Skip to content

Commit 0e3d67e

Browse files
author
Zhen
committed
Fix after review
1 parent 9d8be07 commit 0e3d67e

File tree

4 files changed

+35
-25
lines changed

4 files changed

+35
-25
lines changed

neo4j/addressing.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525
from neo4j.compat import urlparse
2626
from neo4j.exceptions import AddressError
2727

28+
try:
29+
from urllib.parse import parse_qs
30+
except ImportError:
31+
from urllib import parse_qs
32+
2833

2934
VALID_IPv4_SEGMENTS = [str(i).encode("latin1") for i in range(0x100)]
3035
VALID_IPv6_SEGMENT_CHARS = b"0123456789abcdef"
@@ -86,15 +91,14 @@ def parse_routing_context(cls, uri):
8691
return {}
8792

8893
context = {}
89-
parameters = [x for x in query.split('&') if x]
90-
for keyValue in parameters:
91-
pair = keyValue.split('=')
92-
if len(pair) != 2 or not pair[0] or not pair[1]:
93-
raise ValueError("Invalid parameters: '%s' in URI '%s'." % (keyValue, uri))
94-
key = pair[0]
95-
value = pair[1]
96-
if key in context:
97-
raise ValueError("Duplicated query parameters with key '%s' found in URL '%s'" % (key, uri))
94+
parameters = parse_qs(query, True)
95+
for key in parameters:
96+
value_list = parameters[key]
97+
if len(value_list) != 1:
98+
raise ValueError("Duplicated query parameters with key '%s', value '%s' found in URL '%s'" % (key, value_list, uri))
99+
value = value_list[0]
100+
if not value:
101+
raise ValueError("Invalid parameters: key '%s', value '%s' in URI '%s'." % (key, parameters[key], uri))
98102
context[key] = value
99103
return context
100104

neo4j/v1/direct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(self, uri, **config):
5757
# the connection pool may contain multiple IP address keys, one for
5858
# an old address and one for a new address.
5959
if SocketAddress.parse_routing_context(uri):
60-
raise ValueError("Routing parameters are not supported with scheme 'bolt'. Given URI: '%s'." % uri)
60+
raise ValueError("Parameters are not supported with scheme 'bolt'. Given URI: '%s'." % uri)
6161
self.address = SocketAddress.from_uri(uri, DEFAULT_PORT)
6262
self.security_plan = security_plan = SecurityPlan.build(**config)
6363
self.encrypted = security_plan.encrypted

neo4j/v1/routing.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def is_fresh(self, access_mode):
137137
"""
138138
expired = self.last_updated_time + self.ttl <= self.timer()
139139
has_server_for_mode = (access_mode == READ_ACCESS and self.readers) or (access_mode == WRITE_ACCESS and self.writers)
140-
return not expired and len(self.routers) >= 1 and has_server_for_mode
140+
return not expired and self.routers and has_server_for_mode
141141

142142
def update(self, new_routing_table):
143143
""" Update the current routing table with new routing information
@@ -170,7 +170,7 @@ def routing_info_procedure(self, connection):
170170
if ServerVersion.from_str(connection.server.version).at_least_version(3, 2):
171171
return self.call_get_routing_table, {self.get_routing_table_param: self.routing_context}
172172
else:
173-
return self.call_get_servers
173+
return self.call_get_servers, {}
174174

175175
def fetch_routing_info(self, address):
176176
""" Fetch raw routing info from a given router address.
@@ -182,9 +182,15 @@ def fetch_routing_info(self, address):
182182
if routing support is broken
183183
"""
184184
try:
185-
connection = self.acquire_direct(address)
186-
with BoltSession(lambda _: connection) as session:
187-
return list(session.run(*self.routing_info_procedure(connection)))
185+
connections = [None]
186+
187+
def connector(_):
188+
connection = self.acquire_direct(address)
189+
connections[0] = connection
190+
return connection
191+
192+
with BoltSession(lambda _: connector) as session:
193+
return list(session.run(*self.routing_info_procedure(connections[0])))
188194
except CypherError as error:
189195
if error.code == "Neo.ClientError.Procedure.ProcedureNotFound":
190196
raise ServiceUnavailable("Server {!r} does not support routing".format(address))
@@ -269,7 +275,7 @@ def update_routing_table(self):
269275
# None of the routers have been successful, so just fail
270276
raise ServiceUnavailable("Unable to retrieve routing information")
271277

