From 2c0b7086189b3950fa4e0a25d32be146638221bb Mon Sep 17 00:00:00 2001 From: Zhen Date: Tue, 19 Sep 2017 11:52:20 +0200 Subject: [PATCH 1/3] Add max connection lifetime on each connection By default the max connection lifetime is infinite. --- neo4j/bolt/connection.py | 15 +++++++- test/integration/test_connection.py | 6 ++- test/unit/test_connection.py | 57 +++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 4 deletions(-) create mode 100644 test/unit/test_connection.py diff --git a/neo4j/bolt/connection.py b/neo4j/bolt/connection.py index 28e4adcac..75fd4851f 100644 --- a/neo4j/bolt/connection.py +++ b/neo4j/bolt/connection.py @@ -42,11 +42,14 @@ from neo4j.meta import version from neo4j.packstream import Packer, Unpacker from neo4j.util import import_best as _import_best +from time import clock ChunkedInputBuffer = _import_best("neo4j.bolt._io", "neo4j.bolt.io").ChunkedInputBuffer ChunkedOutputBuffer = _import_best("neo4j.bolt._io", "neo4j.bolt.io").ChunkedOutputBuffer +INFINITE_CONNECTION_LIFETIME = -1 +DEFAULT_MAX_CONNECTION_LIFETIME = INFINITE_CONNECTION_LIFETIME DEFAULT_CONNECTION_TIMEOUT = 5.0 DEFAULT_PORT = 7687 DEFAULT_USER_AGENT = "neo4j-python/%s" % version @@ -178,6 +181,8 @@ def __init__(self, address, sock, error_handler, **config): self.packer = Packer(self.output_buffer) self.unpacker = Unpacker() self.responses = deque() + self._max_connection_lifetime = config.get("max_connection_lifetime", DEFAULT_MAX_CONNECTION_LIFETIME) + self._creation_timestamp = clock() # Determine the user agent and ensure it is a Unicode value user_agent = config.get("user_agent", DEFAULT_USER_AGENT) @@ -201,6 +206,7 @@ def __init__(self, address, sock, error_handler, **config): # Pick up the server certificate, if any self.der_encoded_server_certificate = config.get("der_encoded_server_certificate") + def Init(self): response = InitResponse(self) self.append(INIT, (self.user_agent, self.auth_dict), response=response) self.sync() @@ -360,6 +366,9 @@ def _unpack(self): more = False return details, summary_signature, summary_metadata + def timedout(self): + return 0 <= self._max_connection_lifetime <= clock() - self._creation_timestamp + def sync(self): """ Send and fetch all outstanding messages. @@ -425,7 +434,7 @@ def acquire_direct(self, address): except KeyError: connections = self.connections[address] = deque() for connection in list(connections): - if connection.closed() or connection.defunct(): + if connection.closed() or connection.defunct() or connection.timedout(): connections.remove(connection) continue if not connection.in_use: @@ -600,8 +609,10 @@ def connect(address, ssl_context=None, error_handler=None, **config): s.shutdown(SHUT_RDWR) s.close() elif agreed_version == 1: - return Connection(address, s, der_encoded_server_certificate=der_encoded_server_certificate, + connection = Connection(address, s, der_encoded_server_certificate=der_encoded_server_certificate, error_handler=error_handler, **config) + connection.Init() + return connection elif agreed_version == 0x48545450: log_error("S: [CLOSE]") s.close() diff --git a/test/integration/test_connection.py b/test/integration/test_connection.py index 703f97df0..fb0d0ab7c 100644 --- a/test/integration/test_connection.py +++ b/test/integration/test_connection.py @@ -22,7 +22,6 @@ from socket import create_connection from neo4j.v1 import ConnectionPool, ServiceUnavailable, DirectConnectionErrorHandler - from test.integration.tools import IntegrationTestCase @@ -44,6 +43,9 @@ def closed(self): def defunct(self): return False + def timedout(self): + return False + def connector(address, _): return QuickConnection(create_connection(address)) @@ -119,4 +121,4 @@ def test_in_use_count(self): connection = self.pool.acquire_direct(address) self.assertEqual(self.pool.in_use_connection_count(address), 1) self.pool.release(connection) - self.assertEqual(self.pool.in_use_connection_count(address), 0) + self.assertEqual(self.pool.in_use_connection_count(address), 0) \ No newline at end of file diff --git a/test/unit/test_connection.py b/test/unit/test_connection.py new file mode 100644 index 000000000..1d10a4998 --- /dev/null +++ b/test/unit/test_connection.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) 2002-2017 "Neo Technology," +# Network Engine for Objects in Lund AB [http://neotechnology.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import TestCase +from neo4j.v1 import DirectConnectionErrorHandler +from neo4j.bolt import Connection + + +class FakeSocket(object): + def __init__(self, address): + self.address = address + + def getpeername(self): + return self.address + + def sendall(self, data): + return + + def close(self): + return + + +class ConnectionTestCase(TestCase): + + def test_conn_timedout(self): + address = ("127.0.0.1", 7687) + connection = Connection(address, FakeSocket(address), DirectConnectionErrorHandler(), max_connection_lifetime=0) + self.assertEqual(connection.timedout(), True) + + def test_conn_not_timedout_if_not_enabled(self): + address = ("127.0.0.1", 7687) + connection = Connection(address, FakeSocket(address), DirectConnectionErrorHandler(), + max_connection_lifetime=-1) + self.assertEqual(connection.timedout(), False) + + def test_conn_not_timedout(self): + address = ("127.0.0.1", 7687) + connection = Connection(address, FakeSocket(address), DirectConnectionErrorHandler(), + max_connection_lifetime=999999999) + self.assertEqual(connection.timedout(), False) \ No newline at end of file From 8a00a788eae9b55f4d3960e2b2b6464469355ac2 Mon Sep 17 00:00:00 2001 From: Zhen Date: Wed, 20 Sep 2017 16:19:17 +0200 Subject: [PATCH 2/3] Adding max_connection_pool_size as well as connection_acquisition_timeout Moved connection tests from integration test to unit test as no server is needed --- neo4j/bolt/connection.py | 68 +++++++++---- neo4j/v1/direct.py | 6 +- neo4j/v1/routing.py | 6 +- test/integration/test_connection.py | 124 ----------------------- test/unit/test_connection.py | 151 +++++++++++++++++++++++++++- 5 files changed, 200 insertions(+), 155 deletions(-) delete mode 100644 test/integration/test_connection.py diff --git a/neo4j/bolt/connection.py b/neo4j/bolt/connection.py index 75fd4851f..dbfbdc916 100644 --- a/neo4j/bolt/connection.py +++ b/neo4j/bolt/connection.py @@ -32,8 +32,9 @@ from select import select from socket import socket, SOL_SOCKET, SO_KEEPALIVE, SHUT_RDWR, error as SocketError, timeout as SocketTimeout, AF_INET, AF_INET6 from struct import pack as struct_pack, unpack as struct_unpack -from threading import RLock +from threading import RLock, Condition +from neo4j.v1 import ClientError from neo4j.addressing import SocketAddress, is_ip_address from neo4j.bolt.cert import KNOWN_HOSTS from neo4j.bolt.response import InitResponse, AckFailureResponse, ResetResponse @@ -48,9 +49,11 @@ ChunkedOutputBuffer = _import_best("neo4j.bolt._io", "neo4j.bolt.io").ChunkedOutputBuffer -INFINITE_CONNECTION_LIFETIME = -1 -DEFAULT_MAX_CONNECTION_LIFETIME = INFINITE_CONNECTION_LIFETIME +INFINITE = -1 +DEFAULT_MAX_CONNECTION_LIFETIME = INFINITE +DEFAULT_MAX_CONNECTION_POOL_SIZE = INFINITE DEFAULT_CONNECTION_TIMEOUT = 5.0 +DEFAULT_CONNECTION_ACQUISITION_TIMEOUT = 60 DEFAULT_PORT = 7687 DEFAULT_USER_AGENT = "neo4j-python/%s" % version @@ -405,11 +408,14 @@ class ConnectionPool(object): _closed = False - def __init__(self, connector, connection_error_handler): + def __init__(self, connector, connection_error_handler, **config): self.connector = connector self.connection_error_handler = connection_error_handler self.connections = {} self.lock = RLock() + self.cond = Condition(self.lock) + self._max_connection_pool_size = config.get("max_connection_pool_size", DEFAULT_MAX_CONNECTION_POOL_SIZE) + self._connection_acquisition_timeout = config.get("connection_acquisition_timeout", DEFAULT_CONNECTION_ACQUISITION_TIMEOUT) def __enter__(self): return self @@ -433,23 +439,42 @@ def acquire_direct(self, address): connections = self.connections[address] except KeyError: connections = self.connections[address] = deque() - for connection in list(connections): - if connection.closed() or connection.defunct() or connection.timedout(): - connections.remove(connection) - continue - if not connection.in_use: - connection.in_use = True - return connection - try: - connection = self.connector(address, self.connection_error_handler) - except ServiceUnavailable: - self.remove(address) - raise - else: - connection.pool = self - connection.in_use = True - connections.append(connection) - return connection + + connection_acquisition_start_timestamp = clock() + while True: + # try to find a free connection in pool + for connection in list(connections): + if connection.closed() or connection.defunct() or connection.timedout(): + connections.remove(connection) + continue + if not connection.in_use: + connection.in_use = True + return connection + # all connections in pool are in-use + can_create_new_connection = self._max_connection_pool_size == INFINITE or len(connections) < self._max_connection_pool_size + if can_create_new_connection: + try: + connection = self.connector(address, self.connection_error_handler) + except ServiceUnavailable: + self.remove(address) + raise + else: + connection.pool = self + connection.in_use = True + connections.append(connection) + return connection + + # failed to obtain a connection from pool because the pool is full and no free connection in the pool + span_timeout = self._connection_acquisition_timeout - (clock() - connection_acquisition_start_timestamp) + if span_timeout > 0: + self.cond.wait(span_timeout) + # if timed out, then we throw error. This time computation is needed, as with python 2.7, we cannot + # tell if the condition is notified or timed out when we come to this line + if self._connection_acquisition_timeout <= (clock() - connection_acquisition_start_timestamp): + raise ClientError("Failed to obtain a connection from pool within {!r}s".format( + self._connection_acquisition_timeout)) + else: + raise ClientError("Failed to obtain a connection from pool within {!r}s".format(self._connection_acquisition_timeout)) def acquire(self, access_mode=None): """ Acquire a connection to a server that can satisfy a set of parameters. @@ -463,6 +488,7 @@ def release(self, connection): """ with self.lock: connection.in_use = False + self.cond.notify_all() def in_use_connection_count(self, address): """ Count the number of connections currently in use to a given diff --git a/neo4j/v1/direct.py b/neo4j/v1/direct.py index c63419ee7..869de9d2a 100644 --- a/neo4j/v1/direct.py +++ b/neo4j/v1/direct.py @@ -37,8 +37,8 @@ def __init__(self): class DirectConnectionPool(ConnectionPool): - def __init__(self, connector, address): - super(DirectConnectionPool, self).__init__(connector, DirectConnectionErrorHandler()) + def __init__(self, connector, address, **config): + super(DirectConnectionPool, self).__init__(connector, DirectConnectionErrorHandler(), **config) self.address = address def acquire(self, access_mode=None): @@ -73,7 +73,7 @@ def __init__(self, uri, **config): def connector(address, error_handler): return connect(address, security_plan.ssl_context, error_handler, **config) - pool = DirectConnectionPool(connector, self.address) + pool = DirectConnectionPool(connector, self.address, **config) pool.release(pool.acquire()) Driver.__init__(self, pool, **config) diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index 4bfe688c0..7eb845d6a 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -34,7 +34,7 @@ LOAD_BALANCING_STRATEGY_LEAST_CONNECTED = 0 LOAD_BALANCING_STRATEGY_ROUND_ROBIN = 1 -LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED +DEFAULT_LOAD_BALANCING_STRATEGY = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED class OrderedSet(MutableSet): @@ -166,7 +166,7 @@ class LoadBalancingStrategy(object): @classmethod def build(cls, connection_pool, **config): - load_balancing_strategy = config.get("load_balancing_strategy", LOAD_BALANCING_STRATEGY_DEFAULT) + load_balancing_strategy = config.get("load_balancing_strategy", DEFAULT_LOAD_BALANCING_STRATEGY) if load_balancing_strategy == LOAD_BALANCING_STRATEGY_LEAST_CONNECTED: return LeastConnectedLoadBalancingStrategy(connection_pool) elif load_balancing_strategy == LOAD_BALANCING_STRATEGY_ROUND_ROBIN: @@ -265,7 +265,7 @@ class RoutingConnectionPool(ConnectionPool): """ def __init__(self, connector, initial_address, routing_context, *routers, **config): - super(RoutingConnectionPool, self).__init__(connector, RoutingConnectionErrorHandler(self)) + super(RoutingConnectionPool, self).__init__(connector, RoutingConnectionErrorHandler(self), **config) self.initial_address = initial_address self.routing_context = routing_context self.routing_table = RoutingTable(routers) diff --git a/test/integration/test_connection.py b/test/integration/test_connection.py deleted file mode 100644 index fb0d0ab7c..000000000 --- a/test/integration/test_connection.py +++ /dev/null @@ -1,124 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -# Copyright (c) 2002-2017 "Neo Technology," -# Network Engine for Objects in Lund AB [http://neotechnology.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from socket import create_connection - -from neo4j.v1 import ConnectionPool, ServiceUnavailable, DirectConnectionErrorHandler -from test.integration.tools import IntegrationTestCase - - -class QuickConnection(object): - - def __init__(self, socket): - self.socket = socket - self.address = socket.getpeername() - - def reset(self): - pass - - def close(self): - self.socket.close() - - def closed(self): - return False - - def defunct(self): - return False - - def timedout(self): - return False - - -def connector(address, _): - return QuickConnection(create_connection(address)) - - -class ConnectionPoolTestCase(IntegrationTestCase): - - def setUp(self): - self.pool = ConnectionPool(connector, DirectConnectionErrorHandler()) - - def tearDown(self): - self.pool.close() - - def assert_pool_size(self, address, expected_active, expected_inactive): - try: - connections = self.pool.connections[address] - except KeyError: - assert 0 == expected_active - assert 0 == expected_inactive - else: - assert len([c for c in connections if c.in_use]) == expected_active - assert len([c for c in connections if not c.in_use]) == expected_inactive - - def test_can_acquire(self): - address = ("127.0.0.1", 7687) - connection = self.pool.acquire_direct(address) - assert connection.address == address - self.assert_pool_size(address, 1, 0) - - def test_can_acquire_twice(self): - address = ("127.0.0.1", 7687) - connection_1 = self.pool.acquire_direct(address) - connection_2 = self.pool.acquire_direct(address) - assert connection_1.address == address - assert connection_2.address == address - assert connection_1 is not connection_2 - self.assert_pool_size(address, 2, 0) - - def test_can_acquire_two_addresses(self): - address_1 = ("127.0.0.1", 7687) - address_2 = ("127.0.0.1", 7474) - connection_1 = self.pool.acquire_direct(address_1) - connection_2 = self.pool.acquire_direct(address_2) - assert connection_1.address == address_1 - assert connection_2.address == address_2 - self.assert_pool_size(address_1, 1, 0) - self.assert_pool_size(address_2, 1, 0) - - def test_can_acquire_and_release(self): - address = ("127.0.0.1", 7687) - connection = self.pool.acquire_direct(address) - self.assert_pool_size(address, 1, 0) - self.pool.release(connection) - self.assert_pool_size(address, 0, 1) - - def test_releasing_twice(self): - address = ("127.0.0.1", 7687) - connection = self.pool.acquire_direct(address) - self.pool.release(connection) - self.assert_pool_size(address, 0, 1) - self.pool.release(connection) - self.assert_pool_size(address, 0, 1) - - def test_cannot_acquire_after_close(self): - with ConnectionPool(lambda a: QuickConnection(create_connection(a)), DirectConnectionErrorHandler()) as pool: - pool.close() - with self.assertRaises(ServiceUnavailable): - _ = pool.acquire_direct("X") - - def test_in_use_count(self): - address = ("127.0.0.1", 7687) - self.assertEqual(self.pool.in_use_connection_count(address), 0) - connection = self.pool.acquire_direct(address) - self.assertEqual(self.pool.in_use_connection_count(address), 1) - self.pool.release(connection) - self.assertEqual(self.pool.in_use_connection_count(address), 0) \ No newline at end of file diff --git a/test/unit/test_connection.py b/test/unit/test_connection.py index 1d10a4998..e9ea0ea32 100644 --- a/test/unit/test_connection.py +++ b/test/unit/test_connection.py @@ -17,10 +17,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from __future__ import print_function +import time from unittest import TestCase -from neo4j.v1 import DirectConnectionErrorHandler -from neo4j.bolt import Connection +from threading import Thread, Event +from neo4j.v1 import DirectConnectionErrorHandler, ClientError, ServiceUnavailable +from neo4j.bolt import Connection, ConnectionPool class FakeSocket(object): @@ -37,6 +39,32 @@ def close(self): return +class QuickConnection(object): + + def __init__(self, socket): + self.socket = socket + self.address = socket.getpeername() + + def reset(self): + pass + + def close(self): + self.socket.close() + + def closed(self): + return False + + def defunct(self): + return False + + def timedout(self): + return False + + +def connector(address, _): + return QuickConnection(FakeSocket(address)) + + class ConnectionTestCase(TestCase): def test_conn_timedout(self): @@ -54,4 +82,119 @@ def test_conn_not_timedout(self): address = ("127.0.0.1", 7687) connection = Connection(address, FakeSocket(address), DirectConnectionErrorHandler(), max_connection_lifetime=999999999) - self.assertEqual(connection.timedout(), False) \ No newline at end of file + self.assertEqual(connection.timedout(), False) + + +class ConnectionPoolTestCase(TestCase): + def setUp(self): + self.pool = ConnectionPool(connector, DirectConnectionErrorHandler()) + + def tearDown(self): + self.pool.close() + + def assert_pool_size(self, address, expected_active, expected_inactive, pool=None): + if pool is None: + pool = self.pool + try: + connections = pool.connections[address] + except KeyError: + assert 0 == expected_active + assert 0 == expected_inactive + else: + assert len([c for c in connections if c.in_use]) == expected_active + assert len([c for c in connections if not c.in_use]) == expected_inactive + + def test_can_acquire(self): + address = ("127.0.0.1", 7687) + connection = self.pool.acquire_direct(address) + assert connection.address == address + self.assert_pool_size(address, 1, 0) + + def test_can_acquire_twice(self): + address = ("127.0.0.1", 7687) + connection_1 = self.pool.acquire_direct(address) + connection_2 = self.pool.acquire_direct(address) + assert connection_1.address == address + assert connection_2.address == address + assert connection_1 is not connection_2 + self.assert_pool_size(address, 2, 0) + + def test_can_acquire_two_addresses(self): + address_1 = ("127.0.0.1", 7687) + address_2 = ("127.0.0.1", 7474) + connection_1 = self.pool.acquire_direct(address_1) + connection_2 = self.pool.acquire_direct(address_2) + assert connection_1.address == address_1 + assert connection_2.address == address_2 + self.assert_pool_size(address_1, 1, 0) + self.assert_pool_size(address_2, 1, 0) + + def test_can_acquire_and_release(self): + address = ("127.0.0.1", 7687) + connection = self.pool.acquire_direct(address) + self.assert_pool_size(address, 1, 0) + self.pool.release(connection) + self.assert_pool_size(address, 0, 1) + + def test_releasing_twice(self): + address = ("127.0.0.1", 7687) + connection = self.pool.acquire_direct(address) + self.pool.release(connection) + self.assert_pool_size(address, 0, 1) + self.pool.release(connection) + self.assert_pool_size(address, 0, 1) + + def test_cannot_acquire_after_close(self): + with ConnectionPool(lambda a: QuickConnection(FakeSocket(a)), DirectConnectionErrorHandler()) as pool: + pool.close() + with self.assertRaises(ServiceUnavailable): + _ = pool.acquire_direct("X") + + def test_in_use_count(self): + address = ("127.0.0.1", 7687) + self.assertEqual(self.pool.in_use_connection_count(address), 0) + connection = self.pool.acquire_direct(address) + self.assertEqual(self.pool.in_use_connection_count(address), 1) + self.pool.release(connection) + self.assertEqual(self.pool.in_use_connection_count(address), 0) + + def test_max_conn_pool_size(self): + with ConnectionPool(connector, DirectConnectionErrorHandler, + max_connection_pool_size=1, connection_acquisition_timeout=0) as pool: + address = ("127.0.0.1", 7687) + pool.acquire_direct(address) + self.assertEqual(pool.in_use_connection_count(address), 1) + with self.assertRaises(ClientError) as context: + pool.acquire_direct(address) + self.assertTrue('Failed to obtain a connection from pool within 0s' in context.exception) + self.assertEqual(pool.in_use_connection_count(address), 1) + + def test_multithread(self): + with ConnectionPool(connector, DirectConnectionErrorHandler, + max_connection_pool_size=5, connection_acquisition_timeout=10) as pool: + address = ("127.0.0.1", 7687) + releasing_event = Event() + + # We start 10 threads to compete connections from pool with size of 5 + threads = [] + for i in range(10): + t = Thread(target=acquire_release_conn, args=(pool, address, releasing_event)) + t.start() + threads.append(t) + + # The pool size should be 5, all are in-use + self.assert_pool_size(address, 5, 0, pool) + # Now we allow thread to release connections they obtained from pool + releasing_event.set() + + # wait for all threads to release connections back to pool + for t in threads: + t.join() + # The pool size is still 5, but all are free + self.assert_pool_size(address, 0, 5, pool) + + +def acquire_release_conn(pool, address, releasing_event): + conn = pool.acquire_direct(address) + releasing_event.wait() + pool.release(conn) \ No newline at end of file From acfac716f96e48b181a9562f18a2a3cde5173bcc Mon Sep 17 00:00:00 2001 From: Zhen Date: Thu, 21 Sep 2017 11:51:54 +0200 Subject: [PATCH 3/3] cleanup some imports This is to fix some error run tests on zhen's local machine --- neo4j/bolt/connection.py | 3 +-- neo4j/v1/__init__.py | 3 --- neo4j/v1/api.py | 4 ++-- neo4j/v1/direct.py | 2 +- test/integration/test_driver.py | 4 ++-- test/integration/test_security.py | 3 ++- test/integration/test_session.py | 3 ++- test/integration/tools.py | 3 ++- test/performance/tools.py | 3 ++- test/stub/test_routingdriver.py | 3 ++- test/unit/test_connection.py | 8 +++----- 11 files changed, 19 insertions(+), 20 deletions(-) diff --git a/neo4j/bolt/connection.py b/neo4j/bolt/connection.py index dbfbdc916..9c5bcb0fa 100644 --- a/neo4j/bolt/connection.py +++ b/neo4j/bolt/connection.py @@ -34,12 +34,11 @@ from struct import pack as struct_pack, unpack as struct_unpack from threading import RLock, Condition -from neo4j.v1 import ClientError from neo4j.addressing import SocketAddress, is_ip_address from neo4j.bolt.cert import KNOWN_HOSTS from neo4j.bolt.response import InitResponse, AckFailureResponse, ResetResponse from neo4j.compat.ssl import SSL_AVAILABLE, HAS_SNI, SSLError -from neo4j.exceptions import ProtocolError, SecurityError, ServiceUnavailable +from neo4j.exceptions import ClientError, ProtocolError, SecurityError, ServiceUnavailable from neo4j.meta import version from neo4j.packstream import Packer, Unpacker from neo4j.util import import_best as _import_best diff --git a/neo4j/v1/__init__.py b/neo4j/v1/__init__.py index fa13808af..a2eacc335 100644 --- a/neo4j/v1/__init__.py +++ b/neo4j/v1/__init__.py @@ -18,9 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from neo4j.exceptions import * - from .api import * from .direct import * from .exceptions import * diff --git a/neo4j/v1/api.py b/neo4j/v1/api.py index 4e063da36..e6d2358f6 100644 --- a/neo4j/v1/api.py +++ b/neo4j/v1/api.py @@ -25,8 +25,8 @@ from time import time, sleep from warnings import warn -from neo4j.bolt import ProtocolError, ServiceUnavailable -from neo4j.compat import unicode, urlparse +from neo4j.exceptions import ProtocolError, ServiceUnavailable +from neo4j.compat import urlparse from neo4j.exceptions import CypherError, TransientError from .exceptions import DriverError, SessionError, SessionExpired, TransactionError diff --git a/neo4j/v1/direct.py b/neo4j/v1/direct.py index 869de9d2a..2a8f6e6f4 100644 --- a/neo4j/v1/direct.py +++ b/neo4j/v1/direct.py @@ -20,7 +20,7 @@ from neo4j.addressing import SocketAddress, resolve -from neo4j.bolt import DEFAULT_PORT, ConnectionPool, connect, ConnectionErrorHandler +from neo4j.bolt.connection import DEFAULT_PORT, ConnectionPool, connect, ConnectionErrorHandler from neo4j.exceptions import ServiceUnavailable from neo4j.v1.api import Driver from neo4j.v1.security import SecurityPlan diff --git a/test/integration/test_driver.py b/test/integration/test_driver.py index f4466d127..4afc1d916 100644 --- a/test/integration/test_driver.py +++ b/test/integration/test_driver.py @@ -19,8 +19,8 @@ # limitations under the License. -from neo4j.v1 import GraphDatabase, ProtocolError, ServiceUnavailable - +from neo4j.v1 import GraphDatabase, ServiceUnavailable +from neo4j.exceptions import ProtocolError from test.integration.tools import IntegrationTestCase diff --git a/test/integration/test_security.py b/test/integration/test_security.py index 48ed29da1..9751a1b67 100644 --- a/test/integration/test_security.py +++ b/test/integration/test_security.py @@ -23,7 +23,8 @@ from ssl import SSLSocket from unittest import skipUnless -from neo4j.v1 import GraphDatabase, SSL_AVAILABLE, TRUST_ON_FIRST_USE, TRUST_CUSTOM_CA_SIGNED_CERTIFICATES, AuthError +from neo4j.v1 import GraphDatabase, SSL_AVAILABLE, TRUST_ON_FIRST_USE, TRUST_CUSTOM_CA_SIGNED_CERTIFICATES +from neo4j.exceptions import AuthError from test.integration.tools import IntegrationTestCase diff --git a/test/integration/test_session.py b/test/integration/test_session.py index e991c2e5d..6291f611b 100644 --- a/test/integration/test_session.py +++ b/test/integration/test_session.py @@ -24,7 +24,8 @@ from neo4j.v1 import \ READ_ACCESS, WRITE_ACCESS, \ CypherError, SessionError, TransactionError, \ - Node, Relationship, Path, CypherSyntaxError + Node, Relationship, Path +from neo4j.exceptions import CypherSyntaxError from test.integration.tools import DirectIntegrationTestCase diff --git a/test/integration/tools.py b/test/integration/tools.py index cd60a1cb7..604bafd83 100644 --- a/test/integration/tools.py +++ b/test/integration/tools.py @@ -32,7 +32,8 @@ from boltkit.controller import WindowsController, UnixController -from neo4j.v1 import GraphDatabase, AuthError +from neo4j.v1 import GraphDatabase +from neo4j.exceptions import AuthError from neo4j.util import ServerVersion from test.env import NEO4J_SERVER_PACKAGE, NEO4J_USER, NEO4J_PASSWORD diff --git a/test/performance/tools.py b/test/performance/tools.py index b216ff232..fbeca9ce7 100644 --- a/test/performance/tools.py +++ b/test/performance/tools.py @@ -34,7 +34,8 @@ from boltkit.controller import WindowsController, UnixController -from neo4j.v1 import GraphDatabase, AuthError +from neo4j.v1 import GraphDatabase +from neo4j.exceptions import AuthError from neo4j.util import ServerVersion from test.env import NEO4J_SERVER_PACKAGE, NEO4J_USER, NEO4J_PASSWORD diff --git a/test/stub/test_routingdriver.py b/test/stub/test_routingdriver.py index ae12fa8ab..ab2163ccc 100644 --- a/test/stub/test_routingdriver.py +++ b/test/stub/test_routingdriver.py @@ -21,7 +21,8 @@ from neo4j.v1 import GraphDatabase, READ_ACCESS, WRITE_ACCESS, SessionExpired, \ RoutingDriver, RoutingConnectionPool, LeastConnectedLoadBalancingStrategy, LOAD_BALANCING_STRATEGY_ROUND_ROBIN, \ - RoundRobinLoadBalancingStrategy, TransientError, ClientError + RoundRobinLoadBalancingStrategy, TransientError +from neo4j.exceptions import ClientError from neo4j.bolt import ProtocolError, ServiceUnavailable from test.stub.tools import StubTestCase, StubCluster diff --git a/test/unit/test_connection.py b/test/unit/test_connection.py index e9ea0ea32..f0226e8e1 100644 --- a/test/unit/test_connection.py +++ b/test/unit/test_connection.py @@ -18,12 +18,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function -import time from unittest import TestCase from threading import Thread, Event -from neo4j.v1 import DirectConnectionErrorHandler, ClientError, ServiceUnavailable +from neo4j.v1 import DirectConnectionErrorHandler, ServiceUnavailable from neo4j.bolt import Connection, ConnectionPool - +from neo4j.exceptions import ClientError class FakeSocket(object): def __init__(self, address): @@ -164,9 +163,8 @@ def test_max_conn_pool_size(self): address = ("127.0.0.1", 7687) pool.acquire_direct(address) self.assertEqual(pool.in_use_connection_count(address), 1) - with self.assertRaises(ClientError) as context: + with self.assertRaises(ClientError): pool.acquire_direct(address) - self.assertTrue('Failed to obtain a connection from pool within 0s' in context.exception) self.assertEqual(pool.in_use_connection_count(address), 1) def test_multithread(self):