Skip to content

Fix bolt handshake not having a timeout #915

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 17 additions & 24 deletions neo4j/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,16 +279,19 @@ def get_handshake(cls):
return b"".join(version.to_bytes() for version in offered_versions).ljust(16, b"\x00")

@classmethod
def ping(cls, address, *, timeout=None, pool_config=None):
def ping(cls, address, *, deadline=None, pool_config=None):
""" Attempt to establish a Bolt connection, returning the
agreed Bolt protocol version if successful.
"""
if pool_config is None:
pool_config = PoolConfig()
if deadline is None:
deadline = Deadline(None)
try:
s, protocol_version, handshake, data = BoltSocket.connect(
address,
timeout=timeout,
tcp_timeout=pool_config.connection_timeout,
deadline=deadline,
custom_resolver=pool_config.resolver,
ssl_context=pool_config.get_ssl_context(),
keep_alive=pool_config.keep_alive,
Expand All @@ -300,38 +303,30 @@ def ping(cls, address, *, timeout=None, pool_config=None):
return protocol_version

@classmethod
def open(cls, address, *, auth=None, timeout=None, routing_context=None,
def open(cls, address, *, auth=None, deadline=None, routing_context=None,
pool_config=None):
""" Open a new Bolt connection to a given server address.

:param address:
:param auth:
:param timeout: the connection timeout in seconds
:param deadline: how long to wait for the connection to be established
:param routing_context: dict containing routing context
:param pool_config:
:return:
:raise BoltHandshakeError: raised if the Bolt Protocol can not negotiate a protocol version.
:raise ServiceUnavailable: raised if there was a connection issue.
"""
def time_remaining():
if timeout is None:
return None
t = timeout - (perf_counter() - t0)
return t if t > 0 else 0

t0 = perf_counter()
if pool_config is None:
pool_config = PoolConfig()
if deadline is None:
deadline = Deadline(None)

socket_connection_timeout = pool_config.connection_timeout
if socket_connection_timeout is None:
socket_connection_timeout = time_remaining()
elif timeout is not None:
socket_connection_timeout = min(pool_config.connection_timeout,
time_remaining())
s, pool_config.protocol_version, handshake, data = BoltSocket.connect(
address,
timeout=socket_connection_timeout,
tcp_timeout=pool_config.connection_timeout,
deadline=deadline,
custom_resolver=pool_config.resolver,
ssl_context=pool_config.get_ssl_context(),
keep_alive=pool_config.keep_alive,
Expand Down Expand Up @@ -370,7 +365,7 @@ def time_remaining():
)

try:
connection.socket.set_deadline(time_remaining())
connection.socket.set_deadline(deadline)
try:
connection.hello()
finally:
Expand Down Expand Up @@ -732,9 +727,7 @@ def connection_creator():
released_reservation = False
try:
try:
connection = self.opener(
address, deadline.to_timeout()
)
connection = self.opener(address, deadline)
except ServiceUnavailable:
self.deactivate(address)
raise
Expand Down Expand Up @@ -909,9 +902,9 @@ def open(cls, address, *, auth, pool_config, workspace_config):
:return: BoltPool
"""

def opener(addr, timeout):
def opener(addr, deadline):
return Bolt.open(
addr, auth=auth, timeout=timeout, routing_context=None,
addr, auth=auth, deadline=deadline, routing_context=None,
pool_config=pool_config
)

Expand Down Expand Up @@ -955,8 +948,8 @@ def open(cls, *addresses, auth, pool_config, workspace_config, routing_context=N
raise ConfigurationError("The key 'address' is reserved for routing context.")
routing_context["address"] = str(address)

def opener(addr, timeout):
return Bolt.open(addr, auth=auth, timeout=timeout,
def opener(addr, deadline):
return Bolt.open(addr, auth=auth, deadline=deadline,
routing_context=routing_context,
pool_config=pool_config)

Expand Down
64 changes: 39 additions & 25 deletions neo4j/io/_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,12 @@ def _secure(cls, s, host, ssl_context):
return s

@classmethod
def _handshake(cls, s, resolved_address):
def _handshake(cls, s, resolved_address, deadline):
"""

:param s: Socket
:param resolved_address:
:param deadline:

:return: (socket, version, client_handshake, server_response_data)
"""
Expand All @@ -214,46 +215,52 @@ def _handshake(cls, s, resolved_address):
log.debug("[#%04X] C: <HANDSHAKE> %s %s %s %s", local_port,
*supported_versions)

data = cls.Bolt.MAGIC_PREAMBLE + cls.Bolt.get_handshake()
s.sendall(data)
request = cls.Bolt.MAGIC_PREAMBLE + cls.Bolt.get_handshake()

# Handle the handshake response
ready_to_read = False
with selectors.DefaultSelector() as selector:
selector.register(s, selectors.EVENT_READ)
selector.select(1)
original_timeout = s.gettimeout()
s.settimeout(deadline.to_timeout())
try:
data = s.recv(4)
except OSError:
s.sendall(request)
response = s.recv(4)
except OSError as exc:
raise ServiceUnavailable(
"Failed to read any data from server {!r} "
"after connected".format(resolved_address))
data_size = len(data)
f"Failed to read any data from server {resolved_address!r} "
f"after connected (deadline {deadline})"
) from exc
finally:
s.settimeout(original_timeout)
data_size = len(response)
if data_size == 0:
# If no data is returned after a successful select
# response, the server has closed the connection
log.debug("[#%04X] S: <CLOSE>", local_port)
cls.close_socket(s)
raise ServiceUnavailable(
"Connection to {address} closed without handshake response".format(
address=resolved_address))
f"Connection to {resolved_address} closed without handshake "
"response"
)
if data_size != 4:
# Some garbled data has been received
log.debug("[#%04X] S: @*#!", local_port)
cls.close_socket(s)
raise BoltProtocolError(
"Expected four byte Bolt handshake response from %r, received %r instead; check for incorrect port number" % (
resolved_address, data), address=resolved_address)
elif data == b"HTTP":
"Expected four byte Bolt handshake response from "
f"{resolved_address!r}, received {response!r} instead; "
"check for incorrect port number"
, address=resolved_address
)
elif response == b"HTTP":
log.debug("[#%04X] S: <CLOSE>", local_port)
cls.close_socket(s)
raise ServiceUnavailable(
"Cannot to connect to Bolt service on {!r} "
"(looks like HTTP)".format(resolved_address))
agreed_version = data[-1], data[-2]
f"Cannot to connect to Bolt service on {resolved_address!r} "
"(looks like HTTP)"
)
agreed_version = response[-1], response[-2]
log.debug("[#%04X] S: <HANDSHAKE> 0x%06X%02X", local_port,
agreed_version[1], agreed_version[0])
return cls(s), agreed_version, handshake, data
return cls(s), agreed_version, handshake, response

@classmethod
def close_socket(cls, socket_):
Expand All @@ -269,8 +276,8 @@ def close_socket(cls, socket_):
pass

@classmethod
def connect(cls, address, *, timeout, custom_resolver, ssl_context,
keep_alive):
def connect(cls, address, *, tcp_timeout, deadline, custom_resolver,
ssl_context, keep_alive):
""" Connect and perform a handshake and return a valid Connection object,
assuming a protocol version can be agreed.
"""
Expand All @@ -281,12 +288,19 @@ def connect(cls, address, *, timeout, custom_resolver, ssl_context,

resolved_addresses = Address(address).resolve(resolver=custom_resolver)
for resolved_address in resolved_addresses:
deadline_timeout = deadline.to_timeout()
if (
deadline_timeout is not None
and deadline_timeout <= tcp_timeout
):
tcp_timeout = deadline_timeout
s = None
try:
s = BoltSocket._connect(resolved_address, timeout, keep_alive)
s = BoltSocket._connect(resolved_address, tcp_timeout,
keep_alive)
s = BoltSocket._secure(s, resolved_address.host_name,
ssl_context)
return BoltSocket._handshake(s, resolved_address)
return BoltSocket._handshake(s, resolved_address, deadline)
except (BoltError, DriverError, OSError) as error:
try:
local_port = s.getsockname()[1]
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/io/test_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,19 @@ class BoltTestCase(TestCase):

def test_open(self):
with pytest.raises(ServiceUnavailable):
connection = Bolt.open(("localhost", 9999), auth=("test", "test"))
Bolt.open(("localhost", 9999), auth=("test", "test"))

def test_open_timeout(self):
conf = PoolConfig()
with pytest.raises(ServiceUnavailable):
connection = Bolt.open(("localhost", 9999), auth=("test", "test"), timeout=1)
Bolt.open(("localhost", 9999), auth=("test", "test"),
deadline=Deadline(1))

def test_ping(self):
protocol_version = Bolt.ping(("localhost", 9999))
assert protocol_version is None

def test_ping_timeout(self):
protocol_version = Bolt.ping(("localhost", 9999), timeout=1)
protocol_version = Bolt.ping(("localhost", 9999), deadline=Deadline(1))
assert protocol_version is None


Expand Down