Skip to content

Commit cf77739

Browse files
miss-islingtonserhiy-storchakaGobot1234
authored
[3.11] gh-111085: Fix invalid state handling in TaskGroup and Timeout (GH-111111) (GH-111172)
asyncio.TaskGroup and asyncio.Timeout classes now raise proper RuntimeError if they are improperly used. * When they are used without entering the context manager. * When they are used after finishing. * When the context manager is entered more than once (simultaneously or sequentially). * If there is no current task when entering the context manager. They now remain in a consistent state after an exception is thrown, so subsequent operations can be performed correctly (if they are allowed). (cherry picked from commit 6c23635) Co-authored-by: Serhiy Storchaka <[email protected]> Co-authored-by: James Hilton-Balfe <[email protected]>
1 parent cf28c61 commit cf77739

File tree

6 files changed

+120
-9
lines changed

6 files changed

+120
-9
lines changed

Lib/asyncio/taskgroups.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,14 @@ def __repr__(self):
5454
async def __aenter__(self):
5555
if self._entered:
5656
raise RuntimeError(
57-
f"TaskGroup {self!r} has been already entered")
58-
self._entered = True
59-
57+
f"TaskGroup {self!r} has already been entered")
6058
if self._loop is None:
6159
self._loop = events.get_running_loop()
62-
6360
self._parent_task = tasks.current_task(self._loop)
6461
if self._parent_task is None:
6562
raise RuntimeError(
6663
f'TaskGroup {self!r} cannot determine the parent task')
64+
self._entered = True
6765

6866
return self
6967

Lib/asyncio/timeouts.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ def when(self) -> Optional[float]:
4949

5050
def reschedule(self, when: Optional[float]) -> None:
5151
"""Reschedule the timeout."""
52-
assert self._state is not _State.CREATED
5352
if self._state is not _State.ENTERED:
53+
if self._state is _State.CREATED:
54+
raise RuntimeError("Timeout has not been entered")
5455
raise RuntimeError(
5556
f"Cannot change state of {self._state.value} Timeout",
5657
)
@@ -82,11 +83,14 @@ def __repr__(self) -> str:
8283
return f"<Timeout [{self._state.value}]{info_str}>"
8384

8485
async def __aenter__(self) -> "Timeout":
86+
if self._state is not _State.CREATED:
87+
raise RuntimeError("Timeout has already been entered")
88+
task = tasks.current_task()
89+
if task is None:
90+
raise RuntimeError("Timeout should be used inside a task")
8591
self._state = _State.ENTERED
86-
self._task = tasks.current_task()
92+
self._task = task
8793
self._cancelling = self._task.cancelling()
88-
if self._task is None:
89-
raise RuntimeError("Timeout should be used inside a task")
9094
self.reschedule(self._when)
9195
return self
9296

Lib/test/test_asyncio/test_taskgroups.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from asyncio import taskgroups
99
import unittest
1010

11+
from test.test_asyncio.utils import await_without_task
12+
1113

1214
# To prevent a warning "test altered the execution environment"
1315
def tearDownModule():
@@ -779,6 +781,49 @@ async def main():
779781

780782
await asyncio.create_task(main())
781783

784+
async def test_taskgroup_already_entered(self):
785+
tg = taskgroups.TaskGroup()
786+
async with tg:
787+
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
788+
async with tg:
789+
pass
790+
791+
async def test_taskgroup_double_enter(self):
792+
tg = taskgroups.TaskGroup()
793+
async with tg:
794+
pass
795+
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
796+
async with tg:
797+
pass
798+
799+
async def test_taskgroup_finished(self):
800+
tg = taskgroups.TaskGroup()
801+
async with tg:
802+
pass
803+
coro = asyncio.sleep(0)
804+
with self.assertRaisesRegex(RuntimeError, "is finished"):
805+
tg.create_task(coro)
806+
# We still have to await coro to avoid a warning
807+
await coro
808+
809+
async def test_taskgroup_not_entered(self):
810+
tg = taskgroups.TaskGroup()
811+
coro = asyncio.sleep(0)
812+
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
813+
tg.create_task(coro)
814+
# We still have to await coro to avoid a warning
815+
await coro
816+
817+
async def test_taskgroup_without_parent_task(self):
818+
tg = taskgroups.TaskGroup()
819+
with self.assertRaisesRegex(RuntimeError, "parent task"):
820+
await await_without_task(tg.__aenter__())
821+
coro = asyncio.sleep(0)
822+
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
823+
tg.create_task(coro)
824+
# We still have to await coro to avoid a warning
825+
await coro
826+
782827

