Skip to content

Commit 79edb1a

Browse files
authored
ADR 019: revamp auth managers (#957)
This PR updates the preview feature "re-auth" (introduced in [PR #890](#890)) significantly. The changes allow for catering to a wider range of use cases including simple password rotation. (⚠️ Breaking) changes: * Removed `TokenExpiredRetryable` exception. Even though it wasn't marked preview, it was introduced with and only used for re-auth. It now longer serves any purpose. * The `AuthManager` and `AsyncAuthManager` abstract classes were changed. The method `on_auth_expired(self, auth: _TAuth) -> None` was removed in favor of `def handle_security_exception(self, auth: _TAuth, error: Neo4jError) -> bool`. See the API docs for more details. * The factories in `AsyncAuthManagers`a nd `AuthManagers` were changed. * `expiration_based` was renamed to `bearer`. * `basic` was added to cater for password rotation.
1 parent a09e25f commit 79edb1a

16 files changed

+649
-201
lines changed

docs/source/api.rst

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1838,9 +1838,6 @@ Server-side errors
18381838
.. autoexception:: neo4j.exceptions.TokenExpired()
18391839
:show-inheritance:
18401840

1841-
.. autoexception:: neo4j.exceptions.TokenExpiredRetryable()
1842-
:show-inheritance:
1843-
18441841
.. autoexception:: neo4j.exceptions.Forbidden()
18451842
:show-inheritance:
18461843

src/neo4j/_async/auth_management.py

Lines changed: 102 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
# make sure TAuth is resolved in the docs, else they're pretty useless
2222

2323

24-
import time
2524
import typing as t
25+
import warnings
2626
from logging import getLogger
2727

2828
from .._async_compat.concurrency import AsyncLock
@@ -31,12 +31,16 @@
3131
expiring_auth_has_expired,
3232
ExpiringAuth,
3333
)
34-
from .._meta import preview
34+
from .._meta import (
35+
preview,
36+
PreviewWarning,
37+
)
3538

3639
# work around for https://github.com/sphinx-doc/sphinx/pull/10880
3740
# make sure TAuth is resolved in the docs, else they're pretty useless
3841
# if t.TYPE_CHECKING:
3942
from ..api import _TAuth
43+
from ..exceptions import Neo4jError
4044

4145

4246
log = getLogger("neo4j")
@@ -51,21 +55,25 @@ def __init__(self, auth: _TAuth) -> None:
5155
async def get_auth(self) -> _TAuth:
5256
return self._auth
5357

54-
async def on_auth_expired(self, auth: _TAuth) -> None:
55-
pass
58+
async def handle_security_exception(
59+
self, auth: _TAuth, error: Neo4jError
60+
) -> bool:
61+
return False
5662

5763

58-
class AsyncExpirationBasedAuthManager(AsyncAuthManager):
64+
class AsyncNeo4jAuthTokenManager(AsyncAuthManager):
5965
_current_auth: t.Optional[ExpiringAuth]
6066
_provider: t.Callable[[], t.Awaitable[ExpiringAuth]]
67+
_handled_codes: t.FrozenSet[str]
6168
_lock: AsyncLock
6269

63-
6470
def __init__(
6571
self,
66-
provider: t.Callable[[], t.Awaitable[ExpiringAuth]]
72+
provider: t.Callable[[], t.Awaitable[ExpiringAuth]],
73+
handled_codes: t.FrozenSet[str]
6774
) -> None:
6875
self._provider = provider
76+
self._handled_codes = handled_codes
6977
self._current_auth = None
7078
self._lock = AsyncLock()
7179

@@ -81,18 +89,25 @@ async def get_auth(self) -> _TAuth:
8189
async with self._lock:
8290
auth = self._current_auth
8391
if auth is None or expiring_auth_has_expired(auth):
84-
log.debug("[ ] _: <TEMPORAL AUTH> refreshing (time out)")
92+
log.debug("[ ] _: <AUTH MANAGER> refreshing (%s)",
93+
"init" if auth is None else "time out")
8594
await self._refresh_auth()
8695
auth = self._current_auth
8796
assert auth is not None
8897
return auth.auth
8998

90-
async def on_auth_expired(self, auth: _TAuth) -> None:
99+
async def handle_security_exception(
100+
self, auth: _TAuth, error: Neo4jError
101+
) -> bool:
102+
if error.code not in self._handled_codes:
103+
return False
91104
async with self._lock:
92105
cur_auth = self._current_auth
93106
if cur_auth is not None and cur_auth.auth == auth:
94-
log.debug("[ ] _: <TEMPORAL AUTH> refreshing (error)")
107+
log.debug("[ ] _: <AUTH MANAGER> refreshing (error %s)",
108+
error.code)
95109
await self._refresh_auth()
110+
return True
96111

97112

98113
class AsyncAuthManagers:
@@ -103,6 +118,11 @@ class AsyncAuthManagers:
103118
See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features
104119
105120
.. versionadded:: 5.8
121+
122+
.. versionchanged:: 5.12
123+
124+
* Method ``expiration_based()`` was renamed to :meth:`bearer`.
125+
* Added :meth:`basic`.
106126
"""
107127

108128
@staticmethod
@@ -139,10 +159,72 @@ def static(auth: _TAuth) -> AsyncAuthManager:
139159

140160
@staticmethod
141161
@preview("Auth managers are a preview feature.")
142-
def expiration_based(
162+
def basic(
163+
provider: t.Callable[[], t.Awaitable[_TAuth]]
164+
) -> AsyncAuthManager:
165+
"""Create an auth manager handling basic auth password rotation.
166+
167+
.. warning::
168+
169+
The provider function **must not** interact with the driver in any
170+
way as this can cause deadlocks and undefined behaviour.
171+
172+
The provider function must only ever return auth information
173+
belonging to the same identity.
174+
Switching identities is undefined behavior.
175+
You may use session-level authentication for such use-cases
176+
:ref:`session-auth-ref`.
177+
178+
Example::
179+
180+
import neo4j
181+
from neo4j.auth_management import (
182+
AsyncAuthManagers,
183+
ExpiringAuth,
184+
)
185+
186+
187+
async def auth_provider():
188+
# some way of getting a token
189+
user, password = await get_current_auth()
190+
return (user, password)
191+
192+
193+
with neo4j.GraphDatabase.driver(
194+
"neo4j://example.com:7687",
195+
auth=AsyncAuthManagers.basic(auth_provider)
196+
) as driver:
197+
... # do stuff
198+
199+
:param provider:
200+
A callable that provides a :class:`.ExpiringAuth` instance.
201+
202+
:returns:
203+
An instance of an implementation of :class:`.AsyncAuthManager` that
204+
returns auth info from the given provider and refreshes it, calling
205+
the provider again, when the auth info expires (either because it's
206+
reached its expiry time or because the server flagged it as
207+
expired).
208+
209+
.. versionadded:: 5.12
210+
"""
211+
handled_codes = frozenset(("Neo.ClientError.Security.Unauthorized",))
212+
213+
async def wrapped_provider() -> ExpiringAuth:
214+
with warnings.catch_warnings():
215+
warnings.filterwarnings("ignore",
216+
message=r"^Auth managers\b.*",
217+
category=PreviewWarning)
218+
return ExpiringAuth(await provider())
219+
220+
return AsyncNeo4jAuthTokenManager(wrapped_provider, handled_codes)
221+
222+
@staticmethod
223+
@preview("Auth managers are a preview feature.")
224+
def bearer(
143225
provider: t.Callable[[], t.Awaitable[ExpiringAuth]]
144226
) -> AsyncAuthManager:
145-
"""Create an auth manager for potentially expiring auth info.
227+
"""Create an auth manager for potentially expiring bearer auth tokens.
146228
147229
.. warning::
148230
@@ -165,7 +247,7 @@ def expiration_based(
165247
166248
167249
async def auth_provider():
168-
# some way to getting a token
250+
# some way of getting a token
169251
sso_token = await get_sso_token()
170252
# assume we know our tokens expire every 60 seconds
171253
expires_in = 60
@@ -180,7 +262,7 @@ async def auth_provider():
180262
181263
with neo4j.GraphDatabase.driver(
182264
"neo4j://example.com:7687",
183-
auth=AsyncAuthManagers.temporal(auth_provider)
265+
auth=AsyncAuthManagers.bearer(auth_provider)
184266
) as driver:
185267
... # do stuff
186268
@@ -194,6 +276,10 @@ async def auth_provider():
194276
reached its expiry time or because the server flagged it as
195277
expired).
196278
197-
279+
.. versionadded:: 5.12
198280
"""
199-
return AsyncExpirationBasedAuthManager(provider)
281+
handled_codes = frozenset((
282+
"Neo.ClientError.Security.TokenExpired",
283+
"Neo.ClientError.Security.Unauthorized",
284+
))
285+
return AsyncNeo4jAuthTokenManager(provider, handled_codes)

src/neo4j/_async/io/_pool.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
)
4949
from ..._exceptions import BoltError
5050
from ..._routing import RoutingTable
51-
from ..._sync.auth_management import StaticAuthManager
5251
from ...api import (
5352
READ_ACCESS,
5453
WRITE_ACCESS,
@@ -65,11 +64,8 @@
6564
ReadServiceUnavailable,
6665
ServiceUnavailable,
6766
SessionExpired,
68-
TokenExpired,
69-
TokenExpiredRetryable,
7067
WriteServiceUnavailable,
7168
)
72-
from ..auth_management import AsyncStaticAuthManager
7369
from ._bolt import AsyncBolt
7470

7571

@@ -467,15 +463,13 @@ async def on_neo4j_error(self, error, connection):
467463
with self.lock:
468464
for connection in self.connections.get(address, ()):
469465
connection.mark_unauthenticated()
470-
if error._requires_new_credentials():
471-
await AsyncUtil.callback(
472-
connection.auth_manager.on_auth_expired,
473-
connection.auth
466+
if error._has_security_code():
467+
handled = await AsyncUtil.callback(
468+
connection.auth_manager.handle_security_exception,
469+
connection.auth, error
474470
)
475-
if (isinstance(error, TokenExpired)
476-
and not isinstance(self.pool_config.auth, (AsyncStaticAuthManager,
477-
StaticAuthManager))):
478-
error.__class__ = TokenExpiredRetryable
471+
if handled:
472+
error._retryable = True
479473

480474
async def close(self):
481475
""" Close all connections and empty the pool.

src/neo4j/_auth_management.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
PreviewWarning,
3333
)
3434
from .api import _TAuth
35+
from .exceptions import Neo4jError
3536

3637

3738
@preview("Auth managers are a preview feature.")
@@ -128,6 +129,13 @@ class AuthManager(metaclass=abc.ABCMeta):
128129
.. seealso:: :class:`.AuthManagers`
129130
130131
.. versionadded:: 5.8
132+
133+
.. versionchanged:: 5.12
134+
``on_auth_expired`` was removed from the interface and replaced by
135+
:meth:`handle_security_exception`. The new method is called when the
136+
server returns any `Neo.ClientError.Security.*` error. It's signature
137+
differs in that it additionally received the error returned by the
138+
server and returns a boolean indicating whether the error was handled.
131139
"""
132140

133141
@abc.abstractmethod
@@ -148,15 +156,27 @@ def get_auth(self) -> _TAuth:
148156
...
149157

150158
@abc.abstractmethod
151-
def on_auth_expired(self, auth: _TAuth) -> None:
152-
"""Handle the server indicating expired authentication information.
159+
def handle_security_exception(
160+
self, auth: _TAuth, error: Neo4jError
161+
) -> bool:
162+
"""Handle the server indicating authentication failure.
153163
154-
The driver will call this method when the server indicates that the
155-
provided authentication information is no longer valid.
164+
The driver will call this method when the server returns any
165+
`Neo.ClientError.Security.*` error. The error will then be processed
166+
further as usual.
156167
157168
:param auth:
158-
The authentication information that the server flagged as no longer
159-
valid.
169+
The authentication information that was used when the server
170+
returned the error.
171+
:param error:
172+
The error returned by the server.
173+
174+
:returns:
175+
Whether the error was handled (:const:`True`), in which case the
176+
driver will mark the error as retryable
177+
(see :meth:`.Neo4jError.is_retryable`).
178+
179+
.. versionadded:: 5.12
160180
"""
161181
...
162182

@@ -171,6 +191,10 @@ class AsyncAuthManager(metaclass=abc.ABCMeta):
171191
.. seealso:: :class:`.AuthManager`
172192
173193
.. versionadded:: 5.8
194+
195+
.. versionchanged:: 5.12
196+
``on_auth_expired`` was removed from the interface and replaced by
197+
:meth:`handle_security_exception`. See :class:`.AuthManager`.
174198
"""
175199

176200
@abc.abstractmethod
@@ -182,7 +206,9 @@ async def get_auth(self) -> _TAuth:
182206
...
183207

184208
@abc.abstractmethod
185-
async def on_auth_expired(self, auth: _TAuth) -> None:
209+
async def handle_security_exception(
210+
self, auth: _TAuth, error: Neo4jError
211+
) -> bool:
186212
"""Async version of :meth:`.AuthManager.on_auth_expired`.
187213
188214
.. seealso:: :meth:`.AuthManager.on_auth_expired`

0 commit comments

Comments
 (0)