Skip to content

Commit 49e7124

Browse files
dfeeCito
authored andcommitted
Fixed the subscription race condition (#5)
1 parent 827511a commit 49e7124

File tree

3 files changed

+74
-21
lines changed

3 files changed

+74
-21
lines changed
Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from inspect import isawaitable
1+
from asyncio import Event, ensure_future, wait
2+
from concurrent.futures import FIRST_COMPLETED
3+
from inspect import isasyncgen, isawaitable
24
from typing import AsyncIterable, Callable
35

46
__all__ = ['MapAsyncIterator']
@@ -19,35 +21,62 @@ def __init__(self, iterable: AsyncIterable, callback: Callable,
1921
self.iterator = iterable.__aiter__()
2022
self.callback = callback
2123
self.reject_callback = reject_callback
22-
self.stop = False
24+
self._close_event = Event()
25+
26+
@property
27+
def closed(self) -> bool:
28+
return self._close_event.is_set()
29+
30+
@closed.setter
31+
def closed(self, value: bool) -> None:
32+
if value:
33+
self._close_event.set()
34+
else:
35+
self._close_event.clear()
2336

2437
def __aiter__(self):
2538
return self
2639

2740
async def __anext__(self):
28-
if self.stop:
41+
if self.closed:
42+
if not isasyncgen(self.iterator):
43+
raise StopAsyncIteration
44+
result = await self.iterator.__anext__()
45+
return self.callback(result)
46+
47+
_close = ensure_future(self._close_event.wait())
48+
_next = ensure_future(self.iterator.__anext__())
49+
done, pending = await wait(
50+
[_close, _next],
51+
return_when=FIRST_COMPLETED,
52+
)
53+
54+
for task in pending:
55+
task.cancel()
56+
57+
if _close.done():
2958
raise StopAsyncIteration
30-
try:
31-
value = await self.iterator.__anext__()
32-
except Exception as error:
33-
if not self.reject_callback or isinstance(error, (
34-
StopAsyncIteration, GeneratorExit)):
35-
raise
36-
result = self.reject_callback(error)
37-
else:
38-
result = self.callback(value)
39-
if isawaitable(result):
40-
result = await result
41-
return result
59+
60+
if _next.done():
61+
error = _next.exception()
62+
if error:
63+
if not self.reject_callback or isinstance(error, (
64+
StopAsyncIteration, GeneratorExit)):
65+
raise error
66+
result = self.reject_callback(error)
67+
else:
68+
result = self.callback(_next.result())
69+
70+
return (await result) if isawaitable(result) else result
4271

4372
async def athrow(self, type_, value=None, traceback=None):
44-
if self.stop:
73+
if self.closed:
4574
return
4675
athrow = getattr(self.iterator, 'athrow', None)
4776
if athrow:
4877
await athrow(type_, value, traceback)
4978
else:
50-
self.stop = True
79+
self.closed = True
5180
if value is None:
5281
if traceback is None:
5382
raise type_
@@ -57,13 +86,12 @@ async def athrow(self, type_, value=None, traceback=None):
5786
raise value
5887

5988
async def aclose(self):
60-
if self.stop:
89+
if self.closed:
6190
return
6291
aclose = getattr(self.iterator, 'aclose', None)
6392
if aclose:
6493
try:
6594
await aclose()
6695
except RuntimeError:
6796
pass
68-
else:
69-
self.stop = True
97+
self.closed = True

tests/subscription/test_map_async_iterator.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from asyncio import Event, ensure_future, sleep
12
import sys
23

34
from pytest import mark, raises
@@ -247,3 +248,27 @@ async def __anext__(self):
247248
# Throw error
248249
with raises(ValueError):
249250
await doubles.athrow(ValueError, None, tb)
251+
252+
@mark.asyncio
253+
async def stops_async_iteration_on_close():
254+
async def source():
255+
yield 1
256+
await Event().wait() # Block forever
257+
yield 2
258+
yield 3
259+
260+
doubles = MapAsyncIterator(source(), lambda x: x * 2)
261+
262+
result = await anext(doubles)
263+
assert result == 2
264+
265+
# Block at event.wait()
266+
fut = ensure_future(anext(doubles))
267+
await sleep(.01)
268+
assert not fut.done()
269+
270+
# Trigger cancellation and watch StopAsyncIteration propogate
271+
await doubles.aclose()
272+
await sleep(.01)
273+
assert fut.done()
274+
assert isinstance(fut.exception(), StopAsyncIteration)

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,4 @@ deps =
2727
pytest-describe
2828
commands =
2929
python -m pip install -U pip
30-
pytest
30+
pytest {posargs}

0 commit comments

Comments
 (0)