From 2e556ed9a848c67d8739d9040126f29be25c4035 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Mon, 22 Nov 2021 10:46:49 +0200 Subject: [PATCH 01/11] Changed get_message logic to wait for subscription --- redis/client.py | 19 +++++++++++++++++++ tests/test_pubsub.py | 43 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/redis/client.py b/redis/client.py index c02bc3a4d5..2bf8fd5595 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1508,11 +1508,30 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0): before returning. Timeout should be specified as a floating point number. """ + if not self.subscribed and \ + self.wait_for_subscription(timeout) is False: + # The connection isn't subscribed to any channels or patterns, so + # no messages are available + return None + response = self.parse_response(block=False, timeout=timeout) if response: return self.handle_message(response, ignore_subscribe_messages) return None + def wait_for_subscription(self, timeout, period=0.25): + """ + Wait until this pubsub connection has been subscribed. + Return True if the connection was subscribed during the timeout + frametime. Otherwise, return False. + """ + mustend = time.time() + timeout + while time.time() < mustend: + if self.subscribed: + return True + time.sleep(period) + return False + def ping(self, message=None): """ Ping the Redis server diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 20ae0a05c1..b12a07b258 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -6,9 +6,14 @@ import pytest import redis +from redis.client import PubSub from redis.exceptions import ConnectionError -from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt +from .conftest import ( + _get_client, + skip_if_redis_enterprise, + skip_if_server_version_lt +) def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): @@ -16,8 +21,7 @@ def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): timeout = now + timeout while now < timeout: message = pubsub.get_message( - ignore_subscribe_messages=ignore_subscribe_messages - ) + ignore_subscribe_messages=ignore_subscribe_messages) if message is not None: return message time.sleep(0.01) @@ -549,6 +553,39 @@ def test_get_message_with_timeout_returns_none(self, r): assert wait_for_message(p) == make_message("subscribe", "foo", 1) assert p.get_message(timeout=0.01) is None + def test_get_message_not_subscribed_return_none(self, r): + p = r.pubsub() + assert p.subscribed is False + assert p.get_message() is None + assert p.get_message(timeout=0.1) is None + with patch.object(PubSub, 'wait_for_subscription') as mock: + mock.return_value = False + assert p.get_message(timeout=0.01) is None + assert mock.called + + def test_get_message_subscribe_during_waiting(self, r): + p = r.pubsub() + + def poll(ps, expected_res): + assert ps.get_message() is None + message = ps.get_message(timeout=1) + assert message == expected_res + + subscribe_response = make_message('subscribe', 'foo', 1) + poller = threading.Thread(target=poll, args=(p, subscribe_response)) + poller.start() + time.sleep(0.2) + p.subscribe('foo') + poller.join() + + def test_get_message_wait_for_subscription_not_being_called(self, r): + p = r.pubsub() + p.subscribe('foo') + with patch.object(PubSub, 'wait_for_subscription') as mock: + assert p.subscribed is True + assert wait_for_message(p) == make_message('subscribe', 'foo', 1) + assert mock.called is False + class TestPubSubWorkerThread: @pytest.mark.skipif( From deee567bfafda69411ecafc84965174a7d8d61f1 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Tue, 23 Nov 2021 13:46:11 +0200 Subject: [PATCH 02/11] Added an threading Event 'subscribed_event' to set on subscription and clear on unsubscription --- redis/client.py | 45 ++++++++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/redis/client.py b/redis/client.py index 2bf8fd5595..a86f1f812f 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1276,6 +1276,7 @@ def __init__( self.shard_hint = shard_hint self.ignore_subscribe_messages = ignore_subscribe_messages self.connection = None + self.subscribed_event = threading.Event() # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. self.encoder = encoder @@ -1315,6 +1316,7 @@ def reset(self): self.pending_unsubscribe_channels = set() self.patterns = {} self.pending_unsubscribe_patterns = set() + self.subscribed_event.clear() def close(self): self.reset() @@ -1340,7 +1342,7 @@ def on_connect(self, connection): @property def subscribed(self): "Indicates if there are subscriptions to any channels or patterns" - return bool(self.channels or self.patterns) + return self.subscribed_event.is_set() def execute_command(self, *args): "Execute a publish/subscribe command" @@ -1443,6 +1445,9 @@ def psubscribe(self, *args, **kwargs): # for the reconnection. new_patterns = self._normalize_keys(new_patterns) self.patterns.update(new_patterns) + if not self.subscribed: + # Set the subscribed_event flag to True + self.subscribed_event.set() self.pending_unsubscribe_patterns.difference_update(new_patterns) return ret_val @@ -1477,6 +1482,9 @@ def subscribe(self, *args, **kwargs): # for the reconnection. new_channels = self._normalize_keys(new_channels) self.channels.update(new_channels) + if not self.subscribed: + # Set the subscribed_event flag to True + self.subscribed_event.set() self.pending_unsubscribe_channels.difference_update(new_channels) return ret_val @@ -1508,30 +1516,25 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0): before returning. Timeout should be specified as a floating point number. """ - if not self.subscribed and \ - self.wait_for_subscription(timeout) is False: - # The connection isn't subscribed to any channels or patterns, so - # no messages are available - return None + if not self.subscribed: + # Wait for subscription + deadline = time.time() + timeout + if self.subscribed_event.wait(timeout) is True: + # The connection was subscribed during the timeout frametime. + # The timeout should be adjusted for the time spent waiting + # for subscription + time_spent = deadline - time.time() + timeout = timeout - time_spent + else: + # The connection isn't subscribed to any channels or patterns, + # so no messages are available + return None response = self.parse_response(block=False, timeout=timeout) if response: return self.handle_message(response, ignore_subscribe_messages) return None - def wait_for_subscription(self, timeout, period=0.25): - """ - Wait until this pubsub connection has been subscribed. - Return True if the connection was subscribed during the timeout - frametime. Otherwise, return False. - """ - mustend = time.time() + timeout - while time.time() < mustend: - if self.subscribed: - return True - time.sleep(period) - return False - def ping(self, message=None): """ Ping the Redis server @@ -1580,6 +1583,10 @@ def handle_message(self, response, ignore_subscribe_messages=False): if channel in self.pending_unsubscribe_channels: self.pending_unsubscribe_channels.remove(channel) self.channels.pop(channel, None) + if not self.channels and not self.patterns: + # There are no subscriptions anymore, set subscribed_event flag + # to false + self.subscribed_event.clear() if message_type in self.PUBLISH_MESSAGE_TYPES: # if there's a message handler, invoke it From ea64f5b7d479469847d83751f77a90561680cccc Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Tue, 23 Nov 2021 13:54:46 +0200 Subject: [PATCH 03/11] Fixed pubsub tests --- tests/test_pubsub.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index b12a07b258..6a5a173f0b 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -558,7 +558,7 @@ def test_get_message_not_subscribed_return_none(self, r): assert p.subscribed is False assert p.get_message() is None assert p.get_message(timeout=0.1) is None - with patch.object(PubSub, 'wait_for_subscription') as mock: + with patch.object(threading.Event, 'wait') as mock: mock.return_value = False assert p.get_message(timeout=0.01) is None assert mock.called @@ -581,7 +581,7 @@ def poll(ps, expected_res): def test_get_message_wait_for_subscription_not_being_called(self, r): p = r.pubsub() p.subscribe('foo') - with patch.object(PubSub, 'wait_for_subscription') as mock: + with patch.object(threading.Event, 'wait') as mock: assert p.subscribed is True assert wait_for_message(p) == make_message('subscribe', 'foo', 1) assert mock.called is False From 50c0ef8e655a908e5d374db57d9323ff11adf68c Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Tue, 23 Nov 2021 14:18:47 +0200 Subject: [PATCH 04/11] removed unused import --- tests/test_pubsub.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 6a5a173f0b..a20f80b393 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -6,7 +6,6 @@ import pytest import redis -from redis.client import PubSub from redis.exceptions import ConnectionError from .conftest import ( From c2ebd78ad5f2763b1eac5cf1ef77f224f6c737c7 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Tue, 23 Nov 2021 14:57:25 +0200 Subject: [PATCH 05/11] fixed the new timeout --- redis/client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/redis/client.py b/redis/client.py index a86f1f812f..d7a924f321 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1518,12 +1518,12 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0): """ if not self.subscribed: # Wait for subscription - deadline = time.time() + timeout + start_time = time.time() if self.subscribed_event.wait(timeout) is True: # The connection was subscribed during the timeout frametime. - # The timeout should be adjusted for the time spent waiting - # for subscription - time_spent = deadline - time.time() + # The timeout should be adjusted based on the time spent + # waiting for the subscription + time_spent = time.time() - start_time timeout = timeout - time_spent else: # The connection isn't subscribed to any channels or patterns, From 4b404120f488fcdd335c60df844fb4acfc5c89d1 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Tue, 23 Nov 2021 16:48:59 +0200 Subject: [PATCH 06/11] 1) Fixes issue #1740. 2) Changed health check to be executed from the execute_command method only in the first command execution. --- redis/client.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/redis/client.py b/redis/client.py index d7a924f321..a7511fa0d9 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1282,12 +1282,14 @@ def __init__( self.encoder = encoder if self.encoder is None: self.encoder = self.connection_pool.get_encoder() + self.health_check_response_b = self.encoder.encode( + self.HEALTH_CHECK_MESSAGE) if self.encoder.decode_responses: self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE] else: self.health_check_response = [ b"pong", - self.encoder.encode(self.HEALTH_CHECK_MESSAGE), + self.health_check_response_b ] self.reset() @@ -1401,7 +1403,14 @@ def parse_response(self, block=True, timeout=0): return None response = self._execute(conn, conn.read_response) - if conn.health_check_interval and response == self.health_check_response: + if conn.health_check_interval and \ + response in [ + self.health_check_response, # If there was a subscription + self.health_check_response_b # If there wasn't + ]: + # If there are no subscriptions redis responds to PING command with + # a bulk response, instead of a multi-bulk with "pong" and the + # response. # ignore the health check message as user might not expect it return None return response From b26d48f30f82d273f9eeca59d914b7d3cf689a9b Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Tue, 30 Nov 2021 12:04:37 +0200 Subject: [PATCH 07/11] Added a function to clean the socket from health check responses if not the connection isn't subscribed --- redis/client.py | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/redis/client.py b/redis/client.py index a7511fa0d9..7b666ef48f 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1362,8 +1362,22 @@ def execute_command(self, *args): self.connection.register_connect_callback(self.on_connect) connection = self.connection kwargs = {"check_health": not self.subscribed} + if not self.subscribed: + self.clean_health_check_responses() self._execute(connection, connection.send_command, *args, **kwargs) + + def clean_health_check_responses(self): + """ + If any health check responses are present, clean them + """ + conn = self.connection + while self._execute(conn, conn.can_read, timeout=0): + response = self._execute(conn, conn.read_response) + if not self.is_health_check_response(response): + raise PubSubError('A non health check response was cleaned by ' + 'execute_command: {0}'.format(response)) + def _disconnect_raise_connect(self, conn, error): """ Close the connection and raise an exception @@ -1403,18 +1417,22 @@ def parse_response(self, block=True, timeout=0): return None response = self._execute(conn, conn.read_response) - if conn.health_check_interval and \ - response in [ - self.health_check_response, # If there was a subscription - self.health_check_response_b # If there wasn't - ]: - # If there are no subscriptions redis responds to PING command with - # a bulk response, instead of a multi-bulk with "pong" and the - # response. + if self.is_health_check_response(response): # ignore the health check message as user might not expect it return None return response + def is_health_check_response(self, response): + """ + Check if the response is a health check response. + If there are no subscriptions redis responds to PING command with a + bulk response, instead of a multi-bulk with "pong" and the response. + """ + return response in [ + self.health_check_response, # If there was a subscription + self.health_check_response_b # If there wasn't + ] + def check_health(self): conn = self.connection if conn is None: @@ -1529,11 +1547,11 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0): # Wait for subscription start_time = time.time() if self.subscribed_event.wait(timeout) is True: - # The connection was subscribed during the timeout frametime. + # The connection was subscribed during the timeout time frame. # The timeout should be adjusted based on the time spent # waiting for the subscription time_spent = time.time() - start_time - timeout = timeout - time_spent + timeout = max(0.0, timeout - time_spent) else: # The connection isn't subscribed to any channels or patterns, # so no messages are available From 27ed485162686985e73dcf10d45839cf049fe5c6 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Tue, 30 Nov 2021 12:14:37 +0200 Subject: [PATCH 08/11] Fixed linters --- redis/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/redis/client.py b/redis/client.py index 7b666ef48f..9c506493e6 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1366,7 +1366,6 @@ def execute_command(self, *args): self.clean_health_check_responses() self._execute(connection, connection.send_command, *args, **kwargs) - def clean_health_check_responses(self): """ If any health check responses are present, clean them From d0a95fa8870787351edacf6fba23249ba9e25f18 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Wed, 8 Dec 2021 13:18:54 +0200 Subject: [PATCH 09/11] Added a health check response counter --- redis/client.py | 34 ++++++++++++++++++++-------------- tests/test_pubsub.py | 31 ++++++++++--------------------- 2 files changed, 30 insertions(+), 35 deletions(-) diff --git a/redis/client.py b/redis/client.py index 9c506493e6..a68c9d13c9 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1282,15 +1282,11 @@ def __init__( self.encoder = encoder if self.encoder is None: self.encoder = self.connection_pool.get_encoder() - self.health_check_response_b = self.encoder.encode( - self.HEALTH_CHECK_MESSAGE) + self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE) if self.encoder.decode_responses: self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE] else: - self.health_check_response = [ - b"pong", - self.health_check_response_b - ] + self.health_check_response = [b"pong", self.health_check_response_b] self.reset() def __enter__(self): @@ -1315,6 +1311,7 @@ def reset(self): self.connection_pool.release(self.connection) self.connection = None self.channels = {} + self.health_check_response_counter = 0 self.pending_unsubscribe_channels = set() self.patterns = {} self.pending_unsubscribe_patterns = set() @@ -1370,12 +1367,19 @@ def clean_health_check_responses(self): """ If any health check responses are present, clean them """ + ttl = 10 conn = self.connection - while self._execute(conn, conn.can_read, timeout=0): - response = self._execute(conn, conn.read_response) - if not self.is_health_check_response(response): - raise PubSubError('A non health check response was cleaned by ' - 'execute_command: {0}'.format(response)) + while self.health_check_response_counter > 0 and ttl > 0: + if self._execute(conn, conn.can_read, timeout=1): + response = self._execute(conn, conn.read_response) + if self.is_health_check_response(response): + self.health_check_response_counter -= 1 + else: + raise PubSubError( + "A non health check response was cleaned by " + "execute_command: {0}".format(response) + ) + ttl -= 1 def _disconnect_raise_connect(self, conn, error): """ @@ -1418,6 +1422,7 @@ def parse_response(self, block=True, timeout=0): if self.is_health_check_response(response): # ignore the health check message as user might not expect it + self.health_check_response_counter -= 1 return None return response @@ -1428,9 +1433,9 @@ def is_health_check_response(self, response): bulk response, instead of a multi-bulk with "pong" and the response. """ return response in [ - self.health_check_response, # If there was a subscription - self.health_check_response_b # If there wasn't - ] + self.health_check_response, # If there was a subscription + self.health_check_response_b, # If there wasn't + ] def check_health(self): conn = self.connection @@ -1442,6 +1447,7 @@ def check_health(self): if conn.health_check_interval and time.time() > conn.next_health_check: conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False) + self.health_check_response_counter += 1 def _normalize_keys(self, data): """ diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index a20f80b393..23af46153f 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -2,17 +2,14 @@ import threading import time from unittest import mock +from unittest.mock import patch import pytest import redis from redis.exceptions import ConnectionError -from .conftest import ( - _get_client, - skip_if_redis_enterprise, - skip_if_server_version_lt -) +from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): @@ -20,7 +17,8 @@ def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): timeout = now + timeout while now < timeout: message = pubsub.get_message( - ignore_subscribe_messages=ignore_subscribe_messages) + ignore_subscribe_messages=ignore_subscribe_messages + ) if message is not None: return message time.sleep(0.01) @@ -351,15 +349,6 @@ def test_unicode_pattern_message_handler(self, r): "pmessage", channel, "test message", pattern=pattern ) - def test_get_message_without_subscribe(self, r): - p = r.pubsub() - with pytest.raises(RuntimeError) as info: - p.get_message() - expect = ( - "connection not set: " "did you forget to call subscribe() or psubscribe()?" - ) - assert expect in info.exconly() - class TestPubSubAutoDecoding: "These tests only validate that we get unicode values back" @@ -557,7 +546,7 @@ def test_get_message_not_subscribed_return_none(self, r): assert p.subscribed is False assert p.get_message() is None assert p.get_message(timeout=0.1) is None - with patch.object(threading.Event, 'wait') as mock: + with patch.object(threading.Event, "wait") as mock: mock.return_value = False assert p.get_message(timeout=0.01) is None assert mock.called @@ -570,19 +559,19 @@ def poll(ps, expected_res): message = ps.get_message(timeout=1) assert message == expected_res - subscribe_response = make_message('subscribe', 'foo', 1) + subscribe_response = make_message("subscribe", "foo", 1) poller = threading.Thread(target=poll, args=(p, subscribe_response)) poller.start() time.sleep(0.2) - p.subscribe('foo') + p.subscribe("foo") poller.join() def test_get_message_wait_for_subscription_not_being_called(self, r): p = r.pubsub() - p.subscribe('foo') - with patch.object(threading.Event, 'wait') as mock: + p.subscribe("foo") + with patch.object(threading.Event, "wait") as mock: assert p.subscribed is True - assert wait_for_message(p) == make_message('subscribe', 'foo', 1) + assert wait_for_message(p) == make_message("subscribe", "foo", 1) assert mock.called is False From c03ef638655667464c41d1451a9e544b85d6146a Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Wed, 8 Dec 2021 13:31:37 +0200 Subject: [PATCH 10/11] Clear the health check counter the first subscription --- redis/client.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/redis/client.py b/redis/client.py index a68c9d13c9..7bedcdd6fe 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1480,6 +1480,8 @@ def psubscribe(self, *args, **kwargs): if not self.subscribed: # Set the subscribed_event flag to True self.subscribed_event.set() + # Clear the health check counter + self.health_check_response_counter = 0 self.pending_unsubscribe_patterns.difference_update(new_patterns) return ret_val @@ -1517,6 +1519,8 @@ def subscribe(self, *args, **kwargs): if not self.subscribed: # Set the subscribed_event flag to True self.subscribed_event.set() + # Clear the health check counter + self.health_check_response_counter = 0 self.pending_unsubscribe_channels.difference_update(new_channels) return ret_val From 3cf820f0112147ab8e889d1a50c020947f8ab17f Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Wed, 22 Dec 2021 14:46:41 +0200 Subject: [PATCH 11/11] Changed clean_health_check_response timeout to the connection's socket_timeout --- redis/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/client.py b/redis/client.py index 7bedcdd6fe..3d6255c20a 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1370,7 +1370,7 @@ def clean_health_check_responses(self): ttl = 10 conn = self.connection while self.health_check_response_counter > 0 and ttl > 0: - if self._execute(conn, conn.can_read, timeout=1): + if self._execute(conn, conn.can_read, timeout=conn.socket_timeout): response = self._execute(conn, conn.read_response) if self.is_health_check_response(response): self.health_check_response_counter -= 1