From dd2ee3f3070d16ea5c4016ca57ce9b3a14792688 Mon Sep 17 00:00:00 2001 From: Jonas Dittrich <58814480+Kakadus@users.noreply.github.com> Date: Thu, 28 Nov 2024 19:48:23 +0100 Subject: [PATCH 001/113] Add dynamic_startup_nodes parameter to async RedisCluster --- CHANGES | 1 + redis/asyncio/cluster.py | 19 +++++++++++++++++-- tests/test_asyncio/test_cluster.py | 21 +++++++++++++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/CHANGES b/CHANGES index 8750128b05..b955681b89 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Add dynamic_startup_nodes parameter to async RedisCluster (#2472) * Move doctests (doc code examples) to main branch * Update `ResponseT` type hint * Allow to control the minimum SSL version diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 4e82e5448f..2f187ca516 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -133,6 +133,14 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand | Enable read from replicas in READONLY mode. You can read possibly stale data. When set to true, read commands will be assigned between the primary and its replications in a Round-Robin manner. + :param dynamic_startup_nodes: + | Set the RedisCluster's startup nodes to all the discovered nodes. + If true (default value), the cluster's discovered nodes will be used to + determine the cluster nodes-slots mapping in the next topology refresh. + It will remove the initial passed startup nodes if their endpoints aren't + listed in the CLUSTER SLOTS output. + If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists + specific IP addresses, it is best to set it to false. :param reinitialize_steps: | Specifies the number of MOVED errors that need to occur before reinitializing the whole cluster topology. 
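As a usage illustration only (not part of the patch): a minimal sketch of the new parameter on the async client, assuming a cluster reachable through a placeholder DNS name and port.

import asyncio
from redis.asyncio.cluster import RedisCluster

async def main():
    # Keep the DNS-based startup node across topology refreshes by turning
    # dynamic startup nodes off, as described in the docstring above.
    rc = await RedisCluster(
        host="redis-cluster.example.com",  # placeholder DNS endpoint
        port=6379,
        dynamic_startup_nodes=False,
    )
    print(await rc.ping())  # True
    await rc.aclose()

asyncio.run(main())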
If a MOVED error occurs and the cluster does not @@ -233,6 +241,7 @@ def __init__( startup_nodes: Optional[List["ClusterNode"]] = None, require_full_coverage: bool = True, read_from_replicas: bool = False, + dynamic_startup_nodes: bool = True, reinitialize_steps: int = 5, cluster_error_retry_attempts: int = 3, connection_error_retry_attempts: int = 3, @@ -370,6 +379,7 @@ def __init__( startup_nodes, require_full_coverage, kwargs, + dynamic_startup_nodes=dynamic_startup_nodes, address_remap=address_remap, ) self.encoder = Encoder(encoding, encoding_errors, decode_responses) @@ -1093,6 +1103,7 @@ class NodesManager: "require_full_coverage", "slots_cache", "startup_nodes", + "_dynamic_startup_nodes", "address_remap", ) @@ -1101,11 +1112,13 @@ def __init__( startup_nodes: List["ClusterNode"], require_full_coverage: bool, connection_kwargs: Dict[str, Any], + dynamic_startup_nodes: bool = True, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, ) -> None: self.startup_nodes = {node.name: node for node in startup_nodes} self.require_full_coverage = require_full_coverage self.connection_kwargs = connection_kwargs + self._dynamic_startup_nodes = dynamic_startup_nodes self.address_remap = address_remap self.default_node: "ClusterNode" = None @@ -1338,8 +1351,10 @@ async def initialize(self) -> None: # Set the tmp variables to the real variables self.slots_cache = tmp_slots self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True) - # Populate the startup nodes with all discovered nodes - self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True) + + if self._dynamic_startup_nodes: + # Populate the startup nodes with all discovered nodes + self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True) # Set the default node self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 477397dd5f..4dfbd76176 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -2620,6 +2620,27 @@ def cmd_init_mock(self, r: ClusterNode) -> None: assert rc.get_node(host=default_host, port=7001) is not None assert rc.get_node(host=default_host, port=7002) is not None + @pytest.mark.parametrize("dynamic_startup_nodes", [True, False]) + async def test_init_slots_dynamic_startup_nodes(self, dynamic_startup_nodes): + rc = await get_mocked_redis_client( + host="my@DNS.com", + port=7000, + cluster_slots=default_cluster_slots, + dynamic_startup_nodes=dynamic_startup_nodes, + ) + # Nodes are taken from default_cluster_slots + discovered_nodes = [ + "127.0.0.1:7000", + "127.0.0.1:7001", + "127.0.0.1:7002", + "127.0.0.1:7003", + ] + startup_nodes = list(rc.nodes_manager.startup_nodes.keys()) + if dynamic_startup_nodes is True: + assert startup_nodes.sort() == discovered_nodes.sort() + else: + assert startup_nodes == ["my@DNS.com:7000"] + class TestClusterPipeline: """Tests for the ClusterPipeline class.""" From 32d45c96efa0dacea2a87b52a8bd45672c5f15a5 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Mon, 30 Dec 2024 11:06:35 +0200 Subject: [PATCH 002/113] Fixed flacky TokenManager test (#3468) * Fixed flacky TokenManager test * Fixed additional flacky test * Removed token count assertion * Skipped test on version 3.9 --- tests/test_auth/test_token_manager.py | 36 +++++++++++---------------- tests/test_connection.py | 2 ++ 2 files changed, 17 insertions(+), 21 deletions(-) diff --git 
a/tests/test_auth/test_token_manager.py b/tests/test_auth/test_token_manager.py index bb396e246c..cdbf60889d 100644 --- a/tests/test_auth/test_token_manager.py +++ b/tests/test_auth/test_token_manager.py @@ -17,17 +17,17 @@ class TestTokenManager: @pytest.mark.parametrize( - "exp_refresh_ratio,tokens_refreshed", + "exp_refresh_ratio", [ - (0.9, 2), - (0.28, 4), + 0.9, + 0.28, ], ids=[ - "Refresh ratio = 0.9, 2 tokens in 0,1 second", - "Refresh ratio = 0.28, 4 tokens in 0,1 second", + "Refresh ratio = 0.9", + "Refresh ratio = 0.28", ], ) - def test_success_token_renewal(self, exp_refresh_ratio, tokens_refreshed): + def test_success_token_renewal(self, exp_refresh_ratio): tokens = [] mock_provider = Mock(spec=IdentityProviderInterface) mock_provider.request_token.side_effect = [ @@ -39,14 +39,14 @@ def test_success_token_renewal(self, exp_refresh_ratio, tokens_refreshed): ), SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 130, - (datetime.now(timezone.utc).timestamp() * 1000) + 30, + (datetime.now(timezone.utc).timestamp() * 1000) + 150, + (datetime.now(timezone.utc).timestamp() * 1000) + 50, {"oid": "test"}, ), SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 160, - (datetime.now(timezone.utc).timestamp() * 1000) + 60, + (datetime.now(timezone.utc).timestamp() * 1000) + 170, + (datetime.now(timezone.utc).timestamp() * 1000) + 70, {"oid": "test"}, ), SimpleToken( @@ -70,7 +70,7 @@ def on_next(token): mgr.start(mock_listener) sleep(0.1) - assert len(tokens) == tokens_refreshed + assert len(tokens) > 0 @pytest.mark.parametrize( "exp_refresh_ratio,tokens_refreshed", @@ -176,19 +176,13 @@ def test_token_renewal_with_skip_initial(self): mock_provider.request_token.side_effect = [ SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 100, + (datetime.now(timezone.utc).timestamp() * 1000) + 50, (datetime.now(timezone.utc).timestamp() * 1000), {"oid": "test"}, ), SimpleToken( "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 120, - (datetime.now(timezone.utc).timestamp() * 1000), - {"oid": "test"}, - ), - SimpleToken( - "value", - (datetime.now(timezone.utc).timestamp() * 1000) + 140, + (datetime.now(timezone.utc).timestamp() * 1000) + 150, (datetime.now(timezone.utc).timestamp() * 1000), {"oid": "test"}, ), @@ -207,9 +201,9 @@ def on_next(token): mgr.start(mock_listener, skip_initial=True) # Should be less than a 0.1, or it will be flacky due to # additional token renewal. 
- sleep(0.2) + sleep(0.1) - assert len(tokens) == 2 + assert len(tokens) == 1 @pytest.mark.asyncio async def test_async_token_renewal_with_skip_initial(self): diff --git a/tests/test_connection.py b/tests/test_connection.py index 7683a1416d..65d80e2574 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,6 +1,7 @@ import copy import platform import socket +import sys import threading import types from typing import Any @@ -249,6 +250,7 @@ def get_redis_connection(): r1.close() +@pytest.mark.skipif(sys.version_info == (3, 9), reason="Flacky test on Python 3.9") @pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args")) def test_redis_connection_pool(request, from_url): """Verify that basic Redis instances using `connection_pool` From 3f4cde20cc3c98483e7ccd503e73176d95ef36b7 Mon Sep 17 00:00:00 2001 From: zs-neo <48560952+zs-neo@users.noreply.github.com> Date: Mon, 30 Dec 2024 18:26:05 +0800 Subject: [PATCH 003/113] Fix incorrect attribute reuse (#3456) add CacheEntry Co-authored-by: zhousheng06 Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> --- redis/connection.py | 6 +++++- tests/test_connection.py | 5 +---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 9d29b4aba6..ace1ed8727 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -904,9 +904,11 @@ def read_response( and self._cache.get(self._current_command_cache_key).status != CacheEntryStatus.IN_PROGRESS ): - return copy.deepcopy( + res = copy.deepcopy( self._cache.get(self._current_command_cache_key).cache_value ) + self._current_command_cache_key = None + return res response = self._conn.read_response( disable_decoding=disable_decoding, @@ -932,6 +934,8 @@ def read_response( cache_entry.cache_value = response self._cache.set(cache_entry) + self._current_command_cache_key = None + return response def pack_command(self, *args): diff --git a/tests/test_connection.py b/tests/test_connection.py index 65d80e2574..fbc23ae8c0 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -501,9 +501,9 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): ) proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]}) assert proxy_connection.read_response() == b"bar" + assert proxy_connection._current_command_cache_key is None assert proxy_connection.read_response() == b"bar" - mock_connection.read_response.assert_called_once() mock_cache.set.assert_has_calls( [ call( @@ -530,9 +530,6 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): call(CacheKey(command="GET", redis_keys=("foo",))), call(CacheKey(command="GET", redis_keys=("foo",))), call(CacheKey(command="GET", redis_keys=("foo",))), - call(CacheKey(command="GET", redis_keys=("foo",))), - call(CacheKey(command="GET", redis_keys=("foo",))), - call(CacheKey(command="GET", redis_keys=("foo",))), ] ) From cf181be9ff888a5406cde495fa4b8b31297d814f Mon Sep 17 00:00:00 2001 From: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Date: Mon, 6 Jan 2025 04:39:41 -0800 Subject: [PATCH 004/113] Expand type for EncodedT (#3472) As of PEP 688, type checkers will no longer implicitly consider bytearray to be compatible with bytes --- redis/typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/typing.py b/redis/typing.py index b4d442c444..24ad607480 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -20,7 +20,7 @@ Number = Union[int, float] 
-EncodedT = Union[bytes, memoryview] +EncodedT = Union[bytes, bytearray, memoryview] DecodedT = Union[str, int, float] EncodableT = Union[EncodedT, DecodedT] AbsExpiryT = Union[int, datetime] From 0898252c76db267fb2edb5883b9475179ee66779 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Mon, 6 Jan 2025 17:23:03 +0200 Subject: [PATCH 005/113] Moved self._lock initialisation to Pool constructor (#3473) * Moved self._lock initialisation to Pool constructor * Added test case * Codestyle fixes * Added correct annotations --- redis/connection.py | 2 +- tests/test_connection_pool.py | 28 ++++++++++++++++++++++++---- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index ace1ed8727..d905c6481b 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1378,6 +1378,7 @@ def __init__( # will notice the first thread already did the work and simply # release the lock. self._fork_lock = threading.Lock() + self._lock = threading.Lock() self.reset() def __repr__(self) -> (str, str): @@ -1395,7 +1396,6 @@ def get_protocol(self): return self.connection_kwargs.get("protocol", None) def reset(self) -> None: - self._lock = threading.Lock() self._created_connections = 0 self._available_connections = [] self._in_use_connections = set() diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index dee7c554d3..118294ee1b 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -7,10 +7,16 @@ import pytest import redis -from redis.connection import to_bool -from redis.utils import SSL_AVAILABLE - -from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt +from redis.cache import CacheConfig +from redis.connection import CacheProxyConnection, Connection, to_bool +from redis.utils import HIREDIS_AVAILABLE, SSL_AVAILABLE + +from .conftest import ( + _get_client, + skip_if_redis_enterprise, + skip_if_resp_version, + skip_if_server_version_lt, +) from .test_pubsub import wait_for_message @@ -196,6 +202,20 @@ def test_repr_contains_db_info_unix(self): expected = "path=abc,db=0,client_name=test-client" assert expected in repr(pool) + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") + @pytest.mark.onlynoncluster + @skip_if_resp_version(2) + @skip_if_server_version_lt("7.4.0") + def test_initialise_pool_with_cache(self, master_host): + pool = redis.BlockingConnectionPool( + connection_class=Connection, + host=master_host[0], + port=master_host[1], + protocol=3, + cache_config=CacheConfig(), + ) + assert isinstance(pool.get_connection("_"), CacheProxyConnection) + class TestConnectionPoolURLParsing: def test_hostname(self): From 08e9e17fad65317cf83b30f21cae44a08564dc17 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Thu, 9 Jan 2025 14:21:33 +0200 Subject: [PATCH 006/113] Changed default dialect to 2 (#3467) * Changed default dialect to 2 * Codestyle fixes * Fixed async tests * Added handling of RESP3 responses * Fixed flacky tests * Codestyle fix * Added separate file to hold default value --- redis/commands/search/aggregation.py | 4 +- redis/commands/search/dialect.py | 3 + redis/commands/search/query.py | 4 +- tests/test_asyncio/test_search.py | 2 +- tests/test_auth/test_token_manager.py | 20 ++--- tests/test_search.py | 118 ++++++++++++++++++-------- 6 files changed, 102 insertions(+), 49 deletions(-) create mode 100644 redis/commands/search/dialect.py diff 
--git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 5638f1d662..13edefa081 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -1,5 +1,7 @@ from typing import List, Union +from redis.commands.search.dialect import DEFAULT_DIALECT + FIELDNAME = object() @@ -110,7 +112,7 @@ def __init__(self, query: str = "*") -> None: self._with_schema = False self._verbatim = False self._cursor = [] - self._dialect = None + self._dialect = DEFAULT_DIALECT self._add_scores = False self._scorer = "TFIDF" diff --git a/redis/commands/search/dialect.py b/redis/commands/search/dialect.py new file mode 100644 index 0000000000..828b3f2a43 --- /dev/null +++ b/redis/commands/search/dialect.py @@ -0,0 +1,3 @@ +# Value for the default dialect to be used as a part of +# Search or Aggregate query. +DEFAULT_DIALECT = 2 diff --git a/redis/commands/search/query.py b/redis/commands/search/query.py index 84d60a7cec..964ce6cdf4 100644 --- a/redis/commands/search/query.py +++ b/redis/commands/search/query.py @@ -1,5 +1,7 @@ from typing import List, Optional, Union +from redis.commands.search.dialect import DEFAULT_DIALECT + class Query: """ @@ -40,7 +42,7 @@ def __init__(self, query_string: str) -> None: self._highlight_fields: List = [] self._language: Optional[str] = None self._expander: Optional[str] = None - self._dialect: Optional[int] = None + self._dialect: int = DEFAULT_DIALECT def query_string(self) -> str: """Return the query string of this query only.""" diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 5260605039..cc75e4b4a4 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1646,7 +1646,7 @@ async def test_search_commands_in_pipeline(decoded_r: redis.Redis): @pytest.mark.redismod async def test_query_timeout(decoded_r: redis.Redis): q1 = Query("foo").timeout(5000) - assert q1.get_args() == ["foo", "TIMEOUT", 5000, "LIMIT", 0, 10] + assert q1.get_args() == ["foo", "TIMEOUT", 5000, "DIALECT", 2, "LIMIT", 0, 10] q2 = Query("foo").timeout("not_a_number") with pytest.raises(redis.ResponseError): await decoded_r.ft().search(q2) diff --git a/tests/test_auth/test_token_manager.py b/tests/test_auth/test_token_manager.py index cdbf60889d..f675c125dd 100644 --- a/tests/test_auth/test_token_manager.py +++ b/tests/test_auth/test_token_manager.py @@ -73,20 +73,18 @@ def on_next(token): assert len(tokens) > 0 @pytest.mark.parametrize( - "exp_refresh_ratio,tokens_refreshed", + "exp_refresh_ratio", [ - (0.9, 2), - (0.28, 4), + (0.9), + (0.28), ], ids=[ - "Refresh ratio = 0.9, 2 tokens in 0,1 second", - "Refresh ratio = 0.28, 4 tokens in 0,1 second", + "Refresh ratio = 0.9", + "Refresh ratio = 0.28", ], ) @pytest.mark.asyncio - async def test_async_success_token_renewal( - self, exp_refresh_ratio, tokens_refreshed - ): + async def test_async_success_token_renewal(self, exp_refresh_ratio): tokens = [] mock_provider = Mock(spec=IdentityProviderInterface) mock_provider.request_token.side_effect = [ @@ -129,7 +127,7 @@ async def on_next(token): await mgr.start_async(mock_listener, block_for_initial=True) await asyncio.sleep(0.1) - assert len(tokens) == tokens_refreshed + assert len(tokens) > 0 @pytest.mark.parametrize( "block_for_initial,tokens_acquired", @@ -203,7 +201,7 @@ def on_next(token): # additional token renewal. 
sleep(0.1) - assert len(tokens) == 1 + assert len(tokens) > 0 @pytest.mark.asyncio async def test_async_token_renewal_with_skip_initial(self): @@ -245,7 +243,7 @@ async def on_next(token): # due to additional token renewal. await asyncio.sleep(0.2) - assert len(tokens) == 2 + assert len(tokens) > 0 def test_success_token_renewal_with_retry(self): tokens = [] diff --git a/tests/test_search.py b/tests/test_search.py index c6e9a3717f..a257484425 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -2122,7 +2122,7 @@ def test_profile_query_params(client): client.hset("b", "v", "aaaabaaa") client.hset("c", "v", "aaaaabaa") query = "*=>[KNN 2 @v $vec]" - q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2) + q = Query(query).return_field("__v_score").sort_by("__v_score", True) if is_resp2_connection(client): res, det = client.ft().profile(q, query_params={"vec": "aaaaaaaa"}) assert det["Iterators profile"]["Counter"] == 2.0 @@ -2155,7 +2155,7 @@ def test_vector_field(client): client.hset("c", "v", "aaaaabaa") query = "*=>[KNN 2 @v $vec]" - q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2) + q = Query(query).return_field("__v_score").sort_by("__v_score", True) res = client.ft().search(q, query_params={"vec": "aaaaaaaa"}) if is_resp2_connection(client): @@ -2191,7 +2191,7 @@ def test_text_params(client): client.hset("doc3", mapping={"name": "Carol"}) params_dict = {"name1": "Alice", "name2": "Bob"} - q = Query("@name:($name1 | $name2 )").dialect(2) + q = Query("@name:($name1 | $name2 )") res = client.ft().search(q, query_params=params_dict) if is_resp2_connection(client): assert 2 == res.total @@ -2214,7 +2214,7 @@ def test_numeric_params(client): client.hset("doc3", mapping={"numval": 103}) params_dict = {"min": 101, "max": 102} - q = Query("@numval:[$min $max]").dialect(2) + q = Query("@numval:[$min $max]") res = client.ft().search(q, query_params=params_dict) if is_resp2_connection(client): @@ -2236,7 +2236,7 @@ def test_geo_params(client): client.hset("doc3", mapping={"g": "29.68746, 34.94882"}) params_dict = {"lat": "34.95126", "lon": "29.69465", "radius": 1000, "units": "km"} - q = Query("@g:[$lon $lat $radius $units]").dialect(2) + q = Query("@g:[$lon $lat $radius $units]") res = client.ft().search(q, query_params=params_dict) _assert_search_result(client, res, ["doc1", "doc2", "doc3"]) @@ -2355,19 +2355,19 @@ def test_dialect(client): with pytest.raises(redis.ResponseError) as err: client.ft().explain(Query("(*)").dialect(1)) assert "Syntax error" in str(err) - assert "WILDCARD" in client.ft().explain(Query("(*)").dialect(2)) + assert "WILDCARD" in client.ft().explain(Query("(*)")) with pytest.raises(redis.ResponseError) as err: client.ft().explain(Query("$hello").dialect(1)) assert "Syntax error" in str(err) - q = Query("$hello").dialect(2) + q = Query("$hello") expected = "UNION {\n hello\n +hello(expanded)\n}\n" assert expected in client.ft().explain(q, query_params={"hello": "hello"}) expected = "NUMERIC {0.000000 <= @num <= 10.000000}\n" assert expected in client.ft().explain(Query("@title:(@num:[0 10])").dialect(1)) with pytest.raises(redis.ResponseError) as err: - client.ft().explain(Query("@title:(@num:[0 10])").dialect(2)) + client.ft().explain(Query("@title:(@num:[0 10])")) assert "Syntax error" in str(err) @@ -2438,9 +2438,9 @@ def test_withsuffixtrie(client: redis.Redis): @pytest.mark.redismod def test_query_timeout(r: redis.Redis): q1 = Query("foo").timeout(5000) - assert q1.get_args() == ["foo", "TIMEOUT", 
5000, "LIMIT", 0, 10] + assert q1.get_args() == ["foo", "TIMEOUT", 5000, "DIALECT", 2, "LIMIT", 0, 10] q1 = Query("foo").timeout(0) - assert q1.get_args() == ["foo", "TIMEOUT", 0, "LIMIT", 0, 10] + assert q1.get_args() == ["foo", "TIMEOUT", 0, "DIALECT", 2, "LIMIT", 0, 10] q2 = Query("foo").timeout("not_a_number") with pytest.raises(redis.ResponseError): r.ft().search(q2) @@ -2507,28 +2507,26 @@ def test_search_missing_fields(client): ) with pytest.raises(redis.exceptions.ResponseError) as e: - client.ft().search( - Query("ismissing(@title)").dialect(2).return_field("id").no_content() - ) + client.ft().search(Query("ismissing(@title)").return_field("id").no_content()) assert "to be defined with 'INDEXMISSING'" in e.value.args[0] res = client.ft().search( - Query("ismissing(@features)").dialect(2).return_field("id").no_content() + Query("ismissing(@features)").return_field("id").no_content() ) _assert_search_result(client, res, ["property:2"]) res = client.ft().search( - Query("-ismissing(@features)").dialect(2).return_field("id").no_content() + Query("-ismissing(@features)").return_field("id").no_content() ) _assert_search_result(client, res, ["property:1", "property:3"]) res = client.ft().search( - Query("ismissing(@description)").dialect(2).return_field("id").no_content() + Query("ismissing(@description)").return_field("id").no_content() ) _assert_search_result(client, res, ["property:3"]) res = client.ft().search( - Query("-ismissing(@description)").dialect(2).return_field("id").no_content() + Query("-ismissing(@description)").return_field("id").no_content() ) _assert_search_result(client, res, ["property:1", "property:2"]) @@ -2578,31 +2576,25 @@ def test_search_empty_fields(client): ) with pytest.raises(redis.exceptions.ResponseError) as e: - client.ft().search( - Query("@title:''").dialect(2).return_field("id").no_content() - ) + client.ft().search(Query("@title:''").return_field("id").no_content()) assert "Use `INDEXEMPTY` in field creation" in e.value.args[0] res = client.ft().search( - Query("@features:{$empty}").dialect(2).return_field("id").no_content(), + Query("@features:{$empty}").return_field("id").no_content(), query_params={"empty": ""}, ) _assert_search_result(client, res, ["property:2"]) res = client.ft().search( - Query("-@features:{$empty}").dialect(2).return_field("id").no_content(), + Query("-@features:{$empty}").return_field("id").no_content(), query_params={"empty": ""}, ) _assert_search_result(client, res, ["property:1", "property:3"]) - res = client.ft().search( - Query("@description:''").dialect(2).return_field("id").no_content() - ) + res = client.ft().search(Query("@description:''").return_field("id").no_content()) _assert_search_result(client, res, ["property:3"]) - res = client.ft().search( - Query("-@description:''").dialect(2).return_field("id").no_content() - ) + res = client.ft().search(Query("-@description:''").return_field("id").no_content()) _assert_search_result(client, res, ["property:1", "property:2"]) @@ -2643,29 +2635,85 @@ def test_special_characters_in_fields(client): # no need to escape - when using params res = client.ft().search( - Query("@uuid:{$uuid}").dialect(2), + Query("@uuid:{$uuid}"), query_params={"uuid": "123e4567-e89b-12d3-a456-426614174000"}, ) _assert_search_result(client, res, ["resource:1"]) # with double quotes exact match no need to escape the - even without params - res = client.ft().search( - Query('@uuid:{"123e4567-e89b-12d3-a456-426614174000"}').dialect(2) - ) + res = 
client.ft().search(Query('@uuid:{"123e4567-e89b-12d3-a456-426614174000"}')) _assert_search_result(client, res, ["resource:1"]) - res = client.ft().search(Query('@tags:{"new-year\'s-resolutions"}').dialect(2)) + res = client.ft().search(Query('@tags:{"new-year\'s-resolutions"}')) _assert_search_result(client, res, ["resource:2"]) # possible to search numeric fields by single value - res = client.ft().search(Query("@rating:[4]").dialect(2)) + res = client.ft().search(Query("@rating:[4]")) _assert_search_result(client, res, ["resource:2"]) # some chars still need escaping - res = client.ft().search(Query(r"@tags:{\$btc}").dialect(2)) + res = client.ft().search(Query(r"@tags:{\$btc}")) _assert_search_result(client, res, ["resource:1"]) +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +def test_vector_search_with_default_dialect(client): + client.ft().create_index( + ( + VectorField( + "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"} + ), + ) + ) + + client.hset("a", "v", "aaaaaaaa") + client.hset("b", "v", "aaaabaaa") + client.hset("c", "v", "aaaaabaa") + + query = "*=>[KNN 2 @v $vec]" + q = Query(query) + + assert "DIALECT" in q.get_args() + assert 2 in q.get_args() + + res = client.ft().search(q, query_params={"vec": "aaaaaaaa"}) + if is_resp2_connection(client): + assert res.total == 2 + else: + assert res["total_results"] == 2 + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.4.3", "search") +def test_search_query_with_different_dialects(client): + client.ft().create_index( + (TextField("name"), TextField("lastname")), + definition=IndexDefinition(prefix=["test:"]), + ) + + client.hset("test:1", "name", "James") + client.hset("test:1", "lastname", "Brown") + + # Query with default DIALECT 2 + query = "@name: James Brown" + q = Query(query) + res = client.ft().search(q) + if is_resp2_connection(client): + assert res.total == 1 + else: + assert res["total_results"] == 1 + + # Query with explicit DIALECT 1 + query = "@name: James Brown" + q = Query(query).dialect(1) + res = client.ft().search(q) + if is_resp2_connection(client): + assert res.total == 0 + else: + assert res["total_results"] == 0 + + def _assert_search_result(client, result, expected_doc_ids): """ Make sure the result of a geo search is as expected, taking into account the RESP From 58c1604126f6e092f37e47514eb240ec5a9ebdca Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:00:51 +0200 Subject: [PATCH 007/113] Moved ClusterParser exceptions to BaseParser class (#3475) * Moved ClusterParser exceptions to BaseParser class * Codestyle fixes * Removed ubused imports * Sorted imports --- docs/clustering.rst | 4 ++-- redis/_parsers/base.py | 12 ++++++++++++ redis/asyncio/cluster.py | 27 ++------------------------- redis/cluster.py | 19 +------------------ 4 files changed, 17 insertions(+), 45 deletions(-) diff --git a/docs/clustering.rst b/docs/clustering.rst index f8320e4e59..cf257d8ad5 100644 --- a/docs/clustering.rst +++ b/docs/clustering.rst @@ -17,8 +17,8 @@ Nodes <#specifying-target-nodes>`__ \| `Multi-key Commands <#multi-key-commands>`__ \| `Known PubSub Limitations <#known-pubsub-limitations>`__ -Creating clusters ------------------ +Connecting to cluster +--------------------- Connecting redis-py to a Redis Cluster instance(s) requires at a minimum a single node for cluster discovery. 
There are multiple ways in which a diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index 0137539d66..91a4f74199 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -9,18 +9,24 @@ from async_timeout import timeout as async_timeout from ..exceptions import ( + AskError, AuthenticationError, AuthenticationWrongNumberOfArgsError, BusyLoadingError, + ClusterCrossSlotError, + ClusterDownError, ConnectionError, ExecAbortError, + MasterDownError, ModuleError, + MovedError, NoPermissionError, NoScriptError, OutOfMemoryError, ReadOnlyError, RedisError, ResponseError, + TryAgainError, ) from ..typing import EncodableT from .encoders import Encoder @@ -72,6 +78,12 @@ class BaseParser(ABC): "READONLY": ReadOnlyError, "NOAUTH": AuthenticationError, "NOPERM": NoPermissionError, + "ASK": AskError, + "TRYAGAIN": TryAgainError, + "MOVED": MovedError, + "CLUSTERDOWN": ClusterDownError, + "CROSSSLOT": ClusterCrossSlotError, + "MASTERDOWN": MasterDownError, } @classmethod diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index dbc32047aa..0d6d130dcf 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -26,7 +26,7 @@ _RedisCallbacksRESP3, ) from redis.asyncio.client import ResponseCallbackT -from redis.asyncio.connection import Connection, DefaultParser, SSLConnection, parse_url +from redis.asyncio.connection import Connection, SSLConnection, parse_url from redis.asyncio.lock import Lock from redis.asyncio.retry import Retry from redis.auth.token import TokenInterface @@ -50,12 +50,10 @@ from redis.exceptions import ( AskError, BusyLoadingError, - ClusterCrossSlotError, ClusterDownError, ClusterError, ConnectionError, DataError, - MasterDownError, MaxConnectionsError, MovedError, RedisClusterException, @@ -66,33 +64,13 @@ TryAgainError, ) from redis.typing import AnyKeyT, EncodableT, KeyT -from redis.utils import ( - deprecated_function, - dict_merge, - get_lib_version, - safe_str, - str_if_bytes, -) +from redis.utils import deprecated_function, get_lib_version, safe_str, str_if_bytes TargetNodesT = TypeVar( "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] ) -class ClusterParser(DefaultParser): - EXCEPTION_CLASSES = dict_merge( - DefaultParser.EXCEPTION_CLASSES, - { - "ASK": AskError, - "CLUSTERDOWN": ClusterDownError, - "CROSSSLOT": ClusterCrossSlotError, - "MASTERDOWN": MasterDownError, - "MOVED": MovedError, - "TRYAGAIN": TryAgainError, - }, - ) - - class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): """ Create a new RedisCluster client. 
@@ -306,7 +284,6 @@ def __init__( kwargs: Dict[str, Any] = { "max_connections": max_connections, "connection_class": Connection, - "parser_class": ClusterParser, # Client related kwargs "credential_provider": credential_provider, "username": username, diff --git a/redis/cluster.py b/redis/cluster.py index 38bd5dde1a..8718493759 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -13,7 +13,7 @@ from redis.client import CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands from redis.commands.helpers import list_or_args -from redis.connection import ConnectionPool, DefaultParser, parse_url +from redis.connection import ConnectionPool, parse_url from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.event import ( AfterPooledConnectionsInstantiationEvent, @@ -24,12 +24,10 @@ from redis.exceptions import ( AskError, AuthenticationError, - ClusterCrossSlotError, ClusterDownError, ClusterError, ConnectionError, DataError, - MasterDownError, MovedError, RedisClusterException, RedisError, @@ -193,20 +191,6 @@ def cleanup_kwargs(**kwargs): return connection_kwargs -class ClusterParser(DefaultParser): - EXCEPTION_CLASSES = dict_merge( - DefaultParser.EXCEPTION_CLASSES, - { - "ASK": AskError, - "TRYAGAIN": TryAgainError, - "MOVED": MovedError, - "CLUSTERDOWN": ClusterDownError, - "CROSSSLOT": ClusterCrossSlotError, - "MASTERDOWN": MasterDownError, - }, - ) - - class AbstractRedisCluster: RedisClusterRequestTTL = 16 @@ -692,7 +676,6 @@ def on_connect(self, connection): Initialize the connection, authenticate and select a database and send READONLY if it is set during object initialization. """ - connection.set_parser(ClusterParser) connection.on_connect() if self.read_from_replicas: From 570ac929f9a2b2cfe0d708f297340724f7b07633 Mon Sep 17 00:00:00 2001 From: David Dougherty Date: Fri, 17 Jan 2025 00:28:35 -0800 Subject: [PATCH 008/113] DOC-4423: add TCEs for various command pages (#3476) Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> --- doctests/cmds_cnxmgmt.py | 36 +++++++++++ doctests/cmds_hash.py | 24 +++++++ doctests/cmds_list.py | 123 ++++++++++++++++++++++++++++++++++++ doctests/cmds_servermgmt.py | 30 +++++++++ doctests/cmds_set.py | 35 ++++++++++ 5 files changed, 248 insertions(+) create mode 100644 doctests/cmds_cnxmgmt.py create mode 100644 doctests/cmds_list.py create mode 100644 doctests/cmds_servermgmt.py create mode 100644 doctests/cmds_set.py diff --git a/doctests/cmds_cnxmgmt.py b/doctests/cmds_cnxmgmt.py new file mode 100644 index 0000000000..c691f723cf --- /dev/null +++ b/doctests/cmds_cnxmgmt.py @@ -0,0 +1,36 @@ +# EXAMPLE: cmds_cnxmgmt +# HIDE_START +import redis + +r = redis.Redis(decode_responses=True) +# HIDE_END + +# STEP_START auth1 +# REMOVE_START +r.config_set("requirepass", "temp_pass") +# REMOVE_END +res1 = r.auth(password="temp_pass") +print(res1) # >>> True + +res2 = r.auth(password="temp_pass", username="default") +print(res2) # >>> True + +# REMOVE_START +assert res1 == True +assert res2 == True +r.config_set("requirepass", "") +# REMOVE_END +# STEP_END + +# STEP_START auth2 +# REMOVE_START +r.acl_setuser("test-user", enabled=True, passwords=["+strong_password"], commands=["+acl"]) +# REMOVE_END +res = r.auth(username="test-user", password="strong_password") +print(res) # >>> True + +# REMOVE_START +assert res == True +r.acl_deluser("test-user") +# REMOVE_END +# STEP_END diff --git a/doctests/cmds_hash.py b/doctests/cmds_hash.py index 0bc1cb8038..65bbd52d60 100644 --- 
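As an illustration only (not part of the patch), a small sketch of the redirection exceptions whose mappings now live in BaseParser.EXCEPTION_CLASSES; the payload below is a made-up example in the "slot host:port" format that MOVED replies carry.

from redis.exceptions import MovedError

err = MovedError("1234 127.0.0.1:7002")  # fabricated MOVED payload
print(err.slot_id, err.host, err.port)  # 1234 127.0.0.1 7002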
a/doctests/cmds_hash.py +++ b/doctests/cmds_hash.py @@ -61,3 +61,27 @@ r.delete("myhash") # REMOVE_END # STEP_END + +# STEP_START hgetall +res10 = r.hset("myhash", mapping={"field1": "Hello", "field2": "World"}) + +res11 = r.hgetall("myhash") +print(res11) # >>> { "field1": "Hello", "field2": "World" } + +# REMOVE_START +assert res11 == { "field1": "Hello", "field2": "World" } +r.delete("myhash") +# REMOVE_END +# STEP_END + +# STEP_START hvals +res10 = r.hset("myhash", mapping={"field1": "Hello", "field2": "World"}) + +res11 = r.hvals("myhash") +print(res11) # >>> [ "Hello", "World" ] + +# REMOVE_START +assert res11 == [ "Hello", "World" ] +r.delete("myhash") +# REMOVE_END +# STEP_END \ No newline at end of file diff --git a/doctests/cmds_list.py b/doctests/cmds_list.py new file mode 100644 index 0000000000..cce2d540a8 --- /dev/null +++ b/doctests/cmds_list.py @@ -0,0 +1,123 @@ +# EXAMPLE: cmds_list +# HIDE_START +import redis + +r = redis.Redis(decode_responses=True) +# HIDE_END + +# STEP_START lpush +res1 = r.lpush("mylist", "world") +print(res1) # >>> 1 + +res2 = r.lpush("mylist", "hello") +print(res2) # >>> 2 + +res3 = r.lrange("mylist", 0, -1) +print(res3) # >>> [ "hello", "world" ] + +# REMOVE_START +assert res3 == [ "hello", "world" ] +r.delete("mylist") +# REMOVE_END +# STEP_END + +# STEP_START lrange +res4 = r.rpush("mylist", "one"); +print(res4) # >>> 1 + +res5 = r.rpush("mylist", "two") +print(res5) # >>> 2 + +res6 = r.rpush("mylist", "three") +print(res6) # >>> 3 + +res7 = r.lrange('mylist', 0, 0) +print(res7) # >>> [ 'one' ] + +res8 = r.lrange('mylist', -3, 2) +print(res8) # >>> [ 'one', 'two', 'three' ] + +res9 = r.lrange('mylist', -100, 100) +print(res9) # >>> [ 'one', 'two', 'three' ] + +res10 = r.lrange('mylist', 5, 10) +print(res10) # >>> [] + +# REMOVE_START +assert res7 == [ 'one' ] +assert res8 == [ 'one', 'two', 'three' ] +assert res9 == [ 'one', 'two', 'three' ] +assert res10 == [] +r.delete('mylist') +# REMOVE_END +# STEP_END + +# STEP_START llen +res11 = r.lpush("mylist", "World") +print(res11) # >>> 1 + +res12 = r.lpush("mylist", "Hello") +print(res12) # >>> 2 + +res13 = r.llen("mylist") +print(res13) # >>> 2 + +# REMOVE_START +assert res13 == 2 +r.delete("mylist") +# REMOVE_END +# STEP_END + +# STEP_START rpush +res14 = r.rpush("mylist", "hello") +print(res14) # >>> 1 + +res15 = r.rpush("mylist", "world") +print(res15) # >>> 2 + +res16 = r.lrange("mylist", 0, -1) +print(res16) # >>> [ "hello", "world" ] + +# REMOVE_START +assert res16 == [ "hello", "world" ] +r.delete("mylist") +# REMOVE_END +# STEP_END + +# STEP_START lpop +res17 = r.rpush("mylist", *["one", "two", "three", "four", "five"]) +print(res17) # >>> 5 + +res18 = r.lpop("mylist") +print(res18) # >>> "one" + +res19 = r.lpop("mylist", 2) +print(res19) # >>> ['two', 'three'] + +res17 = r.lrange("mylist", 0, -1) +print(res17) # >>> [ "four", "five" ] + +# REMOVE_START +assert res17 == [ "four", "five" ] +r.delete("mylist") +# REMOVE_END +# STEP_END + +# STEP_START rpop +res18 = r.rpush("mylist", *["one", "two", "three", "four", "five"]) +print(res18) # >>> 5 + +res19 = r.rpop("mylist") +print(res19) # >>> "five" + +res20 = r.rpop("mylist", 2) +print(res20) # >>> ['four', 'three'] + +res21 = r.lrange("mylist", 0, -1) +print(res21) # >>> [ "one", "two" ] + +# REMOVE_START +assert res21 == [ "one", "two" ] +r.delete("mylist") +# REMOVE_END +# STEP_END \ No newline at end of file diff --git a/doctests/cmds_servermgmt.py b/doctests/cmds_servermgmt.py new file mode 100644 index 0000000000..6ad2b6acb2 --- 
/dev/null +++ b/doctests/cmds_servermgmt.py @@ -0,0 +1,30 @@ +# EXAMPLE: cmds_servermgmt +# HIDE_START +import redis + +r = redis.Redis(decode_responses=True) +# HIDE_END + +# STEP_START flushall +# REMOVE_START +r.set("foo", "1") +r.set("bar", "2") +r.set("baz", "3") +# REMOVE_END +res1 = r.flushall(asynchronous=False) +print(res1) # >>> True + +res2 = r.keys() +print(res2) # >>> [] + +# REMOVE_START +assert res1 == True +assert res2 == [] +# REMOVE_END +# STEP_END + +# STEP_START info +res3 = r.info() +print(res3) +# >>> {'redis_version': '7.4.0', 'redis_git_sha1': 'c9d29f6a',...} +# STEP_END \ No newline at end of file diff --git a/doctests/cmds_set.py b/doctests/cmds_set.py new file mode 100644 index 0000000000..ece74e8cf0 --- /dev/null +++ b/doctests/cmds_set.py @@ -0,0 +1,35 @@ +# EXAMPLE: cmds_set +# HIDE_START +import redis + +r = redis.Redis(decode_responses=True) +# HIDE_END + +# STEP_START sadd +res1 = r.sadd("myset", "Hello", "World") +print(res1) # >>> 2 + +res2 = r.sadd("myset", "World") +print(res2) # >>> 0 + +res3 = r.smembers("myset") +print(res3) # >>> {'Hello', 'World'} + +# REMOVE_START +assert res3 == {'Hello', 'World'} +r.delete('myset') +# REMOVE_END +# STEP_END + +# STEP_START smembers +res4 = r.sadd("myset", "Hello", "World") +print(res4) # >>> 2 + +res5 = r.smembers("myset") +print(res5) # >>> {'Hello', 'World'} + +# REMOVE_START +assert res5 == {'Hello', 'World'} +r.delete('myset') +# REMOVE_END +# STEP_END \ No newline at end of file From 1c8af3a4c16919f32fcf63058b412a5e7f18d741 Mon Sep 17 00:00:00 2001 From: andy-stark-redis <164213578+andy-stark-redis@users.noreply.github.com> Date: Tue, 21 Jan 2025 09:06:48 +0000 Subject: [PATCH 009/113] DOC-4345 added testable JSON search examples for home page (#3407) * DOC-4345 added testable JSON search examples for home page * DOC-4345 avoid possible non-deterministic results in tests * DOC-4345 close connection at end of example * DOC-4345 remove unnecessary blank lines --- doctests/home_json.py | 137 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 doctests/home_json.py diff --git a/doctests/home_json.py b/doctests/home_json.py new file mode 100644 index 0000000000..922c83d2fe --- /dev/null +++ b/doctests/home_json.py @@ -0,0 +1,137 @@ +# EXAMPLE: py_home_json +""" +JSON examples from redis-py "home" page" + https://redis.io/docs/latest/develop/connect/clients/python/redis-py/#example-indexing-and-querying-json-documents +""" + +# STEP_START import +import redis +from redis.commands.json.path import Path +import redis.commands.search.aggregation as aggregations +import redis.commands.search.reducers as reducers +from redis.commands.search.field import TextField, NumericField, TagField +from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.query import Query +import redis.exceptions +# STEP_END + +# STEP_START connect +r = redis.Redis(decode_responses=True) +# STEP_END + +# REMOVE_START +try: + r.ft("idx:users").dropindex(True) +except redis.exceptions.ResponseError: + pass + +r.delete("user:1", "user:2", "user:3") +# REMOVE_END +# STEP_START create_data +user1 = { + "name": "Paul John", + "email": "paul.john@example.com", + "age": 42, + "city": "London" +} + +user2 = { + "name": "Eden Zamir", + "email": "eden.zamir@example.com", + "age": 29, + "city": "Tel Aviv" +} + +user3 = { + "name": "Paul Zamir", + "email": "paul.zamir@example.com", + "age": 35, + "city": "Tel Aviv" +} +# STEP_END + +# STEP_START make_index +schema = 
( + TextField("$.name", as_name="name"), + TagField("$.city", as_name="city"), + NumericField("$.age", as_name="age") +) + +indexCreated = r.ft("idx:users").create_index( + schema, + definition=IndexDefinition( + prefix=["user:"], index_type=IndexType.JSON + ) +) +# STEP_END +# Tests for 'make_index' step. +# REMOVE_START +assert indexCreated +# REMOVE_END + +# STEP_START add_data +user1Set = r.json().set("user:1", Path.root_path(), user1) +user2Set = r.json().set("user:2", Path.root_path(), user2) +user3Set = r.json().set("user:3", Path.root_path(), user3) +# STEP_END +# Tests for 'add_data' step. +# REMOVE_START +assert user1Set +assert user2Set +assert user3Set +# REMOVE_END + +# STEP_START query1 +findPaulResult = r.ft("idx:users").search( + Query("Paul @age:[30 40]") +) + +print(findPaulResult) +# >>> Result{1 total, docs: [Document {'id': 'user:3', ... +# STEP_END +# Tests for 'query1' step. +# REMOVE_START +assert str(findPaulResult) == ( + "Result{1 total, docs: [Document {'id': 'user:3', 'payload': None, " + + "'json': '{\"name\":\"Paul Zamir\",\"email\":" + + "\"paul.zamir@example.com\",\"age\":35,\"city\":\"Tel Aviv\"}'}]}" +) +# REMOVE_END + +# STEP_START query2 +citiesResult = r.ft("idx:users").search( + Query("Paul").return_field("$.city", as_field="city") +).docs + +print(citiesResult) +# >>> [Document {'id': 'user:1', 'payload': None, ... +# STEP_END +# Tests for 'query2' step. +# REMOVE_START +citiesResult.sort(key=lambda doc: doc['id']) + +assert str(citiesResult) == ( + "[Document {'id': 'user:1', 'payload': None, 'city': 'London'}, " + + "Document {'id': 'user:3', 'payload': None, 'city': 'Tel Aviv'}]" +) +# REMOVE_END + +# STEP_START query3 +req = aggregations.AggregateRequest("*").group_by( + '@city', reducers.count().alias('count') +) + +aggResult = r.ft("idx:users").aggregate(req).rows +print(aggResult) +# >>> [['city', 'London', 'count', '1'], ['city', 'Tel Aviv', 'count', '2']] +# STEP_END +# Tests for 'query3' step. 
+# REMOVE_START +aggResult.sort(key=lambda row: row[1]) + +assert str(aggResult) == ( + "[['city', 'London', 'count', '1'], ['city', 'Tel Aviv', 'count', '2']]" +) +# REMOVE_END + +r.close() From 0d790132db41eff3e2fc750149448bb464805d8b Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Tue, 21 Jan 2025 11:24:42 +0200 Subject: [PATCH 010/113] Added Redis 8.0 to test matrix (#3469) * Added Redis 8.0 to test matrix * Fixed test cases * Added version annotation * Changed FT.PROFILE response type * Added version restrictions * Updated file names, fixed tests assertions * Removed unused API --- .github/workflows/integration.yaml | 2 +- doctests/query_agg.py | 2 +- doctests/query_combined.py | 2 +- doctests/query_em.py | 2 +- doctests/query_ft.py | 2 +- doctests/query_geo.py | 2 +- doctests/query_range.py | 2 +- doctests/search_quickstart.py | 2 +- doctests/search_vss.py | 2 +- redis/commands/helpers.py | 23 --- redis/commands/search/commands.py | 9 +- ...indexDefinition.py => index_definition.py} | 0 redis/commands/search/profile_information.py | 14 ++ tests/test_asyncio/test_search.py | 18 ++- tests/test_commands.py | 26 +++ tests/test_helpers.py | 35 ---- tests/test_search.py | 149 +++++++++++++++--- 17 files changed, 195 insertions(+), 97 deletions(-) rename redis/commands/search/{indexDefinition.py => index_definition.py} (100%) create mode 100644 redis/commands/search/profile_information.py diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 7c74de5290..c32029e6f9 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -74,7 +74,7 @@ jobs: max-parallel: 15 fail-fast: false matrix: - redis-version: [ '${{ needs.redis_version.outputs.CURRENT }}', '7.2.6', '6.2.16'] + redis-version: ['8.0-M02', '${{ needs.redis_version.outputs.CURRENT }}', '7.2.6', '6.2.16'] python-version: ['3.8', '3.12'] parser-backend: ['plain'] event-loop: ['asyncio'] diff --git a/doctests/query_agg.py b/doctests/query_agg.py index 4fa8f14b84..4d81ddbcda 100644 --- a/doctests/query_agg.py +++ b/doctests/query_agg.py @@ -6,7 +6,7 @@ from redis.commands.search import Search from redis.commands.search.aggregation import AggregateRequest from redis.commands.search.field import NumericField, TagField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType import redis.commands.search.reducers as reducers r = redis.Redis(decode_responses=True) diff --git a/doctests/query_combined.py b/doctests/query_combined.py index a17f19417c..e6dd5a2cb5 100644 --- a/doctests/query_combined.py +++ b/doctests/query_combined.py @@ -6,7 +6,7 @@ import warnings from redis.commands.json.path import Path from redis.commands.search.field import NumericField, TagField, TextField, VectorField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType from redis.commands.search.query import Query from sentence_transformers import SentenceTransformer diff --git a/doctests/query_em.py b/doctests/query_em.py index a00ff11150..91cc5ae940 100644 --- a/doctests/query_em.py +++ b/doctests/query_em.py @@ -4,7 +4,7 @@ import redis from redis.commands.json.path import Path from redis.commands.search.field import TextField, NumericField, TagField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from 
redis.commands.search.index_definition import IndexDefinition, IndexType from redis.commands.search.query import NumericFilter, Query r = redis.Redis(decode_responses=True) diff --git a/doctests/query_ft.py b/doctests/query_ft.py index 182a5b2bd3..6272cdab25 100644 --- a/doctests/query_ft.py +++ b/doctests/query_ft.py @@ -5,7 +5,7 @@ import redis from redis.commands.json.path import Path from redis.commands.search.field import TextField, NumericField, TagField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType from redis.commands.search.query import NumericFilter, Query r = redis.Redis(decode_responses=True) diff --git a/doctests/query_geo.py b/doctests/query_geo.py index dcb7db6ee7..ed8c9a5f99 100644 --- a/doctests/query_geo.py +++ b/doctests/query_geo.py @@ -5,7 +5,7 @@ import redis from redis.commands.json.path import Path from redis.commands.search.field import GeoField, GeoShapeField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType from redis.commands.search.query import Query r = redis.Redis(decode_responses=True) diff --git a/doctests/query_range.py b/doctests/query_range.py index 4ef957acfb..674afc492a 100644 --- a/doctests/query_range.py +++ b/doctests/query_range.py @@ -5,7 +5,7 @@ import redis from redis.commands.json.path import Path from redis.commands.search.field import TextField, NumericField, TagField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType from redis.commands.search.query import NumericFilter, Query r = redis.Redis(decode_responses=True) diff --git a/doctests/search_quickstart.py b/doctests/search_quickstart.py index e190393b16..cde4caa84a 100644 --- a/doctests/search_quickstart.py +++ b/doctests/search_quickstart.py @@ -10,7 +10,7 @@ import redis.commands.search.reducers as reducers from redis.commands.json.path import Path from redis.commands.search.field import NumericField, TagField, TextField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType from redis.commands.search.query import Query # HIDE_END diff --git a/doctests/search_vss.py b/doctests/search_vss.py index 8b4884727a..a1132971db 100644 --- a/doctests/search_vss.py +++ b/doctests/search_vss.py @@ -20,7 +20,7 @@ TextField, VectorField, ) -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType from redis.commands.search.query import Query from sentence_transformers import SentenceTransformer diff --git a/redis/commands/helpers.py b/redis/commands/helpers.py index 1ea02a60cf..7d9095ea41 100644 --- a/redis/commands/helpers.py +++ b/redis/commands/helpers.py @@ -79,29 +79,6 @@ def parse_list_to_dict(response): return res -def parse_to_dict(response): - if response is None: - return {} - - res = {} - for det in response: - if not isinstance(det, list) or not det: - continue - if len(det) == 1: - res[det[0]] = True - elif isinstance(det[1], list): - res[det[0]] = parse_list_to_dict(det[1]) - else: - try: # try to set the attribute. 
may be provided without value - try: # try to convert the value to float - res[det[0]] = float(det[1]) - except (TypeError, ValueError): - res[det[0]] = det[1] - except IndexError: - pass - return res - - def random_string(length=10): """ Returns a random N character long string. diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index da79016ad4..2447959922 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -5,12 +5,13 @@ from redis.client import NEVER_DECODE, Pipeline from redis.utils import deprecated_function -from ..helpers import get_protocol_version, parse_to_dict +from ..helpers import get_protocol_version from ._util import to_string from .aggregation import AggregateRequest, AggregateResult, Cursor from .document import Document from .field import Field -from .indexDefinition import IndexDefinition +from .index_definition import IndexDefinition +from .profile_information import ProfileInformation from .query import Query from .result import Result from .suggestion import SuggestionParser @@ -67,7 +68,7 @@ class SearchCommands: def _parse_results(self, cmd, res, **kwargs): if get_protocol_version(self.client) in ["3", 3]: - return res + return ProfileInformation(res) if cmd == "FT.PROFILE" else res else: return self._RESP2_MODULE_CALLBACKS[cmd](res, **kwargs) @@ -101,7 +102,7 @@ def _parse_profile(self, res, **kwargs): with_scores=query._with_scores, ) - return result, parse_to_dict(res[1]) + return result, ProfileInformation(res[1]) def _parse_spellcheck(self, res, **kwargs): corrections = {} diff --git a/redis/commands/search/indexDefinition.py b/redis/commands/search/index_definition.py similarity index 100% rename from redis/commands/search/indexDefinition.py rename to redis/commands/search/index_definition.py diff --git a/redis/commands/search/profile_information.py b/redis/commands/search/profile_information.py new file mode 100644 index 0000000000..23551be27f --- /dev/null +++ b/redis/commands/search/profile_information.py @@ -0,0 +1,14 @@ +from typing import Any + + +class ProfileInformation: + """ + Wrapper around FT.PROFILE response + """ + + def __init__(self, info: Any) -> None: + self._info: Any = info + + @property + def info(self) -> Any: + return self._info diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index cc75e4b4a4..4f5a4c2f04 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -19,7 +19,7 @@ TextField, VectorField, ) -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType from redis.commands.search.query import GeoFilter, NumericFilter, Query from redis.commands.search.result import Result from redis.commands.search.suggestion import Suggestion @@ -27,6 +27,8 @@ is_resp2_connection, skip_if_redis_enterprise, skip_if_resp_version, + skip_if_server_version_gte, + skip_if_server_version_lt, skip_ifmodversion_lt, ) @@ -1111,6 +1113,7 @@ async def test_get(decoded_r: redis.Redis): @pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("2.2.0", "search") +@skip_if_server_version_gte("7.9.0") async def test_config(decoded_r: redis.Redis): assert await decoded_r.ft().config_set("TIMEOUT", "100") with pytest.raises(redis.ResponseError): @@ -1121,6 +1124,19 @@ async def test_config(decoded_r: redis.Redis): assert "100" == res["TIMEOUT"] +@pytest.mark.redismod +@pytest.mark.onlynoncluster 
+@skip_if_server_version_lt("7.9.0") +async def test_config_with_removed_ftconfig(decoded_r: redis.Redis): + assert await decoded_r.config_set("timeout", "100") + with pytest.raises(redis.ResponseError): + await decoded_r.config_set("timeout", "null") + res = await decoded_r.config_get("*") + assert "100" == res["timeout"] + res = await decoded_r.config_get("timeout") + assert "100" == res["timeout"] + + @pytest.mark.redismod @pytest.mark.onlynoncluster async def test_aggregations_groupby(decoded_r: redis.Redis): diff --git a/tests/test_commands.py b/tests/test_commands.py index 4cad4c14b6..2681b8eaf0 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1823,6 +1823,7 @@ def try_delete_libs(self, r, *lib_names): @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.1.140") + @skip_if_server_version_gte("7.9.0") def test_tfunction_load_delete(self, stack_r): self.try_delete_libs(stack_r, "lib1") lib_code = self.generate_lib_code("lib1") @@ -1831,6 +1832,7 @@ def test_tfunction_load_delete(self, stack_r): @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.1.140") + @skip_if_server_version_gte("7.9.0") def test_tfunction_list(self, stack_r): self.try_delete_libs(stack_r, "lib1", "lib2", "lib3") assert stack_r.tfunction_load(self.generate_lib_code("lib1")) @@ -1861,6 +1863,7 @@ def test_tfunction_list(self, stack_r): @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.1.140") + @skip_if_server_version_gte("7.9.0") def test_tfcall(self, stack_r): self.try_delete_libs(stack_r, "lib1") assert stack_r.tfunction_load(self.generate_lib_code("lib1")) @@ -4329,6 +4332,7 @@ def test_xgroup_create_mkstream(self, r): assert r.xinfo_groups(stream) == expected @skip_if_server_version_lt("7.0.0") + @skip_if_server_version_gte("7.9.0") def test_xgroup_create_entriesread(self, r: redis.Redis): stream = "stream" group = "group" @@ -4350,6 +4354,28 @@ def test_xgroup_create_entriesread(self, r: redis.Redis): ] assert r.xinfo_groups(stream) == expected + @skip_if_server_version_lt("7.9.0") + def test_xgroup_create_entriesread_with_fixed_lag_field(self, r: redis.Redis): + stream = "stream" + group = "group" + r.xadd(stream, {"foo": "bar"}) + + # no group is setup yet, no info to obtain + assert r.xinfo_groups(stream) == [] + + assert r.xgroup_create(stream, group, 0, entries_read=7) + expected = [ + { + "name": group.encode(), + "consumers": 0, + "pending": 0, + "last-delivered-id": b"0-0", + "entries-read": 7, + "lag": 1, + } + ] + assert r.xinfo_groups(stream) == expected + @skip_if_server_version_lt("5.0.0") def test_xgroup_delconsumer(self, r): stream = "stream" diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 66ee1c5390..06265d382e 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -4,7 +4,6 @@ delist, list_or_args, nativestr, - parse_to_dict, parse_to_list, quote_string, random_string, @@ -26,40 +25,6 @@ def test_parse_to_list(): assert parse_to_list(r) == ["hello", "my name", 45, 555.55, "is simon!", None] -def test_parse_to_dict(): - assert parse_to_dict(None) == {} - r = [ - ["Some number", "1.0345"], - ["Some string", "hello"], - [ - "Child iterators", - [ - "Time", - "0.2089", - "Counter", - 3, - "Child iterators", - ["Type", "bar", "Time", "0.0729", "Counter", 3], - ["Type", "barbar", "Time", "0.058", "Counter", 3], - ["Type", "barbarbar", "Time", "0.0234", "Counter", 3], - ], - ], - ] - assert parse_to_dict(r) == { - "Child iterators": { - "Child iterators": [ - {"Counter": 3.0, "Time": 0.0729, "Type": "bar"}, - {"Counter": 3.0, "Time": 
0.058, "Type": "barbar"}, - {"Counter": 3.0, "Time": 0.0234, "Type": "barbarbar"}, - ], - "Counter": 3.0, - "Time": 0.2089, - }, - "Some number": 1.0345, - "Some string": "hello", - } - - def test_nativestr(): assert nativestr("teststr") == "teststr" assert nativestr(b"teststr") == "teststr" diff --git a/tests/test_search.py b/tests/test_search.py index a257484425..ee1ba66434 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -20,7 +20,7 @@ TextField, VectorField, ) -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType from redis.commands.search.query import GeoFilter, NumericFilter, Query from redis.commands.search.result import Result from redis.commands.search.suggestion import Suggestion @@ -30,6 +30,7 @@ is_resp2_connection, skip_if_redis_enterprise, skip_if_resp_version, + skip_if_server_version_gte, skip_if_server_version_lt, skip_ifmodversion_lt, ) @@ -1007,6 +1008,7 @@ def test_get(client): @pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("2.2.0", "search") +@skip_if_server_version_gte("7.9.0") def test_config(client): assert client.ft().config_set("TIMEOUT", "100") with pytest.raises(redis.ResponseError): @@ -1017,6 +1019,19 @@ def test_config(client): assert "100" == res["TIMEOUT"] +@pytest.mark.redismod +@pytest.mark.onlynoncluster +@skip_if_server_version_lt("7.9.0") +def test_config_with_removed_ftconfig(client): + assert client.config_set("timeout", "100") + with pytest.raises(redis.ResponseError): + client.config_set("timeout", "null") + res = client.config_get("*") + assert "100" == res["timeout"] + res = client.config_get("timeout") + assert "100" == res["timeout"] + + @pytest.mark.redismod @pytest.mark.onlynoncluster def test_aggregations_groupby(client): @@ -1571,6 +1586,7 @@ def test_index_definition(client): @pytest.mark.redismod @pytest.mark.onlynoncluster @skip_if_redis_enterprise() +@skip_if_server_version_gte("7.9.0") def test_expire(client): client.ft().create_index((TextField("txt", sortable=True),), temporary=4) ttl = client.execute_command("ft.debug", "TTL", "idx") @@ -2025,6 +2041,8 @@ def test_json_with_jsonpath(client): @pytest.mark.redismod @pytest.mark.onlynoncluster @skip_if_redis_enterprise() +@skip_if_server_version_gte("7.9.0") +@skip_if_server_version_lt("6.3.0") def test_profile(client): client.ft().create_index((TextField("t"),)) client.ft().client.hset("1", "t", "hello") @@ -2034,10 +2052,9 @@ def test_profile(client): q = Query("hello|world").no_content() if is_resp2_connection(client): res, det = client.ft().profile(q) - assert det["Iterators profile"]["Counter"] == 2.0 - assert len(det["Iterators profile"]["Child iterators"]) == 2 - assert det["Iterators profile"]["Type"] == "UNION" - assert det["Parsing time"] < 0.5 + det = det.info + + assert isinstance(det, list) assert len(res.docs) == 2 # check also the search result # check using AggregateRequest @@ -2047,15 +2064,14 @@ def test_profile(client): .apply(prefix="startswith(@t, 'hel')") ) res, det = client.ft().profile(req) - assert det["Iterators profile"]["Counter"] == 2 - assert det["Iterators profile"]["Type"] == "WILDCARD" - assert isinstance(det["Parsing time"], float) + det = det.info + assert isinstance(det, list) assert len(res.rows) == 2 # check also the search result else: res = client.ft().profile(q) - assert res["profile"]["Iterators profile"][0]["Counter"] == 2.0 - assert res["profile"]["Iterators profile"][0]["Type"] == "UNION" - assert 
res["profile"]["Parsing time"] < 0.5 + res = res.info + + assert isinstance(res, dict) assert len(res["results"]) == 2 # check also the search result # check using AggregateRequest @@ -2065,14 +2081,97 @@ def test_profile(client): .apply(prefix="startswith(@t, 'hel')") ) res = client.ft().profile(req) - assert res["profile"]["Iterators profile"][0]["Counter"] == 2 - assert res["profile"]["Iterators profile"][0]["Type"] == "WILDCARD" - assert isinstance(res["profile"]["Parsing time"], float) + res = res.info + + assert isinstance(res, dict) assert len(res["results"]) == 2 # check also the search result @pytest.mark.redismod @pytest.mark.onlynoncluster +@skip_if_redis_enterprise() +@skip_if_server_version_lt("7.9.0") +def test_profile_with_coordinator(client): + client.ft().create_index((TextField("t"),)) + client.ft().client.hset("1", "t", "hello") + client.ft().client.hset("2", "t", "world") + + # check using Query + q = Query("hello|world").no_content() + if is_resp2_connection(client): + res, det = client.ft().profile(q) + det = det.info + + assert isinstance(det, list) + assert len(res.docs) == 2 # check also the search result + + # check using AggregateRequest + req = ( + aggregations.AggregateRequest("*") + .load("t") + .apply(prefix="startswith(@t, 'hel')") + ) + res, det = client.ft().profile(req) + det = det.info + + assert isinstance(det, list) + assert det[0] == "Shards" + assert det[2] == "Coordinator" + assert len(res.rows) == 2 # check also the search result + else: + res = client.ft().profile(q) + res = res.info + + assert isinstance(res, dict) + assert len(res["Results"]["results"]) == 2 # check also the search result + + # check using AggregateRequest + req = ( + aggregations.AggregateRequest("*") + .load("t") + .apply(prefix="startswith(@t, 'hel')") + ) + res = client.ft().profile(req) + res = res.info + + assert isinstance(res, dict) + assert len(res["Results"]["results"]) == 2 # check also the search result + + +@pytest.mark.redismod +@pytest.mark.onlynoncluster +@skip_if_redis_enterprise() +@skip_if_server_version_gte("6.3.0") +def test_profile_with_no_warnings(client): + client.ft().create_index((TextField("t"),)) + client.ft().client.hset("1", "t", "hello") + client.ft().client.hset("2", "t", "world") + + # check using Query + q = Query("hello|world").no_content() + res, det = client.ft().profile(q) + det = det.info + + assert isinstance(det, list) + assert len(res.docs) == 2 # check also the search result + + # check using AggregateRequest + req = ( + aggregations.AggregateRequest("*") + .load("t") + .apply(prefix="startswith(@t, 'hel')") + ) + res, det = client.ft().profile(req) + det = det.info + + assert isinstance(det, list) + assert len(res.rows) == 2 # check also the search result + + +@pytest.mark.redismod +@pytest.mark.onlynoncluster +@skip_if_server_version_gte("7.9.0") +@skip_if_server_version_lt("6.3.0") def test_profile_limited(client): client.ft().create_index((TextField("t"),)) client.ft().client.hset("1", "t", "hello") @@ -2083,18 +2182,14 @@ def test_profile_limited(client): q = Query("%hell% hel*") if is_resp2_connection(client): res, det = client.ft().profile(q, limited=True) - assert ( - det["Iterators profile"]["Child iterators"][0]["Child iterators"] - == "The number of iterators in the union is 3" - ) - assert ( - det["Iterators profile"]["Child iterators"][1]["Child iterators"] - == "The number of iterators in the union is 4" - ) - assert det["Iterators profile"]["Type"] == "INTERSECT" + det = det.info + assert det[4][1][7][9] == "The number of 
iterators in the union is 3" + assert det[4][1][8][9] == "The number of iterators in the union is 4" + assert det[4][1][1] == "INTERSECT" assert len(res.docs) == 3 # check also the search result else: res = client.ft().profile(q, limited=True) + res = res.info iterators_profile = res["profile"]["Iterators profile"] assert ( iterators_profile[0]["Child iterators"][0]["Child iterators"] @@ -2110,6 +2205,8 @@ def test_profile_limited(client): @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") +@skip_if_server_version_gte("7.9.0") +@skip_if_server_version_lt("6.3.0") def test_profile_query_params(client): client.ft().create_index( ( @@ -2125,13 +2222,15 @@ def test_profile_query_params(client): q = Query(query).return_field("__v_score").sort_by("__v_score", True) if is_resp2_connection(client): res, det = client.ft().profile(q, query_params={"vec": "aaaaaaaa"}) - assert det["Iterators profile"]["Counter"] == 2.0 - assert det["Iterators profile"]["Type"] == "VECTOR" + det = det.info + assert det[4][1][5] == 2.0 + assert det[4][1][1] == "VECTOR" assert res.total == 2 assert "a" == res.docs[0].id assert "0" == res.docs[0].__getattribute__("__v_score") else: res = client.ft().profile(q, query_params={"vec": "aaaaaaaa"}) + res = res.info assert res["profile"]["Iterators profile"][0]["Counter"] == 2 assert res["profile"]["Iterators profile"][0]["Type"] == "VECTOR" assert res["total_results"] == 2 From c60c41cf7bbd5a3b51c9e9f7380dd88d0489ef18 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Tue, 28 Jan 2025 10:50:49 +0200 Subject: [PATCH 011/113] Updated test infrastructure with latest versions (#3484) * Updated latest image to 8.0-M03-pre * Changed redis image for testing purposes * Updated image version * Updated redis server versions * Updated test case * Revert version restriction * Updated redis versions * Added tests for new default scorer * Skipped test on 8.0 * Fixed query to match exact-match syntax * Codestyle fixes * Added condition for 8.0-M04-pre image * Added test for INFO section --- .github/actions/run-tests/action.yml | 10 ++- .github/workflows/integration.yaml | 4 +- redis/commands/search/query.py | 2 + tests/test_asyncio/test_search.py | 94 +++++++++++++++++++++++++++- tests/test_commands.py | 23 ------- tests/test_search.py | 89 ++++++++++++++++++++++++-- 6 files changed, 186 insertions(+), 36 deletions(-) diff --git a/.github/actions/run-tests/action.yml b/.github/actions/run-tests/action.yml index 5ca6bf5a09..1f9332fb86 100644 --- a/.github/actions/run-tests/action.yml +++ b/.github/actions/run-tests/action.yml @@ -35,6 +35,10 @@ runs: CLIENT_LIBS_TEST_IMAGE: "redislabs/client-libs-test:${{ inputs.redis-version }}" run: | set -e + + if [ "${{inputs.redis-version}}" == "8.0-M04-pre" ]; then + export REDIS_IMAGE=redis:8.0-M03 + fi echo "::group::Installing dependencies" pip install -U setuptools wheel @@ -56,9 +60,9 @@ runs: # Mapping of redis version to stack version declare -A redis_stack_version_mapping=( - ["7.4.1"]="7.4.0-v1" - ["7.2.6"]="7.2.0-v13" - ["6.2.16"]="6.2.6-v17" + ["7.4.2"]="7.4.0-v2" + ["7.2.7"]="7.2.0-v14" + ["6.2.17"]="6.2.6-v18" ) if [[ -v redis_stack_version_mapping[$REDIS_VERSION] ]]; then diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index c32029e6f9..c4548c21ef 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -29,7 +29,7 @@ env: COVERAGE_CORE: sysmon REDIS_IMAGE: redis:latest REDIS_STACK_IMAGE: 
redis/redis-stack-server:latest - CURRENT_REDIS_VERSION: '7.4.1' + CURRENT_REDIS_VERSION: '7.4.2' jobs: dependency-audit: @@ -74,7 +74,7 @@ jobs: max-parallel: 15 fail-fast: false matrix: - redis-version: ['8.0-M02', '${{ needs.redis_version.outputs.CURRENT }}', '7.2.6', '6.2.16'] + redis-version: ['8.0-M04-pre', '${{ needs.redis_version.outputs.CURRENT }}', '7.2.7', '6.2.17'] python-version: ['3.8', '3.12'] parser-backend: ['plain'] event-loop: ['asyncio'] diff --git a/redis/commands/search/query.py b/redis/commands/search/query.py index 964ce6cdf4..a8312a2ad2 100644 --- a/redis/commands/search/query.py +++ b/redis/commands/search/query.py @@ -179,6 +179,8 @@ def scorer(self, scorer: str) -> "Query": Use a different scoring function to evaluate document relevance. Default is `TFIDF`. + Since Redis 8.0 default was changed to BM25STD. + :param scorer: The scoring function to use (e.g. `TFIDF.DOCNORM` or `BM25`) """ diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 4f5a4c2f04..c0efcce882 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -341,6 +341,7 @@ async def test_client(decoded_r: redis.Redis): @pytest.mark.redismod @pytest.mark.onlynoncluster +@skip_if_server_version_gte("7.9.0") async def test_scores(decoded_r: redis.Redis): await decoded_r.ft().create_index((TextField("txt"),)) @@ -361,6 +362,29 @@ async def test_scores(decoded_r: redis.Redis): assert "doc1" == res["results"][1]["id"] +@pytest.mark.redismod +@pytest.mark.onlynoncluster +@skip_if_server_version_lt("7.9.0") +async def test_scores_with_new_default_scorer(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("txt"),)) + + await decoded_r.hset("doc1", mapping={"txt": "foo baz"}) + await decoded_r.hset("doc2", mapping={"txt": "foo bar"}) + + q = Query("foo ~bar").with_scores() + res = await decoded_r.ft().search(q) + if is_resp2_connection(decoded_r): + assert 2 == res.total + assert "doc2" == res.docs[0].id + assert 0.87 == pytest.approx(res.docs[0].score, 0.01) + assert "doc1" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + assert 0.87 == pytest.approx(res["results"][0]["score"], 0.01) + assert "doc1" == res["results"][1]["id"] + + @pytest.mark.redismod async def test_stopwords(decoded_r: redis.Redis): stopwords = ["foo", "bar", "baz"] @@ -663,7 +687,7 @@ async def test_summarize(decoded_r: redis.Redis): await createIndex(decoded_r.ft()) await waitForIndex(decoded_r, "idx") - q = Query("king henry").paging(0, 1) + q = Query('"king henry"').paging(0, 1) q.highlight(fields=("play", "txt"), tags=("", "")) q.summarize("txt") @@ -675,7 +699,7 @@ async def test_summarize(decoded_r: redis.Redis): == doc.txt ) - q = Query("king henry").paging(0, 1).summarize().highlight() + q = Query('"king henry"').paging(0, 1).summarize().highlight() doc = sorted((await decoded_r.ft().search(q)).docs)[0] assert "Henry ... " == doc.play @@ -691,7 +715,7 @@ async def test_summarize(decoded_r: redis.Redis): == doc["extra_attributes"]["txt"] ) - q = Query("king henry").paging(0, 1).summarize().highlight() + q = Query('"king henry"').paging(0, 1).summarize().highlight() doc = sorted((await decoded_r.ft().search(q))["results"])[0] assert "Henry ... 
" == doc["extra_attributes"]["play"] @@ -1029,6 +1053,7 @@ async def test_phonetic_matcher(decoded_r: redis.Redis): @pytest.mark.onlynoncluster # NOTE(imalinovskyi): This test contains hardcoded scores valid only for RediSearch 2.8+ @skip_ifmodversion_lt("2.8.0", "search") +@skip_if_server_version_gte("7.9.0") async def test_scorer(decoded_r: redis.Redis): await decoded_r.ft().create_index((TextField("description"),)) @@ -1087,6 +1112,69 @@ async def test_scorer(decoded_r: redis.Redis): assert 0.0 == res["results"][0]["score"] +@pytest.mark.redismod +@pytest.mark.onlynoncluster +# NOTE(imalinovskyi): This test contains hardcoded scores valid only for RediSearch 2.8+ +@skip_ifmodversion_lt("2.8.0", "search") +@skip_if_server_version_lt("7.9.0") +async def test_scorer_with_new_default_scorer(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("description"),)) + + await decoded_r.hset( + "doc1", mapping={"description": "The quick brown fox jumps over the lazy dog"} + ) + await decoded_r.hset( + "doc2", + mapping={ + "description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." # noqa + }, + ) + + if is_resp2_connection(decoded_r): + # default scorer is BM25STD + res = await decoded_r.ft().search(Query("quick").with_scores()) + assert 0.23 == pytest.approx(res.docs[0].score, 0.05) + res = await decoded_r.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores() + ) + assert 0.14285714285714285 == res.docs[0].score + res = await decoded_r.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.22471909420069797 == res.docs[0].score + res = await decoded_r.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("DOCSCORE").with_scores() + ) + assert 1.0 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("HAMMING").with_scores() + ) + assert 0.0 == res.docs[0].score + else: + res = await decoded_r.ft().search(Query("quick").with_scores()) + assert 0.23 == pytest.approx(res["results"][0]["score"], 0.05) + res = await decoded_r.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores() + ) + assert 0.14285714285714285 == res["results"][0]["score"] + res = await decoded_r.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.22471909420069797 == res["results"][0]["score"] + res = await decoded_r.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("DOCSCORE").with_scores() + ) + assert 1.0 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("HAMMING").with_scores() + ) + assert 0.0 == res["results"][0]["score"] + + @pytest.mark.redismod async def test_get(decoded_r: redis.Redis): await decoded_r.ft().create_index((TextField("f1"), TextField("f2"))) diff --git a/tests/test_commands.py b/tests/test_commands.py index 2681b8eaf0..f83fe76aa9 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -4332,7 +4332,6 @@ def test_xgroup_create_mkstream(self, r): assert r.xinfo_groups(stream) == expected @skip_if_server_version_lt("7.0.0") - 
@skip_if_server_version_gte("7.9.0") def test_xgroup_create_entriesread(self, r: redis.Redis): stream = "stream" group = "group" @@ -4341,28 +4340,6 @@ def test_xgroup_create_entriesread(self, r: redis.Redis): # no group is setup yet, no info to obtain assert r.xinfo_groups(stream) == [] - assert r.xgroup_create(stream, group, 0, entries_read=7) - expected = [ - { - "name": group.encode(), - "consumers": 0, - "pending": 0, - "last-delivered-id": b"0-0", - "entries-read": 7, - "lag": -6, - } - ] - assert r.xinfo_groups(stream) == expected - - @skip_if_server_version_lt("7.9.0") - def test_xgroup_create_entriesread_with_fixed_lag_field(self, r: redis.Redis): - stream = "stream" - group = "group" - r.xadd(stream, {"foo": "bar"}) - - # no group is setup yet, no info to obtain - assert r.xinfo_groups(stream) == [] - assert r.xgroup_create(stream, group, 0, entries_read=7) expected = [ { diff --git a/tests/test_search.py b/tests/test_search.py index ee1ba66434..5b45cfc0a3 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -314,6 +314,7 @@ def test_client(client): @pytest.mark.redismod @pytest.mark.onlynoncluster +@skip_if_server_version_gte("7.9.0") def test_scores(client): client.ft().create_index((TextField("txt"),)) @@ -334,6 +335,29 @@ def test_scores(client): assert "doc1" == res["results"][1]["id"] +@pytest.mark.redismod +@pytest.mark.onlynoncluster +@skip_if_server_version_lt("7.9.0") +def test_scores_with_new_default_scorer(client): + client.ft().create_index((TextField("txt"),)) + + client.hset("doc1", mapping={"txt": "foo baz"}) + client.hset("doc2", mapping={"txt": "foo bar"}) + + q = Query("foo ~bar").with_scores() + res = client.ft().search(q) + if is_resp2_connection(client): + assert 2 == res.total + assert "doc2" == res.docs[0].id + assert 0.87 == pytest.approx(res.docs[0].score, 0.01) + assert "doc1" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + assert 0.87 == pytest.approx(res["results"][0]["score"], 0.01) + assert "doc1" == res["results"][1]["id"] + + @pytest.mark.redismod def test_stopwords(client): client.ft().create_index((TextField("txt"),), stopwords=["foo", "bar", "baz"]) @@ -623,7 +647,7 @@ def test_summarize(client): createIndex(client.ft()) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - q = Query("king henry").paging(0, 1) + q = Query('"king henry"').paging(0, 1) q.highlight(fields=("play", "txt"), tags=("", "")) q.summarize("txt") @@ -635,7 +659,7 @@ def test_summarize(client): == doc.txt ) - q = Query("king henry").paging(0, 1).summarize().highlight() + q = Query('"king henry"').paging(0, 1).summarize().highlight() doc = sorted(client.ft().search(q).docs)[0] assert "Henry ... " == doc.play @@ -651,7 +675,7 @@ def test_summarize(client): == doc["extra_attributes"]["txt"] ) - q = Query("king henry").paging(0, 1).summarize().highlight() + q = Query('"king henry"').paging(0, 1).summarize().highlight() doc = sorted(client.ft().search(q)["results"])[0] assert "Henry ... 
" == doc["extra_attributes"]["play"] @@ -936,6 +960,7 @@ def test_phonetic_matcher(client): @pytest.mark.onlynoncluster # NOTE(imalinovskyi): This test contains hardcoded scores valid only for RediSearch 2.8+ @skip_ifmodversion_lt("2.8.0", "search") +@skip_if_server_version_gte("7.9.0") def test_scorer(client): client.ft().create_index((TextField("description"),)) @@ -982,6 +1007,55 @@ def test_scorer(client): assert 0.0 == res["results"][0]["score"] +@pytest.mark.redismod +@pytest.mark.onlynoncluster +@skip_if_server_version_lt("7.9.0") +def test_scorer_with_new_default_scorer(client): + client.ft().create_index((TextField("description"),)) + + client.hset( + "doc1", mapping={"description": "The quick brown fox jumps over the lazy dog"} + ) + client.hset( + "doc2", + mapping={ + "description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." # noqa + }, + ) + + # default scorer is BM25STD + if is_resp2_connection(client): + res = client.ft().search(Query("quick").with_scores()) + assert 0.23 == pytest.approx(res.docs[0].score, 0.05) + res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) + assert 0.14285714285714285 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.22471909420069797 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) + assert 0.0 == res.docs[0].score + else: + res = client.ft().search(Query("quick").with_scores()) + assert 0.23 == pytest.approx(res["results"][0]["score"], 0.05) + res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) + assert 0.14285714285714285 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.22471909420069797 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) + assert 0.0 == res["results"][0]["score"] + + @pytest.mark.redismod def test_get(client): client.ft().create_index((TextField("f1"), TextField("f2"))) @@ -2605,9 +2679,8 @@ def test_search_missing_fields(client): }, ) - with pytest.raises(redis.exceptions.ResponseError) as e: + with pytest.raises(redis.exceptions.ResponseError): client.ft().search(Query("ismissing(@title)").return_field("id").no_content()) - assert "to be defined with 'INDEXMISSING'" in e.value.args[0] res = client.ft().search( Query("ismissing(@features)").return_field("id").no_content() @@ -2813,6 +2886,12 @@ def test_search_query_with_different_dialects(client): assert res["total_results"] == 0 +@pytest.mark.redismod +@skip_if_server_version_lt("7.9.0") +def test_info_exposes_search_info(client): + assert len(client.info("search")) > 0 + + def _assert_search_result(client, result, expected_doc_ids): """ Make sure the result of a geo 
search is as expected, taking into account the RESP From 10f3a4b56b13b7487c95a84785931c756383ef69 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Tue, 28 Jan 2025 14:33:39 +0200 Subject: [PATCH 012/113] Adding unit text fixes to improve compatibility with MacOS. (#3486) * Adding unit text fixes to improve compatibility with MacOS. * Applying review comments * Unifying the exception msg validation pattern for both test_connection.py files --------- Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> --- docker-compose.yml | 3 ++- tests/test_asyncio/test_connection.py | 30 +++++++++++++-------------- tests/test_connection.py | 18 +++++++--------- tests/test_multiprocessing.py | 4 ++++ 4 files changed, 27 insertions(+), 28 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 7804f09c8a..60657d5653 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -103,7 +103,7 @@ services: - all redis-stack: - image: ${REDIS_STACK_IMAGE:-redis/redis-stack-server:edge} + image: ${REDIS_STACK_IMAGE:-redis/redis-stack-server:latest} container_name: redis-stack ports: - 6479:6379 @@ -112,6 +112,7 @@ services: profiles: - standalone - all-stack + - all redis-stack-graph: image: redis/redis-stack-server:6.2.6-v15 diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index e584fc6999..d4956f16e9 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -1,6 +1,7 @@ import asyncio import socket import types +from errno import ECONNREFUSED from unittest.mock import patch import pytest @@ -36,15 +37,16 @@ async def test_invalid_response(create_redis): fake_stream = MockStream(raw + b"\r\n") parser: _AsyncRESPBase = r.connection._parser - with mock.patch.object(parser, "_stream", fake_stream): - with pytest.raises(InvalidResponse) as cm: - await parser.read_response() + if isinstance(parser, _AsyncRESPBase): - assert str(cm.value) == f"Protocol Error: {raw!r}" + exp_err = f"Protocol Error: {raw!r}" else: - assert ( - str(cm.value) == f'Protocol error, got "{raw.decode()}" as reply type byte' - ) + exp_err = f'Protocol error, got "{raw.decode()}" as reply type byte' + + with mock.patch.object(parser, "_stream", fake_stream): + with pytest.raises(InvalidResponse, match=exp_err): + await parser.read_response() + await r.connection.disconnect() @@ -170,10 +172,9 @@ async def test_connect_timeout_error_without_retry(): conn._connect = mock.AsyncMock() conn._connect.side_effect = socket.timeout - with pytest.raises(TimeoutError) as e: + with pytest.raises(TimeoutError, match="Timeout connecting to server"): await conn.connect() assert conn._connect.call_count == 1 - assert str(e.value) == "Timeout connecting to server" @pytest.mark.onlynoncluster @@ -531,17 +532,14 @@ async def test_format_error_message(conn, error, expected_message): async def test_network_connection_failure(): - with pytest.raises(ConnectionError) as e: + exp_err = rf"^Error {ECONNREFUSED} connecting to 127.0.0.1:9999.(.+)$" + with pytest.raises(ConnectionError, match=exp_err): redis = Redis(host="127.0.0.1", port=9999) await redis.set("a", "b") - assert str(e.value).startswith("Error 111 connecting to 127.0.0.1:9999. Connect") async def test_unix_socket_connection_failure(): - with pytest.raises(ConnectionError) as e: + exp_err = "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." 
+ with pytest.raises(ConnectionError, match=exp_err): redis = Redis(unix_socket_path="unix:///tmp/a.sock") await redis.set("a", "b") - assert ( - str(e.value) - == "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." - ) diff --git a/tests/test_connection.py b/tests/test_connection.py index fbc23ae8c0..6c1498a329 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -4,6 +4,7 @@ import sys import threading import types +from errno import ECONNREFUSED from typing import Any from unittest import mock from unittest.mock import call, patch @@ -44,9 +45,8 @@ def test_invalid_response(r): raw = b"x" parser = r.connection._parser with mock.patch.object(parser._buffer, "readline", return_value=raw): - with pytest.raises(InvalidResponse) as cm: + with pytest.raises(InvalidResponse, match=f"Protocol Error: {raw!r}"): parser.read_response() - assert str(cm.value) == f"Protocol Error: {raw!r}" @skip_if_server_version_lt("4.0.0") @@ -141,10 +141,9 @@ def test_connect_timeout_error_without_retry(self): conn._connect = mock.Mock() conn._connect.side_effect = socket.timeout - with pytest.raises(TimeoutError) as e: + with pytest.raises(TimeoutError, match="Timeout connecting to server"): conn.connect() assert conn._connect.call_count == 1 - assert str(e.value) == "Timeout connecting to server" self.clear(conn) @@ -349,20 +348,17 @@ def test_format_error_message(conn, error, expected_message): def test_network_connection_failure(): - with pytest.raises(ConnectionError) as e: + exp_err = f"Error {ECONNREFUSED} connecting to localhost:9999. Connection refused." + with pytest.raises(ConnectionError, match=exp_err): redis = Redis(port=9999) redis.set("a", "b") - assert str(e.value) == "Error 111 connecting to localhost:9999. Connection refused." def test_unix_socket_connection_failure(): - with pytest.raises(ConnectionError) as e: + exp_err = "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." + with pytest.raises(ConnectionError, match=exp_err): redis = Redis(unix_socket_path="unix:///tmp/a.sock") redis.set("a", "b") - assert ( - str(e.value) - == "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." 
- ) class TestUnitConnectionPool: diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 5cda3190a6..116d20dab0 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -1,5 +1,6 @@ import contextlib import multiprocessing +import sys import pytest import redis @@ -8,6 +9,9 @@ from .conftest import _get_client +if sys.platform == "darwin": + multiprocessing.set_start_method("fork", force=True) + @contextlib.contextmanager def exit_callback(callback, *args): From c98c6eb3ea0f34d05373f0c825fcb63ec5b895bd Mon Sep 17 00:00:00 2001 From: Niklas Becker <48069565+niklasbec@users.noreply.github.com> Date: Tue, 28 Jan 2025 16:36:14 +0100 Subject: [PATCH 013/113] fix: update redis university url, the old link doesn't work (#3481) * fix: update redis university url * fix: add comment to changes --------- Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> --- CHANGES | 2 +- README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGES b/CHANGES index b955681b89..bd96846b6d 100644 --- a/CHANGES +++ b/CHANGES @@ -1,4 +1,4 @@ - * Add dynamic_startup_nodes parameter to async RedisCluster (#2472) + * Update URL in the readme linking to Redis University * Move doctests (doc code examples) to main branch * Update `ResponseT` type hint * Allow to control the minimum SSL version diff --git a/README.md b/README.md index 08eff587be..98ddee5b52 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ The Python interface to the Redis key-value store. ## How do I Redis? -[Learn for free at Redis University](https://redis.io/university/) +[Learn for free at Redis University](https://redis.io/learn/university) [Try the Redis Cloud](https://redis.io/try-free/) From 4a8da2aa89be36260235d346fadc041b213ebace Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Thu, 30 Jan 2025 10:25:33 +0200 Subject: [PATCH 014/113] Adding tests for modules ACL and modules config changes in 8.0 (#3489) * Adding tests for modules ACL and modules config changes in 8.0 * Applying review comments * Adding deprecation annotations for tf config commands --- redis/commands/search/commands.py | 8 ++ tests/test_asyncio/test_commands.py | 210 ++++++++++++++++++++++++++++ tests/test_commands.py | 209 +++++++++++++++++++++++++++ 3 files changed, 427 insertions(+) diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 2447959922..9d9ef42415 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -692,6 +692,10 @@ def dict_dump(self, name: str): cmd = [DICT_DUMP_CMD, name] return self.execute_command(*cmd) + @deprecated_function( + version="8.0.0", + reason="deprecated since Redis 8.0, call config_set from core module instead", + ) def config_set(self, option: str, value: str) -> bool: """Set runtime configuration option. @@ -706,6 +710,10 @@ def config_set(self, option: str, value: str) -> bool: raw = self.execute_command(*cmd) return raw == "OK" + @deprecated_function( + version="8.0.0", + reason="deprecated since Redis 8.0, call config_get from core module instead", + ) def config_get(self, option: str) -> str: """Get runtime configuration option value. 
diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index f6ed07fab5..9f154cb273 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -20,6 +20,9 @@ parse_info, ) from redis.client import EMPTY_RESPONSE, NEVER_DECODE +from redis.commands.json.path import Path +from redis.commands.search.field import TextField +from redis.commands.search.query import Query from tests.conftest import ( assert_resp_response, assert_resp_response_in, @@ -49,6 +52,12 @@ def factory(username): return r yield factory + try: + client_info = await r.client_info() + except exceptions.NoPermissionError: + client_info = {} + if "default" != client_info.get("user", ""): + await r.auth("", "default") for username in usernames: await r.acl_deluser(username) @@ -115,12 +124,65 @@ async def test_acl_cat_no_category(self, r: redis.Redis): assert isinstance(categories, list) assert "read" in categories or b"read" in categories + @pytest.mark.redismod + @skip_if_server_version_lt("7.9.0") + async def test_acl_cat_contain_modules_no_category(self, r: redis.Redis): + modules_list = [ + "search", + "bloom", + "json", + "cuckoo", + "timeseries", + "cms", + "topk", + "tdigest", + ] + categories = await r.acl_cat() + assert isinstance(categories, list) + for module_cat in modules_list: + assert module_cat in categories or module_cat.encode() in categories + @skip_if_server_version_lt(REDIS_6_VERSION) async def test_acl_cat_with_category(self, r: redis.Redis): commands = await r.acl_cat("read") assert isinstance(commands, list) assert "get" in commands or b"get" in commands + @pytest.mark.redismod + @skip_if_server_version_lt("7.9.0") + async def test_acl_modules_cat_with_category(self, r: redis.Redis): + search_commands = await r.acl_cat("search") + assert isinstance(search_commands, list) + assert "FT.SEARCH" in search_commands or b"FT.SEARCH" in search_commands + + bloom_commands = await r.acl_cat("bloom") + assert isinstance(bloom_commands, list) + assert "bf.add" in bloom_commands or b"bf.add" in bloom_commands + + json_commands = await r.acl_cat("json") + assert isinstance(json_commands, list) + assert "json.get" in json_commands or b"json.get" in json_commands + + cuckoo_commands = await r.acl_cat("cuckoo") + assert isinstance(cuckoo_commands, list) + assert "cf.insert" in cuckoo_commands or b"cf.insert" in cuckoo_commands + + cms_commands = await r.acl_cat("cms") + assert isinstance(cms_commands, list) + assert "cms.query" in cms_commands or b"cms.query" in cms_commands + + topk_commands = await r.acl_cat("topk") + assert isinstance(topk_commands, list) + assert "topk.list" in topk_commands or b"topk.list" in topk_commands + + tdigest_commands = await r.acl_cat("tdigest") + assert isinstance(tdigest_commands, list) + assert "tdigest.rank" in tdigest_commands or b"tdigest.rank" in tdigest_commands + + timeseries_commands = await r.acl_cat("timeseries") + assert isinstance(timeseries_commands, list) + assert "ts.range" in timeseries_commands or b"ts.range" in timeseries_commands + @skip_if_server_version_lt(REDIS_6_VERSION) async def test_acl_deluser(self, r_teardown): username = "redis-py-user" @@ -316,6 +378,116 @@ async def test_acl_whoami(self, r: redis.Redis): username = await r.acl_whoami() assert isinstance(username, (str, bytes)) + @pytest.mark.redismod + @skip_if_server_version_lt("7.9.0") + async def test_acl_modules_commands(self, r_teardown): + username = "redis-py-user" + password = "pass-for-test-user" + + r = r_teardown(username) + await 
r.flushdb() + + await r.ft().create_index((TextField("txt"),)) + await r.hset("doc1", mapping={"txt": "foo baz"}) + await r.hset("doc2", mapping={"txt": "foo bar"}) + + await r.acl_setuser( + username, + enabled=True, + reset=True, + passwords=[f"+{password}"], + categories=["-all"], + commands=[ + "+FT.SEARCH", + "-FT.DROPINDEX", + "+json.set", + "+json.get", + "-json.clear", + "+bf.reserve", + "-bf.info", + "+cf.reserve", + "+cms.initbydim", + "+topk.reserve", + "+tdigest.create", + "+ts.create", + "-ts.info", + ], + keys=["*"], + ) + + await r.auth(password, username) + + assert await r.ft().search(Query("foo ~bar")) + with pytest.raises(exceptions.NoPermissionError): + await r.ft().dropindex() + + await r.json().set("foo", Path.root_path(), "bar") + assert await r.json().get("foo") == "bar" + with pytest.raises(exceptions.NoPermissionError): + await r.json().clear("foo") + + assert await r.bf().create("bloom", 0.01, 1000) + assert await r.cf().create("cuckoo", 1000) + assert await r.cms().initbydim("cmsDim", 100, 5) + assert await r.topk().reserve("topk", 5, 100, 5, 0.9) + assert await r.tdigest().create("to-tDigest", 10) + with pytest.raises(exceptions.NoPermissionError): + await r.bf().info("bloom") + + assert await r.ts().create(1, labels={"Redis": "Labs"}) + with pytest.raises(exceptions.NoPermissionError): + await r.ts().info(1) + + @pytest.mark.redismod + @skip_if_server_version_lt("7.9.0") + async def test_acl_modules_category_commands(self, r_teardown): + username = "redis-py-user" + password = "pass-for-test-user" + + r = r_teardown(username) + await r.flushdb() + + # validate modules categories acl config + await r.acl_setuser( + username, + enabled=True, + reset=True, + passwords=[f"+{password}"], + categories=[ + "-all", + "+@search", + "+@json", + "+@bloom", + "+@cuckoo", + "+@topk", + "+@cms", + "+@timeseries", + "+@tdigest", + ], + keys=["*"], + ) + await r.ft().create_index((TextField("txt"),)) + await r.hset("doc1", mapping={"txt": "foo baz"}) + await r.hset("doc2", mapping={"txt": "foo bar"}) + + await r.auth(password, username) + + assert await r.ft().search(Query("foo ~bar")) + assert await r.ft().dropindex() + + assert await r.json().set("foo", Path.root_path(), "bar") + assert await r.json().get("foo") == "bar" + + assert await r.bf().create("bloom", 0.01, 1000) + assert await r.bf().info("bloom") + assert await r.cf().create("cuckoo", 1000) + assert await r.cms().initbydim("cmsDim", 100, 5) + assert await r.topk().reserve("topk", 5, 100, 5, 0.9) + assert await r.tdigest().create("to-tDigest", 10) + + assert await r.ts().create(1, labels={"Redis": "Labs"}) + assert await r.ts().info(1) + @pytest.mark.onlynoncluster async def test_client_list(self, r: redis.Redis): clients = await r.client_list() @@ -512,6 +684,44 @@ async def test_config_set(self, r: redis.Redis): assert await r.config_set("timeout", 0) assert (await r.config_get())["timeout"] == "0" + @pytest.mark.redismod + @skip_if_server_version_lt("7.9.0") + async def test_config_get_for_modules(self, r: redis.Redis): + search_module_configs = await r.config_get("search-*") + assert "search-timeout" in search_module_configs + + ts_module_configs = await r.config_get("ts-*") + assert "ts-retention-policy" in ts_module_configs + + bf_module_configs = await r.config_get("bf-*") + assert "bf-error-rate" in bf_module_configs + + cf_module_configs = await r.config_get("cf-*") + assert "cf-initial-size" in cf_module_configs + + @pytest.mark.redismod + @skip_if_server_version_lt("7.9.0") + async def 
test_config_set_for_search_module(self, r: redis.Redis): + config = await r.config_get("*") + initial_default_search_dialect = config["search-default-dialect"] + try: + default_dialect_new = "3" + assert await r.config_set("search-default-dialect", default_dialect_new) + assert (await r.config_get("*"))[ + "search-default-dialect" + ] == default_dialect_new + assert ( + (await r.ft().config_get("*"))[b"DEFAULT_DIALECT"] + ).decode() == default_dialect_new + except AssertionError as ex: + raise ex + finally: + assert await r.config_set( + "search-default-dialect", initial_default_search_dialect + ) + with pytest.raises(exceptions.ResponseError): + await r.config_set("search-max-doctablesize", 2000000) + @pytest.mark.onlynoncluster async def test_dbsize(self, r: redis.Redis): await r.set("a", "foo") diff --git a/tests/test_commands.py b/tests/test_commands.py index f83fe76aa9..24c320a3f3 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -18,6 +18,9 @@ parse_info, ) from redis.client import EMPTY_RESPONSE, NEVER_DECODE +from redis.commands.json.path import Path +from redis.commands.search.field import TextField +from redis.commands.search.query import Query from .conftest import ( _get_client, @@ -144,12 +147,65 @@ def test_acl_cat_no_category(self, r): assert isinstance(categories, list) assert "read" in categories or b"read" in categories + @pytest.mark.redismod + @skip_if_server_version_lt("7.9.0") + def test_acl_cat_contain_modules_no_category(self, r): + modules_list = [ + "search", + "bloom", + "json", + "cuckoo", + "timeseries", + "cms", + "topk", + "tdigest", + ] + categories = r.acl_cat() + assert isinstance(categories, list) + for module_cat in modules_list: + assert module_cat in categories or module_cat.encode() in categories + @skip_if_server_version_lt("6.0.0") def test_acl_cat_with_category(self, r): commands = r.acl_cat("read") assert isinstance(commands, list) assert "get" in commands or b"get" in commands + @pytest.mark.redismod + @skip_if_server_version_lt("7.9.0") + def test_acl_modules_cat_with_category(self, r): + search_commands = r.acl_cat("search") + assert isinstance(search_commands, list) + assert "FT.SEARCH" in search_commands or b"FT.SEARCH" in search_commands + + bloom_commands = r.acl_cat("bloom") + assert isinstance(bloom_commands, list) + assert "bf.add" in bloom_commands or b"bf.add" in bloom_commands + + json_commands = r.acl_cat("json") + assert isinstance(json_commands, list) + assert "json.get" in json_commands or b"json.get" in json_commands + + cuckoo_commands = r.acl_cat("cuckoo") + assert isinstance(cuckoo_commands, list) + assert "cf.insert" in cuckoo_commands or b"cf.insert" in cuckoo_commands + + cms_commands = r.acl_cat("cms") + assert isinstance(cms_commands, list) + assert "cms.query" in cms_commands or b"cms.query" in cms_commands + + topk_commands = r.acl_cat("topk") + assert isinstance(topk_commands, list) + assert "topk.list" in topk_commands or b"topk.list" in topk_commands + + tdigest_commands = r.acl_cat("tdigest") + assert isinstance(tdigest_commands, list) + assert "tdigest.rank" in tdigest_commands or b"tdigest.rank" in tdigest_commands + + timeseries_commands = r.acl_cat("timeseries") + assert isinstance(timeseries_commands, list) + assert "ts.range" in timeseries_commands or b"ts.range" in timeseries_commands + @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() def test_acl_dryrun(self, r, request): @@ -458,6 +514,123 @@ def test_acl_whoami(self, r): username = r.acl_whoami() assert isinstance(username, 
(str, bytes)) + @pytest.mark.redismod + @skip_if_server_version_lt("7.9.0") + def test_acl_modules_commands(self, r, request): + default_username = "default" + username = "redis-py-user" + password = "pass-for-test-user" + + def teardown(): + r.auth("", default_username) + r.acl_deluser(username) + + request.addfinalizer(teardown) + + r.ft().create_index((TextField("txt"),)) + r.hset("doc1", mapping={"txt": "foo baz"}) + r.hset("doc2", mapping={"txt": "foo bar"}) + + r.acl_setuser( + username, + enabled=True, + reset=True, + passwords=[f"+{password}"], + categories=["-all"], + commands=[ + "+FT.SEARCH", + "-FT.DROPINDEX", + "+json.set", + "+json.get", + "-json.clear", + "+bf.reserve", + "-bf.info", + "+cf.reserve", + "+cms.initbydim", + "+topk.reserve", + "+tdigest.create", + "+ts.create", + "-ts.info", + ], + keys=["*"], + ) + r.auth(password, username) + + assert r.ft().search(Query("foo ~bar")) + with pytest.raises(exceptions.NoPermissionError): + r.ft().dropindex() + + r.json().set("foo", Path.root_path(), "bar") + assert r.json().get("foo") == "bar" + with pytest.raises(exceptions.NoPermissionError): + r.json().clear("foo") + + assert r.bf().create("bloom", 0.01, 1000) + assert r.cf().create("cuckoo", 1000) + assert r.cms().initbydim("cmsDim", 100, 5) + assert r.topk().reserve("topk", 5, 100, 5, 0.9) + assert r.tdigest().create("to-tDigest", 10) + with pytest.raises(exceptions.NoPermissionError): + r.bf().info("bloom") + + assert r.ts().create(1, labels={"Redis": "Labs"}) + with pytest.raises(exceptions.NoPermissionError): + r.ts().info(1) + + @pytest.mark.redismod + @skip_if_server_version_lt("7.9.0") + def test_acl_modules_category_commands(self, r, request): + default_username = "default" + username = "redis-py-user" + password = "pass-for-test-user" + + def teardown(): + r.auth("", default_username) + r.acl_deluser(username) + + request.addfinalizer(teardown) + + # validate modules categories acl config + r.acl_setuser( + username, + enabled=True, + reset=True, + passwords=[f"+{password}"], + categories=[ + "-all", + "+@search", + "+@json", + "+@bloom", + "+@cuckoo", + "+@topk", + "+@cms", + "+@timeseries", + "+@tdigest", + ], + keys=["*"], + ) + r.ft().create_index((TextField("txt"),)) + r.hset("doc1", mapping={"txt": "foo baz"}) + r.hset("doc2", mapping={"txt": "foo bar"}) + + r.auth(password, username) + + assert r.ft().search(Query("foo ~bar")) + assert r.ft().dropindex() + + assert r.json().set("foo", Path.root_path(), "bar") + assert r.json().get("foo") == "bar" + + assert r.bf().create("bloom", 0.01, 1000) + assert r.bf().info("bloom") + assert r.cf().create("cuckoo", 1000) + assert r.cms().initbydim("cmsDim", 100, 5) + assert r.topk().reserve("topk", 5, 100, 5, 0.9) + assert r.tdigest().create("to-tDigest", 10) + + assert r.ts().create(1, labels={"Redis": "Labs"}) + assert r.ts().info(1) + @pytest.mark.onlynoncluster def test_client_list(self, r): clients = r.client_list() @@ -824,6 +997,42 @@ def test_config_set_multi_params(self, r: redis.Redis): assert r.config_get()["timeout"] == "0" assert r.config_get()["maxmemory"] == "0" + @pytest.mark.redismod + @skip_if_server_version_lt("7.9.0") + def test_config_get_for_modules(self, r: redis.Redis): + search_module_configs = r.config_get("search-*") + assert "search-timeout" in search_module_configs + + ts_module_configs = r.config_get("ts-*") + assert "ts-retention-policy" in ts_module_configs + + bf_module_configs = r.config_get("bf-*") + assert "bf-error-rate" in bf_module_configs + + cf_module_configs = 
r.config_get("cf-*") + assert "cf-initial-size" in cf_module_configs + + @pytest.mark.redismod + @skip_if_server_version_lt("7.9.0") + def test_config_set_for_search_module(self, r: redis.Redis): + initial_default_search_dialect = r.config_get("*")["search-default-dialect"] + try: + default_dialect_new = "3" + assert r.config_set("search-default-dialect", default_dialect_new) + assert r.config_get("*")["search-default-dialect"] == default_dialect_new + assert ( + r.ft().config_get("*")[b"DEFAULT_DIALECT"] + ).decode() == default_dialect_new + except AssertionError as ex: + raise ex + finally: + assert r.config_set( + "search-default-dialect", initial_default_search_dialect + ) + + with pytest.raises(exceptions.ResponseError): + r.config_set("search-max-doctablesize", 2000000) + @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() def test_failover(self, r): From f46355713038c93096aa64f3ddff825785feaf8c Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Tue, 4 Feb 2025 08:10:06 +0000 Subject: [PATCH 015/113] Add return type to `close` functions (#3496) --- redis/client.py | 3 +-- redis/cluster.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/redis/client.py b/redis/client.py index a7c1364a10..4fa410c65e 100755 --- a/redis/client.py +++ b/redis/client.py @@ -550,7 +550,7 @@ def __exit__(self, exc_type, exc_value, traceback): def __del__(self): self.close() - def close(self): + def close(self) -> None: # In case a connection property does not yet exist # (due to a crash earlier in the Redis() constructor), return # immediately as there is nothing to clean-up. @@ -1551,7 +1551,6 @@ def _disconnect_raise_reset( conn.retry_on_error is None or isinstance(error, tuple(conn.retry_on_error)) is False ): - self.reset() raise error diff --git a/redis/cluster.py b/redis/cluster.py index 8718493759..2fff761f95 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1227,7 +1227,7 @@ def _execute_command(self, target_node, *args, **kwargs): raise ClusterError("TTL exhausted.") - def close(self): + def close(self) -> None: try: with self._lock: if self.nodes_manager: @@ -1669,7 +1669,7 @@ def initialize(self): # If initialize was called after a MovedError, clear it self._moved_exception = None - def close(self): + def close(self) -> None: self.default_node = None for node in self.nodes_cache.values(): if node.redis_connection: From 2f0fb9ab02c4c393dcbebfb4538e5ec7866cf21a Mon Sep 17 00:00:00 2001 From: David Dougherty Date: Tue, 4 Feb 2025 02:41:16 -0800 Subject: [PATCH 016/113] Update Python imports in doctests (index_definition => indexDefinition) (#3490) --- doctests/query_agg.py | 2 +- doctests/query_combined.py | 2 +- doctests/query_em.py | 2 +- doctests/query_ft.py | 2 +- doctests/query_geo.py | 2 +- doctests/query_range.py | 2 +- doctests/search_quickstart.py | 2 +- doctests/search_vss.py | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/doctests/query_agg.py b/doctests/query_agg.py index 4d81ddbcda..4fa8f14b84 100644 --- a/doctests/query_agg.py +++ b/doctests/query_agg.py @@ -6,7 +6,7 @@ from redis.commands.search import Search from redis.commands.search.aggregation import AggregateRequest from redis.commands.search.field import NumericField, TagField -from redis.commands.search.index_definition import IndexDefinition, IndexType +from redis.commands.search.indexDefinition import IndexDefinition, IndexType import redis.commands.search.reducers as reducers r = redis.Redis(decode_responses=True) diff --git a/doctests/query_combined.py 
b/doctests/query_combined.py index e6dd5a2cb5..a17f19417c 100644 --- a/doctests/query_combined.py +++ b/doctests/query_combined.py @@ -6,7 +6,7 @@ import warnings from redis.commands.json.path import Path from redis.commands.search.field import NumericField, TagField, TextField, VectorField -from redis.commands.search.index_definition import IndexDefinition, IndexType +from redis.commands.search.indexDefinition import IndexDefinition, IndexType from redis.commands.search.query import Query from sentence_transformers import SentenceTransformer diff --git a/doctests/query_em.py b/doctests/query_em.py index 91cc5ae940..a00ff11150 100644 --- a/doctests/query_em.py +++ b/doctests/query_em.py @@ -4,7 +4,7 @@ import redis from redis.commands.json.path import Path from redis.commands.search.field import TextField, NumericField, TagField -from redis.commands.search.index_definition import IndexDefinition, IndexType +from redis.commands.search.indexDefinition import IndexDefinition, IndexType from redis.commands.search.query import NumericFilter, Query r = redis.Redis(decode_responses=True) diff --git a/doctests/query_ft.py b/doctests/query_ft.py index 6272cdab25..182a5b2bd3 100644 --- a/doctests/query_ft.py +++ b/doctests/query_ft.py @@ -5,7 +5,7 @@ import redis from redis.commands.json.path import Path from redis.commands.search.field import TextField, NumericField, TagField -from redis.commands.search.index_definition import IndexDefinition, IndexType +from redis.commands.search.indexDefinition import IndexDefinition, IndexType from redis.commands.search.query import NumericFilter, Query r = redis.Redis(decode_responses=True) diff --git a/doctests/query_geo.py b/doctests/query_geo.py index ed8c9a5f99..dcb7db6ee7 100644 --- a/doctests/query_geo.py +++ b/doctests/query_geo.py @@ -5,7 +5,7 @@ import redis from redis.commands.json.path import Path from redis.commands.search.field import GeoField, GeoShapeField -from redis.commands.search.index_definition import IndexDefinition, IndexType +from redis.commands.search.indexDefinition import IndexDefinition, IndexType from redis.commands.search.query import Query r = redis.Redis(decode_responses=True) diff --git a/doctests/query_range.py b/doctests/query_range.py index 674afc492a..4ef957acfb 100644 --- a/doctests/query_range.py +++ b/doctests/query_range.py @@ -5,7 +5,7 @@ import redis from redis.commands.json.path import Path from redis.commands.search.field import TextField, NumericField, TagField -from redis.commands.search.index_definition import IndexDefinition, IndexType +from redis.commands.search.indexDefinition import IndexDefinition, IndexType from redis.commands.search.query import NumericFilter, Query r = redis.Redis(decode_responses=True) diff --git a/doctests/search_quickstart.py b/doctests/search_quickstart.py index cde4caa84a..e190393b16 100644 --- a/doctests/search_quickstart.py +++ b/doctests/search_quickstart.py @@ -10,7 +10,7 @@ import redis.commands.search.reducers as reducers from redis.commands.json.path import Path from redis.commands.search.field import NumericField, TagField, TextField -from redis.commands.search.index_definition import IndexDefinition, IndexType +from redis.commands.search.indexDefinition import IndexDefinition, IndexType from redis.commands.search.query import Query # HIDE_END diff --git a/doctests/search_vss.py b/doctests/search_vss.py index a1132971db..8b4884727a 100644 --- a/doctests/search_vss.py +++ b/doctests/search_vss.py @@ -20,7 +20,7 @@ TextField, VectorField, ) -from 
redis.commands.search.index_definition import IndexDefinition, IndexType +from redis.commands.search.indexDefinition import IndexDefinition, IndexType from redis.commands.search.query import Query from sentence_transformers import SentenceTransformer From 9cadd6dd654b45cb49c3e4405e6a054885a8d9ee Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Tue, 4 Feb 2025 14:19:46 +0000 Subject: [PATCH 017/113] Add types to ConnectionPool.from_url (#3495) Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> --- redis/connection.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index d905c6481b..d47f46590b 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -9,7 +9,7 @@ from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union from urllib.parse import parse_qs, unquote, urlparse from redis.cache import ( @@ -1263,6 +1263,9 @@ def parse_url(url): return kwargs +_CP = TypeVar("_CP", bound="ConnectionPool") + + class ConnectionPool: """ Create a connection pool. ``If max_connections`` is set, then this @@ -1278,7 +1281,7 @@ class ConnectionPool: """ @classmethod - def from_url(cls, url, **kwargs): + def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: """ Return a connection pool configured from the given URL. From c07b599bf175bda7f81bbfc027039ad85f833d41 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Thu, 6 Feb 2025 08:18:17 +0000 Subject: [PATCH 018/113] Add types to execute method of pipelines (#3494) Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> --- redis/asyncio/client.py | 2 +- redis/client.py | 2 +- redis/cluster.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 9478d539d7..7c17938714 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -1554,7 +1554,7 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception): await self.reset() raise - async def execute(self, raise_on_error: bool = True): + async def execute(self, raise_on_error: bool = True) -> List[Any]: """Execute all the commands in the current pipeline""" stack = self.command_stack if not stack and not self.watching: diff --git a/redis/client.py b/redis/client.py index 4fa410c65e..88bc6bf475 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1554,7 +1554,7 @@ def _disconnect_raise_reset( self.reset() raise error - def execute(self, raise_on_error=True): + def execute(self, raise_on_error: bool = True) -> List[Any]: """Execute all the commands in the current pipeline""" stack = self.command_stack if not stack and not self.watching: diff --git a/redis/cluster.py b/redis/cluster.py index 2fff761f95..6c6cfbf114 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2050,7 +2050,7 @@ def annotate_exception(self, exception, number, command): ) exception.args = (msg,) + exception.args[1:] - def execute(self, raise_on_error=True): + def execute(self, raise_on_error: bool = True) -> List[Any]: """ Execute all the commands in the current pipeline """ From 8091bdbb5b232a277ed7f18d3b331697c5ed2c5f Mon Sep 17 00:00:00 2001 From: andy-stark-redis <164213578+andy-stark-redis@users.noreply.github.com> Date: Thu, 6 Feb 2025 12:31:00 +0000 Subject: [PATCH 019/113] DOC-4796 fixed capped lists example 
(#3493) Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> --- doctests/dt_list.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doctests/dt_list.py b/doctests/dt_list.py index be8a4b8562..111da8eb08 100644 --- a/doctests/dt_list.py +++ b/doctests/dt_list.py @@ -165,20 +165,20 @@ # REMOVE_END # STEP_START ltrim -res27 = r.lpush("bikes:repairs", "bike:1", "bike:2", "bike:3", "bike:4", "bike:5") +res27 = r.rpush("bikes:repairs", "bike:1", "bike:2", "bike:3", "bike:4", "bike:5") print(res27) # >>> 5 res28 = r.ltrim("bikes:repairs", 0, 2) print(res28) # >>> True res29 = r.lrange("bikes:repairs", 0, -1) -print(res29) # >>> ['bike:5', 'bike:4', 'bike:3'] +print(res29) # >>> ['bike:1', 'bike:2', 'bike:3'] # STEP_END # REMOVE_START assert res27 == 5 assert res28 is True -assert res29 == ["bike:5", "bike:4", "bike:3"] +assert res29 == ["bike:1", "bike:2", "bike:3"] r.delete("bikes:repairs") # REMOVE_END From 8604a505ef2b6f2ec0c653aec293639b57124e13 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Fri, 7 Feb 2025 13:12:24 +0200 Subject: [PATCH 020/113] Adding deprecation messages for the exposed in search module commands: FT.INFO, just for async client: FT.CONFIG GET and FT.CONFIG SET (#3499) --- redis/commands/search/commands.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 9d9ef42415..2158c01ba9 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -442,6 +442,10 @@ def get(self, *ids): return self.execute_command(MGET_CMD, self.index_name, *ids) + @deprecated_function( + version="8.0.0", + reason="deprecated since Redis 8.0, call info from core module instead", + ) def info(self): """ Get info an stats about the the current index, including the number of @@ -912,6 +916,10 @@ def syndump(self): class AsyncSearchCommands(SearchCommands): + @deprecated_function( + version="8.0.0", + reason="deprecated since Redis 8.0, call info from core module instead", + ) async def info(self): """ Get info an stats about the the current index, including the number of @@ -1015,6 +1023,10 @@ async def spellcheck(self, query, distance=None, include=None, exclude=None): return self._parse_results(SPELLCHECK_CMD, res) + @deprecated_function( + version="8.0.0", + reason="deprecated since Redis 8.0, call config_set from core module instead", + ) async def config_set(self, option: str, value: str) -> bool: """Set runtime configuration option. @@ -1029,6 +1041,10 @@ async def config_set(self, option: str, value: str) -> bool: raw = await self.execute_command(*cmd) return raw == "OK" + @deprecated_function( + version="8.0.0", + reason="deprecated since Redis 8.0, call config_get from core module instead", + ) async def config_get(self, option: str) -> str: """Get runtime configuration option value. 
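The @deprecated_function markers added above only emit a warning; the async search commands keep working. A minimal sketch of what a caller sees after this change (assumes a reachable Redis server with the search module loaded; the index name is irrelevant to FT.CONFIG and appears only because ft() requires one):

import asyncio
import warnings

import redis.asyncio as redis


async def main():
    r = redis.Redis(decode_responses=True)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Both calls are deprecated on the async search client since Redis 8.0;
        # the core client's config_set()/config_get() are the suggested replacements.
        await r.ft("idx").config_set("TIMEOUT", "500")
        print(await r.ft("idx").config_get("TIMEOUT"))
    print([str(w.message) for w in caught])  # the DeprecationWarning texts
    await r.aclose()


asyncio.run(main())
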
From 996c48bef1e3e6012f220f1d11d24892a3a168f5 Mon Sep 17 00:00:00 2001 From: Artur Mostowski Date: Mon, 10 Feb 2025 15:56:11 +0100 Subject: [PATCH 021/113] typing for client __init__ (#3357) * typing for client __init__ * typing with string literals * retry_on_error more specific typing * retry typing * fix lint --------- Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> --- redis/client.py | 99 ++++++++++++++++++++++++++++--------------------- 1 file changed, 57 insertions(+), 42 deletions(-) diff --git a/redis/client.py b/redis/client.py index 88bc6bf475..5a9f4fafb5 100755 --- a/redis/client.py +++ b/redis/client.py @@ -4,7 +4,17 @@ import time import warnings from itertools import chain -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Type, + Union, +) from redis._parsers.encoders import Encoder from redis._parsers.helpers import ( @@ -53,6 +63,11 @@ str_if_bytes, ) +if TYPE_CHECKING: + import ssl + + import OpenSSL + SYM_EMPTY = b"" EMPTY_RESPONSE = "EMPTY_RESPONSE" @@ -175,47 +190,47 @@ def from_pool( def __init__( self, - host="localhost", - port=6379, - db=0, - password=None, - socket_timeout=None, - socket_connect_timeout=None, - socket_keepalive=None, - socket_keepalive_options=None, - connection_pool=None, - unix_socket_path=None, - encoding="utf-8", - encoding_errors="strict", - charset=None, - errors=None, - decode_responses=False, - retry_on_timeout=False, - retry_on_error=None, - ssl=False, - ssl_keyfile=None, - ssl_certfile=None, - ssl_cert_reqs="required", - ssl_ca_certs=None, - ssl_ca_path=None, - ssl_ca_data=None, - ssl_check_hostname=False, - ssl_password=None, - ssl_validate_ocsp=False, - ssl_validate_ocsp_stapled=False, - ssl_ocsp_context=None, - ssl_ocsp_expected_cert=None, - ssl_min_version=None, - ssl_ciphers=None, - max_connections=None, - single_connection_client=False, - health_check_interval=0, - client_name=None, - lib_name="redis-py", - lib_version=get_lib_version(), - username=None, - retry=None, - redis_connect_func=None, + host: str = "localhost", + port: int = 6379, + db: int = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: Optional[bool] = None, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + connection_pool: Optional[ConnectionPool] = None, + unix_socket_path: Optional[str] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + charset: Optional[str] = None, + errors: Optional[str] = None, + decode_responses: bool = False, + retry_on_timeout: bool = False, + retry_on_error: Optional[List[Type[Exception]]] = None, + ssl: bool = False, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_ca_certs: Optional[str] = None, + ssl_ca_path: Optional[str] = None, + ssl_ca_data: Optional[str] = None, + ssl_check_hostname: bool = False, + ssl_password: Optional[str] = None, + ssl_validate_ocsp: bool = False, + ssl_validate_ocsp_stapled: bool = False, + ssl_ocsp_context: Optional["OpenSSL.SSL.Context"] = None, + ssl_ocsp_expected_cert: Optional[str] = None, + ssl_min_version: Optional["ssl.TLSVersion"] = None, + ssl_ciphers: Optional[str] = None, + max_connections: Optional[int] = None, + single_connection_client: bool = False, + health_check_interval: int = 0, + client_name: Optional[str] = None, + 
lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), + username: Optional[str] = None, + retry: Optional[Retry] = None, + redis_connect_func: Optional[Callable[[], None]] = None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, cache: Optional[CacheInterface] = None, From fd1e205bbe0cf07ee469a48501148967dd5d3e23 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Mon, 10 Feb 2025 19:16:19 +0200 Subject: [PATCH 022/113] Since commands info and ft.info do not return redundant information - ft.info will not be deprecated in current release. (#3500) --- redis/commands/search/commands.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 2158c01ba9..e5e78578be 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -442,10 +442,6 @@ def get(self, *ids): return self.execute_command(MGET_CMD, self.index_name, *ids) - @deprecated_function( - version="8.0.0", - reason="deprecated since Redis 8.0, call info from core module instead", - ) def info(self): """ Get info an stats about the the current index, including the number of @@ -916,10 +912,6 @@ def syndump(self): class AsyncSearchCommands(SearchCommands): - @deprecated_function( - version="8.0.0", - reason="deprecated since Redis 8.0, call info from core module instead", - ) async def info(self): """ Get info an stats about the the current index, including the number of From f614abf2148e1b0c730fe6d3c9fbbcd9105eea91 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Tue, 11 Feb 2025 13:15:14 +0200 Subject: [PATCH 023/113] test: Updated CredentialProvider test infrastructure (#3502) * test: Updated CredentialProvider test infrastructure * Added linter exclusion * Updated dev dependency * Codestyle fixes * Updated async test infra * Added missing constant --- dev_requirements.txt | 2 +- tests/conftest.py | 99 +++++++++++++++++++++------------- tests/test_asyncio/conftest.py | 97 ++++++++++++++++++++------------- 3 files changed, 124 insertions(+), 74 deletions(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index be74470ec2..728536d6fb 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -16,4 +16,4 @@ uvloop vulture>=2.3.0 wheel>=0.30.0 numpy>=1.24.0 -redis-entraid==0.1.0b1 +redis-entraid==0.3.0b1 diff --git a/tests/conftest.py b/tests/conftest.py index a900cea8bf..fc732c0d72 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import time from datetime import datetime, timezone from enum import Enum -from typing import Callable, TypeVar +from typing import Callable, TypeVar, Union from unittest import mock from unittest.mock import Mock from urllib.parse import urlparse @@ -17,6 +17,7 @@ from redis import Sentinel from redis.auth.idp import IdentityProviderInterface from redis.auth.token import JWToken +from redis.auth.token_manager import RetryPolicy, TokenManagerConfig from redis.backoff import NoBackoff from redis.cache import ( CacheConfig, @@ -29,12 +30,21 @@ from redis.credentials import CredentialProvider from redis.exceptions import RedisClusterException from redis.retry import Retry -from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig +from redis_entraid.cred_provider import ( + DEFAULT_DELAY_IN_MS, + DEFAULT_EXPIRATION_REFRESH_RATIO, + DEFAULT_LOWER_REFRESH_BOUND_MILLIS, + DEFAULT_MAX_ATTEMPTS, + 
DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, + EntraIdCredentialsProvider, +) from redis_entraid.identity_provider import ( ManagedIdentityIdType, + ManagedIdentityProviderConfig, ManagedIdentityType, - create_provider_from_managed_identity, - create_provider_from_service_principal, + ServicePrincipalIdentityProviderConfig, + _create_provider_from_managed_identity, + _create_provider_from_service_principal, ) from tests.ssl_utils import get_tls_certificates @@ -623,17 +633,33 @@ def identity_provider(request) -> IdentityProviderInterface: return mock_identity_provider() auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + config = get_identity_provider_config(request=request) if auth_type == "MANAGED_IDENTITY": - return _get_managed_identity_provider(request) + return _create_provider_from_managed_identity(config) + + return _create_provider_from_service_principal(config) - return _get_service_principal_provider(request) +def get_identity_provider_config( + request, +) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} + + auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) -def _get_managed_identity_provider(request): - authority = os.getenv("AZURE_AUTHORITY") + if auth_type == AuthType.MANAGED_IDENTITY: + return _get_managed_identity_provider_config(request) + + return _get_service_principal_provider_config(request) + + +def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: resource = os.getenv("AZURE_RESOURCE") - id_value = os.getenv("AZURE_ID_VALUE", None) + id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -641,23 +667,24 @@ def _get_managed_identity_provider(request): kwargs = {} identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) - id_type = kwargs.pop("id_type", ManagedIdentityIdType.CLIENT_ID) + id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) - return create_provider_from_managed_identity( + return ManagedIdentityProviderConfig( identity_type=identity_type, resource=resource, id_type=id_type, id_value=id_value, - authority=authority, - **kwargs, + kwargs=kwargs, ) -def _get_service_principal_provider(request): +def _get_service_principal_provider_config( + request, +) -> ServicePrincipalIdentityProviderConfig: client_id = os.getenv("AZURE_CLIENT_ID") client_credential = os.getenv("AZURE_CLIENT_SECRET") - authority = os.getenv("AZURE_AUTHORITY") - scopes = os.getenv("AZURE_REDIS_SCOPES", []) + tenant_id = os.getenv("AZURE_TENANT_ID") + scopes = os.getenv("AZURE_REDIS_SCOPES", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -671,14 +698,14 @@ def _get_service_principal_provider(request): if isinstance(scopes, str): scopes = scopes.split(",") - return create_provider_from_service_principal( + return ServicePrincipalIdentityProviderConfig( client_id=client_id, client_credential=client_credential, scopes=scopes, timeout=timeout, token_kwargs=token_kwargs, - authority=authority, - **kwargs, + tenant_id=tenant_id, + app_kwargs=kwargs, ) @@ -690,31 +717,29 @@ def get_credential_provider(request) -> CredentialProvider: return cred_provider_class(**cred_provider_kwargs) idp = identity_provider(request) - initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0) - block_for_initial = 
cred_provider_kwargs.get("block_for_initial", False) expiration_refresh_ratio = cred_provider_kwargs.get( - "expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO + "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO ) lower_refresh_bound_millis = cred_provider_kwargs.get( - "lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS - ) - max_attempts = cred_provider_kwargs.get( - "max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS + "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS ) - delay_in_ms = cred_provider_kwargs.get( - "delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS + max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS) + delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS) + + token_mgr_config = TokenManagerConfig( + expiration_refresh_ratio=expiration_refresh_ratio, + lower_refresh_bound_millis=lower_refresh_bound_millis, + token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa + retry_policy=RetryPolicy( + max_attempts=max_attempts, + delay_in_ms=delay_in_ms, + ), ) - auth_config = TokenAuthConfig(idp) - auth_config.expiration_refresh_ratio = expiration_refresh_ratio - auth_config.lower_refresh_bound_millis = lower_refresh_bound_millis - auth_config.max_attempts = max_attempts - auth_config.delay_in_ms = delay_in_ms - return EntraIdCredentialsProvider( - config=auth_config, - initial_delay_in_ms=initial_delay_in_ms, - block_for_initial=block_for_initial, + identity_provider=idp, + token_manager_config=token_mgr_config, + initial_delay_in_ms=delay_in_ms, ) diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 8833426af1..fb6c51140e 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -17,14 +17,24 @@ from redis.asyncio.retry import Retry from redis.auth.idp import IdentityProviderInterface from redis.auth.token import JWToken +from redis.auth.token_manager import RetryPolicy, TokenManagerConfig from redis.backoff import NoBackoff from redis.credentials import CredentialProvider -from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig +from redis_entraid.cred_provider import ( + DEFAULT_DELAY_IN_MS, + DEFAULT_EXPIRATION_REFRESH_RATIO, + DEFAULT_LOWER_REFRESH_BOUND_MILLIS, + DEFAULT_MAX_ATTEMPTS, + DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, + EntraIdCredentialsProvider, +) from redis_entraid.identity_provider import ( ManagedIdentityIdType, + ManagedIdentityProviderConfig, ManagedIdentityType, - create_provider_from_managed_identity, - create_provider_from_service_principal, + ServicePrincipalIdentityProviderConfig, + _create_provider_from_managed_identity, + _create_provider_from_service_principal, ) from tests.conftest import REDIS_INFO @@ -255,17 +265,33 @@ def identity_provider(request) -> IdentityProviderInterface: return mock_identity_provider() auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + config = get_identity_provider_config(request=request) if auth_type == "MANAGED_IDENTITY": - return _get_managed_identity_provider(request) + return _create_provider_from_managed_identity(config) + + return _create_provider_from_service_principal(config) + + +def get_identity_provider_config( + request, +) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} - return 
_get_service_principal_provider(request) + auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + + if auth_type == AuthType.MANAGED_IDENTITY: + return _get_managed_identity_provider_config(request) + return _get_service_principal_provider_config(request) -def _get_managed_identity_provider(request): - authority = os.getenv("AZURE_AUTHORITY") + +def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: resource = os.getenv("AZURE_RESOURCE") - id_value = os.getenv("AZURE_ID_VALUE", None) + id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -273,23 +299,24 @@ def _get_managed_identity_provider(request): kwargs = {} identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) - id_type = kwargs.pop("id_type", ManagedIdentityIdType.CLIENT_ID) + id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) - return create_provider_from_managed_identity( + return ManagedIdentityProviderConfig( identity_type=identity_type, resource=resource, id_type=id_type, id_value=id_value, - authority=authority, - **kwargs, + kwargs=kwargs, ) -def _get_service_principal_provider(request): +def _get_service_principal_provider_config( + request, +) -> ServicePrincipalIdentityProviderConfig: client_id = os.getenv("AZURE_CLIENT_ID") client_credential = os.getenv("AZURE_CLIENT_SECRET") - authority = os.getenv("AZURE_AUTHORITY") - scopes = os.getenv("AZURE_REDIS_SCOPES", []) + tenant_id = os.getenv("AZURE_TENANT_ID") + scopes = os.getenv("AZURE_REDIS_SCOPES", None) if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -303,14 +330,14 @@ def _get_service_principal_provider(request): if isinstance(scopes, str): scopes = scopes.split(",") - return create_provider_from_service_principal( + return ServicePrincipalIdentityProviderConfig( client_id=client_id, client_credential=client_credential, scopes=scopes, timeout=timeout, token_kwargs=token_kwargs, - authority=authority, - **kwargs, + tenant_id=tenant_id, + app_kwargs=kwargs, ) @@ -322,31 +349,29 @@ def get_credential_provider(request) -> CredentialProvider: return cred_provider_class(**cred_provider_kwargs) idp = identity_provider(request) - initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0) - block_for_initial = cred_provider_kwargs.get("block_for_initial", False) expiration_refresh_ratio = cred_provider_kwargs.get( - "expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO + "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO ) lower_refresh_bound_millis = cred_provider_kwargs.get( - "lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS - ) - max_attempts = cred_provider_kwargs.get( - "max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS + "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS ) - delay_in_ms = cred_provider_kwargs.get( - "delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS + max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS) + delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS) + + token_mgr_config = TokenManagerConfig( + expiration_refresh_ratio=expiration_refresh_ratio, + lower_refresh_bound_millis=lower_refresh_bound_millis, + token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa + retry_policy=RetryPolicy( + max_attempts=max_attempts, + delay_in_ms=delay_in_ms, + ), ) - auth_config = 
TokenAuthConfig(idp) - auth_config.expiration_refresh_ratio = expiration_refresh_ratio - auth_config.lower_refresh_bound_millis = lower_refresh_bound_millis - auth_config.max_attempts = max_attempts - auth_config.delay_in_ms = delay_in_ms - return EntraIdCredentialsProvider( - config=auth_config, - initial_delay_in_ms=initial_delay_in_ms, - block_for_initial=block_for_initial, + identity_provider=idp, + token_manager_config=token_mgr_config, + initial_delay_in_ms=delay_in_ms, ) From c5a42fc4c9d60d39220d2d3f47ea09f667b6bb1e Mon Sep 17 00:00:00 2001 From: Kevin Johnson Date: Tue, 11 Feb 2025 10:16:03 -0500 Subject: [PATCH 024/113] Fixes minor grammar nit in documentation. (#3354) --- docs/opentelemetry.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/opentelemetry.rst b/docs/opentelemetry.rst index c0285e6761..edf6a42071 100644 --- a/docs/opentelemetry.rst +++ b/docs/opentelemetry.rst @@ -46,7 +46,7 @@ You can then use it to instrument code like this: RedisInstrumentor().instrument() -Once the code is patched, you can use redis-py as usually: +Once the code is patched, you can use redis-py as usual: .. code-block:: python From a8ff646f15e1111eabf56121cfecc8597e11df08 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Thu, 13 Feb 2025 10:20:36 +0200 Subject: [PATCH 025/113] maintenance: Adding Python 3.13 compatibility (#3510) --- .github/workflows/integration.yaml | 10 +++++----- dev_requirements.txt | 2 +- setup.py | 1 + 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index c4548c21ef..4467de4f62 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -75,7 +75,7 @@ jobs: fail-fast: false matrix: redis-version: ['8.0-M04-pre', '${{ needs.redis_version.outputs.CURRENT }}', '7.2.7', '6.2.17'] - python-version: ['3.8', '3.12'] + python-version: ['3.8', '3.13'] parser-backend: ['plain'] event-loop: ['asyncio'] env: @@ -99,7 +99,7 @@ jobs: fail-fast: false matrix: redis-version: [ '${{ needs.redis_version.outputs.CURRENT }}' ] - python-version: ['3.9', '3.10', '3.11', 'pypy-3.9', 'pypy-3.10'] + python-version: ['3.9', '3.10', '3.11', '3.12', 'pypy-3.9', 'pypy-3.10'] parser-backend: [ 'plain' ] event-loop: [ 'asyncio' ] env: @@ -123,7 +123,7 @@ jobs: fail-fast: false matrix: redis-version: [ '${{ needs.redis_version.outputs.CURRENT }}' ] - python-version: [ '3.8', '3.12'] + python-version: [ '3.8', '3.13'] parser-backend: [ 'hiredis' ] hiredis-version: [ '>=3.0.0', '<3.0.0' ] event-loop: [ 'asyncio' ] @@ -149,7 +149,7 @@ jobs: fail-fast: false matrix: redis-version: [ '${{ needs.redis_version.outputs.CURRENT }}' ] - python-version: [ '3.8', '3.12' ] + python-version: [ '3.8', '3.13' ] parser-backend: [ 'plain' ] event-loop: [ 'uvloop' ] env: @@ -192,7 +192,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', 'pypy-3.9', 'pypy-3.10'] + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13', 'pypy-3.9', 'pypy-3.10'] steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 diff --git a/dev_requirements.txt b/dev_requirements.txt index 728536d6fb..619fbf479c 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -9,7 +9,7 @@ packaging>=20.4 pytest pytest-asyncio>=0.23.0,<0.24.0 pytest-cov -pytest-profiling==1.7.0 +pytest-profiling==1.8.1 pytest-timeout ujson>=4.2.0 uvloop diff --git a/setup.py b/setup.py index 
02853251b2..74f6fdafb7 100644 --- a/setup.py +++ b/setup.py @@ -54,6 +54,7 @@ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ], From 6129912ae4ed00aba8141a902c888a268872f6ab Mon Sep 17 00:00:00 2001 From: Paolo Date: Thu, 13 Feb 2025 10:37:27 +0100 Subject: [PATCH 026/113] Fix Incorrect markdown formatting for example in connection_examples.ipynb (#3504) Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Co-authored-by: petyaslavova --- docs/examples/connection_examples.ipynb | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/docs/examples/connection_examples.ipynb b/docs/examples/connection_examples.ipynb index fd60e2a495..05c2c82081 100644 --- a/docs/examples/connection_examples.ipynb +++ b/docs/examples/connection_examples.ipynb @@ -75,13 +75,26 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 4, "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import redis\n", "\n", "r = redis.Redis(protocol=3)\n", - "rcon.ping()" + "r.ping()" ] }, { @@ -93,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -102,7 +115,7 @@ "True" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } From 2065ea73a94cbcf3f51c2c7738c289116aef3df9 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Thu, 13 Feb 2025 12:40:53 +0200 Subject: [PATCH 027/113] Adding unit test for core info command related to modules info (#3507) --- tests/test_commands.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_commands.py b/tests/test_commands.py index 24c320a3f3..f89c5f3365 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1066,6 +1066,21 @@ def test_info_multi_sections(self, r): assert "redis_version" in res assert "connected_clients" in res + @pytest.mark.redismod + @skip_if_server_version_lt("7.9.0") + def test_info_with_modules(self, r: redis.Redis): + res = r.info(section="everything") + assert "modules" in res + assert "search_number_of_indexes" in res + + res = r.info(section="modules") + assert "modules" in res + assert "search_number_of_indexes" in res + + res = r.info(section="search") + assert "modules" not in res + assert "search_number_of_indexes" in res + @pytest.mark.onlynoncluster @skip_if_redis_enterprise() def test_lastsave(self, r): From 83162c21a4c89bcbf3c235bedc12ffb1727b07f2 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Thu, 13 Feb 2025 14:45:40 +0200 Subject: [PATCH 028/113] Replacing the redis and redis-stack-server images with redis-libs-tests image in test infrastructure (#3505) * Replacing the redis image with redis-libs-tests image in test infrastructure * Replacing redis-stack-server image usage with client-libs-test. 
Fixing lib version in setup.py * Defining stack tag variable for the build and test github action * Removing unused env var from build and test github actions --- .github/actions/run-tests/action.yml | 15 ++---- .github/workflows/integration.yaml | 8 ++-- .gitignore | 3 ++ docker-compose.yml | 70 ++++++++++++---------------- setup.py | 2 +- 5 files changed, 43 insertions(+), 55 deletions(-) diff --git a/.github/actions/run-tests/action.yml b/.github/actions/run-tests/action.yml index 1f9332fb86..ca775f5a5b 100644 --- a/.github/actions/run-tests/action.yml +++ b/.github/actions/run-tests/action.yml @@ -31,14 +31,9 @@ runs: - name: Setup Test environment env: REDIS_VERSION: ${{ inputs.redis-version }} - REDIS_IMAGE: "redis:${{ inputs.redis-version }}" - CLIENT_LIBS_TEST_IMAGE: "redislabs/client-libs-test:${{ inputs.redis-version }}" + CLIENT_LIBS_TEST_IMAGE_TAG: ${{ inputs.redis-version }} run: | set -e - - if [ "${{inputs.redis-version}}" == "8.0-M04-pre" ]; then - export REDIS_IMAGE=redis:8.0-M03 - fi echo "::group::Installing dependencies" pip install -U setuptools wheel @@ -60,13 +55,13 @@ runs: # Mapping of redis version to stack version declare -A redis_stack_version_mapping=( - ["7.4.2"]="7.4.0-v2" - ["7.2.7"]="7.2.0-v14" - ["6.2.17"]="6.2.6-v18" + ["7.4.2"]="rs-7.4.0-v2" + ["7.2.7"]="rs-7.2.0-v14" + ["6.2.17"]="rs-6.2.6-v18" ) if [[ -v redis_stack_version_mapping[$REDIS_VERSION] ]]; then - export REDIS_STACK_IMAGE="redis/redis-stack-server:${redis_stack_version_mapping[$REDIS_VERSION]}" + export CLIENT_LIBS_TEST_STACK_IMAGE_TAG=${redis_stack_version_mapping[$REDIS_VERSION]} echo "REDIS_MOD_URL=redis://127.0.0.1:6479/0" >> $GITHUB_ENV else echo "Version not found in the mapping." diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 4467de4f62..45e0d5bf8e 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -27,8 +27,7 @@ env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} # this speeds up coverage with Python 3.12: https://github.com/nedbat/coveragepy/issues/1665 COVERAGE_CORE: sysmon - REDIS_IMAGE: redis:latest - REDIS_STACK_IMAGE: redis/redis-stack-server:latest + CURRENT_CLIENT_LIBS_TEST_STACK_IMAGE_TAG: 'rs-7.4.0-v2' CURRENT_REDIS_VERSION: '7.4.2' jobs: @@ -180,9 +179,8 @@ jobs: python-version: 3.9 - name: Run installed unit tests env: - REDIS_VERSION: ${{ env.CURRENT_REDIS_VERSION }} - REDIS_IMAGE: "redis:${{ env.CURRENT_REDIS_VERSION }}" - CLIENT_LIBS_TEST_IMAGE: "redislabs/client-libs-test:${{ env.CURRENT_REDIS_VERSION }}" + CLIENT_LIBS_TEST_IMAGE_TAG: ${{ env.CURRENT_REDIS_VERSION }} + CLIENT_LIBS_TEST_STACK_IMAGE_TAG: ${{ env.CURRENT_CLIENT_LIBS_TEST_STACK_IMAGE_TAG }} run: | bash .github/workflows/install_and_test.sh ${{ matrix.extension }} diff --git a/.gitignore b/.gitignore index ee1bda0fa5..5f77dcfde4 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,6 @@ docker/stunnel/keys /dockers/*/tls/* /dockers/standalone/ /dockers/cluster/ +/dockers/replica/ +/dockers/sentinel/ +/dockers/redis-stack/ diff --git a/docker-compose.yml b/docker-compose.yml index 60657d5653..8ca3471311 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,9 +1,14 @@ --- +x-client-libs-stack-image: &client-libs-stack-image + image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-rs-7.4.0-v2}" + +x-client-libs-image: &client-libs-image + image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-7.4.2}" services: redis: - image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:7.4.1} + <<: 
*client-libs-image container_name: redis-standalone environment: - TLS_ENABLED=yes @@ -24,20 +29,26 @@ services: - all replica: - image: ${REDIS_IMAGE:-redis:7.4.1} + <<: *client-libs-image container_name: redis-replica depends_on: - redis - command: redis-server --replicaof redis 6379 --protected-mode no --save "" + environment: + - TLS_ENABLED=no + - REDIS_CLUSTER=no + - PORT=6380 + command: ${REDIS_EXTRA_ARGS:---enable-debug-command yes --replicaof redis 6379 --protected-mode no --save ""} ports: - - 6380:6379 + - 6380:6380 + volumes: + - "./dockers/replica:/redis/work" profiles: - replica - all-stack - all cluster: - image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:7.4.1} + <<: *client-libs-image container_name: redis-cluster environment: - REDIS_CLUSTER=yes @@ -58,57 +69,38 @@ services: - all sentinel: - image: ${REDIS_IMAGE:-redis:7.4.1} + <<: *client-libs-image container_name: redis-sentinel depends_on: - redis - entrypoint: "redis-sentinel /redis.conf --port 26379" + environment: + - REDIS_CLUSTER=no + - NODES=3 + - PORT=26379 + command: ${REDIS_EXTRA_ARGS:---sentinel} ports: - 26379:26379 - volumes: - - "./dockers/sentinel.conf:/redis.conf" - profiles: - - sentinel - - all-stack - - all - - sentinel2: - image: ${REDIS_IMAGE:-redis:7.4.1} - container_name: redis-sentinel2 - depends_on: - - redis - entrypoint: "redis-sentinel /redis.conf --port 26380" - ports: - 26380:26380 - volumes: - - "./dockers/sentinel.conf:/redis.conf" - profiles: - - sentinel - - all-stack - - all - - sentinel3: - image: ${REDIS_IMAGE:-redis:7.4.1} - container_name: redis-sentinel3 - depends_on: - - redis - entrypoint: "redis-sentinel /redis.conf --port 26381" - ports: - 26381:26381 volumes: - - "./dockers/sentinel.conf:/redis.conf" + - "./dockers/sentinel.conf:/redis/config-default/redis.conf" + - "./dockers/sentinel:/redis/work" profiles: - sentinel - all-stack - all redis-stack: - image: ${REDIS_STACK_IMAGE:-redis/redis-stack-server:latest} + <<: *client-libs-stack-image container_name: redis-stack + environment: + - REDIS_CLUSTER=no + - PORT=6379 + command: ${REDIS_EXTRA_ARGS:---enable-debug-command yes --enable-module-command yes --save ""} ports: - 6479:6379 - environment: - - "REDIS_ARGS=${REDIS_STACK_EXTRA_ARGS:---enable-debug-command yes --enable-module-command yes --save ''}" + volumes: + - "./dockers/redis-stack:/redis/work" profiles: - standalone - all-stack diff --git a/setup.py b/setup.py index 74f6fdafb7..2cde3fb51b 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ long_description_content_type="text/markdown", keywords=["Redis", "key-value store", "database"], license="MIT", - version="5.1.1", + version="5.2.1", packages=find_packages( include=[ "redis", From 2b3420ce8b8e776cac3cb1bd4cdae50d3cd6eb4c Mon Sep 17 00:00:00 2001 From: Max Base Date: Thu, 13 Feb 2025 15:12:22 +0100 Subject: [PATCH 029/113] Fix formatting in README.md - 'Note' - bold formatting (#3413) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 98ddee5b52..8b4d4b6875 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ The Python interface to the Redis key-value store. --------------------------------------------- -**Note: ** redis-py 5.0 will be the last version of redis-py to support Python 3.7, as it has reached [end of life](https://devguide.python.org/versions/). redis-py 5.1 will support Python 3.8+. 
+**Note:** redis-py 5.0 will be the last version of redis-py to support Python 3.7, as it has reached [end of life](https://devguide.python.org/versions/). redis-py 5.1 will support Python 3.8+. --------------------------------------------- From d912801cbe12a69050059d436c72bb6afcec1166 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Thu, 13 Feb 2025 17:06:18 +0200 Subject: [PATCH 030/113] Adding dev_requirements.txt and pytest.ini resources into sdist. Fix for issue #3057 (#3511) --- MANIFEST.in | 2 ++ 1 file changed, 2 insertions(+) diff --git a/MANIFEST.in b/MANIFEST.in index 97fa305889..80afb437ff 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,8 @@ include INSTALL include LICENSE include README.md +include dev_requirements.txt +include pytest.ini exclude __pycache__ recursive-include tests * recursive-exclude tests *.pyc From 27c83de9817558a8297a0e5769931155cf022dd6 Mon Sep 17 00:00:00 2001 From: Mateusz Bilski Date: Mon, 17 Feb 2025 14:40:33 +0100 Subject: [PATCH 031/113] Reinitialize the cluster in case of TimeoutError inside a pipeline (#3513) --- redis/cluster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/cluster.py b/redis/cluster.py index 6c6cfbf114..db866ce2bf 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2185,7 +2185,7 @@ def _send_cluster_commands( redis_node = self.get_redis_connection(node) try: connection = get_connection(redis_node, c.args) - except ConnectionError: + except (ConnectionError, TimeoutError): for n in nodes.values(): n.connection_pool.release(n.connection) # Connection retries are being handled in the node's From 9ea0e25824c815fe67a657a79da8a9c9af1b2236 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 18 Feb 2025 02:54:20 -0800 Subject: [PATCH 032/113] Fix inaccurate docstring for unwatch() (#3424) The 'unwatch()' method of the Redis client, as currently documented, says that it unwatches the value at key "name", but it does not actually take any arguments ("name" or otherwise). According to the latest Redis documentation at the given URL for the UNWATCH command, this command unwatches all previously watched keys for the current transaction. Modified docstring to reflect that this method does not take any arguments and instead (presumably) unwatches all previously watched keys. 
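For context, a minimal sketch of the transaction flow the corrected docstring describes; the key names are illustrative only, and unwatch() takes no arguments:

import redis

r = redis.Redis(decode_responses=True)
with r.pipeline() as pipe:
    pipe.watch("balance", "orders")  # start watching one or more keys
    # ... inspect the watched values, decide not to run the MULTI/EXEC ...
    pipe.unwatch()                   # clears ALL previously watched keys
    # the pipeline is back in buffered mode and can be reused
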
--- redis/commands/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/commands/core.py b/redis/commands/core.py index 8986a48de2..1d19e33f2c 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -2496,7 +2496,7 @@ def watch(self, *names: KeyT) -> None: def unwatch(self) -> None: """ - Unwatches the value at key ``name``, or None of the key doesn't exist + Unwatches all previously watched keys for a transaction For more information see https://redis.io/commands/unwatch """ From 9834cfa9c4b6acee75c8867c650bcda6f159ec63 Mon Sep 17 00:00:00 2001 From: birthdaysgift <35929293+birthdaysgift@users.noreply.github.com> Date: Wed, 19 Feb 2025 11:01:09 +0300 Subject: [PATCH 033/113] Fix invalid return type annotation (#3480) --- redis/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/connection.py b/redis/connection.py index d47f46590b..c4ff9b7b17 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -672,7 +672,7 @@ def pack_commands(self, commands): output.append(SYM_EMPTY.join(pieces)) return output - def get_protocol(self) -> int or str: + def get_protocol(self) -> Union[int, str]: return self.protocol @property From e77940fe4af21db174aaa6564fb72f2aff7e88cc Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Thu, 20 Feb 2025 09:54:28 +0200 Subject: [PATCH 034/113] Deprecating unused arguments in connection pools's get_connection functions (#3517) --- redis/asyncio/client.py | 16 ++---- redis/asyncio/connection.py | 16 +++++- redis/client.py | 14 ++--- redis/cluster.py | 18 +++--- redis/connection.py | 16 +++++- redis/utils.py | 65 ++++++++++++++++++++++ tests/test_asyncio/test_connection.py | 2 +- tests/test_asyncio/test_connection_pool.py | 64 ++++++++++----------- tests/test_asyncio/test_credentials.py | 2 +- tests/test_asyncio/test_encoding.py | 2 +- tests/test_asyncio/test_retry.py | 4 +- tests/test_asyncio/test_sentinel.py | 2 +- tests/test_cache.py | 6 +- tests/test_cluster.py | 4 +- tests/test_connection_pool.py | 60 ++++++++++---------- tests/test_credentials.py | 2 +- tests/test_multiprocessing.py | 10 ++-- tests/test_retry.py | 4 +- tests/test_sentinel.py | 2 +- 19 files changed, 196 insertions(+), 113 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 7c17938714..4254441073 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -375,7 +375,7 @@ async def initialize(self: _RedisT) -> _RedisT: if self.single_connection_client: async with self._single_conn_lock: if self.connection is None: - self.connection = await self.connection_pool.get_connection("_") + self.connection = await self.connection_pool.get_connection() self._event_dispatcher.dispatch( AfterSingleConnectionInstantiationEvent( @@ -638,7 +638,7 @@ async def execute_command(self, *args, **options): await self.initialize() pool = self.connection_pool command_name = args[0] - conn = self.connection or await pool.get_connection(command_name, **options) + conn = self.connection or await pool.get_connection() if self.single_connection_client: await self._single_conn_lock.acquire() @@ -712,7 +712,7 @@ def __init__(self, connection_pool: ConnectionPool): async def connect(self): if self.connection is None: - self.connection = await self.connection_pool.get_connection("MONITOR") + self.connection = await self.connection_pool.get_connection() async def __aenter__(self): await self.connect() @@ -900,9 +900,7 @@ async def connect(self): Ensure that the PubSub is connected """ if self.connection is None: - 
self.connection = await self.connection_pool.get_connection( - "pubsub", self.shard_hint - ) + self.connection = await self.connection_pool.get_connection() # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) @@ -1370,9 +1368,7 @@ async def immediate_execute_command(self, *args, **options): conn = self.connection # if this is the first call, we need a connection if not conn: - conn = await self.connection_pool.get_connection( - command_name, self.shard_hint - ) + conn = await self.connection_pool.get_connection() self.connection = conn return await conn.retry.call_with_retry( @@ -1568,7 +1564,7 @@ async def execute(self, raise_on_error: bool = True) -> List[Any]: conn = self.connection if not conn: - conn = await self.connection_pool.get_connection("MULTI", self.shard_hint) + conn = await self.connection_pool.get_connection() # assign to self.connection so reset() releases the connection # back to the pool after we're done self.connection = conn diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 4a743ff374..e67dc5b207 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -29,7 +29,7 @@ from ..auth.token import TokenInterface from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher -from ..utils import format_error_message +from ..utils import deprecated_args, format_error_message # the functionality is available in 3.11.x but has a major issue before # 3.11.3. See https://github.com/redis/redis-py/issues/2633 @@ -1087,7 +1087,12 @@ def can_get_connection(self) -> bool: or len(self._in_use_connections) < self.max_connections ) - async def get_connection(self, command_name, *keys, **options): + @deprecated_args( + args_to_warn=["*"], + reason="Use get_connection() without args instead", + version="5.0.3", + ) + async def get_connection(self, command_name=None, *keys, **options): async with self._lock: """Get a connected connection from the pool""" connection = self.get_available_connection() @@ -1255,7 +1260,12 @@ def __init__( self._condition = asyncio.Condition() self.timeout = timeout - async def get_connection(self, command_name, *keys, **options): + @deprecated_args( + args_to_warn=["*"], + reason="Use get_connection() without args instead", + version="5.0.3", + ) + async def get_connection(self, command_name=None, *keys, **options): """Gets a connection from the pool, blocking until one is available""" try: async with self._condition: diff --git a/redis/client.py b/redis/client.py index 5a9f4fafb5..fc535c8ca0 100755 --- a/redis/client.py +++ b/redis/client.py @@ -366,7 +366,7 @@ def __init__( self.connection = None self._single_connection_client = single_connection_client if self._single_connection_client: - self.connection = self.connection_pool.get_connection("_") + self.connection = self.connection_pool.get_connection() self._event_dispatcher.dispatch( AfterSingleConnectionInstantiationEvent( self.connection, ClientType.SYNC, self.single_connection_lock @@ -608,7 +608,7 @@ def _execute_command(self, *args, **options): """Execute a command and return a parsed response""" pool = self.connection_pool command_name = args[0] - conn = self.connection or pool.get_connection(command_name, **options) + conn = self.connection or pool.get_connection() if self._single_connection_client: self.single_connection_lock.acquire() @@ -667,7 +667,7 @@ class Monitor: def __init__(self, connection_pool): self.connection_pool = 
connection_pool - self.connection = self.connection_pool.get_connection("MONITOR") + self.connection = self.connection_pool.get_connection() def __enter__(self): self.connection.send_command("MONITOR") @@ -840,9 +840,7 @@ def execute_command(self, *args): # subscribed to one or more channels if self.connection is None: - self.connection = self.connection_pool.get_connection( - "pubsub", self.shard_hint - ) + self.connection = self.connection_pool.get_connection() # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) @@ -1397,7 +1395,7 @@ def immediate_execute_command(self, *args, **options): conn = self.connection # if this is the first call, we need a connection if not conn: - conn = self.connection_pool.get_connection(command_name, self.shard_hint) + conn = self.connection_pool.get_connection() self.connection = conn return conn.retry.call_with_retry( @@ -1583,7 +1581,7 @@ def execute(self, raise_on_error: bool = True) -> List[Any]: conn = self.connection if not conn: - conn = self.connection_pool.get_connection("MULTI", self.shard_hint) + conn = self.connection_pool.get_connection() # assign to self.connection so reset() releases the connection # back to the pool after we're done self.connection = conn diff --git a/redis/cluster.py b/redis/cluster.py index db866ce2bf..c184838a9b 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -40,6 +40,7 @@ from redis.retry import Retry from redis.utils import ( HIREDIS_AVAILABLE, + deprecated_args, dict_merge, list_keys_to_dict, merge_result, @@ -52,10 +53,13 @@ def get_node_name(host: str, port: Union[str, int]) -> str: return f"{host}:{port}" +@deprecated_args( + allowed_args=["redis_node"], + reason="Use get_connection(redis_node) instead", + version="5.0.3", +) def get_connection(redis_node, *args, **options): - return redis_node.connection or redis_node.connection_pool.get_connection( - args[0], **options - ) + return redis_node.connection or redis_node.connection_pool.get_connection() def parse_scan_result(command, res, **options): @@ -1151,7 +1155,7 @@ def _execute_command(self, target_node, *args, **kwargs): moved = False redis_node = self.get_redis_connection(target_node) - connection = get_connection(redis_node, *args, **kwargs) + connection = get_connection(redis_node) if asking: connection.send_command("ASKING") redis_node.parse_response(connection, "ASKING", **kwargs) @@ -1822,9 +1826,7 @@ def execute_command(self, *args): self.node = node redis_connection = self.cluster.get_redis_connection(node) self.connection_pool = redis_connection.connection_pool - self.connection = self.connection_pool.get_connection( - "pubsub", self.shard_hint - ) + self.connection = self.connection_pool.get_connection() # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) @@ -2184,7 +2186,7 @@ def _send_cluster_commands( if node_name not in nodes: redis_node = self.get_redis_connection(node) try: - connection = get_connection(redis_node, c.args) + connection = get_connection(redis_node) except (ConnectionError, TimeoutError): for n in nodes.values(): n.connection_pool.release(n.connection) diff --git a/redis/connection.py b/redis/connection.py index c4ff9b7b17..d59a9b069b 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -42,6 +42,7 @@ HIREDIS_AVAILABLE, SSL_AVAILABLE, compare_versions, + deprecated_args, ensure_string, 
format_error_message, get_lib_version, @@ -1461,8 +1462,14 @@ def _checkpid(self) -> None: finally: self._fork_lock.release() - def get_connection(self, command_name: str, *keys, **options) -> "Connection": + @deprecated_args( + args_to_warn=["*"], + reason="Use get_connection() without args instead", + version="5.0.3", + ) + def get_connection(self, command_name=None, *keys, **options) -> "Connection": "Get a connection from the pool" + self._checkpid() with self._lock: try: @@ -1683,7 +1690,12 @@ def make_connection(self): self._connections.append(connection) return connection - def get_connection(self, command_name, *keys, **options): + @deprecated_args( + args_to_warn=["*"], + reason="Use get_connection() without args instead", + version="5.0.3", + ) + def get_connection(self, command_name=None, *keys, **options): """ Get a connection, blocking for ``self.timeout`` until a connection is available from the pool. diff --git a/redis/utils.py b/redis/utils.py index 8693fb3c8f..66465636a1 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -122,6 +122,71 @@ def wrapper(*args, **kwargs): return decorator +def warn_deprecated_arg_usage( + arg_name: Union[list, str], + function_name: str, + reason: str = "", + version: str = "", + stacklevel: int = 2, +): + import warnings + + msg = ( + f"Call to '{function_name}' function with deprecated" + f" usage of input argument/s '{arg_name}'." + ) + if reason: + msg += f" ({reason})" + if version: + msg += f" -- Deprecated since version {version}." + warnings.warn(msg, category=DeprecationWarning, stacklevel=stacklevel) + + +def deprecated_args( + args_to_warn: list = ["*"], + allowed_args: list = [], + reason: str = "", + version: str = "", +): + """ + Decorator to mark specified args of a function as deprecated. + If '*' is in args_to_warn, all arguments will be marked as deprecated. + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Get function argument names + arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] + + provided_args = dict(zip(arg_names, args)) + provided_args.update(kwargs) + + provided_args.pop("self", None) + for allowed_arg in allowed_args: + provided_args.pop(allowed_arg, None) + + for arg in args_to_warn: + if arg == "*" and len(provided_args) > 0: + warn_deprecated_arg_usage( + list(provided_args.keys()), + func.__name__, + reason, + version, + stacklevel=3, + ) + elif arg in provided_args: + warn_deprecated_arg_usage( + arg, func.__name__, reason, version, stacklevel=3 + ) + + return func(*args, **kwargs) + + return wrapper + + return decorator + + def _set_info_logger(): """ Set up a logger that log info logs to stdout. 
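A small sketch of the behaviour the new deprecated_args decorator adds to the pools (assumes a reachable local Redis server): the legacy positional command-name argument still works, but now triggers a DeprecationWarning steering callers to the bare get_connection():

import warnings

import redis

pool = redis.ConnectionPool(host="localhost", port=6379)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    conn = pool.get_connection("GET")  # legacy call style, now deprecated
    pool.release(conn)
    conn = pool.get_connection()       # preferred call style, no warning
    pool.release(conn)

print([str(w.message) for w in caught])  # one DeprecationWarning from the first call
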
diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index d4956f16e9..38764d30cd 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -78,7 +78,7 @@ async def call_with_retry(self, _, __): mock_conn = mock.AsyncMock(spec=Connection) mock_conn.retry = Retry_() - async def get_conn(_): + async def get_conn(): # Validate only one client is created in single-client mode when # concurrent requests are made nonlocal init_call_count diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index 83545b4ede..3d120e4ca7 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -29,8 +29,8 @@ def get_total_connected_connections(pool): @staticmethod async def create_two_conn(r: redis.Redis): if not r.single_connection_client: # Single already initialized connection - r.connection = await r.connection_pool.get_connection("_") - return await r.connection_pool.get_connection("_") + r.connection = await r.connection_pool.get_connection() + return await r.connection_pool.get_connection() @staticmethod def has_no_connected_connections(pool: redis.ConnectionPool): @@ -138,7 +138,7 @@ async def test_connection_creation(self): async with self.get_pool( connection_kwargs=connection_kwargs, connection_class=DummyConnection ) as pool: - connection = await pool.get_connection("_") + connection = await pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs @@ -155,8 +155,8 @@ async def test_aclosing(self): async def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: - c1 = await pool.get_connection("_") - c2 = await pool.get_connection("_") + c1 = await pool.get_connection() + c2 = await pool.get_connection() assert c1 != c2 async def test_max_connections(self, master_host): @@ -164,17 +164,17 @@ async def test_max_connections(self, master_host): async with self.get_pool( max_connections=2, connection_kwargs=connection_kwargs ) as pool: - await pool.get_connection("_") - await pool.get_connection("_") + await pool.get_connection() + await pool.get_connection() with pytest.raises(redis.ConnectionError): - await pool.get_connection("_") + await pool.get_connection() async def test_reuse_previously_released_connection(self, master_host): connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: - c1 = await pool.get_connection("_") + c1 = await pool.get_connection() await pool.release(c1) - c2 = await pool.get_connection("_") + c2 = await pool.get_connection() assert c1 == c2 async def test_repr_contains_db_info_tcp(self): @@ -223,7 +223,7 @@ async def test_connection_creation(self, master_host): "port": master_host[1], } async with self.get_pool(connection_kwargs=connection_kwargs) as pool: - connection = await pool.get_connection("_") + connection = await pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs @@ -236,14 +236,14 @@ async def test_disconnect(self, master_host): "port": master_host[1], } async with self.get_pool(connection_kwargs=connection_kwargs) as pool: - await pool.get_connection("_") + await pool.get_connection() await pool.disconnect() async def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0], 
"port": master_host[1]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: - c1 = await pool.get_connection("_") - c2 = await pool.get_connection("_") + c1 = await pool.get_connection() + c2 = await pool.get_connection() assert c1 != c2 async def test_connection_pool_blocks_until_timeout(self, master_host): @@ -252,11 +252,11 @@ async def test_connection_pool_blocks_until_timeout(self, master_host): async with self.get_pool( max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs ) as pool: - c1 = await pool.get_connection("_") + c1 = await pool.get_connection() start = asyncio.get_running_loop().time() with pytest.raises(redis.ConnectionError): - await pool.get_connection("_") + await pool.get_connection() # we should have waited at least some period of time assert asyncio.get_running_loop().time() - start >= 0.05 @@ -271,23 +271,23 @@ async def test_connection_pool_blocks_until_conn_available(self, master_host): async with self.get_pool( max_connections=1, timeout=2, connection_kwargs=connection_kwargs ) as pool: - c1 = await pool.get_connection("_") + c1 = await pool.get_connection() async def target(): await asyncio.sleep(0.1) await pool.release(c1) start = asyncio.get_running_loop().time() - await asyncio.gather(target(), pool.get_connection("_")) + await asyncio.gather(target(), pool.get_connection()) stop = asyncio.get_running_loop().time() assert (stop - start) <= 0.2 async def test_reuse_previously_released_connection(self, master_host): connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: - c1 = await pool.get_connection("_") + c1 = await pool.get_connection() await pool.release(c1) - c2 = await pool.get_connection("_") + c2 = await pool.get_connection() assert c1 == c2 def test_repr_contains_db_info_tcp(self): @@ -552,23 +552,23 @@ def test_cert_reqs_options(self): import ssl class DummyConnectionPool(redis.ConnectionPool): - def get_connection(self, *args, **kwargs): + def get_connection(self): return self.make_connection() pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=none") - assert pool.get_connection("_").cert_reqs == ssl.CERT_NONE + assert pool.get_connection().cert_reqs == ssl.CERT_NONE pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=optional") - assert pool.get_connection("_").cert_reqs == ssl.CERT_OPTIONAL + assert pool.get_connection().cert_reqs == ssl.CERT_OPTIONAL pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=required") - assert pool.get_connection("_").cert_reqs == ssl.CERT_REQUIRED + assert pool.get_connection().cert_reqs == ssl.CERT_REQUIRED pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=False") - assert pool.get_connection("_").check_hostname is False + assert pool.get_connection().check_hostname is False pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=True") - assert pool.get_connection("_").check_hostname is True + assert pool.get_connection().check_hostname is True class TestConnection: @@ -756,7 +756,7 @@ async def test_health_check_not_invoked_within_interval(self, r): async def test_health_check_in_pipeline(self, r): async with r.pipeline(transaction=False) as pipe: - pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection = await pipe.connection_pool.get_connection() pipe.connection.next_health_check = 0 with mock.patch.object( pipe.connection, "send_command", wraps=pipe.connection.send_command @@ -767,7 +767,7 @@ async def 
test_health_check_in_pipeline(self, r): async def test_health_check_in_transaction(self, r): async with r.pipeline(transaction=True) as pipe: - pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection = await pipe.connection_pool.get_connection() pipe.connection.next_health_check = 0 with mock.patch.object( pipe.connection, "send_command", wraps=pipe.connection.send_command @@ -779,7 +779,7 @@ async def test_health_check_in_transaction(self, r): async def test_health_check_in_watched_pipeline(self, r): await r.set("foo", "bar") async with r.pipeline(transaction=False) as pipe: - pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection = await pipe.connection_pool.get_connection() pipe.connection.next_health_check = 0 with mock.patch.object( pipe.connection, "send_command", wraps=pipe.connection.send_command @@ -803,7 +803,7 @@ async def test_health_check_in_watched_pipeline(self, r): async def test_health_check_in_pubsub_before_subscribe(self, r): """A health check happens before the first [p]subscribe""" p = r.pubsub() - p.connection = await p.connection_pool.get_connection("_") + p.connection = await p.connection_pool.get_connection() p.connection.next_health_check = 0 with mock.patch.object( p.connection, "send_command", wraps=p.connection.send_command @@ -825,7 +825,7 @@ async def test_health_check_in_pubsub_after_subscribed(self, r): connection health """ p = r.pubsub() - p.connection = await p.connection_pool.get_connection("_") + p.connection = await p.connection_pool.get_connection() p.connection.next_health_check = 0 with mock.patch.object( p.connection, "send_command", wraps=p.connection.send_command @@ -865,7 +865,7 @@ async def test_health_check_in_pubsub_poll(self, r): check the connection's health. 
""" p = r.pubsub() - p.connection = await p.connection_pool.get_connection("_") + p.connection = await p.connection_pool.get_connection() with mock.patch.object( p.connection, "send_command", wraps=p.connection.send_command ) as m: diff --git a/tests/test_asyncio/test_credentials.py b/tests/test_asyncio/test_credentials.py index ca42d19090..1eb988ce71 100644 --- a/tests/test_asyncio/test_credentials.py +++ b/tests/test_asyncio/test_credentials.py @@ -274,7 +274,7 @@ async def test_change_username_password_on_existing_connection( await init_acl_user(r, username, password) r2 = await create_redis(flushdb=False, username=username, password=password) assert await r2.ping() is True - conn = await r2.connection_pool.get_connection("_") + conn = await r2.connection_pool.get_connection() await conn.send_command("PING") assert str_if_bytes(await conn.read_response()) == "PONG" assert conn.username == username diff --git a/tests/test_asyncio/test_encoding.py b/tests/test_asyncio/test_encoding.py index 162ccb367d..74a9f28b2d 100644 --- a/tests/test_asyncio/test_encoding.py +++ b/tests/test_asyncio/test_encoding.py @@ -74,7 +74,7 @@ class TestMemoryviewsAreNotPacked: async def test_memoryviews_are_not_packed(self, r): arg = memoryview(b"some_arg") arg_list = ["SOME_COMMAND", arg] - c = r.connection or await r.connection_pool.get_connection("_") + c = r.connection or await r.connection_pool.get_connection() cmd = c.pack_command(*arg_list) assert cmd[1] is arg cmds = c.pack_commands([arg_list, arg_list]) diff --git a/tests/test_asyncio/test_retry.py b/tests/test_asyncio/test_retry.py index 8bc71c1479..cd251a986f 100644 --- a/tests/test_asyncio/test_retry.py +++ b/tests/test_asyncio/test_retry.py @@ -126,13 +126,13 @@ async def test_get_set_retry_object(self, request): assert r.get_retry()._retries == retry._retries assert isinstance(r.get_retry()._backoff, NoBackoff) new_retry_policy = Retry(ExponentialBackoff(), 3) - exiting_conn = await r.connection_pool.get_connection("_") + exiting_conn = await r.connection_pool.get_connection() r.set_retry(new_retry_policy) assert r.get_retry()._retries == new_retry_policy._retries assert isinstance(r.get_retry()._backoff, ExponentialBackoff) assert exiting_conn.retry._retries == new_retry_policy._retries await r.connection_pool.release(exiting_conn) - new_conn = await r.connection_pool.get_connection("_") + new_conn = await r.connection_pool.get_connection() assert new_conn.retry._retries == new_retry_policy._retries await r.connection_pool.release(new_conn) await r.aclose() diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py index e553fdb00b..a27ba92bb8 100644 --- a/tests/test_asyncio/test_sentinel.py +++ b/tests/test_asyncio/test_sentinel.py @@ -269,7 +269,7 @@ async def mock_disconnect(): @pytest.mark.onlynoncluster async def test_repr_correctly_represents_connection_object(sentinel): pool = SentinelConnectionPool("mymaster", sentinel) - connection = await pool.get_connection("PING") + connection = await pool.get_connection() assert ( str(connection) diff --git a/tests/test_cache.py b/tests/test_cache.py index 67733dc9af..7010baff5f 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -159,7 +159,7 @@ def test_cache_clears_on_disconnect(self, r, cache): == b"bar" ) # Force disconnection - r.connection_pool.get_connection("_").disconnect() + r.connection_pool.get_connection().disconnect() # Make sure cache is empty assert cache.size == 0 @@ -429,7 +429,7 @@ def test_cache_clears_on_disconnect(self, r, r2): # Force 
disconnection r.nodes_manager.get_node_from_slot( 12000 - ).redis_connection.connection_pool.get_connection("_").disconnect() + ).redis_connection.connection_pool.get_connection().disconnect() # Make sure cache is empty assert cache.size == 0 @@ -667,7 +667,7 @@ def test_cache_clears_on_disconnect(self, master, cache): == b"bar" ) # Force disconnection - master.connection_pool.get_connection("_").disconnect() + master.connection_pool.get_connection().disconnect() # Make sure cache_data is empty assert cache.size == 0 diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 1b9b9969c5..908ac26211 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -845,7 +845,7 @@ def test_cluster_get_set_retry_object(self, request): assert node.redis_connection.get_retry()._retries == retry._retries assert isinstance(node.redis_connection.get_retry()._backoff, NoBackoff) rand_node = r.get_random_node() - existing_conn = rand_node.redis_connection.connection_pool.get_connection("_") + existing_conn = rand_node.redis_connection.connection_pool.get_connection() # Change retry policy new_retry = Retry(ExponentialBackoff(), 3) r.set_retry(new_retry) @@ -857,7 +857,7 @@ def test_cluster_get_set_retry_object(self, request): node.redis_connection.get_retry()._backoff, ExponentialBackoff ) assert existing_conn.retry._retries == new_retry._retries - new_conn = rand_node.redis_connection.connection_pool.get_connection("_") + new_conn = rand_node.redis_connection.connection_pool.get_connection() assert new_conn.retry._retries == new_retry._retries def test_cluster_retry_object(self, r) -> None: diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 118294ee1b..387a0f4565 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -54,7 +54,7 @@ def test_connection_creation(self): pool = self.get_pool( connection_kwargs=connection_kwargs, connection_class=DummyConnection ) - connection = pool.get_connection("_") + connection = pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs @@ -71,24 +71,24 @@ def test_closing(self): def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = pool.get_connection("_") - c2 = pool.get_connection("_") + c1 = pool.get_connection() + c2 = pool.get_connection() assert c1 != c2 def test_max_connections(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(max_connections=2, connection_kwargs=connection_kwargs) - pool.get_connection("_") - pool.get_connection("_") + pool.get_connection() + pool.get_connection() with pytest.raises(redis.ConnectionError): - pool.get_connection("_") + pool.get_connection() def test_reuse_previously_released_connection(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = pool.get_connection("_") + c1 = pool.get_connection() pool.release(c1) - c2 = pool.get_connection("_") + c2 = pool.get_connection() assert c1 == c2 def test_repr_contains_db_info_tcp(self): @@ -133,15 +133,15 @@ def test_connection_creation(self, master_host): "port": master_host[1], } pool = self.get_pool(connection_kwargs=connection_kwargs) - connection = pool.get_connection("_") + connection = pool.get_connection() assert isinstance(connection, DummyConnection) assert 
connection.kwargs == connection_kwargs def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = pool.get_connection("_") - c2 = pool.get_connection("_") + c1 = pool.get_connection() + c2 = pool.get_connection() assert c1 != c2 def test_connection_pool_blocks_until_timeout(self, master_host): @@ -150,11 +150,11 @@ def test_connection_pool_blocks_until_timeout(self, master_host): pool = self.get_pool( max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs ) - pool.get_connection("_") + pool.get_connection() start = time.time() with pytest.raises(redis.ConnectionError): - pool.get_connection("_") + pool.get_connection() # we should have waited at least 0.1 seconds assert time.time() - start >= 0.1 @@ -167,7 +167,7 @@ def test_connection_pool_blocks_until_conn_available(self, master_host): pool = self.get_pool( max_connections=1, timeout=2, connection_kwargs=connection_kwargs ) - c1 = pool.get_connection("_") + c1 = pool.get_connection() def target(): time.sleep(0.1) @@ -175,15 +175,15 @@ def target(): start = time.time() Thread(target=target).start() - pool.get_connection("_") + pool.get_connection() assert time.time() - start >= 0.1 def test_reuse_previously_released_connection(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = pool.get_connection("_") + c1 = pool.get_connection() pool.release(c1) - c2 = pool.get_connection("_") + c2 = pool.get_connection() assert c1 == c2 def test_repr_contains_db_info_tcp(self): @@ -214,7 +214,7 @@ def test_initialise_pool_with_cache(self, master_host): protocol=3, cache_config=CacheConfig(), ) - assert isinstance(pool.get_connection("_"), CacheProxyConnection) + assert isinstance(pool.get_connection(), CacheProxyConnection) class TestConnectionPoolURLParsing: @@ -489,23 +489,23 @@ def test_cert_reqs_options(self): import ssl class DummyConnectionPool(redis.ConnectionPool): - def get_connection(self, *args, **kwargs): + def get_connection(self): return self.make_connection() pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=none") - assert pool.get_connection("_").cert_reqs == ssl.CERT_NONE + assert pool.get_connection().cert_reqs == ssl.CERT_NONE pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=optional") - assert pool.get_connection("_").cert_reqs == ssl.CERT_OPTIONAL + assert pool.get_connection().cert_reqs == ssl.CERT_OPTIONAL pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=required") - assert pool.get_connection("_").cert_reqs == ssl.CERT_REQUIRED + assert pool.get_connection().cert_reqs == ssl.CERT_REQUIRED pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=False") - assert pool.get_connection("_").check_hostname is False + assert pool.get_connection().check_hostname is False pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=True") - assert pool.get_connection("_").check_hostname is True + assert pool.get_connection().check_hostname is True class TestConnection: @@ -701,7 +701,7 @@ def test_health_check_not_invoked_within_interval(self, r): def test_health_check_in_pipeline(self, r): with r.pipeline(transaction=False) as pipe: - pipe.connection = pipe.connection_pool.get_connection("_") + pipe.connection = pipe.connection_pool.get_connection() pipe.connection.next_health_check = 0 with mock.patch.object( pipe.connection, "send_command", 
wraps=pipe.connection.send_command @@ -712,7 +712,7 @@ def test_health_check_in_pipeline(self, r): def test_health_check_in_transaction(self, r): with r.pipeline(transaction=True) as pipe: - pipe.connection = pipe.connection_pool.get_connection("_") + pipe.connection = pipe.connection_pool.get_connection() pipe.connection.next_health_check = 0 with mock.patch.object( pipe.connection, "send_command", wraps=pipe.connection.send_command @@ -724,7 +724,7 @@ def test_health_check_in_transaction(self, r): def test_health_check_in_watched_pipeline(self, r): r.set("foo", "bar") with r.pipeline(transaction=False) as pipe: - pipe.connection = pipe.connection_pool.get_connection("_") + pipe.connection = pipe.connection_pool.get_connection() pipe.connection.next_health_check = 0 with mock.patch.object( pipe.connection, "send_command", wraps=pipe.connection.send_command @@ -748,7 +748,7 @@ def test_health_check_in_watched_pipeline(self, r): def test_health_check_in_pubsub_before_subscribe(self, r): "A health check happens before the first [p]subscribe" p = r.pubsub() - p.connection = p.connection_pool.get_connection("_") + p.connection = p.connection_pool.get_connection() p.connection.next_health_check = 0 with mock.patch.object( p.connection, "send_command", wraps=p.connection.send_command @@ -770,7 +770,7 @@ def test_health_check_in_pubsub_after_subscribed(self, r): connection health """ p = r.pubsub() - p.connection = p.connection_pool.get_connection("_") + p.connection = p.connection_pool.get_connection() p.connection.next_health_check = 0 with mock.patch.object( p.connection, "send_command", wraps=p.connection.send_command @@ -810,7 +810,7 @@ def test_health_check_in_pubsub_poll(self, r): check the connection's health. """ p = r.pubsub() - p.connection = p.connection_pool.get_connection("_") + p.connection = p.connection_pool.get_connection() with mock.patch.object( p.connection, "send_command", wraps=p.connection.send_command ) as m: diff --git a/tests/test_credentials.py b/tests/test_credentials.py index b0b79d305f..95ec5577cc 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -252,7 +252,7 @@ def teardown(): redis.Redis, request, flushdb=False, username=username, password=password ) assert r2.ping() is True - conn = r2.connection_pool.get_connection("_") + conn = r2.connection_pool.get_connection() conn.send_command("PING") assert str_if_bytes(conn.read_response()) == "PONG" assert conn.username == username diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 116d20dab0..0e8e8958c5 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -95,7 +95,7 @@ def test_pool(self, max_connections, master_host): max_connections=max_connections, ) - conn = pool.get_connection("ping") + conn = pool.get_connection() main_conn_pid = conn.pid with exit_callback(pool.release, conn): conn.send_command("ping") @@ -103,7 +103,7 @@ def test_pool(self, max_connections, master_host): def target(pool): with exit_callback(pool.disconnect): - conn = pool.get_connection("ping") + conn = pool.get_connection() assert conn.pid != main_conn_pid with exit_callback(pool.release, conn): assert conn.send_command("ping") is None @@ -116,7 +116,7 @@ def target(pool): # Check that connection is still alive after fork process has exited # and disconnected the connections in its pool - conn = pool.get_connection("ping") + conn = pool.get_connection() with exit_callback(pool.release, conn): assert conn.send_command("ping") is None assert conn.read_response() 
== b"PONG" @@ -132,12 +132,12 @@ def test_close_pool_in_main(self, max_connections, master_host): max_connections=max_connections, ) - conn = pool.get_connection("ping") + conn = pool.get_connection() assert conn.send_command("ping") is None assert conn.read_response() == b"PONG" def target(pool, disconnect_event): - conn = pool.get_connection("ping") + conn = pool.get_connection() with exit_callback(pool.release, conn): assert conn.send_command("ping") is None assert conn.read_response() == b"PONG" diff --git a/tests/test_retry.py b/tests/test_retry.py index 183807386d..e1e4c414a4 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -206,7 +206,7 @@ def test_client_retry_on_timeout(self, request): def test_get_set_retry_object(self, request): retry = Retry(NoBackoff(), 2) r = _get_client(Redis, request, retry_on_timeout=True, retry=retry) - exist_conn = r.connection_pool.get_connection("_") + exist_conn = r.connection_pool.get_connection() assert r.get_retry()._retries == retry._retries assert isinstance(r.get_retry()._backoff, NoBackoff) new_retry_policy = Retry(ExponentialBackoff(), 3) @@ -214,5 +214,5 @@ def test_get_set_retry_object(self, request): assert r.get_retry()._retries == new_retry_policy._retries assert isinstance(r.get_retry()._backoff, ExponentialBackoff) assert exist_conn.retry._retries == new_retry_policy._retries - new_conn = r.connection_pool.get_connection("_") + new_conn = r.connection_pool.get_connection() assert new_conn.retry._retries == new_retry_policy._retries diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py index 54b9647098..93455f3290 100644 --- a/tests/test_sentinel.py +++ b/tests/test_sentinel.py @@ -101,7 +101,7 @@ def test_discover_master_error(sentinel): @pytest.mark.onlynoncluster def test_dead_pool(sentinel): master = sentinel.master_for("mymaster", db=9) - conn = master.connection_pool.get_connection("_") + conn = master.connection_pool.get_connection() conn.disconnect() del master conn.connect() From c788cd4141fa745d321039533584176510a91559 Mon Sep 17 00:00:00 2001 From: Sviatoslav Abakumov Date: Fri, 21 Feb 2025 21:56:45 +0400 Subject: [PATCH 035/113] Type hint Lock.extend's additional_time as a Number (#3522) It must be possible for the function to receive floats, as the docstring and the code suggest. --- redis/lock.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redis/lock.py b/redis/lock.py index cae7f27ea1..f44ed629da 100644 --- a/redis/lock.py +++ b/redis/lock.py @@ -264,7 +264,7 @@ def do_release(self, expected_token: str) -> None: lock_name=self.name, ) - def extend(self, additional_time: int, replace_ttl: bool = False) -> bool: + def extend(self, additional_time: Number, replace_ttl: bool = False) -> bool: """ Adds more time to an already acquired lock. @@ -281,7 +281,7 @@ def extend(self, additional_time: int, replace_ttl: bool = False) -> bool: raise LockError("Cannot extend a lock with no timeout", lock_name=self.name) return self.do_extend(additional_time, replace_ttl) - def do_extend(self, additional_time: int, replace_ttl: bool) -> bool: + def do_extend(self, additional_time: Number, replace_ttl: bool) -> bool: additional_time = int(additional_time * 1000) if not bool( self.lua_extend( From 70142d803011f81b41073ccea40771f17469f0f7 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Mon, 24 Feb 2025 17:33:56 +0200 Subject: [PATCH 036/113] Remove decreasing of created connections count when releasing not owned by connection pool connection(fixes issue #2832). 
(#3514) * Removing decreasing of created connections count when releasing not owned by connection pool connection(#2832). * Fixed another issue that was allowing adding connections to a pool owned by other pools. Adding unit tests. * Fixing a typo in a comment --- redis/connection.py | 10 +++++----- tests/test_connection_pool.py | 15 +++++++++++++++ tests/test_multiprocessing.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index d59a9b069b..ece17d752b 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1532,7 +1532,7 @@ def release(self, connection: "Connection") -> None: except KeyError: # Gracefully fail when a connection is returned to this pool # that the pool doesn't actually own - pass + return if self.owns_connection(connection): self._available_connections.append(connection) @@ -1540,10 +1540,10 @@ def release(self, connection: "Connection") -> None: AfterConnectionReleasedEvent(connection) ) else: - # pool doesn't own this connection. do not add it back - # to the pool and decrement the count so that another - # connection can take its place if needed - self._created_connections -= 1 + # Pool doesn't own this connection, do not add it back + # to the pool. + # The created connections count should not be changed, + # because the connection was not created by the pool. connection.disconnect() return diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 387a0f4565..65f42923fe 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -91,6 +91,21 @@ def test_reuse_previously_released_connection(self, master_host): c2 = pool.get_connection() assert c1 == c2 + def test_release_not_owned_connection(self, master_host): + connection_kwargs = {"host": master_host[0], "port": master_host[1]} + pool1 = self.get_pool(connection_kwargs=connection_kwargs) + c1 = pool1.get_connection("_") + pool2 = self.get_pool( + connection_kwargs={"host": master_host[0], "port": master_host[1]} + ) + c2 = pool2.get_connection("_") + pool2.release(c2) + + assert len(pool2._available_connections) == 1 + + pool2.release(c1) + assert len(pool2._available_connections) == 1 + def test_repr_contains_db_info_tcp(self): connection_kwargs = { "host": "localhost", diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 0e8e8958c5..8b9e9fb90b 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -84,6 +84,40 @@ def target(conn, ev): proc.join(3) assert proc.exitcode == 0 + @pytest.mark.parametrize("max_connections", [2, None]) + def test_release_parent_connection_from_pool_in_child_process( + self, max_connections, master_host + ): + """ + A connection owned by a parent should not decrease the _created_connections + counter in child when released - when the child process starts to use the + pool it resets all the counters that have been set in the parent process. 
+ """ + + pool = ConnectionPool.from_url( + f"redis://{master_host[0]}:{master_host[1]}", + max_connections=max_connections, + ) + + parent_conn = pool.get_connection("ping") + + def target(pool, parent_conn): + with exit_callback(pool.disconnect): + child_conn = pool.get_connection("ping") + assert child_conn.pid != parent_conn.pid + pool.release(child_conn) + assert pool._created_connections == 1 + assert child_conn in pool._available_connections + pool.release(parent_conn) + assert pool._created_connections == 1 + assert child_conn in pool._available_connections + assert parent_conn not in pool._available_connections + + proc = multiprocessing.Process(target=target, args=(pool, parent_conn)) + proc.start() + proc.join(3) + assert proc.exitcode == 0 + @pytest.mark.parametrize("max_connections", [1, 2, None]) def test_pool(self, max_connections, master_host): """ From 5cbc526026b98baea57b1bcc9eb6963095c7e318 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 25 Feb 2025 12:58:12 +0200 Subject: [PATCH 037/113] Switch to PEP 517 packaging using hatchling (#2930) * Switch to PEP 517 packaging using hatchling Refs #1316 Refs #1649 Remake of #2388 Co-authored-by: Ofek Lev * Use a single source of truth for version info * Uninstall redis wheel installed as redis-entraid dep * Add build as dev_requirement * Get rid of requirements.txt * Get rid of setuptools and wheel deps * Move pytest configuration to pyproject.toml * Retain tests and dev_requirements.txt in sdist --------- Co-authored-by: Ofek Lev --- .github/actions/run-tests/action.yml | 4 +- .github/workflows/docs.yaml | 2 +- .github/workflows/install_and_test.sh | 1 + .github/workflows/integration.yaml | 3 +- .github/workflows/pypi-publish.yaml | 9 +-- CONTRIBUTING.md | 2 +- INSTALL | 6 -- MANIFEST.in | 8 --- dev_requirements.txt | 2 +- docs/examples/opentelemetry/README.md | 2 +- pyproject.toml | 96 +++++++++++++++++++++++++++ pytest.ini | 20 ------ redis/__init__.py | 13 +--- requirements.txt | 2 - setup.py | 65 ------------------ tasks.py | 2 +- 16 files changed, 110 insertions(+), 127 deletions(-) delete mode 100644 INSTALL delete mode 100644 MANIFEST.in create mode 100644 pyproject.toml delete mode 100644 pytest.ini delete mode 100644 requirements.txt delete mode 100644 setup.py diff --git a/.github/actions/run-tests/action.yml b/.github/actions/run-tests/action.yml index ca775f5a5b..d822e1d499 100644 --- a/.github/actions/run-tests/action.yml +++ b/.github/actions/run-tests/action.yml @@ -36,9 +36,9 @@ runs: set -e echo "::group::Installing dependencies" - pip install -U setuptools wheel - pip install -r requirements.txt pip install -r dev_requirements.txt + pip uninstall -y redis # uninstall Redis package installed via redis-entraid + pip install -e . 
# install the working copy if [ "${{inputs.parser-backend}}" == "hiredis" ]; then pip install "hiredis${{inputs.hiredis-version}}" echo "PARSER_BACKEND=$(echo "${{inputs.parser-backend}}_${{inputs.hiredis-version}}" | sed 's/[^a-zA-Z0-9]/_/g')" >> $GITHUB_ENV diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index a3512b46dc..747b3b6d76 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -36,7 +36,7 @@ jobs: sudo apt-get install -yqq pandoc make - name: run code linters run: | - pip install -r requirements.txt -r dev_requirements.txt -r docs/requirements.txt + pip install -r dev_requirements.txt -r docs/requirements.txt invoke build-docs - name: upload docs diff --git a/.github/workflows/install_and_test.sh b/.github/workflows/install_and_test.sh index 5c879c1b3a..778dbe0b20 100755 --- a/.github/workflows/install_and_test.sh +++ b/.github/workflows/install_and_test.sh @@ -21,6 +21,7 @@ python -m venv ${DESTENV} source ${DESTENV}/bin/activate pip install --upgrade --quiet pip pip install --quiet -r dev_requirements.txt +pip uninstall -y redis # uninstall Redis package installed via redis-entraid invoke devenv --endpoints=all-stack invoke package diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 45e0d5bf8e..bb56e8a024 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -38,7 +38,7 @@ jobs: - uses: actions/checkout@v4 - uses: pypa/gh-action-pip-audit@v1.0.8 with: - inputs: requirements.txt dev_requirements.txt + inputs: dev_requirements.txt ignore-vulns: | GHSA-w596-4wvx-j9j6 # subversion related git pull, dependency for pytest. There is no impact here. @@ -54,6 +54,7 @@ jobs: - name: run code linters run: | pip install -r dev_requirements.txt + pip uninstall -y redis # uninstall Redis package installed via redis-entraid invoke linters redis_version: diff --git a/.github/workflows/pypi-publish.yaml b/.github/workflows/pypi-publish.yaml index e4815aa1b5..048d06c53c 100644 --- a/.github/workflows/pypi-publish.yaml +++ b/.github/workflows/pypi-publish.yaml @@ -18,15 +18,10 @@ jobs: uses: actions/setup-python@v5 with: python-version: 3.9 - - name: Install dev tools - run: | - pip install -r dev_requirements.txt - pip install twine wheel + - run: pip install build twine - name: Build package - run: | - python setup.py build - python setup.py sdist bdist_wheel + run: python -m build . - name: Basic package test prior to upload run: | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d87e6ba1c3..79983e4cb5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -38,7 +38,7 @@ Here's how to get started with your code contribution: a. python -m venv .venv b. source .venv/bin/activate c. pip install -r dev_requirements.txt - c. pip install -r requirements.txt + c. pip install -e . 4. If you need a development environment, run `invoke devenv`. Note: this relies on docker-compose to build environments, and assumes that you have a version supporting [docker profiles](https://docs.docker.com/compose/profiles/). 5. 
While developing, make sure the tests pass by running `invoke tests` diff --git a/INSTALL b/INSTALL deleted file mode 100644 index 951f7dea8a..0000000000 --- a/INSTALL +++ /dev/null @@ -1,6 +0,0 @@ - -Please use - python setup.py install - -and report errors to Andy McCurdy (sedrik@gmail.com) - diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 80afb437ff..0000000000 --- a/MANIFEST.in +++ /dev/null @@ -1,8 +0,0 @@ -include INSTALL -include LICENSE -include README.md -include dev_requirements.txt -include pytest.ini -exclude __pycache__ -recursive-include tests * -recursive-exclude tests *.pyc diff --git a/dev_requirements.txt b/dev_requirements.txt index 619fbf479c..7ee7ac2b75 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,3 +1,4 @@ +build black==24.3.0 click==8.0.4 flake8-isort @@ -14,6 +15,5 @@ pytest-timeout ujson>=4.2.0 uvloop vulture>=2.3.0 -wheel>=0.30.0 numpy>=1.24.0 redis-entraid==0.3.0b1 diff --git a/docs/examples/opentelemetry/README.md b/docs/examples/opentelemetry/README.md index a1d1c04eda..58085c9637 100644 --- a/docs/examples/opentelemetry/README.md +++ b/docs/examples/opentelemetry/README.md @@ -24,7 +24,7 @@ source .venv/bin/active **Step 3**. Install dependencies: ```shell -pip install -r requirements.txt +pip install -e . ``` **Step 4**. Start the services using Docker and make sure Uptrace is running: diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..7becde948e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,96 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "redis" +dynamic = ["version"] +description = "Python client for Redis database and key-value store" +readme = "README.md" +license = "MIT" +requires-python = ">=3.8" +authors = [ + { name = "Redis Inc.", email = "oss@redis.com" }, +] +keywords = [ + "Redis", + "database", + "key-value-store", +] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + 'async-timeout>=4.0.3; python_full_version<"3.11.3"', + "PyJWT~=2.9.0", +] + +[project.optional-dependencies] +hiredis = [ + "hiredis>=3.0.0", +] +ocsp = [ + "cryptography>=36.0.1", + "pyopenssl==20.0.1", + "requests>=2.31.0", +] + +[project.urls] +Changes = "https://github.com/redis/redis-py/releases" +Code = "https://github.com/redis/redis-py" +Documentation = "https://redis.readthedocs.io/en/latest/" +Homepage = "https://github.com/redis/redis-py" +"Issue tracker" = "https://github.com/redis/redis-py/issues" + +[tool.hatch.version] +path = "redis/__init__.py" + +[tool.hatch.build.targets.sdist] +include = [ + "/redis", + "/tests", + "dev_requirements.txt", +] + +[tool.hatch.build.targets.wheel] +include = [ + "/redis", +] + +[tool.pytest.ini_options] +addopts = "-s" +markers = [ + "redismod: run only the redis module tests", + "graph: run only the redisgraph 
tests", + "pipeline: pipeline tests", + "onlycluster: marks tests to be run only with cluster mode redis", + "onlynoncluster: marks tests to be run only with standalone redis", + "ssl: marker for only the ssl tests", + "asyncio: marker for async tests", + "replica: replica tests", + "experimental: run only experimental tests", + "cp_integration: credential provider integration tests", +] +asyncio_mode = "auto" +timeout = 30 +filterwarnings = [ + "always", + "ignore:RedisGraph support is deprecated as of Redis Stack 7.2:DeprecationWarning", + # Ignore a coverage warning when COVERAGE_CORE=sysmon for Pythons < 3.12. + "ignore:sys.monitoring isn't available:coverage.exceptions.CoverageWarning", +] diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index 68fee2b603..0000000000 --- a/pytest.ini +++ /dev/null @@ -1,20 +0,0 @@ -[pytest] -addopts = -s -markers = - redismod: run only the redis module tests - graph: run only the redisgraph tests - pipeline: pipeline tests - onlycluster: marks tests to be run only with cluster mode redis - onlynoncluster: marks tests to be run only with standalone redis - ssl: marker for only the ssl tests - asyncio: marker for async tests - replica: replica tests - experimental: run only experimental tests - cp_integration: credential provider integration tests -asyncio_mode = auto -timeout = 30 -filterwarnings = - always - ignore:RedisGraph support is deprecated as of Redis Stack 7.2:DeprecationWarning - # Ignore a coverage warning when COVERAGE_CORE=sysmon for Pythons < 3.12. - ignore:sys.monitoring isn't available:coverage.exceptions.CoverageWarning diff --git a/redis/__init__.py b/redis/__init__.py index 94324a0de8..f82a876b2d 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -1,5 +1,3 @@ -from importlib import metadata - from redis import asyncio # noqa from redis.backoff import default_backoff from redis.client import Redis, StrictRedis @@ -44,16 +42,9 @@ def int_or_str(value): return value -try: - __version__ = metadata.version("redis") -except metadata.PackageNotFoundError: - __version__ = "99.99.99" - +__version__ = "5.2.1" +VERSION = tuple(map(int_or_str, __version__.split("."))) -try: - VERSION = tuple(map(int_or_str, __version__.split("."))) -except AttributeError: - VERSION = tuple([99, 99, 99]) __all__ = [ "AuthenticationError", diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 9760e5bb13..0000000000 --- a/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -async-timeout>=4.0.3 -PyJWT~=2.9.0 \ No newline at end of file diff --git a/setup.py b/setup.py deleted file mode 100644 index 2cde3fb51b..0000000000 --- a/setup.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python -from setuptools import find_packages, setup - -setup( - name="redis", - description="Python client for Redis database and key-value store", - long_description=open("README.md").read().strip(), - long_description_content_type="text/markdown", - keywords=["Redis", "key-value store", "database"], - license="MIT", - version="5.2.1", - packages=find_packages( - include=[ - "redis", - "redis._parsers", - "redis.asyncio", - "redis.auth", - "redis.commands", - "redis.commands.bf", - "redis.commands.json", - "redis.commands.search", - "redis.commands.timeseries", - "redis.commands.graph", - "redis.parsers", - ] - ), - package_data={"redis": ["py.typed"]}, - include_package_data=True, - url="https://github.com/redis/redis-py", - project_urls={ - "Documentation": "https://redis.readthedocs.io/en/latest/", - "Changes": 
"https://github.com/redis/redis-py/releases", - "Code": "https://github.com/redis/redis-py", - "Issue tracker": "https://github.com/redis/redis-py/issues", - }, - author="Redis Inc.", - author_email="oss@redis.com", - python_requires=">=3.8", - install_requires=[ - 'async-timeout>=4.0.3; python_full_version<"3.11.3"', - "PyJWT~=2.9.0", - ], - classifiers=[ - "Development Status :: 5 - Production/Stable", - "Environment :: Console", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", - ], - extras_require={ - "hiredis": ["hiredis>=3.0.0"], - "ocsp": ["cryptography>=36.0.1", "pyopenssl==23.2.1", "requests>=2.31.0"], - }, -) diff --git a/tasks.py b/tasks.py index f7b728aed4..8a5cae97b2 100644 --- a/tasks.py +++ b/tasks.py @@ -97,4 +97,4 @@ def clean(c): @task def package(c): """Create the python packages""" - run("python setup.py sdist bdist_wheel") + run("python -m build .") From 2324ab24f9700408c42dc65d96bab1f86e680f12 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Tue, 25 Feb 2025 14:55:05 +0200 Subject: [PATCH 038/113] Removing the requirements.txt from docs building dependencies (#3527) --- .readthedocs.yml | 1 - doctests/README.md | 4 ++-- doctests/requirements.txt | 1 + 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 800cb14816..80b9738d82 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -3,7 +3,6 @@ version: 2 python: install: - requirements: ./docs/requirements.txt - - requirements: requirements.txt build: os: ubuntu-20.04 diff --git a/doctests/README.md b/doctests/README.md index 9dd6eaeb5d..15664f1bcd 100644 --- a/doctests/README.md +++ b/doctests/README.md @@ -13,11 +13,11 @@ See https://github.com/redis-stack/redis-stack-website#readme for more details. ## How to test examples Examples are standalone python scripts, committed to the *doctests* directory. These scripts assume that the -```requirements.txt``` and ```dev_requirements.txt``` from this repository have been installed, as per below. +```doctests/requirements.txt``` and ```dev_requirements.txt``` from this repository have been installed, as per below. 
```bash -pip install -r requirements.txt pip install -r dev_requirements.txt +pip uninstall -y redis # uninstall Redis package installed via redis-entraid pip install -r doctests/requirements.txt ``` diff --git a/doctests/requirements.txt b/doctests/requirements.txt index 209d87b9c8..1f239546c1 100644 --- a/doctests/requirements.txt +++ b/doctests/requirements.txt @@ -3,3 +3,4 @@ pandas requests sentence_transformers tabulate +redis #install latest stable version From 005367b7d362c3239eb4585205cfc62f72247a3d Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Tue, 25 Feb 2025 17:06:09 +0200 Subject: [PATCH 039/113] Installing redis and its mandatory dependencies when building readthedocs (#3528) --- .readthedocs.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 80b9738d82..4b22490a49 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -2,7 +2,10 @@ version: 2 python: install: - - requirements: ./docs/requirements.txt + - method: pip + path: . + - method: pip + requirements: ./docs/requirements.txt build: os: ubuntu-20.04 From c58b590241d5fec828f291d2424f1a9dcee1fe54 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Tue, 25 Feb 2025 19:01:03 +0200 Subject: [PATCH 040/113] Adding vector search tests for types int8/uint8 (#3525) --- tests/test_search.py | 58 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/tests/test_search.py b/tests/test_search.py index 5b45cfc0a3..c4598f3773 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -2856,6 +2856,64 @@ def test_vector_search_with_default_dialect(client): assert res["total_results"] == 2 +@pytest.mark.redismod +@skip_if_server_version_lt("7.9.0") +def test_vector_search_with_int8_type(client): + client.ft().create_index( + (VectorField("v", "FLAT", {"TYPE": "INT8", "DIM": 2, "DISTANCE_METRIC": "L2"}),) + ) + + a = [1.5, 10] + b = [123, 100] + c = [1, 1] + + client.hset("a", "v", np.array(a, dtype=np.int8).tobytes()) + client.hset("b", "v", np.array(b, dtype=np.int8).tobytes()) + client.hset("c", "v", np.array(c, dtype=np.int8).tobytes()) + + query = Query("*=>[KNN 2 @v $vec as score]") + query_params = {"vec": np.array(a, dtype=np.int8).tobytes()} + + assert 2 in query.get_args() + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 2 + else: + assert res["total_results"] == 2 + + +@pytest.mark.redismod +@skip_if_server_version_lt("7.9.0") +def test_vector_search_with_uint8_type(client): + client.ft().create_index( + ( + VectorField( + "v", "FLAT", {"TYPE": "UINT8", "DIM": 2, "DISTANCE_METRIC": "L2"} + ), + ) + ) + + a = [1.5, 10] + b = [123, 100] + c = [1, 1] + + client.hset("a", "v", np.array(a, dtype=np.uint8).tobytes()) + client.hset("b", "v", np.array(b, dtype=np.uint8).tobytes()) + client.hset("c", "v", np.array(c, dtype=np.uint8).tobytes()) + + query = Query("*=>[KNN 2 @v $vec as score]") + query_params = {"vec": np.array(a, dtype=np.uint8).tobytes()} + + assert 2 in query.get_args() + + res = client.ft().search(query, query_params=query_params) + if is_resp2_connection(client): + assert res.total == 2 + else: + assert res["total_results"] == 2 + + @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") def test_search_query_with_different_dialects(client): From 126f28af3b3b0d6dee554a190799633b72008d5c Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Tue, 25 Feb 2025 19:54:21 +0200 Subject: [PATCH 041/113] Fixing wrong type hints (#3526) --- 
redis/commands/core.py | 7 ++++--- redis/commands/search/commands.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/redis/commands/core.py b/redis/commands/core.py index 1d19e33f2c..6ab0602698 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -35,6 +35,7 @@ GroupT, KeysT, KeyT, + Number, PatternT, ResponseT, ScriptTextT, @@ -2567,7 +2568,7 @@ class ListCommands(CommandsProtocol): """ def blpop( - self, keys: List, timeout: Optional[int] = 0 + self, keys: List, timeout: Optional[Number] = 0 ) -> Union[Awaitable[list], list]: """ LPOP a value off of the first non-empty list @@ -2588,7 +2589,7 @@ def blpop( return self.execute_command("BLPOP", *keys) def brpop( - self, keys: List, timeout: Optional[int] = 0 + self, keys: List, timeout: Optional[Number] = 0 ) -> Union[Awaitable[list], list]: """ RPOP a value off of the first non-empty list @@ -2609,7 +2610,7 @@ def brpop( return self.execute_command("BRPOP", *keys) def brpoplpush( - self, src: str, dst: str, timeout: Optional[int] = 0 + self, src: str, dst: str, timeout: Optional[Number] = 0 ) -> Union[Awaitable[Optional[str]], Optional[str]]: """ Pop a value off the tail of ``src``, push it on the head of ``dst`` diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index e5e78578be..1db57c23a5 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -586,7 +586,7 @@ def _get_aggregate_result( def profile( self, - query: Union[str, Query, AggregateRequest], + query: Union[Query, AggregateRequest], limited: bool = False, query_params: Optional[Dict[str, Union[str, int, float]]] = None, ): @@ -596,7 +596,7 @@ def profile( ### Parameters - **query**: This can be either an `AggregateRequest`, `Query` or string. + **query**: This can be either an `AggregateRequest` or `Query`. **limited**: If set to True, removes details of reader iterator. **query_params**: Define one or more value parameters. Each parameter has a name and a value. From c7c896e95b76eb8ad416b9797801515e1438fbaa Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Wed, 26 Feb 2025 10:25:23 +0200 Subject: [PATCH 042/113] Fix readthedocs.yml format for python install configuration (#3529) --- .readthedocs.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 4b22490a49..17bfa9d47c 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -2,10 +2,9 @@ version: 2 python: install: + - requirements: docs/requirements.txt - method: pip path: . - - method: pip - requirements: ./docs/requirements.txt build: os: ubuntu-20.04 From dec26ff850cfb557af90ccd4c14da59fd1749f25 Mon Sep 17 00:00:00 2001 From: bssyousefi <44493177+bssyousefi@users.noreply.github.com> Date: Wed, 26 Feb 2025 08:14:59 -0500 Subject: [PATCH 043/113] Add valid Exception type to Except in ClusterPipeline (#3516) --- redis/cluster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/cluster.py b/redis/cluster.py index c184838a9b..f518a1f184 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2122,7 +2122,7 @@ def send_cluster_commands( raise_on_error=raise_on_error, allow_redirections=allow_redirections, ) - except (ClusterDownError, ConnectionError) as e: + except RedisCluster.ERRORS_ALLOW_RETRY as e: if retry_attempts > 0: # Try again with the new cluster setup. All other errors # should be raised. 
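The `Optional[Number]` change above to `blpop`, `brpop`, and `brpoplpush` is annotation-only: the commands already forwarded whatever timeout value they were given, so the wider hint simply lets type checkers accept the fractional timeouts the server supports. A minimal usage sketch, assuming a local Redis 6.0+ server that accepts fractional blocking timeouts (host, port, and key name are illustrative, not taken from the patches):

```python
import redis

# Illustrative connection parameters; adjust for your environment.
r = redis.Redis(host="localhost", port=6379, decode_responses=True)

# With the widened hint, a float timeout no longer upsets type checkers.
# BLPOP blocks for up to 0.5 seconds and returns None if nothing arrives.
item = r.blpop(["jobs"], timeout=0.5)
if item is not None:
    queue, value = item
    print(f"popped {value!r} from {queue!r}")
```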
From 8339b166b908aca8fce4bba8e0cdca0bc0557932 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 27 Feb 2025 10:26:47 +0200 Subject: [PATCH 044/113] Make PyJWT an optional dependency (#3518) --- .github/actions/run-tests/action.yml | 2 +- CONTRIBUTING.md | 14 +++++++------- pyproject.toml | 4 +++- redis/auth/token.py | 7 ++++++- tests/conftest.py | 2 +- tests/test_asyncio/conftest.py | 2 +- tests/test_auth/test_token.py | 3 ++- 7 files changed, 21 insertions(+), 13 deletions(-) diff --git a/.github/actions/run-tests/action.yml b/.github/actions/run-tests/action.yml index d822e1d499..aa958a9236 100644 --- a/.github/actions/run-tests/action.yml +++ b/.github/actions/run-tests/action.yml @@ -38,7 +38,7 @@ runs: echo "::group::Installing dependencies" pip install -r dev_requirements.txt pip uninstall -y redis # uninstall Redis package installed via redis-entraid - pip install -e . # install the working copy + pip install -e .[jwt] # install the working copy if [ "${{inputs.parser-backend}}" == "hiredis" ]; then pip install "hiredis${{inputs.hiredis-version}}" echo "PARSER_BACKEND=$(echo "${{inputs.parser-backend}}_${{inputs.hiredis-version}}" | sed 's/[^a-zA-Z0-9]/_/g')" >> $GITHUB_ENV diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 79983e4cb5..eb333f644f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -32,13 +32,13 @@ Here's how to get started with your code contribution: 1. Create your own fork of redis-py 2. Do the changes in your fork -3. - *Create a virtualenv and install the development dependencies from the dev_requirements.txt file:* - - a. python -m venv .venv - b. source .venv/bin/activate - c. pip install -r dev_requirements.txt - c. pip install -e . +3. Create a virtualenv and install the development dependencies from the dev_requirements.txt file: + ``` + python -m venv .venv + source .venv/bin/activate + pip install -r dev_requirements.txt + pip install -e .[jwt] + ``` 4. If you need a development environment, run `invoke devenv`. Note: this relies on docker-compose to build environments, and assumes that you have a version supporting [docker profiles](https://docs.docker.com/compose/profiles/). 5. 
While developing, make sure the tests pass by running `invoke tests` diff --git a/pyproject.toml b/pyproject.toml index 7becde948e..0ec38ed3b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,6 @@ classifiers = [ ] dependencies = [ 'async-timeout>=4.0.3; python_full_version<"3.11.3"', - "PyJWT~=2.9.0", ] [project.optional-dependencies] @@ -49,6 +48,9 @@ ocsp = [ "pyopenssl==20.0.1", "requests>=2.31.0", ] +jwt = [ + "PyJWT~=2.9.0", +] [project.urls] Changes = "https://github.com/redis/redis-py/releases" diff --git a/redis/auth/token.py b/redis/auth/token.py index 876e95c4fa..1c5246469b 100644 --- a/redis/auth/token.py +++ b/redis/auth/token.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod from datetime import datetime, timezone -import jwt from redis.auth.err import InvalidTokenSchemaErr @@ -81,6 +80,12 @@ class JWToken(TokenInterface): REQUIRED_FIELDS = {"exp"} def __init__(self, token: str): + try: + import jwt + except ImportError as ie: + raise ImportError( + f"The PyJWT library is required for {self.__class__.__name__}.", + ) from ie self._value = token self._decoded = jwt.decode( self._value, diff --git a/tests/conftest.py b/tests/conftest.py index fc732c0d72..8795c9f022 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,6 @@ from unittest.mock import Mock from urllib.parse import urlparse -import jwt import pytest import redis from packaging.version import Version @@ -615,6 +614,7 @@ def cache_key(request) -> CacheKey: def mock_identity_provider() -> IdentityProviderInterface: + jwt = pytest.importorskip("jwt") mock_provider = Mock(spec=IdentityProviderInterface) token = {"exp": datetime.now(timezone.utc).timestamp() + 3600, "oid": "username"} encoded = jwt.encode(token, "secret", algorithm="HS256") diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index fb6c51140e..99ad155d0a 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -5,7 +5,6 @@ from enum import Enum from typing import Union -import jwt import pytest import pytest_asyncio import redis.asyncio as redis @@ -247,6 +246,7 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs): def mock_identity_provider() -> IdentityProviderInterface: + jwt = pytest.importorskip("jwt") mock_provider = Mock(spec=IdentityProviderInterface) token = {"exp": datetime.now(timezone.utc).timestamp() + 3600, "oid": "username"} encoded = jwt.encode(token, "secret", algorithm="HS256") diff --git a/tests/test_auth/test_token.py b/tests/test_auth/test_token.py index 978cc2ca8c..2d72e08895 100644 --- a/tests/test_auth/test_token.py +++ b/tests/test_auth/test_token.py @@ -1,6 +1,5 @@ from datetime import datetime, timezone -import jwt import pytest from redis.auth.err import InvalidTokenSchemaErr from redis.auth.token import JWToken, SimpleToken @@ -39,6 +38,8 @@ def test_simple_token(self): assert token.get_expires_at_ms() == -1 def test_jwt_token(self): + jwt = pytest.importorskip("jwt") + token = { "exp": datetime.now(timezone.utc).timestamp() + 100, "iat": datetime.now(timezone.utc).timestamp(), From 797c59f227e1d00e0db910bbf92e6ae5bae13be9 Mon Sep 17 00:00:00 2001 From: alposomn <45829982+666romeo@users.noreply.github.com> Date: Thu, 27 Feb 2025 14:58:20 +0200 Subject: [PATCH 045/113] Add force_master_ip support to async Sentinel client (#3524) --- redis/asyncio/sentinel.py | 10 +++++++++- tests/test_asyncio/conftest.py | 2 ++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 
f1d2cab3f1..0389539fcf 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -198,6 +198,7 @@ def __init__( sentinels, min_other_sentinels=0, sentinel_kwargs=None, + force_master_ip=None, **connection_kwargs, ): # if sentinel_kwargs isn't defined, use the socket_* options from @@ -214,6 +215,7 @@ def __init__( ] self.min_other_sentinels = min_other_sentinels self.connection_kwargs = connection_kwargs + self._force_master_ip = force_master_ip async def execute_command(self, *args, **kwargs): """ @@ -277,7 +279,13 @@ async def discover_master(self, service_name: str): sentinel, self.sentinels[0], ) - return state["ip"], state["port"] + + ip = ( + self._force_master_ip + if self._force_master_ip is not None + else state["ip"] + ) + return ip, state["port"] error_info = "" if len(collected_errors) > 0: diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 99ad155d0a..d9cccf1b92 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -151,8 +151,10 @@ async def sentinel_setup(local_cache, request): for ip, port in (endpoint.split(":") for endpoint in sentinel_ips.split(",")) ] kwargs = request.param.get("kwargs", {}) if hasattr(request, "param") else {} + force_master_ip = request.param.get("force_master_ip", None) sentinel = Sentinel( sentinel_endpoints, + force_master_ip=force_master_ip, socket_timeout=0.1, client_cache=local_cache, protocol=3, From 77193ceef72584eac0fb0c19896264ee7ceff947 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Fri, 28 Feb 2025 17:16:14 +0200 Subject: [PATCH 046/113] Fixing typing for FCALL commands to match PEP 484 (#3537) * Fixing typing for FCALL commands to match PEP 484 * Codestyle fixes * Fixes issue #3536 --- redis/commands/core.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/redis/commands/core.py b/redis/commands/core.py index 6ab0602698..880d0fda41 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -5,6 +5,7 @@ import warnings from typing import ( TYPE_CHECKING, + Any, AsyncIterator, Awaitable, Callable, @@ -6397,12 +6398,12 @@ def function_list( return self.execute_command("FUNCTION LIST", *args) def _fcall( - self, command: str, function, numkeys: int, *keys_and_args: Optional[List] + self, command: str, function, numkeys: int, *keys_and_args: Any ) -> Union[Awaitable[str], str]: return self.execute_command(command, function, numkeys, *keys_and_args) def fcall( - self, function, numkeys: int, *keys_and_args: Optional[List] + self, function, numkeys: int, *keys_and_args: Any ) -> Union[Awaitable[str], str]: """ Invoke a function. @@ -6412,7 +6413,7 @@ def fcall( return self._fcall("FCALL", function, numkeys, *keys_and_args) def fcall_ro( - self, function, numkeys: int, *keys_and_args: Optional[List] + self, function, numkeys: int, *keys_and_args: Any ) -> Union[Awaitable[str], str]: """ This is a read-only variant of the FCALL command that cannot From 6c81598eaa9e368edeab761fb86e1bc9c53d7a6f Mon Sep 17 00:00:00 2001 From: Mathew Shen Date: Tue, 4 Mar 2025 17:34:19 +0800 Subject: [PATCH 047/113] fix(lock): Fix LockError message when releasing a lock. 
(#3534) * fix(lock): raise LockNotOwnedError when release a lock from non-owned thread * change: error raise * fix: linter * fix(lock): async release --------- Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> --- redis/asyncio/lock.py | 5 ++++- redis/lock.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/redis/asyncio/lock.py b/redis/asyncio/lock.py index e1d11a882d..bb2cccab52 100644 --- a/redis/asyncio/lock.py +++ b/redis/asyncio/lock.py @@ -249,7 +249,10 @@ def release(self) -> Awaitable[None]: """Releases the already acquired lock""" expected_token = self.local.token if expected_token is None: - raise LockError("Cannot release an unlocked lock") + raise LockError( + "Cannot release a lock that's not owned or is already unlocked.", + lock_name=self.name, + ) self.local.token = None return self.do_release(expected_token) diff --git a/redis/lock.py b/redis/lock.py index f44ed629da..7a1becb30a 100644 --- a/redis/lock.py +++ b/redis/lock.py @@ -251,7 +251,10 @@ def release(self) -> None: """ expected_token = self.local.token if expected_token is None: - raise LockError("Cannot release an unlocked lock", lock_name=self.name) + raise LockError( + "Cannot release a lock that's not owned or is already unlocked.", + lock_name=self.name, + ) self.local.token = None self.do_release(expected_token) From 04eafb82267f9686087877b75466f5aaf839085c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 4 Mar 2025 13:05:56 +0200 Subject: [PATCH 048/113] Bump rojopolis/spellcheck-github-actions from 0.38.0 to 0.47.0 (#3538) Bumps [rojopolis/spellcheck-github-actions](https://github.com/rojopolis/spellcheck-github-actions) from 0.38.0 to 0.47.0. - [Release notes](https://github.com/rojopolis/spellcheck-github-actions/releases) - [Changelog](https://github.com/rojopolis/spellcheck-github-actions/blob/master/CHANGELOG.md) - [Commits](https://github.com/rojopolis/spellcheck-github-actions/compare/0.38.0...0.47.0) --- updated-dependencies: - dependency-name: rojopolis/spellcheck-github-actions dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: petyaslavova --- .github/workflows/spellcheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/spellcheck.yml b/.github/workflows/spellcheck.yml index 62e38997e4..beefa6164f 100644 --- a/.github/workflows/spellcheck.yml +++ b/.github/workflows/spellcheck.yml @@ -8,7 +8,7 @@ jobs: - name: Checkout uses: actions/checkout@v4 - name: Check Spelling - uses: rojopolis/spellcheck-github-actions@0.38.0 + uses: rojopolis/spellcheck-github-actions@0.47.0 with: config_path: .github/spellcheck-settings.yml task_name: Markdown From d540d56c1c71c296b973e85c6d771e25330d749a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=89=E6=A0=B7?= <83902645+three-kinds@users.noreply.github.com> Date: Wed, 5 Mar 2025 02:00:17 +0800 Subject: [PATCH 049/113] Fix client_list with multiple client ids (#3539) --- redis/commands/core.py | 2 +- tests/test_commands.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/redis/commands/core.py b/redis/commands/core.py index 880d0fda41..c3ffb955c4 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -530,7 +530,7 @@ def client_list( raise DataError("client_id must be a list") if client_id: args.append(b"ID") - args.append(" ".join(client_id)) + args += client_id return self.execute_command("CLIENT LIST", *args, **kwargs) def client_getname(self, **kwargs) -> ResponseT: diff --git a/tests/test_commands.py b/tests/test_commands.py index f89c5f3365..b6f13f6aa8 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -667,11 +667,15 @@ def test_client_list_client_id(self, r, request): assert "addr" in clients[0] # testing multiple client ids - _get_client(redis.Redis, request, flushdb=False) - _get_client(redis.Redis, request, flushdb=False) - _get_client(redis.Redis, request, flushdb=False) - clients_listed = r.client_list(client_id=clients[:-1]) - assert len(clients_listed) > 1 + client_list = list() + client_count = 3 + for i in range(client_count): + client = _get_client(redis.Redis, request, flushdb=False) + client_list.append(client) + + multiple_client_ids = [str(client.client_id()) for client in client_list] + clients_listed = r.client_list(client_id=multiple_client_ids) + assert len(clients_listed) == len(multiple_client_ids) @pytest.mark.onlynoncluster @skip_if_server_version_lt("5.0.0") From 5b5340fe176ff9da1d0e7e781aeaec90b3b31a24 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 4 Mar 2025 23:45:28 -0800 Subject: [PATCH 050/113] Fix connection health check for protocol != 2 when auth credentials are provided and health check interval is configured (#3477) --- redis/asyncio/connection.py | 6 +++++- redis/connection.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index e67dc5b207..2e2a2502c3 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -363,7 +363,11 @@ async def on_connect(self) -> None: self._parser.on_connect(self) if len(auth_args) == 1: auth_args = ["default", auth_args[0]] - await self.send_command("HELLO", self.protocol, "AUTH", *auth_args) + # avoid checking health here -- PING will fail if we try + # to check the health prior to the AUTH + await self.send_command( + "HELLO", self.protocol, "AUTH", *auth_args, check_health=False + ) response = await self.read_response() if response.get(b"proto") != int(self.protocol) 
and response.get( "proto" diff --git a/redis/connection.py b/redis/connection.py index ece17d752b..2391e74d2c 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -440,7 +440,11 @@ def on_connect(self): self._parser.on_connect(self) if len(auth_args) == 1: auth_args = ["default", auth_args[0]] - self.send_command("HELLO", self.protocol, "AUTH", *auth_args) + # avoid checking health here -- PING will fail if we try + # to check the health prior to the AUTH + self.send_command( + "HELLO", self.protocol, "AUTH", *auth_args, check_health=False + ) self.handshake_metadata = self.read_response() # if response.get(b"proto") != self.protocol and response.get( # "proto" From 5eb9939ddc69ecccd4a89d9fa7ee50b59270e9bb Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 5 Mar 2025 14:02:19 +0200 Subject: [PATCH 051/113] Isolate redis-entraid dependency for tests (#3521) --- tests/conftest.py | 141 +--------------------- tests/entraid_utils.py | 140 ++++++++++++++++++++++ tests/test_asyncio/conftest.py | 154 +------------------------ tests/test_asyncio/test_credentials.py | 9 +- tests/test_credentials.py | 9 +- 5 files changed, 163 insertions(+), 290 deletions(-) create mode 100644 tests/entraid_utils.py diff --git a/tests/conftest.py b/tests/conftest.py index 8795c9f022..e5eea4d582 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,7 @@ import random import time from datetime import datetime, timezone -from enum import Enum -from typing import Callable, TypeVar, Union +from typing import Callable, TypeVar from unittest import mock from unittest.mock import Mock from urllib.parse import urlparse @@ -16,7 +15,6 @@ from redis import Sentinel from redis.auth.idp import IdentityProviderInterface from redis.auth.token import JWToken -from redis.auth.token_manager import RetryPolicy, TokenManagerConfig from redis.backoff import NoBackoff from redis.cache import ( CacheConfig, @@ -29,22 +27,6 @@ from redis.credentials import CredentialProvider from redis.exceptions import RedisClusterException from redis.retry import Retry -from redis_entraid.cred_provider import ( - DEFAULT_DELAY_IN_MS, - DEFAULT_EXPIRATION_REFRESH_RATIO, - DEFAULT_LOWER_REFRESH_BOUND_MILLIS, - DEFAULT_MAX_ATTEMPTS, - DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, - EntraIdCredentialsProvider, -) -from redis_entraid.identity_provider import ( - ManagedIdentityIdType, - ManagedIdentityProviderConfig, - ManagedIdentityType, - ServicePrincipalIdentityProviderConfig, - _create_provider_from_managed_identity, - _create_provider_from_service_principal, -) from tests.ssl_utils import get_tls_certificates REDIS_INFO = {} @@ -60,11 +42,6 @@ _TestDecorator = Callable[[_DecoratedTest], _DecoratedTest] -class AuthType(Enum): - MANAGED_IDENTITY = "managed_identity" - SERVICE_PRINCIPAL = "service_principal" - - # Taken from python3.9 class BooleanOptionalAction(argparse.Action): def __init__( @@ -623,124 +600,18 @@ def mock_identity_provider() -> IdentityProviderInterface: return mock_provider -def identity_provider(request) -> IdentityProviderInterface: - if hasattr(request, "param"): - kwargs = request.param.get("idp_kwargs", {}) - else: - kwargs = {} - - if request.param.get("mock_idp", None) is not None: - return mock_identity_provider() - - auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) - config = get_identity_provider_config(request=request) - - if auth_type == "MANAGED_IDENTITY": - return _create_provider_from_managed_identity(config) - - return _create_provider_from_service_principal(config) - - -def 
get_identity_provider_config( - request, -) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: - if hasattr(request, "param"): - kwargs = request.param.get("idp_kwargs", {}) - else: - kwargs = {} - - auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) - - if auth_type == AuthType.MANAGED_IDENTITY: - return _get_managed_identity_provider_config(request) - - return _get_service_principal_provider_config(request) - - -def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: - resource = os.getenv("AZURE_RESOURCE") - id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) - - if hasattr(request, "param"): - kwargs = request.param.get("idp_kwargs", {}) - else: - kwargs = {} - - identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) - id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) - - return ManagedIdentityProviderConfig( - identity_type=identity_type, - resource=resource, - id_type=id_type, - id_value=id_value, - kwargs=kwargs, - ) - - -def _get_service_principal_provider_config( - request, -) -> ServicePrincipalIdentityProviderConfig: - client_id = os.getenv("AZURE_CLIENT_ID") - client_credential = os.getenv("AZURE_CLIENT_SECRET") - tenant_id = os.getenv("AZURE_TENANT_ID") - scopes = os.getenv("AZURE_REDIS_SCOPES", None) - - if hasattr(request, "param"): - kwargs = request.param.get("idp_kwargs", {}) - token_kwargs = request.param.get("token_kwargs", {}) - timeout = request.param.get("timeout", None) - else: - kwargs = {} - token_kwargs = {} - timeout = None - - if isinstance(scopes, str): - scopes = scopes.split(",") - - return ServicePrincipalIdentityProviderConfig( - client_id=client_id, - client_credential=client_credential, - scopes=scopes, - timeout=timeout, - token_kwargs=token_kwargs, - tenant_id=tenant_id, - app_kwargs=kwargs, - ) - - def get_credential_provider(request) -> CredentialProvider: cred_provider_class = request.param.get("cred_provider_class") cred_provider_kwargs = request.param.get("cred_provider_kwargs", {}) - if cred_provider_class != EntraIdCredentialsProvider: + # Since we can't import EntraIdCredentialsProvider in this module, + # we'll just check the class name. 
+ if cred_provider_class.__name__ != "EntraIdCredentialsProvider": return cred_provider_class(**cred_provider_kwargs) - idp = identity_provider(request) - expiration_refresh_ratio = cred_provider_kwargs.get( - "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO - ) - lower_refresh_bound_millis = cred_provider_kwargs.get( - "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS - ) - max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS) - delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS) - - token_mgr_config = TokenManagerConfig( - expiration_refresh_ratio=expiration_refresh_ratio, - lower_refresh_bound_millis=lower_refresh_bound_millis, - token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa - retry_policy=RetryPolicy( - max_attempts=max_attempts, - delay_in_ms=delay_in_ms, - ), - ) + from tests.entraid_utils import get_entra_id_credentials_provider - return EntraIdCredentialsProvider( - identity_provider=idp, - token_manager_config=token_mgr_config, - initial_delay_in_ms=delay_in_ms, - ) + return get_entra_id_credentials_provider(request, cred_provider_kwargs) @pytest.fixture() diff --git a/tests/entraid_utils.py b/tests/entraid_utils.py new file mode 100644 index 0000000000..daefbd3956 --- /dev/null +++ b/tests/entraid_utils.py @@ -0,0 +1,140 @@ +import os +from enum import Enum +from typing import Union + +from redis.auth.idp import IdentityProviderInterface +from redis.auth.token_manager import RetryPolicy, TokenManagerConfig +from redis_entraid.cred_provider import ( + DEFAULT_DELAY_IN_MS, + DEFAULT_EXPIRATION_REFRESH_RATIO, + DEFAULT_LOWER_REFRESH_BOUND_MILLIS, + DEFAULT_MAX_ATTEMPTS, + DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, + EntraIdCredentialsProvider, +) +from redis_entraid.identity_provider import ( + ManagedIdentityIdType, + ManagedIdentityProviderConfig, + ManagedIdentityType, + ServicePrincipalIdentityProviderConfig, + _create_provider_from_managed_identity, + _create_provider_from_service_principal, +) +from tests.conftest import mock_identity_provider + + +class AuthType(Enum): + MANAGED_IDENTITY = "managed_identity" + SERVICE_PRINCIPAL = "service_principal" + + +def identity_provider(request) -> IdentityProviderInterface: + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} + + if request.param.get("mock_idp", None) is not None: + return mock_identity_provider() + + auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + config = get_identity_provider_config(request=request) + + if auth_type == "MANAGED_IDENTITY": + return _create_provider_from_managed_identity(config) + + return _create_provider_from_service_principal(config) + + +def get_identity_provider_config( + request, +) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} + + auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + + if auth_type == AuthType.MANAGED_IDENTITY: + return _get_managed_identity_provider_config(request) + + return _get_service_principal_provider_config(request) + + +def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: + resource = os.getenv("AZURE_RESOURCE") + id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) + + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + else: + kwargs = {} + + identity_type = 
kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) + id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) + + return ManagedIdentityProviderConfig( + identity_type=identity_type, + resource=resource, + id_type=id_type, + id_value=id_value, + kwargs=kwargs, + ) + + +def _get_service_principal_provider_config( + request, +) -> ServicePrincipalIdentityProviderConfig: + client_id = os.getenv("AZURE_CLIENT_ID") + client_credential = os.getenv("AZURE_CLIENT_SECRET") + tenant_id = os.getenv("AZURE_TENANT_ID") + scopes = os.getenv("AZURE_REDIS_SCOPES", None) + + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + token_kwargs = request.param.get("token_kwargs", {}) + timeout = request.param.get("timeout", None) + else: + kwargs = {} + token_kwargs = {} + timeout = None + + if isinstance(scopes, str): + scopes = scopes.split(",") + + return ServicePrincipalIdentityProviderConfig( + client_id=client_id, + client_credential=client_credential, + scopes=scopes, + timeout=timeout, + token_kwargs=token_kwargs, + tenant_id=tenant_id, + app_kwargs=kwargs, + ) + + +def get_entra_id_credentials_provider(request, cred_provider_kwargs): + idp = identity_provider(request) + expiration_refresh_ratio = cred_provider_kwargs.get( + "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO + ) + lower_refresh_bound_millis = cred_provider_kwargs.get( + "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS + ) + max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS) + delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS) + token_mgr_config = TokenManagerConfig( + expiration_refresh_ratio=expiration_refresh_ratio, + lower_refresh_bound_millis=lower_refresh_bound_millis, + token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa + retry_policy=RetryPolicy( + max_attempts=max_attempts, + delay_in_ms=delay_in_ms, + ), + ) + return EntraIdCredentialsProvider( + identity_provider=idp, + token_manager_config=token_mgr_config, + initial_delay_in_ms=delay_in_ms, + ) diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index d9cccf1b92..226b00aa45 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -1,41 +1,19 @@ -import os import random from contextlib import asynccontextmanager as _asynccontextmanager -from datetime import datetime, timezone from enum import Enum from typing import Union import pytest import pytest_asyncio import redis.asyncio as redis -from mock.mock import Mock from packaging.version import Version from redis.asyncio import Sentinel from redis.asyncio.client import Monitor from redis.asyncio.connection import Connection, parse_url from redis.asyncio.retry import Retry -from redis.auth.idp import IdentityProviderInterface -from redis.auth.token import JWToken -from redis.auth.token_manager import RetryPolicy, TokenManagerConfig from redis.backoff import NoBackoff from redis.credentials import CredentialProvider -from redis_entraid.cred_provider import ( - DEFAULT_DELAY_IN_MS, - DEFAULT_EXPIRATION_REFRESH_RATIO, - DEFAULT_LOWER_REFRESH_BOUND_MILLIS, - DEFAULT_MAX_ATTEMPTS, - DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, - EntraIdCredentialsProvider, -) -from redis_entraid.identity_provider import ( - ManagedIdentityIdType, - ManagedIdentityProviderConfig, - ManagedIdentityType, - ServicePrincipalIdentityProviderConfig, - _create_provider_from_managed_identity, - _create_provider_from_service_principal, -) 
-from tests.conftest import REDIS_INFO +from tests.conftest import REDIS_INFO, get_credential_provider from .compat import mock @@ -247,136 +225,6 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs): yield mocked -def mock_identity_provider() -> IdentityProviderInterface: - jwt = pytest.importorskip("jwt") - mock_provider = Mock(spec=IdentityProviderInterface) - token = {"exp": datetime.now(timezone.utc).timestamp() + 3600, "oid": "username"} - encoded = jwt.encode(token, "secret", algorithm="HS256") - jwt_token = JWToken(encoded) - mock_provider.request_token.return_value = jwt_token - return mock_provider - - -def identity_provider(request) -> IdentityProviderInterface: - if hasattr(request, "param"): - kwargs = request.param.get("idp_kwargs", {}) - else: - kwargs = {} - - if request.param.get("mock_idp", None) is not None: - return mock_identity_provider() - - auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) - config = get_identity_provider_config(request=request) - - if auth_type == "MANAGED_IDENTITY": - return _create_provider_from_managed_identity(config) - - return _create_provider_from_service_principal(config) - - -def get_identity_provider_config( - request, -) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: - if hasattr(request, "param"): - kwargs = request.param.get("idp_kwargs", {}) - else: - kwargs = {} - - auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) - - if auth_type == AuthType.MANAGED_IDENTITY: - return _get_managed_identity_provider_config(request) - - return _get_service_principal_provider_config(request) - - -def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig: - resource = os.getenv("AZURE_RESOURCE") - id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None) - - if hasattr(request, "param"): - kwargs = request.param.get("idp_kwargs", {}) - else: - kwargs = {} - - identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED) - id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID) - - return ManagedIdentityProviderConfig( - identity_type=identity_type, - resource=resource, - id_type=id_type, - id_value=id_value, - kwargs=kwargs, - ) - - -def _get_service_principal_provider_config( - request, -) -> ServicePrincipalIdentityProviderConfig: - client_id = os.getenv("AZURE_CLIENT_ID") - client_credential = os.getenv("AZURE_CLIENT_SECRET") - tenant_id = os.getenv("AZURE_TENANT_ID") - scopes = os.getenv("AZURE_REDIS_SCOPES", None) - - if hasattr(request, "param"): - kwargs = request.param.get("idp_kwargs", {}) - token_kwargs = request.param.get("token_kwargs", {}) - timeout = request.param.get("timeout", None) - else: - kwargs = {} - token_kwargs = {} - timeout = None - - if isinstance(scopes, str): - scopes = scopes.split(",") - - return ServicePrincipalIdentityProviderConfig( - client_id=client_id, - client_credential=client_credential, - scopes=scopes, - timeout=timeout, - token_kwargs=token_kwargs, - tenant_id=tenant_id, - app_kwargs=kwargs, - ) - - -def get_credential_provider(request) -> CredentialProvider: - cred_provider_class = request.param.get("cred_provider_class") - cred_provider_kwargs = request.param.get("cred_provider_kwargs", {}) - - if cred_provider_class != EntraIdCredentialsProvider: - return cred_provider_class(**cred_provider_kwargs) - - idp = identity_provider(request) - expiration_refresh_ratio = cred_provider_kwargs.get( - "expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO - ) - 
lower_refresh_bound_millis = cred_provider_kwargs.get( - "lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS - ) - max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS) - delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS) - - token_mgr_config = TokenManagerConfig( - expiration_refresh_ratio=expiration_refresh_ratio, - lower_refresh_bound_millis=lower_refresh_bound_millis, - token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa - retry_policy=RetryPolicy( - max_attempts=max_attempts, - delay_in_ms=delay_in_ms, - ), - ) - - return EntraIdCredentialsProvider( - identity_provider=idp, - token_manager_config=token_mgr_config, - initial_delay_in_ms=delay_in_ms, - ) - - @pytest_asyncio.fixture() async def credential_provider(request) -> CredentialProvider: return get_credential_provider(request) diff --git a/tests/test_asyncio/test_credentials.py b/tests/test_asyncio/test_credentials.py index 1eb988ce71..ce8d76ea45 100644 --- a/tests/test_asyncio/test_credentials.py +++ b/tests/test_asyncio/test_credentials.py @@ -17,10 +17,14 @@ from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider from redis.exceptions import ConnectionError from redis.utils import str_if_bytes -from redis_entraid.cred_provider import EntraIdCredentialsProvider from tests.conftest import get_endpoint, skip_if_redis_enterprise from tests.test_asyncio.conftest import get_credential_provider +try: + from redis_entraid.cred_provider import EntraIdCredentialsProvider +except ImportError: + EntraIdCredentialsProvider = None + @pytest.fixture() def endpoint(request): @@ -321,6 +325,7 @@ async def test_user_pass_provider_only_password( @pytest.mark.asyncio @pytest.mark.onlynoncluster +@pytest.mark.skipif(not EntraIdCredentialsProvider, reason="requires redis-entraid") class TestStreamingCredentialProvider: @pytest.mark.parametrize( "credential_provider", @@ -599,6 +604,7 @@ async def test_fails_on_token_renewal(self, credential_provider): @pytest.mark.asyncio @pytest.mark.onlynoncluster @pytest.mark.cp_integration +@pytest.mark.skipif(not EntraIdCredentialsProvider, reason="requires redis-entraid") class TestEntraIdCredentialsProvider: @pytest.mark.parametrize( "r_credential", @@ -674,6 +680,7 @@ async def test_async_auth_pubsub_with_credential_provider( @pytest.mark.asyncio @pytest.mark.onlycluster @pytest.mark.cp_integration +@pytest.mark.skipif(not EntraIdCredentialsProvider, reason="requires redis-entraid") class TestClusterEntraIdCredentialsProvider: @pytest.mark.parametrize( "r_credential", diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 95ec5577cc..1f98c5208d 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -16,7 +16,6 @@ from redis.exceptions import ConnectionError, RedisError from redis.retry import Retry from redis.utils import str_if_bytes -from redis_entraid.cred_provider import EntraIdCredentialsProvider from tests.conftest import ( _get_client, get_credential_provider, @@ -24,6 +23,11 @@ skip_if_redis_enterprise, ) +try: + from redis_entraid.cred_provider import EntraIdCredentialsProvider +except ImportError: + EntraIdCredentialsProvider = None + @pytest.fixture() def endpoint(request): @@ -295,6 +299,7 @@ def test_user_pass_provider_only_password(self, r, request): @pytest.mark.onlynoncluster +@pytest.mark.skipif(not EntraIdCredentialsProvider, reason="requires redis-entraid") class TestStreamingCredentialProvider: @pytest.mark.parametrize( 
"credential_provider", @@ -567,6 +572,7 @@ def test_fails_on_token_renewal(self, credential_provider): @pytest.mark.onlynoncluster @pytest.mark.cp_integration +@pytest.mark.skipif(not EntraIdCredentialsProvider, reason="requires redis-entraid") class TestEntraIdCredentialsProvider: @pytest.mark.parametrize( "r_entra", @@ -637,6 +643,7 @@ def test_auth_pubsub_with_credential_provider(self, r_entra: redis.Redis): @pytest.mark.onlycluster @pytest.mark.cp_integration +@pytest.mark.skipif(not EntraIdCredentialsProvider, reason="requires redis-entraid") class TestClusterEntraIdCredentialsProvider: @pytest.mark.parametrize( "r_entra", From dea3d9a1cf01d183fb5857a6dc269e274303c8a7 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 5 Mar 2025 20:04:54 +0200 Subject: [PATCH 052/113] Replace flake8+isort+black with `ruff` (#3147) * Replace flake8 + isort + flynt with ruff * Replace black with `ruff format`; run it --- .flake8 | 28 --------------- .isort.cfg | 5 --- benchmarks/basic_operations.py | 2 +- benchmarks/command_packer_benchmark.py | 1 - benchmarks/socket_read_size.py | 1 - dev_requirements.txt | 5 +-- doctests/README.md | 4 +-- pyproject.toml | 50 ++++++++++++++++++++++++++ redis/_parsers/base.py | 4 +-- redis/asyncio/cluster.py | 6 ++-- redis/asyncio/connection.py | 2 +- redis/asyncio/utils.py | 2 +- redis/auth/token.py | 1 - redis/client.py | 3 +- redis/cluster.py | 5 ++- redis/commands/cluster.py | 2 +- redis/commands/core.py | 7 ++-- redis/commands/graph/commands.py | 4 +-- redis/commands/helpers.py | 4 +-- redis/connection.py | 6 ++-- redis/exceptions.py | 3 ++ redis/ocsp.py | 3 +- redis/sentinel.py | 2 +- tasks.py | 6 ++-- tests/conftest.py | 2 +- tests/test_asyncio/conftest.py | 2 +- tests/test_asyncio/test_cluster.py | 11 +++--- tests/test_asyncio/test_commands.py | 15 ++++---- tests/test_asyncio/test_graph.py | 3 +- tests/test_asyncio/test_search.py | 12 +++---- tests/test_auth/test_token.py | 1 - tests/test_cluster.py | 10 +++--- tests/test_commands.py | 15 ++++---- tests/test_connection.py | 1 - tests/test_graph.py | 3 +- tests/test_search.py | 12 +++---- 36 files changed, 123 insertions(+), 120 deletions(-) delete mode 100644 .flake8 delete mode 100644 .isort.cfg diff --git a/.flake8 b/.flake8 deleted file mode 100644 index b1bd1d0b75..0000000000 --- a/.flake8 +++ /dev/null @@ -1,28 +0,0 @@ -[flake8] -max-line-length = 88 -exclude = - *.egg-info, - *.pyc, - .git, - .venv*, - build, - docs/*, - dist, - docker, - venv*, - .venv*, - whitelist.py, - tasks.py -ignore = - E126 - E203 - E231 - E701 - E704 - F405 - N801 - N802 - N803 - N806 - N815 - W503 diff --git a/.isort.cfg b/.isort.cfg deleted file mode 100644 index 039f0337a2..0000000000 --- a/.isort.cfg +++ /dev/null @@ -1,5 +0,0 @@ -[settings] -profile=black -multi_line_output=3 -src_paths = ["redis", "tests"] -skip_glob=benchmarks/* \ No newline at end of file diff --git a/benchmarks/basic_operations.py b/benchmarks/basic_operations.py index c9f5853652..66cd6b320d 100644 --- a/benchmarks/basic_operations.py +++ b/benchmarks/basic_operations.py @@ -54,7 +54,7 @@ def wrapper(*args, **kwargs): count = args[1] print(f"{func.__name__} - {count} Requests") print(f"Duration = {duration}") - print(f"Rate = {count/duration}") + print(f"Rate = {count / duration}") print() return ret diff --git a/benchmarks/command_packer_benchmark.py b/benchmarks/command_packer_benchmark.py index e66dbbcbf9..4fb7196422 100644 --- a/benchmarks/command_packer_benchmark.py +++ b/benchmarks/command_packer_benchmark.py @@ -78,7 +78,6 @@ def 
pack_command(self, *args): class CommandPackerBenchmark(Benchmark): - ARGUMENTS = ( { "name": "connection_class", diff --git a/benchmarks/socket_read_size.py b/benchmarks/socket_read_size.py index 544c733178..37ffa97812 100644 --- a/benchmarks/socket_read_size.py +++ b/benchmarks/socket_read_size.py @@ -4,7 +4,6 @@ class SocketReadBenchmark(Benchmark): - ARGUMENTS = ( {"name": "parser", "values": [PythonParser, _HiredisParser]}, { diff --git a/dev_requirements.txt b/dev_requirements.txt index 7ee7ac2b75..11797c98e5 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,9 +1,5 @@ build -black==24.3.0 click==8.0.4 -flake8-isort -flake8 -flynt~=0.69.0 invoke==2.2.0 mock packaging>=20.4 @@ -12,6 +8,7 @@ pytest-asyncio>=0.23.0,<0.24.0 pytest-cov pytest-profiling==1.8.1 pytest-timeout +ruff==0.9.6 ujson>=4.2.0 uvloop vulture>=2.3.0 diff --git a/doctests/README.md b/doctests/README.md index 15664f1bcd..b5deff7ff3 100644 --- a/doctests/README.md +++ b/doctests/README.md @@ -21,8 +21,8 @@ pip uninstall -y redis # uninstall Redis package installed via redis-entraid pip install -r doctests/requirements.txt ``` -Note - the CI process, runs the basic ```black``` and ```isort``` linters against the examples. Assuming -the requirements above have been installed you can run ```black yourfile.py``` and ```isort yourfile.py``` +Note - the CI process, runs linters against the examples. Assuming +the requirements above have been installed you can run ```ruff check yourfile.py``` and ```ruff format yourfile.py``` locally to validate the linting, prior to CI. Just include necessary assertions in the example file and run diff --git a/pyproject.toml b/pyproject.toml index 0ec38ed3b1..0ab1b61bb8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,3 +96,53 @@ filterwarnings = [ # Ignore a coverage warning when COVERAGE_CORE=sysmon for Pythons < 3.12. "ignore:sys.monitoring isn't available:coverage.exceptions.CoverageWarning", ] + +[tool.ruff] +target-version = "py38" +line-length = 88 +exclude = [ + "*.egg-info", + "*.pyc", + ".git", + ".venv*", + "build", + "dist", + "docker", + "docs/*", + "doctests/*", + "tasks.py", + "venv*", + "whitelist.py", +] + +[tool.ruff.lint] +ignore = [ + "E501", # line too long (taken care of with ruff format) + "E741", # ambiguous variable name + "N818", # Errors should have Error suffix +] + +select = [ + "E", + "F", + "FLY", + "I", + "N", + "W", +] + +[tool.ruff.lint.per-file-ignores] +"redis/commands/bf/*" = [ + # the `bf` module uses star imports, so this is required there. + "F405", # name may be undefined, or defined from star imports +] +"redis/commands/{bf,timeseries,json,search}/*" = [ + "N", +] +"tests/*" = [ + "I", # TODO: could be enabled, plenty of changes + "N801", # class name should use CapWords convention + "N803", # argument name should be lowercase + "N802", # function name should be lowercase + "N806", # variable name should be lowercase +] diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index 91a4f74199..ebc8313ce7 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -32,9 +32,9 @@ from .encoders import Encoder from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer -MODULE_LOAD_ERROR = "Error loading the extension. " "Please check the server logs." +MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" -MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not " "possible." 
+MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." MODULE_EXPORTS_DATA_TYPES_ERROR = ( "Error unloading module: the module " "exports one or more module-side data " diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 0d6d130dcf..d080943182 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1167,9 +1167,7 @@ def get_node( return self.nodes_cache.get(node_name) else: raise DataError( - "get_node requires one of the following: " - "1. node name " - "2. host and port" + "get_node requires one of the following: 1. node name 2. host and port" ) def set_nodes( @@ -1351,7 +1349,7 @@ async def initialize(self) -> None: if len(disagreements) > 5: raise RedisClusterException( f"startup_nodes could not agree on a valid " - f'slots cache: {", ".join(disagreements)}' + f"slots cache: {', '.join(disagreements)}" ) # Validate if all slots are covered or if we should try next startup node diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 2e2a2502c3..9b5d0d8eb9 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -842,7 +842,7 @@ def __init__( if cert_reqs is None: self.cert_reqs = ssl.CERT_NONE elif isinstance(cert_reqs, str): - CERT_REQS = { + CERT_REQS = { # noqa: N806 "none": ssl.CERT_NONE, "optional": ssl.CERT_OPTIONAL, "required": ssl.CERT_REQUIRED, diff --git a/redis/asyncio/utils.py b/redis/asyncio/utils.py index 5a55b36a33..fa014514ec 100644 --- a/redis/asyncio/utils.py +++ b/redis/asyncio/utils.py @@ -16,7 +16,7 @@ def from_url(url, **kwargs): return Redis.from_url(url, **kwargs) -class pipeline: +class pipeline: # noqa: N801 def __init__(self, redis_obj: "Redis"): self.p: "Pipeline" = redis_obj.pipeline() diff --git a/redis/auth/token.py b/redis/auth/token.py index 1c5246469b..1f613aff5f 100644 --- a/redis/auth/token.py +++ b/redis/auth/token.py @@ -76,7 +76,6 @@ def get_received_at_ms(self) -> float: class JWToken(TokenInterface): - REQUIRED_FIELDS = {"exp"} def __init__(self, token: str): diff --git a/redis/client.py b/redis/client.py index fc535c8ca0..2bacbe14ac 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1514,8 +1514,7 @@ def raise_first_error(self, commands, response): def annotate_exception(self, exception, number, command): cmd = " ".join(map(safe_str, command)) msg = ( - f"Command # {number} ({cmd}) of pipeline " - f"caused error: {exception.args[0]}" + f"Command # {number} ({cmd}) of pipeline caused error: {exception.args[0]}" ) exception.args = (msg,) + exception.args[1:] diff --git a/redis/cluster.py b/redis/cluster.py index f518a1f184..ef4500f895 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1628,7 +1628,7 @@ def initialize(self): if len(disagreements) > 5: raise RedisClusterException( f"startup_nodes could not agree on a valid " - f'slots cache: {", ".join(disagreements)}' + f"slots cache: {', '.join(disagreements)}" ) fully_covered = self.check_slots_coverage(tmp_slots) @@ -2047,8 +2047,7 @@ def annotate_exception(self, exception, number, command): """ cmd = " ".join(map(safe_str, command)) msg = ( - f"Command # {number} ({cmd}) of pipeline " - f"caused error: {exception.args[0]}" + f"Command # {number} ({cmd}) of pipeline caused error: {exception.args[0]}" ) exception.args = (msg,) + exception.args[1:] diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index f31b88bc4e..f0b65612e0 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -595,7 +595,7 @@ def cluster_setslot( "CLUSTER SETSLOT", slot_id, 
state, node_id, target_nodes=target_node ) elif state.upper() == "STABLE": - raise RedisError('For "stable" state please use ' "cluster_setslot_stable") + raise RedisError('For "stable" state please use cluster_setslot_stable') else: raise RedisError(f"Invalid slot state: {state}") diff --git a/redis/commands/core.py b/redis/commands/core.py index c3ffb955c4..b0e5dc6794 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -3415,7 +3415,9 @@ def smembers(self, name: str) -> Union[Awaitable[Set], Set]: """ return self.execute_command("SMEMBERS", name, keys=[name]) - def smismember(self, name: str, values: List, *args: List) -> Union[ + def smismember( + self, name: str, values: List, *args: List + ) -> Union[ Awaitable[List[Union[Literal[0], Literal[1]]]], List[Union[Literal[0], Literal[1]]], ]: @@ -4162,8 +4164,7 @@ def zadd( raise DataError("ZADD allows either 'gt' or 'lt', not both") if incr and len(mapping) != 1: raise DataError( - "ZADD option 'incr' only works when passing a " - "single element/score pair" + "ZADD option 'incr' only works when passing a single element/score pair" ) if nx and (gt or lt): raise DataError("Only one of 'nx', 'lt', or 'gr' may be defined.") diff --git a/redis/commands/graph/commands.py b/redis/commands/graph/commands.py index d92018f731..1e41a5fb1f 100644 --- a/redis/commands/graph/commands.py +++ b/redis/commands/graph/commands.py @@ -171,9 +171,7 @@ def config(self, name, value=None, set=False): if set: params.append(value) else: - raise DataError( - "``value`` can be provided only when ``set`` is True" - ) # noqa + raise DataError("``value`` can be provided only when ``set`` is True") # noqa return self.execute_command(CONFIG_CMD, *params) def list_keys(self): diff --git a/redis/commands/helpers.py b/redis/commands/helpers.py index 7d9095ea41..e11d34fb71 100644 --- a/redis/commands/helpers.py +++ b/redis/commands/helpers.py @@ -138,9 +138,9 @@ def stringify_param_value(value): elif value is None: return "null" elif isinstance(value, (list, tuple)): - return f'[{",".join(map(stringify_param_value, value))}]' + return f"[{','.join(map(stringify_param_value, value))}]" elif isinstance(value, dict): - return f'{{{",".join(f"{k}:{stringify_param_value(v)}" for k, v in value.items())}}}' # noqa + return f"{{{','.join(f'{k}:{stringify_param_value(v)}' for k, v in value.items())}}}" # noqa else: return str(value) diff --git a/redis/connection.py b/redis/connection.py index 2391e74d2c..43501800c8 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -622,9 +622,7 @@ def read_response( except OSError as e: if disconnect_on_error: self.disconnect() - raise ConnectionError( - f"Error while reading from {host_error}" f" : {e.args}" - ) + raise ConnectionError(f"Error while reading from {host_error} : {e.args}") except BaseException: # Also by default close in case of BaseException. A lot of code # relies on this behaviour when doing Command/Response pairs. 
@@ -1040,7 +1038,7 @@ def __init__( if ssl_cert_reqs is None: ssl_cert_reqs = ssl.CERT_NONE elif isinstance(ssl_cert_reqs, str): - CERT_REQS = { + CERT_REQS = { # noqa: N806 "none": ssl.CERT_NONE, "optional": ssl.CERT_OPTIONAL, "required": ssl.CERT_REQUIRED, diff --git a/redis/exceptions.py b/redis/exceptions.py index dcc06774b0..82f62730ab 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -79,6 +79,7 @@ class ModuleError(ResponseError): class LockError(RedisError, ValueError): "Errors acquiring or releasing a lock" + # NOTE: For backwards compatibility, this class derives from ValueError. # This was originally chosen to behave like threading.Lock. @@ -89,11 +90,13 @@ def __init__(self, message=None, lock_name=None): class LockNotOwnedError(LockError): "Error trying to extend or release a lock that is (no longer) owned" + pass class ChildDeadlockedError(Exception): "Error indicating that a child process is deadlocked after a fork()" + pass diff --git a/redis/ocsp.py b/redis/ocsp.py index 8819848fa9..d69c914dee 100644 --- a/redis/ocsp.py +++ b/redis/ocsp.py @@ -15,6 +15,7 @@ from cryptography.hazmat.primitives.hashes import SHA1, Hash from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat from cryptography.x509 import ocsp + from redis.exceptions import AuthorizationError, ConnectionError @@ -56,7 +57,7 @@ def _check_certificate(issuer_cert, ocsp_bytes, validate=True): if ocsp_response.response_status == ocsp.OCSPResponseStatus.SUCCESSFUL: if ocsp_response.certificate_status != ocsp.OCSPCertStatus.GOOD: raise ConnectionError( - f'Received an {str(ocsp_response.certificate_status).split(".")[1]} ' + f"Received an {str(ocsp_response.certificate_status).split('.')[1]} " "ocsp certificate status" ) else: diff --git a/redis/sentinel.py b/redis/sentinel.py index 01e210794c..521ac24142 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -273,7 +273,7 @@ def __repr__(self): ) return ( f"<{type(self).__module__}.{type(self).__name__}" - f'(sentinels=[{",".join(sentinel_addresses)}])>' + f"(sentinels=[{','.join(sentinel_addresses)}])>" ) def check_master_state(self, state, service_name): diff --git a/tasks.py b/tasks.py index 8a5cae97b2..2d1a073437 100644 --- a/tasks.py +++ b/tasks.py @@ -27,11 +27,9 @@ def build_docs(c): @task def linters(c): """Run code linters""" - run("flake8 tests redis") - run("black --target-version py37 --check --diff tests redis") - run("isort --check-only --diff tests redis") + run("ruff check tests redis") + run("ruff format --check --diff tests redis") run("vulture redis whitelist.py --min-confidence 80") - run("flynt --fail-on-change --dry-run tests redis") @task diff --git a/tests/conftest.py b/tests/conftest.py index e5eea4d582..c485d626ca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -646,7 +646,7 @@ def wait_for_command(client, monitor, command, key=None): if Version(redis_version) >= Version("5.0.0"): id_str = str(client.client_id()) else: - id_str = f"{random.randrange(2 ** 32):08x}" + id_str = f"{random.randrange(2**32):08x}" key = f"__REDIS-PY-{id_str}__" client.get(key) while True: diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 226b00aa45..60e447e6fd 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -242,7 +242,7 @@ async def wait_for_command( if Version(redis_version) >= Version("5.0.0"): id_str = str(await client.client_id()) else: - id_str = f"{random.randrange(2 ** 32):08x}" + id_str = f"{random.randrange(2**32):08x}" key = 
f"__REDIS-PY-{id_str}__" await client.get(key) while True: diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 48ddd5e4f3..735b116c5d 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -146,7 +146,6 @@ async def get_mocked_redis_client( with mock.patch.object(ClusterNode, "execute_command") as execute_command_mock: async def execute_command(*_args, **_kwargs): - if _args[0] == "CLUSTER SLOTS": if cluster_slots_raise_error: raise ResponseError() @@ -1577,7 +1576,7 @@ async def test_cluster_bitop_not_empty_string(self, r: RedisCluster) -> None: @skip_if_server_version_lt("2.6.0") async def test_cluster_bitop_not(self, r: RedisCluster) -> None: - test_str = b"\xAA\x00\xFF\x55" + test_str = b"\xaa\x00\xff\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF await r.set("{foo}a", test_str) await r.bitop("not", "{foo}r", "{foo}a") @@ -1585,7 +1584,7 @@ async def test_cluster_bitop_not(self, r: RedisCluster) -> None: @skip_if_server_version_lt("2.6.0") async def test_cluster_bitop_not_in_place(self, r: RedisCluster) -> None: - test_str = b"\xAA\x00\xFF\x55" + test_str = b"\xaa\x00\xff\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF await r.set("{foo}a", test_str) await r.bitop("not", "{foo}a", "{foo}a") @@ -1593,7 +1592,7 @@ async def test_cluster_bitop_not_in_place(self, r: RedisCluster) -> None: @skip_if_server_version_lt("2.6.0") async def test_cluster_bitop_single_string(self, r: RedisCluster) -> None: - test_str = b"\x01\x02\xFF" + test_str = b"\x01\x02\xff" await r.set("{foo}a", test_str) await r.bitop("and", "{foo}res1", "{foo}a") await r.bitop("or", "{foo}res2", "{foo}a") @@ -1604,8 +1603,8 @@ async def test_cluster_bitop_single_string(self, r: RedisCluster) -> None: @skip_if_server_version_lt("2.6.0") async def test_cluster_bitop_string_operands(self, r: RedisCluster) -> None: - await r.set("{foo}a", b"\x01\x02\xFF\xFF") - await r.set("{foo}b", b"\x01\x02\xFF") + await r.set("{foo}a", b"\x01\x02\xff\xff") + await r.set("{foo}b", b"\x01\x02\xff") await r.bitop("and", "{foo}res1", "{foo}a", "{foo}b") await r.bitop("or", "{foo}res2", "{foo}a", "{foo}b") await r.bitop("xor", "{foo}res3", "{foo}a", "{foo}b") diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 9f154cb273..08bd5810f4 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -711,8 +711,9 @@ async def test_config_set_for_search_module(self, r: redis.Redis): "search-default-dialect" ] == default_dialect_new assert ( - (await r.ft().config_get("*"))[b"DEFAULT_DIALECT"] - ).decode() == default_dialect_new + ((await r.ft().config_get("*"))[b"DEFAULT_DIALECT"]).decode() + == default_dialect_new + ) except AssertionError as ex: raise ex finally: @@ -844,7 +845,7 @@ async def test_bitop_not_empty_string(self, r: redis.Redis): @skip_if_server_version_lt("2.6.0") @pytest.mark.onlynoncluster async def test_bitop_not(self, r: redis.Redis): - test_str = b"\xAA\x00\xFF\x55" + test_str = b"\xaa\x00\xff\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF await r.set("a", test_str) await r.bitop("not", "r", "a") @@ -853,7 +854,7 @@ async def test_bitop_not(self, r: redis.Redis): @skip_if_server_version_lt("2.6.0") @pytest.mark.onlynoncluster async def test_bitop_not_in_place(self, r: redis.Redis): - test_str = b"\xAA\x00\xFF\x55" + test_str = b"\xaa\x00\xff\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF await r.set("a", test_str) await r.bitop("not", "a", "a") @@ -862,7 +863,7 @@ async def test_bitop_not_in_place(self, r: 
redis.Redis): @skip_if_server_version_lt("2.6.0") @pytest.mark.onlynoncluster async def test_bitop_single_string(self, r: redis.Redis): - test_str = b"\x01\x02\xFF" + test_str = b"\x01\x02\xff" await r.set("a", test_str) await r.bitop("and", "res1", "a") await r.bitop("or", "res2", "a") @@ -874,8 +875,8 @@ async def test_bitop_single_string(self, r: redis.Redis): @skip_if_server_version_lt("2.6.0") @pytest.mark.onlynoncluster async def test_bitop_string_operands(self, r: redis.Redis): - await r.set("a", b"\x01\x02\xFF\xFF") - await r.set("b", b"\x01\x02\xFF") + await r.set("a", b"\x01\x02\xff\xff") + await r.set("b", b"\x01\x02\xff") await r.bitop("and", "res1", "a", "b") await r.bitop("or", "res2", "a", "b") await r.bitop("xor", "res3", "a", "b") diff --git a/tests/test_asyncio/test_graph.py b/tests/test_asyncio/test_graph.py index 2a506d5e22..7b823265c3 100644 --- a/tests/test_asyncio/test_graph.py +++ b/tests/test_asyncio/test_graph.py @@ -44,8 +44,7 @@ async def test_graph_creation(decoded_r: redis.Redis): await graph.commit() query = ( - 'MATCH (p:person)-[v:visited {purpose:"pleasure"}]->(c:country) ' - "RETURN p, v, c" + 'MATCH (p:person)-[v:visited {purpose:"pleasure"}]->(c:country) RETURN p, v, c' ) result = await graph.query(query) diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index c0efcce882..c55d57f3b2 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1797,10 +1797,10 @@ async def test_binary_and_text_fields(decoded_r: redis.Redis): docs[0]["vector_emb"], dtype=np.float32 ) - assert np.array_equal( - decoded_vec_from_search_results, fake_vec - ), "The vectors are not equal" + assert np.array_equal(decoded_vec_from_search_results, fake_vec), ( + "The vectors are not equal" + ) - assert ( - docs[0]["first_name"] == mixed_data["first_name"] - ), "The text field is not decoded correctly" + assert docs[0]["first_name"] == mixed_data["first_name"], ( + "The text field is not decoded correctly" + ) diff --git a/tests/test_auth/test_token.py b/tests/test_auth/test_token.py index 2d72e08895..97633e38e7 100644 --- a/tests/test_auth/test_token.py +++ b/tests/test_auth/test_token.py @@ -6,7 +6,6 @@ class TestToken: - def test_simple_token(self): token = SimpleToken( "value", diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 908ac26211..bec9a8ecb0 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -1692,7 +1692,7 @@ def test_cluster_bitop_not_empty_string(self, r): @skip_if_server_version_lt("2.6.0") def test_cluster_bitop_not(self, r): - test_str = b"\xAA\x00\xFF\x55" + test_str = b"\xaa\x00\xff\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF r["{foo}a"] = test_str r.bitop("not", "{foo}r", "{foo}a") @@ -1700,7 +1700,7 @@ def test_cluster_bitop_not(self, r): @skip_if_server_version_lt("2.6.0") def test_cluster_bitop_not_in_place(self, r): - test_str = b"\xAA\x00\xFF\x55" + test_str = b"\xaa\x00\xff\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF r["{foo}a"] = test_str r.bitop("not", "{foo}a", "{foo}a") @@ -1708,7 +1708,7 @@ def test_cluster_bitop_not_in_place(self, r): @skip_if_server_version_lt("2.6.0") def test_cluster_bitop_single_string(self, r): - test_str = b"\x01\x02\xFF" + test_str = b"\x01\x02\xff" r["{foo}a"] = test_str r.bitop("and", "{foo}res1", "{foo}a") r.bitop("or", "{foo}res2", "{foo}a") @@ -1719,8 +1719,8 @@ def test_cluster_bitop_single_string(self, r): @skip_if_server_version_lt("2.6.0") def test_cluster_bitop_string_operands(self, r): - r["{foo}a"] = b"\x01\x02\xFF\xFF" - 
r["{foo}b"] = b"\x01\x02\xFF" + r["{foo}a"] = b"\x01\x02\xff\xff" + r["{foo}b"] = b"\x01\x02\xff" r.bitop("and", "{foo}res1", "{foo}a", "{foo}b") r.bitop("or", "{foo}res2", "{foo}a", "{foo}b") r.bitop("xor", "{foo}res3", "{foo}a", "{foo}b") diff --git a/tests/test_commands.py b/tests/test_commands.py index b6f13f6aa8..c6e39f565d 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1025,8 +1025,9 @@ def test_config_set_for_search_module(self, r: redis.Redis): assert r.config_set("search-default-dialect", default_dialect_new) assert r.config_get("*")["search-default-dialect"] == default_dialect_new assert ( - r.ft().config_get("*")[b"DEFAULT_DIALECT"] - ).decode() == default_dialect_new + (r.ft().config_get("*")[b"DEFAULT_DIALECT"]).decode() + == default_dialect_new + ) except AssertionError as ex: raise ex finally: @@ -1268,7 +1269,7 @@ def test_bitop_not_empty_string(self, r): @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.6.0") def test_bitop_not(self, r): - test_str = b"\xAA\x00\xFF\x55" + test_str = b"\xaa\x00\xff\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF r["a"] = test_str r.bitop("not", "r", "a") @@ -1277,7 +1278,7 @@ def test_bitop_not(self, r): @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.6.0") def test_bitop_not_in_place(self, r): - test_str = b"\xAA\x00\xFF\x55" + test_str = b"\xaa\x00\xff\x55" correct = ~0xAA00FF55 & 0xFFFFFFFF r["a"] = test_str r.bitop("not", "a", "a") @@ -1286,7 +1287,7 @@ def test_bitop_not_in_place(self, r): @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.6.0") def test_bitop_single_string(self, r): - test_str = b"\x01\x02\xFF" + test_str = b"\x01\x02\xff" r["a"] = test_str r.bitop("and", "res1", "a") r.bitop("or", "res2", "a") @@ -1298,8 +1299,8 @@ def test_bitop_single_string(self, r): @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.6.0") def test_bitop_string_operands(self, r): - r["a"] = b"\x01\x02\xFF\xFF" - r["b"] = b"\x01\x02\xFF" + r["a"] = b"\x01\x02\xff\xff" + r["b"] = b"\x01\x02\xff" r.bitop("and", "res1", "a", "b") r.bitop("or", "res2", "a", "b") r.bitop("xor", "res3", "a", "b") diff --git a/tests/test_connection.py b/tests/test_connection.py index 6c1498a329..9664146ce5 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -362,7 +362,6 @@ def test_unix_socket_connection_failure(): class TestUnitConnectionPool: - @pytest.mark.parametrize( "max_conn", (-1, "str"), ids=("non-positive", "wrong type") ) diff --git a/tests/test_graph.py b/tests/test_graph.py index efb10dada7..fd08385667 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -73,8 +73,7 @@ def test_graph_creation(client): graph.commit() query = ( - 'MATCH (p:person)-[v:visited {purpose:"pleasure"}]->(c:country) ' - "RETURN p, v, c" + 'MATCH (p:person)-[v:visited {purpose:"pleasure"}]->(c:country) RETURN p, v, c' ) result = graph.query(query) diff --git a/tests/test_search.py b/tests/test_search.py index c4598f3773..11f22ac805 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1919,13 +1919,13 @@ def test_binary_and_text_fields(client): docs[0]["vector_emb"], dtype=np.float32 ) - assert np.array_equal( - decoded_vec_from_search_results, fake_vec - ), "The vectors are not equal" + assert np.array_equal(decoded_vec_from_search_results, fake_vec), ( + "The vectors are not equal" + ) - assert ( - docs[0]["first_name"] == mixed_data["first_name"] - ), "The text field is not decoded correctly" + assert docs[0]["first_name"] == mixed_data["first_name"], ( + "The text field is not decoded correctly" 
+ ) @pytest.mark.redismod From 1de69f590b7c0b50cbffceca35134bbabda56e44 Mon Sep 17 00:00:00 2001 From: Neil Bertram Date: Thu, 6 Mar 2025 20:52:15 +1300 Subject: [PATCH 053/113] Correct the typedef of lock.extend() to accept floats, and test that float TTLs are honoured precisely (#3420) --- CHANGES | 1 + redis/asyncio/lock.py | 7 ++++--- tests/test_asyncio/test_lock.py | 9 +++++---- tests/test_lock.py | 9 +++++---- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/CHANGES b/CHANGES index bd96846b6d..031d909f23 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Fix lock.extend() typedef to accept float TTL extension * Update URL in the readme linking to Redis University * Move doctests (doc code examples) to main branch * Update `ResponseT` type hint diff --git a/redis/asyncio/lock.py b/redis/asyncio/lock.py index bb2cccab52..f70a8d09ab 100644 --- a/redis/asyncio/lock.py +++ b/redis/asyncio/lock.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Awaitable, Optional, Union from redis.exceptions import LockError, LockNotOwnedError +from redis.typing import Number if TYPE_CHECKING: from redis.asyncio import Redis, RedisCluster @@ -82,7 +83,7 @@ def __init__( timeout: Optional[float] = None, sleep: float = 0.1, blocking: bool = True, - blocking_timeout: Optional[float] = None, + blocking_timeout: Optional[Number] = None, thread_local: bool = True, ): """ @@ -167,7 +168,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): async def acquire( self, blocking: Optional[bool] = None, - blocking_timeout: Optional[float] = None, + blocking_timeout: Optional[Number] = None, token: Optional[Union[str, bytes]] = None, ): """ @@ -265,7 +266,7 @@ async def do_release(self, expected_token: bytes) -> None: raise LockNotOwnedError("Cannot release a lock that's no longer owned") def extend( - self, additional_time: float, replace_ttl: bool = False + self, additional_time: Number, replace_ttl: bool = False ) -> Awaitable[bool]: """ Adds more time to an already acquired lock. 
diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index 033a8b7467..9973ef701f 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -174,11 +174,12 @@ async def test_extend_lock_replace_ttl(self, r): await lock.release() async def test_extend_lock_float(self, r): - lock = self.get_lock(r, "foo", timeout=10.0) + lock = self.get_lock(r, "foo", timeout=10.5) assert await lock.acquire(blocking=False) - assert 8000 < (await r.pttl("foo")) <= 10000 - assert await lock.extend(10.0) - assert 16000 < (await r.pttl("foo")) <= 20000 + assert 10400 < (await r.pttl("foo")) <= 10500 + old_ttl = await r.pttl("foo") + assert await lock.extend(10.5) + assert old_ttl + 10400 < (await r.pttl("foo")) <= old_ttl + 10500 await lock.release() async def test_extending_unlocked_lock_raises_error(self, r): diff --git a/tests/test_lock.py b/tests/test_lock.py index d77ff9717a..136c86e459 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -178,11 +178,12 @@ def test_extend_lock_replace_ttl(self, r): lock.release() def test_extend_lock_float(self, r): - lock = self.get_lock(r, "foo", timeout=10.0) + lock = self.get_lock(r, "foo", timeout=10.5) assert lock.acquire(blocking=False) - assert 8000 < r.pttl("foo") <= 10000 - assert lock.extend(10.0) - assert 16000 < r.pttl("foo") <= 20000 + assert 10400 < r.pttl("foo") <= 10500 + old_ttl = r.pttl("foo") + assert lock.extend(10.5) + assert old_ttl + 10400 < r.pttl("foo") <= old_ttl + 10500 lock.release() def test_extending_unlocked_lock_raises_error(self, r): From 9be806ad27e7cbbd93d7fd05f839ec890e390092 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Thu, 6 Mar 2025 11:03:52 +0200 Subject: [PATCH 054/113] Removing deprecated usage of forbid_global_loop=True in pytest.mark.asyncio decorator (#3542) --- dev_requirements.txt | 2 +- pyproject.toml | 1 + tests/test_asyncio/test_scripting.py | 16 ++++++++-------- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index 11797c98e5..2a0938bec3 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -4,7 +4,7 @@ invoke==2.2.0 mock packaging>=20.4 pytest -pytest-asyncio>=0.23.0,<0.24.0 +pytest-asyncio>=0.23.0 pytest-cov pytest-profiling==1.8.1 pytest-timeout diff --git a/pyproject.toml b/pyproject.toml index 0ab1b61bb8..22f18bb1ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,7 @@ markers = [ "experimental: run only experimental tests", "cp_integration: credential provider integration tests", ] +asyncio_default_fixture_loop_scope = "function" asyncio_mode = "auto" timeout = 30 filterwarnings = [ diff --git a/tests/test_asyncio/test_scripting.py b/tests/test_asyncio/test_scripting.py index 8375ecd787..b8e100c04a 100644 --- a/tests/test_asyncio/test_scripting.py +++ b/tests/test_asyncio/test_scripting.py @@ -28,14 +28,14 @@ async def r(self, create_redis): yield redis await redis.script_flush() - @pytest.mark.asyncio(forbid_global_loop=True) + @pytest.mark.asyncio() async def test_eval(self, r): await r.flushdb() await r.set("a", 2) # 2 * 3 == 6 assert await r.eval(multiply_script, 1, "a", 3) == 6 - @pytest.mark.asyncio(forbid_global_loop=True) + @pytest.mark.asyncio() @skip_if_server_version_lt("6.2.0") async def test_script_flush(self, r): await r.set("a", 2) @@ -55,14 +55,14 @@ async def test_script_flush(self, r): await r.script_load(multiply_script) await r.script_flush("NOTREAL") - @pytest.mark.asyncio(forbid_global_loop=True) + @pytest.mark.asyncio() async def test_evalsha(self, 
r): await r.set("a", 2) sha = await r.script_load(multiply_script) # 2 * 3 == 6 assert await r.evalsha(sha, 1, "a", 3) == 6 - @pytest.mark.asyncio(forbid_global_loop=True) + @pytest.mark.asyncio() async def test_evalsha_script_not_loaded(self, r): await r.set("a", 2) sha = await r.script_load(multiply_script) @@ -71,7 +71,7 @@ async def test_evalsha_script_not_loaded(self, r): with pytest.raises(exceptions.NoScriptError): await r.evalsha(sha, 1, "a", 3) - @pytest.mark.asyncio(forbid_global_loop=True) + @pytest.mark.asyncio() async def test_script_loading(self, r): # get the sha, then clear the cache sha = await r.script_load(multiply_script) @@ -80,7 +80,7 @@ async def test_script_loading(self, r): await r.script_load(multiply_script) assert await r.script_exists(sha) == [True] - @pytest.mark.asyncio(forbid_global_loop=True) + @pytest.mark.asyncio() async def test_script_object(self, r): await r.script_flush() await r.set("a", 2) @@ -97,7 +97,7 @@ async def test_script_object(self, r): # Test first evalsha block assert await multiply(keys=["a"], args=[3]) == 6 - @pytest.mark.asyncio(forbid_global_loop=True) + @pytest.mark.asyncio() async def test_script_object_in_pipeline(self, r): await r.script_flush() multiply = r.register_script(multiply_script) @@ -127,7 +127,7 @@ async def test_script_object_in_pipeline(self, r): assert await pipe.execute() == [True, b"2", 6] assert await r.script_exists(multiply.sha) == [True] - @pytest.mark.asyncio(forbid_global_loop=True) + @pytest.mark.asyncio() async def test_eval_msgpack_pipeline_error_in_lua(self, r): msgpack_hello = r.register_script(msgpack_hello_script) assert msgpack_hello.sha From 8a28b960f08c6a205215a42080a0ff1026bbbd57 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Thu, 6 Mar 2025 13:23:58 +0200 Subject: [PATCH 055/113] (tests): Added testing for auth via DefaultAzureCredential (#3544) * (tests): Added testing for auth via DefaultAzureCredential * Added testing for async * Remove unused import --- dev_requirements.txt | 2 +- tests/entraid_utils.py | 39 ++++++++++++++++++++++++-- tests/test_asyncio/conftest.py | 6 ---- tests/test_asyncio/test_credentials.py | 13 +++++++-- tests/test_credentials.py | 13 +++++++-- 5 files changed, 59 insertions(+), 14 deletions(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index 2a0938bec3..ad7330598d 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -13,4 +13,4 @@ ujson>=4.2.0 uvloop vulture>=2.3.0 numpy>=1.24.0 -redis-entraid==0.3.0b1 +redis-entraid==0.4.0b2 diff --git a/tests/entraid_utils.py b/tests/entraid_utils.py index daefbd3956..529c3ccdee 100644 --- a/tests/entraid_utils.py +++ b/tests/entraid_utils.py @@ -19,6 +19,8 @@ ServicePrincipalIdentityProviderConfig, _create_provider_from_managed_identity, _create_provider_from_service_principal, + DefaultAzureCredentialIdentityProviderConfig, + _create_provider_from_default_azure_credential, ) from tests.conftest import mock_identity_provider @@ -26,6 +28,7 @@ class AuthType(Enum): MANAGED_IDENTITY = "managed_identity" SERVICE_PRINCIPAL = "service_principal" + DEFAULT_AZURE_CREDENTIAL = "default_azure_credential" def identity_provider(request) -> IdentityProviderInterface: @@ -37,18 +40,25 @@ def identity_provider(request) -> IdentityProviderInterface: if request.param.get("mock_idp", None) is not None: return mock_identity_provider() - auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + auth_type = kwargs.get("auth_type", AuthType.SERVICE_PRINCIPAL) 
config = get_identity_provider_config(request=request) - if auth_type == "MANAGED_IDENTITY": + if auth_type == AuthType.MANAGED_IDENTITY: return _create_provider_from_managed_identity(config) + if auth_type == AuthType.DEFAULT_AZURE_CREDENTIAL: + return _create_provider_from_default_azure_credential(config) + return _create_provider_from_service_principal(config) def get_identity_provider_config( request, -) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: +) -> Union[ + ManagedIdentityProviderConfig, + ServicePrincipalIdentityProviderConfig, + DefaultAzureCredentialIdentityProviderConfig, +]: if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) else: @@ -59,6 +69,9 @@ def get_identity_provider_config( if auth_type == AuthType.MANAGED_IDENTITY: return _get_managed_identity_provider_config(request) + if auth_type == AuthType.DEFAULT_AZURE_CREDENTIAL: + return _get_default_azure_credential_provider_config(request) + return _get_service_principal_provider_config(request) @@ -114,6 +127,26 @@ def _get_service_principal_provider_config( ) +def _get_default_azure_credential_provider_config( + request, +) -> DefaultAzureCredentialIdentityProviderConfig: + scopes = os.getenv("AZURE_REDIS_SCOPES", ()) + + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + token_kwargs = request.param.get("token_kwargs", {}) + else: + kwargs = {} + token_kwargs = {} + + if isinstance(scopes, str): + scopes = scopes.split(",") + + return DefaultAzureCredentialIdentityProviderConfig( + scopes=scopes, app_kwargs=kwargs, token_kwargs=token_kwargs + ) + + def get_entra_id_credentials_provider(request, cred_provider_kwargs): idp = identity_provider(request) expiration_refresh_ratio = cred_provider_kwargs.get( diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 60e447e6fd..340d146ea3 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -1,6 +1,5 @@ import random from contextlib import asynccontextmanager as _asynccontextmanager -from enum import Enum from typing import Union import pytest @@ -18,11 +17,6 @@ from .compat import mock -class AuthType(Enum): - MANAGED_IDENTITY = "managed_identity" - SERVICE_PRINCIPAL = "service_principal" - - async def _get_info(redis_url): client = redis.Redis.from_url(redis_url) info = await client.info() diff --git a/tests/test_asyncio/test_credentials.py b/tests/test_asyncio/test_credentials.py index ce8d76ea45..b4824be469 100644 --- a/tests/test_asyncio/test_credentials.py +++ b/tests/test_asyncio/test_credentials.py @@ -18,6 +18,7 @@ from redis.exceptions import ConnectionError from redis.utils import str_if_bytes from tests.conftest import get_endpoint, skip_if_redis_enterprise +from tests.entraid_utils import AuthType from tests.test_asyncio.conftest import get_credential_provider try: @@ -616,8 +617,12 @@ class TestEntraIdCredentialsProvider: "cred_provider_class": EntraIdCredentialsProvider, "cred_provider_kwargs": {"block_for_initial": True}, }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL}, + }, ], - ids=["blocked", "non-blocked"], + ids=["blocked", "non-blocked", "DefaultAzureCredential"], indirect=True, ) @pytest.mark.asyncio @@ -692,8 +697,12 @@ class TestClusterEntraIdCredentialsProvider: "cred_provider_class": EntraIdCredentialsProvider, "cred_provider_kwargs": {"block_for_initial": True}, }, + { + "cred_provider_class": EntraIdCredentialsProvider, + 
"idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL}, + }, ], - ids=["blocked", "non-blocked"], + ids=["blocked", "non-blocked", "DefaultAzureCredential"], indirect=True, ) @pytest.mark.asyncio diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 1f98c5208d..58bbd01f28 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -22,6 +22,7 @@ get_endpoint, skip_if_redis_enterprise, ) +from tests.entraid_utils import AuthType try: from redis_entraid.cred_provider import EntraIdCredentialsProvider @@ -585,8 +586,12 @@ class TestEntraIdCredentialsProvider: "cred_provider_class": EntraIdCredentialsProvider, "single_connection_client": True, }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL}, + }, ], - ids=["pool", "single"], + ids=["pool", "single", "DefaultAzureCredential"], indirect=True, ) @pytest.mark.onlynoncluster @@ -656,8 +661,12 @@ class TestClusterEntraIdCredentialsProvider: "cred_provider_class": EntraIdCredentialsProvider, "single_connection_client": True, }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL}, + }, ], - ids=["pool", "single"], + ids=["pool", "single", "DefaultAzureCredential"], indirect=True, ) @pytest.mark.onlycluster From 74977eb9a55fca95b495d06d85a39c67b81d9053 Mon Sep 17 00:00:00 2001 From: "Edmund L. Wong" Date: Mon, 10 Mar 2025 01:11:11 -0700 Subject: [PATCH 056/113] Fix AttributeError when client.get_default_node() returns None (#3458) --- redis/asyncio/cluster.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index d080943182..398311ebf1 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1612,18 +1612,24 @@ async def _execute( result.args = (msg,) + result.args[1:] raise result - default_node = nodes.get(client.get_default_node().name) - if default_node is not None: - # This pipeline execution used the default node, check if we need - # to replace it. - # Note: when the error is raised we'll reset the default node in the - # caller function. - for cmd in default_node[1]: - # Check if it has a command that failed with a relevant - # exception - if type(cmd.result) in self.__class__.ERRORS_ALLOW_RETRY: - client.replace_default_node() - break + default_cluster_node = client.get_default_node() + + # Check whether the default node was used. In some cases, + # 'client.get_default_node()' may return None. The check below + # prevents a potential AttributeError. + if default_cluster_node is not None: + default_node = nodes.get(default_cluster_node.name) + if default_node is not None: + # This pipeline execution used the default node, check if we need + # to replace it. + # Note: when the error is raised we'll reset the default node in the + # caller function. 
+ for cmd in default_node[1]: + # Check if it has a command that failed with a relevant + # exception + if type(cmd.result) in self.__class__.ERRORS_ALLOW_RETRY: + client.replace_default_node() + break return [cmd.result for cmd in stack] From 4f0ee91ceaa82d49147c9ab13a05f81312699529 Mon Sep 17 00:00:00 2001 From: Juliano Amadeu <65794514+julianolm@users.noreply.github.com> Date: Mon, 10 Mar 2025 06:55:57 -0300 Subject: [PATCH 057/113] feat: adds option not to raise exception when leaving context manager after lock expiration (#3531) * adds option not to raise when leaving context manager after lock expiration * keep oroginal traceback Co-authored-by: Aarni Koskela * improves error traceback * adds missing modifications * sort imports * run linter * adds catch for other possible exception * Update redis/lock.py to catch Both LockNotOwnedError and LockError in one except statement as LockError. Co-authored-by: Juliano Amadeu <65794514+julianolm@users.noreply.github.com> * Update redis/asyncio/lock.py Co-authored-by: Juliano Amadeu <65794514+julianolm@users.noreply.github.com> * fix linter errors --------- Co-authored-by: Aarni Koskela Co-authored-by: petyaslavova --- redis/asyncio/client.py | 7 +++++++ redis/asyncio/cluster.py | 7 +++++++ redis/asyncio/lock.py | 19 ++++++++++++++++++- redis/client.py | 7 +++++++ redis/cluster.py | 7 +++++++ redis/exceptions.py | 2 +- redis/lock.py | 19 ++++++++++++++++++- tests/test_asyncio/test_lock.py | 30 ++++++++++++++++++++++++++++++ tests/test_lock.py | 26 ++++++++++++++++++++++++++ 9 files changed, 121 insertions(+), 3 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 4254441073..412d5a24b3 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -478,6 +478,7 @@ def lock( blocking_timeout: Optional[float] = None, lock_class: Optional[Type[Lock]] = None, thread_local: bool = True, + raise_on_release_error: bool = True, ) -> Lock: """ Return a new Lock object using key ``name`` that mimics @@ -524,6 +525,11 @@ def lock( thread-1 would see the token value as "xyz" and would be able to successfully release the thread-2's lock. + ``raise_on_release_error`` indicates whether to raise an exception when + the lock is no longer owned when exiting the context manager. By default, + this is True, meaning an exception will be raised. If False, the warning + will be logged and the exception will be suppressed. + In some use cases it's necessary to disable thread local storage. For example, if you have code where one thread acquires a lock and passes that lock instance to a worker thread to release later. If thread @@ -541,6 +547,7 @@ def lock( blocking=blocking, blocking_timeout=blocking_timeout, thread_local=thread_local, + raise_on_release_error=raise_on_release_error, ) def pubsub(self, **kwargs) -> "PubSub": diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 398311ebf1..51328ad95a 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -849,6 +849,7 @@ def lock( blocking_timeout: Optional[float] = None, lock_class: Optional[Type[Lock]] = None, thread_local: bool = True, + raise_on_release_error: bool = True, ) -> Lock: """ Return a new Lock object using key ``name`` that mimics @@ -895,6 +896,11 @@ def lock( thread-1 would see the token value as "xyz" and would be able to successfully release the thread-2's lock. + ``raise_on_release_error`` indicates whether to raise an exception when + the lock is no longer owned when exiting the context manager. 
By default, + this is True, meaning an exception will be raised. If False, the warning + will be logged and the exception will be suppressed. + In some use cases it's necessary to disable thread local storage. For example, if you have code where one thread acquires a lock and passes that lock instance to a worker thread to release later. If thread @@ -912,6 +918,7 @@ def lock( blocking=blocking, blocking_timeout=blocking_timeout, thread_local=thread_local, + raise_on_release_error=raise_on_release_error, ) diff --git a/redis/asyncio/lock.py b/redis/asyncio/lock.py index f70a8d09ab..16d7fb6957 100644 --- a/redis/asyncio/lock.py +++ b/redis/asyncio/lock.py @@ -1,4 +1,5 @@ import asyncio +import logging import threading import uuid from types import SimpleNamespace @@ -10,6 +11,8 @@ if TYPE_CHECKING: from redis.asyncio import Redis, RedisCluster +logger = logging.getLogger(__name__) + class Lock: """ @@ -85,6 +88,7 @@ def __init__( blocking: bool = True, blocking_timeout: Optional[Number] = None, thread_local: bool = True, + raise_on_release_error: bool = True, ): """ Create a new Lock instance named ``name`` using the Redis client @@ -128,6 +132,11 @@ def __init__( thread-1 would see the token value as "xyz" and would be able to successfully release the thread-2's lock. + ``raise_on_release_error`` indicates whether to raise an exception when + the lock is no longer owned when exiting the context manager. By default, + this is True, meaning an exception will be raised. If False, the warning + will be logged and the exception will be suppressed. + In some use cases it's necessary to disable thread local storage. For example, if you have code where one thread acquires a lock and passes that lock instance to a worker thread to release later. If thread @@ -144,6 +153,7 @@ def __init__( self.blocking_timeout = blocking_timeout self.thread_local = bool(thread_local) self.local = threading.local() if self.thread_local else SimpleNamespace() + self.raise_on_release_error = raise_on_release_error self.local.token = None self.register_scripts() @@ -163,7 +173,14 @@ async def __aenter__(self): raise LockError("Unable to acquire lock within the time specified") async def __aexit__(self, exc_type, exc_value, traceback): - await self.release() + try: + await self.release() + except LockError: + if self.raise_on_release_error: + raise + logger.warning( + "Lock was unlocked or no longer owned when exiting context manager." + ) async def acquire( self, diff --git a/redis/client.py b/redis/client.py index 2bacbe14ac..2ba96bd6f9 100755 --- a/redis/client.py +++ b/redis/client.py @@ -473,6 +473,7 @@ def lock( blocking_timeout: Optional[float] = None, lock_class: Union[None, Any] = None, thread_local: bool = True, + raise_on_release_error: bool = True, ): """ Return a new Lock object using key ``name`` that mimics @@ -519,6 +520,11 @@ def lock( thread-1 would see the token value as "xyz" and would be able to successfully release the thread-2's lock. + ``raise_on_release_error`` indicates whether to raise an exception when + the lock is no longer owned when exiting the context manager. By default, + this is True, meaning an exception will be raised. If False, the warning + will be logged and the exception will be suppressed. + In some use cases it's necessary to disable thread local storage. For example, if you have code where one thread acquires a lock and passes that lock instance to a worker thread to release later. 
If thread @@ -536,6 +542,7 @@ def lock( blocking=blocking, blocking_timeout=blocking_timeout, thread_local=thread_local, + raise_on_release_error=raise_on_release_error, ) def pubsub(self, **kwargs): diff --git a/redis/cluster.py b/redis/cluster.py index ef4500f895..c9523e2a76 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -822,6 +822,7 @@ def lock( blocking_timeout=None, lock_class=None, thread_local=True, + raise_on_release_error: bool = True, ): """ Return a new Lock object using key ``name`` that mimics @@ -868,6 +869,11 @@ def lock( thread-1 would see the token value as "xyz" and would be able to successfully release the thread-2's lock. + ``raise_on_release_error`` indicates whether to raise an exception when + the lock is no longer owned when exiting the context manager. By default, + this is True, meaning an exception will be raised. If False, the warning + will be logged and the exception will be suppressed. + In some use cases it's necessary to disable thread local storage. For example, if you have code where one thread acquires a lock and passes that lock instance to a worker thread to release later. If thread @@ -885,6 +891,7 @@ def lock( blocking=blocking, blocking_timeout=blocking_timeout, thread_local=thread_local, + raise_on_release_error=raise_on_release_error, ) def set_response_callback(self, command, callback): diff --git a/redis/exceptions.py b/redis/exceptions.py index 82f62730ab..bad447a086 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -89,7 +89,7 @@ def __init__(self, message=None, lock_name=None): class LockNotOwnedError(LockError): - "Error trying to extend or release a lock that is (no longer) owned" + "Error trying to extend or release a lock that is not owned (anymore)" pass diff --git a/redis/lock.py b/redis/lock.py index 7a1becb30a..0288496e6f 100644 --- a/redis/lock.py +++ b/redis/lock.py @@ -1,3 +1,4 @@ +import logging import threading import time as mod_time import uuid @@ -7,6 +8,8 @@ from redis.exceptions import LockError, LockNotOwnedError from redis.typing import Number +logger = logging.getLogger(__name__) + class Lock: """ @@ -82,6 +85,7 @@ def __init__( blocking: bool = True, blocking_timeout: Optional[Number] = None, thread_local: bool = True, + raise_on_release_error: bool = True, ): """ Create a new Lock instance named ``name`` using the Redis client @@ -125,6 +129,11 @@ def __init__( thread-1 would see the token value as "xyz" and would be able to successfully release the thread-2's lock. + ``raise_on_release_error`` indicates whether to raise an exception when + the lock is no longer owned when exiting the context manager. By default, + this is True, meaning an exception will be raised. If False, the warning + will be logged and the exception will be suppressed. + In some use cases it's necessary to disable thread local storage. For example, if you have code where one thread acquires a lock and passes that lock instance to a worker thread to release later. 
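[Editor's illustration, not part of the patch] To see what the new ``raise_on_release_error`` flag added by this commit does in practice, here is a minimal sketch. It assumes a Redis server reachable on localhost at the default port, and the 0.15-second sleep (the same pattern the accompanying tests use) deliberately outlives the 0.1-second lock timeout:

    import time
    import redis

    r = redis.Redis()  # assumes a local Redis server on the default port

    # The 0.1s lock expires during the 0.15s sleep. With
    # raise_on_release_error=False, the LockNotOwnedError raised on release
    # is logged as a warning instead of propagating out of the context manager.
    with r.lock("resource", timeout=0.1, raise_on_release_error=False):
        time.sleep(0.15)

With the default ``raise_on_release_error=True``, the same code would raise ``LockNotOwnedError`` when the context manager exits.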
If thread @@ -140,6 +149,7 @@ def __init__( self.blocking = blocking self.blocking_timeout = blocking_timeout self.thread_local = bool(thread_local) + self.raise_on_release_error = raise_on_release_error self.local = threading.local() if self.thread_local else SimpleNamespace() self.local.token = None self.register_scripts() @@ -168,7 +178,14 @@ def __exit__( exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: - self.release() + try: + self.release() + except LockError: + if self.raise_on_release_error: + raise + logger.warning( + "Lock was unlocked or no longer owned when exiting context manager." + ) def acquire( self, diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index 9973ef701f..be4270acdf 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -129,6 +129,36 @@ async def test_context_manager_raises_when_locked_not_acquired(self, r): async with self.get_lock(r, "foo", blocking_timeout=0.1): pass + async def test_context_manager_not_raise_on_release_lock_not_owned_error(self, r): + try: + async with self.get_lock( + r, "foo", timeout=0.1, raise_on_release_error=False + ): + await asyncio.sleep(0.15) + except LockNotOwnedError: + pytest.fail("LockNotOwnedError should not have been raised") + + with pytest.raises(LockNotOwnedError): + async with self.get_lock( + r, "foo", timeout=0.1, raise_on_release_error=True + ): + await asyncio.sleep(0.15) + + async def test_context_manager_not_raise_on_release_lock_error(self, r): + try: + async with self.get_lock( + r, "foo", timeout=0.1, raise_on_release_error=False + ) as lock: + lock.release() + except LockError: + pytest.fail("LockError should not have been raised") + + with pytest.raises(LockError): + async with self.get_lock( + r, "foo", timeout=0.1, raise_on_release_error=True + ) as lock: + lock.release() + async def test_high_sleep_small_blocking_timeout(self, r): lock1 = self.get_lock(r, "foo") assert await lock1.acquire(blocking=False) diff --git a/tests/test_lock.py b/tests/test_lock.py index 136c86e459..3d6d81465e 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -133,6 +133,32 @@ def test_context_manager_raises_when_locked_not_acquired(self, r): with self.get_lock(r, "foo", blocking_timeout=0.1): pass + def test_context_manager_not_raise_on_release_lock_not_owned_error(self, r): + try: + with self.get_lock(r, "foo", timeout=0.1, raise_on_release_error=False): + time.sleep(0.15) + except LockNotOwnedError: + pytest.fail("LockNotOwnedError should not have been raised") + + with pytest.raises(LockNotOwnedError): + with self.get_lock(r, "foo", timeout=0.1, raise_on_release_error=True): + time.sleep(0.15) + + def test_context_manager_not_raise_on_release_lock_error(self, r): + try: + with self.get_lock( + r, "foo", timeout=0.1, raise_on_release_error=False + ) as lock: + lock.release() + except LockError: + pytest.fail("LockError should not have been raised") + + with pytest.raises(LockError): + with self.get_lock( + r, "foo", timeout=0.1, raise_on_release_error=True + ) as lock: + lock.release() + def test_high_sleep_small_blocking_timeout(self, r): lock1 = self.get_lock(r, "foo") assert lock1.acquire(blocking=False) From 333fd8fb7feaf8b50ecbe5ff058127cd89617150 Mon Sep 17 00:00:00 2001 From: Vladimir Chebotarev Date: Mon, 10 Mar 2025 13:23:13 +0300 Subject: [PATCH 058/113] Got rid of `time.time()`- replacing it with time.monotonic(). 
(#3551) --- redis/client.py | 6 +++--- redis/commands/search/commands.py | 12 ++++++------ redis/connection.py | 6 +++--- tests/conftest.py | 4 ++-- tests/test_connection_pool.py | 14 +++++++------- tests/test_pubsub.py | 4 ++-- 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/redis/client.py b/redis/client.py index 2ba96bd6f9..ea29a864ce 100755 --- a/redis/client.py +++ b/redis/client.py @@ -957,7 +957,7 @@ def check_health(self) -> None: "did you forget to call subscribe() or psubscribe()?" ) - if conn.health_check_interval and time.time() > conn.next_health_check: + if conn.health_check_interval and time.monotonic() > conn.next_health_check: conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False) self.health_check_response_counter += 1 @@ -1107,12 +1107,12 @@ def get_message( """ if not self.subscribed: # Wait for subscription - start_time = time.time() + start_time = time.monotonic() if self.subscribed_event.wait(timeout) is True: # The connection was subscribed during the timeout time frame. # The timeout should be adjusted based on the time spent # waiting for the subscription - time_spent = time.time() - start_time + time_spent = time.monotonic() - start_time timeout = max(0.0, timeout - time_spent) else: # The connection isn't subscribed to any channels or patterns, diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 1db57c23a5..96c6d9c2af 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -500,7 +500,7 @@ def search( For more information see `FT.SEARCH `_. """ # noqa args, query = self._mk_query_args(query, query_params=query_params) - st = time.time() + st = time.monotonic() options = {} if get_protocol_version(self.client) not in ["3", 3]: @@ -512,7 +512,7 @@ def search( return res return self._parse_results( - SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0 + SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0 ) def explain( @@ -602,7 +602,7 @@ def profile( Each parameter has a name and a value. """ - st = time.time() + st = time.monotonic() cmd = [PROFILE_CMD, self.index_name, ""] if limited: cmd.append("LIMITED") @@ -621,7 +621,7 @@ def profile( res = self.execute_command(*cmd) return self._parse_results( - PROFILE_CMD, res, query=query, duration=(time.time() - st) * 1000.0 + PROFILE_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0 ) def spellcheck(self, query, distance=None, include=None, exclude=None): @@ -940,7 +940,7 @@ async def search( For more information see `FT.SEARCH `_. 
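[Editor's note, not part of the patch] The motivation for replacing ``time.time()`` with ``time.monotonic()`` throughout these hunks is that a monotonic clock never jumps when the wall clock is adjusted (NTP sync, DST, manual changes), so measured intervals such as query durations and health-check deadlines stay correct. A minimal sketch of the measurement pattern, where ``do_work`` is a hypothetical stand-in for the call being timed:

    import time

    def do_work():
        time.sleep(0.05)  # hypothetical stand-in for the measured call

    st = time.monotonic()
    do_work()
    duration_ms = (time.monotonic() - st) * 1000.0
    print(f"took {duration_ms:.2f} ms")  # unaffected by wall-clock adjustments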
""" # noqa args, query = self._mk_query_args(query, query_params=query_params) - st = time.time() + st = time.monotonic() options = {} if get_protocol_version(self.client) not in ["3", 3]: @@ -952,7 +952,7 @@ async def search( return res return self._parse_results( - SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0 + SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0 ) async def aggregate( diff --git a/redis/connection.py b/redis/connection.py index 43501800c8..a298542c03 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -4,11 +4,11 @@ import ssl import sys import threading +import time import weakref from abc import abstractmethod from itertools import chain from queue import Empty, Full, LifoQueue -from time import time from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union from urllib.parse import parse_qs, unquote, urlparse @@ -542,7 +542,7 @@ def _ping_failed(self, error): def check_health(self): """Check the health of the connection with a PING/PONG""" - if self.health_check_interval and time() > self.next_health_check: + if self.health_check_interval and time.monotonic() > self.next_health_check: self.retry.call_with_retry(self._send_ping, self._ping_failed) def send_packed_command(self, command, check_health=True): @@ -632,7 +632,7 @@ def read_response( raise if self.health_check_interval: - self.next_health_check = time() + self.health_check_interval + self.next_health_check = time.monotonic() + self.health_check_interval if isinstance(response, ResponseError): try: diff --git a/tests/conftest.py b/tests/conftest.py index c485d626ca..7eaccb1acb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -237,7 +237,7 @@ def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=60): :param cluster_nodes: The number of nodes in the cluster :param timeout: the amount of time to wait (in seconds) """ - now = time.time() + now = time.monotonic() end_time = now + timeout client = None print(f"Waiting for {cluster_nodes} cluster nodes to become available") @@ -250,7 +250,7 @@ def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=60): except RedisClusterException: pass time.sleep(1) - now = time.time() + now = time.monotonic() if now >= end_time: available_nodes = 0 if client is None else len(client.get_nodes()) raise RedisClusterException( diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 65f42923fe..c92d84c226 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -167,11 +167,11 @@ def test_connection_pool_blocks_until_timeout(self, master_host): ) pool.get_connection() - start = time.time() + start = time.monotonic() with pytest.raises(redis.ConnectionError): pool.get_connection() # we should have waited at least 0.1 seconds - assert time.time() - start >= 0.1 + assert time.monotonic() - start >= 0.1 def test_connection_pool_blocks_until_conn_available(self, master_host): """ @@ -188,10 +188,10 @@ def target(): time.sleep(0.1) pool.release(c1) - start = time.time() + start = time.monotonic() Thread(target=target).start() pool.get_connection() - assert time.time() - start >= 0.1 + assert time.monotonic() - start >= 0.1 def test_reuse_previously_released_connection(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} @@ -679,18 +679,18 @@ def r(self, request): return _get_client(redis.Redis, request, health_check_interval=self.interval) def assert_interval_advanced(self, connection): - diff = 
connection.next_health_check - time.time() + diff = connection.next_health_check - time.monotonic() assert self.interval > diff > (self.interval - 1) def test_health_check_runs(self, r): - r.connection.next_health_check = time.time() - 1 + r.connection.next_health_check = time.monotonic() - 1 r.connection.check_health() self.assert_interval_advanced(r.connection) def test_arbitrary_command_invokes_health_check(self, r): # invoke a command to make sure the connection is entirely setup r.get("foo") - r.connection.next_health_check = time.time() + r.connection.next_health_check = time.monotonic() with mock.patch.object( r.connection, "send_command", wraps=r.connection.send_command ) as m: diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index fb46772af3..9ead455af3 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -23,7 +23,7 @@ def wait_for_message( pubsub, timeout=0.5, ignore_subscribe_messages=False, node=None, func=None ): - now = time.time() + now = time.monotonic() timeout = now + timeout while now < timeout: if node: @@ -39,7 +39,7 @@ def wait_for_message( if message is not None: return message time.sleep(0.01) - now = time.time() + now = time.monotonic() return None From 83db949719aa4c4e624c2cd79c79454a1c94d940 Mon Sep 17 00:00:00 2001 From: David Hotham Date: Mon, 10 Mar 2025 11:48:06 +0000 Subject: [PATCH 059/113] allow more recent pyopenssl (#3541) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 22f18bb1ae..9c868be4b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ hiredis = [ ] ocsp = [ "cryptography>=36.0.1", - "pyopenssl==20.0.1", + "pyopenssl>=20.0.1", "requests>=2.31.0", ] jwt = [ From d30ebd1ef6dd4a021f8b1c420de2f39b1d99ea27 Mon Sep 17 00:00:00 2001 From: byeongjulee222 <52685247+byeongjulee222@users.noreply.github.com> Date: Mon, 10 Mar 2025 23:48:58 +0900 Subject: [PATCH 060/113] Fix #3464: Correct misleading exception_handler example in docs (#3474) --- docs/advanced_features.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/advanced_features.rst b/docs/advanced_features.rst index de645bd764..0ed3e1ff34 100644 --- a/docs/advanced_features.rst +++ b/docs/advanced_features.rst @@ -380,8 +380,6 @@ run_in_thread. >>> def exception_handler(ex, pubsub, thread): >>> print(ex) >>> thread.stop() - >>> thread.join(timeout=1.0) - >>> pubsub.close() >>> thread = p.run_in_thread(exception_handler=exception_handler) A PubSub object adheres to the same encoding semantics as the client From dc8359f7e4cd8a674d8acb5d7e174e4ec3ebbbda Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Tue, 11 Mar 2025 12:50:45 +0200 Subject: [PATCH 061/113] Removing support for RedisGraph module. 
(#3548) --- .github/wordlist.txt | 2 - .github/workflows/install_and_test.sh | 4 +- CHANGES | 1 + docker-compose.yml | 8 - docs/redismodules.rst | 33 -- pyproject.toml | 57 +-- redis/cluster.py | 1 - redis/commands/graph/__init__.py | 263 ---------- redis/commands/graph/commands.py | 311 ------------ redis/commands/graph/edge.py | 91 ---- redis/commands/graph/exceptions.py | 3 - redis/commands/graph/execution_plan.py | 211 -------- redis/commands/graph/node.py | 88 ---- redis/commands/graph/path.py | 78 --- redis/commands/graph/query_result.py | 588 ---------------------- redis/commands/helpers.py | 47 -- redis/commands/redismodules.py | 20 - tasks.py | 8 +- tests/test_asyncio/test_graph.py | 526 -------------------- tests/test_graph.py | 656 ------------------------- tests/test_graph_utils/__init__.py | 0 tests/test_graph_utils/test_edge.py | 75 --- tests/test_graph_utils/test_node.py | 51 -- tests/test_graph_utils/test_path.py | 90 ---- tests/test_helpers.py | 13 - 25 files changed, 23 insertions(+), 3202 deletions(-) delete mode 100644 redis/commands/graph/__init__.py delete mode 100644 redis/commands/graph/commands.py delete mode 100644 redis/commands/graph/edge.py delete mode 100644 redis/commands/graph/exceptions.py delete mode 100644 redis/commands/graph/execution_plan.py delete mode 100644 redis/commands/graph/node.py delete mode 100644 redis/commands/graph/path.py delete mode 100644 redis/commands/graph/query_result.py delete mode 100644 tests/test_asyncio/test_graph.py delete mode 100644 tests/test_graph.py delete mode 100644 tests/test_graph_utils/__init__.py delete mode 100644 tests/test_graph_utils/test_edge.py delete mode 100644 tests/test_graph_utils/test_node.py delete mode 100644 tests/test_graph_utils/test_path.py diff --git a/.github/wordlist.txt b/.github/wordlist.txt index 3ea543748e..29bcaa9d77 100644 --- a/.github/wordlist.txt +++ b/.github/wordlist.txt @@ -12,7 +12,6 @@ ConnectionPool CoreCommands EVAL EVALSHA -GraphCommands Grokzen's INCR IOError @@ -39,7 +38,6 @@ RedisCluster RedisClusterCommands RedisClusterException RedisClusters -RedisGraph RedisInstrumentor RedisJSON RedisTimeSeries diff --git a/.github/workflows/install_and_test.sh b/.github/workflows/install_and_test.sh index 778dbe0b20..e647126539 100755 --- a/.github/workflows/install_and_test.sh +++ b/.github/workflows/install_and_test.sh @@ -40,9 +40,9 @@ cd ${TESTDIR} # install, run tests pip install ${PKG} # Redis tests -pytest -m 'not onlycluster and not graph' +pytest -m 'not onlycluster' # RedisCluster tests CLUSTER_URL="redis://localhost:16379/0" CLUSTER_SSL_URL="rediss://localhost:27379/0" -pytest -m 'not onlynoncluster and not redismod and not ssl and not graph' \ +pytest -m 'not onlynoncluster and not redismod and not ssl' \ --redis-url="${CLUSTER_URL}" --redis-ssl-url="${CLUSTER_SSL_URL}" diff --git a/CHANGES b/CHANGES index 031d909f23..24b52c54db 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Removing support for RedisGraph module. 
RedisGraph support is deprecated since Redis Stack 7.2 (https://redis.com/blog/redisgraph-eol/) * Fix lock.extend() typedef to accept float TTL extension * Update URL in the readme linking to Redis University * Move doctests (doc code examples) to main branch diff --git a/docker-compose.yml b/docker-compose.yml index 8ca3471311..76a60398f3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -105,11 +105,3 @@ services: - standalone - all-stack - all - - redis-stack-graph: - image: redis/redis-stack-server:6.2.6-v15 - container_name: redis-stack-graph - ports: - - 6480:6379 - profiles: - - graph diff --git a/docs/redismodules.rst b/docs/redismodules.rst index 27757cb692..07914fff12 100644 --- a/docs/redismodules.rst +++ b/docs/redismodules.rst @@ -51,39 +51,6 @@ These are the commands for interacting with the `RedisBloom module `_. Below is a brief example, as well as documentation on the commands themselves. - -**Create a graph, adding two nodes** - -.. code-block:: python - - import redis - from redis.graph.node import Node - - john = Node(label="person", properties={"name": "John Doe", "age": 33} - jane = Node(label="person", properties={"name": "Jane Doe", "age": 34} - - r = redis.Redis() - graph = r.graph() - graph.add_node(john) - graph.add_node(jane) - graph.add_node(pat) - graph.commit() - -.. automodule:: redis.commands.graph.node - :members: Node - -.. automodule:: redis.commands.graph.edge - :members: Edge - -.. automodule:: redis.commands.graph.commands - :members: GraphCommands - ------- - RedisJSON Commands ****************** diff --git a/pyproject.toml b/pyproject.toml index 9c868be4b7..ab3e4cd77e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,14 +9,8 @@ description = "Python client for Redis database and key-value store" readme = "README.md" license = "MIT" requires-python = ">=3.8" -authors = [ - { name = "Redis Inc.", email = "oss@redis.com" }, -] -keywords = [ - "Redis", - "database", - "key-value-store", -] +authors = [{ name = "Redis Inc.", email = "oss@redis.com" }] +keywords = ["Redis", "database", "key-value-store"] classifiers = [ "Development Status :: 5 - Production/Stable", "Environment :: Console", @@ -35,9 +29,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - 'async-timeout>=4.0.3; python_full_version<"3.11.3"', -] +dependencies = ['async-timeout>=4.0.3; python_full_version<"3.11.3"'] [project.optional-dependencies] hiredis = [ @@ -63,22 +55,15 @@ Homepage = "https://github.com/redis/redis-py" path = "redis/__init__.py" [tool.hatch.build.targets.sdist] -include = [ - "/redis", - "/tests", - "dev_requirements.txt", -] +include = ["/redis", "/tests", "dev_requirements.txt"] [tool.hatch.build.targets.wheel] -include = [ - "/redis", -] +include = ["/redis"] [tool.pytest.ini_options] addopts = "-s" markers = [ "redismod: run only the redis module tests", - "graph: run only the redisgraph tests", "pipeline: pipeline tests", "onlycluster: marks tests to be run only with cluster mode redis", "onlynoncluster: marks tests to be run only with standalone redis", @@ -93,7 +78,6 @@ asyncio_mode = "auto" timeout = 30 filterwarnings = [ "always", - "ignore:RedisGraph support is deprecated as of Redis Stack 7.2:DeprecationWarning", # Ignore a coverage warning when COVERAGE_CORE=sysmon for Pythons < 3.12. 
"ignore:sys.monitoring isn't available:coverage.exceptions.CoverageWarning", ] @@ -118,32 +102,23 @@ exclude = [ [tool.ruff.lint] ignore = [ - "E501", # line too long (taken care of with ruff format) - "E741", # ambiguous variable name - "N818", # Errors should have Error suffix + "E501", # line too long (taken care of with ruff format) + "E741", # ambiguous variable name + "N818", # Errors should have Error suffix ] -select = [ - "E", - "F", - "FLY", - "I", - "N", - "W", -] +select = ["E", "F", "FLY", "I", "N", "W"] [tool.ruff.lint.per-file-ignores] "redis/commands/bf/*" = [ # the `bf` module uses star imports, so this is required there. - "F405", # name may be undefined, or defined from star imports -] -"redis/commands/{bf,timeseries,json,search}/*" = [ - "N", + "F405", # name may be undefined, or defined from star imports ] +"redis/commands/{bf,timeseries,json,search}/*" = ["N"] "tests/*" = [ - "I", # TODO: could be enabled, plenty of changes - "N801", # class name should use CapWords convention - "N803", # argument name should be lowercase - "N802", # function name should be lowercase - "N806", # variable name should be lowercase + "I", # TODO: could be enabled, plenty of changes + "N801", # class name should use CapWords convention + "N803", # argument name should be lowercase + "N802", # function name should be lowercase + "N806", # variable name should be lowercase ] diff --git a/redis/cluster.py b/redis/cluster.py index c9523e2a76..13253ec896 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -288,7 +288,6 @@ class AbstractRedisCluster: "TFUNCTION LIST", "TFCALL", "TFCALLASYNC", - "GRAPH.CONFIG", "LATENCY HISTORY", "LATENCY LATEST", "LATENCY RESET", diff --git a/redis/commands/graph/__init__.py b/redis/commands/graph/__init__.py deleted file mode 100644 index ddc0e34f4c..0000000000 --- a/redis/commands/graph/__init__.py +++ /dev/null @@ -1,263 +0,0 @@ -import warnings - -from ..helpers import quote_string, random_string, stringify_param_value -from .commands import AsyncGraphCommands, GraphCommands -from .edge import Edge # noqa -from .node import Node # noqa -from .path import Path # noqa - -DB_LABELS = "DB.LABELS" -DB_RAELATIONSHIPTYPES = "DB.RELATIONSHIPTYPES" -DB_PROPERTYKEYS = "DB.PROPERTYKEYS" - - -class Graph(GraphCommands): - """ - Graph, collection of nodes and edges. - """ - - def __init__(self, client, name=random_string()): - """ - Create a new graph. - """ - warnings.warn( - DeprecationWarning( - "RedisGraph support is deprecated as of Redis Stack 7.2 \ - (https://redis.com/blog/redisgraph-eol/)" - ) - ) - self.NAME = name # Graph key - self.client = client - self.execute_command = client.execute_command - - self.nodes = {} - self.edges = [] - self._labels = [] # List of node labels. - self._properties = [] # List of properties. - self._relationship_types = [] # List of relation types. - self.version = 0 # Graph version - - @property - def name(self): - return self.NAME - - def _clear_schema(self): - self._labels = [] - self._properties = [] - self._relationship_types = [] - - def _refresh_schema(self): - self._clear_schema() - self._refresh_labels() - self._refresh_relations() - self._refresh_attributes() - - def _refresh_labels(self): - lbls = self.labels() - - # Unpack data. - self._labels = [l[0] for _, l in enumerate(lbls)] - - def _refresh_relations(self): - rels = self.relationship_types() - - # Unpack data. - self._relationship_types = [r[0] for _, r in enumerate(rels)] - - def _refresh_attributes(self): - props = self.property_keys() - - # Unpack data. 
- self._properties = [p[0] for _, p in enumerate(props)] - - def get_label(self, idx): - """ - Returns a label by it's index - - Args: - - idx: - The index of the label - """ - try: - label = self._labels[idx] - except IndexError: - # Refresh labels. - self._refresh_labels() - label = self._labels[idx] - return label - - def get_relation(self, idx): - """ - Returns a relationship type by it's index - - Args: - - idx: - The index of the relation - """ - try: - relationship_type = self._relationship_types[idx] - except IndexError: - # Refresh relationship types. - self._refresh_relations() - relationship_type = self._relationship_types[idx] - return relationship_type - - def get_property(self, idx): - """ - Returns a property by it's index - - Args: - - idx: - The index of the property - """ - try: - p = self._properties[idx] - except IndexError: - # Refresh properties. - self._refresh_attributes() - p = self._properties[idx] - return p - - def add_node(self, node): - """ - Adds a node to the graph. - """ - if node.alias is None: - node.alias = random_string() - self.nodes[node.alias] = node - - def add_edge(self, edge): - """ - Adds an edge to the graph. - """ - if not (self.nodes[edge.src_node.alias] and self.nodes[edge.dest_node.alias]): - raise AssertionError("Both edge's end must be in the graph") - - self.edges.append(edge) - - def _build_params_header(self, params): - if params is None: - return "" - if not isinstance(params, dict): - raise TypeError("'params' must be a dict") - # Header starts with "CYPHER" - params_header = "CYPHER " - for key, value in params.items(): - params_header += str(key) + "=" + stringify_param_value(value) + " " - return params_header - - # Procedures. - def call_procedure(self, procedure, *args, read_only=False, **kwagrs): - args = [quote_string(arg) for arg in args] - q = f"CALL {procedure}({','.join(args)})" - - y = kwagrs.get("y", None) - if y is not None: - q += f"YIELD {','.join(y)}" - - return self.query(q, read_only=read_only) - - def labels(self): - return self.call_procedure(DB_LABELS, read_only=True).result_set - - def relationship_types(self): - return self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True).result_set - - def property_keys(self): - return self.call_procedure(DB_PROPERTYKEYS, read_only=True).result_set - - -class AsyncGraph(Graph, AsyncGraphCommands): - """Async version for Graph""" - - async def _refresh_labels(self): - lbls = await self.labels() - - # Unpack data. - self._labels = [l[0] for _, l in enumerate(lbls)] - - async def _refresh_attributes(self): - props = await self.property_keys() - - # Unpack data. - self._properties = [p[0] for _, p in enumerate(props)] - - async def _refresh_relations(self): - rels = await self.relationship_types() - - # Unpack data. - self._relationship_types = [r[0] for _, r in enumerate(rels)] - - async def get_label(self, idx): - """ - Returns a label by it's index - - Args: - - idx: - The index of the label - """ - try: - label = self._labels[idx] - except IndexError: - # Refresh labels. - await self._refresh_labels() - label = self._labels[idx] - return label - - async def get_property(self, idx): - """ - Returns a property by it's index - - Args: - - idx: - The index of the property - """ - try: - p = self._properties[idx] - except IndexError: - # Refresh properties. 
- await self._refresh_attributes() - p = self._properties[idx] - return p - - async def get_relation(self, idx): - """ - Returns a relationship type by it's index - - Args: - - idx: - The index of the relation - """ - try: - relationship_type = self._relationship_types[idx] - except IndexError: - # Refresh relationship types. - await self._refresh_relations() - relationship_type = self._relationship_types[idx] - return relationship_type - - async def call_procedure(self, procedure, *args, read_only=False, **kwagrs): - args = [quote_string(arg) for arg in args] - q = f"CALL {procedure}({','.join(args)})" - - y = kwagrs.get("y", None) - if y is not None: - f"YIELD {','.join(y)}" - return await self.query(q, read_only=read_only) - - async def labels(self): - return (await self.call_procedure(DB_LABELS, read_only=True)).result_set - - async def property_keys(self): - return (await self.call_procedure(DB_PROPERTYKEYS, read_only=True)).result_set - - async def relationship_types(self): - return ( - await self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True) - ).result_set diff --git a/redis/commands/graph/commands.py b/redis/commands/graph/commands.py deleted file mode 100644 index 1e41a5fb1f..0000000000 --- a/redis/commands/graph/commands.py +++ /dev/null @@ -1,311 +0,0 @@ -from redis import DataError -from redis.exceptions import ResponseError - -from .exceptions import VersionMismatchException -from .execution_plan import ExecutionPlan -from .query_result import AsyncQueryResult, QueryResult - -PROFILE_CMD = "GRAPH.PROFILE" -RO_QUERY_CMD = "GRAPH.RO_QUERY" -QUERY_CMD = "GRAPH.QUERY" -DELETE_CMD = "GRAPH.DELETE" -SLOWLOG_CMD = "GRAPH.SLOWLOG" -CONFIG_CMD = "GRAPH.CONFIG" -LIST_CMD = "GRAPH.LIST" -EXPLAIN_CMD = "GRAPH.EXPLAIN" - - -class GraphCommands: - """RedisGraph Commands""" - - def commit(self): - """ - Create entire graph. - """ - if len(self.nodes) == 0 and len(self.edges) == 0: - return None - - query = "CREATE " - for _, node in self.nodes.items(): - query += str(node) + "," - - query += ",".join([str(edge) for edge in self.edges]) - - # Discard leading comma. - if query[-1] == ",": - query = query[:-1] - - return self.query(query) - - def query(self, q, params=None, timeout=None, read_only=False, profile=False): - """ - Executes a query against the graph. - For more information see `GRAPH.QUERY `_. # noqa - - Args: - - q : str - The query. - params : dict - Query parameters. - timeout : int - Maximum runtime for read queries in milliseconds. - read_only : bool - Executes a readonly query if set to True. - profile : bool - Return details on results produced by and time - spent in each operation. - """ - - # maintain original 'q' - query = q - - # handle query parameters - query = self._build_params_header(params) + query - - # construct query command - # ask for compact result-set format - # specify known graph version - if profile: - cmd = PROFILE_CMD - else: - cmd = RO_QUERY_CMD if read_only else QUERY_CMD - command = [cmd, self.name, query, "--compact"] - - # include timeout is specified - if isinstance(timeout, int): - command.extend(["timeout", timeout]) - elif timeout is not None: - raise Exception("Timeout argument must be a positive integer") - - # issue query - try: - response = self.execute_command(*command) - return QueryResult(self, response, profile) - except ResponseError as e: - if "unknown command" in str(e) and read_only: - # `GRAPH.RO_QUERY` is unavailable in older versions. 
- return self.query(q, params, timeout, read_only=False) - raise e - except VersionMismatchException as e: - # client view over the graph schema is out of sync - # set client version and refresh local schema - self.version = e.version - self._refresh_schema() - # re-issue query - return self.query(q, params, timeout, read_only) - - def merge(self, pattern): - """ - Merge pattern. - """ - query = "MERGE " - query += str(pattern) - - return self.query(query) - - def delete(self): - """ - Deletes graph. - For more information see `DELETE `_. # noqa - """ - self._clear_schema() - return self.execute_command(DELETE_CMD, self.name) - - # declared here, to override the built in redis.db.flush() - def flush(self): - """ - Commit the graph and reset the edges and the nodes to zero length. - """ - self.commit() - self.nodes = {} - self.edges = [] - - def bulk(self, **kwargs): - """Internal only. Not supported.""" - raise NotImplementedError( - "GRAPH.BULK is internal only. " - "Use https://github.com/redisgraph/redisgraph-bulk-loader." - ) - - def profile(self, query): - """ - Execute a query and produce an execution plan augmented with metrics - for each operation's execution. Return a string representation of a - query execution plan, with details on results produced by and time - spent in each operation. - For more information see `GRAPH.PROFILE `_. # noqa - """ - return self.query(query, profile=True) - - def slowlog(self): - """ - Get a list containing up to 10 of the slowest queries issued - against the given graph ID. - For more information see `GRAPH.SLOWLOG `_. # noqa - - Each item in the list has the following structure: - 1. A unix timestamp at which the log entry was processed. - 2. The issued command. - 3. The issued query. - 4. The amount of time needed for its execution, in milliseconds. - """ - return self.execute_command(SLOWLOG_CMD, self.name) - - def config(self, name, value=None, set=False): - """ - Retrieve or update a RedisGraph configuration. - For more information see ``__. - - Args: - - name : str - The name of the configuration - value : - The value we want to set (can be used only when `set` is on) - set : bool - Turn on to set a configuration. Default behavior is get. - """ - params = ["SET" if set else "GET", name] - if value is not None: - if set: - params.append(value) - else: - raise DataError("``value`` can be provided only when ``set`` is True") # noqa - return self.execute_command(CONFIG_CMD, *params) - - def list_keys(self): - """ - Lists all graph keys in the keyspace. - For more information see `GRAPH.LIST `_. # noqa - """ - return self.execute_command(LIST_CMD) - - def execution_plan(self, query, params=None): - """ - Get the execution plan for given query, - GRAPH.EXPLAIN returns an array of operations. - - Args: - query: the query that will be executed - params: query parameters - """ - query = self._build_params_header(params) + query - - plan = self.execute_command(EXPLAIN_CMD, self.name, query) - if isinstance(plan[0], bytes): - plan = [b.decode() for b in plan] - return "\n".join(plan) - - def explain(self, query, params=None): - """ - Get the execution plan for given query, - GRAPH.EXPLAIN returns ExecutionPlan object. - For more information see `GRAPH.EXPLAIN `_. 
# noqa - - Args: - query: the query that will be executed - params: query parameters - """ - query = self._build_params_header(params) + query - - plan = self.execute_command(EXPLAIN_CMD, self.name, query) - return ExecutionPlan(plan) - - -class AsyncGraphCommands(GraphCommands): - async def query(self, q, params=None, timeout=None, read_only=False, profile=False): - """ - Executes a query against the graph. - For more information see `GRAPH.QUERY `_. # noqa - - Args: - - q : str - The query. - params : dict - Query parameters. - timeout : int - Maximum runtime for read queries in milliseconds. - read_only : bool - Executes a readonly query if set to True. - profile : bool - Return details on results produced by and time - spent in each operation. - """ - - # maintain original 'q' - query = q - - # handle query parameters - query = self._build_params_header(params) + query - - # construct query command - # ask for compact result-set format - # specify known graph version - if profile: - cmd = PROFILE_CMD - else: - cmd = RO_QUERY_CMD if read_only else QUERY_CMD - command = [cmd, self.name, query, "--compact"] - - # include timeout is specified - if isinstance(timeout, int): - command.extend(["timeout", timeout]) - elif timeout is not None: - raise Exception("Timeout argument must be a positive integer") - - # issue query - try: - response = await self.execute_command(*command) - return await AsyncQueryResult().initialize(self, response, profile) - except ResponseError as e: - if "unknown command" in str(e) and read_only: - # `GRAPH.RO_QUERY` is unavailable in older versions. - return await self.query(q, params, timeout, read_only=False) - raise e - except VersionMismatchException as e: - # client view over the graph schema is out of sync - # set client version and refresh local schema - self.version = e.version - self._refresh_schema() - # re-issue query - return await self.query(q, params, timeout, read_only) - - async def execution_plan(self, query, params=None): - """ - Get the execution plan for given query, - GRAPH.EXPLAIN returns an array of operations. - - Args: - query: the query that will be executed - params: query parameters - """ - query = self._build_params_header(params) + query - - plan = await self.execute_command(EXPLAIN_CMD, self.name, query) - if isinstance(plan[0], bytes): - plan = [b.decode() for b in plan] - return "\n".join(plan) - - async def explain(self, query, params=None): - """ - Get the execution plan for given query, - GRAPH.EXPLAIN returns ExecutionPlan object. - - Args: - query: the query that will be executed - params: query parameters - """ - query = self._build_params_header(params) + query - - plan = await self.execute_command(EXPLAIN_CMD, self.name, query) - return ExecutionPlan(plan) - - async def flush(self): - """ - Commit the graph and reset the edges and the nodes to zero length. - """ - await self.commit() - self.nodes = {} - self.edges = [] diff --git a/redis/commands/graph/edge.py b/redis/commands/graph/edge.py deleted file mode 100644 index 6ee195f1f5..0000000000 --- a/redis/commands/graph/edge.py +++ /dev/null @@ -1,91 +0,0 @@ -from ..helpers import quote_string -from .node import Node - - -class Edge: - """ - An edge connecting two nodes. - """ - - def __init__(self, src_node, relation, dest_node, edge_id=None, properties=None): - """ - Create a new edge. 
- """ - if src_node is None or dest_node is None: - # NOTE(bors-42): It makes sense to change AssertionError to - # ValueError here - raise AssertionError("Both src_node & dest_node must be provided") - - self.id = edge_id - self.relation = relation or "" - self.properties = properties or {} - self.src_node = src_node - self.dest_node = dest_node - - def to_string(self): - res = "" - if self.properties: - props = ",".join( - key + ":" + str(quote_string(val)) - for key, val in sorted(self.properties.items()) - ) - res += "{" + props + "}" - - return res - - def __str__(self): - # Source node. - if isinstance(self.src_node, Node): - res = str(self.src_node) - else: - res = "()" - - # Edge - res += "-[" - if self.relation: - res += ":" + self.relation - if self.properties: - props = ",".join( - key + ":" + str(quote_string(val)) - for key, val in sorted(self.properties.items()) - ) - res += "{" + props + "}" - res += "]->" - - # Dest node. - if isinstance(self.dest_node, Node): - res += str(self.dest_node) - else: - res += "()" - - return res - - def __eq__(self, rhs): - # Type checking - if not isinstance(rhs, Edge): - return False - - # Quick positive check, if both IDs are set. - if self.id is not None and rhs.id is not None and self.id == rhs.id: - return True - - # Source and destination nodes should match. - if self.src_node != rhs.src_node: - return False - - if self.dest_node != rhs.dest_node: - return False - - # Relation should match. - if self.relation != rhs.relation: - return False - - # Quick check for number of properties. - if len(self.properties) != len(rhs.properties): - return False - - # Compare properties. - if self.properties != rhs.properties: - return False - - return True diff --git a/redis/commands/graph/exceptions.py b/redis/commands/graph/exceptions.py deleted file mode 100644 index 4bbac1008e..0000000000 --- a/redis/commands/graph/exceptions.py +++ /dev/null @@ -1,3 +0,0 @@ -class VersionMismatchException(Exception): - def __init__(self, version): - self.version = version diff --git a/redis/commands/graph/execution_plan.py b/redis/commands/graph/execution_plan.py deleted file mode 100644 index 179a80cca0..0000000000 --- a/redis/commands/graph/execution_plan.py +++ /dev/null @@ -1,211 +0,0 @@ -import re - - -class ProfileStats: - """ - ProfileStats, runtime execution statistics of operation. - """ - - def __init__(self, records_produced, execution_time): - self.records_produced = records_produced - self.execution_time = execution_time - - -class Operation: - """ - Operation, single operation within execution plan. - """ - - def __init__(self, name, args=None, profile_stats=None): - """ - Create a new operation. - - Args: - name: string that represents the name of the operation - args: operation arguments - profile_stats: profile statistics - """ - self.name = name - self.args = args - self.profile_stats = profile_stats - self.children = [] - - def append_child(self, child): - if not isinstance(child, Operation) or self is child: - raise Exception("child must be Operation") - - self.children.append(child) - return self - - def child_count(self): - return len(self.children) - - def __eq__(self, o: object) -> bool: - if not isinstance(o, Operation): - return False - - return self.name == o.name and self.args == o.args - - def __str__(self) -> str: - args_str = "" if self.args is None else " | " + self.args - return f"{self.name}{args_str}" - - -class ExecutionPlan: - """ - ExecutionPlan, collection of operations. 
- """ - - def __init__(self, plan): - """ - Create a new execution plan. - - Args: - plan: array of strings that represents the collection operations - the output from GRAPH.EXPLAIN - """ - if not isinstance(plan, list): - raise Exception("plan must be an array") - - if isinstance(plan[0], bytes): - plan = [b.decode() for b in plan] - - self.plan = plan - self.structured_plan = self._operation_tree() - - def _compare_operations(self, root_a, root_b): - """ - Compare execution plan operation tree - - Return: True if operation trees are equal, False otherwise - """ - - # compare current root - if root_a != root_b: - return False - - # make sure root have the same number of children - if root_a.child_count() != root_b.child_count(): - return False - - # recursively compare children - for i in range(root_a.child_count()): - if not self._compare_operations(root_a.children[i], root_b.children[i]): - return False - - return True - - def __str__(self) -> str: - def aggraget_str(str_children): - return "\n".join( - [ - " " + line - for str_child in str_children - for line in str_child.splitlines() - ] - ) - - def combine_str(x, y): - return f"{x}\n{y}" - - return self._operation_traverse( - self.structured_plan, str, aggraget_str, combine_str - ) - - def __eq__(self, o: object) -> bool: - """Compares two execution plans - - Return: True if the two plans are equal False otherwise - """ - # make sure 'o' is an execution-plan - if not isinstance(o, ExecutionPlan): - return False - - # get root for both plans - root_a = self.structured_plan - root_b = o.structured_plan - - # compare execution trees - return self._compare_operations(root_a, root_b) - - def _operation_traverse(self, op, op_f, aggregate_f, combine_f): - """ - Traverse operation tree recursively applying functions - - Args: - op: operation to traverse - op_f: function applied for each operation - aggregate_f: aggregation function applied for all children of a single operation - combine_f: combine function applied for the operation result and the children result - """ # noqa - # apply op_f for each operation - op_res = op_f(op) - if len(op.children) == 0: - return op_res # no children return - else: - # apply _operation_traverse recursively - children = [ - self._operation_traverse(child, op_f, aggregate_f, combine_f) - for child in op.children - ] - # combine the operation result with the children aggregated result - return combine_f(op_res, aggregate_f(children)) - - def _operation_tree(self): - """Build the operation tree from the string representation""" - - # initial state - i = 0 - level = 0 - stack = [] - current = None - - def _create_operation(args): - profile_stats = None - name = args[0].strip() - args.pop(0) - if len(args) > 0 and "Records produced" in args[-1]: - records_produced = int( - re.search("Records produced: (\\d+)", args[-1]).group(1) - ) - execution_time = float( - re.search("Execution time: (\\d+.\\d+) ms", args[-1]).group(1) - ) - profile_stats = ProfileStats(records_produced, execution_time) - args.pop(-1) - return Operation( - name, None if len(args) == 0 else args[0].strip(), profile_stats - ) - - # iterate plan operations - while i < len(self.plan): - current_op = self.plan[i] - op_level = current_op.count(" ") - if op_level == level: - # if the operation level equal to the current level - # set the current operation and move next - child = _create_operation(current_op.split("|")) - if current: - current = stack.pop() - current.append_child(child) - current = child - i += 1 - elif op_level == level + 1: - # if 
the operation is child of the current operation - # add it as child and set as current operation - child = _create_operation(current_op.split("|")) - current.append_child(child) - stack.append(current) - current = child - level += 1 - i += 1 - elif op_level < level: - # if the operation is not child of current operation - # go back to it's parent operation - levels_back = level - op_level + 1 - for _ in range(levels_back): - current = stack.pop() - level -= levels_back - else: - raise Exception("corrupted plan") - return stack[0] diff --git a/redis/commands/graph/node.py b/redis/commands/graph/node.py deleted file mode 100644 index 4546a393b1..0000000000 --- a/redis/commands/graph/node.py +++ /dev/null @@ -1,88 +0,0 @@ -from ..helpers import quote_string - - -class Node: - """ - A node within the graph. - """ - - def __init__(self, node_id=None, alias=None, label=None, properties=None): - """ - Create a new node. - """ - self.id = node_id - self.alias = alias - if isinstance(label, list): - label = [inner_label for inner_label in label if inner_label != ""] - - if ( - label is None - or label == "" - or (isinstance(label, list) and len(label) == 0) - ): - self.label = None - self.labels = None - elif isinstance(label, str): - self.label = label - self.labels = [label] - elif isinstance(label, list) and all( - [isinstance(inner_label, str) for inner_label in label] - ): - self.label = label[0] - self.labels = label - else: - raise AssertionError( - "label should be either None, string or a list of strings" - ) - - self.properties = properties or {} - - def to_string(self): - res = "" - if self.properties: - props = ",".join( - key + ":" + str(quote_string(val)) - for key, val in sorted(self.properties.items()) - ) - res += "{" + props + "}" - - return res - - def __str__(self): - res = "(" - if self.alias: - res += self.alias - if self.labels: - res += ":" + ":".join(self.labels) - if self.properties: - props = ",".join( - key + ":" + str(quote_string(val)) - for key, val in sorted(self.properties.items()) - ) - res += "{" + props + "}" - res += ")" - - return res - - def __eq__(self, rhs): - # Type checking - if not isinstance(rhs, Node): - return False - - # Quick positive check, if both IDs are set. - if self.id is not None and rhs.id is not None and self.id != rhs.id: - return False - - # Label should match. - if self.label != rhs.label: - return False - - # Quick check for number of properties. - if len(self.properties) != len(rhs.properties): - return False - - # Compare properties. 
- if self.properties != rhs.properties: - return False - - return True diff --git a/redis/commands/graph/path.py b/redis/commands/graph/path.py deleted file mode 100644 index ee22dc8c6b..0000000000 --- a/redis/commands/graph/path.py +++ /dev/null @@ -1,78 +0,0 @@ -from .edge import Edge -from .node import Node - - -class Path: - def __init__(self, nodes, edges): - if not (isinstance(nodes, list) and isinstance(edges, list)): - raise TypeError("nodes and edges must be list") - - self._nodes = nodes - self._edges = edges - self.append_type = Node - - @classmethod - def new_empty_path(cls): - return cls([], []) - - def nodes(self): - return self._nodes - - def edges(self): - return self._edges - - def get_node(self, index): - return self._nodes[index] - - def get_relationship(self, index): - return self._edges[index] - - def first_node(self): - return self._nodes[0] - - def last_node(self): - return self._nodes[-1] - - def edge_count(self): - return len(self._edges) - - def nodes_count(self): - return len(self._nodes) - - def add_node(self, node): - if not isinstance(node, self.append_type): - raise AssertionError("Add Edge before adding Node") - self._nodes.append(node) - self.append_type = Edge - return self - - def add_edge(self, edge): - if not isinstance(edge, self.append_type): - raise AssertionError("Add Node before adding Edge") - self._edges.append(edge) - self.append_type = Node - return self - - def __eq__(self, other): - # Type checking - if not isinstance(other, Path): - return False - - return self.nodes() == other.nodes() and self.edges() == other.edges() - - def __str__(self): - res = "<" - edge_count = self.edge_count() - for i in range(0, edge_count): - node_id = self.get_node(i).id - res += "(" + str(node_id) + ")" - edge = self.get_relationship(i) - res += ( - "-[" + str(int(edge.id)) + "]->" - if edge.src_node == node_id - else "<-[" + str(int(edge.id)) + "]-" - ) - node_id = self.get_node(edge_count).id - res += "(" + str(node_id) + ")" - res += ">" - return res diff --git a/redis/commands/graph/query_result.py b/redis/commands/graph/query_result.py deleted file mode 100644 index 7709081bcf..0000000000 --- a/redis/commands/graph/query_result.py +++ /dev/null @@ -1,588 +0,0 @@ -import sys -from collections import OrderedDict - -# from prettytable import PrettyTable -from redis import ResponseError - -from .edge import Edge -from .exceptions import VersionMismatchException -from .node import Node -from .path import Path - -LABELS_ADDED = "Labels added" -LABELS_REMOVED = "Labels removed" -NODES_CREATED = "Nodes created" -NODES_DELETED = "Nodes deleted" -RELATIONSHIPS_DELETED = "Relationships deleted" -PROPERTIES_SET = "Properties set" -PROPERTIES_REMOVED = "Properties removed" -RELATIONSHIPS_CREATED = "Relationships created" -INDICES_CREATED = "Indices created" -INDICES_DELETED = "Indices deleted" -CACHED_EXECUTION = "Cached execution" -INTERNAL_EXECUTION_TIME = "internal execution time" - -STATS = [ - LABELS_ADDED, - LABELS_REMOVED, - NODES_CREATED, - PROPERTIES_SET, - PROPERTIES_REMOVED, - RELATIONSHIPS_CREATED, - NODES_DELETED, - RELATIONSHIPS_DELETED, - INDICES_CREATED, - INDICES_DELETED, - CACHED_EXECUTION, - INTERNAL_EXECUTION_TIME, -] - - -class ResultSetColumnTypes: - COLUMN_UNKNOWN = 0 - COLUMN_SCALAR = 1 - COLUMN_NODE = 2 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. # noqa - COLUMN_RELATION = 3 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. 
# noqa - - -class ResultSetScalarTypes: - VALUE_UNKNOWN = 0 - VALUE_NULL = 1 - VALUE_STRING = 2 - VALUE_INTEGER = 3 - VALUE_BOOLEAN = 4 - VALUE_DOUBLE = 5 - VALUE_ARRAY = 6 - VALUE_EDGE = 7 - VALUE_NODE = 8 - VALUE_PATH = 9 - VALUE_MAP = 10 - VALUE_POINT = 11 - - -class QueryResult: - def __init__(self, graph, response, profile=False): - """ - A class that represents a result of the query operation. - - Args: - - graph: - The graph on which the query was executed. - response: - The response from the server. - profile: - A boolean indicating if the query command was "GRAPH.PROFILE" - """ - self.graph = graph - self.header = [] - self.result_set = [] - - # in case of an error an exception will be raised - self._check_for_errors(response) - - if len(response) == 1: - self.parse_statistics(response[0]) - elif profile: - self.parse_profile(response) - else: - # start by parsing statistics, matches the one we have - self.parse_statistics(response[-1]) # Last element. - self.parse_results(response) - - def _check_for_errors(self, response): - """ - Check if the response contains an error. - """ - if isinstance(response[0], ResponseError): - error = response[0] - if str(error) == "version mismatch": - version = response[1] - error = VersionMismatchException(version) - raise error - - # If we encountered a run-time error, the last response - # element will be an exception - if isinstance(response[-1], ResponseError): - raise response[-1] - - def parse_results(self, raw_result_set): - """ - Parse the query execution result returned from the server. - """ - self.header = self.parse_header(raw_result_set) - - # Empty header. - if len(self.header) == 0: - return - - self.result_set = self.parse_records(raw_result_set) - - def parse_statistics(self, raw_statistics): - """ - Parse the statistics returned in the response. - """ - self.statistics = {} - - # decode statistics - for idx, stat in enumerate(raw_statistics): - if isinstance(stat, bytes): - raw_statistics[idx] = stat.decode() - - for s in STATS: - v = self._get_value(s, raw_statistics) - if v is not None: - self.statistics[s] = v - - def parse_header(self, raw_result_set): - """ - Parse the header of the result. - """ - # An array of column name/column type pairs. - header = raw_result_set[0] - return header - - def parse_records(self, raw_result_set): - """ - Parses the result set and returns a list of records. - """ - records = [ - [ - self.parse_record_types[self.header[idx][0]](cell) - for idx, cell in enumerate(row) - ] - for row in raw_result_set[1] - ] - - return records - - def parse_entity_properties(self, props): - """ - Parse node / edge properties. - """ - # [[name, value type, value] X N] - properties = {} - for prop in props: - prop_name = self.graph.get_property(prop[0]) - prop_value = self.parse_scalar(prop[1:]) - properties[prop_name] = prop_value - - return properties - - def parse_string(self, cell): - """ - Parse the cell as a string. - """ - if isinstance(cell, bytes): - return cell.decode() - elif not isinstance(cell, str): - return str(cell) - else: - return cell - - def parse_node(self, cell): - """ - Parse the cell to a node. 
- """ - # Node ID (integer), - # [label string offset (integer)], - # [[name, value type, value] X N] - - node_id = int(cell[0]) - labels = None - if len(cell[1]) > 0: - labels = [] - for inner_label in cell[1]: - labels.append(self.graph.get_label(inner_label)) - properties = self.parse_entity_properties(cell[2]) - return Node(node_id=node_id, label=labels, properties=properties) - - def parse_edge(self, cell): - """ - Parse the cell to an edge. - """ - # Edge ID (integer), - # reltype string offset (integer), - # src node ID offset (integer), - # dest node ID offset (integer), - # [[name, value, value type] X N] - - edge_id = int(cell[0]) - relation = self.graph.get_relation(cell[1]) - src_node_id = int(cell[2]) - dest_node_id = int(cell[3]) - properties = self.parse_entity_properties(cell[4]) - return Edge( - src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties - ) - - def parse_path(self, cell): - """ - Parse the cell to a path. - """ - nodes = self.parse_scalar(cell[0]) - edges = self.parse_scalar(cell[1]) - return Path(nodes, edges) - - def parse_map(self, cell): - """ - Parse the cell as a map. - """ - m = OrderedDict() - n_entries = len(cell) - - # A map is an array of key value pairs. - # 1. key (string) - # 2. array: (value type, value) - for i in range(0, n_entries, 2): - key = self.parse_string(cell[i]) - m[key] = self.parse_scalar(cell[i + 1]) - - return m - - def parse_point(self, cell): - """ - Parse the cell to point. - """ - p = {} - # A point is received an array of the form: [latitude, longitude] - # It is returned as a map of the form: {"latitude": latitude, "longitude": longitude} # noqa - p["latitude"] = float(cell[0]) - p["longitude"] = float(cell[1]) - return p - - def parse_null(self, cell): - """ - Parse a null value. - """ - return None - - def parse_integer(self, cell): - """ - Parse the integer value from the cell. - """ - return int(cell) - - def parse_boolean(self, value): - """ - Parse the cell value as a boolean. - """ - value = value.decode() if isinstance(value, bytes) else value - try: - scalar = True if strtobool(value) else False - except ValueError: - sys.stderr.write("unknown boolean type\n") - scalar = None - return scalar - - def parse_double(self, cell): - """ - Parse the cell as a double. - """ - return float(cell) - - def parse_array(self, value): - """ - Parse an array of values. - """ - scalar = [self.parse_scalar(value[i]) for i in range(len(value))] - return scalar - - def parse_unknown(self, cell): - """ - Parse a cell of unknown type. - """ - sys.stderr.write("Unknown type\n") - return None - - def parse_scalar(self, cell): - """ - Parse a scalar value from a cell in the result set. 
- """ - scalar_type = int(cell[0]) - value = cell[1] - scalar = self.parse_scalar_types[scalar_type](value) - - return scalar - - def parse_profile(self, response): - self.result_set = [x[0 : x.index(",")].strip() for x in response] - - def is_empty(self): - return len(self.result_set) == 0 - - @staticmethod - def _get_value(prop, statistics): - for stat in statistics: - if prop in stat: - return float(stat.split(": ")[1].split(" ")[0]) - - return None - - def _get_stat(self, stat): - return self.statistics[stat] if stat in self.statistics else 0 - - @property - def labels_added(self): - """Returns the number of labels added in the query""" - return self._get_stat(LABELS_ADDED) - - @property - def labels_removed(self): - """Returns the number of labels removed in the query""" - return self._get_stat(LABELS_REMOVED) - - @property - def nodes_created(self): - """Returns the number of nodes created in the query""" - return self._get_stat(NODES_CREATED) - - @property - def nodes_deleted(self): - """Returns the number of nodes deleted in the query""" - return self._get_stat(NODES_DELETED) - - @property - def properties_set(self): - """Returns the number of properties set in the query""" - return self._get_stat(PROPERTIES_SET) - - @property - def properties_removed(self): - """Returns the number of properties removed in the query""" - return self._get_stat(PROPERTIES_REMOVED) - - @property - def relationships_created(self): - """Returns the number of relationships created in the query""" - return self._get_stat(RELATIONSHIPS_CREATED) - - @property - def relationships_deleted(self): - """Returns the number of relationships deleted in the query""" - return self._get_stat(RELATIONSHIPS_DELETED) - - @property - def indices_created(self): - """Returns the number of indices created in the query""" - return self._get_stat(INDICES_CREATED) - - @property - def indices_deleted(self): - """Returns the number of indices deleted in the query""" - return self._get_stat(INDICES_DELETED) - - @property - def cached_execution(self): - """Returns whether or not the query execution plan was cached""" - return self._get_stat(CACHED_EXECUTION) == 1 - - @property - def run_time_ms(self): - """Returns the server execution time of the query""" - return self._get_stat(INTERNAL_EXECUTION_TIME) - - @property - def parse_scalar_types(self): - return { - ResultSetScalarTypes.VALUE_NULL: self.parse_null, - ResultSetScalarTypes.VALUE_STRING: self.parse_string, - ResultSetScalarTypes.VALUE_INTEGER: self.parse_integer, - ResultSetScalarTypes.VALUE_BOOLEAN: self.parse_boolean, - ResultSetScalarTypes.VALUE_DOUBLE: self.parse_double, - ResultSetScalarTypes.VALUE_ARRAY: self.parse_array, - ResultSetScalarTypes.VALUE_NODE: self.parse_node, - ResultSetScalarTypes.VALUE_EDGE: self.parse_edge, - ResultSetScalarTypes.VALUE_PATH: self.parse_path, - ResultSetScalarTypes.VALUE_MAP: self.parse_map, - ResultSetScalarTypes.VALUE_POINT: self.parse_point, - ResultSetScalarTypes.VALUE_UNKNOWN: self.parse_unknown, - } - - @property - def parse_record_types(self): - return { - ResultSetColumnTypes.COLUMN_SCALAR: self.parse_scalar, - ResultSetColumnTypes.COLUMN_NODE: self.parse_node, - ResultSetColumnTypes.COLUMN_RELATION: self.parse_edge, - ResultSetColumnTypes.COLUMN_UNKNOWN: self.parse_unknown, - } - - -class AsyncQueryResult(QueryResult): - """ - Async version for the QueryResult class - a class that - represents a result of the query operation. 
- """ - - def __init__(self): - """ - To init the class you must call self.initialize() - """ - pass - - async def initialize(self, graph, response, profile=False): - """ - Initializes the class. - Args: - - graph: - The graph on which the query was executed. - response: - The response from the server. - profile: - A boolean indicating if the query command was "GRAPH.PROFILE" - """ - self.graph = graph - self.header = [] - self.result_set = [] - - # in case of an error an exception will be raised - self._check_for_errors(response) - - if len(response) == 1: - self.parse_statistics(response[0]) - elif profile: - self.parse_profile(response) - else: - # start by parsing statistics, matches the one we have - self.parse_statistics(response[-1]) # Last element. - await self.parse_results(response) - - return self - - async def parse_node(self, cell): - """ - Parses a node from the cell. - """ - # Node ID (integer), - # [label string offset (integer)], - # [[name, value type, value] X N] - - labels = None - if len(cell[1]) > 0: - labels = [] - for inner_label in cell[1]: - labels.append(await self.graph.get_label(inner_label)) - properties = await self.parse_entity_properties(cell[2]) - node_id = int(cell[0]) - return Node(node_id=node_id, label=labels, properties=properties) - - async def parse_scalar(self, cell): - """ - Parses a scalar value from the server response. - """ - scalar_type = int(cell[0]) - value = cell[1] - try: - scalar = await self.parse_scalar_types[scalar_type](value) - except TypeError: - # Not all of the functions are async - scalar = self.parse_scalar_types[scalar_type](value) - - return scalar - - async def parse_records(self, raw_result_set): - """ - Parses the result set and returns a list of records. - """ - records = [] - for row in raw_result_set[1]: - record = [ - await self.parse_record_types[self.header[idx][0]](cell) - for idx, cell in enumerate(row) - ] - records.append(record) - - return records - - async def parse_results(self, raw_result_set): - """ - Parse the query execution result returned from the server. - """ - self.header = self.parse_header(raw_result_set) - - # Empty header. - if len(self.header) == 0: - return - - self.result_set = await self.parse_records(raw_result_set) - - async def parse_entity_properties(self, props): - """ - Parse node / edge properties. - """ - # [[name, value type, value] X N] - properties = {} - for prop in props: - prop_name = await self.graph.get_property(prop[0]) - prop_value = await self.parse_scalar(prop[1:]) - properties[prop_name] = prop_value - - return properties - - async def parse_edge(self, cell): - """ - Parse the cell to an edge. - """ - # Edge ID (integer), - # reltype string offset (integer), - # src node ID offset (integer), - # dest node ID offset (integer), - # [[name, value, value type] X N] - - edge_id = int(cell[0]) - relation = await self.graph.get_relation(cell[1]) - src_node_id = int(cell[2]) - dest_node_id = int(cell[3]) - properties = await self.parse_entity_properties(cell[4]) - return Edge( - src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties - ) - - async def parse_path(self, cell): - """ - Parse the cell to a path. - """ - nodes = await self.parse_scalar(cell[0]) - edges = await self.parse_scalar(cell[1]) - return Path(nodes, edges) - - async def parse_map(self, cell): - """ - Parse the cell to a map. - """ - m = OrderedDict() - n_entries = len(cell) - - # A map is an array of key value pairs. - # 1. key (string) - # 2. 
array: (value type, value) - for i in range(0, n_entries, 2): - key = self.parse_string(cell[i]) - m[key] = await self.parse_scalar(cell[i + 1]) - - return m - - async def parse_array(self, value): - """ - Parse array value. - """ - scalar = [await self.parse_scalar(value[i]) for i in range(len(value))] - return scalar - - -def strtobool(val): - """ - Convert a string representation of truth to true (1) or false (0). - True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values - are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if - 'val' is anything else. - """ - val = val.lower() - if val in ("y", "yes", "t", "true", "on", "1"): - return True - elif val in ("n", "no", "f", "false", "off", "0"): - return False - else: - raise ValueError(f"invalid truth value {val!r}") diff --git a/redis/commands/helpers.py b/redis/commands/helpers.py index e11d34fb71..f6121b6c3b 100644 --- a/redis/commands/helpers.py +++ b/redis/commands/helpers.py @@ -88,26 +88,6 @@ def random_string(length=10): ) -def quote_string(v): - """ - RedisGraph strings must be quoted, - quote_string wraps given v with quotes incase - v is a string. - """ - - if isinstance(v, bytes): - v = v.decode() - elif not isinstance(v, str): - return v - if len(v) == 0: - return '""' - - v = v.replace("\\", "\\\\") - v = v.replace('"', '\\"') - - return f'"{v}"' - - def decode_dict_keys(obj): """Decode the keys of the given dictionary with utf-8.""" newobj = copy.copy(obj) @@ -118,33 +98,6 @@ def decode_dict_keys(obj): return newobj -def stringify_param_value(value): - """ - Turn a parameter value into a string suitable for the params header of - a Cypher command. - You may pass any value that would be accepted by `json.dumps()`. - - Ways in which output differs from that of `str()`: - * Strings are quoted. - * None --> "null". - * In dictionaries, keys are _not_ quoted. - - :param value: The parameter value to be turned into a string. - :return: string - """ - - if isinstance(value, str): - return quote_string(value) - elif value is None: - return "null" - elif isinstance(value, (list, tuple)): - return f"[{','.join(map(stringify_param_value, value))}]" - elif isinstance(value, dict): - return f"{{{','.join(f'{k}:{stringify_param_value(v)}' for k, v in value.items())}}}" # noqa - else: - return str(value) - - def get_protocol_version(client): if isinstance(client, redis.Redis) or isinstance(client, redis.asyncio.Redis): return client.connection_pool.connection_kwargs.get("protocol") diff --git a/redis/commands/redismodules.py b/redis/commands/redismodules.py index 7e2045a722..7ba40dd845 100644 --- a/redis/commands/redismodules.py +++ b/redis/commands/redismodules.py @@ -72,16 +72,6 @@ def tdigest(self): tdigest = TDigestBloom(client=self) return tdigest - def graph(self, index_name="idx"): - """Access the graph namespace, providing support for - redis graph data. - """ - - from .graph import Graph - - g = Graph(client=self, name=index_name) - return g - class AsyncRedisModuleCommands(RedisModuleCommands): def ft(self, index_name="idx"): @@ -91,13 +81,3 @@ def ft(self, index_name="idx"): s = AsyncSearch(client=self, index_name=index_name) return s - - def graph(self, index_name="idx"): - """Access the graph namespace, providing support for - redis graph data. 
- """ - - from .graph import AsyncGraph - - g = AsyncGraph(client=self, name=index_name) - return g diff --git a/tasks.py b/tasks.py index 2d1a073437..52decf08e7 100644 --- a/tasks.py +++ b/tasks.py @@ -58,11 +58,11 @@ def standalone_tests( if uvloop: run( - f"pytest {profile_arg} --protocol={protocol} {redis_mod_url} --cov=./ --cov-report=xml:coverage_resp{protocol}_uvloop.xml -m 'not onlycluster and not graph{extra_markers}' --uvloop --junit-xml=standalone-resp{protocol}-uvloop-results.xml" + f"pytest {profile_arg} --protocol={protocol} {redis_mod_url} --cov=./ --cov-report=xml:coverage_resp{protocol}_uvloop.xml -m 'not onlycluster{extra_markers}' --uvloop --junit-xml=standalone-resp{protocol}-uvloop-results.xml" ) else: run( - f"pytest {profile_arg} --protocol={protocol} {redis_mod_url} --cov=./ --cov-report=xml:coverage_resp{protocol}.xml -m 'not onlycluster and not graph{extra_markers}' --junit-xml=standalone-resp{protocol}-results.xml" + f"pytest {profile_arg} --protocol={protocol} {redis_mod_url} --cov=./ --cov-report=xml:coverage_resp{protocol}.xml -m 'not onlycluster{extra_markers}' --junit-xml=standalone-resp{protocol}-results.xml" ) @@ -74,11 +74,11 @@ def cluster_tests(c, uvloop=False, protocol=2, profile=False): cluster_tls_url = "rediss://localhost:27379/0" if uvloop: run( - f"pytest {profile_arg} --protocol={protocol} --cov=./ --cov-report=xml:coverage_cluster_resp{protocol}_uvloop.xml -m 'not onlynoncluster and not redismod and not graph' --redis-url={cluster_url} --redis-ssl-url={cluster_tls_url} --junit-xml=cluster-resp{protocol}-uvloop-results.xml --uvloop" + f"pytest {profile_arg} --protocol={protocol} --cov=./ --cov-report=xml:coverage_cluster_resp{protocol}_uvloop.xml -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --redis-ssl-url={cluster_tls_url} --junit-xml=cluster-resp{protocol}-uvloop-results.xml --uvloop" ) else: run( - f"pytest {profile_arg} --protocol={protocol} --cov=./ --cov-report=xml:coverage_cluster_resp{protocol}.xml -m 'not onlynoncluster and not redismod and not graph' --redis-url={cluster_url} --redis-ssl-url={cluster_tls_url} --junit-xml=cluster-resp{protocol}-results.xml" + f"pytest {profile_arg} --protocol={protocol} --cov=./ --cov-report=xml:coverage_cluster_resp{protocol}.xml -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --redis-ssl-url={cluster_tls_url} --junit-xml=cluster-resp{protocol}-results.xml" ) diff --git a/tests/test_asyncio/test_graph.py b/tests/test_asyncio/test_graph.py deleted file mode 100644 index 7b823265c3..0000000000 --- a/tests/test_asyncio/test_graph.py +++ /dev/null @@ -1,526 +0,0 @@ -import pytest -import pytest_asyncio -import redis.asyncio as redis -from redis.commands.graph import Edge, Node, Path -from redis.commands.graph.execution_plan import Operation -from redis.exceptions import ResponseError -from tests.conftest import skip_if_redis_enterprise, skip_if_resp_version - - -@pytest_asyncio.fixture() -async def decoded_r(create_redis, stack_url): - return await create_redis(decode_responses=True, url="redis://localhost:6480") - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_bulk(decoded_r): - with pytest.raises(NotImplementedError): - await decoded_r.graph().bulk() - await decoded_r.graph().bulk(foo="bar!") - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_graph_creation(decoded_r: redis.Redis): - graph = decoded_r.graph() - - john = Node( - label="person", - properties={ - "name": "John Doe", - "age": 33, - "gender": "male", - "status": "single", 
- }, - ) - graph.add_node(john) - japan = Node(label="country", properties={"name": "Japan"}) - - graph.add_node(japan) - edge = Edge(john, "visited", japan, properties={"purpose": "pleasure"}) - graph.add_edge(edge) - - await graph.commit() - - query = ( - 'MATCH (p:person)-[v:visited {purpose:"pleasure"}]->(c:country) RETURN p, v, c' - ) - - result = await graph.query(query) - - person = result.result_set[0][0] - visit = result.result_set[0][1] - country = result.result_set[0][2] - - assert person == john - assert visit.properties == edge.properties - assert country == japan - - query = """RETURN [1, 2.3, "4", true, false, null]""" - result = await graph.query(query) - assert [1, 2.3, "4", True, False, None] == result.result_set[0][0] - - # All done, remove graph. - await graph.delete() - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_array_functions(decoded_r: redis.Redis): - graph = decoded_r.graph() - - query = """CREATE (p:person{name:'a',age:32, array:[0,1,2]})""" - await graph.query(query) - - query = """WITH [0,1,2] as x return x""" - result = await graph.query(query) - assert [0, 1, 2] == result.result_set[0][0] - - query = """MATCH(n) return collect(n)""" - result = await graph.query(query) - - a = Node( - node_id=0, - label="person", - properties={"name": "a", "age": 32, "array": [0, 1, 2]}, - ) - - assert [a] == result.result_set[0][0] - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_path(decoded_r: redis.Redis): - node0 = Node(node_id=0, label="L1") - node1 = Node(node_id=1, label="L1") - edge01 = Edge(node0, "R1", node1, edge_id=0, properties={"value": 1}) - - graph = decoded_r.graph() - graph.add_node(node0) - graph.add_node(node1) - graph.add_edge(edge01) - await graph.flush() - - path01 = Path.new_empty_path().add_node(node0).add_edge(edge01).add_node(node1) - expected_results = [[path01]] - - query = "MATCH p=(:L1)-[:R1]->(:L1) RETURN p ORDER BY p" - result = await graph.query(query) - assert expected_results == result.result_set - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_param(decoded_r: redis.Redis): - params = [1, 2.3, "str", True, False, None, [0, 1, 2]] - query = "RETURN $param" - for param in params: - result = await decoded_r.graph().query(query, {"param": param}) - expected_results = [[param]] - assert expected_results == result.result_set - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_map(decoded_r: redis.Redis): - query = "RETURN {a:1, b:'str', c:NULL, d:[1,2,3], e:True, f:{x:1, y:2}}" - - actual = (await decoded_r.graph().query(query)).result_set[0][0] - expected = { - "a": 1, - "b": "str", - "c": None, - "d": [1, 2, 3], - "e": True, - "f": {"x": 1, "y": 2}, - } - - assert actual == expected - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_point(decoded_r: redis.Redis): - query = "RETURN point({latitude: 32.070794860, longitude: 34.820751118})" - expected_lat = 32.070794860 - expected_lon = 34.820751118 - actual = (await decoded_r.graph().query(query)).result_set[0][0] - assert abs(actual["latitude"] - expected_lat) < 0.001 - assert abs(actual["longitude"] - expected_lon) < 0.001 - - query = "RETURN point({latitude: 32, longitude: 34.0})" - expected_lat = 32 - expected_lon = 34 - actual = (await decoded_r.graph().query(query)).result_set[0][0] - assert abs(actual["latitude"] - expected_lat) < 0.001 - assert abs(actual["longitude"] - expected_lon) < 0.001 - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_index_response(decoded_r: redis.Redis): - 
result_set = await decoded_r.graph().query("CREATE INDEX ON :person(age)") - assert 1 == result_set.indices_created - - result_set = await decoded_r.graph().query("CREATE INDEX ON :person(age)") - assert 0 == result_set.indices_created - - result_set = await decoded_r.graph().query("DROP INDEX ON :person(age)") - assert 1 == result_set.indices_deleted - - with pytest.raises(ResponseError): - await decoded_r.graph().query("DROP INDEX ON :person(age)") - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_stringify_query_result(decoded_r: redis.Redis): - graph = decoded_r.graph() - - john = Node( - alias="a", - label="person", - properties={ - "name": "John Doe", - "age": 33, - "gender": "male", - "status": "single", - }, - ) - graph.add_node(john) - - japan = Node(alias="b", label="country", properties={"name": "Japan"}) - graph.add_node(japan) - - edge = Edge(john, "visited", japan, properties={"purpose": "pleasure"}) - graph.add_edge(edge) - - assert ( - str(john) - == """(a:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa - ) - assert ( - str(edge) - == """(a:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa - + """-[:visited{purpose:"pleasure"}]->""" - + """(b:country{name:"Japan"})""" - ) - assert str(japan) == """(b:country{name:"Japan"})""" - - await graph.commit() - - query = """MATCH (p:person)-[v:visited {purpose:"pleasure"}]->(c:country) - RETURN p, v, c""" - - result = await graph.query(query) - person = result.result_set[0][0] - visit = result.result_set[0][1] - country = result.result_set[0][2] - - assert ( - str(person) - == """(:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa - ) - assert str(visit) == """()-[:visited{purpose:"pleasure"}]->()""" - assert str(country) == """(:country{name:"Japan"})""" - - await graph.delete() - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_optional_match(decoded_r: redis.Redis): - # Build a graph of form (a)-[R]->(b) - node0 = Node(node_id=0, label="L1", properties={"value": "a"}) - node1 = Node(node_id=1, label="L1", properties={"value": "b"}) - - edge01 = Edge(node0, "R", node1, edge_id=0) - - graph = decoded_r.graph() - graph.add_node(node0) - graph.add_node(node1) - graph.add_edge(edge01) - await graph.flush() - - # Issue a query that collects all outgoing edges from both nodes - # (the second has none) - query = """MATCH (a) OPTIONAL MATCH (a)-[e]->(b) RETURN a, e, b ORDER BY a.value""" # noqa - expected_results = [[node0, edge01, node1], [node1, None, None]] - - result = await graph.query(query) - assert expected_results == result.result_set - - await graph.delete() - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_cached_execution(decoded_r: redis.Redis): - await decoded_r.graph().query("CREATE ()") - - uncached_result = await decoded_r.graph().query( - "MATCH (n) RETURN n, $param", {"param": [0]} - ) - assert uncached_result.cached_execution is False - - # loop to make sure the query is cached on each thread on server - for x in range(0, 64): - cached_result = await decoded_r.graph().query( - "MATCH (n) RETURN n, $param", {"param": [0]} - ) - assert uncached_result.result_set == cached_result.result_set - - # should be cached on all threads by now - assert cached_result.cached_execution - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_slowlog(decoded_r: redis.Redis): - create_query = """CREATE - (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), - (:Rider {name:'Dani 
Pedrosa'})-[:rides]->(:Team {name:'Honda'}), - (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" - await decoded_r.graph().query(create_query) - - results = await decoded_r.graph().slowlog() - assert results[0][1] == "GRAPH.QUERY" - assert results[0][2] == create_query - - -@pytest.mark.xfail(strict=False) -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_query_timeout(decoded_r: redis.Redis): - # Build a sample graph with 1000 nodes. - await decoded_r.graph().query("UNWIND range(0,1000) as val CREATE ({v: val})") - # Issue a long-running query with a 1-millisecond timeout. - with pytest.raises(ResponseError): - await decoded_r.graph().query("MATCH (a), (b), (c), (d) RETURN *", timeout=1) - assert False is False - - with pytest.raises(Exception): - await decoded_r.graph().query("RETURN 1", timeout="str") - assert False is False - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_read_only_query(decoded_r: redis.Redis): - with pytest.raises(Exception): - # Issue a write query, specifying read-only true, - # this call should fail. - await decoded_r.graph().query("CREATE (p:person {name:'a'})", read_only=True) - assert False is False - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_profile(decoded_r: redis.Redis): - q = """UNWIND range(1, 3) AS x CREATE (p:Person {v:x})""" - profile = (await decoded_r.graph().profile(q)).result_set - assert "Create | Records produced: 3" in profile - assert "Unwind | Records produced: 3" in profile - - q = "MATCH (p:Person) WHERE p.v > 1 RETURN p" - profile = (await decoded_r.graph().profile(q)).result_set - assert "Results | Records produced: 2" in profile - assert "Project | Records produced: 2" in profile - assert "Filter | Records produced: 2" in profile - assert "Node By Label Scan | (p:Person) | Records produced: 3" in profile - - -@skip_if_redis_enterprise() -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_config(decoded_r: redis.Redis): - config_name = "RESULTSET_SIZE" - config_value = 3 - - # Set configuration - response = await decoded_r.graph().config(config_name, config_value, set=True) - assert response == "OK" - - # Make sure config been updated. - response = await decoded_r.graph().config(config_name, set=False) - expected_response = [config_name, config_value] - assert response == expected_response - - config_name = "QUERY_MEM_CAPACITY" - config_value = 1 << 20 # 1MB - - # Set configuration - response = await decoded_r.graph().config(config_name, config_value, set=True) - assert response == "OK" - - # Make sure config been updated. 
- response = await decoded_r.graph().config(config_name, set=False) - expected_response = [config_name, config_value] - assert response == expected_response - - # reset to default - await decoded_r.graph().config("QUERY_MEM_CAPACITY", 0, set=True) - await decoded_r.graph().config("RESULTSET_SIZE", -100, set=True) - - -@pytest.mark.onlynoncluster -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_list_keys(decoded_r: redis.Redis): - result = await decoded_r.graph().list_keys() - assert result == [] - - await decoded_r.graph("G").query("CREATE (n)") - result = await decoded_r.graph().list_keys() - assert result == ["G"] - - await decoded_r.graph("X").query("CREATE (m)") - result = await decoded_r.graph().list_keys() - assert result == ["G", "X"] - - await decoded_r.delete("G") - await decoded_r.rename("X", "Z") - result = await decoded_r.graph().list_keys() - assert result == ["Z"] - - await decoded_r.delete("Z") - result = await decoded_r.graph().list_keys() - assert result == [] - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_multi_label(decoded_r: redis.Redis): - redis_graph = decoded_r.graph("g") - - node = Node(label=["l", "ll"]) - redis_graph.add_node(node) - await redis_graph.commit() - - query = "MATCH (n) RETURN n" - result = await redis_graph.query(query) - result_node = result.result_set[0][0] - assert result_node == node - - try: - Node(label=1) - assert False - except AssertionError: - assert True - - try: - Node(label=["l", 1]) - assert False - except AssertionError: - assert True - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_execution_plan(decoded_r: redis.Redis): - redis_graph = decoded_r.graph("execution_plan") - create_query = """CREATE - (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), - (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), - (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" - await redis_graph.query(create_query) - - result = await redis_graph.execution_plan( - "MATCH (r:Rider)-[:rides]->(t:Team) WHERE t.name = $name RETURN r.name, t.name, $params", # noqa - {"name": "Yehuda"}, - ) - expected = "Results\n Project\n Conditional Traverse | (t)->(r:Rider)\n Filter\n Node By Label Scan | (t:Team)" # noqa - assert result == expected - - await redis_graph.delete() - - -@pytest.mark.graph -@skip_if_resp_version(3) -async def test_explain(decoded_r: redis.Redis): - redis_graph = decoded_r.graph("execution_plan") - # graph creation / population - create_query = """CREATE -(:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), -(:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), -(:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" - await redis_graph.query(create_query) - - result = await redis_graph.explain( - """MATCH (r:Rider)-[:rides]->(t:Team) -WHERE t.name = $name -RETURN r.name, t.name -UNION -MATCH (r:Rider)-[:rides]->(t:Team) -WHERE t.name = $name -RETURN r.name, t.name""", - {"name": "Yamaha"}, - ) - expected = """\ -Results -Distinct - Join - Project - Conditional Traverse | (t)->(r:Rider) - Filter - Node By Label Scan | (t:Team) - Project - Conditional Traverse | (t)->(r:Rider) - Filter - Node By Label Scan | (t:Team)""" - assert str(result).replace(" ", "").replace("\n", "") == expected.replace( - " ", "" - ).replace("\n", "") - - expected = Operation("Results").append_child( - Operation("Distinct").append_child( - Operation("Join") - .append_child( - Operation("Project").append_child( - 
Operation("Conditional Traverse", "(t)->(r:Rider)").append_child( - Operation("Filter").append_child( - Operation("Node By Label Scan", "(t:Team)") - ) - ) - ) - ) - .append_child( - Operation("Project").append_child( - Operation("Conditional Traverse", "(t)->(r:Rider)").append_child( - Operation("Filter").append_child( - Operation("Node By Label Scan", "(t:Team)") - ) - ) - ) - ) - ) - ) - - assert result.structured_plan == expected - - result = await redis_graph.explain( - """MATCH (r:Rider), (t:Team) - RETURN r.name, t.name""" - ) - expected = """\ -Results -Project - Cartesian Product - Node By Label Scan | (r:Rider) - Node By Label Scan | (t:Team)""" - assert str(result).replace(" ", "").replace("\n", "") == expected.replace( - " ", "" - ).replace("\n", "") - - expected = Operation("Results").append_child( - Operation("Project").append_child( - Operation("Cartesian Product") - .append_child(Operation("Node By Label Scan")) - .append_child(Operation("Node By Label Scan")) - ) - ) - - assert result.structured_plan == expected - - await redis_graph.delete() diff --git a/tests/test_graph.py b/tests/test_graph.py deleted file mode 100644 index fd08385667..0000000000 --- a/tests/test_graph.py +++ /dev/null @@ -1,656 +0,0 @@ -from unittest.mock import patch - -import pytest -from redis import Redis -from redis.commands.graph import Edge, Node, Path -from redis.commands.graph.execution_plan import Operation -from redis.commands.graph.query_result import ( - CACHED_EXECUTION, - INDICES_CREATED, - INDICES_DELETED, - INTERNAL_EXECUTION_TIME, - LABELS_ADDED, - LABELS_REMOVED, - NODES_CREATED, - NODES_DELETED, - PROPERTIES_REMOVED, - PROPERTIES_SET, - RELATIONSHIPS_CREATED, - RELATIONSHIPS_DELETED, - QueryResult, -) -from redis.exceptions import ResponseError -from tests.conftest import _get_client, skip_if_redis_enterprise, skip_if_resp_version - - -@pytest.fixture -def client(request, stack_url): - r = _get_client( - Redis, request, decode_responses=True, from_url="redis://localhost:6480" - ) - r.flushdb() - return r - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_bulk(client): - with pytest.raises(NotImplementedError): - client.graph().bulk() - client.graph().bulk(foo="bar!") - - -@pytest.mark.graph -def test_graph_creation_throws_deprecation_warning(client): - """Verify that a DeprecationWarning is raised when creating a Graph instance.""" - - match = "RedisGraph support is deprecated as of Redis Stack 7.2" - with pytest.warns(DeprecationWarning, match=match): - client.graph() - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_graph_creation(client): - graph = client.graph() - - john = Node( - label="person", - properties={ - "name": "John Doe", - "age": 33, - "gender": "male", - "status": "single", - }, - ) - graph.add_node(john) - japan = Node(label="country", properties={"name": "Japan"}) - - graph.add_node(japan) - edge = Edge(john, "visited", japan, properties={"purpose": "pleasure"}) - graph.add_edge(edge) - - graph.commit() - - query = ( - 'MATCH (p:person)-[v:visited {purpose:"pleasure"}]->(c:country) RETURN p, v, c' - ) - - result = graph.query(query) - - person = result.result_set[0][0] - visit = result.result_set[0][1] - country = result.result_set[0][2] - - assert person == john - assert visit.properties == edge.properties - assert country == japan - - query = """RETURN [1, 2.3, "4", true, false, null]""" - result = graph.query(query) - assert [1, 2.3, "4", True, False, None] == result.result_set[0][0] - - # All done, remove graph. 
- graph.delete() - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_array_functions(client): - query = """CREATE (p:person{name:'a',age:32, array:[0,1,2]})""" - client.graph().query(query) - - query = """WITH [0,1,2] as x return x""" - result = client.graph().query(query) - assert [0, 1, 2] == result.result_set[0][0] - - query = """MATCH(n) return collect(n)""" - result = client.graph().query(query) - - a = Node( - node_id=0, - label="person", - properties={"name": "a", "age": 32, "array": [0, 1, 2]}, - ) - - assert [a] == result.result_set[0][0] - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_path(client): - node0 = Node(node_id=0, label="L1") - node1 = Node(node_id=1, label="L1") - edge01 = Edge(node0, "R1", node1, edge_id=0, properties={"value": 1}) - - graph = client.graph() - graph.add_node(node0) - graph.add_node(node1) - graph.add_edge(edge01) - graph.flush() - - path01 = Path.new_empty_path().add_node(node0).add_edge(edge01).add_node(node1) - expected_results = [[path01]] - - query = "MATCH p=(:L1)-[:R1]->(:L1) RETURN p ORDER BY p" - result = graph.query(query) - assert expected_results == result.result_set - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_param(client): - params = [1, 2.3, "str", True, False, None, [0, 1, 2], r"\" RETURN 1337 //"] - query = "RETURN $param" - for param in params: - result = client.graph().query(query, {"param": param}) - expected_results = [[param]] - assert expected_results == result.result_set - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_map(client): - query = "RETURN {a:1, b:'str', c:NULL, d:[1,2,3], e:True, f:{x:1, y:2}}" - - actual = client.graph().query(query).result_set[0][0] - expected = { - "a": 1, - "b": "str", - "c": None, - "d": [1, 2, 3], - "e": True, - "f": {"x": 1, "y": 2}, - } - - assert actual == expected - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_point(client): - query = "RETURN point({latitude: 32.070794860, longitude: 34.820751118})" - expected_lat = 32.070794860 - expected_lon = 34.820751118 - actual = client.graph().query(query).result_set[0][0] - assert abs(actual["latitude"] - expected_lat) < 0.001 - assert abs(actual["longitude"] - expected_lon) < 0.001 - - query = "RETURN point({latitude: 32, longitude: 34.0})" - expected_lat = 32 - expected_lon = 34 - actual = client.graph().query(query).result_set[0][0] - assert abs(actual["latitude"] - expected_lat) < 0.001 - assert abs(actual["longitude"] - expected_lon) < 0.001 - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_index_response(client): - result_set = client.graph().query("CREATE INDEX ON :person(age)") - assert 1 == result_set.indices_created - - result_set = client.graph().query("CREATE INDEX ON :person(age)") - assert 0 == result_set.indices_created - - result_set = client.graph().query("DROP INDEX ON :person(age)") - assert 1 == result_set.indices_deleted - - with pytest.raises(ResponseError): - client.graph().query("DROP INDEX ON :person(age)") - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_stringify_query_result(client): - graph = client.graph() - - john = Node( - alias="a", - label="person", - properties={ - "name": "John Doe", - "age": 33, - "gender": "male", - "status": "single", - }, - ) - graph.add_node(john) - - japan = Node(alias="b", label="country", properties={"name": "Japan"}) - graph.add_node(japan) - - edge = Edge(john, "visited", japan, properties={"purpose": "pleasure"}) - graph.add_edge(edge) - - assert ( - str(john) - == 
"""(a:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa - ) - assert ( - str(edge) - == """(a:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa - + """-[:visited{purpose:"pleasure"}]->""" - + """(b:country{name:"Japan"})""" - ) - assert str(japan) == """(b:country{name:"Japan"})""" - - graph.commit() - - query = """MATCH (p:person)-[v:visited {purpose:"pleasure"}]->(c:country) - RETURN p, v, c""" - - result = client.graph().query(query) - person = result.result_set[0][0] - visit = result.result_set[0][1] - country = result.result_set[0][2] - - assert ( - str(person) - == """(:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa - ) - assert str(visit) == """()-[:visited{purpose:"pleasure"}]->()""" - assert str(country) == """(:country{name:"Japan"})""" - - graph.delete() - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_optional_match(client): - # Build a graph of form (a)-[R]->(b) - node0 = Node(node_id=0, label="L1", properties={"value": "a"}) - node1 = Node(node_id=1, label="L1", properties={"value": "b"}) - - edge01 = Edge(node0, "R", node1, edge_id=0) - - graph = client.graph() - graph.add_node(node0) - graph.add_node(node1) - graph.add_edge(edge01) - graph.flush() - - # Issue a query that collects all outgoing edges from both nodes - # (the second has none) - query = """MATCH (a) OPTIONAL MATCH (a)-[e]->(b) RETURN a, e, b ORDER BY a.value""" # noqa - expected_results = [[node0, edge01, node1], [node1, None, None]] - - result = client.graph().query(query) - assert expected_results == result.result_set - - graph.delete() - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_cached_execution(client): - client.graph().query("CREATE ()") - - uncached_result = client.graph().query("MATCH (n) RETURN n, $param", {"param": [0]}) - assert uncached_result.cached_execution is False - - # loop to make sure the query is cached on each thread on server - for x in range(0, 64): - cached_result = client.graph().query( - "MATCH (n) RETURN n, $param", {"param": [0]} - ) - assert uncached_result.result_set == cached_result.result_set - - # should be cached on all threads by now - assert cached_result.cached_execution - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_slowlog(client): - create_query = """CREATE (:Rider - {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), - (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), - (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" - client.graph().query(create_query) - - results = client.graph().slowlog() - assert results[0][1] == "GRAPH.QUERY" - assert results[0][2] == create_query - - -@pytest.mark.graph -@skip_if_resp_version(3) -@pytest.mark.xfail(strict=False) -def test_query_timeout(client): - # Build a sample graph with 1000 nodes. - client.graph().query("UNWIND range(0,1000) as val CREATE ({v: val})") - # Issue a long-running query with a 1-millisecond timeout. - with pytest.raises(ResponseError): - client.graph().query("MATCH (a), (b), (c), (d) RETURN *", timeout=1) - assert False is False - - with pytest.raises(Exception): - client.graph().query("RETURN 1", timeout="str") - assert False is False - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_read_only_query(client): - with pytest.raises(Exception): - # Issue a write query, specifying read-only true, - # this call should fail. 
- client.graph().query("CREATE (p:person {name:'a'})", read_only=True) - assert False is False - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_profile(client): - q = """UNWIND range(1, 3) AS x CREATE (p:Person {v:x})""" - profile = client.graph().profile(q).result_set - assert "Create | Records produced: 3" in profile - assert "Unwind | Records produced: 3" in profile - - q = "MATCH (p:Person) WHERE p.v > 1 RETURN p" - profile = client.graph().profile(q).result_set - assert "Results | Records produced: 2" in profile - assert "Project | Records produced: 2" in profile - assert "Filter | Records produced: 2" in profile - assert "Node By Label Scan | (p:Person) | Records produced: 3" in profile - - -@pytest.mark.graph -@skip_if_resp_version(3) -@skip_if_redis_enterprise() -def test_config(client): - config_name = "RESULTSET_SIZE" - config_value = 3 - - # Set configuration - response = client.graph().config(config_name, config_value, set=True) - assert response == "OK" - - # Make sure config been updated. - response = client.graph().config(config_name, set=False) - expected_response = [config_name, config_value] - assert response == expected_response - - config_name = "QUERY_MEM_CAPACITY" - config_value = 1 << 20 # 1MB - - # Set configuration - response = client.graph().config(config_name, config_value, set=True) - assert response == "OK" - - # Make sure config been updated. - response = client.graph().config(config_name, set=False) - expected_response = [config_name, config_value] - assert response == expected_response - - # reset to default - client.graph().config("QUERY_MEM_CAPACITY", 0, set=True) - client.graph().config("RESULTSET_SIZE", -100, set=True) - - -@pytest.mark.onlynoncluster -@pytest.mark.graph -@skip_if_resp_version(3) -def test_list_keys(client): - result = client.graph().list_keys() - assert result == [] - - client.graph("G").query("CREATE (n)") - result = client.graph().list_keys() - assert result == ["G"] - - client.graph("X").query("CREATE (m)") - result = client.graph().list_keys() - assert result == ["G", "X"] - - client.delete("G") - client.rename("X", "Z") - result = client.graph().list_keys() - assert result == ["Z"] - - client.delete("Z") - result = client.graph().list_keys() - assert result == [] - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_multi_label(client): - redis_graph = client.graph("g") - - node = Node(label=["l", "ll"]) - redis_graph.add_node(node) - redis_graph.commit() - - query = "MATCH (n) RETURN n" - result = redis_graph.query(query) - result_node = result.result_set[0][0] - assert result_node == node - - try: - Node(label=1) - assert False - except AssertionError: - assert True - - try: - Node(label=["l", 1]) - assert False - except AssertionError: - assert True - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_cache_sync(client): - pass - return - # This test verifies that client internal graph schema cache stays - # in sync with the graph schema - # - # Client B will try to get Client A out of sync by: - # 1. deleting the graph - # 2. reconstructing the graph in a different order, this will casuse - # a difference in the current mapping between string IDs and the - # mapping Client A is aware of - # - # Client A should pick up on the changes by comparing graph versions - # and resyncing its cache. - - A = client.graph("cache-sync") - B = client.graph("cache-sync") - - # Build order: - # 1. introduce label 'L' and 'K' - # 2. introduce attribute 'x' and 'q' - # 3. 
introduce relationship-type 'R' and 'S' - - A.query("CREATE (:L)") - B.query("CREATE (:K)") - A.query("MATCH (n) SET n.x = 1") - B.query("MATCH (n) SET n.q = 1") - A.query("MATCH (n) CREATE (n)-[:R]->()") - B.query("MATCH (n) CREATE (n)-[:S]->()") - - # Cause client A to populate its cache - A.query("MATCH (n)-[e]->() RETURN n, e") - - assert len(A._labels) == 2 - assert len(A._properties) == 2 - assert len(A._relationship_types) == 2 - assert A._labels[0] == "L" - assert A._labels[1] == "K" - assert A._properties[0] == "x" - assert A._properties[1] == "q" - assert A._relationship_types[0] == "R" - assert A._relationship_types[1] == "S" - - # Have client B reconstruct the graph in a different order. - B.delete() - - # Build order: - # 1. introduce relationship-type 'R' - # 2. introduce label 'L' - # 3. introduce attribute 'x' - B.query("CREATE ()-[:S]->()") - B.query("CREATE ()-[:R]->()") - B.query("CREATE (:K)") - B.query("CREATE (:L)") - B.query("MATCH (n) SET n.q = 1") - B.query("MATCH (n) SET n.x = 1") - - # A's internal cached mapping is now out of sync - # issue a query and make sure A's cache is synced. - A.query("MATCH (n)-[e]->() RETURN n, e") - - assert len(A._labels) == 2 - assert len(A._properties) == 2 - assert len(A._relationship_types) == 2 - assert A._labels[0] == "K" - assert A._labels[1] == "L" - assert A._properties[0] == "q" - assert A._properties[1] == "x" - assert A._relationship_types[0] == "S" - assert A._relationship_types[1] == "R" - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_execution_plan(client): - redis_graph = client.graph("execution_plan") - create_query = """CREATE - (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), - (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), - (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" - redis_graph.query(create_query) - - result = redis_graph.execution_plan( - "MATCH (r:Rider)-[:rides]->(t:Team) WHERE t.name = $name RETURN r.name, t.name, $params", # noqa - {"name": "Yehuda"}, - ) - expected = "Results\n Project\n Conditional Traverse | (t)->(r:Rider)\n Filter\n Node By Label Scan | (t:Team)" # noqa - assert result == expected - - redis_graph.delete() - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_explain(client): - redis_graph = client.graph("execution_plan") - # graph creation / population - create_query = """CREATE -(:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), -(:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), -(:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" - redis_graph.query(create_query) - - result = redis_graph.explain( - """MATCH (r:Rider)-[:rides]->(t:Team) -WHERE t.name = $name -RETURN r.name, t.name -UNION -MATCH (r:Rider)-[:rides]->(t:Team) -WHERE t.name = $name -RETURN r.name, t.name""", - {"name": "Yamaha"}, - ) - expected = """\ -Results -Distinct - Join - Project - Conditional Traverse | (t)->(r:Rider) - Filter - Node By Label Scan | (t:Team) - Project - Conditional Traverse | (t)->(r:Rider) - Filter - Node By Label Scan | (t:Team)""" - assert str(result).replace(" ", "").replace("\n", "") == expected.replace( - " ", "" - ).replace("\n", "") - - expected = Operation("Results").append_child( - Operation("Distinct").append_child( - Operation("Join") - .append_child( - Operation("Project").append_child( - Operation("Conditional Traverse", "(t)->(r:Rider)").append_child( - Operation("Filter").append_child( - Operation("Node By Label Scan", "(t:Team)") - ) - ) - ) - ) 
- .append_child( - Operation("Project").append_child( - Operation("Conditional Traverse", "(t)->(r:Rider)").append_child( - Operation("Filter").append_child( - Operation("Node By Label Scan", "(t:Team)") - ) - ) - ) - ) - ) - ) - - assert result.structured_plan == expected - - result = redis_graph.explain( - """MATCH (r:Rider), (t:Team) - RETURN r.name, t.name""" - ) - expected = """\ -Results -Project - Cartesian Product - Node By Label Scan | (r:Rider) - Node By Label Scan | (t:Team)""" - assert str(result).replace(" ", "").replace("\n", "") == expected.replace( - " ", "" - ).replace("\n", "") - - expected = Operation("Results").append_child( - Operation("Project").append_child( - Operation("Cartesian Product") - .append_child(Operation("Node By Label Scan")) - .append_child(Operation("Node By Label Scan")) - ) - ) - - assert result.structured_plan == expected - - redis_graph.delete() - - -@pytest.mark.graph -@skip_if_resp_version(3) -def test_resultset_statistics(client): - with patch.object(target=QueryResult, attribute="_get_stat") as mock_get_stats: - result = client.graph().query("RETURN 1") - result.labels_added - mock_get_stats.assert_called_with(LABELS_ADDED) - result.labels_removed - mock_get_stats.assert_called_with(LABELS_REMOVED) - result.nodes_created - mock_get_stats.assert_called_with(NODES_CREATED) - result.nodes_deleted - mock_get_stats.assert_called_with(NODES_DELETED) - result.properties_set - mock_get_stats.assert_called_with(PROPERTIES_SET) - result.properties_removed - mock_get_stats.assert_called_with(PROPERTIES_REMOVED) - result.relationships_created - mock_get_stats.assert_called_with(RELATIONSHIPS_CREATED) - result.relationships_deleted - mock_get_stats.assert_called_with(RELATIONSHIPS_DELETED) - result.indices_created - mock_get_stats.assert_called_with(INDICES_CREATED) - result.indices_deleted - mock_get_stats.assert_called_with(INDICES_DELETED) - result.cached_execution - mock_get_stats.assert_called_with(CACHED_EXECUTION) - result.run_time_ms - mock_get_stats.assert_called_with(INTERNAL_EXECUTION_TIME) diff --git a/tests/test_graph_utils/__init__.py b/tests/test_graph_utils/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/test_graph_utils/test_edge.py b/tests/test_graph_utils/test_edge.py deleted file mode 100644 index 09e6fa08ed..0000000000 --- a/tests/test_graph_utils/test_edge.py +++ /dev/null @@ -1,75 +0,0 @@ -import pytest -from redis.commands.graph import edge, node - - -@pytest.mark.graph -def test_init(): - with pytest.raises(AssertionError): - edge.Edge(None, None, None) - edge.Edge(node.Node(), None, None) - edge.Edge(None, None, node.Node()) - - assert isinstance( - edge.Edge(node.Node(node_id=1), None, node.Node(node_id=2)), edge.Edge - ) - - -@pytest.mark.graph -def test_to_string(): - props_result = edge.Edge( - node.Node(), None, node.Node(), properties={"a": "a", "b": 10} - ).to_string() - assert props_result == '{a:"a",b:10}' - - no_props_result = edge.Edge( - node.Node(), None, node.Node(), properties={} - ).to_string() - assert no_props_result == "" - - -@pytest.mark.graph -def test_stringify(): - john = node.Node( - alias="a", - label="person", - properties={"name": "John Doe", "age": 33, "someArray": [1, 2, 3]}, - ) - japan = node.Node(alias="b", label="country", properties={"name": "Japan"}) - edge_with_relation = edge.Edge( - john, "visited", japan, properties={"purpose": "pleasure"} - ) - assert ( - '(a:person{age:33,name:"John Doe",someArray:[1, 2, 3]})' - '-[:visited{purpose:"pleasure"}]->' - 
'(b:country{name:"Japan"})' == str(edge_with_relation) - ) - - edge_no_relation_no_props = edge.Edge(japan, "", john) - assert ( - '(b:country{name:"Japan"})' - "-[]->" - '(a:person{age:33,name:"John Doe",someArray:[1, 2, 3]})' - == str(edge_no_relation_no_props) - ) - - edge_only_props = edge.Edge(john, "", japan, properties={"a": "b", "c": 3}) - assert ( - '(a:person{age:33,name:"John Doe",someArray:[1, 2, 3]})' - '-[{a:"b",c:3}]->' - '(b:country{name:"Japan"})' == str(edge_only_props) - ) - - -@pytest.mark.graph -def test_comparison(): - node1 = node.Node(node_id=1) - node2 = node.Node(node_id=2) - node3 = node.Node(node_id=3) - - edge1 = edge.Edge(node1, None, node2) - assert edge1 == edge.Edge(node1, None, node2) - assert edge1 != edge.Edge(node1, "bla", node2) - assert edge1 != edge.Edge(node1, None, node3) - assert edge1 != edge.Edge(node3, None, node2) - assert edge1 != edge.Edge(node2, None, node1) - assert edge1 != edge.Edge(node1, None, node2, properties={"a": 10}) diff --git a/tests/test_graph_utils/test_node.py b/tests/test_graph_utils/test_node.py deleted file mode 100644 index e9b8a54f43..0000000000 --- a/tests/test_graph_utils/test_node.py +++ /dev/null @@ -1,51 +0,0 @@ -import pytest -from redis.commands.graph import node - - -@pytest.fixture -def fixture(): - no_args = node.Node() - no_props = node.Node(node_id=1, alias="alias", label="l") - props_only = node.Node(properties={"a": "a", "b": 10}) - no_label = node.Node(node_id=1, alias="alias", properties={"a": "a"}) - multi_label = node.Node(node_id=1, alias="alias", label=["l", "ll"]) - return no_args, no_props, props_only, no_label, multi_label - - -@pytest.mark.graph -def test_to_string(fixture): - no_args, no_props, props_only, no_label, multi_label = fixture - assert no_args.to_string() == "" - assert no_props.to_string() == "" - assert props_only.to_string() == '{a:"a",b:10}' - assert no_label.to_string() == '{a:"a"}' - assert multi_label.to_string() == "" - - -@pytest.mark.graph -def test_stringify(fixture): - no_args, no_props, props_only, no_label, multi_label = fixture - assert str(no_args) == "()" - assert str(no_props) == "(alias:l)" - assert str(props_only) == '({a:"a",b:10})' - assert str(no_label) == '(alias{a:"a"})' - assert str(multi_label) == "(alias:l:ll)" - - -@pytest.mark.graph -def test_comparison(fixture): - no_args, no_props, props_only, no_label, multi_label = fixture - - assert node.Node() == node.Node() - assert node.Node(node_id=1) == node.Node(node_id=1) - assert node.Node(node_id=1) != node.Node(node_id=2) - assert node.Node(node_id=1, alias="a") == node.Node(node_id=1, alias="b") - assert node.Node(node_id=1, alias="a") == node.Node(node_id=1, alias="a") - assert node.Node(node_id=1, label="a") == node.Node(node_id=1, label="a") - assert node.Node(node_id=1, label="a") != node.Node(node_id=1, label="b") - assert node.Node(node_id=1, alias="a", label="l") == node.Node( - node_id=1, alias="a", label="l" - ) - assert node.Node(alias="a", label="l") != node.Node(alias="a", label="l1") - assert node.Node(properties={"a": 10}) == node.Node(properties={"a": 10}) - assert node.Node() != node.Node(properties={"a": 10}) diff --git a/tests/test_graph_utils/test_path.py b/tests/test_graph_utils/test_path.py deleted file mode 100644 index 33ca041cfa..0000000000 --- a/tests/test_graph_utils/test_path.py +++ /dev/null @@ -1,90 +0,0 @@ -import pytest -from redis.commands.graph import edge, node, path - - -@pytest.mark.graph -def test_init(): - with pytest.raises(TypeError): - path.Path(None, None) - 
path.Path([], None) - path.Path(None, []) - - assert isinstance(path.Path([], []), path.Path) - - -@pytest.mark.graph -def test_new_empty_path(): - new_empty_path = path.Path.new_empty_path() - assert isinstance(new_empty_path, path.Path) - assert new_empty_path._nodes == [] - assert new_empty_path._edges == [] - - -@pytest.mark.graph -def test_wrong_flows(): - node_1 = node.Node(node_id=1) - node_2 = node.Node(node_id=2) - node_3 = node.Node(node_id=3) - - edge_1 = edge.Edge(node_1, None, node_2) - edge_2 = edge.Edge(node_1, None, node_3) - - p = path.Path.new_empty_path() - with pytest.raises(AssertionError): - p.add_edge(edge_1) - - p.add_node(node_1) - with pytest.raises(AssertionError): - p.add_node(node_2) - - p.add_edge(edge_1) - with pytest.raises(AssertionError): - p.add_edge(edge_2) - - -@pytest.mark.graph -def test_nodes_and_edges(): - node_1 = node.Node(node_id=1) - node_2 = node.Node(node_id=2) - edge_1 = edge.Edge(node_1, None, node_2) - - p = path.Path.new_empty_path() - assert p.nodes() == [] - p.add_node(node_1) - assert [] == p.edges() - assert 0 == p.edge_count() - assert [node_1] == p.nodes() - assert node_1 == p.get_node(0) - assert node_1 == p.first_node() - assert node_1 == p.last_node() - assert 1 == p.nodes_count() - p.add_edge(edge_1) - assert [edge_1] == p.edges() - assert 1 == p.edge_count() - assert edge_1 == p.get_relationship(0) - p.add_node(node_2) - assert [node_1, node_2] == p.nodes() - assert node_1 == p.first_node() - assert node_2 == p.last_node() - assert 2 == p.nodes_count() - - -@pytest.mark.graph -def test_compare(): - node_1 = node.Node(node_id=1) - node_2 = node.Node(node_id=2) - edge_1 = edge.Edge(node_1, None, node_2) - - assert path.Path.new_empty_path() == path.Path.new_empty_path() - assert path.Path(nodes=[node_1, node_2], edges=[edge_1]) == path.Path( - nodes=[node_1, node_2], edges=[edge_1] - ) - assert path.Path(nodes=[node_1], edges=[]) != path.Path(nodes=[], edges=[]) - assert path.Path(nodes=[node_1], edges=[]) != path.Path(nodes=[], edges=[]) - assert path.Path(nodes=[node_1], edges=[]) != path.Path(nodes=[node_2], edges=[]) - assert path.Path(nodes=[node_1], edges=[edge_1]) != path.Path( - nodes=[node_1], edges=[] - ) - assert path.Path(nodes=[node_1], edges=[edge_1]) != path.Path( - nodes=[node_2], edges=[edge_1] - ) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 06265d382e..367700547f 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -5,7 +5,6 @@ list_or_args, nativestr, parse_to_list, - quote_string, random_string, ) @@ -41,15 +40,3 @@ def test_random_string(): assert len(random_string(15)) == 15 for a in random_string(): assert a in string.ascii_lowercase - - -def test_quote_string(): - assert quote_string("hello world!") == '"hello world!"' - assert quote_string("") == '""' - assert quote_string("hello world!") == '"hello world!"' - assert quote_string("abc") == '"abc"' - assert quote_string("") == '""' - assert quote_string('"') == r'"\""' - assert quote_string(r"foo \ bar") == r'"foo \\ bar"' - assert quote_string(r"foo \" bar") == r'"foo \\\" bar"' - assert quote_string('a"a') == r'"a\"a"' From 6c7acbdb0726ebfc932695ccb7d72f8d49e64254 Mon Sep 17 00:00:00 2001 From: andy-stark-redis <164213578+andy-stark-redis@users.noreply.github.com> Date: Tue, 11 Mar 2025 14:28:36 +0000 Subject: [PATCH 062/113] DOC-4736 added geo indexing examples (#3485) --- doctests/geo_index.py | 153 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 doctests/geo_index.py diff 
--git a/doctests/geo_index.py b/doctests/geo_index.py new file mode 100644 index 0000000000..5e34ec8866 --- /dev/null +++ b/doctests/geo_index.py @@ -0,0 +1,153 @@ +# EXAMPLE: geoindex +import redis +from redis.commands.json.path import Path +from redis.commands.search.field import TextField, GeoField, GeoShapeField +from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.query import Query + +r = redis.Redis() +# REMOVE_START +try: + r.ft("productidx").dropindex(True) +except redis.exceptions.ResponseError: + pass + +try: + r.ft("geomidx").dropindex(True) +except redis.exceptions.ResponseError: + pass + +r.delete("product:46885", "product:46886", "shape:1", "shape:2", "shape:3", "shape:4") +# REMOVE_END + +# STEP_START create_geo_idx +geo_schema = ( + GeoField("$.location", as_name="location") +) + +geo_index_create_result = r.ft("productidx").create_index( + geo_schema, + definition=IndexDefinition( + prefix=["product:"], index_type=IndexType.JSON + ) +) +print(geo_index_create_result) # >>> True +# STEP_END +# REMOVE_START +assert geo_index_create_result +# REMOVE_END + +# STEP_START add_geo_json +prd46885 = { + "description": "Navy Blue Slippers", + "price": 45.99, + "city": "Denver", + "location": "-104.991531, 39.742043" +} + +json_add_result_1 = r.json().set("product:46885", Path.root_path(), prd46885) +print(json_add_result_1) # >>> True + +prd46886 = { + "description": "Bright Green Socks", + "price": 25.50, + "city": "Fort Collins", + "location": "-105.0618814,40.5150098" +} + +json_add_result_2 = r.json().set("product:46886", Path.root_path(), prd46886) +print(json_add_result_2) # >>> True +# STEP_END +# REMOVE_START +assert json_add_result_1 +assert json_add_result_2 +# REMOVE_END + +# STEP_START geo_query +geo_result = r.ft("productidx").search( + "@location:[-104.800644 38.846127 100 mi]" +) +print(geo_result) +# >>> Result{1 total, docs: [Document {'id': 'product:46885'... 
+# STEP_END +# REMOVE_START +assert len(geo_result.docs) == 1 +assert geo_result.docs[0]["id"] == "product:46885" +# REMOVE_END + +# STEP_START create_gshape_idx +geom_schema = ( + TextField("$.name", as_name="name"), + GeoShapeField( + "$.geom", as_name="geom", coord_system=GeoShapeField.FLAT + ) +) + +geom_index_create_result = r.ft("geomidx").create_index( + geom_schema, + definition=IndexDefinition( + prefix=["shape:"], index_type=IndexType.JSON + ) +) +print(geom_index_create_result) # True +# STEP_END +# REMOVE_START +assert geom_index_create_result +# REMOVE_END + +# STEP_START add_gshape_json +shape1 = { + "name": "Green Square", + "geom": "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))" +} + +gm_json_res_1 = r.json().set("shape:1", Path.root_path(), shape1) +print(gm_json_res_1) # >>> True + +shape2 = { + "name": "Red Rectangle", + "geom": "POLYGON ((2 2.5, 2 3.5, 3.5 3.5, 3.5 2.5, 2 2.5))" +} + +gm_json_res_2 = r.json().set("shape:2", Path.root_path(), shape2) +print(gm_json_res_2) # >>> True + +shape3 = { + "name": "Blue Triangle", + "geom": "POLYGON ((3.5 1, 3.75 2, 4 1, 3.5 1))" +} + +gm_json_res_3 = r.json().set("shape:3", Path.root_path(), shape3) +print(gm_json_res_3) # >>> True + +shape4 = { + "name": "Purple Point", + "geom": "POINT (2 2)" +} + +gm_json_res_4 = r.json().set("shape:4", Path.root_path(), shape4) +print(gm_json_res_4) # >>> True +# STEP_END +# REMOVE_START +assert gm_json_res_1 +assert gm_json_res_2 +assert gm_json_res_3 +assert gm_json_res_4 +# REMOVE_END + +# STEP_START gshape_query +geom_result = r.ft("geomidx").search( + Query( + "(-@name:(Green Square) @geom:[WITHIN $qshape])" + ).dialect(4).paging(0, 1), + query_params={ + "qshape": "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))" + } +) +print(geom_result) +# >>> Result{1 total, docs: [Document {'id': 'shape:4'... +# STEP_END +# REMOVE_START +assert len(geom_result.docs) == 1 +assert geom_result.docs[0]["id"] == "shape:4" +# REMOVE_END From 4bd3656e512ab9dc4a4405de0226a689c320ae96 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Wed, 12 Mar 2025 09:26:51 +0200 Subject: [PATCH 063/113] Fixing async cluster pipeline execution when client is created with cluster_error_retry_attempts=0 (#3545) --- redis/asyncio/cluster.py | 23 +++++++++++------------ tests/test_asyncio/test_cluster.py | 13 ++++++++++++- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 51328ad95a..7b9b609417 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1531,29 +1531,28 @@ async def execute( return [] try: - for _ in range(self._client.cluster_error_retry_attempts): - if self._client._initialize: - await self._client.initialize() - + retry_attempts = self._client.cluster_error_retry_attempts + while True: try: + if self._client._initialize: + await self._client.initialize() return await self._execute( self._client, self._command_stack, raise_on_error=raise_on_error, allow_redirections=allow_redirections, ) - except BaseException as e: - if type(e) in self.__class__.ERRORS_ALLOW_RETRY: - # Try again with the new cluster setup. - exception = e + + except self.__class__.ERRORS_ALLOW_RETRY as e: + if retry_attempts > 0: + # Try again with the new cluster setup. All other errors + # should be raised. + retry_attempts -= 1 await self._client.aclose() await asyncio.sleep(0.25) else: # All other errors should be raised. 
- raise - - # If it fails the configured number of times then raise an exception - raise exception + raise e finally: self._command_stack = [] diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 735b116c5d..f298503799 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -2694,6 +2694,17 @@ async def test_redis_cluster_pipeline(self, r: RedisCluster) -> None: ) assert result == [True, b"1", 1, {b"F": b"V"}, True, True, b"2", b"3", 1, 1, 1] + async def test_cluster_pipeline_execution_zero_cluster_err_retries( + self, r: RedisCluster + ) -> None: + """ + Test that we can run successfully cluster pipeline execute at least once when + cluster_error_retry_attempts is set to 0 + """ + r.cluster_error_retry_attempts = 0 + result = await r.pipeline().set("A", 1).get("A").delete("A").execute() + assert result == [True, b"1", 1] + async def test_multi_key_operation_with_a_single_slot( self, r: RedisCluster ) -> None: @@ -2754,7 +2765,7 @@ async def parse_response( await pipe.get(key).execute() assert ( node.parse_response.await_count - == 3 * r.cluster_error_retry_attempts - 2 + == 3 * r.cluster_error_retry_attempts + 1 ) async def test_connection_error_not_raised(self, r: RedisCluster) -> None: From d18922f8173a4084f73c91f3104ad1929cac644b Mon Sep 17 00:00:00 2001 From: Karolina Surma <33810531+befeleme@users.noreply.github.com> Date: Wed, 12 Mar 2025 12:23:02 +0100 Subject: [PATCH 064/113] Avoid the multiprocessing forkserver method (#3442) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Miro Hrončok --- tests/test_multiprocessing.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 8b9e9fb90b..f60898ab86 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -1,6 +1,5 @@ import contextlib import multiprocessing -import sys import pytest import redis @@ -9,9 +8,6 @@ from .conftest import _get_client -if sys.platform == "darwin": - multiprocessing.set_start_method("fork", force=True) - @contextlib.contextmanager def exit_callback(callback, *args): @@ -22,6 +18,16 @@ def exit_callback(callback, *args): class TestMultiprocessing: + # On macOS and newly non-macOS POSIX systems (since Python 3.14), + # the default method has been changed to forkserver. + # The code in this module does not work with it, + # hence the explicit change to 'fork' + # See https://github.com/python/cpython/issues/125714 + if multiprocessing.get_start_method() == "forkserver": + _mp_context = multiprocessing.get_context(method="fork") + else: + _mp_context = multiprocessing.get_context() + # Test connection sharing between forks. # See issue #1085 for details. 
@@ -45,7 +51,7 @@ def target(conn): assert conn.read_response() == b"PONG" conn.disconnect() - proc = multiprocessing.Process(target=target, args=(conn,)) + proc = self._mp_context.Process(target=target, args=(conn,)) proc.start() proc.join(3) assert proc.exitcode == 0 @@ -75,7 +81,7 @@ def target(conn, ev): conn.send_command("ping") ev = multiprocessing.Event() - proc = multiprocessing.Process(target=target, args=(conn, ev)) + proc = self._mp_context.Process(target=target, args=(conn, ev)) proc.start() conn.disconnect() @@ -143,7 +149,7 @@ def target(pool): assert conn.send_command("ping") is None assert conn.read_response() == b"PONG" - proc = multiprocessing.Process(target=target, args=(pool,)) + proc = self._mp_context.Process(target=target, args=(pool,)) proc.start() proc.join(3) assert proc.exitcode == 0 @@ -181,7 +187,7 @@ def target(pool, disconnect_event): ev = multiprocessing.Event() - proc = multiprocessing.Process(target=target, args=(pool, ev)) + proc = self._mp_context.Process(target=target, args=(pool, ev)) proc.start() pool.disconnect() @@ -197,7 +203,7 @@ def target(client): assert client.ping() is True del client - proc = multiprocessing.Process(target=target, args=(r,)) + proc = self._mp_context.Process(target=target, args=(r,)) proc.start() proc.join(3) assert proc.exitcode == 0 From c39c48d76b03385ba427fb21248c8c7a7a54f321 Mon Sep 17 00:00:00 2001 From: Nicolas Noirbent Date: Wed, 12 Mar 2025 18:48:10 +0100 Subject: [PATCH 065/113] Avoid stacktrace on process exit in Client.__del__() (#3397) Client.close() may call ConnectionPool.release() or ConnectionPool.disconnect(); both methods may end up calling os.getpid() (through ConnectionPool._checkpid() or threading.Lock() (through ConnectionPool.reset()). As mentioned in the Python documentation [1], at interpreter shutdown, module globals (in this case, the os and threading module references) may be deleted or set to None before __del__() methods are called. This causes an AttributeError to be raised when trying to run e.g. os.getpid(); while the error is ignored by the interpreter, the traceback is still printed out to stderr. 
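A minimal sketch of the guard pattern this change applies is shown below; the Client class here is a simplified, hypothetical stand-in rather than the actual redis-py client, and is only meant to make the failure mode and the suppression explicit:

    import os

    class Client:
        def close(self):
            # close() eventually reaches module globals such as os.getpid()
            # through the connection pool; at interpreter shutdown the `os`
            # global may already be None, so this raises AttributeError.
            os.getpid()

        def __del__(self):
            # Guard the destructor: swallow any error raised while the
            # interpreter is tearing down so no traceback is printed to stderr.
            try:
                self.close()
            except Exception:
                pass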
Closes #3014 [1] https://docs.python.org/3/reference/datamodel.html#object.__del__ Co-authored-by: petyaslavova --- redis/client.py | 5 ++++- redis/cluster.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/redis/client.py b/redis/client.py index ea29a864ce..2c4a1fadff 100755 --- a/redis/client.py +++ b/redis/client.py @@ -570,7 +570,10 @@ def __exit__(self, exc_type, exc_value, traceback): self.close() def __del__(self): - self.close() + try: + self.close() + except Exception: + pass def close(self) -> None: # In case a connection property does not yet exist diff --git a/redis/cluster.py b/redis/cluster.py index 13253ec896..118e67382c 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -663,7 +663,10 @@ def __exit__(self, exc_type, exc_value, traceback): self.close() def __del__(self): - self.close() + try: + self.close() + except Exception: + pass def disconnect_connection_pools(self): for node in self.get_nodes(): From 527a98fb4bc56147f1fb2260f4e7baac39c485d9 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Thu, 13 Mar 2025 10:22:00 +0200 Subject: [PATCH 066/113] Increasing the operations-per-run for stale issues GH action (#3556) --- .github/workflows/stale-issues.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/stale-issues.yml b/.github/workflows/stale-issues.yml index 445af1c818..f3bc21bbf3 100644 --- a/.github/workflows/stale-issues.yml +++ b/.github/workflows/stale-issues.yml @@ -21,5 +21,5 @@ jobs: days-before-close: 30 stale-issue-label: "Stale" stale-pr-label: "Stale" - operations-per-run: 10 + operations-per-run: 20 remove-stale-when-updated: true From 24bb8f4898968bb550429aa62308188dbe13c85c Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Thu, 13 Mar 2025 13:19:21 +0200 Subject: [PATCH 067/113] Removing support for RedisGears module. (#3553) --- redis/cluster.py | 1 - redis/commands/cluster.py | 10 --- redis/commands/core.py | 127 ---------------------------------- tests/test_cluster.py | 44 ------------ tests/test_commands.py | 51 -------------- tests/test_connection_pool.py | 4 +- tests/test_multiprocessing.py | 4 +- 7 files changed, 4 insertions(+), 237 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index 118e67382c..8edf82e413 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -307,7 +307,6 @@ class AbstractRedisCluster: "FUNCTION LIST", "FUNCTION LOAD", "FUNCTION RESTORE", - "REDISGEARS_2.REFRESHCLUSTER", "SCAN", "SCRIPT EXISTS", "SCRIPT FLUSH", diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index f0b65612e0..13f2035265 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -31,13 +31,11 @@ AsyncACLCommands, AsyncDataAccessCommands, AsyncFunctionCommands, - AsyncGearsCommands, AsyncManagementCommands, AsyncModuleCommands, AsyncScriptCommands, DataAccessCommands, FunctionCommands, - GearsCommands, ManagementCommands, ModuleCommands, PubSubCommands, @@ -693,12 +691,6 @@ def readwrite(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT: self.read_from_replicas = False return self.execute_command("READWRITE", target_nodes=target_nodes) - def gears_refresh_cluster(self, **kwargs) -> ResponseT: - """ - On an OSS cluster, before executing any gears function, you must call this command. 
# noqa - """ - return self.execute_command("REDISGEARS_2.REFRESHCLUSTER", **kwargs) - class AsyncClusterManagementCommands( ClusterManagementCommands, AsyncManagementCommands @@ -874,7 +866,6 @@ class RedisClusterCommands( ClusterDataAccessCommands, ScriptCommands, FunctionCommands, - GearsCommands, ModuleCommands, RedisModuleCommands, ): @@ -905,7 +896,6 @@ class AsyncRedisClusterCommands( AsyncClusterDataAccessCommands, AsyncScriptCommands, AsyncFunctionCommands, - AsyncGearsCommands, AsyncModuleCommands, AsyncRedisModuleCommands, ): diff --git a/redis/commands/core.py b/redis/commands/core.py index b0e5dc6794..df76eafed0 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -6470,131 +6470,6 @@ def function_stats(self) -> Union[Awaitable[List], List]: AsyncFunctionCommands = FunctionCommands -class GearsCommands: - def tfunction_load( - self, lib_code: str, replace: bool = False, config: Union[str, None] = None - ) -> ResponseT: - """ - Load a new library to RedisGears. - - ``lib_code`` - the library code. - ``config`` - a string representation of a JSON object - that will be provided to the library on load time, - for more information refer to - https://github.com/RedisGears/RedisGears/blob/master/docs/function_advance_topics.md#library-configuration - ``replace`` - an optional argument, instructs RedisGears to replace the - function if its already exists - - For more information see https://redis.io/commands/tfunction-load/ - """ - pieces = [] - if replace: - pieces.append("REPLACE") - if config is not None: - pieces.extend(["CONFIG", config]) - pieces.append(lib_code) - return self.execute_command("TFUNCTION LOAD", *pieces) - - def tfunction_delete(self, lib_name: str) -> ResponseT: - """ - Delete a library from RedisGears. - - ``lib_name`` the library name to delete. - - For more information see https://redis.io/commands/tfunction-delete/ - """ - return self.execute_command("TFUNCTION DELETE", lib_name) - - def tfunction_list( - self, - with_code: bool = False, - verbose: int = 0, - lib_name: Union[str, None] = None, - ) -> ResponseT: - """ - List the functions with additional information about each function. - - ``with_code`` Show libraries code. - ``verbose`` output verbosity level, higher number will increase verbosity level - ``lib_name`` specifying a library name (can be used multiple times to show multiple libraries in a single command) # noqa - - For more information see https://redis.io/commands/tfunction-list/ - """ - pieces = [] - if with_code: - pieces.append("WITHCODE") - if verbose >= 1 and verbose <= 3: - pieces.append("v" * verbose) - else: - raise DataError("verbose can be 1, 2 or 3") - if lib_name is not None: - pieces.append("LIBRARY") - pieces.append(lib_name) - - return self.execute_command("TFUNCTION LIST", *pieces) - - def _tfcall( - self, - lib_name: str, - func_name: str, - keys: KeysT = None, - _async: bool = False, - *args: List, - ) -> ResponseT: - pieces = [f"{lib_name}.{func_name}"] - if keys is not None: - pieces.append(len(keys)) - pieces.extend(keys) - else: - pieces.append(0) - if args is not None: - pieces.extend(args) - if _async: - return self.execute_command("TFCALLASYNC", *pieces) - return self.execute_command("TFCALL", *pieces) - - def tfcall( - self, - lib_name: str, - func_name: str, - keys: KeysT = None, - *args: List, - ) -> ResponseT: - """ - Invoke a function. - - ``lib_name`` - the library name contains the function. - ``func_name`` - the function name to run. - ``keys`` - the keys that will be touched by the function. 
- ``args`` - Additional argument to pass to the function. - - For more information see https://redis.io/commands/tfcall/ - """ - return self._tfcall(lib_name, func_name, keys, False, *args) - - def tfcall_async( - self, - lib_name: str, - func_name: str, - keys: KeysT = None, - *args: List, - ) -> ResponseT: - """ - Invoke an async function (coroutine). - - ``lib_name`` - the library name contains the function. - ``func_name`` - the function name to run. - ``keys`` - the keys that will be touched by the function. - ``args`` - Additional argument to pass to the function. - - For more information see https://redis.io/commands/tfcall/ - """ - return self._tfcall(lib_name, func_name, keys, True, *args) - - -AsyncGearsCommands = GearsCommands - - class DataAccessCommands( BasicKeyCommands, HyperlogCommands, @@ -6638,7 +6513,6 @@ class CoreCommands( PubSubCommands, ScriptCommands, FunctionCommands, - GearsCommands, ): """ A class containing all of the implemented redis commands. This class is @@ -6655,7 +6529,6 @@ class AsyncCoreCommands( AsyncPubSubCommands, AsyncScriptCommands, AsyncFunctionCommands, - AsyncGearsCommands, ): """ A class containing all of the implemented redis commands. This class is diff --git a/tests/test_cluster.py b/tests/test_cluster.py index bec9a8ecb0..e64db3690b 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -2448,50 +2448,6 @@ def try_delete_libs(self, r, *lib_names): except Exception: pass - @pytest.mark.redismod - @skip_if_server_version_lt("7.1.140") - def test_tfunction_load_delete(self, r): - r.gears_refresh_cluster() - self.try_delete_libs(r, "lib1") - lib_code = self.generate_lib_code("lib1") - assert r.tfunction_load(lib_code) - assert r.tfunction_delete("lib1") - - @pytest.mark.redismod - @skip_if_server_version_lt("7.1.140") - def test_tfunction_list(self, r): - r.gears_refresh_cluster() - self.try_delete_libs(r, "lib1", "lib2", "lib3") - assert r.tfunction_load(self.generate_lib_code("lib1")) - assert r.tfunction_load(self.generate_lib_code("lib2")) - assert r.tfunction_load(self.generate_lib_code("lib3")) - - # test error thrown when verbose > 4 - with pytest.raises(DataError): - assert r.tfunction_list(verbose=8) - - functions = r.tfunction_list(verbose=1) - assert len(functions) == 3 - - expected_names = [b"lib1", b"lib2", b"lib3"] - actual_names = [functions[0][13], functions[1][13], functions[2][13]] - - assert sorted(expected_names) == sorted(actual_names) - assert r.tfunction_delete("lib1") - assert r.tfunction_delete("lib2") - assert r.tfunction_delete("lib3") - - @pytest.mark.redismod - @skip_if_server_version_lt("7.1.140") - def test_tfcall(self, r): - r.gears_refresh_cluster() - self.try_delete_libs(r, "lib1") - assert r.tfunction_load(self.generate_lib_code("lib1")) - assert r.tfcall("lib1", "foo") == b"bar" - assert r.tfcall_async("lib1", "foo") == b"bar" - - assert r.tfunction_delete("lib1") - @pytest.mark.onlycluster class TestNodesManager: diff --git a/tests/test_commands.py b/tests/test_commands.py index c6e39f565d..5c72a019ba 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -2050,57 +2050,6 @@ def try_delete_libs(self, r, *lib_names): except Exception: pass - @pytest.mark.onlynoncluster - @skip_if_server_version_lt("7.1.140") - @skip_if_server_version_gte("7.9.0") - def test_tfunction_load_delete(self, stack_r): - self.try_delete_libs(stack_r, "lib1") - lib_code = self.generate_lib_code("lib1") - assert stack_r.tfunction_load(lib_code) - assert stack_r.tfunction_delete("lib1") - - 
@pytest.mark.onlynoncluster - @skip_if_server_version_lt("7.1.140") - @skip_if_server_version_gte("7.9.0") - def test_tfunction_list(self, stack_r): - self.try_delete_libs(stack_r, "lib1", "lib2", "lib3") - assert stack_r.tfunction_load(self.generate_lib_code("lib1")) - assert stack_r.tfunction_load(self.generate_lib_code("lib2")) - assert stack_r.tfunction_load(self.generate_lib_code("lib3")) - - # test error thrown when verbose > 4 - with pytest.raises(redis.exceptions.DataError): - assert stack_r.tfunction_list(verbose=8) - - functions = stack_r.tfunction_list(verbose=1) - assert len(functions) == 3 - - expected_names = [b"lib1", b"lib2", b"lib3"] - if is_resp2_connection(stack_r): - actual_names = [functions[0][13], functions[1][13], functions[2][13]] - else: - actual_names = [ - functions[0][b"name"], - functions[1][b"name"], - functions[2][b"name"], - ] - - assert sorted(expected_names) == sorted(actual_names) - assert stack_r.tfunction_delete("lib1") - assert stack_r.tfunction_delete("lib2") - assert stack_r.tfunction_delete("lib3") - - @pytest.mark.onlynoncluster - @skip_if_server_version_lt("7.1.140") - @skip_if_server_version_gte("7.9.0") - def test_tfcall(self, stack_r): - self.try_delete_libs(stack_r, "lib1") - assert stack_r.tfunction_load(self.generate_lib_code("lib1")) - assert stack_r.tfcall("lib1", "foo") == b"bar" - assert stack_r.tfcall_async("lib1", "foo") == b"bar" - - assert stack_r.tfunction_delete("lib1") - def test_ttl(self, r): r["a"] = "1" assert r.expire("a", 10) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index c92d84c226..0ec77a4fff 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -94,11 +94,11 @@ def test_reuse_previously_released_connection(self, master_host): def test_release_not_owned_connection(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool1 = self.get_pool(connection_kwargs=connection_kwargs) - c1 = pool1.get_connection("_") + c1 = pool1.get_connection() pool2 = self.get_pool( connection_kwargs={"host": master_host[0], "port": master_host[1]} ) - c2 = pool2.get_connection("_") + c2 = pool2.get_connection() pool2.release(c2) assert len(pool2._available_connections) == 1 diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index f60898ab86..79301b93f1 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -105,11 +105,11 @@ def test_release_parent_connection_from_pool_in_child_process( max_connections=max_connections, ) - parent_conn = pool.get_connection("ping") + parent_conn = pool.get_connection() def target(pool, parent_conn): with exit_callback(pool.disconnect): - child_conn = pool.get_connection("ping") + child_conn = pool.get_connection() assert child_conn.pid != parent_conn.pid pool.release(child_conn) assert pool._created_connections == 1 From 944f010b64465c2cb0661fe00b7f91298d0a91ee Mon Sep 17 00:00:00 2001 From: Joel Dice Date: Fri, 14 Mar 2025 02:16:56 -0600 Subject: [PATCH 068/113] skip `ssl` import if not available (#3078) * skip `ssl` import if not available Signed-off-by: Joel Dice * address review feedback and fix lint errors Signed-off-by: Joel Dice * remove TYPE_CHECKING clause from ssl conditional Signed-off-by: Joel Dice --------- Signed-off-by: Joel Dice Co-authored-by: petyaslavova --- redis/asyncio/client.py | 9 +++++++-- redis/asyncio/cluster.py | 16 +++++++++++++--- redis/asyncio/connection.py | 25 ++++++++++++++++++++----- redis/connection.py | 6 +++++- 4 files changed, 45 
insertions(+), 11 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 412d5a24b3..0039cea540 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -2,7 +2,6 @@ import copy import inspect import re -import ssl import warnings from typing import ( TYPE_CHECKING, @@ -72,6 +71,7 @@ from redis.typing import ChannelT, EncodableT, KeyT from redis.utils import ( HIREDIS_AVAILABLE, + SSL_AVAILABLE, _set_info_logger, deprecated_function, get_lib_version, @@ -79,6 +79,11 @@ str_if_bytes, ) +if TYPE_CHECKING and SSL_AVAILABLE: + from ssl import TLSVersion +else: + TLSVersion = None + PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] _KeyT = TypeVar("_KeyT", bound=KeyT) _ArgT = TypeVar("_ArgT", KeyT, EncodableT) @@ -226,7 +231,7 @@ def __init__( ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, ssl_check_hostname: bool = False, - ssl_min_version: Optional[ssl.TLSVersion] = None, + ssl_min_version: Optional[TLSVersion] = None, ssl_ciphers: Optional[str] = None, max_connections: Optional[int] = None, single_connection_client: bool = False, diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 7b9b609417..b3358c2817 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -2,7 +2,6 @@ import collections import random import socket -import ssl import warnings from typing import ( Any, @@ -64,7 +63,18 @@ TryAgainError, ) from redis.typing import AnyKeyT, EncodableT, KeyT -from redis.utils import deprecated_function, get_lib_version, safe_str, str_if_bytes +from redis.utils import ( + SSL_AVAILABLE, + deprecated_function, + get_lib_version, + safe_str, + str_if_bytes, +) + +if SSL_AVAILABLE: + from ssl import TLSVersion +else: + TLSVersion = None TargetNodesT = TypeVar( "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] @@ -256,7 +266,7 @@ def __init__( ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, - ssl_min_version: Optional[ssl.TLSVersion] = None, + ssl_min_version: Optional[TLSVersion] = None, ssl_ciphers: Optional[str] = None, protocol: Optional[int] = 2, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 9b5d0d8eb9..15b9219aaa 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -3,7 +3,6 @@ import enum import inspect import socket -import ssl import sys import warnings import weakref @@ -27,6 +26,16 @@ ) from urllib.parse import ParseResult, parse_qs, unquote, urlparse +from ..utils import SSL_AVAILABLE + +if SSL_AVAILABLE: + import ssl + from ssl import SSLContext, TLSVersion +else: + ssl = None + TLSVersion = None + SSLContext = None + from ..auth.token import TokenInterface from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher from ..utils import deprecated_args, format_error_message @@ -763,10 +772,13 @@ def __init__( ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, ssl_check_hostname: bool = False, - ssl_min_version: Optional[ssl.TLSVersion] = None, + ssl_min_version: Optional[TLSVersion] = None, ssl_ciphers: Optional[str] = None, **kwargs, ): + if not SSL_AVAILABLE: + raise RedisError("Python wasn't built with SSL support") + self.ssl_context: RedisSSLContext = RedisSSLContext( keyfile=ssl_keyfile, certfile=ssl_certfile, @@ -834,9 +846,12 @@ def __init__( ca_certs: Optional[str] = None, ca_data: Optional[str] = None, check_hostname: 
bool = False, - min_version: Optional[ssl.TLSVersion] = None, + min_version: Optional[TLSVersion] = None, ciphers: Optional[str] = None, ): + if not SSL_AVAILABLE: + raise RedisError("Python wasn't built with SSL support") + self.keyfile = keyfile self.certfile = certfile if cert_reqs is None: @@ -857,9 +872,9 @@ def __init__( self.check_hostname = check_hostname self.min_version = min_version self.ciphers = ciphers - self.context: Optional[ssl.SSLContext] = None + self.context: Optional[SSLContext] = None - def get(self) -> ssl.SSLContext: + def get(self) -> SSLContext: if not self.context: context = ssl.create_default_context() context.check_hostname = self.check_hostname diff --git a/redis/connection.py b/redis/connection.py index a298542c03..b6dee40d75 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,7 +1,6 @@ import copy import os import socket -import ssl import sys import threading import time @@ -49,6 +48,11 @@ str_if_bytes, ) +if SSL_AVAILABLE: + import ssl +else: + ssl = None + if HIREDIS_AVAILABLE: import hiredis From 540c3e819696ff17ee52e30b5cde4528612412d8 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Tue, 18 Mar 2025 10:46:41 +0200 Subject: [PATCH 069/113] Fixing search module dropindex function not to send invalid third parameter. Updating pipeline infra (#3564) * Fixing search module dropindex function not to send invalid third parameter. Updating pipeline infra * Fixing linters --- .github/workflows/integration.yaml | 2 +- redis/commands/search/commands.py | 14 ++++++++++++-- tests/test_asyncio/test_search.py | 8 ++++---- tests/test_search.py | 10 +++++----- 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index bb56e8a024..514a88a796 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -74,7 +74,7 @@ jobs: max-parallel: 15 fail-fast: false matrix: - redis-version: ['8.0-M04-pre', '${{ needs.redis_version.outputs.CURRENT }}', '7.2.7', '6.2.17'] + redis-version: ['8.0-M05-pre', '${{ needs.redis_version.outputs.CURRENT }}', '7.2.7', '6.2.17'] python-version: ['3.8', '3.13'] parser-backend: ['plain'] event-loop: ['asyncio'] diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 96c6d9c2af..42866f5ec1 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -255,8 +255,18 @@ def dropindex(self, delete_documents: bool = False): For more information see `FT.DROPINDEX `_. 
""" # noqa - delete_str = "DD" if delete_documents else "" - return self.execute_command(DROPINDEX_CMD, self.index_name, delete_str) + args = [DROPINDEX_CMD, self.index_name] + + delete_str = ( + "DD" + if isinstance(delete_documents, bool) and delete_documents is True + else "" + ) + + if delete_str: + args.append(delete_str) + + return self.execute_command(*args) def _add_document( self, diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index c55d57f3b2..9a318796bf 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1603,14 +1603,14 @@ async def test_withsuffixtrie(decoded_r: redis.Redis): if is_resp2_connection(decoded_r): info = await decoded_r.ft().info() assert "WITHSUFFIXTRIE" not in info["attributes"][0] - assert await decoded_r.ft().dropindex("idx") + assert await decoded_r.ft().dropindex() # create withsuffixtrie index (text field) assert await decoded_r.ft().create_index(TextField("t", withsuffixtrie=True)) await waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) info = await decoded_r.ft().info() assert "WITHSUFFIXTRIE" in info["attributes"][0] - assert await decoded_r.ft().dropindex("idx") + assert await decoded_r.ft().dropindex() # create withsuffixtrie index (tag field) assert await decoded_r.ft().create_index(TagField("t", withsuffixtrie=True)) @@ -1620,14 +1620,14 @@ async def test_withsuffixtrie(decoded_r: redis.Redis): else: info = await decoded_r.ft().info() assert "WITHSUFFIXTRIE" not in info["attributes"][0]["flags"] - assert await decoded_r.ft().dropindex("idx") + assert await decoded_r.ft().dropindex() # create withsuffixtrie index (text fields) assert await decoded_r.ft().create_index(TextField("t", withsuffixtrie=True)) await waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) info = await decoded_r.ft().info() assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] - assert await decoded_r.ft().dropindex("idx") + assert await decoded_r.ft().dropindex() # create withsuffixtrie index (tag field) assert await decoded_r.ft().create_index(TagField("t", withsuffixtrie=True)) diff --git a/tests/test_search.py b/tests/test_search.py index 11f22ac805..7e4f59eb79 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1711,7 +1711,7 @@ def test_max_text_fields(client): with pytest.raises(redis.ResponseError): client.ft().alter_schema_add((TextField(f"f{x}"),)) - client.ft().dropindex("idx") + client.ft().dropindex() # Creating the index definition client.ft().create_index((TextField("f0"),), max_text_fields=True) # Fill the index with fields @@ -2575,14 +2575,14 @@ def test_withsuffixtrie(client: redis.Redis): if is_resp2_connection(client): info = client.ft().info() assert "WITHSUFFIXTRIE" not in info["attributes"][0] - assert client.ft().dropindex("idx") + assert client.ft().dropindex() # create withsuffixtrie index (text fields) assert client.ft().create_index(TextField("t", withsuffixtrie=True)) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) info = client.ft().info() assert "WITHSUFFIXTRIE" in info["attributes"][0] - assert client.ft().dropindex("idx") + assert client.ft().dropindex() # create withsuffixtrie index (tag field) assert client.ft().create_index(TagField("t", withsuffixtrie=True)) @@ -2592,14 +2592,14 @@ def test_withsuffixtrie(client: redis.Redis): else: info = client.ft().info() assert "WITHSUFFIXTRIE" not in info["attributes"][0]["flags"] - assert client.ft().dropindex("idx") + assert client.ft().dropindex() # create 
withsuffixtrie index (text fields) assert client.ft().create_index(TextField("t", withsuffixtrie=True)) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) info = client.ft().info() assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] - assert client.ft().dropindex("idx") + assert client.ft().dropindex() # create withsuffixtrie index (tag field) assert client.ft().create_index(TagField("t", withsuffixtrie=True)) From decec9a933b0617b79bfb1e8721058cb624b5a73 Mon Sep 17 00:00:00 2001 From: Don Bowman Date: Tue, 18 Mar 2025 11:33:44 -0400 Subject: [PATCH 070/113] fix: add TimeoutError handling in get_connection() (#1485) * fix: add TimeoutError handling in get_connection() In get_connection() we can implicitly call read on a connection. This can timeout of the underlying TCP session is gone. With this change we remove it from the connection pool and get a new connection. * fix: add TimeoutError handling in get_connection() sync/async In get_connection() we can implicitly call read on a connection. This can timeout of the underlying TCP session is gone. With this change we remove it from the connection pool and get a new connection. * fix: update version number to match test expectations * fix: revert version number to 5.2.1, manually uninstall entraid --- redis/asyncio/connection.py | 2 +- redis/connection.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 15b9219aaa..66dbd09b61 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1157,7 +1157,7 @@ async def ensure_connection(self, connection: AbstractConnection): try: if await connection.can_read_destructive(): raise ConnectionError("Connection has data") from None - except (ConnectionError, OSError): + except (ConnectionError, TimeoutError, OSError): await connection.disconnect() await connection.connect() if await connection.can_read_destructive(): diff --git a/redis/connection.py b/redis/connection.py index b6dee40d75..f754a5165a 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1494,7 +1494,7 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection": try: if connection.can_read() and self.cache is None: raise ConnectionError("Connection has data") - except (ConnectionError, OSError): + except (ConnectionError, TimeoutError, OSError): connection.disconnect() connection.connect() if connection.can_read(): @@ -1741,7 +1741,7 @@ def get_connection(self, command_name=None, *keys, **options): try: if connection.can_read(): raise ConnectionError("Connection has data") - except (ConnectionError, OSError): + except (ConnectionError, TimeoutError, OSError): connection.disconnect() connection.connect() if connection.can_read(): From c7443664b68259bba096c543e23b6ef29513a81c Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Wed, 19 Mar 2025 09:58:11 +0200 Subject: [PATCH 071/113] Adding load balancing strategy configuration to cluster clients(replacement for 'read_from_replicas' config) (#3563) * Adding laod balancing strategy configuration to cluster clients(replacement for 'read_from_replicas' config) * Fixing linter errors * Changing the LoadBalancingStrategy type hints to be defined as optional. 
Fixed wording in pydocs * Adding integration tests with the different load balancing strategies for read operation * Fixing linters --- redis/asyncio/cluster.py | 40 ++++++-- redis/cluster.py | 100 +++++++++++++++--- tests/test_asyncio/test_cluster.py | 139 +++++++++++++++++++++++-- tests/test_cluster.py | 160 +++++++++++++++++++++++++++-- tests/test_multiprocessing.py | 4 +- 5 files changed, 400 insertions(+), 43 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index b3358c2817..7a29550a35 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -38,6 +38,7 @@ SLOT_ID, AbstractRedisCluster, LoadBalancer, + LoadBalancingStrategy, block_pipeline_command, get_node_name, parse_cluster_slots, @@ -65,6 +66,7 @@ from redis.typing import AnyKeyT, EncodableT, KeyT from redis.utils import ( SSL_AVAILABLE, + deprecated_args, deprecated_function, get_lib_version, safe_str, @@ -121,9 +123,14 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand | See: https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters :param read_from_replicas: - | Enable read from replicas in READONLY mode. You can read possibly stale data. + | @deprecated - please use load_balancing_strategy instead + | Enable read from replicas in READONLY mode. When set to true, read commands will be assigned between the primary and its replications in a Round-Robin manner. + :param load_balancing_strategy: + | Enable read from replicas in READONLY mode and defines the load balancing + strategy that will be used for cluster node selection. + The data read from replicas is eventually consistent with the data in primary nodes. :param dynamic_startup_nodes: | Set the RedisCluster's startup nodes to all the discovered nodes. If true (default value), the cluster's discovered nodes will be used to @@ -132,6 +139,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand listed in the CLUSTER SLOTS output. If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists specific IP addresses, it is best to set it to false. + The data read from replicas is eventually consistent with the data in primary nodes. :param reinitialize_steps: | Specifies the number of MOVED errors that need to occur before reinitializing the whole cluster topology. 
If a MOVED error occurs and the cluster does not @@ -224,6 +232,11 @@ def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster": "result_callbacks", ) + @deprecated_args( + args_to_warn=["read_from_replicas"], + reason="Please configure the 'load_balancing_strategy' instead", + version="5.0.3", + ) def __init__( self, host: Optional[str] = None, @@ -232,6 +245,7 @@ def __init__( startup_nodes: Optional[List["ClusterNode"]] = None, require_full_coverage: bool = True, read_from_replicas: bool = False, + load_balancing_strategy: Optional[LoadBalancingStrategy] = None, dynamic_startup_nodes: bool = True, reinitialize_steps: int = 5, cluster_error_retry_attempts: int = 3, @@ -331,7 +345,7 @@ def __init__( } ) - if read_from_replicas: + if read_from_replicas or load_balancing_strategy: # Call our on_connect function to configure READONLY mode kwargs["redis_connect_func"] = self.on_connect @@ -381,6 +395,7 @@ def __init__( ) self.encoder = Encoder(encoding, encoding_errors, decode_responses) self.read_from_replicas = read_from_replicas + self.load_balancing_strategy = load_balancing_strategy self.reinitialize_steps = reinitialize_steps self.cluster_error_retry_attempts = cluster_error_retry_attempts self.connection_error_retry_attempts = connection_error_retry_attempts @@ -599,6 +614,7 @@ async def _determine_nodes( self.nodes_manager.get_node_from_slot( await self._determine_slot(command, *args), self.read_from_replicas and command in READ_COMMANDS, + self.load_balancing_strategy if command in READ_COMMANDS else None, ) ] @@ -779,7 +795,11 @@ async def _execute_command( # refresh the target node slot = await self._determine_slot(*args) target_node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and args[0] in READ_COMMANDS + slot, + self.read_from_replicas and args[0] in READ_COMMANDS, + self.load_balancing_strategy + if args[0] in READ_COMMANDS + else None, ) moved = False @@ -1244,17 +1264,23 @@ def _update_moved_slots(self) -> None: self._moved_exception = None def get_node_from_slot( - self, slot: int, read_from_replicas: bool = False + self, + slot: int, + read_from_replicas: bool = False, + load_balancing_strategy=None, ) -> "ClusterNode": if self._moved_exception: self._update_moved_slots() + if read_from_replicas is True and load_balancing_strategy is None: + load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN + try: - if read_from_replicas: - # get the server index in a Round-Robin manner + if len(self.slots_cache[slot]) > 1 and load_balancing_strategy: + # get the server index using the strategy defined in load_balancing_strategy primary_name = self.slots_cache[slot][0].name node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot]) + primary_name, len(self.slots_cache[slot]), load_balancing_strategy ) return self.slots_cache[slot][node_idx] return self.slots_cache[slot][0] diff --git a/redis/cluster.py b/redis/cluster.py index 8edf82e413..0488608a60 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -4,6 +4,7 @@ import threading import time from collections import OrderedDict +from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple, Union from redis._parsers import CommandsParser, Encoder @@ -482,6 +483,11 @@ class initializer. 
In the case of conflicting arguments, querystring """ return cls(url=url, **kwargs) + @deprecated_args( + args_to_warn=["read_from_replicas"], + reason="Please configure the 'load_balancing_strategy' instead", + version="5.0.3", + ) def __init__( self, host: Optional[str] = None, @@ -492,6 +498,7 @@ def __init__( require_full_coverage: bool = False, reinitialize_steps: int = 5, read_from_replicas: bool = False, + load_balancing_strategy: Optional["LoadBalancingStrategy"] = None, dynamic_startup_nodes: bool = True, url: Optional[str] = None, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, @@ -520,11 +527,16 @@ def __init__( cluster client. If not all slots are covered, RedisClusterException will be thrown. :param read_from_replicas: + @deprecated - please use load_balancing_strategy instead Enable read from replicas in READONLY mode. You can read possibly stale data. When set to true, read commands will be assigned between the primary and its replications in a Round-Robin manner. - :param dynamic_startup_nodes: + :param load_balancing_strategy: + Enable read from replicas in READONLY mode and defines the load balancing + strategy that will be used for cluster node selection. + The data read from replicas is eventually consistent with the data in primary nodes. + :param dynamic_startup_nodes: Set the RedisCluster's startup nodes to all of the discovered nodes. If true (default value), the cluster's discovered nodes will be used to determine the cluster nodes-slots mapping in the next topology refresh. @@ -629,6 +641,7 @@ def __init__( self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.node_flags = self.__class__.NODE_FLAGS.copy() self.read_from_replicas = read_from_replicas + self.load_balancing_strategy = load_balancing_strategy self.reinitialize_counter = 0 self.reinitialize_steps = reinitialize_steps if event_dispatcher is None: @@ -683,7 +696,7 @@ def on_connect(self, connection): """ connection.on_connect() - if self.read_from_replicas: + if self.read_from_replicas or self.load_balancing_strategy: # Sending READONLY command to server to configure connection as # readonly. 
Since each cluster node may change its server type due # to a failover, we should establish a READONLY connection @@ -810,6 +823,7 @@ def pipeline(self, transaction=None, shard_hint=None): cluster_response_callbacks=self.cluster_response_callbacks, cluster_error_retry_attempts=self.cluster_error_retry_attempts, read_from_replicas=self.read_from_replicas, + load_balancing_strategy=self.load_balancing_strategy, reinitialize_steps=self.reinitialize_steps, lock=self._lock, ) @@ -934,7 +948,9 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: # get the node that holds the key's slot slot = self.determine_slot(*args) node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and command in READ_COMMANDS + slot, + self.read_from_replicas and command in READ_COMMANDS, + self.load_balancing_strategy if command in READ_COMMANDS else None, ) return [node] @@ -1158,7 +1174,11 @@ def _execute_command(self, target_node, *args, **kwargs): # refresh the target node slot = self.determine_slot(*args) target_node = self.nodes_manager.get_node_from_slot( - slot, self.read_from_replicas and command in READ_COMMANDS + slot, + self.read_from_replicas and command in READ_COMMANDS, + self.load_balancing_strategy + if command in READ_COMMANDS + else None, ) moved = False @@ -1307,6 +1327,12 @@ def __del__(self): self.redis_connection.close() +class LoadBalancingStrategy(Enum): + ROUND_ROBIN = "round_robin" + ROUND_ROBIN_REPLICAS = "round_robin_replicas" + RANDOM_REPLICA = "random_replica" + + class LoadBalancer: """ Round-Robin Load Balancing @@ -1316,15 +1342,38 @@ def __init__(self, start_index: int = 0) -> None: self.primary_to_idx = {} self.start_index = start_index - def get_server_index(self, primary: str, list_size: int) -> int: - server_index = self.primary_to_idx.setdefault(primary, self.start_index) - # Update the index - self.primary_to_idx[primary] = (server_index + 1) % list_size - return server_index + def get_server_index( + self, + primary: str, + list_size: int, + load_balancing_strategy: LoadBalancingStrategy = LoadBalancingStrategy.ROUND_ROBIN, + ) -> int: + if load_balancing_strategy == LoadBalancingStrategy.RANDOM_REPLICA: + return self._get_random_replica_index(list_size) + else: + return self._get_round_robin_index( + primary, + list_size, + load_balancing_strategy == LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + ) def reset(self) -> None: self.primary_to_idx.clear() + def _get_random_replica_index(self, list_size: int) -> int: + return random.randint(1, list_size - 1) + + def _get_round_robin_index( + self, primary: str, list_size: int, replicas_only: bool + ) -> int: + server_index = self.primary_to_idx.setdefault(primary, self.start_index) + if replicas_only and server_index == 0: + # skip the primary node index + server_index = 1 + # Update the index for the next round + self.primary_to_idx[primary] = (server_index + 1) % list_size + return server_index + class NodesManager: def __init__( @@ -1428,7 +1477,21 @@ def _update_moved_slots(self): # Reset moved_exception self._moved_exception = None - def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None): + @deprecated_args( + args_to_warn=["server_type"], + reason=( + "In case you need select some load balancing strategy " + "that will use replicas, please set it through 'load_balancing_strategy'" + ), + version="5.0.3", + ) + def get_node_from_slot( + self, + slot, + read_from_replicas=False, + load_balancing_strategy=None, + server_type=None, + ): """ Gets a node that servers 
this hash slot """ @@ -1443,11 +1506,14 @@ def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None): f'"require_full_coverage={self._require_full_coverage}"' ) - if read_from_replicas is True: - # get the server index in a Round-Robin manner + if read_from_replicas is True and load_balancing_strategy is None: + load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN + + if len(self.slots_cache[slot]) > 1 and load_balancing_strategy: + # get the server index using the strategy defined in load_balancing_strategy primary_name = self.slots_cache[slot][0].name node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot]) + primary_name, len(self.slots_cache[slot]), load_balancing_strategy ) elif ( server_type is None @@ -1730,7 +1796,7 @@ def __init__( first command execution. The node will be determined by: 1. Hashing the channel name in the request to find its keyslot 2. Selecting a node that handles the keyslot: If read_from_replicas is - set to true, a replica can be selected. + set to true or load_balancing_strategy is set, a replica can be selected. :type redis_cluster: RedisCluster :type node: ClusterNode @@ -1826,7 +1892,9 @@ def execute_command(self, *args): channel = args[1] slot = self.cluster.keyslot(channel) node = self.cluster.nodes_manager.get_node_from_slot( - slot, self.cluster.read_from_replicas + slot, + self.cluster.read_from_replicas, + self.cluster.load_balancing_strategy, ) else: # Get a random node @@ -1969,6 +2037,7 @@ def __init__( cluster_response_callbacks: Optional[Dict[str, Callable]] = None, startup_nodes: Optional[List["ClusterNode"]] = None, read_from_replicas: bool = False, + load_balancing_strategy: Optional[LoadBalancingStrategy] = None, cluster_error_retry_attempts: int = 3, reinitialize_steps: int = 5, lock=None, @@ -1984,6 +2053,7 @@ def __init__( ) self.startup_nodes = startup_nodes if startup_nodes else [] self.read_from_replicas = read_from_replicas + self.load_balancing_strategy = load_balancing_strategy self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.cluster_response_callbacks = cluster_response_callbacks self.cluster_error_retry_attempts = cluster_error_retry_attempts diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index f298503799..a4f0636299 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -14,7 +14,13 @@ from redis.asyncio.connection import Connection, SSLConnection, async_timeout from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff -from redis.cluster import PIPELINE_BLOCKED_COMMANDS, PRIMARY, REPLICA, get_node_name +from redis.cluster import ( + PIPELINE_BLOCKED_COMMANDS, + PRIMARY, + REPLICA, + LoadBalancingStrategy, + get_node_name, +) from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.exceptions import ( AskError, @@ -181,7 +187,18 @@ def cmd_init_mock(self, r: ClusterNode) -> None: cmd_parser_initialize.side_effect = cmd_init_mock - return await RedisCluster(*args, **kwargs) + # Create a subclass of RedisCluster that overrides __del__ + class MockedRedisCluster(RedisCluster): + def __del__(self): + # Override to prevent connection cleanup attempts + pass + + @property + def connection_pool(self): + # Required abstract property implementation + return self.nodes_manager.get_default_node().redis_connection.connection_pool + + return await MockedRedisCluster(*args, **kwargs) def mock_node_resp(node: ClusterNode, 
response: Any) -> ClusterNode: @@ -677,7 +694,24 @@ def cmd_init_mock(self, r: ClusterNode) -> None: assert execute_command.failed_calls == 1 assert execute_command.successful_calls == 1 - async def test_reading_from_replicas_in_round_robin(self) -> None: + @pytest.mark.parametrize( + "read_from_replicas,load_balancing_strategy,mocks_srv_ports", + [ + (True, None, [7001, 7002, 7001]), + (True, LoadBalancingStrategy.ROUND_ROBIN, [7001, 7002, 7001]), + (True, LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, [7002, 7002, 7002]), + (True, LoadBalancingStrategy.RANDOM_REPLICA, [7002, 7002, 7002]), + (False, LoadBalancingStrategy.ROUND_ROBIN, [7001, 7002, 7001]), + (False, LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, [7002, 7002, 7002]), + (False, LoadBalancingStrategy.RANDOM_REPLICA, [7002, 7002, 7002]), + ], + ) + async def test_reading_with_load_balancing_strategies( + self, + read_from_replicas: bool, + load_balancing_strategy: LoadBalancingStrategy, + mocks_srv_ports: List[int], + ) -> None: with mock.patch.multiple( Connection, send_command=mock.DEFAULT, @@ -693,19 +727,19 @@ async def test_reading_from_replicas_in_round_robin(self) -> None: async def execute_command_mock_first(self, *args, **options): await self.connection_class(**self.connection_kwargs).connect() # Primary - assert self.port == 7001 + assert self.port == mocks_srv_ports[0] execute_command.side_effect = execute_command_mock_second return "MOCK_OK" def execute_command_mock_second(self, *args, **options): # Replica - assert self.port == 7002 + assert self.port == mocks_srv_ports[1] execute_command.side_effect = execute_command_mock_third return "MOCK_OK" def execute_command_mock_third(self, *args, **options): # Primary - assert self.port == 7001 + assert self.port == mocks_srv_ports[2] return "MOCK_OK" # We don't need to create a real cluster connection but we @@ -720,9 +754,13 @@ def execute_command_mock_third(self, *args, **options): # Create a cluster with reading from replications read_cluster = await get_mocked_redis_client( - host=default_host, port=default_port, read_from_replicas=True + host=default_host, + port=default_port, + read_from_replicas=read_from_replicas, + load_balancing_strategy=load_balancing_strategy, ) - assert read_cluster.read_from_replicas is True + assert read_cluster.read_from_replicas is read_from_replicas + assert read_cluster.load_balancing_strategy is load_balancing_strategy # Check that we read from the slot's nodes in a round robin # matter. 
# 'foo' belongs to slot 12182 and the slot's nodes are: @@ -970,6 +1008,34 @@ async def test_get_and_set(self, r: RedisCluster) -> None: assert await r.get("integer") == str(integer).encode() assert (await r.get("unicode_string")).decode("utf-8") == unicode_string + @pytest.mark.parametrize( + "load_balancing_strategy", + [ + LoadBalancingStrategy.ROUND_ROBIN, + LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + LoadBalancingStrategy.RANDOM_REPLICA, + ], + ) + async def test_get_and_set_with_load_balanced_client( + self, create_redis, load_balancing_strategy: LoadBalancingStrategy + ) -> None: + r = await create_redis( + cls=RedisCluster, + load_balancing_strategy=load_balancing_strategy, + ) + + # get and set can't be tested independently of each other + assert await r.get("a") is None + + byte_string = b"value" + assert await r.set("byte_string", byte_string) + + # run the get command for the same key several times + # to iterate over the read nodes + assert await r.get("byte_string") == byte_string + assert await r.get("byte_string") == byte_string + assert await r.get("byte_string") == byte_string + async def test_mget_nonatomic(self, r: RedisCluster) -> None: assert await r.mget_nonatomic([]) == [] assert await r.mget_nonatomic(["a", "b"]) == [None, None] @@ -2370,11 +2436,14 @@ async def test_load_balancer(self, r: RedisCluster) -> None: primary2_name = n_manager.slots_cache[slot_2][0].name list1_size = len(n_manager.slots_cache[slot_1]) list2_size = len(n_manager.slots_cache[slot_2]) + + # default load balancer strategy: LoadBalancerStrategy.ROUND_ROBIN # slot 1 assert lb.get_server_index(primary1_name, list1_size) == 0 assert lb.get_server_index(primary1_name, list1_size) == 1 assert lb.get_server_index(primary1_name, list1_size) == 2 assert lb.get_server_index(primary1_name, list1_size) == 0 + # slot 2 assert lb.get_server_index(primary2_name, list2_size) == 0 assert lb.get_server_index(primary2_name, list2_size) == 1 @@ -2384,6 +2453,29 @@ async def test_load_balancer(self, r: RedisCluster) -> None: assert lb.get_server_index(primary1_name, list1_size) == 0 assert lb.get_server_index(primary2_name, list2_size) == 0 + # reset the indexes before load balancing strategy test + lb.reset() + # load balancer strategy: LoadBalancerStrategy.ROUND_ROBIN_REPLICAS + for i in [1, 2, 1]: + srv_index = lb.get_server_index( + primary1_name, + list1_size, + load_balancing_strategy=LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + ) + assert srv_index == i + + # reset the indexes before load balancing strategy test + lb.reset() + # load balancer strategy: LoadBalancerStrategy.RANDOM_REPLICA + for i in range(5): + srv_index = lb.get_server_index( + primary1_name, + list1_size, + load_balancing_strategy=LoadBalancingStrategy.RANDOM_REPLICA, + ) + + assert srv_index > 0 and srv_index <= 2 + async def test_init_slots_cache_not_all_slots_covered(self) -> None: """ Test that if not all slots are covered it should raise an exception @@ -2887,6 +2979,37 @@ async def test_readonly_pipeline_from_readonly_client( break assert executed_on_replica + @pytest.mark.parametrize( + "load_balancing_strategy", + [ + LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + LoadBalancingStrategy.RANDOM_REPLICA, + ], + ) + async def test_readonly_pipeline_with_reading_from_replicas_strategies( + self, r: RedisCluster, load_balancing_strategy: LoadBalancingStrategy + ) -> None: + """ + Test that the pipeline uses replicas for different replica-based + load balancing strategies. 
+ """ + # Set the load balancing strategy + r.load_balancing_strategy = load_balancing_strategy + key = "bar" + await r.set(key, "foo") + + async with r.pipeline() as pipe: + mock_all_nodes_resp(r, "MOCK_OK") + assert await pipe.get(key).get(key).execute() == ["MOCK_OK", "MOCK_OK"] + slot_nodes = r.nodes_manager.slots_cache[r.keyslot(key)] + executed_on_replicas_only = True + for node in slot_nodes: + if node.server_type == PRIMARY: + if node._free.pop().read_response.await_count > 0: + executed_on_replicas_only = False + break + assert executed_on_replicas_only + async def test_can_run_concurrent_pipelines(self, r: RedisCluster) -> None: """Test that the pipeline can be used concurrently.""" await asyncio.gather( diff --git a/tests/test_cluster.py b/tests/test_cluster.py index e64db3690b..b71908d396 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -4,6 +4,7 @@ import socket import socketserver import threading +from typing import List import warnings from queue import LifoQueue, Queue from time import sleep @@ -19,6 +20,7 @@ REDIS_CLUSTER_HASH_SLOTS, REPLICA, ClusterNode, + LoadBalancingStrategy, NodesManager, RedisCluster, get_node_name, @@ -202,7 +204,18 @@ def cmd_init_mock(self, r): cmd_parser_initialize.side_effect = cmd_init_mock - return RedisCluster(*args, **kwargs) + # Create a subclass of RedisCluster that overrides __del__ + class MockedRedisCluster(RedisCluster): + def __del__(self): + # Override to prevent connection cleanup attempts + pass + + @property + def connection_pool(self): + # Required abstract property implementation + return self.nodes_manager.get_default_node().redis_connection.connection_pool + + return MockedRedisCluster(*args, **kwargs) def mock_node_resp(node, response): @@ -590,7 +603,24 @@ def cmd_init_mock(self, r): assert parse_response.failed_calls == 1 assert parse_response.successful_calls == 1 - def test_reading_from_replicas_in_round_robin(self): + @pytest.mark.parametrize( + "read_from_replicas,load_balancing_strategy,mocks_srv_ports", + [ + (True, None, [7001, 7002, 7001]), + (True, LoadBalancingStrategy.ROUND_ROBIN, [7001, 7002, 7001]), + (True, LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, [7002, 7002, 7002]), + (True, LoadBalancingStrategy.RANDOM_REPLICA, [7002, 7002, 7002]), + (False, LoadBalancingStrategy.ROUND_ROBIN, [7001, 7002, 7001]), + (False, LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, [7002, 7002, 7002]), + (False, LoadBalancingStrategy.RANDOM_REPLICA, [7002, 7002, 7002]), + ], + ) + def test_reading_with_load_balancing_strategies( + self, + read_from_replicas: bool, + load_balancing_strategy: LoadBalancingStrategy, + mocks_srv_ports: List[int], + ): with patch.multiple( Connection, send_command=DEFAULT, @@ -603,19 +633,19 @@ def test_reading_from_replicas_in_round_robin(self): def parse_response_mock_first(connection, *args, **options): # Primary - assert connection.port == 7001 + assert connection.port == mocks_srv_ports[0] parse_response.side_effect = parse_response_mock_second return "MOCK_OK" def parse_response_mock_second(connection, *args, **options): # Replica - assert connection.port == 7002 + assert connection.port == mocks_srv_ports[1] parse_response.side_effect = parse_response_mock_third return "MOCK_OK" def parse_response_mock_third(connection, *args, **options): # Primary - assert connection.port == 7001 + assert connection.port == mocks_srv_ports[2] return "MOCK_OK" # We don't need to create a real cluster connection but we @@ -630,9 +660,13 @@ def parse_response_mock_third(connection, *args, **options): # 
Create a cluster with reading from replications read_cluster = get_mocked_redis_client( - host=default_host, port=default_port, read_from_replicas=True + host=default_host, + port=default_port, + read_from_replicas=read_from_replicas, + load_balancing_strategy=load_balancing_strategy, ) - assert read_cluster.read_from_replicas is True + assert read_cluster.read_from_replicas is read_from_replicas + assert read_cluster.load_balancing_strategy is load_balancing_strategy # Check that we read from the slot's nodes in a round robin # matter. # 'foo' belongs to slot 12182 and the slot's nodes are: @@ -640,16 +674,27 @@ def parse_response_mock_third(connection, *args, **options): read_cluster.get("foo") read_cluster.get("foo") read_cluster.get("foo") - mocks["send_command"].assert_has_calls( + expected_calls_list = [] + expected_calls_list.append(call("READONLY")) + expected_calls_list.append(call("GET", "foo", keys=["foo"])) + + if ( + load_balancing_strategy is None + or load_balancing_strategy == LoadBalancingStrategy.ROUND_ROBIN + ): + # in the round robin strategy the primary node can also receive read + # requests and this means that there will be second node connected + expected_calls_list.append(call("READONLY")) + + expected_calls_list.extend( [ - call("READONLY"), - call("GET", "foo", keys=["foo"]), - call("READONLY"), call("GET", "foo", keys=["foo"]), call("GET", "foo", keys=["foo"]), ] ) + mocks["send_command"].assert_has_calls(expected_calls_list) + def test_keyslot(self, r): """ Test that method will compute correct key in all supported cases @@ -975,6 +1020,35 @@ def test_get_and_set(self, r): assert r.get("integer") == str(integer).encode() assert r.get("unicode_string").decode("utf-8") == unicode_string + @pytest.mark.parametrize( + "load_balancing_strategy", + [ + LoadBalancingStrategy.ROUND_ROBIN, + LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + LoadBalancingStrategy.RANDOM_REPLICA, + ], + ) + def test_get_and_set_with_load_balanced_client( + self, request, load_balancing_strategy: LoadBalancingStrategy + ) -> None: + r = _get_client( + cls=RedisCluster, + request=request, + load_balancing_strategy=load_balancing_strategy, + ) + + # get and set can't be tested independently of each other + assert r.get("a") is None + + byte_string = b"value" + assert r.set("byte_string", byte_string) + + # run the get command for the same key several times + # to iterate over the read nodes + assert r.get("byte_string") == byte_string + assert r.get("byte_string") == byte_string + assert r.get("byte_string") == byte_string + def test_mget_nonatomic(self, r): assert r.mget_nonatomic([]) == [] assert r.mget_nonatomic(["a", "b"]) == [None, None] @@ -2473,6 +2547,8 @@ def test_load_balancer(self, r): primary2_name = n_manager.slots_cache[slot_2][0].name list1_size = len(n_manager.slots_cache[slot_1]) list2_size = len(n_manager.slots_cache[slot_2]) + + # default load balancer strategy: LoadBalancerStrategy.ROUND_ROBIN # slot 1 assert lb.get_server_index(primary1_name, list1_size) == 0 assert lb.get_server_index(primary1_name, list1_size) == 1 @@ -2487,6 +2563,29 @@ def test_load_balancer(self, r): assert lb.get_server_index(primary1_name, list1_size) == 0 assert lb.get_server_index(primary2_name, list2_size) == 0 + # reset the indexes before load balancing strategy test + lb.reset() + # load balancer strategy: LoadBalancerStrategy.ROUND_ROBIN_REPLICAS + for i in [1, 2, 1]: + srv_index = lb.get_server_index( + primary1_name, + list1_size, + 
load_balancing_strategy=LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + ) + assert srv_index == i + + # reset the indexes before load balancing strategy test + lb.reset() # reset the indexes + # load balancer strategy: LoadBalancerStrategy.RANDOM_REPLICA + for i in range(5): + srv_index = lb.get_server_index( + primary1_name, + list1_size, + load_balancing_strategy=LoadBalancingStrategy.RANDOM_REPLICA, + ) + + assert srv_index > 0 and srv_index <= 2 + def test_init_slots_cache_not_all_slots_covered(self): """ Test that if not all slots are covered it should raise an exception @@ -3333,6 +3432,45 @@ def test_readonly_pipeline_from_readonly_client(self, request): break assert executed_on_replica is True + @pytest.mark.parametrize( + "load_balancing_strategy", + [ + LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, + LoadBalancingStrategy.RANDOM_REPLICA, + ], + ) + def test_readonly_pipeline_with_reading_from_replicas_strategies( + self, request, load_balancing_strategy: LoadBalancingStrategy + ) -> None: + """ + Test that the pipeline uses replicas for different replica-based + load balancing strategies. + """ + ro = _get_client( + RedisCluster, + request, + load_balancing_strategy=load_balancing_strategy, + ) + key = "bar" + ro.set(key, "foo") + import time + + time.sleep(0.2) + + with ro.pipeline() as readonly_pipe: + mock_all_nodes_resp(ro, "MOCK_OK") + assert readonly_pipe.load_balancing_strategy == load_balancing_strategy + assert readonly_pipe.get(key).get(key).execute() == ["MOCK_OK", "MOCK_OK"] + slot_nodes = ro.nodes_manager.slots_cache[ro.keyslot(key)] + executed_on_replicas_only = True + for node in slot_nodes: + if node.server_type == PRIMARY: + conn = node.redis_connection.connection + if conn.read_response.called: + executed_on_replicas_only = False + break + assert executed_on_replicas_only + @pytest.mark.onlycluster class TestClusterMonitor: diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 79301b93f1..549eeb49a2 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -23,7 +23,7 @@ class TestMultiprocessing: # The code in this module does not work with it, # hence the explicit change to 'fork' # See https://github.com/python/cpython/issues/125714 - if multiprocessing.get_start_method() == "forkserver": + if multiprocessing.get_start_method() in ["forkserver", "spawn"]: _mp_context = multiprocessing.get_context(method="fork") else: _mp_context = multiprocessing.get_context() @@ -119,7 +119,7 @@ def target(pool, parent_conn): assert child_conn in pool._available_connections assert parent_conn not in pool._available_connections - proc = multiprocessing.Process(target=target, args=(pool, parent_conn)) + proc = self._mp_context.Process(target=target, args=(pool, parent_conn)) proc.start() proc.join(3) assert proc.exitcode == 0 From 13e68afc3f360d7a93b5feb96c41a946fa6db591 Mon Sep 17 00:00:00 2001 From: Jim Cameron-Burn Date: Mon, 24 Mar 2025 05:31:04 +0000 Subject: [PATCH 072/113] Exponential with jitter backoff (#3550) --- redis/backoff.py | 15 +++++++++++++++ tests/test_backoff.py | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+) create mode 100644 tests/test_backoff.py diff --git a/redis/backoff.py b/redis/backoff.py index f612d60704..e236764d71 100644 --- a/redis/backoff.py +++ b/redis/backoff.py @@ -110,5 +110,20 @@ def compute(self, failures: int) -> float: return self._previous_backoff +class ExponentialWithJitterBackoff(AbstractBackoff): + """Exponential backoff upon failure, with jitter""" + + def __init__(self, 
cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None: + """ + `cap`: maximum backoff time in seconds + `base`: base backoff time in seconds + """ + self._cap = cap + self._base = base + + def compute(self, failures: int) -> float: + return min(self._cap, random.random() * self._base * 2**failures) + + def default_backoff(): return EqualJitterBackoff() diff --git a/tests/test_backoff.py b/tests/test_backoff.py new file mode 100644 index 0000000000..0a491276ff --- /dev/null +++ b/tests/test_backoff.py @@ -0,0 +1,18 @@ +from unittest.mock import Mock + +import pytest + +from redis.backoff import ExponentialWithJitterBackoff + + +def test_exponential_with_jitter_backoff(monkeypatch: pytest.MonkeyPatch) -> None: + mock_random = Mock(side_effect=[0.25, 0.5, 0.75, 1.0, 0.9]) + monkeypatch.setattr("random.random", mock_random) + + bo = ExponentialWithJitterBackoff(cap=5, base=1) + + assert bo.compute(0) == 0.25 # min(5, 0.25*2^0) + assert bo.compute(1) == 1.0 # min(5, 0.5*2^1) + assert bo.compute(2) == 3.0 # min(5, 0.75*2^2) + assert bo.compute(3) == 5.0 # min(5, 1*2^3) + assert bo.compute(4) == 5.0 # min(5, 0.9*2^4) From 09c6ff995b758523d6b1f1e7e9f65f1803c55c70 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Mon, 24 Mar 2025 14:28:21 +0200 Subject: [PATCH 073/113] Adding new hash commands with expiration options - HGETDEL, HGETEX, HSETEX (#3570) --- redis/commands/core.py | 272 ++++++++++++++++++++------- redis/utils.py | 43 ++++- tests/test_asyncio/test_commands.py | 13 +- tests/test_asyncio/test_hash.py | 276 ++++++++++++++++++++++++++++ tests/test_asyncio/test_utils.py | 8 + tests/test_commands.py | 14 +- tests/test_hash.py | 247 +++++++++++++++++++++++++ tests/test_utils.py | 7 + 8 files changed, 788 insertions(+), 92 deletions(-) create mode 100644 tests/test_asyncio/test_utils.py diff --git a/redis/commands/core.py b/redis/commands/core.py index df76eafed0..271f640dec 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -3,6 +3,7 @@ import datetime import hashlib import warnings +from enum import Enum from typing import ( TYPE_CHECKING, Any, @@ -44,6 +45,10 @@ TimeoutSecT, ZScoreBoundT, ) +from redis.utils import ( + deprecated_function, + extract_expire_flags, +) from .helpers import list_or_args @@ -1837,10 +1842,10 @@ def getdel(self, name: KeyT) -> ResponseT: def getex( self, name: KeyT, - ex: Union[ExpiryT, None] = None, - px: Union[ExpiryT, None] = None, - exat: Union[AbsExpiryT, None] = None, - pxat: Union[AbsExpiryT, None] = None, + ex: Optional[ExpiryT] = None, + px: Optional[ExpiryT] = None, + exat: Optional[AbsExpiryT] = None, + pxat: Optional[AbsExpiryT] = None, persist: bool = False, ) -> ResponseT: """ @@ -1863,7 +1868,6 @@ def getex( For more information see https://redis.io/commands/getex """ - opset = {ex, px, exat, pxat} if len(opset) > 2 or len(opset) > 1 and persist: raise DataError( @@ -1871,33 +1875,12 @@ def getex( "and ``persist`` are mutually exclusive." 
) - pieces: list[EncodableT] = [] - # similar to set command - if ex is not None: - pieces.append("EX") - if isinstance(ex, datetime.timedelta): - ex = int(ex.total_seconds()) - pieces.append(ex) - if px is not None: - pieces.append("PX") - if isinstance(px, datetime.timedelta): - px = int(px.total_seconds() * 1000) - pieces.append(px) - # similar to pexpireat command - if exat is not None: - pieces.append("EXAT") - if isinstance(exat, datetime.datetime): - exat = int(exat.timestamp()) - pieces.append(exat) - if pxat is not None: - pieces.append("PXAT") - if isinstance(pxat, datetime.datetime): - pxat = int(pxat.timestamp() * 1000) - pieces.append(pxat) + exp_options: list[EncodableT] = extract_expire_flags(ex, px, exat, pxat) + if persist: - pieces.append("PERSIST") + exp_options.append("PERSIST") - return self.execute_command("GETEX", name, *pieces) + return self.execute_command("GETEX", name, *exp_options) def __getitem__(self, name: KeyT): """ @@ -2255,14 +2238,14 @@ def set( self, name: KeyT, value: EncodableT, - ex: Union[ExpiryT, None] = None, - px: Union[ExpiryT, None] = None, + ex: Optional[ExpiryT] = None, + px: Optional[ExpiryT] = None, nx: bool = False, xx: bool = False, keepttl: bool = False, get: bool = False, - exat: Union[AbsExpiryT, None] = None, - pxat: Union[AbsExpiryT, None] = None, + exat: Optional[AbsExpiryT] = None, + pxat: Optional[AbsExpiryT] = None, ) -> ResponseT: """ Set the value at key ``name`` to ``value`` @@ -2292,36 +2275,21 @@ def set( For more information see https://redis.io/commands/set """ + opset = {ex, px, exat, pxat} + if len(opset) > 2 or len(opset) > 1 and keepttl: + raise DataError( + "``ex``, ``px``, ``exat``, ``pxat``, " + "and ``keepttl`` are mutually exclusive." + ) + + if nx and xx: + raise DataError("``nx`` and ``xx`` are mutually exclusive.") + pieces: list[EncodableT] = [name, value] options = {} - if ex is not None: - pieces.append("EX") - if isinstance(ex, datetime.timedelta): - pieces.append(int(ex.total_seconds())) - elif isinstance(ex, int): - pieces.append(ex) - elif isinstance(ex, str) and ex.isdigit(): - pieces.append(int(ex)) - else: - raise DataError("ex must be datetime.timedelta or int") - if px is not None: - pieces.append("PX") - if isinstance(px, datetime.timedelta): - pieces.append(int(px.total_seconds() * 1000)) - elif isinstance(px, int): - pieces.append(px) - else: - raise DataError("px must be datetime.timedelta or int") - if exat is not None: - pieces.append("EXAT") - if isinstance(exat, datetime.datetime): - exat = int(exat.timestamp()) - pieces.append(exat) - if pxat is not None: - pieces.append("PXAT") - if isinstance(pxat, datetime.datetime): - pxat = int(pxat.timestamp() * 1000) - pieces.append(pxat) + + pieces.extend(extract_expire_flags(ex, px, exat, pxat)) + if keepttl: pieces.append("KEEPTTL") @@ -4940,6 +4908,16 @@ def pfmerge(self, dest: KeyT, *sources: KeyT) -> ResponseT: AsyncHyperlogCommands = HyperlogCommands +class HashDataPersistOptions(Enum): + # set the value for each provided key to each + # provided value only if all do not already exist. + FNX = "FNX" + + # set the value for each provided key to each + # provided value only if all already exist. + FXX = "FXX" + + class HashCommands(CommandsProtocol): """ Redis commands for Hash data type. 
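
As a quick orientation before the hunks that add HGETDEL, HGETEX and HSETEX below, here is a minimal usage sketch. It is illustrative only and not part of the patch: the connection details are placeholders, the server must be new enough to support these commands (the tests added in this patch gate on server version 7.9.0+), and only the method signatures defined in this diff are used.

    import redis
    from redis.commands.core import HashDataPersistOptions

    r = redis.Redis(host="localhost", port=6379)  # placeholder connection details

    # Set two fields with a 10 second TTL, but only if none of them exist yet (FNX).
    r.hsetex(
        "session",
        mapping={"user": "alice", "token": "abc"},
        ex=10,
        data_persist_option=HashDataPersistOptions.FNX,
    )

    # Read the fields and refresh their TTL to 30 seconds in the same call.
    print(r.hgetex("session", "user", "token", ex=30))

    # Read a field and atomically delete it from the hash.
    print(r.hgetdel("session", "token"))
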
@@ -4980,6 +4958,80 @@ def hgetall(self, name: str) -> Union[Awaitable[dict], dict]: """ return self.execute_command("HGETALL", name, keys=[name]) + def hgetdel( + self, name: str, *keys: str + ) -> Union[ + Awaitable[Optional[List[Union[str, bytes]]]], Optional[List[Union[str, bytes]]] + ]: + """ + Return the value of ``key`` within the hash ``name`` and + delete the field in the hash. + This command is similar to HGET, except for the fact that it also deletes + the key on success from the hash with the provided ```name```. + + Available since Redis 8.0 + For more information see https://redis.io/commands/hgetdel + """ + if len(keys) == 0: + raise DataError("'hgetdel' should have at least one key provided") + + return self.execute_command("HGETDEL", name, "FIELDS", len(keys), *keys) + + def hgetex( + self, + name: KeyT, + *keys: str, + ex: Optional[ExpiryT] = None, + px: Optional[ExpiryT] = None, + exat: Optional[AbsExpiryT] = None, + pxat: Optional[AbsExpiryT] = None, + persist: bool = False, + ) -> Union[ + Awaitable[Optional[List[Union[str, bytes]]]], Optional[List[Union[str, bytes]]] + ]: + """ + Return the values of ``key`` and ``keys`` within the hash ``name`` + and optionally set their expiration. + + ``ex`` sets an expire flag on ``kyes`` for ``ex`` seconds. + + ``px`` sets an expire flag on ``keys`` for ``px`` milliseconds. + + ``exat`` sets an expire flag on ``keys`` for ``ex`` seconds, + specified in unix time. + + ``pxat`` sets an expire flag on ``keys`` for ``ex`` milliseconds, + specified in unix time. + + ``persist`` remove the time to live associated with the ``keys``. + + Available since Redis 8.0 + For more information see https://redis.io/commands/hgetex + """ + if not keys: + raise DataError("'hgetex' should have at least one key provided") + + opset = {ex, px, exat, pxat} + if len(opset) > 2 or len(opset) > 1 and persist: + raise DataError( + "``ex``, ``px``, ``exat``, ``pxat``, " + "and ``persist`` are mutually exclusive." + ) + + exp_options: list[EncodableT] = extract_expire_flags(ex, px, exat, pxat) + + if persist: + exp_options.append("PERSIST") + + return self.execute_command( + "HGETEX", + name, + *exp_options, + "FIELDS", + len(keys), + *keys, + ) + def hincrby( self, name: str, key: str, amount: int = 1 ) -> Union[Awaitable[int], int]: @@ -5034,8 +5086,10 @@ def hset( For more information see https://redis.io/commands/hset """ + if key is None and not mapping and not items: raise DataError("'hset' with no key value pairs") + pieces = [] if items: pieces.extend(items) @@ -5047,6 +5101,89 @@ def hset( return self.execute_command("HSET", name, *pieces) + def hsetex( + self, + name: str, + key: Optional[str] = None, + value: Optional[str] = None, + mapping: Optional[dict] = None, + items: Optional[list] = None, + ex: Optional[ExpiryT] = None, + px: Optional[ExpiryT] = None, + exat: Optional[AbsExpiryT] = None, + pxat: Optional[AbsExpiryT] = None, + data_persist_option: Optional[HashDataPersistOptions] = None, + keepttl: bool = False, + ) -> Union[Awaitable[int], int]: + """ + Set ``key`` to ``value`` within hash ``name`` + + ``mapping`` accepts a dict of key/value pairs that will be + added to hash ``name``. + + ``items`` accepts a list of key/value pairs that will be + added to hash ``name``. + + ``ex`` sets an expire flag on ``keys`` for ``ex`` seconds. + + ``px`` sets an expire flag on ``keys`` for ``px`` milliseconds. + + ``exat`` sets an expire flag on ``keys`` for ``ex`` seconds, + specified in unix time. 
+ + ``pxat`` sets an expire flag on ``keys`` for ``ex`` milliseconds, + specified in unix time. + + ``data_persist_option`` can be set to ``FNX`` or ``FXX`` to control the + behavior of the command. + ``FNX`` will set the value for each provided key to each + provided value only if all do not already exist. + ``FXX`` will set the value for each provided key to each + provided value only if all already exist. + + ``keepttl`` if True, retain the time to live associated with the keys. + + Returns the number of fields that were added. + + Available since Redis 8.0 + For more information see https://redis.io/commands/hsetex + """ + if key is None and not mapping and not items: + raise DataError("'hsetex' with no key value pairs") + + if items and len(items) % 2 != 0: + raise DataError( + "'hsetex' with odd number of items. " + "'items' must contain a list of key/value pairs." + ) + + opset = {ex, px, exat, pxat} + if len(opset) > 2 or len(opset) > 1 and keepttl: + raise DataError( + "``ex``, ``px``, ``exat``, ``pxat``, " + "and ``keepttl`` are mutually exclusive." + ) + + exp_options: list[EncodableT] = extract_expire_flags(ex, px, exat, pxat) + if data_persist_option: + exp_options.append(data_persist_option.value) + + if keepttl: + exp_options.append("KEEPTTL") + + pieces = [] + if items: + pieces.extend(items) + if key is not None: + pieces.extend((key, value)) + if mapping: + for pair in mapping.items(): + pieces.extend(pair) + + return self.execute_command( + "HSETEX", name, *exp_options, "FIELDS", int(len(pieces) / 2), *pieces + ) + def hsetnx(self, name: str, key: str, value: str) -> Union[Awaitable[bool], bool]: """ Set ``key`` to ``value`` within hash ``name`` if ``key`` does not @@ -5056,6 +5193,11 @@ def hsetnx(self, name: str, key: str, value: str) -> Union[Awaitable[bool], bool """ return self.execute_command("HSETNX", name, key, value) + @deprecated_function( + version="4.0.0", + reason="Use 'hset' instead.", + name="hmset", + ) def hmset(self, name: str, mapping: dict) -> Union[Awaitable[str], str]: """ Set key to value within hash ``name`` for each corresponding @@ -5063,12 +5205,6 @@ def hmset(self, name: str, mapping: dict) -> Union[Awaitable[str], str]: For more information see https://redis.io/commands/hmset """ - warnings.warn( - f"{self.__class__.__name__}.hmset() is deprecated. 
" - f"Use {self.__class__.__name__}.hset() instead.", - DeprecationWarning, - stacklevel=2, - ) if not mapping: raise DataError("'hmset' with 'mapping' of length 0") items = [] diff --git a/redis/utils.py b/redis/utils.py index 66465636a1..9d9b4a9580 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -1,7 +1,11 @@ +import datetime import logging from contextlib import contextmanager from functools import wraps -from typing import Any, Dict, Mapping, Union +from typing import Any, Dict, List, Mapping, Optional, Union + +from redis.exceptions import DataError +from redis.typing import AbsExpiryT, EncodableT, ExpiryT try: import hiredis # noqa @@ -257,3 +261,40 @@ def ensure_string(key): return key else: raise TypeError("Key must be either a string or bytes") + + +def extract_expire_flags( + ex: Optional[ExpiryT] = None, + px: Optional[ExpiryT] = None, + exat: Optional[AbsExpiryT] = None, + pxat: Optional[AbsExpiryT] = None, +) -> List[EncodableT]: + exp_options: list[EncodableT] = [] + if ex is not None: + exp_options.append("EX") + if isinstance(ex, datetime.timedelta): + exp_options.append(int(ex.total_seconds())) + elif isinstance(ex, int): + exp_options.append(ex) + elif isinstance(ex, str) and ex.isdigit(): + exp_options.append(int(ex)) + else: + raise DataError("ex must be datetime.timedelta or int") + elif px is not None: + exp_options.append("PX") + if isinstance(px, datetime.timedelta): + exp_options.append(int(px.total_seconds() * 1000)) + elif isinstance(px, int): + exp_options.append(px) + else: + raise DataError("px must be datetime.timedelta or int") + elif exat is not None: + if isinstance(exat, datetime.datetime): + exat = int(exat.timestamp()) + exp_options.extend(["EXAT", exat]) + elif pxat is not None: + if isinstance(pxat, datetime.datetime): + pxat = int(pxat.timestamp() * 1000) + exp_options.extend(["PXAT", pxat]) + + return exp_options diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 08bd5810f4..bfb6855a0f 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -31,6 +31,7 @@ skip_if_server_version_lt, skip_unless_arch_bits, ) +from tests.test_asyncio.test_utils import redis_server_time if sys.version_info >= (3, 11, 3): from asyncio import timeout as async_timeout @@ -77,12 +78,6 @@ async def slowlog(r: redis.Redis): await r.config_set("slowlog-max-len", old_max_legnth_value) -async def redis_server_time(client: redis.Redis): - seconds, milliseconds = await client.time() - timestamp = float(f"{seconds}.{milliseconds}") - return datetime.datetime.fromtimestamp(timestamp) - - async def get_stream_message(client: redis.Redis, stream: str, message_id: str): """Fetch a stream message and format it as a (message_id, fields) pair""" response = await client.xrange(stream, min=message_id, max=message_id) @@ -2328,12 +2323,8 @@ async def test_hmget(self, r: redis.Redis): assert await r.hmget("a", "a", "b", "c") == [b"1", b"2", b"3"] async def test_hmset(self, r: redis.Redis): - warning_message = ( - r"^Redis(?:Cluster)*\.hmset\(\) is deprecated\. 
" - r"Use Redis(?:Cluster)*\.hset\(\) instead\.$" - ) h = {b"a": b"1", b"b": b"2", b"c": b"3"} - with pytest.warns(DeprecationWarning, match=warning_message): + with pytest.warns(DeprecationWarning): assert await r.hmset("a", h) assert await r.hgetall("a") == h diff --git a/tests/test_asyncio/test_hash.py b/tests/test_asyncio/test_hash.py index 15e426673b..4fbc02c5fe 100644 --- a/tests/test_asyncio/test_hash.py +++ b/tests/test_asyncio/test_hash.py @@ -2,7 +2,12 @@ import math from datetime import datetime, timedelta +import pytest + +from redis import exceptions +from redis.commands.core import HashDataPersistOptions from tests.conftest import skip_if_server_version_lt +from tests.test_asyncio.test_utils import redis_server_time @skip_if_server_version_lt("7.3.240") @@ -299,3 +304,274 @@ async def test_pttl_multiple_fields_mixed_conditions(r): result = await r.hpttl("test:hash", "field1", "field2", "field3") assert 30 * 60000 - 10000 < result[0] <= 30 * 60000 assert result[1:] == [-1, -2] + + +@skip_if_server_version_lt("7.9.0") +async def test_hgetdel(r): + await r.delete("test:hash") + await r.hset("test:hash", "foo", "bar", mapping={"1": 1, "2": 2}) + assert await r.hgetdel("test:hash", "foo", "1") == [b"bar", b"1"] + assert await r.hget("test:hash", "foo") is None + assert await r.hget("test:hash", "1") is None + assert await r.hget("test:hash", "2") == b"2" + assert await r.hgetdel("test:hash", "foo", "1") == [None, None] + assert await r.hget("test:hash", "2") == b"2" + + with pytest.raises(exceptions.DataError): + await r.hgetdel("test:hash") + + +@skip_if_server_version_lt("7.9.0") +async def test_hgetex_no_expiration(r): + await r.delete("test:hash") + await r.hset( + "b", "foo", "bar", mapping={"1": 1, "2": 2, "3": "three", "4": b"four"} + ) + + assert await r.hgetex("b", "foo", "1", "4") == [b"bar", b"1", b"four"] + assert await r.hgetex("b", "foo") == [b"bar"] + assert await r.httl("b", "foo", "1", "4") == [-1, -1, -1] + + +@skip_if_server_version_lt("7.9.0") +async def test_hgetex_expiration_configs(r): + await r.delete("test:hash") + await r.hset( + "test:hash", "foo", "bar", mapping={"1": 1, "3": "three", "4": b"four"} + ) + + test_keys = ["foo", "1", "4"] + # test get with multiple fields with expiration set through 'ex' + assert await r.hgetex("test:hash", *test_keys, ex=10) == [ + b"bar", + b"1", + b"four", + ] + ttls = await r.httl("test:hash", *test_keys) + for ttl in ttls: + assert pytest.approx(ttl) == 10 + + # test get with multiple fields removing expiration settings with 'persist' + assert await r.hgetex("test:hash", *test_keys, persist=True) == [ + b"bar", + b"1", + b"four", + ] + assert await r.httl("test:hash", *test_keys) == [-1, -1, -1] + + # test get with multiple fields with expiration set through 'px' + assert await r.hgetex("test:hash", *test_keys, px=6000) == [ + b"bar", + b"1", + b"four", + ] + ttls = await r.httl("test:hash", *test_keys) + for ttl in ttls: + assert pytest.approx(ttl) == 6 + + # test get single field with expiration set through 'pxat' + expire_at = await redis_server_time(r) + timedelta(minutes=1) + assert await r.hgetex("test:hash", "foo", pxat=expire_at) == [b"bar"] + assert (await r.httl("test:hash", "foo"))[0] <= 61 + + # test get single field with expiration set through 'exat' + expire_at = await redis_server_time(r) + timedelta(seconds=10) + assert await r.hgetex("test:hash", "foo", exat=expire_at) == [b"bar"] + assert (await r.httl("test:hash", "foo"))[0] <= 10 + + +@skip_if_server_version_lt("7.9.0") +async def 
test_hgetex_validate_expired_fields_removed(r): + await r.delete("test:hash") + await r.hset( + "test:hash", "foo", "bar", mapping={"1": 1, "3": "three", "4": b"four"} + ) + + # test get multiple fields with expiration set + # validate that expired fields are removed + assert await r.hgetex("test:hash", "foo", "1", "3", ex=1) == [ + b"bar", + b"1", + b"three", + ] + await asyncio.sleep(1.1) + assert await r.hgetex("test:hash", "foo", "1", "3") == [None, None, None] + assert await r.httl("test:hash", "foo", "1", "3") == [-2, -2, -2] + assert await r.hgetex("test:hash", "4") == [b"four"] + + +@skip_if_server_version_lt("7.9.0") +async def test_hgetex_invalid_inputs(r): + with pytest.raises(exceptions.DataError): + await r.hgetex("b", "foo", ex=10, persist=True) + + with pytest.raises(exceptions.DataError): + await r.hgetex("b", "foo", ex=10.0, persist=True) + + with pytest.raises(exceptions.DataError): + await r.hgetex("b", "foo", ex=10, px=6000) + + with pytest.raises(exceptions.DataError): + await r.hgetex("b", ex=10) + + +@skip_if_server_version_lt("7.9.0") +async def test_hsetex_no_expiration(r): + await r.delete("test:hash") + + # # set items from mapping without expiration + assert await r.hsetex("test:hash", None, None, mapping={"1": 1, "4": b"four"}) == 1 + assert await r.httl("test:hash", "foo", "1", "4") == [-2, -1, -1] + assert await r.hgetex("test:hash", "foo", "1") == [None, b"1"] + + +@skip_if_server_version_lt("7.9.0") +async def test_hsetex_expiration_ex_and_keepttl(r): + await r.delete("test:hash") + + # set items from key/value provided + # combined with mapping and items with expiration - testing ex field + assert ( + await r.hsetex( + "test:hash", + "foo", + "bar", + mapping={"1": 1, "2": "2"}, + items=["i1", 11, "i2", 22], + ex=10, + ) + == 1 + ) + test_keys = ["foo", "1", "2", "i1", "i2"] + ttls = await r.httl("test:hash", *test_keys) + for ttl in ttls: + assert pytest.approx(ttl) == 10 + + assert await r.hgetex("test:hash", *test_keys) == [ + b"bar", + b"1", + b"2", + b"11", + b"22", + ] + await asyncio.sleep(1.1) + # validate keepttl + assert await r.hsetex("test:hash", "foo", "bar1", keepttl=True) == 1 + assert 0 < (await r.httl("test:hash", "foo"))[0] < 10 + + +@skip_if_server_version_lt("7.9.0") +async def test_hsetex_expiration_px(r): + await r.delete("test:hash") + # set items from key/value provided and mapping + # with expiration - testing px field + assert ( + await r.hsetex("test:hash", "foo", "bar", mapping={"1": 1, "2": "2"}, px=60000) + == 1 + ) + test_keys = ["foo", "1", "2"] + ttls = await r.httl("test:hash", *test_keys) + for ttl in ttls: + assert pytest.approx(ttl) == 60 + + assert await r.hgetex("test:hash", *test_keys) == [b"bar", b"1", b"2"] + + +@skip_if_server_version_lt("7.9.0") +async def test_hsetex_expiration_pxat_and_fnx(r): + await r.delete("test:hash") + assert ( + await r.hsetex("test:hash", "foo", "bar", mapping={"1": 1, "2": "2"}, ex=30) + == 1 + ) + + expire_at = await redis_server_time(r) + timedelta(minutes=1) + assert ( + await r.hsetex( + "test:hash", + "foo", + "bar1", + mapping={"new": "ok"}, + pxat=expire_at, + data_persist_option=HashDataPersistOptions.FNX, + ) + == 0 + ) + ttls = await r.httl("test:hash", "foo", "new") + assert ttls[0] <= 30 + assert ttls[1] == -2 + + assert await r.hgetex("test:hash", "foo", "1", "new") == [b"bar", b"1", None] + assert ( + await r.hsetex( + "test:hash", + "foo_new", + "bar1", + mapping={"new": "ok"}, + pxat=expire_at, + data_persist_option=HashDataPersistOptions.FNX, + ) + == 1 + ) + ttls = 
await r.httl("test:hash", "foo", "new") + for ttl in ttls: + assert ttl <= 61 + assert await r.hgetex("test:hash", "foo", "foo_new", "new") == [ + b"bar", + b"bar1", + b"ok", + ] + + +@skip_if_server_version_lt("7.9.0") +async def test_hsetex_expiration_exat_and_fxx(r): + await r.delete("test:hash") + assert ( + await r.hsetex("test:hash", "foo", "bar", mapping={"1": 1, "2": "2"}, ex=30) + == 1 + ) + + expire_at = await redis_server_time(r) + timedelta(seconds=10) + assert ( + await r.hsetex( + "test:hash", + "foo", + "bar1", + mapping={"new": "ok"}, + exat=expire_at, + data_persist_option=HashDataPersistOptions.FXX, + ) + == 0 + ) + ttls = await r.httl("test:hash", "foo", "new") + assert 10 < ttls[0] <= 30 + assert ttls[1] == -2 + + assert await r.hgetex("test:hash", "foo", "1", "new") == [b"bar", b"1", None] + assert ( + await r.hsetex( + "test:hash", + "foo", + "bar1", + mapping={"1": "new_value"}, + exat=expire_at, + data_persist_option=HashDataPersistOptions.FXX, + ) + == 1 + ) + assert await r.hgetex("test:hash", "foo", "1") == [b"bar1", b"new_value"] + + +@skip_if_server_version_lt("7.9.0") +async def test_hsetex_invalid_inputs(r): + with pytest.raises(exceptions.DataError): + await r.hsetex("b", "foo", "bar", ex=10.0) + + with pytest.raises(exceptions.DataError): + await r.hsetex("b", None, None) + + with pytest.raises(exceptions.DataError): + await r.hsetex("b", "foo", "bar", items=["i1", 11, "i2"], px=6000) + + with pytest.raises(exceptions.DataError): + await r.hsetex("b", "foo", "bar", ex=10, keepttl=True) diff --git a/tests/test_asyncio/test_utils.py b/tests/test_asyncio/test_utils.py new file mode 100644 index 0000000000..05cad1bfaf --- /dev/null +++ b/tests/test_asyncio/test_utils.py @@ -0,0 +1,8 @@ +from datetime import datetime +import redis + + +async def redis_server_time(client: redis.Redis): + seconds, milliseconds = await client.time() + timestamp = float(f"{seconds}.{milliseconds}") + return datetime.fromtimestamp(timestamp) diff --git a/tests/test_commands.py b/tests/test_commands.py index 5c72a019ba..8758efa771 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -21,6 +21,7 @@ from redis.commands.json.path import Path from redis.commands.search.field import TextField from redis.commands.search.query import Query +from tests.test_utils import redis_server_time from .conftest import ( _get_client, @@ -50,12 +51,6 @@ def cleanup(): r.config_set("slowlog-max-len", 128) -def redis_server_time(client): - seconds, milliseconds = client.time() - timestamp = float(f"{seconds}.{milliseconds}") - return datetime.datetime.fromtimestamp(timestamp) - - def get_stream_message(client, stream, message_id): "Fetch a stream message and format it as a (message_id, fields) pair" response = client.xrange(stream, min=message_id, max=message_id) @@ -3393,13 +3388,8 @@ def test_hmget(self, r): assert r.hmget("a", "a", "b", "c") == [b"1", b"2", b"3"] def test_hmset(self, r): - redis_class = type(r).__name__ - warning_message = ( - r"^{0}\.hmset\(\) is deprecated\. 
" - r"Use {0}\.hset\(\) instead\.$".format(redis_class) - ) h = {b"a": b"1", b"b": b"2", b"c": b"3"} - with pytest.warns(DeprecationWarning, match=warning_message): + with pytest.warns(DeprecationWarning): assert r.hmset("a", h) assert r.hgetall("a") == h diff --git a/tests/test_hash.py b/tests/test_hash.py index 0422185865..c2a92fb852 100644 --- a/tests/test_hash.py +++ b/tests/test_hash.py @@ -3,7 +3,10 @@ from datetime import datetime, timedelta import pytest +from redis import exceptions +from redis.commands.core import HashDataPersistOptions from tests.conftest import skip_if_server_version_lt +from tests.test_utils import redis_server_time @skip_if_server_version_lt("7.3.240") @@ -368,3 +371,247 @@ def test_hpttl_multiple_fields_mixed_conditions(r): def test_hpttl_nonexistent_key(r): r.delete("test:hash") assert r.hpttl("test:hash", "field1", "field2", "field3") == [-2, -2, -2] + + +@skip_if_server_version_lt("7.9.0") +def test_hgetdel(r): + r.delete("test:hash") + r.hset("test:hash", "foo", "bar", mapping={"1": 1, "2": 2}) + assert r.hgetdel("test:hash", "foo", "1") == [b"bar", b"1"] + assert r.hget("test:hash", "foo") is None + assert r.hget("test:hash", "1") is None + assert r.hget("test:hash", "2") == b"2" + assert r.hgetdel("test:hash", "foo", "1") == [None, None] + assert r.hget("test:hash", "2") == b"2" + + with pytest.raises(exceptions.DataError): + r.hgetdel("test:hash") + + +@skip_if_server_version_lt("7.9.0") +def test_hgetex_no_expiration(r): + r.delete("test:hash") + r.hset("b", "foo", "bar", mapping={"1": 1, "2": 2, "3": "three", "4": b"four"}) + + assert r.hgetex("b", "foo", "1", "4") == [b"bar", b"1", b"four"] + assert r.httl("b", "foo", "1", "4") == [-1, -1, -1] + + +@skip_if_server_version_lt("7.9.0") +def test_hgetex_expiration_configs(r): + r.delete("test:hash") + r.hset("test:hash", "foo", "bar", mapping={"1": 1, "3": "three", "4": b"four"}) + test_keys = ["foo", "1", "4"] + + # test get with multiple fields with expiration set through 'ex' + assert r.hgetex("test:hash", *test_keys, ex=10) == [b"bar", b"1", b"four"] + ttls = r.httl("test:hash", *test_keys) + for ttl in ttls: + assert pytest.approx(ttl) == 10 + + # test get with multiple fields removing expiration settings with 'persist' + assert r.hgetex("test:hash", *test_keys, persist=True) == [ + b"bar", + b"1", + b"four", + ] + assert r.httl("test:hash", *test_keys) == [-1, -1, -1] + + # test get with multiple fields with expiration set through 'px' + assert r.hgetex("test:hash", *test_keys, px=6000) == [b"bar", b"1", b"four"] + ttls = r.httl("test:hash", *test_keys) + for ttl in ttls: + assert pytest.approx(ttl) == 6 + + # test get single field with expiration set through 'pxat' + expire_at = redis_server_time(r) + timedelta(minutes=1) + assert r.hgetex("test:hash", "foo", pxat=expire_at) == [b"bar"] + assert r.httl("test:hash", "foo")[0] <= 61 + + # test get single field with expiration set through 'exat' + expire_at = redis_server_time(r) + timedelta(seconds=10) + assert r.hgetex("test:hash", "foo", exat=expire_at) == [b"bar"] + assert r.httl("test:hash", "foo")[0] <= 10 + + +@skip_if_server_version_lt("7.9.0") +def test_hgetex_validate_expired_fields_removed(r): + r.delete("test:hash") + r.hset("test:hash", "foo", "bar", mapping={"1": 1, "3": "three", "4": b"four"}) + + test_keys = ["foo", "1", "3"] + # test get multiple fields with expiration set + # validate that expired fields are removed + assert r.hgetex("test:hash", *test_keys, ex=1) == [b"bar", b"1", b"three"] + time.sleep(1.1) + assert 
r.hgetex("test:hash", *test_keys) == [None, None, None] + assert r.httl("test:hash", *test_keys) == [-2, -2, -2] + assert r.hgetex("test:hash", "4") == [b"four"] + + +@skip_if_server_version_lt("7.9.0") +def test_hgetex_invalid_inputs(r): + with pytest.raises(exceptions.DataError): + r.hgetex("b", "foo", "1", "3", ex=10, persist=True) + + with pytest.raises(exceptions.DataError): + r.hgetex("b", "foo", ex=10.0, persist=True) + + with pytest.raises(exceptions.DataError): + r.hgetex("b", "foo", ex=10, px=6000) + + with pytest.raises(exceptions.DataError): + r.hgetex("b", ex=10) + + +@skip_if_server_version_lt("7.9.0") +def test_hsetex_no_expiration(r): + r.delete("test:hash") + + # # set items from mapping without expiration + assert r.hsetex("test:hash", None, None, mapping={"1": 1, "4": b"four"}) == 1 + assert r.httl("test:hash", "foo", "1", "4") == [-2, -1, -1] + assert r.hgetex("test:hash", "foo", "1") == [None, b"1"] + + +@skip_if_server_version_lt("7.9.0") +def test_hsetex_expiration_ex_and_keepttl(r): + r.delete("test:hash") + + # set items from key/value provided + # combined with mapping and items with expiration - testing ex field + assert ( + r.hsetex( + "test:hash", + "foo", + "bar", + mapping={"1": 1, "2": "2"}, + items=["i1", 11, "i2", 22], + ex=10, + ) + == 1 + ) + ttls = r.httl("test:hash", "foo", "1", "2", "i1", "i2") + for ttl in ttls: + assert pytest.approx(ttl) == 10 + + assert r.hgetex("test:hash", "foo", "1", "2", "i1", "i2") == [ + b"bar", + b"1", + b"2", + b"11", + b"22", + ] + time.sleep(1.1) + # validate keepttl + assert r.hsetex("test:hash", "foo", "bar1", keepttl=True) == 1 + assert r.httl("test:hash", "foo")[0] < 10 + + +@skip_if_server_version_lt("7.9.0") +def test_hsetex_expiration_px(r): + r.delete("test:hash") + # set items from key/value provided and mapping + # with expiration - testing px field + assert ( + r.hsetex("test:hash", "foo", "bar", mapping={"1": 1, "2": "2"}, px=60000) == 1 + ) + test_keys = ["foo", "1", "2"] + ttls = r.httl("test:hash", *test_keys) + for ttl in ttls: + assert pytest.approx(ttl) == 60 + assert r.hgetex("test:hash", *test_keys) == [b"bar", b"1", b"2"] + + +@skip_if_server_version_lt("7.9.0") +def test_hsetex_expiration_pxat_and_fnx(r): + r.delete("test:hash") + assert r.hsetex("test:hash", "foo", "bar", mapping={"1": 1, "2": "2"}, ex=30) == 1 + + expire_at = redis_server_time(r) + timedelta(minutes=1) + assert ( + r.hsetex( + "test:hash", + "foo", + "bar1", + mapping={"new": "ok"}, + pxat=expire_at, + data_persist_option=HashDataPersistOptions.FNX, + ) + == 0 + ) + ttls = r.httl("test:hash", "foo", "new") + assert ttls[0] <= 30 + assert ttls[1] == -2 + + assert r.hgetex("test:hash", "foo", "1", "new") == [b"bar", b"1", None] + assert ( + r.hsetex( + "test:hash", + "foo_new", + "bar1", + mapping={"new": "ok"}, + pxat=expire_at, + data_persist_option=HashDataPersistOptions.FNX, + ) + == 1 + ) + ttls = r.httl("test:hash", "foo", "new") + for ttl in ttls: + assert ttl <= 61 + assert r.hgetex("test:hash", "foo", "foo_new", "new") == [ + b"bar", + b"bar1", + b"ok", + ] + + +@skip_if_server_version_lt("7.9.0") +def test_hsetex_expiration_exat_and_fxx(r): + r.delete("test:hash") + assert r.hsetex("test:hash", "foo", "bar", mapping={"1": 1, "2": "2"}, ex=30) == 1 + + expire_at = redis_server_time(r) + timedelta(seconds=10) + assert ( + r.hsetex( + "test:hash", + "foo", + "bar1", + mapping={"new": "ok"}, + exat=expire_at, + data_persist_option=HashDataPersistOptions.FXX, + ) + == 0 + ) + ttls = r.httl("test:hash", "foo", "new") + assert 10 
< ttls[0] <= 30 + assert ttls[1] == -2 + + assert r.hgetex("test:hash", "foo", "1", "new") == [b"bar", b"1", None] + assert ( + r.hsetex( + "test:hash", + "foo", + "bar1", + mapping={"1": "new_value"}, + exat=expire_at, + data_persist_option=HashDataPersistOptions.FXX, + ) + == 1 + ) + assert r.hgetex("test:hash", "foo", "1") == [b"bar1", b"new_value"] + + +@skip_if_server_version_lt("7.9.0") +def test_hsetex_invalid_inputs(r): + with pytest.raises(exceptions.DataError): + r.hsetex("b", "foo", "bar", ex=10.0) + + with pytest.raises(exceptions.DataError): + r.hsetex("b", None, None) + + with pytest.raises(exceptions.DataError): + r.hsetex("b", "foo", "bar", items=["i1", 11, "i2"], px=6000) + + with pytest.raises(exceptions.DataError): + r.hsetex("b", "foo", "bar", ex=10, keepttl=True) diff --git a/tests/test_utils.py b/tests/test_utils.py index 764ef5d0a9..75de8dbb9f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,4 @@ +from datetime import datetime import pytest from redis.utils import compare_versions @@ -25,3 +26,9 @@ ) def test_compare_versions(version1, version2, expected_res): assert compare_versions(version1, version2) == expected_res + + +def redis_server_time(client): + seconds, milliseconds = client.time() + timestamp = float(f"{seconds}.{milliseconds}") + return datetime.fromtimestamp(timestamp) From a9d02605c3c862bdeb6c6f6f7f302ae7030d33b5 Mon Sep 17 00:00:00 2001 From: Rohan Singh Date: Tue, 25 Mar 2025 17:16:27 +0100 Subject: [PATCH 074/113] Truncate pipeline exception message to a sane size (#3530) Fixes #20234. --- redis/asyncio/client.py | 6 +++++- redis/asyncio/cluster.py | 6 ++++-- redis/client.py | 4 +++- redis/cluster.py | 4 +++- redis/utils.py | 7 +++++++ tests/test_asyncio/test_cluster.py | 19 +++++++++++++++++++ tests/test_asyncio/test_pipeline.py | 16 ++++++++++++++++ tests/test_cluster.py | 19 +++++++++++++++++++ tests/test_pipeline.py | 16 ++++++++++++++++ 9 files changed, 92 insertions(+), 5 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 0039cea540..a35a5f1f8c 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -77,6 +77,7 @@ get_lib_version, safe_str, str_if_bytes, + truncate_text, ) if TYPE_CHECKING and SSL_AVAILABLE: @@ -1513,7 +1514,10 @@ def annotate_exception( self, exception: Exception, number: int, command: Iterable[object] ) -> None: cmd = " ".join(map(safe_str, command)) - msg = f"Command # {number} ({cmd}) of pipeline caused error: {exception.args}" + msg = ( + f"Command # {number} ({truncate_text(cmd)}) " + "of pipeline caused error: {exception.args}" + ) exception.args = (msg,) + exception.args[1:] async def parse_response( diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 7a29550a35..e679a377d7 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -71,6 +71,7 @@ get_lib_version, safe_str, str_if_bytes, + truncate_text, ) if SSL_AVAILABLE: @@ -1648,8 +1649,9 @@ async def _execute( if isinstance(result, Exception): command = " ".join(map(safe_str, cmd.args)) msg = ( - f"Command # {cmd.position + 1} ({command}) of pipeline " - f"caused error: {result.args}" + f"Command # {cmd.position + 1} " + f"({truncate_text(command)}) " + f"of pipeline caused error: {result.args}" ) result.args = (msg,) + result.args[1:] raise result diff --git a/redis/client.py b/redis/client.py index 2c4a1fadff..9fb89ec5cd 100755 --- a/redis/client.py +++ b/redis/client.py @@ -61,6 +61,7 @@ get_lib_version, safe_str, str_if_bytes, + truncate_text, ) if TYPE_CHECKING: 
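
The hunks below rely on a small truncate_text helper added to redis/utils.py later in this patch. A standalone sketch of its behaviour, for illustration only: textwrap.shorten collapses runs of whitespace and appends the placeholder once the text exceeds the requested width.

    import textwrap

    def truncate_text(txt, max_length=100):
        # Same shape as the helper introduced in this patch: cap the rendered
        # command text at max_length characters and mark the cut with "...".
        return textwrap.shorten(
            text=txt, width=max_length, placeholder="...", break_long_words=True
        )

    long_cmd = "HSET key " + " ".join(f"f{i} v{i}" for i in range(50))
    print(truncate_text(long_cmd))  # at most 100 characters, ending with "..."
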
@@ -1524,7 +1525,8 @@ def raise_first_error(self, commands, response): def annotate_exception(self, exception, number, command): cmd = " ".join(map(safe_str, command)) msg = ( - f"Command # {number} ({cmd}) of pipeline caused error: {exception.args[0]}" + f"Command # {number} ({truncate_text(cmd)}) of pipeline " + f"caused error: {exception.args[0]}" ) exception.args = (msg,) + exception.args[1:] diff --git a/redis/cluster.py b/redis/cluster.py index 0488608a60..4ec03ac98f 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -47,6 +47,7 @@ merge_result, safe_str, str_if_bytes, + truncate_text, ) @@ -2125,7 +2126,8 @@ def annotate_exception(self, exception, number, command): """ cmd = " ".join(map(safe_str, command)) msg = ( - f"Command # {number} ({cmd}) of pipeline caused error: {exception.args[0]}" + f"Command # {number} ({truncate_text(cmd)}) of pipeline " + f"caused error: {exception.args[0]}" ) exception.args = (msg,) + exception.args[1:] diff --git a/redis/utils.py b/redis/utils.py index 9d9b4a9580..1f0b24d768 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -1,5 +1,6 @@ import datetime import logging +import textwrap from contextlib import contextmanager from functools import wraps from typing import Any, Dict, List, Mapping, Optional, Union @@ -298,3 +299,9 @@ def extract_expire_flags( exp_options.extend(["PXAT", pxat]) return exp_options + + +def truncate_text(txt, max_length=100): + return textwrap.shorten( + text=txt, width=max_length, placeholder="...", break_long_words=True + ) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index a4f0636299..5a52da3d80 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -2926,6 +2926,25 @@ async def test_asking_error(self, r: RedisCluster) -> None: assert ask_node._free.pop().read_response.await_count assert res == ["MOCK_OK"] + async def test_error_is_truncated(self, r) -> None: + """ + Test that an error from the pipeline is truncated correctly. + """ + key = "a" * 50 + a_value = "a" * 20 + b_value = "b" * 20 + + async with r.pipeline() as pipe: + pipe.set(key, 1) + pipe.hset(key, mapping={"field_a": a_value, "field_b": b_value}) + pipe.expire(key, 100) + + with pytest.raises(Exception) as ex: + await pipe.execute() + + expected = f"Command # 2 (HSET {key} field_a {a_value} field_b...) of pipeline caused error: " + assert str(ex.value).startswith(expected) + async def test_moved_redirection_on_slave_with_default( self, r: RedisCluster ) -> None: diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py index 31759d84a3..19e11dc792 100644 --- a/tests/test_asyncio/test_pipeline.py +++ b/tests/test_asyncio/test_pipeline.py @@ -368,6 +368,22 @@ async def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r): assert await r.get(key) == b"1" + async def test_exec_error_in_pipeline_truncated(self, r): + key = "a" * 50 + a_value = "a" * 20 + b_value = "b" * 20 + + await r.set(key, 1) + async with r.pipeline(transaction=False) as pipe: + pipe.hset(key, mapping={"field_a": a_value, "field_b": b_value}) + pipe.expire(key, 100) + + with pytest.raises(redis.ResponseError) as ex: + await pipe.execute() + + expected = f"Command # 1 (HSET {key} field_a {a_value} field_b...) 
of pipeline caused error: " + assert str(ex.value).startswith(expected) + async def test_pipeline_with_bitfield(self, r): async with r.pipeline() as pipe: pipe.set("a", "1") diff --git a/tests/test_cluster.py b/tests/test_cluster.py index b71908d396..d96342f87a 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -3315,6 +3315,25 @@ def raise_ask_error(): assert ask_node.redis_connection.connection.read_response.called assert res == ["MOCK_OK"] + def test_error_is_truncated(self, r): + """ + Test that an error from the pipeline is truncated correctly. + """ + key = "a" * 50 + a_value = "a" * 20 + b_value = "b" * 20 + + with r.pipeline() as pipe: + pipe.set(key, 1) + pipe.hset(key, mapping={"field_a": a_value, "field_b": b_value}) + pipe.expire(key, 100) + + with pytest.raises(Exception) as ex: + pipe.execute() + + expected = f"Command # 2 (HSET {key} field_a {a_value} field_b...) of pipeline caused error: " + assert str(ex.value).startswith(expected) + def test_return_previously_acquired_connections(self, r): # in order to ensure that a pipeline will make use of connections # from different nodes diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index be7784ad0b..bbf1ec9eb5 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -369,6 +369,22 @@ def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r): assert r[key] == b"1" + def test_exec_error_in_pipeline_truncated(self, r): + key = "a" * 50 + a_value = "a" * 20 + b_value = "b" * 20 + + r[key] = 1 + with r.pipeline(transaction=False) as pipe: + pipe.hset(key, mapping={"field_a": a_value, "field_b": b_value}) + pipe.expire(key, 100) + + with pytest.raises(redis.ResponseError) as ex: + pipe.execute() + + expected = f"Command # 1 (HSET {key} field_a {a_value} field_b...) 
of pipeline caused error: " + assert str(ex.value).startswith(expected) + def test_pipeline_with_bitfield(self, r): with r.pipeline() as pipe: pipe.set("a", "1") From 8df6a6c00192067b275e7fe3eedb02f227ff35e4 Mon Sep 17 00:00:00 2001 From: Logan Attwood Date: Tue, 25 Mar 2025 14:10:03 -0300 Subject: [PATCH 075/113] Support using ssl.VerifyMode enum for ssl_cert_reqs (#3346) --- redis/asyncio/client.py | 5 +++-- redis/asyncio/cluster.py | 5 +++-- redis/asyncio/connection.py | 9 +++++---- redis/client.py | 2 +- redis/connection.py | 2 +- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index a35a5f1f8c..3f35fdd59e 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -81,9 +81,10 @@ ) if TYPE_CHECKING and SSL_AVAILABLE: - from ssl import TLSVersion + from ssl import TLSVersion, VerifyMode else: TLSVersion = None + VerifyMode = None PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] _KeyT = TypeVar("_KeyT", bound=KeyT) @@ -228,7 +229,7 @@ def __init__( ssl: bool = False, ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, - ssl_cert_reqs: str = "required", + ssl_cert_reqs: Union[str, VerifyMode] = "required", ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, ssl_check_hostname: bool = False, diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index e679a377d7..f58ae50a40 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -75,9 +75,10 @@ ) if SSL_AVAILABLE: - from ssl import TLSVersion + from ssl import TLSVersion, VerifyMode else: TLSVersion = None + VerifyMode = None TargetNodesT = TypeVar( "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] @@ -277,7 +278,7 @@ def __init__( ssl: bool = False, ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, - ssl_cert_reqs: str = "required", + ssl_cert_reqs: Union[str, VerifyMode] = "required", ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 66dbd09b61..ddf58cb1c6 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -768,7 +768,7 @@ def __init__( self, ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, - ssl_cert_reqs: str = "required", + ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required", ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, ssl_check_hostname: bool = False, @@ -842,7 +842,7 @@ def __init__( self, keyfile: Optional[str] = None, certfile: Optional[str] = None, - cert_reqs: Optional[str] = None, + cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None, ca_certs: Optional[str] = None, ca_data: Optional[str] = None, check_hostname: bool = False, @@ -855,7 +855,7 @@ def __init__( self.keyfile = keyfile self.certfile = certfile if cert_reqs is None: - self.cert_reqs = ssl.CERT_NONE + cert_reqs = ssl.CERT_NONE elif isinstance(cert_reqs, str): CERT_REQS = { # noqa: N806 "none": ssl.CERT_NONE, @@ -866,7 +866,8 @@ def __init__( raise RedisError( f"Invalid SSL Certificate Requirements Flag: {cert_reqs}" ) - self.cert_reqs = CERT_REQS[cert_reqs] + cert_reqs = CERT_REQS[cert_reqs] + self.cert_reqs = cert_reqs self.ca_certs = ca_certs self.ca_data = ca_data self.check_hostname = check_hostname diff --git a/redis/client.py b/redis/client.py index 9fb89ec5cd..e9435d33ef 100755 --- a/redis/client.py +++ b/redis/client.py @@ -211,7 +211,7 @@ def 
__init__( ssl: bool = False, ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, - ssl_cert_reqs: str = "required", + ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required", ssl_ca_certs: Optional[str] = None, ssl_ca_path: Optional[str] = None, ssl_ca_data: Optional[str] = None, diff --git a/redis/connection.py b/redis/connection.py index f754a5165a..87aa986d17 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1017,7 +1017,7 @@ def __init__( Args: ssl_keyfile: Path to an ssl private key. Defaults to None. ssl_certfile: Path to an ssl certificate. Defaults to None. - ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required). Defaults to "required". + ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. Defaults to "required". ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None. ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates. ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to False. From 9f90baf992c6f35e79403177eb060d0cd71f089d Mon Sep 17 00:00:00 2001 From: Robin <167366979+allrob23@users.noreply.github.com> Date: Wed, 26 Mar 2025 06:09:55 -0300 Subject: [PATCH 076/113] Improvement: Use `shutdown()` Before `close()` in connection.py (#3567) --- redis/connection.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/redis/connection.py b/redis/connection.py index 87aa986d17..b323be058b 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -764,6 +764,10 @@ def _connect(self): except OSError as _: err = _ if sock is not None: + try: + sock.shutdown(socket.SHUT_RDWR) # ensure a clean close + except OSError: + pass sock.close() if err is not None: @@ -1179,6 +1183,10 @@ def _connect(self): sock.connect(self.path) except OSError: # Prevent ResourceWarnings for unclosed sockets. + try: + sock.shutdown(socket.SHUT_RDWR) # ensure a clean close + except OSError: + pass sock.close() raise sock.settimeout(self.socket_timeout) From 4e2da482d27e448faf7ffc79eda5fd0f1015427b Mon Sep 17 00:00:00 2001 From: Vladimir Chebotarev Date: Wed, 26 Mar 2025 14:34:22 +0300 Subject: [PATCH 077/113] Fixed infinitely recursive health checks. 
(#3557) --- redis/asyncio/connection.py | 39 +++++++++++++++++++++++++++++-------- redis/connection.py | 37 ++++++++++++++++++++++++++++------- 2 files changed, 61 insertions(+), 15 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index ddf58cb1c6..7404f3d6f8 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -293,6 +293,9 @@ def set_parser(self, parser_class: Type[BaseParser]) -> None: async def connect(self): """Connects to the Redis server if not already connected""" + await self.connect_check_health(check_health=True) + + async def connect_check_health(self, check_health: bool = True): if self.is_connected: return try: @@ -311,7 +314,7 @@ async def connect(self): try: if not self.redis_connect_func: # Use the default on_connect function - await self.on_connect() + await self.on_connect_check_health(check_health=check_health) else: # Use the passed function redis_connect_func ( @@ -350,6 +353,9 @@ def get_protocol(self): async def on_connect(self) -> None: """Initialize the connection, authenticate and select a database""" + await self.on_connect_check_health(check_health=True) + + async def on_connect_check_health(self, check_health: bool = True) -> None: self._parser.on_connect(self) parser = self._parser @@ -407,7 +413,7 @@ async def on_connect(self) -> None: # update cluster exception classes self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES self._parser.on_connect(self) - await self.send_command("HELLO", self.protocol) + await self.send_command("HELLO", self.protocol, check_health=check_health) response = await self.read_response() # if response.get(b"proto") != self.protocol and response.get( # "proto" @@ -416,18 +422,35 @@ async def on_connect(self) -> None: # if a client_name is given, set it if self.client_name: - await self.send_command("CLIENT", "SETNAME", self.client_name) + await self.send_command( + "CLIENT", + "SETNAME", + self.client_name, + check_health=check_health, + ) if str_if_bytes(await self.read_response()) != "OK": raise ConnectionError("Error setting client name") # set the library name and version, pipeline for lower startup latency if self.lib_name: - await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name) + await self.send_command( + "CLIENT", + "SETINFO", + "LIB-NAME", + self.lib_name, + check_health=check_health, + ) if self.lib_version: - await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version) + await self.send_command( + "CLIENT", + "SETINFO", + "LIB-VER", + self.lib_version, + check_health=check_health, + ) # if a database is specified, switch to it. 
Also pipeline this if self.db: - await self.send_command("SELECT", self.db) + await self.send_command("SELECT", self.db, check_health=check_health) # read responses from pipeline for _ in (sent for sent in (self.lib_name, self.lib_version) if sent): @@ -489,8 +512,8 @@ async def send_packed_command( self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True ) -> None: if not self.is_connected: - await self.connect() - elif check_health: + await self.connect_check_health(check_health=False) + if check_health: await self.check_health() try: diff --git a/redis/connection.py b/redis/connection.py index b323be058b..08e980e866 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -376,6 +376,9 @@ def set_parser(self, parser_class): def connect(self): "Connects to the Redis server if not already connected" + self.connect_check_health(check_health=True) + + def connect_check_health(self, check_health: bool = True): if self._sock: return try: @@ -391,7 +394,7 @@ def connect(self): try: if self.redis_connect_func is None: # Use the default on_connect function - self.on_connect() + self.on_connect_check_health(check_health=check_health) else: # Use the passed function redis_connect_func self.redis_connect_func(self) @@ -421,6 +424,9 @@ def _error_message(self, exception): return format_error_message(self._host_error(), exception) def on_connect(self): + self.on_connect_check_health(check_health=True) + + def on_connect_check_health(self, check_health: bool = True): "Initialize the connection, authenticate and select a database" self._parser.on_connect(self) parser = self._parser @@ -479,7 +485,7 @@ def on_connect(self): # update cluster exception classes self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES self._parser.on_connect(self) - self.send_command("HELLO", self.protocol) + self.send_command("HELLO", self.protocol, check_health=check_health) self.handshake_metadata = self.read_response() if ( self.handshake_metadata.get(b"proto") != self.protocol @@ -489,28 +495,45 @@ def on_connect(self): # if a client_name is given, set it if self.client_name: - self.send_command("CLIENT", "SETNAME", self.client_name) + self.send_command( + "CLIENT", + "SETNAME", + self.client_name, + check_health=check_health, + ) if str_if_bytes(self.read_response()) != "OK": raise ConnectionError("Error setting client name") try: # set the library name and version if self.lib_name: - self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name) + self.send_command( + "CLIENT", + "SETINFO", + "LIB-NAME", + self.lib_name, + check_health=check_health, + ) self.read_response() except ResponseError: pass try: if self.lib_version: - self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version) + self.send_command( + "CLIENT", + "SETINFO", + "LIB-VER", + self.lib_version, + check_health=check_health, + ) self.read_response() except ResponseError: pass # if a database is specified, switch to it if self.db: - self.send_command("SELECT", self.db) + self.send_command("SELECT", self.db, check_health=check_health) if str_if_bytes(self.read_response()) != "OK": raise ConnectionError("Invalid Database") @@ -552,7 +575,7 @@ def check_health(self): def send_packed_command(self, command, check_health=True): """Send an already packed command to the Redis server""" if not self._sock: - self.connect() + self.connect_check_health(check_health=False) # guard against health check recursion if check_health: self.check_health() From 4525c2d8a490a67935dd9b5364a7fea63d5a1b41 Mon Sep 17 00:00:00 2001 From: 
Paolo Date: Thu, 27 Mar 2025 12:20:13 +0100 Subject: [PATCH 078/113] Fix incorrect link to docs for fcall_ro command (#3576) --- redis/commands/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/commands/core.py b/redis/commands/core.py index 271f640dec..a8c327f08f 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -6556,7 +6556,7 @@ def fcall_ro( This is a read-only variant of the FCALL command that cannot execute commands that modify data. - For more information see https://redis.io/commands/fcal_ro + For more information see https://redis.io/commands/fcall_ro """ return self._fcall("FCALL_RO", function, numkeys, *keys_and_args) From 56e61f8aec6f3aed44507f4a1022f3be783be2bd Mon Sep 17 00:00:00 2001 From: Robert Shelton Date: Wed, 2 Apr 2025 09:03:02 -0400 Subject: [PATCH 079/113] Docs/raae 724/remove redis ventures (#3579) --- docs/examples/search_vector_similarity_examples.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/examples/search_vector_similarity_examples.ipynb b/docs/examples/search_vector_similarity_examples.ipynb index 809dbda4ea..af6d825129 100644 --- a/docs/examples/search_vector_similarity_examples.ipynb +++ b/docs/examples/search_vector_similarity_examples.ipynb @@ -638,7 +638,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Find more example apps, tutorials, and projects using Redis Vector Similarity Search [in this GitHub organization](https://github.com/RedisVentures)." + "Find more example apps, tutorials, and projects using Redis Vector Similarity Search check out the [Redis AI resources repo](https://github.com/redis-developer/redis-ai-resources/tree/main)." ] } ], From 71916ee25c66dcfc0d95a4e8d1eff7ed11b054e9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 3 Apr 2025 16:28:19 +0300 Subject: [PATCH 080/113] Bump rojopolis/spellcheck-github-actions from 0.47.0 to 0.48.0 (#3580) --- .github/workflows/spellcheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/spellcheck.yml b/.github/workflows/spellcheck.yml index beefa6164f..4d0fc338d6 100644 --- a/.github/workflows/spellcheck.yml +++ b/.github/workflows/spellcheck.yml @@ -8,7 +8,7 @@ jobs: - name: Checkout uses: actions/checkout@v4 - name: Check Spelling - uses: rojopolis/spellcheck-github-actions@0.47.0 + uses: rojopolis/spellcheck-github-actions@0.48.0 with: config_path: .github/spellcheck-settings.yml task_name: Markdown From 80936748700e3dfcd0607cc409f332b42954635b Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Fri, 4 Apr 2025 12:06:32 +0300 Subject: [PATCH 081/113] Run pipeline tests against latest 8.0 RC1 image. 
(#3585) --- .github/workflows/integration.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 514a88a796..f8aa5c8932 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -74,7 +74,7 @@ jobs: max-parallel: 15 fail-fast: false matrix: - redis-version: ['8.0-M05-pre', '${{ needs.redis_version.outputs.CURRENT }}', '7.2.7', '6.2.17'] + redis-version: ['8.0-RC1-pre', '${{ needs.redis_version.outputs.CURRENT }}', '7.2.7', '6.2.17'] python-version: ['3.8', '3.13'] parser-backend: ['plain'] event-loop: ['asyncio'] From e66e35c1402a882b86b109ad09353df0137570b6 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Fri, 4 Apr 2025 13:11:40 +0300 Subject: [PATCH 082/113] Adding info for sentinel handling failover when Redis client is acquired with master_for() method. (#3578) Co-authored-by: Elena Kolevska --- redis/asyncio/sentinel.py | 2 ++ redis/sentinel.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 0389539fcf..fae6875d82 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -326,6 +326,8 @@ def master_for( ): """ Returns a redis client instance for the ``service_name`` master. + Sentinel client will detect failover and reconnect Redis clients + automatically. A :py:class:`~redis.sentinel.SentinelConnectionPool` class is used to retrieve the master's address before establishing a new diff --git a/redis/sentinel.py b/redis/sentinel.py index 521ac24142..02aa244ede 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -349,6 +349,8 @@ def master_for( ): """ Returns a redis client instance for the ``service_name`` master. + Sentinel client will detect failover and reconnect Redis clients + automatically. A :py:class:`~redis.sentinel.SentinelConnectionPool` class is used to retrieve the master's address before establishing a new From 6c0747370cf55bd525fc5447e63a081207591eff Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Mon, 7 Apr 2025 15:09:29 +0300 Subject: [PATCH 083/113] Adding VectorSet commands support. 
(#3584) --- docker-compose.yml | 4 +- redis/commands/redismodules.py | 8 + redis/commands/vectorset/__init__.py | 46 ++ redis/commands/vectorset/commands.py | 367 ++++++++++++ redis/commands/vectorset/utils.py | 94 +++ tests/test_asyncio/test_vsets.py | 858 +++++++++++++++++++++++++++ tests/test_vsets.py | 856 ++++++++++++++++++++++++++ 7 files changed, 2231 insertions(+), 2 deletions(-) create mode 100644 redis/commands/vectorset/__init__.py create mode 100644 redis/commands/vectorset/commands.py create mode 100644 redis/commands/vectorset/utils.py create mode 100644 tests/test_asyncio/test_vsets.py create mode 100644 tests/test_vsets.py diff --git a/docker-compose.yml b/docker-compose.yml index 76a60398f3..75292bbd03 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,9 +1,9 @@ --- x-client-libs-stack-image: &client-libs-stack-image - image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-rs-7.4.0-v2}" + image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-8.0-M06-pre}" x-client-libs-image: &client-libs-image - image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-7.4.2}" + image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-8.0-M06-pre}" services: diff --git a/redis/commands/redismodules.py b/redis/commands/redismodules.py index 7ba40dd845..6e253b1597 100644 --- a/redis/commands/redismodules.py +++ b/redis/commands/redismodules.py @@ -72,6 +72,14 @@ def tdigest(self): tdigest = TDigestBloom(client=self) return tdigest + def vset(self): + """Access the VectorSet commands namespace.""" + + from .vectorset import VectorSet + + vset = VectorSet(client=self) + return vset + class AsyncRedisModuleCommands(RedisModuleCommands): def ft(self, index_name="idx"): diff --git a/redis/commands/vectorset/__init__.py b/redis/commands/vectorset/__init__.py new file mode 100644 index 0000000000..d78580a73b --- /dev/null +++ b/redis/commands/vectorset/__init__.py @@ -0,0 +1,46 @@ +import json + +from redis._parsers.helpers import pairs_to_dict +from redis.commands.vectorset.utils import ( + parse_vemb_result, + parse_vlinks_result, + parse_vsim_result, +) + +from ..helpers import get_protocol_version +from .commands import ( + VEMB_CMD, + VGETATTR_CMD, + VINFO_CMD, + VLINKS_CMD, + VSIM_CMD, + VectorSetCommands, +) + + +class VectorSet(VectorSetCommands): + def __init__(self, client, **kwargs): + """Create a new VectorSet client.""" + # Set the module commands' callbacks + self._MODULE_CALLBACKS = { + VEMB_CMD: parse_vemb_result, + VGETATTR_CMD: lambda r: r and json.loads(r) or None, + } + + self._RESP2_MODULE_CALLBACKS = { + VINFO_CMD: lambda r: r and pairs_to_dict(r) or None, + VSIM_CMD: parse_vsim_result, + VLINKS_CMD: parse_vlinks_result, + } + self._RESP3_MODULE_CALLBACKS = {} + + self.client = client + self.execute_command = client.execute_command + + if get_protocol_version(self.client) in ["3", 3]: + self._MODULE_CALLBACKS.update(self._RESP3_MODULE_CALLBACKS) + else: + self._MODULE_CALLBACKS.update(self._RESP2_MODULE_CALLBACKS) + + for k, v in self._MODULE_CALLBACKS.items(): + self.client.set_response_callback(k, v) diff --git a/redis/commands/vectorset/commands.py b/redis/commands/vectorset/commands.py new file mode 100644 index 0000000000..c24bd200ce --- /dev/null +++ b/redis/commands/vectorset/commands.py @@ -0,0 +1,367 @@ +import json +from enum import Enum +from typing import Awaitable, Dict, List, Optional, Union + +from redis.client import NEVER_DECODE +from redis.commands.helpers import get_protocol_version +from 
redis.exceptions import DataError +from redis.typing import CommandsProtocol, EncodableT, KeyT, Number + +VADD_CMD = "VADD" +VSIM_CMD = "VSIM" +VREM_CMD = "VREM" +VDIM_CMD = "VDIM" +VCARD_CMD = "VCARD" +VEMB_CMD = "VEMB" +VLINKS_CMD = "VLINKS" +VINFO_CMD = "VINFO" +VSETATTR_CMD = "VSETATTR" +VGETATTR_CMD = "VGETATTR" +VRANDMEMBER_CMD = "VRANDMEMBER" + + +class QuantizationOptions(Enum): + """Quantization options for the VADD command.""" + + NOQUANT = "NOQUANT" + BIN = "BIN" + Q8 = "Q8" + + +class CallbacksOptions(Enum): + """Options that can be set for the commands callbacks""" + + RAW = "RAW" + WITHSCORES = "WITHSCORES" + ALLOW_DECODING = "ALLOW_DECODING" + RESP3 = "RESP3" + + +class VectorSetCommands(CommandsProtocol): + """Redis VectorSet commands""" + + def vadd( + self, + key: KeyT, + vector: Union[List[float], bytes], + element: str, + reduce_dim: Optional[int] = None, + cas: Optional[bool] = False, + quantization: Optional[QuantizationOptions] = None, + ef: Optional[Number] = None, + attributes: Optional[Union[dict, str]] = None, + numlinks: Optional[int] = None, + ) -> Union[Awaitable[int], int]: + """ + Add vector ``vector`` for element ``element`` to a vector set ``key``. + + ``reduce_dim`` sets the dimensions to reduce the vector to. + If not provided, the vector is not reduced. + + ``cas`` is a boolean flag that indicates whether to use CAS (check-and-set style) + when adding the vector. If not provided, CAS is not used. + + ``quantization`` sets the quantization type to use. + If not provided, int8 quantization is used. + The options are: + - NOQUANT: No quantization + - BIN: Binary quantization + - Q8: Signed 8-bit quantization + + ``ef`` sets the exploration factor to use. + If not provided, the default exploration factor is used. + + ``attributes`` is a dictionary or json string that contains the attributes to set for the vector. + If not provided, no attributes are set. + + ``numlinks`` sets the number of links to create for the vector. + If not provided, the default number of links is used. + + For more information see https://redis.io/commands/vadd + """ + if not vector or not element: + raise DataError("Both vector and element must be provided") + + pieces = [] + if reduce_dim: + pieces.extend(["REDUCE", reduce_dim]) + + values_pieces = [] + if isinstance(vector, bytes): + values_pieces.extend(["FP32", vector]) + else: + values_pieces.extend(["VALUES", len(vector)]) + values_pieces.extend(vector) + pieces.extend(values_pieces) + + pieces.append(element) + + if cas: + pieces.append("CAS") + + if quantization: + pieces.append(quantization.value) + + if ef: + pieces.extend(["EF", ef]) + + if attributes: + if isinstance(attributes, dict): + # transform attributes to json string + attributes_json = json.dumps(attributes) + else: + attributes_json = attributes + pieces.extend(["SETATTR", attributes_json]) + + if numlinks: + pieces.extend(["M", numlinks]) + + return self.execute_command(VADD_CMD, key, *pieces) + + def vsim( + self, + key: KeyT, + input: Union[List[float], bytes, str], + with_scores: Optional[bool] = False, + count: Optional[int] = None, + ef: Optional[Number] = None, + filter: Optional[str] = None, + filter_ef: Optional[str] = None, + truth: Optional[bool] = False, + no_thread: Optional[bool] = False, + ) -> Union[ + Awaitable[Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]]], + Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]], + ]: + """ + Compare a vector or element ``input`` with the other vectors in a vector set ``key``. 
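+
+        Illustrative example, assuming a synchronous client ``r`` with the
+        vector set commands exposed via ``r.vset()``:
+
+            r.vset().vadd("myset", [0.1, 0.2, 0.3], "elem1")
+            r.vset().vsim("myset", input="elem1", count=5, with_scores=True)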
+ + ``with_scores`` sets whether the results should be returned with the + similarity scores of the elements in the result. + + ``count`` sets the number of results to return. + + ``ef`` sets the exploration factor. + + ``filter`` sets the filter that should be applied for the search. + + ``filter_ef`` sets the max filtering effort. + + ``truth`` when enabled forces the command to perform a linear scan. + + ``no_thread`` when enabled forces the command to execute the search + on the data structure in the main thread. + + For more information see https://redis.io/commands/vsim + """ + + if not input: + raise DataError("'input' should be provided") + + pieces = [] + options = {} + + if isinstance(input, bytes): + pieces.extend(["FP32", input]) + elif isinstance(input, list): + pieces.extend(["VALUES", len(input)]) + pieces.extend(input) + else: + pieces.extend(["ELE", input]) + + if with_scores: + pieces.append("WITHSCORES") + options[CallbacksOptions.WITHSCORES.value] = True + + if count: + pieces.extend(["COUNT", count]) + + if ef: + pieces.extend(["EF", ef]) + + if filter: + pieces.extend(["FILTER", filter]) + + if filter_ef: + pieces.extend(["FILTER-EF", filter_ef]) + + if truth: + pieces.append("TRUTH") + + if no_thread: + pieces.append("NOTHREAD") + + return self.execute_command(VSIM_CMD, key, *pieces, **options) + + def vdim(self, key: KeyT) -> Union[Awaitable[int], int]: + """ + Get the dimension of a vector set. + + In the case of vectors that were populated using the `REDUCE` + option, for random projection, the vector set will report the size of + the projected (reduced) dimension. + + Raises `redis.exceptions.ResponseError` if the vector set doesn't exist. + + For more information see https://redis.io/commands/vdim + """ + return self.execute_command(VDIM_CMD, key) + + def vcard(self, key: KeyT) -> Union[Awaitable[int], int]: + """ + Get the cardinality (the number of elements) of a vector set with key ``key``. + + Raises `redis.exceptions.ResponseError` if the vector set doesn't exist. + + For more information see https://redis.io/commands/vcard + """ + return self.execute_command(VCARD_CMD, key) + + def vrem(self, key: KeyT, element: str) -> Union[Awaitable[int], int]: + """ + Remove an element from a vector set. + + For more information see https://redis.io/commands/vrem + """ + return self.execute_command(VREM_CMD, key, element) + + def vemb( + self, key: KeyT, element: str, raw: Optional[bool] = False + ) -> Union[ + Awaitable[Optional[Union[List[EncodableT], Dict[str, EncodableT]]]], + Optional[Union[List[EncodableT], Dict[str, EncodableT]]], + ]: + """ + Get the approximated vector of an element ``element`` from vector set ``key``. + + ``raw`` is a boolean flag that indicates whether to return the + internal representation used by the vector. 
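+
+        Illustrative example, assuming ``"elem1"`` was previously added to
+        ``"myset"`` with ``vadd`` on a synchronous client ``r``:
+
+            r.vset().vemb("myset", "elem1")            # approximated vector values
+            r.vset().vemb("myset", "elem1", raw=True)  # dict with quantization, raw blob, l2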
+ + + For more information see https://redis.io/commands/vembed + """ + options = {} + pieces = [] + pieces.extend([key, element]) + + if get_protocol_version(self.client) in ["3", 3]: + options[CallbacksOptions.RESP3.value] = True + + if raw: + pieces.append("RAW") + + options[NEVER_DECODE] = True + if ( + hasattr(self.client, "connection_pool") + and self.client.connection_pool.connection_kwargs["decode_responses"] + ) or ( + hasattr(self.client, "nodes_manager") + and self.client.nodes_manager.connection_kwargs["decode_responses"] + ): + # allow decoding in the postprocessing callback + # if the user set decode_responses=True + # in the connection pool + options[CallbacksOptions.ALLOW_DECODING.value] = True + + options[CallbacksOptions.RAW.value] = True + + return self.execute_command(VEMB_CMD, *pieces, **options) + + def vlinks( + self, key: KeyT, element: str, with_scores: Optional[bool] = False + ) -> Union[ + Awaitable[ + Optional[ + List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]] + ] + ], + Optional[List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]]], + ]: + """ + Returns the neighbors for each level the element ``element`` exists in the vector set ``key``. + + The result is a list of lists, where each list contains the neighbors for one level. + If the element does not exist, or if the vector set does not exist, None is returned. + + If the ``WITHSCORES`` option is provided, the result is a list of dicts, + where each dict contains the neighbors for one level, with the scores as values. + + For more information see https://redis.io/commands/vlinks + """ + options = {} + pieces = [] + pieces.extend([key, element]) + + if with_scores: + pieces.append("WITHSCORES") + options[CallbacksOptions.WITHSCORES.value] = True + + return self.execute_command(VLINKS_CMD, *pieces, **options) + + def vinfo(self, key: KeyT) -> Union[Awaitable[dict], dict]: + """ + Get information about a vector set. + + For more information see https://redis.io/commands/vinfo + """ + return self.execute_command(VINFO_CMD, key) + + def vsetattr( + self, key: KeyT, element: str, attributes: Optional[Union[dict, str]] = None + ) -> Union[Awaitable[int], int]: + """ + Associate or remove JSON attributes ``attributes`` of element ``element`` + for vector set ``key``. + + For more information see https://redis.io/commands/vsetattr + """ + if attributes is None: + attributes_json = "{}" + elif isinstance(attributes, dict): + # transform attributes to json string + attributes_json = json.dumps(attributes) + else: + attributes_json = attributes + + return self.execute_command(VSETATTR_CMD, key, element, attributes_json) + + def vgetattr( + self, key: KeyT, element: str + ) -> Union[Optional[Awaitable[dict]], Optional[dict]]: + """ + Retrieve the JSON attributes of an element ``elemet`` for vector set ``key``. + + If the element does not exist, or if the vector set does not exist, None is + returned. + + For more information see https://redis.io/commands/vgetattr + """ + return self.execute_command(VGETATTR_CMD, key, element) + + def vrandmember( + self, key: KeyT, count: Optional[int] = None + ) -> Union[ + Awaitable[Optional[Union[List[str], str]]], Optional[Union[List[str], str]] + ]: + """ + Returns random elements from a vector set ``key``. + + ``count`` is the number of elements to return. + If ``count`` is not provided, a single element is returned as a single string. 
+ If ``count`` is positive (smaller than the number of elements + in the vector set), the command returns a list with up to ``count`` + distinct elements from the vector set. + If ``count`` is negative, the command returns a list with ``count`` random elements, + potentially with duplicates. + If ``count`` is greater than the number of elements in the vector set, + only the entire set is returned as a list. + + If the vector set does not exist, ``None`` is returned. + + For more information see https://redis.io/commands/vrandmember + """ + pieces = [] + pieces.append(key) + if count is not None: + pieces.append(count) + return self.execute_command(VRANDMEMBER_CMD, *pieces) diff --git a/redis/commands/vectorset/utils.py b/redis/commands/vectorset/utils.py new file mode 100644 index 0000000000..ed6d194ae0 --- /dev/null +++ b/redis/commands/vectorset/utils.py @@ -0,0 +1,94 @@ +from redis._parsers.helpers import pairs_to_dict +from redis.commands.vectorset.commands import CallbacksOptions + + +def parse_vemb_result(response, **options): + """ + Handle VEMB result since the command can return different result + structures depending on input options and on quantization type of the vector set. + + Parsing VEMB result into: + - List[Union[bytes, Union[int, float]]] + - Dict[str, Union[bytes, str, float]] + """ + if response is None: + return response + + if options.get(CallbacksOptions.RAW.value): + result = {} + result["quantization"] = ( + response[0].decode("utf-8") + if options.get(CallbacksOptions.ALLOW_DECODING.value) + else response[0] + ) + result["raw"] = response[1] + result["l2"] = float(response[2]) + if len(response) > 3: + result["range"] = float(response[3]) + return result + else: + if options.get(CallbacksOptions.RESP3.value): + return response + + result = [] + for i in range(len(response)): + try: + result.append(int(response[i])) + except ValueError: + # if the value is not an integer, it should be a float + result.append(float(response[i])) + + return result + + +def parse_vlinks_result(response, **options): + """ + Handle VLINKS result since the command can return different result + structures depending on input options. + Parsing VLINKS result into: + - List[List[str]] + - List[Dict[str, Number]] + """ + if response is None: + return response + + if options.get(CallbacksOptions.WITHSCORES.value): + result = [] + # Redis will return a list of list of strings. + # This list has to be transformed to a list of dicts + for level_item in response: + level_data_dict = {} + for key, value in pairs_to_dict(level_item).items(): + value = float(value) + level_data_dict[key] = value + result.append(level_data_dict) + return result + else: + # return the list of elements for each level + # list of lists + return response + + +def parse_vsim_result(response, **options): + """ + Handle VSIM result since the command can return different result + structures depending on input options. + Parsing VSIM result into: + - List[List[str]] + - List[Dict[str, Number]] + """ + if response is None: + return response + + if options.get(CallbacksOptions.WITHSCORES.value): + # Redis will return a list of list of pairs. 
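+        # e.g. ["elem1", "0.99", "elem2", "0.97"] for two results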
+ # This list have to be transformed to dict + result_dict = {} + for key, value in pairs_to_dict(response).items(): + value = float(value) + result_dict[key] = value + return result_dict + else: + # return the list of elements for each level + # list of lists + return response diff --git a/tests/test_asyncio/test_vsets.py b/tests/test_asyncio/test_vsets.py new file mode 100644 index 0000000000..9abc899066 --- /dev/null +++ b/tests/test_asyncio/test_vsets.py @@ -0,0 +1,858 @@ +import json +import random +import numpy as np +import pytest +import pytest_asyncio +import redis +from redis.commands.vectorset.commands import QuantizationOptions + +from tests.conftest import ( + skip_if_server_version_lt, +) + + +@pytest_asyncio.fixture() +async def d_client(create_redis, redis_url): + return await create_redis(url=redis_url, decode_responses=True) + + +@pytest_asyncio.fixture() +async def client(create_redis, redis_url): + return await create_redis(url=redis_url, decode_responses=False) + + +@skip_if_server_version_lt("7.9.0") +async def test_add_elem_with_values(d_client): + float_array = [1, 4.32, 0.11] + resp = await d_client.vset().vadd("myset", float_array, "elem1") + assert resp == 1 + + emb = await d_client.vset().vemb("myset", "elem1") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + with pytest.raises(redis.DataError): + await d_client.vset().vadd("myset_invalid_data", None, "elem1") + + with pytest.raises(redis.DataError): + await d_client.vset().vadd("myset_invalid_data", [12, 45], None, reduce_dim=3) + + +@skip_if_server_version_lt("7.9.0") +async def test_add_elem_with_vector(d_client): + float_array = [1, 4.32, 0.11] + # Convert the list of floats to a byte array in fp32 format + byte_array = _to_fp32_blob_array(float_array) + resp = await d_client.vset().vadd("myset", byte_array, "elem1") + assert resp == 1 + + emb = await d_client.vset().vemb("myset", "elem1") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + +@skip_if_server_version_lt("7.9.0") +async def test_add_elem_reduced_dim(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9] + resp = await d_client.vset().vadd("myset", float_array, "elem1", reduce_dim=3) + assert resp == 1 + + dim = await d_client.vset().vdim("myset") + assert dim == 3 + + +@skip_if_server_version_lt("7.9.0") +async def test_add_elem_cas(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9] + resp = await d_client.vset().vadd( + "myset", vector=float_array, element="elem1", cas=True + ) + assert resp == 1 + + emb = await d_client.vset().vemb("myset", "elem1") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + +@skip_if_server_version_lt("7.9.0") +async def test_add_elem_no_quant(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9] + resp = await d_client.vset().vadd( + "myset", + vector=float_array, + element="elem1", + quantization=QuantizationOptions.NOQUANT, + ) + assert resp == 1 + + emb = await d_client.vset().vemb("myset", "elem1") + assert _validate_quantization(float_array, emb, tolerance=0.0) + + +@skip_if_server_version_lt("7.9.0") +async def test_add_elem_bin_quant(d_client): + float_array = [1, 4.32, 0.0, 0.05, -2.9] + resp = await d_client.vset().vadd( + "myset", + vector=float_array, + element="elem1", + quantization=QuantizationOptions.BIN, + ) + assert resp == 1 + + emb = await d_client.vset().vemb("myset", "elem1") + expected_array = [1, 1, -1, 1, -1] + assert _validate_quantization(expected_array, emb, tolerance=0.0) + + +@skip_if_server_version_lt("7.9.0") +async def 
test_add_elem_q8_quant(d_client): + float_array = [1, 4.32, 10.0, -21, -2.9] + resp = await d_client.vset().vadd( + "myset", + vector=float_array, + element="elem1", + quantization=QuantizationOptions.BIN, + ) + assert resp == 1 + + emb = await d_client.vset().vemb("myset", "elem1") + expected_array = [1, 1, 1, -1, -1] + assert _validate_quantization(expected_array, emb, tolerance=0.0) + + +@skip_if_server_version_lt("7.9.0") +async def test_add_elem_ef(d_client): + await d_client.vset().vadd("myset", vector=[5, 55, 65, -20, 30], element="elem1") + await d_client.vset().vadd( + "myset", vector=[-40, -40.32, 10.0, -4, 2.9], element="elem2" + ) + + float_array = [1, 4.32, 10.0, -21, -2.9] + resp = await d_client.vset().vadd("myset", float_array, "elem3", ef=1) + assert resp == 1 + + emb = await d_client.vset().vemb("myset", "elem3") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + sim = await d_client.vset().vsim("myset", input="elem3", with_scores=True) + assert len(sim) == 3 + + +@skip_if_server_version_lt("7.9.0") +async def test_add_elem_with_attr(d_client): + float_array = [1, 4.32, 10.0, -21, -2.9] + attrs_dict = {"key1": "value1", "key2": "value2"} + resp = await d_client.vset().vadd( + "myset", + vector=float_array, + element="elem3", + attributes=attrs_dict, + ) + assert resp == 1 + + emb = await d_client.vset().vemb("myset", "elem3") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + attr_saved = await d_client.vset().vgetattr("myset", "elem3") + assert attr_saved == attrs_dict + + resp = await d_client.vset().vadd( + "myset", + vector=float_array, + element="elem4", + attributes={}, + ) + assert resp == 1 + + emb = await d_client.vset().vemb("myset", "elem4") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + attr_saved = await d_client.vset().vgetattr("myset", "elem4") + assert attr_saved is None + + resp = await d_client.vset().vadd( + "myset", + vector=float_array, + element="elem5", + attributes=json.dumps(attrs_dict), + ) + assert resp == 1 + + emb = await d_client.vset().vemb("myset", "elem5") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + attr_saved = await d_client.vset().vgetattr("myset", "elem5") + assert attr_saved == attrs_dict + + +@skip_if_server_version_lt("7.9.0") +async def test_add_elem_with_numlinks(d_client): + elements_count = 100 + vector_dim = 10 + for i in range(elements_count): + float_array = [random.randint(0, 10) for x in range(vector_dim)] + await d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=8, + ) + + float_array = [1, 4.32, 0.11, 0.5, 0.9, 0.1, 0.2, 0.3, 0.4, 0.5] + resp = await d_client.vset().vadd("myset", float_array, "elem_numlinks", numlinks=8) + assert resp == 1 + + emb = await d_client.vset().vemb("myset", "elem_numlinks") + assert _validate_quantization(float_array, emb, tolerance=0.5) + + numlinks_all_layers = await d_client.vset().vlinks("myset", "elem_numlinks") + for neighbours_list_for_layer in numlinks_all_layers: + assert len(neighbours_list_for_layer) <= 8 + + +@skip_if_server_version_lt("7.9.0") +async def test_vsim_count(d_client): + elements_count = 30 + vector_dim = 800 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + await d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=64, + ) + + vsim = await d_client.vset().vsim("myset", input="elem1") + assert len(vsim) == 10 + assert isinstance(vsim, list) + assert isinstance(vsim[0], str) + + vsim = await 
d_client.vset().vsim("myset", input="elem1", count=5) + assert len(vsim) == 5 + assert isinstance(vsim, list) + assert isinstance(vsim[0], str) + + vsim = await d_client.vset().vsim("myset", input="elem1", count=50) + assert len(vsim) == 30 + assert isinstance(vsim, list) + assert isinstance(vsim[0], str) + + vsim = await d_client.vset().vsim("myset", input="elem1", count=15) + assert len(vsim) == 15 + assert isinstance(vsim, list) + assert isinstance(vsim[0], str) + + +@skip_if_server_version_lt("7.9.0") +async def test_vsim_with_scores(d_client): + elements_count = 20 + vector_dim = 50 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + await d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=64, + ) + + vsim = await d_client.vset().vsim("myset", input="elem1", with_scores=True) + assert len(vsim) == 10 + assert isinstance(vsim, dict) + assert isinstance(vsim["elem1"], float) + assert 0 <= vsim["elem1"] <= 1 + + +@skip_if_server_version_lt("7.9.0") +async def test_vsim_with_different_vector_input_types(d_client): + elements_count = 10 + vector_dim = 5 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + attributes = {"index": i, "elem_name": f"elem_{i}"} + await d_client.vset().vadd( + "myset", + float_array, + f"elem_{i}", + numlinks=4, + attributes=attributes, + ) + sim = await d_client.vset().vsim("myset", input="elem_1") + assert len(sim) == 10 + assert isinstance(sim, list) + + float_array = [1, 4.32, 0.0, 0.05, -2.9] + sim_to_float_array = await d_client.vset().vsim("myset", input=float_array) + assert len(sim_to_float_array) == 10 + assert isinstance(sim_to_float_array, list) + + fp32_vector = _to_fp32_blob_array(float_array) + sim_to_fp32_vector = await d_client.vset().vsim("myset", input=fp32_vector) + assert len(sim_to_fp32_vector) == 10 + assert isinstance(sim_to_fp32_vector, list) + assert sim_to_float_array == sim_to_fp32_vector + + with pytest.raises(redis.DataError): + await d_client.vset().vsim("myset", input=None) + + +@skip_if_server_version_lt("7.9.0") +async def test_vsim_unexisting(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9] + await d_client.vset().vadd("myset", vector=float_array, element="elem1", cas=True) + + with pytest.raises(redis.ResponseError): + await d_client.vset().vsim("myset", input="elem_not_existing") + + sim = await d_client.vset().vsim("myset_not_existing", input="elem1") + assert sim == [] + + +@skip_if_server_version_lt("7.9.0") +async def test_vsim_with_filter(d_client): + elements_count = 30 + vector_dim = 800 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + attributes = {"index": i, "elem_name": f"elem_{i}"} + await d_client.vset().vadd( + "myset", + float_array, + f"elem_{i}", + numlinks=4, + attributes=attributes, + ) + sim = await d_client.vset().vsim("myset", input="elem_1", filter=".index > 10") + assert len(sim) == 10 + assert isinstance(sim, list) + for elem in sim: + assert int(elem.split("_")[1]) > 10 + + sim = await d_client.vset().vsim( + "myset", + input="elem_1", + filter=".index > 10 and .index < 15 and .elem_name in ['elem_12', 'elem_17']", + ) + assert len(sim) == 1 + assert isinstance(sim, list) + assert sim[0] == "elem_12" + + sim = await d_client.vset().vsim( + "myset", + input="elem_1", + filter=".index > 25 and .elem_name in ['elem_12', 'elem_17', 'elem_19']", + ef=100, + ) + assert len(sim) == 0 + assert isinstance(sim, list) + + sim = await 
d_client.vset().vsim( + "myset", + input="elem_1", + filter=".index > 28 and .elem_name in ['elem_12', 'elem_17', 'elem_29']", + filter_ef=1, + ) + assert len(sim) == 0 + assert isinstance(sim, list) + + sim = await d_client.vset().vsim( + "myset", + input="elem_1", + filter=".index > 28 and .elem_name in ['elem_12', 'elem_17', 'elem_29']", + filter_ef=20, + ) + assert len(sim) == 1 + assert isinstance(sim, list) + + +@skip_if_server_version_lt("7.9.0") +async def test_vsim_truth_no_thread_enabled(d_client): + elements_count = 5000 + vector_dim = 30 + for i in range(1, elements_count + 1): + float_array = [random.uniform(10 * i, 1000 * i) for x in range(vector_dim)] + await d_client.vset().vadd("myset", float_array, f"elem_{i}") + + await d_client.vset().vadd("myset", [-22 for _ in range(vector_dim)], "elem_man_2") + + sim_without_truth = await d_client.vset().vsim( + "myset", input="elem_man_2", with_scores=True + ) + sim_truth = await d_client.vset().vsim( + "myset", input="elem_man_2", with_scores=True, truth=True + ) + + assert len(sim_without_truth) == 10 + assert len(sim_truth) == 10 + + assert isinstance(sim_without_truth, dict) + assert isinstance(sim_truth, dict) + + results_scores = list( + zip( + [v for _, v in sim_truth.items()], [v for _, v in sim_without_truth.items()] + ) + ) + + found_better_match = False + for index, (score_with_truth, score_without_truth) in enumerate(results_scores): + if score_with_truth < score_without_truth: + assert False, ( + "Score with truth [{score_with_truth}] < score without truth [{score_without_truth}]" + ) + elif score_with_truth > score_without_truth: + found_better_match = True + + assert found_better_match + + sim_no_thread = await d_client.vset().vsim( + "myset", input="elem_man_2", with_scores=True, no_thread=True + ) + + assert len(sim_no_thread) == 10 + assert isinstance(sim_no_thread, dict) + + +@skip_if_server_version_lt("7.9.0") +async def test_vdim(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9, 0.1, 0.2] + await d_client.vset().vadd("myset", float_array, "elem1") + + dim = await d_client.vset().vdim("myset") + assert dim == len(float_array) + + await d_client.vset().vadd("myset_reduced", float_array, "elem1", reduce_dim=4) + reduced_dim = await d_client.vset().vdim("myset_reduced") + assert reduced_dim == 4 + + with pytest.raises(redis.ResponseError): + await d_client.vset().vdim("myset_unexisting") + + +@skip_if_server_version_lt("7.9.0") +async def test_vcard(d_client): + n = 20 + for i in range(n): + float_array = [random.uniform(0, 10) for x in range(1, 8)] + await d_client.vset().vadd("myset", float_array, f"elem{i}") + + card = await d_client.vset().vcard("myset") + assert card == n + + with pytest.raises(redis.ResponseError): + await d_client.vset().vdim("myset_unexisting") + + +@skip_if_server_version_lt("7.9.0") +async def test_vrem(d_client): + n = 3 + for i in range(n): + float_array = [random.uniform(0, 10) for x in range(1, 8)] + await d_client.vset().vadd("myset", float_array, f"elem{i}") + + resp = await d_client.vset().vrem("myset", "elem2") + assert resp == 1 + + card = await d_client.vset().vcard("myset") + assert card == n - 1 + + resp = await d_client.vset().vrem("myset", "elem2") + assert resp == 0 + + card = await d_client.vset().vcard("myset") + assert card == n - 1 + + resp = await d_client.vset().vrem("myset_unexisting", "elem1") + assert resp == 0 + + +@skip_if_server_version_lt("7.9.0") +async def test_vemb_bin_quantization(d_client): + e = [1, 4.32, 0.0, 0.05, -2.9] + await d_client.vset().vadd( + 
"myset", + e, + "elem", + quantization=QuantizationOptions.BIN, + ) + emb_no_quant = await d_client.vset().vemb("myset", "elem") + assert emb_no_quant == [1, 1, -1, 1, -1] + + emb_no_quant_raw = await d_client.vset().vemb("myset", "elem", raw=True) + assert emb_no_quant_raw["quantization"] == "bin" + assert isinstance(emb_no_quant_raw["raw"], bytes) + assert isinstance(emb_no_quant_raw["l2"], float) + assert "range" not in emb_no_quant_raw + + +@skip_if_server_version_lt("7.9.0") +async def test_vemb_q8_quantization(d_client): + e = [1, 10.32, 0.0, 2.05, -12.5] + await d_client.vset().vadd("myset", e, "elem", quantization=QuantizationOptions.Q8) + + emb_q8_quant = await d_client.vset().vemb("myset", "elem") + assert _validate_quantization(e, emb_q8_quant, tolerance=0.1) + + emb_q8_quant_raw = await d_client.vset().vemb("myset", "elem", raw=True) + assert emb_q8_quant_raw["quantization"] == "int8" + assert isinstance(emb_q8_quant_raw["raw"], bytes) + assert isinstance(emb_q8_quant_raw["l2"], float) + assert isinstance(emb_q8_quant_raw["range"], float) + + +@skip_if_server_version_lt("7.9.0") +async def test_vemb_no_quantization(d_client): + e = [1, 10.32, 0.0, 2.05, -12.5] + await d_client.vset().vadd( + "myset", e, "elem", quantization=QuantizationOptions.NOQUANT + ) + + emb_no_quant = await d_client.vset().vemb("myset", "elem") + assert _validate_quantization(e, emb_no_quant, tolerance=0.1) + + emb_no_quant_raw = await d_client.vset().vemb("myset", "elem", raw=True) + assert emb_no_quant_raw["quantization"] == "f32" + assert isinstance(emb_no_quant_raw["raw"], bytes) + assert isinstance(emb_no_quant_raw["l2"], float) + assert "range" not in emb_no_quant_raw + + +@skip_if_server_version_lt("7.9.0") +async def test_vemb_default_quantization(d_client): + e = [1, 5.32, 0.0, 0.25, -5] + await d_client.vset().vadd("myset", vector=e, element="elem") + + emb_default_quant = await d_client.vset().vemb("myset", "elem") + assert _validate_quantization(e, emb_default_quant, tolerance=0.1) + + emb_default_quant_raw = await d_client.vset().vemb("myset", "elem", raw=True) + assert emb_default_quant_raw["quantization"] == "int8" + assert isinstance(emb_default_quant_raw["raw"], bytes) + assert isinstance(emb_default_quant_raw["l2"], float) + assert isinstance(emb_default_quant_raw["range"], float) + + +@skip_if_server_version_lt("7.9.0") +async def test_vemb_fp32_quantization(d_client): + float_array_fp32 = [1, 4.32, 0.11] + # Convert the list of floats to a byte array in fp32 format + byte_array = _to_fp32_blob_array(float_array_fp32) + await d_client.vset().vadd("myset", byte_array, "elem") + + emb_fp32_quant = await d_client.vset().vemb("myset", "elem") + assert _validate_quantization(float_array_fp32, emb_fp32_quant, tolerance=0.1) + + emb_fp32_quant_raw = await d_client.vset().vemb("myset", "elem", raw=True) + assert emb_fp32_quant_raw["quantization"] == "int8" + assert isinstance(emb_fp32_quant_raw["raw"], bytes) + assert isinstance(emb_fp32_quant_raw["l2"], float) + assert isinstance(emb_fp32_quant_raw["range"], float) + + +@skip_if_server_version_lt("7.9.0") +async def test_vemb_unexisting(d_client): + emb_not_existing = await d_client.vset().vemb("not_existing", "elem") + assert emb_not_existing is None + + e = [1, 5.32, 0.0, 0.25, -5] + await d_client.vset().vadd("myset", vector=e, element="elem") + emb_elem_not_existing = await d_client.vset().vemb("myset", "not_existing") + assert emb_elem_not_existing is None + + +@skip_if_server_version_lt("7.9.0") +async def test_vlinks(d_client): + 
elements_count = 100 + vector_dim = 800 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + await d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=8, + ) + + element_links_all_layers = await d_client.vset().vlinks("myset", "elem1") + assert len(element_links_all_layers) >= 1 + for neighbours_list_for_layer in element_links_all_layers: + assert isinstance(neighbours_list_for_layer, list) + for neighbour in neighbours_list_for_layer: + assert isinstance(neighbour, str) + + elem_links_all_layers_with_scores = await d_client.vset().vlinks( + "myset", "elem1", with_scores=True + ) + assert len(elem_links_all_layers_with_scores) >= 1 + for neighbours_dict_for_layer in elem_links_all_layers_with_scores: + assert isinstance(neighbours_dict_for_layer, dict) + for neighbour_key, score_value in neighbours_dict_for_layer.items(): + assert isinstance(neighbour_key, str) + assert isinstance(score_value, float) + + float_array = [0.75, 0.25, 0.5, 0.1, 0.9] + await d_client.vset().vadd("myset_one_elem_only", float_array, "elem1") + elem_no_neighbours_with_scores = await d_client.vset().vlinks( + "myset_one_elem_only", "elem1", with_scores=True + ) + assert len(elem_no_neighbours_with_scores) >= 1 + for neighbours_dict_for_layer in elem_no_neighbours_with_scores: + assert isinstance(neighbours_dict_for_layer, dict) + assert len(neighbours_dict_for_layer) == 0 + + elem_no_neighbours_no_scores = await d_client.vset().vlinks( + "myset_one_elem_only", "elem1" + ) + assert len(elem_no_neighbours_no_scores) >= 1 + for neighbours_list_for_layer in elem_no_neighbours_no_scores: + assert isinstance(neighbours_list_for_layer, list) + assert len(neighbours_list_for_layer) == 0 + + unexisting_element_links = await d_client.vset().vlinks("myset", "unexisting_elem") + assert unexisting_element_links is None + + unexisting_vset_links = await d_client.vset().vlinks("myset_unexisting", "elem1") + assert unexisting_vset_links is None + + unexisting_element_links = await d_client.vset().vlinks( + "myset", "unexisting_elem", with_scores=True + ) + assert unexisting_element_links is None + + unexisting_vset_links = await d_client.vset().vlinks( + "myset_unexisting", "elem1", with_scores=True + ) + assert unexisting_vset_links is None + + +@skip_if_server_version_lt("7.9.0") +async def test_vinfo(d_client): + elements_count = 100 + vector_dim = 800 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + await d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=8, + quantization=QuantizationOptions.BIN, + ) + + vset_info = await d_client.vset().vinfo("myset") + assert vset_info["quant-type"] == "bin" + assert vset_info["vector-dim"] == vector_dim + assert vset_info["size"] == elements_count + assert vset_info["max-level"] > 0 + assert vset_info["hnsw-max-node-uid"] == elements_count + + unexisting_vset_info = await d_client.vset().vinfo("myset_unexisting") + assert unexisting_vset_info is None + + +@skip_if_server_version_lt("7.9.0") +async def test_vset_vget_attributes(d_client): + float_array = [1, 4.32, 0.11] + attributes = {"key1": "value1", "key2": "value2"} + + # validate vgetattrs when no attributes are set with vadd + resp = await d_client.vset().vadd("myset", float_array, "elem1") + assert resp == 1 + + attrs = await d_client.vset().vgetattr("myset", "elem1") + assert attrs is None + + # validate vgetattrs when attributes are set with vadd + resp = await d_client.vset().vadd( + 
"myset_with_attrs", float_array, "elem1", attributes=attributes + ) + assert resp == 1 + + attrs = await d_client.vset().vgetattr("myset_with_attrs", "elem1") + assert attrs == attributes + + # Set attributes and get attributes + resp = await d_client.vset().vsetattr("myset", "elem1", attributes) + assert resp == 1 + attr_saved = await d_client.vset().vgetattr("myset", "elem1") + assert attr_saved == attributes + + # Set attributes to None + resp = await d_client.vset().vsetattr("myset", "elem1", None) + assert resp == 1 + attr_saved = await d_client.vset().vgetattr("myset", "elem1") + assert attr_saved is None + + # Set attributes to empty dict + resp = await d_client.vset().vsetattr("myset", "elem1", {}) + assert resp == 1 + attr_saved = await d_client.vset().vgetattr("myset", "elem1") + assert attr_saved is None + + # Set attributes provided as string + resp = await d_client.vset().vsetattr("myset", "elem1", json.dumps(attributes)) + assert resp == 1 + attr_saved = await d_client.vset().vgetattr("myset", "elem1") + assert attr_saved == attributes + + # Set attributes to unexisting element + resp = await d_client.vset().vsetattr("myset", "elem2", attributes) + assert resp == 0 + attr_saved = await d_client.vset().vgetattr("myset", "elem2") + assert attr_saved is None + + # Set attributes to unexisting vset + resp = await d_client.vset().vsetattr("myset_unexisting", "elem1", attributes) + assert resp == 0 + attr_saved = await d_client.vset().vgetattr("myset_unexisting", "elem1") + assert attr_saved is None + + +@skip_if_server_version_lt("7.9.0") +async def test_vrandmember(d_client): + elements = ["elem1", "elem2", "elem3"] + for elem in elements: + float_array = [random.uniform(0, 10) for x in range(1, 8)] + await d_client.vset().vadd("myset", float_array, element=elem) + + random_member = await d_client.vset().vrandmember("myset") + assert random_member in elements + + members_list = await d_client.vset().vrandmember("myset", count=2) + assert len(members_list) == 2 + assert all(member in elements for member in members_list) + + # Test with count greater than the number of elements + members_list = await d_client.vset().vrandmember("myset", count=10) + assert len(members_list) == len(elements) + assert all(member in elements for member in members_list) + + # Test with negative count + members_list = await d_client.vset().vrandmember("myset", count=-2) + assert len(members_list) == 2 + assert all(member in elements for member in members_list) + + # Test with count equal to the number of elements + members_list = await d_client.vset().vrandmember("myset", count=len(elements)) + assert len(members_list) == len(elements) + assert all(member in elements for member in members_list) + + # Test with count equal to 0 + members_list = await d_client.vset().vrandmember("myset", count=0) + assert members_list == [] + + # Test with count equal to 1 + members_list = await d_client.vset().vrandmember("myset", count=1) + assert len(members_list) == 1 + assert members_list[0] in elements + + # Test with count equal to -1 + members_list = await d_client.vset().vrandmember("myset", count=-1) + assert len(members_list) == 1 + assert members_list[0] in elements + + # Test with unexisting vset & without count + members_list = await d_client.vset().vrandmember("myset_unexisting") + assert members_list is None + + # Test with unexisting vset & count + members_list = await d_client.vset().vrandmember("myset_unexisting", count=5) + assert members_list == [] + + +@skip_if_server_version_lt("7.9.0") +async def 
test_vset_commands_without_decoding_responces(client): + # test vadd + elements = ["elem1", "elem2", "elem3"] + for elem in elements: + float_array = [random.uniform(0, 10) for x in range(0, 8)] + resp = await client.vset().vadd("myset", float_array, element=elem) + assert resp == 1 + + # test vemb + emb = await client.vset().vemb("myset", "elem1") + assert len(emb) == 8 + assert isinstance(emb, list) + assert all(isinstance(x, float) for x in emb) + + emb_raw = await client.vset().vemb("myset", "elem1", raw=True) + assert emb_raw["quantization"] == b"int8" + assert isinstance(emb_raw["raw"], bytes) + assert isinstance(emb_raw["l2"], float) + assert isinstance(emb_raw["range"], float) + + # test vsim + vsim = await client.vset().vsim("myset", input="elem1") + assert len(vsim) == 3 + assert isinstance(vsim, list) + assert isinstance(vsim[0], bytes) + + # test vsim with scores + vsim_with_scores = await client.vset().vsim( + "myset", input="elem1", with_scores=True + ) + assert len(vsim_with_scores) == 3 + assert isinstance(vsim_with_scores, dict) + assert isinstance(vsim_with_scores[b"elem1"], float) + + # test vlinks - no scores + element_links_all_layers = await client.vset().vlinks("myset", "elem1") + assert len(element_links_all_layers) >= 1 + for neighbours_list_for_layer in element_links_all_layers: + assert isinstance(neighbours_list_for_layer, list) + for neighbour in neighbours_list_for_layer: + assert isinstance(neighbour, bytes) + # test vlinks with scores + elem_links_all_layers_with_scores = await client.vset().vlinks( + "myset", "elem1", with_scores=True + ) + assert len(elem_links_all_layers_with_scores) >= 1 + for neighbours_dict_for_layer in elem_links_all_layers_with_scores: + assert isinstance(neighbours_dict_for_layer, dict) + for neighbour_key, score_value in neighbours_dict_for_layer.items(): + assert isinstance(neighbour_key, bytes) + assert isinstance(score_value, float) + + # test vinfo + vset_info = await client.vset().vinfo("myset") + assert vset_info[b"quant-type"] == b"int8" + assert vset_info[b"vector-dim"] == 8 + assert vset_info[b"size"] == len(elements) + assert vset_info[b"max-level"] >= 0 + assert vset_info[b"hnsw-max-node-uid"] == len(elements) + + # test vgetattr + attributes = {"key1": "value1", "key2": "value2"} + await client.vset().vsetattr("myset", "elem1", attributes) + attrs = await client.vset().vgetattr("myset", "elem1") + assert attrs == attributes + + # test vrandmember + random_member = await client.vset().vrandmember("myset") + assert isinstance(random_member, bytes) + assert random_member.decode("utf-8") in elements + + members_list = await client.vset().vrandmember("myset", count=2) + assert len(members_list) == 2 + assert all(member.decode("utf-8") in elements for member in members_list) + + +def _to_fp32_blob_array(float_array): + """ + Convert a list of floats to a byte array in fp32 format. 
+ """ + # Convert the list of floats to a NumPy array with dtype np.float32 + arr = np.array(float_array, dtype=np.float32) + # Convert the NumPy array to a byte array + byte_array = arr.tobytes() + return byte_array + + +def _validate_quantization(original, quantized, tolerance=0.1): + original = np.array(original, dtype=np.float32) + quantized = np.array(quantized, dtype=np.float32) + + max_diff = np.max(np.abs(original - quantized)) + if max_diff > tolerance: + return False + else: + return True diff --git a/tests/test_vsets.py b/tests/test_vsets.py new file mode 100644 index 0000000000..ab4194657b --- /dev/null +++ b/tests/test_vsets.py @@ -0,0 +1,856 @@ +import json +import random +import numpy as np +import pytest +import redis +from redis.commands.vectorset.commands import QuantizationOptions + +from .conftest import ( + _get_client, + skip_if_server_version_lt, +) + + +@pytest.fixture +def d_client(request): + r = _get_client(redis.Redis, request, decode_responses=True) + + r.flushdb() + return r + + +@pytest.fixture +def client(request): + r = _get_client(redis.Redis, request, decode_responses=False) + + r.flushdb() + return r + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_with_values(d_client): + float_array = [1, 4.32, 0.11] + resp = d_client.vset().vadd("myset", float_array, "elem1") + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem1") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + with pytest.raises(redis.DataError): + d_client.vset().vadd("myset_invalid_data", None, "elem1") + + with pytest.raises(redis.DataError): + d_client.vset().vadd("myset_invalid_data", [12, 45], None, reduce_dim=3) + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_with_vector(d_client): + float_array = [1, 4.32, 0.11] + # Convert the list of floats to a byte array in fp32 format + byte_array = _to_fp32_blob_array(float_array) + resp = d_client.vset().vadd("myset", byte_array, "elem1") + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem1") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_reduced_dim(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9] + resp = d_client.vset().vadd("myset", float_array, "elem1", reduce_dim=3) + assert resp == 1 + + dim = d_client.vset().vdim("myset") + assert dim == 3 + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_cas(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9] + resp = d_client.vset().vadd("myset", vector=float_array, element="elem1", cas=True) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem1") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_no_quant(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9] + resp = d_client.vset().vadd( + "myset", + vector=float_array, + element="elem1", + quantization=QuantizationOptions.NOQUANT, + ) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem1") + assert _validate_quantization(float_array, emb, tolerance=0.0) + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_bin_quant(d_client): + float_array = [1, 4.32, 0.0, 0.05, -2.9] + resp = d_client.vset().vadd( + "myset", + vector=float_array, + element="elem1", + quantization=QuantizationOptions.BIN, + ) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem1") + expected_array = [1, 1, -1, 1, -1] + assert _validate_quantization(expected_array, emb, tolerance=0.0) + + 
+@skip_if_server_version_lt("7.9.0") +def test_add_elem_q8_quant(d_client): + float_array = [1, 4.32, 10.0, -21, -2.9] + resp = d_client.vset().vadd( + "myset", + vector=float_array, + element="elem1", + quantization=QuantizationOptions.BIN, + ) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem1") + expected_array = [1, 1, 1, -1, -1] + assert _validate_quantization(expected_array, emb, tolerance=0.0) + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_ef(d_client): + d_client.vset().vadd("myset", vector=[5, 55, 65, -20, 30], element="elem1") + d_client.vset().vadd("myset", vector=[-40, -40.32, 10.0, -4, 2.9], element="elem2") + + float_array = [1, 4.32, 10.0, -21, -2.9] + resp = d_client.vset().vadd("myset", float_array, "elem3", ef=1) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem3") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + sim = d_client.vset().vsim("myset", input="elem3", with_scores=True) + assert len(sim) == 3 + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_with_attr(d_client): + float_array = [1, 4.32, 10.0, -21, -2.9] + attrs_dict = {"key1": "value1", "key2": "value2"} + resp = d_client.vset().vadd( + "myset", + vector=float_array, + element="elem3", + attributes=attrs_dict, + ) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem3") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + attr_saved = d_client.vset().vgetattr("myset", "elem3") + assert attr_saved == attrs_dict + + resp = d_client.vset().vadd( + "myset", + vector=float_array, + element="elem4", + attributes={}, + ) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem4") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + attr_saved = d_client.vset().vgetattr("myset", "elem4") + assert attr_saved is None + + resp = d_client.vset().vadd( + "myset", + vector=float_array, + element="elem5", + attributes=json.dumps(attrs_dict), + ) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem5") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + attr_saved = d_client.vset().vgetattr("myset", "elem5") + assert attr_saved == attrs_dict + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_with_numlinks(d_client): + elements_count = 100 + vector_dim = 10 + for i in range(elements_count): + float_array = [random.randint(0, 10) for x in range(vector_dim)] + d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=8, + ) + + float_array = [1, 4.32, 0.11, 0.5, 0.9, 0.1, 0.2, 0.3, 0.4, 0.5] + resp = d_client.vset().vadd("myset", float_array, "elem_numlinks", numlinks=8) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem_numlinks") + assert _validate_quantization(float_array, emb, tolerance=0.5) + + numlinks_all_layers = d_client.vset().vlinks("myset", "elem_numlinks") + for neighbours_list_for_layer in numlinks_all_layers: + assert len(neighbours_list_for_layer) <= 8 + + +@skip_if_server_version_lt("7.9.0") +def test_vsim_count(d_client): + elements_count = 30 + vector_dim = 800 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=64, + ) + + vsim = d_client.vset().vsim("myset", input="elem1") + assert len(vsim) == 10 + assert isinstance(vsim, list) + assert isinstance(vsim[0], str) + + vsim = d_client.vset().vsim("myset", input="elem1", count=5) + assert len(vsim) == 5 + assert isinstance(vsim, list) + assert 
isinstance(vsim[0], str) + + vsim = d_client.vset().vsim("myset", input="elem1", count=50) + assert len(vsim) == 30 + assert isinstance(vsim, list) + assert isinstance(vsim[0], str) + + vsim = d_client.vset().vsim("myset", input="elem1", count=15) + assert len(vsim) == 15 + assert isinstance(vsim, list) + assert isinstance(vsim[0], str) + + +@skip_if_server_version_lt("7.9.0") +def test_vsim_with_scores(d_client): + elements_count = 20 + vector_dim = 50 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=64, + ) + + vsim = d_client.vset().vsim("myset", input="elem1", with_scores=True) + assert len(vsim) == 10 + assert isinstance(vsim, dict) + assert isinstance(vsim["elem1"], float) + assert 0 <= vsim["elem1"] <= 1 + + +@skip_if_server_version_lt("7.9.0") +def test_vsim_with_different_vector_input_types(d_client): + elements_count = 10 + vector_dim = 5 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + attributes = {"index": i, "elem_name": f"elem_{i}"} + d_client.vset().vadd( + "myset", + float_array, + f"elem_{i}", + numlinks=4, + attributes=attributes, + ) + sim = d_client.vset().vsim("myset", input="elem_1") + assert len(sim) == 10 + assert isinstance(sim, list) + + float_array = [1, 4.32, 0.0, 0.05, -2.9] + sim_to_float_array = d_client.vset().vsim("myset", input=float_array) + assert len(sim_to_float_array) == 10 + assert isinstance(sim_to_float_array, list) + + fp32_vector = _to_fp32_blob_array(float_array) + sim_to_fp32_vector = d_client.vset().vsim("myset", input=fp32_vector) + assert len(sim_to_fp32_vector) == 10 + assert isinstance(sim_to_fp32_vector, list) + assert sim_to_float_array == sim_to_fp32_vector + + with pytest.raises(redis.DataError): + d_client.vset().vsim("myset", input=None) + + +@skip_if_server_version_lt("7.9.0") +def test_vsim_unexisting(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9] + d_client.vset().vadd("myset", vector=float_array, element="elem1", cas=True) + + with pytest.raises(redis.ResponseError): + d_client.vset().vsim("myset", input="elem_not_existing") + + sim = d_client.vset().vsim("myset_not_existing", input="elem1") + assert sim == [] + + +@skip_if_server_version_lt("7.9.0") +def test_vsim_with_filter(d_client): + elements_count = 30 + vector_dim = 800 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + attributes = {"index": i, "elem_name": f"elem_{i}"} + d_client.vset().vadd( + "myset", + float_array, + f"elem_{i}", + numlinks=4, + attributes=attributes, + ) + sim = d_client.vset().vsim("myset", input="elem_1", filter=".index > 10") + assert len(sim) == 10 + assert isinstance(sim, list) + for elem in sim: + assert int(elem.split("_")[1]) > 10 + + sim = d_client.vset().vsim( + "myset", + input="elem_1", + filter=".index > 10 and .index < 15 and .elem_name in ['elem_12', 'elem_17']", + ) + assert len(sim) == 1 + assert isinstance(sim, list) + assert sim[0] == "elem_12" + + sim = d_client.vset().vsim( + "myset", + input="elem_1", + filter=".index > 25 and .elem_name in ['elem_12', 'elem_17', 'elem_19']", + ef=100, + ) + assert len(sim) == 0 + assert isinstance(sim, list) + + sim = d_client.vset().vsim( + "myset", + input="elem_1", + filter=".index > 28 and .elem_name in ['elem_12', 'elem_17', 'elem_29']", + filter_ef=1, + ) + assert len(sim) == 0 + assert isinstance(sim, list) + + sim = d_client.vset().vsim( + "myset", + 
input="elem_1", + filter=".index > 28 and .elem_name in ['elem_12', 'elem_17', 'elem_29']", + filter_ef=20, + ) + assert len(sim) == 1 + assert isinstance(sim, list) + + +@skip_if_server_version_lt("7.9.0") +def test_vsim_truth_no_thread_enabled(d_client): + elements_count = 5000 + vector_dim = 30 + for i in range(1, elements_count + 1): + float_array = [random.uniform(10 * i, 1000 * i) for x in range(vector_dim)] + d_client.vset().vadd("myset", float_array, f"elem_{i}") + + d_client.vset().vadd("myset", [-22 for _ in range(vector_dim)], "elem_man_2") + + sim_without_truth = d_client.vset().vsim( + "myset", input="elem_man_2", with_scores=True + ) + sim_truth = d_client.vset().vsim( + "myset", input="elem_man_2", with_scores=True, truth=True + ) + + assert len(sim_without_truth) == 10 + assert len(sim_truth) == 10 + + assert isinstance(sim_without_truth, dict) + assert isinstance(sim_truth, dict) + + results_scores = list( + zip( + [v for _, v in sim_truth.items()], [v for _, v in sim_without_truth.items()] + ) + ) + + found_better_match = False + for index, (score_with_truth, score_without_truth) in enumerate(results_scores): + if score_with_truth < score_without_truth: + assert False, ( + "Score with truth [{score_with_truth}] < score without truth [{score_without_truth}]" + ) + elif score_with_truth > score_without_truth: + found_better_match = True + + assert found_better_match + + sim_no_thread = d_client.vset().vsim( + "myset", input="elem_man_2", with_scores=True, no_thread=True + ) + + assert len(sim_no_thread) == 10 + assert isinstance(sim_no_thread, dict) + + +@skip_if_server_version_lt("7.9.0") +def test_vdim(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9, 0.1, 0.2] + d_client.vset().vadd("myset", float_array, "elem1") + + dim = d_client.vset().vdim("myset") + assert dim == len(float_array) + + d_client.vset().vadd("myset_reduced", float_array, "elem1", reduce_dim=4) + reduced_dim = d_client.vset().vdim("myset_reduced") + assert reduced_dim == 4 + + with pytest.raises(redis.ResponseError): + d_client.vset().vdim("myset_unexisting") + + +@skip_if_server_version_lt("7.9.0") +def test_vcard(d_client): + n = 20 + for i in range(n): + float_array = [random.uniform(0, 10) for x in range(1, 8)] + d_client.vset().vadd("myset", float_array, f"elem{i}") + + card = d_client.vset().vcard("myset") + assert card == n + + with pytest.raises(redis.ResponseError): + d_client.vset().vdim("myset_unexisting") + + +@skip_if_server_version_lt("7.9.0") +def test_vrem(d_client): + n = 3 + for i in range(n): + float_array = [random.uniform(0, 10) for x in range(1, 8)] + d_client.vset().vadd("myset", float_array, f"elem{i}") + + resp = d_client.vset().vrem("myset", "elem2") + assert resp == 1 + + card = d_client.vset().vcard("myset") + assert card == n - 1 + + resp = d_client.vset().vrem("myset", "elem2") + assert resp == 0 + + card = d_client.vset().vcard("myset") + assert card == n - 1 + + resp = d_client.vset().vrem("myset_unexisting", "elem1") + assert resp == 0 + + +@skip_if_server_version_lt("7.9.0") +def test_vemb_bin_quantization(d_client): + e = [1, 4.32, 0.0, 0.05, -2.9] + d_client.vset().vadd( + "myset", + e, + "elem", + quantization=QuantizationOptions.BIN, + ) + emb_no_quant = d_client.vset().vemb("myset", "elem") + assert emb_no_quant == [1, 1, -1, 1, -1] + + emb_no_quant_raw = d_client.vset().vemb("myset", "elem", raw=True) + assert emb_no_quant_raw["quantization"] == "bin" + assert isinstance(emb_no_quant_raw["raw"], bytes) + assert isinstance(emb_no_quant_raw["l2"], float) + assert 
"range" not in emb_no_quant_raw + + +@skip_if_server_version_lt("7.9.0") +def test_vemb_q8_quantization(d_client): + e = [1, 10.32, 0.0, 2.05, -12.5] + d_client.vset().vadd("myset", e, "elem", quantization=QuantizationOptions.Q8) + + emb_q8_quant = d_client.vset().vemb("myset", "elem") + assert _validate_quantization(e, emb_q8_quant, tolerance=0.1) + + emb_q8_quant_raw = d_client.vset().vemb("myset", "elem", raw=True) + assert emb_q8_quant_raw["quantization"] == "int8" + assert isinstance(emb_q8_quant_raw["raw"], bytes) + assert isinstance(emb_q8_quant_raw["l2"], float) + assert isinstance(emb_q8_quant_raw["range"], float) + + +@skip_if_server_version_lt("7.9.0") +def test_vemb_no_quantization(d_client): + e = [1, 10.32, 0.0, 2.05, -12.5] + d_client.vset().vadd("myset", e, "elem", quantization=QuantizationOptions.NOQUANT) + + emb_no_quant = d_client.vset().vemb("myset", "elem") + assert _validate_quantization(e, emb_no_quant, tolerance=0.1) + + emb_no_quant_raw = d_client.vset().vemb("myset", "elem", raw=True) + assert emb_no_quant_raw["quantization"] == "f32" + assert isinstance(emb_no_quant_raw["raw"], bytes) + assert isinstance(emb_no_quant_raw["l2"], float) + assert "range" not in emb_no_quant_raw + + +@skip_if_server_version_lt("7.9.0") +def test_vemb_default_quantization(d_client): + e = [1, 5.32, 0.0, 0.25, -5] + d_client.vset().vadd("myset", vector=e, element="elem") + + emb_default_quant = d_client.vset().vemb("myset", "elem") + assert _validate_quantization(e, emb_default_quant, tolerance=0.1) + + emb_default_quant_raw = d_client.vset().vemb("myset", "elem", raw=True) + assert emb_default_quant_raw["quantization"] == "int8" + assert isinstance(emb_default_quant_raw["raw"], bytes) + assert isinstance(emb_default_quant_raw["l2"], float) + assert isinstance(emb_default_quant_raw["range"], float) + + +@skip_if_server_version_lt("7.9.0") +def test_vemb_fp32_quantization(d_client): + float_array_fp32 = [1, 4.32, 0.11] + # Convert the list of floats to a byte array in fp32 format + byte_array = _to_fp32_blob_array(float_array_fp32) + d_client.vset().vadd("myset", byte_array, "elem") + + emb_fp32_quant = d_client.vset().vemb("myset", "elem") + assert _validate_quantization(float_array_fp32, emb_fp32_quant, tolerance=0.1) + + emb_fp32_quant_raw = d_client.vset().vemb("myset", "elem", raw=True) + assert emb_fp32_quant_raw["quantization"] == "int8" + assert isinstance(emb_fp32_quant_raw["raw"], bytes) + assert isinstance(emb_fp32_quant_raw["l2"], float) + assert isinstance(emb_fp32_quant_raw["range"], float) + + +@skip_if_server_version_lt("7.9.0") +def test_vemb_unexisting(d_client): + emb_not_existing = d_client.vset().vemb("not_existing", "elem") + assert emb_not_existing is None + + e = [1, 5.32, 0.0, 0.25, -5] + d_client.vset().vadd("myset", vector=e, element="elem") + emb_elem_not_existing = d_client.vset().vemb("myset", "not_existing") + assert emb_elem_not_existing is None + + +@skip_if_server_version_lt("7.9.0") +def test_vlinks(d_client): + elements_count = 100 + vector_dim = 800 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=8, + ) + + element_links_all_layers = d_client.vset().vlinks("myset", "elem1") + assert len(element_links_all_layers) >= 1 + for neighbours_list_for_layer in element_links_all_layers: + assert isinstance(neighbours_list_for_layer, list) + for neighbour in neighbours_list_for_layer: + assert isinstance(neighbour, str) + + 
elem_links_all_layers_with_scores = d_client.vset().vlinks( + "myset", "elem1", with_scores=True + ) + assert len(elem_links_all_layers_with_scores) >= 1 + for neighbours_dict_for_layer in elem_links_all_layers_with_scores: + assert isinstance(neighbours_dict_for_layer, dict) + for neighbour_key, score_value in neighbours_dict_for_layer.items(): + assert isinstance(neighbour_key, str) + assert isinstance(score_value, float) + + float_array = [0.75, 0.25, 0.5, 0.1, 0.9] + d_client.vset().vadd("myset_one_elem_only", float_array, "elem1") + elem_no_neighbours_with_scores = d_client.vset().vlinks( + "myset_one_elem_only", "elem1", with_scores=True + ) + assert len(elem_no_neighbours_with_scores) >= 1 + for neighbours_dict_for_layer in elem_no_neighbours_with_scores: + assert isinstance(neighbours_dict_for_layer, dict) + assert len(neighbours_dict_for_layer) == 0 + + elem_no_neighbours_no_scores = d_client.vset().vlinks( + "myset_one_elem_only", "elem1" + ) + assert len(elem_no_neighbours_no_scores) >= 1 + for neighbours_list_for_layer in elem_no_neighbours_no_scores: + assert isinstance(neighbours_list_for_layer, list) + assert len(neighbours_list_for_layer) == 0 + + unexisting_element_links = d_client.vset().vlinks("myset", "unexisting_elem") + assert unexisting_element_links is None + + unexisting_vset_links = d_client.vset().vlinks("myset_unexisting", "elem1") + assert unexisting_vset_links is None + + unexisting_element_links = d_client.vset().vlinks( + "myset", "unexisting_elem", with_scores=True + ) + assert unexisting_element_links is None + + unexisting_vset_links = d_client.vset().vlinks( + "myset_unexisting", "elem1", with_scores=True + ) + assert unexisting_vset_links is None + + +@skip_if_server_version_lt("7.9.0") +def test_vinfo(d_client): + elements_count = 100 + vector_dim = 800 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=8, + quantization=QuantizationOptions.BIN, + ) + + vset_info = d_client.vset().vinfo("myset") + assert vset_info["quant-type"] == "bin" + assert vset_info["vector-dim"] == vector_dim + assert vset_info["size"] == elements_count + assert vset_info["max-level"] > 0 + assert vset_info["hnsw-max-node-uid"] == elements_count + + unexisting_vset_info = d_client.vset().vinfo("myset_unexisting") + assert unexisting_vset_info is None + + +@skip_if_server_version_lt("7.9.0") +def test_vset_vget_attributes(d_client): + float_array = [1, 4.32, 0.11] + attributes = {"key1": "value1", "key2": "value2"} + + # validate vgetattrs when no attributes are set with vadd + resp = d_client.vset().vadd("myset", float_array, "elem1") + assert resp == 1 + + attrs = d_client.vset().vgetattr("myset", "elem1") + assert attrs is None + + # validate vgetattrs when attributes are set with vadd + resp = d_client.vset().vadd( + "myset_with_attrs", float_array, "elem1", attributes=attributes + ) + assert resp == 1 + + attrs = d_client.vset().vgetattr("myset_with_attrs", "elem1") + assert attrs == attributes + + # Set attributes and get attributes + resp = d_client.vset().vsetattr("myset", "elem1", attributes) + assert resp == 1 + attr_saved = d_client.vset().vgetattr("myset", "elem1") + assert attr_saved == attributes + + # Set attributes to None + resp = d_client.vset().vsetattr("myset", "elem1", None) + assert resp == 1 + attr_saved = d_client.vset().vgetattr("myset", "elem1") + assert attr_saved is None + + # Set attributes to empty dict + resp = 
d_client.vset().vsetattr("myset", "elem1", {}) + assert resp == 1 + attr_saved = d_client.vset().vgetattr("myset", "elem1") + assert attr_saved is None + + # Set attributes provided as string + resp = d_client.vset().vsetattr("myset", "elem1", json.dumps(attributes)) + assert resp == 1 + attr_saved = d_client.vset().vgetattr("myset", "elem1") + assert attr_saved == attributes + + # Set attributes to unexisting element + resp = d_client.vset().vsetattr("myset", "elem2", attributes) + assert resp == 0 + attr_saved = d_client.vset().vgetattr("myset", "elem2") + assert attr_saved is None + + # Set attributes to unexisting vset + resp = d_client.vset().vsetattr("myset_unexisting", "elem1", attributes) + assert resp == 0 + attr_saved = d_client.vset().vgetattr("myset_unexisting", "elem1") + assert attr_saved is None + + +@skip_if_server_version_lt("7.9.0") +def test_vrandmember(d_client): + elements = ["elem1", "elem2", "elem3"] + for elem in elements: + float_array = [random.uniform(0, 10) for x in range(1, 8)] + d_client.vset().vadd("myset", float_array, element=elem) + + random_member = d_client.vset().vrandmember("myset") + assert random_member in elements + + members_list = d_client.vset().vrandmember("myset", count=2) + assert len(members_list) == 2 + assert all(member in elements for member in members_list) + + # Test with count greater than the number of elements + members_list = d_client.vset().vrandmember("myset", count=10) + assert len(members_list) == len(elements) + assert all(member in elements for member in members_list) + + # Test with negative count + members_list = d_client.vset().vrandmember("myset", count=-2) + assert len(members_list) == 2 + assert all(member in elements for member in members_list) + + # Test with count equal to the number of elements + members_list = d_client.vset().vrandmember("myset", count=len(elements)) + assert len(members_list) == len(elements) + assert all(member in elements for member in members_list) + + # Test with count equal to 0 + members_list = d_client.vset().vrandmember("myset", count=0) + assert members_list == [] + + # Test with count equal to 1 + members_list = d_client.vset().vrandmember("myset", count=1) + assert len(members_list) == 1 + assert members_list[0] in elements + + # Test with count equal to -1 + members_list = d_client.vset().vrandmember("myset", count=-1) + assert len(members_list) == 1 + assert members_list[0] in elements + + # Test with unexisting vset & without count + members_list = d_client.vset().vrandmember("myset_unexisting") + assert members_list is None + + # Test with unexisting vset & count + members_list = d_client.vset().vrandmember("myset_unexisting", count=5) + assert members_list == [] + + +@skip_if_server_version_lt("7.9.0") +def test_vset_commands_without_decoding_responces(client): + # test vadd + elements = ["elem1", "elem2", "elem3"] + for elem in elements: + float_array = [random.uniform(0, 10) for x in range(0, 8)] + resp = client.vset().vadd("myset", float_array, element=elem) + assert resp == 1 + + # test vemb + emb = client.vset().vemb("myset", "elem1") + assert len(emb) == 8 + assert isinstance(emb, list) + assert all(isinstance(x, float) for x in emb) + + emb_raw = client.vset().vemb("myset", "elem1", raw=True) + assert emb_raw["quantization"] == b"int8" + assert isinstance(emb_raw["raw"], bytes) + assert isinstance(emb_raw["l2"], float) + assert isinstance(emb_raw["range"], float) + + # test vsim + vsim = client.vset().vsim("myset", input="elem1") + assert len(vsim) == 3 + assert 
isinstance(vsim, list) + assert isinstance(vsim[0], bytes) + + # test vsim with scores + vsim_with_scores = client.vset().vsim("myset", input="elem1", with_scores=True) + assert len(vsim_with_scores) == 3 + assert isinstance(vsim_with_scores, dict) + assert isinstance(vsim_with_scores[b"elem1"], float) + + # test vlinks - no scores + element_links_all_layers = client.vset().vlinks("myset", "elem1") + assert len(element_links_all_layers) >= 1 + for neighbours_list_for_layer in element_links_all_layers: + assert isinstance(neighbours_list_for_layer, list) + for neighbour in neighbours_list_for_layer: + assert isinstance(neighbour, bytes) + # test vlinks with scores + elem_links_all_layers_with_scores = client.vset().vlinks( + "myset", "elem1", with_scores=True + ) + assert len(elem_links_all_layers_with_scores) >= 1 + for neighbours_dict_for_layer in elem_links_all_layers_with_scores: + assert isinstance(neighbours_dict_for_layer, dict) + for neighbour_key, score_value in neighbours_dict_for_layer.items(): + assert isinstance(neighbour_key, bytes) + assert isinstance(score_value, float) + + # test vinfo + vset_info = client.vset().vinfo("myset") + assert vset_info[b"quant-type"] == b"int8" + assert vset_info[b"vector-dim"] == 8 + assert vset_info[b"size"] == len(elements) + assert vset_info[b"max-level"] >= 0 + assert vset_info[b"hnsw-max-node-uid"] == len(elements) + + # test vgetattr + attributes = {"key1": "value1", "key2": "value2"} + client.vset().vsetattr("myset", "elem1", attributes) + attrs = client.vset().vgetattr("myset", "elem1") + assert attrs == attributes + + # test vrandmember + random_member = client.vset().vrandmember("myset") + assert isinstance(random_member, bytes) + assert random_member.decode("utf-8") in elements + + members_list = client.vset().vrandmember("myset", count=2) + assert len(members_list) == 2 + assert all(member.decode("utf-8") in elements for member in members_list) + + +def _to_fp32_blob_array(float_array): + """ + Convert a list of floats to a byte array in fp32 format. 
+ """ + # Convert the list of floats to a NumPy array with dtype np.float32 + arr = np.array(float_array, dtype=np.float32) + # Convert the NumPy array to a byte array + byte_array = arr.tobytes() + return byte_array + + +def _validate_quantization(original, quantized, tolerance=0.1): + original = np.array(original, dtype=np.float32) + quantized = np.array(quantized, dtype=np.float32) + + max_diff = np.max(np.abs(original - quantized)) + if max_diff > tolerance: + return False + else: + return True From 266c59ce8dfc9044a4e1b3d3f16d09effa80920b Mon Sep 17 00:00:00 2001 From: andy-stark-redis <164213578+andy-stark-redis@users.noreply.github.com> Date: Wed, 9 Apr 2025 14:12:54 +0100 Subject: [PATCH 084/113] DOC-5073 added examples for vector sets intro page (#3590) --- doctests/dt_vec_set.py | 210 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 210 insertions(+) create mode 100644 doctests/dt_vec_set.py diff --git a/doctests/dt_vec_set.py b/doctests/dt_vec_set.py new file mode 100644 index 0000000000..398171a04c --- /dev/null +++ b/doctests/dt_vec_set.py @@ -0,0 +1,210 @@ +# EXAMPLE: vecset_tutorial +# HIDE_START +""" +Code samples for Vector set doc pages: + https://redis.io/docs/latest/develop/data-types/vector-sets/ +""" + +import redis + +r = redis.Redis(decode_responses=True) +# HIDE_END + +# REMOVE_START +r.delete("points") +# REMOVE_END + +# STEP_START vadd +res1 = r.vset().vadd("points", [1.0, 1.0], "pt:A") +print(res1) # >>> 1 + +res2 = r.vset().vadd("points", [-1.0, -1.0], "pt:B") +print(res2) # >>> 1 + +res3 = r.vset().vadd("points", [-1.0, 1.0], "pt:C") +print(res3) # >>> 1 + +res4 = r.vset().vadd("points", [1.0, -1.0], "pt:D") +print(res4) # >>> 1 + +res5 = r.vset().vadd("points", [1.0, 0], "pt:E") +print(res5) # >>> 1 + +res6 = r.type("points") +print(res6) # >>> vectorset +# STEP_END +# REMOVE_START +assert res1 == 1 +assert res2 == 1 +assert res3 == 1 +assert res4 == 1 +assert res5 == 1 + +assert res6 == "vectorset" +# REMOVE_END + +# STEP_START vcardvdim +res7 = r.vset().vcard("points") +print(res7) # >>> 5 + +res8 = r.vset().vdim("points") +print(res8) # >>> 2 +# STEP_END +# REMOVE_START +assert res7 == 5 +assert res8 == 2 +# REMOVE_END + +# STEP_START vemb +res9 = r.vset().vemb("points", "pt:A") +print(res9) # >>> [0.9999999403953552, 0.9999999403953552] + +res10 = r.vset().vemb("points", "pt:B") +print(res10) # >>> [-0.9999999403953552, -0.9999999403953552] + +res11 = r.vset().vemb("points", "pt:C") +print(res11) # >>> [-0.9999999403953552, 0.9999999403953552] + +res12 = r.vset().vemb("points", "pt:D") +print(res12) # >>> [0.9999999403953552, -0.9999999403953552] + +res13 = r.vset().vemb("points", "pt:E") +print(res13) # >>> [1, 0] +# STEP_END +# REMOVE_START +assert 1 - res9[0] < 0.001 +assert 1 - res9[1] < 0.001 +assert 1 + res10[0] < 0.001 +assert 1 + res10[1] < 0.001 +assert 1 + res11[0] < 0.001 +assert 1 - res11[1] < 0.001 +assert 1 - res12[0] < 0.001 +assert 1 + res12[1] < 0.001 +assert res13 == [1, 0] +# REMOVE_END + +# STEP_START attr +res14 = r.vset().vsetattr("points", "pt:A", { + "name": "Point A", + "description": "First point added" +}) +print(res14) # >>> 1 + +res15 = r.vset().vgetattr("points", "pt:A") +print(res15) +# >>> {'name': 'Point A', 'description': 'First point added'} + +res16 = r.vset().vsetattr("points", "pt:A", "") +print(res16) # >>> 1 + +res17 = r.vset().vgetattr("points", "pt:A") +print(res17) # >>> None +# STEP_END +# REMOVE_START +assert res14 == 1 +assert res15 == {"name": "Point A", "description": "First point added"} +assert 
res16 == 1 +assert res17 is None +# REMOVE_END + +# STEP_START vrem +res18 = r.vset().vadd("points", [0, 0], "pt:F") +print(res18) # >>> 1 + +res19 = r.vset().vcard("points") +print(res19) # >>> 6 + +res20 = r.vset().vrem("points", "pt:F") +print(res20) # >>> 1 + +res21 = r.vset().vcard("points") +print(res21) # >>> 5 +# STEP_END +# REMOVE_START +assert res18 == 1 +assert res19 == 6 +assert res20 == 1 +assert res21 == 5 +# REMOVE_END + +# STEP_START vsim_basic +res22 = r.vset().vsim("points", [0.9, 0.1]) +print(res22) +# >>> ['pt:E', 'pt:A', 'pt:D', 'pt:C', 'pt:B'] +# STEP_END +# REMOVE_START +assert res22 == ["pt:E", "pt:A", "pt:D", "pt:C", "pt:B"] +# REMOVE_END + +# STEP_START vsim_options +res23 = r.vset().vsim( + "points", "pt:A", + with_scores=True, + count=4 +) +print(res23) +# >>> {'pt:A': 1.0, 'pt:E': 0.8535534143447876, 'pt:D': 0.5, 'pt:C': 0.5} +# STEP_END +# REMOVE_START +assert res23["pt:A"] == 1.0 +assert res23["pt:C"] == 0.5 +assert res23["pt:D"] == 0.5 +assert res23["pt:E"] - 0.85 < 0.005 +# REMOVE_END + +# STEP_START vsim_filter +res24 = r.vset().vsetattr("points", "pt:A", { + "size": "large", + "price": 18.99 +}) +print(res24) # >>> 1 + +res25 = r.vset().vsetattr("points", "pt:B", { + "size": "large", + "price": 35.99 +}) +print(res25) # >>> 1 + +res26 = r.vset().vsetattr("points", "pt:C", { + "size": "large", + "price": 25.99 +}) +print(res26) # >>> 1 + +res27 = r.vset().vsetattr("points", "pt:D", { + "size": "small", + "price": 21.00 +}) +print(res27) # >>> 1 + +res28 = r.vset().vsetattr("points", "pt:E", { + "size": "small", + "price": 17.75 +}) +print(res28) # >>> 1 + +# Return elements in order of distance from point A whose +# `size` attribute is `large`. +res29 = r.vset().vsim( + "points", "pt:A", + filter='.size == "large"' +) +print(res29) # >>> ['pt:A', 'pt:C', 'pt:B'] + +# Return elements in order of distance from point A whose size is +# `large` and whose price is greater than 20.00. 
+res30 = r.vset().vsim( + "points", "pt:A", + filter='.size == "large" && .price > 20.00' +) +print(res30) # >>> ['pt:C', 'pt:B'] +# STEP_END +# REMOVE_START +assert res24 == 1 +assert res25 == 1 +assert res26 == 1 +assert res27 == 1 +assert res28 == 1 + +assert res30 == ['pt:C', 'pt:B'] +# REMOVE_END From 9a6479734167a8a31e8e419e985f30e0894a3e93 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Wed, 9 Apr 2025 17:22:44 +0300 Subject: [PATCH 085/113] Fixing some sporadically failing tests - part 1 (#3589) --- docker-compose.yml | 4 ++-- redis/_parsers/resp3.py | 4 ++-- tests/test_asyncio/test_lock.py | 4 ++-- tests/test_asyncio/test_vsets.py | 27 +++++++++++++++++++-------- tests/test_vsets.py | 31 +++++++++++++++++++++---------- 5 files changed, 46 insertions(+), 24 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 75292bbd03..6b544553cb 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,9 +1,9 @@ --- x-client-libs-stack-image: &client-libs-stack-image - image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-8.0-M06-pre}" + image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-7.4.2}" x-client-libs-image: &client-libs-image - image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-8.0-M06-pre}" + image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-7.4.2}" services: diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 281546430b..ce4c59fb5b 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -19,7 +19,7 @@ def __init__(self, socket_read_size): def handle_pubsub_push_response(self, response): logger = getLogger("push_response") - logger.info("Push response: " + str(response)) + logger.debug("Push response: " + str(response)) return response def read_response(self, disable_decoding=False, push_request=False): @@ -150,7 +150,7 @@ def __init__(self, socket_read_size): async def handle_pubsub_push_response(self, response): logger = getLogger("push_response") - logger.info("Push response: " + str(response)) + logger.debug("Push response: " + str(response)) return response async def read_response( diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index be4270acdf..fff045a7f4 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -149,7 +149,7 @@ async def test_context_manager_not_raise_on_release_lock_error(self, r): async with self.get_lock( r, "foo", timeout=0.1, raise_on_release_error=False ) as lock: - lock.release() + await lock.release() except LockError: pytest.fail("LockError should not have been raised") @@ -157,7 +157,7 @@ async def test_context_manager_not_raise_on_release_lock_error(self, r): async with self.get_lock( r, "foo", timeout=0.1, raise_on_release_error=True ) as lock: - lock.release() + await lock.release() async def test_high_sleep_small_blocking_timeout(self, r): lock1 = self.get_lock(r, "foo") diff --git a/tests/test_asyncio/test_vsets.py b/tests/test_asyncio/test_vsets.py index 9abc899066..4ae336acf8 100644 --- a/tests/test_asyncio/test_vsets.py +++ b/tests/test_asyncio/test_vsets.py @@ -309,7 +309,7 @@ async def test_vsim_unexisting(d_client): @skip_if_server_version_lt("7.9.0") async def test_vsim_with_filter(d_client): - elements_count = 30 + elements_count = 50 vector_dim = 800 for i in range(elements_count): float_array = [random.uniform(0, 10) for x in range(vector_dim)] @@ -321,6 +321,15 @@ async def test_vsim_with_filter(d_client): numlinks=4, attributes=attributes, ) + float_array 
= [-random.uniform(10, 20) for x in range(vector_dim)] + attributes = {"index": elements_count, "elem_name": "elem_special"} + await d_client.vset().vadd( + "myset", + float_array, + "elem_special", + numlinks=4, + attributes=attributes, + ) sim = await d_client.vset().vsim("myset", input="elem_1", filter=".index > 10") assert len(sim) == 10 assert isinstance(sim, list) @@ -348,17 +357,19 @@ async def test_vsim_with_filter(d_client): sim = await d_client.vset().vsim( "myset", input="elem_1", - filter=".index > 28 and .elem_name in ['elem_12', 'elem_17', 'elem_29']", + filter=".index > 28 and .elem_name in ['elem_12', 'elem_17', 'elem_special']", filter_ef=1, ) - assert len(sim) == 0 + assert len(sim) == 0, ( + f"Expected 0 results, but got {len(sim)} with filter_ef=1, sim: {sim}" + ) assert isinstance(sim, list) sim = await d_client.vset().vsim( "myset", input="elem_1", - filter=".index > 28 and .elem_name in ['elem_12', 'elem_17', 'elem_29']", - filter_ef=20, + filter=".index > 28 and .elem_name in ['elem_12', 'elem_17', 'elem_special']", + filter_ef=500, ) assert len(sim) == 1 assert isinstance(sim, list) @@ -367,7 +378,7 @@ async def test_vsim_with_filter(d_client): @skip_if_server_version_lt("7.9.0") async def test_vsim_truth_no_thread_enabled(d_client): elements_count = 5000 - vector_dim = 30 + vector_dim = 50 for i in range(1, elements_count + 1): float_array = [random.uniform(10 * i, 1000 * i) for x in range(vector_dim)] await d_client.vset().vadd("myset", float_array, f"elem_{i}") @@ -394,7 +405,7 @@ async def test_vsim_truth_no_thread_enabled(d_client): ) found_better_match = False - for index, (score_with_truth, score_without_truth) in enumerate(results_scores): + for score_with_truth, score_without_truth in results_scores: if score_with_truth < score_without_truth: assert False, ( "Score with truth [{score_with_truth}] < score without truth [{score_without_truth}]" @@ -764,7 +775,7 @@ async def test_vset_commands_without_decoding_responces(client): # test vadd elements = ["elem1", "elem2", "elem3"] for elem in elements: - float_array = [random.uniform(0, 10) for x in range(0, 8)] + float_array = [random.uniform(0.5, 10) for x in range(0, 8)] resp = await client.vset().vadd("myset", float_array, element=elem) assert resp == 1 diff --git a/tests/test_vsets.py b/tests/test_vsets.py index ab4194657b..4a9d95bc1f 100644 --- a/tests/test_vsets.py +++ b/tests/test_vsets.py @@ -311,10 +311,10 @@ def test_vsim_unexisting(d_client): @skip_if_server_version_lt("7.9.0") def test_vsim_with_filter(d_client): - elements_count = 30 + elements_count = 50 vector_dim = 800 for i in range(elements_count): - float_array = [random.uniform(0, 10) for x in range(vector_dim)] + float_array = [random.uniform(10, 20) for x in range(vector_dim)] attributes = {"index": i, "elem_name": f"elem_{i}"} d_client.vset().vadd( "myset", @@ -323,6 +323,15 @@ def test_vsim_with_filter(d_client): numlinks=4, attributes=attributes, ) + float_array = [-random.uniform(10, 20) for x in range(vector_dim)] + attributes = {"index": elements_count, "elem_name": "elem_special"} + d_client.vset().vadd( + "myset", + float_array, + "elem_special", + numlinks=4, + attributes=attributes, + ) sim = d_client.vset().vsim("myset", input="elem_1", filter=".index > 10") assert len(sim) == 10 assert isinstance(sim, list) @@ -350,17 +359,19 @@ def test_vsim_with_filter(d_client): sim = d_client.vset().vsim( "myset", input="elem_1", - filter=".index > 28 and .elem_name in ['elem_12', 'elem_17', 'elem_29']", + filter=".index > 28 and 
.elem_name in ['elem_12', 'elem_17', 'elem_special']", filter_ef=1, ) - assert len(sim) == 0 + assert len(sim) == 0, ( + f"Expected 0 results, but got {len(sim)} with filter_ef=1, sim: {sim}" + ) assert isinstance(sim, list) sim = d_client.vset().vsim( "myset", input="elem_1", - filter=".index > 28 and .elem_name in ['elem_12', 'elem_17', 'elem_29']", - filter_ef=20, + filter=".index > 28 and .elem_name in ['elem_12', 'elem_17', 'elem_special']", + filter_ef=500, ) assert len(sim) == 1 assert isinstance(sim, list) @@ -369,7 +380,7 @@ def test_vsim_with_filter(d_client): @skip_if_server_version_lt("7.9.0") def test_vsim_truth_no_thread_enabled(d_client): elements_count = 5000 - vector_dim = 30 + vector_dim = 50 for i in range(1, elements_count + 1): float_array = [random.uniform(10 * i, 1000 * i) for x in range(vector_dim)] d_client.vset().vadd("myset", float_array, f"elem_{i}") @@ -396,7 +407,7 @@ def test_vsim_truth_no_thread_enabled(d_client): ) found_better_match = False - for index, (score_with_truth, score_without_truth) in enumerate(results_scores): + for score_with_truth, score_without_truth in results_scores: if score_with_truth < score_without_truth: assert False, ( "Score with truth [{score_with_truth}] < score without truth [{score_without_truth}]" @@ -764,7 +775,7 @@ def test_vset_commands_without_decoding_responces(client): # test vadd elements = ["elem1", "elem2", "elem3"] for elem in elements: - float_array = [random.uniform(0, 10) for x in range(0, 8)] + float_array = [random.uniform(0.5, 10) for x in range(0, 8)] resp = client.vset().vadd("myset", float_array, element=elem) assert resp == 1 @@ -772,7 +783,7 @@ def test_vset_commands_without_decoding_responces(client): emb = client.vset().vemb("myset", "elem1") assert len(emb) == 8 assert isinstance(emb, list) - assert all(isinstance(x, float) for x in emb) + assert all(isinstance(x, float) for x in emb), f"Expected float values, got {emb}" emb_raw = client.vset().vemb("myset", "elem1", raw=True) assert emb_raw["quantization"] == b"int8" From 19b9b720c181ee9efec635da250c767aa0d4c4b3 Mon Sep 17 00:00:00 2001 From: andy-stark-redis <164213578+andy-stark-redis@users.noreply.github.com> Date: Mon, 14 Apr 2025 08:58:17 +0100 Subject: [PATCH 086/113] DOC-5073 vector set quantization and dimension reduction examples (#3597) * DOC-5073 added example of quantization * DOC-5073 added example of reduce option for vadd() --- doctests/dt_vec_set.py | 78 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/doctests/dt_vec_set.py b/doctests/dt_vec_set.py index 398171a04c..687ad90593 100644 --- a/doctests/dt_vec_set.py +++ b/doctests/dt_vec_set.py @@ -7,11 +7,18 @@ import redis +from redis.commands.vectorset.commands import ( + QuantizationOptions +) + r = redis.Redis(decode_responses=True) # HIDE_END # REMOVE_START -r.delete("points") +r.delete( + "points", "quantSetQ8", "quantSetNoQ", + "quantSetBin", "setNotReduced", "setReduced" +) # REMOVE_END # STEP_START vadd @@ -208,3 +215,72 @@ assert res30 == ['pt:C', 'pt:B'] # REMOVE_END + +# STEP_START add_quant +# Import `QuantizationOptions` enum using: +# +# from redis.commands.vectorset.commands import ( +# QuantizationOptions +# ) +res31 = r.vset().vadd( + "quantSetQ8", [1.262185, 1.958231], + "quantElement", + quantization=QuantizationOptions.Q8 +) +print(res31) # >>> 1 + +res32 = r.vset().vemb("quantSetQ8", "quantElement") +print(f"Q8: {res32}") +# >>> Q8: [1.2643694877624512, 1.958230972290039] + +res33 = r.vset().vadd( + "quantSetNoQ", 
[1.262185, 1.958231], + "quantElement", + quantization=QuantizationOptions.NOQUANT +) +print(res33) # >>> 1 + +res34 = r.vset().vemb("quantSetNoQ", "quantElement") +print(f"NOQUANT: {res34}") +# >>> NOQUANT: [1.262184977531433, 1.958230972290039] + +res35 = r.vset().vadd( + "quantSetBin", [1.262185, 1.958231], + "quantElement", + quantization=QuantizationOptions.BIN +) +print(res35) # >>> 1 + +res36 = r.vset().vemb("quantSetBin", "quantElement") +print(f"BIN: {res36}") +# >>> BIN: [1, 1] +# STEP_END +# REMOVE_START +assert res31 == 1 +# REMOVE_END + +# STEP_START add_reduce +# Create a list of 300 arbitrary values. +values = [x / 299 for x in range(300)] + +res37 = r.vset().vadd( + "setNotReduced", + values, + "element" +) +print(res37) # >>> 1 + +res38 = r.vset().vdim("setNotReduced") +print(res38) # >>> 300 + +res39 = r.vset().vadd( + "setReduced", + values, + "element", + reduce_dim=100 +) +print(res39) # >>> 1 + +res40 = r.vset().vdim("setReduced") # >>> 100 +print(res40) +# STEP_END From 4f1774578cab6c13e055a4c7385b678ddb47ac07 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Thu, 17 Apr 2025 14:25:35 +0300 Subject: [PATCH 087/113] Removing 'charset' and 'errors' inputs from the Redis initialization arguments - deprecated 3 years ago. (#3608) --- docs/advanced_features.rst | 4 ++-- redis/client.py | 17 ----------------- redis/cluster.py | 2 -- 3 files changed, 2 insertions(+), 21 deletions(-) diff --git a/docs/advanced_features.rst b/docs/advanced_features.rst index 0ed3e1ff34..603e728e84 100644 --- a/docs/advanced_features.rst +++ b/docs/advanced_features.rst @@ -384,13 +384,13 @@ run_in_thread. A PubSub object adheres to the same encoding semantics as the client instance it was created from. Any channel or pattern that's unicode will -be encoded using the charset specified on the client before being sent +be encoded using the encoding specified on the client before being sent to Redis. If the client's decode_responses flag is set the False (the default), the 'channel', 'pattern' and 'data' values in message dictionaries will be byte strings (str on Python 2, bytes on Python 3). If the client's decode_responses is True, then the 'channel', 'pattern' and 'data' values will be automatically decoded to unicode strings using -the client's charset. +the client's encoding. PubSub objects remember what channels and patterns they are subscribed to. In the event of a disconnection such as a network error or timeout, diff --git a/redis/client.py b/redis/client.py index e9435d33ef..fda927507a 100755 --- a/redis/client.py +++ b/redis/client.py @@ -2,7 +2,6 @@ import re import threading import time -import warnings from itertools import chain from typing import ( TYPE_CHECKING, @@ -203,8 +202,6 @@ def __init__( unix_socket_path: Optional[str] = None, encoding: str = "utf-8", encoding_errors: str = "strict", - charset: Optional[str] = None, - errors: Optional[str] = None, decode_responses: bool = False, retry_on_timeout: bool = False, retry_on_error: Optional[List[Type[Exception]]] = None, @@ -256,20 +253,6 @@ def __init__( else: self._event_dispatcher = event_dispatcher if not connection_pool: - if charset is not None: - warnings.warn( - DeprecationWarning( - '"charset" is deprecated. Use "encoding" instead' - ) - ) - encoding = charset - if errors is not None: - warnings.warn( - DeprecationWarning( - '"errors" is deprecated. 
Use "encoding_errors" instead' - ) - ) - encoding_errors = errors if not retry_on_error: retry_on_error = [] if retry_on_timeout is True: diff --git a/redis/cluster.py b/redis/cluster.py index 4ec03ac98f..39b454babe 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -142,7 +142,6 @@ def parse_cluster_myshardid(resp, **options): SLOT_ID = "slot-id" REDIS_ALLOWED_KEYS = ( - "charset", "connection_class", "connection_pool", "connection_pool_class", @@ -152,7 +151,6 @@ def parse_cluster_myshardid(resp, **options): "decode_responses", "encoding", "encoding_errors", - "errors", "host", "lib_name", "lib_version", From 8bcccd43b405f13b02ce8a96e25e2ba42778cb99 Mon Sep 17 00:00:00 2001 From: andy-stark-redis <164213578+andy-stark-redis@users.noreply.github.com> Date: Wed, 23 Apr 2025 08:15:51 +0100 Subject: [PATCH 088/113] DOC-5107 added hash examples for index/query intro page (#3609) * DOC-5107 added hash examples for index/query intro page * DOC-5107 restored old index_definition import --- doctests/home_json.py | 55 +++++++++++++++++++++++++++++++++-- doctests/query_agg.py | 2 +- doctests/query_combined.py | 2 +- doctests/query_em.py | 2 +- doctests/query_ft.py | 2 +- doctests/query_geo.py | 2 +- doctests/query_range.py | 2 +- doctests/search_quickstart.py | 2 +- doctests/search_vss.py | 2 +- 9 files changed, 61 insertions(+), 10 deletions(-) diff --git a/doctests/home_json.py b/doctests/home_json.py index 922c83d2fe..794f844034 100644 --- a/doctests/home_json.py +++ b/doctests/home_json.py @@ -10,7 +10,7 @@ import redis.commands.search.aggregation as aggregations import redis.commands.search.reducers as reducers from redis.commands.search.field import TextField, NumericField, TagField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType from redis.commands.search.query import Query import redis.exceptions # STEP_END @@ -25,7 +25,12 @@ except redis.exceptions.ResponseError: pass -r.delete("user:1", "user:2", "user:3") +try: + r.ft("hash-idx:users").dropindex(True) +except redis.exceptions.ResponseError: + pass + +r.delete("user:1", "user:2", "user:3", "huser:1", "huser:2", "huser:3") # REMOVE_END # STEP_START create_data user1 = { @@ -134,4 +139,50 @@ ) # REMOVE_END +# STEP_START make_hash_index +hashSchema = ( + TextField("name"), + TagField("city"), + NumericField("age") +) + +hashIndexCreated = r.ft("hash-idx:users").create_index( + hashSchema, + definition=IndexDefinition( + prefix=["huser:"], index_type=IndexType.HASH + ) +) +# STEP_END +# REMOVE_START +assert hashIndexCreated +# REMOVE_END + +# STEP_START add_hash_data +huser1Set = r.hset("huser:1", mapping=user1) +huser2Set = r.hset("huser:2", mapping=user2) +huser3Set = r.hset("huser:3", mapping=user3) +# STEP_END +# REMOVE_START +assert huser1Set +assert huser2Set +assert huser3Set +# REMOVE_END + +# STEP_START query1_hash +findPaulHashResult = r.ft("hash-idx:users").search( + Query("Paul @age:[30 40]") +) + +print(findPaulHashResult) +# >>> Result{1 total, docs: [Document {'id': 'huser:3', +# >>> 'payload': None, 'name': 'Paul Zamir', ... 
+# STEP_END +# REMOVE_START +assert str(findPaulHashResult) == ( + "Result{1 total, docs: [Document " + + "{'id': 'huser:3', 'payload': None, 'name': 'Paul Zamir', " + + "'email': 'paul.zamir@example.com', 'age': '35', 'city': 'Tel Aviv'}]}" +) +# REMOVE_END + r.close() diff --git a/doctests/query_agg.py b/doctests/query_agg.py index 4fa8f14b84..4d81ddbcda 100644 --- a/doctests/query_agg.py +++ b/doctests/query_agg.py @@ -6,7 +6,7 @@ from redis.commands.search import Search from redis.commands.search.aggregation import AggregateRequest from redis.commands.search.field import NumericField, TagField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType import redis.commands.search.reducers as reducers r = redis.Redis(decode_responses=True) diff --git a/doctests/query_combined.py b/doctests/query_combined.py index a17f19417c..e6dd5a2cb5 100644 --- a/doctests/query_combined.py +++ b/doctests/query_combined.py @@ -6,7 +6,7 @@ import warnings from redis.commands.json.path import Path from redis.commands.search.field import NumericField, TagField, TextField, VectorField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType from redis.commands.search.query import Query from sentence_transformers import SentenceTransformer diff --git a/doctests/query_em.py b/doctests/query_em.py index a00ff11150..91cc5ae940 100644 --- a/doctests/query_em.py +++ b/doctests/query_em.py @@ -4,7 +4,7 @@ import redis from redis.commands.json.path import Path from redis.commands.search.field import TextField, NumericField, TagField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType from redis.commands.search.query import NumericFilter, Query r = redis.Redis(decode_responses=True) diff --git a/doctests/query_ft.py b/doctests/query_ft.py index 182a5b2bd3..6272cdab25 100644 --- a/doctests/query_ft.py +++ b/doctests/query_ft.py @@ -5,7 +5,7 @@ import redis from redis.commands.json.path import Path from redis.commands.search.field import TextField, NumericField, TagField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType from redis.commands.search.query import NumericFilter, Query r = redis.Redis(decode_responses=True) diff --git a/doctests/query_geo.py b/doctests/query_geo.py index dcb7db6ee7..ed8c9a5f99 100644 --- a/doctests/query_geo.py +++ b/doctests/query_geo.py @@ -5,7 +5,7 @@ import redis from redis.commands.json.path import Path from redis.commands.search.field import GeoField, GeoShapeField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType from redis.commands.search.query import Query r = redis.Redis(decode_responses=True) diff --git a/doctests/query_range.py b/doctests/query_range.py index 4ef957acfb..674afc492a 100644 --- a/doctests/query_range.py +++ b/doctests/query_range.py @@ -5,7 +5,7 @@ import redis from redis.commands.json.path import Path from redis.commands.search.field import TextField, NumericField, TagField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType from 
redis.commands.search.query import NumericFilter, Query r = redis.Redis(decode_responses=True) diff --git a/doctests/search_quickstart.py b/doctests/search_quickstart.py index e190393b16..cde4caa84a 100644 --- a/doctests/search_quickstart.py +++ b/doctests/search_quickstart.py @@ -10,7 +10,7 @@ import redis.commands.search.reducers as reducers from redis.commands.json.path import Path from redis.commands.search.field import NumericField, TagField, TextField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType from redis.commands.search.query import Query # HIDE_END diff --git a/doctests/search_vss.py b/doctests/search_vss.py index 8b4884727a..a1132971db 100644 --- a/doctests/search_vss.py +++ b/doctests/search_vss.py @@ -20,7 +20,7 @@ TextField, VectorField, ) -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.index_definition import IndexDefinition, IndexType from redis.commands.search.query import Query from sentence_transformers import SentenceTransformer From f5d5ff3444b7e484516fbdd525f0f5b5ab18f147 Mon Sep 17 00:00:00 2001 From: Elena Kolevska Date: Thu, 24 Apr 2025 14:48:37 +0100 Subject: [PATCH 089/113] Marks old RediSearch 1.0 commands as deprecated (#3606) * Marks old RediSearch 1.0 commands as deprecated * linters --- redis/commands/search/commands.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 42866f5ec1..bc48fa9aa8 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -23,7 +23,6 @@ SEARCH_CMD = "FT.SEARCH" ADD_CMD = "FT.ADD" ADDHASH_CMD = "FT.ADDHASH" -DROP_CMD = "FT.DROP" DROPINDEX_CMD = "FT.DROPINDEX" EXPLAIN_CMD = "FT.EXPLAIN" EXPLAINCLI_CMD = "FT.EXPLAINCLI" @@ -35,7 +34,6 @@ DICT_ADD_CMD = "FT.DICTADD" DICT_DEL_CMD = "FT.DICTDEL" DICT_DUMP_CMD = "FT.DICTDUMP" -GET_CMD = "FT.GET" MGET_CMD = "FT.MGET" CONFIG_CMD = "FT.CONFIG" TAGVALS_CMD = "FT.TAGVALS" @@ -406,6 +404,7 @@ def add_document_hash(self, doc_id, score=1.0, language=None, replace=False): doc_id, conn=None, score=score, language=language, replace=replace ) + @deprecated_function(version="2.0.0", reason="deprecated since redisearch 2.0") def delete_document(self, doc_id, conn=None, delete_actual_document=False): """ Delete a document from index @@ -440,6 +439,7 @@ def load_document(self, id): return Document(id=id, **fields) + @deprecated_function(version="2.0.0", reason="deprecated since redisearch 2.0") def get(self, *ids): """ Returns the full contents of multiple documents. 
From d02fbbb9ea2a0ae41f3a7fe731af9d01072ce813 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Mon, 28 Apr 2025 15:43:51 +0300 Subject: [PATCH 090/113] Fixing flaky tests - part 2 (#3592) --- tests/test_asyncio/test_search.py | 18 ++++++++++++------ tests/test_cache.py | 2 ++ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 9a318796bf..2ee74f710f 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1,14 +1,13 @@ import bz2 import csv import os -import time +import asyncio from io import TextIOWrapper import numpy as np import pytest import pytest_asyncio import redis.asyncio as redis -import redis.commands.search import redis.commands.search.aggregation as aggregations import redis.commands.search.reducers as reducers from redis.commands.search import AsyncSearch @@ -49,8 +48,8 @@ async def decoded_r(create_redis, stack_url): async def waitForIndex(env, idx, timeout=None): delay = 0.1 while True: - res = await env.execute_command("FT.INFO", idx) try: + res = await env.execute_command("FT.INFO", idx) if int(res[res.index("indexing") + 1]) == 0: break except ValueError: @@ -62,7 +61,7 @@ async def waitForIndex(env, idx, timeout=None): except ValueError: break - time.sleep(delay) + await asyncio.sleep(delay) if timeout is not None: timeout -= delay if timeout <= 0: @@ -1765,7 +1764,7 @@ async def test_binary_and_text_fields(decoded_r: redis.Redis): mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()} await decoded_r.hset(f"{index_name}:1", mapping=mixed_data) - schema = ( + schema = [ TagField("first_name"), VectorField( "embeddings_bio", @@ -1776,7 +1775,7 @@ async def test_binary_and_text_fields(decoded_r: redis.Redis): "DISTANCE_METRIC": "COSINE", }, ), - ) + ] await decoded_r.ft(index_name).create_index( fields=schema, @@ -1784,6 +1783,7 @@ async def test_binary_and_text_fields(decoded_r: redis.Redis): prefix=[f"{index_name}:"], index_type=IndexType.HASH ), ) + await waitForIndex(decoded_r, index_name) query = ( Query("*") @@ -1793,6 +1793,12 @@ async def test_binary_and_text_fields(decoded_r: redis.Redis): result = await decoded_r.ft(index_name).search(query=query, query_params={}) docs = result.docs + if len(docs) == 0: + hash_content = await decoded_r.hget(f"{index_name}:1", "first_name") + assert len(docs) > 0, ( + f"Returned search results are empty. Result: {result}; Hash: {hash_content}" + ) + decoded_vec_from_search_results = np.frombuffer( docs[0]["vector_emb"], dtype=np.float32 ) diff --git a/tests/test_cache.py b/tests/test_cache.py index 7010baff5f..a305d2de7b 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -2,6 +2,7 @@ import pytest import redis + from redis.cache import ( CacheConfig, CacheEntry, @@ -636,6 +637,7 @@ def test_get_from_default_cache(self, r, r2): ] # change key in redis (cause invalidation) r2.set("foo", "barbar") + time.sleep(0.1) # Retrieves a new value from server and cache_data it assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached From 6573d38636de8e3263c403be53b12853389bde3f Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Mon, 28 Apr 2025 18:22:55 +0300 Subject: [PATCH 091/113] Updating default retry strategy for standalone clients. 3 retries with ExponentialWithJitterBackoff become the default config. (#3614) * Changing the default retry configuration for Redis standalone clients. * Updating default retry strategy for standalone clients. 
3 retries with ExponentialWithJitterBackoff become the default config. * Applying review comments - removing unused methods from retry objects, updating pydocs of error handler method --- redis/asyncio/client.py | 141 ++++++++++---------- redis/client.py | 148 ++++++++++++--------- redis/connection.py | 2 +- tests/test_asyncio/test_connection_pool.py | 6 +- tests/test_connection_pool.py | 6 +- tests/test_retry.py | 4 +- 6 files changed, 165 insertions(+), 142 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 3f35fdd59e..ac907b0c10 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -39,6 +39,7 @@ ) from redis.asyncio.lock import Lock from redis.asyncio.retry import Retry +from redis.backoff import ExponentialWithJitterBackoff from redis.client import ( EMPTY_RESPONSE, NEVER_DECODE, @@ -65,7 +66,6 @@ PubSubError, RedisError, ResponseError, - TimeoutError, WatchError, ) from redis.typing import ChannelT, EncodableT, KeyT @@ -73,6 +73,7 @@ HIREDIS_AVAILABLE, SSL_AVAILABLE, _set_info_logger, + deprecated_args, deprecated_function, get_lib_version, safe_str, @@ -208,6 +209,11 @@ def from_pool( client.auto_close_connection_pool = True return client + @deprecated_args( + args_to_warn=["retry_on_timeout"], + reason="TimeoutError is included by default.", + version="6.0.0", + ) def __init__( self, *, @@ -225,6 +231,9 @@ def __init__( encoding_errors: str = "strict", decode_responses: bool = False, retry_on_timeout: bool = False, + retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ), retry_on_error: Optional[list] = None, ssl: bool = False, ssl_keyfile: Optional[str] = None, @@ -242,7 +251,6 @@ def __init__( lib_name: Optional[str] = "redis-py", lib_version: Optional[str] = get_lib_version(), username: Optional[str] = None, - retry: Optional[Retry] = None, auto_close_connection_pool: Optional[bool] = None, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, @@ -251,10 +259,24 @@ def __init__( ): """ Initialize a new Redis client. - To specify a retry policy for specific errors, first set - `retry_on_error` to a list of the error/s to retry on, then set - `retry` to a valid `Retry` object. - To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. + + To specify a retry policy for specific errors, you have two options: + + 1. Set the `retry_on_error` to a list of the error/s to retry on, and + you can also set `retry` to a valid `Retry` object(in case the default + one is not appropriate) - with this approach the retries will be triggered + on the default errors specified in the Retry object enriched with the + errors specified in `retry_on_error`. + + 2. Define a `Retry` object with configured 'supported_errors' and set + it to the `retry` parameter - with this approach you completely redefine + the errors on which retries will happen. + + `retry_on_timeout` is deprecated - please include the TimeoutError + either in the Retry object or in the `retry_on_error` list. + + When 'connection_pool' is provided - the retry configuration of the + provided pool will be used. 
""" kwargs: Dict[str, Any] if event_dispatcher is None: @@ -280,8 +302,6 @@ def __init__( # Create internal connection pool, expected to be closed by Redis instance if not retry_on_error: retry_on_error = [] - if retry_on_timeout is True: - retry_on_error.append(TimeoutError) kwargs = { "db": db, "username": username, @@ -291,7 +311,6 @@ def __init__( "encoding": encoding, "encoding_errors": encoding_errors, "decode_responses": decode_responses, - "retry_on_timeout": retry_on_timeout, "retry_on_error": retry_on_error, "retry": copy.deepcopy(retry), "max_connections": max_connections, @@ -403,10 +422,10 @@ def get_connection_kwargs(self): """Get the connection's key-word arguments""" return self.connection_pool.connection_kwargs - def get_retry(self) -> Optional["Retry"]: + def get_retry(self) -> Optional[Retry]: return self.get_connection_kwargs().get("retry") - def set_retry(self, retry: "Retry") -> None: + def set_retry(self, retry: Retry) -> None: self.get_connection_kwargs().update({"retry": retry}) self.connection_pool.set_retry(retry) @@ -633,18 +652,17 @@ async def _send_command_parse_response(self, conn, command_name, *args, **option await conn.send_command(*args) return await self.parse_response(conn, command_name, **options) - async def _disconnect_raise(self, conn: Connection, error: Exception): + async def _close_connection(self, conn: Connection): """ - Close the connection and raise an exception - if retry_on_error is not set or the error - is not one of the specified error types + Close the connection before retrying. + + The supported exceptions are already checked in the + retry object so we don't need to do it here. + + After we disconnect the connection, it will try to reconnect and + do a health check as part of the send_command logic(on connection level). """ await conn.disconnect() - if ( - conn.retry_on_error is None - or isinstance(error, tuple(conn.retry_on_error)) is False - ): - raise error # COMMAND EXECUTION AND PROTOCOL PARSING async def execute_command(self, *args, **options): @@ -661,7 +679,7 @@ async def execute_command(self, *args, **options): lambda: self._send_command_parse_response( conn, command_name, *args, **options ), - lambda error: self._disconnect_raise(conn, error), + lambda _: self._close_connection(conn), ) finally: if self.single_connection_client: @@ -929,19 +947,11 @@ async def connect(self): ) ) - async def _disconnect_raise_connect(self, conn, error): + async def _reconnect(self, conn): """ - Close the connection and raise an exception - if retry_on_error is not set or the error is not one - of the specified error types. Otherwise, try to - reconnect + Try to reconnect """ await conn.disconnect() - if ( - conn.retry_on_error is None - or isinstance(error, tuple(conn.retry_on_error)) is False - ): - raise error await conn.connect() async def _execute(self, conn, command, *args, **kwargs): @@ -954,7 +964,7 @@ async def _execute(self, conn, command, *args, **kwargs): """ return await conn.retry.call_with_retry( lambda: command(*args, **kwargs), - lambda error: self._disconnect_raise_connect(conn, error), + lambda _: self._reconnect(conn), ) async def parse_response(self, block: bool = True, timeout: float = 0): @@ -1245,7 +1255,8 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass] in one transmission. This is convenient for batch processing, such as saving all the values in a list to Redis. 
- All commands executed within a pipeline are wrapped with MULTI and EXEC + All commands executed within a pipeline(when running in transactional mode, + which is the default behavior) are wrapped with MULTI and EXEC calls. This guarantees all commands executed in the pipeline will be executed atomically. @@ -1274,7 +1285,7 @@ def __init__( self.shard_hint = shard_hint self.watching = False self.command_stack: CommandStackT = [] - self.scripts: Set["Script"] = set() + self.scripts: Set[Script] = set() self.explicit_transaction = False async def __aenter__(self: _RedisT) -> _RedisT: @@ -1346,36 +1357,36 @@ def execute_command( return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) - async def _disconnect_reset_raise(self, conn, error): + async def _disconnect_reset_raise_on_watching( + self, + conn: Connection, + error: Exception, + ): """ - Close the connection, reset watching state and - raise an exception if we were watching, - if retry_on_error is not set or the error is not one - of the specified error types. + Close the connection reset watching state and + raise an exception if we were watching. + + The supported exceptions are already checked in the + retry object so we don't need to do it here. + + After we disconnect the connection, it will try to reconnect and + do a health check as part of the send_command logic(on connection level). """ await conn.disconnect() # if we were already watching a variable, the watch is no longer # valid since this connection has died. raise a WatchError, which # indicates the user should retry this transaction. if self.watching: - await self.aclose() + await self.reset() raise WatchError( - "A ConnectionError occurred on while watching one or more keys" + f"A {type(error).__name__} occurred while watching one or more keys" ) - # if retry_on_error is not set or the error is not one - # of the specified error types, raise it - if ( - conn.retry_on_error is None - or isinstance(error, tuple(conn.retry_on_error)) is False - ): - await self.aclose() - raise async def immediate_execute_command(self, *args, **options): """ - Execute a command immediately, but don't auto-retry on a - ConnectionError if we're already WATCHing a variable. Used when - issuing WATCH or subsequent commands retrieving their values but before + Execute a command immediately, but don't auto-retry on the supported + errors for retry if we're already WATCHing a variable. + Used when issuing WATCH or subsequent commands retrieving their values but before MULTI is called. """ command_name = args[0] @@ -1389,7 +1400,7 @@ async def immediate_execute_command(self, *args, **options): lambda: self._send_command_parse_response( conn, command_name, *args, **options ), - lambda error: self._disconnect_reset_raise(conn, error), + lambda error: self._disconnect_reset_raise_on_watching(conn, error), ) def pipeline_execute_command(self, *args, **options): @@ -1544,11 +1555,15 @@ async def load_scripts(self): if not exist: s.sha = await immediate("SCRIPT LOAD", s.script) - async def _disconnect_raise_reset(self, conn: Connection, error: Exception): + async def _disconnect_raise_on_watching(self, conn: Connection, error: Exception): """ - Close the connection, raise an exception if we were watching, - and raise an exception if retry_on_error is not set or the - error is not one of the specified error types. + Close the connection, raise an exception if we were watching. 
+ + The supported exceptions are already checked in the + retry object so we don't need to do it here. + + After we disconnect the connection, it will try to reconnect and + do a health check as part of the send_command logic(on connection level). """ await conn.disconnect() # if we were watching a variable, the watch is no longer valid @@ -1556,16 +1571,8 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception): # indicates the user should retry this transaction. if self.watching: raise WatchError( - "A ConnectionError occurred on while watching one or more keys" + f"A {type(error).__name__} occurred while watching one or more keys" ) - # if retry_on_error is not set or the error is not one - # of the specified error types, raise it - if ( - conn.retry_on_error is None - or isinstance(error, tuple(conn.retry_on_error)) is False - ): - await self.reset() - raise async def execute(self, raise_on_error: bool = True) -> List[Any]: """Execute all the commands in the current pipeline""" @@ -1590,7 +1597,7 @@ async def execute(self, raise_on_error: bool = True) -> List[Any]: try: return await conn.retry.call_with_retry( lambda: execute(conn, stack, raise_on_error), - lambda error: self._disconnect_raise_reset(conn, error), + lambda error: self._disconnect_raise_on_watching(conn, error), ) finally: await self.reset() diff --git a/redis/client.py b/redis/client.py index fda927507a..138f561974 100755 --- a/redis/client.py +++ b/redis/client.py @@ -11,6 +11,7 @@ List, Mapping, Optional, + Set, Type, Union, ) @@ -22,6 +23,7 @@ _RedisCallbacksRESP3, bool_ok, ) +from redis.backoff import ExponentialWithJitterBackoff from redis.cache import CacheConfig, CacheInterface from redis.commands import ( CoreCommands, @@ -29,6 +31,7 @@ SentinelCommands, list_or_args, ) +from redis.commands.core import Script from redis.connection import ( AbstractConnection, ConnectionPool, @@ -49,7 +52,6 @@ PubSubError, RedisError, ResponseError, - TimeoutError, WatchError, ) from redis.lock import Lock @@ -57,6 +59,7 @@ from redis.utils import ( HIREDIS_AVAILABLE, _set_info_logger, + deprecated_args, get_lib_version, safe_str, str_if_bytes, @@ -188,6 +191,11 @@ def from_pool( client.auto_close_connection_pool = True return client + @deprecated_args( + args_to_warn=["retry_on_timeout"], + reason="TimeoutError is included by default.", + version="6.0.0", + ) def __init__( self, host: str = "localhost", @@ -204,6 +212,9 @@ def __init__( encoding_errors: str = "strict", decode_responses: bool = False, retry_on_timeout: bool = False, + retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ), retry_on_error: Optional[List[Type[Exception]]] = None, ssl: bool = False, ssl_keyfile: Optional[str] = None, @@ -227,7 +238,6 @@ def __init__( lib_name: Optional[str] = "redis-py", lib_version: Optional[str] = get_lib_version(), username: Optional[str] = None, - retry: Optional[Retry] = None, redis_connect_func: Optional[Callable[[], None]] = None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, @@ -237,10 +247,24 @@ def __init__( ) -> None: """ Initialize a new Redis client. - To specify a retry policy for specific errors, first set - `retry_on_error` to a list of the error/s to retry on, then set - `retry` to a valid `Retry` object. - To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. + + To specify a retry policy for specific errors, you have two options: + + 1. 
Set the `retry_on_error` to a list of the error/s to retry on, and + you can also set `retry` to a valid `Retry` object(in case the default + one is not appropriate) - with this approach the retries will be triggered + on the default errors specified in the Retry object enriched with the + errors specified in `retry_on_error`. + + 2. Define a `Retry` object with configured 'supported_errors' and set + it to the `retry` parameter - with this approach you completely redefine + the errors on which retries will happen. + + `retry_on_timeout` is deprecated - please include the TimeoutError + either in the Retry object or in the `retry_on_error` list. + + When 'connection_pool' is provided - the retry configuration of the + provided pool will be used. Args: @@ -255,8 +279,6 @@ def __init__( if not connection_pool: if not retry_on_error: retry_on_error = [] - if retry_on_timeout is True: - retry_on_error.append(TimeoutError) kwargs = { "db": db, "username": username, @@ -378,10 +400,10 @@ def get_connection_kwargs(self) -> Dict: """Get the connection's key-word arguments""" return self.connection_pool.connection_kwargs - def get_retry(self) -> Optional["Retry"]: + def get_retry(self) -> Optional[Retry]: return self.get_connection_kwargs().get("retry") - def set_retry(self, retry: "Retry") -> None: + def set_retry(self, retry: Retry) -> None: self.get_connection_kwargs().update({"retry": retry}) self.connection_pool.set_retry(retry) @@ -581,18 +603,18 @@ def _send_command_parse_response(self, conn, command_name, *args, **options): conn.send_command(*args, **options) return self.parse_response(conn, command_name, **options) - def _disconnect_raise(self, conn, error): + def _close_connection(self, conn) -> None: """ - Close the connection and raise an exception - if retry_on_error is not set or the error - is not one of the specified error types + Close the connection before retrying. + + The supported exceptions are already checked in the + retry object so we don't need to do it here. + + After we disconnect the connection, it will try to reconnect and + do a health check as part of the send_command logic(on connection level). """ + conn.disconnect() - if ( - conn.retry_on_error is None - or isinstance(error, tuple(conn.retry_on_error)) is False - ): - raise error # COMMAND EXECUTION AND PROTOCOL PARSING def execute_command(self, *args, **options): @@ -611,7 +633,7 @@ def _execute_command(self, *args, **options): lambda: self._send_command_parse_response( conn, command_name, *args, **options ), - lambda error: self._disconnect_raise(conn, error), + lambda _: self._close_connection(conn), ) finally: if self._single_connection_client: @@ -870,19 +892,14 @@ def clean_health_check_responses(self) -> None: ) ttl -= 1 - def _disconnect_raise_connect(self, conn, error) -> None: + def _reconnect(self, conn) -> None: """ - Close the connection and raise an exception - if retry_on_error is not set or the error is not one - of the specified error types. Otherwise, try to - reconnect + The supported exceptions are already checked in the + retry object so we don't need to do it here. + + In this error handler we are trying to reconnect to the server. 
""" conn.disconnect() - if ( - conn.retry_on_error is None - or isinstance(error, tuple(conn.retry_on_error)) is False - ): - raise error conn.connect() def _execute(self, conn, command, *args, **kwargs): @@ -895,7 +912,7 @@ def _execute(self, conn, command, *args, **kwargs): """ return conn.retry.call_with_retry( lambda: command(*args, **kwargs), - lambda error: self._disconnect_raise_connect(conn, error), + lambda _: self._reconnect(conn), ) def parse_response(self, block=True, timeout=0): @@ -1264,7 +1281,8 @@ class Pipeline(Redis): in one transmission. This is convenient for batch processing, such as saving all the values in a list to Redis. - All commands executed within a pipeline are wrapped with MULTI and EXEC + All commands executed within a pipeline(when running in transactional mode, + which is the default behavior) are wrapped with MULTI and EXEC calls. This guarantees all commands executed in the pipeline will be executed atomically. @@ -1285,9 +1303,10 @@ def __init__(self, connection_pool, response_callbacks, transaction, shard_hint) self.response_callbacks = response_callbacks self.transaction = transaction self.shard_hint = shard_hint - self.watching = False - self.reset() + self.command_stack = [] + self.scripts: Set[Script] = set() + self.explicit_transaction = False def __enter__(self) -> "Pipeline": return self @@ -1353,36 +1372,37 @@ def execute_command(self, *args, **kwargs): return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) - def _disconnect_reset_raise(self, conn, error) -> None: + def _disconnect_reset_raise_on_watching( + self, + conn: AbstractConnection, + error: Exception, + ) -> None: """ - Close the connection, reset watching state and - raise an exception if we were watching, - if retry_on_error is not set or the error is not one - of the specified error types. + Close the connection reset watching state and + raise an exception if we were watching. + + The supported exceptions are already checked in the + retry object so we don't need to do it here. + + After we disconnect the connection, it will try to reconnect and + do a health check as part of the send_command logic(on connection level). """ conn.disconnect() + # if we were already watching a variable, the watch is no longer # valid since this connection has died. raise a WatchError, which # indicates the user should retry this transaction. if self.watching: self.reset() raise WatchError( - "A ConnectionError occurred on while watching one or more keys" + f"A {type(error).__name__} occurred while watching one or more keys" ) - # if retry_on_error is not set or the error is not one - # of the specified error types, raise it - if ( - conn.retry_on_error is None - or isinstance(error, tuple(conn.retry_on_error)) is False - ): - self.reset() - raise def immediate_execute_command(self, *args, **options): """ - Execute a command immediately, but don't auto-retry on a - ConnectionError if we're already WATCHing a variable. Used when - issuing WATCH or subsequent commands retrieving their values but before + Execute a command immediately, but don't auto-retry on the supported + errors for retry if we're already WATCHing a variable. + Used when issuing WATCH or subsequent commands retrieving their values but before MULTI is called. 
""" command_name = args[0] @@ -1396,7 +1416,7 @@ def immediate_execute_command(self, *args, **options): lambda: self._send_command_parse_response( conn, command_name, *args, **options ), - lambda error: self._disconnect_reset_raise(conn, error), + lambda error: self._disconnect_reset_raise_on_watching(conn, error), ) def pipeline_execute_command(self, *args, **options) -> "Pipeline": @@ -1534,15 +1554,19 @@ def load_scripts(self): if not exist: s.sha = immediate("SCRIPT LOAD", s.script) - def _disconnect_raise_reset( + def _disconnect_raise_on_watching( self, conn: AbstractConnection, error: Exception, ) -> None: """ - Close the connection, raise an exception if we were watching, - and raise an exception if retry_on_error is not set or the - error is not one of the specified error types. + Close the connection, raise an exception if we were watching. + + The supported exceptions are already checked in the + retry object so we don't need to do it here. + + After we disconnect the connection, it will try to reconnect and + do a health check as part of the send_command logic(on connection level). """ conn.disconnect() # if we were watching a variable, the watch is no longer valid @@ -1550,16 +1574,8 @@ def _disconnect_raise_reset( # indicates the user should retry this transaction. if self.watching: raise WatchError( - "A ConnectionError occurred on while watching one or more keys" + f"A {type(error).__name__} occurred while watching one or more keys" ) - # if retry_on_error is not set or the error is not one - # of the specified error types, raise it - if ( - conn.retry_on_error is None - or isinstance(error, tuple(conn.retry_on_error)) is False - ): - self.reset() - raise error def execute(self, raise_on_error: bool = True) -> List[Any]: """Execute all the commands in the current pipeline""" @@ -1583,7 +1599,7 @@ def execute(self, raise_on_error: bool = True) -> List[Any]: try: return conn.retry.call_with_retry( lambda: execute(conn, stack, raise_on_error), - lambda error: self._disconnect_raise_reset(conn, error), + lambda error: self._disconnect_raise_on_watching(conn, error), ) finally: self.reset() diff --git a/redis/connection.py b/redis/connection.py index 08e980e866..ffb1e37ba3 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1611,7 +1611,7 @@ def close(self) -> None: """Close the pool, disconnecting all connections""" self.disconnect() - def set_retry(self, retry: "Retry") -> None: + def set_retry(self, retry: Retry) -> None: self.connection_kwargs.update({"retry": retry}) for conn in self._available_connections: conn.retry = retry diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index 3d120e4ca7..09409e04a8 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -614,9 +614,9 @@ async def test_busy_loading_from_pipeline_immediate_command(self, r): "DEBUG", "ERROR", "LOADING fake message" ) pool = r.connection_pool - assert not pipe.connection - assert len(pool._available_connections) == 1 - assert not pool._available_connections[0]._reader + assert pipe.connection + assert pipe.connection in pool._in_use_connections + assert not pipe.connection._reader @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.8") diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 0ec77a4fff..d97c9063ac 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -563,9 +563,9 @@ def 
test_busy_loading_from_pipeline_immediate_command(self, r): with pytest.raises(redis.BusyLoadingError): pipe.immediate_execute_command("DEBUG", "ERROR", "LOADING fake message") pool = r.connection_pool - assert not pipe.connection - assert len(pool._available_connections) == 1 - assert not pool._available_connections[0]._sock + assert pipe.connection + assert pipe.connection in pool._in_use_connections + assert not pipe.connection._sock @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.8") diff --git a/tests/test_retry.py b/tests/test_retry.py index e1e4c414a4..cb001fbbd5 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -159,7 +159,7 @@ def test_client_retry_on_error_raise(self, request): def test_client_retry_on_error_different_error_raised(self, request): with patch.object(Redis, "parse_response") as parse_response: - parse_response.side_effect = TimeoutError() + parse_response.side_effect = OSError() retries = 3 r = _get_client( Redis, @@ -167,7 +167,7 @@ def test_client_retry_on_error_different_error_raised(self, request): retry_on_error=[ReadOnlyError], retry=Retry(NoBackoff(), retries), ) - with pytest.raises(TimeoutError): + with pytest.raises(OSError): try: r.get("foo") finally: From 8dadea2504f83c866c98035172bb89c908904c38 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Tue, 29 Apr 2025 14:40:17 +0300 Subject: [PATCH 092/113] Handling some special values when transforming responce data into list(issue #3573) (#3586) --- redis/commands/helpers.py | 25 +++++++++++++++++++------ tests/test_asyncio/test_bloom.py | 28 ++++++++++++++++++++++++++++ tests/test_bloom.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 6 deletions(-) diff --git a/redis/commands/helpers.py b/redis/commands/helpers.py index f6121b6c3b..859a43aea9 100644 --- a/redis/commands/helpers.py +++ b/redis/commands/helpers.py @@ -43,19 +43,32 @@ def parse_to_list(response): """Optimistically parse the response to a list.""" res = [] + special_values = {"infinity", "nan", "-infinity"} + if response is None: return res for item in response: + if item is None: + res.append(None) + continue try: - res.append(int(item)) - except ValueError: - try: - res.append(float(item)) - except ValueError: - res.append(nativestr(item)) + item_str = nativestr(item) except TypeError: res.append(None) + continue + + if isinstance(item_str, str) and item_str.lower() in special_values: + res.append(item_str) # Keep as string + else: + try: + res.append(int(item)) + except ValueError: + try: + res.append(float(item)) + except ValueError: + res.append(item_str) + return res diff --git a/tests/test_asyncio/test_bloom.py b/tests/test_asyncio/test_bloom.py index 031e8364d7..d67858570f 100644 --- a/tests/test_asyncio/test_bloom.py +++ b/tests/test_asyncio/test_bloom.py @@ -333,6 +333,34 @@ async def test_topk(decoded_r: redis.Redis): assert 0.9 == round(float(info["decay"]), 1) +@pytest.mark.redismod +async def test_topk_list_with_special_words(decoded_r: redis.Redis): + # test list with empty buckets + assert await decoded_r.topk().reserve("topklist:specialwords", 5, 20, 4, 0.9) + assert await decoded_r.topk().add( + "topklist:specialwords", + "infinity", + "B", + "nan", + "D", + "-infinity", + "infinity", + "infinity", + "B", + "nan", + "G", + "D", + "B", + "D", + "infinity", + "-infinity", + "-infinity", + ) + assert ["infinity", "B", "D", "-infinity", "nan"] == await decoded_r.topk().list( + "topklist:specialwords" + ) + + @pytest.mark.redismod async def test_topk_incrby(decoded_r: redis.Redis): 
await decoded_r.flushdb() diff --git a/tests/test_bloom.py b/tests/test_bloom.py index e44c421634..a8d6390048 100644 --- a/tests/test_bloom.py +++ b/tests/test_bloom.py @@ -364,6 +364,34 @@ def test_topk(client): assert 0.9 == round(float(info["decay"]), 1) +@pytest.mark.redismod +def test_topk_list_with_special_words(client): + # test list with empty buckets + assert client.topk().reserve("topklist:specialwords", 5, 20, 4, 0.9) + assert client.topk().add( + "topklist:specialwords", + "infinity", + "B", + "nan", + "D", + "-infinity", + "infinity", + "infinity", + "B", + "nan", + "G", + "D", + "B", + "D", + "infinity", + "-infinity", + "-infinity", + ) + assert ["infinity", "B", "D", "-infinity", "nan"] == client.topk().list( + "topklist:specialwords" + ) + + @pytest.mark.redismod def test_topk_incrby(client): client.flushdb() From fb547af9c7b29749578fb3bb716fd369f543afb9 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Tue, 29 Apr 2025 16:48:48 +0300 Subject: [PATCH 093/113] When SlotNotCoveredError is raised, the cluster topology should be reinitialized as part of error handling and retrying of the commands. (#3621) --- redis/asyncio/cluster.py | 8 +++++++- redis/cluster.py | 17 ++++++++++++++--- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index f58ae50a40..9749ba7b6f 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -818,10 +818,16 @@ async def _execute_command( # and try again with the new setup await self.aclose() raise - except ClusterDownError: + except (ClusterDownError, SlotNotCoveredError): # ClusterDownError can occur during a failover and to get # self-healed, we will try to reinitialize the cluster layout # and retry executing the command + + # SlotNotCoveredError can occur when the cluster is not fully + # initialized or can be temporary issue. + # We will try to reinitialize the cluster topology + # and retry executing the command + await self.aclose() await asyncio.sleep(0.25) raise diff --git a/redis/cluster.py b/redis/cluster.py index 39b454babe..fc5ffab892 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -410,7 +410,12 @@ class AbstractRedisCluster: list_keys_to_dict(["SCRIPT FLUSH"], lambda command, res: all(res.values())), ) - ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, ClusterDownError) + ERRORS_ALLOW_RETRY = ( + ConnectionError, + TimeoutError, + ClusterDownError, + SlotNotCoveredError, + ) def replace_default_node(self, target_node: "ClusterNode" = None) -> None: """Replace the default cluster node. @@ -1239,13 +1244,19 @@ def _execute_command(self, target_node, *args, **kwargs): except AskError as e: redirect_addr = get_node_name(host=e.host, port=e.port) asking = True - except ClusterDownError as e: + except (ClusterDownError, SlotNotCoveredError): # ClusterDownError can occur during a failover and to get # self-healed, we will try to reinitialize the cluster layout # and retry executing the command + + # SlotNotCoveredError can occur when the cluster is not fully + # initialized or can be temporary issue. 
+ # We will try to reinitialize the cluster topology + # and retry executing the command + time.sleep(0.25) self.nodes_manager.initialize() - raise e + raise except ResponseError: raise except Exception as e: From c918139d2a1695dd749c98a6659c2ec3094b1634 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Wed, 30 Apr 2025 13:02:33 +0300 Subject: [PATCH 094/113] Fixing the versions of some deprecations that wrongly added as 5.0.3 - the correct version is 5.3.0 (#3625) --- redis/asyncio/cluster.py | 2 +- redis/asyncio/connection.py | 4 ++-- redis/cluster.py | 6 +++--- redis/connection.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 9749ba7b6f..e6fd52d790 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -237,7 +237,7 @@ def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster": @deprecated_args( args_to_warn=["read_from_replicas"], reason="Please configure the 'load_balancing_strategy' instead", - version="5.0.3", + version="5.3.0", ) def __init__( self, diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 7404f3d6f8..70d7d91898 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1133,7 +1133,7 @@ def can_get_connection(self) -> bool: @deprecated_args( args_to_warn=["*"], reason="Use get_connection() without args instead", - version="5.0.3", + version="5.3.0", ) async def get_connection(self, command_name=None, *keys, **options): async with self._lock: @@ -1306,7 +1306,7 @@ def __init__( @deprecated_args( args_to_warn=["*"], reason="Use get_connection() without args instead", - version="5.0.3", + version="5.3.0", ) async def get_connection(self, command_name=None, *keys, **options): """Gets a connection from the pool, blocking until one is available""" diff --git a/redis/cluster.py b/redis/cluster.py index fc5ffab892..010a3d94e6 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -58,7 +58,7 @@ def get_node_name(host: str, port: Union[str, int]) -> str: @deprecated_args( allowed_args=["redis_node"], reason="Use get_connection(redis_node) instead", - version="5.0.3", + version="5.3.0", ) def get_connection(redis_node, *args, **options): return redis_node.connection or redis_node.connection_pool.get_connection() @@ -490,7 +490,7 @@ class initializer. 
In the case of conflicting arguments, querystring @deprecated_args( args_to_warn=["read_from_replicas"], reason="Please configure the 'load_balancing_strategy' instead", - version="5.0.3", + version="5.3.0", ) def __init__( self, @@ -1493,7 +1493,7 @@ def _update_moved_slots(self): "In case you need select some load balancing strategy " "that will use replicas, please set it through 'load_balancing_strategy'" ), - version="5.0.3", + version="5.3.0", ) def get_node_from_slot( self, diff --git a/redis/connection.py b/redis/connection.py index ffb1e37ba3..ddc6991c5c 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1502,7 +1502,7 @@ def _checkpid(self) -> None: @deprecated_args( args_to_warn=["*"], reason="Use get_connection() without args instead", - version="5.0.3", + version="5.3.0", ) def get_connection(self, command_name=None, *keys, **options) -> "Connection": "Get a connection from the pool" @@ -1730,7 +1730,7 @@ def make_connection(self): @deprecated_args( args_to_warn=["*"], reason="Use get_connection() without args instead", - version="5.0.3", + version="5.3.0", ) def get_connection(self, command_name=None, *keys, **options): """ From cf5c755ee62fba9e4fba15f7ebfb00174a7e22f1 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Wed, 30 Apr 2025 14:38:07 +0300 Subject: [PATCH 095/113] Updating pipeline tests to use test libs image with RC2. Updating timeseries tests. (#3623) --- .github/workflows/integration.yaml | 2 +- docker-compose.yml | 2 +- tests/test_asyncio/test_timeseries.py | 39 +++++++++++++++++++++++++++ tests/test_timeseries.py | 39 ++++++++++++++++++++++++++- 4 files changed, 79 insertions(+), 3 deletions(-) diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index f8aa5c8932..fbfcecdf68 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -74,7 +74,7 @@ jobs: max-parallel: 15 fail-fast: false matrix: - redis-version: ['8.0-RC1-pre', '${{ needs.redis_version.outputs.CURRENT }}', '7.2.7', '6.2.17'] + redis-version: ['8.0-RC2-pre', '${{ needs.redis_version.outputs.CURRENT }}', '7.2.7', '6.2.17'] python-version: ['3.8', '3.13'] parser-backend: ['plain'] event-loop: ['asyncio'] diff --git a/docker-compose.yml b/docker-compose.yml index 6b544553cb..76a60398f3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,6 +1,6 @@ --- x-client-libs-stack-image: &client-libs-stack-image - image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-7.4.2}" + image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-rs-7.4.0-v2}" x-client-libs-image: &client-libs-image image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-7.4.2}" diff --git a/tests/test_asyncio/test_timeseries.py b/tests/test_asyncio/test_timeseries.py index b21f7d0ac8..24c36c9ca2 100644 --- a/tests/test_asyncio/test_timeseries.py +++ b/tests/test_asyncio/test_timeseries.py @@ -7,6 +7,8 @@ from tests.conftest import ( assert_resp_response, is_resp2_connection, + skip_if_server_version_gte, + skip_if_server_version_lt, skip_ifmodversion_lt, ) @@ -75,7 +77,24 @@ async def test_alter(decoded_r: redis.Redis): @pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") +@skip_if_server_version_lt("7.9.0") async def test_alter_duplicate_policy(decoded_r: redis.Redis): + assert await decoded_r.ts().create(1) + info = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, "block", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) + assert await decoded_r.ts().alter(1, 
duplicate_policy="min") + info = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, "min", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("1.4.0", "timeseries") +@skip_if_server_version_gte("7.9.0") +async def test_alter_duplicate_policy_prior_redis_8(decoded_r: redis.Redis): assert await decoded_r.ts().create(1) info = await decoded_r.ts().info(1) assert_resp_response( @@ -722,7 +741,27 @@ async def test_info(decoded_r: redis.Redis): @pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") +@skip_if_server_version_lt("7.9.0") async def test_info_duplicate_policy(decoded_r: redis.Redis): + await decoded_r.ts().create( + 1, retention_msecs=5, labels={"currentLabel": "currentData"} + ) + info = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, "block", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) + + await decoded_r.ts().create("time-serie-2", duplicate_policy="min") + info = await decoded_r.ts().info("time-serie-2") + assert_resp_response( + decoded_r, "min", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("1.4.0", "timeseries") +@skip_if_server_version_gte("7.9.0") +async def test_info_duplicate_policy_prior_redis_8(decoded_r: redis.Redis): await decoded_r.ts().create( 1, retention_msecs=5, labels={"currentLabel": "currentData"} ) diff --git a/tests/test_timeseries.py b/tests/test_timeseries.py index fb604d0329..ad98b1ca2f 100644 --- a/tests/test_timeseries.py +++ b/tests/test_timeseries.py @@ -9,6 +9,8 @@ _get_client, assert_resp_response, is_resp2_connection, + skip_if_server_version_gte, + skip_if_server_version_lt, skip_ifmodversion_lt, ) @@ -84,7 +86,8 @@ def test_alter(client): @pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") -def test_alter_duplicate_policy(client): +@skip_if_server_version_gte("7.9.0") +def test_alter_duplicate_policy_prior_redis_8(client): assert client.ts().create(1) info = client.ts().info(1) assert_resp_response( @@ -97,6 +100,22 @@ def test_alter_duplicate_policy(client): ) +@pytest.mark.redismod +@skip_ifmodversion_lt("1.4.0", "timeseries") +@skip_if_server_version_lt("7.9.0") +def test_alter_duplicate_policy(client): + assert client.ts().create(1) + info = client.ts().info(1) + assert_resp_response( + client, "block", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) + assert client.ts().alter(1, duplicate_policy="min") + info = client.ts().info(1) + assert_resp_response( + client, "min", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) + + @pytest.mark.redismod def test_add(client): assert 1 == client.ts().add(1, 1, 1) @@ -967,7 +986,25 @@ def test_info(client): @pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") +@skip_if_server_version_lt("7.9.0") def test_info_duplicate_policy(client): + client.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"}) + info = client.ts().info(1) + assert_resp_response( + client, "block", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) + + client.ts().create("time-serie-2", duplicate_policy="min") + info = client.ts().info("time-serie-2") + assert_resp_response( + client, "min", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("1.4.0", "timeseries") +@skip_if_server_version_gte("7.9.0") +def test_info_duplicate_policy_prior_redis_8(client): client.ts().create(1, retention_msecs=5, 
labels={"currentLabel": "currentData"}) info = client.ts().info(1) assert_resp_response( From 5fe120d5976b46ead99ea62127bb8fc15990e136 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Wed, 30 Apr 2025 15:56:11 +0300 Subject: [PATCH 096/113] Updated default value of 'require_full_coverage' argument to true for sync Cluster client to match sync/async cluster APIs (#3434) * Updated default value to much sync cluster API * Updated default value to TRUE --------- Co-authored-by: petyaslavova --- redis/cluster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/cluster.py b/redis/cluster.py index 010a3d94e6..ae9720652a 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -499,7 +499,7 @@ def __init__( startup_nodes: Optional[List["ClusterNode"]] = None, cluster_error_retry_attempts: int = 3, retry: Optional["Retry"] = None, - require_full_coverage: bool = False, + require_full_coverage: bool = True, reinitialize_steps: int = 5, read_from_replicas: bool = False, load_balancing_strategy: Optional["LoadBalancingStrategy"] = None, From 41fdadb21b64c2b242b69d7d15c7eace15193575 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Wed, 30 Apr 2025 17:03:26 +0300 Subject: [PATCH 097/113] Adding default retry configuration changes for cluster clients (#3622) * Adding default retry configuration changes for sync cluster client * Adding default retry configuration changes for sync cluster client * Adding default retry configuration changes for async cluster client * Updating docs related to retries and read_from_replicas. * Applying review comments. * Removing retry checks when using set_retry for cluster clients. --- docs/clustering.rst | 11 ++- docs/retry.rst | 43 ++++++----- redis/asyncio/cluster.py | 97 +++++++++++++---------- redis/asyncio/retry.py | 12 +++ redis/cluster.py | 106 ++++++++++++++++++++------ redis/commands/json/__init__.py | 2 +- redis/commands/timeseries/__init__.py | 2 +- redis/retry.py | 12 +++ tests/test_asyncio/test_cluster.py | 85 ++++++++++++--------- tests/test_cluster.py | 50 ++++++------ tests/test_retry.py | 16 ++++ 11 files changed, 283 insertions(+), 153 deletions(-) diff --git a/docs/clustering.rst b/docs/clustering.rst index cf257d8ad5..3c28b9ee16 100644 --- a/docs/clustering.rst +++ b/docs/clustering.rst @@ -187,8 +187,8 @@ When a ClusterPubSub instance is created without specifying a node, a single node will be transparently chosen for the pubsub connection on the first command execution. The node will be determined by: 1. Hashing the channel name in the request to find its keyslot 2. Selecting a node -that handles the keyslot: If read_from_replicas is set to true, a -replica can be selected. +that handles the keyslot: If read_from_replicas is set to true or +load_balancing_strategy is provided, a replica can be selected. Known PubSub Limitations ------------------------ @@ -216,9 +216,12 @@ By default, Redis Cluster always returns MOVE redirection response on accessing a replica node. You can overcome this limitation and scale read commands by triggering READONLY mode. -To enable READONLY mode pass read_from_replicas=True to RedisCluster -constructor. When set to true, read commands will be assigned between +To enable READONLY mode pass read_from_replicas=True or define +a load_balancing_strategy to RedisCluster constructor. +When read_from_replicas is set to true read commands will be assigned between the primary and its replications in a Round-Robin manner. 
+With load_balancing_strategy you can define a custom strategy for +assigning read commands to the replicas and primary nodes. READONLY mode can be set at runtime by calling the readonly() method with target_nodes=‘replicas’, and read-write access can be restored by diff --git a/docs/retry.rst b/docs/retry.rst index acf198ec94..0f2e318022 100644 --- a/docs/retry.rst +++ b/docs/retry.rst @@ -13,25 +13,25 @@ Retry in Redis Standalone >>> from redis.client import Redis >>> from redis.exceptions import ( >>> BusyLoadingError, ->>> ConnectionError, ->>> TimeoutError +>>> RedisError, >>> ) >>> >>> # Run 3 retries with exponential backoff strategy >>> retry = Retry(ExponentialBackoff(), 3) ->>> # Redis client with retries on custom errors ->>> r = Redis(host='localhost', port=6379, retry=retry, retry_on_error=[BusyLoadingError, ConnectionError, TimeoutError]) ->>> # Redis client with retries on TimeoutError only ->>> r_only_timeout = Redis(host='localhost', port=6379, retry=retry, retry_on_timeout=True) +>>> # Redis client with retries on custom errors in addition to the errors +>>> # that are already retried by default +>>> r = Redis(host='localhost', port=6379, retry=retry, retry_on_error=[BusyLoadingError, RedisError]) -As you can see from the example above, Redis client supports 3 parameters to configure the retry behaviour: +As you can see from the example above, Redis client supports 2 parameters to configure the retry behaviour: * ``retry``: :class:`~.Retry` instance with a :ref:`backoff-label` strategy and the max number of retries -* ``retry_on_error``: list of :ref:`exceptions-label` to retry on -* ``retry_on_timeout``: if ``True``, retry on :class:`~.TimeoutError` only + * The :class:`~.Retry` instance has default set of :ref:`exceptions-label` to retry on, + which can be overridden by passing a tuple with :ref:`exceptions-label` to the ``supported_errors`` parameter. +* ``retry_on_error``: list of additional :ref:`exceptions-label` to retry on -If either ``retry_on_error`` or ``retry_on_timeout`` are passed and no ``retry`` is given, -by default it uses a ``Retry(NoBackoff(), 1)`` (meaning 1 retry right after the first failure). + +If no ``retry`` is provided, a default one is created with :class:`~.ExponentialWithJitterBackoff` as backoff strategy +and 3 retries. 
Retry in Redis Cluster @@ -44,12 +44,18 @@ Retry in Redis Cluster >>> # Run 3 retries with exponential backoff strategy >>> retry = Retry(ExponentialBackoff(), 3) >>> # Redis Cluster client with retries ->>> rc = RedisCluster(host='localhost', port=6379, retry=retry, cluster_error_retry_attempts=2) +>>> rc = RedisCluster(host='localhost', port=6379, retry=retry) Retry behaviour in Redis Cluster is a little bit different from Standalone: -* ``retry``: :class:`~.Retry` instance with a :ref:`backoff-label` strategy and the max number of retries, default value is ``Retry(NoBackoff(), 0)`` -* ``cluster_error_retry_attempts``: number of times to retry before raising an error when :class:`~.TimeoutError` or :class:`~.ConnectionError` or :class:`~.ClusterDownError` are encountered, default value is ``3`` +* ``retry``: :class:`~.Retry` instance with a :ref:`backoff-label` strategy and the max number of retries, default value is ``Retry(ExponentialWithJitterBackoff(base=1, cap=10), cluster_error_retry_attempts)`` +* ``cluster_error_retry_attempts``: number of times to retry before raising an error when :class:`~.TimeoutError`, :class:`~.ConnectionError`, :class:`~.ClusterDownError` or :class:`~.SlotNotCoveredError` are encountered, default value is ``3`` + * This argument is deprecated - it is used to initialize the number of retries for the retry object, + only in the case when the ``retry`` object is not provided. + When the ``retry`` argument is provided, the ``cluster_error_retry_attempts`` argument is ignored! + +* The retry object is not yet fully utilized in the cluster client. + The retry object is used only to determine the number of retries for the cluster level calls. Let's consider the following example: @@ -57,14 +63,11 @@ Let's consider the following example: >>> from redis.retry import Retry >>> from redis.cluster import RedisCluster >>> ->>> rc = RedisCluster(host='localhost', port=6379, retry=Retry(ExponentialBackoff(), 6), cluster_error_retry_attempts=1) +>>> rc = RedisCluster(host='localhost', port=6379, retry=Retry(ExponentialBackoff(), 6)) >>> rc.set('foo', 'bar') #. the client library calculates the hash slot for key 'foo'. #. given the hash slot, it then determines which node to connect to, in order to execute the command. #. during the connection, a :class:`~.ConnectionError` is raised. -#. because we set ``retry=Retry(ExponentialBackoff(), 6)``, the client tries to reconnect to the node up to 6 times, with an exponential backoff between each attempt. -#. even after 6 retries, the client is still unable to connect. -#. because we set ``cluster_error_retry_attempts=1``, before giving up, the client starts a cluster update, removes the failed node from the startup nodes, and re-initializes the cluster. -#. after the cluster has been re-initialized, it starts a new cycle of retries, up to 6 retries, with an exponential backoff. -#. if the client can connect, we're good. Otherwise, the exception is finally raised to the caller, because we've run out of attempts. \ No newline at end of file +#. because we set ``retry=Retry(ExponentialBackoff(), 6)``, the cluster client starts a cluster update, removes the failed node from the startup nodes, and re-initializes the cluster. +#. the cluster client retries the command until it either succeeds or the max number of retries is reached. 
\ No newline at end of file diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index e6fd52d790..6a0d4414fd 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -29,7 +29,7 @@ from redis.asyncio.lock import Lock from redis.asyncio.retry import Retry from redis.auth.token import TokenInterface -from redis.backoff import default_backoff +from redis.backoff import ExponentialWithJitterBackoff, NoBackoff from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis from redis.cluster import ( PIPELINE_BLOCKED_COMMANDS, @@ -151,19 +151,23 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand To avoid reinitializing the cluster on moved errors, set reinitialize_steps to 0. :param cluster_error_retry_attempts: - | Number of times to retry before raising an error when :class:`~.TimeoutError` - or :class:`~.ConnectionError` or :class:`~.ClusterDownError` are encountered - :param connection_error_retry_attempts: - | Number of times to retry before reinitializing when :class:`~.TimeoutError` - or :class:`~.ConnectionError` are encountered. - The default backoff strategy will be set if Retry object is not passed (see - default_backoff in backoff.py). To change it, pass a custom Retry object - using the "retry" keyword. + | @deprecated - Please configure the 'retry' object instead + In case 'retry' object is set - this argument is ignored! + + Number of times to retry before raising an error when :class:`~.TimeoutError`, + :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError` + or :class:`~.ClusterDownError` are encountered + :param retry: + | A retry object that defines the retry strategy and the number of + retries for the cluster client. + In current implementation for the cluster client (starting form redis-py version 6.0.0) + the retry object is not yet fully utilized, instead it is used just to determine + the number of retries for the cluster client. + In the future releases the retry object will be used to handle the cluster client retries! :param max_connections: | Maximum number of connections per node. If there are no free connections & the maximum number of connections are already created, a - :class:`~.MaxConnectionsError` is raised. This error may be retried as defined - by :attr:`connection_error_retry_attempts` + :class:`~.MaxConnectionsError` is raised. :param address_remap: | An optional callable which, when provided with an internal network address of a node, e.g. 
a `(host, port)` tuple, will return the address @@ -219,10 +223,9 @@ def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster": __slots__ = ( "_initialize", "_lock", - "cluster_error_retry_attempts", + "retry", "command_flags", "commands_parser", - "connection_error_retry_attempts", "connection_kwargs", "encoder", "node_flags", @@ -239,6 +242,13 @@ def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster": reason="Please configure the 'load_balancing_strategy' instead", version="5.3.0", ) + @deprecated_args( + args_to_warn=[ + "cluster_error_retry_attempts", + ], + reason="Please configure the 'retry' object instead", + version="6.0.0", + ) def __init__( self, host: Optional[str] = None, @@ -251,8 +261,9 @@ def __init__( dynamic_startup_nodes: bool = True, reinitialize_steps: int = 5, cluster_error_retry_attempts: int = 3, - connection_error_retry_attempts: int = 3, max_connections: int = 2**31, + retry: Optional["Retry"] = None, + retry_on_error: Optional[List[Type[Exception]]] = None, # Client related kwargs db: Union[str, int] = 0, path: Optional[str] = None, @@ -272,8 +283,6 @@ def __init__( socket_keepalive: bool = False, socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, socket_timeout: Optional[float] = None, - retry: Optional["Retry"] = None, - retry_on_error: Optional[List[Type[Exception]]] = None, # SSL related kwargs ssl: bool = False, ssl_ca_certs: Optional[str] = None, @@ -327,7 +336,6 @@ def __init__( "socket_keepalive": socket_keepalive, "socket_keepalive_options": socket_keepalive_options, "socket_timeout": socket_timeout, - "retry": retry, "protocol": protocol, } @@ -351,17 +359,15 @@ def __init__( # Call our on_connect function to configure READONLY mode kwargs["redis_connect_func"] = self.on_connect - self.retry = retry - if retry or retry_on_error or connection_error_retry_attempts > 0: - # Set a retry object for all cluster nodes - self.retry = retry or Retry( - default_backoff(), connection_error_retry_attempts + if retry: + self.retry = retry + else: + self.retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), + retries=cluster_error_retry_attempts, ) - if not retry_on_error: - # Default errors for retrying - retry_on_error = [ConnectionError, TimeoutError] + if retry_on_error: self.retry.update_supported_errors(retry_on_error) - kwargs.update({"retry": self.retry}) kwargs["response_callbacks"] = _RedisCallbacks.copy() if kwargs.get("protocol") in ["3", 3]: @@ -399,8 +405,6 @@ def __init__( self.read_from_replicas = read_from_replicas self.load_balancing_strategy = load_balancing_strategy self.reinitialize_steps = reinitialize_steps - self.cluster_error_retry_attempts = cluster_error_retry_attempts - self.connection_error_retry_attempts = connection_error_retry_attempts self.reinitialize_counter = 0 self.commands_parser = AsyncCommandsParser() self.node_flags = self.__class__.NODE_FLAGS.copy() @@ -571,15 +575,8 @@ def get_connection_kwargs(self) -> Dict[str, Optional[Any]]: """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`.""" return self.connection_kwargs - def get_retry(self) -> Optional["Retry"]: - return self.retry - - def set_retry(self, retry: "Retry") -> None: + def set_retry(self, retry: Retry) -> None: self.retry = retry - for node in self.get_nodes(): - node.connection_kwargs.update({"retry": retry}) - for conn in node._connections: - conn.retry = retry def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None: """Set a custom response callback.""" @@ -698,8 
+695,8 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: """ Execute a raw command on the appropriate cluster node or target_nodes. - It will retry the command as specified by :attr:`cluster_error_retry_attempts` & - then raise an exception. + It will retry the command as specified by the retries property of + the :attr:`retry` & then raise an exception. :param args: | Raw command args @@ -715,7 +712,7 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: command = args[0] target_nodes = [] target_nodes_specified = False - retry_attempts = self.cluster_error_retry_attempts + retry_attempts = self.retry.get_retries() passed_targets = kwargs.pop("target_nodes", None) if passed_targets and not self._is_node_flag(passed_targets): @@ -1058,7 +1055,23 @@ def acquire_connection(self) -> Connection: return self._free.popleft() except IndexError: if len(self._connections) < self.max_connections: - connection = self.connection_class(**self.connection_kwargs) + # We are configuring the connection pool not to retry + # connections on lower level clients to avoid retrying + # connections to nodes that are not reachable + # and to avoid blocking the connection pool. + # The only error that will have some handling in the lower + # level clients is ConnectionError which will trigger disconnection + # of the socket. + # The retries will be handled on cluster client level + # where we will have proper handling of the cluster topology + retry = Retry( + backoff=NoBackoff(), + retries=0, + supported_errors=(ConnectionError,), + ) + connection_kwargs = self.connection_kwargs.copy() + connection_kwargs["retry"] = retry + connection = self.connection_class(**connection_kwargs) self._connections.append(connection) return connection @@ -1559,7 +1572,7 @@ async def execute( """ Execute the pipeline. - It will retry the commands as specified by :attr:`cluster_error_retry_attempts` + It will retry the commands as specified by retries specified in :attr:`retry` & then raise an exception. :param raise_on_error: @@ -1575,7 +1588,7 @@ async def execute( return [] try: - retry_attempts = self._client.cluster_error_retry_attempts + retry_attempts = self._client.retry.get_retries() while True: try: if self._client._initialize: diff --git a/redis/asyncio/retry.py b/redis/asyncio/retry.py index 7c5e3b0e7d..a20f8b4849 100644 --- a/redis/asyncio/retry.py +++ b/redis/asyncio/retry.py @@ -43,6 +43,18 @@ def update_supported_errors(self, specified_errors: list): set(self._supported_errors + tuple(specified_errors)) ) + def get_retries(self) -> int: + """ + Get the number of retries. + """ + return self._retries + + def update_retries(self, value: int) -> None: + """ + Set the number of retries. 
+ """ + self._retries = value + async def call_with_retry( self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any] ) -> T: diff --git a/redis/cluster.py b/redis/cluster.py index ae9720652a..c79f8e429d 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -9,7 +9,7 @@ from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan -from redis.backoff import default_backoff +from redis.backoff import ExponentialWithJitterBackoff, NoBackoff from redis.cache import CacheConfig, CacheFactory, CacheFactoryInterface, CacheInterface from redis.client import CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands @@ -179,7 +179,7 @@ def parse_cluster_myshardid(resp, **options): "cache", "cache_config", ) -KWARGS_DISABLED_KEYS = ("host", "port") +KWARGS_DISABLED_KEYS = ("host", "port", "retry") def cleanup_kwargs(**kwargs): @@ -436,7 +436,7 @@ def replace_default_node(self, target_node: "ClusterNode" = None) -> None: # Choose a primary if the cluster contains different primaries self.nodes_manager.default_node = random.choice(primaries) else: - # Otherwise, hoose a primary if the cluster contains different primaries + # Otherwise, choose a primary if the cluster contains different primaries replicas = [node for node in self.get_replicas() if node != curr_node] if replicas: self.nodes_manager.default_node = random.choice(replicas) @@ -492,6 +492,13 @@ class initializer. In the case of conflicting arguments, querystring reason="Please configure the 'load_balancing_strategy' instead", version="5.3.0", ) + @deprecated_args( + args_to_warn=[ + "cluster_error_retry_attempts", + ], + reason="Please configure the 'retry' object instead", + version="6.0.0", + ) def __init__( self, host: Optional[str] = None, @@ -549,9 +556,19 @@ def __init__( If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists specific IP addresses, it is best to set it to false. :param cluster_error_retry_attempts: + @deprecated - Please configure the 'retry' object instead + In case 'retry' object is set - this argument is ignored! + Number of times to retry before raising an error when - :class:`~.TimeoutError` or :class:`~.ConnectionError` or + :class:`~.TimeoutError` or :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError` or :class:`~.ClusterDownError` are encountered + :param retry: + A retry object that defines the retry strategy and the number of + retries for the cluster client. + In current implementation for the cluster client (starting form redis-py version 6.0.0) + the retry object is not yet fully utilized, instead it is used just to determine + the number of retries for the cluster client. + In the future releases the retry object will be used to handle the cluster client retries! :param reinitialize_steps: Specifies the number of MOVED errors that need to occur before reinitializing the whole cluster topology. If a MOVED error occurs @@ -571,7 +588,8 @@ def __init__( :**kwargs: Extra arguments that will be sent into Redis instance when created - (See Official redis-py doc for supported kwargs + (See Official redis-py doc for supported kwargs - the only limitation + is that you can't provide 'retry' object as part of kwargs. 
[https://github.com/andymccurdy/redis-py/blob/master/redis/client.py]) Some kwargs are not supported and will raise a RedisClusterException: @@ -586,6 +604,15 @@ def __init__( "Argument 'db' is not possible to use in cluster mode" ) + if "retry" in kwargs: + # Argument 'retry' is not possible to be used in kwargs when in cluster mode + # the kwargs are set to the lower level connections to the cluster nodes + # and there we provide retry configuration without retries allowed. + # The retries should be handled on cluster client level. + raise RedisClusterException( + "The 'retry' argument cannot be used in kwargs when running in cluster mode." + ) + # Get the startup node/s from_url = False if url is not None: @@ -628,9 +655,11 @@ def __init__( kwargs = cleanup_kwargs(**kwargs) if retry: self.retry = retry - kwargs.update({"retry": self.retry}) else: - kwargs.update({"retry": Retry(default_backoff(), 0)}) + self.retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), + retries=cluster_error_retry_attempts, + ) self.encoder = Encoder( kwargs.get("encoding", "utf-8"), @@ -641,7 +670,6 @@ def __init__( if (cache_config or cache) and protocol not in [3, "3"]: raise RedisError("Client caching is only supported with RESP version 3") - self.cluster_error_retry_attempts = cluster_error_retry_attempts self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.node_flags = self.__class__.NODE_FLAGS.copy() self.read_from_replicas = read_from_replicas @@ -772,13 +800,8 @@ def set_default_node(self, node): self.nodes_manager.default_node = node return True - def get_retry(self) -> Optional["Retry"]: - return self.retry - - def set_retry(self, retry: "Retry") -> None: + def set_retry(self, retry: Retry) -> None: self.retry = retry - for node in self.get_nodes(): - node.redis_connection.set_retry(retry) def monitor(self, target_node=None): """ @@ -825,10 +848,11 @@ def pipeline(self, transaction=None, shard_hint=None): startup_nodes=self.nodes_manager.startup_nodes, result_callbacks=self.result_callbacks, cluster_response_callbacks=self.cluster_response_callbacks, - cluster_error_retry_attempts=self.cluster_error_retry_attempts, + cluster_error_retry_attempts=self.retry.get_retries(), read_from_replicas=self.read_from_replicas, load_balancing_strategy=self.load_balancing_strategy, reinitialize_steps=self.reinitialize_steps, + retry=self.retry, lock=self._lock, ) @@ -1090,8 +1114,8 @@ def _internal_execute_command(self, *args, **kwargs): """ Wrapper for ERRORS_ALLOW_RETRY error handling. - It will try the number of times specified by the config option - "self.cluster_error_retry_attempts" which defaults to 3 unless manually + It will try the number of times specified by the retries property from + config option "self.retry" which defaults to 3 unless manually configured. If it reaches the number of times, the command will raise the exception @@ -1117,9 +1141,7 @@ def _internal_execute_command(self, *args, **kwargs): # execution since the nodes may not be valid anymore after the tables # were reinitialized. So in case of passed target nodes, # retry_attempts will be set to 0. 
- retry_attempts = ( - 0 if target_nodes_specified else self.cluster_error_retry_attempts - ) + retry_attempts = 0 if target_nodes_specified else self.retry.get_retries() # Add one for the first execution execute_attempts = 1 + retry_attempts for _ in range(execute_attempts): @@ -1333,8 +1355,12 @@ def __eq__(self, obj): return isinstance(obj, ClusterNode) and obj.name == self.name def __del__(self): - if self.redis_connection is not None: - self.redis_connection.close() + try: + if self.redis_connection is not None: + self.redis_connection.close() + except Exception: + # Ignore errors when closing the connection + pass class LoadBalancingStrategy(Enum): @@ -1585,17 +1611,32 @@ def create_redis_connections(self, nodes): ) def create_redis_node(self, host, port, **kwargs): + # We are configuring the connection pool not to retry + # connections on lower level clients to avoid retrying + # connections to nodes that are not reachable + # and to avoid blocking the connection pool. + # The only error that will have some handling in the lower + # level clients is ConnectionError which will trigger disconnection + # of the socket. + # The retries will be handled on cluster client level + # where we will have proper handling of the cluster topology + node_retry_config = Retry( + backoff=NoBackoff(), retries=0, supported_errors=(ConnectionError,) + ) + if self.from_url: # Create a redis node with a costumed connection pool kwargs.update({"host": host}) kwargs.update({"port": port}) kwargs.update({"cache": self._cache}) + kwargs.update({"retry": node_retry_config}) r = Redis(connection_pool=self.connection_pool_class(**kwargs)) else: r = Redis( host=host, port=port, cache=self._cache, + retry=node_retry_config, **kwargs, ) return r @@ -2039,6 +2080,13 @@ class ClusterPipeline(RedisCluster): TryAgainError, ) + @deprecated_args( + args_to_warn=[ + "cluster_error_retry_attempts", + ], + reason="Please configure the 'retry' object instead", + version="6.0.0", + ) def __init__( self, nodes_manager: "NodesManager", @@ -2050,6 +2098,7 @@ def __init__( load_balancing_strategy: Optional[LoadBalancingStrategy] = None, cluster_error_retry_attempts: int = 3, reinitialize_steps: int = 5, + retry: Optional[Retry] = None, lock=None, **kwargs, ): @@ -2066,9 +2115,16 @@ def __init__( self.load_balancing_strategy = load_balancing_strategy self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.cluster_response_callbacks = cluster_response_callbacks - self.cluster_error_retry_attempts = cluster_error_retry_attempts self.reinitialize_counter = 0 self.reinitialize_steps = reinitialize_steps + if retry is not None: + self.retry = retry + else: + self.retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), + retries=self.cluster_error_retry_attempts, + ) + self.encoder = Encoder( kwargs.get("encoding", "utf-8"), kwargs.get("encoding_errors", "strict"), @@ -2194,7 +2250,7 @@ def send_cluster_commands( - refereh_table_asap set to True It will try the number of times specified by - the config option "self.cluster_error_retry_attempts" + the retries in config option "self.retry" which defaults to 3 unless manually configured. 
If it reaches the number of times, the command will @@ -2202,7 +2258,7 @@ def send_cluster_commands( """ if not stack: return [] - retry_attempts = self.cluster_error_retry_attempts + retry_attempts = self.retry.get_retries() while True: try: return self._send_cluster_commands( diff --git a/redis/commands/json/__init__.py b/redis/commands/json/__init__.py index 01077e6b88..0e717b31d6 100644 --- a/redis/commands/json/__init__.py +++ b/redis/commands/json/__init__.py @@ -120,7 +120,7 @@ def pipeline(self, transaction=True, shard_hint=None): startup_nodes=self.client.nodes_manager.startup_nodes, result_callbacks=self.client.result_callbacks, cluster_response_callbacks=self.client.cluster_response_callbacks, - cluster_error_retry_attempts=self.client.cluster_error_retry_attempts, + cluster_error_retry_attempts=self.client.retry.get_retries(), read_from_replicas=self.client.read_from_replicas, reinitialize_steps=self.client.reinitialize_steps, lock=self.client._lock, diff --git a/redis/commands/timeseries/__init__.py b/redis/commands/timeseries/__init__.py index 4188b93d70..3fbf821172 100644 --- a/redis/commands/timeseries/__init__.py +++ b/redis/commands/timeseries/__init__.py @@ -84,7 +84,7 @@ def pipeline(self, transaction=True, shard_hint=None): startup_nodes=self.client.nodes_manager.startup_nodes, result_callbacks=self.client.result_callbacks, cluster_response_callbacks=self.client.cluster_response_callbacks, - cluster_error_retry_attempts=self.client.cluster_error_retry_attempts, + cluster_error_retry_attempts=self.client.retry.get_retries(), read_from_replicas=self.client.read_from_replicas, reinitialize_steps=self.client.reinitialize_steps, lock=self.client._lock, diff --git a/redis/retry.py b/redis/retry.py index 03fd973c4c..ca9ea76f24 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -44,6 +44,18 @@ def update_supported_errors( set(self._supported_errors + tuple(specified_errors)) ) + def get_retries(self) -> int: + """ + Get the number of retries. + """ + return self._retries + + def update_retries(self, value: int) -> None: + """ + Set the number of retries. 
+ """ + self._retries = value + def call_with_retry( self, do: Callable[[], T], diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 5a52da3d80..3897492ea6 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -13,7 +13,11 @@ from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster from redis.asyncio.connection import Connection, SSLConnection, async_timeout from redis.asyncio.retry import Retry -from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff +from redis.backoff import ( + ExponentialBackoff, + ExponentialWithJitterBackoff, + NoBackoff, +) from redis.cluster import ( PIPELINE_BLOCKED_COMMANDS, PRIMARY, @@ -367,71 +371,79 @@ async def test_cluster_set_get_retry_object(self, request: FixtureRequest): retry = Retry(NoBackoff(), 2) url = request.config.getoption("--redis-url") async with RedisCluster.from_url(url, retry=retry) as r: - assert r.get_retry()._retries == retry._retries - assert isinstance(r.get_retry()._backoff, NoBackoff) + assert r.retry.get_retries() == retry.get_retries() + assert isinstance(r.retry._backoff, NoBackoff) for node in r.get_nodes(): - n_retry = node.connection_kwargs.get("retry") + # validate nodes lower level connections default + # retry policy is applied + n_retry = node.acquire_connection().retry assert n_retry is not None - assert n_retry._retries == retry._retries + assert n_retry._retries == 0 assert isinstance(n_retry._backoff, NoBackoff) rand_cluster_node = r.get_random_node() existing_conn = rand_cluster_node.acquire_connection() # Change retry policy new_retry = Retry(ExponentialBackoff(), 3) r.set_retry(new_retry) - assert r.get_retry()._retries == new_retry._retries - assert isinstance(r.get_retry()._backoff, ExponentialBackoff) + assert r.retry.get_retries() == new_retry.get_retries() + assert isinstance(r.retry._backoff, ExponentialBackoff) for node in r.get_nodes(): - n_retry = node.connection_kwargs.get("retry") + # validate nodes lower level connections are not affected + n_retry = node.acquire_connection().retry assert n_retry is not None - assert n_retry._retries == new_retry._retries - assert isinstance(n_retry._backoff, ExponentialBackoff) - assert existing_conn.retry._retries == new_retry._retries + assert n_retry._retries == 0 + assert isinstance(n_retry._backoff, NoBackoff) + assert existing_conn.retry.get_retries() == 0 new_conn = rand_cluster_node.acquire_connection() - assert new_conn.retry._retries == new_retry._retries + assert new_conn.retry._retries == 0 async def test_cluster_retry_object(self, request: FixtureRequest) -> None: url = request.config.getoption("--redis-url") async with RedisCluster.from_url(url) as rc_default: # Test default retry - retry = rc_default.connection_kwargs.get("retry") + retry = rc_default.retry # FIXME: Workaround for https://github.com/redis/redis-py/issues/3030 host = rc_default.get_default_node().host assert isinstance(retry, Retry) assert retry._retries == 3 - assert isinstance(retry._backoff, type(default_backoff())) - assert rc_default.get_node(host, 16379).connection_kwargs.get( - "retry" - ) == rc_default.get_node(host, 16380).connection_kwargs.get("retry") + assert isinstance(retry._backoff, type(ExponentialWithJitterBackoff())) + + # validate nodes connections are using the default retry for + # lower level connections when client is created through 'from_url' method + # without specified retry object + node1_retry = rc_default.get_node(host, 16379).acquire_connection().retry 
+ node2_retry = rc_default.get_node(host, 16380).acquire_connection().retry + for node_retry in (node1_retry, node2_retry): + assert node_retry.get_retries() == 0 + assert isinstance(node_retry._backoff, NoBackoff) + assert node_retry._supported_errors == (ConnectionError,) retry = Retry(ExponentialBackoff(10, 5), 5) async with RedisCluster.from_url(url, retry=retry) as rc_custom_retry: # Test custom retry - assert ( - rc_custom_retry.get_node(host, 16379).connection_kwargs.get("retry") - == retry - ) + assert rc_custom_retry.retry == retry + # validate nodes connections are using the default retry for + # lower level connections when client is created through 'from_url' method + # with specified retry object + node1_retry = rc_default.get_node(host, 16379).acquire_connection().retry + node2_retry = rc_default.get_node(host, 16380).acquire_connection().retry + for node_retry in (node1_retry, node2_retry): + assert node_retry.get_retries() == 0 + assert isinstance(node_retry._backoff, NoBackoff) + assert node_retry._supported_errors == (ConnectionError,) async with RedisCluster.from_url( - url, connection_error_retry_attempts=0 + url, cluster_error_retry_attempts=0 ) as rc_no_retries: - # Test no connection retries - assert ( - rc_no_retries.get_node(host, 16379).connection_kwargs.get("retry") - is None - ) + # Test no cluster retries + assert rc_no_retries.retry.get_retries() == 0 async with RedisCluster.from_url( url, retry=Retry(NoBackoff(), 0) ) as rc_no_retries: - assert ( - rc_no_retries.get_node(host, 16379) - .connection_kwargs.get("retry") - ._retries - == 0 - ) + assert rc_no_retries.retry.get_retries() == 0 async def test_empty_startup_nodes(self) -> None: """ @@ -2830,7 +2842,7 @@ async def test_multi_key_operation_with_multi_slots(self, r: RedisCluster) -> No async def test_cluster_down_error(self, r: RedisCluster) -> None: """ - Test that the pipeline retries cluster_error_retry_attempts times before raising + Test that the pipeline retries the specified in retry object times before raising an error. 
""" key = "foo" @@ -2855,10 +2867,7 @@ async def parse_response( async with r.pipeline() as pipe: with pytest.raises(ClusterDownError): await pipe.get(key).execute() - assert ( - node.parse_response.await_count - == 3 * r.cluster_error_retry_attempts + 1 - ) + assert node.parse_response.await_count == 3 * r.retry.get_retries() + 1 async def test_connection_error_not_raised(self, r: RedisCluster) -> None: """Test ConnectionError handling with raise_on_error=False.""" diff --git a/tests/test_cluster.py b/tests/test_cluster.py index d96342f87a..d4e48e199b 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -14,7 +14,11 @@ import redis from redis import Redis from redis._parsers import CommandsParser -from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff +from redis.backoff import ( + ExponentialBackoff, + ExponentialWithJitterBackoff, + NoBackoff, +) from redis.cluster import ( PRIMARY, REDIS_CLUSTER_HASH_SLOTS, @@ -884,46 +888,48 @@ def moved_redirect_effect(connection, *args, **options): def test_cluster_get_set_retry_object(self, request): retry = Retry(NoBackoff(), 2) r = _get_client(RedisCluster, request, retry=retry) - assert r.get_retry()._retries == retry._retries - assert isinstance(r.get_retry()._backoff, NoBackoff) + assert r.retry.get_retries() == retry.get_retries() + assert isinstance(r.retry._backoff, NoBackoff) for node in r.get_nodes(): - assert node.redis_connection.get_retry()._retries == retry._retries + assert node.redis_connection.get_retry().get_retries() == 0 assert isinstance(node.redis_connection.get_retry()._backoff, NoBackoff) rand_node = r.get_random_node() existing_conn = rand_node.redis_connection.connection_pool.get_connection() # Change retry policy new_retry = Retry(ExponentialBackoff(), 3) r.set_retry(new_retry) - assert r.get_retry()._retries == new_retry._retries - assert isinstance(r.get_retry()._backoff, ExponentialBackoff) + assert r.retry.get_retries() == new_retry.get_retries() + assert isinstance(r.retry._backoff, ExponentialBackoff) for node in r.get_nodes(): - assert node.redis_connection.get_retry()._retries == new_retry._retries - assert isinstance( - node.redis_connection.get_retry()._backoff, ExponentialBackoff - ) - assert existing_conn.retry._retries == new_retry._retries + assert node.redis_connection.get_retry()._retries == 0 + assert isinstance(node.redis_connection.get_retry()._backoff, NoBackoff) + assert existing_conn.retry._retries == 0 new_conn = rand_node.redis_connection.connection_pool.get_connection() - assert new_conn.retry._retries == new_retry._retries + assert new_conn.retry._retries == 0 def test_cluster_retry_object(self, r) -> None: # Test default retry # FIXME: Workaround for https://github.com/redis/redis-py/issues/3030 host = r.get_default_node().host - retry = r.get_connection_kwargs().get("retry") + # test default retry config + retry = r.retry assert isinstance(retry, Retry) - assert retry._retries == 0 - assert isinstance(retry._backoff, type(default_backoff())) - node1 = r.get_node(host, 16379).redis_connection - node2 = r.get_node(host, 16380).redis_connection - assert node1.get_retry()._retries == node2.get_retry()._retries - - # Test custom retry + assert retry.get_retries() == 3 + assert isinstance(retry._backoff, type(ExponentialWithJitterBackoff())) + node1_connection = r.get_node(host, 16379).redis_connection + node2_connection = r.get_node(host, 16380).redis_connection + assert node1_connection.get_retry()._retries == 0 + assert node2_connection.get_retry()._retries == 0 + + 
# Test custom retry is not applied to nodes retry = Retry(ExponentialBackoff(10, 5), 5) rc_custom_retry = RedisCluster(host, 16379, retry=retry) assert ( - rc_custom_retry.get_node(host, 16379).redis_connection.get_retry()._retries - == retry._retries + rc_custom_retry.get_node(host, 16379) + .redis_connection.get_retry() + .get_retries() + == 0 ) def test_replace_cluster_node(self, r) -> None: diff --git a/tests/test_retry.py b/tests/test_retry.py index cb001fbbd5..926fe28313 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -203,6 +203,22 @@ def test_client_retry_on_timeout(self, request): finally: assert parse_response.call_count == retries + 1 + @pytest.mark.onlycluster + def test_get_set_retry_object_for_cluster_client(self, request): + retry = Retry(NoBackoff(), 2) + r = _get_client(Redis, request, retry_on_timeout=True, retry=retry) + exist_conn = r.connection_pool.get_connection() + assert r.retry._retries == retry._retries + assert isinstance(r.retry._backoff, NoBackoff) + new_retry_policy = Retry(ExponentialBackoff(), 3) + r.set_retry(new_retry_policy) + assert r.retry._retries == new_retry_policy._retries + assert isinstance(r.retry._backoff, ExponentialBackoff) + assert exist_conn.retry._retries == new_retry_policy._retries + new_conn = r.connection_pool.get_connection() + assert new_conn.retry._retries == new_retry_policy._retries + + @pytest.mark.onlynoncluster def test_get_set_retry_object(self, request): retry = Retry(NoBackoff(), 2) r = _get_client(Redis, request, retry_on_timeout=True, retry=retry) From c980e9589b791234c99d7a783cd6ec957a74851e Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Wed, 30 Apr 2025 21:10:54 +0300 Subject: [PATCH 098/113] Changing the default value for ssl_check_hostname to True, to ensure security validations are not skipped by default (#3626) * Changing the default value for ssl_check_hostname to True, to ensure security validations are not skipped by default * Applying review comments * Removing unused operation in tests. * Removing unneeded comment from tests. 
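
With hostname verification now on by default, TLS clients that use a
certificate which does not match the connection host (for example the
self-signed certificates used in the test suite and docs examples) must
opt out explicitly. A minimal sketch of the new opt-out, mirroring the
docs/examples notebook updated in this patch (the localhost:6666 TLS
endpoint is the example setup from those docs, not a general default):

    import redis

    # ssl_check_hostname previously defaulted to False; it now defaults
    # to True, so skipping hostname verification has to be explicit.
    r = redis.Redis(
        host="localhost",
        port=6666,                 # TLS port used in the docs example
        ssl=True,
        ssl_check_hostname=False,  # explicit opt-out for self-signed certs
        ssl_cert_reqs="none",
    )
    r.ping()

Production deployments should keep the new default and supply a
certificate that matches the host rather than disabling the check.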
--- docs/examples/ssl_connection_examples.ipynb | 19 ++++++++++------- redis/asyncio/client.py | 2 +- redis/asyncio/cluster.py | 2 +- redis/asyncio/connection.py | 2 +- redis/client.py | 2 +- redis/connection.py | 2 +- tests/test_asyncio/test_cluster.py | 23 ++++++++++++++++++++- tests/test_asyncio/test_connect.py | 13 +++++++++++- tests/test_connect.py | 10 +++++++++ tests/test_ssl.py | 12 ++++++++++- 10 files changed, 71 insertions(+), 16 deletions(-) diff --git a/docs/examples/ssl_connection_examples.ipynb b/docs/examples/ssl_connection_examples.ipynb index c94c4e0191..a09b87ec1f 100644 --- a/docs/examples/ssl_connection_examples.ipynb +++ b/docs/examples/ssl_connection_examples.ipynb @@ -34,9 +34,10 @@ "import redis\n", "\n", "r = redis.Redis(\n", - " host='localhost', \n", - " port=6666, \n", - " ssl=True, \n", + " host='localhost',\n", + " port=6666,\n", + " ssl=True,\n", + " ssl_check_hostname=False,\n", " ssl_cert_reqs=\"none\",\n", ")\n", "r.ping()" @@ -68,7 +69,7 @@ "source": [ "import redis\n", "\n", - "r = redis.from_url(\"rediss://localhost:6666?ssl_cert_reqs=none&decode_responses=True&health_check_interval=2\")\n", + "r = redis.from_url(\"rediss://localhost:6666?ssl_cert_reqs=none&ssl_check_hostname=False&decode_responses=True&health_check_interval=2\")\n", "r.ping()" ] }, @@ -99,13 +100,14 @@ "import redis\n", "\n", "redis_pool = redis.ConnectionPool(\n", - " host=\"localhost\", \n", - " port=6666, \n", - " connection_class=redis.SSLConnection, \n", + " host=\"localhost\",\n", + " port=6666,\n", + " connection_class=redis.SSLConnection,\n", + " ssl_check_hostname=False,\n", " ssl_cert_reqs=\"none\",\n", ")\n", "\n", - "r = redis.StrictRedis(connection_pool=redis_pool) \n", + "r = redis.StrictRedis(connection_pool=redis_pool)\n", "r.ping()" ] }, @@ -141,6 +143,7 @@ " port=6666,\n", " ssl=True,\n", " ssl_min_version=ssl.TLSVersion.TLSv1_3,\n", + " ssl_check_hostname=False,\n", " ssl_cert_reqs=\"none\",\n", ")\n", "r.ping()" diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index ac907b0c10..1cb28e725e 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -241,7 +241,7 @@ def __init__( ssl_cert_reqs: Union[str, VerifyMode] = "required", ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, - ssl_check_hostname: bool = False, + ssl_check_hostname: bool = True, ssl_min_version: Optional[TLSVersion] = None, ssl_ciphers: Optional[str] = None, max_connections: Optional[int] = None, diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 6a0d4414fd..23e039c62f 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -289,7 +289,7 @@ def __init__( ssl_ca_data: Optional[str] = None, ssl_cert_reqs: Union[str, VerifyMode] = "required", ssl_certfile: Optional[str] = None, - ssl_check_hostname: bool = False, + ssl_check_hostname: bool = True, ssl_keyfile: Optional[str] = None, ssl_min_version: Optional[TLSVersion] = None, ssl_ciphers: Optional[str] = None, diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 70d7d91898..77131ab951 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -794,7 +794,7 @@ def __init__( ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required", ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, - ssl_check_hostname: bool = False, + ssl_check_hostname: bool = True, ssl_min_version: Optional[TLSVersion] = None, ssl_ciphers: Optional[str] = None, **kwargs, diff --git a/redis/client.py b/redis/client.py index 
138f561974..2ef95600c2 100755 --- a/redis/client.py +++ b/redis/client.py @@ -223,7 +223,7 @@ def __init__( ssl_ca_certs: Optional[str] = None, ssl_ca_path: Optional[str] = None, ssl_ca_data: Optional[str] = None, - ssl_check_hostname: bool = False, + ssl_check_hostname: bool = True, ssl_password: Optional[str] = None, ssl_validate_ocsp: bool = False, ssl_validate_ocsp_stapled: bool = False, diff --git a/redis/connection.py b/redis/connection.py index ddc6991c5c..dab45906d2 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1028,7 +1028,7 @@ def __init__( ssl_cert_reqs="required", ssl_ca_certs=None, ssl_ca_data=None, - ssl_check_hostname=False, + ssl_check_hostname=True, ssl_ca_path=None, ssl_password=None, ssl_validate_ocsp=False, diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 3897492ea6..5a8b6dfee7 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -3139,7 +3139,9 @@ async def test_ssl_with_invalid_cert( async def test_ssl_connection( self, create_client: Callable[..., Awaitable[RedisCluster]] ) -> None: - async with await create_client(ssl=True, ssl_cert_reqs="none") as rc: + async with await create_client( + ssl=True, ssl_check_hostname=False, ssl_cert_reqs="none" + ) as rc: assert await rc.ping() @pytest.mark.parametrize( @@ -3155,6 +3157,7 @@ async def test_ssl_connection_tls12_custom_ciphers( ) -> None: async with await create_client( ssl=True, + ssl_check_hostname=False, ssl_cert_reqs="none", ssl_min_version=ssl.TLSVersion.TLSv1_2, ssl_ciphers=ssl_ciphers, @@ -3166,6 +3169,7 @@ async def test_ssl_connection_tls12_custom_ciphers_invalid( ) -> None: async with await create_client( ssl=True, + ssl_check_hostname=False, ssl_cert_reqs="none", ssl_min_version=ssl.TLSVersion.TLSv1_2, ssl_ciphers="foo:bar", @@ -3187,6 +3191,7 @@ async def test_ssl_connection_tls13_custom_ciphers( # TLSv1.3 does not support changing the ciphers async with await create_client( ssl=True, + ssl_check_hostname=False, ssl_cert_reqs="none", ssl_min_version=ssl.TLSVersion.TLSv1_2, ssl_ciphers=ssl_ciphers, @@ -3198,12 +3203,20 @@ async def test_ssl_connection_tls13_custom_ciphers( async def test_validating_self_signed_certificate( self, create_client: Callable[..., Awaitable[RedisCluster]] ) -> None: + # ssl_check_hostname=False is used to avoid hostname verification + # in the test environment, where the server certificate is self-signed + # and does not match the hostname that is extracted for the cluster. + # Cert hostname is 'localhost' in the cluster initialization when using + # 'localhost' it gets transformed into 127.0.0.1 + # In production code, ssl_check_hostname should be set to True + # to ensure proper hostname verification. async with await create_client( ssl=True, ssl_ca_certs=self.ca_cert, ssl_cert_reqs="required", ssl_certfile=self.client_cert, ssl_keyfile=self.client_key, + ssl_check_hostname=False, ) as rc: assert await rc.ping() @@ -3213,10 +3226,18 @@ async def test_validating_self_signed_string_certificate( with open(self.ca_cert) as f: cert_data = f.read() + # ssl_check_hostname=False is used to avoid hostname verification + # in the test environment, where the server certificate is self-signed + # and does not match the hostname that is extracted for the cluster. + # Cert hostname is 'localhost' in the cluster initialization when using + # 'localhost' it gets transformed into 127.0.0.1 + # In production code, ssl_check_hostname should be set to True + # to ensure proper hostname verification. 
async with await create_client( ssl=True, ssl_ca_data=cert_data, ssl_cert_reqs="required", + ssl_check_hostname=False, ssl_certfile=self.client_cert, ssl_keyfile=self.client_key, ) as rc: diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py index 6c4b3c33d7..62e8665d1f 100644 --- a/tests/test_asyncio/test_connect.py +++ b/tests/test_asyncio/test_connect.py @@ -58,6 +58,10 @@ async def test_uds_connect(uds_address): async def test_tcp_ssl_tls12_custom_ciphers(tcp_address, ssl_ciphers): host, port = tcp_address + # in order to have working hostname verification, we need to use "localhost" + # as redis host as the server certificate is self-signed and only valid for "localhost" + host = "localhost" + server_certs = get_tls_certificates(cert_type=CertificateType.server) conn = SSLConnection( @@ -89,6 +93,10 @@ async def test_tcp_ssl_tls12_custom_ciphers(tcp_address, ssl_ciphers): async def test_tcp_ssl_connect(tcp_address, ssl_min_version): host, port = tcp_address + # in order to have working hostname verification, we need to use "localhost" + # as redis host as the server certificate is self-signed and only valid for "localhost" + host = "localhost" + server_certs = get_tls_certificates(cert_type=CertificateType.server) conn = SSLConnection( @@ -100,7 +108,10 @@ async def test_tcp_ssl_connect(tcp_address, ssl_min_version): ssl_min_version=ssl_min_version, ) await _assert_connect( - conn, tcp_address, certfile=server_certs.certfile, keyfile=server_certs.keyfile + conn, + tcp_address, + certfile=server_certs.certfile, + keyfile=server_certs.keyfile, ) await conn.disconnect() diff --git a/tests/test_connect.py b/tests/test_connect.py index f3c02b330f..1e1c23c87e 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -54,10 +54,16 @@ def test_uds_connect(uds_address): ) def test_tcp_ssl_connect(tcp_address, ssl_min_version): host, port = tcp_address + + # in order to have working hostname verification, we need to use "localhost" + # as redis host as the server certificate is self-signed and only valid for "localhost" + host = "localhost" server_certs = get_tls_certificates(cert_type=CertificateType.server) + conn = SSLConnection( host=host, port=port, + ssl_check_hostname=True, client_name=_CLIENT_NAME, ssl_ca_certs=server_certs.ca_certfile, socket_timeout=10, @@ -80,6 +86,10 @@ def test_tcp_ssl_connect(tcp_address, ssl_min_version): def test_tcp_ssl_tls12_custom_ciphers(tcp_address, ssl_ciphers): host, port = tcp_address + # in order to have working hostname verification, we need to use "localhost" + # as redis host as the server certificate is self-signed and only valid for "localhost" + host = "localhost" + server_certs = get_tls_certificates(cert_type=CertificateType.server) conn = SSLConnection( diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 2a945ac287..5aa33353a8 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -37,7 +37,14 @@ def test_ssl_with_invalid_cert(self, request): def test_ssl_connection(self, request): ssl_url = request.config.option.redis_ssl_url p = urlparse(ssl_url)[1].split(":") - r = redis.Redis(host=p[0], port=p[1], ssl=True, ssl_cert_reqs="none") + + r = redis.Redis( + host=p[0], + port=p[1], + ssl=True, + ssl_check_hostname=False, + ssl_cert_reqs="none", + ) assert r.ping() r.close() @@ -98,6 +105,7 @@ def test_ssl_connection_tls12_custom_ciphers(self, request, ssl_ciphers): host=p[0], port=p[1], ssl=True, + ssl_check_hostname=False, ssl_cert_reqs="none", ssl_min_version=ssl.TLSVersion.TLSv1_3, 
ssl_ciphers=ssl_ciphers, @@ -112,6 +120,7 @@ def test_ssl_connection_tls12_custom_ciphers_invalid(self, request): host=p[0], port=p[1], ssl=True, + ssl_check_hostname=False, ssl_cert_reqs="none", ssl_min_version=ssl.TLSVersion.TLSv1_2, ssl_ciphers="foo:bar", @@ -136,6 +145,7 @@ def test_ssl_connection_tls13_custom_ciphers(self, request, ssl_ciphers): host=p[0], port=p[1], ssl=True, + ssl_check_hostname=False, ssl_cert_reqs="none", ssl_min_version=ssl.TLSVersion.TLSv1_2, ssl_ciphers=ssl_ciphers, From c40b98c2e69e75b82a55dcb9995fe713a9fccffc Mon Sep 17 00:00:00 2001 From: Ivana Kellyer Date: Mon, 5 May 2025 16:18:09 +0200 Subject: [PATCH 099/113] Fix AttributeError on ClusterPipeline (#3634) --- redis/cluster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/cluster.py b/redis/cluster.py index c79f8e429d..99e5c37f9d 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2122,7 +2122,7 @@ def __init__( else: self.retry = Retry( backoff=ExponentialWithJitterBackoff(base=1, cap=10), - retries=self.cluster_error_retry_attempts, + retries=cluster_error_retry_attempts, ) self.encoder = Encoder( From 18d3b321d4441f83e3f75b96aa60adf5fa6d614d Mon Sep 17 00:00:00 2001 From: Terence Honles Date: Wed, 7 May 2025 09:51:09 +0200 Subject: [PATCH 100/113] add equality and hashability to ``Retry`` and backoff classes (#3628) --- redis/backoff.py | 54 +++++++++++++++++++++++++++++++++++++++++++++ redis/retry.py | 13 +++++++++++ tests/test_retry.py | 45 ++++++++++++++++++++++++++++++++++++- 3 files changed, 111 insertions(+), 1 deletion(-) diff --git a/redis/backoff.py b/redis/backoff.py index e236764d71..22a3ed0abb 100644 --- a/redis/backoff.py +++ b/redis/backoff.py @@ -31,6 +31,15 @@ def __init__(self, backoff: float) -> None: """`backoff`: backoff time in seconds""" self._backoff = backoff + def __hash__(self) -> int: + return hash((self._backoff,)) + + def __eq__(self, other) -> bool: + if not isinstance(other, ConstantBackoff): + return NotImplemented + + return self._backoff == other._backoff + def compute(self, failures: int) -> float: return self._backoff @@ -53,6 +62,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE): self._cap = cap self._base = base + def __hash__(self) -> int: + return hash((self._base, self._cap)) + + def __eq__(self, other) -> bool: + if not isinstance(other, ExponentialBackoff): + return NotImplemented + + return self._base == other._base and self._cap == other._cap + def compute(self, failures: int) -> float: return min(self._cap, self._base * 2**failures) @@ -68,6 +86,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None self._cap = cap self._base = base + def __hash__(self) -> int: + return hash((self._base, self._cap)) + + def __eq__(self, other) -> bool: + if not isinstance(other, FullJitterBackoff): + return NotImplemented + + return self._base == other._base and self._cap == other._cap + def compute(self, failures: int) -> float: return random.uniform(0, min(self._cap, self._base * 2**failures)) @@ -83,6 +110,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None self._cap = cap self._base = base + def __hash__(self) -> int: + return hash((self._base, self._cap)) + + def __eq__(self, other) -> bool: + if not isinstance(other, EqualJitterBackoff): + return NotImplemented + + return self._base == other._base and self._cap == other._cap + def compute(self, failures: int) -> float: temp = min(self._cap, self._base * 2**failures) / 2 return temp + 
random.uniform(0, temp) @@ -100,6 +136,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None self._base = base self._previous_backoff = 0 + def __hash__(self) -> int: + return hash((self._base, self._cap)) + + def __eq__(self, other) -> bool: + if not isinstance(other, DecorrelatedJitterBackoff): + return NotImplemented + + return self._base == other._base and self._cap == other._cap + def reset(self) -> None: self._previous_backoff = 0 @@ -121,6 +166,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None self._cap = cap self._base = base + def __hash__(self) -> int: + return hash((self._base, self._cap)) + + def __eq__(self, other) -> bool: + if not isinstance(other, EqualJitterBackoff): + return NotImplemented + + return self._base == other._base and self._cap == other._cap + def compute(self, failures: int) -> float: return min(self._cap, random.random() * self._base * 2**failures) diff --git a/redis/retry.py b/redis/retry.py index ca9ea76f24..c93f34e65f 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -34,6 +34,19 @@ def __init__( self._retries = retries self._supported_errors = supported_errors + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Retry): + return NotImplemented + + return ( + self._backoff == other._backoff + and self._retries == other._retries + and set(self._supported_errors) == set(other._supported_errors) + ) + + def __hash__(self) -> int: + return hash((self._backoff, self._retries, frozenset(self._supported_errors))) + def update_supported_errors( self, specified_errors: Iterable[Type[Exception]] ) -> None: diff --git a/tests/test_retry.py b/tests/test_retry.py index 926fe28313..4f4f04caca 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -1,7 +1,16 @@ from unittest.mock import patch import pytest -from redis.backoff import AbstractBackoff, ExponentialBackoff, NoBackoff +from redis.backoff import ( + AbstractBackoff, + ConstantBackoff, + DecorrelatedJitterBackoff, + EqualJitterBackoff, + ExponentialBackoff, + ExponentialWithJitterBackoff, + FullJitterBackoff, + NoBackoff, +) from redis.client import Redis from redis.connection import Connection, UnixDomainSocketConnection from redis.exceptions import ( @@ -80,6 +89,40 @@ def test_retry_on_error_retry(self, Class, retries): assert c.retry._retries == retries +@pytest.mark.parametrize( + "args", + [ + (ConstantBackoff(0), 0), + (ConstantBackoff(10), 5), + (NoBackoff(), 0), + ] + + [ + backoff + for Backoff in ( + DecorrelatedJitterBackoff, + EqualJitterBackoff, + ExponentialBackoff, + ExponentialWithJitterBackoff, + FullJitterBackoff, + ) + for backoff in ((Backoff(), 2), (Backoff(25), 5), (Backoff(25, 5), 5)) + ], +) +def test_retry_eq_and_hashable(args): + assert Retry(*args) == Retry(*args) + + # create another retry object with different parameters + copy = list(args) + if isinstance(copy[0], ConstantBackoff): + copy[1] = 9000 + else: + copy[0] = ConstantBackoff(9000) + + assert Retry(*args) != Retry(*copy) + assert Retry(*copy) != Retry(*args) + assert len({Retry(*args), Retry(*args), Retry(*copy), Retry(*copy)}) == 2 + + class TestRetry: "Test that Retry calls backoff and retries the expected number of times" From 436ad609466b75c7b64c94f1f82889b9fe2ac87f Mon Sep 17 00:00:00 2001 From: kesha1225 <48860626+kesha1225@users.noreply.github.com> Date: Thu, 8 May 2025 10:35:48 +0300 Subject: [PATCH 101/113] Change type hints with possible None args or return types to be annotated with Optional - includes commands in 
core.py and json commands (#3610) * fix(redis-client): change `zrange` num parameter type to `Optional[int]` * fix(redis-client): normalize optional parameter annotations Replace all occurrences of Union[T, None] = None and bare T = None with Optional[T] = None in zrange, _zrange, arrtrim, and other methods so that type checkers no longer report errors. * commit message: fix(redis-client): normalize optional parameter annotations and correct arrtrim return type body: replaced all Union[T, None] = None and bare T = None with Optional[T] = None; changed arrtrim return annotation to Optional[int] * fix(redis-client): replace Optional[None] with Optional[int] for numeric parameters --- CHANGES | 1 + redis/commands/core.py | 208 ++++++++++++++++---------------- redis/commands/json/commands.py | 16 +-- 3 files changed, 113 insertions(+), 112 deletions(-) diff --git a/CHANGES b/CHANGES index 24b52c54db..7ce774621d 100644 --- a/CHANGES +++ b/CHANGES @@ -1146,3 +1146,4 @@ incompatible in code using*SCAN commands loops such as * Implemented STRLEN * Implemented PERSIST * Implemented SETRANGE + * Changed type annotation of the `num` parameter in `zrange` from `int` to `Optional[int] \ No newline at end of file diff --git a/redis/commands/core.py b/redis/commands/core.py index a8c327f08f..378898272f 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -63,7 +63,7 @@ class ACLCommands(CommandsProtocol): see: https://redis.io/topics/acl """ - def acl_cat(self, category: Union[str, None] = None, **kwargs) -> ResponseT: + def acl_cat(self, category: Optional[str] = None, **kwargs) -> ResponseT: """ Returns a list of categories or commands within a category. @@ -92,7 +92,7 @@ def acl_deluser(self, *username: str, **kwargs) -> ResponseT: """ return self.execute_command("ACL DELUSER", *username, **kwargs) - def acl_genpass(self, bits: Union[int, None] = None, **kwargs) -> ResponseT: + def acl_genpass(self, bits: Optional[int] = None, **kwargs) -> ResponseT: """Generate a random password value. If ``bits`` is supplied then use this number of bits, rounded to the next multiple of 4. @@ -137,7 +137,7 @@ def acl_list(self, **kwargs) -> ResponseT: """ return self.execute_command("ACL LIST", **kwargs) - def acl_log(self, count: Union[int, None] = None, **kwargs) -> ResponseT: + def acl_log(self, count: Optional[int] = None, **kwargs) -> ResponseT: """ Get ACL logs as a list. :param int count: Get logs[0:count]. 
@@ -190,8 +190,8 @@ def acl_setuser( username: str, enabled: bool = False, nopass: bool = False, - passwords: Union[str, Iterable[str], None] = None, - hashed_passwords: Union[str, Iterable[str], None] = None, + passwords: Optional[Union[str, Iterable[str]]] = None, + hashed_passwords: Optional[Union[str, Iterable[str]]] = None, categories: Optional[Iterable[str]] = None, commands: Optional[Iterable[str]] = None, keys: Optional[Iterable[KeyT]] = None, @@ -450,13 +450,13 @@ def client_kill(self, address: str, **kwargs) -> ResponseT: def client_kill_filter( self, - _id: Union[str, None] = None, - _type: Union[str, None] = None, - addr: Union[str, None] = None, - skipme: Union[bool, None] = None, - laddr: Union[bool, None] = None, - user: str = None, - maxage: Union[int, None] = None, + _id: Optional[str] = None, + _type: Optional[str] = None, + addr: Optional[str] = None, + skipme: Optional[bool] = None, + laddr: Optional[bool] = None, + user: Optional[str] = None, + maxage: Optional[int] = None, **kwargs, ) -> ResponseT: """ @@ -512,7 +512,7 @@ def client_info(self, **kwargs) -> ResponseT: return self.execute_command("CLIENT INFO", **kwargs) def client_list( - self, _type: Union[str, None] = None, client_id: List[EncodableT] = [], **kwargs + self, _type: Optional[str] = None, client_id: List[EncodableT] = [], **kwargs ) -> ResponseT: """ Returns a list of currently connected clients. @@ -589,7 +589,7 @@ def client_id(self, **kwargs) -> ResponseT: def client_tracking_on( self, - clientid: Union[int, None] = None, + clientid: Optional[int] = None, prefix: Sequence[KeyT] = [], bcast: bool = False, optin: bool = False, @@ -608,7 +608,7 @@ def client_tracking_on( def client_tracking_off( self, - clientid: Union[int, None] = None, + clientid: Optional[int] = None, prefix: Sequence[KeyT] = [], bcast: bool = False, optin: bool = False, @@ -628,7 +628,7 @@ def client_tracking_off( def client_tracking( self, on: bool = True, - clientid: Union[int, None] = None, + clientid: Optional[int] = None, prefix: Sequence[KeyT] = [], bcast: bool = False, optin: bool = False, @@ -988,7 +988,7 @@ def select(self, index: int, **kwargs) -> ResponseT: return self.execute_command("SELECT", index, **kwargs) def info( - self, section: Union[str, None] = None, *args: List[str], **kwargs + self, section: Optional[str] = None, *args: List[str], **kwargs ) -> ResponseT: """ Returns a dictionary containing information about the Redis server @@ -1070,7 +1070,7 @@ def migrate( timeout: int, copy: bool = False, replace: bool = False, - auth: Union[str, None] = None, + auth: Optional[str] = None, **kwargs, ) -> ResponseT: """ @@ -1152,7 +1152,7 @@ def memory_malloc_stats(self, **kwargs) -> ResponseT: return self.execute_command("MEMORY MALLOC-STATS", **kwargs) def memory_usage( - self, key: KeyT, samples: Union[int, None] = None, **kwargs + self, key: KeyT, samples: Optional[int] = None, **kwargs ) -> ResponseT: """ Return the total memory usage for key, its value and associated @@ -1291,7 +1291,7 @@ def shutdown( raise RedisError("SHUTDOWN seems to have failed.") def slaveof( - self, host: Union[str, None] = None, port: Union[int, None] = None, **kwargs + self, host: Optional[str] = None, port: Optional[int] = None, **kwargs ) -> ResponseT: """ Set the server to be a replicated slave of the instance identified @@ -1304,7 +1304,7 @@ def slaveof( return self.execute_command("SLAVEOF", b"NO", b"ONE", **kwargs) return self.execute_command("SLAVEOF", host, port, **kwargs) - def slowlog_get(self, num: Union[int, None] = None, 
**kwargs) -> ResponseT: + def slowlog_get(self, num: Optional[int] = None, **kwargs) -> ResponseT: """ Get the entries from the slowlog. If ``num`` is specified, get the most recent ``num`` items. @@ -1451,7 +1451,7 @@ def __init__( self, client: Union["redis.client.Redis", "redis.asyncio.client.Redis"], key: str, - default_overflow: Union[str, None] = None, + default_overflow: Optional[str] = None, ): self.client = client self.key = key @@ -1487,7 +1487,7 @@ def incrby( fmt: str, offset: BitfieldOffsetT, increment: int, - overflow: Union[str, None] = None, + overflow: Optional[str] = None, ): """ Increment a bitfield by a given amount. @@ -1572,8 +1572,8 @@ def append(self, key: KeyT, value: EncodableT) -> ResponseT: def bitcount( self, key: KeyT, - start: Union[int, None] = None, - end: Union[int, None] = None, + start: Optional[int] = None, + end: Optional[int] = None, mode: Optional[str] = None, ) -> ResponseT: """ @@ -1595,7 +1595,7 @@ def bitcount( def bitfield( self: Union["redis.client.Redis", "redis.asyncio.client.Redis"], key: KeyT, - default_overflow: Union[str, None] = None, + default_overflow: Optional[str] = None, ) -> BitFieldOperation: """ Return a BitFieldOperation instance to conveniently construct one or @@ -1641,8 +1641,8 @@ def bitpos( self, key: KeyT, bit: int, - start: Union[int, None] = None, - end: Union[int, None] = None, + start: Optional[int] = None, + end: Optional[int] = None, mode: Optional[str] = None, ) -> ResponseT: """ @@ -1672,7 +1672,7 @@ def copy( self, source: str, destination: str, - destination_db: Union[str, None] = None, + destination_db: Optional[str] = None, replace: bool = False, ) -> ResponseT: """ @@ -2137,7 +2137,7 @@ def pttl(self, name: KeyT) -> ResponseT: return self.execute_command("PTTL", name) def hrandfield( - self, key: str, count: int = None, withvalues: bool = False + self, key: str, count: Optional[int] = None, withvalues: bool = False ) -> ResponseT: """ Return a random field from the hash value stored at key. @@ -2191,8 +2191,8 @@ def restore( value: EncodableT, replace: bool = False, absttl: bool = False, - idletime: Union[int, None] = None, - frequency: Union[int, None] = None, + idletime: Optional[int] = None, + frequency: Optional[int] = None, ) -> ResponseT: """ Create a key using the provided serialized value, previously obtained @@ -2360,7 +2360,7 @@ def stralgo( specific_argument: Union[Literal["strings"], Literal["keys"]] = "strings", len: bool = False, idx: bool = False, - minmatchlen: Union[int, None] = None, + minmatchlen: Optional[int] = None, withmatchlen: bool = False, **kwargs, ) -> ResponseT: @@ -2960,8 +2960,8 @@ def scan( self, cursor: int = 0, match: Union[PatternT, None] = None, - count: Union[int, None] = None, - _type: Union[str, None] = None, + count: Optional[int] = None, + _type: Optional[str] = None, **kwargs, ) -> ResponseT: """ @@ -2992,8 +2992,8 @@ def scan( def scan_iter( self, match: Union[PatternT, None] = None, - count: Union[int, None] = None, - _type: Union[str, None] = None, + count: Optional[int] = None, + _type: Optional[str] = None, **kwargs, ) -> Iterator: """ @@ -3022,7 +3022,7 @@ def sscan( name: KeyT, cursor: int = 0, match: Union[PatternT, None] = None, - count: Union[int, None] = None, + count: Optional[int] = None, ) -> ResponseT: """ Incrementally return lists of elements in a set. 
Also return a cursor @@ -3045,7 +3045,7 @@ def sscan_iter( self, name: KeyT, match: Union[PatternT, None] = None, - count: Union[int, None] = None, + count: Optional[int] = None, ) -> Iterator: """ Make an iterator using the SSCAN command so that the client doesn't @@ -3065,7 +3065,7 @@ def hscan( name: KeyT, cursor: int = 0, match: Union[PatternT, None] = None, - count: Union[int, None] = None, + count: Optional[int] = None, no_values: Union[bool, None] = None, ) -> ResponseT: """ @@ -3093,7 +3093,7 @@ def hscan_iter( self, name: str, match: Union[PatternT, None] = None, - count: Union[int, None] = None, + count: Optional[int] = None, no_values: Union[bool, None] = None, ) -> Iterator: """ @@ -3121,7 +3121,7 @@ def zscan( name: KeyT, cursor: int = 0, match: Union[PatternT, None] = None, - count: Union[int, None] = None, + count: Optional[int] = None, score_cast_func: Union[type, Callable] = float, ) -> ResponseT: """ @@ -3148,7 +3148,7 @@ def zscan_iter( self, name: KeyT, match: Union[PatternT, None] = None, - count: Union[int, None] = None, + count: Optional[int] = None, score_cast_func: Union[type, Callable] = float, ) -> Iterator: """ @@ -3177,8 +3177,8 @@ class AsyncScanCommands(ScanCommands): async def scan_iter( self, match: Union[PatternT, None] = None, - count: Union[int, None] = None, - _type: Union[str, None] = None, + count: Optional[int] = None, + _type: Optional[str] = None, **kwargs, ) -> AsyncIterator: """ @@ -3207,7 +3207,7 @@ async def sscan_iter( self, name: KeyT, match: Union[PatternT, None] = None, - count: Union[int, None] = None, + count: Optional[int] = None, ) -> AsyncIterator: """ Make an iterator using the SSCAN command so that the client doesn't @@ -3229,7 +3229,7 @@ async def hscan_iter( self, name: str, match: Union[PatternT, None] = None, - count: Union[int, None] = None, + count: Optional[int] = None, no_values: Union[bool, None] = None, ) -> AsyncIterator: """ @@ -3258,7 +3258,7 @@ async def zscan_iter( self, name: KeyT, match: Union[PatternT, None] = None, - count: Union[int, None] = None, + count: Optional[int] = None, score_cast_func: Union[type, Callable] = float, ) -> AsyncIterator: """ @@ -3489,11 +3489,11 @@ def xadd( name: KeyT, fields: Dict[FieldT, EncodableT], id: StreamIdT = "*", - maxlen: Union[int, None] = None, + maxlen: Optional[int] = None, approximate: bool = True, nomkstream: bool = False, minid: Union[StreamIdT, None] = None, - limit: Union[int, None] = None, + limit: Optional[int] = None, ) -> ResponseT: """ Add to a stream. @@ -3544,7 +3544,7 @@ def xautoclaim( consumername: ConsumerT, min_idle_time: int, start_id: StreamIdT = "0-0", - count: Union[int, None] = None, + count: Optional[int] = None, justid: bool = False, ) -> ResponseT: """ @@ -3595,9 +3595,9 @@ def xclaim( consumername: ConsumerT, min_idle_time: int, message_ids: Union[List[StreamIdT], Tuple[StreamIdT]], - idle: Union[int, None] = None, - time: Union[int, None] = None, - retrycount: Union[int, None] = None, + idle: Optional[int] = None, + time: Optional[int] = None, + retrycount: Optional[int] = None, force: bool = False, justid: bool = False, ) -> ResponseT: @@ -3829,7 +3829,7 @@ def xpending_range( max: StreamIdT, count: int, consumername: Union[ConsumerT, None] = None, - idle: Union[int, None] = None, + idle: Optional[int] = None, ) -> ResponseT: """ Returns information about pending messages, in a range. 
@@ -3883,7 +3883,7 @@ def xrange( name: KeyT, min: StreamIdT = "-", max: StreamIdT = "+", - count: Union[int, None] = None, + count: Optional[int] = None, ) -> ResponseT: """ Read stream values within an interval. @@ -3913,8 +3913,8 @@ def xrange( def xread( self, streams: Dict[KeyT, StreamIdT], - count: Union[int, None] = None, - block: Union[int, None] = None, + count: Optional[int] = None, + block: Optional[int] = None, ) -> ResponseT: """ Block and monitor multiple streams for new data. @@ -3953,8 +3953,8 @@ def xreadgroup( groupname: str, consumername: str, streams: Dict[KeyT, StreamIdT], - count: Union[int, None] = None, - block: Union[int, None] = None, + count: Optional[int] = None, + block: Optional[int] = None, noack: bool = False, ) -> ResponseT: """ @@ -4000,7 +4000,7 @@ def xrevrange( name: KeyT, max: StreamIdT = "+", min: StreamIdT = "-", - count: Union[int, None] = None, + count: Optional[int] = None, ) -> ResponseT: """ Read stream values within an interval, in reverse order. @@ -4030,10 +4030,10 @@ def xrevrange( def xtrim( self, name: KeyT, - maxlen: Union[int, None] = None, + maxlen: Optional[int] = None, approximate: bool = True, minid: Union[StreamIdT, None] = None, - limit: Union[int, None] = None, + limit: Optional[int] = None, ) -> ResponseT: """ Trims old messages from a stream. @@ -4205,7 +4205,7 @@ def zincrby(self, name: KeyT, amount: float, value: EncodableT) -> ResponseT: return self.execute_command("ZINCRBY", name, amount, value) def zinter( - self, keys: KeysT, aggregate: Union[str, None] = None, withscores: bool = False + self, keys: KeysT, aggregate: Optional[str] = None, withscores: bool = False ) -> ResponseT: """ Return the intersect of multiple sorted sets specified by ``keys``. @@ -4224,7 +4224,7 @@ def zinterstore( self, dest: KeyT, keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], - aggregate: Union[str, None] = None, + aggregate: Optional[str] = None, ) -> ResponseT: """ Intersect multiple sorted sets specified by ``keys`` into a new @@ -4263,7 +4263,7 @@ def zlexcount(self, name, min, max): """ return self.execute_command("ZLEXCOUNT", name, min, max, keys=[name]) - def zpopmax(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: + def zpopmax(self, name: KeyT, count: Optional[int] = None) -> ResponseT: """ Remove and return up to ``count`` members with the highest scores from the sorted set ``name``. @@ -4274,7 +4274,7 @@ def zpopmax(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: options = {"withscores": True} return self.execute_command("ZPOPMAX", name, *args, **options) - def zpopmin(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: + def zpopmin(self, name: KeyT, count: Optional[int] = None) -> ResponseT: """ Remove and return up to ``count`` members with the lowest scores from the sorted set ``name``. @@ -4286,7 +4286,7 @@ def zpopmin(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: return self.execute_command("ZPOPMIN", name, *args, **options) def zrandmember( - self, key: KeyT, count: int = None, withscores: bool = False + self, key: KeyT, count: Optional[int] = None, withscores: bool = False ) -> ResponseT: """ Return a random element from the sorted set value stored at key. 
@@ -4418,8 +4418,8 @@ def _zrange( bylex: bool = False, withscores: bool = False, score_cast_func: Union[type, Callable, None] = float, - offset: Union[int, None] = None, - num: Union[int, None] = None, + offset: Optional[int] = None, + num: Optional[int] = None, ) -> ResponseT: if byscore and bylex: raise DataError("``byscore`` and ``bylex`` can not be specified together.") @@ -4457,8 +4457,8 @@ def zrange( score_cast_func: Union[type, Callable] = float, byscore: bool = False, bylex: bool = False, - offset: int = None, - num: int = None, + offset: Optional[int] = None, + num: Optional[int] = None, ) -> ResponseT: """ Return a range of values from sorted set ``name`` between @@ -4545,8 +4545,8 @@ def zrangestore( byscore: bool = False, bylex: bool = False, desc: bool = False, - offset: Union[int, None] = None, - num: Union[int, None] = None, + offset: Optional[int] = None, + num: Optional[int] = None, ) -> ResponseT: """ Stores in ``dest`` the result of a range of values from sorted set @@ -4591,8 +4591,8 @@ def zrangebylex( name: KeyT, min: EncodableT, max: EncodableT, - start: Union[int, None] = None, - num: Union[int, None] = None, + start: Optional[int] = None, + num: Optional[int] = None, ) -> ResponseT: """ Return the lexicographical range of values from sorted set ``name`` @@ -4615,8 +4615,8 @@ def zrevrangebylex( name: KeyT, max: EncodableT, min: EncodableT, - start: Union[int, None] = None, - num: Union[int, None] = None, + start: Optional[int] = None, + num: Optional[int] = None, ) -> ResponseT: """ Return the reversed lexicographical range of values from sorted set @@ -4639,8 +4639,8 @@ def zrangebyscore( name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT, - start: Union[int, None] = None, - num: Union[int, None] = None, + start: Optional[int] = None, + num: Optional[int] = None, withscores: bool = False, score_cast_func: Union[type, Callable] = float, ) -> ResponseT: @@ -4674,8 +4674,8 @@ def zrevrangebyscore( name: KeyT, max: ZScoreBoundT, min: ZScoreBoundT, - start: Union[int, None] = None, - num: Union[int, None] = None, + start: Optional[int] = None, + num: Optional[int] = None, withscores: bool = False, score_cast_func: Union[type, Callable] = float, ): @@ -4794,7 +4794,7 @@ def zscore(self, name: KeyT, value: EncodableT) -> ResponseT: def zunion( self, keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], - aggregate: Union[str, None] = None, + aggregate: Optional[str] = None, withscores: bool = False, ) -> ResponseT: """ @@ -4811,7 +4811,7 @@ def zunionstore( self, dest: KeyT, keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], - aggregate: Union[str, None] = None, + aggregate: Optional[str] = None, ) -> ResponseT: """ Union multiple sorted sets specified by ``keys`` into @@ -4843,7 +4843,7 @@ def _zaggregate( command: str, dest: Union[KeyT, None], keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], - aggregate: Union[str, None] = None, + aggregate: Optional[str] = None, **options, ) -> ResponseT: pieces: list[EncodableT] = [command] @@ -5994,7 +5994,7 @@ def geoadd( return self.execute_command("GEOADD", *pieces) def geodist( - self, name: KeyT, place1: FieldT, place2: FieldT, unit: Union[str, None] = None + self, name: KeyT, place1: FieldT, place2: FieldT, unit: Optional[str] = None ) -> ResponseT: """ Return the distance between ``place1`` and ``place2`` members of the @@ -6036,14 +6036,14 @@ def georadius( longitude: float, latitude: float, radius: float, - unit: Union[str, None] = None, + unit: Optional[str] = None, withdist: bool = False, withcoord: bool = False, 
withhash: bool = False, - count: Union[int, None] = None, - sort: Union[str, None] = None, - store: Union[KeyT, None] = None, - store_dist: Union[KeyT, None] = None, + count: Optional[int] = None, + sort: Optional[str] = None, + store: Optional[KeyT] = None, + store_dist: Optional[KeyT] = None, any: bool = False, ) -> ResponseT: """ @@ -6098,12 +6098,12 @@ def georadiusbymember( name: KeyT, member: FieldT, radius: float, - unit: Union[str, None] = None, + unit: Optional[str] = None, withdist: bool = False, withcoord: bool = False, withhash: bool = False, - count: Union[int, None] = None, - sort: Union[str, None] = None, + count: Optional[int] = None, + sort: Optional[str] = None, store: Union[KeyT, None] = None, store_dist: Union[KeyT, None] = None, any: bool = False, @@ -6188,8 +6188,8 @@ def geosearch( radius: Union[float, None] = None, width: Union[float, None] = None, height: Union[float, None] = None, - sort: Union[str, None] = None, - count: Union[int, None] = None, + sort: Optional[str] = None, + count: Optional[int] = None, any: bool = False, withcoord: bool = False, withdist: bool = False, @@ -6263,15 +6263,15 @@ def geosearchstore( self, dest: KeyT, name: KeyT, - member: Union[FieldT, None] = None, - longitude: Union[float, None] = None, - latitude: Union[float, None] = None, + member: Optional[FieldT] = None, + longitude: Optional[float] = None, + latitude: Optional[float] = None, unit: str = "m", - radius: Union[float, None] = None, - width: Union[float, None] = None, - height: Union[float, None] = None, - sort: Union[str, None] = None, - count: Union[int, None] = None, + radius: Optional[float] = None, + width: Optional[float] = None, + height: Optional[float] = None, + sort: Optional[str] = None, + count: Optional[int] = None, any: bool = False, storedist: bool = False, ) -> ResponseT: diff --git a/redis/commands/json/commands.py b/redis/commands/json/commands.py index da879df611..48849e1888 100644 --- a/redis/commands/json/commands.py +++ b/redis/commands/json/commands.py @@ -15,7 +15,7 @@ class JSONCommands: def arrappend( self, name: str, path: Optional[str] = Path.root_path(), *args: List[JsonType] - ) -> List[Union[int, None]]: + ) -> List[Optional[int]]: """Append the objects ``args`` to the array under the ``path` in key ``name``. @@ -33,7 +33,7 @@ def arrindex( scalar: int, start: Optional[int] = None, stop: Optional[int] = None, - ) -> List[Union[int, None]]: + ) -> List[Optional[int]]: """ Return the index of ``scalar`` in the JSON array under ``path`` at key ``name``. @@ -53,7 +53,7 @@ def arrindex( def arrinsert( self, name: str, path: str, index: int, *args: List[JsonType] - ) -> List[Union[int, None]]: + ) -> List[Optional[int]]: """Insert the objects ``args`` to the array at index ``index`` under the ``path` in key ``name``. @@ -66,7 +66,7 @@ def arrinsert( def arrlen( self, name: str, path: Optional[str] = Path.root_path() - ) -> List[Union[int, None]]: + ) -> List[Optional[int]]: """Return the length of the array JSON value under ``path`` at key``name``. @@ -79,7 +79,7 @@ def arrpop( name: str, path: Optional[str] = Path.root_path(), index: Optional[int] = -1, - ) -> List[Union[str, None]]: + ) -> List[Optional[str]]: """Pop the element at ``index`` in the array JSON value under ``path`` at key ``name``. 
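The `List[Optional[int]]` / `List[Optional[str]]` annotations in this file reflect that a JSONPath expression can match several values, and RedisJSON reports nil for matches a command does not apply to. A rough sketch, assuming a local Redis with the JSON module loaded (key name and document contents are illustrative):

    import redis

    r = redis.Redis(decode_responses=True)
    r.json().set("doc", "$", {"a": [1, 2], "b": "scalar"})

    # "$.*" matches both values: the array yields its length, the
    # non-array value yields None, hence the List[Optional[int]] hint.
    print(r.json().arrlen("doc", "$.*"))  # e.g. [2, None]
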
@@ -89,7 +89,7 @@ def arrpop( def arrtrim( self, name: str, path: str, start: int, stop: int - ) -> List[Union[int, None]]: + ) -> List[Optional[int]]: """Trim the array JSON value under ``path`` at key ``name`` to the inclusive range given by ``start`` and ``stop``. @@ -113,7 +113,7 @@ def resp(self, name: str, path: Optional[str] = Path.root_path()) -> List: def objkeys( self, name: str, path: Optional[str] = Path.root_path() - ) -> List[Union[List[str], None]]: + ) -> List[Optional[List[str]]]: """Return the key names in the dictionary JSON value under ``path`` at key ``name``. @@ -357,7 +357,7 @@ def set_path( return set_files_result - def strlen(self, name: str, path: Optional[str] = None) -> List[Union[int, None]]: + def strlen(self, name: str, path: Optional[str] = None) -> List[Optional[int]]: """Return the length of the string JSON value under ``path`` at key ``name``. From 07a4bbbda13779191fe3bea0ddc67a0cfff4e611 Mon Sep 17 00:00:00 2001 From: Armin Berres <20811121+aberres@users.noreply.github.com> Date: Thu, 8 May 2025 14:55:33 +0200 Subject: [PATCH 102/113] Allow newer PyJWT versions (#3636) Close #3630 Co-authored-by: Armin Berres --- CHANGES | 1 + pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGES b/CHANGES index 7ce774621d..50126a86d7 100644 --- a/CHANGES +++ b/CHANGES @@ -70,6 +70,7 @@ * Close Unix sockets if the connection attempt fails. This prevents `ResourceWarning`s. (#3314) * Close SSL sockets if the connection attempt fails, or if validations fail. (#3317) * Eliminate mutable default arguments in the `redis.commands.core.Script` class. (#3332) + * Allow newer versions of PyJWT as dependency. (#3630) * 4.1.3 (Feb 8, 2022) * Fix flushdb and flushall (#1926) diff --git a/pyproject.toml b/pyproject.toml index ab3e4cd77e..5cd40c0212 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ ocsp = [ "requests>=2.31.0", ] jwt = [ - "PyJWT~=2.9.0", + "PyJWT>=2.9.0", ] [project.urls] From fa3067d9901095c0bbe3ec04c2b905fd81e09e0a Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Thu, 8 May 2025 16:59:46 +0300 Subject: [PATCH 103/113] Updating Redis 8 test image for GH pipeline (#3639) --- .github/workflows/integration.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index fbfcecdf68..dcc21b5a9c 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -74,7 +74,7 @@ jobs: max-parallel: 15 fail-fast: false matrix: - redis-version: ['8.0-RC2-pre', '${{ needs.redis_version.outputs.CURRENT }}', '7.2.7', '6.2.17'] + redis-version: ['8.0.1-pre', '${{ needs.redis_version.outputs.CURRENT }}', '7.2.7', '6.2.17'] python-version: ['3.8', '3.13'] parser-backend: ['plain'] event-loop: ['asyncio'] From f3dfbd4646cfa76609275a812ed4de30e8dcef40 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Thu, 8 May 2025 18:09:26 +0300 Subject: [PATCH 104/113] Prevent RuntimeError while reinitializing clusters - sync and async (#3633) * Prevent RuntimeError while reinitializing clusters - sync and async * Applying copilot's review comments --- redis/asyncio/cluster.py | 4 +++- redis/cluster.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 23e039c62f..9faf5b891d 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1326,7 +1326,9 @@ async def initialize(self) -> None: startup_nodes_reachable = False fully_covered = False exception 
= None - for startup_node in self.startup_nodes.values(): + # Convert to tuple to prevent RuntimeError if self.startup_nodes + # is modified during iteration + for startup_node in tuple(self.startup_nodes.values()): try: # Make sure cluster mode is enabled on this node try: diff --git a/redis/cluster.py b/redis/cluster.py index 99e5c37f9d..6e3505404a 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1674,7 +1674,9 @@ def initialize(self): fully_covered = False kwargs = self.connection_kwargs exception = None - for startup_node in self.startup_nodes.values(): + # Convert to tuple to prevent RuntimeError if self.startup_nodes + # is modified during iteration + for startup_node in tuple(self.startup_nodes.values()): try: if startup_node.redis_connection: r = startup_node.redis_connection From 36fec153ed99963432f623275847c16694ff83d2 Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Thu, 8 May 2025 19:55:26 +0300 Subject: [PATCH 105/113] Adding return types for the RedisModuleCommands class (#3632) --- redis/commands/redismodules.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/redis/commands/redismodules.py b/redis/commands/redismodules.py index 6e253b1597..078844f7aa 100644 --- a/redis/commands/redismodules.py +++ b/redis/commands/redismodules.py @@ -1,4 +1,14 @@ +from __future__ import annotations + from json import JSONDecoder, JSONEncoder +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .bf import BFBloom, CFBloom, CMSBloom, TDigestBloom, TOPKBloom + from .json import JSON + from .search import AsyncSearch, Search + from .timeseries import TimeSeries + from .vectorset import VectorSet class RedisModuleCommands: @@ -6,7 +16,7 @@ class RedisModuleCommands: modules into the command namespace. """ - def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()): + def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()) -> JSON: """Access the json namespace, providing support for redis json.""" from .json import JSON @@ -14,7 +24,7 @@ def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()): jj = JSON(client=self, encoder=encoder, decoder=decoder) return jj - def ft(self, index_name="idx"): + def ft(self, index_name="idx") -> Search: """Access the search namespace, providing support for redis search.""" from .search import Search @@ -22,7 +32,7 @@ def ft(self, index_name="idx"): s = Search(client=self, index_name=index_name) return s - def ts(self): + def ts(self) -> TimeSeries: """Access the timeseries namespace, providing support for redis timeseries data. 
""" @@ -32,7 +42,7 @@ def ts(self): s = TimeSeries(client=self) return s - def bf(self): + def bf(self) -> BFBloom: """Access the bloom namespace.""" from .bf import BFBloom @@ -40,7 +50,7 @@ def bf(self): bf = BFBloom(client=self) return bf - def cf(self): + def cf(self) -> CFBloom: """Access the bloom namespace.""" from .bf import CFBloom @@ -48,7 +58,7 @@ def cf(self): cf = CFBloom(client=self) return cf - def cms(self): + def cms(self) -> CMSBloom: """Access the bloom namespace.""" from .bf import CMSBloom @@ -56,7 +66,7 @@ def cms(self): cms = CMSBloom(client=self) return cms - def topk(self): + def topk(self) -> TOPKBloom: """Access the bloom namespace.""" from .bf import TOPKBloom @@ -64,7 +74,7 @@ def topk(self): topk = TOPKBloom(client=self) return topk - def tdigest(self): + def tdigest(self) -> TDigestBloom: """Access the bloom namespace.""" from .bf import TDigestBloom @@ -72,7 +82,7 @@ def tdigest(self): tdigest = TDigestBloom(client=self) return tdigest - def vset(self): + def vset(self) -> VectorSet: """Access the VectorSet commands namespace.""" from .vectorset import VectorSet @@ -82,7 +92,7 @@ def vset(self): class AsyncRedisModuleCommands(RedisModuleCommands): - def ft(self, index_name="idx"): + def ft(self, index_name="idx") -> AsyncSearch: """Access the search namespace, providing support for redis search.""" from .search import AsyncSearch From f69192acc6201a6d3273d50e7787e72c89d8a168 Mon Sep 17 00:00:00 2001 From: Igor Malinovskiy Date: Fri, 9 May 2025 13:02:14 +0200 Subject: [PATCH 106/113] Test against unstable hiredis-py (#3617) * Test against unstable hiredis-py * Create a separate workflow instead * Remove outdated guard for hiredis-py The guard was required to prevent cluster tests on RESP3 with hiredis-py before 3.1.0 --------- Co-authored-by: petyaslavova --- .github/actions/run-tests/action.yml | 31 ++++++--- .github/workflows/hiredis-py-integration.yaml | 66 +++++++++++++++++++ 2 files changed, 89 insertions(+), 8 deletions(-) create mode 100644 .github/workflows/hiredis-py-integration.yaml diff --git a/.github/actions/run-tests/action.yml b/.github/actions/run-tests/action.yml index aa958a9236..ae9575e055 100644 --- a/.github/actions/run-tests/action.yml +++ b/.github/actions/run-tests/action.yml @@ -14,6 +14,10 @@ inputs: description: 'hiredis version to test against' required: false default: '>3.0.0' + hiredis-branch: + description: 'hiredis branch to test against' + required: false + default: 'master' event-loop: description: 'Event loop to use' required: false @@ -28,6 +32,14 @@ runs: python-version: ${{ inputs.python-version }} cache: 'pip' + - uses: actions/checkout@v4 + if: ${{ inputs.parser-backend == 'hiredis' && inputs.hiredis-version == 'unstable' }} + with: + repository: redis/hiredis-py + submodules: true + path: hiredis-py + ref: ${{ inputs.hiredis-branch }} + - name: Setup Test environment env: REDIS_VERSION: ${{ inputs.redis-version }} @@ -40,8 +52,13 @@ runs: pip uninstall -y redis # uninstall Redis package installed via redis-entraid pip install -e .[jwt] # install the working copy if [ "${{inputs.parser-backend}}" == "hiredis" ]; then - pip install "hiredis${{inputs.hiredis-version}}" - echo "PARSER_BACKEND=$(echo "${{inputs.parser-backend}}_${{inputs.hiredis-version}}" | sed 's/[^a-zA-Z0-9]/_/g')" >> $GITHUB_ENV + if [[ "${{inputs.hiredis-version}}" == "unstable" ]]; then + echo "Installing unstable version of hiredis from local directory" + pip install -e ./hiredis-py + else + pip install "hiredis${{inputs.hiredis-version}}" + fi + 
echo "PARSER_BACKEND=$(echo "${{inputs.parser-backend}}_${{inputs.hiredis-version}}" | sed 's/[^a-zA-Z0-9]/_/g')" >> $GITHUB_ENV else echo "PARSER_BACKEND=${{inputs.parser-backend}}" >> $GITHUB_ENV fi @@ -108,12 +125,10 @@ runs: fi echo "::endgroup::" - - if [ "$protocol" == "2" ] || [ "${{inputs.parser-backend}}" != 'hiredis' ]; then - echo "::group::RESP${protocol} cluster tests" - invoke cluster-tests $eventloop --protocol=${protocol} - echo "::endgroup::" - fi + + echo "::group::RESP${protocol} cluster tests" + invoke cluster-tests $eventloop --protocol=${protocol} + echo "::endgroup::" } run_tests 2 "${{inputs.event-loop}}" diff --git a/.github/workflows/hiredis-py-integration.yaml b/.github/workflows/hiredis-py-integration.yaml new file mode 100644 index 0000000000..816a143fba --- /dev/null +++ b/.github/workflows/hiredis-py-integration.yaml @@ -0,0 +1,66 @@ +name: Hiredis-py integration tests + +on: + workflow_dispatch: + inputs: + redis-py-branch: + description: 'redis-py branch to run tests on' + required: true + default: 'master' + hiredis-branch: + description: 'hiredis-py branch to run tests on' + required: true + default: 'master' + +concurrency: + group: ${{ github.event.pull_request.number || github.ref }}-hiredis-integration + cancel-in-progress: true + +permissions: + contents: read # to fetch code (actions/checkout) + +env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + # this speeds up coverage with Python 3.12: https://github.com/nedbat/coveragepy/issues/1665 + COVERAGE_CORE: sysmon + CURRENT_CLIENT_LIBS_TEST_STACK_IMAGE_TAG: 'rs-7.4.0-v2' + CURRENT_REDIS_VERSION: '7.4.2' + +jobs: + redis_version: + runs-on: ubuntu-latest + outputs: + CURRENT: ${{ env.CURRENT_REDIS_VERSION }} + steps: + - name: Compute outputs + run: | + echo "CURRENT=${{ env.CURRENT_REDIS_VERSION }}" >> $GITHUB_OUTPUT + + hiredis-tests: + runs-on: ubuntu-latest + needs: [redis_version] + timeout-minutes: 60 + strategy: + max-parallel: 15 + fail-fast: false + matrix: + redis-version: [ '${{ needs.redis_version.outputs.CURRENT }}' ] + python-version: [ '3.8', '3.13'] + parser-backend: [ 'hiredis' ] + hiredis-version: [ 'unstable' ] + hiredis-branch: ${{ inputs.hiredis-branch }} + event-loop: [ 'asyncio' ] + env: + ACTIONS_ALLOW_UNSECURE_COMMANDS: true + name: Redis ${{ matrix.redis-version }}; Python ${{ matrix.python-version }}; RESP Parser:${{matrix.parser-backend}} (${{ matrix.hiredis-version }}); EL:${{matrix.event-loop}} + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ inputs.redis-py-branch }} + - name: Run tests + uses: ./.github/actions/run-tests + with: + python-version: ${{ matrix.python-version }} + parser-backend: ${{ matrix.parser-backend }} + redis-version: ${{ matrix.redis-version }} + hiredis-version: ${{ matrix.hiredis-version }} \ No newline at end of file From cda3fbd75f2eb195a01aa49637570bc15d695102 Mon Sep 17 00:00:00 2001 From: Igor Malinovskiy Date: Fri, 9 May 2025 17:10:42 +0200 Subject: [PATCH 107/113] Fix matrix in hiredis-py-integration.yaml (#3641) --- .github/workflows/hiredis-py-integration.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/hiredis-py-integration.yaml b/.github/workflows/hiredis-py-integration.yaml index 816a143fba..947a0df05b 100644 --- a/.github/workflows/hiredis-py-integration.yaml +++ b/.github/workflows/hiredis-py-integration.yaml @@ -48,7 +48,6 @@ jobs: python-version: [ '3.8', '3.13'] parser-backend: [ 'hiredis' ] hiredis-version: [ 'unstable' ] - hiredis-branch: ${{ inputs.hiredis-branch }} 
event-loop: [ 'asyncio' ] env: ACTIONS_ALLOW_UNSECURE_COMMANDS: true @@ -63,4 +62,5 @@ jobs: python-version: ${{ matrix.python-version }} parser-backend: ${{ matrix.parser-backend }} redis-version: ${{ matrix.redis-version }} - hiredis-version: ${{ matrix.hiredis-version }} \ No newline at end of file + hiredis-version: ${{ matrix.hiredis-version }} + hiredis-branch: ${{ inputs.hiredis-branch }} \ No newline at end of file From 6ef5d7167ccdff5e4fa84ec2ff68349aeccf8b0c Mon Sep 17 00:00:00 2001 From: Igor Malinovskiy Date: Mon, 12 May 2025 08:50:11 +0200 Subject: [PATCH 108/113] Export REDIS_MAJOR_VERSION correctly in run-tests (#3642) --- .github/actions/run-tests/action.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/actions/run-tests/action.yml b/.github/actions/run-tests/action.yml index ae9575e055..20cceb922b 100644 --- a/.github/actions/run-tests/action.yml +++ b/.github/actions/run-tests/action.yml @@ -66,6 +66,7 @@ runs: echo "::group::Starting Redis servers" redis_major_version=$(echo "$REDIS_VERSION" | grep -oP '^\d+') + echo "REDIS_MAJOR_VERSION=${redis_major_version}" >> $GITHUB_ENV if (( redis_major_version < 8 )); then echo "Using redis-stack for module tests" @@ -87,8 +88,7 @@ runs: if (( redis_major_version < 7 )); then export REDIS_STACK_EXTRA_ARGS="--tls-auth-clients optional --save ''" - export REDIS_EXTRA_ARGS="--tls-auth-clients optional --save ''" - echo "REDIS_MAJOR_VERSION=${redis_major_version}" >> $GITHUB_ENV + export REDIS_EXTRA_ARGS="--tls-auth-clients optional --save ''" fi invoke devenv --endpoints=all-stack From a2f7e4b4d91ef8e3f9f05fea4e06df50c88efdea Mon Sep 17 00:00:00 2001 From: robertosantamaria-scopely <136002179+robertosantamaria-scopely@users.noreply.github.com> Date: Mon, 12 May 2025 12:54:29 +0200 Subject: [PATCH 109/113] Multi exec on cluster (#3611) * feat(cluster): support for transactions on cluster-aware client Adds support for transactions based on multi/watch/exec on clusters. Transactions in this mode are limited to a single hash slot. 
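For illustration, a transaction that satisfies the single-slot constraint by giving all keys the same hash tag (the `rc` instance, port, and key names are assumptions, not part of this patch); this mirrors the usage documented further down in docs/advanced_features.rst:

    from redis.cluster import RedisCluster

    rc = RedisCluster(host="localhost", port=16379)

    # Keys sharing the "{user:42}" hash tag map to one hash slot, which is
    # the precondition for a cluster transaction described above.
    with rc.pipeline(transaction=True) as pipe:
        pipe.set("{user:42}:name", "alice")
        pipe.incr("{user:42}:visits")
        pipe.get("{user:42}:name")
        results = pipe.execute()
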
Contributed-by: Scopely * fix: remove deprecated argument * remove attributions from code * Refactor ClusterPipeline to use execution strategies * Refactored strategy to use composition * Added test cases * Sync with master * Filter tests, ensure that tests are working after refactor * Added test case * Revert port changes * Improved exception handling * Change visibility of variable to public * Changed variable ref * Changed ref type * Added documentation * Refactored retries, fixed comments, fixed linters * Added word to a wordlist * Revert port changes * Added quotes * Fixed docs * Updated CONNECTION_ERRORS * Codestyle fixes * Updated docs * Revert import --------- Co-authored-by: vladvildanov Co-authored-by: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> --- .github/wordlist.txt | 1 + .gitignore | 1 + CHANGES | 1 + docs/advanced_features.rst | 119 ++- redis/__init__.py | 6 + redis/client.py | 15 +- redis/cluster.py | 1290 ++++++++++++++++++++++------- redis/exceptions.py | 18 + tests/test_cluster.py | 115 ++- tests/test_cluster_transaction.py | 392 +++++++++ 10 files changed, 1601 insertions(+), 357 deletions(-) create mode 100644 tests/test_cluster_transaction.py diff --git a/.github/wordlist.txt b/.github/wordlist.txt index 29bcaa9d77..150f96a624 100644 --- a/.github/wordlist.txt +++ b/.github/wordlist.txt @@ -2,6 +2,7 @@ APM ARGV BFCommands CacheImpl +CAS CFCommands CMSCommands ClusterNode diff --git a/.gitignore b/.gitignore index 5f77dcfde4..7184ad4e20 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ vagrant/.vagrant .cache .eggs .idea +.vscode .coverage env venv diff --git a/CHANGES b/CHANGES index 50126a86d7..1a1f4eca11 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Support transactions in ClusterPipeline * Removing support for RedisGraph module. RedisGraph support is deprecated since Redis Stack 7.2 (https://redis.com/blog/redisgraph-eol/) * Fix lock.extend() typedef to accept float TTL extension * Update URL in the readme linking to Redis University diff --git a/docs/advanced_features.rst b/docs/advanced_features.rst index 603e728e84..11ab8af716 100644 --- a/docs/advanced_features.rst +++ b/docs/advanced_features.rst @@ -167,6 +167,7 @@ the server. .. code:: python + >>> rc = RedisCluster() >>> with rc.pipeline() as pipe: ... pipe.set('foo', 'value1') ... pipe.set('bar', 'value2') @@ -177,20 +178,110 @@ the server. ... pipe.set('foo1', 'bar1').get('foo1').execute() [True, b'bar1'] -Please note: - RedisCluster pipelines currently only support key-based -commands. - The pipeline gets its ‘read_from_replicas’ value from the -cluster’s parameter. Thus, if read from replications is enabled in the -cluster instance, the pipeline will also direct read commands to -replicas. - The ‘transaction’ option is NOT supported in cluster-mode. -In non-cluster mode, the ‘transaction’ option is available when -executing pipelines. This wraps the pipeline commands with MULTI/EXEC -commands, and effectively turns the pipeline commands into a single -transaction block. This means that all commands are executed -sequentially without any interruptions from other clients. However, in -cluster-mode this is not possible, because commands are partitioned -according to their respective destination nodes. This means that we can -not turn the pipeline commands into one transaction block, because in -most cases they are split up into several smaller pipelines. +Please note: + +- RedisCluster pipelines currently only support key-based commands. 
+- The pipeline gets its ‘load_balancing_strategy’ value from the + cluster’s parameter. Thus, if read from replications is enabled in + the cluster instance, the pipeline will also direct read commands to + replicas. + + +Transactions in clusters +~~~~~~~~~~~~~~~~~~~~~~~~ + +Transactions are supported in cluster-mode with one caveat: all keys of +all commands issued on a transaction pipeline must reside on the +same slot. This is similar to the limitation of multikey commands in +cluster. The reason behind this is that the Redis engine does not offer +a mechanism to block or exchange key data across nodes on the fly. A +client may add some logic to abstract engine limitations when running +on a cluster, such as the pipeline behavior explained on the previous +block, but there is no simple way that a client can enforce atomicity +across nodes on a distributed system. + +The compromise of limiting the transaction pipeline to same-slot keys +is exactly that: a compromise. While this behavior is different from +non-transactional cluster pipelines, it simplifies migration of clients +from standalone to cluster under some circumstances. Note that application +code that issues multi/exec commands on a standalone client without +embedding them within a pipeline would eventually get ‘AttributeError’. +With this approach, if the application uses ‘client.pipeline(transaction=True)’, +then switching the client with a cluster-aware instance would simplify +code changes (to some extent). This may be true for application code that +makes use of hash keys, since its transactions may already be +mapping all commands to the same slot. + +An alternative is some kind of two-step commit solution, where a slot +validation is run before the actual commands are run. This could work +with controlled node maintenance but does not cover single node failures. + +Given the cluster limitations for transactions, by default pipeline isn't in +transactional mode. To enable transactional context set: + +.. code:: python + + >>> p = rc.pipeline(transaction=True) + +After entering the transactional context you can add commands to a transactional +context, by one of the following ways: + +.. code:: python + + >>> p = rc.pipeline(transaction=True) # Chaining commands + >>> p.set("key", "value") + >>> p.get("key") + >>> response = p.execute() + +Or + +.. code:: python + + >>> with rc.pipeline(transaction=True) as pipe: # Using context manager + ... pipe.set("key", "value") + ... pipe.get("key") + ... response = pipe.execute() + +As you see there's no need to explicitly send `MULTI/EXEC` commands to control context start/end +`ClusterPipeline` will take care of it. + +To ensure that different keys will be mapped to a same hash slot on the server side +prepend your keys with the same hash tag, the technique that allows you to control +keys distribution. +More information `here `_ + +.. code:: python + + >>> with rc.pipeline(transaction=True) as pipe: + ... pipe.set("{tag}foo", "bar") + ... pipe.set("{tag}bar", "foo") + ... pipe.get("{tag}foo") + ... pipe.get("{tag}bar") + ... response = pipe.execute() + +CAS Transactions +~~~~~~~~~~~~~~~~~~~~~~~~ + +If you want to apply optimistic locking for certain keys, you have to execute +`WATCH` command in transactional context. `WATCH` command follows the same limitations +as any other multi key command - all keys should be mapped to the same hash slot. 
+ +However, the difference between CAS transaction and normal one is that you have to +explicitly call MULTI command to indicate the start of transactional context, WATCH +command itself and any subsequent commands before MULTI will be immediately executed +on the server side so you can apply optimistic locking and get necessary data before +transaction execution. + +.. code:: python + + >>> with rc.pipeline(transaction=True) as pipe: + ... pipe.watch("mykey") # Apply locking by immediately executing command + ... val = pipe.get("mykey") # Immediately retrieves value + ... val = val + 1 # Increment value + ... pipe.multi() # Starting transaction context + ... pipe.set("mykey", val) # Command will be pipelined + ... response = pipe.execute() # Returns OK or None if key was modified in the meantime + Publish / Subscribe ------------------- diff --git a/redis/__init__.py b/redis/__init__.py index f82a876b2d..14030205e3 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -16,11 +16,14 @@ BusyLoadingError, ChildDeadlockedError, ConnectionError, + CrossSlotTransactionError, DataError, + InvalidPipelineStack, InvalidResponse, OutOfMemoryError, PubSubError, ReadOnlyError, + RedisClusterException, RedisError, ResponseError, TimeoutError, @@ -56,15 +59,18 @@ def int_or_str(value): "ConnectionError", "ConnectionPool", "CredentialProvider", + "CrossSlotTransactionError", "DataError", "from_url", "default_backoff", + "InvalidPipelineStack", "InvalidResponse", "OutOfMemoryError", "PubSubError", "ReadOnlyError", "Redis", "RedisCluster", + "RedisClusterException", "RedisError", "ResponseError", "Sentinel", diff --git a/redis/client.py b/redis/client.py index 2ef95600c2..dc4f0f9d0c 100755 --- a/redis/client.py +++ b/redis/client.py @@ -34,6 +34,7 @@ from redis.commands.core import Script from redis.connection import ( AbstractConnection, + Connection, ConnectionPool, SSLConnection, UnixDomainSocketConnection, @@ -1297,9 +1298,15 @@ class Pipeline(Redis): UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} - def __init__(self, connection_pool, response_callbacks, transaction, shard_hint): + def __init__( + self, + connection_pool: ConnectionPool, + response_callbacks, + transaction, + shard_hint, + ): self.connection_pool = connection_pool - self.connection = None + self.connection: Optional[Connection] = None self.response_callbacks = response_callbacks self.transaction = transaction self.shard_hint = shard_hint @@ -1434,7 +1441,9 @@ def pipeline_execute_command(self, *args, **options) -> "Pipeline": self.command_stack.append((args, options)) return self - def _execute_transaction(self, connection, commands, raise_on_error) -> List: + def _execute_transaction( + self, connection: Connection, commands, raise_on_error + ) -> List: cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})]) all_cmds = connection.pack_commands( [args for args, options in cmds if EMPTY_RESPONSE not in options] diff --git a/redis/cluster.py b/redis/cluster.py index 6e3505404a..b614c598f9 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -3,18 +3,25 @@ import sys import threading import time +from abc import ABC, abstractmethod from collections import OrderedDict +from copy import copy from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from itertools import chain +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan from redis.backoff import 
ExponentialWithJitterBackoff, NoBackoff from redis.cache import CacheConfig, CacheFactory, CacheFactoryInterface, CacheInterface -from redis.client import CaseInsensitiveDict, PubSub, Redis +from redis.client import EMPTY_RESPONSE, CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands from redis.commands.helpers import list_or_args -from redis.connection import ConnectionPool, parse_url +from redis.connection import ( + Connection, + ConnectionPool, + parse_url, +) from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.event import ( AfterPooledConnectionsInstantiationEvent, @@ -28,7 +35,10 @@ ClusterDownError, ClusterError, ConnectionError, + CrossSlotTransactionError, DataError, + ExecAbortError, + InvalidPipelineStack, MovedError, RedisClusterException, RedisError, @@ -36,6 +46,7 @@ SlotNotCoveredError, TimeoutError, TryAgainError, + WatchError, ) from redis.lock import Lock from redis.retry import Retry @@ -60,7 +71,7 @@ def get_node_name(host: str, port: Union[str, int]) -> str: reason="Use get_connection(redis_node) instead", version="5.3.0", ) -def get_connection(redis_node, *args, **options): +def get_connection(redis_node: Redis, *args, **options) -> Connection: return redis_node.connection or redis_node.connection_pool.get_connection() @@ -741,7 +752,7 @@ def on_connect(self, connection): if self.user_on_connect_func is not None: self.user_on_connect_func(connection) - def get_redis_connection(self, node): + def get_redis_connection(self, node: "ClusterNode") -> Redis: if not node.redis_connection: with self._lock: if not node.redis_connection: @@ -839,9 +850,6 @@ def pipeline(self, transaction=None, shard_hint=None): if shard_hint: raise RedisClusterException("shard_hint is deprecated in cluster mode") - if transaction: - raise RedisClusterException("transaction is deprecated in cluster mode") - return ClusterPipeline( nodes_manager=self.nodes_manager, commands_parser=self.commands_parser, @@ -854,6 +862,7 @@ def pipeline(self, transaction=None, shard_hint=None): reinitialize_steps=self.reinitialize_steps, retry=self.retry, lock=self._lock, + transaction=transaction, ) def lock( @@ -1015,7 +1024,7 @@ def _get_command_keys(self, *args): redis_conn = self.get_default_node().redis_connection return self.commands_parser.get_keys(redis_conn, *args) - def determine_slot(self, *args): + def determine_slot(self, *args) -> int: """ Figure out what slot to use based on args. @@ -1228,8 +1237,6 @@ def _execute_command(self, target_node, *args, **kwargs): except AuthenticationError: raise except (ConnectionError, TimeoutError) as e: - # Connection retries are being handled in the node's - # Retry object. # ConnectionError can also be raised if we couldn't get a # connection from the pool before timing out, so check that # this is an actual connection before attempting to disconnect. @@ -1330,6 +1337,28 @@ def load_external_module(self, funcname, func): """ setattr(self, funcname, func) + def transaction(self, func, *watches, **kwargs): + """ + Convenience method for executing the callable `func` as a transaction + while watching all keys specified in `watches`. The 'func' callable + should expect a single argument which is a Pipeline object. 
+ """ + shard_hint = kwargs.pop("shard_hint", None) + value_from_callable = kwargs.pop("value_from_callable", False) + watch_delay = kwargs.pop("watch_delay", None) + with self.pipeline(True, shard_hint) as pipe: + while True: + try: + if watches: + pipe.watch(*watches) + func_value = func(pipe) + exec_value = pipe.execute() + return func_value if value_from_callable else exec_value + except WatchError: + if watch_delay is not None and watch_delay > 0: + time.sleep(watch_delay) + continue + class ClusterNode: def __init__(self, host, port, server_type=None, redis_connection=None): @@ -1427,7 +1456,7 @@ def __init__( event_dispatcher: Optional[EventDispatcher] = None, **kwargs, ): - self.nodes_cache = {} + self.nodes_cache: Dict[str, Redis] = {} self.slots_cache = {} self.startup_nodes = {} self.default_node = None @@ -1527,7 +1556,7 @@ def get_node_from_slot( read_from_replicas=False, load_balancing_strategy=None, server_type=None, - ): + ) -> ClusterNode: """ Gets a node that servers this hash slot """ @@ -1823,6 +1852,16 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: return self.address_remap((host, port)) return host, port + def find_connection_owner(self, connection: Connection) -> Optional[Redis]: + node_name = get_node_name(connection.host, connection.port) + for node in tuple(self.nodes_cache.values()): + if node.redis_connection: + conn_args = node.redis_connection.connection_pool.connection_kwargs + if node_name == get_node_name( + conn_args.get("host"), conn_args.get("port") + ): + return node + class ClusterPubSub(PubSub): """ @@ -2082,6 +2121,10 @@ class ClusterPipeline(RedisCluster): TryAgainError, ) + NO_SLOTS_COMMANDS = {"UNWATCH"} + IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"} + UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} + @deprecated_args( args_to_warn=[ "cluster_error_retry_attempts", @@ -2102,6 +2145,7 @@ def __init__( reinitialize_steps: int = 5, retry: Optional[Retry] = None, lock=None, + transaction=False, **kwargs, ): """ """ @@ -2135,6 +2179,10 @@ def __init__( if lock is None: lock = threading.Lock() self._lock = lock + self.parent_execute_command = super().execute_command + self._execution_strategy: ExecutionStrategy = ( + PipelineStrategy(self) if not transaction else TransactionStrategy(self) + ) def __repr__(self): """ """ @@ -2156,7 +2204,7 @@ def __del__(self): def __len__(self): """ """ - return len(self.command_stack) + return len(self._execution_strategy.command_queue) def __bool__(self): "Pipeline instances should always evaluate to True on Python 3+" @@ -2166,45 +2214,35 @@ def execute_command(self, *args, **kwargs): """ Wrapper function for pipeline_execute_command """ - return self.pipeline_execute_command(*args, **kwargs) + return self._execution_strategy.execute_command(*args, **kwargs) def pipeline_execute_command(self, *args, **options): """ - Appends the executed command to the pipeline's command stack - """ - self.command_stack.append( - PipelineCommand(args, options, len(self.command_stack)) - ) - return self + Stage a command to be executed when execute() is next called - def raise_first_error(self, stack): - """ - Raise the first exception on the stack + Returns the current Pipeline object back so commands can be + chained together, such as: + + pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + + At some other point, you can then run: pipe.execute(), + which will execute all commands queued in the pipe. 
""" - for c in stack: - r = c.result - if isinstance(r, Exception): - self.annotate_exception(r, c.position + 1, c.args) - raise r + return self._execution_strategy.execute_command(*args, **options) def annotate_exception(self, exception, number, command): """ Provides extra context to the exception prior to it being handled """ - cmd = " ".join(map(safe_str, command)) - msg = ( - f"Command # {number} ({truncate_text(cmd)}) of pipeline " - f"caused error: {exception.args[0]}" - ) - exception.args = (msg,) + exception.args[1:] + self._execution_strategy.annotate_exception(exception, number, command) def execute(self, raise_on_error: bool = True) -> List[Any]: """ Execute all the commands in the current pipeline """ - stack = self.command_stack + try: - return self.send_cluster_commands(stack, raise_on_error) + return self._execution_strategy.execute(raise_on_error) finally: self.reset() @@ -2212,312 +2250,53 @@ def reset(self): """ Reset back to empty pipeline. """ - self.command_stack = [] - - self.scripts = set() - - # TODO: Implement - # make sure to reset the connection state in the event that we were - # watching something - # if self.watching and self.connection: - # try: - # # call this manually since our unwatch or - # # immediate_execute_command methods can call reset() - # self.connection.send_command('UNWATCH') - # self.connection.read_response() - # except ConnectionError: - # # disconnect will also remove any previous WATCHes - # self.connection.disconnect() - - # clean up the other instance attributes - self.watching = False - self.explicit_transaction = False - - # TODO: Implement - # we can safely return the connection to the pool here since we're - # sure we're no longer WATCHing anything - # if self.connection: - # self.connection_pool.release(self.connection) - # self.connection = None + self._execution_strategy.reset() def send_cluster_commands( self, stack, raise_on_error=True, allow_redirections=True ): - """ - Wrapper for CLUSTERDOWN error handling. - - If the cluster reports it is down it is assumed that: - - connection_pool was disconnected - - connection_pool was reseted - - refereh_table_asap set to True - - It will try the number of times specified by - the retries in config option "self.retry" - which defaults to 3 unless manually configured. - - If it reaches the number of times, the command will - raises ClusterDownException. - """ - if not stack: - return [] - retry_attempts = self.retry.get_retries() - while True: - try: - return self._send_cluster_commands( - stack, - raise_on_error=raise_on_error, - allow_redirections=allow_redirections, - ) - except RedisCluster.ERRORS_ALLOW_RETRY as e: - if retry_attempts > 0: - # Try again with the new cluster setup. All other errors - # should be raised. - retry_attempts -= 1 - pass - else: - raise e - - def _send_cluster_commands( - self, stack, raise_on_error=True, allow_redirections=True - ): - """ - Send a bunch of cluster commands to the redis cluster. - - `allow_redirections` If the pipeline should follow - `ASK` & `MOVED` responses automatically. If set - to false it will raise RedisClusterException. - """ - # the first time sending the commands we send all of - # the commands that were queued up. - # if we have to run through it again, we only retry - # the commands that failed. 
- attempt = sorted(stack, key=lambda x: x.position) - is_default_node = False - # build a list of node objects based on node names we need to - nodes = {} - - # as we move through each command that still needs to be processed, - # we figure out the slot number that command maps to, then from - # the slot determine the node. - for c in attempt: - while True: - # refer to our internal node -> slot table that - # tells us where a given command should route to. - # (it might be possible we have a cached node that no longer - # exists in the cluster, which is why we do this in a loop) - passed_targets = c.options.pop("target_nodes", None) - if passed_targets and not self._is_nodes_flag(passed_targets): - target_nodes = self._parse_target_nodes(passed_targets) - else: - target_nodes = self._determine_nodes( - *c.args, node_flag=passed_targets - ) - if not target_nodes: - raise RedisClusterException( - f"No targets were found to execute {c.args} command on" - ) - if len(target_nodes) > 1: - raise RedisClusterException( - f"Too many targets for command {c.args}" - ) - - node = target_nodes[0] - if node == self.get_default_node(): - is_default_node = True - - # now that we know the name of the node - # ( it's just a string in the form of host:port ) - # we can build a list of commands for each node. - node_name = node.name - if node_name not in nodes: - redis_node = self.get_redis_connection(node) - try: - connection = get_connection(redis_node) - except (ConnectionError, TimeoutError): - for n in nodes.values(): - n.connection_pool.release(n.connection) - # Connection retries are being handled in the node's - # Retry object. Reinitialize the node -> slot table. - self.nodes_manager.initialize() - if is_default_node: - self.replace_default_node() - raise - nodes[node_name] = NodeCommands( - redis_node.parse_response, - redis_node.connection_pool, - connection, - ) - nodes[node_name].append(c) - break - - # send the commands in sequence. - # we write to all the open sockets for each node first, - # before reading anything - # this allows us to flush all the requests out across the - # network essentially in parallel - # so that we can read them all in parallel as they come back. - # we dont' multiplex on the sockets as they come available, - # but that shouldn't make too much difference. - node_commands = nodes.values() - try: - node_commands = nodes.values() - for n in node_commands: - n.write() - - for n in node_commands: - n.read() - finally: - # release all of the redis connections we allocated earlier - # back into the connection pool. - # we used to do this step as part of a try/finally block, - # but it is really dangerous to - # release connections back into the pool if for some - # reason the socket has data still left in it - # from a previous operation. The write and - # read operations already have try/catch around them for - # all known types of errors including connection - # and socket level errors. - # So if we hit an exception, something really bad - # happened and putting any oF - # these connections back into the pool is a very bad idea. - # the socket might have unread buffer still sitting in it, - # and then the next time we read from it we pass the - # buffered result back from a previous command and - # every single request after to that connection will always get - # a mismatched result. 
- for n in nodes.values(): - n.connection_pool.release(n.connection) - - # if the response isn't an exception it is a - # valid response from the node - # we're all done with that command, YAY! - # if we have more commands to attempt, we've run into problems. - # collect all the commands we are allowed to retry. - # (MOVED, ASK, or connection errors or timeout errors) - attempt = sorted( - ( - c - for c in attempt - if isinstance(c.result, ClusterPipeline.ERRORS_ALLOW_RETRY) - ), - key=lambda x: x.position, + return self._execution_strategy.send_cluster_commands( + stack, raise_on_error=raise_on_error, allow_redirections=allow_redirections ) - if attempt and allow_redirections: - # RETRY MAGIC HAPPENS HERE! - # send these remaining commands one at a time using `execute_command` - # in the main client. This keeps our retry logic - # in one place mostly, - # and allows us to be more confident in correctness of behavior. - # at this point any speed gains from pipelining have been lost - # anyway, so we might as well make the best - # attempt to get the correct behavior. - # - # The client command will handle retries for each - # individual command sequentially as we pass each - # one into `execute_command`. Any exceptions - # that bubble out should only appear once all - # retries have been exhausted. - # - # If a lot of commands have failed, we'll be setting the - # flag to rebuild the slots table from scratch. - # So MOVED errors should correct themselves fairly quickly. - self.reinitialize_counter += 1 - if self._should_reinitialized(): - self.nodes_manager.initialize() - if is_default_node: - self.replace_default_node() - for c in attempt: - try: - # send each command individually like we - # do in the main client. - c.result = super().execute_command(*c.args, **c.options) - except RedisError as e: - c.result = e - - # turn the response back into a simple flat array that corresponds - # to the sequence of commands issued in the stack in pipeline.execute() - response = [] - for c in sorted(stack, key=lambda x: x.position): - if c.args[0] in self.cluster_response_callbacks: - # Remove keys entry, it needs only for cache. - c.options.pop("keys", None) - c.result = self.cluster_response_callbacks[c.args[0]]( - c.result, **c.options - ) - response.append(c.result) - - if raise_on_error: - self.raise_first_error(stack) - - return response - - def _fail_on_redirect(self, allow_redirections): - """ """ - if not allow_redirections: - raise RedisClusterException( - "ASK & MOVED redirection not allowed in this pipeline" - ) def exists(self, *keys): - return self.execute_command("EXISTS", *keys) + return self._execution_strategy.exists(*keys) def eval(self): """ """ - raise RedisClusterException("method eval() is not implemented") + return self._execution_strategy.eval() def multi(self): - """ """ - raise RedisClusterException("method multi() is not implemented") - - def immediate_execute_command(self, *args, **options): - """ """ - raise RedisClusterException( - "method immediate_execute_command() is not implemented" - ) + """ + Start a transactional block of the pipeline after WATCH commands + are issued. End the transactional block with `execute`. 
+ """ + self._execution_strategy.multi() - def _execute_transaction(self, *args, **kwargs): + def load_scripts(self): """ """ - raise RedisClusterException("method _execute_transaction() is not implemented") + self._execution_strategy.load_scripts() - def load_scripts(self): + def discard(self): """ """ - raise RedisClusterException("method load_scripts() is not implemented") + self._execution_strategy.discard() def watch(self, *names): - """ """ - raise RedisClusterException("method watch() is not implemented") + """Watches the values at keys ``names``""" + self._execution_strategy.watch(*names) def unwatch(self): - """ """ - raise RedisClusterException("method unwatch() is not implemented") + """Unwatches all previously specified keys""" + self._execution_strategy.unwatch() def script_load_for_pipeline(self, *args, **kwargs): - """ """ - raise RedisClusterException( - "method script_load_for_pipeline() is not implemented" - ) + self._execution_strategy.script_load_for_pipeline(*args, **kwargs) def delete(self, *names): - """ - "Delete a key specified by ``names``" - """ - if len(names) != 1: - raise RedisClusterException( - "deleting multiple keys is not implemented in pipeline command" - ) - - return self.execute_command("DEL", names[0]) + self._execution_strategy.delete(*names) def unlink(self, *names): - """ - "Unlink a key specified by ``names``" - """ - if len(names) != 1: - raise RedisClusterException( - "unlinking multiple keys is not implemented in pipeline command" - ) - - return self.execute_command("UNLINK", names[0]) + self._execution_strategy.unlink(*names) def block_pipeline_command(name: str) -> Callable[..., Any]: @@ -2694,3 +2473,880 @@ def read(self): return except RedisError: c.result = sys.exc_info()[1] + + +class ExecutionStrategy(ABC): + @property + @abstractmethod + def command_queue(self): + pass + + @abstractmethod + def execute_command(self, *args, **kwargs): + """ + Execution flow for current execution strategy. + + See: ClusterPipeline.execute_command() + """ + pass + + @abstractmethod + def annotate_exception(self, exception, number, command): + """ + Annotate exception according to current execution strategy. + + See: ClusterPipeline.annotate_exception() + """ + pass + + @abstractmethod + def pipeline_execute_command(self, *args, **options): + """ + Pipeline execution flow for current execution strategy. + + See: ClusterPipeline.pipeline_execute_command() + """ + pass + + @abstractmethod + def execute(self, raise_on_error: bool = True) -> List[Any]: + """ + Executes current execution strategy. + + See: ClusterPipeline.execute() + """ + pass + + @abstractmethod + def send_cluster_commands( + self, stack, raise_on_error=True, allow_redirections=True + ): + """ + Sends commands according to current execution strategy. + + See: ClusterPipeline.send_cluster_commands() + """ + pass + + @abstractmethod + def reset(self): + """ + Resets current execution strategy. + + See: ClusterPipeline.reset() + """ + pass + + @abstractmethod + def exists(self, *keys): + pass + + @abstractmethod + def eval(self): + pass + + @abstractmethod + def multi(self): + """ + Starts transactional context. 
+ + See: ClusterPipeline.multi() + """ + pass + + @abstractmethod + def load_scripts(self): + pass + + @abstractmethod + def watch(self, *names): + pass + + @abstractmethod + def unwatch(self): + """ + Unwatches all previously specified keys + + See: ClusterPipeline.unwatch() + """ + pass + + @abstractmethod + def script_load_for_pipeline(self, *args, **kwargs): + pass + + @abstractmethod + def delete(self, *names): + """ + "Delete a key specified by ``names``" + + See: ClusterPipeline.delete() + """ + pass + + @abstractmethod + def unlink(self, *names): + """ + "Unlink a key specified by ``names``" + + See: ClusterPipeline.unlink() + """ + pass + + @abstractmethod + def discard(self): + pass + + +class AbstractStrategy(ExecutionStrategy): + def __init__( + self, + pipe: ClusterPipeline, + ): + self._command_queue: List[PipelineCommand] = [] + self._pipe = pipe + self._nodes_manager = self._pipe.nodes_manager + + @property + def command_queue(self): + return self._command_queue + + @command_queue.setter + def command_queue(self, queue: List[PipelineCommand]): + self._command_queue = queue + + @abstractmethod + def execute_command(self, *args, **kwargs): + pass + + def pipeline_execute_command(self, *args, **options): + self._command_queue.append( + PipelineCommand(args, options, len(self._command_queue)) + ) + return self._pipe + + @abstractmethod + def execute(self, raise_on_error: bool = True) -> List[Any]: + pass + + @abstractmethod + def send_cluster_commands( + self, stack, raise_on_error=True, allow_redirections=True + ): + pass + + @abstractmethod + def reset(self): + pass + + def exists(self, *keys): + return self.execute_command("EXISTS", *keys) + + def eval(self): + """ """ + raise RedisClusterException("method eval() is not implemented") + + def load_scripts(self): + """ """ + raise RedisClusterException("method load_scripts() is not implemented") + + def script_load_for_pipeline(self, *args, **kwargs): + """ """ + raise RedisClusterException( + "method script_load_for_pipeline() is not implemented" + ) + + def annotate_exception(self, exception, number, command): + """ + Provides extra context to the exception prior to it being handled + """ + cmd = " ".join(map(safe_str, command)) + msg = ( + f"Command # {number} ({truncate_text(cmd)}) of pipeline " + f"caused error: {exception.args[0]}" + ) + exception.args = (msg,) + exception.args[1:] + + +class PipelineStrategy(AbstractStrategy): + def __init__(self, pipe: ClusterPipeline): + super().__init__(pipe) + self.command_flags = pipe.command_flags + + def execute_command(self, *args, **kwargs): + return self.pipeline_execute_command(*args, **kwargs) + + def _raise_first_error(self, stack): + """ + Raise the first exception on the stack + """ + for c in stack: + r = c.result + if isinstance(r, Exception): + self.annotate_exception(r, c.position + 1, c.args) + raise r + + def execute(self, raise_on_error: bool = True) -> List[Any]: + stack = self._command_queue + if not stack: + return [] + + try: + return self.send_cluster_commands(stack, raise_on_error) + finally: + self.reset() + + def reset(self): + """ + Reset back to empty pipeline. + """ + self._command_queue = [] + + def send_cluster_commands( + self, stack, raise_on_error=True, allow_redirections=True + ): + """ + Wrapper for RedisCluster.ERRORS_ALLOW_RETRY errors handling. 
+ + If one of the retryable exceptions has been thrown we assume that: + - connection_pool was disconnected + - connection_pool was reseted + - refereh_table_asap set to True + + It will try the number of times specified by + the retries in config option "self.retry" + which defaults to 3 unless manually configured. + + If it reaches the number of times, the command will + raises ClusterDownException. + """ + if not stack: + return [] + retry_attempts = self._pipe.retry.get_retries() + while True: + try: + return self._send_cluster_commands( + stack, + raise_on_error=raise_on_error, + allow_redirections=allow_redirections, + ) + except RedisCluster.ERRORS_ALLOW_RETRY as e: + if retry_attempts > 0: + # Try again with the new cluster setup. All other errors + # should be raised. + retry_attempts -= 1 + pass + else: + raise e + + def _send_cluster_commands( + self, stack, raise_on_error=True, allow_redirections=True + ): + """ + Send a bunch of cluster commands to the redis cluster. + + `allow_redirections` If the pipeline should follow + `ASK` & `MOVED` responses automatically. If set + to false it will raise RedisClusterException. + """ + # the first time sending the commands we send all of + # the commands that were queued up. + # if we have to run through it again, we only retry + # the commands that failed. + attempt = sorted(stack, key=lambda x: x.position) + is_default_node = False + # build a list of node objects based on node names we need to + nodes = {} + + # as we move through each command that still needs to be processed, + # we figure out the slot number that command maps to, then from + # the slot determine the node. + for c in attempt: + while True: + # refer to our internal node -> slot table that + # tells us where a given command should route to. + # (it might be possible we have a cached node that no longer + # exists in the cluster, which is why we do this in a loop) + passed_targets = c.options.pop("target_nodes", None) + if passed_targets and not self._is_nodes_flag(passed_targets): + target_nodes = self._parse_target_nodes(passed_targets) + else: + target_nodes = self._determine_nodes( + *c.args, node_flag=passed_targets + ) + if not target_nodes: + raise RedisClusterException( + f"No targets were found to execute {c.args} command on" + ) + if len(target_nodes) > 1: + raise RedisClusterException( + f"Too many targets for command {c.args}" + ) + + node = target_nodes[0] + if node == self._pipe.get_default_node(): + is_default_node = True + + # now that we know the name of the node + # ( it's just a string in the form of host:port ) + # we can build a list of commands for each node. + node_name = node.name + if node_name not in nodes: + redis_node = self._pipe.get_redis_connection(node) + try: + connection = get_connection(redis_node) + except (ConnectionError, TimeoutError): + for n in nodes.values(): + n.connection_pool.release(n.connection) + # Connection retries are being handled in the node's + # Retry object. Reinitialize the node -> slot table. + self._nodes_manager.initialize() + if is_default_node: + self._pipe.replace_default_node() + raise + nodes[node_name] = NodeCommands( + redis_node.parse_response, + redis_node.connection_pool, + connection, + ) + nodes[node_name].append(c) + break + + # send the commands in sequence. + # we write to all the open sockets for each node first, + # before reading anything + # this allows us to flush all the requests out across the + # network + # so that we can read them from different sockets as they come back. 
+        # we don't multiplex on the sockets as they come available,
+        # but that shouldn't make too much difference.
+        try:
+            node_commands = nodes.values()
+            for n in node_commands:
+                n.write()
+
+            for n in node_commands:
+                n.read()
+        finally:
+            # release all of the redis connections we allocated earlier
+            # back into the connection pool.
+            # we used to do this step as part of a try/finally block,
+            # but it is really dangerous to
+            # release connections back into the pool if for some
+            # reason the socket has data still left in it
+            # from a previous operation. The write and
+            # read operations already have try/catch around them for
+            # all known types of errors including connection
+            # and socket level errors.
+            # So if we hit an exception, something really bad
+            # happened and putting any of
+            # these connections back into the pool is a very bad idea.
+            # the socket might have unread buffer still sitting in it,
+            # and then the next time we read from it we pass the
+            # buffered result back from a previous command and
+            # every single request sent on that connection afterwards
+            # will get a mismatched result.
+            for n in nodes.values():
+                n.connection_pool.release(n.connection)
+
+        # if the response isn't an exception it is a
+        # valid response from the node
+        # we're all done with that command, YAY!
+        # if we have more commands to attempt, we've run into problems.
+        # collect all the commands we are allowed to retry.
+        # (MOVED, ASK, or connection errors or timeout errors)
+        attempt = sorted(
+            (
+                c
+                for c in attempt
+                if isinstance(c.result, ClusterPipeline.ERRORS_ALLOW_RETRY)
+            ),
+            key=lambda x: x.position,
+        )
+        if attempt and allow_redirections:
+            # RETRY MAGIC HAPPENS HERE!
+            # send these remaining commands one at a time using `execute_command`
+            # in the main client. This keeps our retry logic
+            # in one place mostly,
+            # and allows us to be more confident in correctness of behavior.
+            # at this point any speed gains from pipelining have been lost
+            # anyway, so we might as well make the best
+            # attempt to get the correct behavior.
+            #
+            # The client command will handle retries for each
+            # individual command sequentially as we pass each
+            # one into `execute_command`. Any exceptions
+            # that bubble out should only appear once all
+            # retries have been exhausted.
+            #
+            # If a lot of commands have failed, we'll be setting the
+            # flag to rebuild the slots table from scratch.
+            # So MOVED errors should correct themselves fairly quickly.
+            self._pipe.reinitialize_counter += 1
+            if self._pipe._should_reinitialized():
+                self._nodes_manager.initialize()
+                if is_default_node:
+                    self._pipe.replace_default_node()
+            for c in attempt:
+                try:
+                    # send each command individually like we
+                    # do in the main client.
+                    c.result = self._pipe.parent_execute_command(*c.args, **c.options)
+                except RedisError as e:
+                    c.result = e
+
+        # turn the response back into a simple flat array that corresponds
+        # to the sequence of commands issued in the stack in pipeline.execute()
+        response = []
+        for c in sorted(stack, key=lambda x: x.position):
+            if c.args[0] in self._pipe.cluster_response_callbacks:
+                # Remove the keys entry, it is only needed for the cache.
+ c.options.pop("keys", None) + c.result = self._pipe.cluster_response_callbacks[c.args[0]]( + c.result, **c.options + ) + response.append(c.result) + + if raise_on_error: + self._raise_first_error(stack) + + return response + + def _is_nodes_flag(self, target_nodes): + return isinstance(target_nodes, str) and target_nodes in self._pipe.node_flags + + def _parse_target_nodes(self, target_nodes): + if isinstance(target_nodes, list): + nodes = target_nodes + elif isinstance(target_nodes, ClusterNode): + # Supports passing a single ClusterNode as a variable + nodes = [target_nodes] + elif isinstance(target_nodes, dict): + # Supports dictionaries of the format {node_name: node}. + # It enables to execute commands with multi nodes as follows: + # rc.cluster_save_config(rc.get_primaries()) + nodes = target_nodes.values() + else: + raise TypeError( + "target_nodes type can be one of the following: " + "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES)," + "ClusterNode, list, or dict. " + f"The passed type is {type(target_nodes)}" + ) + return nodes + + def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: + # Determine which nodes should be executed the command on. + # Returns a list of target nodes. + command = args[0].upper() + if ( + len(args) >= 2 + and f"{args[0]} {args[1]}".upper() in self._pipe.command_flags + ): + command = f"{args[0]} {args[1]}".upper() + + nodes_flag = kwargs.pop("nodes_flag", None) + if nodes_flag is not None: + # nodes flag passed by the user + command_flag = nodes_flag + else: + # get the nodes group for this command if it was predefined + command_flag = self._pipe.command_flags.get(command) + if command_flag == self._pipe.RANDOM: + # return a random node + return [self._pipe.get_random_node()] + elif command_flag == self._pipe.PRIMARIES: + # return all primaries + return self._pipe.get_primaries() + elif command_flag == self._pipe.REPLICAS: + # return all replicas + return self._pipe.get_replicas() + elif command_flag == self._pipe.ALL_NODES: + # return all nodes + return self._pipe.get_nodes() + elif command_flag == self._pipe.DEFAULT_NODE: + # return the cluster's default node + return [self._nodes_manager.default_node] + elif command in self._pipe.SEARCH_COMMANDS[0]: + return [self._nodes_manager.default_node] + else: + # get the node that holds the key's slot + slot = self._pipe.determine_slot(*args) + node = self._nodes_manager.get_node_from_slot( + slot, + self._pipe.read_from_replicas and command in READ_COMMANDS, + self._pipe.load_balancing_strategy + if command in READ_COMMANDS + else None, + ) + return [node] + + def multi(self): + raise RedisClusterException( + "method multi() is not supported outside of transactional context" + ) + + def discard(self): + raise RedisClusterException( + "method discard() is not supported outside of transactional context" + ) + + def watch(self, *names): + raise RedisClusterException( + "method watch() is not supported outside of transactional context" + ) + + def unwatch(self, *names): + raise RedisClusterException( + "method unwatch() is not supported outside of transactional context" + ) + + def delete(self, *names): + if len(names) != 1: + raise RedisClusterException( + "deleting multiple keys is not implemented in pipeline command" + ) + + return self.execute_command("DEL", names[0]) + + def unlink(self, *names): + if len(names) != 1: + raise RedisClusterException( + "unlinking multiple keys is not implemented in pipeline command" + ) + + return self.execute_command("UNLINK", names[0]) + + +class 
TransactionStrategy(AbstractStrategy): + NO_SLOTS_COMMANDS = {"UNWATCH"} + IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"} + UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} + SLOT_REDIRECT_ERRORS = (AskError, MovedError) + CONNECTION_ERRORS = ( + ConnectionError, + OSError, + ClusterDownError, + SlotNotCoveredError, + ) + + def __init__(self, pipe: ClusterPipeline): + super().__init__(pipe) + self._explicit_transaction = False + self._watching = False + self._pipeline_slots: Set[int] = set() + self._transaction_connection: Optional[Connection] = None + self._executing = False + self._retry = copy(self._pipe.retry) + self._retry.update_supported_errors( + RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS + ) + + def _get_client_and_connection_for_transaction(self) -> Tuple[Redis, Connection]: + """ + Find a connection for a pipeline transaction. + + For running an atomic transaction, watch keys ensure that contents have not been + altered as long as the watch commands for those keys were sent over the same + connection. So once we start watching a key, we fetch a connection to the + node that owns that slot and reuse it. + """ + if not self._pipeline_slots: + raise RedisClusterException( + "At least a command with a key is needed to identify a node" + ) + + node: ClusterNode = self._nodes_manager.get_node_from_slot( + list(self._pipeline_slots)[0], False + ) + redis_node: Redis = self._pipe.get_redis_connection(node) + if self._transaction_connection: + if not redis_node.connection_pool.owns_connection( + self._transaction_connection + ): + previous_node = self._nodes_manager.find_connection_owner( + self._transaction_connection + ) + previous_node.connection_pool.release(self._transaction_connection) + self._transaction_connection = None + + if not self._transaction_connection: + self._transaction_connection = get_connection(redis_node) + + return redis_node, self._transaction_connection + + def execute_command(self, *args, **kwargs): + slot_number: Optional[int] = None + if args[0] not in ClusterPipeline.NO_SLOTS_COMMANDS: + slot_number = self._pipe.determine_slot(*args) + + if ( + self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS + ) and not self._explicit_transaction: + if args[0] == "WATCH": + self._validate_watch() + + if slot_number is not None: + if self._pipeline_slots and slot_number not in self._pipeline_slots: + raise CrossSlotTransactionError( + "Cannot watch or send commands on different slots" + ) + + self._pipeline_slots.add(slot_number) + elif args[0] not in self.NO_SLOTS_COMMANDS: + raise RedisClusterException( + f"Cannot identify slot number for command: {args[0]}," + "it cannot be triggered in a transaction" + ) + + return self._immediate_execute_command(*args, **kwargs) + else: + if slot_number is not None: + self._pipeline_slots.add(slot_number) + + return self.pipeline_execute_command(*args, **kwargs) + + def _validate_watch(self): + if self._explicit_transaction: + raise RedisError("Cannot issue a WATCH after a MULTI") + + self._watching = True + + def _immediate_execute_command(self, *args, **options): + return self._retry.call_with_retry( + lambda: self._get_connection_and_send_command(*args, **options), + self._reinitialize_on_error, + ) + + def _get_connection_and_send_command(self, *args, **options): + redis_node, connection = self._get_client_and_connection_for_transaction() + return self._send_command_parse_response( + connection, redis_node, args[0], *args, **options + ) + + def _send_command_parse_response( + self, conn, redis_node: 
Redis, command_name, *args, **options + ): + """ + Send a command and parse the response + """ + + conn.send_command(*args) + output = redis_node.parse_response(conn, command_name, **options) + + if command_name in self.UNWATCH_COMMANDS: + self._watching = False + return output + + def _reinitialize_on_error(self, error): + if self._watching: + if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing: + raise WatchError("Slot rebalancing occurred while watching keys") + + if ( + type(error) in self.SLOT_REDIRECT_ERRORS + or type(error) in self.CONNECTION_ERRORS + ): + if self._transaction_connection: + self._transaction_connection = None + + self._pipe.reinitialize_counter += 1 + if self._pipe._should_reinitialized(): + self._nodes_manager.initialize() + self.reinitialize_counter = 0 + else: + self._nodes_manager.update_moved_exception(error) + + self._executing = False + + def _raise_first_error(self, responses, stack): + """ + Raise the first exception on the stack + """ + for r, cmd in zip(responses, stack): + if isinstance(r, Exception): + self.annotate_exception(r, cmd.position + 1, cmd.args) + raise r + + def execute(self, raise_on_error: bool = True) -> List[Any]: + stack = self._command_queue + if not stack and (not self._watching or not self._pipeline_slots): + return [] + + return self._execute_transaction_with_retries(stack, raise_on_error) + + def _execute_transaction_with_retries( + self, stack: List["PipelineCommand"], raise_on_error: bool + ): + return self._retry.call_with_retry( + lambda: self._execute_transaction(stack, raise_on_error), + self._reinitialize_on_error, + ) + + def _execute_transaction( + self, stack: List["PipelineCommand"], raise_on_error: bool + ): + if len(self._pipeline_slots) > 1: + raise CrossSlotTransactionError( + "All keys involved in a cluster transaction must map to the same slot" + ) + + self._executing = True + + redis_node, connection = self._get_client_and_connection_for_transaction() + + stack = chain( + [PipelineCommand(("MULTI",))], + stack, + [PipelineCommand(("EXEC",))], + ) + commands = [c.args for c in stack if EMPTY_RESPONSE not in c.options] + packed_commands = connection.pack_commands(commands) + connection.send_packed_command(packed_commands) + errors = [] + + # parse off the response for MULTI + # NOTE: we need to handle ResponseErrors here and continue + # so that we read all the additional command messages from + # the socket + try: + redis_node.parse_response(connection, "MULTI") + except ResponseError as e: + self.annotate_exception(e, 0, "MULTI") + errors.append(e) + except self.CONNECTION_ERRORS as cluster_error: + self.annotate_exception(cluster_error, 0, "MULTI") + raise + + # and all the other commands + for i, command in enumerate(self._command_queue): + if EMPTY_RESPONSE in command.options: + errors.append((i, command.options[EMPTY_RESPONSE])) + else: + try: + _ = redis_node.parse_response(connection, "_") + except self.SLOT_REDIRECT_ERRORS as slot_error: + self.annotate_exception(slot_error, i + 1, command.args) + errors.append(slot_error) + except self.CONNECTION_ERRORS as cluster_error: + self.annotate_exception(cluster_error, i + 1, command.args) + raise + except ResponseError as e: + self.annotate_exception(e, i + 1, command.args) + errors.append(e) + + response = None + # parse the EXEC. 
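+        # EXEC returns one reply per queued command, or None if a watched key
+        # changed and the server discarded the transaction (surfaced as WatchError below).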
+ try: + response = redis_node.parse_response(connection, "EXEC") + except ExecAbortError: + if errors: + raise errors[0] + raise + + self._executing = False + + # EXEC clears any watched keys + self._watching = False + + if response is None: + raise WatchError("Watched variable changed.") + + # put any parse errors into the response + for i, e in errors: + response.insert(i, e) + + if len(response) != len(self._command_queue): + raise InvalidPipelineStack( + "Unexpected response length for cluster pipeline EXEC." + " Command stack was {} but response had length {}".format( + [c.args[0] for c in self._command_queue], len(response) + ) + ) + + # find any errors in the response and raise if necessary + if raise_on_error or len(errors) > 0: + self._raise_first_error( + response, + self._command_queue, + ) + + # We have to run response callbacks manually + data = [] + for r, cmd in zip(response, self._command_queue): + if not isinstance(r, Exception): + command_name = cmd.args[0] + if command_name in self._pipe.cluster_response_callbacks: + r = self._pipe.cluster_response_callbacks[command_name]( + r, **cmd.options + ) + data.append(r) + return data + + def reset(self): + self._command_queue = [] + + # make sure to reset the connection state in the event that we were + # watching something + if self._transaction_connection: + try: + # call this manually since our unwatch or + # immediate_execute_command methods can call reset() + self._transaction_connection.send_command("UNWATCH") + self._transaction_connection.read_response() + # we can safely return the connection to the pool here since we're + # sure we're no longer WATCHing anything + node = self._nodes_manager.find_connection_owner( + self._transaction_connection + ) + node.redis_connection.connection_pool.release( + self._transaction_connection + ) + self._transaction_connection = None + except self.CONNECTION_ERRORS: + # disconnect will also remove any previous WATCHes + if self._transaction_connection: + self._transaction_connection.disconnect() + + # clean up the other instance attributes + self._watching = False + self._explicit_transaction = False + self._pipeline_slots = set() + self._executing = False + + def send_cluster_commands( + self, stack, raise_on_error=True, allow_redirections=True + ): + raise NotImplementedError( + "send_cluster_commands cannot be executed in transactional context." + ) + + def multi(self): + if self._explicit_transaction: + raise RedisError("Cannot issue nested calls to MULTI") + if self._command_queue: + raise RedisError( + "Commands without an initial WATCH have already been issued" + ) + self._explicit_transaction = True + + def watch(self, *names): + if self._explicit_transaction: + raise RedisError("Cannot issue a WATCH after a MULTI") + + return self.execute_command("WATCH", *names) + + def unwatch(self): + if self._watching: + return self.execute_command("UNWATCH") + + return True + + def discard(self): + self.reset() + + def delete(self, *names): + return self.execute_command("DEL", *names) + + def unlink(self, *names): + return self.execute_command("UNLINK", *names) diff --git a/redis/exceptions.py b/redis/exceptions.py index bad447a086..a00ac65ac1 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -221,3 +221,21 @@ class SlotNotCoveredError(RedisClusterException): class MaxConnectionsError(ConnectionError): ... 
+ + +class CrossSlotTransactionError(RedisClusterException): + """ + Raised when a transaction or watch is triggered in a pipeline + and not all keys or all commands belong to the same slot. + """ + + pass + + +class InvalidPipelineStack(RedisClusterException): + """ + Raised on unexpected response length on pipelines. This is + most likely a handling error on the stack. + """ + + pass diff --git a/tests/test_cluster.py b/tests/test_cluster.py index d4e48e199b..d360ab07f7 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -3015,24 +3015,10 @@ def test_blocked_methods(self, r): They maybe implemented in the future. """ pipe = r.pipeline() - with pytest.raises(RedisClusterException): - pipe.multi() - - with pytest.raises(RedisClusterException): - pipe.immediate_execute_command() - - with pytest.raises(RedisClusterException): - pipe._execute_transaction(None, None, None) with pytest.raises(RedisClusterException): pipe.load_scripts() - with pytest.raises(RedisClusterException): - pipe.watch() - - with pytest.raises(RedisClusterException): - pipe.unwatch() - with pytest.raises(RedisClusterException): pipe.script_load_for_pipeline(None) @@ -3044,14 +3030,6 @@ def test_blocked_arguments(self, r): Currently some arguments is blocked when using in cluster mode. They maybe implemented in the future. """ - with pytest.raises(RedisClusterException) as ex: - r.pipeline(transaction=True) - - assert ( - str(ex.value).startswith("transaction is deprecated in cluster mode") - is True - ) - with pytest.raises(RedisClusterException) as ex: r.pipeline(shard_hint=True) @@ -3109,7 +3087,7 @@ def test_delete_single(self, r): pipe.delete("a") assert pipe.execute() == [1] - def test_multi_delete_unsupported(self, r): + def test_multi_delete_unsupported_cross_slot(self, r): """ Test that multi delete operation is unsupported """ @@ -3119,6 +3097,16 @@ def test_multi_delete_unsupported(self, r): with pytest.raises(RedisClusterException): pipe.delete("a", "b") + def test_multi_delete_supported_single_slot(self, r): + """ + Test that multi delete operation is supported when all keys are in the same hash slot + """ + with r.pipeline(transaction=True) as pipe: + r["{key}:a"] = 1 + r["{key}:b"] = 2 + pipe.delete("{key}:a", "{key}:b") + assert pipe.execute() + def test_unlink_single(self, r): """ Test a single unlink operation @@ -3374,6 +3362,87 @@ def test_empty_stack(self, r): result = p.execute() assert result == [] + @pytest.mark.onlycluster + def test_exec_error_in_response(self, r): + """ + an invalid pipeline command at exec time adds the exception instance + to the list of returned values + """ + hashkey = "{key}" + r[f"{hashkey}:c"] = "a" + with r.pipeline() as pipe: + pipe.set(f"{hashkey}:a", 1).set(f"{hashkey}:b", 2) + pipe.lpush(f"{hashkey}:c", 3).set(f"{hashkey}:d", 4) + result = pipe.execute(raise_on_error=False) + + assert result[0] + assert r[f"{hashkey}:a"] == b"1" + assert result[1] + assert r[f"{hashkey}:b"] == b"2" + + # we can't lpush to a key that's a string value, so this should + # be a ResponseError exception + assert isinstance(result[2], redis.ResponseError) + assert r[f"{hashkey}:c"] == b"a" + + # since this isn't a transaction, the other commands after the + # error are still executed + assert result[3] + assert r[f"{hashkey}:d"] == b"4" + + # make sure the pipe was restored to a working state + assert pipe.set(f"{hashkey}:z", "zzz").execute() == [True] + assert r[f"{hashkey}:z"] == b"zzz" + + def test_exec_error_in_no_transaction_pipeline(self, r): + r["a"] = 1 + with 
r.pipeline(transaction=False) as pipe: + pipe.llen("a") + pipe.expire("a", 100) + + with pytest.raises(redis.ResponseError) as ex: + pipe.execute() + + assert str(ex.value).startswith( + "Command # 1 (LLEN a) of pipeline caused error: " + ) + + assert r["a"] == b"1" + + @pytest.mark.onlycluster + @skip_if_server_version_lt("2.0.0") + def test_pipeline_discard(self, r): + hashkey = "{key}" + + # empty pipeline should raise an error + with r.pipeline() as pipe: + pipe.set(f"{hashkey}:key", "someval") + with pytest.raises(redis.exceptions.RedisClusterException) as ex: + pipe.discard() + + assert str(ex.value).startswith( + "method discard() is not supported outside of transactional context" + ) + + # setting a pipeline and discarding should do the same + with r.pipeline() as pipe: + pipe.set(f"{hashkey}:key", "someval") + pipe.set(f"{hashkey}:someotherkey", "val") + response = pipe.execute() + pipe.set(f"{hashkey}:key", "another value!") + with pytest.raises(redis.exceptions.RedisClusterException) as ex: + pipe.discard() + + assert str(ex.value).startswith( + "method discard() is not supported outside of transactional context" + ) + + pipe.set(f"{hashkey}:foo", "bar") + response = pipe.execute() + + assert response[0] + assert r.get(f"{hashkey}:foo") == b"bar" + @pytest.mark.onlycluster class TestReadOnlyPipeline: diff --git a/tests/test_cluster_transaction.py b/tests/test_cluster_transaction.py new file mode 100644 index 0000000000..0eb7a4f256 --- /dev/null +++ b/tests/test_cluster_transaction.py @@ -0,0 +1,392 @@ +import threading +from typing import Tuple +from unittest.mock import patch, Mock + +import pytest + +import redis +from redis import CrossSlotTransactionError, ConnectionPool, RedisClusterException +from redis.backoff import NoBackoff +from redis.client import Redis +from redis.cluster import PRIMARY, ClusterNode, NodesManager, RedisCluster +from redis.retry import Retry + +from .conftest import skip_if_server_version_lt + + +def _find_source_and_target_node_for_slot( + r: RedisCluster, slot: int +) -> Tuple[ClusterNode, ClusterNode]: + """Returns a pair of ClusterNodes, where the first node is the + one that owns the slot and the second is a possible target + for that slot, i.e. a primary node different from the first + one. 
+    """
+    node_migrating = r.nodes_manager.get_node_from_slot(slot)
+    assert node_migrating, f"No node could be found that owns slot #{slot}"
+
+    available_targets = [
+        n
+        for n in r.nodes_manager.startup_nodes.values()
+        if node_migrating.name != n.name and n.server_type == PRIMARY
+    ]
+
+    assert available_targets, f"No possible target nodes for slot #{slot}"
+    return node_migrating, available_targets[0]
+
+
+class TestClusterTransaction:
+    @pytest.mark.onlycluster
+    def test_pipeline_is_true(self, r):
+        "Ensure pipeline instances are not false-y"
+        with r.pipeline(transaction=True) as pipe:
+            assert pipe
+
+    @pytest.mark.onlycluster
+    def test_pipeline_empty_transaction(self, r):
+        r["a"] = 0
+
+        with r.pipeline(transaction=True) as pipe:
+            assert pipe.execute() == []
+
+    @pytest.mark.onlycluster
+    def test_executes_transaction_against_cluster(self, r):
+        with r.pipeline(transaction=True) as tx:
+            tx.set("{foo}bar", "value1")
+            tx.set("{foo}baz", "value2")
+            tx.set("{foo}bad", "value3")
+            tx.get("{foo}bar")
+            tx.get("{foo}baz")
+            tx.get("{foo}bad")
+            assert tx.execute() == [
+                b"OK",
+                b"OK",
+                b"OK",
+                b"value1",
+                b"value2",
+                b"value3",
+            ]
+
+        r.flushall()
+
+        tx = r.pipeline(transaction=True)
+        tx.set("{foo}bar", "value1")
+        tx.set("{foo}baz", "value2")
+        tx.set("{foo}bad", "value3")
+        tx.get("{foo}bar")
+        tx.get("{foo}baz")
+        tx.get("{foo}bad")
+        assert tx.execute() == [b"OK", b"OK", b"OK", b"value1", b"value2", b"value3"]
+
+    @pytest.mark.onlycluster
+    def test_throws_exception_on_different_hash_slots(self, r):
+        with r.pipeline(transaction=True) as tx:
+            tx.set("{foo}bar", "value1")
+            tx.set("{foobar}baz", "value2")
+
+            with pytest.raises(
+                CrossSlotTransactionError,
+                match="All keys involved in a cluster transaction must map to the same slot",
+            ):
+                tx.execute()
+
+    @pytest.mark.onlycluster
+    def test_throws_exception_with_watch_on_different_hash_slots(self, r):
+        with r.pipeline(transaction=True) as tx:
+            with pytest.raises(
+                RedisClusterException,
+                match="WATCH - all keys must map to the same key slot",
+            ):
+                tx.watch("key1", "key2")
+
+    @pytest.mark.onlycluster
+    def test_transaction_with_watched_keys(self, r):
+        r["a"] = 0
+
+        with r.pipeline(transaction=True) as pipe:
+            pipe.watch("a")
+            a = pipe.get("a")
+            pipe.multi()
+            pipe.set("a", int(a) + 1)
+            assert pipe.execute() == [b"OK"]
+
+    @pytest.mark.onlycluster
+    def test_retry_transaction_during_unfinished_slot_migration(self, r):
+        """
+        When a transaction is triggered during a slot migration, a MovedError
+        or AskError may appear (depending on whether the key has already been
+        migrated or does not exist yet). The patch on parse_response
+        simulates such an error, but the slot cache is not updated
+        (meaning the migration is still ongoing), so the pipeline is retried
+        but eventually fails because the migration never completes.
+ """ + key = "book" + slot = r.keyslot(key) + node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + + with patch.object(Redis, "parse_response") as parse_response, patch.object( + NodesManager, "_update_moved_slots" + ) as manager_update_moved_slots: + + def ask_redirect_effect(connection, *args, **options): + if "MULTI" in args: + return + elif "EXEC" in args: + raise redis.exceptions.ExecAbortError() + + raise redis.exceptions.AskError(f"{slot} {node_importing.name}") + + parse_response.side_effect = ask_redirect_effect + + with r.pipeline(transaction=True) as pipe: + pipe.set(key, "val") + with pytest.raises(redis.exceptions.AskError) as ex: + pipe.execute() + + assert str(ex.value).startswith( + "Command # 1 (SET book val) of pipeline caused error:" + f" {slot} {node_importing.name}" + ) + + manager_update_moved_slots.assert_called() + + @pytest.mark.onlycluster + def test_retry_transaction_during_slot_migration_successful(self, r): + """ + If a MovedError or AskError appears when calling EXEC and no key is watched, + the pipeline is retried after updating the node manager slot table. If the + migration was completed, the transaction may then complete successfully. + """ + key = "book" + slot = r.keyslot(key) + node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + + with patch.object(Redis, "parse_response") as parse_response, patch.object( + NodesManager, "_update_moved_slots" + ) as manager_update_moved_slots: + + def ask_redirect_effect(conn, *args, **options): + # first call should go here, we trigger an AskError + if f"{conn.host}:{conn.port}" == node_migrating.name: + if "MULTI" in args: + return + elif "EXEC" in args: + raise redis.exceptions.ExecAbortError() + + raise redis.exceptions.AskError(f"{slot} {node_importing.name}") + # if the slot table is updated, the next call will go here + elif f"{conn.host}:{conn.port}" == node_importing.name: + if "EXEC" in args: + return [ + "MOCK_OK" + ] # mock value to validate this section was called + return + else: + assert False, f"unexpected node {conn.host}:{conn.port} was called" + + def update_moved_slot(): # simulate slot table update + ask_error = r.nodes_manager._moved_exception + assert ask_error is not None, "No AskError was previously triggered" + assert f"{ask_error.host}:{ask_error.port}" == node_importing.name + r.nodes_manager._moved_exception = None + r.nodes_manager.slots_cache[slot] = [node_importing] + + parse_response.side_effect = ask_redirect_effect + manager_update_moved_slots.side_effect = update_moved_slot + + result = None + with r.pipeline(transaction=True) as pipe: + pipe.multi() + pipe.set(key, "val") + result = pipe.execute() + + assert result and "MOCK_OK" in result, "Target node was not called" + + @pytest.mark.onlycluster + def test_retry_transaction_with_watch_after_slot_migration(self, r): + """ + If a MovedError or AskError appears when calling WATCH, the client + must attempt to recover itself before proceeding and no WatchError + should appear. 
+ """ + key = "book" + slot = r.keyslot(key) + r.reinitialize_steps = 1 + + # force a MovedError on the first call to pipe.watch() + # by switching the node that owns the slot to another one + _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + r.nodes_manager.slots_cache[slot] = [node_importing] + + with r.pipeline(transaction=True) as pipe: + pipe.watch(key) + pipe.multi() + pipe.set(key, "val") + assert pipe.execute() == [b"OK"] + + @pytest.mark.onlycluster + def test_retry_transaction_with_watch_during_slot_migration(self, r): + """ + If a MovedError or AskError appears when calling EXEC and keys were + being watched before the migration started, a WatchError should appear. + These errors imply resetting the connection and connecting to a new node, + so watches are lost anyway and the client code must be notified. + """ + key = "book" + slot = r.keyslot(key) + node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + + with patch.object(Redis, "parse_response") as parse_response: + + def ask_redirect_effect(conn, *args, **options): + if f"{conn.host}:{conn.port}" == node_migrating.name: + # we simulate the watch was sent before the migration started + if "WATCH" in args: + return b"OK" + # but the pipeline was triggered after the migration started + elif "MULTI" in args: + return + elif "EXEC" in args: + raise redis.exceptions.ExecAbortError() + + raise redis.exceptions.AskError(f"{slot} {node_importing.name}") + # we should not try to connect to any other node + else: + assert False, f"unexpected node {conn.host}:{conn.port} was called" + + parse_response.side_effect = ask_redirect_effect + + with r.pipeline(transaction=True) as pipe: + pipe.watch(key) + pipe.multi() + pipe.set(key, "val") + with pytest.raises(redis.exceptions.WatchError) as ex: + pipe.execute() + + assert str(ex.value).startswith( + "Slot rebalancing occurred while watching keys" + ) + + @pytest.mark.onlycluster + def test_retry_transaction_on_connection_error(self, r, mock_connection): + key = "book" + slot = r.keyslot(key) + + mock_connection.read_response.side_effect = redis.exceptions.ConnectionError( + "Conn error" + ) + mock_connection.retry = Retry(NoBackoff(), 0) + mock_pool = Mock(spec=ConnectionPool) + mock_pool.get_connection.return_value = mock_connection + mock_pool._available_connections = [mock_connection] + mock_pool._lock = threading.Lock() + + _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + node_importing.redis_connection.connection_pool = mock_pool + r.nodes_manager.slots_cache[slot] = [node_importing] + r.reinitialize_steps = 1 + + with r.pipeline(transaction=True) as pipe: + pipe.set(key, "val") + assert pipe.execute() == [b"OK"] + + @pytest.mark.onlycluster + def test_retry_transaction_on_connection_error_with_watched_keys( + self, r, mock_connection + ): + key = "book" + slot = r.keyslot(key) + + mock_connection.read_response.side_effect = redis.exceptions.ConnectionError( + "Conn error" + ) + mock_connection.retry = Retry(NoBackoff(), 0) + mock_pool = Mock(spec=ConnectionPool) + mock_pool.get_connection.return_value = mock_connection + mock_pool._available_connections = [mock_connection] + mock_pool._lock = threading.Lock() + + _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + node_importing.redis_connection.connection_pool = mock_pool + r.nodes_manager.slots_cache[slot] = [node_importing] + r.reinitialize_steps = 1 + + with r.pipeline(transaction=True) as pipe: + 
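+            # the first WATCH hits the mocked connection and fails; after the
+            # topology refresh the retry succeeds against the real slot owner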
pipe.watch(key) + pipe.multi() + pipe.set(key, "val") + assert pipe.execute() == [b"OK"] + + @pytest.mark.onlycluster + def test_exec_error_raised(self, r): + hashkey = "{key}" + r[f"{hashkey}:c"] = "a" + with r.pipeline(transaction=True) as pipe: + pipe.set(f"{hashkey}:a", 1).set(f"{hashkey}:b", 2) + pipe.lpush(f"{hashkey}:c", 3).set(f"{hashkey}:d", 4) + with pytest.raises(redis.ResponseError) as ex: + pipe.execute() + assert str(ex.value).startswith( + "Command # 3 (LPUSH {key}:c 3) of pipeline caused error: " + ) + + # make sure the pipe was restored to a working state + assert pipe.set(f"{hashkey}:z", "zzz").execute() == [b"OK"] + assert r[f"{hashkey}:z"] == b"zzz" + + @pytest.mark.onlycluster + def test_parse_error_raised(self, r): + hashkey = "{key}" + with r.pipeline(transaction=True) as pipe: + # the zrem is invalid because we don't pass any keys to it + pipe.set(f"{hashkey}:a", 1).zrem(f"{hashkey}:b").set(f"{hashkey}:b", 2) + with pytest.raises(redis.ResponseError) as ex: + pipe.execute() + + assert str(ex.value).startswith( + "Command # 2 (ZREM {key}:b) of pipeline caused error: wrong number" + ) + + # make sure the pipe was restored to a working state + assert pipe.set(f"{hashkey}:z", "zzz").execute() == [b"OK"] + assert r[f"{hashkey}:z"] == b"zzz" + + @pytest.mark.onlycluster + def test_transaction_callable(self, r): + hashkey = "{key}" + r[f"{hashkey}:a"] = 1 + r[f"{hashkey}:b"] = 2 + has_run = [] + + def my_transaction(pipe): + a_value = pipe.get(f"{hashkey}:a") + assert a_value in (b"1", b"2") + b_value = pipe.get(f"{hashkey}:b") + assert b_value == b"2" + + # silly run-once code... incr's "a" so WatchError should be raised + # forcing this all to run again. this should incr "a" once to "2" + if not has_run: + r.incr(f"{hashkey}:a") + has_run.append("it has") + + pipe.multi() + pipe.set(f"{hashkey}:c", int(a_value) + int(b_value)) + + result = r.transaction(my_transaction, f"{hashkey}:a", f"{hashkey}:b") + assert result == [b"OK"] + assert r[f"{hashkey}:c"] == b"4" + + @pytest.mark.onlycluster + @skip_if_server_version_lt("2.0.0") + def test_transaction_discard(self, r): + hashkey = "{key}" + + # pipelines enabled as transactions can be discarded at any point + with r.pipeline(transaction=True) as pipe: + pipe.watch(f"{hashkey}:key") + pipe.set(f"{hashkey}:key", "someval") + pipe.discard() + + assert not pipe._execution_strategy._watching + assert not pipe.command_stack From ed35c58b2682c5dd0f64cdbd2bec9797a70ab758 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96mer=20Feyzioglu?= <95534250+omerfeyzioglu@users.noreply.github.com> Date: Mon, 12 May 2025 18:55:11 +0300 Subject: [PATCH 110/113] Fix RedisCluster ssl_check_hostname not set to connections. For SSL verification with ssl_cert_reqs="none", check_hostname is set to False (#3637) * Fix SSL verification with ssl_cert_reqs=none and ssl_check_hostname=True * Add ssl_check_hostname to REDIS_ALLOWED_KEYS and fix default value in RedisSSLContext --- CHANGES | 1 + docs/examples/ssl_connection_examples.ipynb | 5 +- redis/asyncio/connection.py | 6 ++- redis/cluster.py | 1 + redis/connection.py | 4 +- tests/test_asyncio/test_cluster.py | 7 +-- tests/test_asyncio/test_ssl.py | 56 +++++++++++++++++++++ tests/test_ssl.py | 27 ++++++++-- 8 files changed, 90 insertions(+), 17 deletions(-) create mode 100644 tests/test_asyncio/test_ssl.py diff --git a/CHANGES b/CHANGES index 1a1f4eca11..dbc27dbacc 100644 --- a/CHANGES +++ b/CHANGES @@ -71,6 +71,7 @@ * Close Unix sockets if the connection attempt fails. This prevents `ResourceWarning`s. 
(#3314) * Close SSL sockets if the connection attempt fails, or if validations fail. (#3317) * Eliminate mutable default arguments in the `redis.commands.core.Script` class. (#3332) + * Fix SSL verification with `ssl_cert_reqs="none"` and `ssl_check_hostname=True` by automatically setting `check_hostname=False` when `verify_mode=ssl.CERT_NONE` (#3635) * Allow newer versions of PyJWT as dependency. (#3630) * 4.1.3 (Feb 8, 2022) diff --git a/docs/examples/ssl_connection_examples.ipynb b/docs/examples/ssl_connection_examples.ipynb index a09b87ec1f..3fcc7bc3cc 100644 --- a/docs/examples/ssl_connection_examples.ipynb +++ b/docs/examples/ssl_connection_examples.ipynb @@ -37,7 +37,6 @@ " host='localhost',\n", " port=6666,\n", " ssl=True,\n", - " ssl_check_hostname=False,\n", " ssl_cert_reqs=\"none\",\n", ")\n", "r.ping()" @@ -69,7 +68,7 @@ "source": [ "import redis\n", "\n", - "r = redis.from_url(\"rediss://localhost:6666?ssl_cert_reqs=none&ssl_check_hostname=False&decode_responses=True&health_check_interval=2\")\n", + "r = redis.from_url(\"rediss://localhost:6666?ssl_cert_reqs=none&decode_responses=True&health_check_interval=2\")\n", "r.ping()" ] }, @@ -103,7 +102,6 @@ " host=\"localhost\",\n", " port=6666,\n", " connection_class=redis.SSLConnection,\n", - " ssl_check_hostname=False,\n", " ssl_cert_reqs=\"none\",\n", ")\n", "\n", @@ -143,7 +141,6 @@ " port=6666,\n", " ssl=True,\n", " ssl_min_version=ssl.TLSVersion.TLSv1_3,\n", - " ssl_check_hostname=False,\n", " ssl_cert_reqs=\"none\",\n", ")\n", "r.ping()" diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 77131ab951..d1ae81d269 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -868,7 +868,7 @@ def __init__( cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None, ca_certs: Optional[str] = None, ca_data: Optional[str] = None, - check_hostname: bool = False, + check_hostname: bool = True, min_version: Optional[TLSVersion] = None, ciphers: Optional[str] = None, ): @@ -893,7 +893,9 @@ def __init__( self.cert_reqs = cert_reqs self.ca_certs = ca_certs self.ca_data = ca_data - self.check_hostname = check_hostname + self.check_hostname = ( + check_hostname if self.cert_reqs != ssl.CERT_NONE else False + ) self.min_version = min_version self.ciphers = ciphers self.context: Optional[SSLContext] = None diff --git a/redis/cluster.py b/redis/cluster.py index b614c598f9..af60e1c76c 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -185,6 +185,7 @@ def parse_cluster_myshardid(resp, **options): "ssl_cert_reqs", "ssl_keyfile", "ssl_password", + "ssl_check_hostname", "unix_socket_path", "username", "cache", diff --git a/redis/connection.py b/redis/connection.py index dab45906d2..cc805e442f 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1083,7 +1083,9 @@ def __init__( self.ca_certs = ssl_ca_certs self.ca_data = ssl_ca_data self.ca_path = ssl_ca_path - self.check_hostname = ssl_check_hostname + self.check_hostname = ( + ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False + ) self.certificate_password = ssl_password self.ssl_validate_ocsp = ssl_validate_ocsp self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 5a8b6dfee7..b56ad6dbd1 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -3139,9 +3139,7 @@ async def test_ssl_with_invalid_cert( async def test_ssl_connection( self, create_client: Callable[..., Awaitable[RedisCluster]] ) -> None: - 
async with await create_client( - ssl=True, ssl_check_hostname=False, ssl_cert_reqs="none" - ) as rc: + async with await create_client(ssl=True, ssl_cert_reqs="none") as rc: assert await rc.ping() @pytest.mark.parametrize( @@ -3157,7 +3155,6 @@ async def test_ssl_connection_tls12_custom_ciphers( ) -> None: async with await create_client( ssl=True, - ssl_check_hostname=False, ssl_cert_reqs="none", ssl_min_version=ssl.TLSVersion.TLSv1_2, ssl_ciphers=ssl_ciphers, @@ -3169,7 +3166,6 @@ async def test_ssl_connection_tls12_custom_ciphers_invalid( ) -> None: async with await create_client( ssl=True, - ssl_check_hostname=False, ssl_cert_reqs="none", ssl_min_version=ssl.TLSVersion.TLSv1_2, ssl_ciphers="foo:bar", @@ -3191,7 +3187,6 @@ async def test_ssl_connection_tls13_custom_ciphers( # TLSv1.3 does not support changing the ciphers async with await create_client( ssl=True, - ssl_check_hostname=False, ssl_cert_reqs="none", ssl_min_version=ssl.TLSVersion.TLSv1_2, ssl_ciphers=ssl_ciphers, diff --git a/tests/test_asyncio/test_ssl.py b/tests/test_asyncio/test_ssl.py new file mode 100644 index 0000000000..75800f22de --- /dev/null +++ b/tests/test_asyncio/test_ssl.py @@ -0,0 +1,56 @@ +from urllib.parse import urlparse +import pytest +import pytest_asyncio +import redis.asyncio as redis + +# Skip test or not based on cryptography installation +try: + import cryptography # noqa + + skip_if_cryptography = pytest.mark.skipif(False, reason="") + skip_if_nocryptography = pytest.mark.skipif(False, reason="") +except ImportError: + skip_if_cryptography = pytest.mark.skipif(True, reason="cryptography not installed") + skip_if_nocryptography = pytest.mark.skipif( + True, reason="cryptography not installed" + ) + + +@pytest.mark.ssl +class TestSSL: + """Tests for SSL connections in asyncio.""" + + @pytest_asyncio.fixture() + async def _get_client(self, request): + ssl_url = request.config.option.redis_ssl_url + p = urlparse(ssl_url)[1].split(":") + client = redis.Redis(host=p[0], port=p[1], ssl=True) + yield client + await client.aclose() + + async def test_ssl_with_invalid_cert(self, _get_client): + """Test SSL connection with invalid certificate.""" + pass + + async def test_cert_reqs_none_with_check_hostname(self, request): + """Test that when ssl_cert_reqs=none is used with ssl_check_hostname=True, + the connection is created successfully with check_hostname internally set to False""" + ssl_url = request.config.option.redis_ssl_url + parsed_url = urlparse(ssl_url) + r = redis.Redis( + host=parsed_url.hostname, + port=parsed_url.port, + ssl=True, + ssl_cert_reqs="none", + # Check that ssl_check_hostname is ignored, when ssl_cert_reqs=none + ssl_check_hostname=True, + ) + try: + # Connection should be successful + assert await r.ping() + # check_hostname should have been automatically set to False + assert r.connection_pool.connection_class == redis.SSLConnection + conn = r.connection_pool.make_connection() + assert conn.check_hostname is False + finally: + await r.aclose() diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 5aa33353a8..cb3f227629 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -42,7 +42,6 @@ def test_ssl_connection(self, request): host=p[0], port=p[1], ssl=True, - ssl_check_hostname=False, ssl_cert_reqs="none", ) assert r.ping() @@ -105,7 +104,6 @@ def test_ssl_connection_tls12_custom_ciphers(self, request, ssl_ciphers): host=p[0], port=p[1], ssl=True, - ssl_check_hostname=False, ssl_cert_reqs="none", ssl_min_version=ssl.TLSVersion.TLSv1_3, ssl_ciphers=ssl_ciphers, @@ -120,7 +118,6 
@@ def test_ssl_connection_tls12_custom_ciphers_invalid(self, request): host=p[0], port=p[1], ssl=True, - ssl_check_hostname=False, ssl_cert_reqs="none", ssl_min_version=ssl.TLSVersion.TLSv1_2, ssl_ciphers="foo:bar", @@ -145,7 +142,6 @@ def test_ssl_connection_tls13_custom_ciphers(self, request, ssl_ciphers): host=p[0], port=p[1], ssl=True, - ssl_check_hostname=False, ssl_cert_reqs="none", ssl_min_version=ssl.TLSVersion.TLSv1_2, ssl_ciphers=ssl_ciphers, @@ -309,3 +305,26 @@ def test_mock_ocsp_staple(self, request): r.ping() assert "no ocsp response present" in str(e) r.close() + + def test_cert_reqs_none_with_check_hostname(self, request): + """Test that when ssl_cert_reqs=none is used with ssl_check_hostname=True, + the connection is created successfully with check_hostname internally set to False""" + ssl_url = request.config.option.redis_ssl_url + parsed_url = urlparse(ssl_url) + r = redis.Redis( + host=parsed_url.hostname, + port=parsed_url.port, + ssl=True, + ssl_cert_reqs="none", + # Check that ssl_check_hostname is ignored, when ssl_cert_reqs=none + ssl_check_hostname=True, + ) + try: + # Connection should be successful + assert r.ping() + # check_hostname should have been automatically set to False + assert r.connection_pool.connection_class == redis.SSLConnection + conn = r.connection_pool.make_connection() + assert conn.check_hostname is False + finally: + r.close() From 66d4a02654904dd0a730fc799603ae999577202b Mon Sep 17 00:00:00 2001 From: petyaslavova Date: Tue, 13 May 2025 12:16:13 +0300 Subject: [PATCH 111/113] Updating the readme and lib version to contain the changes from the latest stable release (#3644) --- README.md | 52 ++++++++++++++++++++++++++++++++++++++++------ docker-compose.yml | 5 +++-- redis/__init__.py | 2 +- 3 files changed, 50 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 8b4d4b6875..177f78feeb 100644 --- a/README.md +++ b/README.md @@ -31,12 +31,17 @@ The Python interface to the Redis key-value store. ## Installation -Start a redis via docker: +Start a redis via docker (for Redis versions >= 8.0): ``` bash -docker run -p 6379:6379 -it redis/redis-stack:latest +docker run -p 6379:6379 -it redis:latest ``` +Start a redis via docker (for Redis versions < 8.0): + +``` bash +docker run -p 6379:6379 -it redis/redis-stack:latest + To install redis-py, simply: ``` bash @@ -54,7 +59,7 @@ Looking for a high-level library to handle object mapping? See [redis-om-python] ## Supported Redis Versions -The most recent version of this library supports redis version [5.0](https://github.com/redis/redis/blob/5.0/00-RELEASENOTES), [6.0](https://github.com/redis/redis/blob/6.0/00-RELEASENOTES), [6.2](https://github.com/redis/redis/blob/6.2/00-RELEASENOTES), [7.0](https://github.com/redis/redis/blob/7.0/00-RELEASENOTES), [7.2](https://github.com/redis/redis/blob/7.2/00-RELEASENOTES) and [7.4](https://github.com/redis/redis/blob/7.4/00-RELEASENOTES). +The most recent version of this library supports Redis version [7.2](https://github.com/redis/redis/blob/7.2/00-RELEASENOTES), [7.4](https://github.com/redis/redis/blob/7.4/00-RELEASENOTES) and [8.0](https://github.com/redis/redis/blob/8.0/00-RELEASENOTES). The table below highlights version compatibility of the most-recent library versions and redis versions. 
@@ -62,7 +67,8 @@ The table below highlights version compatibility of the most-recent library vers |-----------------|-------------------| | 3.5.3 | <= 6.2 Family of releases | | >= 4.5.0 | Version 5.0 to 7.0 | -| >= 5.0.0 | Version 5.0 to current | +| >= 5.0.0 | Version 5.0 to 7.4 | +| >= 6.0.0 | Version 7.2 to current | ## Usage @@ -152,8 +158,42 @@ The following example shows how to utilize [Redis Pub/Sub](https://redis.io/docs {'pattern': None, 'type': 'subscribe', 'channel': b'my-second-channel', 'data': 1} ``` +### Redis’ search and query capabilities default dialect + +Release 6.0.0 introduces a client-side default dialect for Redis’ search and query capabilities. +By default, the client now overrides the server-side dialect with version 2, automatically appending *DIALECT 2* to commands like *FT.AGGREGATE* and *FT.SEARCH*. --------------------------- +**Important**: Be aware that the query dialect may impact the results returned. If needed, you can revert to a different dialect version by configuring the client accordingly. + +``` python +>>> from redis.commands.search.field import TextField +>>> from redis.commands.search.query import Query +>>> from redis.commands.search.index_definition import IndexDefinition +>>> import redis + +>>> r = redis.Redis(host='localhost', port=6379, db=0) +>>> r.ft().create_index( +>>> (TextField("name"), TextField("lastname")), +>>> definition=IndexDefinition(prefix=["test:"]), +>>> ) + +>>> r.hset("test:1", "name", "James") +>>> r.hset("test:1", "lastname", "Brown") + +>>> # Query with default DIALECT 2 +>>> query = "@name: James Brown" +>>> q = Query(query) +>>> res = r.ft().search(q) + +>>> # Query with explicit DIALECT 1 +>>> query = "@name: James Brown" +>>> q = Query(query).dialect(1) +>>> res = r.ft().search(q) +``` + +You can find further details in the [query dialect documentation](https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/dialects/). + +--------------------------------------------- ### Author @@ -169,4 +209,4 @@ Special thanks to: system. - Paul Hubbard for initial packaging support. 
-[![Redis](./docs/_static/logo-redis.svg)](https://redis.io) +[![Redis](./docs/_static/logo-redis.svg)](https://redis.io) \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 76a60398f3..bcf85df1a7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,9 +1,10 @@ --- +# image tag 8.0-RC2-pre is the one matching the 8.0 GA release x-client-libs-stack-image: &client-libs-stack-image - image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-rs-7.4.0-v2}" + image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-8.0-RC2-pre}" x-client-libs-image: &client-libs-image - image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-7.4.2}" + image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-8.0-RC2-pre}" services: diff --git a/redis/__init__.py b/redis/__init__.py index 14030205e3..cd3ee12adb 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -45,7 +45,7 @@ def int_or_str(value): return value -__version__ = "5.2.1" +__version__ = "6.1.0" VERSION = tuple(map(int_or_str, __version__.split("."))) From 7c600dcf5da3d89428282b729698905d40c8fecb Mon Sep 17 00:00:00 2001 From: Jonas Dittrich <58814480+Kakadus@users.noreply.github.com> Date: Thu, 28 Nov 2024 19:48:23 +0100 Subject: [PATCH 112/113] Add dynamic_startup_nodes parameter to async RedisCluster --- CHANGES | 342 --------------------------------------- redis/asyncio/cluster.py | 1 - 2 files changed, 343 deletions(-) diff --git a/CHANGES b/CHANGES index dbc27dbacc..5df6f81c84 100644 --- a/CHANGES +++ b/CHANGES @@ -1,346 +1,4 @@ - * Support transactions in ClusterPipeline - * Removing support for RedisGraph module. RedisGraph support is deprecated since Redis Stack 7.2 (https://redis.com/blog/redisgraph-eol/) - * Fix lock.extend() typedef to accept float TTL extension - * Update URL in the readme linking to Redis University - * Move doctests (doc code examples) to main branch - * Update `ResponseT` type hint - * Allow to control the minimum SSL version - * Add an optional lock_name attribute to LockError. - * Fix return types for `get`, `set_path` and `strappend` in JSONCommands - * Connection.register_connect_callback() is made public. - * Fix async `read_response` to use `disable_decoding`. - * Add 'aclose()' methods to async classes, deprecate async close(). - * Fix #2831, add auto_close_connection_pool=True arg to asyncio.Redis.from_url() - * Fix incorrect redis.asyncio.Cluster type hint for `retry_on_error` - * Fix dead weakref in sentinel connection causing ReferenceError (#2767) - * Fix #2768, Fix KeyError: 'first-entry' in parse_xinfo_stream. - * Fix #2749, remove unnecessary __del__ logic to close connections. - * Fix #2754, adding a missing argument to SentinelManagedConnection - * Fix `xadd` command to accept non-negative `maxlen` including 0 - * Revert #2104, #2673, add `disconnect_on_error` option to `read_response()` (issues #2506, #2624) - * Add `address_remap` parameter to `RedisCluster` - * Fix incorrect usage of once flag in async Sentinel - * asyncio: Fix memory leak caused by hiredis (#2693) - * Allow data to drain from async PythonParser when reading during a disconnect() - * Use asyncio.timeout() instead of async_timeout.timeout() for python >= 3.11 (#2602) - * Add a Dependabot configuration to auto-update GitHub action versions. - * Add test and fix async HiredisParser when reading during a disconnect() (#2349) - * Use hiredis-py pack_command if available. 
- * Support `.unlink()` in ClusterPipeline
- * Simplify synchronous SocketBuffer state management
- * Fix string cleanse in Redis Graph
- * Make PythonParser resumable in case of error (#2510)
- * Add `timeout=None` in `SentinelConnectionManager.read_response`
- * Documentation fix: password protected socket connection (#2374)
- * Allow `timeout=None` in `PubSub.get_message()` to wait forever
- * add `nowait` flag to `asyncio.Connection.disconnect()`
- * Update README.md links
- * Fix timezone handling for datetime to unixtime conversions
- * Fix start_id type for XAUTOCLAIM
- * Remove verbose logging from cluster.py
- * Add retry mechanism to async version of Connection
- * Compare commands case-insensitively in the asyncio command parser
- * Allow negative `retries` for `Retry` class to retry forever
- * Add `items` parameter to `hset` signature
- * Create codeql-analysis.yml (#1988). Thanks @chayim
- * Add limited support for Lua scripting with RedisCluster
- * Implement `.lock()` method on RedisCluster
- * Fix cursor returned by SCAN for RedisCluster & change default target to PRIMARIES
- * Fix scan_iter for RedisCluster
- * Remove verbose logging when initializing ClusterPubSub, ClusterPipeline or RedisCluster
- * Fix broken connection writer lock-up for asyncio (#2065)
- * Fix auth bug when provided with no username (#2086)
- * Fix missing ClusterPipeline._lock (#2189)
- * Added dynaminc_startup_nodes configuration to RedisCluster
- * Fix reusing the old nodes' connections when cluster topology refresh is being done
- * Fix RedisCluster to immediately raise AuthenticationError without a retry
- * ClusterPipeline Doesn't Handle ConnectionError for Dead Hosts (#2225)
- * Remove compatibility code for old versions of Hiredis, drop Packaging dependency
- * The `deprecated` library is no longer a dependency
- * Failover handling improvements for RedisCluster and Async RedisCluster (#2377)
- * Fixed "cannot pickle '_thread.lock' object" bug (#2354, #2297)
- * Added CredentialsProvider class to support password rotation
- * Enable Lock for asyncio cluster mode
- * Fix Sentinel.execute_command doesn't execute across the entire sentinel cluster bug (#2458)
- * Added a replacement for the default cluster node in the event of failure (#2463)
- * Fix for Unhandled exception related to self.host with unix socket (#2496)
- * Improve error output for master discovery
- * Make `ClusterCommandsProtocol` an actual Protocol
- * Add `sum` to DUPLICATE_POLICY documentation of `TS.CREATE`, `TS.ADD` and `TS.ALTER`
- * Prevent async ClusterPipeline instances from becoming "false-y" in case of empty command stack (#3061)
- * Close Unix sockets if the connection attempt fails. This prevents `ResourceWarning`s. (#3314)
- * Close SSL sockets if the connection attempt fails, or if validations fail. (#3317)
- * Eliminate mutable default arguments in the `redis.commands.core.Script` class. (#3332)
- * Fix SSL verification with `ssl_cert_reqs="none"` and `ssl_check_hostname=True` by automatically setting `check_hostname=False` when `verify_mode=ssl.CERT_NONE` (#3635)
- * Allow newer versions of PyJWT as dependency. (#3630)
-* 4.1.3 (Feb 8, 2022)
- * Fix flushdb and flushall (#1926)
- * Add redis5 and redis4 dockers (#1871)
- * Change json.clear test multi to be up to date with redisjson (#1922)
- * Fixing volume for unstable_cluster docker (#1914)
- * Update changes file with changes since 4.0.0-beta2 (#1915)
-* 4.1.2 (Jan 27, 2022)
- * Invalid OCSP certificates should raise ConnectionError on failed validation (#1907)
- * Added retry mechanism on socket timeouts when connecting to the server (#1895)
- * LMOVE, BLMOVE return incorrect responses (#1906)
- * Fixing AttributeError in UnixDomainSocketConnection (#1903)
- * Fixing TypeError in GraphCommands.explain (#1901)
- * For tests, increasing wait time for the cluster (#1908)
- * Increased pubsub's wait_for_messages timeout to prevent flaky tests (#1893)
- * README code snippets formatted to highlight properly (#1888)
- * Fix link in the main page (#1897)
- * Documentation fixes: JSON Example, SSL Connection Examples, RTD version (#1887)
- * Direct link to readthedocs (#1885)
-* 4.1.1 (Jan 17, 2022)
- * Add retries to connections in Sentinel Pools (#1879)
- * OCSP Stapling Support (#1873)
- * Define incr/decr as aliases of incrby/decrby (#1874)
- * FT.CREATE - support MAXTEXTFIELDS, TEMPORARY, NOHL, NOFREQS, SKIPINITIALSCAN (#1847)
- * Timeseries docs fix (#1877)
- * get_connection: catch OSError too (#1832)
- * Set keys var otherwise variable not created (#1853)
- * Clusters should optionally require full slot coverage (#1845)
- * Triple quote docstrings in client.py PEP 257 (#1876)
- * syncing requirements (#1870)
- * Typo and typing in GraphCommands documentation (#1855)
- * Allowing poetry and redis-py to install together (#1854)
- * setup.py: Add project_urls for PyPI (#1867)
- * Support test with redis unstable docker (#1850)
- * Connection examples (#1835)
- * Documentation cleanup (#1841)
-* 4.1.0 (Dec 26, 2021)
- * OCSP stapling support (#1820)
- * Support for SELECT (#1825)
- * Support for specifying error types with retry (#1817)
- * Support for RESET command since Redis 6.2.0 (#1824)
- * Support CLIENT TRACKING (#1612)
- * Support WRITE in CLIENT PAUSE (#1549)
- * JSON set_file and set_path support (#1818)
- * Allow ssl_ca_path with rediss:// urls (#1814)
- * Support for password-encrypted SSL private keys (#1782)
- * Support SYNC and PSYNC (#1741)
- * Retry on error exception and timeout fixes (#1821)
- * Fixing read race condition during pubsub (#1737)
- * Fixing exception in listen (#1823)
- * Fixed MovedError, and stopped iterating through startup nodes when slots are fully covered (#1819)
- * Socket not closing after server disconnect (#1797)
- * Single sourcing the package version (#1791)
- * Ensure redis_connect_func is set on uds connection (#1794)
- * SRTALGO - Skip for redis versions greater than 7.0.0 (#1831)
- * Documentation updates (#1822)
- * Add CI action to install package from repository commit hash (#1781) (#1790)
- * Fix link in lmove docstring (#1793)
- * Disabling JSON.DEBUG tests (#1787)
- * Migrated targeted nodes to kwargs in Cluster Mode (#1762)
- * Added support for MONITOR in clusters (#1756)
- * Adding ROLE Command (#1610)
- * Integrate RedisBloom support (#1683)
- * Adding RedisGraph support (#1556)
- * Allow overriding connection class via keyword arguments (#1752)
- * Aggregation LOAD * support for RediSearch (#1735)
- * Adding cluster, bloom, and graph docs (#1779)
- * Add packaging to setup_requires, and use >= to play nice to setup.py (fixes #1625) (#1780)
- * Fixing the license link in the readme (#1778)
- * Removing distutils from tests (#1773)
- * Fix cluster ACL tests (#1774)
- * Improved RedisCluster's reinitialize_steps and documentation (#1765)
- * Added black and isort (#1734)
- * Link Documents for all module commands (#1711)
- * Pyupgrade + flynt + f-strings (#1759)
- * Remove unused aggregation subclasses in RediSearch (#1754)
- * Adding RedisCluster client to support Redis Cluster Mode (#1660)
- * Support RediSearch FT.PROFILE command (#1727)
- * Adding support for non-decodable commands (#1731)
- * COMMAND GETKEYS support (#1738)
- * RedisJSON 2.0.4 behaviour support (#1747)
- * Removing deprecating distutils (PEP 632) (#1730)
- * Updating PR template (#1745)
- * Removing duplication of Script class (#1751)
- * Splitting documentation for read the docs (#1743)
- * Improve code coverage for aggregation tests (#1713)
- * Fixing COMMAND GETKEYS tests (#1750)
- * GitHub release improvements (#1684)
-* 4.0.2 (Nov 22, 2021)
- * Restoring Sentinel commands to redis client (#1723)
- * Better removal of hiredis warning (#1726)
- * Adding links to redis documents in function calls (#1719)
-* 4.0.1 (Nov 17, 2021)
- * Removing command on initial connections (#1722)
- * Removing hiredis warning when not installed (#1721)
-* 4.0.0 (Nov 15, 2021)
- * FT.EXPLAINCLI intentionally raising NotImplementedError
- * Restoring ZRANGE desc for Redis < 6.2.0 (#1697)
- * Response parsing occasionally fails to parse floats (#1692)
- * Re-enabling read-the-docs (#1707)
- * Call HSET after FT.CREATE to avoid keyspace scan (#1706)
- * Unit tests fixes for compatibility (#1703)
- * Improve documentation about Locks (#1701)
- * Fixes to allow --redis-url to pass through all tests (#1700)
- * Fix unit tests running against Redis 4.0.0 (#1699)
- * Search alias test fix (#1695)
- * Adding RediSearch/RedisJSON tests (#1691)
- * Updating codecov rules (#1689)
- * Tests to validate custom JSON decoders (#1681)
- * Added breaking icon to release drafter (#1702)
- * Removing dependency on six (#1676)
- * Re-enable pipeline support for JSON and TimeSeries (#1674)
- * Export Sentinel, and SSL like other classes (#1671)
- * Restore zrange functionality for older versions of Redis (#1670)
- * Fixed garbage collection deadlock (#1578)
- * Tests to validate built python packages (#1678)
- * Sleep for flaky search test (#1680)
- * Test function renames, to match standards (#1679)
- * Docstring improvements for Redis class (#1675)
- * Fix georadius tests (#1672)
- * Improvements to JSON coverage (#1666)
- * Add python_requires setuptools check for python > 3.6 (#1656)
- * SMISMEMBER support (#1667)
- * Exposing the module version in loaded_modules (#1648)
- * RedisTimeSeries support (#1652)
- * Support for json multipath ($) (#1663)
- * Added boolean parsing to PEXPIRE and PEXPIREAT (#1665)
- * Add python_requires setuptools check for python > 3.6 (#1656)
- * Adding vulture for static analysis (#1655)
- * Starting to clean the docs (#1657)
- * Update README.md (#1654)
- * Adding description format for package (#1651)
- * Publish to pypi as releases are generated with the release drafter (#1647)
- * Restore actions to prs (#1653)
- * Fixing the package to include commands (#1649)
- * Re-enabling codecov as part of CI process (#1646)
- * Adding support for redisearch (#1640) Thanks @chayim
- * redisjson support (#1636) Thanks @chayim
- * Sentinel: Add SentinelManagedSSLConnection (#1419) Thanks @AbdealiJK
- * Enable floating parameters in SET (ex and px) (#1635) Thanks @AvitalFineRedis
- * Add warning when hiredis not installed. Recommend installation. (#1621) Thanks @adiamzn
- * Raising NotImplementedError for SCRIPT DEBUG and DEBUG SEGFAULT (#1624) Thanks @chayim
- * CLIENT REDIR command support (#1623) Thanks @chayim
- * REPLICAOF command implementation (#1622) Thanks @chayim
- * Add support to NX XX and CH to GEOADD (#1605) Thanks @AvitalFineRedis
- * Add support to ZRANGE and ZRANGESTORE parameters (#1603) Thanks @AvitalFineRedis
- * Pre 6.2 redis should default to None for script flush (#1641) Thanks @chayim
- * Add FULL option to XINFO SUMMARY (#1638) Thanks @agusdmb
- * Geosearch test should use any=True (#1594) Thanks @Andrew-Chen-Wang
- * Removing packaging dependency (#1626) Thanks @chayim
- * Fix client_kill_filter docs for skimpy (#1596) Thanks @Andrew-Chen-Wang
- * Normalize minid and maxlen docs (#1593) Thanks @Andrew-Chen-Wang
- * Update docs for multiple usernames for ACL DELUSER (#1595) Thanks @Andrew-Chen-Wang
- * Fix grammar of get param in set command (#1588) Thanks @Andrew-Chen-Wang
- * Fix docs for client_kill_filter (#1584) Thanks @Andrew-Chen-Wang
- * Convert README & CONTRIBUTING from rst to md (#1633) Thanks @davidylee
- * Test BYLEX param in zrangestore (#1634) Thanks @AvitalFineRedis
- * Tox integrations with invoke and docker (#1632) Thanks @chayim
- * Adding the release drafter to help simplify release notes (#1618). Thanks @chayim
- * BACKWARDS INCOMPATIBLE: Removed support for end of life Python 2.7. #1318
- * BACKWARDS INCOMPATIBLE: All values within Redis URLs are unquoted via urllib.parse.unquote. Prior versions of redis-py supported this by specifying the ``decode_components`` flag to the ``from_url`` functions. This is now done by default and cannot be disabled. #589
- * POTENTIALLY INCOMPATIBLE: Redis commands were moved into a mixin (see commands.py). Anyone importing ``redis.client`` to access commands directly should import ``redis.commands``. #1534, #1550
- * Removed technical debt on REDIS_6_VERSION placeholder. Thanks @chayim #1582.
- * Various docus fixes. Thanks @Andrew-Chen-Wang #1585, #1586.
- * Support for LOLWUT command, available since Redis 5.0.0. Thanks @brainix #1568.
- * Added support for CLIENT REPLY, available in Redis 3.2.0. Thanks @chayim #1581.
- * Support for Auto-reconnect PubSub on get_message. Thanks @luhn #1574.
- * Fix RST syntax error in README/ Thanks @JanCBrammer #1451.
- * IDLETIME and FREQ support for RESTORE. Thanks @chayim #1580.
- * Supporting args with MODULE LOAD. Thanks @chayim #1579.
- * Updating RedisLabs with Redis. Thanks @gkorland #1575.
- * Added support for ASYNC to SCRIPT FLUSH available in Redis 6.2.0. Thanks @chayim. #1567
- * Added CLIENT LIST fix to support multiple client ids available in Redis 2.8.12. Thanks @chayim #1563.
- * Added DISCARD support for pipelines available in Redis 2.0.0. Thanks @chayim #1565.
- * Added ACL DELUSER support for deleting lists of users available in Redis 6.2.0. Thanks @chayim. #1562
- * Added CLIENT TRACKINFO support available in Redis 6.2.0. Thanks @chayim. #1560
- * Added GEOSEARCH and GEOSEARCHSTORE support available in Redis 6.2.0. Thanks @AvitalFine Redis. #1526
- * Added LPUSHX support for lists available in Redis 4.0.0. Thanks @chayim. #1559
- * Added support for QUIT available in Redis 1.0.0. Thanks @chayim. #1558
- * Added support for COMMAND COUNT available in Redis 2.8.13. Thanks @chayim. #1554.
- * Added CREATECONSUMER support for XGROUP available in Redis 6.2.0. Thanks @AvitalFineRedis. #1553
- * Including slowly complexity in INFO if available. Thanks @ian28223 #1489.
- * Added support for STRALGO available in Redis 6.0.0. Thanks @AvitalFineRedis. #1528
- * Addes support for ZMSCORE available in Redis 6.2.0. Thanks @2014BDuck and @jiekun.zhu. #1437
- * Support MINID and LIMIT on XADD available in Redis 6.2.0. Thanks @AvitalFineRedis. #1548
- * Added sentinel commands FLUSHCONFIG, CKQUORUM, FAILOVER, and RESET available in Redis 2.8.12. Thanks @otherpirate. #834
- * Migrated Version instead of StrictVersion for Python 3.10. Thanks @tirkarthi. #1552
- * Added retry mechanism with backoff. Thanks @nbraun-amazon. #1494
- * Migrated commands to a mixin. Thanks @chayim. #1534
- * Added support for ZUNION, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1522
- * Added support for CLIENT LIST with ID, available in Redis 6.2.0. Thanks @chayim. #1505
- * Added support for MINID and LIMIT with xtrim, available in Reds 6.2.0. Thanks @chayim. #1508
- * Implemented LMOVE and BLMOVE commands, available in Redis 6.2.0. Thanks @chayim. #1504
- * Added GET argument to SET command, available in Redis 6.2.0. Thanks @2014BDuck. #1412
- * Documentation fixes. Thanks @enjoy-binbin @jonher937. #1496 #1532
- * Added support for XAUTOCLAIM, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1529
- * Added IDLE support for XPENDING, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1523
- * Add a count parameter to lpop/rpop, available in Redis 6.2.0. Thanks @wavenator. #1487
- * Added a (pypy) trove classifier for Python 3.9. Thanks @D3X. #1535
- * Added ZINTER support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1520
- * Added ZINTER support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1520
- * Added ZDIFF and ZDIFFSTORE support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1518
- * Added ZRANGESTORE support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1521
- * Added LT and GT support for ZADD, available in Redis 6.2.0. Thanks @chayim. #1509
- * Added ZRANDMEMBER support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1519
- * Added GETDEL support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1514
- * Added CLIENT KILL laddr filter, available in Redis 6.2.0. Thanks @chayim. #1506
- * Added CLIENT UNPAUSE, available in Redis 6.2.0. Thanks @chayim. #1512
- * Added NOMKSTREAM support for XADD, available in Redis 6.2.0. Thanks @chayim. #1507
- * Added HRANDFIELD support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1513
- * Added CLIENT INFO support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1517
- * Added GETEX support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1515
- * Added support for COPY command, available in Redis 6.2.0. Thanks @malinaa96. #1492
- * Provide a development and testing environment via docker. Thanks @abrookins. #1365
- * Added support for the LPOS command available in Redis 6.0.6. Thanks @aparcar #1353/#1354
- * Added support for the ACL LOG command available in Redis 6. Thanks @2014BDuck. #1307
- * Added support for ABSTTL option of the RESTORE command available in Redis 5.0. Thanks @charettes. #1423
 * 3.5.3 (June 1, 2020)
     * Restore try/except clauses to __del__ methods. These will be removed in 4.0 when more explicit resource management if enforced. #1339
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index 9faf5b891d..9652def198
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -141,7 +141,6 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
         listed in the CLUSTER SLOTS output.
         If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists
         specific IP addresses, it is best to set it to false.
-        The data read from replicas is eventually consistent with the data in primary nodes.
     :param reinitialize_steps:
       | Specifies the number of MOVED errors that need to occur before reinitializing
         the whole cluster topology. If a MOVED error occurs and the cluster does not

From 6f53c029775feaebdc63b86fbc4b63dbcef34d46 Mon Sep 17 00:00:00 2001
From: Petya Slavova
Date: Tue, 13 May 2025 20:06:17 +0300
Subject: [PATCH 113/113] Applying review comments

---
 redis/asyncio/cluster.py           | 5 +++--
 tests/test_asyncio/test_cluster.py | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index 9652def198..13a2606cc3 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -1196,13 +1196,14 @@ def __init__(
         self.startup_nodes = {node.name: node for node in startup_nodes}
         self.require_full_coverage = require_full_coverage
         self.connection_kwargs = connection_kwargs
-        self._dynamic_startup_nodes = dynamic_startup_nodes
         self.address_remap = address_remap

         self.default_node: "ClusterNode" = None
         self.nodes_cache: Dict[str, "ClusterNode"] = {}
         self.slots_cache: Dict[int, List["ClusterNode"]] = {}
-        self.read_load_balancer = LoadBalancer()
+        self.read_load_balancer: LoadBalancer = LoadBalancer()
+
+        self._dynamic_startup_nodes: bool = dynamic_startup_nodes
         self._moved_exception: MovedError = None
         if event_dispatcher is None:
             self._event_dispatcher = EventDispatcher()
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py
index b56ad6dbd1..1b3fbd5526 100644
--- a/tests/test_asyncio/test_cluster.py
+++ b/tests/test_asyncio/test_cluster.py
@@ -2740,7 +2740,7 @@ async def test_init_slots_dynamic_startup_nodes(self, dynamic_startup_nodes):
         ]
         startup_nodes = list(rc.nodes_manager.startup_nodes.keys())
         if dynamic_startup_nodes is True:
-            assert startup_nodes.sort() == discovered_nodes.sort()
+            assert sorted(startup_nodes) == sorted(discovered_nodes)
         else:
             assert startup_nodes == ["my@DNS.com:7000"]
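
A note on the test assertion rewritten above: list.sort() sorts in place and returns None, so the old
assertion compared None == None and passed regardless of the lists' contents, while sorted() returns a
new list and makes the comparison meaningful. A quick Python REPL illustration:

    >>> nodes = ["127.0.0.1:7001", "127.0.0.1:7000"]
    >>> nodes.sort() is None      # sort() mutates the list and returns None
    True
    >>> sorted(["b", "a"]) == sorted(["a", "b"])
    True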
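The docstring hunk above recommends disabling dynamic startup nodes when the cluster is reached through
a dynamic DNS name while CLUSTER SLOTS reports specific IP addresses. A minimal usage sketch of the async
client under that assumption follows; the host name is a hypothetical placeholder, and only the
dynamic_startup_nodes parameter itself comes from this patch series:

    import asyncio

    from redis.asyncio.cluster import RedisCluster

    async def main() -> None:
        # The cluster is addressed via a DNS name that will not appear in the
        # CLUSTER SLOTS output, so keep the original startup nodes instead of
        # replacing them with the discovered IP addresses.
        rc = RedisCluster(
            host="redis-cluster.example.internal",  # hypothetical DNS endpoint
            port=6379,
            dynamic_startup_nodes=False,
        )
        print(await rc.ping())
        # Release connections; aclose() is available on recent redis-py versions.
        await rc.aclose()

    asyncio.run(main())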