Coverage for aiostream/aiter_utils.py: 96%

106 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-02 23:18 +0000

1"""Utilities for asynchronous iteration.""" 

2from __future__ import annotations 

3from types import TracebackType 

4 

5import warnings 

6import functools 

7from typing import ( 

8 TYPE_CHECKING, 

9 AsyncContextManager, 

10 AsyncGenerator, 

11 AsyncIterable, 

12 Awaitable, 

13 Callable, 

14 Type, 

15 TypeVar, 

16 AsyncIterator, 

17 Any, 

18) 

19 

if TYPE_CHECKING:
    # ParamSpec is imported from typing_extensions only for the type checker
    # (TYPE_CHECKING is False at runtime). This is safe because
    # ``from __future__ import annotations`` keeps the ``P.args`` /
    # ``P.kwargs`` annotations used by ``async_`` unevaluated at runtime.
    from typing_extensions import ParamSpec

    P = ParamSpec("P")

24 

25from contextlib import AsyncExitStack 

26 

__all__ = [
    # Magic method shortcuts
    "aiter",
    "anext",
    # Async / await helpers
    "await_",
    "async_",
    # Iterability helpers
    "is_async_iterable",
    "assert_async_iterable",
    "is_async_iterator",
    "assert_async_iterator",
    # Async iterator context
    "AsyncIteratorContext",
    "aitercontext",
    # Re-exported from contextlib
    "AsyncExitStack",
]

40 

41 

# Magic method shortcuts

43 

44 

def aiter(obj: AsyncIterable[T]) -> AsyncIterator[T]:
    """Return the asynchronous iterator produced by *obj*'s ``__aiter__``.

    Raises:
        TypeError: if *obj* does not define ``__aiter__``.
    """
    assert_async_iterable(obj)
    iterator = obj.__aiter__()
    return iterator

49 

50 

def anext(obj: AsyncIterator[T]) -> Awaitable[T]:
    """Return the awaitable produced by *obj*'s ``__anext__`` method.

    Raises:
        TypeError: if *obj* does not define ``__anext__``.
    """
    assert_async_iterator(obj)
    next_awaitable = obj.__anext__()
    return next_awaitable

55 

56 

57# Async / await helper functions 

58 

59 

async def await_(obj: Awaitable[T]) -> T:
    """Await *obj* and hand back its result unchanged (identity coroutine)."""
    result = await obj
    return result

63 

64 

