Skip to content

Roll back to initial server if running out of routers #155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 27, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions neo4j/v1/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,9 @@ class RoutingConnectionPool(ConnectionPool):

routing_info_procedure = "dbms.cluster.routing.getServers"

def __init__(self, connector, *routers):
def __init__(self, connector, initial_address, *routers):
super(RoutingConnectionPool, self).__init__(connector)
self.initial_address = initial_address
self.routing_table = RoutingTable(routers)
self.refresh_lock = Lock()

Expand Down Expand Up @@ -216,16 +217,32 @@ def fetch_routing_table(self, address):
# At least one of each is fine, so return this table
return new_routing_table

def update_routing_table_with_routers(self, routers):
"""Try to update routing tables with the given routers
:return: True if the routing table is successfully updated, otherwise False
"""
for router in routers:
new_routing_table = self.fetch_routing_table(router)
if new_routing_table is not None:
self.routing_table.update(new_routing_table)
return True
return False

def update_routing_table(self):
""" Update the routing table from the first router able to provide
valid routing information.
"""
# copied because it can be modified
copy_of_routers = list(self.routing_table.routers)
if self.update_routing_table_with_routers(copy_of_routers):
return

initial_routers = resolve(self.initial_address)
for router in copy_of_routers:
new_routing_table = self.fetch_routing_table(router)
if new_routing_table is not None:
self.routing_table.update(new_routing_table)
if router in initial_routers:
initial_routers.remove(router)
if initial_routers:
if self.update_routing_table_with_routers(initial_routers):
return

# None of the routers have been successful, so just fail
Expand Down Expand Up @@ -304,7 +321,7 @@ def __init__(self, uri, **config):
def connector(a):
return connect(a, security_plan.ssl_context, **config)

pool = RoutingConnectionPool(connector, *resolve(initial_address))
pool = RoutingConnectionPool(connector, initial_address, *resolve(initial_address))
try:
pool.update_routing_table()
except:
Expand Down
74 changes: 44 additions & 30 deletions test/stub/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
"X": 1,
}

UNREACHABLE_ADDRESS = ("127.0.0.1", 8080)

