Skip to content

Read without writer #158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 28, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions neo4j/addressing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
from neo4j.compat import urlparse
from neo4j.exceptions import AddressError

try:
from urllib.parse import parse_qs
except ImportError:
from urlparse import parse_qs


VALID_IPv4_SEGMENTS = [str(i).encode("latin1") for i in range(0x100)]
VALID_IPv6_SEGMENT_CHARS = b"0123456789abcdef"
Expand Down Expand Up @@ -79,6 +84,24 @@ def parse(cls, string, default_port=0):
"""
return cls.from_uri("//{}".format(string), default_port)

@classmethod
def parse_routing_context(cls, uri):
query = urlparse(uri).query
if not query:
return {}

context = {}
parameters = parse_qs(query, True)
for key in parameters:
value_list = parameters[key]
if len(value_list) != 1:
raise ValueError("Duplicated query parameters with key '%s', value '%s' found in URL '%s'" % (key, value_list, uri))
value = value_list[0]
if not value:
raise ValueError("Invalid parameters:'%s=%s' in URI '%s'." % (key, value, uri))
context[key] = value
return context


def resolve(socket_address):
try:
Expand Down
19 changes: 19 additions & 0 deletions neo4j/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,25 @@
from sys import stdout


class ServerVersion(object):
def __init__(self, product, version_tuple, tags_tuple):
self.product = product
self.version_tuple = version_tuple
self.tags_tuple = tags_tuple

def at_least_version(self, major, minor):
return self.version_tuple >= (major, minor)

@classmethod
def from_str(cls, full_version):
if full_version is None:
return ServerVersion("Neo4j", (3, 0), ())
product, _, tagged_version = full_version.partition("/")
tags = tagged_version.split("-")
version = map(int, tags[0].split("."))
return ServerVersion(product, tuple(version), tuple(tags[1:]))


class ColourFormatter(logging.Formatter):
""" Colour formatter for pretty log output.
"""
Expand Down
2 changes: 2 additions & 0 deletions neo4j/v1/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def __init__(self, uri, **config):
# will carry out DNS resolution, leading to the possibility that
# the connection pool may contain multiple IP address keys, one for
# an old address and one for a new address.
if SocketAddress.parse_routing_context(uri):
raise ValueError("Parameters are not supported with scheme 'bolt'. Given URI: '%s'." % uri)
self.address = SocketAddress.from_uri(uri, DEFAULT_PORT)
self.security_plan = security_plan = SecurityPlan.build(**config)
self.encrypted = security_plan.encrypted
Expand Down
92 changes: 63 additions & 29 deletions neo4j/v1/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect
from neo4j.compat.collections import MutableSet, OrderedDict
from neo4j.exceptions import CypherError
from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS
from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS, fix_statement, fix_parameters
from neo4j.v1.exceptions import SessionExpired
from neo4j.v1.security import SecurityPlan
from neo4j.v1.session import BoltSession
from neo4j.util import ServerVersion


class RoundRobinSet(MutableSet):
Expand Down Expand Up @@ -131,11 +132,12 @@ def __init__(self, routers=(), readers=(), writers=(), ttl=0):
self.last_updated_time = self.timer()
self.ttl = ttl

def is_fresh(self):
def is_fresh(self, access_mode):
""" Indicator for whether routing information is still usable.
"""
expired = self.last_updated_time + self.ttl <= self.timer()
return not expired and len(self.routers) > 1 and self.readers and self.writers
has_server_for_mode = (access_mode == READ_ACCESS and self.readers) or (access_mode == WRITE_ACCESS and self.writers)
return not expired and self.routers and has_server_for_mode

def update(self, new_routing_table):
""" Update the current routing table with new routing information
Expand All @@ -148,16 +150,34 @@ def update(self, new_routing_table):
self.ttl = new_routing_table.ttl


class RoutingSession(BoltSession):

call_get_servers = "CALL dbms.cluster.routing.getServers"
get_routing_table_param = "context"
call_get_routing_table = "CALL dbms.cluster.routing.getRoutingTable({%s})" % get_routing_table_param

def routing_info_procedure(self, routing_context):
if ServerVersion.from_str(self._connection.server.version).at_least_version(3, 2):
return self.call_get_routing_table, {self.get_routing_table_param: routing_context}
else:
return self.call_get_servers, {}

