Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,12 +717,31 @@ async def execute_command(self, *args, **options):
if self.single_connection_client:
await self._single_conn_lock.acquire()
try:
return await conn.retry.call_with_retry(
result = await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda _: self._close_connection(conn),
)

# Clean up iter_req_id for SCAN family commands when the cursor returns to 0
iter_req_id = options.get("iter_req_id")
if iter_req_id and command_name.upper() in (
"SCAN",
"SSCAN",
"HSCAN",
"ZSCAN",
):
# If the result is a tuple with cursor as the first element and cursor is 0, cleanup
if (
isinstance(result, (list, tuple))
and len(result) >= 2
and result[0] == 0
):
if hasattr(pool, "cleanup"):
Copy link
Preview

Copilot AI Sep 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable pool is not defined in this scope. It should be referencing conn.connection_pool or self.connection_pool to access the cleanup method.

Copilot uses AI. Check for mistakes.

await pool.cleanup(iter_req_id)

return result
finally:
if self.single_connection_client:
self._single_conn_lock.release()
Expand Down
64 changes: 63 additions & 1 deletion redis/asyncio/sentinel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import asyncio
import random
import weakref
from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type
from typing import (
AsyncIterator,
Dict,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
Type,
)

from redis.asyncio.client import Redis
from redis.asyncio.connection import (
Expand All @@ -17,6 +26,7 @@
ResponseError,
TimeoutError,
)
from redis.utils import deprecated_args


class MasterNotFoundError(ConnectionError):
Expand Down Expand Up @@ -121,6 +131,7 @@ def __init__(self, service_name, sentinel_manager, **kwargs):
self.sentinel_manager = sentinel_manager
self.master_address = None
self.slave_rr_counter = None
self._iter_req_connections: Dict[str, tuple] = {}

def __repr__(self):
return (
Expand Down Expand Up @@ -166,6 +177,57 @@ async def rotate_slaves(self) -> AsyncIterator:
pass
raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")

async def cleanup(self, iter_req_id: str):
"""Remove tracking for a completed iteration request."""
self._iter_req_connections.pop(iter_req_id, None)

@deprecated_args(
args_to_warn=["*"],
reason="Use get_connection() without args instead",
version="5.3.0",
)
async def get_connection(self, command_name=None, *keys, **options):
"""
Get a connection from the pool, with special handling for scan commands.

For scan commands with iter_req_id, ensures the same replica is used
throughout the iteration to maintain cursor consistency.
"""
iter_req_id = options.get("iter_req_id")

# For scan commands with iter_req_id, ensure we use the same replica
if iter_req_id and not self.is_master:
# Check if we've already established a connection for this iteration
if iter_req_id in self._iter_req_connections:
target_address = self._iter_req_connections[iter_req_id]
connection = await super().get_connection()
# If the connection doesn't match our target, try to get the right one
if (connection.host, connection.port) != target_address:
# Release this connection and try to find one for the target replica
await self.release(connection)
# For now, use the connection we got and update tracking
connection = await super().get_connection()
await connection.connect_to(target_address)
return connection
else:
# First time for this iter_req_id, get a connection and track its replica
connection = await super().get_connection()
# Get the replica address this connection will use
if hasattr(connection, "connect_to"):
# Let the connection establish to its target replica
try:
replica_address = await self.rotate_slaves().__anext__()
Copy link
Preview

Copilot AI Sep 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using __anext__() directly on an async generator is not the recommended approach. Consider using anext() builtin function instead: replica_address = await anext(self.rotate_slaves()).

Copilot uses AI. Check for mistakes.

await connection.connect_to(replica_address)
# Track this replica for future requests with this iter_req_id
self._iter_req_connections[iter_req_id] = replica_address
except (SlaveNotFoundError, StopAsyncIteration):
# Fallback to normal connection if no slaves available
pass
return connection

# For non-scan commands or master connections, use normal behavior
return await super().get_connection()


class Sentinel(AsyncSentinelCommands):
"""
Expand Down
20 changes: 19 additions & 1 deletion redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,13 +658,31 @@ def _execute_command(self, *args, **options):
if self._single_connection_client:
self.single_connection_lock.acquire()
try:
return conn.retry.call_with_retry(
result = conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda _: self._close_connection(conn),
)

# Clean up iter_req_id for SCAN family commands when the cursor returns to 0
iter_req_id = options.get("iter_req_id")
if iter_req_id and command_name.upper() in (
"SCAN",
"SSCAN",
"HSCAN",
"ZSCAN",
):
if (
isinstance(result, (list, tuple))
and len(result) >= 2
and result[0] == 0
):
if hasattr(pool, "cleanup"):
pool.cleanup(iter_req_id)
Comment on lines +681 to +682
Copy link
Preview

Copilot AI Sep 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable pool is not defined in this scope. It should be referencing conn.connection_pool or self.connection_pool to access the cleanup method.

Suggested change
if hasattr(pool, "cleanup"):
pool.cleanup(iter_req_id)
if hasattr(conn.connection_pool, "cleanup"):
conn.connection_pool.cleanup(iter_req_id)

Copilot uses AI. Check for mistakes.


return result

finally:
if conn and conn.should_reconnect():
self._close_connection(conn)
Expand Down
Loading