Skip to content

ADR 019: revamp auth managers #957

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1838,9 +1838,6 @@ Server-side errors
.. autoexception:: neo4j.exceptions.TokenExpired()
:show-inheritance:

.. autoexception:: neo4j.exceptions.TokenExpiredRetryable()
:show-inheritance:

.. autoexception:: neo4j.exceptions.Forbidden()
:show-inheritance:

Expand Down
118 changes: 102 additions & 16 deletions src/neo4j/_async/auth_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
# make sure TAuth is resolved in the docs, else they're pretty useless


import time
import typing as t
import warnings
from logging import getLogger

from .._async_compat.concurrency import AsyncLock
Expand All @@ -31,12 +31,16 @@
expiring_auth_has_expired,
ExpiringAuth,
)
from .._meta import preview
from .._meta import (
preview,
PreviewWarning,
)

# work around for https://github.com/sphinx-doc/sphinx/pull/10880
# make sure TAuth is resolved in the docs, else they're pretty useless
# if t.TYPE_CHECKING:
from ..api import _TAuth
from ..exceptions import Neo4jError


log = getLogger("neo4j")
Expand All @@ -51,21 +55,25 @@ def __init__(self, auth: _TAuth) -> None:
async def get_auth(self) -> _TAuth:
return self._auth

async def on_auth_expired(self, auth: _TAuth) -> None:
pass
async def handle_security_exception(
self, auth: _TAuth, error: Neo4jError
) -> bool:
return False


class AsyncExpirationBasedAuthManager(AsyncAuthManager):
class AsyncNeo4jAuthTokenManager(AsyncAuthManager):
_current_auth: t.Optional[ExpiringAuth]
_provider: t.Callable[[], t.Awaitable[ExpiringAuth]]
_handled_codes: t.FrozenSet[str]
_lock: AsyncLock


def __init__(
self,
provider: t.Callable[[], t.Awaitable[ExpiringAuth]]
provider: t.Callable[[], t.Awaitable[ExpiringAuth]],
handled_codes: t.FrozenSet[str]
) -> None:
self._provider = provider
self._handled_codes = handled_codes
self._current_auth = None
self._lock = AsyncLock()

Expand All @@ -81,18 +89,25 @@ async def get_auth(self) -> _TAuth:
async with self._lock:
auth = self._current_auth
if auth is None or expiring_auth_has_expired(auth):
log.debug("[ ] _: <TEMPORAL AUTH> refreshing (time out)")
log.debug("[ ] _: <AUTH MANAGER> refreshing (%s)",
"init" if auth is None else "time out")
await self._refresh_auth()
auth = self._current_auth
assert auth is not None
return auth.auth

async def on_auth_expired(self, auth: _TAuth) -> None:
async def handle_security_exception(
self, auth: _TAuth, error: Neo4jError
) -> bool:
if error.code not in self._handled_codes:
return False
async with self._lock:
cur_auth = self._current_auth
if cur_auth is not None and cur_auth.auth == auth:
log.debug("[ ] _: <TEMPORAL AUTH> refreshing (error)")
log.debug("[ ] _: <AUTH MANAGER> refreshing (error %s)",
error.code)
await self._refresh_auth()
return True


class AsyncAuthManagers:
Expand All @@ -103,6 +118,11 @@ class AsyncAuthManagers:
See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features

.. versionadded:: 5.8

.. versionchanged:: 5.12

* Method ``expiration_based()`` was renamed to :meth:`bearer`.
* Added :meth:`basic`.
"""

@staticmethod
Expand Down Expand Up @@ -139,10 +159,72 @@ def static(auth: _TAuth) -> AsyncAuthManager:

@staticmethod
@preview("Auth managers are a preview feature.")
def expiration_based(
def basic(
provider: t.Callable[[], t.Awaitable[_TAuth]]
) -> AsyncAuthManager:
"""Create an auth manager handling basic auth password rotation.

.. warning::

The provider function **must not** interact with the driver in any
way as this can cause deadlocks and undefined behaviour.

The provider function must only ever return auth information
belonging to the same identity.
Switching identities is undefined behavior.
You may use session-level authentication for such use-cases
:ref:`session-auth-ref`.

Example::

import neo4j
from neo4j.auth_management import (
AsyncAuthManagers,
ExpiringAuth,
)


async def auth_provider():
# some way of getting a token
user, password = await get_current_auth()
return (user, password)


with neo4j.GraphDatabase.driver(
"neo4j://example.com:7687",
auth=AsyncAuthManagers.basic(auth_provider)
) as driver:
... # do stuff

:param provider:
A callable that provides a :class:`.ExpiringAuth` instance.

:returns:
An instance of an implementation of :class:`.AsyncAuthManager` that
returns auth info from the given provider and refreshes it, calling
the provider again, when the auth info expires (either because it's
reached its expiry time or because the server flagged it as
expired).

