Coverage for aiostream/aiter_utils.py: 96%
106 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-02 23:18 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-02 23:18 +0000
1"""Utilities for asynchronous iteration."""
2from __future__ import annotations
3from types import TracebackType
5import warnings
6import functools
7from typing import (
8 TYPE_CHECKING,
9 AsyncContextManager,
10 AsyncGenerator,
11 AsyncIterable,
12 Awaitable,
13 Callable,
14 Type,
15 TypeVar,
16 AsyncIterator,
17 Any,
18)
20if TYPE_CHECKING:
21 from typing_extensions import ParamSpec
23 P = ParamSpec("P")
25from contextlib import AsyncExitStack
# Public API of this module, grouped by theme (order is significant only
# for readability; the names are the module's exported helpers).
__all__ = [
    # magic method shortcuts
    "aiter", "anext",
    # async / await helpers
    "await_", "async_",
    # iterability checks
    "is_async_iterable", "assert_async_iterable",
    "is_async_iterator", "assert_async_iterator",
    # async iterator context management
    "AsyncIteratorContext", "aitercontext", "AsyncExitStack",
]
# Magic method shortcuts
def aiter(obj: AsyncIterable[T]) -> AsyncIterator[T]:
    """Return the asynchronous iterator of *obj* (the result of ``__aiter__``).

    Raises:
        TypeError: if *obj* is not an asynchronous iterable.
    """
    assert_async_iterable(obj)
    iterator = obj.__aiter__()
    return iterator
def anext(obj: AsyncIterator[T]) -> Awaitable[T]:
    """Return the awaitable produced by ``obj.__anext__()``.

    Nothing is awaited here: the caller awaits the result to get the item.

    Raises:
        TypeError: if *obj* is not an asynchronous iterator.
    """
    assert_async_iterator(obj)
    next_step = obj.__anext__()
    return next_step
57# Async / await helper functions
async def await_(obj: Awaitable[T]) -> T:
    """Await *obj* and hand its result back unchanged (identity coroutine).

    Useful to turn any awaitable into a genuine coroutine object.
    """
    result = await obj
    return result
