Skip to content

Routing context #157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions neo4j/addressing.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,25 @@ def parse(cls, string, default_port=0):
"""
return cls.from_uri("//{}".format(string), default_port)

@classmethod
def parse_routing_context(cls, uri):
query = urlparse(uri).query
if not query:
return {}

context = {}
parameters = [x for x in query.split('&') if x]
for keyValue in parameters:
pair = keyValue.split('=')
if len(pair) != 2 or not pair[0] or not pair[1]:
raise ValueError("Invalid parameters: '" + keyValue + "' in URI '" + uri + "'.")
key = pair[0]
value = pair[1]
if key in context:
raise ValueError("Duplicated query parameters with key '" + key + "' found in URL '" + uri + "'")
context[key] = value
return context


def resolve(socket_address):
try:
Expand Down
19 changes: 19 additions & 0 deletions neo4j/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,25 @@
from sys import stdout


class ServerVersion(object):
def __init__(self, product, version_tuple, tags_tuple):
self.product = product
self.version_tuple = version_tuple
self.tags_tuple = tags_tuple

def at_least_version(self, major, minor):
return self.version_tuple >= (major, minor)

@classmethod
def from_str(cls, full_version):
if full_version is None:
return ServerVersion("Neo4j", (3, 0), ())
product, _, tagged_version = full_version.partition("/")
tags = tagged_version.split("-")
version = map(int, tags[0].split("."))
return ServerVersion(product, tuple(version), tuple(tags[1:]))


class ColourFormatter(logging.Formatter):
""" Colour formatter for pretty log output.
"""
Expand Down
2 changes: 2 additions & 0 deletions neo4j/v1/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def __init__(self, uri, **config):
# will carry out DNS resolution, leading to the possibility that
# the connection pool may contain multiple IP address keys, one for
# an old address and one for a new address.
if SocketAddress.parse_routing_context(uri):
raise ValueError("Routing parameters are not supported with scheme 'bolt'. Given URI: '" + uri + "'.")
self.address = SocketAddress.from_uri(uri, DEFAULT_PORT)
self.security_plan = security_plan = SecurityPlan.build(**config)
self.encrypted = security_plan.encrypted
Expand Down
45 changes: 37 additions & 8 deletions neo4j/v1/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from neo4j.v1.exceptions import SessionExpired
from neo4j.v1.security import SecurityPlan
from neo4j.v1.session import BoltSession
from neo4j.util import ServerVersion


class RoundRobinSet(MutableSet):
Expand Down Expand Up @@ -152,13 +153,23 @@ class RoutingConnectionPool(ConnectionPool):
""" Connection pool with routing table.
"""

routing_info_procedure = "dbms.cluster.routing.getServers"
call_get_servers = "CALL dbms.cluster.routing.getServers"
get_routing_table_param = "context"
call_get_routing_table = "CALL dbms.cluster.routing.getRoutingTable({" + get_routing_table_param + "})"

def __init__(self, connector, *routers):
def __init__(self, connector, initial_address, routing_context, *routers):
super(RoutingConnectionPool, self).__init__(connector)
self.initial_address = initial_address
self.routing_context = routing_context
self.routing_table = RoutingTable(routers)
self.refresh_lock = Lock()

def routing_info_procedure(self, connection):
if ServerVersion.from_str(connection.server.version).at_least_version(3, 2):
return self.call_get_routing_table, {self.get_routing_table_param: self.routing_context}
else:
return self.call_get_servers, None

def fetch_routing_info(self, address):
""" Fetch raw routing info from a given router address.