783828
if __name__ == "__main__":
784829
unittest.main()

Lib/test/test_asyncio/test_timeouts.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
import asyncio
77
from asyncio import tasks
88

9+
from test.test_asyncio.utils import await_without_task
10+
911

1012
def tearDownModule():
1113
asyncio.set_event_loop_policy(None)
1214

13-
1415
class TimeoutTests(unittest.IsolatedAsyncioTestCase):
1516

1617
async def test_timeout_basic(self):
@@ -258,6 +259,51 @@ async def test_timeout_exception_cause (self):
258259
cause = exc.exception.__cause__
259260
assert isinstance(cause, asyncio.CancelledError)
260261

262+
async def test_timeout_already_entered(self):
263+
async with asyncio.timeout(0.01) as cm:
264+
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
265+
async with cm:
266+
pass
267+
268+
async def test_timeout_double_enter(self):
269+
async with asyncio.timeout(0.01) as cm:
270+
pass
271+
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
272+
async with cm:
273+
pass
274+
275+
async def test_timeout_finished(self):
276+
async with asyncio.timeout(0.01) as cm:
277+
pass
278+
with self.assertRaisesRegex(RuntimeError, "finished"):
279+
cm.reschedule(0.02)
280+
281+
async def test_timeout_expired(self):
282+
with self.assertRaises(TimeoutError):
283+
async with asyncio.timeout(0.01) as cm:
284+
await asyncio.sleep(1)
285+
with self.assertRaisesRegex(RuntimeError, "expired"):
286+
cm.reschedule(0.02)
287+
288+
async def test_timeout_expiring(self):
289+
async with asyncio.timeout(0.01) as cm:
290+
with self.assertRaises(asyncio.CancelledError):
291+
await asyncio.sleep(1)
292+
with self.assertRaisesRegex(RuntimeError, "expiring"):
293+
cm.reschedule(0.02)
294+
295+
async def test_timeout_not_entered(self):
296+
cm = asyncio.timeout(0.01)
297+
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
298+
cm.reschedule(0.02)
299+
300+
async def test_timeout_without_task(self):
301+
cm = asyncio.timeout(0.01)
302+
with self.assertRaisesRegex(RuntimeError, "task"):
303+
await await_without_task(cm.__aenter__())
304+
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
305+
cm.reschedule(0.02)
306+
261307

262308
if __name__ == '__main__':
263309
unittest.main()

Lib/test/test_asyncio/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,3 +612,18 @@ def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
612612
sock.family = family
613613
sock.gettimeout.return_value = 0.0
614614
return sock
615+
616+
617+
async def await_without_task(coro):
618+
exc = None
619+
def func():
620+
try:
621+
for _ in coro.__await__():
622+
pass
623+
except BaseException as err:
624+
nonlocal exc
625+
exc = err
626+
asyncio.get_running_loop().call_soon(func)
627+
await asyncio.sleep(0)
628+
if exc is not None:
629+
raise exc
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Fix invalid state handling in :class:`asyncio.TaskGroup` and
2+
:class:`asyncio.Timeout`. They now raise proper RuntimeError if they are
3+
improperly used and are left in consistent state after this.

0 commit comments

Comments
 (0)