272-
def ensure_routing_table(self, access_mode):
278+
def ensure_routing_table_is_fresh(self, access_mode):
273279
""" Update the routing table if stale.
274280
275281
This method performs two freshness checks, before and after acquiring
@@ -303,7 +309,7 @@ def acquire(self, access_mode=None):
303309
else:
304310
raise ValueError("Unsupported access mode {}".format(access_mode))
305311

306-
self.ensure_routing_table(access_mode)
312+
self.ensure_routing_table_is_fresh(access_mode)
307313
while True:
308314
address = next(server_list)
309315
if address is None:

test/stub/test_routing.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def test_should_update_if_stale(self):
278278
with RoutingPool(address) as pool:
279279
first_updated_time = pool.routing_table.last_updated_time
280280
pool.routing_table.ttl = 0
281-
pool.ensure_routing_table(WRITE_ACCESS)
281+
pool.ensure_routing_table_is_fresh(WRITE_ACCESS)
282282
second_updated_time = pool.routing_table.last_updated_time
283283
assert second_updated_time != first_updated_time
284284
assert not pool.missing_writer
@@ -287,9 +287,9 @@ def test_should_not_update_if_fresh(self):
287287
with StubCluster({9001: "router.script"}):
288288
address = ("127.0.0.1", 9001)
289289
with RoutingPool(address) as pool:
290-
pool.ensure_routing_table(WRITE_ACCESS)
290+
pool.ensure_routing_table_is_fresh(WRITE_ACCESS)
291291
first_updated_time = pool.routing_table.last_updated_time
292-
pool.ensure_routing_table(WRITE_ACCESS)
292+
pool.ensure_routing_table_is_fresh(WRITE_ACCESS)
293293
second_updated_time = pool.routing_table.last_updated_time
294294
assert second_updated_time == first_updated_time
295295
assert not pool.missing_writer
@@ -300,7 +300,7 @@ def test_should_flag_reading_without_writer(self):
300300
with RoutingPool(address) as pool:
301301
assert not pool.routing_table.is_fresh(READ_ACCESS)
302302
assert not pool.routing_table.is_fresh(WRITE_ACCESS)
303-
pool.ensure_routing_table(READ_ACCESS)
303+
pool.ensure_routing_table_is_fresh(READ_ACCESS)
304304
assert pool.missing_writer
305305

306306
# TODO: fix flaky test
@@ -486,7 +486,7 @@ def test_should_remove_router_from_routing_table_if_present(self):
486486
with StubCluster({9001: "router.script"}):
487487
address = ("127.0.0.1", 9001)
488488
with RoutingPool(address) as pool:
489-
pool.ensure_routing_table(WRITE_ACCESS)
489+
pool.ensure_routing_table_is_fresh(WRITE_ACCESS)
490490
target = ("127.0.0.1", 9001)
491491
assert target in pool.routing_table.routers
492492
pool.remove(target)
@@ -496,7 +496,7 @@ def test_should_remove_reader_from_routing_table_if_present(self):
496496
with StubCluster({9001: "router.script"}):
497497
address = ("127.0.0.1", 9001)
498498
with RoutingPool(address) as pool:
499-
pool.ensure_routing_table(WRITE_ACCESS)
499+
pool.ensure_routing_table_is_fresh(WRITE_ACCESS)
500500
target = ("127.0.0.1", 9004)
501501
assert target in pool.routing_table.readers
502502
pool.remove(target)
@@ -506,7 +506,7 @@ def test_should_remove_writer_from_routing_table_if_present(self):
506506
with StubCluster({9001: "router.script"}):
507507
address = ("127.0.0.1", 9001)
508508
with RoutingPool(address) as pool:
509-
pool.ensure_routing_table(WRITE_ACCESS)
509+
pool.ensure_routing_table_is_fresh(WRITE_ACCESS)
510510
target = ("127.0.0.1", 9006)
511511
assert target in pool.routing_table.writers
512512
pool.remove(target)
@@ -516,6 +516,6 @@ def test_should_not_fail_if_absent(self):
516516
with StubCluster({9001: "router.script"}):
517517
address = ("127.0.0.1", 9001)
518518
with RoutingPool(address) as pool:
519-
pool.ensure_routing_table(WRITE_ACCESS)
519+
pool.ensure_routing_table_is_fresh(WRITE_ACCESS)
520520
target = ("127.0.0.1", 9007)
521521
pool.remove(target)

0 commit comments

Comments
 (0)