Expand All @@ -169,8 +180,9 @@ def fetch_routing_info(self, address):
if routing support is broken
"""
try:
with BoltSession(lambda _: self.acquire_direct(address)) as session:
return list(session.run("CALL %s" % self.routing_info_procedure))
connection = self.acquire_direct(address)
with BoltSession(lambda _: connection) as session:
return list(session.run(*self.routing_info_procedure(connection)))
except CypherError as error:
if error.code == "Neo.ClientError.Procedure.ProcedureNotFound":
raise ServiceUnavailable("Server {!r} does not support routing".format(address))
Expand Down Expand Up @@ -216,16 +228,32 @@ def fetch_routing_table(self, address):
# At least one of each is fine, so return this table
return new_routing_table

def update_routing_table_with_routers(self, routers):
"""Try to update routing tables with the given routers
:return: True if the routing table is successfully updated, otherwise False
"""
for router in routers:
new_routing_table = self.fetch_routing_table(router)
if new_routing_table is not None:
self.routing_table.update(new_routing_table)
return True
return False

def update_routing_table(self):
""" Update the routing table from the first router able to provide
valid routing information.
"""
# copied because it can be modified
copy_of_routers = list(self.routing_table.routers)
if self.update_routing_table_with_routers(copy_of_routers):
return

initial_routers = resolve(self.initial_address)
for router in copy_of_routers:
new_routing_table = self.fetch_routing_table(router)
if new_routing_table is not None:
self.routing_table.update(new_routing_table)
if router in initial_routers:
initial_routers.remove(router)
if len(initial_routers) != 0:
if self.update_routing_table_with_routers(initial_routers):
return

# None of the routers have been successful, so just fail
Expand Down Expand Up @@ -296,6 +324,7 @@ def __init__(self, uri, **config):
self.initial_address = initial_address = SocketAddress.from_uri(uri, DEFAULT_PORT)
self.security_plan = security_plan = SecurityPlan.build(**config)
self.encrypted = security_plan.encrypted
routing_context = SocketAddress.parse_routing_context(uri)
if not security_plan.routing_compatible:
# this error message is case-specific as there is only one incompatible
# scenario right now
Expand All @@ -304,7 +333,7 @@ def __init__(self, uri, **config):
def connector(a):
return connect(a, security_plan.ssl_context, **config)

pool = RoutingConnectionPool(connector, *resolve(initial_address))
pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address))
try:
pool.update_routing_table()
except:
Expand Down
11 changes: 3 additions & 8 deletions test/integration/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from boltkit.controller import WindowsController, UnixController

from neo4j.v1 import GraphDatabase, AuthError
from neo4j.util import ServerVersion

from test.env import NEO4J_SERVER_PACKAGE, NEO4J_USER, NEO4J_PASSWORD

Expand Down Expand Up @@ -89,17 +90,11 @@ def server_version_info(cls):
with GraphDatabase.driver(cls.bolt_uri, auth=cls.auth_token) as driver:
with driver.session() as session:
full_version = session.run("RETURN 1").summary().server.version
if full_version is None:
return "Neo4j", (3, 0), ()
product, _, tagged_version = full_version.partition("/")
tags = tagged_version.split("-")
version = map(int, tags[0].split("."))
return product, tuple(version), tuple(tags[1:])
return ServerVersion.from_str(full_version)

@classmethod
def at_least_version(cls, major, minor):
_, server_version, _ = cls.server_version_info()
return server_version >= (major, minor)
return cls.server_version_info().at_least_version(major, minor);

@classmethod
def delete_known_hosts_file(cls):
Expand Down
9 changes: 9 additions & 0 deletions test/stub/scripts/get_routing_table.script
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
!: AUTO INIT
!: AUTO RESET

S: SUCCESS {"server": "Neo4j/3.2.2"}
C: RUN "CALL dbms.cluster.routing.getRoutingTable({context})" {"context": {}}
PULL_ALL
S: SUCCESS {"fields": ["ttl", "servers"]}
RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]]
SUCCESS {}
9 changes: 9 additions & 0 deletions test/stub/scripts/get_routing_table_with_context.script
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
!: AUTO INIT
!: AUTO RESET

S: SUCCESS {"server": "Neo4j/3.2.3"}
C: RUN "CALL dbms.cluster.routing.getRoutingTable({context})" {"context": {"name": "molly", "age": "1"}}
PULL_ALL
S: SUCCESS {"fields": ["ttl", "servers"]}
RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]]
SUCCESS {}
6 changes: 6 additions & 0 deletions test/stub/test_directdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,9 @@ def test_direct_disconnect_on_pull_all(self):
with self.assertRaises(ServiceUnavailable):
with driver.session() as session:
session.run("RETURN $x", {"x": 1}).consume()

def test_direct_should_reject_routing_context(self):
uri = "bolt://127.0.0.1:9001/?name=molly&age=1"
with self.assertRaises(ValueError):
GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False)

Loading