Coverage for aiostream/core.py: 92%

176 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-02 23:18 +0000

1"""Core objects for stream operators.""" 

2from __future__ import annotations 

3 

4import inspect 

5import functools 

6import sys 

7import warnings 

8 

9from .aiter_utils import AsyncIteratorContext, aiter, assert_async_iterable 

10from typing import ( 

11 Any, 

12 AsyncIterator, 

13 Callable, 

14 Generator, 

15 Iterator, 

16 Protocol, 

17 Union, 

18 TypeVar, 

19 cast, 

20 AsyncIterable, 

21 Awaitable, 

22) 

23 

24from typing_extensions import ParamSpec, Concatenate 

25 

26 

27__all__ = ["Stream", "Streamer", "StreamEmpty", "operator", "streamcontext"] 

28 

29 

30# Exception 

31 

32 

class StreamEmpty(Exception):
    """Signal that an awaited stream finished without producing any item."""

37 

38 

39# Helpers 

40 

41T = TypeVar("T") 

42X = TypeVar("X") 

43A = TypeVar("A", contravariant=True) 

44P = ParamSpec("P") 

45Q = ParamSpec("Q") 

46 

47# Hack for python 3.8 compatibility 

48if sys.version_info < (3, 9): 

49 P = TypeVar("P") 

50 

51 

async def wait_stream(aiterable: BaseStream[T]) -> T:
    """Run an asynchronous iterable to completion and return its final item.

    The iteration takes place inside the context manager returned by
    ``streamcontext``.

    Raises:
        StreamEmpty: if the iterable ends without yielding a single item.
    """
    # A private sentinel tells "nothing was produced" apart from any real
    # item, including ``None``.
    missing = object()
    latest: Any = missing

    async with streamcontext(aiterable) as streamer:
        async for item in streamer:
            latest = item

    if latest is missing:
        raise StreamEmpty()
    return latest

71 

72 

73# Core objects 

74 

75 

class BaseStream(AsyncIterable[T], Awaitable[T]):
    """Common base class of ``Stream`` and ``Streamer``.

    It keeps a generator of asynchronous iterables built from a factory and
    implements the await (``await xs``), pipe (``xs | op``), addition
    (``xs + ys``) and item access (``xs[i]`` / ``xs[a:b]``) protocols.
    """

    def __init__(self, factory: Callable[[], AsyncIterable[T]]) -> None:
        """Set up the stream from an asynchronous iterable factory.

        Args:
            factory: a callable taking no argument and returning an
                asynchronous iterable.

        The factory is called once right away and its result is validated
        with ``assert_async_iterable``; later iterables are produced on
        demand by the internal generator.
        """
        first_iterable = factory()
        assert_async_iterable(first_iterable)
        self._generator = self._make_generator(first_iterable, factory)

    def _make_generator(
        self, first: AsyncIterable[T], factory: Callable[[], AsyncIterable[T]]
    ) -> Iterator[AsyncIterable[T]]:
        """Yield asynchronous iterables, one per request, forever.

        The first value is the iterable that was already created (and
        checked) by ``__init__``; every following value is a fresh call to
        ``factory``.
        """
        yield first
        # Drop the local reference so this generator frame does not keep
        # the first iterable alive once it has been handed out.
        del first
        while True:
            yield factory()

    def __await__(self) -> Generator[Any, None, T]:
        """Await protocol.

        Delegate to ``wait_stream``, which iterates the stream and returns
        its last element.
        """
        waiter = wait_stream(self)
        return waiter.__await__()

    def __or__(self, func: Callable[[BaseStream[T]], X]) -> X:
        """Pipe protocol.

        ``xs | func`` evaluates to ``func(xs)``, which allows stream
        operators to be chained.
        """
        return func(self)

    def __add__(self, value: AsyncIterable[X]) -> Stream[Union[X, T]]:
        """Addition protocol.

        ``xs + ys`` concatenates this stream with the given asynchronous
        sequence using the ``chain`` operator.
        """
        # Imported at call time rather than at module level
        # (presumably to avoid a circular import -- TODO confirm).
        from .stream import chain

        return chain(self, value)

    def __getitem__(self, value: Union[int, slice]) -> Stream[T]:
        """Get item protocol.

        Accept an index or a slice and delegate to the ``getitem``
        operator to extract the corresponding item(s).
        """
        # Imported at call time rather than at module level
        # (presumably to avoid a circular import -- TODO confirm).
        from .stream import getitem

        return getitem(self, value)

    # Synchronous iteration is explicitly disabled. This is required
    # because ``__getitem__`` is defined, and python accepts it as a valid
    # fallback for regular for-loops.
    __iter__: None = None

