diff --git a/neo4j/v1/bolt.py b/neo4j/v1/bolt.py index 417aba0fc..f85e9e19d 100644 --- a/neo4j/v1/bolt.py +++ b/neo4j/v1/bolt.py @@ -29,7 +29,7 @@ from __future__ import division from base64 import b64encode -from collections import deque +from collections import deque, namedtuple from io import BytesIO import logging from os import makedirs, open as os_open, write as os_write, close as os_close, O_CREAT, O_APPEND, O_WRONLY @@ -81,12 +81,16 @@ log_error = log.error +Address = namedtuple("Address", ["host", "port"]) +ServerInfo = namedtuple("ServerInfo", ["address", "version"]) + + class BufferingSocket(object): def __init__(self, connection): self.connection = connection self.socket = connection.socket - self.address = self.socket.getpeername() + self.address = Address(*self.socket.getpeername()) self.buffer = bytearray() def fill(self): @@ -132,7 +136,7 @@ class ChunkChannel(object): def __init__(self, sock): self.socket = sock - self.address = sock.getpeername() + self.address = Address(*sock.getpeername()) self.raw = BytesIO() self.output_buffer = [] self.output_size = 0 @@ -206,6 +210,22 @@ def on_ignored(self, metadata=None): pass +class InitResponse(Response): + + def on_success(self, metadata): + super(InitResponse, self).on_success(metadata) + connection = self.connection + address = Address(*connection.socket.getpeername()) + version = metadata.get("server") + connection.server = ServerInfo(address, version) + + def on_failure(self, metadata): + code = metadata.get("code") + error = (Unauthorized if code == "Neo.ClientError.Security.Unauthorized" else + ServiceUnavailable) + raise error(metadata.get("message", "INIT failed")) + + class Connection(object): """ Server connection for Bolt protocol v1. @@ -222,15 +242,15 @@ class Connection(object): defunct = False - server_version = None # TODO: remove this when PR#108 is merged - #: The pool of which this connection is a member pool = None + #: Server version details + server = None + def __init__(self, sock, **config): self.socket = sock self.buffering_socket = BufferingSocket(self) - self.address = sock.getpeername() self.channel = ChunkChannel(sock) self.packer = Packer(self.channel) self.unpacker = Unpacker() @@ -251,19 +271,7 @@ def __init__(self, sock, **config): # Pick up the server certificate, if any self.der_encoded_server_certificate = config.get("der_encoded_server_certificate") - def on_success(metadata): - self.server_version = metadata.get("server") - - def on_failure(metadata): - code = metadata.get("code") - error = (Unauthorized if code == "Neo.ClientError.Security.Unauthorized" else - ServiceUnavailable) - raise error(metadata.get("message", "INIT failed")) - - response = Response(self) - response.on_success = on_success - response.on_failure = on_failure - + response = InitResponse(self) self.append(INIT, (self.user_agent, self.auth_dict), response=response) self.sync() @@ -323,18 +331,18 @@ def send(self): """ Send all queued messages to the server. """ if self.closed: - raise ServiceUnavailable("Failed to write to closed connection %r" % (self.address,)) + raise ServiceUnavailable("Failed to write to closed connection %r" % (self.server.address,)) if self.defunct: - raise ServiceUnavailable("Failed to write to defunct connection %r" % (self.address,)) + raise ServiceUnavailable("Failed to write to defunct connection %r" % (self.server.address,)) self.channel.send() def fetch(self): """ Receive exactly one message from the server. """ if self.closed: - raise ServiceUnavailable("Failed to read from closed connection %r" % (self.address,)) + raise ServiceUnavailable("Failed to read from closed connection %r" % (self.server.address,)) if self.defunct: - raise ServiceUnavailable("Failed to read from defunct connection %r" % (self.address,)) + raise ServiceUnavailable("Failed to read from defunct connection %r" % (self.server.address,)) try: message_data = self.buffering_socket.read_message() except ProtocolError: diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index a408fdfe1..a58c758b2 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -22,7 +22,7 @@ from threading import Lock from time import clock -from .bolt import ConnectionPool +from .bolt import Address, ConnectionPool from .compat.collections import MutableSet, OrderedDict from .exceptions import CypherError, ProtocolError, ServiceUnavailable @@ -94,7 +94,7 @@ def parse_address(cls, address): """ Convert an address string to a tuple. """ host, _, port = address.partition(":") - return host, int(port) + return Address(host, int(port)) @classmethod def parse_routing_info(cls, records): diff --git a/test/test_driver.py b/test/test_driver.py index 55580f6fb..76de32c48 100644 --- a/test/test_driver.py +++ b/test/test_driver.py @@ -159,7 +159,7 @@ def test_should_be_able_to_read(self): result = session.run("RETURN $x", {"x": 1}) for record in result: assert record["x"] == 1 - assert session.connection.address == ('127.0.0.1', 9004) + assert session.connection.server.address == ('127.0.0.1', 9004) def test_should_be_able_to_write(self): with StubCluster({9001: "router.script", 9006: "create_a.script"}): @@ -168,7 +168,7 @@ def test_should_be_able_to_write(self): with driver.session(WRITE_ACCESS) as session: result = session.run("CREATE (a $x)", {"x": {"name": "Alice"}}) assert not list(result) - assert session.connection.address == ('127.0.0.1', 9006) + assert session.connection.server.address == ('127.0.0.1', 9006) def test_should_be_able_to_write_as_default(self): with StubCluster({9001: "router.script", 9006: "create_a.script"}): @@ -177,7 +177,7 @@ def test_should_be_able_to_write_as_default(self): with driver.session() as session: result = session.run("CREATE (a $x)", {"x": {"name": "Alice"}}) assert not list(result) - assert session.connection.address == ('127.0.0.1', 9006) + assert session.connection.server.address == ('127.0.0.1', 9006) def test_routing_disconnect_on_run(self): with StubCluster({9001: "router.script", 9004: "disconnect_on_run.script"}): diff --git a/test/test_routing.py b/test/test_routing.py index 8a30aa3e8..b4c60b12f 100644 --- a/test/test_routing.py +++ b/test/test_routing.py @@ -575,7 +575,7 @@ def test_connected_to_reader(self): with RoutingConnectionPool(connector, address) as pool: assert not pool.routing_table.is_fresh() connection = pool.acquire_for_read() - assert connection.address in pool.routing_table.readers + assert connection.server.address in pool.routing_table.readers def test_should_retry_if_first_reader_fails(self): with StubCluster({9001: "router.script", @@ -605,7 +605,7 @@ def test_connected_to_writer(self): with RoutingConnectionPool(connector, address) as pool: assert not pool.routing_table.is_fresh() connection = pool.acquire_for_write() - assert connection.address in pool.routing_table.writers + assert connection.server.address in pool.routing_table.writers def test_should_retry_if_first_writer_fails(self): with StubCluster({9001: "router_with_multiple_writers.script", diff --git a/test/test_session.py b/test/test_session.py index 6fc4fd3fd..e3b9024ce 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -37,7 +37,7 @@ def get_server_version(): driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN, encrypted=False) with driver.session() as session: - full_version = session.connection.server_version + full_version = session.connection.server.version if full_version is None: return "Neo4j", (3, 0), () product, _, tagged_version = full_version.partition("/")