Skip to content

Commit ed27f5e

Browse files
authored
Merge pull request #158 from zhenlineo/1.3-read-without-writer
Read without writer
2 parents b553e28 + 3cb094a commit ed27f5e

13 files changed

+338
-104
lines changed

neo4j/addressing.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525
from neo4j.compat import urlparse
2626
from neo4j.exceptions import AddressError
2727

28+
try:
29+
from urllib.parse import parse_qs
30+
except ImportError:
31+
from urlparse import parse_qs
32+
2833

2934
VALID_IPv4_SEGMENTS = [str(i).encode("latin1") for i in range(0x100)]
3035
VALID_IPv6_SEGMENT_CHARS = b"0123456789abcdef"
@@ -79,6 +84,24 @@ def parse(cls, string, default_port=0):
7984
"""
8085
return cls.from_uri("//{}".format(string), default_port)
8186

87+
@classmethod
88+
def parse_routing_context(cls, uri):
89+
query = urlparse(uri).query
90+
if not query:
91+
return {}
92+
93+
context = {}
94+
parameters = parse_qs(query, True)
95+
for key in parameters:
96+
value_list = parameters[key]
97+
if len(value_list) != 1:
98+
raise ValueError("Duplicated query parameters with key '%s', value '%s' found in URL '%s'" % (key, value_list, uri))
99+
value = value_list[0]
100+
if not value:
101+
raise ValueError("Invalid parameters:'%s=%s' in URI '%s'." % (key, value, uri))
102+
context[key] = value
103+
return context
104+
82105

83106
def resolve(socket_address):
84107
try:

neo4j/util.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,25 @@
2525
from sys import stdout
2626

2727

28+
class ServerVersion(object):
29+
def __init__(self, product, version_tuple, tags_tuple):
30+
self.product = product
31+
self.version_tuple = version_tuple
32+
self.tags_tuple = tags_tuple
33+
34+
def at_least_version(self, major, minor):
35+
return self.version_tuple >= (major, minor)
36+
37+
@classmethod
38+
def from_str(cls, full_version):
39+
if full_version is None:
40+
return ServerVersion("Neo4j", (3, 0), ())
41+
product, _, tagged_version = full_version.partition("/")
42+
tags = tagged_version.split("-")
43+
version = map(int, tags[0].split("."))
44+
return ServerVersion(product, tuple(version), tuple(tags[1:]))
45+
46+
2847
class ColourFormatter(logging.Formatter):
2948
""" Colour formatter for pretty log output.
3049
"""

neo4j/v1/direct.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def __init__(self, uri, **config):
5656
# will carry out DNS resolution, leading to the possibility that
5757
# the connection pool may contain multiple IP address keys, one for
5858
# an old address and one for a new address.
59+
if SocketAddress.parse_routing_context(uri):
60+
raise ValueError("Parameters are not supported with scheme 'bolt'. Given URI: '%s'." % uri)
5961
self.address = SocketAddress.from_uri(uri, DEFAULT_PORT)
6062
self.security_plan = security_plan = SecurityPlan.build(**config)
6163
self.encrypted = security_plan.encrypted

neo4j/v1/routing.py

Lines changed: 63 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@
2626
from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect
2727
from neo4j.compat.collections import MutableSet, OrderedDict
2828
from neo4j.exceptions import CypherError
29-
from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS
29+
from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS, fix_statement, fix_parameters
3030
from neo4j.v1.exceptions import SessionExpired
3131
from neo4j.v1.security import SecurityPlan
3232
from neo4j.v1.session import BoltSession
33+
from neo4j.util import ServerVersion
3334

3435

3536
class RoundRobinSet(MutableSet):
@@ -131,11 +132,12 @@ def __init__(self, routers=(), readers=(), writers=(), ttl=0):
131132
self.last_updated_time = self.timer()
132133
self.ttl = ttl
133134

134-
def is_fresh(self):
135+
def is_fresh(self, access_mode):
135136
""" Indicator for whether routing information is still usable.
136137
"""
137138
expired = self.last_updated_time + self.ttl <= self.timer()
138-
return not expired and len(self.routers) > 1 and self.readers and self.writers
139+
has_server_for_mode = (access_mode == READ_ACCESS and self.readers) or (access_mode == WRITE_ACCESS and self.writers)
140+
return not expired and self.routers and has_server_for_mode
139141

140142
def update(self, new_routing_table):
141143
""" Update the current routing table with new routing information
@@ -148,16 +150,34 @@ def update(self, new_routing_table):
148150
self.ttl = new_routing_table.ttl
149151

150152