def __run__(self, ignored, routing_context):
# the statement is ignored as it will be get routing table procedure call.
statement, parameters = self.routing_info_procedure(routing_context)
return self._run(fix_statement(statement), fix_parameters(parameters))


class RoutingConnectionPool(ConnectionPool):
""" Connection pool with routing table.
"""

routing_info_procedure = "dbms.cluster.routing.getServers"

def __init__(self, connector, initial_address, *routers):
def __init__(self, connector, initial_address, routing_context, *routers):
super(RoutingConnectionPool, self).__init__(connector)
self.initial_address = initial_address
self.routing_context = routing_context
self.routing_table = RoutingTable(routers)
self.missing_writer = False
self.refresh_lock = Lock()

def fetch_routing_info(self, address):
Expand All @@ -170,8 +190,8 @@ def fetch_routing_info(self, address):
if routing support is broken
"""
try:
with BoltSession(lambda _: self.acquire_direct(address)) as session:
return list(session.run("CALL %s" % self.routing_info_procedure))
with RoutingSession(lambda _: self.acquire_direct(address)) as session:
return list(session.run("ignored", self.routing_context))
except CypherError as error:
if error.code == "Neo.ClientError.Procedure.ProcedureNotFound":
raise ServiceUnavailable("Server {!r} does not support routing".format(address))
Expand Down Expand Up @@ -200,6 +220,11 @@ def fetch_routing_table(self, address):
num_readers = len(new_routing_table.readers)
num_writers = len(new_routing_table.writers)

# No writers are available. This likely indicates a temporary state,
# such as leader switching, so we should not signal an error.
# When no writers available, then we flag we are reading in absence of writer
self.missing_writer = (num_writers == 0)

# No routers
if num_routers == 0:
raise ProtocolError("No routing servers returned from server %r" % (address,))
Expand All @@ -208,12 +233,6 @@ def fetch_routing_table(self, address):
if num_readers == 0:
raise ProtocolError("No read servers returned from server %r" % (address,))

# No writers
if num_writers == 0:
# No writers are available. This likely indicates a temporary state,
# such as leader switching, so we should not signal an error.
return None

# At least one of each is fine, so return this table
return new_routing_table

Expand All @@ -234,21 +253,30 @@ def update_routing_table(self):
"""
# copied because it can be modified
copy_of_routers = list(self.routing_table.routers)

has_tried_initial_routers = False
if self.missing_writer:
has_tried_initial_routers = True
if self.update_routing_table_with_routers(resolve(self.initial_address)):
return

if self.update_routing_table_with_routers(copy_of_routers):
return

initial_routers = resolve(self.initial_address)
for router in copy_of_routers:
if router in initial_routers:
initial_routers.remove(router)
if initial_routers:
if self.update_routing_table_with_routers(initial_routers):
return
if not has_tried_initial_routers:
initial_routers = resolve(self.initial_address)
for router in copy_of_routers:
if router in initial_routers:
initial_routers.remove(router)
if initial_routers:
if self.update_routing_table_with_routers(initial_routers):
return


# None of the routers have been successful, so just fail
raise ServiceUnavailable("Unable to retrieve routing information")

def refresh_routing_table(self):
def ensure_routing_table_is_fresh(self, access_mode):
""" Update the routing table if stale.

This method performs two freshness checks, before and after acquiring
Expand All @@ -261,10 +289,13 @@ def refresh_routing_table(self):

:return: `True` if an update was required, `False` otherwise.
"""
if self.routing_table.is_fresh():
if self.routing_table.is_fresh(access_mode):
return False
with self.refresh_lock:
if self.routing_table.is_fresh():
if self.routing_table.is_fresh(access_mode):
if access_mode == READ_ACCESS:
# if reader is fresh but writers is not fresh, then we are reading in absence of writer
self.missing_writer = not self.routing_table.is_fresh(WRITE_ACCESS)
return False
self.update_routing_table()
return True
Expand All @@ -278,18 +309,20 @@ def acquire(self, access_mode=None):
server_list = self.routing_table.writers
else:
raise ValueError("Unsupported access mode {}".format(access_mode))

