Skip to content

Commit 9edcf42

Browse files
committed
Improve connection error handling infra
Connection now takes a dedicated error handler object in constructor instead of having one dynamically attached after creation.
1 parent e41b65b commit 9edcf42

File tree

6 files changed

+78
-42
lines changed

6 files changed

+78
-42
lines changed

neo4j/bolt/connection.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,26 @@ def supports_bytes(self):
119119
return self.version_info() >= (3, 2)
120120

121121

122+
class ConnectionErrorHandler(object):
123+
""" A handler for send and receive errors.
124+
"""
125+
126+
def __init__(self, handlers_by_error_class=None):
127+
if handlers_by_error_class is None:
128+
handlers_by_error_class = {}
129+
130+
self.handlers_by_error_class = handlers_by_error_class
131+
self.known_errors = tuple(handlers_by_error_class.keys())
132+
133+
def handle(self, error, address):
134+
try:
135+
error_class = error.__class__
136+
handler = self.handlers_by_error_class[error_class]
137+
handler(address)
138+
except KeyError:
139+
pass
140+
141+
122142
class Connection(object):
123143
""" Server connection for Bolt protocol v1.
124144
@@ -144,15 +164,14 @@ class Connection(object):
144164
#: Error class used for raising connection errors
145165
Error = ServiceUnavailable
146166

147-
#: The function to handle send and receive errors
148-
error_handler = None
149-
150167
_supports_statement_reuse = False
151168

152169
_last_run_statement = None
153170

154-
def __init__(self, sock, **config):
171+
def __init__(self, address, sock, error_handler, **config):
172+
self.address = address
155173
self.socket = sock
174+
self.error_handler = error_handler
156175
self.server = ServerInfo(SocketAddress.from_socket(sock))
157176
self.input_buffer = ChunkedInputBuffer()
158177
self.output_buffer = ChunkedOutputBuffer()
@@ -242,9 +261,8 @@ def reset(self):
242261
def send(self):
243262
try:
244263
self._send()
245-
except Exception as error:
246-
if self.error_handler is not None:
247-
self.error_handler(error)
264+
except self.error_handler.known_errors as error:
265+
self.error_handler.handle(error, self.address)
248266
raise error
249267

250268
def _send(self):
@@ -263,9 +281,8 @@ def _send(self):
263281
def fetch(self):
264282
try:
265283
return self._fetch()
266-
except Exception as error:
267-
if self.error_handler is not None:
268-
self.error_handler(error)
284+
except self.error_handler.known_errors as error:
285+
self.error_handler.handle(error, self.address)
269286
raise error
270287

271288
def _fetch(self):
@@ -379,8 +396,9 @@ class ConnectionPool(object):
379396

380397
_closed = False
381398

382-
def __init__(self, connector):
399+
def __init__(self, connector, connection_error_handler):
383400
self.connector = connector
401+
self.connection_error_handler = connection_error_handler
384402
self.connections = {}
385403
self.lock = RLock()
386404

@@ -414,7 +432,7 @@ def acquire_direct(self, address):
414432
connection.in_use = True
415433
return connection
416434
try:
417-
connection = self.connector(address)
435+
connection = self.connector(address, self.connection_error_handler)
418436
except ServiceUnavailable:
419437
self.remove(address)
420438
raise
@@ -476,7 +494,7 @@ def closed(self):
476494
return self._closed
477495

478496

479-
def connect(address, ssl_context=None, **config):
497+
def connect(address, ssl_context=None, error_handler=None, **config):
480498
""" Connect and perform a handshake and return a valid Connection object, assuming
481499
a protocol version can be agreed.
482500
"""
@@ -582,7 +600,8 @@ def connect(address, ssl_context=None, **config):
582600
s.shutdown(SHUT_RDWR)
583601
s.close()
584602
elif agreed_version == 1:
585-
return Connection(s, der_encoded_server_certificate=der_encoded_server_certificate, **config)
603+
return Connection(address, s, der_encoded_server_certificate=der_encoded_server_certificate,
604+
error_handler=error_handler, **config)
586605
elif agreed_version == 0x48545450:
587606
log_error("S: [CLOSE]")
588607
s.close()

neo4j/v1/direct.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,25 @@
2020

2121

2222
from neo4j.addressing import SocketAddress, resolve
23-
from neo4j.bolt import DEFAULT_PORT, ConnectionPool, connect
23+
from neo4j.bolt import DEFAULT_PORT, ConnectionPool, connect, ConnectionErrorHandler
2424
from neo4j.exceptions import ServiceUnavailable
2525
from neo4j.v1.api import Driver
2626
from neo4j.v1.security import SecurityPlan
2727
from neo4j.v1.session import BoltSession
2828

2929

30+
class DirectConnectionErrorHandler(ConnectionErrorHandler):
31+
""" Handler for errors in direct driver connections.
32+
"""
33+
34+
def __init__(self):
35+
super(DirectConnectionErrorHandler, self).__init__({}) # does not need to handle errors
36+
37+
3038
class DirectConnectionPool(ConnectionPool):
3139

3240
def __init__(self, connector, address):
33-
super(DirectConnectionPool, self).__init__(connector)
41+
super(DirectConnectionPool, self).__init__(connector, DirectConnectionErrorHandler())
3442
self.address = address
3543

3644
def acquire(self, access_mode=None):
@@ -61,7 +69,11 @@ def __init__(self, uri, **config):
6169
self.address = SocketAddress.from_uri(uri, DEFAULT_PORT)
6270
self.security_plan = security_plan = SecurityPlan.build(**config)
6371
self.encrypted = security_plan.encrypted
64-
pool = DirectConnectionPool(lambda a: connect(a, security_plan.ssl_context, **config), self.address)
72+
73+
def connector(address, error_handler):
74+
return connect(address, security_plan.ssl_context, error_handler, **config)
75+
76+
pool = DirectConnectionPool(connector, self.address)
6577
pool.release(pool.acquire())
6678
Driver.__init__(self, pool, **config)
6779

neo4j/v1/routing.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from time import clock
2424

2525
from neo4j.addressing import SocketAddress, resolve
26-
from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect
26+
from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect, ConnectionErrorHandler
2727
from neo4j.compat.collections import MutableSet, OrderedDict
2828
from neo4j.exceptions import CypherError, DatabaseUnavailableError, NotALeaderError, ForbiddenOnReadOnlyDatabaseError
2929
from neo4j.util import ServerVersion
@@ -32,7 +32,6 @@
3232
from neo4j.v1.security import SecurityPlan
3333
from neo4j.v1.session import BoltSession
3434

35-
3635
LOAD_BALANCING_STRATEGY_LEAST_CONNECTED = 0
3736
LOAD_BALANCING_STRATEGY_ROUND_ROBIN = 1
3837
LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED
@@ -247,15 +246,26 @@ def _select(self, offset, addresses):
247246
return least_connected_address
248247

249248

249+
class RoutingConnectionErrorHandler(ConnectionErrorHandler):
250+
""" Handler for errors in routing driver connections.
251+
"""
252+
253+
def __init__(self, pool):
254+
super(RoutingConnectionErrorHandler, self).__init__({
255+
SessionExpired: lambda address: pool.remove(address),
256+
ServiceUnavailable: lambda address: pool.remove(address),
257+
DatabaseUnavailableError: lambda address: pool.remove(address),
258+
NotALeaderError: lambda address: pool.remove_writer(address),
259+
ForbiddenOnReadOnlyDatabaseError: lambda address: pool.remove_writer(address)
260+
})
261+
262+
250263
class RoutingConnectionPool(ConnectionPool):
251264
""" Connection pool with routing table.
252265
"""
253266

254-
CLUSTER_MEMBER_FAILURE_ERRORS = (ServiceUnavailable, SessionExpired, DatabaseUnavailableError)
255-
WRITE_FAILURE_ERRORS = (NotALeaderError, ForbiddenOnReadOnlyDatabaseError)
256-
257267
def __init__(self, connector, initial_address, routing_context, *routers, **config):
258-
super(RoutingConnectionPool, self).__init__(connector)
268+
super(RoutingConnectionPool, self).__init__(connector, RoutingConnectionErrorHandler(self))
259269
self.initial_address = initial_address
260270
self.routing_context = routing_context
261271
self.routing_table = RoutingTable(routers)
@@ -402,21 +412,12 @@ def acquire(self, access_mode=None):
402412
try:
403413
connection = self.acquire_direct(address) # should always be a resolved address
404414
connection.Error = SessionExpired
405-
connection.error_handler = lambda error: self._handle_connection_error(address, error)
406415
except ServiceUnavailable:
407416
self.remove(address)
408417
else:
409418
return connection
410419
raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode)
411420

412-
def _handle_connection_error(self, address, error):
413-
""" Handle routing connection send or receive error.
414-
"""
415-
if isinstance(error, self.CLUSTER_MEMBER_FAILURE_ERRORS):
416-
self.remove(address)
417-
elif isinstance(error, self.WRITE_FAILURE_ERRORS):
418-
self._remove_writer(address)
419-
420421
def remove(self, address):
421422
""" Remove an address from the connection pool, if present, closing
422423
all connections to that address. Also remove from the routing table.
@@ -428,7 +429,7 @@ def remove(self, address):
428429
self.routing_table.writers.discard(address)
429430
super(RoutingConnectionPool, self).remove(address)
430431

