Skip to content

Commit a09e25f

Browse files
authored
Tests to make sure connections are re-auth'ed on full pool (#925)
Signed-off-by: Florent Biville <[email protected]>
1 parent d4d1c84 commit a09e25f

File tree

3 files changed

+93
-0
lines changed

3 files changed

+93
-0
lines changed

tests/unit/async_/io/test_direct.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ def defunct(self):
8383
def timedout(self):
8484
return False
8585

86+
def assert_re_auth_support(self):
87+
pass
88+
8689

8790
class AsyncFakeBoltPool(AsyncIOPool):
8891
is_direct_pool = False

tests/unit/mixed/io/test_direct.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818

1919
import asyncio
20+
import threading
21+
import time
2022
from asyncio import Event as AsyncEvent
2123
from threading import (
2224
Event,
@@ -26,10 +28,20 @@
2628

2729
import pytest
2830

31+
from neo4j._async.io._pool import AcquireAuth as AsyncAcquireAuth
2932
from neo4j._deadline import Deadline
33+
from neo4j._sync.io._pool import AcquireAuth
34+
from neo4j.auth_management import (
35+
AsyncAuthManagers,
36+
AuthManagers,
37+
)
3038

3139
from ...async_.io.test_direct import AsyncFakeBoltPool
40+
from ...async_.test_auth_manager import (
41+
static_auth_manager as static_async_auth_manager,
42+
)
3243
from ...sync.io.test_direct import FakeBoltPool
44+
from ...sync.test_auth_manager import static_auth_manager
3345
from ._common import (
3446
AsyncMultiEvent,
3547
MultiEvent,
@@ -111,6 +123,49 @@ def acquire_release_conn(pool_, address_, acquired_counter_,
111123
# The pool size is still 5, but all are free
112124
self.assert_pool_size(address, 0, 5, pool)
113125

126+
def test_full_pool_re_auth(self, mocker):
127+
address = ("127.0.0.1", 7687)
128+
acquire_auth1 = AcquireAuth(auth=static_auth_manager(
129+
("user1", "pass1"))
130+
)
131+
auth2 = ("user2", "pass2")
132+
acquire_auth2 = AcquireAuth(auth=static_auth_manager(auth2))
133+
acquire1_event = threading.Event()
134+
cx1 = None
135+
136+
def acquire1(pool_):
137+
nonlocal cx1
138+
cx = pool_._acquire(address, acquire_auth1, Deadline(0), None)
139+
acquire1_event.set()
140+
cx1 = cx
141+
while True:
142+
with pool_.cond:
143+
# _waiters is an internal attribute of threading.Condition
144+
# this might break in the future, but we couldn't come up
145+
# with a better way of waiting for the other thread to block.
146+
waiters = len(pool_.cond._waiters)
147+
if waiters:
148+
break
149+
time.sleep(0.001)
150+
cx.re_auth = mocker.Mock(spec=cx.re_auth)
151+
pool_.release(cx)
152+
153+
def acquire2(pool_):
154+
acquire1_event.wait(timeout=10)
155+
cx = pool_._acquire(address, acquire_auth2, Deadline(10), None)
156+
assert cx is cx1
157+
cx.re_auth.assert_called_once()
158+
assert auth2 in cx.re_auth.call_args.args
159+
pool_.release(cx)
160+
161+
with FakeBoltPool((), max_connection_pool_size=1) as pool:
162+
t1 = threading.Thread(target=acquire1, args=(pool,), daemon=True)
163+
t2 = threading.Thread(target=acquire2, args=(pool,), daemon=True)
164+
t1.start()
165+
t2.start()
166+
t1.join()
167+
t2.join()
168+
114169
@pytest.mark.parametrize("pre_populated", (0, 3, 5))
115170
@pytest.mark.asyncio
116171
async def test_multi_coroutine(self, pre_populated):
@@ -172,3 +227,35 @@ async def waiter(pool_, acquired_counter_, release_event_):
172227
waiter(pool, acquired_counter, release_event),
173228
*coroutines
174229
)
230+
231+
@pytest.mark.asyncio
232+
async def test_full_pool_re_auth_async(self, mocker):
233+
address = ("127.0.0.1", 7687)
234+
acquire_auth1 = AsyncAcquireAuth(auth=static_async_auth_manager(
235+
("user1", "pass1"))
236+
)
237+
auth2 = ("user2", "pass2")
238+
acquire_auth2 = AsyncAcquireAuth(auth=static_async_auth_manager(auth2))
239+
cx1 = None
240+
241+
async def acquire1(pool_):
242+
nonlocal cx1
243+
cx = await pool_._acquire(address, acquire_auth1, Deadline(0), None)
244+
cx1 = cx
245+
while len(pool_.cond._waiters) == 0:
246+
await asyncio.sleep(0)
247+
cx.re_auth = mocker.Mock(spec=cx.re_auth)
248+
await pool_.release(cx)
249+
250+
async def acquire2(pool_):
251+
while cx1 is None:
252+
await asyncio.sleep(0)
253+
cx = await pool_._acquire(address, acquire_auth2,
254+
Deadline(float("inf")), None)
255+
assert cx is cx1
256+
cx.re_auth.assert_called_once()
257+
assert auth2 in cx.re_auth.call_args.args
258+
await pool_.release(cx)
259+
260+
async with AsyncFakeBoltPool((), max_connection_pool_size=1) as pool:
261+
await asyncio.gather(acquire1(pool), acquire2(pool))

tests/unit/sync/io/test_direct.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ def defunct(self):
8383
def timedout(self):
8484
return False
8585

86+
def assert_re_auth_support(self):
87+
pass
88+
8689

8790
class FakeBoltPool(IOPool):
8891
is_direct_pool = False

0 commit comments

Comments
 (0)