Skip to content

Commit 23cce6a

Browse files
authored
Harden driver against unexpected RESET responses (#1006)
The server has been observed to reply with `FAILURE` and `IGNORED` to `RESET` requests. The former is according to spec and the driver should drop the connection (which it didn't), the latter isn't. The right combination of those two unexpected responses at the right time could get the driver stuck in an infinite loop. This change makes the driver drop the connection in either case to gracefully handle the situation.
1 parent 018a49f commit 23cce6a

File tree

12 files changed

+288
-56
lines changed

12 files changed

+288
-56
lines changed

src/neo4j/_async/io/_bolt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -857,7 +857,7 @@ async def fetch_all(self):
857857
messages fetched
858858
"""
859859
detail_count = summary_count = 0
860-
while self.responses:
860+
while not self._closed and self.responses:
861861
response = self.responses[0]
862862
while not response.complete:
863863
detail_delta, summary_delta = await self.fetch_message()

src/neo4j/_async/io/_bolt3.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
check_supported_server_product,
4646
CommitResponse,
4747
InitResponse,
48+
ResetResponse,
4849
Response,
4950
)
5051

@@ -391,17 +392,14 @@ def rollback(self, dehydration_hooks=None, hydration_hooks=None,
391392
dehydration_hooks=dehydration_hooks)
392393

393394
async def reset(self, dehydration_hooks=None, hydration_hooks=None):
394-
""" Add a RESET message to the outgoing queue, send
395-
it and consume all remaining messages.
396-
"""
397-
398-
def fail(metadata):
399-
raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address)
395+
"""Reset the connection.
400396
397+
Add a RESET message to the outgoing queue, send it and consume all
398+
remaining messages.
399+
"""
401400
log.debug("[#%04X] C: RESET", self.local_port)
402-
self._append(b"\x0F",
403-
response=Response(self, "reset", hydration_hooks,
404-
on_failure=fail),
401+
response = ResetResponse(self, "reset", hydration_hooks)
402+
self._append(b"\x0F", response=response,
405403
dehydration_hooks=dehydration_hooks)
406404
await self.send_all()
407405
await self.fetch_all()

src/neo4j/_async/io/_bolt4.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
check_supported_server_product,
4848
CommitResponse,
4949
InitResponse,
50+
ResetResponse,
5051
Response,
5152
)
5253

@@ -311,17 +312,14 @@ def rollback(self, dehydration_hooks=None, hydration_hooks=None,
311312
dehydration_hooks=dehydration_hooks)
312313

313314
async def reset(self, dehydration_hooks=None, hydration_hooks=None):
314-
""" Add a RESET message to the outgoing queue, send
315-
it and consume all remaining messages.
316-
"""
317-
318-
def fail(metadata):
319-
raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address)
315+
"""Reset the connection.
320316
317+
Add a RESET message to the outgoing queue, send it and consume all
318+
remaining messages.
319+
"""
321320
log.debug("[#%04X] C: RESET", self.local_port)
322-
self._append(b"\x0F",
323-
response=Response(self, "reset", hydration_hooks,
324-
on_failure=fail),
321+
response = ResetResponse(self, "reset", hydration_hooks)
322+
self._append(b"\x0F", response=response,
325323
dehydration_hooks=dehydration_hooks)
326324
await self.send_all()
327325
await self.fetch_all()

src/neo4j/_async/io/_bolt5.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
CommitResponse,
5151
InitResponse,
5252
LogonResponse,
53+
ResetResponse,
5354
Response,
5455
)
5556

@@ -314,15 +315,9 @@ async def reset(self, dehydration_hooks=None, hydration_hooks=None):
314315
Add a RESET message to the outgoing queue, send it and consume all
315316
remaining messages.
316317
"""
317-
318-
def fail(metadata):
319-
raise BoltProtocolError("RESET failed %r" % metadata,
320-
self.unresolved_address)
321-
322318
log.debug("[#%04X] C: RESET", self.local_port)
323-
self._append(b"\x0F",
324-
response=Response(self, "reset", hydration_hooks,
325-
on_failure=fail),
319+
response = ResetResponse(self, "reset", hydration_hooks)
320+
self._append(b"\x0F", response=response,
326321
dehydration_hooks=dehydration_hooks)
327322
await self.send_all()
328323
await self.fetch_all()

src/neo4j/_async/io/_common.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,26 @@ async def on_failure(self, metadata):
281281
raise Neo4jError.hydrate(**metadata)
282282

283283

284+
class ResetResponse(Response):
285+
async def _unexpected_message(self, response):
286+
log.warning("[#%04X] _: <CONNECTION> RESET received %s "
287+
"(unexpected response) => dropping connection",
288+
self.connection.local_port, response)
289+
await self.connection.close()
290+
291+
async def on_records(self, records):
292+
await self._unexpected_message("RECORD")
293+
294+
async def on_success(self, metadata):
295+
pass
296+
297+
async def on_failure(self, metadata):
298+
await self._unexpected_message("FAILURE")
299+
300+
async def on_ignored(self, metadata=None):
301+
await self._unexpected_message("IGNORED")
302+
303+
284304
class CommitResponse(Response):
285305
pass
286306

src/neo4j/_sync/io/_bolt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -857,7 +857,7 @@ def fetch_all(self):
857857
messages fetched
858858
"""
859859
detail_count = summary_count = 0
860-
while self.responses:
860+
while not self._closed and self.responses:
861861
response = self.responses[0]
862862
while not response.complete:
863863
detail_delta, summary_delta = self.fetch_message()

src/neo4j/_sync/io/_bolt3.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
check_supported_server_product,
4646
CommitResponse,
4747
InitResponse,
48+
ResetResponse,
4849
Response,
4950
)
5051

@@ -391,17 +392,14 @@ def rollback(self, dehydration_hooks=None, hydration_hooks=None,
391392
dehydration_hooks=dehydration_hooks)
392393

393394
def reset(self, dehydration_hooks=None, hydration_hooks=None):
394-
""" Add a RESET message to the outgoing queue, send
395-
it and consume all remaining messages.
396-
"""
397-
398-
def fail(metadata):
399-
raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address)
395+
"""Reset the connection.
400396
397+
Add a RESET message to the outgoing queue, send it and consume all
398+
remaining messages.
399+
"""
401400
log.debug("[#%04X] C: RESET", self.local_port)
402-
self._append(b"\x0F",
403-
response=Response(self, "reset", hydration_hooks,
404-
on_failure=fail),
401+
response = ResetResponse(self, "reset", hydration_hooks)
402+
self._append(b"\x0F", response=response,
405403
dehydration_hooks=dehydration_hooks)
406404
self.send_all()
407405
self.fetch_all()

src/neo4j/_sync/io/_bolt4.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
check_supported_server_product,
4848
CommitResponse,
4949
InitResponse,
50+
ResetResponse,
5051
Response,
5152
)
5253

@@ -311,17 +312,14 @@ def rollback(self, dehydration_hooks=None, hydration_hooks=None,
311312
dehydration_hooks=dehydration_hooks)
312313

313314
def reset(self, dehydration_hooks=None, hydration_hooks=None):
314-
""" Add a RESET message to the outgoing queue, send
315-
it and consume all remaining messages.
316-
"""
317-
318-
def fail(metadata):
319-
raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address)
315+
"""Reset the connection.
320316
317+
Add a RESET message to the outgoing queue, send it and consume all
318+
remaining messages.
319+
"""
321320
log.debug("[#%04X] C: RESET", self.local_port)
322-
self._append(b"\x0F",
323-
response=Response(self, "reset", hydration_hooks,
324-
on_failure=fail),
321+
response = ResetResponse(self, "reset", hydration_hooks)
322+
self._append(b"\x0F", response=response,
325323
dehydration_hooks=dehydration_hooks)
326324
self.send_all()
327325
self.fetch_all()

src/neo4j/_sync/io/_bolt5.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
CommitResponse,
5151
InitResponse,
5252
LogonResponse,
53+
ResetResponse,
5354
Response,
5455
)
5556

@@ -314,15 +315,9 @@ def reset(self, dehydration_hooks=None, hydration_hooks=None):
314315
Add a RESET message to the outgoing queue, send it and consume all
315316
remaining messages.
316317
"""
317-
318-
def fail(metadata):
319-
raise BoltProtocolError("RESET failed %r" % metadata,
320-
self.unresolved_address)
321-
322318
log.debug("[#%04X] C: RESET", self.local_port)
323-
self._append(b"\x0F",
324-
response=Response(self, "reset", hydration_hooks,
325-
on_failure=fail),
319+
response = ResetResponse(self, "reset", hydration_hooks)
320+
self._append(b"\x0F", response=response,
326321
dehydration_hooks=dehydration_hooks)
327322
self.send_all()
328323
self.fetch_all()

src/neo4j/_sync/io/_common.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,26 @@ def on_failure(self, metadata):
281281
raise Neo4jError.hydrate(**metadata)
282282

283283

284+
class ResetResponse(Response):
285+
def _unexpected_message(self, response):
286+
log.warning("[#%04X] _: <CONNECTION> RESET received %s "
287+
"(unexpected response) => dropping connection",
288+
self.connection.local_port, response)
289+
self.connection.close()
290+
291+
def on_records(self, records):
292+
self._unexpected_message("RECORD")
293+
294+
def on_success(self, metadata):
295+
pass
296+
297+
def on_failure(self, metadata):
298+
self._unexpected_message("FAILURE")
299+
300+
def on_ignored(self, metadata=None):
301+
self._unexpected_message("IGNORED")
302+
303+
284304
class CommitResponse(Response):
285305
pass
286306

tests/unit/async_/io/test__common.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,14 @@
1414
# limitations under the License.
1515

1616

17+
import logging
18+
1719
import pytest
1820

19-
from neo4j._async.io._common import AsyncOutbox
21+
from neo4j._async.io._common import (
22+
AsyncOutbox,
23+
ResetResponse,
24+
)
2025
from neo4j._codec.packstream.v1 import PackableBuffer
2126

2227
from ...._async_compat import mark_async_test
@@ -56,3 +61,103 @@ async def test_async_outbox_chunking(chunk_size, data, result, mocker):
5661

5762
assert not await outbox.flush()
5863
socket_mock.sendall.assert_awaited_once()
64+
65+
66+
def get_handler_arg(response):
67+
if response == "RECORD":
68+
return []
69+
elif response == "IGNORED":
70+
return {}
71+
elif response == "FAILURE":
72+
return {}
73+
elif response == "SUCCESS":
74+
return {}
75+
else:
76+
raise ValueError(f"Unexpected response: {response}")
77+
78+
79+
def call_handler(handler, response, arg=None):
80+
if arg is None:
81+
arg = get_handler_arg(response)
82+
83+
if response == "RECORD":
84+
return handler.on_records(arg)
85+
elif response == "IGNORED":
86+
return handler.on_ignored(arg)
87+
elif response == "FAILURE":
88+
return handler.on_failure(arg)
89+
elif response == "SUCCESS":
90+
return handler.on_success(arg)
91+
else:
92+
raise ValueError(f"Unexpected response: {response}")
93+
94+
95+
@pytest.mark.parametrize(
96+
("response", "unexpected"),
97+
(
98+
("RECORD", True),
99+
("IGNORED", True),
100+
("FAILURE", True),
101+
("SUCCESS", False),
102+
)
103+
)
104+
@mark_async_test
105+
async def test_reset_response_closes_connection_on_unexpected_responses(
106+
response, unexpected, async_fake_connection
107+
):
108+
handler = ResetResponse(async_fake_connection, "reset", {})
109+
async_fake_connection.close.assert_not_called()
110+
111+
await call_handler(handler, response)
112+
113+
if unexpected:
114+
async_fake_connection.close.assert_awaited_once()
115+
else:
116+
async_fake_connection.close.assert_not_called()
117+
118+
119+
@pytest.mark.parametrize(
120+
("response", "unexpected"),
121+
(
122+
("RECORD", True),
123+
("IGNORED", True),
124+
("FAILURE", True),
125+
("SUCCESS", False),
126+
)
127+
)
128+
@mark_async_test
129+
async def test_reset_response_logs_warning_on_unexpected_responses(
130+
response, unexpected, async_fake_connection, caplog
131+
):
132+
handler = ResetResponse(async_fake_connection, "reset", {})
133+
134+
with caplog.at_level(logging.WARNING):
135+
await call_handler(handler, response)
136+
137+
log_message_found = any("RESET" in msg and "unexpected response" in msg
138+
for msg in caplog.messages)
139+
if unexpected:
140+
assert log_message_found
141+
else:
142+
assert not log_message_found
143+
144+
145+
@pytest.mark.parametrize("response",
146+
("RECORD", "IGNORED", "FAILURE", "SUCCESS"))
147+
@mark_async_test
148+
async def test_reset_response_never_calls_handlers(
149+
response, async_fake_connection, mocker
150+
):
151+
handlers = {
152+
key: mocker.AsyncMock(name=key)
153+
for key in
154+
("on_records", "on_ignored", "on_failure", "on_success", "on_summary")
155+
}
156+
157+
handler = ResetResponse(async_fake_connection, "reset", {}, **handlers)
158+
159+
arg = get_handler_arg(response)
160+
await call_handler(handler, response, arg)
161+
162+
for handler in handlers.values():
163+
handler.assert_not_called()

0 commit comments

Comments
 (0)