141 

142 

class Stream(BaseStream[T]):
    """Enhanced asynchronous iterable.

    Features:

    - **Operator pipe-lining** - using the pipe symbol ``|``
    - **Repeatability** - every iteration creates a different iterator
    - **Safe iteration context** - using ``async with`` and the ``stream``
      method
    - **Simplified execution** - get the last element from a stream using
      ``await``
    - **Slicing and indexing** - using square brackets ``[]``
    - **Concatenation** - using the addition symbol ``+``

    This class is not meant to be instantiated directly; use the stream
    operators instead.

    Example::

        xs = stream.count()     # xs is a stream object
        ys = xs | pipe.skip(5)  # pipe xs and skip the first 5 elements
        zs = ys[5:10:2]         # slice ys using start, stop and step

        async with zs.stream() as streamer:  # stream zs in a safe context
            async for z in streamer:         # iterate the zs streamer
                print(z)                     # Prints 10, 12, 14

        result = await zs  # await zs and return its last element
        print(result)      # Prints 14
        result = await zs  # zs can be used several times
        print(result)      # Prints 14
    """

    def stream(self) -> Streamer[T]:
        """Return a streamer context for safe iteration.

        This is the same object ``__aiter__`` returns.

        Example::

            xs = stream.count()
            async with xs.stream() as streamer:
                async for item in streamer:
                    ...
        """
        return self.__aiter__()

    def __aiter__(self) -> Streamer[T]:
        """Asynchronous iteration protocol.

        Take the next asynchronous iterable from the internal generator and
        wrap it with ``streamcontext``, so that each call yields a separate
        streamer.
        """
        next_iterable = next(self._generator)
        return streamcontext(next_iterable)

    # Advertise the proper syntax for entering a stream context:
    # ``__aexit__`` is present but set to ``None`` and ``__aenter__`` always
    # raises, so ``async with xs`` fails with an explanatory message.

    __aexit__: None = None

    async def __aenter__(self) -> None:
        """Always raise ``TypeError`` pointing the user to ``stream()``."""
        message = (
            "A stream object cannot be used as a context manager. "
            "Use the `stream` method instead: "
            "`async with xs.stream() as streamer`"
        )
        raise TypeError(message)

206 

207 

class Streamer(AsyncIteratorContext[T], BaseStream[T]):
    """Enhanced asynchronous iterator context.

    It combines ``AsyncIteratorContext`` with ``BaseStream``, so on top of
    the iterator context it offers the stream magic methods for
    concatenation, indexing and awaiting.

    It is not meant to be instantiated directly; use ``streamcontext``
    instead.

    Example::

        ait = some_asynchronous_iterable()
        async with streamcontext(ait) as streamer:
            async for item in streamer:
                await streamer[5]
    """

225 

226 

def streamcontext(aiterable: AsyncIterable[T]) -> Streamer[T]:
    """Build a stream context manager from an asynchronous iterable.

    The context management makes sure the ``aclose`` asynchronous method of
    the corresponding iterator has run before it exits. It also issues
    warnings and ``RuntimeError`` if it is used incorrectly.

    It is safe to use with any asynchronous iterable: an iterator that is
    already a ``Streamer`` is returned unchanged, which prevents an
    asynchronous iterator context from being wrapped twice.

    Correct usage::

        ait = some_asynchronous_iterable()
        async with streamcontext(ait) as streamer:
            async for item in streamer:
                ...

    For stream objects, it is possible to use the ``stream`` method
    instead::

        xs = stream.count()
        async with xs.stream() as streamer:
            async for item in streamer:
                ...
    """
    candidate = aiter(aiterable)
    # Reuse an existing streamer as-is; only wrap plain iterators.
    return candidate if isinstance(candidate, Streamer) else Streamer(candidate)

