Coverage for aiostream/core.py: 92%
176 statements
coverage.py v7.3.2, created at 2023-12-02 23:18 +0000
1"""Core objects for stream operators."""
2from __future__ import annotations
4import inspect
5import functools
6import sys
7import warnings
9from .aiter_utils import AsyncIteratorContext, aiter, assert_async_iterable
10from typing import (
11 Any,
12 AsyncIterator,
13 Callable,
14 Generator,
15 Iterator,
16 Protocol,
17 Union,
18 TypeVar,
19 cast,
20 AsyncIterable,
21 Awaitable,
22)
24from typing_extensions import ParamSpec, Concatenate
27__all__ = ["Stream", "Streamer", "StreamEmpty", "operator", "streamcontext"]
30# Exception
33class StreamEmpty(Exception):
34 """Exception raised when awaiting an empty stream."""
36 pass
39# Helpers
41T = TypeVar("T")
42X = TypeVar("X")
43A = TypeVar("A", contravariant=True)
44P = ParamSpec("P")
45Q = ParamSpec("Q")
47# Hack for python 3.8 compatibility
48if sys.version_info < (3, 9):
49 P = TypeVar("P")
52async def wait_stream(aiterable: BaseStream[T]) -> T:
53 """Wait for an asynchronous iterable to finish and return the last item.
55 The iterable is executed within a safe stream context.
56 A StreamEmpty exception is raised if the sequence is empty.
57 """
59 class Unassigned:
60 pass
62 last_item: Unassigned | T = Unassigned()
64 async with streamcontext(aiterable) as streamer:
65 async for item in streamer:
66 last_item = item
68 if isinstance(last_item, Unassigned):
69 raise StreamEmpty()
70 return last_item
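

# Illustrative sketch (not part of the core API): awaiting a stream goes through
# wait_stream, so it returns the last item and raises StreamEmpty when there is
# none. It assumes the `range` and `empty` operators from the public
# `aiostream.stream` module; the import is kept local to avoid a circular import
# at module load time, and the function is never invoked here.
async def _example_wait_stream() -> None:
    from aiostream import stream

    assert await stream.range(3) == 2  # last item of 0, 1, 2
    try:
        await stream.empty()  # no items: wait_stream raises StreamEmpty
    except StreamEmpty:
        pass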


# Core objects


class BaseStream(AsyncIterable[T], Awaitable[T]):
    """
    Base class for streams.

    See `Stream` and `Streamer` for more information.
    """

    def __init__(self, factory: Callable[[], AsyncIterable[T]]) -> None:
        """Initialize the stream with an asynchronous iterable factory.

        The factory is a callable that takes no argument.
        The factory return value is an asynchronous iterable.
        """
        aiter = factory()
        assert_async_iterable(aiter)
        self._generator = self._make_generator(aiter, factory)

    def _make_generator(
        self, first: AsyncIterable[T], factory: Callable[[], AsyncIterable[T]]
    ) -> Iterator[AsyncIterable[T]]:
        """Generate asynchronous iterables when required.

        The first iterable is created beforehand for extra checking.
        """
        yield first
        del first
        while True:
            yield factory()

    def __await__(self) -> Generator[Any, None, T]:
        """Await protocol.

        Safely iterate and return the last element.
        """
        return wait_stream(self).__await__()

    def __or__(self, func: Callable[[BaseStream[T]], X]) -> X:
        """Pipe protocol.

        Allow the piping of stream operators.
        """
        return func(self)

    def __add__(self, value: AsyncIterable[X]) -> Stream[Union[X, T]]:
        """Addition protocol.

        Concatenate with a given asynchronous sequence.
        """
        from .stream import chain

        return chain(self, value)

    def __getitem__(self, value: Union[int, slice]) -> Stream[T]:
        """Get item protocol.

        Accept an index or a slice to extract the corresponding item(s).
        """
        from .stream import getitem

        return getitem(self, value)

    # Disable sync iteration
    # This is necessary because __getitem__ is defined,
    # which is a valid fallback for for-loops in Python
    __iter__: None = None
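

# Illustrative sketch (not part of the core API): the magic methods above in
# action. It assumes `stream.range`, `stream.just` and `pipe.map` from the
# public aiostream modules; the import is local and the function is never
# invoked here.
async def _example_base_stream_protocols() -> None:
    from aiostream import stream, pipe

    xs = stream.range(5)                  # 0, 1, 2, 3, 4
    ys = xs | pipe.map(lambda x: x * 10)  # __or__: pipe into an operator
    zs = ys + stream.just(999)            # __add__: concatenate sequences
    assert await zs == 999                # __await__: last element
    assert await zs[0] == 0               # __getitem__: index into the stream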


class Stream(BaseStream[T]):
    """Enhanced asynchronous iterable.

    It provides the following features:

    - **Operator pipe-lining** - using pipe symbol ``|``
    - **Repeatability** - every iteration creates a different iterator
    - **Safe iteration context** - using ``async with`` and the ``stream``
      method
    - **Simplified execution** - get the last element from a stream using
      ``await``
    - **Slicing and indexing** - using square brackets ``[]``
    - **Concatenation** - using addition symbol ``+``

    It is not meant to be instantiated directly.
    Use the stream operators instead.

    Example::

        xs = stream.count()    # xs is a stream object
        ys = xs | pipe.skip(5) # pipe xs and skip the first 5 elements
        zs = ys[5:10:2]        # slice ys using start, stop and step

        async with zs.stream() as streamer:  # stream zs in a safe context
            async for z in streamer:         # iterate the zs streamer
                print(z)                     # Prints 10, 12, 14

        result = await zs  # await zs and return its last element
        print(result)      # Prints 14
        result = await zs  # zs can be used several times
        print(result)      # Prints 14
    """

    def stream(self) -> Streamer[T]:
        """Return a streamer context for safe iteration.

        Example::

            xs = stream.count()
            async with xs.stream() as streamer:
                async for item in streamer:
                    <block>

        """
        return self.__aiter__()

    def __aiter__(self) -> Streamer[T]:
        """Asynchronous iteration protocol.

        Return a streamer context for safe iteration.
        """
        return streamcontext(next(self._generator))

    # Advertise the proper syntax for entering a stream context

    __aexit__: None = None

    async def __aenter__(self) -> None:
        raise TypeError(
            "A stream object cannot be used as a context manager. "
            "Use the `stream` method instead: "
            "`async with xs.stream() as streamer`"
        )
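

# Illustrative sketch (not part of the core API): the guard above rejects using
# a stream object directly as an asynchronous context manager; the `stream()`
# method returns the proper streamer context. `stream.range` is assumed from the
# public aiostream module, and the function is never invoked here.
async def _example_stream_context_guard() -> None:
    from aiostream import stream

    xs = stream.range(3)
    try:
        async with xs:  # raises TypeError: use xs.stream() instead
            pass
    except TypeError:
        pass
    async with xs.stream() as streamer:
        async for item in streamer:
            print(item)  # 0, 1, 2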


class Streamer(AsyncIteratorContext[T], BaseStream[T]):
    """Enhanced asynchronous iterator context.

    It is similar to AsyncIteratorContext but provides the stream
    magic methods for concatenation, indexing and awaiting.

    It is not meant to be instantiated directly; use streamcontext instead.

    Example::

        ait = some_asynchronous_iterable()
        async with streamcontext(ait) as streamer:
            async for item in streamer:
                await streamer[5]
    """

    pass


def streamcontext(aiterable: AsyncIterable[T]) -> Streamer[T]:
    """Return a stream context manager from an asynchronous iterable.

    The context management makes sure the aclose asynchronous method
    of the corresponding iterator has run before it exits. It also issues
    warnings and RuntimeError if it is used incorrectly.

    It is safe to use with any asynchronous iterable and prevents
    an asynchronous iterator context from being wrapped twice.

    Correct usage::

        ait = some_asynchronous_iterable()
        async with streamcontext(ait) as streamer:
            async for item in streamer:
                <block>

    For stream objects, it is possible to use the stream method instead::

        xs = stream.count()
        async with xs.stream() as streamer:
            async for item in streamer:
                <block>
    """
    aiterator = aiter(aiterable)
    if isinstance(aiterator, Streamer):
        return aiterator
    return Streamer(aiterator)
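

# Illustrative sketch (not part of the core API): streamcontext wraps a plain
# asynchronous generator into a Streamer and returns an already-wrapped Streamer
# unchanged, as described above. The function is never invoked here.
async def _example_streamcontext() -> None:
    async def numbers() -> AsyncIterator[int]:
        for i in range(3):
            yield i

    async with streamcontext(numbers()) as streamer:
        assert streamcontext(streamer) is streamer  # no double wrapping
        async for item in streamer:
            print(item)  # 0, 1, 2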


# Operator type protocol


class OperatorType(Protocol[P, T]):
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Stream[T]:
        ...

    def raw(self, *args: P.args, **kwargs: P.kwargs) -> AsyncIterator[T]:
        ...


class PipableOperatorType(Protocol[A, P, T]):
    def __call__(
        self, source: AsyncIterable[A], /, *args: P.args, **kwargs: P.kwargs
    ) -> Stream[T]:
        ...

    def raw(
        self, source: AsyncIterable[A], /, *args: P.args, **kwargs: P.kwargs
    ) -> AsyncIterator[T]:
        ...

    def pipe(
        self, *args: P.args, **kwargs: P.kwargs
    ) -> Callable[[AsyncIterable[A]], Stream[T]]:
        ...
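

# Illustrative sketch (not part of the core API): the protocols above describe
# the classes produced by the decorators below, so a helper can be annotated to
# accept "any pipable operator". This generic helper is hypothetical and never
# invoked here.
def _example_apply_operator(
    op: PipableOperatorType[X, P, T],
    source: AsyncIterable[X],
    *args: P.args,
    **kwargs: P.kwargs,
) -> Stream[T]:
    # Both forms below are equivalent for any decorated pipable operator
    piped = op.pipe(*args, **kwargs)  # Callable[[AsyncIterable[X]], Stream[T]]
    return piped(source)              # same as op(source, *args, **kwargs)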


# Operator decorator


def operator(
    func: Callable[P, AsyncIterator[T]] | None = None,
    pipable: bool | None = None,
) -> OperatorType[P, T]:
    """Create a stream operator from an asynchronous generator
    (or any function returning an asynchronous iterable).

    Decorator usage::

        @operator
        async def random(offset=0., width=1.):
            while True:
                yield offset + width * random.random()

    The return value is a dynamically created class.
    It has the same name, module and doc as the original function.

    A new stream is created by simply instantiating the operator::

        xs = random()

    The original function is called at instantiation to check that
    the signature matches. Other methods are available:

    - `original`: the original function as a static method
    - `raw`: same as original but with extra checking

    The `pipable` argument is deprecated, use `pipable_operator` instead.
    """

    # Handle compatibility with legacy (aiostream <= 0.4)
    if pipable is not None or func is None:
        warnings.warn(
            "The `pipable` argument is deprecated. Use either `@operator` or `@pipable_operator` directly.",
            DeprecationWarning,
        )
    if func is None:
        return pipable_operator if pipable else operator  # type: ignore
    if pipable is True:
        return pipable_operator(func)  # type: ignore

    # First check for classmethod instance, to avoid more confusing errors later on
    if isinstance(func, classmethod):
        raise ValueError(
            "An operator cannot be created from a class method, "
            "since the decorated function becomes an operator class"
        )

    # Gather data
    bases = (Stream,)
    name = func.__name__
    module = func.__module__
    extra_doc = func.__doc__
    doc = extra_doc or f"Regular {name} stream operator."

    # Extract signature
    signature = inspect.signature(func)
    parameters = list(signature.parameters.values())
    if parameters and parameters[0].name in ("self", "cls"):
        raise ValueError(
            "An operator cannot be created from a method, "
            "since the decorated function becomes an operator class"
        )

    # Look for "more_sources"
    for i, p in enumerate(parameters):
        if p.name == "more_sources" and p.kind == inspect.Parameter.VAR_POSITIONAL:
            more_sources_index = i
            break
    else:
        more_sources_index = None

    # Injected parameters
    self_parameter = inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
    inspect.Parameter("cls", inspect.Parameter.POSITIONAL_OR_KEYWORD)

    # Wrapped static method
    original = func
    original.__qualname__ = name + ".original"

    # Raw static method
    raw = func
    raw.__qualname__ = name + ".raw"

    # Init method
    def init(self: BaseStream[T], *args: P.args, **kwargs: P.kwargs) -> None:
        if more_sources_index is not None:
            for source in args[more_sources_index:]:
                assert_async_iterable(source)
        factory = functools.partial(raw, *args, **kwargs)
        return BaseStream.__init__(self, factory)

    # Customize init signature
    new_parameters = [self_parameter] + parameters
    init.__signature__ = signature.replace(parameters=new_parameters)  # type: ignore[attr-defined]

    # Customize init method
    init.__qualname__ = name + ".__init__"
    init.__name__ = "__init__"
    init.__module__ = module
    init.__doc__ = f"Initialize the {name} stream."

    # Gather attributes
    attrs = {
        "__init__": init,
        "__module__": module,
        "__doc__": doc,
        "raw": staticmethod(raw),
        "original": staticmethod(original),
    }

    # Create operator class
    return cast("OperatorType[P, T]", type(name, bases, attrs))
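

# Illustrative sketch (not part of the core API): a minimal source operator
# created with the decorator above. The `_example_countdown` operator is
# hypothetical and its usage function is never invoked here.
@operator
async def _example_countdown(start: int = 3) -> AsyncIterator[int]:
    """Yield start, start - 1, ..., 1."""
    for value in range(start, 0, -1):
        yield value


async def _example_countdown_usage() -> None:
    xs = _example_countdown(5)  # instantiation creates a Stream
    async with xs.stream() as streamer:
        async for item in streamer:
            print(item)  # 5, 4, 3, 2, 1
    assert await xs == 1  # awaiting returns the last item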


def pipable_operator(
    func: Callable[Concatenate[AsyncIterable[X], P], AsyncIterator[T]],
) -> PipableOperatorType[X, P, T]:
    """Create a pipable stream operator from an asynchronous generator
    (or any function returning an asynchronous iterable).

    Decorator usage::

        @pipable_operator
        async def multiply(source, factor):
            async with streamcontext(source) as streamer:
                async for item in streamer:
                    yield factor * item

    The first argument is expected to be the asynchronous iterable used
    for piping.

    The return value is a dynamically created class.
    It has the same name, module and doc as the original function.

    A new stream is created by simply instantiating the operator::

        xs = random()
        ys = multiply(xs, 2)

    The original function is called at instantiation to check that
    the signature matches. The source is also checked for asynchronous iteration.

    The operator also has a pipe class method that can be used along
    with the piping syntax::

        xs = random()
        ys = xs | multiply.pipe(2)

    This is strictly equivalent to the previous example.

    Other methods are available:

    - `original`: the original function as a static method
    - `raw`: same as original but with extra checking

    The raw method is useful to create new operators from existing ones::

        @pipable_operator
        def double(source):
            return multiply.raw(source, 2)
    """

    # First check for classmethod instance, to avoid more confusing errors later on
    if isinstance(func, classmethod):
        raise ValueError(
            "An operator cannot be created from a class method, "
            "since the decorated function becomes an operator class"
        )

    # Gather data
    bases = (Stream,)
    name = func.__name__
    module = func.__module__
    extra_doc = func.__doc__
    doc = extra_doc or f"Regular {name} stream operator."

    # Extract signature
    signature = inspect.signature(func)
    parameters = list(signature.parameters.values())
    if parameters and parameters[0].name in ("self", "cls"):
        raise ValueError(
            "An operator cannot be created from a method, "
            "since the decorated function becomes an operator class"
        )

    # Look for "more_sources"
    for i, p in enumerate(parameters):
        if p.name == "more_sources" and p.kind == inspect.Parameter.VAR_POSITIONAL:
            more_sources_index = i
            break
    else:
        more_sources_index = None

    # Injected parameters
    self_parameter = inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
    cls_parameter = inspect.Parameter("cls", inspect.Parameter.POSITIONAL_OR_KEYWORD)

    # Wrapped static method
    original = func
    original.__qualname__ = name + ".original"

    # Raw static method
    def raw(
        arg: AsyncIterable[X], *args: P.args, **kwargs: P.kwargs
    ) -> AsyncIterator[T]:
        assert_async_iterable(arg)
        if more_sources_index is not None:
            for source in args[more_sources_index - 1 :]:
                assert_async_iterable(source)
        return func(arg, *args, **kwargs)

    # Customize raw method
    raw.__signature__ = signature  # type: ignore[attr-defined]
    raw.__qualname__ = name + ".raw"
    raw.__module__ = module
    raw.__doc__ = doc

    # Init method
    def init(
        self: BaseStream[T], arg: AsyncIterable[X], *args: P.args, **kwargs: P.kwargs
    ) -> None:
        assert_async_iterable(arg)
        if more_sources_index is not None:
            for source in args[more_sources_index - 1 :]:
                assert_async_iterable(source)
        factory = functools.partial(raw, arg, *args, **kwargs)
        return BaseStream.__init__(self, factory)

    # Customize init signature
    new_parameters = [self_parameter] + parameters
    init.__signature__ = signature.replace(parameters=new_parameters)  # type: ignore[attr-defined]

    # Customize init method
    init.__qualname__ = name + ".__init__"
    init.__name__ = "__init__"
    init.__module__ = module
    init.__doc__ = f"Initialize the {name} stream."

    # Pipe class method
    def pipe(
        cls: PipableOperatorType[X, P, T],
        /,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Callable[[AsyncIterable[X]], Stream[T]]:
        return lambda source: cls(source, *args, **kwargs)

    # Customize pipe signature
    if parameters and parameters[0].kind in (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    ):
        new_parameters = [cls_parameter] + parameters[1:]
    else:
        new_parameters = [cls_parameter] + parameters
    pipe.__signature__ = signature.replace(parameters=new_parameters)  # type: ignore[attr-defined]

    # Customize pipe method
    pipe.__qualname__ = name + ".pipe"
    pipe.__module__ = module
    pipe.__doc__ = f'Pipable "{name}" stream operator.'
    if extra_doc:
        pipe.__doc__ += "\n\n    " + extra_doc

    # Gather attributes
    attrs = {
        "__init__": init,
        "__module__": module,
        "__doc__": doc,
        "raw": staticmethod(raw),
        "original": staticmethod(original),
        "pipe": classmethod(pipe),  # type: ignore[arg-type]
    }

    # Create operator class
    return cast(
        "PipableOperatorType[X, P, T]",
        type(name, bases, attrs),
    )
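

# Illustrative sketch (not part of the core API): a minimal pipable operator
# created with the decorator above, used both by direct instantiation and with
# the pipe syntax. The `_example_scale` operator is hypothetical, `stream.range`
# is assumed from the public module, and the usage function is never invoked here.
@pipable_operator
async def _example_scale(
    source: AsyncIterable[int], factor: int = 2
) -> AsyncIterator[int]:
    """Multiply every item of the source by the given factor."""
    async with streamcontext(source) as streamer:
        async for item in streamer:
            yield item * factor


async def _example_scale_usage() -> None:
    from aiostream import stream

    xs = stream.range(1, 4)            # 1, 2, 3
    ys = _example_scale(xs, 10)        # direct instantiation
    zs = xs | _example_scale.pipe(10)  # equivalent, using the pipe syntax
    assert await ys == await zs == 30  # last item is 3 * 10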