diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 21f87ff74..4c9ebfbb9 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -28,6 +28,7 @@ from ..._codec.hydration import v1 as hydration_v1 from ..._codec.packstream import v1 as packstream_v1 from ..._conf import PoolConfig +from ..._deadline import Deadline from ..._exceptions import ( BoltError, BoltHandshakeError, @@ -289,17 +290,21 @@ def get_handshake(cls): return b"".join(version.to_bytes() for version in offered_versions).ljust(16, b"\x00") @classmethod - async def ping(cls, address, *, timeout=None, pool_config=None): + async def ping(cls, address, *, deadline=None, pool_config=None): """ Attempt to establish a Bolt connection, returning the agreed Bolt protocol version if successful. """ if pool_config is None: pool_config = PoolConfig() + if deadline is None: + deadline = Deadline(None) + try: s, protocol_version, handshake, data = \ await AsyncBoltSocket.connect( address, - timeout=timeout, + tcp_timeout=pool_config.connection_timeout, + deadline=deadline, custom_resolver=pool_config.resolver, ssl_context=pool_config.get_ssl_context(), keep_alive=pool_config.keep_alive, @@ -313,14 +318,14 @@ async def ping(cls, address, *, timeout=None, pool_config=None): # [bolt-version-bump] search tag when changing bolt version support @classmethod async def open( - cls, address, *, auth=None, timeout=None, routing_context=None, + cls, address, *, auth=None, deadline=None, routing_context=None, pool_config=None ): """Open a new Bolt connection to a given server address. :param address: :param auth: - :param timeout: the connection timeout in seconds + :param deadline: how long to wait for the connection to be established :param routing_context: dict containing routing context :param pool_config: @@ -330,26 +335,17 @@ async def open( raised if the Bolt Protocol can not negotiate a protocol version. :raise ServiceUnavailable: raised if there was a connection issue. """ - def time_remaining(): - if timeout is None: - return None - t = timeout - (perf_counter() - t0) - return t if t > 0 else 0 - t0 = perf_counter() if pool_config is None: pool_config = PoolConfig() + if deadline is None: + deadline = Deadline(None) - socket_connection_timeout = pool_config.connection_timeout - if socket_connection_timeout is None: - socket_connection_timeout = time_remaining() - elif timeout is not None: - socket_connection_timeout = min(pool_config.connection_timeout, - time_remaining()) s, protocol_version, handshake, data = \ await AsyncBoltSocket.connect( address, - timeout=socket_connection_timeout, + tcp_timeout=pool_config.connection_timeout, + deadline=deadline, custom_resolver=pool_config.resolver, ssl_context=pool_config.get_ssl_context(), keep_alive=pool_config.keep_alive, @@ -410,7 +406,7 @@ def time_remaining(): ) try: - connection.socket.set_deadline(time_remaining()) + connection.socket.set_deadline(deadline) try: await connection.hello() finally: diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 8376d5610..9058bb5a6 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -135,9 +135,7 @@ async def connection_creator(): released_reservation = False try: try: - connection = await self.opener( - address, deadline.to_timeout() - ) + connection = await self.opener(address, deadline) except ServiceUnavailable: await self.deactivate(address) raise @@ -382,9 +380,9 @@ def open(cls, address, *, auth, pool_config, workspace_config): :returns: BoltPool """ - async def opener(addr, timeout): + async def opener(addr, deadline): return await AsyncBolt.open( - addr, auth=auth, timeout=timeout, routing_context=None, + addr, auth=auth, deadline=deadline, routing_context=None, pool_config=pool_config ) @@ -437,9 +435,9 @@ def open(cls, *addresses, auth, pool_config, workspace_config, raise ConfigurationError("The key 'address' is reserved for routing context.") routing_context["address"] = str(address) - async def opener(addr, timeout): + async def opener(addr, deadline): return await AsyncBolt.open( - addr, auth=auth, timeout=timeout, + addr, auth=auth, deadline=deadline, routing_context=routing_context, pool_config=pool_config ) diff --git a/src/neo4j/_async_compat/network/_bolt_socket.py b/src/neo4j/_async_compat/network/_bolt_socket.py index 18dec61fe..507d080a2 100644 --- a/src/neo4j/_async_compat/network/_bolt_socket.py +++ b/src/neo4j/_async_compat/network/_bolt_socket.py @@ -272,11 +272,11 @@ async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl): ) from error raise - async def _handshake(self, resolved_address): + async def _handshake(self, resolved_address, deadline): """ - :param s: Socket :param resolved_address: + :param deadline: Deadline for handshake :returns: (socket, version, client_handshake, server_response_data) """ @@ -296,47 +296,52 @@ async def _handshake(self, resolved_address): log.debug("[#%04X] C: %s %s %s %s", local_port, *supported_versions) - data = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake() - await self.sendall(data) + request = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake() # Handle the handshake response original_timeout = self.gettimeout() - if original_timeout is not None: - self.settimeout(original_timeout + 1) + self.settimeout(deadline.to_timeout()) try: - data = await self.recv(4) + await self.sendall(request) + response = await self.recv(4) except OSError as exc: raise ServiceUnavailable( - "Failed to read any data from server {!r} " - "after connected".format(resolved_address)) from exc + f"Failed to read any data from server {resolved_address!r} " + f"after connected (deadline {deadline})" + ) from exc finally: self.settimeout(original_timeout) - data_size = len(data) + data_size = len(response) if data_size == 0: # If no data is returned after a successful select # response, the server has closed the connection log.debug("[#%04X] S: ", local_port) await self.close() raise ServiceUnavailable( - "Connection to {address} closed without handshake response".format( - address=resolved_address)) + f"Connection to {resolved_address} closed without handshake " + "response" + ) if data_size != 4: # Some garbled data has been received log.debug("[#%04X] S: @*#!", local_port) await self.close() raise BoltProtocolError( - "Expected four byte Bolt handshake response from %r, received %r instead; check for incorrect port number" % ( - resolved_address, data), address=resolved_address) - elif data == b"HTTP": + "Expected four byte Bolt handshake response from " + f"{resolved_address!r}, received {response!r} instead; " + "check for incorrect port number" + , address=resolved_address + ) + elif response == b"HTTP": log.debug("[#%04X] S: ", local_port) await self.close() raise ServiceUnavailable( - "Cannot to connect to Bolt service on {!r} " - "(looks like HTTP)".format(resolved_address)) - agreed_version = data[-1], data[-2] + f"Cannot to connect to Bolt service on {resolved_address!r} " + "(looks like HTTP)" + ) + agreed_version = response[-1], response[-2] log.debug("[#%04X] S: 0x%06X%02X", local_port, agreed_version[1], agreed_version[0]) - return self, agreed_version, handshake, data + return self, agreed_version, handshake, response @classmethod async def close_socket(cls, socket_): @@ -356,8 +361,8 @@ async def close_socket(cls, socket_): pass @classmethod - async def connect(cls, address, *, timeout, custom_resolver, ssl_context, - keep_alive): + async def connect(cls, address, *, tcp_timeout, deadline, + custom_resolver, ssl_context, keep_alive): """ Connect and perform a handshake and return a valid Connection object, assuming a protocol version can be agreed. """ @@ -371,12 +376,18 @@ async def connect(cls, address, *, timeout, custom_resolver, ssl_context, addressing.Address(address), resolver=custom_resolver ) async for resolved_address in resolved_addresses: + deadline_timeout = deadline.to_timeout() + if ( + deadline_timeout is not None + and deadline_timeout <= tcp_timeout + ): + tcp_timeout = deadline_timeout s = None try: s = await cls._connect_secure( - resolved_address, timeout, keep_alive, ssl_context + resolved_address, tcp_timeout, keep_alive, ssl_context ) - return await s._handshake(resolved_address) + return await s._handshake(resolved_address, deadline) except (BoltError, DriverError, OSError) as error: try: local_port = s.getsockname()[1] @@ -560,11 +571,12 @@ def _secure(cls, s, host, ssl_context): return s @classmethod - def _handshake(cls, s, resolved_address): + def _handshake(cls, s, resolved_address, deadline): """ :param s: Socket :param resolved_address: + :param deadline: :returns: (socket, version, client_handshake, server_response_data) """ @@ -584,46 +596,52 @@ def _handshake(cls, s, resolved_address): log.debug("[#%04X] C: %s %s %s %s", local_port, *supported_versions) - data = cls.Bolt.MAGIC_PREAMBLE + cls.Bolt.get_handshake() - s.sendall(data) + request = cls.Bolt.MAGIC_PREAMBLE + cls.Bolt.get_handshake() # Handle the handshake response - ready_to_read = False - with selectors.DefaultSelector() as selector: - selector.register(s, selectors.EVENT_READ) - selector.select(1) + original_timeout = s.gettimeout() + s.settimeout(deadline.to_timeout()) try: - data = s.recv(4) + s.sendall(request) + response = s.recv(4) except OSError as exc: raise ServiceUnavailable( - "Failed to read any data from server {!r} " - "after connected".format(resolved_address)) from exc - data_size = len(data) + f"Failed to read any data from server {resolved_address!r} " + f"after connected (deadline {deadline})" + ) from exc + finally: + s.settimeout(original_timeout) + data_size = len(response) if data_size == 0: # If no data is returned after a successful select # response, the server has closed the connection log.debug("[#%04X] S: ", local_port) cls.close_socket(s) raise ServiceUnavailable( - "Connection to {address} closed without handshake response".format( - address=resolved_address)) + f"Connection to {resolved_address} closed without handshake " + "response" + ) if data_size != 4: # Some garbled data has been received log.debug("[#%04X] S: @*#!", local_port) cls.close_socket(s) raise BoltProtocolError( - "Expected four byte Bolt handshake response from %r, received %r instead; check for incorrect port number" % ( - resolved_address, data), address=resolved_address) - elif data == b"HTTP": + "Expected four byte Bolt handshake response from " + f"{resolved_address!r}, received {response!r} instead; " + "check for incorrect port number" + , address=resolved_address + ) + elif response == b"HTTP": log.debug("[#%04X] S: ", local_port) cls.close_socket(s) raise ServiceUnavailable( - "Cannot to connect to Bolt service on {!r} " - "(looks like HTTP)".format(resolved_address)) - agreed_version = data[-1], data[-2] + f"Cannot to connect to Bolt service on {resolved_address!r} " + "(looks like HTTP)" + ) + agreed_version = response[-1], response[-2] log.debug("[#%04X] S: 0x%06X%02X", local_port, agreed_version[1], agreed_version[0]) - return cls(s), agreed_version, handshake, data + return cls(s), agreed_version, handshake, response @classmethod def close_socket(cls, socket_): @@ -639,8 +657,8 @@ def close_socket(cls, socket_): pass @classmethod - def connect(cls, address, *, timeout, custom_resolver, ssl_context, - keep_alive): + def connect(cls, address, *, tcp_timeout, deadline, custom_resolver, + ssl_context, keep_alive): """ Connect and perform a handshake and return a valid Connection object, assuming a protocol version can be agreed. """ @@ -653,12 +671,19 @@ def connect(cls, address, *, timeout, custom_resolver, ssl_context, addressing.Address(address), resolver=custom_resolver ) for resolved_address in resolved_addresses: + deadline_timeout = deadline.to_timeout() + if ( + deadline_timeout is not None + and deadline_timeout <= tcp_timeout + ): + tcp_timeout = deadline_timeout s = None try: - s = BoltSocket._connect(resolved_address, timeout, keep_alive) + s = BoltSocket._connect(resolved_address, tcp_timeout, + keep_alive) s = BoltSocket._secure(s, resolved_address._host_name, ssl_context) - return BoltSocket._handshake(s, resolved_address) + return BoltSocket._handshake(s, resolved_address, deadline) except (BoltError, DriverError, OSError) as error: try: local_port = s.getsockname()[1] diff --git a/src/neo4j/_deadline.py b/src/neo4j/_deadline.py index cfcc9035d..560e55a54 100644 --- a/src/neo4j/_deadline.py +++ b/src/neo4j/_deadline.py @@ -72,6 +72,9 @@ def from_timeout_or_deadline(cls, timeout): return timeout return cls(timeout) + def __str__(self): + return f"Deadline(timeout={self._original_timeout})" + merge_deadlines = min diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index f1d7f1528..09f832e20 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -28,6 +28,7 @@ from ..._codec.hydration import v1 as hydration_v1 from ..._codec.packstream import v1 as packstream_v1 from ..._conf import PoolConfig +from ..._deadline import Deadline from ..._exceptions import ( BoltError, BoltHandshakeError, @@ -289,17 +290,21 @@ def get_handshake(cls): return b"".join(version.to_bytes() for version in offered_versions).ljust(16, b"\x00") @classmethod - def ping(cls, address, *, timeout=None, pool_config=None): + def ping(cls, address, *, deadline=None, pool_config=None): """ Attempt to establish a Bolt connection, returning the agreed Bolt protocol version if successful. """ if pool_config is None: pool_config = PoolConfig() + if deadline is None: + deadline = Deadline(None) + try: s, protocol_version, handshake, data = \ BoltSocket.connect( address, - timeout=timeout, + tcp_timeout=pool_config.connection_timeout, + deadline=deadline, custom_resolver=pool_config.resolver, ssl_context=pool_config.get_ssl_context(), keep_alive=pool_config.keep_alive, @@ -313,14 +318,14 @@ def ping(cls, address, *, timeout=None, pool_config=None): # [bolt-version-bump] search tag when changing bolt version support @classmethod def open( - cls, address, *, auth=None, timeout=None, routing_context=None, + cls, address, *, auth=None, deadline=None, routing_context=None, pool_config=None ): """Open a new Bolt connection to a given server address. :param address: :param auth: - :param timeout: the connection timeout in seconds + :param deadline: how long to wait for the connection to be established :param routing_context: dict containing routing context :param pool_config: @@ -330,26 +335,17 @@ def open( raised if the Bolt Protocol can not negotiate a protocol version. :raise ServiceUnavailable: raised if there was a connection issue. """ - def time_remaining(): - if timeout is None: - return None - t = timeout - (perf_counter() - t0) - return t if t > 0 else 0 - t0 = perf_counter() if pool_config is None: pool_config = PoolConfig() + if deadline is None: + deadline = Deadline(None) - socket_connection_timeout = pool_config.connection_timeout - if socket_connection_timeout is None: - socket_connection_timeout = time_remaining() - elif timeout is not None: - socket_connection_timeout = min(pool_config.connection_timeout, - time_remaining()) s, protocol_version, handshake, data = \ BoltSocket.connect( address, - timeout=socket_connection_timeout, + tcp_timeout=pool_config.connection_timeout, + deadline=deadline, custom_resolver=pool_config.resolver, ssl_context=pool_config.get_ssl_context(), keep_alive=pool_config.keep_alive, @@ -410,7 +406,7 @@ def time_remaining(): ) try: - connection.socket.set_deadline(time_remaining()) + connection.socket.set_deadline(deadline) try: connection.hello() finally: diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 9fe584284..229cc8c91 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -135,9 +135,7 @@ def connection_creator(): released_reservation = False try: try: - connection = self.opener( - address, deadline.to_timeout() - ) + connection = self.opener(address, deadline) except ServiceUnavailable: self.deactivate(address) raise @@ -382,9 +380,9 @@ def open(cls, address, *, auth, pool_config, workspace_config): :returns: BoltPool """ - def opener(addr, timeout): + def opener(addr, deadline): return Bolt.open( - addr, auth=auth, timeout=timeout, routing_context=None, + addr, auth=auth, deadline=deadline, routing_context=None, pool_config=pool_config ) @@ -437,9 +435,9 @@ def open(cls, *addresses, auth, pool_config, workspace_config, raise ConfigurationError("The key 'address' is reserved for routing context.") routing_context["address"] = str(address) - def opener(addr, timeout): + def opener(addr, deadline): return Bolt.open( - addr, auth=auth, timeout=timeout, + addr, auth=auth, deadline=deadline, routing_context=routing_context, pool_config=pool_config ) diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 7729036ea..541fff6fb 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -102,14 +102,17 @@ async def acquire( @mark_async_test async def test_bolt_connection_open(): with pytest.raises(ServiceUnavailable): - await AsyncBolt.open(("localhost", 9999), auth=("test", "test")) + await AsyncBolt.open( + ("localhost", 9999), auth=("test", "test") + ) @mark_async_test async def test_bolt_connection_open_timeout(): with pytest.raises(ServiceUnavailable): - await AsyncBolt.open(("localhost", 9999), auth=("test", "test"), - timeout=1) + await AsyncBolt.open( + ("localhost", 9999), auth=("test", "test"), deadline=Deadline(1) + ) @mark_async_test @@ -120,7 +123,9 @@ async def test_bolt_connection_ping(): @mark_async_test async def test_bolt_connection_ping_timeout(): - protocol_version = await AsyncBolt.ping(("localhost", 9999), timeout=1) + protocol_version = await AsyncBolt.ping( + ("localhost", 9999), deadline=Deadline(1) + ) assert protocol_version is None diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index bd5dbe70d..45a901b8e 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -102,14 +102,17 @@ def acquire( @mark_sync_test def test_bolt_connection_open(): with pytest.raises(ServiceUnavailable): - Bolt.open(("localhost", 9999), auth=("test", "test")) + Bolt.open( + ("localhost", 9999), auth=("test", "test") + ) @mark_sync_test def test_bolt_connection_open_timeout(): with pytest.raises(ServiceUnavailable): - Bolt.open(("localhost", 9999), auth=("test", "test"), - timeout=1) + Bolt.open( + ("localhost", 9999), auth=("test", "test"), deadline=Deadline(1) + ) @mark_sync_test @@ -120,7 +123,9 @@ def test_bolt_connection_ping(): @mark_sync_test def test_bolt_connection_ping_timeout(): - protocol_version = Bolt.ping(("localhost", 9999), timeout=1) + protocol_version = Bolt.ping( + ("localhost", 9999), deadline=Deadline(1) + ) assert protocol_version is None