Skip to content

Commit bf594bd

Browse files
committed
handle asyncio.wait_for(coro, None) and timeouts.timeout(0)
1 parent 54f4895 commit bf594bd

File tree

3 files changed

+30
-7
lines changed

3 files changed

+30
-7
lines changed

Lib/asyncio/tasks.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -446,10 +446,12 @@ async def wait_for(fut, timeout):
446446
447447
This function is a coroutine.
448448
"""
449+
if timeout is None:
450+
return await fut
449451

450452
async def inner():
451453
async with timeouts.timeout(timeout):
452-
return await fut
454+
return await ensure_future(fut)
453455

454456
return await create_task(inner())
455457

Lib/asyncio/timeouts.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,13 @@ def reschedule(self, when: Optional[float]) -> None:
5252
self._timeout_handler = None
5353
else:
5454
loop = events.get_running_loop()
55-
self._timeout_handler = loop.call_at(
56-
when,
57-
self._on_timeout,
58-
)
55+
if when <= 0:
56+
self._timeout_handler = loop.call_soon(self._on_timeout)
57+
else:
58+
self._timeout_handler = loop.call_at(
59+
when,
60+
self._on_timeout,
61+
)
5962

6063
def expired(self) -> bool:
6164
"""Is timeout expired during execution?"""
@@ -126,7 +129,13 @@ def timeout(delay: Optional[float]) -> Timeout:
126129
into TimeoutError.
127130
"""
128131
loop = events.get_running_loop()
129-
return Timeout(loop.time() + delay if delay is not None else None)
132+
if delay is None:
133+
return Timeout(None)
134+
135+
if delay <= 0:
136+
return Timeout(0)
137+
138+
return Timeout(loop.time() + delay)
130139

131140

132141
def timeout_at(when: Optional[float]) -> Timeout:

Lib/test/test_asyncio/test_timeouts.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,19 @@ async def test_timeout_zero(self):
103103
self.assertTrue(cm.expired())
104104
# 2 sec for slow CI boxes
105105
self.assertLess(t1-t0, 2)
106-
self.assertTrue(t0 <= cm.when() <= t1)
106+
self.assertEqual(cm.when(), 0)
107+
108+
async def test_timeout_zero_zero(self):
109+
loop = asyncio.get_running_loop()
110+
t0 = loop.time()
111+
with self.assertRaises(TimeoutError):
112+
async with asyncio.timeout(0) as cm:
113+
await asyncio.sleep(0)
114+
t1 = loop.time()
115+
self.assertTrue(cm.expired())
116+
# 2 sec for slow CI boxes
117+
self.assertLess(t1-t0, 2)
118+
self.assertEqual(cm.when(), 0)
107119

108120
async def test_foreign_exception_passed(self):
109121
with self.assertRaises(KeyError):

0 commit comments

Comments
 (0)