431-
def _remove_writer(self, address):
432+
def remove_writer(self, address):
432433
""" Remove a writer address from the routing table, if present.
433434
"""
434435
self.routing_table.writers.discard(address)
@@ -450,8 +451,8 @@ def __init__(self, uri, **config):
450451
# scenario right now
451452
raise ValueError("TRUST_ON_FIRST_USE is not compatible with routing")
452453

453-
def connector(a):
454-
return connect(a, security_plan.ssl_context, **config)
454+
def connector(address, error_handler):
455+
return connect(address, security_plan.ssl_context, error_handler, **config)
455456

456457
pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address), **config)
457458
try:

test/integration/test_connection.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from socket import create_connection
2323

24-
from neo4j.v1 import ConnectionPool, ServiceUnavailable
24+
from neo4j.v1 import ConnectionPool, ServiceUnavailable, DirectConnectionErrorHandler
2525

2626
from test.integration.tools import IntegrationTestCase
2727

@@ -45,10 +45,14 @@ def defunct(self):
4545
return False
4646

4747

48+
def connector(address, _):
49+
return QuickConnection(create_connection(address))
50+
51+
4852
class ConnectionPoolTestCase(IntegrationTestCase):
4953

5054
def setUp(self):
51-
self.pool = ConnectionPool(lambda a: QuickConnection(create_connection(a)))
55+
self.pool = ConnectionPool(connector, DirectConnectionErrorHandler())
5256