def async_(fn: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
    """Turn *fn*, any callable returning an awaitable, into a coroutine function.

    The returned coroutine function forwards every argument to *fn*, awaits
    whatever *fn* returns and yields that result. Metadata such as
    ``__name__`` and ``__doc__`` is copied from *fn* via ``functools.wraps``.
    """

    async def _coroutine_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        awaitable = fn(*args, **kwargs)
        return await awaitable

    return functools.wraps(fn)(_coroutine_wrapper)

73 

74 

75# Iterability helpers 

76 

77 

def is_async_iterable(obj: object) -> bool:
    """Tell whether *obj* exposes an ``__aiter__`` attribute.

    Only ``AttributeError`` counts as "missing"; any other error raised while
    looking the attribute up propagates to the caller.
    """
    try:
        getattr(obj, "__aiter__")
    except AttributeError:
        return False
    return True

81 

82 

def assert_async_iterable(obj: object) -> None:
    """Ensure *obj* is an asynchronous iterable.

    Raises:
        TypeError: if ``is_async_iterable(obj)`` is false.
    """
    if is_async_iterable(obj):
        return
    type_name = type(obj).__name__
    raise TypeError(f"{type_name!r} object is not async iterable")

89 

90 

def is_async_iterator(obj: object) -> bool:
    """Tell whether *obj* exposes an ``__anext__`` attribute."""
    # A private sentinel distinguishes "attribute missing" from any real
    # value (including None) the attribute might hold.
    missing = object()
    return getattr(obj, "__anext__", missing) is not missing

94 

95 

def assert_async_iterator(obj: object) -> None:
    """Ensure *obj* is an asynchronous iterator.

    Raises:
        TypeError: if ``is_async_iterator(obj)`` is false.
    """
    if is_async_iterator(obj):
        return
    raise TypeError(
        f"{type(obj).__name__!r} object is not an async iterator"
    )

102 

103 

104# Async iterator context 

105 

106T = TypeVar("T") 

107Self = TypeVar("Self", bound="AsyncIteratorContext[Any]") 

108 

109 

110class AsyncIteratorContext(AsyncIterator[T], AsyncContextManager[Any]): 

111 """Asynchronous iterator with context management. 

112 

113 The context management makes sure the aclose asynchronous method 

114 of the corresponding iterator has run before it exits. It also issues 

115 warnings and RuntimeError if it is used incorrectly. 

116 

117 Correct usage:: 

118 

119 ait = some_asynchronous_iterable() 

120 async with AsyncIteratorContext(ait) as safe_ait: 

121 async for item in safe_ait: 

122 <block> 

123 

124 It is nonetheless not meant to use directly. 

125 Prefer aitercontext helper instead. 

126 """ 

127 

128 _STANDBY = "STANDBY" 

129 _RUNNING = "RUNNING" 

130 _FINISHED = "FINISHED" 

131 

132 def __init__(self, aiterator: AsyncIterator[T]): 

133 """Initialize with an asynchrnous iterator.""" 

134 assert_async_iterator(aiterator) 

135 if isinstance(aiterator, AsyncIteratorContext): 

136 raise TypeError(f"{aiterator!r} is already an AsyncIteratorContext") 

137 self._state = self._STANDBY 

138 self._aiterator = aiterator 

139 

140 def __aiter__(self: Self) -> Self: 

141 return self 

142 

143 def __anext__(self) -> Awaitable[T]: 

144 if self._state == self._FINISHED: 

145 raise RuntimeError( 

146 f"{type(self).__name__} is closed and cannot be iterated" 

147 ) 

148 if self._state == self._STANDBY: 

149 warnings.warn( 

150 f"{type(self).__name__} is iterated outside of its context", 

151 stacklevel=2, 

152 ) 

153 return anext(self._aiterator) 

154 

155 async def __aenter__(self: Self) -> Self: 

156 if self._state == self._RUNNING: 

157 raise RuntimeError(f"{type(self).__name__} has already been entered") 

158 if self._state == self._FINISHED: 

159 raise RuntimeError( 

160 f"{type(self).__name__} is closed and cannot be iterated" 

161 ) 

162 self._state = self._RUNNING 

163 return self 

164 

165 async def __aexit__( 

166 self, 

167 typ: Type[BaseException] | None, 

168 value: BaseException | None, 

169 traceback: TracebackType | None, 

170 ) -> bool: 

171 try: 

172 if self._state == self._FINISHED: 

173 return False 

174 try: 

175 # No exception to throw 

176 if typ is None: 

177 return False 

178 

179 # Prevent GeneratorExit from being silenced 

180 if typ is GeneratorExit: 

181 return False 

182 

183 # No method to throw 

184 if not hasattr(self._aiterator, "athrow"): 

185 return False 

186 

187 # No frame to throw 

188 if not getattr(self._aiterator, "ag_frame", True): 

189 return False 

190 

191 # Cannot throw at the moment 

192 if getattr(self._aiterator, "ag_running", False): 

193 return False 

194 

195 # Throw 

196 try: 

197 assert isinstance(self._aiterator, AsyncGenerator) 

198 await self._aiterator.athrow(typ, value, traceback) 

199 raise RuntimeError("Async iterator didn't stop after athrow()") 

200 

201 # Exception has been (most probably) silenced 

202 except StopAsyncIteration as exc: 

203 return exc is not value 

204 

205 # A (possibly new) exception has been raised 

206 except BaseException as exc: 

207 if exc is value: 

208 return False 

209 raise 

210 finally: 

211 # Look for an aclose method 

212 aclose = getattr(self._aiterator, "aclose", None) 

213 

214 # The ag_running attribute has been introduced with python 3.8 

215 running = getattr(self._aiterator, "ag_running", False) 

216 closed = not getattr(self._aiterator, "ag_frame", True) 

217 

218 # A RuntimeError is raised if aiterator is running or closed 

219 if aclose and not running and not closed: 

220 try: 

221 await aclose() 

222 

223 # Work around bpo-35409 

224 except GeneratorExit: 

225 pass # pragma: no cover 

226 finally: 

227 self._state = self._FINISHED 

228 

229 async def aclose(self) -> None: 

230 await self.__aexit__(None, None, None) 

231 

232 async def athrow(self, exc: Exception) -> T: 

233 if self._state == self._FINISHED: 

234 raise RuntimeError(f"{type(self).__name__} is closed and cannot be used") 

235 assert isinstance(self._aiterator, AsyncGenerator) 

236 item: T = await self._aiterator.athrow(exc) 

237 return item 

238 

239 

def aitercontext(
    aiterable: AsyncIterable[T],
) -> AsyncIteratorContext[T]:
    """Return an asynchronous context manager from an asynchronous iterable.

    The context management makes sure the aclose asynchronous method
    has run before it exits. It also issues warnings and RuntimeError
    if it is used incorrectly.

    It is safe to use with any asynchronous iterable and prevents an
    asynchronous iterator context from being wrapped twice: if the
    iterable's iterator is already an AsyncIteratorContext, it is
    returned as-is.

    Correct usage::

        ait = some_asynchronous_iterable()
        async with aitercontext(ait) as safe_ait:
            async for item in safe_ait:
                <block>
    """
    iterator = aiter(aiterable)
    return (
        iterator
        if isinstance(iterator, AsyncIteratorContext)
        else AsyncIteratorContext(iterator)
    )