|
17 | 17 |
|
18 | 18 |
|
19 | 19 | import asyncio
|
| 20 | +import threading |
| 21 | +import time |
20 | 22 | from asyncio import Event as AsyncEvent
|
21 | 23 | from threading import (
|
22 | 24 | Event,
|
|
26 | 28 |
|
27 | 29 | import pytest
|
28 | 30 |
|
| 31 | +from neo4j._async.io._pool import AcquireAuth as AsyncAcquireAuth |
29 | 32 | from neo4j._deadline import Deadline
|
| 33 | +from neo4j._sync.io._pool import AcquireAuth |
| 34 | +from neo4j.auth_management import ( |
| 35 | + AsyncAuthManagers, |
| 36 | + AuthManagers, |
| 37 | +) |
30 | 38 |
|
31 | 39 | from ...async_.io.test_direct import AsyncFakeBoltPool
|
| 40 | +from ...async_.test_auth_manager import ( |
| 41 | + static_auth_manager as static_async_auth_manager, |
| 42 | +) |
32 | 43 | from ...sync.io.test_direct import FakeBoltPool
|
| 44 | +from ...sync.test_auth_manager import static_auth_manager |
33 | 45 | from ._common import (
|
34 | 46 | AsyncMultiEvent,
|
35 | 47 | MultiEvent,
|
@@ -111,6 +123,49 @@ def acquire_release_conn(pool_, address_, acquired_counter_,
|
111 | 123 | # The pool size is still 5, but all are free
|
112 | 124 | self.assert_pool_size(address, 0, 5, pool)
|
113 | 125 |
|
| 126 | + def test_full_pool_re_auth(self, mocker): |
| 127 | + address = ("127.0.0.1", 7687) |
| 128 | + acquire_auth1 = AcquireAuth(auth=static_auth_manager( |
| 129 | + ("user1", "pass1")) |
| 130 | + ) |
| 131 | + auth2 = ("user2", "pass2") |
| 132 | + acquire_auth2 = AcquireAuth(auth=static_auth_manager(auth2)) |
| 133 | + acquire1_event = threading.Event() |
| 134 | + cx1 = None |
| 135 | + |
| 136 | + def acquire1(pool_): |
| 137 | + nonlocal cx1 |
| 138 | + cx = pool_._acquire(address, acquire_auth1, Deadline(0), None) |
| 139 | + acquire1_event.set() |
| 140 | + cx1 = cx |
| 141 | + while True: |
| 142 | + with pool_.cond: |
| 143 | + # _waiters is an internal attribute of threading.Condition |
| 144 | + # this might break in the future, but we couldn't come up |
| 145 | + # with a better way of waiting for the other thread to block. |
| 146 | + waiters = len(pool_.cond._waiters) |
| 147 | + if waiters: |
| 148 | + break |
| 149 | + time.sleep(0.001) |
| 150 | + cx.re_auth = mocker.Mock(spec=cx.re_auth) |
| 151 | + pool_.release(cx) |
| 152 | + |
| 153 | + def acquire2(pool_): |
| 154 | + acquire1_event.wait(timeout=10) |
| 155 | + cx = pool_._acquire(address, acquire_auth2, Deadline(10), None) |
| 156 | + assert cx is cx1 |
| 157 | + cx.re_auth.assert_called_once() |
| 158 | + assert auth2 in cx.re_auth.call_args.args |
| 159 | + pool_.release(cx) |
| 160 | + |
| 161 | + with FakeBoltPool((), max_connection_pool_size=1) as pool: |
| 162 | + t1 = threading.Thread(target=acquire1, args=(pool,), daemon=True) |
| 163 | + t2 = threading.Thread(target=acquire2, args=(pool,), daemon=True) |
| 164 | + t1.start() |
| 165 | + t2.start() |
| 166 | + t1.join() |
| 167 | + t2.join() |
| 168 | + |
114 | 169 | @pytest.mark.parametrize("pre_populated", (0, 3, 5))
|
115 | 170 | @pytest.mark.asyncio
|
116 | 171 | async def test_multi_coroutine(self, pre_populated):
|
@@ -172,3 +227,35 @@ async def waiter(pool_, acquired_counter_, release_event_):
|
172 | 227 | waiter(pool, acquired_counter, release_event),
|
173 | 228 | *coroutines
|
174 | 229 | )
|
| 230 | + |
| 231 | + @pytest.mark.asyncio |
| 232 | + async def test_full_pool_re_auth_async(self, mocker): |
| 233 | + address = ("127.0.0.1", 7687) |
| 234 | + acquire_auth1 = AsyncAcquireAuth(auth=static_async_auth_manager( |
| 235 | + ("user1", "pass1")) |
| 236 | + ) |
| 237 | + auth2 = ("user2", "pass2") |
| 238 | + acquire_auth2 = AsyncAcquireAuth(auth=static_async_auth_manager(auth2)) |
| 239 | + cx1 = None |
| 240 | + |
| 241 | + async def acquire1(pool_): |
| 242 | + nonlocal cx1 |
| 243 | + cx = await pool_._acquire(address, acquire_auth1, Deadline(0), None) |
| 244 | + cx1 = cx |
| 245 | + while len(pool_.cond._waiters) == 0: |
| 246 | + await asyncio.sleep(0) |
| 247 | + cx.re_auth = mocker.Mock(spec=cx.re_auth) |
| 248 | + await pool_.release(cx) |
| 249 | + |
| 250 | + async def acquire2(pool_): |
| 251 | + while cx1 is None: |
| 252 | + await asyncio.sleep(0) |
| 253 | + cx = await pool_._acquire(address, acquire_auth2, |
| 254 | + Deadline(float("inf")), None) |
| 255 | + assert cx is cx1 |
| 256 | + cx.re_auth.assert_called_once() |
| 257 | + assert auth2 in cx.re_auth.call_args.args |
| 258 | + await pool_.release(cx) |
| 259 | + |
| 260 | + async with AsyncFakeBoltPool((), max_connection_pool_size=1) as pool: |
| 261 | + await asyncio.gather(acquire1(pool), acquire2(pool)) |
0 commit comments