Skip to content

Commit 496480e

Browse files
authored
Update AyncCondition to be based on Python 3.11 (#879)
Most notably, this changes how the async driver behaves when created while no event loop is running. Previously, this would lead to hard to predict errors only occurring once the connection pool ran full. Now, the async synchronization primitives will defer binding to an event loop until they're used in an async context for the first time. This simply allows users to safely create a driver without a running event loop (in a sync context) and later use it in an async context. Important note: this will likely only work when the user is on Python 3.10+ because the driver also relies on synchronization primitives that come with `asyncio`. So their behavior depends on the used Python version. + Add test for creating async driver in sync environment
1 parent 9cfa4c8 commit 496480e

File tree

2 files changed

+129
-14
lines changed

2 files changed

+129
-14
lines changed

src/neo4j/_async_compat/concurrency.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -305,26 +305,17 @@ class AsyncCondition:
305305
A new Lock object is created and used as the underlying lock.
306306
"""
307307

308-
# copied and modified from Python 3.7's asyncio.locks module
309-
# to add support for `.wait(timeout)`
308+
# copied and modified from Python 3.11's asyncio package
309+
# to add support for `.wait(timeout)` and cooperative locks
310310

311311
# Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
312312
# 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022
313313
# Python Software Foundation;
314314
# All Rights Reserved
315315

316-
def __init__(self, lock=None, *, loop=None):
317-
if loop is not None:
318-
self._loop = loop
319-
else:
320-
self._loop = asyncio.get_event_loop()
321-
316+
def __init__(self, lock=None):
322317
if lock is None:
323-
lock = asyncio.Lock(loop=self._loop)
324-
elif (hasattr(lock, "_loop")
325-
and lock._loop is not None
326-
and lock._loop is not self._loop):
327-
raise ValueError("loop argument must agree with lock")
318+
lock = AsyncLock()
328319

329320
self._lock = lock
330321
# Export the lock's locked(), acquire() and release() methods.
@@ -334,6 +325,23 @@ def __init__(self, lock=None, *, loop=None):
334325

335326
self._waiters = collections.deque()
336327

328+
_loop = None
329+
_loop_lock = threading.Lock()
330+
331+
def _get_loop(self):
332+
try:
333+
loop = asyncio.get_running_loop()
334+
except RuntimeError:
335+
loop = None
336+
337+
if self._loop is None:
338+
with self._loop_lock:
339+
if self._loop is None:
340+
self._loop = loop
341+
if loop is not self._loop:
342+
raise RuntimeError(f'{self!r} is bound to a different event loop')
343+
return loop
344+
337345
async def __aenter__(self):
338346
if isinstance(self._lock, (AsyncCooperativeLock,
339347
AsyncCooperativeRLock)):
@@ -374,7 +382,7 @@ async def _wait(self, timeout=None, me=None):
374382
else:
375383
self._lock.release()
376384
try:
377-
fut = self._loop.create_future()
385+
fut = self._get_loop().create_future()
378386
self._waiters.append(fut)
379387
try:
380388
await wait_for(fut, timeout)
@@ -409,6 +417,19 @@ async def wait(self, timeout=None):
409417
me = asyncio.current_task()
410418
return await self._wait(timeout=timeout, me=me)
411419

420+
async def wait_for(self, predicate):
421+
"""Wait until a predicate becomes true.
422+
423+
The predicate should be a callable which result will be
424+
interpreted as a boolean value. The final predicate value is
425+
the return value.
426+
"""
427+
result = predicate()
428+
while not result:
429+
await self.wait()
430+
result = predicate()
431+
return result
432+
412433
def notify(self, n=1):
413434
"""By default, wake up one coroutine waiting on this condition, if any.
414435
If the calling coroutine has not acquired the lock when this method
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# This file is part of Neo4j.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# https://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
19+
import asyncio
20+
import sys
21+
22+
import pytest
23+
24+
import neo4j
25+
26+
from ... import env
27+
28+
29+
# TODO: Python 3.9: when support gets dropped, remove this mark
30+
@pytest.mark.xfail(
31+
# direct driver is not making use of `asyncio.Lock`.
32+
sys.version_info < (3, 10) and env.NEO4J_SCHEME == "neo4j",
33+
reason="asyncio's synchronization primitives can create a new event loop "
34+
"if instantiated while there is no running event loop. This "
35+
"changed with Python 3.10.",
36+
raises=RuntimeError,
37+
strict=True,
38+
)
39+
def test_can_create_async_driver_outside_of_loop(uri, auth):
40+
pool_size = 2
41+
# used to make sure the pool was full at least at some point
42+
counter = 0
43+
was_full = False
44+
45+
async def return_1(tx):
46+
nonlocal counter, was_full
47+
res = await tx.run("RETURN 1")
48+
49+
counter += 1
50+
while not was_full and counter < pool_size:
51+
await asyncio.sleep(0.001)
52+
if not was_full:
53+
# a little extra time to make sure a connection too many was
54+
# tried to be acquired from the pool
55+
was_full = True
56+
await asyncio.sleep(0.5)
57+
58+
await res.consume()
59+
counter -= 1
60+
61+
async def run(driver: neo4j.AsyncDriver):
62+
async with driver:
63+
sessions = []
64+
try:
65+
for i in range(pool_size * 4):
66+
sessions.append(driver.session())
67+
work_loads = (session.execute_read(return_1)
68+
for session in sessions)
69+
await asyncio.gather(*work_loads)
70+
finally:
71+
cancelled = None
72+
for session in sessions:
73+
if not cancelled:
74+
try:
75+
await session.close()
76+
except asyncio.CancelledError as e:
77+
session.cancel()
78+
cancelled = e
79+
else:
80+
session.cancel()
81+
await driver.close()
82+
if cancelled:
83+
raise cancelled
84+
85+
86+
driver = neo4j.AsyncGraphDatabase.driver(
87+
uri, auth=auth, max_connection_pool_size=pool_size
88+
)
89+
coro = run(driver)
90+
loop = asyncio.new_event_loop()
91+
try:
92+
loop.run_until_complete(coro)
93+
finally:
94+
loop.close()

0 commit comments

Comments
 (0)