diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index 73b9fcfa7..b45b881d1 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -154,8 +154,9 @@ class RoutingConnectionPool(ConnectionPool): routing_info_procedure = "dbms.cluster.routing.getServers" - def __init__(self, connector, *routers): + def __init__(self, connector, initial_address, *routers): super(RoutingConnectionPool, self).__init__(connector) + self.initial_address = initial_address self.routing_table = RoutingTable(routers) self.refresh_lock = Lock() @@ -216,16 +217,32 @@ def fetch_routing_table(self, address): # At least one of each is fine, so return this table return new_routing_table + def update_routing_table_with_routers(self, routers): + """Try to update routing tables with the given routers + :return: True if the routing table is successfully updated, otherwise False + """ + for router in routers: + new_routing_table = self.fetch_routing_table(router) + if new_routing_table is not None: + self.routing_table.update(new_routing_table) + return True + return False + def update_routing_table(self): """ Update the routing table from the first router able to provide valid routing information. """ # copied because it can be modified copy_of_routers = list(self.routing_table.routers) + if self.update_routing_table_with_routers(copy_of_routers): + return + + initial_routers = resolve(self.initial_address) for router in copy_of_routers: - new_routing_table = self.fetch_routing_table(router) - if new_routing_table is not None: - self.routing_table.update(new_routing_table) + if router in initial_routers: + initial_routers.remove(router) + if initial_routers: + if self.update_routing_table_with_routers(initial_routers): return # None of the routers have been successful, so just fail @@ -304,7 +321,7 @@ def __init__(self, uri, **config): def connector(a): return connect(a, security_plan.ssl_context, **config) - pool = RoutingConnectionPool(connector, *resolve(initial_address)) + pool = RoutingConnectionPool(connector, initial_address, *resolve(initial_address)) try: pool.update_routing_table() except: diff --git a/test/stub/test_routing.py b/test/stub/test_routing.py index bffcaa75d..6cf0ae357 100644 --- a/test/stub/test_routing.py +++ b/test/stub/test_routing.py @@ -48,6 +48,7 @@ "X": 1, } +UNREACHABLE_ADDRESS = ("127.0.0.1", 8080) def connector(address): return connect(address, auth=basic_auth("neotest", "neotest")) @@ -58,7 +59,7 @@ class RoutingConnectionPoolFetchRoutingInfoTestCase(StubTestCase): def test_should_get_info_from_router(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: result = pool.fetch_routing_info(address) assert len(result) == 1 record = result[0] @@ -72,7 +73,7 @@ def test_should_get_info_from_router(self): def test_should_remove_router_if_cannot_connect(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: assert address in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers @@ -80,14 +81,14 @@ def test_should_remove_router_if_cannot_connect(self): def test_should_remove_router_if_connection_drops(self): with StubCluster({9001: "rude_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: assert address in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers def test_should_not_fail_if_cannot_connect_but_router_already_removed(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: assert address not in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers @@ -95,35 +96,35 @@ def test_should_not_fail_if_cannot_connect_but_router_already_removed(self): def test_should_not_fail_if_connection_drops_but_router_already_removed(self): with StubCluster({9001: "rude_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: assert address not in pool.routing_table.routers _ = pool.fetch_routing_info(address) assert address not in pool.routing_table.routers def test_should_return_none_if_cannot_connect(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: result = pool.fetch_routing_info(address) assert result is None def test_should_return_none_if_connection_drops(self): with StubCluster({9001: "rude_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: result = pool.fetch_routing_info(address) assert result is None def test_should_fail_for_non_router(self): with StubCluster({9001: "non_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: with self.assertRaises(ServiceUnavailable): _ = pool.fetch_routing_info(address) def test_should_fail_if_database_error(self): with StubCluster({9001: "broken_router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: with self.assertRaises(ServiceUnavailable): _ = pool.fetch_routing_info(address) @@ -133,7 +134,7 @@ class RoutingConnectionPoolFetchRoutingTableTestCase(StubTestCase): def test_should_get_table_from_router(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: table = pool.fetch_routing_table(address) assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003)} @@ -143,28 +144,28 @@ def test_should_get_table_from_router(self): def test_null_info_should_return_null_table(self): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: table = pool.fetch_routing_table(address) assert table is None def test_no_routers_should_raise_protocol_error(self): with StubCluster({9001: "router_no_routers.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: with self.assertRaises(ProtocolError): _ = pool.fetch_routing_table(address) def test_no_readers_should_raise_protocol_error(self): with StubCluster({9001: "router_no_readers.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: with self.assertRaises(ProtocolError): _ = pool.fetch_routing_table(address) def test_no_writers_should_return_null_table(self): with StubCluster({9001: "router_no_writers.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: table = pool.fetch_routing_table(address) assert table is None @@ -183,8 +184,21 @@ class RoutingConnectionPoolUpdateRoutingTableTestCase(StubTestCase): (None, None, ServiceUnavailable): ServiceUnavailable, } + def test_roll_back_to_initial_server_if_failed_update_with_existing_routers(self): + with StubCluster({9001: "router.script"}): + initial_address =("127.0.0.1", 9001) # roll back addresses + routers = [("127.0.0.1", 9002), ("127.0.0.1", 9003)] # not reachable servers + with RoutingConnectionPool(connector, initial_address, *routers) as pool: + pool.update_routing_table() + table = pool.routing_table + assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002), + ("127.0.0.1", 9003)} + assert table.readers == {("127.0.0.1", 9004), ("127.0.0.1", 9005)} + assert table.writers == {("127.0.0.1", 9006)} + assert table.ttl == 300 + def test_update_with_no_routers_should_signal_service_unavailable(self): - with RoutingConnectionPool(connector) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool: with self.assertRaises(ServiceUnavailable): pool.update_routing_table() @@ -207,7 +221,7 @@ def _test_server_outcome(self, server_outcomes, overall_outcome): assert False, "Unexpected server outcome %r" % outcome routers.append(("127.0.0.1", port)) with StubCluster(servers): - with RoutingConnectionPool(connector, *routers) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, *routers) as pool: if overall_outcome is RoutingTable: pool.update_routing_table() table = pool.routing_table @@ -228,7 +242,7 @@ class RoutingConnectionPoolRefreshRoutingTableTestCase(StubTestCase): def test_should_update_if_stale(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: first_updated_time = pool.routing_table.last_updated_time pool.routing_table.ttl = 0 pool.refresh_routing_table() @@ -238,7 +252,7 @@ def test_should_update_if_stale(self): def test_should_not_update_if_fresh(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: pool.refresh_routing_table() first_updated_time = pool.routing_table.last_updated_time pool.refresh_routing_table() @@ -250,7 +264,7 @@ def test_should_not_update_if_fresh(self): # address = ("127.0.0.1", 9001) # table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD]) # - # with RoutingConnectionPool(connector, address) as pool: + # with RoutingConnectionPool(connector, not_reachable_address, address) as pool: # semaphore = Semaphore() # # class Refresher(Thread): @@ -297,7 +311,7 @@ def test_should_not_update_if_fresh(self): # address = ("127.0.0.1", 9001) # table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD]) # - # with RoutingConnectionPool(connector, address) as pool: + # with RoutingConnectionPool(connector, not_reachable_address, address) as pool: # semaphore = Semaphore() # # class Refresher(Thread): @@ -345,7 +359,7 @@ class RoutingConnectionPoolAcquireForReadTestCase(StubTestCase): def test_should_refresh(self): with StubCluster({9001: "router.script", 9004: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: assert not pool.routing_table.is_fresh() _ = pool.acquire(access_mode=READ_ACCESS) assert pool.routing_table.is_fresh() @@ -353,7 +367,7 @@ def test_should_refresh(self): def test_connected_to_reader(self): with StubCluster({9001: "router.script", 9004: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: assert not pool.routing_table.is_fresh() connection = pool.acquire(access_mode=READ_ACCESS) assert connection.server.address in pool.routing_table.readers @@ -363,7 +377,7 @@ def test_should_retry_if_first_reader_fails(self): 9004: "fail_on_init.script", 9005: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: assert not pool.routing_table.is_fresh() _ = pool.acquire(access_mode=READ_ACCESS) assert ("127.0.0.1", 9004) not in pool.routing_table.readers @@ -375,7 +389,7 @@ class RoutingConnectionPoolAcquireForWriteTestCase(StubTestCase): def test_should_refresh(self): with StubCluster({9001: "router.script", 9006: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: assert not pool.routing_table.is_fresh() _ = pool.acquire(access_mode=WRITE_ACCESS) assert pool.routing_table.is_fresh() @@ -383,7 +397,7 @@ def test_should_refresh(self): def test_connected_to_writer(self): with StubCluster({9001: "router.script", 9006: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: assert not pool.routing_table.is_fresh() connection = pool.acquire(access_mode=WRITE_ACCESS) assert connection.server.address in pool.routing_table.writers @@ -393,7 +407,7 @@ def test_should_retry_if_first_writer_fails(self): 9006: "fail_on_init.script", 9007: "empty.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: assert not pool.routing_table.is_fresh() _ = pool.acquire(access_mode=WRITE_ACCESS) assert ("127.0.0.1", 9006) not in pool.routing_table.writers @@ -405,7 +419,7 @@ class RoutingConnectionPoolRemoveTestCase(StubTestCase): def test_should_remove_router_from_routing_table_if_present(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: pool.refresh_routing_table() target = ("127.0.0.1", 9001) assert target in pool.routing_table.routers @@ -415,7 +429,7 @@ def test_should_remove_router_from_routing_table_if_present(self): def test_should_remove_reader_from_routing_table_if_present(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: pool.refresh_routing_table() target = ("127.0.0.1", 9004) assert target in pool.routing_table.readers @@ -425,7 +439,7 @@ def test_should_remove_reader_from_routing_table_if_present(self): def test_should_remove_writer_from_routing_table_if_present(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: pool.refresh_routing_table() target = ("127.0.0.1", 9006) assert target in pool.routing_table.writers @@ -435,7 +449,7 @@ def test_should_remove_writer_from_routing_table_if_present(self): def test_should_not_fail_if_absent(self): with StubCluster({9001: "router.script"}): address = ("127.0.0.1", 9001) - with RoutingConnectionPool(connector, address) as pool: + with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool: pool.refresh_routing_table() target = ("127.0.0.1", 9007) pool.remove(target) diff --git a/test/unit/test_routing.py b/test/unit/test_routing.py index 9e26689db..5361259e9 100644 --- a/test/unit/test_routing.py +++ b/test/unit/test_routing.py @@ -227,5 +227,7 @@ def test_update_should_replace_ttl(self): class RoutingConnectionPoolConstructionTestCase(TestCase): def test_should_populate_initial_router(self): - with RoutingConnectionPool(connector, ("127.0.0.1", 9001)) as pool: - assert pool.routing_table.routers == {("127.0.0.1", 9001)} + initial_router = ("127.0.0.1", 9001) + router = ("127.0.0.1", 9002) + with RoutingConnectionPool(connector, initial_router, router) as pool: + assert pool.routing_table.routers == {("127.0.0.1", 9002)}