255 

256 

257# Operator type protocol 

258 

259 

class OperatorType(Protocol[P, T]):
    """Static type of the operator classes returned by ``operator``.

    ``P`` describes the parameters of the decorated function and ``T`` the
    type of the items it produces.
    """

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Stream[T]:
        """Instantiate the operator and return the resulting stream."""
        ...

    def raw(self, *args: P.args, **kwargs: P.kwargs) -> AsyncIterator[T]:
        """Return the underlying asynchronous iterator for the arguments."""
        ...

266 

267 

class PipableOperatorType(Protocol[A, P, T]):
    """Static type of the operator classes returned by ``pipable_operator``.

    ``A`` is the item type of the source, ``P`` describes the parameters
    following the source and ``T`` is the type of the produced items.
    """

    def __call__(
        self, source: AsyncIterable[A], /, *args: P.args, **kwargs: P.kwargs
    ) -> Stream[T]:
        """Instantiate the operator on ``source`` and return the stream."""
        ...

    def raw(
        self, source: AsyncIterable[A], /, *args: P.args, **kwargs: P.kwargs
    ) -> AsyncIterator[T]:
        """Return the underlying asynchronous iterator for ``source``."""
        ...

    def pipe(
        self, *args: P.args, **kwargs: P.kwargs
    ) -> Callable[[AsyncIterable[A]], Stream[T]]:
        """Return a callable that applies the operator to a given source."""
        ...

283 

284 

285# Operator decorator 

286 

287 

def operator(
    func: Callable[P, AsyncIterator[T]] | None = None,
    pipable: bool | None = None,
) -> OperatorType[P, T]:
    """Create a stream operator from an asynchronous generator
    (or any function returning an asynchronous iterable).

    Decorator usage::

        @operator
        async def random(offset=0., width=1.):
            while True:
                yield offset + width * random.random()

    The return value is a dynamically created class deriving from
    ``Stream``. It has the same name, module and doc as the original
    function.

    A new stream is created by simply instantiating the operator::

        xs = random()

    The original function is called at instantiation, which checks that
    the arguments match its signature. Arguments passed for a variadic
    ``*more_sources`` parameter are additionally checked for asynchronous
    iteration. Other attributes are available on the class:

    - `original`: the original function as a static method
    - `raw`: for this decorator, the very same function as `original`

    Args:
        func: the function to turn into an operator. ``None`` is only
            accepted for the legacy ``@operator(pipable=...)`` form.
        pipable: deprecated, use `pipable_operator` instead.

    Raises:
        ValueError: if ``func`` is a ``classmethod`` object or if its first
            parameter is named ``self`` or ``cls``.
    """

    # Handle compatibility with legacy (aiostream <= 0.4)
    if pipable is not None or func is None:
        warnings.warn(
            "The `pipable` argument is deprecated. Use either `@operator` or `@pipable_operator` directly.",
            DeprecationWarning,
            # Attribute the warning to the code using the deprecated form
            # rather than to this module.
            stacklevel=2,
        )
    if func is None:
        return pipable_operator if pipable else operator  # type: ignore
    if pipable is True:
        return pipable_operator(func)  # type: ignore

    # First check for classmethod instance, to avoid more confusing errors later on
    if isinstance(func, classmethod):
        raise ValueError(
            "An operator cannot be created from a class method, "
            "since the decorated function becomes an operator class"
        )

    # Gather data
    bases = (Stream,)
    name = func.__name__
    module = func.__module__
    extra_doc = func.__doc__
    doc = extra_doc or f"Regular {name} stream operator."

    # Extract signature
    signature = inspect.signature(func)
    parameters = list(signature.parameters.values())
    if parameters and parameters[0].name in ("self", "cls"):
        raise ValueError(
            "An operator cannot be created from a method, "
            "since the decorated function becomes an operator class"
        )

    # Look for "more_sources": position of a variadic ``*more_sources``
    # parameter, or None when the function does not declare one
    more_sources_index = next(
        (
            index
            for index, parameter in enumerate(parameters)
            if parameter.name == "more_sources"
            and parameter.kind == inspect.Parameter.VAR_POSITIONAL
        ),
        None,
    )

    # Injected parameter. A signature rejects a positional-or-keyword
    # parameter placed before a positional-only one, so ``self`` has to be
    # positional-only whenever the first function parameter is.
    first_is_positional_only = (
        bool(parameters)
        and parameters[0].kind == inspect.Parameter.POSITIONAL_ONLY
    )
    self_parameter = inspect.Parameter(
        "self",
        inspect.Parameter.POSITIONAL_ONLY
        if first_is_positional_only
        else inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )

    # Wrapped and raw static methods.
    # NOTE: both names refer to the decorated function itself, so they share
    # a single ``__qualname__``; the ".raw" value is the one that ends up
    # being reported for both.
    original = raw = func
    func.__qualname__ = name + ".raw"

    # Init method
    def init(self: BaseStream[T], *args: P.args, **kwargs: P.kwargs) -> None:
        if more_sources_index is not None:
            for source in args[more_sources_index:]:
                assert_async_iterable(source)
        factory = functools.partial(raw, *args, **kwargs)
        return BaseStream.__init__(self, factory)

    # Customize init signature
    new_parameters = [self_parameter] + parameters
    init.__signature__ = signature.replace(parameters=new_parameters)  # type: ignore[attr-defined]

    # Customize init method
    init.__qualname__ = name + ".__init__"
    init.__name__ = "__init__"
    init.__module__ = module
    init.__doc__ = f"Initialize the {name} stream."

    # Gather attributes
    attrs = {
        "__init__": init,
        "__module__": module,
        "__doc__": doc,
        "raw": staticmethod(raw),
        "original": staticmethod(original),
    }

    # Create operator class
    return cast("OperatorType[P, T]", type(name, bases, attrs))

