diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 9e0dba8e..ace6b337 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -100,7 +100,7 @@ def __init__(self, protocol, transport, loop, # # Used for `con.fetchval()`, `con.fetch()`, `con.fetchrow()`, # `con.execute()`, and `con.executemany()`. - self._stmt_exclusive_section = _Atomic() + self._stmt_exclusive_section = asyncio.Lock(loop=loop) if loop.get_debug(): self._source_traceback = _extract_stack() @@ -1409,7 +1409,7 @@ async def reload_schema_state(self): self._drop_global_statement_cache() async def _execute(self, query, args, limit, timeout, return_status=False): - with self._stmt_exclusive_section: + async with self._stmt_exclusive_section: result, _ = await self.__execute( query, args, limit, timeout, return_status=return_status) return result @@ -1425,7 +1425,7 @@ async def _executemany(self, query, args, timeout): executor = lambda stmt, timeout: self._protocol.bind_execute_many( stmt, args, '', timeout) timeout = self._protocol._get_timeout(timeout) - with self._stmt_exclusive_section: + async with self._stmt_exclusive_section: result, _ = await self._do_execute(query, executor, timeout) return result @@ -1783,22 +1783,6 @@ def _maybe_cleanup(self): self._on_remove(old_entry._statement) -class _Atomic: - __slots__ = ('_acquired',) - - def __init__(self): - self._acquired = 0 - - def __enter__(self): - if self._acquired: - raise exceptions.InterfaceError( - 'cannot perform operation: another operation is in progress') - self._acquired = 1 - - def __exit__(self, t, e, tb): - self._acquired = 0 - - class _ConnectionProxy: # Base class to enable `isinstance(Connection)` check. __slots__ = () diff --git a/tests/test_prepare.py b/tests/test_prepare.py index 8fc06e3e..b9b2ea6d 100644 --- a/tests/test_prepare.py +++ b/tests/test_prepare.py @@ -383,15 +383,22 @@ async def test_prepare_20_concurrent_calls(self): meth = getattr(self.con, methname) + footprints = [] + + async def _trace(awaitable, footprint_before, footprint_after): + footprints.append(footprint_before) + result = await awaitable + footprints.append(footprint_after) + return result + vf = self.loop.create_task( - meth('SELECT ROW(pg_sleep(0.1), 1)')) + _trace(meth('SELECT ROW(pg_sleep(0.1), 1)'), '11', '12')) await asyncio.sleep(0.01, loop=self.loop) - with self.assertRaisesRegex(asyncpg.InterfaceError, - 'another operation'): - await meth('SELECT 2') + await _trace(meth('SELECT 2'), '21', '22') + self.assertEqual(footprints, ['11', '21', '12', '22']) self.assertEqual(await vf, val) async def test_prepare_21_errors(self):