5357
def tearDown(self):
5458
self.pool.close()
@@ -104,7 +108,7 @@ def test_releasing_twice(self):
104108
self.assert_pool_size(address, 0, 1)
105109

106110
def test_cannot_acquire_after_close(self):
107-
with ConnectionPool(lambda a: QuickConnection(create_connection(a))) as pool:
111+
with ConnectionPool(lambda a: QuickConnection(create_connection(a)), DirectConnectionErrorHandler()) as pool:
108112
pool.close()
109113
with self.assertRaises(ServiceUnavailable):
110114
_ = pool.acquire_direct("X")

test/stub/test_routing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@
5050
UNREACHABLE_ADDRESS = ("127.0.0.1", 8080)
5151

5252

53-
def connector(address):
54-
return connect(address, auth=basic_auth("neotest", "neotest"))
53+
def connector(address, error_handler):
54+
return connect(address, error_handler=error_handler, auth=basic_auth("neotest", "neotest"))
5555

5656

5757
def RoutingPool(*routers):

test/unit/test_routing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@
5252
}
5353

5454

55-
def connector(address):
56-
return connect(address, auth=basic_auth("neotest", "neotest"))
55+
def connector(address, error_handler):
56+
return connect(address, error_handler=error_handler, auth=basic_auth("neotest", "neotest"))
5757

5858

5959
class RoundRobinSetTestCase(TestCase):

0 commit comments

Comments
 (0)