From ea40da8fdf32800d45bee87245d8edacbec1ef9e Mon Sep 17 00:00:00 2001 From: Zhen Date: Mon, 24 Apr 2017 14:25:20 +0200 Subject: [PATCH 1/3] Roll back to initial server if running out of routers --- neo4j/v1/routing.py | 27 +++++++++++--- test/stub/test_routing.py | 74 +++++++++++++++++++++++---------------- test/unit/test_routing.py | 6 ++-- 3 files changed, 70 insertions(+), 37 deletions(-) diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index 73b9fcfa7..7540d5fff 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -154,8 +154,9 @@ class RoutingConnectionPool(ConnectionPool): routing_info_procedure = "dbms.cluster.routing.getServers" - def __init__(self, connector, *routers): + def __init__(self, connector, initial_address, *routers): super(RoutingConnectionPool, self).__init__(connector) + self.initial_address = initial_address self.routing_table = RoutingTable(routers) self.refresh_lock = Lock() @@ -216,16 +217,32 @@ def fetch_routing_table(self, address): # At least one of each is fine, so return this table return new_routing_table + def update_routing_table_with_routers(self, routers): + """Try to update routing tables with the given routers + :return: True if the routing table is successfully updated, otherwise False + """ + for router in routers: + new_routing_table = self.fetch_routing_table(router) + if new_routing_table is not None: + self.routing_table.update(new_routing_table) + return True + return False + def update_routing_table(self): """ Update the routing table from the first router able to provide valid routing information. """ # copied because it can be modified copy_of_routers = list(self.routing_table.routers) + if self.update_routing_table_with_routers(copy_of_routers): + return + + initial_routers = resolve(self.initial_address) for router in copy_of_routers: - new_routing_table = self.fetch_routing_table(router) - if new_routing_table is not None: - self.routing_table.update(new_routing_table) + if router in initial_routers: + initial_routers.remove(router) + if len(initial_routers) != 0: + if self.update_routing_table_with_routers(initial_routers): return # None of the routers have been successful, so just fail @@ -304,7 +321,7 @@ def __init__(self, uri, **config): def connector(a): return connect(a, security_plan.ssl_context, **config) - pool = RoutingConnectionPool(connector, *resolve(initial_address)) + pool = RoutingConnectionPool(connector, initial_address, *resolve(initial_address)) try: pool.update_routing_table() except: diff --git a/test/stub/test_routing.py b/test/stub/test_routing.py index bffcaa75d..6cf0ae357 100644 --- a/test/stub/test_routing.py +++ b/test/stub/test_routing.py @@ -48,6 +48,7 @@ "X": 1, } +UNREACHABLE_ADDRESS = ("127.0.0.1", 8080) def connector(address): return connect(address, auth=basic_auth("neotest", "neotest")) @@ -58,7 +59,7 @@ class RoutingConnectionPoolFetchRoutingInfoTestCase(StubTestCase): def test_should_get_info_from_router(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: result = pool.fetch_routing_info(address) assert len(result) == 1 record = result[0] @@ -72,7 +73,7 @@ def test_should_get_info_from_router(self): def test_should_remove_router_if_cannot_connect(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: assert address in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers @@ -80,14 +81,14 @@ def test_should_remove_router_if_cannot_connect(self): def test_should_remove_router_if_connection_drops(self): with StubCluster({9001: "rude_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: assert address in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers def test_should_not_fail_if_cannot_connect_but_router_already_removed(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: assert address not in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers @@ -95,35 +96,35 @@ def test_should_not_fail_if_cannot_connect_but_router_already_removed(self): def test_should_not_fail_if_connection_drops_but_router_already_removed(self): with StubCluster({9001: "rude_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: assert address not in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers def test_should_return_none_if_cannot_connect(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: result = pool.fetch_routing_info(address) assert result is None def test_should_return_none_if_connection_drops(self): with StubCluster({9001: "rude_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: result = pool.fetch_routing_info(address) assert result is None def test_should_fail_for_non_router(self): with StubCluster({9001: "non_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: with self.assertRaises(ServiceUnavailable): _ = pool.fetch_routing_info(address) def test_should_fail_if_database_error(self): with StubCluster({9001: "broken_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: with self.assertRaises(ServiceUnavailable): _ = pool.fetch_routing_info(address) @@ -133,7 +134,7 @@ class RoutingConnectionPoolFetchRoutingTableTestCase(StubTestCase): def test_should_get_table_from_router(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: table = pool.fetch_routing_table(address) assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003)} @@ -143,28 +144,28 @@ def test_should_get_table_from_router(self): def test_null_info_should_return_null_table(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: table = pool.fetch_routing_table(address) assert table is None def test_no_routers_should_raise_protocol_error(self): with StubCluster({9001: "router_no_routers.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: with self.assertRaises(ProtocolError): _ = pool.fetch_routing_table(address) def test_no_readers_should_raise_protocol_error(self): with StubCluster({9001: "router_no_readers.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: with self.assertRaises(ProtocolError): _ = pool.fetch_routing_table(address) def test_no_writers_should_return_null_table(self): with StubCluster({9001: "router_no_writers.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: table = pool.fetch_routing_table(address) assert table is None @@ -183,8 +184,21 @@ class RoutingConnectionPoolUpdateRoutingTableTestCase(StubTestCase): (None, None, ServiceUnavailable): ServiceUnavailable, } + def test_roll_back_to_initial_server_if_failed_update_with_existing_routers(self): + with StubCluster({9001: "router.script"}): + initial_address =("127.0.0.1", 9001) # roll back addresses + routers = [("127.0.0.1", 9002), ("127.0.0.1", 9003)] # not reachable servers + with RoutingConnectionPool(connector, initial_address, *routers) as pool: + pool.update_routing_table() + table = pool.routing_table + assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), + ("127.0.0.1", 9003)} + assert table.readers == {("127.0.0.1", 9004), ("127.0.0.1", 9005)} + assert table.writers == {("127.0.0.1", 9006)} + assert table.ttl == 300 + def test_update_with_no_routers_should_signal_service_unavailable(self): - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: with self.assertRaises(ServiceUnavailable): pool.update_routing_table() @@ -207,7 +221,7 @@ def _test_server_outcome(self, server_outcomes, overall_outcome): assert False, "Unexpected server outcome %r" % outcome routers.append(("127.0.0.1", port)) with StubCluster(servers): - with RoutingConnectionPool(connector, *routers) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, *routers) as pool: if overall_outcome is RoutingTable: pool.update_routing_table() table = pool.routing_table @@ -228,7 +242,7 @@ class RoutingConnectionPoolRefreshRoutingTableTestCase(StubTestCase): def test_should_update_if_stale(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: first_updated_time = pool.routing_table.last_updated_time pool.routing_table.ttl = 0 pool.refresh_routing_table() @@ -238,7 +252,7 @@ def test_should_update_if_stale(self): def test_should_not_update_if_fresh(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: pool.refresh_routing_table() first_updated_time = pool.routing_table.last_updated_time pool.refresh_routing_table() @@ -250,7 +264,7 @@ def test_should_not_update_if_fresh(self): # address = ("127.0.0.1", 9001) # table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD]) # - # with RoutingConnectionPool(connector, address) as pool: + # with RoutingConnectionPool(connector, not_reachable_address, address) as pool: # semaphore = Semaphore() # # class Refresher(Thread): @@ -297,7 +311,7 @@ def test_should_not_update_if_fresh(self): # address = ("127.0.0.1", 9001) # table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD]) # - # with RoutingConnectionPool(connector, address) as pool: + # with RoutingConnectionPool(connector, not_reachable_address, address) as pool: # semaphore = Semaphore() # # class Refresher(Thread): @@ -345,7 +359,7 @@ class RoutingConnectionPoolAcquireForReadTestCase(StubTestCase): def test_should_refresh(self): with StubCluster({9001: "router.script", 9004: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: assert not pool.routing_table.is_fresh() _ = pool.acquire(access_mode=READ_ACCESS) assert pool.routing_table.is_fresh() @@ -353,7 +367,7 @@ def test_should_refresh(self): def test_connected_to_reader(self): with StubCluster({9001: "router.script", 9004: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: assert not pool.routing_table.is_fresh() connection = pool.acquire(access_mode=READ_ACCESS) assert connection.server.address in pool.routing_table.readers @@ -363,7 +377,7 @@ def test_should_retry_if_first_reader_fails(self): 9004: "fail_on_init.script", 9005: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: assert not pool.routing_table.is_fresh() _ = pool.acquire(access_mode=READ_ACCESS) assert ("127.0.0.1", 9004) not in pool.routing_table.readers @@ -375,7 +389,7 @@ class RoutingConnectionPoolAcquireForWriteTestCase(StubTestCase): def test_should_refresh(self): with StubCluster({9001: "router.script", 9006: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: assert not pool.routing_table.is_fresh() _ = pool.acquire(access_mode=WRITE_ACCESS) assert pool.routing_table.is_fresh() @@ -383,7 +397,7 @@ def test_should_refresh(self): def test_connected_to_writer(self): with StubCluster({9001: "router.script", 9006: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: assert not pool.routing_table.is_fresh() connection = pool.acquire(access_mode=WRITE_ACCESS) assert connection.server.address in pool.routing_table.writers @@ -393,7 +407,7 @@ def test_should_retry_if_first_writer_fails(self): 9006: "fail_on_init.script", 9007: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: assert not pool.routing_table.is_fresh() _ = pool.acquire(access_mode=WRITE_ACCESS) assert ("127.0.0.1", 9006) not in pool.routing_table.writers @@ -405,7 +419,7 @@ class RoutingConnectionPoolRemoveTestCase(StubTestCase): def test_should_remove_router_from_routing_table_if_present(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: pool.refresh_routing_table() target = ("127.0.0.1", 9001) assert target in pool.routing_table.routers @@ -415,7 +429,7 @@ def test_should_remove_router_from_routing_table_if_present(self): def test_should_remove_reader_from_routing_table_if_present(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: pool.refresh_routing_table() target = ("127.0.0.1", 9004) assert target in pool.routing_table.readers @@ -425,7 +439,7 @@ def test_should_remove_reader_from_routing_table_if_present(self): def test_should_remove_writer_from_routing_table_if_present(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: pool.refresh_routing_table() target = ("127.0.0.1", 9006) assert target in pool.routing_table.writers @@ -435,7 +449,7 @@ def test_should_remove_writer_from_routing_table_if_present(self): def test_should_not_fail_if_absent(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: pool.refresh_routing_table() target = ("127.0.0.1", 9007) pool.remove(target) diff --git a/test/unit/test_routing.py b/test/unit/test_routing.py index 9e26689db..5361259e9 100644 --- a/test/unit/test_routing.py +++ b/test/unit/test_routing.py @@ -227,5 +227,7 @@ def test_update_should_replace_ttl(self): class RoutingConnectionPoolConstructionTestCase(TestCase): def test_should_populate_initial_router(self): - with RoutingConnectionPool(connector, ("127.0.0.1", 9001)) as pool: - assert pool.routing_table.routers == {("127.0.0.1", 9001)} + initial_router = ("127.0.0.1", 9001) + router = ("127.0.0.1", 9002) + with RoutingConnectionPool(connector, initial_router, router) as pool: + assert pool.routing_table.routers == {("127.0.0.1", 9002)} From 3c00b97c5519602315c0975c95e58c3765bc1c2f Mon Sep 17 00:00:00 2001 From: Zhen Date: Tue, 25 Apr 2017 12:29:48 +0200 Subject: [PATCH 2/3] Support routing context from bolt routing uri --- neo4j/addressing.py | 19 +++++++++ neo4j/util.py | 19 +++++++++ neo4j/v1/routing.py | 22 +++++++--- test/integration/tools.py | 11 ++--- test/stub/test_routing.py | 79 +++++++++++++++++------------------- test/unit/test_addressing.py | 26 ++++++++++++ 6 files changed, 122 insertions(+), 54 deletions(-) diff --git a/neo4j/addressing.py b/neo4j/addressing.py index 8dca01e17..a2016764c 100644 --- a/neo4j/addressing.py +++ b/neo4j/addressing.py @@ -79,6 +79,25 @@ def parse(cls, string, default_port=0): """ return cls.from_uri("//{}".format(string), default_port) + @classmethod + def parse_routing_context(cls, uri): + query = urlparse(uri).query + if not query: + return {} + + context = {} + parameters = [x for x in query.split('&') if x] + for keyValue in parameters: + pair = keyValue.split('=') + if len(pair) != 2 or not pair[0] or not pair[1]: + raise ValueError("Invalid parameters: '" + keyValue + "' in URI '" + uri + "'.") + key = pair[0] + value = pair[1] + if key in context: + raise ValueError("Duplicated query parameters with key '" + key + "' found in URL '" + uri + "'") + context[key] = value + return context + def resolve(socket_address): try: diff --git a/neo4j/util.py b/neo4j/util.py index a91380777..d8b711b4f 100644 --- a/neo4j/util.py +++ b/neo4j/util.py @@ -25,6 +25,25 @@ from sys import stdout +class ServerVersion(object): + def __init__(self, product, version_tuple, tags_tuple): + self.product = product + self.version_tuple = version_tuple + self.tags_tuple = tags_tuple + + def at_least_version(self, major, minor): + return self.version_tuple >= (major, minor) + + @classmethod + def from_str(cls, full_version): + if full_version is None: + return ServerVersion("Neo4j", (3, 0), ()) + product, _, tagged_version = full_version.partition("/") + tags = tagged_version.split("-") + version = map(int, tags[0].split(".")) + return ServerVersion(product, tuple(version), tuple(tags[1:])) + + class ColourFormatter(logging.Formatter): """ Colour formatter for pretty log output. """ diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index 7540d5fff..ce1292287 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -30,6 +30,7 @@ from neo4j.v1.exceptions import SessionExpired from neo4j.v1.security import SecurityPlan from neo4j.v1.session import BoltSession +from neo4j.util import ServerVersion class RoundRobinSet(MutableSet): @@ -152,14 +153,23 @@ class RoutingConnectionPool(ConnectionPool): """ Connection pool with routing table. """ - routing_info_procedure = "dbms.cluster.routing.getServers" + call_get_servers = "CALL dbms.cluster.routing.getServers" + get_routing_table_param = "context" + call_get_routing_table = "CALL dbms.cluster.routing.getRoutingTable({" + get_routing_table_param + "})" - def __init__(self, connector, initial_address, *routers): + def __init__(self, connector, initial_address, routing_context, *routers): super(RoutingConnectionPool, self).__init__(connector) self.initial_address = initial_address + self.routing_context = routing_context self.routing_table = RoutingTable(routers) self.refresh_lock = Lock() + def routing_info_procedure(self, connection): + if ServerVersion.from_str(connection.server.version).at_least_version(3, 2): + return self.call_get_routing_table, {self.get_routing_table_param: self.routing_context} + else: + return self.call_get_servers, None + def fetch_routing_info(self, address): """ Fetch raw routing info from a given router address. @@ -170,8 +180,9 @@ def fetch_routing_info(self, address): if routing support is broken """ try: - with BoltSession(lambda _: self.acquire_direct(address)) as session: - return list(session.run("CALL %s" % self.routing_info_procedure)) + connection = self.acquire_direct(address) + with BoltSession(lambda _: connection) as session: + return list(session.run(*self.routing_info_procedure(connection))) except CypherError as error: if error.code == "Neo.ClientError.Procedure.ProcedureNotFound": raise ServiceUnavailable("Server {!r} does not support routing".format(address)) @@ -313,6 +324,7 @@ def __init__(self, uri, **config): self.initial_address = initial_address = SocketAddress.from_uri(uri, DEFAULT_PORT) self.security_plan = security_plan = SecurityPlan.build(**config) self.encrypted = security_plan.encrypted + routing_context = SocketAddress.parse_routing_context(uri) if not security_plan.routing_compatible: # this error message is case-specific as there is only one incompatible # scenario right now @@ -321,7 +333,7 @@ def __init__(self, uri, **config): def connector(a): return connect(a, security_plan.ssl_context, **config) - pool = RoutingConnectionPool(connector, initial_address, *resolve(initial_address)) + pool = RoutingConnectionPool(connector, initial_address, routing_context *resolve(initial_address)) try: pool.update_routing_table() except: diff --git a/test/integration/tools.py b/test/integration/tools.py index ff261cc98..9402eee90 100644 --- a/test/integration/tools.py +++ b/test/integration/tools.py @@ -33,6 +33,7 @@ from boltkit.controller import WindowsController, UnixController from neo4j.v1 import GraphDatabase, AuthError +from neo4j.util import ServerVersion from test.env import NEO4J_SERVER_PACKAGE, NEO4J_USER, NEO4J_PASSWORD @@ -89,17 +90,11 @@ def server_version_info(cls): with GraphDatabase.driver(cls.bolt_uri, auth=cls.auth_token) as driver: with driver.session() as session: full_version = session.run("RETURN 1").summary().server.version - if full_version is None: - return "Neo4j", (3, 0), () - product, _, tagged_version = full_version.partition("/") - tags = tagged_version.split("-") - version = map(int, tags[0].split(".")) - return product, tuple(version), tuple(tags[1:]) + return ServerVersion.from_str(full_version) @classmethod def at_least_version(cls, major, minor): - _, server_version, _ = cls.server_version_info() - return server_version >= (major, minor) + return cls.server_version_info().at_least_version(major, minor); @classmethod def delete_known_hosts_file(cls): diff --git a/test/stub/test_routing.py b/test/stub/test_routing.py index 6cf0ae357..17c2d7561 100644 --- a/test/stub/test_routing.py +++ b/test/stub/test_routing.py @@ -24,7 +24,6 @@ from test.stub.tools import StubCluster, StubTestCase - VALID_ROUTING_RECORD = { "ttl": 300, "servers": [ @@ -50,16 +49,20 @@ UNREACHABLE_ADDRESS = ("127.0.0.1", 8080) + def connector(address): return connect(address, auth=basic_auth("neotest", "neotest")) -class RoutingConnectionPoolFetchRoutingInfoTestCase(StubTestCase): +def RoutingPool(*routers): + return RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, {}, *routers) + +class RoutingConnectionPoolFetchRoutingInfoTestCase(StubTestCase): def test_should_get_info_from_router(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: result = pool.fetch_routing_info(address) assert len(result) == 1 record = result[0] @@ -73,7 +76,7 @@ def test_should_get_info_from_router(self): def test_should_remove_router_if_cannot_connect(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: assert address in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers @@ -81,14 +84,14 @@ def test_should_remove_router_if_cannot_connect(self): def test_should_remove_router_if_connection_drops(self): with StubCluster({9001: "rude_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: assert address in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers def test_should_not_fail_if_cannot_connect_but_router_already_removed(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: assert address not in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers @@ -96,45 +99,44 @@ def test_should_not_fail_if_cannot_connect_but_router_already_removed(self): def test_should_not_fail_if_connection_drops_but_router_already_removed(self): with StubCluster({9001: "rude_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: assert address not in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers def test_should_return_none_if_cannot_connect(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: result = pool.fetch_routing_info(address) assert result is None def test_should_return_none_if_connection_drops(self): with StubCluster({9001: "rude_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: result = pool.fetch_routing_info(address) assert result is None def test_should_fail_for_non_router(self): with StubCluster({9001: "non_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: with self.assertRaises(ServiceUnavailable): _ = pool.fetch_routing_info(address) def test_should_fail_if_database_error(self): with StubCluster({9001: "broken_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: with self.assertRaises(ServiceUnavailable): _ = pool.fetch_routing_info(address) class RoutingConnectionPoolFetchRoutingTableTestCase(StubTestCase): - def test_should_get_table_from_router(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: table = pool.fetch_routing_table(address) assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003)} @@ -144,34 +146,33 @@ def test_should_get_table_from_router(self): def test_null_info_should_return_null_table(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: table = pool.fetch_routing_table(address) assert table is None def test_no_routers_should_raise_protocol_error(self): with StubCluster({9001: "router_no_routers.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: with self.assertRaises(ProtocolError): _ = pool.fetch_routing_table(address) def test_no_readers_should_raise_protocol_error(self): with StubCluster({9001: "router_no_readers.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: with self.assertRaises(ProtocolError): _ = pool.fetch_routing_table(address) def test_no_writers_should_return_null_table(self): with StubCluster({9001: "router_no_writers.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: table = pool.fetch_routing_table(address) assert table is None class RoutingConnectionPoolUpdateRoutingTableTestCase(StubTestCase): - scenarios = { (None,): ServiceUnavailable, (RoutingTable,): RoutingTable, @@ -186,9 +187,9 @@ class RoutingConnectionPoolUpdateRoutingTableTestCase(StubTestCase): def test_roll_back_to_initial_server_if_failed_update_with_existing_routers(self): with StubCluster({9001: "router.script"}): - initial_address =("127.0.0.1", 9001) # roll back addresses - routers = [("127.0.0.1", 9002), ("127.0.0.1", 9003)] # not reachable servers - with RoutingConnectionPool(connector, initial_address, *routers) as pool: + initial_address = ("127.0.0.1", 9001) # roll back addresses + routers = [("127.0.0.1", 9002), ("127.0.0.1", 9003)] # not reachable servers + with RoutingConnectionPool(connector, initial_address, {}, *routers) as pool: pool.update_routing_table() table = pool.routing_table assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), @@ -198,7 +199,7 @@ def test_roll_back_to_initial_server_if_failed_update_with_existing_routers(self assert table.ttl == 300 def test_update_with_no_routers_should_signal_service_unavailable(self): - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: with self.assertRaises(ServiceUnavailable): pool.update_routing_table() @@ -221,7 +222,7 @@ def _test_server_outcome(self, server_outcomes, overall_outcome): assert False, "Unexpected server outcome %r" % outcome routers.append(("127.0.0.1", port)) with StubCluster(servers): - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, *routers) as pool: + with RoutingPool(*routers) as pool: if overall_outcome is RoutingTable: pool.update_routing_table() table = pool.routing_table @@ -238,11 +239,10 @@ def _test_server_outcome(self, server_outcomes, overall_outcome): class RoutingConnectionPoolRefreshRoutingTableTestCase(StubTestCase): - def test_should_update_if_stale(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: first_updated_time = pool.routing_table.last_updated_time pool.routing_table.ttl = 0 pool.refresh_routing_table() @@ -252,7 +252,7 @@ def test_should_update_if_stale(self): def test_should_not_update_if_fresh(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: pool.refresh_routing_table() first_updated_time = pool.routing_table.last_updated_time pool.refresh_routing_table() @@ -264,7 +264,7 @@ def test_should_not_update_if_fresh(self): # address = ("127.0.0.1", 9001) # table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD]) # - # with RoutingConnectionPool(connector, not_reachable_address, address) as pool: + # with RoutingPool(address) as pool: # semaphore = Semaphore() # # class Refresher(Thread): @@ -311,7 +311,7 @@ def test_should_not_update_if_fresh(self): # address = ("127.0.0.1", 9001) # table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD]) # - # with RoutingConnectionPool(connector, not_reachable_address, address) as pool: + # with RoutingPool(address) as pool: # semaphore = Semaphore() # # class Refresher(Thread): @@ -355,11 +355,10 @@ def test_should_not_update_if_fresh(self): class RoutingConnectionPoolAcquireForReadTestCase(StubTestCase): - def test_should_refresh(self): with StubCluster({9001: "router.script", 9004: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: assert not pool.routing_table.is_fresh() _ = pool.acquire(access_mode=READ_ACCESS) assert pool.routing_table.is_fresh() @@ -367,7 +366,7 @@ def test_should_refresh(self): def test_connected_to_reader(self): with StubCluster({9001: "router.script", 9004: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: assert not pool.routing_table.is_fresh() connection = pool.acquire(access_mode=READ_ACCESS) assert connection.server.address in pool.routing_table.readers @@ -377,7 +376,7 @@ def test_should_retry_if_first_reader_fails(self): 9004: "fail_on_init.script", 9005: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: assert not pool.routing_table.is_fresh() _ = pool.acquire(access_mode=READ_ACCESS) assert ("127.0.0.1", 9004) not in pool.routing_table.readers @@ -385,11 +384,10 @@ def test_should_retry_if_first_reader_fails(self): class RoutingConnectionPoolAcquireForWriteTestCase(StubTestCase): - def test_should_refresh(self): with StubCluster({9001: "router.script", 9006: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: assert not pool.routing_table.is_fresh() _ = pool.acquire(access_mode=WRITE_ACCESS) assert pool.routing_table.is_fresh() @@ -397,7 +395,7 @@ def test_should_refresh(self): def test_connected_to_writer(self): with StubCluster({9001: "router.script", 9006: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: assert not pool.routing_table.is_fresh() connection = pool.acquire(access_mode=WRITE_ACCESS) assert connection.server.address in pool.routing_table.writers @@ -407,7 +405,7 @@ def test_should_retry_if_first_writer_fails(self): 9006: "fail_on_init.script", 9007: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: assert not pool.routing_table.is_fresh() _ = pool.acquire(access_mode=WRITE_ACCESS) assert ("127.0.0.1", 9006) not in pool.routing_table.writers @@ -415,11 +413,10 @@ def test_should_retry_if_first_writer_fails(self): class RoutingConnectionPoolRemoveTestCase(StubTestCase): - def test_should_remove_router_from_routing_table_if_present(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: pool.refresh_routing_table() target = ("127.0.0.1", 9001) assert target in pool.routing_table.routers @@ -429,7 +426,7 @@ def test_should_remove_router_from_routing_table_if_present(self): def test_should_remove_reader_from_routing_table_if_present(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: pool.refresh_routing_table() target = ("127.0.0.1", 9004) assert target in pool.routing_table.readers @@ -439,7 +436,7 @@ def test_should_remove_reader_from_routing_table_if_present(self): def test_should_remove_writer_from_routing_table_if_present(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: pool.refresh_routing_table() target = ("127.0.0.1", 9006) assert target in pool.routing_table.writers @@ -449,7 +446,7 @@ def test_should_remove_writer_from_routing_table_if_present(self): def test_should_not_fail_if_absent(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: pool.refresh_routing_table() target = ("127.0.0.1", 9007) pool.remove(target) diff --git a/test/unit/test_addressing.py b/test/unit/test_addressing.py index 33c8d8aea..4c708d83d 100644 --- a/test/unit/test_addressing.py +++ b/test/unit/test_addressing.py @@ -41,3 +41,29 @@ def test_should_parse_host_name_and_port(self): def test_should_fail_on_non_numeric_port(self): with self.assertRaises(ValueError): _ = SocketAddress.parse("127.0.0.1:X") + + def test_parse_empty_routing_context(self): + verify_routing_context({}, "bolt+routing://127.0.0.1/cat?") + verify_routing_context({}, "bolt+routing://127.0.0.1/cat") + verify_routing_context({}, "bolt+routing://127.0.0.1/?") + verify_routing_context({}, "bolt+routing://127.0.0.1/") + verify_routing_context({}, "bolt+routing://127.0.0.1?") + verify_routing_context({}, "bolt+routing://127.0.0.1") + + def test_parse_routing_context(self): + verify_routing_context({"name": "molly", "color": "white"}, "bolt+routing://127.0.0.1/cat?name=molly&color=white") + verify_routing_context({"name": "molly", "color": "white"}, "bolt+routing://127.0.0.1/?name=molly&color=white") + verify_routing_context({"name": "molly", "color": "white"}, "bolt+routing://127.0.0.1?name=molly&color=white") + + def test_should_error_when_value_missing(self): + with self.assertRaises(ValueError): + SocketAddress.parse_routing_context("bolt+routing://127.0.0.1/?name=&color=white") + + def test_should_error_when_key_duplicate(self): + with self.assertRaises(ValueError): + SocketAddress.parse_routing_context("bolt+routing://127.0.0.1/?name=molly&name=white") + + +def verify_routing_context(expected, uri): + context = SocketAddress.parse_routing_context(uri) + assert context == expected From 476f7d1248e757024a724ab9fac28db7fe65fbaf Mon Sep 17 00:00:00 2001 From: Zhen Date: Tue, 25 Apr 2017 14:43:28 +0200 Subject: [PATCH 3/3] Added boltkit test to verify that the driver actually works with the routing context in uri --- neo4j/v1/direct.py | 2 ++ neo4j/v1/routing.py | 2 +- test/stub/scripts/get_routing_table.script | 9 +++++++++ .../get_routing_table_with_context.script | 9 +++++++++ test/stub/test_directdriver.py | 6 ++++++ test/stub/test_routing.py | 13 ++++++++++++ test/stub/test_routingdriver.py | 20 +++++++++++++++++++ test/unit/test_routing.py | 2 +- 8 files changed, 61 insertions(+), 2 deletions(-) create mode 100644 test/stub/scripts/get_routing_table.script create mode 100644 test/stub/scripts/get_routing_table_with_context.script diff --git a/neo4j/v1/direct.py b/neo4j/v1/direct.py index 92376a72e..42f0fdd4d 100644 --- a/neo4j/v1/direct.py +++ b/neo4j/v1/direct.py @@ -56,6 +56,8 @@ def __init__(self, uri, **config): # will carry out DNS resolution, leading to the possibility that # the connection pool may contain multiple IP address keys, one for # an old address and one for a new address. + if SocketAddress.parse_routing_context(uri): + raise ValueError("Routing parameters are not supported with scheme 'bolt'. Given URI: '" + uri + "'.") self.address = SocketAddress.from_uri(uri, DEFAULT_PORT) self.security_plan = security_plan = SecurityPlan.build(**config) self.encrypted = security_plan.encrypted diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index ce1292287..05e1dd7b0 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -333,7 +333,7 @@ def __init__(self, uri, **config): def connector(a): return connect(a, security_plan.ssl_context, **config) - pool = RoutingConnectionPool(connector, initial_address, routing_context *resolve(initial_address)) + pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address)) try: pool.update_routing_table() except: diff --git a/test/stub/scripts/get_routing_table.script b/test/stub/scripts/get_routing_table.script new file mode 100644 index 000000000..945ea64cd --- /dev/null +++ b/test/stub/scripts/get_routing_table.script @@ -0,0 +1,9 @@ +!: AUTO INIT +!: AUTO RESET + +S: SUCCESS {"server": "Neo4j/3.2.2"} +C: RUN "CALL dbms.cluster.routing.getRoutingTable({context})" {"context": {}} + PULL_ALL +S: SUCCESS {"fields": ["ttl", "servers"]} + RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]] + SUCCESS {} \ No newline at end of file diff --git a/test/stub/scripts/get_routing_table_with_context.script b/test/stub/scripts/get_routing_table_with_context.script new file mode 100644 index 000000000..ce4244744 --- /dev/null +++ b/test/stub/scripts/get_routing_table_with_context.script @@ -0,0 +1,9 @@ +!: AUTO INIT +!: AUTO RESET + +S: SUCCESS {"server": "Neo4j/3.2.3"} +C: RUN "CALL dbms.cluster.routing.getRoutingTable({context})" {"context": {"name": "molly", "age": "1"}} + PULL_ALL +S: SUCCESS {"fields": ["ttl", "servers"]} + RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]] + SUCCESS {} \ No newline at end of file diff --git a/test/stub/test_directdriver.py b/test/stub/test_directdriver.py index efb63062b..60d80d809 100644 --- a/test/stub/test_directdriver.py +++ b/test/stub/test_directdriver.py @@ -48,3 +48,9 @@ def test_direct_disconnect_on_pull_all(self): with self.assertRaises(ServiceUnavailable): with driver.session() as session: session.run("RETURN $x", {"x": 1}).consume() + + def test_direct_should_reject_routing_context(self): + uri = "bolt://127.0.0.1:9001/?name=molly&age=1" + with self.assertRaises(ValueError): + GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) + diff --git a/test/stub/test_routing.py b/test/stub/test_routing.py index 17c2d7561..1b2d7dc3b 100644 --- a/test/stub/test_routing.py +++ b/test/stub/test_routing.py @@ -131,6 +131,19 @@ def test_should_fail_if_database_error(self): with self.assertRaises(ServiceUnavailable): _ = pool.fetch_routing_info(address) + def test_should_call_get_routing_tables_with_context(self): + with StubCluster({9001: "get_routing_table_with_context.script"}): + address = ("127.0.0.1", 9001) + routing_context = {"name": "molly", "age": "1"} + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, routing_context) as pool: + pool.fetch_routing_info(address) + + def test_should_call_get_routing_tables(self): + with StubCluster({9001: "get_routing_table.script"}): + address = ("127.0.0.1", 9001) + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, {}) as pool: + pool.fetch_routing_info(address) + class RoutingConnectionPoolFetchRoutingTableTestCase(StubTestCase): def test_should_get_table_from_router(self): diff --git a/test/stub/test_routingdriver.py b/test/stub/test_routingdriver.py index 3fe952f07..9f9ecb7ac 100644 --- a/test/stub/test_routingdriver.py +++ b/test/stub/test_routingdriver.py @@ -179,3 +179,23 @@ def test_two_sessions_can_share_a_connection(self): session_2.close() session_1.close() + + def test_should_call_get_routing_table_procedure(self): + with StubCluster({9001: "get_routing_table.script", 9002: "return_1.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + with driver.session(READ_ACCESS) as session: + result = session.run("RETURN $x", {"x": 1}) + for record in result: + assert record["x"] == 1 + assert result.summary().server.address == ('127.0.0.1', 9002) + + def test_should_call_get_routing_table_with_context(self): + with StubCluster({9001: "get_routing_table_with_context.script", 9002: "return_1.script"}): + uri = "bolt+routing://127.0.0.1:9001/?name=molly&age=1" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + with driver.session(READ_ACCESS) as session: + result = session.run("RETURN $x", {"x": 1}) + for record in result: + assert record["x"] == 1 + assert result.summary().server.address == ('127.0.0.1', 9002) \ No newline at end of file diff --git a/test/unit/test_routing.py b/test/unit/test_routing.py index 5361259e9..e6f2f80c9 100644 --- a/test/unit/test_routing.py +++ b/test/unit/test_routing.py @@ -229,5 +229,5 @@ class RoutingConnectionPoolConstructionTestCase(TestCase): def test_should_populate_initial_router(self): initial_router = ("127.0.0.1", 9001) router = ("127.0.0.1", 9002) - with RoutingConnectionPool(connector, initial_router, router) as pool: + with RoutingConnectionPool(connector, initial_router, {}, router) as pool: assert pool.routing_table.routers == {("127.0.0.1", 9002)}