def async_(fn: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
    """Turn a function returning an awaitable into a coroutine function.

    The wrapper awaits whatever *fn* returns. It carries *fn*'s metadata
    (``__name__``, ``__doc__``, ``__wrapped__``, ...) via ``functools.wraps``.
    """

    async def coroutine_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        awaitable = fn(*args, **kwargs)
        return await awaitable

    return functools.wraps(fn)(coroutine_wrapper)
75# Iterability helpers
def is_async_iterable(obj: object) -> bool:
    """Tell whether *obj* exposes ``__aiter__``, i.e. is an async iterable."""
    try:
        getattr(obj, "__aiter__")
    except AttributeError:
        return False
    return True
def assert_async_iterable(obj: object) -> None:
    """Ensure *obj* is an asynchronous iterable.

    Raises:
        TypeError: if *obj* does not provide ``__aiter__``.
    """
    if is_async_iterable(obj):
        return
    raise TypeError(f"{type(obj).__name__!r} object is not async iterable")
def is_async_iterator(obj: object) -> bool:
    """Tell whether *obj* exposes ``__anext__``, i.e. is an async iterator."""
    missing = object()
    return getattr(obj, "__anext__", missing) is not missing
def assert_async_iterator(obj: object) -> None:
    """Ensure *obj* is an asynchronous iterator.

    Raises:
        TypeError: if *obj* does not provide ``__anext__``.
    """
    if is_async_iterator(obj):
        return
    raise TypeError(f"{type(obj).__name__!r} object is not an async iterator")
104# Async iterator context
# Generic item type used in the signatures of this module's helpers.
T = TypeVar("T")
# Type of ``self`` for AsyncIteratorContext methods that return the instance
# itself (``__aiter__``, ``__aenter__``), so subclasses keep their own type.
Self = TypeVar("Self", bound="AsyncIteratorContext[Any]")
class AsyncIteratorContext(AsyncIterator[T], AsyncContextManager[Any]):
    """Asynchronous iterator with context management.

    The context management makes sure the aclose asynchronous method
    of the corresponding iterator has run before it exits. It also issues
    warnings and RuntimeError if it is used incorrectly.

    Correct usage::

        ait = some_asynchronous_iterable()
        async with AsyncIteratorContext(ait) as safe_ait:
            async for item in safe_ait:
                <block>

    It is nonetheless not meant to be used directly.
    Prefer the aitercontext helper instead.
    """

    # Lifecycle states: STANDBY (created, not entered yet), RUNNING (inside
    # ``async with``), FINISHED (exited or closed; can no longer be used).
    _STANDBY = "STANDBY"
    _RUNNING = "RUNNING"
    _FINISHED = "FINISHED"

    def __init__(self, aiterator: AsyncIterator[T]):
        """Initialize with an asynchronous iterator.

        Raises:
            TypeError: if *aiterator* is not an asynchronous iterator, or if
                it is already an AsyncIteratorContext (no double wrapping).
        """
        assert_async_iterator(aiterator)
        if isinstance(aiterator, AsyncIteratorContext):
            raise TypeError(f"{aiterator!r} is already an AsyncIteratorContext")
        self._state = self._STANDBY
        self._aiterator = aiterator

    def __aiter__(self: Self) -> Self:
        return self

    def __anext__(self) -> Awaitable[T]:
        """Return the awaitable for the next item of the wrapped iterator.

        Raises RuntimeError once the context is finished, and warns when the
        iterator is used before the context has been entered.
        """
        if self._state == self._FINISHED:
            raise RuntimeError(
                f"{type(self).__name__} is closed and cannot be iterated"
            )
        if self._state == self._STANDBY:
            # stacklevel=2 attributes the warning to the code requesting items.
            warnings.warn(
                f"{type(self).__name__} is iterated outside of its context",
                stacklevel=2,
            )
        return anext(self._aiterator)

    async def __aenter__(self: Self) -> Self:
        """Enter the context; it can only be entered once.

        Raises:
            RuntimeError: if the context is already running or finished.
        """
        if self._state == self._RUNNING:
            raise RuntimeError(f"{type(self).__name__} has already been entered")
        if self._state == self._FINISHED:
            raise RuntimeError(
                f"{type(self).__name__} is closed and cannot be iterated"
            )
        self._state = self._RUNNING
        return self

    async def __aexit__(
        self,
        typ: Type[BaseException] | None,
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        """Forward a pending exception to the wrapped iterator, then close it.

        Returns True when the exception, thrown into the wrapped iterator,
        made it stop (the exception has been silenced), False otherwise.
        Whatever happens, the context ends up FINISHED.
        """
        try:
            if self._state == self._FINISHED:
                return False
            try:
                return await self._throw_into_iterator(typ, value, traceback)
            finally:
                await self._close_iterator()
        finally:
            self._state = self._FINISHED

    async def _throw_into_iterator(
        self,
        typ: Type[BaseException] | None,
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        """Throw the pending exception into the wrapped iterator if possible.

        Returns whether ``__aexit__`` should suppress the exception.
        """
        # No exception to throw
        if typ is None:
            return False

        # Prevent GeneratorExit from being silenced
        if typ is GeneratorExit:
            return False

        # No method to throw
        if not hasattr(self._aiterator, "athrow"):
            return False

        # No frame to throw
        if not getattr(self._aiterator, "ag_frame", True):
            return False

        # Cannot throw at the moment
        if getattr(self._aiterator, "ag_running", False):
            return False

        # Throw
        try:
            # Use the single-argument athrow(): the (type, value, traceback)
            # signature is deprecated since Python 3.12. Build an instance
            # when only the type was given, and attach the traceback like the
            # three-argument form did. ``value`` itself is left untouched so
            # the identity checks below keep their meaning.
            to_throw = typ() if value is None else value
            if traceback is not None:
                to_throw = to_throw.with_traceback(traceback)
            athrow = self._aiterator.athrow  # type: ignore[attr-defined]
            await athrow(to_throw)
            raise RuntimeError("Async iterator didn't stop after athrow()")

        # Exception has been (most probably) silenced
        except StopAsyncIteration as exc:
            return exc is not value

        # A (possibly new) exception has been raised
        except BaseException as exc:
            if exc is value:
                return False
            raise

    async def _close_iterator(self) -> None:
        """Await the wrapped iterator's aclose() method when it is safe to."""
        # Look for an aclose method
        aclose = getattr(self._aiterator, "aclose", None)

        # The ag_running attribute has been introduced with python 3.8
        running = getattr(self._aiterator, "ag_running", False)
        closed = not getattr(self._aiterator, "ag_frame", True)

        # A RuntimeError is raised if aiterator is running or closed
        if aclose and not running and not closed:
            try:
                await aclose()

            # Work around bpo-35409
            except GeneratorExit:
                pass  # pragma: no cover

    async def aclose(self) -> None:
        """Finish the context, closing the wrapped iterator if possible."""
        await self.__aexit__(None, None, None)

    async def athrow(self, exc: Exception) -> T:
        """Throw *exc* into the wrapped iterator and return what it yields next.

        Raises:
            RuntimeError: if the context is already finished.
            TypeError: if the wrapped iterator has no ``athrow`` method.
        """
        if self._state == self._FINISHED:
            raise RuntimeError(f"{type(self).__name__} is closed and cannot be used")
        athrow = getattr(self._aiterator, "athrow", None)
        if athrow is None:
            raise TypeError(
                f"{type(self._aiterator).__name__!r} object does not support athrow()"
            )
        item: T = await athrow(exc)
        return item
def aitercontext(
    aiterable: AsyncIterable[T],
) -> AsyncIteratorContext[T]:
    """Return an asynchronous context manager from an asynchronous iterable.

    The context management makes sure the aclose asynchronous method
    has run before it exits. It also issues warnings and RuntimeError
    if it is used incorrectly.

    It is safe to use with any asynchronous iterable, and it prevents an
    asynchronous iterator context from being wrapped twice.

    Correct usage::

        ait = some_asynchronous_iterable()
        async with aitercontext(ait) as safe_ait:
            async for item in safe_ait:
                <block>
    """
    iterator = aiter(aiterable)
    # An existing context is handed back as-is instead of being re-wrapped.
    already_wrapped = isinstance(iterator, AsyncIteratorContext)
    return iterator if already_wrapped else AsyncIteratorContext(iterator)