1
- from inspect import isawaitable
1
+ from asyncio import Event , ensure_future , wait
2
+ from concurrent .futures import FIRST_COMPLETED
3
+ from inspect import isasyncgen , isawaitable
2
4
from typing import AsyncIterable , Callable
3
5
4
6
__all__ = ['MapAsyncIterator' ]
@@ -19,35 +21,62 @@ def __init__(self, iterable: AsyncIterable, callback: Callable,
19
21
self .iterator = iterable .__aiter__ ()
20
22
self .callback = callback
21
23
self .reject_callback = reject_callback
22
- self .stop = False
24
+ self ._close_event = Event ()
25
+
26
+ @property
27
+ def closed (self ) -> bool :
28
+ return self ._close_event .is_set ()
29
+
30
+ @closed .setter
31
+ def closed (self , value : bool ) -> None :
32
+ if value :
33
+ self ._close_event .set ()
34
+ else :
35
+ self ._close_event .clear ()
23
36
24
37
def __aiter__ (self ):
25
38
return self
26
39
27
40
async def __anext__ (self ):
28
- if self .stop :
41
+ if self .closed :
42
+ if not isasyncgen (self .iterator ):
43
+ raise StopAsyncIteration
44
+ result = await self .iterator .__anext__ ()
45
+ return self .callback (result )
46
+
47
+ _close = ensure_future (self ._close_event .wait ())
48
+ _next = ensure_future (self .iterator .__anext__ ())
49
+ done , pending = await wait (
50
+ [_close , _next ],
51
+ return_when = FIRST_COMPLETED ,
52
+ )
53
+
54
+ for task in pending :
55
+ task .cancel ()
56
+
57
+ if _close .done ():
29
58
raise StopAsyncIteration
30
- try :
31
- value = await self . iterator . __anext__ ()
32
- except Exception as error :
33
- if not self . reject_callback or isinstance ( error , (
34
- StopAsyncIteration , GeneratorExit )):
35
- raise
36
- result = self . reject_callback ( error )
37
- else :
38
- result = self . callback ( value )
39
- if isawaitable ( result ):
40
- result = await result
41
- return result
59
+
60
+ if _next . done ():
61
+ error = _next . exception ()
62
+ if error :
63
+ if not self . reject_callback or isinstance ( error , (
64
+ StopAsyncIteration , GeneratorExit )):
65
+ raise error
66
+ result = self . reject_callback ( error )
67
+ else :
68
+ result = self . callback ( _next . result ())
69
+
70
+ return ( await result ) if isawaitable ( result ) else result
42
71
43
72
async def athrow (self , type_ , value = None , traceback = None ):
44
- if self .stop :
73
+ if self .closed :
45
74
return
46
75
athrow = getattr (self .iterator , 'athrow' , None )
47
76
if athrow :
48
77
await athrow (type_ , value , traceback )
49
78
else :
50
- self .stop = True
79
+ self .closed = True
51
80
if value is None :
52
81
if traceback is None :
53
82
raise type_
@@ -57,13 +86,12 @@ async def athrow(self, type_, value=None, traceback=None):
57
86
raise value
58
87
59
88
async def aclose (self ):
60
- if self .stop :
89
+ if self .closed :
61
90
return
62
91
aclose = getattr (self .iterator , 'aclose' , None )
63
92
if aclose :
64
93
try :
65
94
await aclose ()
66
95
except RuntimeError :
67
96
pass
68
- else :
69
- self .stop = True
97
+ self .closed = True
0 commit comments