401 

402 

def pipable_operator(
    func: Callable[Concatenate[AsyncIterable[X], P], AsyncIterator[T]],
) -> PipableOperatorType[X, P, T]:
    """Create a pipable stream operator from an asynchronous generator
    (or any function returning an asynchronous iterable).

    Decorator usage::

        @pipable_operator
        async def multiply(source, factor):
            async with streamcontext(source) as streamer:
                async for item in streamer:
                    yield factor * item

    The first argument is expected to be the asynchronous iterable used
    for piping.

    The return value is a dynamically created class deriving from
    ``Stream``. It has the same name, module and doc as the original
    function.

    A new stream is created by simply instantiating the operator::

        xs = random()
        ys = multiply(xs, 2)

    The original function is called at instantiation, which checks that
    the arguments match its signature. The source (and any argument passed
    for a variadic ``*more_sources`` parameter) is also checked for
    asynchronous iteration.

    The operator also has a pipe class method that can be used along
    with the piping syntax::

        xs = random()
        ys = xs | multiply.pipe(2)

    This is strictly equivalent to the previous example.

    Other methods are available:

    - `original`: the original function as a static method
    - `raw`: same as original but adds the asynchronous iteration checks

    The raw method is useful to create new operators from existing ones::

        @pipable_operator
        def double(source):
            return multiply.raw(source, 2)

    Raises:
        ValueError: if ``func`` is a ``classmethod`` object or if its first
            parameter is named ``self`` or ``cls``.
    """

    def injected_parameter(
        param_name: str, following: list[inspect.Parameter]
    ) -> inspect.Parameter:
        # Build the implicit ``self``/``cls`` parameter placed in front of
        # ``following``. A signature rejects a positional-or-keyword
        # parameter placed before a positional-only one, so the injected
        # parameter has to be positional-only whenever the next one is.
        if following and following[0].kind == inspect.Parameter.POSITIONAL_ONLY:
            kind = inspect.Parameter.POSITIONAL_ONLY
        else:
            kind = inspect.Parameter.POSITIONAL_OR_KEYWORD
        return inspect.Parameter(param_name, kind)

    # First check for classmethod instance, to avoid more confusing errors later on
    if isinstance(func, classmethod):
        raise ValueError(
            "An operator cannot be created from a class method, "
            "since the decorated function becomes an operator class"
        )

    # Gather data
    bases = (Stream,)
    name = func.__name__
    module = func.__module__
    extra_doc = func.__doc__
    doc = extra_doc or f"Regular {name} stream operator."

    # Extract signature
    signature = inspect.signature(func)
    parameters = list(signature.parameters.values())
    if parameters and parameters[0].name in ("self", "cls"):
        raise ValueError(
            "An operator cannot be created from a method, "
            "since the decorated function becomes an operator class"
        )

    # Look for "more_sources": position of a variadic ``*more_sources``
    # parameter, or None when the function does not declare one
    more_sources_index = next(
        (
            index
            for index, parameter in enumerate(parameters)
            if parameter.name == "more_sources"
            and parameter.kind == inspect.Parameter.VAR_POSITIONAL
        ),
        None,
    )

    # Wrapped static method
    original = func
    original.__qualname__ = name + ".original"

    # Raw static method
    def raw(
        arg: AsyncIterable[X], *args: P.args, **kwargs: P.kwargs
    ) -> AsyncIterator[T]:
        assert_async_iterable(arg)
        if more_sources_index is not None:
            # ``args`` does not include the source, hence the shifted index
            for source in args[more_sources_index - 1 :]:
                assert_async_iterable(source)
        return func(arg, *args, **kwargs)

    # Customize raw method
    raw.__signature__ = signature  # type: ignore[attr-defined]
    raw.__qualname__ = name + ".raw"
    raw.__module__ = module
    raw.__doc__ = doc

    # Init method
    def init(
        self: BaseStream[T], arg: AsyncIterable[X], *args: P.args, **kwargs: P.kwargs
    ) -> None:
        assert_async_iterable(arg)
        if more_sources_index is not None:
            # ``args`` does not include the source, hence the shifted index
            for source in args[more_sources_index - 1 :]:
                assert_async_iterable(source)
        factory = functools.partial(raw, arg, *args, **kwargs)
        return BaseStream.__init__(self, factory)

    # Customize init signature
    init_parameters = [injected_parameter("self", parameters)] + parameters
    init.__signature__ = signature.replace(parameters=init_parameters)  # type: ignore[attr-defined]

    # Customize init method
    init.__qualname__ = name + ".__init__"
    init.__name__ = "__init__"
    init.__module__ = module
    init.__doc__ = f"Initialize the {name} stream."

    # Pipe class method
    def pipe(
        cls: PipableOperatorType[X, P, T],
        /,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Callable[[AsyncIterable[X]], Stream[T]]:
        return lambda source: cls(source, *args, **kwargs)

    # Customize pipe signature: the source parameter is provided by the
    # pipe itself, so it is removed when it is a regular positional one
    if parameters and parameters[0].kind in (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    ):
        remaining_parameters = parameters[1:]
    else:
        remaining_parameters = parameters
    pipe_parameters = [
        injected_parameter("cls", remaining_parameters)
    ] + remaining_parameters
    pipe.__signature__ = signature.replace(parameters=pipe_parameters)  # type: ignore[attr-defined]

    # Customize pipe method
    pipe.__qualname__ = name + ".pipe"
    pipe.__module__ = module
    pipe.__doc__ = f'Pipable "{name}" stream operator.'
    if extra_doc:
        pipe.__doc__ += "\n\n " + extra_doc

    # Gather attributes
    attrs = {
        "__init__": init,
        "__module__": module,
        "__doc__": doc,
        "raw": staticmethod(raw),
        "original": staticmethod(original),
        "pipe": classmethod(pipe),  # type: ignore[arg-type]
    }

    # Create operator class
    return cast(
        "PipableOperatorType[X, P, T]",
        type(name, bases, attrs),
    )