diff --git a/neo4j/addressing.py b/neo4j/addressing.py index 8dca01e17..b05730c7e 100644 --- a/neo4j/addressing.py +++ b/neo4j/addressing.py @@ -25,6 +25,11 @@ from neo4j.compat import urlparse from neo4j.exceptions import AddressError +try: + from urllib.parse import parse_qs +except ImportError: + from urlparse import parse_qs + VALID_IPv4_SEGMENTS = [str(i).encode("latin1") for i in range(0x100)] VALID_IPv6_SEGMENT_CHARS = b"0123456789abcdef" @@ -79,6 +84,24 @@ def parse(cls, string, default_port=0): """ return cls.from_uri("//{}".format(string), default_port) + @classmethod + def parse_routing_context(cls, uri): + query = urlparse(uri).query + if not query: + return {} + + context = {} + parameters = parse_qs(query, True) + for key in parameters: + value_list = parameters[key] + if len(value_list) != 1: + raise ValueError("Duplicated query parameters with key '%s', value '%s' found in URL '%s'" % (key, value_list, uri)) + value = value_list[0] + if not value: + raise ValueError("Invalid parameters:'%s=%s' in URI '%s'." % (key, value, uri)) + context[key] = value + return context + def resolve(socket_address): try: diff --git a/neo4j/util.py b/neo4j/util.py index a91380777..d8b711b4f 100644 --- a/neo4j/util.py +++ b/neo4j/util.py @@ -25,6 +25,25 @@ from sys import stdout +class ServerVersion(object): + def __init__(self, product, version_tuple, tags_tuple): + self.product = product + self.version_tuple = version_tuple + self.tags_tuple = tags_tuple + + def at_least_version(self, major, minor): + return self.version_tuple >= (major, minor) + + @classmethod + def from_str(cls, full_version): + if full_version is None: + return ServerVersion("Neo4j", (3, 0), ()) + product, _, tagged_version = full_version.partition("/") + tags = tagged_version.split("-") + version = map(int, tags[0].split(".")) + return ServerVersion(product, tuple(version), tuple(tags[1:])) + + class ColourFormatter(logging.Formatter): """ Colour formatter for pretty log output. """ diff --git a/neo4j/v1/direct.py b/neo4j/v1/direct.py index 92376a72e..0278bb342 100644 --- a/neo4j/v1/direct.py +++ b/neo4j/v1/direct.py @@ -56,6 +56,8 @@ def __init__(self, uri, **config): # will carry out DNS resolution, leading to the possibility that # the connection pool may contain multiple IP address keys, one for # an old address and one for a new address. + if SocketAddress.parse_routing_context(uri): + raise ValueError("Parameters are not supported with scheme 'bolt'. Given URI: '%s'." % uri) self.address = SocketAddress.from_uri(uri, DEFAULT_PORT) self.security_plan = security_plan = SecurityPlan.build(**config) self.encrypted = security_plan.encrypted diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index b45b881d1..1b6c157ce 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -26,10 +26,11 @@ from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect from neo4j.compat.collections import MutableSet, OrderedDict from neo4j.exceptions import CypherError -from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS +from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS, fix_statement, fix_parameters from neo4j.v1.exceptions import SessionExpired from neo4j.v1.security import SecurityPlan from neo4j.v1.session import BoltSession +from neo4j.util import ServerVersion class RoundRobinSet(MutableSet): @@ -131,11 +132,12 @@ def __init__(self, routers=(), readers=(), writers=(), ttl=0): self.last_updated_time = self.timer() self.ttl = ttl - def is_fresh(self): + def is_fresh(self, access_mode): """ Indicator for whether routing information is still usable. """ expired = self.last_updated_time + self.ttl <= self.timer() - return not expired and len(self.routers) > 1 and self.readers and self.writers + has_server_for_mode = (access_mode == READ_ACCESS and self.readers) or (access_mode == WRITE_ACCESS and self.writers) + return not expired and self.routers and has_server_for_mode def update(self, new_routing_table): """ Update the current routing table with new routing information @@ -148,16 +150,34 @@ def update(self, new_routing_table): self.ttl = new_routing_table.ttl +class RoutingSession(BoltSession): + + call_get_servers = "CALL dbms.cluster.routing.getServers" + get_routing_table_param = "context" + call_get_routing_table = "CALL dbms.cluster.routing.getRoutingTable({%s})" % get_routing_table_param + + def routing_info_procedure(self, routing_context): + if ServerVersion.from_str(self._connection.server.version).at_least_version(3, 2): + return self.call_get_routing_table, {self.get_routing_table_param: routing_context} + else: + return self.call_get_servers, {} + + def __run__(self, ignored, routing_context): + # the statement is ignored as it will be get routing table procedure call. + statement, parameters = self.routing_info_procedure(routing_context) + return self._run(fix_statement(statement), fix_parameters(parameters)) + + class RoutingConnectionPool(ConnectionPool): """ Connection pool with routing table. """ - routing_info_procedure = "dbms.cluster.routing.getServers" - - def __init__(self, connector, initial_address, *routers): + def __init__(self, connector, initial_address, routing_context, *routers): super(RoutingConnectionPool, self).__init__(connector) self.initial_address = initial_address + self.routing_context = routing_context self.routing_table = RoutingTable(routers) + self.missing_writer = False self.refresh_lock = Lock() def fetch_routing_info(self, address): @@ -170,8 +190,8 @@ def fetch_routing_info(self, address): if routing support is broken """ try: - with BoltSession(lambda _: self.acquire_direct(address)) as session: - return list(session.run("CALL %s" % self.routing_info_procedure)) + with RoutingSession(lambda _: self.acquire_direct(address)) as session: + return list(session.run("ignored", self.routing_context)) except CypherError as error: if error.code == "Neo.ClientError.Procedure.ProcedureNotFound": raise ServiceUnavailable("Server {!r} does not support routing".format(address)) @@ -200,6 +220,11 @@ def fetch_routing_table(self, address): num_readers = len(new_routing_table.readers) num_writers = len(new_routing_table.writers) + # No writers are available. This likely indicates a temporary state, + # such as leader switching, so we should not signal an error. + # When no writers available, then we flag we are reading in absence of writer + self.missing_writer = (num_writers == 0) + # No routers if num_routers == 0: raise ProtocolError("No routing servers returned from server %r" % (address,)) @@ -208,12 +233,6 @@ def fetch_routing_table(self, address): if num_readers == 0: raise ProtocolError("No read servers returned from server %r" % (address,)) - # No writers - if num_writers == 0: - # No writers are available. This likely indicates a temporary state, - # such as leader switching, so we should not signal an error. - return None - # At least one of each is fine, so return this table return new_routing_table @@ -234,21 +253,30 @@ def update_routing_table(self): """ # copied because it can be modified copy_of_routers = list(self.routing_table.routers) + + has_tried_initial_routers = False + if self.missing_writer: + has_tried_initial_routers = True + if self.update_routing_table_with_routers(resolve(self.initial_address)): + return + if self.update_routing_table_with_routers(copy_of_routers): return - initial_routers = resolve(self.initial_address) - for router in copy_of_routers: - if router in initial_routers: - initial_routers.remove(router) - if initial_routers: - if self.update_routing_table_with_routers(initial_routers): - return + if not has_tried_initial_routers: + initial_routers = resolve(self.initial_address) + for router in copy_of_routers: + if router in initial_routers: + initial_routers.remove(router) + if initial_routers: + if self.update_routing_table_with_routers(initial_routers): + return + # None of the routers have been successful, so just fail raise ServiceUnavailable("Unable to retrieve routing information") - def refresh_routing_table(self): + def ensure_routing_table_is_fresh(self, access_mode): """ Update the routing table if stale. This method performs two freshness checks, before and after acquiring @@ -261,10 +289,13 @@ def refresh_routing_table(self): :return: `True` if an update was required, `False` otherwise. """ - if self.routing_table.is_fresh(): + if self.routing_table.is_fresh(access_mode): return False with self.refresh_lock: - if self.routing_table.is_fresh(): + if self.routing_table.is_fresh(access_mode): + if access_mode == READ_ACCESS: + # if reader is fresh but writers is not fresh, then we are reading in absence of writer + self.missing_writer = not self.routing_table.is_fresh(WRITE_ACCESS) return False self.update_routing_table() return True @@ -278,11 +309,12 @@ def acquire(self, access_mode=None): server_list = self.routing_table.writers else: raise ValueError("Unsupported access mode {}".format(access_mode)) + + self.ensure_routing_table_is_fresh(access_mode) while True: - address = None - while address is None: - self.refresh_routing_table() - address = next(server_list) + address = next(server_list) + if address is None: + break try: connection = self.acquire_direct(address) # should always be a resolved address connection.Error = SessionExpired @@ -290,6 +322,7 @@ def acquire(self, access_mode=None): self.remove(address) else: return connection + raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode) def remove(self, address): """ Remove an address from the connection pool, if present, closing @@ -313,6 +346,7 @@ def __init__(self, uri, **config): self.initial_address = initial_address = SocketAddress.from_uri(uri, DEFAULT_PORT) self.security_plan = security_plan = SecurityPlan.build(**config) self.encrypted = security_plan.encrypted + routing_context = SocketAddress.parse_routing_context(uri) if not security_plan.routing_compatible: # this error message is case-specific as there is only one incompatible # scenario right now @@ -321,7 +355,7 @@ def __init__(self, uri, **config): def connector(a): return connect(a, security_plan.ssl_context, **config) - pool = RoutingConnectionPool(connector, initial_address, *resolve(initial_address)) + pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address)) try: pool.update_routing_table() except: diff --git a/neo4j/v1/session.py b/neo4j/v1/session.py index 855f9860a..0cbdce4a2 100644 --- a/neo4j/v1/session.py +++ b/neo4j/v1/session.py @@ -34,7 +34,7 @@ class BoltSession(Session): :param bookmark: """ - def __run__(self, statement, parameters): + def _run(self, statement, parameters): assert isinstance(statement, unicode) assert isinstance(parameters, dict) @@ -52,6 +52,9 @@ def __run__(self, statement, parameters): return result + def __run__(self, statement, parameters): + return self._run(statement, parameters) + def __begin__(self): return self.__run__(u"BEGIN", {"bookmark": self._bookmark} if self._bookmark else {}) diff --git a/test/integration/tools.py b/test/integration/tools.py index ff261cc98..9402eee90 100644 --- a/test/integration/tools.py +++ b/test/integration/tools.py @@ -33,6 +33,7 @@ from boltkit.controller import WindowsController, UnixController from neo4j.v1 import GraphDatabase, AuthError +from neo4j.util import ServerVersion from test.env import NEO4J_SERVER_PACKAGE, NEO4J_USER, NEO4J_PASSWORD @@ -89,17 +90,11 @@ def server_version_info(cls): with GraphDatabase.driver(cls.bolt_uri, auth=cls.auth_token) as driver: with driver.session() as session: full_version = session.run("RETURN 1").summary().server.version - if full_version is None: - return "Neo4j", (3, 0), () - product, _, tagged_version = full_version.partition("/") - tags = tagged_version.split("-") - version = map(int, tags[0].split(".")) - return product, tuple(version), tuple(tags[1:]) + return ServerVersion.from_str(full_version) @classmethod def at_least_version(cls, major, minor): - _, server_version, _ = cls.server_version_info() - return server_version >= (major, minor) + return cls.server_version_info().at_least_version(major, minor); @classmethod def delete_known_hosts_file(cls): diff --git a/test/stub/scripts/get_routing_table.script b/test/stub/scripts/get_routing_table.script new file mode 100644 index 000000000..945ea64cd --- /dev/null +++ b/test/stub/scripts/get_routing_table.script @@ -0,0 +1,9 @@ +!: AUTO INIT +!: AUTO RESET + +S: SUCCESS {"server": "Neo4j/3.2.2"} +C: RUN "CALL dbms.cluster.routing.getRoutingTable({context})" {"context": {}} + PULL_ALL +S: SUCCESS {"fields": ["ttl", "servers"]} + RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]] + SUCCESS {} \ No newline at end of file diff --git a/test/stub/scripts/get_routing_table_with_context.script b/test/stub/scripts/get_routing_table_with_context.script new file mode 100644 index 000000000..ce4244744 --- /dev/null +++ b/test/stub/scripts/get_routing_table_with_context.script @@ -0,0 +1,9 @@ +!: AUTO INIT +!: AUTO RESET + +S: SUCCESS {"server": "Neo4j/3.2.3"} +C: RUN "CALL dbms.cluster.routing.getRoutingTable({context})" {"context": {"name": "molly", "age": "1"}} + PULL_ALL +S: SUCCESS {"fields": ["ttl", "servers"]} + RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]] + SUCCESS {} \ No newline at end of file diff --git a/test/stub/test_directdriver.py b/test/stub/test_directdriver.py index efb63062b..60d80d809 100644 --- a/test/stub/test_directdriver.py +++ b/test/stub/test_directdriver.py @@ -48,3 +48,9 @@ def test_direct_disconnect_on_pull_all(self): with self.assertRaises(ServiceUnavailable): with driver.session() as session: session.run("RETURN $x", {"x": 1}).consume() + + def test_direct_should_reject_routing_context(self): + uri = "bolt://127.0.0.1:9001/?name=molly&age=1" + with self.assertRaises(ValueError): + GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) + diff --git a/test/stub/test_routing.py b/test/stub/test_routing.py index 6cf0ae357..2c9a23cea 100644 --- a/test/stub/test_routing.py +++ b/test/stub/test_routing.py @@ -24,7 +24,6 @@ from test.stub.tools import StubCluster, StubTestCase - VALID_ROUTING_RECORD = { "ttl": 300, "servers": [ @@ -50,16 +49,20 @@ UNREACHABLE_ADDRESS = ("127.0.0.1", 8080) + def connector(address): return connect(address, auth=basic_auth("neotest", "neotest")) -class RoutingConnectionPoolFetchRoutingInfoTestCase(StubTestCase): +def RoutingPool(*routers): + return RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, {}, *routers) + +class RoutingConnectionPoolFetchRoutingInfoTestCase(StubTestCase): def test_should_get_info_from_router(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: result = pool.fetch_routing_info(address) assert len(result) == 1 record = result[0] @@ -73,7 +76,7 @@ def test_should_get_info_from_router(self): def test_should_remove_router_if_cannot_connect(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: assert address in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers @@ -81,14 +84,14 @@ def test_should_remove_router_if_cannot_connect(self): def test_should_remove_router_if_connection_drops(self): with StubCluster({9001: "rude_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: assert address in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers def test_should_not_fail_if_cannot_connect_but_router_already_removed(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: assert address not in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers @@ -96,82 +99,99 @@ def test_should_not_fail_if_cannot_connect_but_router_already_removed(self): def test_should_not_fail_if_connection_drops_but_router_already_removed(self): with StubCluster({9001: "rude_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: assert address not in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers def test_should_return_none_if_cannot_connect(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: result = pool.fetch_routing_info(address) assert result is None def test_should_return_none_if_connection_drops(self): with StubCluster({9001: "rude_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: result = pool.fetch_routing_info(address) assert result is None def test_should_fail_for_non_router(self): with StubCluster({9001: "non_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: with self.assertRaises(ServiceUnavailable): _ = pool.fetch_routing_info(address) def test_should_fail_if_database_error(self): with StubCluster({9001: "broken_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: with self.assertRaises(ServiceUnavailable): _ = pool.fetch_routing_info(address) + def test_should_call_get_routing_tables_with_context(self): + with StubCluster({9001: "get_routing_table_with_context.script"}): + address = ("127.0.0.1", 9001) + routing_context = {"name": "molly", "age": "1"} + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, routing_context) as pool: + pool.fetch_routing_info(address) -class RoutingConnectionPoolFetchRoutingTableTestCase(StubTestCase): + def test_should_call_get_routing_tables(self): + with StubCluster({9001: "get_routing_table.script"}): + address = ("127.0.0.1", 9001) + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, {}) as pool: + pool.fetch_routing_info(address) + +class RoutingConnectionPoolFetchRoutingTableTestCase(StubTestCase): def test_should_get_table_from_router(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: table = pool.fetch_routing_table(address) assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003)} assert table.readers == {("127.0.0.1", 9004), ("127.0.0.1", 9005)} assert table.writers == {("127.0.0.1", 9006)} assert table.ttl == 300 + assert not pool.missing_writer def test_null_info_should_return_null_table(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: table = pool.fetch_routing_table(address) assert table is None def test_no_routers_should_raise_protocol_error(self): with StubCluster({9001: "router_no_routers.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: with self.assertRaises(ProtocolError): _ = pool.fetch_routing_table(address) def test_no_readers_should_raise_protocol_error(self): with StubCluster({9001: "router_no_readers.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: with self.assertRaises(ProtocolError): _ = pool.fetch_routing_table(address) - def test_no_writers_should_return_null_table(self): + def test_no_writers_should_return_table_with_no_writer(self): with StubCluster({9001: "router_no_writers.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: table = pool.fetch_routing_table(address) - assert table is None + assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), + ("127.0.0.1", 9003)} + assert table.readers == {("127.0.0.1", 9004), ("127.0.0.1", 9005)} + assert not table.writers + assert table.ttl == 300 + assert pool.missing_writer class RoutingConnectionPoolUpdateRoutingTableTestCase(StubTestCase): - scenarios = { (None,): ServiceUnavailable, (RoutingTable,): RoutingTable, @@ -186,9 +206,9 @@ class RoutingConnectionPoolUpdateRoutingTableTestCase(StubTestCase): def test_roll_back_to_initial_server_if_failed_update_with_existing_routers(self): with StubCluster({9001: "router.script"}): - initial_address =("127.0.0.1", 9001) # roll back addresses - routers = [("127.0.0.1", 9002), ("127.0.0.1", 9003)] # not reachable servers - with RoutingConnectionPool(connector, initial_address, *routers) as pool: + initial_address = ("127.0.0.1", 9001) # roll back addresses + routers = [("127.0.0.1", 9002), ("127.0.0.1", 9003)] # not reachable servers + with RoutingConnectionPool(connector, initial_address, {}, *routers) as pool: pool.update_routing_table() table = pool.routing_table assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), @@ -197,8 +217,22 @@ def test_roll_back_to_initial_server_if_failed_update_with_existing_routers(self assert table.writers == {("127.0.0.1", 9006)} assert table.ttl == 300 + def test_try_initial_server_first_if_missing_writer(self): + with StubCluster({9001: "router.script"}): + initial_address = ("127.0.0.1", 9001) + with RoutingConnectionPool(connector, initial_address, {}) as pool: + pool.missing_writer = True + pool.update_routing_table() + table = pool.routing_table + assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), + ("127.0.0.1", 9003)} + assert table.readers == {("127.0.0.1", 9004), ("127.0.0.1", 9005)} + assert table.writers == {("127.0.0.1", 9006)} + assert table.ttl == 300 + assert not pool.missing_writer + def test_update_with_no_routers_should_signal_service_unavailable(self): - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: + with RoutingPool() as pool: with self.assertRaises(ServiceUnavailable): pool.update_routing_table() @@ -212,7 +246,7 @@ def _test_server_outcome(self, server_outcomes, overall_outcome): routers = [] for port, outcome in enumerate(server_outcomes, 9001): if outcome is None: - servers[port] = "router_no_writers.script" + servers[port] = "rude_router.script" elif outcome is RoutingTable: servers[port] = "router.script" elif outcome is ServiceUnavailable: @@ -221,7 +255,7 @@ def _test_server_outcome(self, server_outcomes, overall_outcome): assert False, "Unexpected server outcome %r" % outcome routers.append(("127.0.0.1", port)) with StubCluster(servers): - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, *routers) as pool: + with RoutingPool(*routers) as pool: if overall_outcome is RoutingTable: pool.update_routing_table() table = pool.routing_table @@ -237,34 +271,44 @@ def _test_server_outcome(self, server_outcomes, overall_outcome): assert False, "Unexpected overall outcome %r" % overall_outcome -class RoutingConnectionPoolRefreshRoutingTableTestCase(StubTestCase): - +class RoutingConnectionPoolEnsureRoutingTableTestCase(StubTestCase): def test_should_update_if_stale(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: + with RoutingPool(address) as pool: first_updated_time = pool.routing_table.last_updated_time pool.routing_table.ttl = 0 - pool.refresh_routing_table() + pool.ensure_routing_table_is_fresh(WRITE_ACCESS) second_updated_time = pool.routing_table.last_updated_time assert second_updated_time != first_updated_time + assert not pool.missing_writer def test_should_not_update_if_fresh(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: - pool.refresh_routing_table() + with RoutingPool(address) as pool: + pool.ensure_routing_table_is_fresh(WRITE_ACCESS) first_updated_time = pool.routing_table.last_updated_time - pool.refresh_routing_table() + pool.ensure_routing_table_is_fresh(WRITE_ACCESS) second_updated_time = pool.routing_table.last_updated_time assert second_updated_time == first_updated_time + assert not pool.missing_writer + + def test_should_flag_reading_without_writer(self): + with StubCluster({9001: "router_no_writers.script"}): + address = ("127.0.0.1", 9001) + with RoutingPool(address) as pool: + assert not pool.routing_table.is_fresh(READ_ACCESS) + assert not pool.routing_table.is_fresh(WRITE_ACCESS) + pool.ensure_routing_table_is_fresh(READ_ACCESS) + assert pool.missing_writer # TODO: fix flaky test # def test_concurrent_refreshes_should_not_block_if_fresh(self): # address = ("127.0.0.1", 9001) # table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD]) # - # with RoutingConnectionPool(connector, not_reachable_address, address) as pool: + # with RoutingPool(address) as pool: # semaphore = Semaphore() # # class Refresher(Thread): @@ -311,7 +355,7 @@ def test_should_not_update_if_fresh(self): # address = ("127.0.0.1", 9001) # table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD]) # - # with RoutingConnectionPool(connector, not_reachable_address, address) as pool: + # with RoutingPool(address) as pool: # semaphore = Semaphore() # # class Refresher(Thread): @@ -355,72 +399,94 @@ def test_should_not_update_if_fresh(self): class RoutingConnectionPoolAcquireForReadTestCase(StubTestCase): - def test_should_refresh(self): with StubCluster({9001: "router.script", 9004: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: - assert not pool.routing_table.is_fresh() + with RoutingPool(address) as pool: + assert not pool.routing_table.is_fresh(READ_ACCESS) _ = pool.acquire(access_mode=READ_ACCESS) - assert pool.routing_table.is_fresh() + assert pool.routing_table.is_fresh(READ_ACCESS) + assert not pool.missing_writer def test_connected_to_reader(self): with StubCluster({9001: "router.script", 9004: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: - assert not pool.routing_table.is_fresh() + with RoutingPool(address) as pool: + assert not pool.routing_table.is_fresh(READ_ACCESS) connection = pool.acquire(access_mode=READ_ACCESS) assert connection.server.address in pool.routing_table.readers + assert not pool.missing_writer def test_should_retry_if_first_reader_fails(self): with StubCluster({9001: "router.script", 9004: "fail_on_init.script", 9005: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: - assert not pool.routing_table.is_fresh() + with RoutingPool(address) as pool: + assert not pool.routing_table.is_fresh(READ_ACCESS) _ = pool.acquire(access_mode=READ_ACCESS) assert ("127.0.0.1", 9004) not in pool.routing_table.readers assert ("127.0.0.1", 9005) in pool.routing_table.readers + def test_should_connect_to_read_in_absent_of_writer(self): + with StubCluster({9001: "router_no_writers.script", 9004: "empty.script"}): + address = ("127.0.0.1", 9001) + with RoutingPool(address) as pool: + assert not pool.routing_table.is_fresh(READ_ACCESS) + connection = pool.acquire(access_mode=READ_ACCESS) + assert connection.server.address in pool.routing_table.readers + assert not pool.routing_table.is_fresh(WRITE_ACCESS) + assert pool.missing_writer -class RoutingConnectionPoolAcquireForWriteTestCase(StubTestCase): +class RoutingConnectionPoolAcquireForWriteTestCase(StubTestCase): def test_should_refresh(self): with StubCluster({9001: "router.script", 9006: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: - assert not pool.routing_table.is_fresh() + with RoutingPool(address) as pool: + assert not pool.routing_table.is_fresh(WRITE_ACCESS) _ = pool.acquire(access_mode=WRITE_ACCESS) - assert pool.routing_table.is_fresh() + assert pool.routing_table.is_fresh(WRITE_ACCESS) + assert not pool.missing_writer def test_connected_to_writer(self): with StubCluster({9001: "router.script", 9006: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: - assert not pool.routing_table.is_fresh() + with RoutingPool(address) as pool: + assert not pool.routing_table.is_fresh(WRITE_ACCESS) connection = pool.acquire(access_mode=WRITE_ACCESS) assert connection.server.address in pool.routing_table.writers + assert not pool.missing_writer def test_should_retry_if_first_writer_fails(self): with StubCluster({9001: "router_with_multiple_writers.script", 9006: "fail_on_init.script", 9007: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: - assert not pool.routing_table.is_fresh() + with RoutingPool(address) as pool: + assert not pool.routing_table.is_fresh(WRITE_ACCESS) _ = pool.acquire(access_mode=WRITE_ACCESS) assert ("127.0.0.1", 9006) not in pool.routing_table.writers assert ("127.0.0.1", 9007) in pool.routing_table.writers + def test_should_error_to_writer_in_absent_of_reader(self): + with StubCluster({9001: "router_no_readers.script"}): + address = ("127.0.0.1", 9001) + with RoutingPool(address) as pool: + assert not pool.routing_table.is_fresh(WRITE_ACCESS) + with self.assertRaises(ProtocolError): + _ = pool.acquire(access_mode=WRITE_ACCESS) + assert not pool.routing_table.is_fresh(READ_ACCESS) + assert not pool.routing_table.is_fresh(WRITE_ACCESS) + assert not pool.missing_writer -class RoutingConnectionPoolRemoveTestCase(StubTestCase): +class RoutingConnectionPoolRemoveTestCase(StubTestCase): def test_should_remove_router_from_routing_table_if_present(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: - pool.refresh_routing_table() + with RoutingPool(address) as pool: + pool.ensure_routing_table_is_fresh(WRITE_ACCESS) target = ("127.0.0.1", 9001) assert target in pool.routing_table.routers pool.remove(target) @@ -429,8 +495,8 @@ def test_should_remove_router_from_routing_table_if_present(self): def test_should_remove_reader_from_routing_table_if_present(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: - pool.refresh_routing_table() + with RoutingPool(address) as pool: + pool.ensure_routing_table_is_fresh(WRITE_ACCESS) target = ("127.0.0.1", 9004) assert target in pool.routing_table.readers pool.remove(target) @@ -439,8 +505,8 @@ def test_should_remove_reader_from_routing_table_if_present(self): def test_should_remove_writer_from_routing_table_if_present(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: - pool.refresh_routing_table() + with RoutingPool(address) as pool: + pool.ensure_routing_table_is_fresh(WRITE_ACCESS) target = ("127.0.0.1", 9006) assert target in pool.routing_table.writers pool.remove(target) @@ -449,7 +515,7 @@ def test_should_remove_writer_from_routing_table_if_present(self): def test_should_not_fail_if_absent(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: - pool.refresh_routing_table() + with RoutingPool(address) as pool: + pool.ensure_routing_table_is_fresh(WRITE_ACCESS) target = ("127.0.0.1", 9007) pool.remove(target) diff --git a/test/stub/test_routingdriver.py b/test/stub/test_routingdriver.py index 3fe952f07..debe7050a 100644 --- a/test/stub/test_routingdriver.py +++ b/test/stub/test_routingdriver.py @@ -179,3 +179,39 @@ def test_two_sessions_can_share_a_connection(self): session_2.close() session_1.close() + + def test_should_call_get_routing_table_procedure(self): + with StubCluster({9001: "get_routing_table.script", 9002: "return_1.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + with driver.session(READ_ACCESS) as session: + result = session.run("RETURN $x", {"x": 1}) + for record in result: + assert record["x"] == 1 + assert result.summary().server.address == ('127.0.0.1', 9002) + + def test_should_call_get_routing_table_with_context(self): + with StubCluster({9001: "get_routing_table_with_context.script", 9002: "return_1.script"}): + uri = "bolt+routing://127.0.0.1:9001/?name=molly&age=1" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + with driver.session(READ_ACCESS) as session: + result = session.run("RETURN $x", {"x": 1}) + for record in result: + assert record["x"] == 1 + assert result.summary().server.address == ('127.0.0.1', 9002) + + def test_should_serve_read_when_missing_writer(self): + with StubCluster({9001: "router_no_writers.script", 9005: "return_1.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + with driver.session(READ_ACCESS) as session: + result = session.run("RETURN $x", {"x": 1}) + for record in result: + assert record["x"] == 1 + assert result.summary().server.address == ('127.0.0.1', 9005) + + def test_should_error_when_missing_reader(self): + with StubCluster({9001: "router_no_readers.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with self.assertRaises(ProtocolError): + GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) diff --git a/test/unit/test_addressing.py b/test/unit/test_addressing.py index 33c8d8aea..4c708d83d 100644 --- a/test/unit/test_addressing.py +++ b/test/unit/test_addressing.py @@ -41,3 +41,29 @@ def test_should_parse_host_name_and_port(self): def test_should_fail_on_non_numeric_port(self): with self.assertRaises(ValueError): _ = SocketAddress.parse("127.0.0.1:X") + + def test_parse_empty_routing_context(self): + verify_routing_context({}, "bolt+routing://127.0.0.1/cat?") + verify_routing_context({}, "bolt+routing://127.0.0.1/cat") + verify_routing_context({}, "bolt+routing://127.0.0.1/?") + verify_routing_context({}, "bolt+routing://127.0.0.1/") + verify_routing_context({}, "bolt+routing://127.0.0.1?") + verify_routing_context({}, "bolt+routing://127.0.0.1") + + def test_parse_routing_context(self): + verify_routing_context({"name": "molly", "color": "white"}, "bolt+routing://127.0.0.1/cat?name=molly&color=white") + verify_routing_context({"name": "molly", "color": "white"}, "bolt+routing://127.0.0.1/?name=molly&color=white") + verify_routing_context({"name": "molly", "color": "white"}, "bolt+routing://127.0.0.1?name=molly&color=white") + + def test_should_error_when_value_missing(self): + with self.assertRaises(ValueError): + SocketAddress.parse_routing_context("bolt+routing://127.0.0.1/?name=&color=white") + + def test_should_error_when_key_duplicate(self): + with self.assertRaises(ValueError): + SocketAddress.parse_routing_context("bolt+routing://127.0.0.1/?name=molly&name=white") + + +def verify_routing_context(expected, uri): + context = SocketAddress.parse_routing_context(uri) + assert context == expected diff --git a/test/unit/test_routing.py b/test/unit/test_routing.py index 5361259e9..2b9fce5ce 100644 --- a/test/unit/test_routing.py +++ b/test/unit/test_routing.py @@ -24,6 +24,7 @@ from neo4j.bolt.connection import connect from neo4j.v1.routing import RoundRobinSet, RoutingTable, RoutingConnectionPool from neo4j.v1.security import basic_auth +from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS VALID_ROUTING_RECORD = { @@ -144,7 +145,8 @@ class RoutingTableConstructionTestCase(TestCase): def test_should_be_initially_stale(self): table = RoutingTable() - assert not table.is_fresh() + assert not table.is_fresh(READ_ACCESS) + assert not table.is_fresh(WRITE_ACCESS) class RoutingTableParseRoutingInfoTestCase(TestCase): @@ -180,22 +182,26 @@ class RoutingTableFreshnessTestCase(TestCase): def test_should_be_fresh_after_update(self): table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD]) - assert table.is_fresh() + assert table.is_fresh(READ_ACCESS) + assert table.is_fresh(WRITE_ACCESS) def test_should_become_stale_on_expiry(self): table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD]) table.ttl = 0 - assert not table.is_fresh() + assert not table.is_fresh(READ_ACCESS) + assert not table.is_fresh(WRITE_ACCESS) def test_should_become_stale_if_no_readers(self): table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD]) table.readers.clear() - assert not table.is_fresh() + assert not table.is_fresh(READ_ACCESS) + assert table.is_fresh(WRITE_ACCESS) def test_should_become_stale_if_no_writers(self): table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD]) table.writers.clear() - assert not table.is_fresh() + assert table.is_fresh(READ_ACCESS) + assert not table.is_fresh(WRITE_ACCESS) class RoutingTableUpdateTestCase(TestCase): @@ -229,5 +235,5 @@ class RoutingConnectionPoolConstructionTestCase(TestCase): def test_should_populate_initial_router(self): initial_router = ("127.0.0.1", 9001) router = ("127.0.0.1", 9002) - with RoutingConnectionPool(connector, initial_router, router) as pool: + with RoutingConnectionPool(connector, initial_router, {}, router) as pool: assert pool.routing_table.routers == {("127.0.0.1", 9002)}