.. versionadded:: 5.12
"""
handled_codes = frozenset(("Neo.ClientError.Security.Unauthorized",))

async def wrapped_provider() -> ExpiringAuth:
with warnings.catch_warnings():
warnings.filterwarnings("ignore",
message=r"^Auth managers\b.*",
category=PreviewWarning)
return ExpiringAuth(await provider())

return AsyncNeo4jAuthTokenManager(wrapped_provider, handled_codes)

@staticmethod
@preview("Auth managers are a preview feature.")
def bearer(
provider: t.Callable[[], t.Awaitable[ExpiringAuth]]
) -> AsyncAuthManager:
"""Create an auth manager for potentially expiring auth info.
"""Create an auth manager for potentially expiring bearer auth tokens.

.. warning::

Expand All @@ -165,7 +247,7 @@ def expiration_based(


async def auth_provider():
# some way to getting a token
# some way of getting a token
sso_token = await get_sso_token()
# assume we know our tokens expire every 60 seconds
expires_in = 60
Expand All @@ -180,7 +262,7 @@ async def auth_provider():

with neo4j.GraphDatabase.driver(
"neo4j://example.com:7687",
auth=AsyncAuthManagers.temporal(auth_provider)
auth=AsyncAuthManagers.bearer(auth_provider)
) as driver:
... # do stuff

Expand All @@ -194,6 +276,10 @@ async def auth_provider():
reached its expiry time or because the server flagged it as
expired).


.. versionadded:: 5.12
"""
return AsyncExpirationBasedAuthManager(provider)
handled_codes = frozenset((
"Neo.ClientError.Security.TokenExpired",
"Neo.ClientError.Security.Unauthorized",
))
return AsyncNeo4jAuthTokenManager(provider, handled_codes)
18 changes: 6 additions & 12 deletions src/neo4j/_async/io/_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
)
from ..._exceptions import BoltError
from ..._routing import RoutingTable
from ..._sync.auth_management import StaticAuthManager
from ...api import (
READ_ACCESS,
WRITE_ACCESS,
Expand All @@ -65,11 +64,8 @@
ReadServiceUnavailable,
ServiceUnavailable,
SessionExpired,
TokenExpired,
TokenExpiredRetryable,
WriteServiceUnavailable,
)
from ..auth_management import AsyncStaticAuthManager
from ._bolt import AsyncBolt


Expand Down Expand Up @@ -467,15 +463,13 @@ async def on_neo4j_error(self, error, connection):
with self.lock:
for connection in self.connections.get(address, ()):
connection.mark_unauthenticated()
if error._requires_new_credentials():
await AsyncUtil.callback(
connection.auth_manager.on_auth_expired,
connection.auth
if error._has_security_code():
handled = await AsyncUtil.callback(
connection.auth_manager.handle_security_exception,
connection.auth, error
)
if (isinstance(error, TokenExpired)
and not isinstance(self.pool_config.auth, (AsyncStaticAuthManager,
StaticAuthManager))):
error.__class__ = TokenExpiredRetryable
if handled:
error._retryable = True

async def close(self):
""" Close all connections and empty the pool.
Expand Down
40 changes: 33 additions & 7 deletions src/neo4j/_auth_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
PreviewWarning,
)
from .api import _TAuth
from .exceptions import Neo4jError


@preview("Auth managers are a preview feature.")
Expand Down Expand Up @@ -128,6 +129,13 @@ class AuthManager(metaclass=abc.ABCMeta):
.. seealso:: :class:`.AuthManagers`

.. versionadded:: 5.8

.. versionchanged:: 5.12
``on_auth_expired`` was removed from the interface and replaced by
:meth:`handle_security_exception`. The new method is called when the
server returns any `Neo.ClientError.Security.*` error. It's signature
differs in that it additionally received the error returned by the
server and returns a boolean indicating whether the error was handled.
"""

@abc.abstractmethod
Expand All @@ -148,15 +156,27 @@ def get_auth(self) -> _TAuth:
...

@abc.abstractmethod
def on_auth_expired(self, auth: _TAuth) -> None:
"""Handle the server indicating expired authentication information.
def handle_security_exception(
self, auth: _TAuth, error: Neo4jError
) -> bool:
"""Handle the server indicating authentication failure.

The driver will call this method when the server indicates that the
provided authentication information is no longer valid.
The driver will call this method when the server returns any
`Neo.ClientError.Security.*` error. The error will then be processed
further as usual.

:param auth:
The authentication information that the server flagged as no longer
valid.
The authentication information that was used when the server
returned the error.
:param error:
The error returned by the server.

:returns:
Whether the error was handled (:const:`True`), in which case the
driver will mark the error as retryable
(see :meth:`.Neo4jError.is_retryable`).

.. versionadded:: 5.12
"""
...

Expand All @@ -171,6 +191,10 @@ class AsyncAuthManager(metaclass=abc.ABCMeta):
.. seealso:: :class:`.AuthManager`

.. versionadded:: 5.8

.. versionchanged:: 5.12
``on_auth_expired`` was removed from the interface and replaced by
:meth:`handle_security_exception`. See :class:`.AuthManager`.
"""

@abc.abstractmethod
Expand All @@ -182,7 +206,9 @@ async def get_auth(self) -> _TAuth:
...

@abc.abstractmethod
async def on_auth_expired(self, auth: _TAuth) -> None:
async def handle_security_exception(
self, auth: _TAuth, error: Neo4jError
) -> bool:
"""Async version of :meth:`.AuthManager.on_auth_expired`.

.. seealso:: :meth:`.AuthManager.on_auth_expired`
Expand Down
Loading