Skip to content
Merged
64 changes: 41 additions & 23 deletions Lib/asyncio/locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,9 +346,8 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
def __init__(self, value=1):
if value < 0:
raise ValueError("Semaphore initial value must be >= 0")
self._waiters = None
self._value = value
self._waiters = collections.deque()
self._wakeup_scheduled = False

def __repr__(self):
res = super().__repr__()
Expand All @@ -357,40 +356,44 @@ def __repr__(self):
extra = f'{extra}, waiters:{len(self._waiters)}'
return f'<{res[1:-1]} [{extra}]>'

def _wake_up_next(self):
while self._waiters:
waiter = self._waiters.popleft()
if not waiter.done():
waiter.set_result(None)
self._wakeup_scheduled = True
return

def locked(self):
"""Returns True if semaphore can not be acquired immediately."""
return self._value == 0
return self._value <= 0

async def acquire(self):
"""Acquire a semaphore.

If the internal counter is larger than zero on entry,
decrement it by one and return True immediately. If it is
zero on entry, block, waiting until some other coroutine has
called release() to make it larger than 0, and then return
True.
"""
# _wakeup_scheduled is set if *another* task is scheduled to wakeup
# but its acquire() is not resumed yet
while self._wakeup_scheduled or self._value <= 0:
fut = self._get_loop().create_future()
self._waiters.append(fut)
if (not self.locked() and (self._waiters is None or
all(w.cancelled() for w in self._waiters))):
self._value -= 1
return True

if self._waiters is None:
self._waiters = collections.deque()
fut = self._get_loop().create_future()
self._waiters.append(fut)

# Finally block should be called before the CancelledError
# handling as we don't want CancelledError to call
# _wake_up_first() and attempt to wake up itself.
try:
try:
await fut
# reset _wakeup_scheduled *after* waiting for a future
self._wakeup_scheduled = False
except exceptions.CancelledError:
self._wake_up_next()
raise
finally:
self._waiters.remove(fut)
except exceptions.CancelledError:
if not self.locked():
self._wake_up_first()
raise

self._value -= 1
if not self.locked():
self._wake_up_first()
return True

def release(self):
Expand All @@ -399,7 +402,22 @@ def release(self):
become larger than zero again, wake up that coroutine.
"""
self._value += 1
self._wake_up_next()
self._wake_up_first()

def _wake_up_first(self):
"""Wake up the first waiter if it isn't done."""
if not self._waiters:
return
try:
fut = next(iter(self._waiters))
except StopIteration:
return

# .done() necessarily means that a waiter will wake up later on and
# either take the lock, or, if it was cancelled and lock wasn't
# taken already, will hit this again and wake up a new waiter.
if not fut.done():
fut.set_result(True)


class BoundedSemaphore(Semaphore):
Expand Down
107 changes: 102 additions & 5 deletions Lib/test/test_asyncio/test_locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re

import asyncio
import collections

STR_RGX_REPR = (
r'^<(?P<class>.*?) object at (?P<address>.*?)'
Expand Down Expand Up @@ -774,6 +775,9 @@ async def test_repr(self):
self.assertTrue('waiters' not in repr(sem))
self.assertTrue(RGX_REPR.match(repr(sem)))

if sem._waiters is None:
sem._waiters = collections.deque()

sem._waiters.append(mock.Mock())
self.assertTrue('waiters:1' in repr(sem))
self.assertTrue(RGX_REPR.match(repr(sem)))
Expand Down Expand Up @@ -830,7 +834,7 @@ async def c4(result):
t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result))

await asyncio.sleep(0)
await asyncio.sleep(0.01)
self.assertEqual([1], result)
self.assertTrue(sem.locked())
self.assertEqual(2, len(sem._waiters))
Expand All @@ -842,7 +846,7 @@ async def c4(result):
sem.release()
self.assertEqual(2, sem._value)

await asyncio.sleep(0)
await asyncio.sleep(0.01)
self.assertEqual(0, sem._value)
self.assertEqual(3, len(result))
self.assertTrue(sem.locked())
Expand Down Expand Up @@ -878,13 +882,13 @@ async def test_acquire_cancel_before_awoken(self):
t3 = asyncio.create_task(sem.acquire())
t4 = asyncio.create_task(sem.acquire())

await asyncio.sleep(0)
await asyncio.sleep(0.01)

t1.cancel()
t2.cancel()
sem.release()

await asyncio.sleep(0)
await asyncio.sleep(0.01)
num_done = sum(t.done() for t in [t3, t4])
self.assertEqual(num_done, 1)
self.assertTrue(t3.done())
Expand All @@ -903,10 +907,32 @@ async def test_acquire_hang(self):

t1.cancel()
sem.release()
await asyncio.sleep(0)
await asyncio.sleep(0.01)
self.assertTrue(sem.locked())
self.assertTrue(t2.done())

async def test_acquire_no_hang(self):

sem = asyncio.Semaphore(1)

async def c1():
async with sem:
await asyncio.sleep(0)
t2.cancel()

async def c2():
async with sem:
await asyncio.sleep(0)

t1 = asyncio.create_task(c1())
t2 = asyncio.create_task(c2())

result = await asyncio.gather(t1, t2, return_exceptions=True)
self.assertTrue(result[0] is None)
self.assertTrue(isinstance(result[1], asyncio.CancelledError))

await asyncio.wait_for(sem.acquire(), timeout=0.01)

def test_release_not_acquired(self):
sem = asyncio.BoundedSemaphore()

Expand Down Expand Up @@ -945,6 +971,77 @@ async def coro(tag):
result
)

async def test_acquire_fifo_order_2(self):
sem = asyncio.Semaphore(1)
result = []

async def c1(result):
await sem.acquire()
result.append(1)
return True

async def c2(result):
await sem.acquire()
result.append(2)
sem.release()
await sem.acquire()
result.append(4)
return True

async def c3(result):
await sem.acquire()
result.append(3)
return True

t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result))

await asyncio.sleep(0)

sem.release()
sem.release()

tasks = [t1, t2, t3]
await asyncio.gather(*tasks)
self.assertEqual([1, 2, 3, 4], result)

async def test_acquire_fifo_order_3(self):
sem = asyncio.Semaphore(0)
result = []

async def c1(result):
await sem.acquire()
result.append(1)
return True

async def c2(result):
await sem.acquire()
result.append(2)
return True

async def c3(result):
await sem.acquire()
result.append(3)
return True

t1 = asyncio.create_task(c1(result))
t2 = asyncio.create_task(c2(result))
t3 = asyncio.create_task(c3(result))

await asyncio.sleep(0)

t1.cancel()

await asyncio.sleep(0.01)

sem.release()
sem.release()

tasks = [t1, t2, t3]
await asyncio.gather(*tasks, return_exceptions=True)
self.assertEqual([2, 3], result)


class BarrierTests(unittest.IsolatedAsyncioTestCase):

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix broken :class:`asyncio.Semaphore` when acquire is cancelled.