Skip to content

Fix bolt handshake not having a timeout #905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 14 additions & 18 deletions src/neo4j/_async/io/_bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ..._codec.hydration import v1 as hydration_v1
from ..._codec.packstream import v1 as packstream_v1
from ..._conf import PoolConfig
from ..._deadline import Deadline
from ..._exceptions import (
BoltError,
BoltHandshakeError,
Expand Down Expand Up @@ -289,17 +290,21 @@ def get_handshake(cls):
return b"".join(version.to_bytes() for version in offered_versions).ljust(16, b"\x00")

@classmethod
async def ping(cls, address, *, timeout=None, pool_config=None):
async def ping(cls, address, *, deadline=None, pool_config=None):
""" Attempt to establish a Bolt connection, returning the
agreed Bolt protocol version if successful.
"""
if pool_config is None:
pool_config = PoolConfig()
if deadline is None:
deadline = Deadline(None)

try:
s, protocol_version, handshake, data = \
await AsyncBoltSocket.connect(
address,
timeout=timeout,
tcp_timeout=pool_config.connection_timeout,
deadline=deadline,
custom_resolver=pool_config.resolver,
ssl_context=pool_config.get_ssl_context(),
keep_alive=pool_config.keep_alive,
Expand All @@ -313,14 +318,14 @@ async def ping(cls, address, *, timeout=None, pool_config=None):
# [bolt-version-bump] search tag when changing bolt version support
@classmethod
async def open(
cls, address, *, auth=None, timeout=None, routing_context=None,
cls, address, *, auth=None, deadline=None, routing_context=None,
pool_config=None
):
"""Open a new Bolt connection to a given server address.

:param address:
:param auth:
:param timeout: the connection timeout in seconds
:param deadline: how long to wait for the connection to be established
:param routing_context: dict containing routing context
:param pool_config:

Expand All @@ -330,26 +335,17 @@ async def open(
raised if the Bolt Protocol can not negotiate a protocol version.
:raise ServiceUnavailable: raised if there was a connection issue.
"""
def time_remaining():
if timeout is None:
return None
t = timeout - (perf_counter() - t0)
return t if t > 0 else 0

t0 = perf_counter()
if pool_config is None:
pool_config = PoolConfig()
if deadline is None:
deadline = Deadline(None)

socket_connection_timeout = pool_config.connection_timeout
if socket_connection_timeout is None:
socket_connection_timeout = time_remaining()
elif timeout is not None:
socket_connection_timeout = min(pool_config.connection_timeout,
time_remaining())
s, protocol_version, handshake, data = \
await AsyncBoltSocket.connect(
address,
timeout=socket_connection_timeout,
tcp_timeout=pool_config.connection_timeout,
deadline=deadline,
custom_resolver=pool_config.resolver,
ssl_context=pool_config.get_ssl_context(),
keep_alive=pool_config.keep_alive,
Expand Down Expand Up @@ -410,7 +406,7 @@ def time_remaining():
)

try:
connection.socket.set_deadline(time_remaining())
connection.socket.set_deadline(deadline)
try:
await connection.hello()
finally:
Expand Down
12 changes: 5 additions & 7 deletions src/neo4j/_async/io/_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,7 @@ async def connection_creator():
released_reservation = False
try:
try:
connection = await self.opener(
address, deadline.to_timeout()
)
connection = await self.opener(address, deadline)
except ServiceUnavailable:
await self.deactivate(address)
raise
Expand Down Expand Up @@ -382,9 +380,9 @@ def open(cls, address, *, auth, pool_config, workspace_config):
:returns: BoltPool
"""

async def opener(addr, timeout):
async def opener(addr, deadline):
return await AsyncBolt.open(
addr, auth=auth, timeout=timeout, routing_context=None,
addr, auth=auth, deadline=deadline, routing_context=None,
pool_config=pool_config
)

Expand Down Expand Up @@ -437,9 +435,9 @@ def open(cls, *addresses, auth, pool_config, workspace_config,
raise ConfigurationError("The key 'address' is reserved for routing context.")
routing_context["address"] = str(address)

async def opener(addr, timeout):
async def opener(addr, deadline):
return await AsyncBolt.open(
addr, auth=auth, timeout=timeout,
addr, auth=auth, deadline=deadline,
routing_context=routing_context, pool_config=pool_config
)

Expand Down
119 changes: 72 additions & 47 deletions src/neo4j/_async_compat/network/_bolt_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,11 +272,11 @@ async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl):
) from error
raise

async def _handshake(self, resolved_address):
async def _handshake(self, resolved_address, deadline):
"""

:param s: Socket
:param resolved_address:
:param deadline: Deadline for handshake

:returns: (socket, version, client_handshake, server_response_data)
"""
Expand All @@ -296,47 +296,52 @@ async def _handshake(self, resolved_address):
log.debug("[#%04X] C: <HANDSHAKE> %s %s %s %s", local_port,
*supported_versions)

data = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake()
await self.sendall(data)
request = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake()

# Handle the handshake response
original_timeout = self.gettimeout()
if original_timeout is not None:
self.settimeout(original_timeout + 1)
self.settimeout(deadline.to_timeout())
try:
data = await self.recv(4)
await self.sendall(request)
response = await self.recv(4)
except OSError as exc:
raise ServiceUnavailable(
"Failed to read any data from server {!r} "
"after connected".format(resolved_address)) from exc
f"Failed to read any data from server {resolved_address!r} "
f"after connected (deadline {deadline})"
) from exc
finally:
self.settimeout(original_timeout)
data_size = len(data)
data_size = len(response)
if data_size == 0:
# If no data is returned after a successful select
# response, the server has closed the connection
log.debug("[#%04X] S: <CLOSE>", local_port)
await self.close()
raise ServiceUnavailable(
"Connection to {address} closed without handshake response".format(
address=resolved_address))
f"Connection to {resolved_address} closed without handshake "
"response"
)
if data_size != 4:
# Some garbled data has been received
log.debug("[#%04X] S: @*#!", local_port)
await self.close()
raise BoltProtocolError(
"Expected four byte Bolt handshake response from %r, received %r instead; check for incorrect port number" % (
resolved_address, data), address=resolved_address)
elif data == b"HTTP":
"Expected four byte Bolt handshake response from "
f"{resolved_address!r}, received {response!r} instead; "
"check for incorrect port number"
, address=resolved_address
)
elif response == b"HTTP":
log.debug("[#%04X] S: <CLOSE>", local_port)
await self.close()
raise ServiceUnavailable(
"Cannot to connect to Bolt service on {!r} "
"(looks like HTTP)".format(resolved_address))
agreed_version = data[-1], data[-2]
f"Cannot to connect to Bolt service on {resolved_address!r} "
"(looks like HTTP)"
)
agreed_version = response[-1], response[-2]
log.debug("[#%04X] S: <HANDSHAKE> 0x%06X%02X", local_port,
agreed_version[1], agreed_version[0])
return self, agreed_version, handshake, data
return self, agreed_version, handshake, response

@classmethod
async def close_socket(cls, socket_):
Expand All @@ -356,8 +361,8 @@ async def close_socket(cls, socket_):
pass

@classmethod
async def connect(cls, address, *, timeout, custom_resolver, ssl_context,
keep_alive):
async def connect(cls, address, *, tcp_timeout, deadline,
custom_resolver, ssl_context, keep_alive):
""" Connect and perform a handshake and return a valid Connection object,
assuming a protocol version can be agreed.
"""
Expand All @@ -371,12 +376,18 @@ async def connect(cls, address, *, timeout, custom_resolver, ssl_context,
addressing.Address(address), resolver=custom_resolver
)
async for resolved_address in resolved_addresses:
deadline_timeout = deadline.to_timeout()
if (
deadline_timeout is not None
and deadline_timeout <= tcp_timeout
):
tcp_timeout = deadline_timeout
s = None
try:
s = await cls._connect_secure(
resolved_address, timeout, keep_alive, ssl_context
resolved_address, tcp_timeout, keep_alive, ssl_context
)
return await s._handshake(resolved_address)
return await s._handshake(resolved_address, deadline)
except (BoltError, DriverError, OSError) as error:
try:
local_port = s.getsockname()[1]
Expand Down Expand Up @@ -560,11 +571,12 @@ def _secure(cls, s, host, ssl_context):
return s

@classmethod
def _handshake(cls, s, resolved_address):
def _handshake(cls, s, resolved_address, deadline):
"""

:param s: Socket
:param resolved_address:
:param deadline:

:returns: (socket, version, client_handshake, server_response_data)
"""
Expand All @@ -584,46 +596,52 @@ def _handshake(cls, s, resolved_address):
log.debug("[#%04X] C: <HANDSHAKE> %s %s %s %s", local_port,
*supported_versions)

data = cls.Bolt.MAGIC_PREAMBLE + cls.Bolt.get_handshake()
s.sendall(data)
request = cls.Bolt.MAGIC_PREAMBLE + cls.Bolt.get_handshake()

# Handle the handshake response
ready_to_read = False
with selectors.DefaultSelector() as selector:
selector.register(s, selectors.EVENT_READ)
selector.select(1)
original_timeout = s.gettimeout()
s.settimeout(deadline.to_timeout())
try:
data = s.recv(4)
s.sendall(request)
response = s.recv(4)
except OSError as exc:
raise ServiceUnavailable(
"Failed to read any data from server {!r} "
"after connected".format(resolved_address)) from exc
data_size = len(data)
f"Failed to read any data from server {resolved_address!r} "
f"after connected (deadline {deadline})"
) from exc
finally:
s.settimeout(original_timeout)
data_size = len(response)
if data_size == 0:
# If no data is returned after a successful select
# response, the server has closed the connection
log.debug("[#%04X] S: <CLOSE>", local_port)
cls.close_socket(s)
raise ServiceUnavailable(
"Connection to {address} closed without handshake response".format(
address=resolved_address))
f"Connection to {resolved_address} closed without handshake "
"response"
)
if data_size != 4:
# Some garbled data has been received
log.debug("[#%04X] S: @*#!", local_port)
cls.close_socket(s)
raise BoltProtocolError(
"Expected four byte Bolt handshake response from %r, received %r instead; check for incorrect port number" % (
resolved_address, data), address=resolved_address)
elif data == b"HTTP":
"Expected four byte Bolt handshake response from "
f"{resolved_address!r}, received {response!r} instead; "
"check for incorrect port number"
, address=resolved_address
)
elif response == b"HTTP":
log.debug("[#%04X] S: <CLOSE>", local_port)
cls.close_socket(s)
raise ServiceUnavailable(
"Cannot to connect to Bolt service on {!r} "
"(looks like HTTP)".format(resolved_address))
agreed_version = data[-1], data[-2]
f"Cannot to connect to Bolt service on {resolved_address!r} "
"(looks like HTTP)"
)
agreed_version = response[-1], response[-2]
log.debug("[#%04X] S: <HANDSHAKE> 0x%06X%02X", local_port,
agreed_version[1], agreed_version[0])
return cls(s), agreed_version, handshake, data
return cls(s), agreed_version, handshake, response

@classmethod
def close_socket(cls, socket_):
Expand All @@ -639,8 +657,8 @@ def close_socket(cls, socket_):
pass

@classmethod
def connect(cls, address, *, timeout, custom_resolver, ssl_context,
keep_alive):
def connect(cls, address, *, tcp_timeout, deadline, custom_resolver,
ssl_context, keep_alive):
""" Connect and perform a handshake and return a valid Connection object,
assuming a protocol version can be agreed.
"""
Expand All @@ -653,12 +671,19 @@ def connect(cls, address, *, timeout, custom_resolver, ssl_context,
addressing.Address(address), resolver=custom_resolver
)
for resolved_address in resolved_addresses:
deadline_timeout = deadline.to_timeout()
if (
deadline_timeout is not None
and deadline_timeout <= tcp_timeout
):
tcp_timeout = deadline_timeout
s = None
try:
s = BoltSocket._connect(resolved_address, timeout, keep_alive)
s = BoltSocket._connect(resolved_address, tcp_timeout,
keep_alive)
s = BoltSocket._secure(s, resolved_address._host_name,
ssl_context)
return BoltSocket._handshake(s, resolved_address)
return BoltSocket._handshake(s, resolved_address, deadline)
except (BoltError, DriverError, OSError) as error:
try:
local_port = s.getsockname()[1]
Expand Down
3 changes: 3 additions & 0 deletions src/neo4j/_deadline.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def from_timeout_or_deadline(cls, timeout):
return timeout
return cls(timeout)

def __str__(self):
return f"Deadline(timeout={self._original_timeout})"


merge_deadlines = min

Expand Down
Loading