21
21
# make sure TAuth is resolved in the docs, else they're pretty useless
22
22
23
23
24
- import time
25
24
import typing as t
25
+ import warnings
26
26
from logging import getLogger
27
27
28
28
from .._async_compat .concurrency import AsyncLock
31
31
expiring_auth_has_expired ,
32
32
ExpiringAuth ,
33
33
)
34
- from .._meta import preview
34
+ from .._meta import (
35
+ preview ,
36
+ PreviewWarning ,
37
+ )
35
38
36
39
# work around for https://github.com/sphinx-doc/sphinx/pull/10880
37
40
# make sure TAuth is resolved in the docs, else they're pretty useless
38
41
# if t.TYPE_CHECKING:
39
42
from ..api import _TAuth
43
+ from ..exceptions import Neo4jError
40
44
41
45
42
46
log = getLogger ("neo4j" )
@@ -51,21 +55,25 @@ def __init__(self, auth: _TAuth) -> None:
51
55
async def get_auth (self ) -> _TAuth :
52
56
return self ._auth
53
57
54
- async def on_auth_expired (self , auth : _TAuth ) -> None :
55
- pass
58
+ async def handle_security_exception (
59
+ self , auth : _TAuth , error : Neo4jError
60
+ ) -> bool :
61
+ return False
56
62
57
63
58
- class AsyncExpirationBasedAuthManager (AsyncAuthManager ):
64
+ class AsyncNeo4jAuthTokenManager (AsyncAuthManager ):
59
65
_current_auth : t .Optional [ExpiringAuth ]
60
66
_provider : t .Callable [[], t .Awaitable [ExpiringAuth ]]
67
+ _handled_codes : t .FrozenSet [str ]
61
68
_lock : AsyncLock
62
69
63
-
64
70
def __init__ (
65
71
self ,
66
- provider : t .Callable [[], t .Awaitable [ExpiringAuth ]]
72
+ provider : t .Callable [[], t .Awaitable [ExpiringAuth ]],
73
+ handled_codes : t .FrozenSet [str ]
67
74
) -> None :
68
75
self ._provider = provider
76
+ self ._handled_codes = handled_codes
69
77
self ._current_auth = None
70
78
self ._lock = AsyncLock ()
71
79
@@ -81,18 +89,25 @@ async def get_auth(self) -> _TAuth:
81
89
async with self ._lock :
82
90
auth = self ._current_auth
83
91
if auth is None or expiring_auth_has_expired (auth ):
84
- log .debug ("[ ] _: <TEMPORAL AUTH> refreshing (time out)" )
92
+ log .debug ("[ ] _: <AUTH MANAGER> refreshing (%s)" ,
93
+ "init" if auth is None else "time out" )
85
94
await self ._refresh_auth ()
86
95
auth = self ._current_auth
87
96
assert auth is not None
88
97
return auth .auth
89
98
90
- async def on_auth_expired (self , auth : _TAuth ) -> None :
99
+ async def handle_security_exception (
100
+ self , auth : _TAuth , error : Neo4jError
101
+ ) -> bool :
102
+ if error .code not in self ._handled_codes :
103
+ return False
91
104
async with self ._lock :
92
105
cur_auth = self ._current_auth
93
106
if cur_auth is not None and cur_auth .auth == auth :
94
- log .debug ("[ ] _: <TEMPORAL AUTH> refreshing (error)" )
107
+ log .debug ("[ ] _: <AUTH MANAGER> refreshing (error %s)" ,
108
+ error .code )
95
109
await self ._refresh_auth ()
110
+ return True
96
111
97
112
98
113
class AsyncAuthManagers :
@@ -103,6 +118,11 @@ class AsyncAuthManagers:
103
118
See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features
104
119
105
120
.. versionadded:: 5.8
121
+
122
+ .. versionchanged:: 5.12
123
+
124
+ * Method ``expiration_based()`` was renamed to :meth:`bearer`.
125
+ * Added :meth:`basic`.
106
126
"""
107
127
108
128
@staticmethod
@@ -139,10 +159,72 @@ def static(auth: _TAuth) -> AsyncAuthManager:
139
159
140
160
@staticmethod
141
161
@preview ("Auth managers are a preview feature." )
142
- def expiration_based (
162
+ def basic (
163
+ provider : t .Callable [[], t .Awaitable [_TAuth ]]
164
+ ) -> AsyncAuthManager :
165
+ """Create an auth manager handling basic auth password rotation.
166
+
167
+ .. warning::
168
+
169
+ The provider function **must not** interact with the driver in any
170
+ way as this can cause deadlocks and undefined behaviour.
171
+
172
+ The provider function must only ever return auth information
173
+ belonging to the same identity.
174
+ Switching identities is undefined behavior.
175
+ You may use session-level authentication for such use-cases
176
+ :ref:`session-auth-ref`.
177
+
178
+ Example::
179
+
180
+ import neo4j
181
+ from neo4j.auth_management import (
182
+ AsyncAuthManagers,
183
+ ExpiringAuth,
184
+ )
185
+
186
+
187
+ async def auth_provider():
188
+ # some way of getting a token
189
+ user, password = await get_current_auth()
190
+ return (user, password)
191
+
192
+
193
+ with neo4j.GraphDatabase.driver(
194
+ "neo4j://example.com:7687",
195
+ auth=AsyncAuthManagers.basic(auth_provider)
196
+ ) as driver:
197
+ ... # do stuff
198
+
199
+ :param provider:
200
+ A callable that provides a :class:`.ExpiringAuth` instance.
201
+
202
+ :returns:
203
+ An instance of an implementation of :class:`.AsyncAuthManager` that
204
+ returns auth info from the given provider and refreshes it, calling
205
+ the provider again, when the auth info expires (either because it's
206
+ reached its expiry time or because the server flagged it as
207
+ expired).
208
+
209
+ .. versionadded:: 5.12
210
+ """
211
+ handled_codes = frozenset (("Neo.ClientError.Security.Unauthorized" ,))
212
+
213
+ async def wrapped_provider () -> ExpiringAuth :
214
+ with warnings .catch_warnings ():
215
+ warnings .filterwarnings ("ignore" ,
216
+ message = r"^Auth managers\b.*" ,
217
+ category = PreviewWarning )
218
+ return ExpiringAuth (await provider ())
219
+
220
+ return AsyncNeo4jAuthTokenManager (wrapped_provider , handled_codes )
221
+
222
+ @staticmethod
223
+ @preview ("Auth managers are a preview feature." )
224
+ def bearer (
143
225
provider : t .Callable [[], t .Awaitable [ExpiringAuth ]]
144
226
) -> AsyncAuthManager :
145
- """Create an auth manager for potentially expiring auth info .
227
+ """Create an auth manager for potentially expiring bearer auth tokens .
146
228
147
229
.. warning::
148
230
@@ -165,7 +247,7 @@ def expiration_based(
165
247
166
248
167
249
async def auth_provider():
168
- # some way to getting a token
250
+ # some way of getting a token
169
251
sso_token = await get_sso_token()
170
252
# assume we know our tokens expire every 60 seconds
171
253
expires_in = 60
@@ -180,7 +262,7 @@ async def auth_provider():
180
262
181
263
with neo4j.GraphDatabase.driver(
182
264
"neo4j://example.com:7687",
183
- auth=AsyncAuthManagers.temporal (auth_provider)
265
+ auth=AsyncAuthManagers.bearer (auth_provider)
184
266
) as driver:
185
267
... # do stuff
186
268
@@ -194,6 +276,10 @@ async def auth_provider():
194
276
reached its expiry time or because the server flagged it as
195
277
expired).
196
278
197
-
279
+ .. versionadded:: 5.12
198
280
"""
199
- return AsyncExpirationBasedAuthManager (provider )
281
+ handled_codes = frozenset ((
282
+ "Neo.ClientError.Security.TokenExpired" ,
283
+ "Neo.ClientError.Security.Unauthorized" ,
284
+ ))
285
+ return AsyncNeo4jAuthTokenManager (provider , handled_codes )
0 commit comments