self.ensure_routing_table_is_fresh(access_mode)
while True:
address = None
while address is None:
self.refresh_routing_table()
address = next(server_list)
address = next(server_list)
if address is None:
break
try:
connection = self.acquire_direct(address) # should always be a resolved address
connection.Error = SessionExpired
except ServiceUnavailable:
self.remove(address)
else:
return connection
raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode)

def remove(self, address):
""" Remove an address from the connection pool, if present, closing
Expand All @@ -313,6 +346,7 @@ def __init__(self, uri, **config):
self.initial_address = initial_address = SocketAddress.from_uri(uri, DEFAULT_PORT)
self.security_plan = security_plan = SecurityPlan.build(**config)
self.encrypted = security_plan.encrypted
routing_context = SocketAddress.parse_routing_context(uri)
if not security_plan.routing_compatible:
# this error message is case-specific as there is only one incompatible
# scenario right now
Expand All @@ -321,7 +355,7 @@ def __init__(self, uri, **config):
def connector(a):
return connect(a, security_plan.ssl_context, **config)

pool = RoutingConnectionPool(connector, initial_address, *resolve(initial_address))
pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address))
try:
pool.update_routing_table()
except:
Expand Down
5 changes: 4 additions & 1 deletion neo4j/v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class BoltSession(Session):
:param bookmark:
"""

def __run__(self, statement, parameters):
def _run(self, statement, parameters):
assert isinstance(statement, unicode)
assert isinstance(parameters, dict)

Expand All @@ -52,6 +52,9 @@ def __run__(self, statement, parameters):

return result

def __run__(self, statement, parameters):
return self._run(statement, parameters)

def __begin__(self):
return self.__run__(u"BEGIN", {"bookmark": self._bookmark} if self._bookmark else {})

Expand Down
11 changes: 3 additions & 8 deletions test/integration/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from boltkit.controller import WindowsController, UnixController

from neo4j.v1 import GraphDatabase, AuthError
from neo4j.util import ServerVersion

from test.env import NEO4J_SERVER_PACKAGE, NEO4J_USER, NEO4J_PASSWORD

Expand Down Expand Up @@ -89,17 +90,11 @@ def server_version_info(cls):
with GraphDatabase.driver(cls.bolt_uri, auth=cls.auth_token) as driver:
with driver.session() as session:
full_version = session.run("RETURN 1").summary().server.version
if full_version is None:
return "Neo4j", (3, 0), ()
product, _, tagged_version = full_version.partition("/")
tags = tagged_version.split("-")
version = map(int, tags[0].split("."))
return product, tuple(version), tuple(tags[1:])
return ServerVersion.from_str(full_version)

@classmethod
def at_least_version(cls, major, minor):
_, server_version, _ = cls.server_version_info()
return server_version >= (major, minor)
return cls.server_version_info().at_least_version(major, minor);

@classmethod
def delete_known_hosts_file(cls):
Expand Down
9 changes: 9 additions & 0 deletions test/stub/scripts/get_routing_table.script
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
!: AUTO INIT
!: AUTO RESET

S: SUCCESS {"server": "Neo4j/3.2.2"}
C: RUN "CALL dbms.cluster.routing.getRoutingTable({context})" {"context": {}}
PULL_ALL
S: SUCCESS {"fields": ["ttl", "servers"]}
RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]]
SUCCESS {}
9 changes: 9 additions & 0 deletions test/stub/scripts/get_routing_table_with_context.script
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
!: AUTO INIT
!: AUTO RESET

S: SUCCESS {"server": "Neo4j/3.2.3"}
C: RUN "CALL dbms.cluster.routing.getRoutingTable({context})" {"context": {"name": "molly", "age": "1"}}
PULL_ALL
S: SUCCESS {"fields": ["ttl", "servers"]}
RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]]
SUCCESS {}
6 changes: 6 additions & 0 deletions test/stub/test_directdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,9 @@ def test_direct_disconnect_on_pull_all(self):
with self.assertRaises(ServiceUnavailable):
with driver.session() as session:
session.run("RETURN $x", {"x": 1}).consume()

def test_direct_should_reject_routing_context(self):
uri = "bolt://127.0.0.1:9001/?name=molly&age=1"
with self.assertRaises(ValueError):
GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False)

Loading