26
26
from neo4j .bolt import ConnectionPool , ServiceUnavailable , ProtocolError , DEFAULT_PORT , connect
27
27
from neo4j .compat .collections import MutableSet , OrderedDict
28
28
from neo4j .exceptions import CypherError
29
- from neo4j .v1 .api import Driver , READ_ACCESS , WRITE_ACCESS
29
+ from neo4j .v1 .api import Driver , READ_ACCESS , WRITE_ACCESS , fix_statement , fix_parameters
30
30
from neo4j .v1 .exceptions import SessionExpired
31
31
from neo4j .v1 .security import SecurityPlan
32
32
from neo4j .v1 .session import BoltSession
33
+ from neo4j .util import ServerVersion
33
34
34
35
35
36
class RoundRobinSet (MutableSet ):
@@ -131,11 +132,12 @@ def __init__(self, routers=(), readers=(), writers=(), ttl=0):
131
132
self .last_updated_time = self .timer ()
132
133
self .ttl = ttl
133
134
134
- def is_fresh (self ):
135
+ def is_fresh (self , access_mode ):
135
136
""" Indicator for whether routing information is still usable.
136
137
"""
137
138
expired = self .last_updated_time + self .ttl <= self .timer ()
138
- return not expired and len (self .routers ) > 1 and self .readers and self .writers
139
+ has_server_for_mode = (access_mode == READ_ACCESS and self .readers ) or (access_mode == WRITE_ACCESS and self .writers )
140
+ return not expired and self .routers and has_server_for_mode
139
141
140
142
def update (self , new_routing_table ):
141
143
""" Update the current routing table with new routing information
@@ -148,16 +150,34 @@ def update(self, new_routing_table):
148
150
self .ttl = new_routing_table .ttl
149
151
150
152
153
+ class RoutingSession (BoltSession ):
154
+
155
+ call_get_servers = "CALL dbms.cluster.routing.getServers"
156
+ get_routing_table_param = "context"
157
+ call_get_routing_table = "CALL dbms.cluster.routing.getRoutingTable({%s})" % get_routing_table_param
158
+
159
+ def routing_info_procedure (self , routing_context ):
160
+ if ServerVersion .from_str (self ._connection .server .version ).at_least_version (3 , 2 ):
161
+ return self .call_get_routing_table , {self .get_routing_table_param : routing_context }
162
+ else :
163
+ return self .call_get_servers , {}
164
+
165
+ def __run__ (self , ignored , routing_context ):
166
+ # the statement is ignored as it will be get routing table procedure call.
167
+ statement , parameters = self .routing_info_procedure (routing_context )
168
+ return self ._run (fix_statement (statement ), fix_parameters (parameters ))
169
+
170
+
151
171
class RoutingConnectionPool (ConnectionPool ):
152
172
""" Connection pool with routing table.
153
173
"""
154
174
155
- routing_info_procedure = "dbms.cluster.routing.getServers"
156
-
157
- def __init__ (self , connector , initial_address , * routers ):
175
+ def __init__ (self , connector , initial_address , routing_context , * routers ):
158
176
super (RoutingConnectionPool , self ).__init__ (connector )
159
177
self .initial_address = initial_address
178
+ self .routing_context = routing_context
160
179
self .routing_table = RoutingTable (routers )
180
+ self .missing_writer = False
161
181
self .refresh_lock = Lock ()
162
182
163
183
def fetch_routing_info (self , address ):
@@ -170,8 +190,8 @@ def fetch_routing_info(self, address):
170
190
if routing support is broken
171
191
"""
172
192
try :
173
- with BoltSession (lambda _ : self .acquire_direct (address )) as session :
174
- return list (session .run ("CALL %s" % self .routing_info_procedure ))
193
+ with RoutingSession (lambda _ : self .acquire_direct (address )) as session :
194
+ return list (session .run ("ignored" , self .routing_context ))
175
195
except CypherError as error :
176
196
if error .code == "Neo.ClientError.Procedure.ProcedureNotFound" :
177
197
raise ServiceUnavailable ("Server {!r} does not support routing" .format (address ))
@@ -200,6 +220,11 @@ def fetch_routing_table(self, address):
200
220
num_readers = len (new_routing_table .readers )
201
221
num_writers = len (new_routing_table .writers )
202
222
223
+ # No writers are available. This likely indicates a temporary state,
224
+ # such as leader switching, so we should not signal an error.
225
+ # When no writers available, then we flag we are reading in absence of writer
226
+ self .missing_writer = (num_writers == 0 )
227
+
203
228
# No routers
204
229
if num_routers == 0 :
205
230
raise ProtocolError ("No routing servers returned from server %r" % (address ,))
@@ -208,12 +233,6 @@ def fetch_routing_table(self, address):
208
233
if num_readers == 0 :
209
234
raise ProtocolError ("No read servers returned from server %r" % (address ,))
210
235
211
- # No writers
212
- if num_writers == 0 :
213
- # No writers are available. This likely indicates a temporary state,
214
- # such as leader switching, so we should not signal an error.
215
- return None
216
-
217
236
# At least one of each is fine, so return this table
218
237
return new_routing_table
219
238
@@ -234,21 +253,30 @@ def update_routing_table(self):
234
253
"""
235
254
# copied because it can be modified
236
255
copy_of_routers = list (self .routing_table .routers )
256
+
257
+ has_tried_initial_routers = False
258
+ if self .missing_writer :
259
+ has_tried_initial_routers = True
260
+ if self .update_routing_table_with_routers (resolve (self .initial_address )):
261
+ return
262
+
237
263
if self .update_routing_table_with_routers (copy_of_routers ):
238
264
return
239
265
240
- initial_routers = resolve (self .initial_address )
241
- for router in copy_of_routers :
242
- if router in initial_routers :
243
- initial_routers .remove (router )
244
- if initial_routers :
245
- if self .update_routing_table_with_routers (initial_routers ):
246
- return
266
+ if not has_tried_initial_routers :
267
+ initial_routers = resolve (self .initial_address )
268
+ for router in copy_of_routers :
269
+ if router in initial_routers :
270
+ initial_routers .remove (router )
271
+ if initial_routers :
272
+ if self .update_routing_table_with_routers (initial_routers ):
273
+ return
274
+
247
275
248
276
# None of the routers have been successful, so just fail
249
277
raise ServiceUnavailable ("Unable to retrieve routing information" )
250
278
251
- def refresh_routing_table (self ):
279
+ def ensure_routing_table_is_fresh (self , access_mode ):
252
280
""" Update the routing table if stale.
253
281
254
282
This method performs two freshness checks, before and after acquiring
@@ -261,10 +289,13 @@ def refresh_routing_table(self):
261
289
262
290
:return: `True` if an update was required, `False` otherwise.
263
291
"""
264
- if self .routing_table .is_fresh ():
292
+ if self .routing_table .is_fresh (access_mode ):
265
293
return False
266
294
with self .refresh_lock :
267
- if self .routing_table .is_fresh ():
295
+ if self .routing_table .is_fresh (access_mode ):
296
+ if access_mode == READ_ACCESS :
297
+ # if reader is fresh but writers is not fresh, then we are reading in absence of writer
298
+ self .missing_writer = not self .routing_table .is_fresh (WRITE_ACCESS )
268
299
return False
269
300
self .update_routing_table ()
270
301
return True
@@ -278,18 +309,20 @@ def acquire(self, access_mode=None):
278
309
server_list = self .routing_table .writers
279
310
else :
280
311
raise ValueError ("Unsupported access mode {}" .format (access_mode ))
312
+
313
+ self .ensure_routing_table_is_fresh (access_mode )
281
314
while True :
282
- address = None
283
- while address is None :
284
- self .refresh_routing_table ()
285
- address = next (server_list )
315
+ address = next (server_list )
316
+ if address is None :
317
+ break
286
318
try :
287
319
connection = self .acquire_direct (address ) # should always be a resolved address
288
320
connection .Error = SessionExpired
289
321
except ServiceUnavailable :
290
322
self .remove (address )
291
323
else :
292
324
return connection
325
+ raise SessionExpired ("Failed to obtain connection towards '%s' server." % access_mode )
293
326
294
327
def remove (self , address ):
295
328
""" Remove an address from the connection pool, if present, closing
@@ -313,6 +346,7 @@ def __init__(self, uri, **config):
313
346
self .initial_address = initial_address = SocketAddress .from_uri (uri , DEFAULT_PORT )
314
347
self .security_plan = security_plan = SecurityPlan .build (** config )
315
348
self .encrypted = security_plan .encrypted
349
+ routing_context = SocketAddress .parse_routing_context (uri )
316
350
if not security_plan .routing_compatible :
317
351
# this error message is case-specific as there is only one incompatible
318
352
# scenario right now
@@ -321,7 +355,7 @@ def __init__(self, uri, **config):
321
355
def connector (a ):
322
356
return connect (a , security_plan .ssl_context , ** config )
323
357
324
- pool = RoutingConnectionPool (connector , initial_address , * resolve (initial_address ))
358
+ pool = RoutingConnectionPool (connector , initial_address , routing_context , * resolve (initial_address ))
325
359
try :
326
360
pool .update_routing_table ()
327
361
except :
0 commit comments