153+
class RoutingSession(BoltSession):
154+
155+
call_get_servers = "CALL dbms.cluster.routing.getServers"
156+
get_routing_table_param = "context"
157+
call_get_routing_table = "CALL dbms.cluster.routing.getRoutingTable({%s})" % get_routing_table_param
158+
159+
def routing_info_procedure(self, routing_context):
160+
if ServerVersion.from_str(self._connection.server.version).at_least_version(3, 2):
161+
return self.call_get_routing_table, {self.get_routing_table_param: routing_context}
162+
else:
163+
return self.call_get_servers, {}
164+
165+
def __run__(self, ignored, routing_context):
166+
# the statement is ignored as it will be get routing table procedure call.
167+
statement, parameters = self.routing_info_procedure(routing_context)
168+
return self._run(fix_statement(statement), fix_parameters(parameters))
169+
170+
151171
class RoutingConnectionPool(ConnectionPool):
152172
""" Connection pool with routing table.
153173
"""
154174

155-
routing_info_procedure = "dbms.cluster.routing.getServers"
156-
157-
def __init__(self, connector, initial_address, *routers):
175+
def __init__(self, connector, initial_address, routing_context, *routers):
158176
super(RoutingConnectionPool, self).__init__(connector)
159177
self.initial_address = initial_address
178+
self.routing_context = routing_context
160179
self.routing_table = RoutingTable(routers)
180+
self.missing_writer = False
161181
self.refresh_lock = Lock()
162182

163183
def fetch_routing_info(self, address):
@@ -170,8 +190,8 @@ def fetch_routing_info(self, address):
170190
if routing support is broken
171191
"""
172192
try:
173-
with BoltSession(lambda _: self.acquire_direct(address)) as session:
174-
return list(session.run("CALL %s" % self.routing_info_procedure))
193+
with RoutingSession(lambda _: self.acquire_direct(address)) as session:
194+
return list(session.run("ignored", self.routing_context))
175195
except CypherError as error:
176196
if error.code == "Neo.ClientError.Procedure.ProcedureNotFound":
177197
raise ServiceUnavailable("Server {!r} does not support routing".format(address))
@@ -200,6 +220,11 @@ def fetch_routing_table(self, address):
200220
num_readers = len(new_routing_table.readers)
201221
num_writers = len(new_routing_table.writers)
202222

223+
# No writers are available. This likely indicates a temporary state,
224+
# such as leader switching, so we should not signal an error.
225+
# When no writers available, then we flag we are reading in absence of writer
226+
self.missing_writer = (num_writers == 0)
227+
203228
# No routers
204229
if num_routers == 0:
205230
raise ProtocolError("No routing servers returned from server %r" % (address,))
@@ -208,12 +233,6 @@ def fetch_routing_table(self, address):
208233
if num_readers == 0:
209234
raise ProtocolError("No read servers returned from server %r" % (address,))
210235

211-
# No writers
212-
if num_writers == 0:
213-
# No writers are available. This likely indicates a temporary state,
214-
# such as leader switching, so we should not signal an error.
215-
return None
216-
217236
# At least one of each is fine, so return this table
218237
return new_routing_table
219238

@@ -234,21 +253,30 @@ def update_routing_table(self):
234253
"""
235254
# copied because it can be modified
236255
copy_of_routers = list(self.routing_table.routers)
256+
257+
has_tried_initial_routers = False
258+
if self.missing_writer:
259+
has_tried_initial_routers = True
260+
if self.update_routing_table_with_routers(resolve(self.initial_address)):
261+
return
262+
237263
if self.update_routing_table_with_routers(copy_of_routers):
238264
return
239265

240-
initial_routers = resolve(self.initial_address)
241-
for router in copy_of_routers:
242-
if router in initial_routers:
243-
initial_routers.remove(router)
244-
if initial_routers:
245-
if self.update_routing_table_with_routers(initial_routers):
246-
return
266+
if not has_tried_initial_routers:
267+
initial_routers = resolve(self.initial_address)
268+
for router in copy_of_routers:
269+
if router in initial_routers:
270+
initial_routers.remove(router)
271+
if initial_routers:
272+
if self.update_routing_table_with_routers(initial_routers):
273+
return
274+
247275

248276
# None of the routers have been successful, so just fail
249277
raise ServiceUnavailable("Unable to retrieve routing information")
250278

251-
def refresh_routing_table(self):
279+
def ensure_routing_table_is_fresh(self, access_mode):
252280
""" Update the routing table if stale.
253281
254282
This method performs two freshness checks, before and after acquiring
@@ -261,10 +289,13 @@ def refresh_routing_table(self):
261289
262290
:return: `True` if an update was required, `False` otherwise.
263291
"""
264-
if self.routing_table.is_fresh():
292+
if self.routing_table.is_fresh(access_mode):
265293
return False
266294
with self.refresh_lock:
267-
if self.routing_table.is_fresh():
295+
if self.routing_table.is_fresh(access_mode):
296+
if access_mode == READ_ACCESS:
297+
# if reader is fresh but writers is not fresh, then we are reading in absence of writer
298+
self.missing_writer = not self.routing_table.is_fresh(WRITE_ACCESS)
268299
return False
269300
self.update_routing_table()
270301
return True
@@ -278,18 +309,20 @@ def acquire(self, access_mode=None):
278309
server_list = self.routing_table.writers
279310
else:
280311
raise ValueError("Unsupported access mode {}".format(access_mode))
312+
313+
self.ensure_routing_table_is_fresh(access_mode)
281314
while True:
282-
address = None
283-
while address is None:
284-
self.refresh_routing_table()
285-
address = next(server_list)
315+
address = next(server_list)
316+
if address is None:
317+
break
286318
try:
287319
connection = self.acquire_direct(address) # should always be a resolved address
288320
connection.Error = SessionExpired
289321
except ServiceUnavailable:
290322
self.remove(address)
291323
else:
292324
return connection
325+
raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode)
293326

