diff --git a/.gitignore b/.gitignore index f485e71be..16e9a90ab 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ neo4j-enterprise-* testkit/CAs testkit/CustomCAs +/.vs diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 8865e09cf..e5c59ce9a 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -49,6 +49,7 @@ AsyncInbox, AsyncOutbox, CommitResponse, + auth_to_dict, ) @@ -142,7 +143,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, self.auth_dict = vars(Auth("basic", *auth)) else: try: - self.auth_dict = vars(auth) + self.auth_dict = auth_to_dict(auth) except (KeyError, TypeError): raise AuthError("Cannot determine auth details from %r" % auth) diff --git a/src/neo4j/_async/io/_common.py b/src/neo4j/_async/io/_common.py index fa21d3b60..b5ab425c1 100644 --- a/src/neo4j/_async/io/_common.py +++ b/src/neo4j/_async/io/_common.py @@ -294,3 +294,12 @@ async def receive_into_buffer(sock, buffer, n_bytes): if n == 0: raise OSError("No data") buffer.used += n + + +def auth_to_dict(auth): + auth_dict = vars(auth).copy() + if "credentials_refresher" in auth_dict: + if auth_dict["credentials_refresher"] is not None: + auth_dict["credentials"] = auth_dict["credentials_refresher"]() + auth_dict.pop("credentials_refresher") + return auth_dict diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 9980b8c71..ec7195d62 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -24,6 +24,7 @@ from logging import getLogger from time import perf_counter +from ..._async.io._common import auth_to_dict from ..._async_compat.network import BoltSocket from ..._codec.hydration import v1 as hydration_v1 from ..._codec.packstream import v1 as packstream_v1 @@ -142,7 +143,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, self.auth_dict = vars(Auth("basic", *auth)) else: try: - self.auth_dict = vars(auth) + self.auth_dict = auth_to_dict(auth) except (KeyError, TypeError): raise AuthError("Cannot determine auth details from %r" % auth) diff --git a/src/neo4j/api.py b/src/neo4j/api.py index e91aed6e7..535b55320 100644 --- a/src/neo4j/api.py +++ b/src/neo4j/api.py @@ -102,6 +102,10 @@ def __init__( self.realm = realm if parameters: self.parameters = parameters + self.credentials_refresher = None + + def add_credentials_refresher(self, credentials_refresher: t.Any): + self.credentials_refresher = credentials_refresher # For backwards compatibility diff --git a/tests/unit/async_/io/test__common.py b/tests/unit/async_/io/test__common.py index 1c14ea202..28e6f206e 100644 --- a/tests/unit/async_/io/test__common.py +++ b/tests/unit/async_/io/test__common.py @@ -18,7 +18,8 @@ import pytest -from neo4j._async.io._common import AsyncOutbox +from neo4j import basic_auth +from neo4j._async.io._common import AsyncOutbox, auth_to_dict from neo4j._codec.packstream.v1 import PackableBuffer from ...._async_compat import mark_async_test @@ -58,3 +59,27 @@ async def test_async_outbox_chunking(chunk_size, data, result, mocker): assert not await outbox.flush() socket_mock.sendall.assert_awaited_once() + + +def test_auth_to_dict_without_refresher(): + auth = basic_auth("some_login", "some_password") + result = auth_to_dict(auth) + assert result == {"principal": "some_login", + "credentials": "some_password", + "scheme": "basic"} + + +def test_auth_to_dict_with_refresher(): + creds = "old credentials" + auth = basic_auth("some_login", "ignored") + auth.add_credentials_refresher(lambda: creds) + result = auth_to_dict(auth) + assert result == {"principal": "some_login", + "credentials": "old credentials", + "scheme": "basic"} + # run again and should get new credentials + creds = "new credentials" + result = auth_to_dict(auth) + assert result == {"principal": "some_login", + "credentials": "new credentials", + "scheme": "basic"}