def connector(address):
return connect(address, auth=basic_auth("neotest", "neotest"))
Expand All @@ -58,7 +59,7 @@ class RoutingConnectionPoolFetchRoutingInfoTestCase(StubTestCase):
def test_should_get_info_from_router(self):
with StubCluster({9001: "router.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool:
result = pool.fetch_routing_info(address)
assert len(result) == 1
record = result[0]
Expand All @@ -72,58 +73,58 @@ def test_should_get_info_from_router(self):

def test_should_remove_router_if_cannot_connect(self):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
assert address in pool.routing_table.routers
_ = pool.fetch_routing_info(address)
assert address not in pool.routing_table.routers

def test_should_remove_router_if_connection_drops(self):
with StubCluster({9001: "rude_router.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
assert address in pool.routing_table.routers
_ = pool.fetch_routing_info(address)
assert address not in pool.routing_table.routers

def test_should_not_fail_if_cannot_connect_but_router_already_removed(self):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool:
assert address not in pool.routing_table.routers
_ = pool.fetch_routing_info(address)
assert address not in pool.routing_table.routers

def test_should_not_fail_if_connection_drops_but_router_already_removed(self):
with StubCluster({9001: "rude_router.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool:
assert address not in pool.routing_table.routers
_ = pool.fetch_routing_info(address)
assert address not in pool.routing_table.routers

def test_should_return_none_if_cannot_connect(self):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
result = pool.fetch_routing_info(address)
assert result is None

def test_should_return_none_if_connection_drops(self):
with StubCluster({9001: "rude_router.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
result = pool.fetch_routing_info(address)
assert result is None

def test_should_fail_for_non_router(self):
with StubCluster({9001: "non_router.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
with self.assertRaises(ServiceUnavailable):
_ = pool.fetch_routing_info(address)

def test_should_fail_if_database_error(self):
with StubCluster({9001: "broken_router.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
with self.assertRaises(ServiceUnavailable):
_ = pool.fetch_routing_info(address)

Expand All @@ -133,7 +134,7 @@ class RoutingConnectionPoolFetchRoutingTableTestCase(StubTestCase):
def test_should_get_table_from_router(self):
with StubCluster({9001: "router.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool:
table = pool.fetch_routing_table(address)
assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002),
("127.0.0.1", 9003)}
Expand All @@ -143,28 +144,28 @@ def test_should_get_table_from_router(self):

def test_null_info_should_return_null_table(self):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool:
table = pool.fetch_routing_table(address)
assert table is None

def test_no_routers_should_raise_protocol_error(self):
with StubCluster({9001: "router_no_routers.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool:
with self.assertRaises(ProtocolError):
_ = pool.fetch_routing_table(address)

def test_no_readers_should_raise_protocol_error(self):
with StubCluster({9001: "router_no_readers.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool:
with self.assertRaises(ProtocolError):
_ = pool.fetch_routing_table(address)

def test_no_writers_should_return_null_table(self):
with StubCluster({9001: "router_no_writers.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool:
table = pool.fetch_routing_table(address)
assert table is None

Expand All @@ -183,8 +184,21 @@ class RoutingConnectionPoolUpdateRoutingTableTestCase(StubTestCase):
(None, None, ServiceUnavailable): ServiceUnavailable,
}

def test_roll_back_to_initial_server_if_failed_update_with_existing_routers(self):
with StubCluster({9001: "router.script"}):
initial_address =("127.0.0.1", 9001) # roll back addresses
routers = [("127.0.0.1", 9002), ("127.0.0.1", 9003)] # not reachable servers
with RoutingConnectionPool(connector, initial_address, *routers) as pool:
pool.update_routing_table()
table = pool.routing_table
assert table.routers == {("127.0.0.1", 9001), ("127.0.0.1", 9002),
("127.0.0.1", 9003)}
assert table.readers == {("127.0.0.1", 9004), ("127.0.0.1", 9005)}
assert table.writers == {("127.0.0.1", 9006)}
assert table.ttl == 300

def test_update_with_no_routers_should_signal_service_unavailable(self):
with RoutingConnectionPool(connector) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS) as pool:
with self.assertRaises(ServiceUnavailable):
pool.update_routing_table()

Expand All @@ -207,7 +221,7 @@ def _test_server_outcome(self, server_outcomes, overall_outcome):
assert False, "Unexpected server outcome %r" % outcome
routers.append(("127.0.0.1", port))
with StubCluster(servers):
with RoutingConnectionPool(connector, *routers) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, *routers) as pool:
if overall_outcome is RoutingTable:
pool.update_routing_table()
table = pool.routing_table
Expand All @@ -228,7 +242,7 @@ class RoutingConnectionPoolRefreshRoutingTableTestCase(StubTestCase):
def test_should_update_if_stale(self):
with StubCluster({9001: "router.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
first_updated_time = pool.routing_table.last_updated_time
pool.routing_table.ttl = 0
pool.refresh_routing_table()
Expand All @@ -238,7 +252,7 @@ def test_should_update_if_stale(self):
def test_should_not_update_if_fresh(self):
with StubCluster({9001: "router.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
pool.refresh_routing_table()
first_updated_time = pool.routing_table.last_updated_time
pool.refresh_routing_table()
Expand All @@ -250,7 +264,7 @@ def test_should_not_update_if_fresh(self):
# address = ("127.0.0.1", 9001)
# table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD])
#
# with RoutingConnectionPool(connector, address) as pool:
# with RoutingConnectionPool(connector, not_reachable_address, address) as pool:
# semaphore = Semaphore()
#
# class Refresher(Thread):
Expand Down Expand Up @@ -297,7 +311,7 @@ def test_should_not_update_if_fresh(self):
# address = ("127.0.0.1", 9001)
# table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD])
#
# with RoutingConnectionPool(connector, address) as pool:
# with RoutingConnectionPool(connector, not_reachable_address, address) as pool:
# semaphore = Semaphore()
#
# class Refresher(Thread):
Expand Down Expand Up @@ -345,15 +359,15 @@ class RoutingConnectionPoolAcquireForReadTestCase(StubTestCase):
def test_should_refresh(self):
with StubCluster({9001: "router.script", 9004: "empty.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
assert not pool.routing_table.is_fresh()
_ = pool.acquire(access_mode=READ_ACCESS)
assert pool.routing_table.is_fresh()

def test_connected_to_reader(self):
with StubCluster({9001: "router.script", 9004: "empty.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
assert not pool.routing_table.is_fresh()
connection = pool.acquire(access_mode=READ_ACCESS)
assert connection.server.address in pool.routing_table.readers
Expand All @@ -363,7 +377,7 @@ def test_should_retry_if_first_reader_fails(self):
9004: "fail_on_init.script",
9005: "empty.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
assert not pool.routing_table.is_fresh()
_ = pool.acquire(access_mode=READ_ACCESS)
assert ("127.0.0.1", 9004) not in pool.routing_table.readers
Expand All @@ -375,15 +389,15 @@ class RoutingConnectionPoolAcquireForWriteTestCase(StubTestCase):
def test_should_refresh(self):
with StubCluster({9001: "router.script", 9006: "empty.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
assert not pool.routing_table.is_fresh()
_ = pool.acquire(access_mode=WRITE_ACCESS)
assert pool.routing_table.is_fresh()

def test_connected_to_writer(self):
with StubCluster({9001: "router.script", 9006: "empty.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
assert not pool.routing_table.is_fresh()
connection = pool.acquire(access_mode=WRITE_ACCESS)
assert connection.server.address in pool.routing_table.writers
Expand All @@ -393,7 +407,7 @@ def test_should_retry_if_first_writer_fails(self):
9006: "fail_on_init.script",
9007: "empty.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
assert not pool.routing_table.is_fresh()
_ = pool.acquire(access_mode=WRITE_ACCESS)
assert ("127.0.0.1", 9006) not in pool.routing_table.writers
Expand All @@ -405,7 +419,7 @@ class RoutingConnectionPoolRemoveTestCase(StubTestCase):
def test_should_remove_router_from_routing_table_if_present(self):
with StubCluster({9001: "router.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
pool.refresh_routing_table()
target = ("127.0.0.1", 9001)
assert target in pool.routing_table.routers
Expand All @@ -415,7 +429,7 @@ def test_should_remove_router_from_routing_table_if_present(self):
def test_should_remove_reader_from_routing_table_if_present(self):
with StubCluster({9001: "router.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
pool.refresh_routing_table()
target = ("127.0.0.1", 9004)
assert target in pool.routing_table.readers
Expand All @@ -425,7 +439,7 @@ def test_should_remove_reader_from_routing_table_if_present(self):
def test_should_remove_writer_from_routing_table_if_present(self):
with StubCluster({9001: "router.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
pool.refresh_routing_table()
target = ("127.0.0.1", 9006)
assert target in pool.routing_table.writers
Expand All @@ -435,7 +449,7 @@ def test_should_remove_writer_from_routing_table_if_present(self):
def test_should_not_fail_if_absent(self):
with StubCluster({9001: "router.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
with RoutingConnectionPool(connector, UNREACHABLE_ADDRESS, address) as pool:
pool.refresh_routing_table()
target = ("127.0.0.1", 9007)
pool.remove(target)
6 changes: 4 additions & 2 deletions test/unit/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,5 +227,7 @@ def test_update_should_replace_ttl(self):
class RoutingConnectionPoolConstructionTestCase(TestCase):

def test_should_populate_initial_router(self):
with RoutingConnectionPool(connector, ("127.0.0.1", 9001)) as pool:
assert pool.routing_table.routers == {("127.0.0.1", 9001)}
initial_router = ("127.0.0.1", 9001)
router = ("127.0.0.1", 9002)
with RoutingConnectionPool(connector, initial_router, router) as pool:
assert pool.routing_table.routers == {("127.0.0.1", 9002)}