From 7da31ee68c7b60b04bcc522dd84f189edf5477b9 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 11 Aug 2023 13:23:39 +0200 Subject: [PATCH 1/7] ADR 019: re-auth revamp * Adjust auth manager interface and corresponding TestKit messages * Rename `ExpirationBasedAuthTokenManager` to `BearerAuthTokenManager` * Add `BasicAuthTokenManager` * Add and adjust tests --- nutkit/frontend/__init__.py | 3 +- nutkit/frontend/auth_token_manager.py | 150 +++++++-- nutkit/frontend/driver.py | 12 +- nutkit/protocol/feature.py | 2 + nutkit/protocol/requests.py | 41 ++- nutkit/protocol/responses.py | 75 +++-- ...ed.script => reader_reauth_handled.script} | 2 +- ...t => reader_reauth_handled_minimal.script} | 2 +- .../v5x0/reader_reauth_unhandled.script | 80 +++++ ...ed.script => writer_reauth_handled.script} | 0 ...t => writer_reauth_handled_minimal.script} | 0 .../v5x0/writer_reauth_unhandled.script | 24 ++ ...ed.script => reader_reauth_handled.script} | 2 +- ...=> reader_reauth_handled_pipelined.script} | 2 +- ...r_reauth_handled_pipelined_minimal.script} | 2 +- .../v5x1/reader_reauth_unhandled.script | 81 +++++ ...ed.script => writer_reauth_handled.script} | 0 ...=> writer_reauth_handled_pipelined.script} | 0 ...r_reauth_handled_pipelined_minimal.script} | 0 .../v5x1/writer_reauth_unhandled.script | 25 ++ .../authorization/test_auth_token_manager.py | 282 ++++++++--------- .../stub/authorization/test_authorization.py | 155 ++++++++- .../authorization/test_basic_auth_manager.py | 298 ++++++++++++++++++ ...manager.py => test_bearer_auth_manager.py} | 82 +++-- .../test_token_expired_retry.py | 62 ++-- 25 files changed, 1079 insertions(+), 303 deletions(-) rename tests/stub/authorization/scripts/v5x0/{reader_reauth_token_expired.script => reader_reauth_handled.script} (95%) rename tests/stub/authorization/scripts/v5x0/{reader_reauth_token_expired_minimal.script => reader_reauth_handled_minimal.script} (95%) create mode 100644 tests/stub/authorization/scripts/v5x0/reader_reauth_unhandled.script rename tests/stub/authorization/scripts/v5x0/{writer_reauth_token_expired.script => writer_reauth_handled.script} (100%) rename tests/stub/authorization/scripts/v5x0/{writer_reauth_token_expired_minimal.script => writer_reauth_handled_minimal.script} (100%) create mode 100644 tests/stub/authorization/scripts/v5x0/writer_reauth_unhandled.script rename tests/stub/authorization/scripts/v5x1/{reader_reauth_token_expired.script => reader_reauth_handled.script} (95%) rename tests/stub/authorization/scripts/v5x1/{reader_reauth_token_expired_pipelined.script => reader_reauth_handled_pipelined.script} (96%) rename tests/stub/authorization/scripts/v5x1/{reader_reauth_token_expired_pipelined_minimal.script => reader_reauth_handled_pipelined_minimal.script} (96%) create mode 100644 tests/stub/authorization/scripts/v5x1/reader_reauth_unhandled.script rename tests/stub/authorization/scripts/v5x1/{writer_reauth_token_expired.script => writer_reauth_handled.script} (100%) rename tests/stub/authorization/scripts/v5x1/{writer_reauth_token_expired_pipelined.script => writer_reauth_handled_pipelined.script} (100%) rename tests/stub/authorization/scripts/v5x1/{writer_reauth_token_expired_pipelined_minimal.script => writer_reauth_handled_pipelined_minimal.script} (100%) create mode 100644 tests/stub/authorization/scripts/v5x1/writer_reauth_unhandled.script create mode 100644 tests/stub/authorization/test_basic_auth_manager.py rename tests/stub/authorization/{test_expiration_based_auth_manager.py => test_bearer_auth_manager.py} (80%) diff --git a/nutkit/frontend/__init__.py b/nutkit/frontend/__init__.py index 991330a17..fb9d13cc2 100644 --- a/nutkit/frontend/__init__.py +++ b/nutkit/frontend/__init__.py @@ -1,6 +1,7 @@ from .auth_token_manager import ( AuthTokenManager, - ExpirationBasedAuthTokenManager, + BasicAuthTokenManager, + BearerAuthTokenManager, ) from .bookmark_manager import ( BookmarkManager, diff --git a/nutkit/frontend/auth_token_manager.py b/nutkit/frontend/auth_token_manager.py index f183b3ea4..90e19435d 100644 --- a/nutkit/frontend/auth_token_manager.py +++ b/nutkit/frontend/auth_token_manager.py @@ -8,8 +8,38 @@ Dict, ) -from .. import protocol from ..backend import Backend +from ..protocol import ( + AuthorizationToken, + AuthTokenAndExpiration, +) +from ..protocol import AuthTokenManager as AuthTokenManagerMessage +from ..protocol import ( + AuthTokenManagerClose, + AuthTokenManagerGetAuthCompleted, + AuthTokenManagerGetAuthRequest, + AuthTokenManagerHandleSecurityExceptionCompleted, + AuthTokenManagerHandleSecurityExceptionRequest, +) +from ..protocol import BasicAuthTokenManager as BasicAuthTokenManagerMessage +from ..protocol import ( + BasicAuthTokenProviderCompleted, + BasicAuthTokenProviderRequest, +) +from ..protocol import BearerAuthTokenManager as BearerAuthTokenManagerMessage +from ..protocol import ( + BearerAuthTokenProviderCompleted, + BearerAuthTokenProviderRequest, + NewAuthTokenManager, + NewBasicAuthTokenManager, + NewBearerAuthTokenManager, +) + +__all__ = [ + "AuthTokenManager", + "BasicAuthTokenManager", + "BearerAuthTokenManager", +] @dataclass @@ -19,16 +49,16 @@ class AuthTokenManager: def __init__( self, backend: Backend, - get_auth: Callable[[], protocol.AuthorizationToken], - on_auth_expired: Callable[[protocol.AuthorizationToken], None] + get_auth: Callable[[], AuthorizationToken], + handle_security_exception: Callable[[AuthorizationToken, str], bool] ): self._backend = backend self._get_auth = get_auth - self._on_auth_expired = on_auth_expired + self._handle_security_exception = handle_security_exception - req = protocol.NewAuthTokenManager() + req = NewAuthTokenManager() res = backend.send_and_receive(req) - if not isinstance(res, protocol.AuthTokenManager): + if not isinstance(res, AuthTokenManagerMessage): raise Exception(f"Should be AuthTokenManager but was {res}") self._auth_token_manager = res @@ -40,7 +70,7 @@ def id(self): @classmethod def process_callbacks(cls, request): - if isinstance(request, protocol.AuthTokenManagerGetAuthRequest): + if isinstance(request, AuthTokenManagerGetAuthRequest): if request.auth_token_manager_id not in cls._registry: raise Exception( "Backend provided unknown Auth Token Manager " @@ -48,25 +78,27 @@ def process_callbacks(cls, request): ) manager = cls._registry[request.auth_token_manager_id] auth_token = manager._get_auth() - return protocol.AuthTokenManagerGetAuthCompleted( + return AuthTokenManagerGetAuthCompleted( request.id, auth_token ) - if isinstance(request, protocol.AuthTokenManagerOnAuthExpiredRequest): + if isinstance(request, AuthTokenManagerHandleSecurityExceptionRequest): if request.auth_token_manager_id not in cls._registry: raise Exception( "Backend provided unknown Auth Token Manager " f"id: {request.auth_token_manager_id} not found" ) manager = cls._registry[request.auth_token_manager_id] - manager._on_auth_expired(request.auth) - return protocol.AuthTokenManagerOnAuthExpiredCompleted(request.id) + handled = manager._handle_security_exception(request.auth, + request.error_code) + return AuthTokenManagerHandleSecurityExceptionCompleted(request.id, + handled) def close(self, hooks=None): res = self._backend.send_and_receive( - protocol.AuthTokenManagerClose(self.id), + AuthTokenManagerClose(self.id), hooks=hooks ) - if not isinstance(res, protocol.AuthTokenManager): + if not isinstance(res, AuthTokenManagerMessage): raise Exception( f"Should be AuthTokenManager but was {res}" ) @@ -74,58 +106,116 @@ def close(self, hooks=None): @dataclass -class ExpirationBasedAuthTokenManager: - _registry: ClassVar[Dict[Any, ExpirationBasedAuthTokenManager]] = {} +class BasicAuthTokenManager: + _registry: ClassVar[Dict[Any, BasicAuthTokenManager]] = {} + + def __init__( + self, + backend: Backend, + callback: Callable[[], AuthorizationToken] + ): + self._backend = backend + self._callback = callback + + req = NewBasicAuthTokenManager() + res = backend.send_and_receive(req) + if not isinstance(res, BasicAuthTokenManagerMessage): + raise Exception( + f"Should be BasicAuthTokenManager but was {res}" + ) + + self._basic_auth_token_manager = res + self._registry[self._basic_auth_token_manager.id] = self + + @property + def id(self): + return self._basic_auth_token_manager.id + + @classmethod + def process_callbacks(cls, request): + if isinstance(request, + BasicAuthTokenProviderRequest): + if ( + request.basic_auth_token_manager_id + not in cls._registry + ): + raise Exception( + "Backend provided unknown BasicAuthTokenManager " + f"id: {request.basic_auth_token_manager_id} " + f"not found" + ) + + manager = cls._registry[ + request.basic_auth_token_manager_id + ] + renewable_auth_token = manager._callback() + return BasicAuthTokenProviderCompleted( + request.id, renewable_auth_token + ) + + def close(self, hooks=None): + res = self._backend.send_and_receive( + AuthTokenManagerClose(self.id), + hooks=hooks + ) + if not isinstance(res, AuthTokenManagerMessage): + raise Exception(f"Should be AuthTokenManager but was {res}") + del self._registry[self.id] + + +@dataclass +class BearerAuthTokenManager: + _registry: ClassVar[Dict[Any, BearerAuthTokenManager]] = {} def __init__( self, backend: Backend, - callback: Callable[[], protocol.AuthTokenAndExpiration] + callback: Callable[[], AuthTokenAndExpiration] ): self._backend = backend self._callback = callback - req = protocol.NewExpirationBasedAuthTokenManager() + req = NewBearerAuthTokenManager() res = backend.send_and_receive(req) - if not isinstance(res, protocol.ExpirationBasedAuthTokenManager): + if not isinstance(res, BearerAuthTokenManagerMessage): raise Exception( - f"Should be TemporalAuthTokenManager but was {res}" + f"Should be BearerAuthTokenManager but was {res}" ) - self._temporal_auth_token_manager = res - self._registry[self._temporal_auth_token_manager.id] = self + self._bearer_auth_token_manager = res + self._registry[self._bearer_auth_token_manager.id] = self @property def id(self): - return self._temporal_auth_token_manager.id + return self._bearer_auth_token_manager.id @classmethod def process_callbacks(cls, request): if isinstance(request, - protocol.ExpirationBasedAuthTokenProviderRequest): + BearerAuthTokenProviderRequest): if ( - request.expiration_based_auth_token_manager_id + request.bearer_auth_token_manager_id not in cls._registry ): raise Exception( - "Backend provided unknown ExpirationBasedAuthTokenManager " - f"id: {request.expiration_based_auth_token_manager_id} " + "Backend provided unknown BearerAuthTokenManager " + f"id: {request.bearer_auth_token_manager_id} " f"not found" ) manager = cls._registry[ - request.expiration_based_auth_token_manager_id + request.bearer_auth_token_manager_id ] renewable_auth_token = manager._callback() - return protocol.ExpirationBasedAuthTokenProviderCompleted( + return BearerAuthTokenProviderCompleted( request.id, renewable_auth_token ) def close(self, hooks=None): res = self._backend.send_and_receive( - protocol.AuthTokenManagerClose(self.id), + AuthTokenManagerClose(self.id), hooks=hooks ) - if not isinstance(res, protocol.AuthTokenManager): + if not isinstance(res, AuthTokenManagerMessage): raise Exception(f"Should be AuthTokenManager but was {res}") del self._registry[self.id] diff --git a/nutkit/frontend/driver.py b/nutkit/frontend/driver.py index de949b4f6..64e868706 100644 --- a/nutkit/frontend/driver.py +++ b/nutkit/frontend/driver.py @@ -1,7 +1,8 @@ from .. import protocol from .auth_token_manager import ( AuthTokenManager, - ExpirationBasedAuthTokenManager, + BasicAuthTokenManager, + BearerAuthTokenManager, ) from .bookmark_manager import BookmarkManager from .session import Session @@ -29,7 +30,9 @@ def __init__(self, backend, uri, auth_token, user_agent=None, self._auth_token = auth_token else: assert isinstance( - auth_token, (AuthTokenManager, ExpirationBasedAuthTokenManager) + auth_token, (AuthTokenManager, + BearerAuthTokenManager, + BasicAuthTokenManager) ) self._auth_token_manager = auth_token auth_token_manager_id = auth_token.id @@ -73,9 +76,10 @@ def receive(self, timeout=None, hooks=None, *, allow_resolution): ) continue for cb_processor in ( - BookmarkManager, - ExpirationBasedAuthTokenManager, AuthTokenManager, + BasicAuthTokenManager, + BearerAuthTokenManager, + BookmarkManager, ): cb_response = cb_processor.process_callbacks(res) if cb_response is not None: diff --git a/nutkit/protocol/feature.py b/nutkit/protocol/feature.py index f8ab03dd5..fcd3adf86 100644 --- a/nutkit/protocol/feature.py +++ b/nutkit/protocol/feature.py @@ -54,6 +54,8 @@ class Feature(Enum): # If there are more than records, the driver emits a warning. # This method is supposed to always exhaust the result stream. API_RESULT_SINGLE_OPTIONAL = "Feature:API:Result.SingleOptional" + # The driver offers a way to determine if exceptions are retryable or not. + API_RETRYABLE_EXCEPTION = "Feature:API:RetryableExceptions" # The session configuration allows to switch the authentication context # by supplying new credentials. This new context is only valid for the # current session. diff --git a/nutkit/protocol/requests.py b/nutkit/protocol/requests.py index 7c039c44c..ee14c430b 100644 --- a/nutkit/protocol/requests.py +++ b/nutkit/protocol/requests.py @@ -178,15 +178,16 @@ def __init__(self, request_id, auth): self.auth = auth -class AuthTokenManagerOnAuthExpiredCompleted: +class AuthTokenManagerHandleSecurityExceptionCompleted: """ - Result of a completed auth token provider function call. + Result of a completed security exception handler call. No response is expected. """ - def __init__(self, request_id): + def __init__(self, request_id, handled): self.requestId = request_id + self.handled = bool(handled) class AuthTokenManagerClose: @@ -203,20 +204,46 @@ def __init__(self, id): self.id = id -class NewExpirationBasedAuthTokenManager: +class NewBasicAuthTokenManager: + """ + Create a new token manager for password rotation on the backend. + + The manager will wrap a plain token provider function on the backend. + + The backend should respond with `BasicAuthTokenManager`. + """ + + def __init__(self): + pass + + +class BasicAuthTokenProviderCompleted: + """ + Result of a completed auth token provider function call. + + No response is expected. + """ + + def __init__(self, request_id, auth): + self.requestId = request_id + assert isinstance(auth, AuthorizationToken) + self.auth = auth + + +class NewBearerAuthTokenManager: """ - Create a new auth temporal token manager on the backend. + Create a new manager for potentially expiring bearer tokens on the backend. The manager will wrap a temporal token provider function on the backend. - The backend should respond with `ExpirationBasedAuthTokenManager`. + The backend should respond with `BearerAuthTokenManager`. """ def __init__(self): pass -class ExpirationBasedAuthTokenProviderCompleted: +class BearerAuthTokenProviderCompleted: """ Result of a completed auth token provider function call. diff --git a/nutkit/protocol/responses.py b/nutkit/protocol/responses.py index 13219f5d9..e4f801352 100644 --- a/nutkit/protocol/responses.py +++ b/nutkit/protocol/responses.py @@ -87,63 +87,96 @@ def __init__(self, id, authTokenManagerId): self.auth_token_manager_id = authTokenManagerId -class AuthTokenManagerOnAuthExpiredRequest: +class AuthTokenManagerHandleSecurityExceptionRequest: """ - Represents the need for getting an auth token from the manager. + Represents the driver notifying the manger of a security exception. This message may be sent by the backend at any time should the driver call - OnAuthExpired() on the manager that was previously created in response to - `NewAuthTokenManager`. + HandleSecurityException() on the manager that was previously created in + esponse to `NewAuthTokenManager`. - TestKit will respond with `TemporalAuthTokenProviderCompleted` with the + TestKit will respond with + `AuthTokenManagerHandleSecurityExceptionRequestCompleted`. """ - def __init__(self, id, authTokenManagerId, auth): + def __init__(self, id, authTokenManagerId, auth, errorCode): # Id of the request. TestKit will send the same id back as `requestId` # in the `TemporalTemporalAuthTokenProviderCompleted` response. self.id = id # Id of the auth token manager that spawned this request. self.auth_token_manager_id = authTokenManagerId from .requests import AuthorizationToken - - # The expired auth data. - assert isinstance(auth, AuthorizationToken) + assert isinstance(auth, AuthorizationToken) # The expired auth data self.auth = auth + self.error_code = errorCode + + +class BasicAuthTokenManager: + """ + Represents a new auth manager to handle password rotation. + + The passed id is used when creating a new driver (`NewDriver`) to refer to + this auth token manager + """ + + def __init__(self, id): + # Id of BasicAuthTokenManager instance on backend. + # Note that the id space needs to be shared with AuthTokenManager. + self.id = id + + +class BasicAuthTokenProviderRequest: + """ + Represents the need for a fresh auth token. + + This message may be sent by the backend at any time should the driver call + a temporal auth token provider function that was previously created in + response to `BearerAuthTokenManager`. + + TestKit will respond with `BasicAuthTokenProviderCompleted`. + """ + + def __init__(self, id, basicAuthTokenManagerId): + # Id of the request. TestKit will send the same id back as `requestId` + # in the `BasicAuthTokenProviderCompleted` response. + self.id = id + # Id of the temporal auth token manager that called its provider + # function. + self.basic_auth_token_manager_id = basicAuthTokenManagerId -class ExpirationBasedAuthTokenManager: +class BearerAuthTokenManager: """ - Represents a new expiration based auth token manager. + Represents a new auth manager to handle potentially expiring bearer tokens. The passed id is used when creating a new driver (`NewDriver`) to refer to this auth token manager """ def __init__(self, id): - # Id of ExpirationBasedAuthTokenManager instance on backend. + # Id of BearerAuthTokenManager instance on backend. # Note that the id space needs to be shared with AuthTokenManager. self.id = id -class ExpirationBasedAuthTokenProviderRequest: +class BearerAuthTokenProviderRequest: """ Represents the need for a fresh auth token. This message may be sent by the backend at any time should the driver call a temporal auth token provider function that was previously created in - response to `NewExpirationBasedAuthTokenManager`. + response to `BearerAuthTokenManager`. - TestKit will respond with `ExpirationBasedAuthTokenProviderCompleted`. + TestKit will respond with `BearerAuthTokenProviderCompleted`. """ - def __init__(self, id, expirationBasedAuthTokenManagerId): + def __init__(self, id, bearerAuthTokenManagerId): # Id of the request. TestKit will send the same id back as `requestId` - # in the `ExpirationBasedAuthTokenProviderCompleted` response. + # in the `BearerAuthTokenProviderCompleted` response. self.id = id # Id of the temporal auth token manager that called its provider # function. - self.expiration_based_auth_token_manager_id = \ - expirationBasedAuthTokenManagerId + self.bearer_auth_token_manager_id = bearerAuthTokenManagerId class ResolverResolutionRequired: @@ -652,11 +685,13 @@ class DriverError(BaseError): test framework needs to check detailed error handling. """ - def __init__(self, id=None, errorType=None, msg="", code=""): + def __init__(self, id=None, errorType=None, msg="", code="", + retryable=None): self.id = id self.errorType = errorType self.msg = msg self.code = code + self.retryable = retryable def __str__(self): return f"DriverError(type={self.errorType}, msg={self.msg!r})" diff --git a/tests/stub/authorization/scripts/v5x0/reader_reauth_token_expired.script b/tests/stub/authorization/scripts/v5x0/reader_reauth_handled.script similarity index 95% rename from tests/stub/authorization/scripts/v5x0/reader_reauth_token_expired.script rename to tests/stub/authorization/scripts/v5x0/reader_reauth_handled.script index 3cf98852f..315608e8f 100644 --- a/tests/stub/authorization/scripts/v5x0/reader_reauth_token_expired.script +++ b/tests/stub/authorization/scripts/v5x0/reader_reauth_handled.script @@ -31,7 +31,7 @@ *: RESET C: RUN "RETURN 2.2 AS n" "*" "*" - S: FAILURE {"code": "Neo.ClientError.Security.TokenExpired", "message": "Token expired."} + S: FAILURE #ERROR# S: ---- C: RUN "RETURN 3.1 AS n" "*" "*" diff --git a/tests/stub/authorization/scripts/v5x0/reader_reauth_token_expired_minimal.script b/tests/stub/authorization/scripts/v5x0/reader_reauth_handled_minimal.script similarity index 95% rename from tests/stub/authorization/scripts/v5x0/reader_reauth_token_expired_minimal.script rename to tests/stub/authorization/scripts/v5x0/reader_reauth_handled_minimal.script index ed549df9c..81a95939f 100644 --- a/tests/stub/authorization/scripts/v5x0/reader_reauth_token_expired_minimal.script +++ b/tests/stub/authorization/scripts/v5x0/reader_reauth_handled_minimal.script @@ -31,7 +31,7 @@ *: RESET C: RUN "RETURN 2.2 AS n" "*" "*" - S: FAILURE {"code": "Neo.ClientError.Security.TokenExpired", "message": "Token expired."} + S: FAILURE #ERROR# S: ---- C: RUN "RETURN 3.1 AS n" "*" "*" diff --git a/tests/stub/authorization/scripts/v5x0/reader_reauth_unhandled.script b/tests/stub/authorization/scripts/v5x0/reader_reauth_unhandled.script new file mode 100644 index 000000000..6dc4669b1 --- /dev/null +++ b/tests/stub/authorization/scripts/v5x0/reader_reauth_unhandled.script @@ -0,0 +1,80 @@ +!: BOLT #VERSION# +!: ALLOW CONCURRENT + +A: HELLO {"user_agent": "*", "[routing]": "*", "scheme": "basic", "principal": "neo4j", "credentials": "pass", "[realm]": ""} + +*: RESET + +C: BEGIN {"{}": "*"} +S: SUCCESS {} +# Three concurrent connections +{{ + C: RUN "RETURN 1.1 AS n" "*" "*" + S: SUCCESS {"fields": ["n"]} + C: PULL "*" + S: RECORD [1] + SUCCESS {"type": "r"} + C: COMMIT + S: SUCCESS {} + + # now the second connection receives the error + # => this connection should not be affected + *: RESET + + C: BEGIN {"{}": "*"} + S: SUCCESS {} + C: RUN "RETURN 1.2 AS n" "*" "*" + S: SUCCESS {"fields": ["n"]} + C: PULL "*" + S: RECORD [1] + SUCCESS {"type": "r"} + C: COMMIT + S: SUCCESS {} +---- + C: RUN "RETURN 2.1 AS n" "*" "*" + S: SUCCESS {"fields": ["n"]} + C: PULL "*" + S: RECORD [1] + SUCCESS {"type": "r"} + C: COMMIT + S: SUCCESS {} + + *: RESET + + C: RUN "RETURN 2.2 AS n" "*" "*" + S: FAILURE #ERROR# + S: +---- + C: RUN "RETURN 2.3 AS n" "*" "*" + S: SUCCESS {"fields": ["n"]} + C: PULL "*" + S: RECORD [1] + SUCCESS {"type": "r"} + C: COMMIT + S: SUCCESS {} +---- + C: RUN "RETURN 3.1 AS n" "*" "*" + S: SUCCESS {"fields": ["n"]} + C: PULL "*" + S: RECORD [1] + SUCCESS {"type": "r"} + C: COMMIT + S: SUCCESS {} + + # now the second connection receives the error + # => this connection should not be affected + *: RESET + + C: BEGIN {"{}": "*"} + S: SUCCESS {} + C: RUN "RETURN 3.2 AS n" "*" "*" + S: SUCCESS {"fields": ["n"]} + C: PULL "*" + S: RECORD [1] + SUCCESS {"type": "r"} + C: COMMIT + S: SUCCESS {} +}} + +*: RESET +?: GOODBYE diff --git a/tests/stub/authorization/scripts/v5x0/writer_reauth_token_expired.script b/tests/stub/authorization/scripts/v5x0/writer_reauth_handled.script similarity index 100% rename from tests/stub/authorization/scripts/v5x0/writer_reauth_token_expired.script rename to tests/stub/authorization/scripts/v5x0/writer_reauth_handled.script diff --git a/tests/stub/authorization/scripts/v5x0/writer_reauth_token_expired_minimal.script b/tests/stub/authorization/scripts/v5x0/writer_reauth_handled_minimal.script similarity index 100% rename from tests/stub/authorization/scripts/v5x0/writer_reauth_token_expired_minimal.script rename to tests/stub/authorization/scripts/v5x0/writer_reauth_handled_minimal.script diff --git a/tests/stub/authorization/scripts/v5x0/writer_reauth_unhandled.script b/tests/stub/authorization/scripts/v5x0/writer_reauth_unhandled.script new file mode 100644 index 000000000..d9e281b36 --- /dev/null +++ b/tests/stub/authorization/scripts/v5x0/writer_reauth_unhandled.script @@ -0,0 +1,24 @@ +!: BOLT #VERSION# + +A: HELLO {"user_agent": "*", "[routing]": "*", "scheme": "basic", "principal": "neo4j", "credentials": "pass", "[realm]": ""} + +*: RESET + +C: RUN "RETURN 1 AS n" "*" "*" +S: SUCCESS {"fields": ["n"]} +C: PULL "*" +S: RECORD [1] + SUCCESS {"type": "w"} + +# reader fails now with #ERROR# +# => this connection to a different host should not be affected +*: RESET + +C: RUN "RETURN 2 AS n" "*" "*" +S: SUCCESS {"fields": ["n"]} +C: PULL "*" +S: RECORD [1] + SUCCESS {"type": "w"} + +*: RESET +?: GOODBYE diff --git a/tests/stub/authorization/scripts/v5x1/reader_reauth_token_expired.script b/tests/stub/authorization/scripts/v5x1/reader_reauth_handled.script similarity index 95% rename from tests/stub/authorization/scripts/v5x1/reader_reauth_token_expired.script rename to tests/stub/authorization/scripts/v5x1/reader_reauth_handled.script index 941e7b9d5..0a2d97906 100644 --- a/tests/stub/authorization/scripts/v5x1/reader_reauth_token_expired.script +++ b/tests/stub/authorization/scripts/v5x1/reader_reauth_handled.script @@ -48,7 +48,7 @@ A: HELLO {"user_agent": "*", "[routing]": "*"} *: RESET C: RUN "RETURN 2.2 AS n" "*" "*" - S: FAILURE {"code": "Neo.ClientError.Security.TokenExpired", "message": "Token expired."} + S: FAILURE #ERROR# S: ---- C: RUN "RETURN 3.1 AS n" "*" "*" diff --git a/tests/stub/authorization/scripts/v5x1/reader_reauth_token_expired_pipelined.script b/tests/stub/authorization/scripts/v5x1/reader_reauth_handled_pipelined.script similarity index 96% rename from tests/stub/authorization/scripts/v5x1/reader_reauth_token_expired_pipelined.script rename to tests/stub/authorization/scripts/v5x1/reader_reauth_handled_pipelined.script index bce9bbf12..f7a7d7dd2 100644 --- a/tests/stub/authorization/scripts/v5x1/reader_reauth_token_expired_pipelined.script +++ b/tests/stub/authorization/scripts/v5x1/reader_reauth_handled_pipelined.script @@ -60,7 +60,7 @@ C: HELLO {"user_agent": "*", "[routing]": "*"} *: RESET C: RUN "RETURN 2.2 AS n" "*" "*" - S: FAILURE {"code": "Neo.ClientError.Security.TokenExpired", "message": "Token expired."} + S: FAILURE #ERROR# S: ---- C: RUN "RETURN 3.1 AS n" "*" "*" diff --git a/tests/stub/authorization/scripts/v5x1/reader_reauth_token_expired_pipelined_minimal.script b/tests/stub/authorization/scripts/v5x1/reader_reauth_handled_pipelined_minimal.script similarity index 96% rename from tests/stub/authorization/scripts/v5x1/reader_reauth_token_expired_pipelined_minimal.script rename to tests/stub/authorization/scripts/v5x1/reader_reauth_handled_pipelined_minimal.script index 44a24b2fd..860270201 100644 --- a/tests/stub/authorization/scripts/v5x1/reader_reauth_token_expired_pipelined_minimal.script +++ b/tests/stub/authorization/scripts/v5x1/reader_reauth_handled_pipelined_minimal.script @@ -60,7 +60,7 @@ C: HELLO {"user_agent": "*", "[routing]": "*"} *: RESET C: RUN "RETURN 2.2 AS n" "*" "*" - S: FAILURE {"code": "Neo.ClientError.Security.TokenExpired", "message": "Token expired."} + S: FAILURE #ERROR# S: ---- C: RUN "RETURN 3.1 AS n" "*" "*" diff --git a/tests/stub/authorization/scripts/v5x1/reader_reauth_unhandled.script b/tests/stub/authorization/scripts/v5x1/reader_reauth_unhandled.script new file mode 100644 index 000000000..3110912ee --- /dev/null +++ b/tests/stub/authorization/scripts/v5x1/reader_reauth_unhandled.script @@ -0,0 +1,81 @@ +!: BOLT #VERSION# +!: ALLOW CONCURRENT + +A: HELLO {"user_agent": "*", "[routing]": "*"} +A: LOGON {"scheme": "basic", "principal": "neo4j", "credentials": "pass", "[realm]": ""} + +*: RESET + +C: BEGIN {"{}": "*"} +S: SUCCESS {} +# Three concurrent connections +{{ + C: RUN "RETURN 1.1 AS n" "*" "*" + S: SUCCESS {"fields": ["n"]} + C: PULL "*" + S: RECORD [1] + SUCCESS {"type": "r"} + C: COMMIT + S: SUCCESS {} + + # now the second connection receives the error + # => this connection should not be affected + *: RESET + + C: BEGIN {"{}": "*"} + S: SUCCESS {} + C: RUN "RETURN 1.2 AS n" "*" "*" + S: SUCCESS {"fields": ["n"]} + C: PULL "*" + S: RECORD [1] + SUCCESS {"type": "r"} + C: COMMIT + S: SUCCESS {} +---- + C: RUN "RETURN 2.1 AS n" "*" "*" + S: SUCCESS {"fields": ["n"]} + C: PULL "*" + S: RECORD [1] + SUCCESS {"type": "r"} + C: COMMIT + S: SUCCESS {} + + *: RESET + + C: RUN "RETURN 2.2 AS n" "*" "*" + S: FAILURE #ERROR# + S: +---- + C: RUN "RETURN 3.1 AS n" "*" "*" + S: SUCCESS {"fields": ["n"]} + C: PULL "*" + S: RECORD [1] + SUCCESS {"type": "r"} + C: COMMIT + S: SUCCESS {} + + # now the second connection receives the error + # => this connection should not be affected + *: RESET + + C: BEGIN {"{}": "*"} + S: SUCCESS {} + C: RUN "RETURN 3.2 AS n" "*" "*" + S: SUCCESS {"fields": ["n"]} + C: PULL "*" + S: RECORD [1] + SUCCESS {"type": "r"} + C: COMMIT + S: SUCCESS {} +---- + C: RUN "RETURN 2.3 AS n" "*" "*" + S: SUCCESS {"fields": ["n"]} + C: PULL "*" + S: RECORD [1] + SUCCESS {"type": "r"} + C: COMMIT + S: SUCCESS {} +}} + +*: RESET +?: GOODBYE diff --git a/tests/stub/authorization/scripts/v5x1/writer_reauth_token_expired.script b/tests/stub/authorization/scripts/v5x1/writer_reauth_handled.script similarity index 100% rename from tests/stub/authorization/scripts/v5x1/writer_reauth_token_expired.script rename to tests/stub/authorization/scripts/v5x1/writer_reauth_handled.script diff --git a/tests/stub/authorization/scripts/v5x1/writer_reauth_token_expired_pipelined.script b/tests/stub/authorization/scripts/v5x1/writer_reauth_handled_pipelined.script similarity index 100% rename from tests/stub/authorization/scripts/v5x1/writer_reauth_token_expired_pipelined.script rename to tests/stub/authorization/scripts/v5x1/writer_reauth_handled_pipelined.script diff --git a/tests/stub/authorization/scripts/v5x1/writer_reauth_token_expired_pipelined_minimal.script b/tests/stub/authorization/scripts/v5x1/writer_reauth_handled_pipelined_minimal.script similarity index 100% rename from tests/stub/authorization/scripts/v5x1/writer_reauth_token_expired_pipelined_minimal.script rename to tests/stub/authorization/scripts/v5x1/writer_reauth_handled_pipelined_minimal.script diff --git a/tests/stub/authorization/scripts/v5x1/writer_reauth_unhandled.script b/tests/stub/authorization/scripts/v5x1/writer_reauth_unhandled.script new file mode 100644 index 000000000..d4f24ff25 --- /dev/null +++ b/tests/stub/authorization/scripts/v5x1/writer_reauth_unhandled.script @@ -0,0 +1,25 @@ +!: BOLT #VERSION# + +A: HELLO {"user_agent": "*", "[routing]": "*"} +A: LOGON {"scheme": "basic", "principal": "neo4j", "credentials": "pass", "[realm]": ""} + +*: RESET + +C: RUN "RETURN 1 AS n" "*" "*" +S: SUCCESS {"fields": ["n"]} +C: PULL "*" +S: RECORD [1] + SUCCESS {"type": "w"} + +# reader fails now with #ERROR# +# => this connection to a different host should not be affected +*: RESET + +C: RUN "RETURN 2 AS n" "*" "*" +S: SUCCESS {"fields": ["n"]} +C: PULL "*" +S: RECORD [1] + SUCCESS {"type": "w"} + +*: RESET +?: GOODBYE diff --git a/tests/stub/authorization/test_auth_token_manager.py b/tests/stub/authorization/test_auth_token_manager.py index b291ec6c2..6ba6983d1 100644 --- a/tests/stub/authorization/test_auth_token_manager.py +++ b/tests/stub/authorization/test_auth_token_manager.py @@ -1,4 +1,6 @@ +import json from contextlib import contextmanager +from dataclasses import dataclass import nutkit.protocol as types from nutkit.frontend import ( @@ -10,13 +12,19 @@ from tests.stub.shared import StubServer +@dataclass(frozen=True) +class HandleSecurityExceptionArgs: + auth: types.AuthorizationToken + error_code: str + + class TrackingAuthTokenManager: def __init__(self, backend): self._backend = backend self._get_auth_count = 0 - self._on_auth_expired_args = [] + self._handle_security_exception_args = [] self._manager = AuthTokenManager( - backend, self.get_auth, self.on_auth_expired + backend, self.get_auth, self.handle_security_exception ) def get_auth(self): @@ -30,20 +38,27 @@ def raw_get_auth(self): credentials="pass" ) - def on_auth_expired(self, auth): - self._on_auth_expired_args.append(auth) + def handle_security_exception( + self, auth: types.AuthorizationToken, code: str + ) -> bool: + args = HandleSecurityExceptionArgs(auth, code) + self._handle_security_exception_args.append(args) + return self._handles_security_exception(code) + + def _handles_security_exception(self, code: str) -> bool: + return False @property def get_auth_count(self): return self._get_auth_count @property - def on_auth_expired_args(self): - return self._on_auth_expired_args + def handle_security_exception_args(self): + return self._handle_security_exception_args @property - def on_auth_expired_count(self): - return len(self._on_auth_expired_args) + def handle_security_exception_count(self): + return len(self._handle_security_exception_args) @property def manager(self): @@ -127,7 +142,7 @@ def _test(routing_): if routing_: self._router.done() self.assertEqual(auth_manager.get_auth_count, 2 if routing_ else 1) - self.assertEqual(auth_manager.on_auth_expired_count, 0) + self.assertEqual(auth_manager.handle_security_exception_count, 0) self.post_script_assertions(self._reader) for routing in (False, True): @@ -142,7 +157,7 @@ def test_dynamic_auth_manager(self): def _test(routing_): get_auth_count = 0 current_auth = ("neo4j", "pass") - on_auth_expired_count = 0 + handle_security_exception_count = 0 def get_auth(): nonlocal get_auth_count @@ -154,12 +169,13 @@ def get_auth(): credentials=password ) - def on_auth_expired(auth_token): - nonlocal on_auth_expired_count - on_auth_expired_count += 1 + def handle_security_exception(auth_token, code): + nonlocal handle_security_exception_count + handle_security_exception_count += 1 + return False - auth_manager = AuthTokenManager(self._backend, - get_auth, on_auth_expired) + auth_manager = AuthTokenManager(self._backend, get_auth, + handle_security_exception) self.start_server( self._reader, @@ -196,7 +212,7 @@ def on_auth_expired(auth_token): if routing_: expected_get_auth_count *= 2 self.assertEqual(get_auth_count, expected_get_auth_count) - self.assertEqual(on_auth_expired_count, 0) + self.assertEqual(handle_security_exception_count, 0) self.post_script_assertions(self._reader) logon_message = "HELLO" if self.backwards_compatible else "LOGON" if routing_: @@ -213,8 +229,30 @@ def on_auth_expired(auth_token): self._reader.reset() self._router.reset() - def _test_notify(self, error, error_assertion, script, session_cb): - def _test(routing_, session_auth_): + def _get_error_assertion(self, error, handled): + def retryable(f, can_retry): + def inner(*args, **kwargs): + kwargs["retryable"] = can_retry + f(*args, **kwargs) + return inner + + if error == self._AUTH_EXPIRED: + return retryable(self.assert_is_authorization_error, handled) + elif error == self._TOKEN_EXPIRED: + return retryable(self.assert_is_token_error, handled) + elif error == self._UNAUTHORIZED: + return retryable(self.assert_is_unauthorized_error, handled) + elif error == self._SECURITY_EXC: + return retryable(self.assert_is_security_error, handled) + elif error == self._TRANSIENT_EXC: + return self.assert_is_transient_error + elif error == self._RANDOM_EXC: + return self.assert_is_random_error + else: + raise ValueError(f"Unknown error: {error}") + + def _test_notify(self, script, session_cb): + def _test(routing_, session_auth_, should_notify_, handled_): if session_auth_: session_auth_ = types.AuthorizationToken( scheme="basic", @@ -223,7 +261,12 @@ def _test(routing_, session_auth_): ) else: session_auth_ = None - manager = TrackingAuthTokenManager(self._backend) + + class AuthManager(TrackingAuthTokenManager): + def _handles_security_exception(self, code: str) -> bool: + return handled_ + + manager = AuthManager(self._backend) if routing_: self.start_server(self._router, "router_single_reader.script") vars_ = self.get_vars() @@ -232,6 +275,9 @@ def _test(routing_, session_auth_): with self.driver(manager.manager, routing=routing_) as driver: with self.session(driver, auth_token=session_auth_) as session: exc = session_cb(session) + error_assertion = self._get_error_assertion( + error, handled_ + ) error_assertion(exc.exception) self._reader.done() if routing_: @@ -243,12 +289,15 @@ def _test(routing_, session_auth_): self.assertEqual(manager.get_auth_count, 2) else: self.assertEqual(manager.get_auth_count, 1) - if error == self._TOKEN_EXPIRED and not session_auth_: - self.assertEqual(manager.on_auth_expired_count, 1) - self.assertEqual(manager.on_auth_expired_args, - [manager.raw_get_auth()]) + if should_notify_: + self.assertEqual(manager.handle_security_exception_count, 1) + self.assertEqual( + manager.handle_security_exception_args, + [HandleSecurityExceptionArgs(manager.raw_get_auth(), + error_code)] + ) else: - self.assertEqual(manager.on_auth_expired_count, 0) + self.assertEqual(manager.handle_security_exception_count, 0) self.post_script_assertions(self._reader) @@ -258,37 +307,49 @@ def _test(routing_, session_auth_): for session_auth in session_auths: for routing in (False, True): - with self.subTest(routing=routing, session_auth=session_auth): - try: - _test(routing, session_auth) - finally: - self._reader.reset() - self._router.reset() - - def _notify_on_failed_pull_using_session_run(self, error, error_assertion): + for error in ( + self._AUTH_EXPIRED, + self._TOKEN_EXPIRED, + self._UNAUTHORIZED, + self._SECURITY_EXC, + self._TRANSIENT_EXC, + self._RANDOM_EXC, + ): + error_code = self._get_error_code(error) + should_notify = ( + error_code.startswith("Neo.ClientError.Security.") + and not session_auth + ) + handles = [False] + if should_notify: + handles.append(True) + for handled in handles: + with self.subTest( + routing=routing, session_auth=session_auth, + error=error_code, handled=handled + ): + try: + _test(routing, session_auth, should_notify, + handled) + finally: + self._reader.reset() + self._router.reset() + + @staticmethod + def _get_error_code(error): + print(error) + return json.loads(error)["code"] + + def test_error_on_pull_using_session_run(self): def session_cb(session): result = session.run("RETURN 1 AS n") with self.assertRaises(types.DriverError) as exc: result.next() return exc - self._test_notify( - error, error_assertion, - "reader_yielding_error_on_pull.script", - session_cb - ) - - def test_not_notify_on_auth_expired_pull_using_session_run(self): - self._notify_on_failed_pull_using_session_run( - self._AUTH_EXPIRED, self.assert_is_authorization_error - ) + self._test_notify("reader_yielding_error_on_pull.script", session_cb) - def test_notify_on_token_expired_pull_using_session_run(self): - self._notify_on_failed_pull_using_session_run( - self._TOKEN_EXPIRED, self.assert_is_retryable_token_error - ) - - def _notify_on_failed_begin_using_tx_run(self, error, error_assertion): + def test_error_on_begin_using_tx_run(self): def session_cb(session): if not self.driver_supports_features( types.Feature.OPT_EAGER_TX_BEGIN @@ -302,26 +363,10 @@ def session_cb(session): session.begin_transaction() return exc - self._test_notify( - error, error_assertion, - "reader_tx_yielding_error_on_begin.script", - session_cb - ) - - def test_not_notify_on_auth_expired_begin_using_tx_run(self): - if get_driver_name() in ["javascript"]: - self.skipTest("Fails on sending RESET after auth-error and " - "surfaces SessionExpired instead.") - self._notify_on_failed_begin_using_tx_run( - self._AUTH_EXPIRED, self.assert_is_authorization_error - ) - - def test_notify_on_token_expired_begin_using_tx_run(self): - self._notify_on_failed_begin_using_tx_run( - self._TOKEN_EXPIRED, self.assert_is_retryable_token_error - ) + self._test_notify("reader_tx_yielding_error_on_begin.script", + session_cb) - def _notify_on_failed_run_using_tx_run(self, error, error_assertion): + def test_error_on_run_using_tx_run(self): def session_cb(session): tx = session.begin_transaction() with self.assertRaises(types.DriverError) as exc: @@ -332,23 +377,9 @@ def session_cb(session): result.consume() return exc - self._test_notify( - error, error_assertion, - "reader_tx_yielding_error_on_run.script", - session_cb - ) - - def test_not_notify_on_auth_expired_run_using_tx_run(self): - self._notify_on_failed_run_using_tx_run( - self._AUTH_EXPIRED, self.assert_is_authorization_error - ) - - def test_notify_on_token_expired_run_using_tx_run(self): - self._notify_on_failed_run_using_tx_run( - self._TOKEN_EXPIRED, self.assert_is_retryable_token_error - ) + self._test_notify("reader_tx_yielding_error_on_run.script", session_cb) - def _notify_on_failed_pull_using_tx_run(self, error, error_assertion): + def test_error_on_pull_using_tx_run(self): def session_cb(session): tx = session.begin_transaction() with self.assertRaises(types.DriverError) as exc: @@ -356,23 +387,10 @@ def session_cb(session): result.next() return exc - self._test_notify( - error, error_assertion, - "reader_tx_yielding_error_on_pull.script", - session_cb - ) + self._test_notify("reader_tx_yielding_error_on_pull.script", + session_cb) - def test_not_notify_on_auth_expired_pull_using_tx_run(self): - self._notify_on_failed_pull_using_tx_run( - self._AUTH_EXPIRED, self.assert_is_authorization_error - ) - - def test_notify_on_token_expired_pull_using_tx_run(self): - self._notify_on_failed_pull_using_tx_run( - self._TOKEN_EXPIRED, self.assert_is_retryable_token_error - ) - - def _notify_on_failed_commit_using_tx_run(self, error, error_assertion): + def test_error_on_commit_using_tx_run(self): def session_cb(session): tx = session.begin_transaction() tx.run("RETURN 1 AS n") @@ -381,22 +399,11 @@ def session_cb(session): return exc self._test_notify( - error, error_assertion, "reader_tx_yielding_error_on_commit_with_pull_or_discard.script", session_cb ) - def test_not_notify_on_auth_expired_commit_using_tx_run(self): - self._notify_on_failed_commit_using_tx_run( - self._AUTH_EXPIRED, self.assert_is_authorization_error - ) - - def test_notify_on_token_expired_commit_using_tx_run(self): - self._notify_on_failed_commit_using_tx_run( - self._TOKEN_EXPIRED, self.assert_is_retryable_token_error - ) - - def _notify_on_failed_rollback_using_tx_run(self, error, error_assertion): + def test_error_on_rollback_using_tx_run(self): def session_cb(session): tx = session.begin_transaction() tx.run("RETURN 1 AS n") @@ -405,21 +412,10 @@ def session_cb(session): return exc self._test_notify( - error, error_assertion, "reader_tx_yielding_error_on_rollback_with_pull_or_discard.script", session_cb ) - def test_not_notify_on_auth_expired_rollback_using_tx_run(self): - self._notify_on_failed_rollback_using_tx_run( - self._AUTH_EXPIRED, self.assert_is_authorization_error - ) - - def test_notify_on_token_expired_rollback_using_tx_run(self): - self._notify_on_failed_rollback_using_tx_run( - self._TOKEN_EXPIRED, self.assert_is_retryable_token_error - ) - class TestAuthTokenManager5x0(TestAuthTokenManager5x1): @@ -436,38 +432,20 @@ def test_static_auth_manager(self): def test_dynamic_auth_manager(self): super().test_dynamic_auth_manager() - def test_not_notify_on_auth_expired_pull_using_session_run(self): - super().test_not_notify_on_auth_expired_pull_using_session_run() - - def test_notify_on_token_expired_pull_using_session_run(self): - super().test_notify_on_token_expired_pull_using_session_run() - - def test_not_notify_on_auth_expired_begin_using_tx_run(self): - super().test_not_notify_on_auth_expired_begin_using_tx_run() - - def test_notify_on_token_expired_begin_using_tx_run(self): - super().test_notify_on_token_expired_begin_using_tx_run() - - def test_not_notify_on_auth_expired_run_using_tx_run(self): - super().test_not_notify_on_auth_expired_run_using_tx_run() - - def test_notify_on_token_expired_run_using_tx_run(self): - super().test_notify_on_token_expired_run_using_tx_run() - - def test_not_notify_on_auth_expired_pull_using_tx_run(self): - super().test_not_notify_on_auth_expired_pull_using_tx_run() + def test_error_on_pull_using_session_run(self): + super().test_error_on_pull_using_session_run() - def test_notify_on_token_expired_pull_using_tx_run(self): - super().test_notify_on_token_expired_pull_using_tx_run() + def test_error_on_begin_using_tx_run(self): + super().test_error_on_begin_using_tx_run() - def test_not_notify_on_auth_expired_commit_using_tx_run(self): - super().test_not_notify_on_auth_expired_commit_using_tx_run() + def test_error_on_run_using_tx_run(self): + super().test_error_on_run_using_tx_run() - def test_notify_on_token_expired_commit_using_tx_run(self): - super().test_notify_on_token_expired_commit_using_tx_run() + def test_error_on_pull_using_tx_run(self): + super().test_error_on_pull_using_tx_run() - def test_not_notify_on_auth_expired_rollback_using_tx_run(self): - super().test_not_notify_on_auth_expired_rollback_using_tx_run() + def test_error_on_commit_using_tx_run(self): + super().test_error_on_commit_using_tx_run() - def test_notify_on_token_expired_rollback_using_tx_run(self): - super().test_notify_on_token_expired_rollback_using_tx_run() + def test_error_on_rollback_using_tx_run(self): + super().test_error_on_rollback_using_tx_run() diff --git a/tests/stub/authorization/test_authorization.py b/tests/stub/authorization/test_authorization.py index 63684bb97..946c62730 100644 --- a/tests/stub/authorization/test_authorization.py +++ b/tests/stub/authorization/test_authorization.py @@ -15,10 +15,14 @@ class AuthorizationBase(TestkitTestCase): # While there is no unified language agnostic error type mapping, a # dedicated driver mapping is required to determine if the expected # error is returned. - def assert_is_authorization_error(self, error): - driver = get_driver_name() + def assert_is_authorization_error(self, error, retryable=False): self.assertEqual("Neo.ClientError.Security.AuthorizationExpired", error.code) + + if retryable: + return self._assert_is_retryable_authorization_error(error) + + driver = get_driver_name() expected_type = None if driver in ["java"]: expected_type = \ @@ -38,12 +42,28 @@ def assert_is_authorization_error(self, error): self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: self.assertEqual(expected_type, error.errorType) + # This exception should always be considered retryable + self._assert_retryable(error) - def assert_is_token_error(self, error): + def _assert_is_retryable_authorization_error(self, error): driver = get_driver_name() + expected_type = None + if driver in ["python"]: + expected_type = "" + else: + self.fail("no error mapping is defined for %s driver" % driver) + if expected_type is not None: + self.assertEqual(expected_type, error.errorType) + self._assert_retryable(error) + + def assert_is_token_error(self, error, retryable=False): self.assertEqual("Neo.ClientError.Security.TokenExpired", error.code) self.assertIn("Token expired", error.msg) + if retryable: + return self._assert_is_retryable_token_error(error) + + driver = get_driver_name() expected_type = None if driver in ["python"]: expected_type = "" @@ -63,29 +83,106 @@ def assert_is_token_error(self, error): self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: self.assertEqual(expected_type, error.errorType) + self._assert_not_retryable(error) - def assert_is_retryable_token_error(self, error): + def _assert_is_retryable_token_error(self, error): driver = get_driver_name() - self.assertEqual("Neo.ClientError.Security.TokenExpired", error.code) - self.assertIn("Token expired", error.msg) + expected_type = None + if driver in ["python"]: + expected_type = "" + else: + self.fail("no error mapping is defined for %s driver" % driver) + if expected_type is not None: + self.assertEqual(expected_type, error.errorType) + self._assert_retryable(error) + + def assert_is_unauthorized_error(self, error, retryable=False): + self.assertEqual( + "Neo.ClientError.Security.Unauthorized", error.code + ) + self.assertIn("Wrong credentials. Kthxbye!", error.msg) + + if retryable: + return self._assert_is_retryable_unauthorized_error(error) + driver = get_driver_name() expected_type = None if driver in ["python"]: - expected_type = "" - elif driver in ["go", "javascript"]: - pass # code and msg check are enough - elif driver in ["dotnet"]: - expected_type = "ClientError" - elif driver == "java": - expected_type = \ - "org.neo4j.driver.exceptions.TokenExpiredRetryableException" - elif driver == "ruby": - expected_type = \ - "Neo4j::Driver::Exceptions::TokenExpiredRetryableException" + expected_type = "" + else: + self.fail("no error mapping is defined for %s driver" % driver) + if expected_type is not None: + self.assertEqual(expected_type, error.errorType) + self._assert_not_retryable(error) + + def _assert_is_retryable_unauthorized_error(self, error): + driver = get_driver_name() + expected_type = None + if driver in ["python"]: + expected_type = "" + else: + self.fail("no error mapping is defined for %s driver" % driver) + if expected_type is not None: + self.assertEqual(expected_type, error.errorType) + self._assert_retryable(error) + + def assert_is_security_error(self, error, retryable=False): + self.assertEqual("Neo.ClientError.Security.MadeUp", error.code) + self.assertIn(r"Some security issue ¯\_(ツ)_/¯", error.msg) + + if retryable: + return self._assert_is_retryable_security_error(error) + + driver = get_driver_name() + expected_type = None + if driver in ["python"]: + expected_type = "" + else: + self.fail("no error mapping is defined for %s driver" % driver) + if expected_type is not None: + self.assertEqual(expected_type, error.errorType) + self._assert_not_retryable(error) + + def _assert_is_retryable_security_error(self, error): + driver = get_driver_name() + expected_type = None + if driver in ["python"]: + expected_type = "" + else: + self.fail("no error mapping is defined for %s driver" % driver) + if expected_type is not None: + self.assertEqual(expected_type, error.errorType) + self._assert_retryable(error) + + def assert_is_transient_error(self, error): + self.assertEqual("Neo.TransientError.General.TransactionMemoryLimit", + error.code) + self.assertIn(r"RAM sticks", error.msg) + + driver = get_driver_name() + expected_type = None + if driver in ["python"]: + expected_type = "" else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: self.assertEqual(expected_type, error.errorType) + self._assert_retryable(error) + + def assert_is_random_error(self, error): + self.assertEqual("Neo.ClientError.Procedure.ProcedureCallFailed", + error.code) + self.assertIn(r"A thing happened...", error.msg) + + driver = get_driver_name() + expected_type = None + if driver in ["python"]: + expected_type = "" + else: + self.fail("no error mapping is defined for %s driver" % driver) + if expected_type is not None: + self.assertEqual(expected_type, error.errorType) + self._assert_not_retryable(error) def assert_re_auth_unsupported_error(self, error): self.assertIsInstance(error, types.DriverError) @@ -128,6 +225,18 @@ def assert_re_auth_unsupported_error(self, error): else: self.fail("no error mapping is defined for %s driver" % driver) + def _assert_not_retryable(self, error): + if self.driver_supports_features( + types.Feature.API_RETRYABLE_EXCEPTION + ): + self.assertFalse(error.retryable) + + def _assert_retryable(self, error): + if self.driver_supports_features( + types.Feature.API_RETRYABLE_EXCEPTION + ): + self.assertTrue(error.retryable) + def _find_version_script(self, script_fns): if isinstance(script_fns, str): script_fns = [script_fns] @@ -206,6 +315,18 @@ def get_vars(self): ) _TOKEN_EXPIRED = ('{"code": "Neo.ClientError.Security.TokenExpired", ' '"message": "Token expired"}') + _UNAUTHORIZED = ('{"code": "Neo.ClientError.Security.Unauthorized", ' + '"message": "Wrong credentials. Kthxbye!"}') + _SECURITY_EXC = ('{"code": "Neo.ClientError.Security.MadeUp", ' + r'"message": "Some security issue ¯\\_(ツ)_/¯"}') + _TRANSIENT_EXC = ( + '{"code": "Neo.TransientError.General.TransactionMemoryLimit", ' + '"message": "These are not the RAM sticks you\'re looking for!"}' + ) + _RANDOM_EXC = ( + '{"code": "Neo.ClientError.Procedure.ProcedureCallFailed", ' + '"message": "A thing happened..."}' + ) # TODO: find a way to test that driver ditches all open connection in the pool diff --git a/tests/stub/authorization/test_basic_auth_manager.py b/tests/stub/authorization/test_basic_auth_manager.py new file mode 100644 index 000000000..3f6ad3481 --- /dev/null +++ b/tests/stub/authorization/test_basic_auth_manager.py @@ -0,0 +1,298 @@ +from contextlib import contextmanager + +import nutkit.protocol as types +from nutkit.frontend import ( + BasicAuthTokenManager, + Driver, + FakeTime, +) +from tests.shared import driver_feature +from tests.stub.authorization.test_authorization import AuthorizationBase +from tests.stub.shared import StubServer + + +class TestBearerAuthManager5x1(AuthorizationBase): + + required_features = (types.Feature.BOLT_5_1, + types.Feature.AUTH_MANAGED) + + def setUp(self): + super().setUp() + self._router = StubServer(9000) + self._reader = StubServer(9010) + self._writer = StubServer(9020) + self._uri = "bolt://%s:%d" % (self._reader.host, + self._reader.port) + self._driver = None + + def tearDown(self): + self._router.reset() + self._reader.reset() + self._writer.reset() + if self._driver: + self._driver.close() + super().tearDown() + + def get_vars(self): + host = self._router.host + return { + "#VERSION#": "5.1", + "#HOST#": host, + "#ROUTINGCTX#": '{"address": "' + host + ':9000"}' + } + + @contextmanager + def driver(self, auth, routing=False, **kwargs): + if routing: + uri = f"neo4j://{self._router.address}" + else: + uri = f"bolt://{self._reader.address}" + driver = Driver(self._backend, uri, auth, **kwargs) + try: + yield driver + finally: + driver.close() + + @contextmanager + def session(self, driver, access_mode="r"): + session = driver.session(access_mode, database="adb") + try: + yield session + finally: + session.close() + + def post_script_assertions(self, server): + # add OPT_MINIMAL_RESETS assertion (if driver claims to support it) + if self.driver_supports_features(types.Feature.OPT_MINIMAL_RESETS): + self.assertEqual(server.count_requests("RESET"), 0) + + def test_static_provider(self): + count = 0 + + def provider(): + nonlocal count + count += 1 + return types.AuthorizationToken( + scheme="basic", + principal="neo4j", + credentials="pass" + ) + + auth_manager = BasicAuthTokenManager(self._backend, provider) + + self.start_server( + self._reader, + self.script_fn_with_features("reader_no_reauth.script") + ) + with self.driver(auth_manager) as driver: + with self.session(driver) as session: + list(session.run("RETURN 1 AS n")) + + self._reader.done() + self.assertEqual(count, 1) + self.post_script_assertions(self._reader) + + @driver_feature(types.Feature.BACKEND_MOCK_TIME) + def test_static_provider_long_time(self): + count = 0 + + def provider(): + nonlocal count + count += 1 + return types.AuthorizationToken( + scheme="basic", + principal="neo4j", + credentials="pass" + ) + + auth_manager = BasicAuthTokenManager(self._backend, provider) + + self.start_server( + self._reader, + self.script_fn_with_features("reader_no_reauth.script") + ) + + with FakeTime(self._backend) as time: + with self.driver(auth_manager) as driver: + with self.session(driver) as session: + list(session.run("RETURN 1 AS n")) + # just under 1 hour to make sure to not trip over the + # connection max lifetime + time.tick(1000 * 3600 - 1) + with self.session(driver) as session: + # should still use the same token without calling the + # provider + list(session.run("RETURN 1 AS n")) + + self._reader.done() + self.assertEqual(count, 1) + self.post_script_assertions(self._reader) + + def test_renewing_on_expiration_error(self): + def _test(error_, routing_): + count = 0 + + def provider(): + nonlocal count + count += 1 + credentials = "pass++" if count > 1 else "pass" + principal = "neo5j" if count > 1 else "neo4j" + + return types.AuthorizationToken( + scheme="basic", + principal=principal, + credentials=credentials + ) + + auth_manager = BasicAuthTokenManager(self._backend, provider) + + if error_ in ("authorization_expired",): + reader_script = self.script_fn_with_features( + f"reader_reauth_{error_}.script" + ) + writer_script = self.script_fn_with_features( + f"writer_reauth_{error_}.script" + ) + vars_ = None + expected_call_count = 1 + elif error_ in ("unauthorized",): + reader_script = self.script_fn_with_features( + "reader_reauth_handled.script" + ) + writer_script = self.script_fn_with_features( + "writer_reauth_handled.script" + ) + vars_ = self.get_vars() + expected_call_count = 2 + if error_ == "token_expired": + vars_["#ERROR#"] = self._TOKEN_EXPIRED + elif error_ == "unauthorized": + vars_["#ERROR#"] = self._UNAUTHORIZED + else: + reader_script = "reader_reauth_unhandled.script" + writer_script = "writer_reauth_unhandled.script" + expected_call_count = 1 + vars_ = self.get_vars() + if error_ == "security": + vars_["#ERROR#"] = self._SECURITY_EXC + elif error_ == "token_expired": + vars_["#ERROR#"] = self._TOKEN_EXPIRED + else: + self.fail(f"Unknown error type {error_}") + self.start_server(self._reader, reader_script, vars_=vars_) + if routing_: + self.start_server(self._writer, writer_script, vars_=vars_) + self.start_server(self._router, "router_single_reader.script") + + with self.driver(auth_manager, routing=routing_, + max_connection_pool_size=3) as driver: + if routing_: + with self.session(driver, "w") as session_w: + list(session_w.run("RETURN 1 AS n")) + + with self.session(driver) as session_r1: + with self.session(driver) as session_r2: + with self.session(driver) as session_r3: + # bind connection 1 + s1_tx = session_r1.begin_transaction() + list(s1_tx.run("RETURN 1.1 AS n")) + + self.assertEqual(1, count) + + # bind connection 2 + s2_tx = session_r2.begin_transaction() + list(s2_tx.run("RETURN 2.1 AS n")) + + # bind connection 3 + s3_tx = session_r3.begin_transaction() + list(s3_tx.run("RETURN 3.1 AS n")) + + s2_tx.commit() + + self.assertEqual(1, count) + + with self.assertRaises(types.DriverError) as exc: + # connection 2 fails, gets closed + list(session_r2.run("RETURN 2.2 AS n")) + if error_ == "token_expired": + self.assert_is_token_error( + exc.exception + ) + elif error_ == "authorization_expired": + self.assert_is_authorization_error( + exc.exception + ) + elif error == "unauthorized": + self.assert_is_unauthorized_error( + exc.exception, retryable=True + ) + elif error == "security": + self.assert_is_security_error( + exc.exception + ) + else: + raise ValueError(f"Unknown error {error_}") + + # bind connection 2 + s2_tx = session_r2.begin_transaction() + self.assertEqual(expected_call_count, count) + list(s2_tx.run("RETURN 2.3 AS n")) + + # free connection 1 + s1_tx.commit() + # bind connection 1 + s1_tx = session_r1.begin_transaction() + list(s1_tx.run("RETURN 1.2 AS n")) + + # free connection 3 + s3_tx.commit() + # bind connection 3 + s3_tx = session_r3.begin_transaction() + list(s3_tx.run("RETURN 3.2 AS n")) + + # free all connections + s3_tx.commit() + s1_tx.commit() + s2_tx.commit() + + if routing_: + with self.session(driver, "w") as session_w: + list(session_w.run("RETURN 2 AS n")) + + self.assertEqual(expected_call_count, count) + self._reader.done() + self.post_script_assertions(self._reader) + if routing_: + self._writer.done() + self.post_script_assertions(self._writer) + self._router.done() + self.post_script_assertions(self._router) + + for error in ("authorization_expired", "token_expired", + "unauthorized", "security"): + for routing in (False, True): + with self.subTest(error=error, routing=routing): + try: + _test(error, routing) + finally: + self._reader.reset() + self._writer.reset() + self._router.reset() + + +class TestBearerAuthManager5x0(TestBearerAuthManager5x1): + + required_features = (types.Feature.BOLT_5_0, + types.Feature.AUTH_MANAGED) + + def get_vars(self): + return {**super().get_vars(), "#VERSION#": "5.0"} + + def test_static_provider(self): + super().test_static_provider() + + def test_static_provider_long_time(self): + super().test_static_provider_long_time() + + def test_renewing_on_expiration_error(self): + super().test_renewing_on_expiration_error() diff --git a/tests/stub/authorization/test_expiration_based_auth_manager.py b/tests/stub/authorization/test_bearer_auth_manager.py similarity index 80% rename from tests/stub/authorization/test_expiration_based_auth_manager.py rename to tests/stub/authorization/test_bearer_auth_manager.py index 5c9b43aa0..7f17a08f6 100644 --- a/tests/stub/authorization/test_expiration_based_auth_manager.py +++ b/tests/stub/authorization/test_bearer_auth_manager.py @@ -2,8 +2,8 @@ import nutkit.protocol as types from nutkit.frontend import ( + BearerAuthTokenManager, Driver, - ExpirationBasedAuthTokenManager, FakeTime, ) from tests.shared import driver_feature @@ -11,7 +11,7 @@ from tests.stub.shared import StubServer -class TestExpirationBasedAuthManager5x1(AuthorizationBase): +class TestBearerAuthManager5x1(AuthorizationBase): required_features = (types.Feature.BOLT_5_1, types.Feature.AUTH_MANAGED) @@ -80,7 +80,7 @@ def provider(): ) ) - auth_manager = ExpirationBasedAuthTokenManager(self._backend, provider) + auth_manager = BearerAuthTokenManager(self._backend, provider) self.start_server( self._reader, @@ -109,7 +109,7 @@ def provider(): ) ) - auth_manager = ExpirationBasedAuthTokenManager(self._backend, provider) + auth_manager = BearerAuthTokenManager(self._backend, provider) self.start_server( self._reader, @@ -151,7 +151,7 @@ def provider(): 10_000 ) - auth_manager = ExpirationBasedAuthTokenManager(self._backend, provider) + auth_manager = BearerAuthTokenManager(self._backend, provider) self.start_server(self._reader, self.script_fn_with_features("reader_reauth.script")) @@ -199,22 +199,42 @@ def provider(): 10_000 ) - auth_manager = ExpirationBasedAuthTokenManager(self._backend, - provider) + auth_manager = BearerAuthTokenManager(self._backend, provider) - expected_call_count = 2 if error_ == "token_expired" else 1 - - self.start_server( - self._reader, - self.script_fn_with_features(f"reader_reauth_{error_}.script") - ) - if routing_: - self.start_server( - self._writer, - self.script_fn_with_features( - f"writer_reauth_{error_}.script" - ) + if error_ in ("authorization_expired",): + reader_script = self.script_fn_with_features( + f"reader_reauth_{error_}.script" + ) + writer_script = self.script_fn_with_features( + f"writer_reauth_{error_}.script" ) + vars_ = None + expected_call_count = 1 + elif error_ in ("token_expired", "unauthorized"): + reader_script = self.script_fn_with_features( + "reader_reauth_handled.script" + ) + writer_script = self.script_fn_with_features( + "writer_reauth_handled.script" + ) + vars_ = self.get_vars() + expected_call_count = 2 + if error_ == "token_expired": + vars_["#ERROR#"] = self._TOKEN_EXPIRED + elif error_ == "unauthorized": + vars_["#ERROR#"] = self._UNAUTHORIZED + else: + reader_script = "reader_reauth_unhandled.script" + writer_script = "writer_reauth_unhandled.script" + expected_call_count = 1 + vars_ = self.get_vars() + if error_ == "security": + vars_["#ERROR#"] = self._SECURITY_EXC + else: + self.fail(f"Unknown error type {error_}") + self.start_server(self._reader, reader_script, vars_=vars_) + if routing_: + self.start_server(self._writer, writer_script, vars_=vars_) self.start_server(self._router, "router_single_reader.script") with self.driver(auth_manager, routing=routing_, @@ -248,13 +268,21 @@ def provider(): # connection 2 fails, gets closed list(session_r2.run("RETURN 2.2 AS n")) if error_ == "token_expired": - self.assert_is_retryable_token_error( - exc.exception + self.assert_is_token_error( + exc.exception, retryable=True ) elif error_ == "authorization_expired": self.assert_is_authorization_error( exc.exception ) + elif error == "unauthorized": + self.assert_is_unauthorized_error( + exc.exception, retryable=True + ) + elif error == "security": + self.assert_is_security_error( + exc.exception + ) else: raise ValueError(f"Unknown error {error_}") @@ -293,7 +321,8 @@ def provider(): self._router.done() self.post_script_assertions(self._router) - for error in ("authorization_expired", "token_expired"): + for error in ("authorization_expired", "token_expired", + "unauthorized", "security"): for routing in (False, True): with self.subTest(error=error, routing=routing): try: @@ -303,12 +332,8 @@ def provider(): self._writer.reset() self._router.reset() - @driver_feature(types.Feature.API_DRIVER_SUPPORTS_SESSION_AUTH) - def test_not_renewing_on_user_switch_expiration_error(self): - ... - -class TestExpirationBasedAuthManager5x0(TestExpirationBasedAuthManager5x1): +class TestBearerAuthManager5x0(TestBearerAuthManager5x1): required_features = (types.Feature.BOLT_5_0, types.Feature.AUTH_MANAGED) @@ -327,6 +352,3 @@ def test_expiring_token_deadline(self): def test_renewing_on_expiration_error(self): super().test_renewing_on_expiration_error() - - def test_not_renewing_on_user_switch_expiration_error(self): - super().test_not_renewing_on_user_switch_expiration_error() diff --git a/tests/stub/authorization/token_expired_retry/test_token_expired_retry.py b/tests/stub/authorization/token_expired_retry/test_token_expired_retry.py index 7a52af7c0..89478a209 100644 --- a/tests/stub/authorization/token_expired_retry/test_token_expired_retry.py +++ b/tests/stub/authorization/token_expired_retry/test_token_expired_retry.py @@ -2,9 +2,9 @@ import nutkit.protocol as types from nutkit.frontend import ( - AuthTokenManager, + BasicAuthTokenManager, + BearerAuthTokenManager, Driver, - ExpirationBasedAuthTokenManager, ) from tests.shared import driver_feature from tests.stub.authorization.test_authorization import AuthorizationBase @@ -65,7 +65,7 @@ def work(tx): with self.assertRaises(types.DriverError) as exc: res = tx.run("RETURN 1 AS n") list(res) - self.assert_is_retryable_token_error(exc.exception) + self.assert_is_token_error(exc.exception, retryable=True) raise exc.exception else: res = tx.run(f"RETURN {count} AS n") @@ -116,54 +116,42 @@ def test_no_retry_with_static_token(self): self._router.reset() @driver_feature(types.Feature.AUTH_MANAGED) - def test_retry_with_temporal_token(self): - count = 0 - + def test_no_retry_with_basic_manager(self): def provider(): nonlocal count count += 1 if count == 1: - return types.AuthTokenAndExpiration(self._auth1, None) - return types.AuthTokenAndExpiration(self._auth2, None) + return self._auth1 + return self._auth2 for routing in (True, False): with self.subTest(routing=routing): - auth = ExpirationBasedAuthTokenManager(self._backend, provider) - self._test_retry(auth, routing=routing) - self.assertEqual(count, 2) + count = 0 + auth = BasicAuthTokenManager(self._backend, provider) + self._test_no_retry(auth, routing=routing) + self.assertEqual(count, 1) self._reader.reset() self._router.reset() - count = 0 @driver_feature(types.Feature.AUTH_MANAGED) - def test_retry_with_static_token(self): - expired_count = 0 - get_count = 0 - - def get_auth(): - nonlocal get_count - nonlocal expired_count - get_count += 1 - if expired_count == 0: - return self._auth1 - return self._auth2 + def test_retry_with_bearer_manager(self): + count = 0 - def on_auth_expired(auth_): - nonlocal expired_count - expired_count += 1 - assert auth_ == self._auth1 + def provider(): + nonlocal count + count += 1 + if count == 1: + return types.AuthTokenAndExpiration(self._auth1, None) + return types.AuthTokenAndExpiration(self._auth2, None) for routing in (True, False): with self.subTest(routing=routing): - auth = AuthTokenManager(self._backend, get_auth, - on_auth_expired) + auth = BearerAuthTokenManager(self._backend, provider) self._test_retry(auth, routing=routing) - self.assertEqual(expired_count, 1) - self.assertEqual(get_count, 3 if routing else 2) + self.assertEqual(count, 2) self._reader.reset() self._router.reset() - expired_count = 0 - get_count = 0 + count = 0 class TestTokenExpiredRetryV5x0(TestTokenExpiredRetryV5x1): @@ -176,8 +164,8 @@ def get_vars(self): def test_no_retry_with_static_token(self): super().test_no_retry_with_static_token() - def test_retry_with_temporal_token(self): - super().test_retry_with_temporal_token() + def test_no_retry_with_basic_manager(self): + super().test_no_retry_with_basic_manager() - def test_retry_with_static_token(self): - super().test_retry_with_static_token() + def test_retry_with_bearer_manager(self): + super().test_retry_with_bearer_manager() From 1ed1c150745afd8caaa743dc5dc53cc7fdb42d41 Mon Sep 17 00:00:00 2001 From: Antonio Barcelos Date: Fri, 11 Aug 2023 14:01:47 +0200 Subject: [PATCH 2/7] Add error mapping for javascript --- tests/stub/authorization/test_authorization.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/stub/authorization/test_authorization.py b/tests/stub/authorization/test_authorization.py index 946c62730..caedabc1d 100644 --- a/tests/stub/authorization/test_authorization.py +++ b/tests/stub/authorization/test_authorization.py @@ -50,6 +50,8 @@ def _assert_is_retryable_authorization_error(self, error): expected_type = None if driver in ["python"]: expected_type = "" + elif driver in ["javascript"]: + pass else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -90,6 +92,8 @@ def _assert_is_retryable_token_error(self, error): expected_type = None if driver in ["python"]: expected_type = "" + elif driver in ["javascript"]: + pass else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -109,6 +113,8 @@ def assert_is_unauthorized_error(self, error, retryable=False): expected_type = None if driver in ["python"]: expected_type = "" + elif driver in ["javascript"]: + pass else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -120,6 +126,8 @@ def _assert_is_retryable_unauthorized_error(self, error): expected_type = None if driver in ["python"]: expected_type = "" + elif driver in ["javascript"]: + pass else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -137,6 +145,8 @@ def assert_is_security_error(self, error, retryable=False): expected_type = None if driver in ["python"]: expected_type = "" + elif driver in ["javascript"]: + pass else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -148,6 +158,8 @@ def _assert_is_retryable_security_error(self, error): expected_type = None if driver in ["python"]: expected_type = "" + elif driver in ["javascript"]: + pass else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -163,6 +175,8 @@ def assert_is_transient_error(self, error): expected_type = None if driver in ["python"]: expected_type = "" + elif driver in ["javascript"]: + pass else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -178,6 +192,8 @@ def assert_is_random_error(self, error): expected_type = None if driver in ["python"]: expected_type = "" + elif driver in ["javascript"]: + pass else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: From 8cc105d463ce104b2afa0a88a0f0e7d7748566c3 Mon Sep 17 00:00:00 2001 From: Antonio Barcelos Date: Fri, 11 Aug 2023 14:34:08 +0200 Subject: [PATCH 3/7] Fix basic names --- tests/stub/authorization/test_basic_auth_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/stub/authorization/test_basic_auth_manager.py b/tests/stub/authorization/test_basic_auth_manager.py index 3f6ad3481..70612985d 100644 --- a/tests/stub/authorization/test_basic_auth_manager.py +++ b/tests/stub/authorization/test_basic_auth_manager.py @@ -11,7 +11,7 @@ from tests.stub.shared import StubServer -class TestBearerAuthManager5x1(AuthorizationBase): +class TestBasicAuthManager5x1(AuthorizationBase): required_features = (types.Feature.BOLT_5_1, types.Feature.AUTH_MANAGED) @@ -280,7 +280,7 @@ def provider(): self._router.reset() -class TestBearerAuthManager5x0(TestBearerAuthManager5x1): +class TestBasicAuthManager5x0(TestBasicAuthManager5x1): required_features = (types.Feature.BOLT_5_0, types.Feature.AUTH_MANAGED) From 8b764e88accac43b7dc43930d6e8b144a8bb47cb Mon Sep 17 00:00:00 2001 From: Antonio Barcelos Date: Fri, 11 Aug 2023 17:50:48 +0200 Subject: [PATCH 4/7] Add test checking if bearer and basic managers handles unknown auth This scenario is important to make sure it retries on previous auths without have to save it on state. --- ...reader_reauth_authorization_expired.script | 5 + ...eauth_authorization_expired_minimal.script | 5 +- .../scripts/v5x0/reader_reauth_handled.script | 5 + .../v5x0/reader_reauth_handled_minimal.script | 5 + .../scripts/v5x0/reader_reauth_minimal.script | 5 + .../v5x0/reader_reauth_unhandled.script | 5 + ...reader_reauth_authorization_expired.script | 5 + ...uth_authorization_expired_pipelined.script | 5 + ...orization_expired_pipelined_minimal.script | 5 + .../scripts/v5x1/reader_reauth_handled.script | 5 + .../reader_reauth_handled_pipelined.script | 4 + ...er_reauth_handled_pipelined_minimal.script | 5 + .../v5x1/reader_reauth_unhandled.script | 5 + .../authorization/test_basic_auth_manager.py | 134 +++++++++++++++++ .../authorization/test_bearer_auth_manager.py | 137 ++++++++++++++++++ 15 files changed, 334 insertions(+), 1 deletion(-) diff --git a/tests/stub/authorization/scripts/v5x0/reader_reauth_authorization_expired.script b/tests/stub/authorization/scripts/v5x0/reader_reauth_authorization_expired.script index c0727f53d..dea68e75a 100644 --- a/tests/stub/authorization/scripts/v5x0/reader_reauth_authorization_expired.script +++ b/tests/stub/authorization/scripts/v5x0/reader_reauth_authorization_expired.script @@ -38,6 +38,11 @@ S: SUCCESS {} *: RESET + {? + C: BEGIN {"{}": "*"} + S: SUCCESS {} + ?} + C: RUN "RETURN 2.2 AS n" "*" "*" S: FAILURE {"code": "Neo.ClientError.Security.AuthorizationExpired", "message": "Authorization expired."} S: diff --git a/tests/stub/authorization/scripts/v5x0/reader_reauth_authorization_expired_minimal.script b/tests/stub/authorization/scripts/v5x0/reader_reauth_authorization_expired_minimal.script index f1012d150..e8b18c8e0 100644 --- a/tests/stub/authorization/scripts/v5x0/reader_reauth_authorization_expired_minimal.script +++ b/tests/stub/authorization/scripts/v5x0/reader_reauth_authorization_expired_minimal.script @@ -37,7 +37,10 @@ S: SUCCESS {} S: SUCCESS {} *: RESET - + {? + C: BEGIN {"{}": "*"} + S: SUCCESS {} + ?} C: RUN "RETURN 2.2 AS n" "*" "*" S: FAILURE {"code": "Neo.ClientError.Security.AuthorizationExpired", "message": "Authorization expired."} S: diff --git a/tests/stub/authorization/scripts/v5x0/reader_reauth_handled.script b/tests/stub/authorization/scripts/v5x0/reader_reauth_handled.script index 315608e8f..07210cd9c 100644 --- a/tests/stub/authorization/scripts/v5x0/reader_reauth_handled.script +++ b/tests/stub/authorization/scripts/v5x0/reader_reauth_handled.script @@ -30,6 +30,11 @@ *: RESET + {? + C: BEGIN {"{}": "*"} + S: SUCCESS {} + ?} + C: RUN "RETURN 2.2 AS n" "*" "*" S: FAILURE #ERROR# S: diff --git a/tests/stub/authorization/scripts/v5x0/reader_reauth_handled_minimal.script b/tests/stub/authorization/scripts/v5x0/reader_reauth_handled_minimal.script index 81a95939f..19a61791f 100644 --- a/tests/stub/authorization/scripts/v5x0/reader_reauth_handled_minimal.script +++ b/tests/stub/authorization/scripts/v5x0/reader_reauth_handled_minimal.script @@ -30,6 +30,11 @@ *: RESET + {? + C: BEGIN {"{}": "*"} + S: SUCCESS {} + ?} + C: RUN "RETURN 2.2 AS n" "*" "*" S: FAILURE #ERROR# S: diff --git a/tests/stub/authorization/scripts/v5x0/reader_reauth_minimal.script b/tests/stub/authorization/scripts/v5x0/reader_reauth_minimal.script index a6b946b96..775bc3e1b 100644 --- a/tests/stub/authorization/scripts/v5x0/reader_reauth_minimal.script +++ b/tests/stub/authorization/scripts/v5x0/reader_reauth_minimal.script @@ -14,6 +14,11 @@ *: RESET + {? + C: BEGIN {"{}": "*"} + S: SUCCESS {} + ?} + C: RUN "RETURN 2 AS n" "*" "*" S: SUCCESS {"fields": ["n"]} C: PULL "*" diff --git a/tests/stub/authorization/scripts/v5x0/reader_reauth_unhandled.script b/tests/stub/authorization/scripts/v5x0/reader_reauth_unhandled.script index 6dc4669b1..70b22dcf2 100644 --- a/tests/stub/authorization/scripts/v5x0/reader_reauth_unhandled.script +++ b/tests/stub/authorization/scripts/v5x0/reader_reauth_unhandled.script @@ -41,6 +41,11 @@ S: SUCCESS {} *: RESET + {? + C: BEGIN {"{}": "*"} + S: SUCCESS {} + ?} + C: RUN "RETURN 2.2 AS n" "*" "*" S: FAILURE #ERROR# S: diff --git a/tests/stub/authorization/scripts/v5x1/reader_reauth_authorization_expired.script b/tests/stub/authorization/scripts/v5x1/reader_reauth_authorization_expired.script index cc8b8d786..bc7b85385 100644 --- a/tests/stub/authorization/scripts/v5x1/reader_reauth_authorization_expired.script +++ b/tests/stub/authorization/scripts/v5x1/reader_reauth_authorization_expired.script @@ -47,6 +47,11 @@ S: SUCCESS {} *: RESET + {? + C: BEGIN {"{}": "*"} + S: SUCCESS {} + ?} + C: RUN "RETURN 2.2 AS n" "*" "*" S: FAILURE {"code": "Neo.ClientError.Security.AuthorizationExpired", "message": "Authorization expired."} S: diff --git a/tests/stub/authorization/scripts/v5x1/reader_reauth_authorization_expired_pipelined.script b/tests/stub/authorization/scripts/v5x1/reader_reauth_authorization_expired_pipelined.script index f52364e0d..8b71197c9 100644 --- a/tests/stub/authorization/scripts/v5x1/reader_reauth_authorization_expired_pipelined.script +++ b/tests/stub/authorization/scripts/v5x1/reader_reauth_authorization_expired_pipelined.script @@ -59,6 +59,11 @@ S: SUCCESS {} *: RESET + {? + C: BEGIN {"{}": "*"} + S: SUCCESS {} + ?} + C: RUN "RETURN 2.2 AS n" "*" "*" S: FAILURE {"code": "Neo.ClientError.Security.AuthorizationExpired", "message": "Authorization expired."} S: diff --git a/tests/stub/authorization/scripts/v5x1/reader_reauth_authorization_expired_pipelined_minimal.script b/tests/stub/authorization/scripts/v5x1/reader_reauth_authorization_expired_pipelined_minimal.script index 94e2f988e..36c64ccd6 100644 --- a/tests/stub/authorization/scripts/v5x1/reader_reauth_authorization_expired_pipelined_minimal.script +++ b/tests/stub/authorization/scripts/v5x1/reader_reauth_authorization_expired_pipelined_minimal.script @@ -59,6 +59,11 @@ S: SUCCESS {} *: RESET + {? + C: BEGIN {"{}": "*"} + S: SUCCESS {} + ?} + C: RUN "RETURN 2.2 AS n" "*" "*" S: FAILURE {"code": "Neo.ClientError.Security.AuthorizationExpired", "message": "Authorization expired."} S: diff --git a/tests/stub/authorization/scripts/v5x1/reader_reauth_handled.script b/tests/stub/authorization/scripts/v5x1/reader_reauth_handled.script index 0a2d97906..6d0aa5079 100644 --- a/tests/stub/authorization/scripts/v5x1/reader_reauth_handled.script +++ b/tests/stub/authorization/scripts/v5x1/reader_reauth_handled.script @@ -47,6 +47,11 @@ A: HELLO {"user_agent": "*", "[routing]": "*"} *: RESET + {? + C: BEGIN {"{}": "*"} + S: SUCCESS {} + ?} + C: RUN "RETURN 2.2 AS n" "*" "*" S: FAILURE #ERROR# S: diff --git a/tests/stub/authorization/scripts/v5x1/reader_reauth_handled_pipelined.script b/tests/stub/authorization/scripts/v5x1/reader_reauth_handled_pipelined.script index f7a7d7dd2..a5c4b5fc2 100644 --- a/tests/stub/authorization/scripts/v5x1/reader_reauth_handled_pipelined.script +++ b/tests/stub/authorization/scripts/v5x1/reader_reauth_handled_pipelined.script @@ -59,6 +59,10 @@ C: HELLO {"user_agent": "*", "[routing]": "*"} *: RESET + {? + C: BEGIN {"{}": "*"} + S: SUCCESS {} + ?} C: RUN "RETURN 2.2 AS n" "*" "*" S: FAILURE #ERROR# S: diff --git a/tests/stub/authorization/scripts/v5x1/reader_reauth_handled_pipelined_minimal.script b/tests/stub/authorization/scripts/v5x1/reader_reauth_handled_pipelined_minimal.script index 860270201..3b7b8ca85 100644 --- a/tests/stub/authorization/scripts/v5x1/reader_reauth_handled_pipelined_minimal.script +++ b/tests/stub/authorization/scripts/v5x1/reader_reauth_handled_pipelined_minimal.script @@ -59,6 +59,11 @@ C: HELLO {"user_agent": "*", "[routing]": "*"} *: RESET + {? + C: BEGIN {"{}": "*"} + S: SUCCESS {} + ?} + C: RUN "RETURN 2.2 AS n" "*" "*" S: FAILURE #ERROR# S: diff --git a/tests/stub/authorization/scripts/v5x1/reader_reauth_unhandled.script b/tests/stub/authorization/scripts/v5x1/reader_reauth_unhandled.script index 3110912ee..12e1ac1bd 100644 --- a/tests/stub/authorization/scripts/v5x1/reader_reauth_unhandled.script +++ b/tests/stub/authorization/scripts/v5x1/reader_reauth_unhandled.script @@ -42,6 +42,11 @@ S: SUCCESS {} *: RESET + {? + C: BEGIN {"{}": "*"} + S: SUCCESS {} + ?} + C: RUN "RETURN 2.2 AS n" "*" "*" S: FAILURE #ERROR# S: diff --git a/tests/stub/authorization/test_basic_auth_manager.py b/tests/stub/authorization/test_basic_auth_manager.py index 70612985d..f97cf98b8 100644 --- a/tests/stub/authorization/test_basic_auth_manager.py +++ b/tests/stub/authorization/test_basic_auth_manager.py @@ -279,6 +279,137 @@ def provider(): self._writer.reset() self._router.reset() + def test_handles_unknown_auth(self): + def _trigger_error(runner, error_): + with self.assertRaises(types.DriverError) as exc: + # connection fails, gets closed + list(runner.run("RETURN 2.2 AS n")) + if error_ == "token_expired": + self.assert_is_token_error( + exc.exception + ) + elif error == "unauthorized": + self.assert_is_unauthorized_error( + exc.exception, retryable=True + ) + elif error == "security": + self.assert_is_security_error( + exc.exception + ) + else: + raise ValueError(f"Unknown error {error_}") + + def _test(error_, routing_): + count = 0 + + def provider(): + nonlocal count + count += 1 + credentials = "pass++" if count > 1 else "pass" + principal = "neo5j" if count > 1 else "neo4j" + + return types.AuthorizationToken( + scheme="basic", + principal=principal, + credentials=credentials + ) + + auth_manager = BasicAuthTokenManager(self._backend, provider) + + if error_ in ("unauthorized",): + reader_script = self.script_fn_with_features( + "reader_reauth_handled.script" + ) + writer_script = self.script_fn_with_features( + "writer_reauth_handled.script" + ) + vars_ = self.get_vars() + expected_call_count = 2 + if error_ == "token_expired": + vars_["#ERROR#"] = self._TOKEN_EXPIRED + elif error_ == "unauthorized": + vars_["#ERROR#"] = self._UNAUTHORIZED + else: + reader_script = "reader_reauth_unhandled.script" + writer_script = "writer_reauth_unhandled.script" + expected_call_count = 1 + vars_ = self.get_vars() + if error_ == "security": + vars_["#ERROR#"] = self._SECURITY_EXC + elif error_ == "token_expired": + vars_["#ERROR#"] = self._TOKEN_EXPIRED + else: + self.fail(f"Unknown error type {error_}") + self.start_server(self._reader, reader_script, vars_=vars_) + if routing_: + self.start_server(self._writer, writer_script, vars_=vars_) + self.start_server(self._router, "router_single_reader.script") + + with self.driver(auth_manager, routing=routing_, + max_connection_pool_size=2) as driver: + if routing_: + with self.session(driver, "w") as session_w: + list(session_w.run("RETURN 1 AS n")) + + with self.session(driver) as session_r1: + with self.session(driver) as session_r2: + # bind connection 1 + s1_tx = session_r1.begin_transaction() + list(s1_tx.run("RETURN 2.1 AS n")) + + self.assertEqual(1, count) + + # bind connection 2 + s2_tx = session_r2.begin_transaction() + list(s2_tx.run("RETURN 2.1 AS n")) + + s2_tx.commit() + s1_tx.commit() + + s1_tx = session_r1.begin_transaction() + + self.assertEqual(1, count) + + _trigger_error(session_r2, error_) + _trigger_error(s1_tx, error_) + + # bind connection 2 + s2_tx = session_r2.begin_transaction() + self.assertEqual(expected_call_count, count) + list(s2_tx.run("RETURN 2.3 AS n")) + + # bind connection 1 + s1_tx = session_r1.begin_transaction() + list(s1_tx.run("RETURN 2.3 AS n")) + + # free all connections + s1_tx.commit() + s2_tx.commit() + + if routing_: + with self.session(driver, "w") as session_w: + list(session_w.run("RETURN 2 AS n")) + + self.assertEqual(expected_call_count, count) + self._reader.done() + self.post_script_assertions(self._reader) + if routing_: + self._writer.done() + self.post_script_assertions(self._writer) + self._router.done() + self.post_script_assertions(self._router) + + for error in ("token_expired", + "unauthorized", "security"): + for routing in (False, True): + with self.subTest(error=error, routing=routing): + try: + _test(error, routing) + finally: + self._reader.reset() + self._writer.reset() + self._router.reset() + class TestBasicAuthManager5x0(TestBasicAuthManager5x1): @@ -296,3 +427,6 @@ def test_static_provider_long_time(self): def test_renewing_on_expiration_error(self): super().test_renewing_on_expiration_error() + + def test_handles_unknown_auth(self): + super().test_handles_unknown_auth() diff --git a/tests/stub/authorization/test_bearer_auth_manager.py b/tests/stub/authorization/test_bearer_auth_manager.py index 7f17a08f6..c92b599b6 100644 --- a/tests/stub/authorization/test_bearer_auth_manager.py +++ b/tests/stub/authorization/test_bearer_auth_manager.py @@ -332,6 +332,140 @@ def provider(): self._writer.reset() self._router.reset() + def test_handles_unknown_auth(self): + def _trigger_error(runner, error_): + with self.assertRaises(types.DriverError) as exc: + # connection fails, gets closed + list(runner.run("RETURN 2.2 AS n")) + if error_ == "token_expired": + self.assert_is_token_error( + exc.exception, retryable=True + ) + elif error == "unauthorized": + self.assert_is_unauthorized_error( + exc.exception, retryable=True + ) + elif error == "security": + self.assert_is_security_error( + exc.exception + ) + else: + raise ValueError(f"Unknown error {error_}") + + def _test(error_, routing_): + count = 0 + + def provider(): + nonlocal count + count += 1 + credentials = "pass++" if count > 1 else "pass" + principal = "neo5j" if count > 1 else "neo4j" + + return types.AuthTokenAndExpiration( + types.AuthorizationToken( + scheme="basic", + principal=principal, + credentials=credentials + ), + 10_000 + ) + + auth_manager = BearerAuthTokenManager(self._backend, provider) + + if error_ in ("unauthorized", "token_expired"): + reader_script = self.script_fn_with_features( + "reader_reauth_handled.script" + ) + writer_script = self.script_fn_with_features( + "writer_reauth_handled.script" + ) + vars_ = self.get_vars() + expected_call_count = 2 + if error_ == "token_expired": + vars_["#ERROR#"] = self._TOKEN_EXPIRED + elif error_ == "unauthorized": + vars_["#ERROR#"] = self._UNAUTHORIZED + else: + reader_script = "reader_reauth_unhandled.script" + writer_script = "writer_reauth_unhandled.script" + expected_call_count = 1 + vars_ = self.get_vars() + if error_ == "security": + vars_["#ERROR#"] = self._SECURITY_EXC + elif error_ == "token_expired": + vars_["#ERROR#"] = self._TOKEN_EXPIRED + else: + self.fail(f"Unknown error type {error_}") + self.start_server(self._reader, reader_script, vars_=vars_) + if routing_: + self.start_server(self._writer, writer_script, vars_=vars_) + self.start_server(self._router, "router_single_reader.script") + + with self.driver(auth_manager, routing=routing_, + max_connection_pool_size=2) as driver: + if routing_: + with self.session(driver, "w") as session_w: + list(session_w.run("RETURN 1 AS n")) + + with self.session(driver) as session_r1: + with self.session(driver) as session_r2: + # bind connection 1 + s1_tx = session_r1.begin_transaction() + list(s1_tx.run("RETURN 2.1 AS n")) + + self.assertEqual(1, count) + + # bind connection 2 + s2_tx = session_r2.begin_transaction() + list(s2_tx.run("RETURN 2.1 AS n")) + + s2_tx.commit() + s1_tx.commit() + + s1_tx = session_r1.begin_transaction() + + self.assertEqual(1, count) + + _trigger_error(session_r2, error_) + _trigger_error(s1_tx, error_) + + # bind connection 2 + s2_tx = session_r2.begin_transaction() + self.assertEqual(expected_call_count, count) + list(s2_tx.run("RETURN 2.3 AS n")) + + # bind connection 1 + s1_tx = session_r1.begin_transaction() + list(s1_tx.run("RETURN 2.3 AS n")) + + # free all connections + s1_tx.commit() + s2_tx.commit() + + if routing_: + with self.session(driver, "w") as session_w: + list(session_w.run("RETURN 2 AS n")) + + self.assertEqual(expected_call_count, count) + self._reader.done() + self.post_script_assertions(self._reader) + if routing_: + self._writer.done() + self.post_script_assertions(self._writer) + self._router.done() + self.post_script_assertions(self._router) + + for error in ("token_expired", + "unauthorized", "security"): + for routing in (False, True): + with self.subTest(error=error, routing=routing): + try: + _test(error, routing) + finally: + self._reader.reset() + self._writer.reset() + self._router.reset() + class TestBearerAuthManager5x0(TestBearerAuthManager5x1): @@ -352,3 +486,6 @@ def test_expiring_token_deadline(self): def test_renewing_on_expiration_error(self): super().test_renewing_on_expiration_error() + + def test_handles_unknown_auth(self): + super().test_handles_unknown_auth() From 840c4e4e12e1825bf05c1e3ab2d8f7dbc6d397c8 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 14 Aug 2023 10:18:08 +0200 Subject: [PATCH 5/7] Code clean-up --- .../stub/authorization/test_basic_auth_manager.py | 8 ++------ .../stub/authorization/test_bearer_auth_manager.py | 14 +++++--------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/tests/stub/authorization/test_basic_auth_manager.py b/tests/stub/authorization/test_basic_auth_manager.py index f97cf98b8..0ff20adbc 100644 --- a/tests/stub/authorization/test_basic_auth_manager.py +++ b/tests/stub/authorization/test_basic_auth_manager.py @@ -325,10 +325,7 @@ def provider(): ) vars_ = self.get_vars() expected_call_count = 2 - if error_ == "token_expired": - vars_["#ERROR#"] = self._TOKEN_EXPIRED - elif error_ == "unauthorized": - vars_["#ERROR#"] = self._UNAUTHORIZED + vars_["#ERROR#"] = self._UNAUTHORIZED else: reader_script = "reader_reauth_unhandled.script" writer_script = "writer_reauth_unhandled.script" @@ -399,8 +396,7 @@ def provider(): self._router.done() self.post_script_assertions(self._router) - for error in ("token_expired", - "unauthorized", "security"): + for error in ("token_expired", "unauthorized", "security"): for routing in (False, True): with self.subTest(error=error, routing=routing): try: diff --git a/tests/stub/authorization/test_bearer_auth_manager.py b/tests/stub/authorization/test_bearer_auth_manager.py index c92b599b6..0564aaacf 100644 --- a/tests/stub/authorization/test_bearer_auth_manager.py +++ b/tests/stub/authorization/test_bearer_auth_manager.py @@ -385,17 +385,14 @@ def provider(): vars_["#ERROR#"] = self._TOKEN_EXPIRED elif error_ == "unauthorized": vars_["#ERROR#"] = self._UNAUTHORIZED - else: + elif error_ in ("security",): reader_script = "reader_reauth_unhandled.script" writer_script = "writer_reauth_unhandled.script" expected_call_count = 1 vars_ = self.get_vars() - if error_ == "security": - vars_["#ERROR#"] = self._SECURITY_EXC - elif error_ == "token_expired": - vars_["#ERROR#"] = self._TOKEN_EXPIRED - else: - self.fail(f"Unknown error type {error_}") + vars_["#ERROR#"] = self._SECURITY_EXC + else: + self.fail(f"Unknown error type {error_}") self.start_server(self._reader, reader_script, vars_=vars_) if routing_: self.start_server(self._writer, writer_script, vars_=vars_) @@ -455,8 +452,7 @@ def provider(): self._router.done() self.post_script_assertions(self._router) - for error in ("token_expired", - "unauthorized", "security"): + for error in ("token_expired", "unauthorized", "security"): for routing in (False, True): with self.subTest(error=error, routing=routing): try: From 4840062c89f81bfd9c2cbe54e53873299120b4c5 Mon Sep 17 00:00:00 2001 From: Richard Irons Date: Mon, 14 Aug 2023 14:24:49 +0100 Subject: [PATCH 6/7] Add some .NET error mappings --- tests/stub/authorization/test_authorization.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/stub/authorization/test_authorization.py b/tests/stub/authorization/test_authorization.py index caedabc1d..27e12c2d9 100644 --- a/tests/stub/authorization/test_authorization.py +++ b/tests/stub/authorization/test_authorization.py @@ -52,6 +52,8 @@ def _assert_is_retryable_authorization_error(self, error): expected_type = "" elif driver in ["javascript"]: pass + elif driver in ["dotnet"]: + expected_type = "AuthorizationExpired" else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -94,6 +96,8 @@ def _assert_is_retryable_token_error(self, error): expected_type = "" elif driver in ["javascript"]: pass + elif driver in ["dotnet"]: + expected_type = "ClientError" else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -115,6 +119,8 @@ def assert_is_unauthorized_error(self, error, retryable=False): expected_type = "" elif driver in ["javascript"]: pass + elif driver in ["dotnet"]: + expected_type = "AuthenticationError" else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -128,6 +134,8 @@ def _assert_is_retryable_unauthorized_error(self, error): expected_type = "" elif driver in ["javascript"]: pass + elif driver in ["dotnet"]: + expected_type = "AuthenticationError" else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -147,6 +155,8 @@ def assert_is_security_error(self, error, retryable=False): expected_type = "" elif driver in ["javascript"]: pass + elif driver in ["dotnet"]: + expected_type = "OtherSecurityException" else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: From dec44d25bd3905621e5b9e7002a3ba747409847b Mon Sep 17 00:00:00 2001 From: Dmitriy Tverdiakov Date: Tue, 15 Aug 2023 12:08:31 +0100 Subject: [PATCH 7/7] Add Java mappings --- .../stub/authorization/test_authorization.py | 21 +++++++++++++++++++ .../authorization/test_basic_auth_manager.py | 1 + .../authorization/test_bearer_auth_manager.py | 1 + 3 files changed, 23 insertions(+) diff --git a/tests/stub/authorization/test_authorization.py b/tests/stub/authorization/test_authorization.py index 27e12c2d9..e7e9f1a8b 100644 --- a/tests/stub/authorization/test_authorization.py +++ b/tests/stub/authorization/test_authorization.py @@ -54,6 +54,9 @@ def _assert_is_retryable_authorization_error(self, error): pass elif driver in ["dotnet"]: expected_type = "AuthorizationExpired" + elif driver in ["java"]: + expected_type = \ + "org.neo4j.driver.exceptions.SecurityRetryableException" else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -98,6 +101,9 @@ def _assert_is_retryable_token_error(self, error): pass elif driver in ["dotnet"]: expected_type = "ClientError" + elif driver in ["java"]: + expected_type = \ + "org.neo4j.driver.exceptions.SecurityRetryableException" else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -121,6 +127,9 @@ def assert_is_unauthorized_error(self, error, retryable=False): pass elif driver in ["dotnet"]: expected_type = "AuthenticationError" + elif driver in ["java"]: + expected_type = \ + "org.neo4j.driver.exceptions.AuthenticationException" else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -136,6 +145,9 @@ def _assert_is_retryable_unauthorized_error(self, error): pass elif driver in ["dotnet"]: expected_type = "AuthenticationError" + elif driver in ["java"]: + expected_type = \ + "org.neo4j.driver.exceptions.SecurityRetryableException" else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -157,6 +169,8 @@ def assert_is_security_error(self, error, retryable=False): pass elif driver in ["dotnet"]: expected_type = "OtherSecurityException" + elif driver in ["java"]: + expected_type = "org.neo4j.driver.exceptions.SecurityException" else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -170,6 +184,9 @@ def _assert_is_retryable_security_error(self, error): expected_type = "" elif driver in ["javascript"]: pass + elif driver in ["java"]: + expected_type = \ + "org.neo4j.driver.exceptions.SecurityRetryableException" else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -187,6 +204,8 @@ def assert_is_transient_error(self, error): expected_type = "" elif driver in ["javascript"]: pass + if driver in ["java"]: + expected_type = "org.neo4j.driver.exceptions.TransientException" else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: @@ -204,6 +223,8 @@ def assert_is_random_error(self, error): expected_type = "" elif driver in ["javascript"]: pass + if driver in ["java"]: + expected_type = "org.neo4j.driver.exceptions.ClientException" else: self.fail("no error mapping is defined for %s driver" % driver) if expected_type is not None: diff --git a/tests/stub/authorization/test_basic_auth_manager.py b/tests/stub/authorization/test_basic_auth_manager.py index 0ff20adbc..b702bd087 100644 --- a/tests/stub/authorization/test_basic_auth_manager.py +++ b/tests/stub/authorization/test_basic_auth_manager.py @@ -369,6 +369,7 @@ def provider(): _trigger_error(session_r2, error_) _trigger_error(s1_tx, error_) + s1_tx.close() # bind connection 2 s2_tx = session_r2.begin_transaction() diff --git a/tests/stub/authorization/test_bearer_auth_manager.py b/tests/stub/authorization/test_bearer_auth_manager.py index 0564aaacf..5a60c7527 100644 --- a/tests/stub/authorization/test_bearer_auth_manager.py +++ b/tests/stub/authorization/test_bearer_auth_manager.py @@ -425,6 +425,7 @@ def provider(): _trigger_error(session_r2, error_) _trigger_error(s1_tx, error_) + s1_tx.close() # bind connection 2 s2_tx = session_r2.begin_transaction()