@@ -801,6 +801,7 @@ class AbstractRedis:
801
801
"QUIT" : bool_ok ,
802
802
"STRALGO" : parse_stralgo ,
803
803
"PUBSUB NUMSUB" : parse_pubsub_numsub ,
804
+ "PUBSUB SHARDNUMSUB" : parse_pubsub_numsub ,
804
805
"RANDOMKEY" : lambda r : r and r or None ,
805
806
"RESET" : str_if_bytes ,
806
807
"SCAN" : parse_scan ,
@@ -1365,8 +1366,8 @@ class PubSub:
1365
1366
will be returned and it's safe to start listening again.
1366
1367
"""
1367
1368
1368
- PUBLISH_MESSAGE_TYPES = ("message" , "pmessage" )
1369
- UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe" , "punsubscribe" )
1369
+ PUBLISH_MESSAGE_TYPES = ("message" , "pmessage" , "smessage" )
1370
+ UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe" , "punsubscribe" , "sunsubscribe" )
1370
1371
HEALTH_CHECK_MESSAGE = "redis-py-health-check"
1371
1372
1372
1373
def __init__ (
@@ -1414,9 +1415,11 @@ def reset(self):
1414
1415
self .connection .clear_connect_callbacks ()
1415
1416
self .connection_pool .release (self .connection )
1416
1417
self .connection = None
1417
- self .channels = {}
1418
1418
self .health_check_response_counter = 0
1419
+ self .channels = {}
1419
1420
self .pending_unsubscribe_channels = set ()
1421
+ self .shard_channels = {}
1422
+ self .pending_unsubscribe_shard_channels = set ()
1420
1423
self .patterns = {}
1421
1424
self .pending_unsubscribe_patterns = set ()
1422
1425
self .subscribed_event .clear ()
@@ -1431,16 +1434,23 @@ def on_connect(self, connection):
1431
1434
# before passing them to [p]subscribe.
1432
1435
self .pending_unsubscribe_channels .clear ()
1433
1436
self .pending_unsubscribe_patterns .clear ()
1437
+ self .pending_unsubscribe_shard_channels .clear ()
1434
1438
if self .channels :
1435
- channels = {}
1436
- for k , v in self .channels .items ():
1437
- channels [ self . encoder . decode ( k , force = True )] = v
1439
+ channels = {
1440
+ self . encoder . decode ( k , force = True ): v for k , v in self .channels .items ()
1441
+ }
1438
1442
self .subscribe (** channels )
1439
1443
if self .patterns :
1440
- patterns = {}
1441
- for k , v in self .patterns .items ():
1442
- patterns [ self . encoder . decode ( k , force = True )] = v
1444
+ patterns = {
1445
+ self . encoder . decode ( k , force = True ): v for k , v in self .patterns .items ()
1446
+ }
1443
1447
self .psubscribe (** patterns )
1448
+ if self .shard_channels :
1449
+ shard_channels = {
1450
+ self .encoder .decode (k , force = True ): v
1451
+ for k , v in self .shard_channels .items ()
1452
+ }
1453
+ self .ssubscribe (** shard_channels )
1444
1454
1445
1455
@property
1446
1456
def subscribed (self ):
@@ -1647,6 +1657,45 @@ def unsubscribe(self, *args):
1647
1657
self .pending_unsubscribe_channels .update (channels )
1648
1658
return self .execute_command ("UNSUBSCRIBE" , * args )
1649
1659
1660
+ def ssubscribe (self , * args , target_node = None , ** kwargs ):
1661
+ """
1662
+ Subscribes the client to the specified shard channels.
1663
+ Channels supplied as keyword arguments expect a channel name as the key
1664
+ and a callable as the value. A channel's callable will be invoked automatically
1665
+ when a message is received on that channel rather than producing a message via
1666
+ ``listen()`` or ``get_sharded_message()``.
1667
+ """
1668
+ if args :
1669
+ args = list_or_args (args [0 ], args [1 :])
1670
+ new_s_channels = dict .fromkeys (args )
1671
+ new_s_channels .update (kwargs )
1672
+ ret_val = self .execute_command ("SSUBSCRIBE" , * new_s_channels .keys ())
1673
+ # update the s_channels dict AFTER we send the command. we don't want to
1674
+ # subscribe twice to these channels, once for the command and again
1675
+ # for the reconnection.
1676
+ new_s_channels = self ._normalize_keys (new_s_channels )
1677
+ self .shard_channels .update (new_s_channels )
1678
+ if not self .subscribed :
1679
+ # Set the subscribed_event flag to True
1680
+ self .subscribed_event .set ()
1681
+ # Clear the health check counter
1682
+ self .health_check_response_counter = 0
1683
+ self .pending_unsubscribe_shard_channels .difference_update (new_s_channels )
1684
+ return ret_val
1685
+
1686
+ def sunsubscribe (self , * args , target_node = None ):
1687
+ """
1688
+ Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
1689
+ all shard_channels
1690
+ """
1691
+ if args :
1692
+ args = list_or_args (args [0 ], args [1 :])
1693
+ s_channels = self ._normalize_keys (dict .fromkeys (args ))
1694
+ else :
1695
+ s_channels = self .shard_channels
1696
+ self .pending_unsubscribe_shard_channels .update (s_channels )
1697
+ return self .execute_command ("SUNSUBSCRIBE" , * args )
1698
+
1650
1699
def listen (self ):
1651
1700
"Listen for messages on channels this client has been subscribed to"
1652
1701
while self .subscribed :
@@ -1681,6 +1730,8 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
1681
1730
return self .handle_message (response , ignore_subscribe_messages )
1682
1731
return None
1683
1732
1733
+ get_sharded_message = get_message
1734
+
1684
1735
def ping (self , message = None ):
1685
1736
"""
1686
1737
Ping the Redis server
@@ -1726,12 +1777,17 @@ def handle_message(self, response, ignore_subscribe_messages=False):
1726
1777
if pattern in self .pending_unsubscribe_patterns :
1727
1778
self .pending_unsubscribe_patterns .remove (pattern )
1728
1779
self .patterns .pop (pattern , None )
1780
+ elif message_type == "sunsubscribe" :
1781
+ s_channel = response [1 ]
1782
+ if s_channel in self .pending_unsubscribe_shard_channels :
1783
+ self .pending_unsubscribe_shard_channels .remove (s_channel )
1784
+ self .shard_channels .pop (s_channel , None )
1729
1785
else :
1730
1786
channel = response [1 ]
1731
1787
if channel in self .pending_unsubscribe_channels :
1732
1788
self .pending_unsubscribe_channels .remove (channel )
1733
1789
self .channels .pop (channel , None )
1734
- if not self .channels and not self .patterns :
1790
+ if not self .channels and not self .patterns and not self . shard_channels :
1735
1791
# There are no subscriptions anymore, set subscribed_event flag
1736
1792
# to false
1737
1793
self .subscribed_event .clear ()
@@ -1740,6 +1796,8 @@ def handle_message(self, response, ignore_subscribe_messages=False):
1740
1796
# if there's a message handler, invoke it
1741
1797
if message_type == "pmessage" :
1742
1798
handler = self .patterns .get (message ["pattern" ], None )
1799
+ elif message_type == "smessage" :
1800
+ handler = self .shard_channels .get (message ["channel" ], None )
1743
1801
else :
1744
1802
handler = self .channels .get (message ["channel" ], None )
1745
1803
if handler :
@@ -1760,6 +1818,11 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None):
1760
1818
for pattern , handler in self .patterns .items ():
1761
1819
if handler is None :
1762
1820
raise PubSubError (f"Pattern: '{ pattern } ' has no handler registered" )
1821
+ for s_channel , handler in self .shard_channels .items ():
1822
+ if handler is None :
1823
+ raise PubSubError (
1824
+ f"Shard Channel: '{ s_channel } ' has no handler registered"
1825
+ )
1763
1826
1764
1827
thread = PubSubWorkerThread (
1765
1828
self , sleep_time , daemon = daemon , exception_handler = exception_handler
0 commit comments