294327
def remove(self, address):
295328
""" Remove an address from the connection pool, if present, closing
@@ -313,6 +346,7 @@ def __init__(self, uri, **config):
313346
self.initial_address = initial_address = SocketAddress.from_uri(uri, DEFAULT_PORT)
314347
self.security_plan = security_plan = SecurityPlan.build(**config)
315348
self.encrypted = security_plan.encrypted
349+
routing_context = SocketAddress.parse_routing_context(uri)
316350
if not security_plan.routing_compatible:
317351
# this error message is case-specific as there is only one incompatible
318352
# scenario right now
@@ -321,7 +355,7 @@ def __init__(self, uri, **config):
321355
def connector(a):
322356
return connect(a, security_plan.ssl_context, **config)
323357

324-
pool = RoutingConnectionPool(connector, initial_address, *resolve(initial_address))
358+
pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address))
325359
try:
326360
pool.update_routing_table()
327361
except:

neo4j/v1/session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class BoltSession(Session):
3434
:param bookmark:
3535
"""
3636

37-
def __run__(self, statement, parameters):
37+
def _run(self, statement, parameters):
3838
assert isinstance(statement, unicode)
3939
assert isinstance(parameters, dict)
4040

@@ -52,6 +52,9 @@ def __run__(self, statement, parameters):
5252

5353
return result
5454

55+
def __run__(self, statement, parameters):
56+
return self._run(statement, parameters)
57+
5558
def __begin__(self):
5659
return self.__run__(u"BEGIN", {"bookmark": self._bookmark} if self._bookmark else {})
5760

test/integration/tools.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from boltkit.controller import WindowsController, UnixController
3434

3535
from neo4j.v1 import GraphDatabase, AuthError
36+
from neo4j.util import ServerVersion
3637

3738
from test.env import NEO4J_SERVER_PACKAGE, NEO4J_USER, NEO4J_PASSWORD
3839

@@ -89,17 +90,11 @@ def server_version_info(cls):
8990
with GraphDatabase.driver(cls.bolt_uri, auth=cls.auth_token) as driver:
9091
with driver.session() as session:
9192
full_version = session.run("RETURN 1").summary().server.version
92-
if full_version is None:
93-
return "Neo4j", (3, 0), ()
94-
product, _, tagged_version = full_version.partition("/")
95-
tags = tagged_version.split("-")
96-
version = map(int, tags[0].split("."))
97-
return product, tuple(version), tuple(tags[1:])
93+
return ServerVersion.from_str(full_version)
9894

9995
@classmethod
10096
def at_least_version(cls, major, minor):
101-
_, server_version, _ = cls.server_version_info()
102-
return server_version >= (major, minor)
97+
return cls.server_version_info().at_least_version(major, minor);
10398

10499
@classmethod
105100
def delete_known_hosts_file(cls):
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
!: AUTO INIT
2+
!: AUTO RESET
3+
4+
S: SUCCESS {"server": "Neo4j/3.2.2"}
5+
C: RUN "CALL dbms.cluster.routing.getRoutingTable({context})" {"context": {}}
6+
PULL_ALL
7+
S: SUCCESS {"fields": ["ttl", "servers"]}
8+
RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]]
9+
SUCCESS {}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
!: AUTO INIT
2+
!: AUTO RESET
3+
4+
S: SUCCESS {"server": "Neo4j/3.2.3"}
5+
C: RUN "CALL dbms.cluster.routing.getRoutingTable({context})" {"context": {"name": "molly", "age": "1"}}
6+
PULL_ALL
7+
S: SUCCESS {"fields": ["ttl", "servers"]}
8+
RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]]
9+
SUCCESS {}

test/stub/test_directdriver.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,9 @@ def test_direct_disconnect_on_pull_all(self):
4848
with self.assertRaises(ServiceUnavailable):
4949
with driver.session() as session:
5050
session.run("RETURN $x", {"x": 1}).consume()
51+
52+
def test_direct_should_reject_routing_context(self):
53+
uri = "bolt://127.0.0.1:9001/?name=molly&age=1"
54+
with self.assertRaises(ValueError):
55+
GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False)
56+

0 commit comments

Comments
 (0)