Skip to content

Commit f81e6c4

Browse files
committed
Remove process-id checks from asyncio. Asyncio and fork() do not mix.
1 parent 19b55c6 commit f81e6c4

File tree

2 files changed

+6
-119
lines changed

2 files changed

+6
-119
lines changed

redis/asyncio/connection.py

Lines changed: 6 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22
import copy
33
import enum
44
import inspect
5-
import os
65
import socket
76
import ssl
87
import sys
9-
import threading
108
import weakref
119
from abc import abstractmethod
1210
from itertools import chain
@@ -97,7 +95,6 @@ class AbstractConnection:
9795
"""Manages communication to and from a Redis server"""
9896

9997
__slots__ = (
100-
"pid",
10198
"db",
10299
"username",
103100
"client_name",
@@ -158,7 +155,6 @@ def __init__(
158155
"1. 'password' and (optional) 'username'\n"
159156
"2. 'credential_provider'"
160157
)
161-
self.pid = os.getpid()
162158
self.db = db
163159
self.client_name = client_name
164160
self.lib_name = lib_name
@@ -381,12 +377,11 @@ async def disconnect(self, nowait: bool = False) -> None:
381377
if not self.is_connected:
382378
return
383379
try:
384-
if os.getpid() == self.pid:
385-
self._writer.close() # type: ignore[union-attr]
386-
# wait for close to finish, except when handling errors and
387-
# forcefully disconnecting.
388-
if not nowait:
389-
await self._writer.wait_closed() # type: ignore[union-attr]
380+
self._writer.close() # type: ignore[union-attr]
381+
# wait for close to finish, except when handling errors and
382+
# forcefully disconnecting.
383+
if not nowait:
384+
await self._writer.wait_closed() # type: ignore[union-attr]
390385
except OSError:
391386
pass
392387
finally:
@@ -1004,15 +999,6 @@ def __init__(
1004999
self.connection_kwargs = connection_kwargs
10051000
self.max_connections = max_connections
10061001

1007-
# a lock to protect the critical section in _checkpid().
1008-
# this lock is acquired when the process id changes, such as
1009-
# after a fork. during this time, multiple threads in the child
1010-
# process could attempt to acquire this lock. the first thread
1011-
# to acquire the lock will reset the data structures and lock
1012-
# object of this pool. subsequent threads acquiring this lock
1013-
# will notice the first thread already did the work and simply
1014-
# release the lock.
1015-
self._fork_lock = threading.Lock()
10161002
self._lock = asyncio.Lock()
10171003
self._created_connections: int
10181004
self._available_connections: List[AbstractConnection]
@@ -1032,67 +1018,8 @@ def reset(self):
10321018
self._available_connections = []
10331019
self._in_use_connections = set()
10341020

1035-
# this must be the last operation in this method. while reset() is
1036-
# called when holding _fork_lock, other threads in this process
1037-
# can call _checkpid() which compares self.pid and os.getpid() without
1038-
# holding any lock (for performance reasons). keeping this assignment
1039-
# as the last operation ensures that those other threads will also
1040-
# notice a pid difference and block waiting for the first thread to
1041-
# release _fork_lock. when each of these threads eventually acquire
1042-
# _fork_lock, they will notice that another thread already called
1043-
# reset() and they will immediately release _fork_lock and continue on.
1044-
self.pid = os.getpid()
1045-
1046-
def _checkpid(self):
1047-
# _checkpid() attempts to keep ConnectionPool fork-safe on modern
1048-
# systems. this is called by all ConnectionPool methods that
1049-
# manipulate the pool's state such as get_connection() and release().
1050-
#
1051-
# _checkpid() determines whether the process has forked by comparing
1052-
# the current process id to the process id saved on the ConnectionPool
1053-
# instance. if these values are the same, _checkpid() simply returns.
1054-
#
1055-
# when the process ids differ, _checkpid() assumes that the process
1056-
# has forked and that we're now running in the child process. the child
1057-
# process cannot use the parent's file descriptors (e.g., sockets).
1058-
# therefore, when _checkpid() sees the process id change, it calls
1059-
# reset() in order to reinitialize the child's ConnectionPool. this
1060-
# will cause the child to make all new connection objects.
1061-
#
1062-
# _checkpid() is protected by self._fork_lock to ensure that multiple
1063-
# threads in the child process do not call reset() multiple times.
1064-
#
1065-
# there is an extremely small chance this could fail in the following
1066-
# scenario:
1067-
# 1. process A calls _checkpid() for the first time and acquires
1068-
# self._fork_lock.
1069-
# 2. while holding self._fork_lock, process A forks (the fork()
1070-
# could happen in a different thread owned by process A)
1071-
# 3. process B (the forked child process) inherits the
1072-
# ConnectionPool's state from the parent. that state includes
1073-
# a locked _fork_lock. process B will not be notified when
1074-
# process A releases the _fork_lock and will thus never be
1075-
# able to acquire the _fork_lock.
1076-
#
1077-
# to mitigate this possible deadlock, _checkpid() will only wait 5
1078-
# seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
1079-
# that time it is assumed that the child is deadlocked and a
1080-
# redis.ChildDeadlockedError error is raised.
1081-
if self.pid != os.getpid():
1082-
acquired = self._fork_lock.acquire(timeout=5)
1083-
if not acquired:
1084-
raise ChildDeadlockedError
1085-
# reset() the instance for the new process if another thread
1086-
# hasn't already done so
1087-
try:
1088-
if self.pid != os.getpid():
1089-
self.reset()
1090-
finally:
1091-
self._fork_lock.release()
1092-
10931021
async def get_connection(self, command_name, *keys, **options):
10941022
"""Get a connection from the pool"""
1095-
self._checkpid()
10961023
async with self._lock:
10971024
try:
10981025
connection = self._available_connections.pop()
@@ -1141,7 +1068,6 @@ def make_connection(self):
11411068

11421069
async def release(self, connection: AbstractConnection):
11431070
"""Releases the connection back to the pool"""
1144-
self._checkpid()
11451071
async with self._lock:
11461072
try:
11471073
self._in_use_connections.remove(connection)
@@ -1150,18 +1076,7 @@ async def release(self, connection: AbstractConnection):
11501076
# that the pool doesn't actually own
11511077
pass
11521078

1153-
if self.owns_connection(connection):
1154-
self._available_connections.append(connection)
1155-
else:
1156-
# pool doesn't own this connection. do not add it back
1157-
# to the pool and decrement the count so that another
1158-
# connection can take its place if needed
1159-
self._created_connections -= 1
1160-
await connection.disconnect()
1161-
return
1162-
1163-
def owns_connection(self, connection: AbstractConnection):
1164-
return connection.pid == self.pid
1079+
self._available_connections.append(connection)
11651080

11661081
async def disconnect(self, inuse_connections: bool = True):
11671082
"""
@@ -1171,7 +1086,6 @@ async def disconnect(self, inuse_connections: bool = True):
11711086
current in use, potentially by other tasks. Otherwise only disconnect
11721087
connections that are idle in the pool.
11731088
"""
1174-
self._checkpid()
11751089
async with self._lock:
11761090
if inuse_connections:
11771091
connections: Iterable[AbstractConnection] = chain(
@@ -1259,17 +1173,6 @@ def reset(self):
12591173
# disconnect them later.
12601174
self._connections = []
12611175

1262-
# this must be the last operation in this method. while reset() is
1263-
# called when holding _fork_lock, other threads in this process
1264-
# can call _checkpid() which compares self.pid and os.getpid() without
1265-
# holding any lock (for performance reasons). keeping this assignment
1266-
# as the last operation ensures that those other threads will also
1267-
# notice a pid difference and block waiting for the first thread to
1268-
# release _fork_lock. when each of these threads eventually acquire
1269-
# _fork_lock, they will notice that another thread already called
1270-
# reset() and they will immediately release _fork_lock and continue on.
1271-
self.pid = os.getpid()
1272-
12731176
def make_connection(self):
12741177
"""Make a fresh connection."""
12751178
connection = self.connection_class(**self.connection_kwargs)
@@ -1288,8 +1191,6 @@ async def get_connection(self, command_name, *keys, **options):
12881191
create new connections when we need to, i.e.: the actual number of
12891192
connections will only increase in response to demand.
12901193
"""
1291-
# Make sure we haven't changed process.
1292-
self._checkpid()
12931194

12941195
# Try and get a connection from the pool. If one isn't available within
12951196
# self.timeout then raise a ``ConnectionError``.
@@ -1331,17 +1232,6 @@ async def get_connection(self, command_name, *keys, **options):
13311232

13321233
async def release(self, connection: AbstractConnection):
13331234
"""Releases the connection back to the pool."""
1334-
# Make sure we haven't changed process.
1335-
self._checkpid()
1336-
if not self.owns_connection(connection):
1337-
# pool doesn't own this connection. do not add it back
1338-
# to the pool. instead add a None value which is a placeholder
1339-
# that will cause the pool to recreate the connection if
1340-
# its needed.
1341-
await connection.disconnect()
1342-
self.pool.put_nowait(None)
1343-
return
1344-
13451235
# Put the connection back into the pool.
13461236
try:
13471237
self.pool.put_nowait(connection)
@@ -1352,7 +1242,6 @@ async def release(self, connection: AbstractConnection):
13521242

13531243
async def disconnect(self, inuse_connections: bool = True):
13541244
"""Disconnects all connections in the pool."""
1355-
self._checkpid()
13561245
async with self._lock:
13571246
resp = await asyncio.gather(
13581247
*(connection.disconnect() for connection in self._connections),

tests/test_asyncio/test_connection_pool.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import os
32
import re
43

54
import pytest
@@ -94,7 +93,6 @@ class DummyConnection(Connection):
9493

9594
def __init__(self, **kwargs):
9695
self.kwargs = kwargs
97-
self.pid = os.getpid()
9896

9997
async def connect(self):
10098
pass

0 commit comments

Comments
 (0)