diff --git a/redis/connection.py b/redis/connection.py index 004c7a6f78..908382e3cc 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -2,6 +2,8 @@ from distutils.version import StrictVersion from itertools import chain from select import select +from copy import copy + import os import socket import sys @@ -534,6 +536,26 @@ def disconnect(self): pass self._sock = None + def shutdown_socket(self): + """ + Shutdown the socket hold by the current connection, called from + the connection pool class u other manager to singal it that has to be + disconnected in a thread safe way. Later the connection instance + will get an error and will call `disconnect` by it self. + """ + try: + self._sock.shutdown(socket.SHUT_RDWR) + except AttributeError: + # either _sock attribute does not exist or + # connection thread removed it. + pass + except OSError as e: + if e.errno == 107: + # Transport endpoint is not connected + pass + else: + raise + def send_packed_command(self, command): "Send an already packed command to the Redis server" if not self._sock: @@ -950,10 +972,10 @@ def release(self, connection): def disconnect(self): "Disconnects all connections in the pool" - all_conns = chain(self._available_connections, - self._in_use_connections) + all_conns = chain(copy(self._available_connections), + copy(self._in_use_connections)) for connection in all_conns: - connection.disconnect() + connection.shutdown_socket() class BlockingConnectionPool(ConnectionPool): @@ -1072,4 +1094,4 @@ def release(self, connection): def disconnect(self): "Disconnects all connections in the pool." for connection in self._connections: - connection.disconnect() + connection.shutdown_socket() diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 11c20080a9..6b01fdd7e7 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -1,4 +1,5 @@ from __future__ import with_statement + import os import pytest import redis @@ -69,6 +70,30 @@ def test_repr_contains_db_info_unix(self): expected = 'ConnectionPool>' assert repr(pool) == expected + def test_disconnect_active_connections(self): + + class MyConnection(redis.Connection): + + connect_calls = 0 + + def __init__(self, *args, **kwargs): + super(MyConnection, self).__init__(*args, **kwargs) + self.register_connect_callback(self.count_connect) + + def count_connect(self, connection): + MyConnection.connect_calls += 1 + + pool = self.get_pool(connection_class=MyConnection) + r = redis.StrictRedis(connection_pool=pool) + r.ping() + pool.disconnect() + r.ping() + + # If the connection is not disconnected by the pool the + # callback belonging to Connection will be called just + # one time. + assert MyConnection.connect_calls == 2 + class TestBlockingConnectionPool(object): def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):