diff --git a/neo4j/addressing.py b/neo4j/addressing.py index 8dca01e17..a2016764c 100644 --- a/neo4j/addressing.py +++ b/neo4j/addressing.py @@ -79,6 +79,25 @@ def parse(cls, string, default_port=0): """ return cls.from_uri("//{}".format(string), default_port) + @classmethod + def parse_routing_context(cls, uri): + query = urlparse(uri).query + if not query: + return {} + + context = {} + parameters = [x for x in query.split('&') if x] + for keyValue in parameters: + pair = keyValue.split('=') + if len(pair) != 2 or not pair[0] or not pair[1]: + raise ValueError("Invalid parameters: '" + keyValue + "' in URI '" + uri + "'.") + key = pair[0] + value = pair[1] + if key in context: + raise ValueError("Duplicated query parameters with key '" + key + "' found in URL '" + uri + "'") + context[key] = value + return context + def resolve(socket_address): try: diff --git a/neo4j/util.py b/neo4j/util.py index a91380777..d8b711b4f 100644 --- a/neo4j/util.py +++ b/neo4j/util.py @@ -25,6 +25,25 @@ from sys import stdout +class ServerVersion(object): + def __init__(self, product, version_tuple, tags_tuple): + self.product = product + self.version_tuple = version_tuple + self.tags_tuple = tags_tuple + + def at_least_version(self, major, minor): + return self.version_tuple >= (major, minor) + + @classmethod + def from_str(cls, full_version): + if full_version is None: + return ServerVersion("Neo4j", (3, 0), ()) + product, _, tagged_version = full_version.partition("/") + tags = tagged_version.split("-") + version = map(int, tags[0].split(".")) + return ServerVersion(product, tuple(version), tuple(tags[1:])) + + class ColourFormatter(logging.Formatter): """ Colour formatter for pretty log output. """ diff --git a/neo4j/v1/direct.py b/neo4j/v1/direct.py index 92376a72e..42f0fdd4d 100644 --- a/neo4j/v1/direct.py +++ b/neo4j/v1/direct.py @@ -56,6 +56,8 @@ def __init__(self, uri, **config): # will carry out DNS resolution, leading to the possibility that # the connection pool may contain multiple IP address keys, one for # an old address and one for a new address. + if SocketAddress.parse_routing_context(uri): + raise ValueError("Routing parameters are not supported with scheme 'bolt'. Given URI: '" + uri + "'.") self.address = SocketAddress.from_uri(uri, DEFAULT_PORT) self.security_plan = security_plan = SecurityPlan.build(**config) self.encrypted = security_plan.encrypted diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index 73b9fcfa7..05e1dd7b0 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -30,6 +30,7 @@ from neo4j.v1.exceptions import SessionExpired from neo4j.v1.security import SecurityPlan from neo4j.v1.session import BoltSession +from neo4j.util import ServerVersion class RoundRobinSet(MutableSet): @@ -152,13 +153,23 @@ class RoutingConnectionPool(ConnectionPool): """ Connection pool with routing table. """ - routing_info_procedure = "dbms.cluster.routing.getServers" + call_get_servers = "CALL dbms.cluster.routing.getServers" + get_routing_table_param = "context" + call_get_routing_table = "CALL dbms.cluster.routing.getRoutingTable({" + get_routing_table_param + "})" - def __init__(self, connector, *routers): + def __init__(self, connector, initial_address, routing_context, *routers): super(RoutingConnectionPool, self).__init__(connector) + self.initial_address = initial_address + self.routing_context = routing_context self.routing_table = RoutingTable(routers) self.refresh_lock = Lock() + def routing_info_procedure(self, connection): + if ServerVersion.from_str(connection.server.version).at_least_version(3, 2): + return self.call_get_routing_table, {self.get_routing_table_param: self.routing_context} + else: + return self.call_get_servers, None + def fetch_routing_info(self, address): """ Fetch raw routing info from a given router address. @@ -169,8 +180,9 @@ def fetch_routing_info(self, address): if routing support is broken """ try: - with BoltSession(lambda _: self.acquire_direct(address)) as session: - return list(session.run("CALL %s" % self.routing_info_procedure)) + connection = self.acquire_direct(address) + with BoltSession(lambda _: connection) as session: + return list(session.run(*self.routing_info_procedure(connection))) except CypherError as error: if error.code == "Neo.ClientError.Procedure.ProcedureNotFound": raise ServiceUnavailable("Server {!r} does not support routing".format(address)) @@ -216,16 +228,32 @@ def fetch_routing_table(self, address): # At least one of each is fine, so return this table return new_routing_table + def update_routing_table_with_routers(self, routers): + """Try to update routing tables with the given routers + :return: True if the routing table is successfully updated, otherwise False + """ + for router in routers: + new_routing_table = self.fetch_routing_table(router) + if new_routing_table is not None: + self.routing_table.update(new_routing_table) + return True + return False + def update_routing_table(self): """ Update the routing table from the first router able to provide valid routing information. """ # copied because it can be modified copy_of_routers = list(self.routing_table.routers) + if self.update_routing_table_with_routers(copy_of_routers): + return + + initial_routers = resolve(self.initial_address) for router in copy_of_routers: - new_routing_table = self.fetch_routing_table(router) - if new_routing_table is not None: - self.routing_table.update(new_routing_table) + if router in initial_routers: + initial_routers.remove(router) + if len(initial_routers) != 0: + if self.update_routing_table_with_routers(initial_routers): return # None of the routers have been successful, so just fail @@ -296,6 +324,7 @@ def __init__(self, uri, **config): self.initial_address = initial_address = SocketAddress.from_uri(uri, DEFAULT_PORT) self.security_plan = security_plan = SecurityPlan.build(**config) self.encrypted = security_plan.encrypted + routing_context = SocketAddress.parse_routing_context(uri) if not security_plan.routing_compatible: # this error message is case-specific as there is only one incompatible # scenario right now @@ -304,7 +333,7 @@ def __init__(self, uri, **config): def connector(a): return connect(a, security_plan.ssl_context, **config) - pool = RoutingConnectionPool(connector, *resolve(initial_address)) + pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address)) try: pool.update_routing_table() except: diff --git a/test/integration/tools.py b/test/integration/tools.py index ff261cc98..9402eee90 100644 --- a/test/integration/tools.py +++ b/test/integration/tools.py @@ -33,6 +33,7 @@ from boltkit.controller import WindowsController, UnixController from neo4j.v1 import GraphDatabase, AuthError +from neo4j.util import ServerVersion from test.env import NEO4J_SERVER_PACKAGE, NEO4J_USER, NEO4J_PASSWORD @@ -89,17 +90,11 @@ def server_version_info(cls): with GraphDatabase.driver(cls.bolt_uri, auth=cls.auth_token) as driver: with driver.session() as session: full_version = session.run("RETURN 1").summary().server.version - if full_version is None: - return "Neo4j", (3, 0), () - product, _, tagged_version = full_version.partition("/") - tags = tagged_version.split("-") - version = map(int, tags[0].split(".")) - return product, tuple(version), tuple(tags[1:]) + return ServerVersion.from_str(full_version) @classmethod def at_least_version(cls, major, minor): - _, server_version, _ = cls.server_version_info() - return server_version >= (major, minor) + return cls.server_version_info().at_least_version(major, minor); @classmethod def delete_known_hosts_file(cls): diff --git a/test/stub/scripts/get_routing_table.script b/test/stub/scripts/get_routing_table.script new file mode 100644 index 000000000..945ea64cd --- /dev/null +++ b/test/stub/scripts/get_routing_table.script @@ -0,0 +1,9 @@ +!: AUTO INIT +!: AUTO RESET + +S: SUCCESS {"server": "Neo4j/3.2.2"} +C: RUN "CALL dbms.cluster.routing.getRoutingTable({context})" {"context": {}} + PULL_ALL +S: SUCCESS {"fields": ["ttl", "servers"]} + RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]] + SUCCESS {} \ No newline at end of file diff --git a/test/stub/scripts/get_routing_table_with_context.script b/test/stub/scripts/get_routing_table_with_context.script new file mode 100644 index 000000000..ce4244744 --- /dev/null +++ b/test/stub/scripts/get_routing_table_with_context.script @@ -0,0 +1,9 @@ +!: AUTO INIT +!: AUTO RESET + +S: SUCCESS {"server": "Neo4j/3.2.3"} +C: RUN "CALL dbms.cluster.routing.getRoutingTable({context})" {"context": {"name": "molly", "age": "1"}} + PULL_ALL +S: SUCCESS {"fields": ["ttl", "servers"]} + RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]] + SUCCESS {} \ No newline at end of file diff --git a/test/stub/test_directdriver.py b/test/stub/test_directdriver.py index efb63062b..60d80d809 100644 --- a/test/stub/test_directdriver.py +++ b/test/stub/test_directdriver.py @@ -48,3 +48,9 @@ def test_direct_disconnect_on_pull_all(self): with self.assertRaises(ServiceUnavailable): with driver.session() as session: session.run("RETURN $x", {"x": 1}).consume() + + def test_direct_should_reject_routing_context(self): + uri = "bolt://127.0.0.1:9001/?name=molly&age=1" + with self.assertRaises(ValueError): + GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) + diff --git a/test/stub/test_routing.py b/test/stub/test_routing.py index bffcaa75d..1b2d7dc3b 100644 --- a/test/stub/test_routing.py +++ b/test/stub/test_routing.py @@ -24,7 +24,6 @@ from test.stub.tools import StubCluster, StubTestCase - VALID_ROUTING_RECORD = { "ttl": 300, "servers": [ @@ -48,17 +47,22 @@ "X": 1, } +UNREACHABLE_ADDRESS = ("127.0.0.1", 8080) + def connector(address): return connect(address, auth=basic_auth("neotest", "neotest")) -class RoutingConnectionPoolFetchRoutingInfoTestCase(StubTestCase): +def RoutingPool(*routers): + return RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, {}, *routers) + +class RoutingConnectionPoolFetchRoutingInfoTestCase(StubTestCase): def test_should_get_info_from_router(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingPool() as pool: result = pool.fetch_routing_info(address) assert len(result) == 1 record = result[0] @@ -72,7 +76,7 @@ def test_should_get_info_from_router(self): def test_should_remove_router_if_cannot_connect(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: assert address in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers @@ -80,14 +84,14 @@ def test_should_remove_router_if_cannot_connect(self): def test_should_remove_router_if_connection_drops(self): with StubCluster({9001: "rude_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: assert address in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers def test_should_not_fail_if_cannot_connect_but_router_already_removed(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingPool() as pool: assert address not in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers @@ -95,45 +99,57 @@ def test_should_not_fail_if_cannot_connect_but_router_already_removed(self): def test_should_not_fail_if_connection_drops_but_router_already_removed(self): with StubCluster({9001: "rude_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingPool() as pool: assert address not in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers def test_should_return_none_if_cannot_connect(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: result = pool.fetch_routing_info(address) assert result is None def test_should_return_none_if_connection_drops(self): with StubCluster({9001: "rude_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: result = pool.fetch_routing_info(address) assert result is None def test_should_fail_for_non_router(self): with StubCluster({9001: "non_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: with self.assertRaises(ServiceUnavailable): _ = pool.fetch_routing_info(address) def test_should_fail_if_database_error(self): with StubCluster({9001: "broken_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: with self.assertRaises(ServiceUnavailable): _ = pool.fetch_routing_info(address) + def test_should_call_get_routing_tables_with_context(self): + with StubCluster({9001: "get_routing_table_with_context.script"}): + address = ("127.0.0.1", 9001) + routing_context = {"name": "molly", "age": "1"} + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, routing_context) as pool: + pool.fetch_routing_info(address) + + def test_should_call_get_routing_tables(self): + with StubCluster({9001: "get_routing_table.script"}): + address = ("127.0.0.1", 9001) + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, {}) as pool: + pool.fetch_routing_info(address) -class RoutingConnectionPoolFetchRoutingTableTestCase(StubTestCase): +class RoutingConnectionPoolFetchRoutingTableTestCase(StubTestCase): def test_should_get_table_from_router(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingPool() as pool: table = pool.fetch_routing_table(address) assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003)} @@ -143,34 +159,33 @@ def test_should_get_table_from_router(self): def test_null_info_should_return_null_table(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingPool() as pool: table = pool.fetch_routing_table(address) assert table is None def test_no_routers_should_raise_protocol_error(self): with StubCluster({9001: "router_no_routers.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingPool() as pool: with self.assertRaises(ProtocolError): _ = pool.fetch_routing_table(address) def test_no_readers_should_raise_protocol_error(self): with StubCluster({9001: "router_no_readers.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingPool() as pool: with self.assertRaises(ProtocolError): _ = pool.fetch_routing_table(address) def test_no_writers_should_return_null_table(self): with StubCluster({9001: "router_no_writers.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingPool() as pool: table = pool.fetch_routing_table(address) assert table is None class RoutingConnectionPoolUpdateRoutingTableTestCase(StubTestCase): - scenarios = { (None,): ServiceUnavailable, (RoutingTable,): RoutingTable, @@ -183,8 +198,21 @@ class RoutingConnectionPoolUpdateRoutingTableTestCase(StubTestCase): (None, None, ServiceUnavailable): ServiceUnavailable, } + def test_roll_back_to_initial_server_if_failed_update_with_existing_routers(self): + with StubCluster({9001: "router.script"}): + initial_address = ("127.0.0.1", 9001) # roll back addresses + routers = [("127.0.0.1", 9002), ("127.0.0.1", 9003)] # not reachable servers + with RoutingConnectionPool(connector, initial_address, {}, *routers) as pool: + pool.update_routing_table() + table = pool.routing_table + assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), + ("127.0.0.1", 9003)} + assert table.readers == {("127.0.0.1", 9004), ("127.0.0.1", 9005)} + assert table.writers == {("127.0.0.1", 9006)} + assert table.ttl == 300 + def test_update_with_no_routers_should_signal_service_unavailable(self): - with RoutingConnectionPool(connector) as pool: + with RoutingPool() as pool: with self.assertRaises(ServiceUnavailable): pool.update_routing_table() @@ -207,7 +235,7 @@ def _test_server_outcome(self, server_outcomes, overall_outcome): assert False, "Unexpected server outcome %r" % outcome routers.append(("127.0.0.1", port)) with StubCluster(servers): - with RoutingConnectionPool(connector, *routers) as pool: + with RoutingPool(*routers) as pool: if overall_outcome is RoutingTable: pool.update_routing_table() table = pool.routing_table @@ -224,11 +252,10 @@ def _test_server_outcome(self, server_outcomes, overall_outcome): class RoutingConnectionPoolRefreshRoutingTableTestCase(StubTestCase): - def test_should_update_if_stale(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: first_updated_time = pool.routing_table.last_updated_time pool.routing_table.ttl = 0 pool.refresh_routing_table() @@ -238,7 +265,7 @@ def test_should_update_if_stale(self): def test_should_not_update_if_fresh(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: pool.refresh_routing_table() first_updated_time = pool.routing_table.last_updated_time pool.refresh_routing_table() @@ -250,7 +277,7 @@ def test_should_not_update_if_fresh(self): # address = ("127.0.0.1", 9001) # table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD]) # - # with RoutingConnectionPool(connector, address) as pool: + # with RoutingPool(address) as pool: # semaphore = Semaphore() # # class Refresher(Thread): @@ -297,7 +324,7 @@ def test_should_not_update_if_fresh(self): # address = ("127.0.0.1", 9001) # table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD]) # - # with RoutingConnectionPool(connector, address) as pool: + # with RoutingPool(address) as pool: # semaphore = Semaphore() # # class Refresher(Thread): @@ -341,11 +368,10 @@ def test_should_not_update_if_fresh(self): class RoutingConnectionPoolAcquireForReadTestCase(StubTestCase): - def test_should_refresh(self): with StubCluster({9001: "router.script", 9004: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: assert not pool.routing_table.is_fresh() _ = pool.acquire(access_mode=READ_ACCESS) assert pool.routing_table.is_fresh() @@ -353,7 +379,7 @@ def test_should_refresh(self): def test_connected_to_reader(self): with StubCluster({9001: "router.script", 9004: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: assert not pool.routing_table.is_fresh() connection = pool.acquire(access_mode=READ_ACCESS) assert connection.server.address in pool.routing_table.readers @@ -363,7 +389,7 @@ def test_should_retry_if_first_reader_fails(self): 9004: "fail_on_init.script", 9005: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: assert not pool.routing_table.is_fresh() _ = pool.acquire(access_mode=READ_ACCESS) assert ("127.0.0.1", 9004) not in pool.routing_table.readers @@ -371,11 +397,10 @@ def test_should_retry_if_first_reader_fails(self): class RoutingConnectionPoolAcquireForWriteTestCase(StubTestCase): - def test_should_refresh(self): with StubCluster({9001: "router.script", 9006: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: assert not pool.routing_table.is_fresh() _ = pool.acquire(access_mode=WRITE_ACCESS) assert pool.routing_table.is_fresh() @@ -383,7 +408,7 @@ def test_should_refresh(self): def test_connected_to_writer(self): with StubCluster({9001: "router.script", 9006: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: assert not pool.routing_table.is_fresh() connection = pool.acquire(access_mode=WRITE_ACCESS) assert connection.server.address in pool.routing_table.writers @@ -393,7 +418,7 @@ def test_should_retry_if_first_writer_fails(self): 9006: "fail_on_init.script", 9007: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: assert not pool.routing_table.is_fresh() _ = pool.acquire(access_mode=WRITE_ACCESS) assert ("127.0.0.1", 9006) not in pool.routing_table.writers @@ -401,11 +426,10 @@ def test_should_retry_if_first_writer_fails(self): class RoutingConnectionPoolRemoveTestCase(StubTestCase): - def test_should_remove_router_from_routing_table_if_present(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: pool.refresh_routing_table() target = ("127.0.0.1", 9001) assert target in pool.routing_table.routers @@ -415,7 +439,7 @@ def test_should_remove_router_from_routing_table_if_present(self): def test_should_remove_reader_from_routing_table_if_present(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: pool.refresh_routing_table() target = ("127.0.0.1", 9004) assert target in pool.routing_table.readers @@ -425,7 +449,7 @@ def test_should_remove_reader_from_routing_table_if_present(self): def test_should_remove_writer_from_routing_table_if_present(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: pool.refresh_routing_table() target = ("127.0.0.1", 9006) assert target in pool.routing_table.writers @@ -435,7 +459,7 @@ def test_should_remove_writer_from_routing_table_if_present(self): def test_should_not_fail_if_absent(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingPool(address) as pool: pool.refresh_routing_table() target = ("127.0.0.1", 9007) pool.remove(target) diff --git a/test/stub/test_routingdriver.py b/test/stub/test_routingdriver.py index 3fe952f07..9f9ecb7ac 100644 --- a/test/stub/test_routingdriver.py +++ b/test/stub/test_routingdriver.py @@ -179,3 +179,23 @@ def test_two_sessions_can_share_a_connection(self): session_2.close() session_1.close() + + def test_should_call_get_routing_table_procedure(self): + with StubCluster({9001: "get_routing_table.script", 9002: "return_1.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + with driver.session(READ_ACCESS) as session: + result = session.run("RETURN $x", {"x": 1}) + for record in result: + assert record["x"] == 1 + assert result.summary().server.address == ('127.0.0.1', 9002) + + def test_should_call_get_routing_table_with_context(self): + with StubCluster({9001: "get_routing_table_with_context.script", 9002: "return_1.script"}): + uri = "bolt+routing://127.0.0.1:9001/?name=molly&age=1" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + with driver.session(READ_ACCESS) as session: + result = session.run("RETURN $x", {"x": 1}) + for record in result: + assert record["x"] == 1 + assert result.summary().server.address == ('127.0.0.1', 9002) \ No newline at end of file diff --git a/test/unit/test_addressing.py b/test/unit/test_addressing.py index 33c8d8aea..4c708d83d 100644 --- a/test/unit/test_addressing.py +++ b/test/unit/test_addressing.py @@ -41,3 +41,29 @@ def test_should_parse_host_name_and_port(self): def test_should_fail_on_non_numeric_port(self): with self.assertRaises(ValueError): _ = SocketAddress.parse("127.0.0.1:X") + + def test_parse_empty_routing_context(self): + verify_routing_context({}, "bolt+routing://127.0.0.1/cat?") + verify_routing_context({}, "bolt+routing://127.0.0.1/cat") + verify_routing_context({}, "bolt+routing://127.0.0.1/?") + verify_routing_context({}, "bolt+routing://127.0.0.1/") + verify_routing_context({}, "bolt+routing://127.0.0.1?") + verify_routing_context({}, "bolt+routing://127.0.0.1") + + def test_parse_routing_context(self): + verify_routing_context({"name": "molly", "color": "white"}, "bolt+routing://127.0.0.1/cat?name=molly&color=white") + verify_routing_context({"name": "molly", "color": "white"}, "bolt+routing://127.0.0.1/?name=molly&color=white") + verify_routing_context({"name": "molly", "color": "white"}, "bolt+routing://127.0.0.1?name=molly&color=white") + + def test_should_error_when_value_missing(self): + with self.assertRaises(ValueError): + SocketAddress.parse_routing_context("bolt+routing://127.0.0.1/?name=&color=white") + + def test_should_error_when_key_duplicate(self): + with self.assertRaises(ValueError): + SocketAddress.parse_routing_context("bolt+routing://127.0.0.1/?name=molly&name=white") + + +def verify_routing_context(expected, uri): + context = SocketAddress.parse_routing_context(uri) + assert context == expected diff --git a/test/unit/test_routing.py b/test/unit/test_routing.py index 9e26689db..e6f2f80c9 100644 --- a/test/unit/test_routing.py +++ b/test/unit/test_routing.py @@ -227,5 +227,7 @@ def test_update_should_replace_ttl(self): class RoutingConnectionPoolConstructionTestCase(TestCase): def test_should_populate_initial_router(self): - with RoutingConnectionPool(connector, ("127.0.0.1", 9001)) as pool: - assert pool.routing_table.routers == {("127.0.0.1", 9001)} + initial_router = ("127.0.0.1", 9001) + router = ("127.0.0.1", 9002) + with RoutingConnectionPool(connector, initial_router, {}, router) as pool: + assert pool.routing_table.routers == {("127.0.0.1", 9002)}