From 4d1cfc54c66442f70f1ec44ac84f2b9075a410b8 Mon Sep 17 00:00:00 2001 From: Andrew-Chen-Wang Date: Mon, 24 Jan 2022 11:21:26 -0500 Subject: [PATCH 01/24] Initial aioredis import --- redis/asyncio/__init__.py | 21 + redis/asyncio/client.py | 2082 +++++++++++++++++++++++++++++++++++ redis/asyncio/connection.py | 1696 ++++++++++++++++++++++++++++ redis/asyncio/lock.py | 306 +++++ redis/asyncio/retry.py | 60 + redis/asyncio/sentinel.py | 355 ++++++ redis/compat.py | 9 + redis/typing.py | 45 + 8 files changed, 4574 insertions(+) create mode 100644 redis/asyncio/__init__.py create mode 100644 redis/asyncio/client.py create mode 100644 redis/asyncio/connection.py create mode 100644 redis/asyncio/lock.py create mode 100644 redis/asyncio/retry.py create mode 100644 redis/asyncio/sentinel.py create mode 100644 redis/compat.py create mode 100644 redis/typing.py diff --git a/redis/asyncio/__init__.py b/redis/asyncio/__init__.py new file mode 100644 index 0000000000..b762b70642 --- /dev/null +++ b/redis/asyncio/__init__.py @@ -0,0 +1,21 @@ +from redis.asyncio.client import Redis, StrictRedis +from redis.asyncio.connection import ( + BlockingConnectionPool, + Connection, + ConnectionPool, + SSLConnection, + UnixDomainSocketConnection, +) +from redis.asyncio.utils import from_url + + +__all__ = [ + "BlockingConnectionPool", + "Connection", + "ConnectionPool", + "from_url", + "Redis", + "SSLConnection", + "StrictRedis", + "UnixDomainSocketConnection", +] diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py new file mode 100644 index 0000000000..77fd3321c8 --- /dev/null +++ b/redis/asyncio/client.py @@ -0,0 +1,2082 @@ +import asyncio +import copy +import datetime +import inspect +import re +import warnings +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Dict, + Iterable, + List, + Mapping, + MutableMapping, + NoReturn, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, + TYPE_CHECKING, +) + +from redis.asyncio.connection import ( + Connection, + ConnectionPool, + SSLConnection, + UnixDomainSocketConnection, +) +from redis.commands import ( + CoreCommands, + RedisModuleCommands, + SentinelCommands, + list_or_args, +) +from redis.compat import Protocol, TypedDict +from redis.exceptions import ( + ConnectionError, + ExecAbortError, + ModuleError, + PubSubError, + RedisError, + ResponseError, + TimeoutError, + WatchError, +) +from redis.lock import Lock +from redis.retry import Retry +from redis.typing import ChannelT, EncodableT, KeyT +from redis.utils import safe_str, str_if_bytes + +PubSubHandler = Callable[[Dict[str, str]], None] +_KeyT = TypeVar("_KeyT", bound=KeyT) +_ArgT = TypeVar("_ArgT", KeyT, EncodableT) +_RedisT = TypeVar("_RedisT", bound="Redis") +_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object]) +if TYPE_CHECKING: + from redis.commands.core import Script + +SYM_EMPTY = b"" +EMPTY_RESPONSE = "EMPTY_RESPONSE" + +# some responses (ie. dump) are binary, and just meant to never be decoded +NEVER_DECODE = "NEVER_DECODE" + + +def timestamp_to_datetime(response): + """Converts a unix timestamp to a Python datetime object""" + if not response: + return None + try: + response = int(response) + except ValueError: + return None + return datetime.datetime.fromtimestamp(response) + + +def string_keys_to_dict(key_string, callback): + return dict.fromkeys(key_string.split(), callback) + + +class CaseInsensitiveDict(dict): + """Case insensitive dict implementation. 
Assumes string keys only.""" + + def __init__(self, data: Mapping[str, Any]): + for k, v in data.items(): + self[k.upper()] = v + + def __contains__(self, k): + return super().__contains__(k.upper()) + + def __delitem__(self, k): + super().__delitem__(k.upper()) + + def __getitem__(self, k): + return super().__getitem__(k.upper()) + + def get(self, k, default=None): + return super().get(k.upper(), default) + + def __setitem__(self, k, v): + super().__setitem__(k.upper(), v) + + def update(self, data): + data = CaseInsensitiveDict(data) + super().update(data) + + +def parse_debug_object(response): + """Parse the results of Redis's DEBUG OBJECT command into a Python dict""" + # The 'type' of the object is the first item in the response, but isn't + # prefixed with a name + response = str_if_bytes(response) + response = "type:" + response + response = dict(kv.split(":") for kv in response.split()) + + # parse some expected int values from the string response + # note: this cmd isn't spec'd so these may not appear in all redis versions + int_fields = ("refcount", "serializedlength", "lru", "lru_seconds_idle") + for field in int_fields: + if field in response: + response[field] = int(response[field]) + + return response + + +def parse_object(response, infotype): + """Parse the results of an OBJECT command""" + if infotype in ("idletime", "refcount"): + return int_or_none(response) + return response + + +def parse_info(response): + """Parse the result of Redis's INFO command into a Python dict""" + info: Dict[str, Any] = {} + response = str_if_bytes(response) + + def get_value(value): + if "," not in value or "=" not in value: + try: + if "." in value: + return float(value) + else: + return int(value) + except ValueError: + return value + else: + sub_dict = {} + for item in value.split(","): + k, v = item.rsplit("=", 1) + sub_dict[k] = get_value(v) + return sub_dict + + for line in response.splitlines(): + if line and not line.startswith("#"): + if line.find(":") != -1: + # Split, the info fields keys and values. + # Note that the value may contain ':'. 
but the 'host:' + # pseudo-command is the only case where the key contains ':' + key, value = line.split(":", 1) + if key == "cmdstat_host": + key, value = line.rsplit(":", 1) + + if key == "module": + # Hardcode a list for key 'modules' since there could be + # multiple lines that started with 'module' + info.setdefault("modules", []).append(get_value(value)) + else: + info[key] = get_value(value) + else: + # if the line isn't splittable, append it to the "__raw__" key + info.setdefault("__raw__", []).append(line) + + return info + + +def parse_memory_stats(response, **kwargs): + """Parse the results of MEMORY STATS""" + stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True) + for key, value in stats.items(): + if key.startswith("db."): + stats[key] = pairs_to_dict( + value, decode_keys=True, decode_string_values=True + ) + return stats + + +SENTINEL_STATE_TYPES = { + "can-failover-its-master": int, + "config-epoch": int, + "down-after-milliseconds": int, + "failover-timeout": int, + "info-refresh": int, + "last-hello-message": int, + "last-ok-ping-reply": int, + "last-ping-reply": int, + "last-ping-sent": int, + "master-link-down-time": int, + "master-port": int, + "num-other-sentinels": int, + "num-slaves": int, + "o-down-time": int, + "pending-commands": int, + "parallel-syncs": int, + "port": int, + "quorum": int, + "role-reported-time": int, + "s-down-time": int, + "slave-priority": int, + "slave-repl-offset": int, + "voted-leader-epoch": int, +} + + +def parse_sentinel_state(item): + result = pairs_to_dict_typed(item, SENTINEL_STATE_TYPES) + flags = set(result["flags"].split(",")) + for name, flag in ( + ("is_master", "master"), + ("is_slave", "slave"), + ("is_sdown", "s_down"), + ("is_odown", "o_down"), + ("is_sentinel", "sentinel"), + ("is_disconnected", "disconnected"), + ("is_master_down", "master_down"), + ): + result[name] = flag in flags + return result + + +def parse_sentinel_master(response): + return parse_sentinel_state(map(str_if_bytes, response)) + + +def parse_sentinel_masters(response): + result = {} + for item in response: + state = parse_sentinel_state(map(str_if_bytes, item)) + result[state["name"]] = state + return result + + +def parse_sentinel_slaves_and_sentinels(response): + return [parse_sentinel_state(map(str_if_bytes, item)) for item in response] + + +def parse_sentinel_get_master(response): + return response and (response[0], int(response[1])) or None + + +def pairs_to_dict(response, decode_keys=False, decode_string_values=False): + """Create a dict given a list of key/value pairs""" + if response is None: + return {} + if decode_keys or decode_string_values: + # the iter form is faster, but I don't know how to make that work + # with a str_if_bytes() map + keys = response[::2] + if decode_keys: + keys = map(str_if_bytes, keys) + values = response[1::2] + if decode_string_values: + values = map(str_if_bytes, values) + return dict(zip(keys, values)) + else: + it = iter(response) + return dict(zip(it, it)) + + +def pairs_to_dict_typed(response, type_info): + it = iter(response) + result = {} + for key, value in zip(it, it): + if key in type_info: + try: + value = type_info[key](value) + except Exception: + # if for some reason the value can't be coerced, just use + # the string value + pass + result[key] = value + return result + + +def zset_score_pairs(response, **options): + """ + If ``withscores`` is specified in the options, return the response as + a list of (value, score) pairs + """ + if not response or not 
options.get("withscores"): + return response + score_cast_func = options.get("score_cast_func", float) + it = iter(response) + return list(zip(it, map(score_cast_func, it))) + + +def sort_return_tuples(response, **options): + """ + If ``groups`` is specified, return the response as a list of + n-element tuples with n being the value found in options['groups'] + """ + if not response or not options.get("groups"): + return response + n = options["groups"] + return list(zip(*(response[i::n] for i in range(n)))) + + +def int_or_none(response): + if response is None: + return None + return int(response) + + +def parse_stream_list(response): + if response is None: + return None + data = [] + for r in response: + if r is not None: + data.append((r[0], pairs_to_dict(r[1]))) + else: + data.append((None, None)) + return data + + +def pairs_to_dict_with_str_keys(response): + return pairs_to_dict(response, decode_keys=True) + + +def parse_list_of_dicts(response): + return list(map(pairs_to_dict_with_str_keys, response)) + + +def parse_xclaim(response, **options): + if options.get("parse_justid", False): + return response + return parse_stream_list(response) + + +def parse_xautoclaim(response, **options): + if options.get("parse_justid", False): + return response[1] + return parse_stream_list(response[1]) + + +def parse_xinfo_stream(response, **options): + data = pairs_to_dict(response, decode_keys=True) + if not options.get("full", False): + first = data["first-entry"] + if first is not None: + data["first-entry"] = (first[0], pairs_to_dict(first[1])) + last = data["last-entry"] + if last is not None: + data["last-entry"] = (last[0], pairs_to_dict(last[1])) + else: + data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]} + data["groups"] = [ + pairs_to_dict(group, decode_keys=True) for group in data["groups"] + ] + return data + + +def parse_xread(response): + if response is None: + return [] + return [[r[0], parse_stream_list(r[1])] for r in response] + + +def parse_xpending(response, **options): + if options.get("parse_detail", False): + return parse_xpending_range(response) + consumers = [{"name": n, "pending": int(p)} for n, p in response[3] or []] + return { + "pending": response[0], + "min": response[1], + "max": response[2], + "consumers": consumers, + } + + +def parse_xpending_range(response): + k = ("message_id", "consumer", "time_since_delivered", "times_delivered") + return [dict(zip(k, r)) for r in response] + + +def float_or_none(response): + if response is None: + return None + return float(response) + + +def bool_ok(response): + return str_if_bytes(response) == "OK" + + +def parse_zadd(response, **options): + if response is None: + return None + if options.get("as_score"): + return float(response) + return int(response) + + +def parse_client_list(response, **options): + clients = [] + for c in str_if_bytes(response).splitlines(): + # Values might contain '=' + clients.append(dict(pair.split("=", 1) for pair in c.split(" "))) + return clients + + +def parse_config_get(response, **options): + response = [str_if_bytes(i) if i is not None else None for i in response] + return response and pairs_to_dict(response) or {} + + +def parse_scan(response, **options): + cursor, r = response + return int(cursor), r + + +def parse_hscan(response, **options): + cursor, r = response + return int(cursor), r and pairs_to_dict(r) or {} + + +def parse_zscan(response, **options): + score_cast_func = options.get("score_cast_func", float) + cursor, r = response + it = iter(r) + return 
int(cursor), list(zip(it, map(score_cast_func, it))) + + +def parse_zmscore(response, **options): + # zmscore: list of scores (double precision floating point number) or nil + return [float(score) if score is not None else None for score in response] + + +def parse_slowlog_get(response, **options): + space: Union[str, bytes] = " " if options.get("decode_responses", False) else b" " + + def parse_item(item): + result = { + "id": item[0], + "start_time": int(item[1]), + "duration": int(item[2]), + } + # Redis Enterprise injects another entry at index [3], which has + # the complexity info (i.e. the value N in case the command has + # an O(N) complexity) instead of the command. + if isinstance(item[3], list): + result["command"] = space.join(item[3]) + else: + result["complexity"] = item[3] + result["command"] = space.join(item[4]) + return result + + return [parse_item(item) for item in response] + + +def parse_stralgo(response, **options): + """ + Parse the response from `STRALGO` command. + Without modifiers the returned value is string. + When LEN is given the command returns the length of the result + (i.e integer). + When IDX is given the command returns a dictionary with the LCS + length and all the ranges in both the strings, start and end + offset for each string, where there are matches. + When WITHMATCHLEN is given, each array representing a match will + also have the length of the match at the beginning of the array. + """ + if options.get("len", False): + return int(response) + if options.get("idx", False): + if options.get("withmatchlen", False): + matches = [ + [(int(match[-1]))] + list(map(tuple, match[:-1])) + for match in response[1] + ] + else: + matches = [list(map(tuple, match)) for match in response[1]] + return { + str_if_bytes(response[0]): matches, + str_if_bytes(response[2]): int(response[3]), + } + return str_if_bytes(response) + + +def parse_cluster_info(response, **options): + response = str_if_bytes(response) + return dict(line.split(":") for line in response.splitlines() if line) + + +def _parse_node_line(line): + line_items = line.split(" ") + node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8] + addr = addr.split("@")[0] + slots = [sl.split("-") for sl in line_items[8:]] + node_dict = { + "node_id": node_id, + "flags": flags, + "master_id": master_id, + "last_ping_sent": ping, + "last_pong_rcvd": pong, + "epoch": epoch, + "slots": slots, + "connected": True if connected == "connected" else False, + } + return addr, node_dict + + +def parse_cluster_nodes(response, **options): + """ + @see: https://redis.io/commands/cluster-nodes # string + @see: https://redis.io/commands/cluster-replicas # list of string + """ + if isinstance(response, str): + response = response.splitlines() + return dict(_parse_node_line(str_if_bytes(node)) for node in response) + + +def parse_geosearch_generic(response, **options): + """ + Parse the response of 'GEOSEARCH', GEORADIUS' and 'GEORADIUSBYMEMBER' + commands according to 'withdist', 'withhash' and 'withcoord' labels. + """ + if options["store"] or options["store_dist"]: + # `store` and `store_dist` cant be combined + # with other command arguments. 
+ # relevant to 'GEORADIUS' and 'GEORADIUSBYMEMBER' + return response + + if type(response) != list: + response_list = [response] + else: + response_list = response + + if not options["withdist"] and not options["withcoord"] and not options["withhash"]: + # just a bunch of places + return response_list + + cast: Dict[str, Callable] = { + "withdist": float, + "withcoord": lambda ll: (float(ll[0]), float(ll[1])), + "withhash": int, + } + + # zip all output results with each casting function to get + # the properly native Python value. + f = [lambda x: x] + f += [cast[o] for o in ["withdist", "withhash", "withcoord"] if options[o]] + return [list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list] + + +def parse_command(response, **options): + commands = {} + for command in response: + cmd_dict = {} + cmd_name = str_if_bytes(command[0]) + cmd_dict["name"] = cmd_name + cmd_dict["arity"] = int(command[1]) + cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]] + cmd_dict["first_key_pos"] = command[3] + cmd_dict["last_key_pos"] = command[4] + cmd_dict["step_count"] = command[5] + commands[cmd_name] = cmd_dict + return commands + + +def parse_pubsub_numsub(response, **options): + return list(zip(response[0::2], response[1::2])) + + +def parse_client_kill(response, **options): + if isinstance(response, int): + return response + return str_if_bytes(response) == "OK" + + +def parse_acl_getuser(response, **options): + if response is None: + return None + data = pairs_to_dict(response, decode_keys=True) + + # convert everything but user-defined data in 'keys' to native strings + data["flags"] = list(map(str_if_bytes, data["flags"])) + data["passwords"] = list(map(str_if_bytes, data["passwords"])) + data["commands"] = str_if_bytes(data["commands"]) + + # split 'commands' into separate 'categories' and 'commands' lists + commands, categories = [], [] + for command in data["commands"].split(" "): + if "@" in command: + categories.append(command) + else: + commands.append(command) + + data["commands"] = commands + data["categories"] = categories + data["enabled"] = "on" in data["flags"] + return data + + +def parse_acl_log(response, **options): + if response is None: + return None + if isinstance(response, list): + data = [] + for log in response: + log_data = pairs_to_dict(log, True, True) + client_info = log_data.get("client-info", "") + log_data["client-info"] = parse_client_info(client_info) + + # float() is lossy comparing to the "double" in C + log_data["age-seconds"] = float(log_data["age-seconds"]) + data.append(log_data) + else: + data = bool_ok(response) + return data + + +def parse_client_info(value): + """ + Parsing client-info in ACL Log in following format. + "key1=value1 key2=value2 key3=value3" + """ + client_info = {} + infos = str_if_bytes(value).split(" ") + for info in infos: + key, value = info.split("=") + client_info[key] = value + + # Those fields are defined as int in networking.c + for int_key in { + "id", + "age", + "idle", + "db", + "sub", + "psub", + "multi", + "qbuf", + "qbuf-free", + "obl", + "argv-mem", + "oll", + "omem", + "tot-mem", + }: + client_info[int_key] = int(client_info[int_key]) + return client_info + + +def parse_module_result(response): + if isinstance(response, ModuleError): + raise response + return True + + +def parse_set_result(response, **options): + """ + Handle SET result since GET argument is available since Redis 6.2. 
+ Parsing SET result into: + - BOOL + - String when GET argument is used + """ + if options.get("get"): + # Redis will return a getCommand result. + # See `setGenericCommand` in t_string.c + return response + return response and str_if_bytes(response) == "OK" + + +class ResponseCallbackProtocol(Protocol): + def __call__(self, response: Any, **kwargs): + ... + + +class AsyncResponseCallbackProtocol(Protocol): + async def __call__(self, response: Any, **kwargs): + ... + + +ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol] + + +_R = TypeVar("_R") + + +class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): + """ + Implementation of the Redis protocol. + + This abstract class provides a Python interface to all Redis commands + and an implementation of the Redis protocol. + + Pipelines derive from this, implementing how + the commands are sent and received to the Redis server. Based on + configuration, an instance will either use a ConnectionPool, or + Connection object to talk to redis. + """ + + RESPONSE_CALLBACKS = { + **string_keys_to_dict( + "AUTH COPY EXPIRE EXPIREAT PEXPIRE PEXPIREAT HEXISTS HMSET LMOVE BLMOVE MOVE " + "MSETNX PERSIST PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX", + bool, + ), + **string_keys_to_dict( + "BITCOUNT BITPOS DECRBY DEL EXISTS GEOADD GETBIT HDEL HLEN " + "HSTRLEN INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD " + "SCARD SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN " + "SUNIONSTORE UNLINK XACK XDEL XLEN XTRIM ZCARD ZLEXCOUNT ZREM " + "ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE", + int, + ), + **string_keys_to_dict("INCRBYFLOAT HINCRBYFLOAT", float), + **string_keys_to_dict( + # these return OK, or int if redis-server is >=1.3.4 + "LPUSH RPUSH", + lambda r: isinstance(r, int) and r or str_if_bytes(r) == "OK", + ), + **string_keys_to_dict("SORT", sort_return_tuples), + **string_keys_to_dict("ZSCORE ZINCRBY GEODIST", float_or_none), + **string_keys_to_dict( + "FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE READONLY READWRITE " + "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH ", + bool_ok, + ), + **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None), + **string_keys_to_dict( + "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() + ), + **string_keys_to_dict( + "ZPOPMAX ZPOPMIN ZINTER ZDIFF ZUNION ZRANGE ZRANGEBYSCORE ZREVRANGE " + "ZREVRANGEBYSCORE", + zset_score_pairs, + ), + **string_keys_to_dict( + "BZPOPMIN BZPOPMAX", lambda r: r and (r[0], r[1], float(r[2])) or None + ), + **string_keys_to_dict("ZRANK ZREVRANK", int_or_none), + **string_keys_to_dict("XREVRANGE XRANGE", parse_stream_list), + **string_keys_to_dict("XREAD XREADGROUP", parse_xread), + **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True), + "ACL CAT": lambda r: list(map(str_if_bytes, r)), + "ACL DELUSER": int, + "ACL GENPASS": str_if_bytes, + "ACL GETUSER": parse_acl_getuser, + "ACL HELP": lambda r: list(map(str_if_bytes, r)), + "ACL LIST": lambda r: list(map(str_if_bytes, r)), + "ACL LOAD": bool_ok, + "ACL LOG": parse_acl_log, + "ACL SAVE": bool_ok, + "ACL SETUSER": bool_ok, + "ACL USERS": lambda r: list(map(str_if_bytes, r)), + "ACL WHOAMI": str_if_bytes, + "CLIENT GETNAME": str_if_bytes, + "CLIENT ID": int, + "CLIENT KILL": parse_client_kill, + "CLIENT LIST": parse_client_list, + "CLIENT INFO": parse_client_info, + "CLIENT SETNAME": bool_ok, + "CLIENT UNBLOCK": lambda r: r and int(r) == 1 or False, + "CLIENT PAUSE": bool_ok, + "CLIENT GETREDIR": int, + "CLIENT TRACKINGINFO": lambda r: 
list(map(str_if_bytes, r)), + "CLUSTER ADDSLOTS": bool_ok, + "CLUSTER COUNT-FAILURE-REPORTS": lambda x: int(x), + "CLUSTER COUNTKEYSINSLOT": lambda x: int(x), + "CLUSTER DELSLOTS": bool_ok, + "CLUSTER FAILOVER": bool_ok, + "CLUSTER FORGET": bool_ok, + "CLUSTER INFO": parse_cluster_info, + "CLUSTER KEYSLOT": lambda x: int(x), + "CLUSTER MEET": bool_ok, + "CLUSTER NODES": parse_cluster_nodes, + "CLUSTER REPLICATE": bool_ok, + "CLUSTER RESET": bool_ok, + "CLUSTER SAVECONFIG": bool_ok, + "CLUSTER SET-CONFIG-EPOCH": bool_ok, + "CLUSTER SETSLOT": bool_ok, + "CLUSTER SLAVES": parse_cluster_nodes, + "CLUSTER REPLICAS": parse_cluster_nodes, + "COMMAND": parse_command, + "COMMAND COUNT": int, + "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), + "CONFIG GET": parse_config_get, + "CONFIG RESETSTAT": bool_ok, + "CONFIG SET": bool_ok, + "DEBUG OBJECT": parse_debug_object, + "GEOHASH": lambda r: list(map(str_if_bytes, r)), + "GEOPOS": lambda r: list( + map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r) + ), + "GEOSEARCH": parse_geosearch_generic, + "GEORADIUS": parse_geosearch_generic, + "GEORADIUSBYMEMBER": parse_geosearch_generic, + "HGETALL": lambda r: r and pairs_to_dict(r) or {}, + "HSCAN": parse_hscan, + "INFO": parse_info, + "LASTSAVE": timestamp_to_datetime, + "MEMORY PURGE": bool_ok, + "MEMORY STATS": parse_memory_stats, + "MEMORY USAGE": int_or_none, + "MODULE LOAD": parse_module_result, + "MODULE UNLOAD": parse_module_result, + "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r], + "OBJECT": parse_object, + "PING": lambda r: str_if_bytes(r) == "PONG", + "QUIT": bool_ok, + "STRALGO": parse_stralgo, + "PUBSUB NUMSUB": parse_pubsub_numsub, + "RANDOMKEY": lambda r: r and r or None, + "SCAN": parse_scan, + "SCRIPT EXISTS": lambda r: list(map(bool, r)), + "SCRIPT FLUSH": bool_ok, + "SCRIPT KILL": bool_ok, + "SCRIPT LOAD": str_if_bytes, + "SENTINEL CKQUORUM": bool_ok, + "SENTINEL FAILOVER": bool_ok, + "SENTINEL FLUSHCONFIG": bool_ok, + "SENTINEL GET-MASTER-ADDR-BY-NAME": parse_sentinel_get_master, + "SENTINEL MASTER": parse_sentinel_master, + "SENTINEL MASTERS": parse_sentinel_masters, + "SENTINEL MONITOR": bool_ok, + "SENTINEL RESET": bool_ok, + "SENTINEL REMOVE": bool_ok, + "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, + "SENTINEL SET": bool_ok, + "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, + "SET": parse_set_result, + "SLOWLOG GET": parse_slowlog_get, + "SLOWLOG LEN": int, + "SLOWLOG RESET": bool_ok, + "SSCAN": parse_scan, + "TIME": lambda x: (int(x[0]), int(x[1])), + "XCLAIM": parse_xclaim, + "XAUTOCLAIM": parse_xautoclaim, + "XGROUP CREATE": bool_ok, + "XGROUP DELCONSUMER": int, + "XGROUP DESTROY": bool, + "XGROUP SETID": bool_ok, + "XINFO CONSUMERS": parse_list_of_dicts, + "XINFO GROUPS": parse_list_of_dicts, + "XINFO STREAM": parse_xinfo_stream, + "XPENDING": parse_xpending, + "ZADD": parse_zadd, + "ZSCAN": parse_zscan, + "ZMSCORE": parse_zmscore, + } + + response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT] + + @classmethod + def from_url(cls, url: str, **kwargs): + """ + Return a Redis client object configured from the given URL + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[[username]:[password]]@/path/to/socket.sock?db=0 + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + + - `rediss://` creates a SSL wrapped TCP socket connection. 
See more at: + + - ``unix://``: creates a Unix Domain Socket connection. + + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + + """ + connection_pool = ConnectionPool.from_url(url, **kwargs) + return cls(connection_pool=connection_pool) + + def __init__( + self, + *, + host: str = "localhost", + port: int = 6379, + db: Union[str, int] = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: Optional[bool] = None, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + connection_pool: Optional[ConnectionPool] = None, + unix_socket_path: Optional[str] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + retry_on_timeout: bool = False, + ssl: bool = False, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_ca_certs: Optional[str] = None, + ssl_check_hostname: bool = False, + max_connections: Optional[int] = None, + single_connection_client: bool = False, + health_check_interval: int = 0, + client_name: Optional[str] = None, + username: Optional[str] = None, + retry: Optional[Retry] = None, + ): + """ + Initialize a new Redis client. 
+ To specify a retry policy, first set `retry_on_timeout` to `True` + then set `retry` to a valid `Retry` object + """ + kwargs: Dict[str, Any] + if not connection_pool: + kwargs = { + "db": db, + "username": username, + "password": password, + "socket_timeout": socket_timeout, + "encoding": encoding, + "encoding_errors": encoding_errors, + "decode_responses": decode_responses, + "retry_on_timeout": retry_on_timeout, + "retry": copy.deepcopy(retry), + "max_connections": max_connections, + "health_check_interval": health_check_interval, + "client_name": client_name, + } + # based on input, setup appropriate connection args + if unix_socket_path is not None: + kwargs.update( + { + "path": unix_socket_path, + "connection_class": UnixDomainSocketConnection, + } + ) + else: + # TCP specific options + kwargs.update( + { + "host": host, + "port": port, + "socket_connect_timeout": socket_connect_timeout, + "socket_keepalive": socket_keepalive, + "socket_keepalive_options": socket_keepalive_options, + } + ) + + if ssl: + kwargs.update( + { + "connection_class": SSLConnection, + "ssl_keyfile": ssl_keyfile, + "ssl_certfile": ssl_certfile, + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_ca_certs": ssl_ca_certs, + "ssl_check_hostname": ssl_check_hostname, + } + ) + connection_pool = ConnectionPool(**kwargs) + self.connection_pool = connection_pool + self.single_connection_client = single_connection_client + self.connection: Optional[Connection] = None + + self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + + def __repr__(self): + return f"{self.__class__.__name__}<{self.connection_pool!r}>" + + def __await__(self): + return self.initialize().__await__() + + async def initialize(self: _RedisT) -> _RedisT: + if self.single_connection_client and self.connection is None: + self.connection = await self.connection_pool.get_connection("_") + return self + + def set_response_callback(self, command: str, callback: ResponseCallbackT): + """Set a custom Response Callback""" + self.response_callbacks[command] = callback + + def load_external_module( + self, + funcname, + func, + ): + """ + This function can be used to add externally defined redis modules, + and their namespaces to the redis client. + + funcname - A string containing the name of the function to create + func - The function, being added to this class. + + ex: Assume that one has a custom redis module named foomod that + creates command named 'foo.dothing' and 'foo.anotherthing' in redis. + To load function functions into this namespace: + + from redis import Redis + from foomodule import F + r = Redis() + r.load_external_module("foo", F) + r.foo().dothing('your', 'arguments') + + For a concrete example see the reimport of the redisjson module in + tests/test_connection.py::test_loading_external_modules + """ + setattr(self, funcname, func) + + def pipeline( + self, transaction: bool = True, shard_hint: Optional[str] = None + ) -> "Pipeline": + """ + Return a new pipeline object that can queue multiple commands for + later execution. ``transaction`` indicates whether all commands + should be executed atomically. Apart from making a group of operations + atomic, pipelines are useful for reducing the back-and-forth overhead + between the client and server. 
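+
+        A minimal usage sketch, assuming ``client`` is an initialized
+        ``Redis`` instance (``execute()`` returns the replies in command
+        order)::
+
+            pipe = client.pipeline(transaction=True)
+            results = await pipe.set("foo", "bar").incr("baz").execute()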
+ """ + return Pipeline( + self.connection_pool, self.response_callbacks, transaction, shard_hint + ) + + async def transaction( + self, + func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], + *watches: KeyT, + shard_hint: Optional[str] = None, + value_from_callable: bool = False, + watch_delay: Optional[float] = None, + ): + """ + Convenience method for executing the callable `func` as a transaction + while watching all keys specified in `watches`. The 'func' callable + should expect a single argument which is a Pipeline object. + """ + pipe: Pipeline + async with self.pipeline(True, shard_hint) as pipe: + while True: + try: + if watches: + await pipe.watch(*watches) + func_value = func(pipe) + if inspect.isawaitable(func_value): + func_value = await func_value + exec_value = await pipe.execute() + return func_value if value_from_callable else exec_value + except WatchError: + if watch_delay is not None and watch_delay > 0: + await asyncio.sleep(watch_delay) + continue + + def lock( + self, + name: KeyT, + timeout: Optional[float] = None, + sleep: float = 0.1, + blocking_timeout: Optional[float] = None, + lock_class: Optional[Type[Lock]] = None, + thread_local: bool = True, + ) -> Lock: + """ + Return a new Lock object using key ``name`` that mimics + the behavior of threading.Lock. + + If specified, ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``lock_class`` forces the specified lock implementation. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release the thread-2's lock. + + In some use cases it's necessary to disable thread local storage. For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. Our assumption + is that these cases aren't common and as such default to using + thread local storage.""" + if lock_class is None: + lock_class = Lock + return lock_class( + self, + name, + timeout=timeout, + sleep=sleep, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) + + def pubsub(self, **kwargs) -> "PubSub": + """ + Return a Publish/Subscribe object. 
With this object, you can + subscribe to channels and listen for messages that get published to + them. + """ + return PubSub(self.connection_pool, **kwargs) + + def monitor(self) -> "Monitor": + return Monitor(self.connection_pool) + + def client(self) -> "Redis": + return self.__class__( + connection_pool=self.connection_pool, single_connection_client=True + ) + + async def __aenter__(self: _RedisT) -> _RedisT: + return await self.initialize() + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + _DEL_MESSAGE = "Unclosed Redis client" + + def __del__(self, _warnings: Any = warnings) -> None: + if self.connection is not None: + _warnings.warn( + f"Unclosed client session {self!r}", + ResourceWarning, + source=self, + ) + context = {"client": self, "message": self._DEL_MESSAGE} + asyncio.get_event_loop().call_exception_handler(context) + + async def close(self): + conn = self.connection + if conn: + self.connection = None + await self.connection_pool.release(conn) + + async def _send_command_parse_response(self, conn, command_name, *args, **options): + """ + Send a command and parse the response + """ + await conn.send_command(*args) + return await self.parse_response(conn, command_name, **options) + + async def _disconnect_raise(self, conn: Connection, error: Exception): + """ + Close the connection and raise an exception + if retry_on_timeout is not set or the error + is not a TimeoutError + """ + await conn.disconnect() + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + raise error + + # COMMAND EXECUTION AND PROTOCOL PARSING + async def execute_command(self, *args, **options): + """Execute a command and return a parsed response""" + await self.initialize() + pool = self.connection_pool + command_name = args[0] + conn = self.connection or await pool.get_connection(command_name, **options) + + try: + return await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) + finally: + if not self.connection: + await pool.release(conn) + + async def parse_response( + self, connection: Connection, command_name: Union[str, bytes], **options + ): + """Parses a response from the Redis server""" + try: + if NEVER_DECODE in options: + response = await connection.read_response(disable_encoding=True) + else: + response = await connection.read_response() + except ResponseError: + if EMPTY_RESPONSE in options: + return options[EMPTY_RESPONSE] + raise + if command_name in self.response_callbacks: + # Mypy bug: https://github.com/python/mypy/issues/10977 + command_name = cast(str, command_name) + retval = self.response_callbacks[command_name](response, **options) + return await retval if inspect.isawaitable(retval) else retval + return response + + +StrictRedis = Redis + + +class MonitorCommandInfo(TypedDict): + time: float + db: int + client_address: str + client_port: str + client_type: str + command: str + + +class Monitor: + """ + Monitor is useful for handling the MONITOR command to the redis server. + next_command() method returns one command from monitor + listen() method yields commands from monitor. + """ + + monitor_re = re.compile(r"\[(\d+) (.*)\] (.*)") + command_re = re.compile(r'"(.*?)(? 
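<!\\)"')
+
+    def __init__(self, connection_pool: ConnectionPool):
+        self.connection_pool = connection_pool
+        self.connection: Optional[Connection] = None
+
+    async def connect(self):
+        if self.connection is None:
+            self.connection = await self.connection_pool.get_connection("MONITOR")
+
+    async def __aenter__(self):
+        await self.connect()
+        await self.connection.send_command("MONITOR")
+        # check that monitor returns 'OK', but don't return it to user
+        response = await self.connection.read_response()
+        if not bool_ok(response):
+            raise RedisError(f"MONITOR failed: {response}")
+        return self
+
+    async def __aexit__(self, *args):
+        await self.connection.disconnect()
+        await self.connection_pool.release(self.connection)
+
+    async def next_command(self) ->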
MonitorCommandInfo: + """Parse the response from a monitor command""" + await self.connect() + response = await self.connection.read_response() + if isinstance(response, bytes): + response = self.connection.encoder.decode(response, force=True) + command_time, command_data = response.split(" ", 1) + m = self.monitor_re.match(command_data) + db_id, client_info, command = m.groups() + command = " ".join(self.command_re.findall(command)) + # Redis escapes double quotes because each piece of the command + # string is surrounded by double quotes. We don't have that + # requirement so remove the escaping and leave the quote. + command = command.replace('\\"', '"') + + if client_info == "lua": + client_address = "lua" + client_port = "" + client_type = "lua" + elif client_info.startswith("unix"): + client_address = "unix" + client_port = client_info[5:] + client_type = "unix" + else: + # use rsplit as ipv6 addresses contain colons + client_address, client_port = client_info.rsplit(":", 1) + client_type = "tcp" + return { + "time": float(command_time), + "db": int(db_id), + "client_address": client_address, + "client_port": client_port, + "client_type": client_type, + "command": command, + } + + async def listen(self) -> AsyncIterator[MonitorCommandInfo]: + """Listen for commands coming to the server.""" + while True: + yield await self.next_command() + + +class PubSub: + """ + PubSub provides publish, subscribe and listen support to Redis channels. + + After subscribing to one or more channels, the listen() method will block + until a message arrives on one of the subscribed channels. That message + will be returned and it's safe to start listening again. + """ + + PUBLISH_MESSAGE_TYPES = ("message", "pmessage") + UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe") + HEALTH_CHECK_MESSAGE = "redis-py-health-check" + + def __init__( + self, + connection_pool: ConnectionPool, + shard_hint: Optional[str] = None, + ignore_subscribe_messages: bool = False, + encoder=None, + ): + self.connection_pool = connection_pool + self.shard_hint = shard_hint + self.ignore_subscribe_messages = ignore_subscribe_messages + self.connection = None + # we need to know the encoding options for this connection in order + # to lookup channel and pattern names for callback handlers. 
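+        # fall back to the pool's encoder when no encoder is supplied, so
+        # channel/pattern keys get normalized the same way as responses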
+ self.encoder = encoder + if self.encoder is None: + self.encoder = self.connection_pool.get_encoder() + if self.encoder.decode_responses: + self.health_check_response: Iterable[Union[str, bytes]] = [ + "pong", + self.HEALTH_CHECK_MESSAGE, + ] + else: + self.health_check_response = [ + b"pong", + self.encoder.encode(self.HEALTH_CHECK_MESSAGE), + ] + self.channels = {} + self.pending_unsubscribe_channels = set() + self.patterns = {} + self.pending_unsubscribe_patterns = set() + self._lock = asyncio.Lock() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.reset() + + def __del__(self): + if self.connection: + self.connection.clear_connect_callbacks() + + async def reset(self): + async with self._lock: + if self.connection: + await self.connection.disconnect() + self.connection.clear_connect_callbacks() + await self.connection_pool.release(self.connection) + self.connection = None + self.channels = {} + self.pending_unsubscribe_channels = set() + self.patterns = {} + self.pending_unsubscribe_patterns = set() + + def close(self) -> Awaitable[NoReturn]: + return self.reset() + + async def on_connect(self, connection: Connection): + """Re-subscribe to any channels and patterns previously subscribed to""" + # NOTE: for python3, we can't pass bytestrings as keyword arguments + # so we need to decode channel/pattern names back to unicode strings + # before passing them to [p]subscribe. + self.pending_unsubscribe_channels.clear() + self.pending_unsubscribe_patterns.clear() + if self.channels: + channels = {} + for k, v in self.channels.items(): + channels[self.encoder.decode(k, force=True)] = v + await self.subscribe(**channels) + if self.patterns: + patterns = {} + for k, v in self.patterns.items(): + patterns[self.encoder.decode(k, force=True)] = v + await self.psubscribe(**patterns) + + @property + def subscribed(self): + """Indicates if there are subscriptions to any channels or patterns""" + return bool(self.channels or self.patterns) + + async def execute_command(self, *args: EncodableT): + """Execute a publish/subscribe command""" + + # NOTE: don't parse the response in this function -- it could pull a + # legitimate message off the stack if the connection is already + # subscribed to one or more channels + + if self.connection is None: + self.connection = await self.connection_pool.get_connection( + "pubsub", self.shard_hint + ) + # register a callback that re-subscribes to any channels we + # were listening to when we were disconnected + self.connection.register_connect_callback(self.on_connect) + connection = self.connection + kwargs = {"check_health": not self.subscribed} + await self._execute(connection, connection.send_command, *args, **kwargs) + + async def _disconnect_raise_connect(self, conn, error): + """ + Close the connection and raise an exception + if retry_on_timeout is not set or the error + is not a TimeoutError. Otherwise, try to reconnect + """ + await conn.disconnect() + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + raise error + await conn.connect() + + async def _execute(self, conn, command, *args, **kwargs): + """ + Connect manually upon disconnection. If the Redis server is down, + this will fail and raise a ConnectionError as desired. 
+ After reconnection, the ``on_connect`` callback should have been + called by the # connection to resubscribe us to any channels and + patterns we were previously listening to + """ + return await conn.retry.call_with_retry( + lambda: command(*args, **kwargs), + lambda error: self._disconnect_raise_connect(conn, error), + ) + + async def parse_response(self, block: bool = True, timeout: float = 0): + """Parse the response from a publish/subscribe command""" + conn = self.connection + if conn is None: + raise RuntimeError( + "pubsub connection not set: " + "did you forget to call subscribe() or psubscribe()?" + ) + + await self.check_health() + + if not block and not await self._execute(conn, conn.can_read, timeout=timeout): + return None + response = await self._execute(conn, conn.read_response) + + if conn.health_check_interval and response == self.health_check_response: + # ignore the health check message as user might not expect it + return None + return response + + async def check_health(self): + conn = self.connection + if conn is None: + raise RuntimeError( + "pubsub connection not set: " + "did you forget to call subscribe() or psubscribe()?" + ) + + if ( + conn.health_check_interval + and asyncio.get_event_loop().time() > conn.next_health_check + ): + await conn.send_command( + "PING", self.HEALTH_CHECK_MESSAGE, check_health=False + ) + + def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT: + """ + normalize channel/pattern names to be either bytes or strings + based on whether responses are automatically decoded. this saves us + from coercing the value for each message coming in. + """ + encode = self.encoder.encode + decode = self.encoder.decode + return {decode(encode(k)): v for k, v in data.items()} # type: ignore[return-value] + + async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): + """ + Subscribe to channel patterns. Patterns supplied as keyword arguments + expect a pattern name as the key and a callable as the value. A + pattern's callable will be invoked automatically when a message is + received on that pattern rather than producing a message via + ``listen()``. + """ + parsed_args = list_or_args((args[0],), args[1:]) if args else args + new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args) + # Mypy bug: https://github.com/python/mypy/issues/10970 + new_patterns.update(kwargs) # type: ignore[arg-type] + ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys()) + # update the patterns dict AFTER we send the command. we don't want to + # subscribe twice to these patterns, once for the command and again + # for the reconnection. + new_patterns = self._normalize_keys(new_patterns) + self.patterns.update(new_patterns) + self.pending_unsubscribe_patterns.difference_update(new_patterns) + return ret_val + + def punsubscribe(self, *args: ChannelT) -> Awaitable: + """ + Unsubscribe from the supplied patterns. If empty, unsubscribe from + all patterns. + """ + patterns: Iterable[ChannelT] + if args: + parsed_args = list_or_args((args[0],), args[1:]) + patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys() + else: + parsed_args = [] + patterns = self.patterns + self.pending_unsubscribe_patterns.update(patterns) + return self.execute_command("PUNSUBSCRIBE", *parsed_args) + + async def subscribe(self, *args: ChannelT, **kwargs: Callable): + """ + Subscribe to channels. Channels supplied as keyword arguments expect + a channel name as the key and a callable as the value. 
A channel's + callable will be invoked automatically when a message is received on + that channel rather than producing a message via ``listen()`` or + ``get_message()``. + """ + parsed_args = list_or_args((args[0],), args[1:]) if args else () + new_channels = dict.fromkeys(parsed_args) + # Mypy bug: https://github.com/python/mypy/issues/10970 + new_channels.update(kwargs) # type: ignore[arg-type] + ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys()) + # update the channels dict AFTER we send the command. we don't want to + # subscribe twice to these channels, once for the command and again + # for the reconnection. + new_channels = self._normalize_keys(new_channels) + self.channels.update(new_channels) + self.pending_unsubscribe_channels.difference_update(new_channels) + return ret_val + + def unsubscribe(self, *args) -> Awaitable: + """ + Unsubscribe from the supplied channels. If empty, unsubscribe from + all channels + """ + if args: + parsed_args = list_or_args(args[0], args[1:]) + channels = self._normalize_keys(dict.fromkeys(parsed_args)) + else: + parsed_args = [] + channels = self.channels + self.pending_unsubscribe_channels.update(channels) + return self.execute_command("UNSUBSCRIBE", *parsed_args) + + async def listen(self) -> AsyncIterator: + """Listen for messages on channels this client has been subscribed to""" + while self.subscribed: + response = self.handle_message(await self.parse_response(block=True)) + if response is not None: + yield response + + async def get_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): + """ + Get the next message if one is available, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number. + """ + response = await self.parse_response(block=False, timeout=timeout) + if response: + return self.handle_message(response, ignore_subscribe_messages) + return None + + def ping(self, message=None) -> Awaitable: + """ + Ping the Redis server + """ + message = "" if message is None else message + return self.execute_command("PING", message) + + def handle_message(self, response, ignore_subscribe_messages=False): + """ + Parses a pub/sub message. If the channel or pattern was subscribed to + with a message handler, the handler is invoked instead of a parsed + message being returned. 
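+
+        Parsed messages are dicts with ``type``, ``pattern``, ``channel``
+        and ``data`` keys; ``None`` is returned when a registered handler
+        consumed the message or a subscribe/unsubscribe notice was ignored.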
+ """ + message_type = str_if_bytes(response[0]) + if message_type == "pmessage": + message = { + "type": message_type, + "pattern": response[1], + "channel": response[2], + "data": response[3], + } + elif message_type == "pong": + message = { + "type": message_type, + "pattern": None, + "channel": None, + "data": response[1], + } + else: + message = { + "type": message_type, + "pattern": None, + "channel": response[1], + "data": response[2], + } + + # if this is an unsubscribe message, remove it from memory + if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES: + if message_type == "punsubscribe": + pattern = response[1] + if pattern in self.pending_unsubscribe_patterns: + self.pending_unsubscribe_patterns.remove(pattern) + self.patterns.pop(pattern, None) + else: + channel = response[1] + if channel in self.pending_unsubscribe_channels: + self.pending_unsubscribe_channels.remove(channel) + self.channels.pop(channel, None) + + if message_type in self.PUBLISH_MESSAGE_TYPES: + # if there's a message handler, invoke it + if message_type == "pmessage": + handler = self.patterns.get(message["pattern"], None) + else: + handler = self.channels.get(message["channel"], None) + if handler: + handler(message) + return None + elif message_type != "pong": + # this is a subscribe/unsubscribe message. ignore if we don't + # want them + if ignore_subscribe_messages or self.ignore_subscribe_messages: + return None + + return message + + async def run( + self, + *, + exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None, + poll_timeout: float = 1.0, + ) -> None: + """Process pub/sub messages using registered callbacks. + + This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in + redis-py, but it is a coroutine. To launch it as a separate task, use + ``asyncio.create_task``: + + >>> task = asyncio.create_task(pubsub.run()) + + To shut it down, use asyncio cancellation: + + >>> task.cancel() + >>> await task + """ + for channel, handler in self.channels.items(): + if handler is None: + raise PubSubError(f"Channel: '{channel}' has no handler registered") + for pattern, handler in self.patterns.items(): + if handler is None: + raise PubSubError(f"Pattern: '{pattern}' has no handler registered") + + while True: + try: + await self.get_message( + ignore_subscribe_messages=True, timeout=poll_timeout + ) + except asyncio.CancelledError: + raise + except BaseException as e: + if exception_handler is None: + raise + res = exception_handler(e, self) + if inspect.isawaitable(res): + await res + # Ensure that other tasks on the event loop get a chance to run + # if we didn't have to block for I/O anywhere. + await asyncio.sleep(0) + + +class PubsubWorkerExceptionHandler(Protocol): + def __call__(self, e: BaseException, pubsub: PubSub): + ... + + +class AsyncPubsubWorkerExceptionHandler(Protocol): + async def __call__(self, e: BaseException, pubsub: PubSub): + ... + + +PSWorkerThreadExcHandlerT = Union[ + PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler +] + + +CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]] +CommandStackT = List[CommandT] + + +class Pipeline(Redis): # lgtm [py/init-calls-subclass] + """ + Pipelines provide a way to transmit multiple commands to the Redis server + in one transmission. This is convenient for batch processing, such as + saving all the values in a list to Redis. + + All commands executed within a pipeline are wrapped with MULTI and EXEC + calls. This guarantees all commands executed in the pipeline will be + executed atomically. 
+ + Any command raising an exception does *not* halt the execution of + subsequent commands in the pipeline. Instead, the exception is caught + and its instance is placed into the response list returned by execute(). + Code iterating over the response list should be able to deal with an + instance of an exception as a potential value. In general, these will be + ResponseError exceptions, such as those raised when issuing a command + on a key of a different datatype. + """ + + UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} + + def __init__( + self, + connection_pool: ConnectionPool, + response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT], + transaction: bool, + shard_hint: Optional[str], + ): + self.connection_pool = connection_pool + self.connection = None + self.response_callbacks = response_callbacks + self.is_transaction = transaction + self.shard_hint = shard_hint + self.watching = False + self.command_stack: CommandStackT = [] + self.scripts: Set["Script"] = set() + self.explicit_transaction = False + + async def __aenter__(self: _RedisT) -> _RedisT: + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.reset() + + def __await__(self): + return self._async_self().__await__() + + _DEL_MESSAGE = "Unclosed Pipeline client" + + def __len__(self): + return len(self.command_stack) + + def __bool__(self): + """Pipeline instances should always evaluate to True""" + return True + + async def _async_self(self): + return self + + async def reset(self): + self.command_stack = [] + self.scripts = set() + # make sure to reset the connection state in the event that we were + # watching something + if self.watching and self.connection: + try: + # call this manually since our unwatch or + # immediate_execute_command methods can call reset() + await self.connection.send_command("UNWATCH") + await self.connection.read_response() + except ConnectionError: + # disconnect will also remove any previous WATCHes + if self.connection: + await self.connection.disconnect() + # clean up the other instance attributes + self.watching = False + self.explicit_transaction = False + # we can safely return the connection to the pool here since we're + # sure we're no longer WATCHing anything + if self.connection: + await self.connection_pool.release(self.connection) + self.connection = None + + def multi(self): + """ + Start a transactional block of the pipeline after WATCH commands + are issued. End the transactional block with `execute`. + """ + if self.explicit_transaction: + raise RedisError("Cannot issue nested calls to MULTI") + if self.command_stack: + raise RedisError( + "Commands without an initial WATCH have already " "been issued" + ) + self.explicit_transaction = True + + def execute_command( + self, *args, **kwargs + ) -> Union["Pipeline", Awaitable["Pipeline"]]: + if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: + return self.immediate_execute_command(*args, **kwargs) + return self.pipeline_execute_command(*args, **kwargs) + + async def _disconnect_reset_raise(self, conn, error): + """ + Close the connection, reset watching state and + raise an exception if we were watching, + retry_on_timeout is not set, + or the error is not a TimeoutError + """ + await conn.disconnect() + # if we were already watching a variable, the watch is no longer + # valid since this connection has died. raise a WatchError, which + # indicates the user should retry this transaction. 
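+        # reset() also returns the connection to the pool, so the
+        # WatchError below is raised with the pipeline in a clean state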
+        if self.watching:
+            await self.reset()
+            raise WatchError(
+                "A ConnectionError occurred while watching one or more keys"
+            )
+        # if retry_on_timeout is not set, or the error is not
+        # a TimeoutError, raise it
+        if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
+            await self.reset()
+            raise
+
+    async def immediate_execute_command(self, *args, **options):
+        """
+        Execute a command immediately, but don't auto-retry on a
+        ConnectionError if we're already WATCHing a variable. Used when
+        issuing WATCH or subsequent commands retrieving their values but
+        before MULTI is called.
+        """
+        command_name = args[0]
+        conn = self.connection
+        # if this is the first call, we need a connection
+        if not conn:
+            conn = await self.connection_pool.get_connection(
+                command_name, self.shard_hint
+            )
+            self.connection = conn
+
+        return await conn.retry.call_with_retry(
+            lambda: self._send_command_parse_response(
+                conn, command_name, *args, **options
+            ),
+            lambda error: self._disconnect_reset_raise(conn, error),
+        )
+
+    def pipeline_execute_command(self, *args, **options):
+        """
+        Stage a command to be executed when execute() is next called
+
+        Returns the current Pipeline object back so commands can be
+        chained together, such as:
+
+        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
+
+        At some other point, you can then run: pipe.execute(),
+        which will execute all commands queued in the pipe.
+        """
+        self.command_stack.append((args, options))
+        return self
+
+    async def _execute_transaction(  # noqa: C901
+        self, connection: Connection, commands: CommandStackT, raise_on_error
+    ):
+        pre: CommandT = (("MULTI",), {})
+        post: CommandT = (("EXEC",), {})
+        cmds = (pre, *commands, post)
+        all_cmds = connection.pack_commands(
+            args for args, options in cmds if EMPTY_RESPONSE not in options
+        )
+        await connection.send_packed_command(all_cmds)
+        errors = []
+
+        # parse off the response for MULTI
+        # NOTE: we need to handle ResponseErrors here and continue
+        # so that we read all the additional command messages from
+        # the socket
+        try:
+            await self.parse_response(connection, "_")
+        except ResponseError as err:
+            errors.append((0, err))
+
+        # and all the other commands
+        for i, command in enumerate(commands):
+            if EMPTY_RESPONSE in command[1]:
+                errors.append((i, command[1][EMPTY_RESPONSE]))
+            else:
+                try:
+                    await self.parse_response(connection, "_")
+                except ResponseError as err:
+                    self.annotate_exception(err, i + 1, command[0])
+                    errors.append((i, err))
+
+        # parse the EXEC.
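+        # a successful EXEC returns one reply per queued command; a nil
+        # reply means the transaction was aborted because a WATCHed key
+        # changed, which is surfaced below as a WatchError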
+ try: + response = await self.parse_response(connection, "_") + except ExecAbortError as err: + if errors: + raise errors[0][1] from err + raise + + # EXEC clears any watched keys + self.watching = False + + if response is None: + raise WatchError("Watched variable changed.") from None + + # put any parse errors into the response + for i, e in errors: + response.insert(i, e) + + if len(response) != len(commands): + if self.connection: + await self.connection.disconnect() + raise ResponseError( + "Wrong number of response items from pipeline execution" + ) from None + + # find any errors in the response and raise if necessary + if raise_on_error: + self.raise_first_error(commands, response) + + # We have to run response callbacks manually + data = [] + for r, cmd in zip(response, commands): + if not isinstance(r, Exception): + args, options = cmd + command_name = args[0] + if command_name in self.response_callbacks: + r = self.response_callbacks[command_name](r, **options) + if inspect.isawaitable(r): + r = await r + data.append(r) + return data + + async def _execute_pipeline( + self, connection: Connection, commands: CommandStackT, raise_on_error: bool + ): + # build up all commands into a single request to increase network perf + all_cmds = connection.pack_commands([args for args, _ in commands]) + await connection.send_packed_command(all_cmds) + + response = [] + for args, options in commands: + try: + response.append( + await self.parse_response(connection, args[0], **options) + ) + except ResponseError as e: + response.append(e) + + if raise_on_error: + self.raise_first_error(commands, response) + return response + + def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]): + for i, r in enumerate(response): + if isinstance(r, ResponseError): + self.annotate_exception(r, i + 1, commands[i][0]) + raise r + + def annotate_exception( + self, exception: Exception, number: int, command: Iterable[object] + ) -> None: + cmd = " ".join(map(safe_str, command)) + msg = f"Command # {number} ({cmd}) of pipeline caused error: {exception.args}" + exception.args = (msg,) + exception.args[1:] + + async def parse_response( + self, connection: Connection, command_name: Union[str, bytes], **options + ): + result = await super().parse_response(connection, command_name, **options) + if command_name in self.UNWATCH_COMMANDS: + self.watching = False + elif command_name == "WATCH": + self.watching = True + return result + + async def load_scripts(self): + # make sure all scripts that are about to be run on this pipeline exist + scripts = list(self.scripts) + immediate = self.immediate_execute_command + shas = [s.sha for s in scripts] + # we can't use the normal script_* methods because they would just + # get buffered in the pipeline. + exists = await immediate("SCRIPT EXISTS", *shas) + if not all(exists): + for s, exist in zip(scripts, exists): + if not exist: + s.sha = await immediate("SCRIPT LOAD", s.script) + + async def _disconnect_raise_reset(self, conn: Connection, error: Exception): + """ + Close the connection, raise an exception if we were watching, + and raise an exception if retry_on_timeout is not set, + or the error is not a TimeoutError + """ + await conn.disconnect() + # if we were watching a variable, the watch is no longer valid + # since this connection has died. raise a WatchError, which + # indicates the user should retry this transaction. 
+ if self.watching: + raise WatchError( + "A ConnectionError occurred on while " "watching one or more keys" + ) + # if retry_on_timeout is not set, or the error is not + # a TimeoutError, raise it + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + await self.reset() + raise + + async def execute(self, raise_on_error: bool = True): + """Execute all the commands in the current pipeline""" + stack = self.command_stack + if not stack and not self.watching: + return [] + if self.scripts: + await self.load_scripts() + if self.is_transaction or self.explicit_transaction: + execute = self._execute_transaction + else: + execute = self._execute_pipeline + + conn = self.connection + if not conn: + conn = await self.connection_pool.get_connection("MULTI", self.shard_hint) + # assign to self.connection so reset() releases the connection + # back to the pool after we're done + self.connection = conn + conn = cast(Connection, conn) + + try: + return await conn.retry.call_with_retry( + lambda: execute(conn, stack, raise_on_error), + lambda error: self._disconnect_raise_reset(conn, error), + ) + finally: + await self.reset() + + async def discard(self): + """Flushes all previously queued commands + See: https://redis.io/commands/DISCARD + """ + await self.execute_command("DISCARD") + + async def watch(self, *names: KeyT): + """Watches the values at keys ``names``""" + if self.explicit_transaction: + raise RedisError("Cannot issue a WATCH after a MULTI") + return await self.execute_command("WATCH", *names) + + async def unwatch(self): + """Unwatches all previously specified keys""" + return self.watching and await self.execute_command("UNWATCH") or True diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py new file mode 100644 index 0000000000..cefb5f2cc2 --- /dev/null +++ b/redis/asyncio/connection.py @@ -0,0 +1,1696 @@ +import asyncio +import copy +import enum +import errno +import inspect +import io +import os +import socket +import ssl +import threading +import weakref +from itertools import chain +from types import MappingProxyType +from typing import ( + Any, + Callable, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) +from urllib.parse import ParseResult, parse_qs, unquote, urlparse + +import async_timeout + +from redis.asyncio.retry import Retry +from redis.backoff import NoBackoff +from redis.compat import Protocol, TypedDict +from redis.exceptions import ( + AuthenticationError, + AuthenticationWrongNumberOfArgsError, + BusyLoadingError, + ChildDeadlockedError, + ConnectionError, + DataError, + ExecAbortError, + InvalidResponse, + ModuleError, + NoPermissionError, + NoScriptError, + ReadOnlyError, + RedisError, + ResponseError, + TimeoutError, +) +from redis.typing import EncodableT, EncodedT +from redis.utils import HIREDIS_AVAILABLE, str_if_bytes + +hiredis = None +if HIREDIS_AVAILABLE: + import hiredis + +NONBLOCKING_EXCEPTION_ERROR_NUMBERS = { + BlockingIOError: errno.EWOULDBLOCK, + ssl.SSLWantReadError: 2, + ssl.SSLWantWriteError: 2, + ssl.SSLError: 2, +} + +NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) + + +SYM_STAR = b"*" +SYM_DOLLAR = b"$" +SYM_CRLF = b"\r\n" +SYM_LF = b"\n" +SYM_EMPTY = b"" + +SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." + + +class _Sentinel(enum.Enum): + sentinel = object() + + +SENTINEL = _Sentinel.sentinel +MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." 
+NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" +MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." +MODULE_EXPORTS_DATA_TYPES_ERROR = ( + "Error unloading module: the module " + "exports one or more module-side data " + "types, can't unload" +) + + +class _HiredisReaderArgs(TypedDict, total=False): + protocolError: Callable[[str], Exception] + replyError: Callable[[str], Exception] + encoding: Optional[str] + errors: Optional[str] + + +class Encoder: + """Encode strings to bytes-like and decode bytes-like to strings""" + + __slots__ = "encoding", "encoding_errors", "decode_responses" + + def __init__(self, encoding: str, encoding_errors: str, decode_responses: bool): + self.encoding = encoding + self.encoding_errors = encoding_errors + self.decode_responses = decode_responses + + def encode(self, value: EncodableT) -> EncodedT: + """Return a bytestring or bytes-like representation of the value""" + if isinstance(value, (bytes, memoryview)): + return value + if isinstance(value, bool): + # special case bool since it is a subclass of int + raise DataError( + "Invalid input of type: 'bool'. " + "Convert to a bytes, string, int or float first." + ) + if isinstance(value, (int, float)): + return repr(value).encode() + if not isinstance(value, str): + # a value we don't know how to deal with. throw an error + typename = value.__class__.__name__ # type: ignore[unreachable] + raise DataError( + f"Invalid input of type: {typename!r}. " + "Convert to a bytes, string, int or float first." + ) + return value.encode(self.encoding, self.encoding_errors) + + def decode(self, value: EncodableT, force=False) -> EncodableT: + """Return a unicode string from the bytes-like representation""" + if self.decode_responses or force: + if isinstance(value, memoryview): + return value.tobytes().decode(self.encoding, self.encoding_errors) + if isinstance(value, bytes): + return value.decode(self.encoding, self.encoding_errors) + return value + + +ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]] + + +class BaseParser: + """Plain Python parsing class""" + + __slots__ = "_stream", "_buffer", "_read_size" + + EXCEPTION_CLASSES: ExceptionMappingT = { + "ERR": { + "max number of clients reached": ConnectionError, + "Client sent AUTH, but no password is set": AuthenticationError, + "invalid password": AuthenticationError, + # some Redis server versions report invalid command syntax + # in lowercase + "wrong number of arguments for 'auth' command": AuthenticationWrongNumberOfArgsError, + # some Redis server versions report invalid command syntax + # in uppercase + "wrong number of arguments for 'AUTH' command": AuthenticationWrongNumberOfArgsError, + MODULE_LOAD_ERROR: ModuleError, + MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, + NO_SUCH_MODULE_ERROR: ModuleError, + MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, + }, + "EXECABORT": ExecAbortError, + "LOADING": BusyLoadingError, + "NOSCRIPT": NoScriptError, + "READONLY": ReadOnlyError, + "NOAUTH": AuthenticationError, + "NOPERM": NoPermissionError, + } + + def __init__(self, socket_read_size: int): + self._stream: Optional[asyncio.StreamReader] = None + self._buffer: Optional[SocketBuffer] = None + self._read_size = socket_read_size + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + def parse_error(self, response: str) -> ResponseError: + """Parse an error response""" + error_code = response.split(" ")[0] + if error_code in 
self.EXCEPTION_CLASSES: + response = response[len(error_code) + 1 :] + exception_class = self.EXCEPTION_CLASSES[error_code] + if isinstance(exception_class, dict): + exception_class = exception_class.get(response, ResponseError) + return exception_class(response) + return ResponseError(response) + + def on_disconnect(self): + raise NotImplementedError() + + def on_connect(self, connection: "Connection"): + raise NotImplementedError() + + async def can_read(self, timeout: float) -> bool: + raise NotImplementedError() + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: + raise NotImplementedError() + + +class SocketBuffer: + """Async-friendly re-impl of redis-py's SocketBuffer. + + TODO: We're currently passing through two buffers, + the asyncio.StreamReader and this. I imagine we can reduce the layers here + while maintaining compliance with prior art. + """ + + def __init__( + self, + stream_reader: asyncio.StreamReader, + socket_read_size: int, + socket_timeout: Optional[float], + ): + self._stream: Optional[asyncio.StreamReader] = stream_reader + self.socket_read_size = socket_read_size + self.socket_timeout = socket_timeout + self._buffer: Optional[io.BytesIO] = io.BytesIO() + # number of bytes written to the buffer from the socket + self.bytes_written = 0 + # number of bytes read from the buffer + self.bytes_read = 0 + + @property + def length(self): + return self.bytes_written - self.bytes_read + + async def _read_from_socket( + self, + length: Optional[int] = None, + timeout: Union[float, None, _Sentinel] = SENTINEL, + raise_on_timeout: bool = True, + ) -> bool: + buf = self._buffer + if buf is None or self._stream is None: + raise RedisError("Buffer is closed.") + buf.seek(self.bytes_written) + marker = 0 + timeout = timeout if timeout is not SENTINEL else self.socket_timeout + + try: + while True: + async with async_timeout.timeout(timeout): + data = await self._stream.read(self.socket_read_size) + # an empty string indicates the server shutdown the socket + if isinstance(data, bytes) and len(data) == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + buf.write(data) + data_length = len(data) + self.bytes_written += data_length + marker += data_length + + if length is not None and length > marker: + continue + return True + except (socket.timeout, asyncio.TimeoutError): + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. 
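+            # NONBLOCKING_EXCEPTION_ERROR_NUMBERS maps each exception class
+            # to the errno that merely means "no data available yet"; any
+            # other errno is converted to a ConnectionError below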
+ allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + + async def can_read(self, timeout: float) -> bool: + return bool(self.length) or await self._read_from_socket( + timeout=timeout, raise_on_timeout=False + ) + + async def read(self, length: int) -> bytes: + length = length + 2 # make sure to read the \r\n terminator + # make sure we've read enough data from the socket + if length > self.length: + await self._read_from_socket(length - self.length) + + if self._buffer is None: + raise RedisError("Buffer is closed.") + + self._buffer.seek(self.bytes_read) + data = self._buffer.read(length) + self.bytes_read += len(data) + + # purge the buffer when we've consumed it all so it doesn't + # grow forever + if self.bytes_read == self.bytes_written: + self.purge() + + return data[:-2] + + async def readline(self) -> bytes: + buf = self._buffer + if buf is None: + raise RedisError("Buffer is closed.") + + buf.seek(self.bytes_read) + data = buf.readline() + while not data.endswith(SYM_CRLF): + # there's more data in the socket that we need + await self._read_from_socket() + buf.seek(self.bytes_read) + data = buf.readline() + + self.bytes_read += len(data) + + # purge the buffer when we've consumed it all so it doesn't + # grow forever + if self.bytes_read == self.bytes_written: + self.purge() + + return data[:-2] + + def purge(self): + if self._buffer is None: + raise RedisError("Buffer is closed.") + + self._buffer.seek(0) + self._buffer.truncate() + self.bytes_written = 0 + self.bytes_read = 0 + + def close(self): + try: + self.purge() + self._buffer.close() # type: ignore[union-attr] + except Exception: + # issue #633 suggests the purge/close somehow raised a + # BadFileDescriptor error. Perhaps the client ran out of + # memory or something else? It's probably OK to ignore + # any error being raised from purge/close since we're + # removing the reference to the instance below. 
+ pass + self._buffer = None + self._stream = None + + +class PythonParser(BaseParser): + """Plain Python parsing class""" + + __slots__ = BaseParser.__slots__ + ("encoder",) + + def __init__(self, socket_read_size: int): + super().__init__(socket_read_size) + self.encoder: Optional[Encoder] = None + + def on_connect(self, connection: "Connection"): + """Called when the stream connects""" + self._stream = connection._reader + if self._stream is None: + raise RedisError("Buffer is closed.") + + self._buffer = SocketBuffer( + self._stream, self._read_size, connection.socket_timeout + ) + self.encoder = connection.encoder + + def on_disconnect(self): + """Called when the stream disconnects""" + if self._stream is not None: + self._stream = None + if self._buffer is not None: + self._buffer.close() + self._buffer = None + self.encoder = None + + async def can_read(self, timeout: float): + return self._buffer and bool(await self._buffer.can_read(timeout)) + + async def read_response(self, disable_decoding: bool = False) -> Union[EncodableT, ResponseError, None]: + if not self._buffer or not self.encoder: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + raw = await self._buffer.readline() + if not raw: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + response: Any + byte, response = raw[:1], raw[1:] + + if byte not in (b"-", b"+", b":", b"$", b"*"): + raise InvalidResponse(f"Protocol Error: {raw!r}") + + # server returned an error + if byte == b"-": + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. 
+ return error + # single value + elif byte == b"+": + pass + # int value + elif byte == b":": + response = int(response) + # bulk response + elif byte == b"$": + length = int(response) + if length == -1: + return None + response = await self._buffer.read(length) + # multi-bulk response + elif byte == b"*": + length = int(response) + if length == -1: + return None + response = [(await self.read_response(disable_decoding)) for _ in range(length)] + if isinstance(response, bytes) and disable_decoding is False: + response = self.encoder.decode(response) + return response + + +class HiredisParser(BaseParser): + """Parser class for connections using Hiredis""" + + __slots__ = BaseParser.__slots__ + ("_next_response", "_reader", "_socket_timeout") + + _next_response: bool + + def __init__(self, socket_read_size: int): + if not HIREDIS_AVAILABLE: + raise RedisError("Hiredis is not available.") + super().__init__(socket_read_size=socket_read_size) + self._reader: Optional[hiredis.Reader] = None + self._socket_timeout: Optional[float] = None + + def on_connect(self, connection: "Connection"): + self._stream = connection._reader + kwargs: _HiredisReaderArgs = { + "protocolError": InvalidResponse, + "replyError": self.parse_error, + } + if connection.encoder.decode_responses: + kwargs["encoding"] = connection.encoder.encoding + kwargs["errors"] = connection.encoder.encoding_errors + + self._reader = hiredis.Reader(**kwargs) + self._next_response = False + self._socket_timeout = connection.socket_timeout + + def on_disconnect(self): + self._stream = None + self._reader = None + self._next_response = False + + async def can_read(self, timeout: float): + if not self._reader: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + if self._next_response is False: + self._next_response = self._reader.gets() + if self._next_response is False: + return await self.read_from_socket(timeout=timeout, raise_on_timeout=False) + return True + + async def read_from_socket( + self, + timeout: Union[float, None, _Sentinel] = SENTINEL, + raise_on_timeout: bool = True, + ): + if self._stream is None or self._reader is None: + raise RedisError("Parser already closed.") + + timeout = self._socket_timeout if timeout is SENTINEL else timeout + try: + async with async_timeout.timeout(timeout): + buffer = await self._stream.read(self._read_size) + if not isinstance(buffer, bytes) or len(buffer) == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + self._reader.feed(buffer) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. + return True + except asyncio.CancelledError: + raise + except (socket.timeout, asyncio.TimeoutError): + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") from None + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. 
+ allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + + async def read_response(self, disable_decoding: bool = False) -> Union[EncodableT, List[EncodableT]]: + if not self._stream or not self._reader: + self.on_disconnect() + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + + response: Union[ + EncodableT, ConnectionError, List[Union[EncodableT, ConnectionError]] + ] + # _next_response might be cached from a can_read() call + if self._next_response is not False: + response = self._next_response + self._next_response = False + return response + + response = self._reader.gets() + while response is False: + await self.read_from_socket() + response = self._reader.gets() + + # if the response is a ConnectionError or the response is a list and + # the first item is a ConnectionError, raise it as something bad + # happened + if isinstance(response, ConnectionError): + raise response + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ConnectionError) + ): + raise response[0] + # cast as there won't be a ConnectionError here. + return cast(Union[EncodableT, List[EncodableT]], response) + + +DefaultParser: Type[Union[PythonParser, HiredisParser]] +if HIREDIS_AVAILABLE: + DefaultParser = HiredisParser +else: + DefaultParser = PythonParser + + +class ConnectCallbackProtocol(Protocol): + def __call__(self, connection: "Connection"): + ... + + +class AsyncConnectCallbackProtocol(Protocol): + async def __call__(self, connection: "Connection"): + ... + + +ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol] + + +class Connection: + """Manages TCP communication to and from a Redis server""" + + __slots__ = ( + "pid", + "host", + "port", + "db", + "username", + "client_name", + "password", + "socket_timeout", + "socket_connect_timeout", + "socket_keepalive", + "socket_keepalive_options", + "socket_type", + "redis_connect_func", + "retry_on_timeout", + "health_check_interval", + "next_health_check", + "last_active_at", + "encoder", + "ssl_context", + "_reader", + "_writer", + "_parser", + "_connect_callbacks", + "_buffer_cutoff", + "_lock", + "_socket_read_size", + "__dict__", + ) + + def __init__( + self, + *, + host: str = "localhost", + port: Union[str, int] = 6379, + db: Union[str, int] = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: bool = False, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + socket_type: int = 0, + retry_on_timeout: bool = False, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + parser_class: Type[BaseParser] = DefaultParser, + socket_read_size: int = 65536, + health_check_interval: float = 0, + client_name: Optional[str] = None, + username: Optional[str] = None, + retry: Optional[Retry] = None, + redis_connect_func: Optional[ConnectCallbackT] = None, + encoder_class: Type[Encoder] = Encoder, + ): + self.pid = os.getpid() + self.host = host + self.port = int(port) + self.db = db + self.username = username + self.client_name = client_name + self.password = password + self.socket_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None + self.socket_keepalive = socket_keepalive + self.socket_keepalive_options = 
socket_keepalive_options or {} + self.socket_type = socket_type + self.retry_on_timeout = retry_on_timeout + if retry_on_timeout: + if retry is None: + self.retry = Retry(NoBackoff(), 1) + else: + # deep-copy the Retry object as it is mutable + self.retry = copy.deepcopy(retry) + else: + self.retry = Retry(NoBackoff(), 0) + self.health_check_interval = health_check_interval + self.next_health_check: float = -1 + self.ssl_context: Optional[RedisSSLContext] = None + self.encoder = encoder_class(encoding, encoding_errors, decode_responses) + self.redis_connect_func = redis_connect_func + self._reader: Optional[asyncio.StreamReader] = None + self._writer: Optional[asyncio.StreamWriter] = None + self._socket_read_size = socket_read_size + self.set_parser(parser_class) + self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = [] + self._buffer_cutoff = 6000 + self._lock = asyncio.Lock() + + def __repr__(self): + repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) + return f"{self.__class__.__name__}<{repr_args}>" + + def repr_pieces(self): + pieces = [("host", self.host), ("port", self.port), ("db", self.db)] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + def __del__(self): + try: + if self.is_connected: + loop = asyncio.get_event_loop() + coro = self.disconnect() + if loop.is_running(): + loop.create_task(coro) + else: + loop.run_until_complete(coro) + except Exception: + pass + + @property + def is_connected(self): + return bool(self._reader and self._writer) + + def register_connect_callback(self, callback): + self._connect_callbacks.append(weakref.WeakMethod(callback)) + + def clear_connect_callbacks(self): + self._connect_callbacks = [] + + def set_parser(self, parser_class): + """ + Creates a new instance of parser_class with socket size: + _socket_read_size and assigns it to the parser for the connection + :param parser_class: The required parser class + """ + self._parser = parser_class(socket_read_size=self._socket_read_size) + + async def connect(self): + """Connects to the Redis server if not already connected""" + if self.is_connected: + return + try: + await self._connect() + except asyncio.CancelledError: + raise + except (socket.timeout, asyncio.TimeoutError): + raise TimeoutError("Timeout connecting to server") + except OSError as e: + raise ConnectionError(self._error_message(e)) + except Exception as exc: + raise ConnectionError(exc) from exc + + try: + if self.redis_connect_func is None: + # Use the default on_connect function + await self.on_connect() + else: + # Use the passed function redis_connect_func + await self.redis_connect_func(self) if asyncio.iscoroutinefunction( + self.redis_connect_func + ) else self.redis_connect_func(self) + except RedisError: + # clean up after any error in on_connect + await self.disconnect() + raise + + # run any user callbacks. 
right now the only internal callback
+        # is for pubsub channel/pattern resubscription
+        for ref in self._connect_callbacks:
+            callback = ref()
+            task = callback(self)
+            if task and inspect.isawaitable(task):
+                await task
+
+    async def _connect(self):
+        """Create a TCP socket connection"""
+        async with async_timeout.timeout(self.socket_connect_timeout):
+            reader, writer = await asyncio.open_connection(
+                host=self.host,
+                port=self.port,
+                ssl=self.ssl_context.get() if self.ssl_context else None,
+            )
+        self._reader = reader
+        self._writer = writer
+        sock = writer.transport.get_extra_info("socket")
+        if sock is not None:
+            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+            try:
+                # TCP_KEEPALIVE
+                if self.socket_keepalive:
+                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+                    for k, v in self.socket_keepalive_options.items():
+                        sock.setsockopt(socket.SOL_TCP, k, v)
+
+            except (OSError, TypeError):
+                # `socket_keepalive_options` might contain invalid options
+                # causing an error. Do not leave the connection open.
+                writer.close()
+                raise
+
+    def _error_message(self, exception):
+        # args for socket.error can either be (errno, "message")
+        # or just "message"
+        if len(exception.args) == 1:
+            return f"Error connecting to {self.host}:{self.port}. {exception.args[0]}."
+        else:
+            return (
+                f"Error {exception.args[0]} connecting to {self.host}:{self.port}. "
+                f"{exception.args[1]}."
+            )
+
+    async def on_connect(self):
+        """Initialize the connection, authenticate and select a database"""
+        self._parser.on_connect(self)
+
+        # if username and/or password are set, authenticate
+        if self.username or self.password:
+            auth_args: Union[Tuple[str], Tuple[str, str]]
+            if self.username:
+                auth_args = (self.username, self.password or "")
+            else:
+                # Mypy bug: https://github.com/python/mypy/issues/10944
+                auth_args = (self.password or "",)
+            # avoid checking health here -- PING will fail if we try
+            # to check the health prior to the AUTH
+            await self.send_command("AUTH", *auth_args, check_health=False)
+
+            try:
+                auth_response = await self.read_response()
+            except AuthenticationWrongNumberOfArgsError:
+                # a username and password were specified but the Redis
+                # server seems to be < 6.0.0 which expects a single password
+                # arg. retry auth with just the password.
+ # https://github.com/andymccurdy/redis-py/issues/1274 + await self.send_command("AUTH", self.password, check_health=False) + auth_response = await self.read_response() + + if str_if_bytes(auth_response) != "OK": + raise AuthenticationError("Invalid Username or Password") + + # if a client_name is given, set it + if self.client_name: + await self.send_command("CLIENT", "SETNAME", self.client_name) + if str_if_bytes(await self.read_response()) != "OK": + raise ConnectionError("Error setting client name") + + # if a database is specified, switch to it + if self.db: + await self.send_command("SELECT", self.db) + if str_if_bytes(await self.read_response()) != "OK": + raise ConnectionError("Invalid Database") + + async def disconnect(self): + """Disconnects from the Redis server""" + try: + async with async_timeout.timeout(self.socket_connect_timeout): + self._parser.on_disconnect() + if not self.is_connected: + return + try: + if os.getpid() == self.pid: + self._writer.close() # type: ignore[union-attr] + # py3.6 doesn't have this method + if hasattr(self._writer, "wait_closed"): + await self._writer.wait_closed() # type: ignore[union-attr] + except OSError: + pass + self._reader = None + self._writer = None + except asyncio.TimeoutError: + raise TimeoutError( + f"Timed out closing connection after {self.socket_connect_timeout}" + ) from None + + async def _send_ping(self): + """Send PING, expect PONG in return""" + await self.send_command("PING", check_health=False) + if str_if_bytes(await self.read_response()) != "PONG": + raise ConnectionError("Bad response from PING health check") + + async def _ping_failed(self, error): + """Function to call when PING fails""" + await self.disconnect() + + async def check_health(self): + """Check the health of the connection with a PING/PONG""" + if ( + self.health_check_interval + and asyncio.get_running_loop().time() > self.next_health_check + ): + await self.retry.call_with_retry(self._send_ping, self._ping_failed) + + async def _send_packed_command(self, command: Iterable[bytes]) -> None: + if self._writer is None: + raise RedisError("Connection already closed.") + + self._writer.writelines(command) + await self._writer.drain() + + async def send_packed_command( + self, + command: Union[bytes, str, Iterable[bytes]], + check_health: bool = True, + ): + """Send an already packed command to the Redis server""" + if not self._writer: + await self.connect() + # guard against health check recursion + if check_health: + await self.check_health() + try: + if isinstance(command, str): + command = command.encode() + if isinstance(command, bytes): + command = [command] + await asyncio.wait_for( + self._send_packed_command(command), + self.socket_timeout, + ) + except asyncio.TimeoutError: + await self.disconnect() + raise TimeoutError("Timeout writing to socket") from None + except OSError as e: + await self.disconnect() + if len(e.args) == 1: + err_no, errmsg = "UNKNOWN", e.args[0] + else: + err_no = e.args[0] + errmsg = e.args[1] + raise ConnectionError( + f"Error {err_no} while writing to socket. {errmsg}." 
+ ) from e + except BaseException: + await self.disconnect() + raise + + async def send_command(self, *args, **kwargs): + """Pack and send a command to the Redis server""" + if not self.is_connected: + await self.connect() + await self.send_packed_command( + self.pack_command(*args), check_health=kwargs.get("check_health", True) + ) + + async def can_read(self, timeout: float = 0): + """Poll the socket to see if there's data that can be read.""" + if not self.is_connected: + await self.connect() + return await self._parser.can_read(timeout) + + async def read_response(self, disable_decoding: bool = False): + """Read the response from a previously sent command""" + try: + async with self._lock: + async with async_timeout.timeout(self.socket_timeout): + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) + except asyncio.TimeoutError: + await self.disconnect() + raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") + except OSError as e: + await self.disconnect() + raise ConnectionError( + f"Error while reading from {self.host}:{self.port} : {e.args}" + ) + except BaseException: + await self.disconnect() + raise + + if self.health_check_interval: + self.next_health_check = ( + asyncio.get_running_loop().time() + self.health_check_interval + ) + + if isinstance(response, ResponseError): + raise response from None + return response + + def pack_command(self, *args: EncodableT) -> List[bytes]: + """Pack a series of arguments into the Redis protocol""" + output = [] + # the client might have included 1 or more literal arguments in + # the command name, e.g., 'CONFIG GET'. The Redis server expects these + # arguments to be sent separately, so split the first argument + # manually. These arguments should be bytestrings so that they are + # not encoded. 
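+        # e.g. pack_command("CONFIG GET", "maxmemory") sends CONFIG, GET
+        # and maxmemory to the server as three separate arguments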
+ assert not isinstance(args[0], float) + if isinstance(args[0], str): + args = tuple(args[0].encode().split()) + args[1:] + elif b" " in args[0]: + args = tuple(args[0].split()) + args[1:] + + buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF)) + + buffer_cutoff = self._buffer_cutoff + for arg in map(self.encoder.encode, args): + # to avoid large string mallocs, chunk the command into the + # output list if we're sending large values or memoryviews + arg_length = len(arg) + if ( + len(buff) > buffer_cutoff + or arg_length > buffer_cutoff + or isinstance(arg, memoryview) + ): + buff = SYM_EMPTY.join( + (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) + ) + output.append(buff) + output.append(arg) + buff = SYM_CRLF + else: + buff = SYM_EMPTY.join( + ( + buff, + SYM_DOLLAR, + str(arg_length).encode(), + SYM_CRLF, + arg, + SYM_CRLF, + ) + ) + output.append(buff) + return output + + def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]: + """Pack multiple commands into the Redis protocol""" + output: List[bytes] = [] + pieces: List[bytes] = [] + buffer_length = 0 + buffer_cutoff = self._buffer_cutoff + + for cmd in commands: + for chunk in self.pack_command(*cmd): + chunklen = len(chunk) + if ( + buffer_length > buffer_cutoff + or chunklen > buffer_cutoff + or isinstance(chunk, memoryview) + ): + output.append(SYM_EMPTY.join(pieces)) + buffer_length = 0 + pieces = [] + + if chunklen > buffer_cutoff or isinstance(chunk, memoryview): + output.append(chunk) + else: + pieces.append(chunk) + buffer_length += chunklen + + if pieces: + output.append(SYM_EMPTY.join(pieces)) + return output + + +class SSLConnection(Connection): + def __init__( + self, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_ca_certs: Optional[str] = None, + ssl_check_hostname: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.ssl_context: RedisSSLContext = RedisSSLContext( + keyfile=ssl_keyfile, + certfile=ssl_certfile, + cert_reqs=ssl_cert_reqs, + ca_certs=ssl_ca_certs, + check_hostname=ssl_check_hostname, + ) + + @property + def keyfile(self): + return self.ssl_context.keyfile + + @property + def certfile(self): + return self.ssl_context.certfile + + @property + def cert_reqs(self): + return self.ssl_context.cert_reqs + + @property + def ca_certs(self): + return self.ssl_context.ca_certs + + @property + def check_hostname(self): + return self.ssl_context.check_hostname + + +class RedisSSLContext: + __slots__ = ( + "keyfile", + "certfile", + "cert_reqs", + "ca_certs", + "context", + "check_hostname", + ) + + def __init__( + self, + keyfile: Optional[str] = None, + certfile: Optional[str] = None, + cert_reqs: Optional[str] = None, + ca_certs: Optional[str] = None, + check_hostname: bool = False, + ): + self.keyfile = keyfile + self.certfile = certfile + if cert_reqs is None: + self.cert_reqs = ssl.CERT_NONE + elif isinstance(cert_reqs, str): + CERT_REQS = { + "none": ssl.CERT_NONE, + "optional": ssl.CERT_OPTIONAL, + "required": ssl.CERT_REQUIRED, + } + if cert_reqs not in CERT_REQS: + raise RedisError( + f"Invalid SSL Certificate Requirements Flag: {cert_reqs}" + ) + self.cert_reqs = CERT_REQS[cert_reqs] + self.ca_certs = ca_certs + self.check_hostname = check_hostname + self.context: Optional[ssl.SSLContext] = None + + def get(self) -> ssl.SSLContext: + if not self.context: + context = ssl.create_default_context() + context.check_hostname = self.check_hostname + context.verify_mode = 
self.cert_reqs + if self.certfile and self.keyfile: + context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile) + if self.ca_certs: + context.load_verify_locations(self.ca_certs) + self.context = context + return self.context + + +class UnixDomainSocketConnection(Connection): # lgtm [py/missing-call-to-init] + def __init__( + self, + *, + path: str = "", + db: Union[str, int] = 0, + username: Optional[str] = None, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + retry_on_timeout: bool = False, + parser_class: Type[BaseParser] = DefaultParser, + socket_read_size: int = 65536, + health_check_interval: float = 0.0, + client_name: str = None, + retry: Optional[Retry] = None, + ): + """ + Initialize a new UnixDomainSocketConnection. + To specify a retry policy, first set `retry_on_timeout` to `True` + then set `retry` to a valid `Retry` object + """ + self.pid = os.getpid() + self.path = path + self.db = db + self.username = username + self.client_name = client_name + self.password = password + self.socket_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None + self.retry_on_timeout = retry_on_timeout + if retry_on_timeout: + if retry is None: + self.retry = Retry(NoBackoff(), 1) + else: + # deep-copy the Retry object as it is mutable + self.retry = copy.deepcopy(retry) + else: + self.retry = Retry(NoBackoff(), 0) + self.health_check_interval = health_check_interval + self.next_health_check = -1 + self.encoder = Encoder(encoding, encoding_errors, decode_responses) + self._sock = None + self._reader = None + self._writer = None + self._socket_read_size = socket_read_size + self.set_parser(parser_class) + self._connect_callbacks = [] + self._buffer_cutoff = 6000 + self._lock = asyncio.Lock() + + def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]: + pieces = [ + ("path", self.path), + ("db", self.db), + ] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + async def _connect(self): + async with async_timeout.timeout(self.socket_connect_timeout): + reader, writer = await asyncio.open_unix_connection(path=self.path) + self._reader = reader + self._writer = writer + await self.on_connect() + + def _error_message(self, exception): + # args for socket.error can either be (errno, "message") + # or just "message" + if len(exception.args) == 1: + return f"Error connecting to unix socket: {self.path}. {exception.args[0]}." + else: + return ( + f"Error {exception.args[0]} connecting to unix socket: " + f"{self.path}. {exception.args[1]}." 
+ ) + + +FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") + + +def to_bool(value) -> Optional[bool]: + if value is None or value == "": + return None + if isinstance(value, str) and value.upper() in FALSE_STRINGS: + return False + return bool(value) + + +URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType( + { + "db": int, + "socket_timeout": float, + "socket_connect_timeout": float, + "socket_keepalive": to_bool, + "retry_on_timeout": to_bool, + "max_connections": int, + "health_check_interval": int, + "ssl_check_hostname": to_bool, + } +) + + +class ConnectKwargs(TypedDict, total=False): + username: str + password: str + connection_class: Type[Connection] + host: str + port: int + db: int + path: str + + +def parse_url(url: str) -> ConnectKwargs: + parsed: ParseResult = urlparse(url) + kwargs: ConnectKwargs = {} + + for name, value_list in parse_qs(parsed.query).items(): + if value_list and len(value_list) > 0: + value = unquote(value_list[0]) + parser = URL_QUERY_ARGUMENT_PARSERS.get(name) + if parser: + try: + # We can't type this. + kwargs[name] = parser(value) # type: ignore[misc] + except (TypeError, ValueError): + raise ValueError(f"Invalid value for `{name}` in connection URL.") + else: + kwargs[name] = value # type: ignore[misc] + + if parsed.username: + kwargs["username"] = unquote(parsed.username) + if parsed.password: + kwargs["password"] = unquote(parsed.password) + + # We only support redis://, rediss:// and unix:// schemes. + if parsed.scheme == "unix": + if parsed.path: + kwargs["path"] = unquote(parsed.path) + kwargs["connection_class"] = UnixDomainSocketConnection + + elif parsed.scheme in ("redis", "rediss"): + if parsed.hostname: + kwargs["host"] = unquote(parsed.hostname) + if parsed.port: + kwargs["port"] = int(parsed.port) + + # If there's a path argument, use it as the db argument if a + # querystring value wasn't specified + if parsed.path and "db" not in kwargs: + try: + kwargs["db"] = int(unquote(parsed.path).replace("/", "")) + except (AttributeError, ValueError): + pass + + if parsed.scheme == "rediss": + kwargs["connection_class"] = SSLConnection + else: + valid_schemes = "redis://, rediss://, unix://" + raise ValueError( + f"Redis URL must specify one of the following schemes ({valid_schemes})" + ) + + return kwargs + + +_CP = TypeVar("_CP", bound="ConnectionPool") + + +class ConnectionPool: + """ + Create a connection pool. ``If max_connections`` is set, then this + object raises :py:class:`~redis.ConnectionError` when the pool's + limit is reached. + + By default, TCP connections are created unless ``connection_class`` + is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for + unix sockets. + + Any additional keyword arguments are passed to the constructor of + ``connection_class``. + """ + + @classmethod + def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: + """ + Return a connection pool configured from the given URL. + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[[username]:[password]]@/path/to/socket.sock?db=0 + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + + - ``unix://``: creates a Unix Domain Socket connection. 
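+
+        For instance, a pool with a five second socket timeout might be
+        created like this (host and option values are illustrative):
+
+        >>> pool = ConnectionPool.from_url(
+        ...     "redis://localhost:6379/0?socket_timeout=5"
+        ... )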
+ + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + """ + url_options = parse_url(url) + kwargs.update(url_options) + return cls(**kwargs) + + def __init__( + self, + connection_class: Type[Connection] = Connection, + max_connections: Optional[int] = None, + **connection_kwargs, + ): + max_connections = max_connections or 2 ** 31 + if not isinstance(max_connections, int) or max_connections < 0: + raise ValueError('"max_connections" must be a positive integer') + + self.connection_class = connection_class + self.connection_kwargs = connection_kwargs + self.max_connections = max_connections + + # a lock to protect the critical section in _checkpid(). + # this lock is acquired when the process id changes, such as + # after a fork. during this time, multiple threads in the child + # process could attempt to acquire this lock. the first thread + # to acquire the lock will reset the data structures and lock + # object of this pool. subsequent threads acquiring this lock + # will notice the first thread already did the work and simply + # release the lock. + self._fork_lock = threading.Lock() + self._lock = asyncio.Lock() + self._created_connections: int + self._available_connections: List[Connection] + self._in_use_connections: Set[Connection] + self.reset() # lgtm [py/init-calls-subclass] + self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder) + + def __repr__(self): + return ( + f"{self.__class__.__name__}" + f"<{self.connection_class(**self.connection_kwargs)!r}>" + ) + + def reset(self): + self._lock = asyncio.Lock() + self._created_connections = 0 + self._available_connections = [] + self._in_use_connections = set() + + # this must be the last operation in this method. while reset() is + # called when holding _fork_lock, other threads in this process + # can call _checkpid() which compares self.pid and os.getpid() without + # holding any lock (for performance reasons). keeping this assignment + # as the last operation ensures that those other threads will also + # notice a pid difference and block waiting for the first thread to + # release _fork_lock. when each of these threads eventually acquire + # _fork_lock, they will notice that another thread already called + # reset() and they will immediately release _fork_lock and continue on. + self.pid = os.getpid() + + def _checkpid(self): + # _checkpid() attempts to keep ConnectionPool fork-safe on modern + # systems. this is called by all ConnectionPool methods that + # manipulate the pool's state such as get_connection() and release(). 
+ # + # _checkpid() determines whether the process has forked by comparing + # the current process id to the process id saved on the ConnectionPool + # instance. if these values are the same, _checkpid() simply returns. + # + # when the process ids differ, _checkpid() assumes that the process + # has forked and that we're now running in the child process. the child + # process cannot use the parent's file descriptors (e.g., sockets). + # therefore, when _checkpid() sees the process id change, it calls + # reset() in order to reinitialize the child's ConnectionPool. this + # will cause the child to make all new connection objects. + # + # _checkpid() is protected by self._fork_lock to ensure that multiple + # threads in the child process do not call reset() multiple times. + # + # there is an extremely small chance this could fail in the following + # scenario: + # 1. process A calls _checkpid() for the first time and acquires + # self._fork_lock. + # 2. while holding self._fork_lock, process A forks (the fork() + # could happen in a different thread owned by process A) + # 3. process B (the forked child process) inherits the + # ConnectionPool's state from the parent. that state includes + # a locked _fork_lock. process B will not be notified when + # process A releases the _fork_lock and will thus never be + # able to acquire the _fork_lock. + # + # to mitigate this possible deadlock, _checkpid() will only wait 5 + # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in + # that time it is assumed that the child is deadlocked and a + # redis.ChildDeadlockedError error is raised. + if self.pid != os.getpid(): + acquired = self._fork_lock.acquire(timeout=5) + if not acquired: + raise ChildDeadlockedError + # reset() the instance for the new process if another thread + # hasn't already done so + try: + if self.pid != os.getpid(): + self.reset() + finally: + self._fork_lock.release() + + async def get_connection(self, command_name, *keys, **options): + """Get a connection from the pool""" + self._checkpid() + async with self._lock: + try: + connection = self._available_connections.pop() + except IndexError: + connection = self.make_connection() + self._in_use_connections.add(connection) + + try: + # ensure this connection is connected to Redis + await connection.connect() + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the + # pool before all data has been read or the socket has been + # closed. either way, reconnect and verify everything is good. 
+            try:
+                if await connection.can_read():
+                    raise ConnectionError("Connection has data") from None
+            except ConnectionError:
+                await connection.disconnect()
+                await connection.connect()
+                if await connection.can_read():
+                    raise ConnectionError("Connection not ready") from None
+        except BaseException:
+            # release the connection back to the pool so that we don't
+            # leak it
+            await self.release(connection)
+            raise
+
+        return connection
+
+    def get_encoder(self):
+        """Return an encoder based on encoding settings"""
+        kwargs = self.connection_kwargs
+        return self.encoder_class(
+            encoding=kwargs.get("encoding", "utf-8"),
+            encoding_errors=kwargs.get("encoding_errors", "strict"),
+            decode_responses=kwargs.get("decode_responses", False),
+        )
+
+    def make_connection(self):
+        """Create a new connection"""
+        if self._created_connections >= self.max_connections:
+            raise ConnectionError("Too many connections")
+        self._created_connections += 1
+        return self.connection_class(**self.connection_kwargs)
+
+    async def release(self, connection: Connection):
+        """Releases the connection back to the pool"""
+        self._checkpid()
+        async with self._lock:
+            try:
+                self._in_use_connections.remove(connection)
+            except KeyError:
+                # Gracefully fail when a connection is returned to this pool
+                # that the pool doesn't actually own
+                pass
+
+            if self.owns_connection(connection):
+                self._available_connections.append(connection)
+            else:
+                # pool doesn't own this connection. do not add it back
+                # to the pool and decrement the count so that another
+                # connection can take its place if needed
+                self._created_connections -= 1
+                await connection.disconnect()
+                return
+
+    def owns_connection(self, connection: Connection):
+        return connection.pid == self.pid
+
+    async def disconnect(self, inuse_connections: bool = True):
+        """
+        Disconnects connections in the pool
+
+        If ``inuse_connections`` is True, disconnect connections that are
+        currently in use, potentially by other tasks. Otherwise only
+        disconnect connections that are idle in the pool.
+        """
+        self._checkpid()
+        async with self._lock:
+            if inuse_connections:
+                connections: Iterable[Connection] = chain(
+                    self._available_connections, self._in_use_connections
+                )
+            else:
+                connections = self._available_connections
+            resp = await asyncio.gather(
+                *(connection.disconnect() for connection in connections),
+                return_exceptions=True,
+            )
+            exc = next((r for r in resp if isinstance(r, BaseException)), None)
+            if exc:
+                raise exc
+
+
+class BlockingConnectionPool(ConnectionPool):
+    """
+    Thread-safe blocking connection pool::
+
+        >>> from redis.client import Redis
+        >>> client = Redis(connection_pool=BlockingConnectionPool())
+
+    It performs the same function as the default
+    :py:class:`~redis.ConnectionPool` implementation, in that it maintains
+    a pool of reusable connections that can be shared by multiple redis
+    clients (safely across threads if required).
+
+    The difference is that, in the event that a client tries to get a
+    connection from the pool when all of the connections are in use, rather
+    than raising a :py:class:`~redis.ConnectionError` (as the default
+    :py:class:`~redis.ConnectionPool` implementation does), it makes the
+    client wait ("blocks") for a specified number of seconds until a
+    connection becomes available.
+
+    Use ``max_connections`` to increase / decrease the pool size::
+
+        >>> pool = BlockingConnectionPool(max_connections=10)
+
+    Use ``timeout`` to tell it either how many seconds to wait for a connection
+    to become available, or to block forever:
+
+        >>> # Block forever.
+        >>> pool = BlockingConnectionPool(timeout=None)
+
+        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
+        >>> # not available.
+        >>> pool = BlockingConnectionPool(timeout=5)
+    """
+
+    def __init__(
+        self,
+        max_connections: int = 50,
+        timeout: Optional[int] = 20,
+        connection_class: Type[Connection] = Connection,
+        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,
+        **connection_kwargs,
+    ):
+
+        self.queue_class = queue_class
+        self.timeout = timeout
+        self._connections: List[Connection]
+        super().__init__(
+            connection_class=connection_class,
+            max_connections=max_connections,
+            **connection_kwargs,
+        )
+
+    def reset(self):
+        # Create and fill up a thread safe queue with ``None`` values.
+        self.pool = self.queue_class(self.max_connections)
+        while True:
+            try:
+                self.pool.put_nowait(None)
+            except asyncio.QueueFull:
+                break
+
+        # Keep a list of actual connection instances so that we can
+        # disconnect them later.
+        self._connections = []
+
+        # this must be the last operation in this method. while reset() is
+        # called when holding _fork_lock, other threads in this process
+        # can call _checkpid() which compares self.pid and os.getpid() without
+        # holding any lock (for performance reasons). keeping this assignment
+        # as the last operation ensures that those other threads will also
+        # notice a pid difference and block waiting for the first thread to
+        # release _fork_lock. when each of these threads eventually acquire
+        # _fork_lock, they will notice that another thread already called
+        # reset() and they will immediately release _fork_lock and continue on.
+        self.pid = os.getpid()
+
+    def make_connection(self):
+        """Make a fresh connection."""
+        connection = self.connection_class(**self.connection_kwargs)
+        self._connections.append(connection)
+        return connection
+
+    async def get_connection(self, command_name, *keys, **options):
+        """
+        Get a connection, blocking for ``self.timeout`` until a connection
+        is available from the pool.
+
+        If the connection returned is ``None`` then a new connection is
+        created. Because we use a last-in first-out queue, the existing
+        connections (having been returned to the pool after the initial
+        ``None`` values were added) will be returned before ``None`` values.
+        This means we only create new connections when we need to, i.e.: the
+        actual number of connections will only increase in response to demand.
+        """
+        # Make sure we haven't changed process.
+        self._checkpid()
+
+        # Try and get a connection from the pool. If one isn't available within
+        # self.timeout then raise a ``ConnectionError``.
+        connection = None
+        try:
+            async with async_timeout.timeout(self.timeout):
+                connection = await self.pool.get()
+        except (asyncio.QueueEmpty, asyncio.TimeoutError):
+            # Note that this is not caught by the redis client and will be
+            # raised unless handled by application code.
+            raise ConnectionError("No connection available.")
+
+        # If the ``connection`` is actually ``None`` then that's a cue to make
+        # a new connection to add to the pool.
+ if connection is None:
+ connection = self.make_connection()
+
+ try:
+ # ensure this connection is connected to Redis
+ await connection.connect()
+ # connections that the pool provides should be ready to send
+ # a command. if not, the connection was either returned to the
+ # pool before all data has been read or the socket has been
+ # closed. either way, reconnect and verify everything is good.
+ try:
+ if await connection.can_read():
+ raise ConnectionError("Connection has data") from None
+ except ConnectionError:
+ await connection.disconnect()
+ await connection.connect()
+ if await connection.can_read():
+ raise ConnectionError("Connection not ready") from None
+ except BaseException:
+ # release the connection back to the pool so that we don't leak it
+ await self.release(connection)
+ raise
+
+ return connection
+
+ async def release(self, connection: Connection):
+ """Releases the connection back to the pool."""
+ # Make sure we haven't changed process.
+ self._checkpid()
+ if not self.owns_connection(connection):
+ # pool doesn't own this connection. do not add it back
+ # to the pool. instead add a None value which is a placeholder
+ # that will cause the pool to recreate the connection if
+ # it's needed.
+ await connection.disconnect()
+ self.pool.put_nowait(None)
+ return
+
+ # Put the connection back into the pool.
+ try:
+ self.pool.put_nowait(connection)
+ except asyncio.QueueFull:
+ # perhaps the pool has been reset() after a fork? regardless,
+ # we don't want this connection
+ pass
+
+ async def disconnect(self, inuse_connections: bool = True):
+ """Disconnects all connections in the pool."""
+ self._checkpid()
+ async with self._lock:
+ resp = await asyncio.gather(
+ *(connection.disconnect() for connection in self._connections),
+ return_exceptions=True,
+ )
+ exc = next((r for r in resp if isinstance(r, BaseException)), None)
+ if exc:
+ raise exc
diff --git a/redis/asyncio/lock.py b/redis/asyncio/lock.py
new file mode 100644
index 0000000000..784594e3af
--- /dev/null
+++ b/redis/asyncio/lock.py
@@ -0,0 +1,306 @@
+import asyncio
+import threading
+import uuid
+from types import SimpleNamespace
+from typing import TYPE_CHECKING, Awaitable, Optional, Union
+
+from redis.exceptions import LockError, LockNotOwnedError
+
+if TYPE_CHECKING:
+ from redis.asyncio import Redis
+
+
+class Lock:
+ """
+ A shared, distributed Lock. Using Redis for locking allows the Lock
+ to be shared across processes and/or machines.
+
+ It's left to the user to resolve deadlock issues and make sure
+ multiple clients play nicely together.
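+
+ A minimal usage sketch (assumes an already-created asyncio ``Redis``
+ client named ``client``; the lock is acquired and released by the
+ ``async with`` block)::
+
+ >>> lock = Lock(client, "my-lock", timeout=10)
+ >>> async with lock:
+ ... await client.set("foo", "bar")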
+ """ + + lua_release = None + lua_extend = None + lua_reacquire = None + + # KEYS[1] - lock name + # ARGV[1] - token + # return 1 if the lock was released, otherwise 0 + LUA_RELEASE_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('del', KEYS[1]) + return 1 + """ + + # KEYS[1] - lock name + # ARGV[1] - token + # ARGV[2] - additional milliseconds + # ARGV[3] - "0" if the additional time should be added to the lock's + # existing ttl or "1" if the existing ttl should be replaced + # return 1 if the locks time was extended, otherwise 0 + LUA_EXTEND_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + local expiration = redis.call('pttl', KEYS[1]) + if not expiration then + expiration = 0 + end + if expiration < 0 then + return 0 + end + + local newttl = ARGV[2] + if ARGV[3] == "0" then + newttl = ARGV[2] + expiration + end + redis.call('pexpire', KEYS[1], newttl) + return 1 + """ + + # KEYS[1] - lock name + # ARGV[1] - token + # ARGV[2] - milliseconds + # return 1 if the locks time was reacquired, otherwise 0 + LUA_REACQUIRE_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('pexpire', KEYS[1], ARGV[2]) + return 1 + """ + + def __init__( + self, + redis: "Redis", + name: Union[str, bytes, memoryview], + timeout: Optional[float] = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: Optional[float] = None, + thread_local: bool = True, + ): + """ + Create a new Lock instance named ``name`` using the Redis client + supplied by ``redis``. + + ``timeout`` indicates a maximum life for the lock in seconds. + By default, it will remain locked until release() is called. + ``timeout`` can be specified as a float or integer, both representing + the number of seconds to wait. + + ``sleep`` indicates the amount of time to sleep in seconds per loop + iteration when the lock is in blocking mode and another client is + currently holding the lock. + + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release the thread-2's lock. 
+ + In some use cases it's necessary to disable thread local storage. For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. Our assumption + is that these cases aren't common and as such default to using + thread local storage. + """ + self.redis = redis + self.name = name + self.timeout = timeout + self.sleep = sleep + self.blocking = blocking + self.blocking_timeout = blocking_timeout + self.thread_local = bool(thread_local) + self.local = threading.local() if self.thread_local else SimpleNamespace() + self.local.token = None + self.register_scripts() + + def register_scripts(self): + cls = self.__class__ + client = self.redis + if cls.lua_release is None: + cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT) + if cls.lua_extend is None: + cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT) + if cls.lua_reacquire is None: + cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT) + + async def __aenter__(self): + if await self.acquire(): + return self + raise LockError("Unable to acquire lock within the time specified") + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.release() + + async def acquire( + self, + blocking: Optional[bool] = None, + blocking_timeout: Optional[float] = None, + token: Optional[Union[str, bytes]] = None, + ): + """ + Use Redis to hold a shared, distributed lock named ``name``. + Returns True once the lock is acquired. + + If ``blocking`` is False, always return immediately. If the lock + was acquired, return True, otherwise return False. + + ``blocking_timeout`` specifies the maximum number of seconds to + wait trying to acquire the lock. + + ``token`` specifies the token value to be used. If provided, token + must be a bytes object or a string that can be encoded to a bytes + object with the default encoding. If a token isn't specified, a UUID + will be generated. + """ + loop = asyncio.get_running_loop() + sleep = self.sleep + if token is None: + token = uuid.uuid1().hex.encode() + else: + encoder = self.redis.connection_pool.get_encoder() + token = encoder.encode(token) + if blocking is None: + blocking = self.blocking + if blocking_timeout is None: + blocking_timeout = self.blocking_timeout + stop_trying_at = None + if blocking_timeout is not None: + stop_trying_at = loop.time() + blocking_timeout + while True: + if await self.do_acquire(token): + self.local.token = token + return True + if not blocking: + return False + next_try_at = loop.time() + sleep + if stop_trying_at is not None and next_try_at > stop_trying_at: + return False + await asyncio.sleep(sleep) + + async def do_acquire(self, token: Union[str, bytes]) -> bool: + if self.timeout: + # convert to milliseconds + timeout = int(self.timeout * 1000) + else: + timeout = None + if await self.redis.set(self.name, token, nx=True, px=timeout): + return True + return False + + async def locked(self) -> bool: + """ + Returns True if this key is locked by any process, otherwise False. + """ + return await self.redis.get(self.name) is not None + + async def owned(self) -> bool: + """ + Returns True if this key is locked by this lock, otherwise False. 
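+
+ For example (a sketch; assumes ``lock`` is an acquired Lock
+ instance)::
+
+ >>> await lock.owned()
+ True
+ >>> await Lock(lock.redis, lock.name).owned()
+ False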
+ """ + stored_token = await self.redis.get(self.name) + # need to always compare bytes to bytes + # TODO: this can be simplified when the context manager is finished + if stored_token and not isinstance(stored_token, bytes): + encoder = self.redis.connection_pool.get_encoder() + stored_token = encoder.encode(stored_token) + return self.local.token is not None and stored_token == self.local.token + + def release(self) -> Awaitable[NoReturn]: + """Releases the already acquired lock""" + expected_token = self.local.token + if expected_token is None: + raise LockError("Cannot release an unlocked lock") + self.local.token = None + return self.do_release(expected_token) + + async def do_release(self, expected_token: bytes): + if not bool( + await self.lua_release( + keys=[self.name], args=[expected_token], client=self.redis + ) + ): + raise LockNotOwnedError("Cannot release a lock" " that's no longer owned") + + def extend( + self, additional_time: float, replace_ttl: bool = False + ) -> Awaitable[bool]: + """ + Adds more time to an already acquired lock. + + ``additional_time`` can be specified as an integer or a float, both + representing the number of seconds to add. + + ``replace_ttl`` if False (the default), add `additional_time` to + the lock's existing ttl. If True, replace the lock's ttl with + `additional_time`. + """ + if self.local.token is None: + raise LockError("Cannot extend an unlocked lock") + if self.timeout is None: + raise LockError("Cannot extend a lock with no timeout") + return self.do_extend(additional_time, replace_ttl) + + async def do_extend(self, additional_time, replace_ttl) -> bool: + additional_time = int(additional_time * 1000) + if not bool( + await self.lua_extend( + keys=[self.name], + args=[self.local.token, additional_time, replace_ttl and "1" or "0"], + client=self.redis, + ) + ): + raise LockNotOwnedError("Cannot extend a lock that's" " no longer owned") + return True + + def reacquire(self) -> Awaitable[bool]: + """ + Resets a TTL of an already acquired lock back to a timeout value. + """ + if self.local.token is None: + raise LockError("Cannot reacquire an unlocked lock") + if self.timeout is None: + raise LockError("Cannot reacquire a lock with no timeout") + return self.do_reacquire() + + async def do_reacquire(self) -> bool: + timeout = int(self.timeout * 1000) + if not bool( + await self.lua_reacquire( + keys=[self.name], args=[self.local.token, timeout], client=self.redis + ) + ): + raise LockNotOwnedError("Cannot reacquire a lock that's" " no longer owned") + return True diff --git a/redis/asyncio/retry.py b/redis/asyncio/retry.py new file mode 100644 index 0000000000..a5cc9461fb --- /dev/null +++ b/redis/asyncio/retry.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from asyncio import sleep +from typing import TYPE_CHECKING, Awaitable, Callable, TypeVar + +from redis.exceptions import ConnectionError, RedisError, TimeoutError + +if TYPE_CHECKING: + from redis.backoff import AbstractBackoff + + +T = TypeVar("T") + + +class Retry: + """Retry a specific number of times after a failure""" + + __slots__ = "_backoff", "_retries", "_supported_errors" + + def __init__( + self, + backoff: AbstractBackoff, + retries: int, + supported_errors: type[tuple[RedisError, ...]] = ( + ConnectionError, + TimeoutError, + ), + ): + """ + Initialize a `Retry` object with a `Backoff` object + that retries a maximum of `retries` times. + You can specify the types of supported errors which trigger + a retry with the `supported_errors` parameter. 
+ """ + self._backoff = backoff + self._retries = retries + self._supported_errors = supported_errors + + async def call_with_retry( + self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], ...] + ) -> T: + """ + Execute an operation that might fail and returns its result, or + raise the exception that was thrown depending on the `Backoff` object. + `do`: the operation to call. Expects no argument. + `fail`: the failure handler, expects the last error that was thrown + """ + self._backoff.reset() + failures = 0 + while True: + try: + return await do() + except self._supported_errors as error: + failures += 1 + fail(error) + if failures > self._retries: + raise error + backoff = self._backoff.compute(failures) + if backoff > 0: + await sleep(backoff) diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py new file mode 100644 index 0000000000..824802754f --- /dev/null +++ b/redis/asyncio/sentinel.py @@ -0,0 +1,355 @@ +import asyncio +import random +import weakref +from typing import AsyncIterator, Iterable, Mapping, Sequence, Tuple, Type + +from redis.asyncio.client import Redis +from redis.asyncio.connection import ( + Connection, ConnectionPool, EncodableT, SSLConnection, +) +from redis.commands import SentinelCommands +from redis.exceptions import ( + ConnectionError, + ReadOnlyError, + ResponseError, + TimeoutError, +) +from redis.utils import str_if_bytes + + +class MasterNotFoundError(ConnectionError): + pass + + +class SlaveNotFoundError(ConnectionError): + pass + + +class SentinelManagedConnection(Connection): + def __init__(self, **kwargs): + self.connection_pool = kwargs.pop("connection_pool") + super().__init__(**kwargs) + + def __repr__(self): + pool = self.connection_pool + s = f"{self.__class__.__name__}" + + async def connect_to(self, address): + self.host, self.port = address + await super().connect() + if self.connection_pool.check_connection: + await self.send_command("PING") + if str_if_bytes(await self.read_response()) != "PONG": + raise ConnectionError("PING failed") + + async def connect(self): + if self._reader: + return # already connected + if self.connection_pool.is_master: + await self.connect_to(await self.connection_pool.get_master_address()) + else: + async for slave in self.connection_pool.rotate_slaves(): + try: + return await self.connect_to(slave) + except ConnectionError: + continue + raise SlaveNotFoundError # Never be here + + async def read_response(self, disable_decoding: bool = False): + try: + return await super().read_response(disable_decoding=disable_decoding) + except ReadOnlyError: + if self.connection_pool.is_master: + # When talking to a master, a ReadOnlyError when likely + # indicates that the previous master that we're still connected + # to has been demoted to a slave and there's a new master. + # calling disconnect will force the connection to re-query + # sentinel during the next connect() attempt. + await self.disconnect() + raise ConnectionError("The previous master is now a slave") + raise + + +class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection): + pass + + +class SentinelConnectionPool(ConnectionPool): + """ + Sentinel backed connection pool. + + If ``check_connection`` flag is set to True, SentinelManagedConnection + sends a PING command right after establishing the connection. 
+ """ + + def __init__(self, service_name, sentinel_manager, **kwargs): + kwargs["connection_class"] = kwargs.get( + "connection_class", + SentinelManagedSSLConnection + if kwargs.pop("ssl", False) + else SentinelManagedConnection, + ) + self.is_master = kwargs.pop("is_master", True) + self.check_connection = kwargs.pop("check_connection", False) + super().__init__(**kwargs) + self.connection_kwargs["connection_pool"] = weakref.proxy(self) + self.service_name = service_name + self.sentinel_manager = sentinel_manager + self.master_address = None + self.slave_rr_counter = None + + def __repr__(self): + return ( + f"{self.__class__.__name__}" + f"" + ) + + def reset(self): + super().reset() + self.master_address = None + self.slave_rr_counter = None + + def owns_connection(self, connection: Connection): + check = not self.is_master or ( + self.is_master and self.master_address == (connection.host, connection.port) + ) + return check and super().owns_connection(connection) + + async def get_master_address(self): + master_address = await self.sentinel_manager.discover_master(self.service_name) + if self.is_master: + if self.master_address != master_address: + self.master_address = master_address + # disconnect any idle connections so that they reconnect + # to the new master the next time that they are used. + await self.disconnect(inuse_connections=False) + return master_address + + async def rotate_slaves(self) -> AsyncIterator: + """Round-robin slave balancer""" + slaves = await self.sentinel_manager.discover_slaves(self.service_name) + if slaves: + if self.slave_rr_counter is None: + self.slave_rr_counter = random.randint(0, len(slaves) - 1) + for _ in range(len(slaves)): + self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves) + slave = slaves[self.slave_rr_counter] + yield slave + # Fallback to the master connection + try: + yield await self.get_master_address() + except MasterNotFoundError: + pass + raise SlaveNotFoundError(f"No slave found for {self.service_name!r}") + + +class Sentinel(SentinelCommands): + """ + Redis Sentinel cluster client + + >>> from redis.sentinel import Sentinel + >>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1) + >>> master = sentinel.master_for('mymaster', socket_timeout=0.1) + >>> await master.set('foo', 'bar') + >>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1) + >>> await slave.get('foo') + b'bar' + + ``sentinels`` is a list of sentinel nodes. Each node is represented by + a pair (hostname, port). + + ``min_other_sentinels`` defined a minimum number of peers for a sentinel. + When querying a sentinel, if it doesn't meet this threshold, responses + from that sentinel won't be considered valid. + + ``sentinel_kwargs`` is a dictionary of connection arguments used when + connecting to sentinel instances. Any argument that can be passed to + a normal Redis connection can be specified here. If ``sentinel_kwargs`` is + not specified, any socket_timeout and socket_keepalive options specified + in ``connection_kwargs`` will be used. + + ``connection_kwargs`` are keyword arguments that will be used when + establishing a connection to a Redis server. 
+ """ + + def __init__( + self, + sentinels, + min_other_sentinels=0, + sentinel_kwargs=None, + **connection_kwargs, + ): + # if sentinel_kwargs isn't defined, use the socket_* options from + # connection_kwargs + if sentinel_kwargs is None: + sentinel_kwargs = { + k: v for k, v in connection_kwargs.items() if k.startswith("socket_") + } + self.sentinel_kwargs = sentinel_kwargs + + self.sentinels = [ + Redis(host=hostname, port=port, **self.sentinel_kwargs) + for hostname, port in sentinels + ] + self.min_other_sentinels = min_other_sentinels + self.connection_kwargs = connection_kwargs + + async def execute_command(self, *args, **kwargs): + """ + Execute Sentinel command in sentinel nodes. + once - If set to True, then execute the resulting command on a single + node at random, rather than across the entire sentinel cluster. + """ + once = bool(kwargs.get("once", False)) + if "once" in kwargs.keys(): + kwargs.pop("once") + + if once: + tasks = [ + asyncio.create_task(sentinel.execute_command(*args, **kwargs)) + for sentinel in self.sentinels + ] + await asyncio.gather(*tasks) + else: + await random.choice(self.sentinels).execute_command(*args, **kwargs) + return True + + def __repr__(self): + sentinel_addresses = [] + for sentinel in self.sentinels: + sentinel_addresses.append( + f"{sentinel.connection_pool.connection_kwargs['host']}:" + f"{sentinel.connection_pool.connection_kwargs['port']}" + ) + return f"{self.__class__.__name__}" + + def check_master_state(self, state: dict, service_name: str) -> bool: + if not state["is_master"] or state["is_sdown"] or state["is_odown"]: + return False + # Check if our sentinel doesn't see other nodes + if state["num-other-sentinels"] < self.min_other_sentinels: + return False + return True + + async def discover_master(self, service_name: str): + """ + Asks sentinel servers for the Redis master's address corresponding + to the service labeled ``service_name``. + + Returns a pair (address, port) or raises MasterNotFoundError if no + master is found. + """ + for sentinel_no, sentinel in enumerate(self.sentinels): + try: + masters = await sentinel.sentinel_masters() + except (ConnectionError, TimeoutError): + continue + state = masters.get(service_name) + if state and self.check_master_state(state, service_name): + # Put this sentinel at the top of the list + self.sentinels[0], self.sentinels[sentinel_no] = ( + sentinel, + self.sentinels[0], + ) + return state["ip"], state["port"] + raise MasterNotFoundError(f"No master found for {service_name!r}") + + def filter_slaves( + self, slaves: Iterable[Mapping] + ) -> Sequence[Tuple[EncodableT, EncodableT]]: + """Remove slaves that are in an ODOWN or SDOWN state""" + slaves_alive = [] + for slave in slaves: + if slave["is_odown"] or slave["is_sdown"]: + continue + slaves_alive.append((slave["ip"], slave["port"])) + return slaves_alive + + async def discover_slaves( + self, service_name: str + ) -> Sequence[Tuple[EncodableT, EncodableT]]: + """Returns a list of alive slaves for service ``service_name``""" + for sentinel in self.sentinels: + try: + slaves = await sentinel.sentinel_slaves(service_name) + except (ConnectionError, ResponseError, TimeoutError): + continue + slaves = self.filter_slaves(slaves) + if slaves: + return slaves + return [] + + def master_for( + self, + service_name: str, + redis_class: Type[Redis] = Redis, + connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool, + **kwargs, + ): + """ + Returns a redis client instance for the ``service_name`` master. 
+
+ A :py:class:`~redis.sentinel.SentinelConnectionPool` class is
+ used to retrieve the master's address before establishing a new
+ connection.
+
+ NOTE: If the master's address has changed, any cached connections to
+ the old master are closed.
+
+ By default clients will be a :py:class:`~redis.Redis` instance.
+ Specify a different class to the ``redis_class`` argument if you
+ desire something different.
+
+ The ``connection_pool_class`` specifies the connection pool to
+ use. The :py:class:`~redis.sentinel.SentinelConnectionPool`
+ will be used by default.
+
+ All other keyword arguments are merged with any connection_kwargs
+ passed to this class and passed to the connection pool as keyword
+ arguments to be used to initialize Redis connections.
+ """
+ kwargs["is_master"] = True
+ connection_kwargs = dict(self.connection_kwargs)
+ connection_kwargs.update(kwargs)
+ return redis_class(
+ connection_pool=connection_pool_class(
+ service_name, self, **connection_kwargs
+ )
+ )
+
+ def slave_for(
+ self,
+ service_name: str,
+ redis_class: Type[Redis] = Redis,
+ connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool,
+ **kwargs,
+ ):
+ """
+ Returns a redis client instance for the ``service_name`` slave(s).
+
+ A SentinelConnectionPool class is used to retrieve the slave's
+ address before establishing a new connection.
+
+ By default clients will be a :py:class:`~redis.Redis` instance.
+ Specify a different class to the ``redis_class`` argument if you
+ desire something different.
+
+ The ``connection_pool_class`` specifies the connection pool to use.
+ The SentinelConnectionPool will be used by default.
+
+ All other keyword arguments are merged with any connection_kwargs
+ passed to this class and passed to the connection pool as keyword
+ arguments to be used to initialize Redis connections.
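+
+ A usage sketch (assumes the ``sentinel`` instance from the class
+ docstring above)::
+
+ >>> replica = sentinel.slave_for('mymaster', socket_timeout=0.1)
+ >>> await replica.get('foo')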
+ """ + kwargs["is_master"] = False + connection_kwargs = dict(self.connection_kwargs) + connection_kwargs.update(kwargs) + return redis_class( + connection_pool=connection_pool_class( + service_name, self, **connection_kwargs + ) + ) diff --git a/redis/compat.py b/redis/compat.py new file mode 100644 index 0000000000..738687f645 --- /dev/null +++ b/redis/compat.py @@ -0,0 +1,9 @@ +# flake8: noqa +try: + from typing import Literal, Protocol, TypedDict # lgtm [py/unused-import] +except ImportError: + from typing_extensions import ( # lgtm [py/unused-import] + Literal, + Protocol, + TypedDict, + ) diff --git a/redis/typing.py b/redis/typing.py new file mode 100644 index 0000000000..12372a3db6 --- /dev/null +++ b/redis/typing.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Iterable, TypeVar, Union + +from redis.compat import Protocol + +if TYPE_CHECKING: + from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool + from redis.connection import ConnectionPool + + +EncodedT = Union[bytes, memoryview] +DecodedT = Union[str, int, float] +EncodableT = Union[EncodedT, DecodedT] +AbsExpiryT = Union[int, datetime] +ExpiryT = Union[float, timedelta] +ZScoreBoundT = Union[float, str] # str allows for the [ or ( prefix +BitfieldOffsetT = Union[int, str] # str allows for #x syntax +_StringLikeT = Union[bytes, str, memoryview] +KeyT = _StringLikeT # Main redis key space +PatternT = _StringLikeT # Patterns matched against keys, fields etc +FieldT = EncodableT # Fields within hash tables, streams and geo commands +KeysT = Union[KeyT, Iterable[KeyT]] +ChannelT = _StringLikeT +GroupT = _StringLikeT # Consumer group +ConsumerT = _StringLikeT # Consumer name +StreamIdT = Union[int, _StringLikeT] +ScriptTextT = _StringLikeT +TimeoutSecT = Union[int, float, _StringLikeT] +# Mapping is not covariant in the key type, which prevents +# Mapping[_StringLikeT, X from accepting arguments of type Dict[str, X]. Using +# a TypeVar instead of a Union allows mappings with any of the permitted types +# to be passed. Care is needed if there is more than one such mapping in a +# type signature because they will all be required to be the same key type. +AnyKeyT = TypeVar("AnyKeyT", bytes, str, memoryview) +AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) +AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview) + + +class CommandsProtocol(Protocol): + connection_pool: Union[AsyncConnectionPool, ConnectionPool] + + def execute_command(self, *args, **options): + ... 
From 76cccc97eece235e8bd08dc4def5fc1b1db81aea Mon Sep 17 00:00:00 2001
From: Andrew-Chen-Wang
Date: Mon, 24 Jan 2022 19:12:25 -0500
Subject: [PATCH 02/24] Add asyncio test suite and remove Py3.6 support

* We need Python 3.7 for async-timeout, plus 3.6 is EOL
* test_commands not implemented yet
* Added uvloop support (hopefully)
---
 .github/workflows/integration.yaml | 4 +-
 .mypy.ini | 24 +
 README.md | 2 +-
 dev_requirements.txt | 2 +
 docs/index.rst | 2 +-
 redis/asyncio/__init__.py | 38 +
 redis/asyncio/client.py | 6 +-
 redis/asyncio/retry.py | 2 +-
 redis/asyncio/utils.py | 28 +
 redis/commands/__init__.py | 6 +-
 redis/commands/core.py | 1783 ++++++++++++++------
 redis/commands/sentinel.py | 6 +
 requirements.txt | 2 +
 setup.py | 5 +-
 tests/conftest.py | 79 +-
 tests/test_asyncio/__init__.py | 0
 tests/test_asyncio/compat.py | 6 +
 tests/test_asyncio/conftest.py | 200 +++
 tests/test_asyncio/test_commands.py | 3 +
 tests/test_asyncio/test_connection.py | 61 +
 tests/test_asyncio/test_connection_pool.py | 802 +++++++++
 tests/test_asyncio/test_encoding.py | 116 ++
 tests/test_asyncio/test_lock.py | 236 +++
 tests/test_asyncio/test_monitor.py | 69 +
 tests/test_asyncio/test_multiprocessing.py | 181 ++
 tests/test_asyncio/test_pipeline.py | 408 +++++
 tests/test_asyncio/test_pubsub.py | 626 +++++++
 tests/test_asyncio/test_retry.py | 68 +
 tests/test_asyncio/test_scripting.py | 163 ++
 tests/test_asyncio/test_sentinel.py | 243 +++
 tox.ini | 8 +-
 31 files changed, 4629 insertions(+), 550 deletions(-)
 create mode 100644 .mypy.ini
 create mode 100644 redis/asyncio/utils.py
 create mode 100644 tests/test_asyncio/__init__.py
 create mode 100644 tests/test_asyncio/compat.py
 create mode 100644 tests/test_asyncio/conftest.py
 create mode 100644 tests/test_asyncio/test_commands.py
 create mode 100644 tests/test_asyncio/test_connection.py
 create mode 100644 tests/test_asyncio/test_connection_pool.py
 create mode 100644 tests/test_asyncio/test_encoding.py
 create mode 100644 tests/test_asyncio/test_lock.py
 create mode 100644 tests/test_asyncio/test_monitor.py
 create mode 100644 tests/test_asyncio/test_multiprocessing.py
 create mode 100644 tests/test_asyncio/test_pipeline.py
 create mode 100644 tests/test_asyncio/test_pubsub.py
 create mode 100644 tests/test_asyncio/test_retry.py
 create mode 100644 tests/test_asyncio/test_scripting.py
 create mode 100644 tests/test_asyncio/test_sentinel.py
diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml
index b034428bcd..4b8b5fa73e 100644
--- a/.github/workflows/integration.yaml
+++ b/.github/workflows/integration.yaml
@@ -34,7 +34,7 @@ jobs:
 strategy:
 max-parallel: 15
 matrix:
- python-version: ['3.6', '3.7', '3.8', '3.9', '3.10', 'pypy-3.7']
+ python-version: ['3.7', '3.8', '3.9', '3.10', 'pypy-3.7']
 test-type: ['standalone', 'cluster']
 connection-type: ['hiredis', 'plain']
 env:
@@ -78,7 +78,7 @@ jobs:
 runs-on: ubuntu-latest
 strategy:
 matrix:
- python-version: ['3.6', '3.7', '3.8', '3.9', '3.10', 'pypy-3.7']
+ python-version: ['3.7', '3.8', '3.9', '3.10', 'pypy-3.7']
 steps:
 - uses: actions/checkout@v2
 - name: install python ${{ matrix.python-version }}
diff --git a/.mypy.ini b/.mypy.ini
new file mode 100644
index 0000000000..942574e0f3
--- /dev/null
+++ b/.mypy.ini
@@ -0,0 +1,24 @@
+[mypy]
+#, docs/examples, tests
+files = redis
+check_untyped_defs = True
+follow_imports_for_stubs = True
+#disallow_any_decorated = True
+disallow_subclassing_any = True
+#disallow_untyped_calls = True
+disallow_untyped_decorators = True
+#disallow_untyped_defs = True
+implicit_reexport = False +no_implicit_optional = True +show_error_codes = True +strict_equality = True +warn_incomplete_stub = True +warn_redundant_casts = True +warn_unreachable = True +warn_unused_ignores = True +disallow_any_unimported = True +#warn_return_any = True + +[mypy-redis.asyncio.lock] +# TODO: Remove once locks has been rewritten +ignore_errors = True diff --git a/README.md b/README.md index 166e80c23b..820fede0f8 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ contributing](https://github.com/redis/redis-py/blob/master/CONTRIBUTING.md). ## Getting Started -redis-py supports Python 3.6+. +redis-py supports Python 3.7+. ``` pycon >>> import redis diff --git a/dev_requirements.txt b/dev_requirements.txt index 1d33b9875b..637d93aaae 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -4,6 +4,7 @@ flynt~=0.69.0 isort==5.10.1 pytest==6.2.5 pytest-timeout==2.0.1 +pytest-asyncio==0.17.2 tox==3.24.4 tox-docker==3.1.0 tox-run-before==0.1 @@ -11,4 +12,5 @@ invoke==1.6.0 pytest-cov>=3.0.0 vulture>=2.3.0 ujson>=4.2.0 +uvloop>=0.16.0 wheel>=0.30.0 diff --git a/docs/index.rst b/docs/index.rst index 51b38a2bf7..e4ddcf4cee 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,7 +9,7 @@ Welcome to redis-py's documentation! Getting Started **************** -`redis-py `_ requires a running Redis server, and Python 3.6+. See the `Redis +`redis-py `_ requires a running Redis server, and Python 3.7+. See the `Redis quickstart `_ for Redis installation instructions. redis-py can be installed using pip via ``pip install redis``. diff --git a/redis/asyncio/__init__.py b/redis/asyncio/__init__.py index b762b70642..3959b9acee 100644 --- a/redis/asyncio/__init__.py +++ b/redis/asyncio/__init__.py @@ -7,15 +7,53 @@ UnixDomainSocketConnection, ) from redis.asyncio.utils import from_url +from redis.asyncio.sentinel import ( + Sentinel, + SentinelConnectionPool, + SentinelManagedConnection, + SentinelManagedSSLConnection, +) +from redis.exceptions import ( + AuthenticationError, + AuthenticationWrongNumberOfArgsError, + BusyLoadingError, + ChildDeadlockedError, + ConnectionError, + DataError, + InvalidResponse, + PubSubError, + ReadOnlyError, + RedisError, + ResponseError, + TimeoutError, + WatchError, +) __all__ = [ + "AuthenticationError", + "AuthenticationWrongNumberOfArgsError", "BlockingConnectionPool", + "BusyLoadingError", + "ChildDeadlockedError", "Connection", + "ConnectionError", "ConnectionPool", + "DataError", "from_url", + "InvalidResponse", + "PubSubError", + "ReadOnlyError", "Redis", + "RedisError", + "ResponseError", + "Sentinel", + "SentinelConnectionPool", + "SentinelManagedConnection", + "SentinelManagedSSLConnection", "SSLConnection", "StrictRedis", + "TimeoutError", "UnixDomainSocketConnection", + "WatchError", ] diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 77fd3321c8..d69279c4b7 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -32,9 +32,9 @@ UnixDomainSocketConnection, ) from redis.commands import ( - CoreCommands, + AsyncCoreCommands, RedisModuleCommands, - SentinelCommands, + AsyncSentinelCommands, list_or_args, ) from redis.compat import Protocol, TypedDict @@ -693,7 +693,7 @@ async def __call__(self, response: Any, **kwargs): _R = TypeVar("_R") -class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): +class Redis(RedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands): """ Implementation of the Redis protocol. 
diff --git a/redis/asyncio/retry.py b/redis/asyncio/retry.py index a5cc9461fb..d98a5fec87 100644 --- a/redis/asyncio/retry.py +++ b/redis/asyncio/retry.py @@ -52,7 +52,7 @@ async def call_with_retry( return await do() except self._supported_errors as error: failures += 1 - fail(error) + await fail(error) if failures > self._retries: raise error backoff = self._backoff.compute(failures) diff --git a/redis/asyncio/utils.py b/redis/asyncio/utils.py new file mode 100644 index 0000000000..2090e893fa --- /dev/null +++ b/redis/asyncio/utils.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from redis.asyncio.client import Redis, Pipeline + + +def from_url(url, **kwargs): + """ + Returns an active Redis client generated from the given database URL. + + Will attempt to extract the database id from the path url fragment, if + none is provided. + """ + from redis.asyncio.client import Redis + + return Redis.from_url(url, **kwargs) + + +class pipeline: + def __init__(self, redis_obj: "Redis"): + self.p: "Pipeline" = redis_obj.pipeline() + + async def __aenter__(self) -> "Pipeline": + return self.p + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.p.execute() + del self.p diff --git a/redis/commands/__init__.py b/redis/commands/__init__.py index 07fa7f1431..b9dd0b7210 100644 --- a/redis/commands/__init__.py +++ b/redis/commands/__init__.py @@ -1,15 +1,17 @@ from .cluster import RedisClusterCommands -from .core import CoreCommands +from .core import AsyncCoreCommands, CoreCommands from .helpers import list_or_args from .parser import CommandsParser from .redismodules import RedisModuleCommands -from .sentinel import SentinelCommands +from .sentinel import AsyncSentinelCommands, SentinelCommands __all__ = [ "RedisClusterCommands", "CommandsParser", + "AsyncCoreCommands", "CoreCommands", "list_or_args", "RedisModuleCommands", + "AsyncSentinelCommands", "SentinelCommands", ] diff --git a/redis/commands/core.py b/redis/commands/core.py index 73003e7fb6..e7e08d1a66 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -1,20 +1,60 @@ +from __future__ import annotations + import datetime import hashlib import time import warnings - +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Awaitable, + Callable, + Iterable, + Mapping, + Sequence, + Union, Iterator, +) + +from redis.compat import Literal from redis.exceptions import ConnectionError, DataError, NoScriptError, RedisError +from redis.typing import ( + AbsExpiryT, + AnyFieldT, + AnyKeyT, + BitfieldOffsetT, + ChannelT, + CommandsProtocol, + ConsumerT, + EncodableT, + ExpiryT, + FieldT, + GroupT, + KeysT, + KeyT, + PatternT, + ScriptTextT, + StreamIdT, + TimeoutSecT, + ZScoreBoundT, +) from .helpers import list_or_args +if TYPE_CHECKING: + from redis.asyncio.client import Redis as AsyncRedis + from redis.client import Redis + +ResponseT = Union[Awaitable, Any] -class ACLCommands: + +class ACLCommands(CommandsProtocol): """ Redis Access Control List (ACL) commands. see: https://redis.io/topics/acl """ - def acl_cat(self, category=None, **kwargs): + def acl_cat(self, category: str | None = None, **kwargs) -> ResponseT: """ Returns a list of categories or commands within a category. 
@@ -24,10 +64,10 @@ def acl_cat(self, category=None, **kwargs): For more information check https://redis.io/commands/acl-cat """ - pieces = [category] if category else [] + pieces: list[EncodableT] = [category] if category else [] return self.execute_command("ACL CAT", *pieces, **kwargs) - def acl_deluser(self, *username, **kwargs): + def acl_deluser(self, *username: str, **kwargs) -> ResponseT: """ Delete the ACL for the specified ``username``s @@ -35,7 +75,7 @@ def acl_deluser(self, *username, **kwargs): """ return self.execute_command("ACL DELUSER", *username, **kwargs) - def acl_genpass(self, bits=None, **kwargs): + def acl_genpass(self, bits: int | None = None, **kwargs) -> ResponseT: """Generate a random password value. If ``bits`` is supplied then use this number of bits, rounded to the next multiple of 4. @@ -53,7 +93,7 @@ def acl_genpass(self, bits=None, **kwargs): ) return self.execute_command("ACL GENPASS", *pieces, **kwargs) - def acl_getuser(self, username, **kwargs): + def acl_getuser(self, username: str, **kwargs) -> ResponseT: """ Get the ACL details for the specified ``username``. @@ -63,7 +103,7 @@ def acl_getuser(self, username, **kwargs): """ return self.execute_command("ACL GETUSER", username, **kwargs) - def acl_help(self, **kwargs): + def acl_help(self, **kwargs) -> ResponseT: """The ACL HELP command returns helpful text describing the different subcommands. @@ -71,7 +111,7 @@ def acl_help(self, **kwargs): """ return self.execute_command("ACL HELP", **kwargs) - def acl_list(self, **kwargs): + def acl_list(self, **kwargs) -> ResponseT: """ Return a list of all ACLs on the server @@ -79,7 +119,7 @@ def acl_list(self, **kwargs): """ return self.execute_command("ACL LIST", **kwargs) - def acl_log(self, count=None, **kwargs): + def acl_log(self, count: int | None = None, **kwargs) -> ResponseT: """ Get ACL logs as a list. :param int count: Get logs[0:count]. @@ -95,7 +135,7 @@ def acl_log(self, count=None, **kwargs): return self.execute_command("ACL LOG", *args, **kwargs) - def acl_log_reset(self, **kwargs): + def acl_log_reset(self, **kwargs) -> ResponseT: """ Reset ACL logs. :rtype: Boolean. @@ -105,7 +145,7 @@ def acl_log_reset(self, **kwargs): args = [b"RESET"] return self.execute_command("ACL LOG", *args, **kwargs) - def acl_load(self, **kwargs): + def acl_load(self, **kwargs) -> ResponseT: """ Load ACL rules from the configured ``aclfile``. @@ -116,7 +156,7 @@ def acl_load(self, **kwargs): """ return self.execute_command("ACL LOAD", **kwargs) - def acl_save(self, **kwargs): + def acl_save(self, **kwargs) -> ResponseT: """ Save ACL rules to the configured ``aclfile``. @@ -129,19 +169,19 @@ def acl_save(self, **kwargs): def acl_setuser( self, - username, - enabled=False, - nopass=False, - passwords=None, - hashed_passwords=None, - categories=None, - commands=None, - keys=None, - reset=False, - reset_keys=False, - reset_passwords=False, + username: str, + enabled: bool = False, + nopass: bool = False, + passwords: str | Iterable[str] | None = None, + hashed_passwords: str | Iterable[str] | None = None, + categories: Iterable[str] | None = None, + commands: Iterable[str] | None = None, + keys: Iterable[KeyT] | None = None, + reset: bool = False, + reset_keys: bool = False, + reset_passwords: bool = False, **kwargs, - ): + ) -> ResponseT: """ Create or update an ACL user. 
@@ -204,7 +244,7 @@ def acl_setuser( For more information check https://redis.io/commands/acl-setuser """ encoder = self.get_encoder() - pieces = [username] + pieces: list[str | bytes] = [username] if reset: pieces.append(b"reset") @@ -294,14 +334,14 @@ def acl_setuser( return self.execute_command("ACL SETUSER", *pieces, **kwargs) - def acl_users(self, **kwargs): + def acl_users(self, **kwargs) -> ResponseT: """Returns a list of all registered users on the server. For more information check https://redis.io/commands/acl-users """ return self.execute_command("ACL USERS", **kwargs) - def acl_whoami(self, **kwargs): + def acl_whoami(self, **kwargs) -> ResponseT: """Get the username for the current connection For more information check https://redis.io/commands/acl-whoami @@ -309,19 +349,22 @@ def acl_whoami(self, **kwargs): return self.execute_command("ACL WHOAMI", **kwargs) -class ManagementCommands: +AsyncACLCommands = ACLCommands + + +class ManagementCommands(CommandsProtocol): """ Redis management commands """ - def bgrewriteaof(self, **kwargs): + def bgrewriteaof(self, **kwargs) -> ResponseT: """Tell the Redis server to rewrite the AOF file from data in memory. For more information check https://redis.io/commands/bgrewriteaof """ return self.execute_command("BGREWRITEAOF", **kwargs) - def bgsave(self, schedule=True, **kwargs): + def bgsave(self, schedule: bool = True, **kwargs) -> ResponseT: """ Tell the Redis server to save its data to disk. Unlike save(), this method is asynchronous and returns immediately. @@ -333,7 +376,7 @@ def bgsave(self, schedule=True, **kwargs): pieces.append("SCHEDULE") return self.execute_command("BGSAVE", *pieces, **kwargs) - def role(self): + def role(self) -> ResponseT: """ Provide information on the role of a Redis instance in the context of replication, by returning if the instance @@ -343,7 +386,7 @@ def role(self): """ return self.execute_command("ROLE") - def client_kill(self, address, **kwargs): + def client_kill(self, address: str, **kwargs) -> ResponseT: """Disconnects the client at ``address`` (ip:port) For more information check https://redis.io/commands/client-kill @@ -352,18 +395,18 @@ def client_kill(self, address, **kwargs): def client_kill_filter( self, - _id=None, - _type=None, - addr=None, - skipme=None, - laddr=None, - user=None, + _id: str | None = None, + _type: str | None = None, + addr: str | None = None, + skipme: bool | None = None, + laddr: bool | None = None, + user: str = None, **kwargs, - ): + ) -> ResponseT: """ Disconnects client(s) using a variety of filter options - :param id: Kills a client by its unique ID field - :param type: Kills a client by type where type is one of 'normal', + :param _id: Kills a client by its unique ID field + :param _type: Kills a client by type where type is one of 'normal', 'master', 'slave' or 'pubsub' :param addr: Kills a client by its 'address:port' :param skipme: If True, then the client calling the command @@ -400,7 +443,7 @@ def client_kill_filter( ) return self.execute_command("CLIENT KILL", *args, **kwargs) - def client_info(self, **kwargs): + def client_info(self, **kwargs) -> ResponseT: """ Returns information and statistics about the current client connection. 
@@ -409,7 +452,12 @@ def client_info(self, **kwargs): """ return self.execute_command("CLIENT INFO", **kwargs) - def client_list(self, _type=None, client_id=[], **kwargs): + def client_list( + self, + _type: str | None = None, + client_id: list[EncodableT] = [], + **kwargs, + ) -> ResponseT: """ Returns a list of currently connected clients. If type of client specified, only that type will be returned. @@ -428,12 +476,12 @@ def client_list(self, _type=None, client_id=[], **kwargs): args.append(_type) if not isinstance(client_id, list): raise DataError("client_id must be a list") - if client_id != []: + if client_id: args.append(b"ID") args.append(" ".join(client_id)) return self.execute_command("CLIENT LIST", *args, **kwargs) - def client_getname(self, **kwargs): + def client_getname(self, **kwargs) -> ResponseT: """ Returns the current connection name @@ -441,7 +489,7 @@ def client_getname(self, **kwargs): """ return self.execute_command("CLIENT GETNAME", **kwargs) - def client_getredir(self, **kwargs): + def client_getredir(self, **kwargs) -> ResponseT: """ Returns the ID (an integer) of the client to whom we are redirecting tracking notifications. @@ -450,7 +498,11 @@ def client_getredir(self, **kwargs): """ return self.execute_command("CLIENT GETREDIR", **kwargs) - def client_reply(self, reply, **kwargs): + def client_reply( + self, + reply: Literal["ON"] | Literal["OFF"] | Literal["SKIP"], + **kwargs, + ) -> ResponseT: """ Enable and disable redis server replies. ``reply`` Must be ON OFF or SKIP, @@ -471,7 +523,7 @@ def client_reply(self, reply, **kwargs): raise DataError(f"CLIENT REPLY must be one of {replies!r}") return self.execute_command("CLIENT REPLY", reply, **kwargs) - def client_id(self, **kwargs): + def client_id(self, **kwargs) -> ResponseT: """ Returns the current connection id @@ -481,13 +533,13 @@ def client_id(self, **kwargs): def client_tracking_on( self, - clientid=None, - prefix=[], - bcast=False, - optin=False, - optout=False, - noloop=False, - ): + clientid: int | None = None, + prefix: Sequence[KeyT] = [], + bcast: bool = False, + optin: bool = False, + optout: bool = False, + noloop: bool = False, + ) -> ResponseT: """ Turn on the tracking mode. For more information about the options look at client_tracking func. @@ -500,13 +552,13 @@ def client_tracking_on( def client_tracking_off( self, - clientid=None, - prefix=[], - bcast=False, - optin=False, - optout=False, - noloop=False, - ): + clientid: int | None = None, + prefix: Sequence[KeyT] = [], + bcast: bool = False, + optin: bool = False, + optout: bool = False, + noloop: bool = False, + ) -> ResponseT: """ Turn off the tracking mode. For more information about the options look at client_tracking func. @@ -519,15 +571,15 @@ def client_tracking_off( def client_tracking( self, - on=True, - clientid=None, - prefix=[], - bcast=False, - optin=False, - optout=False, - noloop=False, + on: bool = True, + clientid: int | None = None, + prefix: Sequence[KeyT] = [], + bcast: bool = False, + optin: bool = False, + optout: bool = False, + noloop: bool = False, **kwargs, - ): + ) -> ResponseT: """ Enables the tracking feature of the Redis server, that is used for server assisted client side caching. @@ -577,7 +629,7 @@ def client_tracking( return self.execute_command("CLIENT TRACKING", *pieces) - def client_trackinginfo(self, **kwargs): + def client_trackinginfo(self, **kwargs) -> ResponseT: """ Returns the information about the current client connection's use of the server assisted client side cache. 
@@ -586,7 +638,7 @@
 """
 return self.execute_command("CLIENT TRACKINGINFO", **kwargs)
 
- def client_setname(self, name, **kwargs):
+ def client_setname(self, name: str, **kwargs) -> ResponseT:
 """
 Sets the current connection name
 
@@ -594,7 +646,12 @@
 """
 return self.execute_command("CLIENT SETNAME", name, **kwargs)
 
- def client_unblock(self, client_id, error=False, **kwargs):
+ def client_unblock(
+ self,
+ client_id: int,
+ error: bool = False,
+ **kwargs,
+ ) -> ResponseT:
 """
 Unblocks a connection by its client id.
 If ``error`` is True, unblocks the client with a special error message.
@@ -608,7 +665,7 @@
 args.append(b"ERROR")
 return self.execute_command(*args, **kwargs)
 
- def client_pause(self, timeout, all=True, **kwargs):
+ def client_pause(self, timeout: int, all: bool = True, **kwargs) -> ResponseT:
 """
 Suspend all the Redis clients for the specified amount of time
 :param timeout: milliseconds to pause clients
@@ -631,7 +688,7 @@
 args.append("WRITE")
 return self.execute_command(*args, **kwargs)
 
- def client_unpause(self, **kwargs):
+ def client_unpause(self, **kwargs) -> ResponseT:
 """
 Unpause all redis clients
 
@@ -639,7 +696,7 @@
 """
 return self.execute_command("CLIENT UNPAUSE", **kwargs)
 
- def command(self, **kwargs):
+ def command(self, **kwargs) -> ResponseT:
 """
 Returns dict reply of details about all Redis commands.
 
@@ -647,15 +704,15 @@
 """
 return self.execute_command("COMMAND", **kwargs)
 
- def command_info(self, **kwargs):
+ def command_info(self, **kwargs) -> None:
 raise NotImplementedError(
 "COMMAND INFO is intentionally not implemented in the client."
 )
 
- def command_count(self, **kwargs):
+ def command_count(self, **kwargs) -> ResponseT:
 return self.execute_command("COMMAND COUNT", **kwargs)
 
- def config_get(self, pattern="*", **kwargs):
+ def config_get(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
 """
 Return a dictionary of configuration based on the ``pattern``
 
@@ -663,14 +720,14 @@
 """
 return self.execute_command("CONFIG GET", pattern, **kwargs)
 
- def config_set(self, name, value, **kwargs):
+ def config_set(self, name: KeyT, value: EncodableT, **kwargs) -> ResponseT:
 """Set config item ``name`` with ``value``
 
 For more information check https://redis.io/commands/config-set
 """
 return self.execute_command("CONFIG SET", name, value, **kwargs)
 
- def config_resetstat(self, **kwargs):
+ def config_resetstat(self, **kwargs) -> ResponseT:
 """
 Reset runtime statistics
 
@@ -678,7 +735,7 @@
 """
 return self.execute_command("CONFIG RESETSTAT", **kwargs)
 
- def config_rewrite(self, **kwargs):
+ def config_rewrite(self, **kwargs) -> ResponseT:
 """
 Rewrite config file with the minimal change to reflect running config.
@@ -686,7 +743,7 @@ def config_rewrite(self, **kwargs): """ return self.execute_command("CONFIG REWRITE", **kwargs) - def dbsize(self, **kwargs): + def dbsize(self, **kwargs) -> ResponseT: """ Returns the number of keys in the current database @@ -694,7 +751,7 @@ def dbsize(self, **kwargs): """ return self.execute_command("DBSIZE", **kwargs) - def debug_object(self, key, **kwargs): + def debug_object(self, key: KeyT, **kwargs) -> ResponseT: """ Returns version specific meta information about a given key @@ -702,7 +759,7 @@ def debug_object(self, key, **kwargs): """ return self.execute_command("DEBUG OBJECT", key, **kwargs) - def debug_segfault(self, **kwargs): + def debug_segfault(self, **kwargs) -> None: raise NotImplementedError( """ DEBUG SEGFAULT is intentionally not implemented in the client. @@ -711,7 +768,7 @@ def debug_segfault(self, **kwargs): """ ) - def echo(self, value, **kwargs): + def echo(self, value: EncodableT, **kwargs) -> ResponseT: """ Echo the string back from the server @@ -719,7 +776,7 @@ def echo(self, value, **kwargs): """ return self.execute_command("ECHO", value, **kwargs) - def flushall(self, asynchronous=False, **kwargs): + def flushall(self, asynchronous: bool = False, **kwargs) -> ResponseT: """ Delete all keys in all databases on the current host. @@ -733,7 +790,7 @@ def flushall(self, asynchronous=False, **kwargs): args.append(b"ASYNC") return self.execute_command("FLUSHALL", *args, **kwargs) - def flushdb(self, asynchronous=False, **kwargs): + def flushdb(self, asynchronous: bool = False, **kwargs) -> ResponseT: """ Delete all keys in the current database. @@ -747,7 +804,7 @@ def flushdb(self, asynchronous=False, **kwargs): args.append(b"ASYNC") return self.execute_command("FLUSHDB", *args, **kwargs) - def sync(self): + def sync(self) -> ResponseT: """ Initiates a replication stream from the master. @@ -759,7 +816,7 @@ def sync(self): options[NEVER_DECODE] = [] return self.execute_command("SYNC", **options) - def psync(self, replicationid, offset): + def psync(self, replicationid: str, offset: int): """ Initiates a replication stream from the master. Newer version for `sync`. @@ -772,7 +829,7 @@ def psync(self, replicationid, offset): options[NEVER_DECODE] = [] return self.execute_command("PSYNC", replicationid, offset, **options) - def swapdb(self, first, second, **kwargs): + def swapdb(self, first: int, second: int, **kwargs) -> ResponseT: """ Swap two databases @@ -780,14 +837,14 @@ def swapdb(self, first, second, **kwargs): """ return self.execute_command("SWAPDB", first, second, **kwargs) - def select(self, index, **kwargs): + def select(self, index: int, **kwargs) -> ResponseT: """Select the Redis logical database at index. 
 See: https://redis.io/commands/select
 """
 return self.execute_command("SELECT", index, **kwargs)
 
- def info(self, section=None, **kwargs):
+ def info(self, section: str | None = None, **kwargs) -> ResponseT:
 """
 Returns a dictionary containing information about the Redis server
 
@@ -804,7 +861,7 @@
 else:
 return self.execute_command("INFO", section, **kwargs)
 
- def lastsave(self, **kwargs):
+ def lastsave(self, **kwargs) -> ResponseT:
 """
 Return a Python datetime object representing the last time the
 Redis database was saved to disk
 
@@ -813,7 +870,7 @@
 """
 return self.execute_command("LASTSAVE", **kwargs)
 
- def lolwut(self, *version_numbers, **kwargs):
+ def lolwut(self, *version_numbers: str | float, **kwargs) -> ResponseT:
 """
 Get the Redis version and a piece of generative computer art
 
@@ -824,7 +881,7 @@
 else:
 return self.execute_command("LOLWUT", **kwargs)
 
- def reset(self):
+ def reset(self) -> ResponseT:
 """Perform a full reset on the connection's server side context.
 
 See: https://redis.io/commands/reset
@@ -833,16 +890,16 @@
 
 def migrate(
 self,
- host,
- port,
- keys,
- destination_db,
- timeout,
- copy=False,
- replace=False,
- auth=None,
+ host: str,
+ port: int,
+ keys: KeysT,
+ destination_db: int,
+ timeout: int,
+ copy: bool = False,
+ replace: bool = False,
+ auth: str | None = None,
 **kwargs,
- ):
+ ) -> ResponseT:
 """
 Migrate 1 or more keys from the current Redis server to a different
 server specified by the ``host``, ``port`` and ``destination_db``.
@@ -879,7 +936,7 @@
 "MIGRATE", host, port, "", destination_db, timeout, *pieces, **kwargs
 )
 
- def object(self, infotype, key, **kwargs):
+ def object(self, infotype: str, key: KeyT, **kwargs) -> ResponseT:
 """
 Return the encoding, idletime, or refcount about the key
 """
@@ -887,7 +944,7 @@
 "OBJECT", infotype, key, infotype=infotype, **kwargs
 )
 
- def memory_doctor(self, **kwargs):
+ def memory_doctor(self, **kwargs) -> None:
 raise NotImplementedError(
 """
 MEMORY DOCTOR is intentionally not implemented in the client.
@@ -896,7 +953,7 @@
 """
 )
 
- def memory_help(self, **kwargs):
+ def memory_help(self, **kwargs) -> None:
 raise NotImplementedError(
 """
 MEMORY HELP is intentionally not implemented in the client.
@@ -905,7 +962,7 @@
 """
 )
 
- def memory_stats(self, **kwargs):
+ def memory_stats(self, **kwargs) -> ResponseT:
 """
 Return a dictionary of memory stats
 
@@ -913,7 +970,7 @@
 """
 return self.execute_command("MEMORY STATS", **kwargs)
 
- def memory_malloc_stats(self, **kwargs):
+ def memory_malloc_stats(self, **kwargs) -> ResponseT:
 """
 Return an internal statistics report from the memory allocator.
 
@@ -921,7 +978,7 @@
 """
 return self.execute_command("MEMORY MALLOC-STATS", **kwargs)
 
- def memory_usage(self, key, samples=None, **kwargs):
+ def memory_usage(self, key: KeyT, samples: int | None = None, **kwargs) -> ResponseT:
 """
 Return the total memory usage for key, its value and associated
 administrative overheads.
@@ -937,7 +994,7 @@ def memory_usage(self, key, samples=None, **kwargs): args.extend([b"SAMPLES", samples]) return self.execute_command("MEMORY USAGE", key, *args, **kwargs) - def memory_purge(self, **kwargs): + def memory_purge(self, **kwargs) -> ResponseT: """ Attempts to purge dirty pages for reclamation by allocator @@ -945,7 +1002,7 @@ def memory_purge(self, **kwargs): """ return self.execute_command("MEMORY PURGE", **kwargs) - def ping(self, **kwargs): + def ping(self, **kwargs) -> ResponseT: """ Ping the Redis server @@ -953,7 +1010,7 @@ def ping(self, **kwargs): """ return self.execute_command("PING", **kwargs) - def quit(self, **kwargs): + def quit(self, **kwargs) -> ResponseT: """ Ask the server to close the connection. @@ -961,7 +1018,7 @@ def quit(self, **kwargs): """ return self.execute_command("QUIT", **kwargs) - def replicaof(self, *args, **kwargs): + def replicaof(self, *args, **kwargs) -> ResponseT: """ Update the replication settings of a redis replica, on the fly. Examples of valid arguments include: @@ -972,7 +1029,7 @@ def replicaof(self, *args, **kwargs): """ return self.execute_command("REPLICAOF", *args, **kwargs) - def save(self, **kwargs): + def save(self, **kwargs) -> ResponseT: """ Tell the Redis server to save its data to disk, blocking until the save is complete @@ -981,7 +1038,7 @@ def save(self, **kwargs): """ return self.execute_command("SAVE", **kwargs) - def shutdown(self, save=False, nosave=False, **kwargs): + def shutdown(self, save: bool = False, nosave: bool = False, **kwargs) -> None: """Shutdown the Redis server. If Redis has persistence configured, data will be flushed before shutdown. If the "save" option is set, a data flush will be attempted even if there is no persistence @@ -1004,7 +1061,7 @@ def shutdown(self, save=False, nosave=False, **kwargs): return raise RedisError("SHUTDOWN seems to have failed.") - def slaveof(self, host=None, port=None, **kwargs): + def slaveof(self, host: str | None = None, port: int | None = None, **kwargs) -> ResponseT: """ Set the server to be a replicated slave of the instance identified by the ``host`` and ``port``. If called without arguments, the @@ -1016,7 +1073,7 @@ def slaveof(self, host=None, port=None, **kwargs): return self.execute_command("SLAVEOF", b"NO", b"ONE", **kwargs) return self.execute_command("SLAVEOF", host, port, **kwargs) - def slowlog_get(self, num=None, **kwargs): + def slowlog_get(self, num: int | None = None, **kwargs) -> ResponseT: """ Get the entries from the slowlog. If ``num`` is specified, get the most recent ``num`` items. @@ -1033,7 +1090,7 @@ def slowlog_get(self, num=None, **kwargs): kwargs[NEVER_DECODE] = [] return self.execute_command(*args, **kwargs) - def slowlog_len(self, **kwargs): + def slowlog_len(self, **kwargs) -> ResponseT: """ Get the number of items in the slowlog @@ -1041,7 +1098,7 @@ def slowlog_len(self, **kwargs): """ return self.execute_command("SLOWLOG LEN", **kwargs) - def slowlog_reset(self, **kwargs): + def slowlog_reset(self, **kwargs) -> ResponseT: """ Remove all items in the slowlog @@ -1049,7 +1106,7 @@ def slowlog_reset(self, **kwargs): """ return self.execute_command("SLOWLOG RESET", **kwargs) - def time(self, **kwargs): + def time(self, **kwargs) -> ResponseT: """ Returns the server time as a 2-item tuple of ints: (seconds since epoch, microseconds into this second). 
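
slaveof toggles replication in both directions; with no arguments it promotes the node. Illustrative calls with the sync client ``r`` from the sketch above (the host is made up):

    r.slaveof("10.0.0.5", 6379)   # replicate from 10.0.0.5:6379
    r.slaveof()                   # SLAVEOF NO ONE: promote back to primary
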
@@ -1058,7 +1115,7 @@ def time(self, **kwargs): """ return self.execute_command("TIME", **kwargs) - def wait(self, num_replicas, timeout, **kwargs): + def wait(self, num_replicas: int, timeout: int, **kwargs) -> ResponseT: """ Redis synchronous replication That returns the number of replicas that processed the query when @@ -1070,12 +1127,52 @@ def wait(self, num_replicas, timeout, **kwargs): return self.execute_command("WAIT", num_replicas, timeout, **kwargs) -class BasicKeyCommands: +AsyncManagementCommands = ManagementCommands + + +class AsyncManagementCommands(ManagementCommands): + async def command_info(self, **kwargs) -> None: + return super().command_info(**kwargs) + + async def debug_segfault(self, **kwargs) -> None: + return super().debug_segfault(**kwargs) + + async def memory_doctor(self, **kwargs) -> None: + return super().memory_doctor(**kwargs) + + async def memory_help(self, **kwargs) -> None: + return super().memory_help(**kwargs) + + async def shutdown(self, save: bool = False, nosave: bool = False, **kwargs) -> None: + """Shutdown the Redis server. If Redis has persistence configured, + data will be flushed before shutdown. If the "save" option is set, + a data flush will be attempted even if there is no persistence + configured. If the "nosave" option is set, no data flush will be + attempted. The "save" and "nosave" options cannot both be set. + + For more information check https://redis.io/commands/shutdown + """ + if save and nosave: + raise DataError("SHUTDOWN save and nosave cannot both be set") + args = ["SHUTDOWN"] + if save: + args.append("SAVE") + if nosave: + args.append("NOSAVE") + try: + await self.execute_command(*args, **kwargs) + except ConnectionError: + # a ConnectionError here is expected + return + raise RedisError("SHUTDOWN seems to have failed.") + + +class BasicKeyCommands(CommandsProtocol): """ Redis basic key-based commands """ - def append(self, key, value): + def append(self, key: KeyT, value: EncodableT) -> ResponseT: """ Appends the string ``value`` to the value at ``key``. If ``key`` doesn't already exist, create it with a value of ``value``. @@ -1085,7 +1182,12 @@ def append(self, key, value): """ return self.execute_command("APPEND", key, value) - def bitcount(self, key, start=None, end=None): + def bitcount( + self, + key: KeyT, + start: int | None = None, + end: int | None = None, + ) -> ResponseT: """ Returns the count of set bits in the value of ``key``. Optional ``start`` and ``end`` parameters indicate which bytes to consider @@ -1100,7 +1202,11 @@ def bitcount(self, key, start=None, end=None): raise DataError("Both start and end must be specified") return self.execute_command("BITCOUNT", *params) - def bitfield(self, key, default_overflow=None): + def bitfield( + self: Redis | AsyncRedis, + key: KeyT, + default_overflow: str | None = None, + ) -> BitFieldOperation: """ Return a BitFieldOperation instance to conveniently construct one or more bitfield operations on ``key``. @@ -1109,7 +1215,12 @@ def bitfield(self, key, default_overflow=None): """ return BitFieldOperation(self, key, default_overflow=default_overflow) - def bitop(self, operation, dest, *keys): + def bitop( + self, + operation: str, + dest: KeyT, + *keys: KeyT, + ) -> ResponseT: """ Perform a bitwise operation using ``operation`` between ``keys`` and store the result in ``dest``. 
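
The AsyncManagementCommands overrides above re-declare the unsupported commands as coroutines so that the NotImplementedError (or, for SHUTDOWN, the expected ConnectionError) surfaces at ``await`` time rather than at call time. A minimal standalone sketch of the pattern:

    class Base:
        def op(self) -> None:
            raise NotImplementedError("intentionally unsupported")

    class AsyncBase(Base):
        async def op(self) -> None:
            # Delegating inside a coroutine moves the raise to the
            # ``await``, matching the async client's calling convention.
            return super().op()
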
@@ -1118,7 +1229,13 @@ def bitop(self, operation, dest, *keys): """ return self.execute_command("BITOP", operation, dest, *keys) - def bitpos(self, key, bit, start=None, end=None): + def bitpos( + self, + key: KeyT, + bit: int, + start: int | None = None, + end: int | None = None, + ) -> ResponseT: """ Return the position of the first bit set to 1 or 0 in a string. ``start`` and ``end`` defines search range. The range is interpreted @@ -1139,7 +1256,13 @@ def bitpos(self, key, bit, start=None, end=None): raise DataError("start argument is not set, " "when end is specified") return self.execute_command("BITPOS", *params) - def copy(self, source, destination, destination_db=None, replace=False): + def copy( + self, + source: str, + destination: str, + destination_db: str | None = None, + replace: bool = False, + ) -> ResponseT: """ Copy the value stored in the ``source`` key to the ``destination`` key. @@ -1159,7 +1282,7 @@ def copy(self, source, destination, destination_db=None, replace=False): params.append("REPLACE") return self.execute_command("COPY", *params) - def decrby(self, name, amount=1): + def decrby(self, name: KeyT, amount: int = 1) -> ResponseT: """ Decrements the value of ``key`` by ``amount``. If no key exists, the value will be initialized as 0 - ``amount`` @@ -1170,16 +1293,16 @@ def decrby(self, name, amount=1): decr = decrby - def delete(self, *names): + def delete(self, *names: KeyT) -> ResponseT: """ Delete one or more keys specified by ``names`` """ return self.execute_command("DEL", *names) - def __delitem__(self, name): + def __delitem__(self, name: KeyT): self.delete(name) - def dump(self, name): + def dump(self, name: KeyT) -> ResponseT: """ Return a serialized version of the value stored at the specified key. If key does not exist a nil bulk reply is returned. @@ -1192,7 +1315,7 @@ def dump(self, name): options[NEVER_DECODE] = [] return self.execute_command("DUMP", name, **options) - def exists(self, *names): + def exists(self, *names: KeyT) -> ResponseT: """ Returns the number of ``names`` that exist @@ -1202,7 +1325,7 @@ def exists(self, *names): __contains__ = exists - def expire(self, name, time): + def expire(self, name: KeyT, time: ExpiryT) -> ResponseT: """ Set an expire flag on key ``name`` for ``time`` seconds. ``time`` can be represented by an integer or a Python timedelta object. @@ -1213,7 +1336,7 @@ def expire(self, name, time): time = int(time.total_seconds()) return self.execute_command("EXPIRE", name, time) - def expireat(self, name, when): + def expireat(self, name: KeyT, when: AbsExpiryT) -> ResponseT: """ Set an expire flag on key ``name``. ``when`` can be represented as an integer indicating unix time or a Python datetime object. @@ -1224,7 +1347,7 @@ def expireat(self, name, when): when = int(time.mktime(when.timetuple())) return self.execute_command("EXPIREAT", name, when) - def get(self, name): + def get(self, name: KeyT) -> ResponseT: """ Return the value at key ``name``, or None if the key doesn't exist @@ -1232,7 +1355,7 @@ def get(self, name): """ return self.execute_command("GET", name) - def getdel(self, name): + def getdel(self, name: KeyT) -> ResponseT: """ Get the value at key ``name`` and delete the key. 
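
As the expire/expireat hunks show, a TTL may be an int or a datetime.timedelta, converted to whole seconds (or milliseconds, below) before being sent:

    import datetime

    r.set("session:42", "payload")
    r.expire("session:42", datetime.timedelta(minutes=30))    # sent as 1800
    r.pexpire("session:42", datetime.timedelta(seconds=1.5))  # sent as 1500
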
This command is similar to GET, except for the fact that it also deletes @@ -1243,7 +1366,15 @@ def getdel(self, name): """ return self.execute_command("GETDEL", name) - def getex(self, name, ex=None, px=None, exat=None, pxat=None, persist=False): + def getex( + self, + name: KeyT, + ex: ExpiryT | None = None, + px: ExpiryT | None = None, + exat: AbsExpiryT | None = None, + pxat: AbsExpiryT | None = None, + persist: bool = False, + ) -> ResponseT: """ Get the value of key and optionally set its expiration. GETEX is similar to GET, but is a write command with @@ -1272,7 +1403,7 @@ def getex(self, name, ex=None, px=None, exat=None, pxat=None, persist=False): "and ``persist`` are mutually exclusive." ) - pieces = [] + pieces: list[EncodableT] = [] # similar to set command if ex is not None: pieces.append("EX") @@ -1302,7 +1433,7 @@ def getex(self, name, ex=None, px=None, exat=None, pxat=None, persist=False): return self.execute_command("GETEX", name, *pieces) - def __getitem__(self, name): + def __getitem__(self, name: KeyT): """ Return the value at key ``name``, raises a KeyError if the key doesn't exist. @@ -1312,7 +1443,7 @@ def __getitem__(self, name): return value raise KeyError(name) - def getbit(self, name, offset): + def getbit(self, name: KeyT, offset: int) -> ResponseT: """ Returns a boolean indicating the value of ``offset`` in ``name`` @@ -1320,7 +1451,7 @@ def getbit(self, name, offset): """ return self.execute_command("GETBIT", name, offset) - def getrange(self, key, start, end): + def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: """ Returns the substring of the string value stored at ``key``, determined by the offsets ``start`` and ``end`` (both are inclusive) @@ -1329,7 +1460,7 @@ def getrange(self, key, start, end): """ return self.execute_command("GETRANGE", key, start, end) - def getset(self, name, value): + def getset(self, name: KeyT, value: EncodableT) -> ResponseT: """ Sets the value at key ``name`` to ``value`` and returns the old value at key ``name`` atomically. @@ -1341,7 +1472,7 @@ def getset(self, name, value): """ return self.execute_command("GETSET", name, value) - def incrby(self, name, amount=1): + def incrby(self, name: KeyT, amount: int = 1) -> ResponseT: """ Increments the value of ``key`` by ``amount``. If no key exists, the value will be initialized as ``amount`` @@ -1352,7 +1483,7 @@ def incrby(self, name, amount=1): incr = incrby - def incrbyfloat(self, name, amount=1.0): + def incrbyfloat(self, name: KeyT, amount: float = 1.0) -> ResponseT: """ Increments the value at key ``name`` by floating ``amount``. If no key exists, the value will be initialized as ``amount`` @@ -1361,7 +1492,7 @@ def incrbyfloat(self, name, amount=1.0): """ return self.execute_command("INCRBYFLOAT", name, amount) - def keys(self, pattern="*", **kwargs): + def keys(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ Returns a list of keys matching ``pattern`` @@ -1369,7 +1500,13 @@ def keys(self, pattern="*", **kwargs): """ return self.execute_command("KEYS", pattern, **kwargs) - def lmove(self, first_list, second_list, src="LEFT", dest="RIGHT"): + def lmove( + self, + first_list: str, + second_list: str, + src: str = "LEFT", + dest: str = "RIGHT", + ) -> ResponseT: """ Atomically returns and removes the first/last element of a list, pushing it as the first/last element on the destination list. 
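
getex accepts at most one of the expiry options; combining them raises DataError. Typical single-round-trip calls:

    r.set("token", "abc")
    r.getex("token", ex=60)          # read the value and set a 60 s TTL
    r.getex("token", persist=True)   # read the value and drop any TTL
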
@@ -1380,7 +1517,14 @@ def lmove(self, first_list, second_list, src="LEFT", dest="RIGHT"): params = [first_list, second_list, src, dest] return self.execute_command("LMOVE", *params) - def blmove(self, first_list, second_list, timeout, src="LEFT", dest="RIGHT"): + def blmove( + self, + first_list: str, + second_list: str, + timeout: int, + src: str = "LEFT", + dest: str = "RIGHT", + ) -> ResponseT: """ Blocking version of lmove. @@ -1389,7 +1533,7 @@ def blmove(self, first_list, second_list, timeout, src="LEFT", dest="RIGHT"): params = [first_list, second_list, src, dest, timeout] return self.execute_command("BLMOVE", *params) - def mget(self, keys, *args): + def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT: """ Returns a list of values ordered identically to ``keys`` @@ -1403,7 +1547,7 @@ def mget(self, keys, *args): options[EMPTY_RESPONSE] = [] return self.execute_command("MGET", *args, **options) - def mset(self, mapping): + def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: """ Sets key/values based on a mapping. Mapping is a dictionary of key/value pairs. Both keys and values should be strings or types that @@ -1416,7 +1560,7 @@ def mset(self, mapping): items.extend(pair) return self.execute_command("MSET", *items) - def msetnx(self, mapping): + def msetnx(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: """ Sets key/values based on a mapping if none of the keys are already set. Mapping is a dictionary of key/value pairs. Both keys and values @@ -1430,7 +1574,7 @@ def msetnx(self, mapping): items.extend(pair) return self.execute_command("MSETNX", *items) - def move(self, name, db): + def move(self, name: KeyT, db: int) -> ResponseT: """ Moves the key ``name`` to a different Redis database ``db`` @@ -1438,7 +1582,7 @@ def move(self, name, db): """ return self.execute_command("MOVE", name, db) - def persist(self, name): + def persist(self, name: KeyT) -> ResponseT: """ Removes an expiration on ``name`` @@ -1446,7 +1590,7 @@ def persist(self, name): """ return self.execute_command("PERSIST", name) - def pexpire(self, name, time): + def pexpire(self, name: KeyT, time: ExpiryT) -> ResponseT: """ Set an expire flag on key ``name`` for ``time`` milliseconds. ``time`` can be represented by an integer or a Python timedelta @@ -1458,7 +1602,7 @@ def pexpire(self, name, time): time = int(time.total_seconds() * 1000) return self.execute_command("PEXPIRE", name, time) - def pexpireat(self, name, when): + def pexpireat(self, name: KeyT, when: AbsExpiryT) -> ResponseT: """ Set an expire flag on key ``name``. ``when`` can be represented as an integer representing unix time in milliseconds (unix time * 1000) @@ -1471,7 +1615,12 @@ def pexpireat(self, name, when): when = int(time.mktime(when.timetuple())) * 1000 + ms return self.execute_command("PEXPIREAT", name, when) - def psetex(self, name, time_ms, value): + def psetex( + self, + name: KeyT, + time_ms: ExpiryT, + value: EncodableT, + ): """ Set the value of key ``name`` to ``value`` that expires in ``time_ms`` milliseconds. 
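
mset flattens its mapping into an alternating key/value argument list; paired with mget this gives a simple multi-key round trip (byte replies shown for a non-decoding client):

    r.mset({"a": 1, "b": 2})
    r.mget(["a", "b", "missing"])    # [b"1", b"2", None]
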
``time_ms`` can be represented by an integer or a Python @@ -1483,7 +1632,7 @@ def psetex(self, name, time_ms, value): time_ms = int(time_ms.total_seconds() * 1000) return self.execute_command("PSETEX", name, time_ms, value) - def pttl(self, name): + def pttl(self, name: KeyT) -> ResponseT: """ Returns the number of milliseconds until the key ``name`` will expire @@ -1491,7 +1640,12 @@ def pttl(self, name): """ return self.execute_command("PTTL", name) - def hrandfield(self, key, count=None, withvalues=False): + def hrandfield( + self, + key: str, + count: int = None, + withvalues: bool = False, + ) -> ResponseT: """ Return a random field from the hash value stored at key. @@ -1513,7 +1667,7 @@ def hrandfield(self, key, count=None, withvalues=False): return self.execute_command("HRANDFIELD", key, *params) - def randomkey(self, **kwargs): + def randomkey(self, **kwargs) -> ResponseT: """ Returns the name of a random key @@ -1521,7 +1675,7 @@ def randomkey(self, **kwargs): """ return self.execute_command("RANDOMKEY", **kwargs) - def rename(self, src, dst): + def rename(self, src: KeyT, dst: KeyT) -> ResponseT: """ Rename key ``src`` to ``dst`` @@ -1529,7 +1683,7 @@ def rename(self, src, dst): """ return self.execute_command("RENAME", src, dst) - def renamenx(self, src, dst): + def renamenx(self, src: KeyT, dst: KeyT): """ Rename key ``src`` to ``dst`` if ``dst`` doesn't already exist @@ -1539,14 +1693,14 @@ def renamenx(self, src, dst): def restore( self, - name, - ttl, - value, - replace=False, - absttl=False, - idletime=None, - frequency=None, - ): + name: KeyT, + ttl: float, + value: EncodableT, + replace: bool = False, + absttl: bool = False, + idletime: int | None = None, + frequency: int | None = None, + ) -> ResponseT: """ Create a key using the provided serialized value, previously obtained using DUMP. @@ -1589,17 +1743,17 @@ def restore( def set( self, - name, - value, - ex=None, - px=None, - nx=False, - xx=False, - keepttl=False, - get=False, - exat=None, - pxat=None, - ): + name: KeyT, + value: EncodableT, + ex: ExpiryT | None = None, + px: ExpiryT | None = None, + nx: bool = False, + xx: bool = False, + keepttl: bool = False, + get: bool = False, + exat: AbsExpiryT | None = None, + pxat: AbsExpiryT | None = None, + ) -> ResponseT: """ Set the value at key ``name`` to ``value`` @@ -1628,7 +1782,7 @@ def set( For more information check https://redis.io/commands/set """ - pieces = [name, value] + pieces: list[EncodableT] = [name, value] options = {} if ex is not None: pieces.append("EX") @@ -1672,10 +1826,10 @@ def set( return self.execute_command("SET", *pieces, **options) - def __setitem__(self, name, value): + def __setitem__(self, name: KeyT, value: EncodableT): self.set(name, value) - def setbit(self, name, offset, value): + def setbit(self, name: KeyT, offset: int, value: int) -> ResponseT: """ Flag the ``offset`` in ``name`` as ``value``. Returns a boolean indicating the previous value of ``offset``. @@ -1685,7 +1839,7 @@ def setbit(self, name, offset, value): value = value and 1 or 0 return self.execute_command("SETBIT", name, offset, value) - def setex(self, name, time, value): + def setex(self, name: KeyT, time: ExpiryT, value: EncodableT) -> ResponseT: """ Set the value of key ``name`` to ``value`` that expires in ``time`` seconds. 
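
The widened set() signature covers the NX/XX/GET/KEEPTTL variants in one place; a couple of illustrative calls (the GET flag needs Redis >= 6.2):

    r.set("lock", "holder-1", nx=True, ex=10)  # None if the key already exists
    r.set("counter", "5", xx=True, get=True)   # returns the old value, or None
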
``time`` can be represented by an integer or a Python @@ -1697,7 +1851,7 @@ def setex(self, name, time, value): time = int(time.total_seconds()) return self.execute_command("SETEX", name, time, value) - def setnx(self, name, value): + def setnx(self, name: KeyT, value: EncodableT) -> ResponseT: """ Set the value of key ``name`` to ``value`` if key doesn't exist @@ -1705,7 +1859,12 @@ def setnx(self, name, value): """ return self.execute_command("SETNX", name, value) - def setrange(self, name, offset, value): + def setrange( + self, + name: KeyT, + offset: int, + value: EncodableT, + ) -> ResponseT: """ Overwrite bytes in the value of ``name`` starting at ``offset`` with ``value``. If ``offset`` plus the length of ``value`` exceeds the @@ -1722,16 +1881,16 @@ def setrange(self, name, offset, value): def stralgo( self, - algo, - value1, - value2, - specific_argument="strings", - len=False, - idx=False, - minmatchlen=None, - withmatchlen=False, + algo: Literal["LCS"], + value1: KeyT, + value2: KeyT, + specific_argument: Literal["strings"] | Literal["keys"] = "strings", + len: bool = False, + idx: bool = False, + minmatchlen: int | None = None, + withmatchlen: bool = False, **kwargs, - ): + ) -> ResponseT: """ Implements complex algorithms that operate on strings. Right now the only algorithm implemented is the LCS algorithm @@ -1761,7 +1920,7 @@ def stralgo( if len and idx: raise DataError("len and idx cannot be provided together.") - pieces = [algo, specific_argument.upper(), value1, value2] + pieces: list[EncodableT] = [algo, specific_argument.upper(), value1, value2] if len: pieces.append(b"LEN") if idx: @@ -1784,7 +1943,7 @@ def stralgo( **kwargs, ) - def strlen(self, name): + def strlen(self, name: KeyT) -> ResponseT: """ Return the number of bytes stored in the value of ``name`` @@ -1792,14 +1951,14 @@ def strlen(self, name): """ return self.execute_command("STRLEN", name) - def substr(self, name, start, end=-1): + def substr(self, name: KeyT, start: int, end: int = -1) -> ResponseT: """ Return a substring of the string at key ``name``. ``start`` and ``end`` are 0-based integers specifying the portion of the string to return. """ return self.execute_command("SUBSTR", name, start, end) - def touch(self, *args): + def touch(self, *args: KeyT) -> ResponseT: """ Alters the last access time of a key(s) ``*args``. A key is ignored if it does not exist. 
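
stralgo only speaks LCS today, as the Literal["LCS"] annotation encodes. Against the strings from the Redis documentation it behaves roughly like this (replies illustrative; STRALGO is a Redis 6.x command):

    r.stralgo("LCS", "ohmytext", "mynewtext")             # b"mytext"
    r.stralgo("LCS", "ohmytext", "mynewtext", len=True)   # 6
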
@@ -1808,7 +1967,7 @@ def touch(self, *args): """ return self.execute_command("TOUCH", *args) - def ttl(self, name): + def ttl(self, name: KeyT) -> ResponseT: """ Returns the number of seconds until the key ``name`` will expire @@ -1816,7 +1975,7 @@ def ttl(self, name): """ return self.execute_command("TTL", name) - def type(self, name): + def type(self, name: KeyT) -> ResponseT: """ Returns the type of key ``name`` @@ -1824,7 +1983,7 @@ def type(self, name): """ return self.execute_command("TYPE", name) - def watch(self, *names): + def watch(self, *names: KeyT) -> None: """ Watches the values at keys ``names``, or None if the key doesn't exist @@ -1832,7 +1991,7 @@ def watch(self, *names): """ warnings.warn(DeprecationWarning("Call WATCH from a Pipeline object")) - def unwatch(self): + def unwatch(self) -> None: """ Unwatches the value at key ``name``, or None of the key doesn't exist @@ -1840,7 +1999,7 @@ def unwatch(self): """ warnings.warn(DeprecationWarning("Call UNWATCH from a Pipeline object")) - def unlink(self, *names): + def unlink(self, *names: KeyT) -> ResponseT: """ Unlink one or more keys specified by ``names`` @@ -1849,13 +2008,33 @@ def unlink(self, *names): return self.execute_command("UNLINK", *names) -class ListCommands: +class AsyncBasicKeyCommands(BasicKeyCommands): + def __delitem__(self, name: KeyT): + raise TypeError("Async Redis client does not support class deletion") + + def __contains__(self, name: KeyT): + raise TypeError("Async Redis client does not support class inclusion") + + def __getitem__(self, name: KeyT): + raise TypeError("Async Redis client does not support class retrieval") + + def __setitem__(self, name: KeyT, value: EncodableT): + raise TypeError("Async Redis client does not support class assignment") + + async def watch(self, *names: KeyT) -> None: + return super().watch(*names) + + async def unwatch(self) -> None: + return super().unwatch() + + +class ListCommands(CommandsProtocol): """ Redis commands for List data type. see: https://redis.io/topics/data-types#lists """ - def blpop(self, keys, timeout=0): + def blpop(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: """ LPOP a value off of the first non-empty list named in the ``keys`` list. @@ -1874,7 +2053,7 @@ def blpop(self, keys, timeout=0): keys.append(timeout) return self.execute_command("BLPOP", *keys) - def brpop(self, keys, timeout=0): + def brpop(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: """ RPOP a value off of the first non-empty list named in the ``keys`` list. @@ -1893,7 +2072,12 @@ def brpop(self, keys, timeout=0): keys.append(timeout) return self.execute_command("BRPOP", *keys) - def brpoplpush(self, src, dst, timeout=0): + def brpoplpush( + self, + src: KeyT, + dst: KeyT, + timeout: TimeoutSecT = 0, + ) -> ResponseT: """ Pop a value off the tail of ``src``, push it on the head of ``dst`` and then return it. 
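
blpop/brpop append the timeout after the key list and resolve to a (key, value) pair from the first non-empty list, or None once the timeout lapses:

    item = r.blpop(["jobs:high", "jobs:low"], timeout=5)
    if item is not None:
        queue, payload = item   # e.g. (b"jobs:high", b"...")
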
@@ -1908,7 +2092,7 @@ def brpoplpush(self, src, dst, timeout=0): timeout = 0 return self.execute_command("BRPOPLPUSH", src, dst, timeout) - def lindex(self, name, index): + def lindex(self, name: KeyT, index: int) -> ResponseT: """ Return the item from list ``name`` at position ``index`` @@ -1919,7 +2103,13 @@ def lindex(self, name, index): """ return self.execute_command("LINDEX", name, index) - def linsert(self, name, where, refvalue, value): + def linsert( + self, + name: KeyT, + where: str, + refvalue: EncodableT, + value: EncodableT, + ) -> ResponseT: """ Insert ``value`` in list ``name`` either immediately before or after [``where``] ``refvalue`` @@ -1931,7 +2121,7 @@ def linsert(self, name, where, refvalue, value): """ return self.execute_command("LINSERT", name, where, refvalue, value) - def llen(self, name): + def llen(self, name: KeyT) -> ResponseT: """ Return the length of the list ``name`` @@ -1939,7 +2129,11 @@ def llen(self, name): """ return self.execute_command("LLEN", name) - def lpop(self, name, count=None): + def lpop( + self, + name: KeyT, + count: int | None = None, + ) -> ResponseT: """ Removes and returns the first elements of the list ``name``. @@ -1954,7 +2148,7 @@ def lpop(self, name, count=None): else: return self.execute_command("LPOP", name) - def lpush(self, name, *values): + def lpush(self, name: KeyT, *values: EncodableT) -> ResponseT: """ Push ``values`` onto the head of the list ``name`` @@ -1962,7 +2156,7 @@ def lpush(self, name, *values): """ return self.execute_command("LPUSH", name, *values) - def lpushx(self, name, *values): + def lpushx(self, name: KeyT, *values: EncodableT) -> ResponseT: """ Push ``value`` onto the head of the list ``name`` if ``name`` exists @@ -1970,7 +2164,7 @@ def lpushx(self, name, *values): """ return self.execute_command("LPUSHX", name, *values) - def lrange(self, name, start, end): + def lrange(self, name: KeyT, start: int, end: int) -> ResponseT: """ Return a slice of the list ``name`` between position ``start`` and ``end`` @@ -1982,7 +2176,12 @@ def lrange(self, name, start, end): """ return self.execute_command("LRANGE", name, start, end) - def lrem(self, name, count, value): + def lrem( + self, + name: KeyT, + count: int, + value: EncodableT, + ) -> ResponseT: """ Remove the first ``count`` occurrences of elements equal to ``value`` from the list stored at ``name``. @@ -1996,7 +2195,12 @@ def lrem(self, name, count, value): """ return self.execute_command("LREM", name, count, value) - def lset(self, name, index, value): + def lset( + self, + name: KeyT, + index: int, + value: EncodableT, + ) -> ResponseT: """ Set ``position`` of list ``name`` to ``value`` @@ -2004,7 +2208,7 @@ def lset(self, name, index, value): """ return self.execute_command("LSET", name, index, value) - def ltrim(self, name, start, end): + def ltrim(self, name: KeyT, start: int, end: int) -> ResponseT: """ Trim the list ``name``, removing all values not within the slice between ``start`` and ``end`` @@ -2016,7 +2220,11 @@ def ltrim(self, name, start, end): """ return self.execute_command("LTRIM", name, start, end) - def rpop(self, name, count=None): + def rpop( + self, + name: KeyT, + count: int | None = None, + ) -> ResponseT: """ Removes and returns the last elements of the list ``name``. 
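
A short round trip covering the list hunks above (lpop's ``count`` needs Redis >= 6.2):

    r.rpush("q", "a", "b", "c")
    r.lrange("q", 0, -1)    # [b"a", b"b", b"c"]
    r.lpop("q", count=2)    # [b"a", b"b"]
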
@@ -2031,7 +2239,7 @@ def rpop(self, name, count=None): else: return self.execute_command("RPOP", name) - def rpoplpush(self, src, dst): + def rpoplpush(self, src: KeyT, dst: KeyT) -> ResponseT: """ RPOP a value off of the ``src`` list and atomically LPUSH it on to the ``dst`` list. Returns the value. @@ -2040,7 +2248,7 @@ def rpoplpush(self, src, dst): """ return self.execute_command("RPOPLPUSH", src, dst) - def rpush(self, name, *values): + def rpush(self, name: KeyT, *values: EncodableT) -> ResponseT: """ Push ``values`` onto the tail of the list ``name`` @@ -2048,7 +2256,7 @@ def rpush(self, name, *values): """ return self.execute_command("RPUSH", name, *values) - def rpushx(self, name, value): + def rpushx(self, name: KeyT, value: EncodableT) -> ResponseT: """ Push ``value`` onto the tail of the list ``name`` if ``name`` exists @@ -2056,7 +2264,14 @@ def rpushx(self, name, value): """ return self.execute_command("RPUSHX", name, value) - def lpos(self, name, value, rank=None, count=None, maxlen=None): + def lpos( + self, + name: KeyT, + value: EncodableT, + rank: int | None = None, + count: int | None = None, + maxlen: int | None = None, + ) -> ResponseT: """ Get position of ``value`` within the list ``name`` @@ -2082,7 +2297,7 @@ def lpos(self, name, value, rank=None, count=None, maxlen=None): For more information check https://redis.io/commands/lpos """ - pieces = [name, value] + pieces: list[EncodableT] = [name, value] if rank is not None: pieces.extend(["RANK", rank]) @@ -2096,16 +2311,16 @@ def lpos(self, name, value, rank=None, count=None, maxlen=None): def sort( self, - name, - start=None, - num=None, - by=None, - get=None, - desc=False, - alpha=False, - store=None, - groups=False, - ): + name: KeyT, + start: int | None = None, + num: int | None = None, + by: KeyT | None = None, + get: KeysT | None = None, + desc: bool = False, + alpha: bool = False, + store: KeyT | None = None, + groups: bool = False, + ) -> ResponseT: """ Sort and return the list, set or sorted set at ``name``. @@ -2134,7 +2349,7 @@ def sort( if (start is not None and num is None) or (num is not None and start is None): raise DataError("``start`` and ``num`` must both be specified") - pieces = [name] + pieces: list[EncodableT] = [name] if by is not None: pieces.extend([b"BY", by]) if start is not None and num is not None: @@ -2167,13 +2382,23 @@ def sort( return self.execute_command("SORT", *pieces, **options) -class ScanCommands: +AsyncListCommands = ListCommands + + +class ScanCommands(CommandsProtocol): """ Redis SCAN commands. see: https://redis.io/commands/scan """ - def scan(self, cursor=0, match=None, count=None, _type=None, **kwargs): + def scan( + self, + cursor: int = 0, + match: PatternT | None = None, + count: int | None = None, + _type: str | None = None, + **kwargs, + ) -> ResponseT: """ Incrementally return lists of key names. Also return a cursor indicating the scan position. 
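
sort composes BY/LIMIT/GET/ASC|DESC/ALPHA/STORE from its keyword arguments; note that ``start`` and ``num`` must be given together:

    r.rpush("nums", 3, 1, 2)
    r.sort("nums")                              # [b"1", b"2", b"3"]
    r.sort("nums", start=0, num=2, desc=True)   # two highest, descending
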
@@ -2190,7 +2415,7 @@ def scan(self, cursor=0, match=None, count=None, _type=None, **kwargs): For more information check https://redis.io/commands/scan """ - pieces = [cursor] + pieces: list[EncodableT] = [cursor] if match is not None: pieces.extend([b"MATCH", match]) if count is not None: @@ -2199,7 +2424,13 @@ def scan(self, cursor=0, match=None, count=None, _type=None, **kwargs): pieces.extend([b"TYPE", _type]) return self.execute_command("SCAN", *pieces, **kwargs) - def scan_iter(self, match=None, count=None, _type=None, **kwargs): + def scan_iter( + self, + match: PatternT | None = None, + count: int | None = None, + _type: str | None = None, + **kwargs + ) -> Iterator: """ Make an iterator using the SCAN command so that the client doesn't need to remember the cursor position. @@ -2221,7 +2452,13 @@ def scan_iter(self, match=None, count=None, _type=None, **kwargs): ) yield from data - def sscan(self, name, cursor=0, match=None, count=None): + def sscan( + self, + name: KeyT, + cursor: int = 0, + match: PatternT | None = None, + count: int | None = None, + ) -> ResponseT: """ Incrementally return lists of elements in a set. Also return a cursor indicating the scan position. @@ -2232,14 +2469,19 @@ def sscan(self, name, cursor=0, match=None, count=None): For more information check https://redis.io/commands/sscan """ - pieces = [name, cursor] + pieces: list[EncodableT] = [name, cursor] if match is not None: pieces.extend([b"MATCH", match]) if count is not None: pieces.extend([b"COUNT", count]) return self.execute_command("SSCAN", *pieces) - def sscan_iter(self, name, match=None, count=None): + def sscan_iter( + self, + name: KeyT, + match: PatternT | None = None, + count: int | None = None, + ) -> Iterator: """ Make an iterator using the SSCAN command so that the client doesn't need to remember the cursor position. @@ -2253,7 +2495,13 @@ def sscan_iter(self, name, match=None, count=None): cursor, data = self.sscan(name, cursor=cursor, match=match, count=count) yield from data - def hscan(self, name, cursor=0, match=None, count=None): + def hscan( + self, + name: KeyT, + cursor: int = 0, + match: PatternT | None = None, + count: int | None = None, + ) -> ResponseT: """ Incrementally return key/value slices in a hash. Also return a cursor indicating the scan position. @@ -2264,14 +2512,19 @@ def hscan(self, name, cursor=0, match=None, count=None): For more information check https://redis.io/commands/hscan """ - pieces = [name, cursor] + pieces: list[EncodableT] = [name, cursor] if match is not None: pieces.extend([b"MATCH", match]) if count is not None: pieces.extend([b"COUNT", count]) return self.execute_command("HSCAN", *pieces) - def hscan_iter(self, name, match=None, count=None): + def hscan_iter( + self, + name: str, + match: PatternT | None = None, + count: int | None = None, + ) -> Iterator: """ Make an iterator using the HSCAN command so that the client doesn't need to remember the cursor position. @@ -2285,7 +2538,14 @@ def hscan_iter(self, name, match=None, count=None): cursor, data = self.hscan(name, cursor=cursor, match=match, count=count) yield from data.items() - def zscan(self, name, cursor=0, match=None, count=None, score_cast_func=float): + def zscan( + self, + name: KeyT, + cursor: int = 0, + match: PatternT | None = None, + count: int | None = None, + score_cast_func: type | Callable = float, + ) -> ResponseT: """ Incrementally return lists of elements in a sorted set. Also return a cursor indicating the scan position. 
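
The *_iter helpers seed their loop with the string "0", which compares unequal to the integer 0, so the body always runs at least once; iteration stops when the server hands back cursor 0. Callers never see the cursor:

    for key in r.scan_iter(match="user:*", count=500):
        ...                          # keys, one at a time
    for field, value in r.hscan_iter("user:42"):
        ...                          # HSCAN pages are dicts; items are yielded
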
@@ -2306,7 +2566,13 @@ def zscan(self, name, cursor=0, match=None, count=None, score_cast_func=float): options = {"score_cast_func": score_cast_func} return self.execute_command("ZSCAN", *pieces, **options) - def zscan_iter(self, name, match=None, count=None, score_cast_func=float): + def zscan_iter( + self, + name: KeyT, + match: PatternT | None = None, + count: int | None = None, + score_cast_func: type | Callable = float, + ) -> Iterator: """ Make an iterator using the ZSCAN command so that the client doesn't need to remember the cursor position. @@ -2329,13 +2595,117 @@ def zscan_iter(self, name, match=None, count=None, score_cast_func=float): yield from data -class SetCommands: +class AsyncScanCommands(ScanCommands): + async def scan_iter( + self, + match: PatternT | None = None, + count: int | None = None, + _type: str | None = None, + **kwargs + ) -> AsyncIterator: + """ + Make an iterator using the SCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` provides a hint to Redis about the number of keys to + return per batch. + + ``_type`` filters the returned values by a particular Redis type. + Stock Redis instances allow for the following types: + HASH, LIST, SET, STREAM, STRING, ZSET + Additionally, Redis modules can expose other types as well. + """ + cursor = "0" + while cursor != 0: + cursor, data = await self.scan( + cursor=cursor, match=match, count=count, _type=_type, **kwargs + ) + for d in data: + yield d + + async def sscan_iter( + self, + name: KeyT, + match: PatternT | None = None, + count: int | None = None, + ) -> AsyncIterator: + """ + Make an iterator using the SSCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + """ + cursor = "0" + while cursor != 0: + cursor, data = await self.sscan( + name, cursor=cursor, match=match, count=count + ) + for d in data: + yield d + + async def hscan_iter( + self, + name: str, + match: PatternT | None = None, + count: int | None = None, + ) -> AsyncIterator: + """ + Make an iterator using the HSCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + """ + cursor = "0" + while cursor != 0: + cursor, data = await self.hscan( + name, cursor=cursor, match=match, count=count + ) + for it in data.items(): + yield it + + async def zscan_iter( + self, + name: KeyT, + match: PatternT | None = None, + count: int | None = None, + score_cast_func: type | Callable = float, + ) -> AsyncIterator: + """ + Make an iterator using the ZSCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + + ``score_cast_func`` a callable used to cast the score return value + """ + cursor = "0" + while cursor != 0: + cursor, data = await self.zscan( + name, + cursor=cursor, + match=match, + count=count, + score_cast_func=score_cast_func, + ) + for d in data: + yield d + + +class SetCommands(CommandsProtocol): """ Redis commands for Set data type. 
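
AsyncScanCommands below rewrites each iterator as an async generator: the sync versions' ``yield from`` becomes an explicit loop, and consumers switch to ``async for``. For example, given an async client ``ar``:

    async def count_user_keys(ar) -> int:
        n = 0
        async for _key in ar.scan_iter(match="user:*"):
            n += 1
        return n
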
see: https://redis.io/topics/data-types#sets """ - def sadd(self, name, *values): + def sadd(self, name: KeyT, *values: EncodableT) -> ResponseT: """ Add ``value(s)`` to set ``name`` @@ -2343,7 +2713,7 @@ def sadd(self, name, *values): """ return self.execute_command("SADD", name, *values) - def scard(self, name): + def scard(self, name: KeyT) -> ResponseT: """ Return the number of elements in set ``name`` @@ -2351,7 +2721,7 @@ def scard(self, name): """ return self.execute_command("SCARD", name) - def sdiff(self, keys, *args): + def sdiff(self, keys: KeysT, *args: EncodableT) -> ResponseT: """ Return the difference of sets specified by ``keys`` @@ -2360,7 +2730,12 @@ def sdiff(self, keys, *args): args = list_or_args(keys, args) return self.execute_command("SDIFF", *args) - def sdiffstore(self, dest, keys, *args): + def sdiffstore( + self, + dest: KeyT, + keys: KeysT, + *args: EncodableT, + ) -> ResponseT: """ Store the difference of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. @@ -2370,7 +2745,7 @@ def sdiffstore(self, dest, keys, *args): args = list_or_args(keys, args) return self.execute_command("SDIFFSTORE", dest, *args) - def sinter(self, keys, *args): + def sinter(self, keys: KeysT, *args: EncodableT) -> ResponseT: """ Return the intersection of sets specified by ``keys`` @@ -2379,7 +2754,12 @@ def sinter(self, keys, *args): args = list_or_args(keys, args) return self.execute_command("SINTER", *args) - def sinterstore(self, dest, keys, *args): + def sinterstore( + self, + dest: KeyT, + keys: KeysT, + *args: EncodableT, + ) -> ResponseT: """ Store the intersection of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. @@ -2389,7 +2769,7 @@ def sinterstore(self, dest, keys, *args): args = list_or_args(keys, args) return self.execute_command("SINTERSTORE", dest, *args) - def sismember(self, name, value): + def sismember(self, name: KeyT, value: EncodableT) -> ResponseT: """ Return a boolean indicating if ``value`` is a member of set ``name`` @@ -2397,7 +2777,7 @@ def sismember(self, name, value): """ return self.execute_command("SISMEMBER", name, value) - def smembers(self, name): + def smembers(self, name: KeyT) -> ResponseT: """ Return all members of the set ``name`` @@ -2405,7 +2785,7 @@ def smembers(self, name): """ return self.execute_command("SMEMBERS", name) - def smismember(self, name, values, *args): + def smismember(self, name: KeyT, values: Sequence[EncodableT], *args: EncodableT) -> ResponseT: """ Return whether each value in ``values`` is a member of the set ``name`` as a list of ``bool`` in the order of ``values`` @@ -2415,7 +2795,12 @@ def smismember(self, name, values, *args): args = list_or_args(values, args) return self.execute_command("SMISMEMBER", name, *args) - def smove(self, src, dst, value): + def smove( + self, + src: KeyT, + dst: KeyT, + value: EncodableT, + ) -> ResponseT: """ Move ``value`` from set ``src`` to set ``dst`` atomically @@ -2423,7 +2808,11 @@ def smove(self, src, dst, value): """ return self.execute_command("SMOVE", src, dst, value) - def spop(self, name, count=None): + def spop( + self, + name: KeyT, + count: int | None = None, + ) -> ResponseT: """ Remove and return a random member of set ``name`` @@ -2432,7 +2821,11 @@ def spop(self, name, count=None): args = (count is not None) and [count] or [] return self.execute_command("SPOP", name, *args) - def srandmember(self, name, number=None): + def srandmember( + self, + name: KeyT, + number: 
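
Most set commands above map one-to-one onto a single Redis command; a short round trip (the client returns Python sets for SINTER and friends):

    r.sadd("s1", "a", "b", "c")
    r.sadd("s2", "b", "c", "d")
    r.sinter(["s1", "s2"])                 # {b"b", b"c"}
    r.sdiffstore("only1", ["s1", "s2"])    # 1 (just b"a")
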
int | None = None, + ) -> ResponseT: """ If ``number`` is None, returns a random member of set ``name``. @@ -2445,7 +2838,7 @@ def srandmember(self, name, number=None): args = (number is not None) and [number] or [] return self.execute_command("SRANDMEMBER", name, *args) - def srem(self, name, *values): + def srem(self, name: KeyT, *values: EncodableT) -> ResponseT: """ Remove ``values`` from set ``name`` @@ -2453,7 +2846,7 @@ def srem(self, name, *values): """ return self.execute_command("SREM", name, *values) - def sunion(self, keys, *args): + def sunion(self, keys: KeysT, *args: EncodableT) -> ResponseT: """ Return the union of sets specified by ``keys`` @@ -2462,7 +2855,12 @@ def sunion(self, keys, *args): args = list_or_args(keys, args) return self.execute_command("SUNION", *args) - def sunionstore(self, dest, keys, *args): + def sunionstore( + self, + dest: KeyT, + keys: KeysT, + *args: EncodableT, + ) -> ResponseT: """ Store the union of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. @@ -2473,13 +2871,21 @@ def sunionstore(self, dest, keys, *args): return self.execute_command("SUNIONSTORE", dest, *args) -class StreamCommands: +AsyncSetCommands = SetCommands + + +class StreamCommands(CommandsProtocol): """ Redis commands for Stream data type. see: https://redis.io/topics/streams-intro """ - def xack(self, name, groupname, *ids): + def xack( + self, + name: KeyT, + groupname: GroupT, + *ids: StreamIdT, + ) -> ResponseT: """ Acknowledges the successful processing of one or more messages. name: name of the stream. @@ -2492,15 +2898,15 @@ def xack(self, name, groupname, *ids): def xadd( self, - name, - fields, - id="*", - maxlen=None, - approximate=True, - nomkstream=False, - minid=None, - limit=None, - ): + name: KeyT, + fields: dict[FieldT, EncodableT], + id: StreamIdT = "*", + maxlen: int | None = None, + approximate: bool = True, + nomkstream: bool = False, + minid: StreamIdT | None = None, + limit: int | None = None, + ) -> ResponseT: """ Add to a stream. name: name of the stream @@ -2516,7 +2922,7 @@ def xadd( For more information check https://redis.io/commands/xadd """ - pieces = [] + pieces: list[EncodableT] = [] if maxlen is not None and minid is not None: raise DataError( "Only one of ```maxlen``` or ```minid``` " "may be specified" @@ -2547,14 +2953,14 @@ def xadd( def xautoclaim( self, - name, - groupname, - consumername, - min_idle_time, - start_id=0, - count=None, - justid=False, - ): + name: KeyT, + groupname: GroupT, + consumername: ConsumerT, + min_idle_time: int, + start_id: int = 0, + count: int | None = None, + justid: bool = False, + ) -> ResponseT: """ Transfers ownership of pending stream entries that match the specified criteria. Conceptually, equivalent to calling XPENDING and then XCLAIM, @@ -2598,17 +3004,17 @@ def xautoclaim( def xclaim( self, - name, - groupname, - consumername, - min_idle_time, - message_ids, - idle=None, - time=None, - retrycount=None, - force=False, - justid=False, - ): + name: KeyT, + groupname: GroupT, + consumername: ConsumerT, + min_idle_time: int, + message_ids: list[StreamIdT] | tuple[StreamIdT], + idle: int | None = None, + time: int | None = None, + retrycount: int | None = None, + force: bool = False, + justid: bool = False, + ) -> ResponseT: """ Changes the ownership of a pending message. name: name of the stream. 
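
xadd assembles the trim options (NOMKSTREAM, MAXLEN/MINID, LIMIT) ahead of the entry id and field/value pairs; ``maxlen`` and ``minid`` are mutually exclusive. Illustrative producers (the stream id below is made up; MINID needs Redis >= 6.2):

    r.xadd("events", {"type": "click"}, maxlen=10_000, approximate=True)
    r.xadd("events", {"type": "view"}, minid="1526569495631-0")
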
@@ -2642,7 +3048,7 @@ def xclaim( ) kwargs = {} - pieces = [name, groupname, consumername, str(min_idle_time)] + pieces: list[EncodableT] = [name, groupname, consumername, str(min_idle_time)] pieces.extend(list(message_ids)) if idle is not None: @@ -2669,7 +3075,7 @@ def xclaim( kwargs["parse_justid"] = True return self.execute_command("XCLAIM", *pieces, **kwargs) - def xdel(self, name, *ids): + def xdel(self, name: KeyT, *ids: StreamIdT) -> ResponseT: """ Deletes one or more messages from a stream. name: name of the stream. @@ -2679,7 +3085,13 @@ def xdel(self, name, *ids): """ return self.execute_command("XDEL", name, *ids) - def xgroup_create(self, name, groupname, id="$", mkstream=False): + def xgroup_create( + self, + name: KeyT, + groupname: GroupT, + id: StreamIdT = "$", + mkstream: bool = False, + ) -> ResponseT: """ Create a new consumer group associated with a stream. name: name of the stream. @@ -2688,12 +3100,17 @@ def xgroup_create(self, name, groupname, id="$", mkstream=False): For more information check https://redis.io/commands/xgroup-create """ - pieces = ["XGROUP CREATE", name, groupname, id] + pieces: list[EncodableT] = ["XGROUP CREATE", name, groupname, id] if mkstream: pieces.append(b"MKSTREAM") return self.execute_command(*pieces) - def xgroup_delconsumer(self, name, groupname, consumername): + def xgroup_delconsumer( + self, + name: KeyT, + groupname: GroupT, + consumername: ConsumerT, + ) -> ResponseT: """ Remove a specific consumer from a consumer group. Returns the number of pending messages that the consumer had before it @@ -2706,7 +3123,7 @@ def xgroup_delconsumer(self, name, groupname, consumername): """ return self.execute_command("XGROUP DELCONSUMER", name, groupname, consumername) - def xgroup_destroy(self, name, groupname): + def xgroup_destroy(self, name: KeyT, groupname: GroupT) -> ResponseT: """ Destroy a consumer group. name: name of the stream. @@ -2716,7 +3133,9 @@ def xgroup_destroy(self, name, groupname): """ return self.execute_command("XGROUP DESTROY", name, groupname) - def xgroup_createconsumer(self, name, groupname, consumername): + def xgroup_createconsumer( + self, name: KeyT, groupname: GroupT, consumername: ConsumerT, + ) -> ResponseT: """ Consumers in a consumer group are auto-created every time a new consumer name is mentioned by some command. @@ -2731,7 +3150,12 @@ def xgroup_createconsumer(self, name, groupname, consumername): "XGROUP CREATECONSUMER", name, groupname, consumername ) - def xgroup_setid(self, name, groupname, id): + def xgroup_setid( + self, + name: KeyT, + groupname: GroupT, + id: StreamIdT, + ) -> ResponseT: """ Set the consumer group last delivered ID to something else. name: name of the stream. @@ -2742,7 +3166,7 @@ def xgroup_setid(self, name, groupname, id): """ return self.execute_command("XGROUP SETID", name, groupname, id) - def xinfo_consumers(self, name, groupname): + def xinfo_consumers(self, name: KeyT, groupname: GroupT) -> ResponseT: """ Returns general information about the consumers in the group. name: name of the stream. @@ -2752,7 +3176,7 @@ def xinfo_consumers(self, name, groupname): """ return self.execute_command("XINFO CONSUMERS", name, groupname) - def xinfo_groups(self, name): + def xinfo_groups(self, name: KeyT) -> ResponseT: """ Returns general information about the consumer groups of the stream. name: name of the stream. 
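
A sketch of group administration with the commands above (stream, group, and consumer names are made up; XAUTOCLAIM needs Redis >= 6.2):

    r.xgroup_create("events", "workers", id="$", mkstream=True)
    r.xgroup_createconsumer("events", "workers", "w-1")
    # hand long-idle pending entries (>= 60 s) to a fresh consumer
    r.xautoclaim("events", "workers", "w-2", min_idle_time=60_000)
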
@@ -2761,7 +3185,7 @@ def xinfo_groups(self, name): """ return self.execute_command("XINFO GROUPS", name) - def xinfo_stream(self, name, full=False): + def xinfo_stream(self, name: KeyT, full: bool = False) -> ResponseT: """ Returns general information about the stream. name: name of the stream. @@ -2776,7 +3200,7 @@ def xinfo_stream(self, name, full=False): options = {"full": full} return self.execute_command("XINFO STREAM", *pieces, **options) - def xlen(self, name): + def xlen(self, name: KeyT) -> ResponseT: """ Returns the number of elements in a given stream. @@ -2784,7 +3208,7 @@ def xlen(self, name): """ return self.execute_command("XLEN", name) - def xpending(self, name, groupname): + def xpending(self, name: KeyT, groupname: GroupT) -> ResponseT: """ Returns information about pending messages of a group. name: name of the stream. @@ -2796,14 +3220,14 @@ def xpending(self, name, groupname): def xpending_range( self, - name, - groupname, - idle=None, - min=None, - max=None, - count=None, - consumername=None, - ): + name: KeyT, + groupname: GroupT, + min: StreamIdT, + max: StreamIdT, + count: int, + consumername: ConsumerT | None = None, + idle: int | None = None, + ) -> ResponseT: """ Returns information about pending messages, in a range. @@ -2851,7 +3275,13 @@ def xpending_range( return self.execute_command("XPENDING", *pieces, parse_detail=True) - def xrange(self, name, min="-", max="+", count=None): + def xrange( + self, + name: KeyT, + min: StreamIdT = "-", + max: StreamIdT = "+", + count: int | None = None, + ) -> ResponseT: """ Read stream values within an interval. name: name of the stream. @@ -2873,7 +3303,12 @@ def xrange(self, name, min="-", max="+", count=None): return self.execute_command("XRANGE", name, *pieces) - def xread(self, streams, count=None, block=None): + def xread( + self, + streams: dict[KeyT, StreamIdT], + count: int | None = None, + block: int | None = None, + ) -> ResponseT: """ Block and monitor multiple streams for new data. streams: a dict of stream names to stream IDs, where @@ -2904,8 +3339,14 @@ def xread(self, streams, count=None, block=None): return self.execute_command("XREAD", *pieces) def xreadgroup( - self, groupname, consumername, streams, count=None, block=None, noack=False - ): + self, + groupname: str, + consumername: str, + streams: dict[KeyT, StreamIdT], + count: int | None = None, + block: int | None = None, + noack: bool = False, + ) -> ResponseT: """ Read from a stream via a consumer group. groupname: name of the consumer group. @@ -2919,7 +3360,7 @@ def xreadgroup( For more information check https://redis.io/commands/xreadgroup """ - pieces = [b"GROUP", groupname, consumername] + pieces: list[EncodableT] = [b"GROUP", groupname, consumername] if count is not None: if not isinstance(count, int) or count < 1: raise DataError("XREADGROUP count must be a positive integer") @@ -2939,7 +3380,13 @@ def xreadgroup( pieces.extend(streams.values()) return self.execute_command("XREADGROUP", *pieces) - def xrevrange(self, name, max="+", min="-", count=None): + def xrevrange( + self, + name: KeyT, + max: StreamIdT = "+", + min: StreamIdT = "-", + count: int | None = None, + ) -> ResponseT: """ Read stream values within an interval, in reverse order. 
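
Together with xack from earlier, xreadgroup gives the usual consume/acknowledge cycle (names as in the previous sketch; ``block`` is in milliseconds):

    batch = r.xreadgroup("workers", "w-1", {"events": ">"}, count=10, block=2_000)
    for _stream, entries in batch or []:
        for entry_id, fields in entries:
            ...                                    # process the fields dict
            r.xack("events", "workers", entry_id)
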
name: name of the stream @@ -2952,7 +3399,7 @@ def xrevrange(self, name, max="+", min="-", count=None): For more information check https://redis.io/commands/xrevrange """ - pieces = [max, min] + pieces: list[EncodableT] = [max, min] if count is not None: if not isinstance(count, int) or count < 1: raise DataError("XREVRANGE count must be a positive integer") @@ -2961,7 +3408,14 @@ def xrevrange(self, name, max="+", min="-", count=None): return self.execute_command("XREVRANGE", name, *pieces) - def xtrim(self, name, maxlen=None, approximate=True, minid=None, limit=None): + def xtrim( + self, + name: KeyT, + maxlen: int, + approximate: bool = True, + minid: StreamIdT | None = None, + limit: int | None = None, + ) -> ResponseT: """ Trims old messages from a stream. name: name of the stream. @@ -2974,7 +3428,7 @@ def xtrim(self, name, maxlen=None, approximate=True, minid=None, limit=None): For more information check https://redis.io/commands/xtrim """ - pieces = [] + pieces: list[EncodableT] = [] if maxlen is not None and minid is not None: raise DataError("Only one of ``maxlen`` or ``minid`` " "may be specified") @@ -2995,15 +3449,26 @@ def xtrim(self, name, maxlen=None, approximate=True, minid=None, limit=None): return self.execute_command("XTRIM", name, *pieces) -class SortedSetCommands: +AsyncStreamCommands = StreamCommands + + +class SortedSetCommands(CommandsProtocol): """ Redis commands for Sorted Sets data type. see: https://redis.io/topics/data-types-intro#redis-sorted-sets """ def zadd( - self, name, mapping, nx=False, xx=False, ch=False, incr=False, gt=None, lt=None - ): + self, + name: KeyT, + mapping: Mapping[AnyKeyT, EncodableT], + nx: bool = False, + xx: bool = False, + ch: bool = False, + incr: bool = False, + gt: bool = None, + lt: bool = None, + ) -> ResponseT: """ Set any number of element-name, score pairs to the key ``name``. Pairs are specified as a dict of element-names keys to score values. @@ -3049,7 +3514,7 @@ def zadd( if nx is True and (gt is not None or lt is not None): raise DataError("Only one of 'nx', 'lt', or 'gr' may be defined.") - pieces = [] + pieces: list[EncodableT] = [] options = {} if nx: pieces.append(b"NX") @@ -3069,7 +3534,7 @@ def zadd( pieces.append(pair[0]) return self.execute_command("ZADD", name, *pieces, **options) - def zcard(self, name): + def zcard(self, name: KeyT): """ Return the number of elements in the sorted set ``name`` @@ -3077,7 +3542,12 @@ def zcard(self, name): """ return self.execute_command("ZCARD", name) - def zcount(self, name, min, max): + def zcount( + self, + name: KeyT, + min: ZScoreBoundT, + max: ZScoreBoundT + ) -> ResponseT: """ Returns the number of elements in the sorted set at key ``name`` with a score between ``min`` and ``max``. @@ -3086,7 +3556,7 @@ def zcount(self, name, min, max): """ return self.execute_command("ZCOUNT", name, min, max) - def zdiff(self, keys, withscores=False): + def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT: """ Returns the difference between the first and all successive input sorted sets provided in ``keys``. @@ -3098,7 +3568,7 @@ def zdiff(self, keys, withscores=False): pieces.append("WITHSCORES") return self.execute_command("ZDIFF", *pieces) - def zdiffstore(self, dest, keys): + def zdiffstore(self, dest: KeyT, keys: KeysT) -> ResponseT: """ Computes the difference between the first and all successive input sorted sets provided in ``keys`` and stores the result in ``dest``. 
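
zadd's flag handling above rejects contradictory combinations (nx with xx, gt with lt, incr with more than one pair). Typical calls (GT/LT need Redis >= 6.2):

    r.zadd("leaders", {"alice": 50, "bob": 30})
    r.zadd("leaders", {"alice": 10}, incr=True)       # returns 60.0
    r.zadd("leaders", {"bob": 99}, xx=True, gt=True)  # only raise existing scores
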
@@ -3108,7 +3578,12 @@ def zdiffstore(self, dest, keys): pieces = [len(keys), *keys] return self.execute_command("ZDIFFSTORE", dest, *pieces) - def zincrby(self, name, amount, value): + def zincrby( + self, + name: KeyT, + amount: float, + value: EncodableT, + ) -> ResponseT: """ Increment the score of ``value`` in sorted set ``name`` by ``amount`` @@ -3116,7 +3591,12 @@ def zincrby(self, name, amount, value): """ return self.execute_command("ZINCRBY", name, amount, value) - def zinter(self, keys, aggregate=None, withscores=False): + def zinter( + self, + keys: KeysT, + aggregate: str | None = None, + withscores: bool = False, + ) -> ResponseT: """ Return the intersect of multiple sorted sets specified by ``keys``. With the ``aggregate`` option, it is possible to specify how the @@ -3130,7 +3610,12 @@ def zinter(self, keys, aggregate=None, withscores=False): """ return self._zaggregate("ZINTER", None, keys, aggregate, withscores=withscores) - def zinterstore(self, dest, keys, aggregate=None): + def zinterstore( + self, + dest: KeyT, + keys: Sequence[KeyT] | Mapping[AnyKeyT, float], + aggregate: str | None = None, + ) -> ResponseT: """ Intersect multiple sorted sets specified by ``keys`` into a new sorted set, ``dest``. Scores in the destination will be aggregated @@ -3144,7 +3629,12 @@ def zinterstore(self, dest, keys, aggregate=None): """ return self._zaggregate("ZINTERSTORE", dest, keys, aggregate) - def zlexcount(self, name, min, max): + def zlexcount( + self, + name: KeyT, + min: EncodableT, + max: EncodableT, + ) -> ResponseT: """ Return the number of items in the sorted set ``name`` between the lexicographical range ``min`` and ``max``. @@ -3153,7 +3643,11 @@ def zlexcount(self, name, min, max): """ return self.execute_command("ZLEXCOUNT", name, min, max) - def zpopmax(self, name, count=None): + def zpopmax( + self, + name: KeyT, + count: int | None = None, + ) -> ResponseT: """ Remove and return up to ``count`` members with the highest scores from the sorted set ``name``. @@ -3164,7 +3658,11 @@ def zpopmax(self, name, count=None): options = {"withscores": True} return self.execute_command("ZPOPMAX", name, *args, **options) - def zpopmin(self, name, count=None): + def zpopmin( + self, + name: KeyT, + count: int | None = None, + ) -> ResponseT: """ Remove and return up to ``count`` members with the lowest scores from the sorted set ``name``. @@ -3175,7 +3673,12 @@ def zpopmin(self, name, count=None): options = {"withscores": True} return self.execute_command("ZPOPMIN", name, *args, **options) - def zrandmember(self, key, count=None, withscores=False): + def zrandmember( + self, + key: KeyT, + count: int = None, + withscores: bool = False, + ) -> ResponseT: """ Return a random element from the sorted set value stored at key. @@ -3199,7 +3702,7 @@ def zrandmember(self, key, count=None, withscores=False): return self.execute_command("ZRANDMEMBER", key, *params) - def bzpopmax(self, keys, timeout=0): + def bzpopmax(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: """ ZPOPMAX a value off of the first non-empty sorted set named in the ``keys`` list. @@ -3218,7 +3721,7 @@ def bzpopmax(self, keys, timeout=0): keys.append(timeout) return self.execute_command("BZPOPMAX", *keys) - def bzpopmin(self, keys, timeout=0): + def bzpopmin(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: """ ZPOPMIN a value off of the first non-empty sorted set named in the ``keys`` list. 
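
bzpopmin/bzpopmax resolve to a (key, member, score) triple, or None on timeout:

    r.zadd("leaders", {"carol": 5})
    r.bzpopmin(["leaders"], timeout=1)   # (b"leaders", b"carol", 5.0) or None
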
@@ -3233,25 +3736,25 @@ def bzpopmin(self, keys, timeout=0): """ if timeout is None: timeout = 0 - keys = list_or_args(keys, None) + keys: list[EncodableT] = list_or_args(keys, None) keys.append(timeout) return self.execute_command("BZPOPMIN", *keys) def _zrange( self, command, - dest, - name, - start, - end, - desc=False, - byscore=False, - bylex=False, - withscores=False, - score_cast_func=float, - offset=None, - num=None, - ): + dest: KeyT | None, + name: KeyT, + start: int, + end: int, + desc: bool = False, + byscore: bool = False, + bylex: bool = False, + withscores: bool = False, + score_cast_func: type | Callable | None = float, + offset: int | None = None, + num: int | None = None, + ) -> ResponseT: if byscore and bylex: raise DataError( "``byscore`` and ``bylex`` can not be " "specified together." @@ -3281,17 +3784,17 @@ def _zrange( def zrange( self, - name, - start, - end, - desc=False, - withscores=False, - score_cast_func=float, - byscore=False, - bylex=False, - offset=None, - num=None, - ): + name: KeyT, + start: int, + end: int, + desc: bool = False, + withscores: bool = False, + score_cast_func: type | Callable = float, + byscore: bool = False, + bylex: bool = False, + offset: int = None, + num: int = None, + ) -> ResponseT: """ Return a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in ascending order. @@ -3340,7 +3843,14 @@ def zrange( num, ) - def zrevrange(self, name, start, end, withscores=False, score_cast_func=float): + def zrevrange( + self, + name: KeyT, + start: int, + end: int, + withscores: bool = False, + score_cast_func: type | Callable = float + ) -> ResponseT: """ Return a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in descending order. @@ -3362,16 +3872,16 @@ def zrevrange(self, name, start, end, withscores=False, score_cast_func=float): def zrangestore( self, - dest, - name, - start, - end, - byscore=False, - bylex=False, - desc=False, - offset=None, - num=None, - ): + dest: KeyT, + name: KeyT, + start: int, + end: int, + byscore: bool = False, + bylex: bool = False, + desc: bool = False, + offset: int | None = None, + num: int | None = None, + ) -> ResponseT: """ Stores in ``dest`` the result of a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in ascending order. @@ -3410,7 +3920,14 @@ def zrangestore( num, ) - def zrangebylex(self, name, min, max, start=None, num=None): + def zrangebylex( + self, + name: KeyT, + min: EncodableT, + max: EncodableT, + start: int | None = None, + num: int | None = None + ) -> ResponseT: """ Return the lexicographical range of values from sorted set ``name`` between ``min`` and ``max``. @@ -3427,7 +3944,14 @@ def zrangebylex(self, name, min, max, start=None, num=None): pieces.extend([b"LIMIT", start, num]) return self.execute_command(*pieces) - def zrevrangebylex(self, name, max, min, start=None, num=None): + def zrevrangebylex( + self, + name: KeyT, + max: EncodableT, + min: EncodableT, + start: int | None = None, + num: int | None = None + ) -> ResponseT: """ Return the reversed lexicographical range of values from sorted set ``name`` between ``max`` and ``min``. 
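``zrange`` wraps the unified Redis 6.2 ZRANGE options, so ``byscore``/``bylex`` together with ``offset``/``num`` cover what previously required separate commands. A sketch, assuming a Redis >= 6.2 server on localhost:

    import asyncio
    import redis.asyncio as redis

    async def main() -> None:
        r = redis.Redis()  # assumed Redis >= 6.2 on localhost
        await r.zadd("z", {"a": 1, "b": 2, "c": 3, "d": 4})
        # byscore=True makes start/end score bounds; offset/num must be
        # given together and paginate the result (the LIMIT clause).
        page = await r.zrange("z", 2, 4, byscore=True, offset=0, num=2)
        print(page)  # expected: [b'b', b'c']
        await r.close()

    asyncio.run(main())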
@@ -3446,14 +3970,14 @@ def zrevrangebylex(self, name, max, min, start=None, num=None): def zrangebyscore( self, - name, - min, - max, - start=None, - num=None, - withscores=False, - score_cast_func=float, - ): + name: KeyT, + min: ZScoreBoundT, + max: ZScoreBoundT, + start: int | None = None, + num: int | None = None, + withscores: bool = False, + score_cast_func: type | Callable = float, + ) -> ResponseT: """ Return a range of values from the sorted set ``name`` with scores between ``min`` and ``max``. @@ -3480,13 +4004,13 @@ def zrangebyscore( def zrevrangebyscore( self, - name, - max, - min, - start=None, - num=None, - withscores=False, - score_cast_func=float, + name: KeyT, + max: ZScoreBoundT, + min: ZScoreBoundT, + start: int | None = None, + num: int | None = None, + withscores: bool = False, + score_cast_func: type | Callable = float, ): """ Return a range of values from the sorted set ``name`` with scores @@ -3512,7 +4036,7 @@ def zrevrangebyscore( options = {"withscores": withscores, "score_cast_func": score_cast_func} return self.execute_command(*pieces, **options) - def zrank(self, name, value): + def zrank(self, name: KeyT, value: EncodableT) -> ResponseT: """ Returns a 0-based value indicating the rank of ``value`` in sorted set ``name`` @@ -3521,7 +4045,7 @@ def zrank(self, name, value): """ return self.execute_command("ZRANK", name, value) - def zrem(self, name, *values): + def zrem(self, name: KeyT, *values: EncodableT) -> ResponseT: """ Remove member ``values`` from sorted set ``name`` @@ -3529,7 +4053,12 @@ def zrem(self, name, *values): """ return self.execute_command("ZREM", name, *values) - def zremrangebylex(self, name, min, max): + def zremrangebylex( + self, + name: KeyT, + min: EncodableT, + max: EncodableT + ) -> ResponseT: """ Remove all elements in the sorted set ``name`` between the lexicographical range specified by ``min`` and ``max``. @@ -3540,7 +4069,12 @@ def zremrangebylex(self, name, min, max): """ return self.execute_command("ZREMRANGEBYLEX", name, min, max) - def zremrangebyrank(self, name, min, max): + def zremrangebyrank( + self, + name: KeyT, + min: int, + max: int + ) -> ResponseT: """ Remove all elements in the sorted set ``name`` with ranks between ``min`` and ``max``. Values are 0-based, ordered from smallest score @@ -3551,7 +4085,7 @@ def zremrangebyrank(self, name, min, max): """ return self.execute_command("ZREMRANGEBYRANK", name, min, max) - def zremrangebyscore(self, name, min, max): + def zremrangebyscore(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: """ Remove all elements in the sorted set ``name`` with scores between ``min`` and ``max``. Returns the number of elements removed. 
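The ``min``/``max`` arguments of the score-bounded commands take the usual Redis bound syntax, including ``-inf``/``+inf`` and exclusive ``(`` bounds, while ``score_cast_func`` controls how returned scores are converted. For example (again assuming a local server):

    import asyncio
    import redis.asyncio as redis

    async def main() -> None:
        r = redis.Redis()  # assumed local server
        await r.zadd("z", {"a": 1, "b": 2, "c": 3})
        # "(1" is an exclusive lower bound; returned scores pass through
        # score_cast_func, here cast to int instead of the default float.
        hits = await r.zrangebyscore(
            "z", "(1", "+inf", withscores=True, score_cast_func=int
        )
        print(hits)  # expected: [(b'b', 2), (b'c', 3)]
        await r.close()

    asyncio.run(main())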
@@ -3560,7 +4094,7 @@ def zremrangebyscore(self, name, min, max): """ return self.execute_command("ZREMRANGEBYSCORE", name, min, max) - def zrevrank(self, name, value): + def zrevrank(self, name: KeyT, value: EncodableT) -> ResponseT: """ Returns a 0-based value indicating the descending rank of ``value`` in sorted set ``name`` @@ -3569,7 +4103,7 @@ def zrevrank(self, name, value): """ return self.execute_command("ZREVRANK", name, value) - def zscore(self, name, value): + def zscore(self, name: KeyT, value: EncodableT) -> ResponseT: """ Return the score of element ``value`` in sorted set ``name`` @@ -3577,7 +4111,12 @@ def zscore(self, name, value): """ return self.execute_command("ZSCORE", name, value) - def zunion(self, keys, aggregate=None, withscores=False): + def zunion( + self, + keys: Sequence[KeyT] | Mapping[AnyKeyT, float], + aggregate: str | None = None, + withscores: bool = False + ) -> ResponseT: """ Return the union of multiple sorted sets specified by ``keys``. ``keys`` can be provided as dictionary of keys and their weights. @@ -3588,7 +4127,12 @@ def zunion(self, keys, aggregate=None, withscores=False): """ return self._zaggregate("ZUNION", None, keys, aggregate, withscores=withscores) - def zunionstore(self, dest, keys, aggregate=None): + def zunionstore( + self, + dest: KeyT, + keys: Sequence[KeyT] | Mapping[AnyKeyT, float], + aggregate: str | None = None + ) -> ResponseT: """ Union multiple sorted sets specified by ``keys`` into a new sorted set, ``dest``. Scores in the destination will be @@ -3598,7 +4142,11 @@ def zunionstore(self, dest, keys, aggregate=None): """ return self._zaggregate("ZUNIONSTORE", dest, keys, aggregate) - def zmscore(self, key, members): + def zmscore( + self, + key: KeyT, + members: list[str], + ) -> ResponseT: """ Returns the scores associated with the specified members in the sorted set stored at key. @@ -3614,8 +4162,15 @@ def zmscore(self, key, members): pieces = [key] + members return self.execute_command("ZMSCORE", *pieces) - def _zaggregate(self, command, dest, keys, aggregate=None, **options): - pieces = [command] + def _zaggregate( + self, + command: str, + dest: KeyT | None, + keys: Sequence[KeyT] | Mapping[AnyKeyT, float], + aggregate: str | None = None, + **options, + ) -> ResponseT: + pieces: list[EncodableT] = [command] if dest is not None: pieces.append(dest) pieces.append(len(keys)) @@ -3638,13 +4193,16 @@ def _zaggregate(self, command, dest, keys, aggregate=None, **options): return self.execute_command(*pieces, **options) -class HyperlogCommands: +AsyncSortedSetCommands = SortedSetCommands + + +class HyperlogCommands(CommandsProtocol): """ Redis commands of HyperLogLogs data type. see: https://redis.io/topics/data-types-intro#hyperloglogs """ - def pfadd(self, name, *values): + def pfadd(self, name: KeyT, *values: EncodableT) -> ResponseT: """ Adds the specified elements to the specified HyperLogLog. @@ -3652,7 +4210,7 @@ def pfadd(self, name, *values): """ return self.execute_command("PFADD", name, *values) - def pfcount(self, *sources): + def pfcount(self, *sources: KeyT) -> ResponseT: """ Return the approximated cardinality of the set observed by the HyperLogLog at key(s). @@ -3661,7 +4219,7 @@ def pfcount(self, *sources): """ return self.execute_command("PFCOUNT", *sources) - def pfmerge(self, dest, *sources): + def pfmerge(self, dest: KeyT, *sources: KeyT) -> ResponseT: """ Merge N different HyperLogLogs into a single one. 
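The HyperLogLog commands trade exactness for constant memory: ``pfcount`` over several keys reports the approximate cardinality of their union, and ``pfmerge`` materializes that union into a destination key. A sketch under the same local-server assumption:

    import asyncio
    import redis.asyncio as redis

    async def main() -> None:
        r = redis.Redis()  # assumed local server
        await r.pfadd("visits:mon", "u1", "u2", "u3")
        await r.pfadd("visits:tue", "u2", "u4")
        # Approximate cardinality of the union of both days.
        print(await r.pfcount("visits:mon", "visits:tue"))  # ~4
        await r.pfmerge("visits:week", "visits:mon", "visits:tue")
        print(await r.pfcount("visits:week"))  # ~4
        await r.close()

    asyncio.run(main())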
@@ -3670,13 +4228,16 @@ def pfmerge(self, dest, *sources): return self.execute_command("PFMERGE", dest, *sources) -class HashCommands: +AsyncHyperlogCommands = HyperlogCommands + + +class HashCommands(CommandsProtocol): """ Redis commands for Hash data type. see: https://redis.io/topics/data-types-intro#redis-hashes """ - def hdel(self, name, *keys): + def hdel(self, name: KeyT, *keys: FieldT) -> ResponseT: """ Delete ``keys`` from hash ``name`` @@ -3684,7 +4245,7 @@ def hdel(self, name, *keys): """ return self.execute_command("HDEL", name, *keys) - def hexists(self, name, key): + def hexists(self, name: KeyT, key: FieldT) -> ResponseT: """ Returns a boolean indicating if ``key`` exists within hash ``name`` @@ -3692,7 +4253,7 @@ def hexists(self, name, key): """ return self.execute_command("HEXISTS", name, key) - def hget(self, name, key): + def hget(self, name: KeyT, key: FieldT) -> ResponseT: """ Return the value of ``key`` within the hash ``name`` @@ -3700,7 +4261,7 @@ def hget(self, name, key): """ return self.execute_command("HGET", name, key) - def hgetall(self, name): + def hgetall(self, name: KeyT) -> ResponseT: """ Return a Python dict of the hash's name/value pairs @@ -3708,7 +4269,7 @@ def hgetall(self, name): """ return self.execute_command("HGETALL", name) - def hincrby(self, name, key, amount=1): + def hincrby(self, name: KeyT, key: FieldT, amount: int = 1) -> ResponseT: """ Increment the value of ``key`` in hash ``name`` by ``amount`` @@ -3716,7 +4277,7 @@ def hincrby(self, name, key, amount=1): """ return self.execute_command("HINCRBY", name, key, amount) - def hincrbyfloat(self, name, key, amount=1.0): + def hincrbyfloat(self, name: KeyT, key: FieldT, amount: float = 1.0) -> ResponseT: """ Increment the value of ``key`` in hash ``name`` by floating ``amount`` @@ -3724,7 +4285,7 @@ def hincrbyfloat(self, name, key, amount=1.0): """ return self.execute_command("HINCRBYFLOAT", name, key, amount) - def hkeys(self, name): + def hkeys(self, name: KeyT) -> ResponseT: """ Return the list of keys within hash ``name`` @@ -3732,7 +4293,7 @@ def hkeys(self, name): """ return self.execute_command("HKEYS", name) - def hlen(self, name): + def hlen(self, name: KeyT) -> ResponseT: """ Return the number of elements in hash ``name`` @@ -3740,7 +4301,13 @@ def hlen(self, name): """ return self.execute_command("HLEN", name) - def hset(self, name, key=None, value=None, mapping=None): + def hset( + self, + name: KeyT, + key: FieldT = None, + value: EncodableT = None, + mapping: Mapping[AnyFieldT, EncodableT] = None + ) -> ResponseT: """ Set ``key`` to ``value`` within hash ``name``, ``mapping`` accepts a dict of key/value pairs that will be @@ -3760,7 +4327,7 @@ def hset(self, name, key=None, value=None, mapping=None): return self.execute_command("HSET", name, *items) - def hsetnx(self, name, key, value): + def hsetnx(self, name: KeyT, key: FieldT, value: EncodableT) -> ResponseT: """ Set ``key`` to ``value`` within hash ``name`` if ``key`` does not exist. Returns 1 if HSETNX created a field, otherwise 0. @@ -3769,7 +4336,7 @@ def hsetnx(self, name, key, value): """ return self.execute_command("HSETNX", name, key, value) - def hmset(self, name, mapping): + def hmset(self, name: KeyT, mapping: Mapping[AnyFieldT, EncodableT]) -> ResponseT: """ Set key to value within hash ``name`` for each corresponding key and value from the ``mapping`` dict. 
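``hset`` accepts a single ``key``/``value`` pair, a ``mapping``, or both in one round trip, which keeps multi-field writes to a single HSET command. An illustration, assuming a local server:

    import asyncio
    import redis.asyncio as redis

    async def main() -> None:
        r = redis.Redis(decode_responses=True)  # assumed local server
        # One HSET call can carry a whole mapping...
        await r.hset("user:1", mapping={"name": "ada", "lang": "py"})
        # ...or a single field/value pair.
        await r.hset("user:1", "age", 36)
        print(await r.hgetall("user:1"))
        # expected: {'name': 'ada', 'lang': 'py', 'age': '36'}
        await r.close()

    asyncio.run(main())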
@@ -3789,7 +4356,7 @@ def hmset(self, name, mapping): items.extend(pair) return self.execute_command("HMSET", name, *items) - def hmget(self, name, keys, *args): + def hmget(self, name: KeyT, keys: Sequence[FieldT], *args: FieldT) -> ResponseT: """ Returns a list of values ordered identically to ``keys`` @@ -3798,7 +4365,7 @@ def hmget(self, name, keys, *args): args = list_or_args(keys, args) return self.execute_command("HMGET", name, *args) - def hvals(self, name): + def hvals(self, name: KeyT): """ Return the list of values within hash ``name`` @@ -3806,7 +4373,7 @@ def hvals(self, name): """ return self.execute_command("HVALS", name) - def hstrlen(self, name, key): + def hstrlen(self, name: KeyT, key: FieldT) -> ResponseT: """ Return the number of bytes stored in the value of ``key`` within hash ``name`` @@ -3816,13 +4383,16 @@ def hstrlen(self, name, key): return self.execute_command("HSTRLEN", name, key) -class PubSubCommands: +AsyncHashCommands = HashCommands + + +class PubSubCommands(CommandsProtocol): """ Redis PubSub commands. see https://redis.io/topics/pubsub """ - def publish(self, channel, message, **kwargs): + def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT: """ Publish ``message`` on ``channel``. Returns the number of subscribers the message was delivered to. @@ -3831,7 +4401,7 @@ def publish(self, channel, message, **kwargs): """ return self.execute_command("PUBLISH", channel, message, **kwargs) - def pubsub_channels(self, pattern="*", **kwargs): + def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ Return a list of channels that have at least one subscriber @@ -3839,7 +4409,7 @@ def pubsub_channels(self, pattern="*", **kwargs): """ return self.execute_command("PUBSUB CHANNELS", pattern, **kwargs) - def pubsub_numpat(self, **kwargs): + def pubsub_numpat(self, **kwargs) -> ResponseT: """ Returns the number of subscriptions to patterns @@ -3847,7 +4417,7 @@ def pubsub_numpat(self, **kwargs): """ return self.execute_command("PUBSUB NUMPAT", **kwargs) - def pubsub_numsub(self, *args, **kwargs): + def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT: """ Return a list of (channel, number of subscribers) tuples for each channel given in ``*args`` @@ -3857,13 +4427,21 @@ def pubsub_numsub(self, *args, **kwargs): return self.execute_command("PUBSUB NUMSUB", *args, **kwargs) -class ScriptCommands: +AsyncPubSubCommands = PubSubCommands + + +class ScriptCommands(CommandsProtocol): """ Redis Lua script commands. see: https://redis.com/ebook/part-3-next-steps/chapter-11-scripting-redis-with-lua/ """ - def eval(self, script, numkeys, *keys_and_args): + def eval( + self, + script: ScriptTextT, + numkeys: int, + *keys_and_args: EncodableT + ) -> ResponseT: """ Execute the Lua ``script``, specifying the ``numkeys`` the script will touch and the key names and argument values in ``keys_and_args``. @@ -3876,7 +4454,12 @@ def eval(self, script, numkeys, *keys_and_args): """ return self.execute_command("EVAL", script, numkeys, *keys_and_args) - def evalsha(self, sha, numkeys, *keys_and_args): + def evalsha( + self, + sha: str, + numkeys: int, + *keys_and_args: EncodableT, + ) -> ResponseT: """ Use the ``sha`` to execute a Lua script already registered via EVAL or SCRIPT LOAD. 
Specify the ``numkeys`` the script will touch and the @@ -3890,7 +4473,7 @@ def evalsha(self, sha, numkeys, *keys_and_args): """ return self.execute_command("EVALSHA", sha, numkeys, *keys_and_args) - def script_exists(self, *args): + def script_exists(self, *args: str): """ Check if a script exists in the script cache by specifying the SHAs of each script as ``args``. Returns a list of boolean values indicating if @@ -3900,12 +4483,15 @@ def script_exists(self, *args): """ return self.execute_command("SCRIPT EXISTS", *args) - def script_debug(self, *args): + def script_debug(self, *args) -> None: raise NotImplementedError( "SCRIPT DEBUG is intentionally not implemented in the client." ) - def script_flush(self, sync_type=None): + def script_flush( + self, + sync_type: Literal["SYNC"] | Literal["ASYNC"] = None + ) -> ResponseT: """Flush all scripts from the script cache. ``sync_type`` is by default SYNC (synchronous) but it can also be ASYNC. @@ -3925,7 +4511,7 @@ def script_flush(self, sync_type=None): pieces = [sync_type] return self.execute_command("SCRIPT FLUSH", *pieces) - def script_kill(self): + def script_kill(self) -> ResponseT: """ Kill the currently executing Lua script @@ -3933,7 +4519,7 @@ def script_kill(self): """ return self.execute_command("SCRIPT KILL") - def script_load(self, script): + def script_load(self, script: ScriptTextT) -> ResponseT: """ Load a Lua ``script`` into the script cache. Returns the SHA. @@ -3941,7 +4527,7 @@ def script_load(self, script): """ return self.execute_command("SCRIPT LOAD", script) - def register_script(self, script): + def register_script(self: Redis, script: ScriptTextT) -> Script: """ Register a Lua ``script`` specifying the ``keys`` it will touch. Returns a Script object that is callable and hides the complexity of @@ -3951,13 +4537,34 @@ def register_script(self, script): return Script(self, script) -class GeoCommands: +class AsyncScriptCommands(ScriptCommands): + async def script_debug(self, *args) -> None: + return super().script_debug() + + def register_script(self: AsyncRedis, script: ScriptTextT) -> AsyncScript: + """ + Register a Lua ``script`` specifying the ``keys`` it will touch. + Returns a Script object that is callable and hides the complexity of + deal with scripts, keys, and shas. This is the preferred way to work + with Lua scripts. + """ + return AsyncScript(self, script) + + +class GeoCommands(CommandsProtocol): """ Redis Geospatial commands. see: https://redis.com/redis-best-practices/indexing-patterns/geospatial/ """ - def geoadd(self, name, values, nx=False, xx=False, ch=False): + def geoadd( + self, + name: KeyT, + values: Sequence[EncodableT], + nx: bool = False, + xx: bool = False, + ch: bool = False, + ) -> ResponseT: """ Add the specified geospatial items to the specified key identified by the ``name`` argument. The Geospatial items are given as ordered @@ -3992,7 +4599,13 @@ def geoadd(self, name, values, nx=False, xx=False, ch=False): pieces.extend(values) return self.execute_command("GEOADD", *pieces) - def geodist(self, name, place1, place2, unit=None): + def geodist( + self, + name: KeyT, + place1: FieldT, + place2: FieldT, + unit: str | None = None, + ) -> ResponseT: """ Return the distance between ``place1`` and ``place2`` members of the ``name`` key. 
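``register_script`` ties ``eval``, ``evalsha``, and ``script_load`` together: the returned object caches the script's SHA, calls EVALSHA, and reloads the source if the server no longer knows the script. A sketch of the async flavor, assuming a local server:

    import asyncio
    import redis.asyncio as redis

    async def main() -> None:
        r = redis.Redis()  # assumed local server
        # The returned AsyncScript object is awaitable; it runs EVALSHA
        # and transparently re-loads the script on NoScriptError.
        incr_by = r.register_script(
            "return redis.call('INCRBY', KEYS[1], ARGV[1])"
        )
        print(await incr_by(keys=["counter"], args=[5]))  # 5
        print(await incr_by(keys=["counter"], args=[5]))  # 10
        await r.close()

    asyncio.run(main())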
@@ -4001,14 +4614,14 @@ def geodist(self, name, place1, place2, unit=None): For more information check https://redis.io/commands/geodist """ - pieces = [name, place1, place2] + pieces: list[EncodableT] = [name, place1, place2] if unit and unit not in ("m", "km", "mi", "ft"): raise DataError("GEODIST invalid unit") elif unit: pieces.append(unit) return self.execute_command("GEODIST", *pieces) - def geohash(self, name, *values): + def geohash(self, name: KeyT, *values: FieldT) -> ResponseT: """ Return the geo hash string for each item of ``values`` members of the specified key identified by the ``name`` argument. @@ -4017,7 +4630,7 @@ def geohash(self, name, *values): """ return self.execute_command("GEOHASH", name, *values) - def geopos(self, name, *values): + def geopos(self, name: KeyT, *values: FieldT) -> ResponseT: """ Return the positions of each item of ``values`` as members of the specified key identified by the ``name`` argument. Each position @@ -4029,20 +4642,20 @@ def geopos(self, name, *values): def georadius( self, - name, - longitude, - latitude, - radius, - unit=None, - withdist=False, - withcoord=False, - withhash=False, - count=None, - sort=None, - store=None, - store_dist=None, - any=False, - ): + name: KeyT, + longitude: float, + latitude: float, + radius: float, + unit: str | None = None, + withdist: bool = False, + withcoord: bool = False, + withhash: bool = False, + count: int | None = None, + sort: str | None = None, + store: KeyT | None = None, + store_dist: KeyT | None = None, + any: bool = False, + ) -> ResponseT: """ Return the members of the specified key identified by the ``name`` argument which are within the borders of the area specified @@ -4092,19 +4705,19 @@ def georadius( def georadiusbymember( self, - name, - member, - radius, - unit=None, - withdist=False, - withcoord=False, - withhash=False, - count=None, - sort=None, - store=None, - store_dist=None, - any=False, - ): + name: KeyT, + member: FieldT, + radius: float, + unit: str | None = None, + withdist: bool = False, + withcoord: bool = False, + withhash: bool = False, + count: int | None = None, + sort: str | None = None, + store: KeyT | None = None, + store_dist: KeyT | None = None, + any: bool = False, + ) -> ResponseT: """ This command is exactly like ``georadius`` with the sole difference that instead of taking, as the center of the area to query, a longitude @@ -4129,7 +4742,12 @@ def georadiusbymember( any=any, ) - def _georadiusgeneric(self, command, *args, **kwargs): + def _georadiusgeneric( + self, + command: str, + *args: EncodableT, + **kwargs: EncodableT | None, + ) -> ResponseT: pieces = list(args) if kwargs["unit"] and kwargs["unit"] not in ("m", "km", "mi", "ft"): raise DataError("GEORADIUS invalid unit") @@ -4177,21 +4795,21 @@ def _georadiusgeneric(self, command, *args, **kwargs): def geosearch( self, - name, - member=None, - longitude=None, - latitude=None, - unit="m", - radius=None, - width=None, - height=None, - sort=None, - count=None, - any=False, - withcoord=False, - withdist=False, - withhash=False, - ): + name: KeyT, + member: FieldT | None = None, + longitude: float | None = None, + latitude: float | None = None, + unit: str = "m", + radius: float | None = None, + width: float | None = None, + height: float | None = None, + sort: str | None = None, + count: int | None = None, + any: bool = False, + withcoord: bool = False, + withdist: bool = False, + withhash: bool = False, + ) -> ResponseT: """ Return the members of specified key identified by the ``name`` argument, which 
are within the borders of the @@ -4249,20 +4867,20 @@ def geosearch( def geosearchstore( self, - dest, - name, - member=None, - longitude=None, - latitude=None, - unit="m", - radius=None, - width=None, - height=None, - sort=None, - count=None, - any=False, - storedist=False, - ): + dest: KeyT, + name: KeyT, + member: FieldT | None = None, + longitude: float | None = None, + latitude: float | None = None, + unit: str = "m", + radius: float | None = None, + width: float | None = None, + height: float | None = None, + sort: str | None = None, + count: int | None = None, + any: bool = False, + storedist: bool = False, + ) -> ResponseT: """ This command is like GEOSEARCH, but stores the result in ``dest``. By default, it stores the results in the destination @@ -4294,7 +4912,12 @@ def geosearchstore( store_dist=storedist, ) - def _geosearchgeneric(self, command, *args, **kwargs): + def _geosearchgeneric( + self, + command: str, + *args: EncodableT, + **kwargs: EncodableT | None, + ) -> ResponseT: pieces = list(args) # FROMMEMBER or FROMLONLAT @@ -4359,13 +4982,16 @@ def _geosearchgeneric(self, command, *args, **kwargs): return self.execute_command(command, *pieces, **kwargs) -class ModuleCommands: +AsyncGeoCommands = GeoCommands + + +class ModuleCommands(CommandsProtocol): """ Redis Module commands. see: https://redis.io/topics/modules-intro """ - def module_load(self, path, *args): + def module_load(self, path, *args) -> ResponseT: """ Loads the module from ``path``. Passes all ``*args`` to the module, during loading. @@ -4375,7 +5001,7 @@ def module_load(self, path, *args): """ return self.execute_command("MODULE LOAD", path, *args) - def module_unload(self, name): + def module_unload(self, name) -> ResponseT: """ Unloads the module ``name``. Raises ``ModuleError`` if ``name`` is not in loaded modules. @@ -4384,7 +5010,7 @@ def module_unload(self, name): """ return self.execute_command("MODULE UNLOAD", name) - def module_list(self): + def module_list(self) -> ResponseT: """ Returns a list of dictionaries containing the name and version of all loaded modules. @@ -4393,27 +5019,32 @@ def module_list(self): """ return self.execute_command("MODULE LIST") - def command_info(self): + def command_info(self) -> None: raise NotImplementedError( "COMMAND INFO is intentionally not implemented in the client." ) - def command_count(self): + def command_count(self) -> ResponseT: return self.execute_command("COMMAND COUNT") - def command_getkeys(self, *args): + def command_getkeys(self, *args) -> ResponseT: return self.execute_command("COMMAND GETKEYS", *args) - def command(self): + def command(self) -> ResponseT: return self.execute_command("COMMAND") +class AsyncModuleCommands(ModuleCommands): + async def command_info(self) -> None: + return super().command_info() + + class Script: """ An executable Lua script object returned by ``register_script`` """ - def __init__(self, registered_client, script): + def __init__(self, registered_client: Redis, script: ScriptTextT): self.registered_client = registered_client self.script = script # Precalculate and store the SHA1 hex digest of the script. 
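For the geo commands above, note that ``geoadd`` consumes flat (longitude, latitude, member) triples, while ``geosearch`` exposes the Redis 6.2 GEOSEARCH variants (FROMMEMBER/FROMLONLAT, BYRADIUS/BYBOX) through keyword arguments. A sketch, assuming a Redis >= 6.2 server on localhost:

    import asyncio
    import redis.asyncio as redis

    async def main() -> None:
        r = redis.Redis()  # assumed Redis >= 6.2 on localhost
        # Values are flat (longitude, latitude, member) triples.
        await r.geoadd("stations", (13.361389, 38.115556, "palermo"))
        await r.geoadd("stations", (15.087269, 37.502669, "catania"))
        # FROMLONLAT + BYRADIUS search, with distances included.
        found = await r.geosearch(
            "stations", longitude=15, latitude=37,
            radius=200, unit="km", withdist=True,
        )
        print(found)  # expected to contain catania with its distance
        await r.close()

    asyncio.run(main())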
@@ -4425,8 +5056,15 @@ def __init__(self, registered_client, script): script = encoder.encode(script) self.sha = hashlib.sha1(script).hexdigest() - def __call__(self, keys=[], args=[], client=None): - "Execute the script, passing any required ``args``" + def __call__( + self, + keys: Sequence[KeyT] | None = None, + args: Iterable[EncodableT] | None = None, + client: Redis | None = None + ): + """Execute the script, passing any required ``args``""" + keys = keys or [] + args = args or [] if client is None: client = self.registered_client args = tuple(keys) + tuple(args) @@ -4446,15 +5084,63 @@ def __call__(self, keys=[], args=[], client=None): return client.evalsha(self.sha, len(keys), *args) +class AsyncScript: + """ + An executable Lua script object returned by ``register_script`` + """ + + def __init__(self, registered_client: AsyncRedis, script: ScriptTextT): + self.registered_client = registered_client + self.script = script + # Precalculate and store the SHA1 hex digest of the script. + + if isinstance(script, str): + # We need the encoding from the client in order to generate an + # accurate byte representation of the script + encoder = registered_client.connection_pool.get_encoder() + script = encoder.encode(script) + self.sha = hashlib.sha1(script).hexdigest() + + async def __call__( + self, + keys: Sequence[KeyT] | None = None, + args: Iterable[EncodableT] | None = None, + client: AsyncRedis | None = None, + ): + """Execute the script, passing any required ``args``""" + keys = keys or [] + args = args or [] + if client is None: + client = self.registered_client + args = tuple(keys) + tuple(args) + # make sure the Redis server knows about the script + from redis.asyncio.client import Pipeline + + if isinstance(client, Pipeline): + # Make sure the pipeline can register the script before executing. + client.scripts.add(self) + try: + return await client.evalsha(self.sha, len(keys), *args) + except NoScriptError: + # Maybe the client is pointed to a different server than the client + # that created this instance? + # Overwrite the sha just in case there was a discrepancy. + self.sha = await client.script_load(self.script) + return await client.evalsha(self.sha, len(keys), *args) + + class BitFieldOperation: """ Command builder for BITFIELD commands. """ - def __init__(self, client, key, default_overflow=None): + def __init__(self, client: Redis | AsyncRedis, key: str, default_overflow: str | None = None): self.client = client self.key = key self._default_overflow = default_overflow + # for typing purposes, run the following in constructor and in reset() + self.operations: list[tuple[EncodableT, ...]] = [] + self._last_overflow = "WRAP" self.reset() def reset(self): @@ -4465,7 +5151,7 @@ def reset(self): self._last_overflow = "WRAP" self.overflow(self._default_overflow or self._last_overflow) - def overflow(self, overflow): + def overflow(self, overflow: str): """ Update the overflow algorithm of successive INCRBY operations :param overflow: Overflow algorithm, one of WRAP, SAT, FAIL. See the @@ -4478,7 +5164,13 @@ def overflow(self, overflow): self.operations.append(("OVERFLOW", overflow)) return self - def incrby(self, fmt, offset, increment, overflow=None): + def incrby( + self, + fmt: str, + offset: BitfieldOffsetT, + increment: int, + overflow: str | None = None, + ): """ Increment a bitfield by a given amount. :param fmt: format-string for the bitfield being updated, e.g. 
'u8' @@ -4498,7 +5190,7 @@ def incrby(self, fmt, offset, increment, overflow=None): self.operations.append(("INCRBY", fmt, offset, increment)) return self - def get(self, fmt, offset): + def get(self, fmt: str, offset: BitfieldOffsetT): """ Get the value of a given bitfield. :param fmt: format-string for the bitfield being read, e.g. 'u8' for @@ -4511,7 +5203,7 @@ def get(self, fmt, offset): self.operations.append(("GET", fmt, offset)) return self - def set(self, fmt, offset, value): + def set(self, fmt: str, offset: BitfieldOffsetT, value: int): """ Set the value of a given bitfield. :param fmt: format-string for the bitfield being read, e.g. 'u8' for @@ -4532,7 +5224,7 @@ def command(self): cmd.extend(ops) return cmd - def execute(self): + def execute(self) -> ResponseT: """ Execute the operation(s) in a single BITFIELD command. The return value is a list of values corresponding to each operation. If the client @@ -4544,15 +5236,15 @@ def execute(self): return self.client.execute_command(*command) -class ClusterCommands: +class ClusterCommands(CommandsProtocol): """ Class for Redis Cluster commands """ - def cluster(self, cluster_arg, *args, **kwargs): + def cluster(self, cluster_arg, *args, **kwargs) -> ResponseT: return self.execute_command(f"CLUSTER {cluster_arg.upper()}", *args, **kwargs) - def readwrite(self, **kwargs): + def readwrite(self, **kwargs) -> ResponseT: """ Disables read queries for a connection to a Redis Cluster slave node. @@ -4560,7 +5252,7 @@ def readwrite(self, **kwargs): """ return self.execute_command("READWRITE", **kwargs) - def readonly(self, **kwargs): + def readonly(self, **kwargs) -> ResponseT: """ Enables read queries for a connection to a Redis Cluster replica node. @@ -4569,6 +5261,9 @@ def readonly(self, **kwargs): return self.execute_command("READONLY", **kwargs) +AsyncClusterCommands = ClusterCommands + + class DataAccessCommands( BasicKeyCommands, HyperlogCommands, @@ -4582,7 +5277,24 @@ class DataAccessCommands( ): """ A class containing all of the implemented data access redis commands. - This class is to be used as a mixin. + This class is to be used as a mixin for synchronous Redis clients. + """ + + +class AsyncDataAccessCommands( + AsyncBasicKeyCommands, + AsyncHyperlogCommands, + AsyncHashCommands, + AsyncGeoCommands, + AsyncListCommands, + AsyncScanCommands, + AsyncSetCommands, + AsyncStreamCommands, + AsyncSortedSetCommands, +): + """ + A class containing all of the implemented data access redis commands. + This class is to be used as a mixin for asynchronous Redis clients. """ @@ -4597,5 +5309,20 @@ class CoreCommands( ): """ A class containing all of the implemented redis commands. This class is - to be used as a mixin. + to be used as a mixin for synchronous Redis clients. + """ + + +class AsyncCoreCommands( + AsyncACLCommands, + AsyncClusterCommands, + AsyncDataAccessCommands, + AsyncManagementCommands, + AsyncModuleCommands, + AsyncPubSubCommands, + AsyncScriptCommands, +): + """ + A class containing all of the implemented redis commands. This class is + to be used as a mixin for asynchronous Redis clients. """ diff --git a/redis/commands/sentinel.py b/redis/commands/sentinel.py index a9b06c2f6e..bb12f14568 100644 --- a/redis/commands/sentinel.py +++ b/redis/commands/sentinel.py @@ -91,3 +91,9 @@ def sentinel_flushconfig(self): completely missing. """ return self.execute_command("SENTINEL FLUSHCONFIG") + + +class AsyncSentinelCommands(SentinelCommands): + async def sentinel(self, *args) -> None: + "Redis Sentinel's SENTINEL command." 
+ super().sentinel(*args) diff --git a/requirements.txt b/requirements.txt index b05ff454bf..001ecb6bfd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,4 @@ +async-timeout deprecated>=1.2.3 packaging>=20.4 +typing-extensions diff --git a/setup.py b/setup.py index 8b84c2a97a..1565a9e4d2 100644 --- a/setup.py +++ b/setup.py @@ -29,11 +29,12 @@ }, author="Redis Inc.", author_email="oss@redis.com", - python_requires=">=3.6", + python_requires=">=3.7", install_requires=[ "deprecated>=1.2.3", "packaging>=20.4", 'importlib-metadata >= 1.0; python_version < "3.8"', + "typing-extensions", ], classifiers=[ "Development Status :: 5 - Production/Stable", @@ -44,7 +45,6 @@ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", @@ -53,6 +53,7 @@ "Programming Language :: Python :: Implementation :: PyPy", ], extras_require={ + "async": ["async-timeout"], "hiredis": ["hiredis>=1.0.0"], "ocsp": ["cryptography>=36.0.1", "pyopenssl==20.0.1", "requests>=2.26.0"], }, diff --git a/tests/conftest.py b/tests/conftest.py index 9ba63d6caa..b6366e9d6b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ +import argparse import random import time +from typing import Callable, TypeVar from unittest.mock import Mock from urllib.parse import urlparse @@ -21,6 +23,54 @@ default_redis_ssl_url = "rediss://localhost:6666" default_cluster_nodes = 6 +_DecoratedTest = TypeVar("_DecoratedTest", bound="Callable") +_TestDecorator = Callable[[_DecoratedTest], _DecoratedTest] + + +# Taken from python3.9 +class BooleanOptionalAction(argparse.Action): + def __init__( + self, + option_strings, + dest, + default=None, + type=None, + choices=None, + required=False, + help=None, + metavar=None, + ): + + _option_strings = [] + for option_string in option_strings: + _option_strings.append(option_string) + + if option_string.startswith("--"): + option_string = "--no-" + option_string[2:] + _option_strings.append(option_string) + + if help is not None and default is not None: + help += f" (default: {default})" + + super().__init__( + option_strings=_option_strings, + dest=dest, + nargs=0, + default=default, + type=type, + choices=choices, + required=required, + help=help, + metavar=metavar, + ) + + def __call__(self, parser, namespace, values, option_string=None): + if option_string in self.option_strings: + setattr(namespace, self.dest, not option_string.startswith("--no-")) + + def format_usage(self): + return " | ".join(self.option_strings) + def pytest_addoption(parser): parser.addoption( @@ -62,6 +112,9 @@ def pytest_addoption(parser): help="Redis unstable (latest version) connection string " "defaults to %(default)s`", ) + parser.addoption( + "--uvloop", action=BooleanOptionalAction, help="Run tests with uvloop" + ) def _get_info(redis_url): @@ -101,6 +154,18 @@ def pytest_sessionstart(session): cluster_nodes = session.config.getoption("--redis-cluster-nodes") wait_for_cluster_creation(redis_url, cluster_nodes) + use_uvloop = session.config.getoption("--uvloop") + + if use_uvloop: + try: + import uvloop + + uvloop.install() + except ImportError as e: + raise RuntimeError( + "Can not import uvloop, make sure it is installed" + ) from e + def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=20): """ @@ -133,19 +198,19 @@ def wait_for_cluster_creation(redis_url, 
cluster_nodes, timeout=20): ) -def skip_if_server_version_lt(min_version): +def skip_if_server_version_lt(min_version: str) -> _TestDecorator: redis_version = REDIS_INFO["version"] check = Version(redis_version) < Version(min_version) return pytest.mark.skipif(check, reason=f"Redis version required >= {min_version}") -def skip_if_server_version_gte(min_version): +def skip_if_server_version_gte(min_version: str) -> _TestDecorator: redis_version = REDIS_INFO["version"] check = Version(redis_version) >= Version(min_version) return pytest.mark.skipif(check, reason=f"Redis version required < {min_version}") -def skip_unless_arch_bits(arch_bits): +def skip_unless_arch_bits(arch_bits: int) -> _TestDecorator: return pytest.mark.skipif( REDIS_INFO["arch_bits"] != arch_bits, reason=f"server is not {arch_bits}-bit" ) @@ -169,17 +234,17 @@ def skip_ifmodversion_lt(min_version: str, module_name: str): raise AttributeError(f"No redis module named {module_name}") -def skip_if_redis_enterprise(): +def skip_if_redis_enterprise() -> _TestDecorator: check = REDIS_INFO["enterprise"] is True return pytest.mark.skipif(check, reason="Redis enterprise") -def skip_ifnot_redis_enterprise(): +def skip_ifnot_redis_enterprise() -> _TestDecorator: check = REDIS_INFO["enterprise"] is False return pytest.mark.skipif(check, reason="Not running in redis enterprise") -def skip_if_nocryptography(): +def skip_if_nocryptography() -> _TestDecorator: try: import cryptography # noqa @@ -188,7 +253,7 @@ def skip_if_nocryptography(): return pytest.mark.skipif(True, reason="No cryptography dependency") -def skip_if_cryptography(): +def skip_if_cryptography() -> _TestDecorator: try: import cryptography # noqa diff --git a/tests/test_asyncio/__init__.py b/tests/test_asyncio/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_asyncio/compat.py b/tests/test_asyncio/compat.py new file mode 100644 index 0000000000..ced4974196 --- /dev/null +++ b/tests/test_asyncio/compat.py @@ -0,0 +1,6 @@ +from unittest import mock + +try: + mock.AsyncMock +except AttributeError: + import mock diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py new file mode 100644 index 0000000000..0657088eb5 --- /dev/null +++ b/tests/test_asyncio/conftest.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +import asyncio +import random +from packaging.version import Version +from urllib.parse import urlparse + +import pytest + +import redis.asyncio as redis +from redis.asyncio.client import Monitor +from redis.asyncio.connection import ( + HIREDIS_AVAILABLE, + HiredisParser, + PythonParser, + parse_url, +) + +from tests.conftest import REDIS_INFO +from .compat import mock + + +async def _get_info(redis_url): + client = redis.Redis.from_url(redis_url) + info = await client.info() + await client.connection_pool.disconnect() + return info + + +@pytest.fixture( + params=[ + (True, PythonParser), + (False, PythonParser), + pytest.param( + (True, HiredisParser), + marks=pytest.mark.skipif( + not HIREDIS_AVAILABLE, reason="hiredis is not installed" + ), + ), + pytest.param( + (False, HiredisParser), + marks=pytest.mark.skipif( + not HIREDIS_AVAILABLE, reason="hiredis is not installed" + ), + ), + ], + ids=[ + "single-python-parser", + "pool-python-parser", + "single-hiredis", + "pool-hiredis", + ], +) +def create_redis(request, event_loop: asyncio.BaseEventLoop): + """Wrapper around redis.create_redis.""" + single_connection, parser_cls = request.param + + async def f(url: str = 
request.config.getoption("--redis-url"), **kwargs): + single = kwargs.pop("single_connection_client", False) or single_connection + parser_class = kwargs.pop("parser_class", None) or parser_cls + url_options = parse_url(url) + url_options.update(kwargs) + pool = redis.ConnectionPool(parser_class=parser_class, **url_options) + client: redis.Redis = redis.Redis(connection_pool=pool) + if single: + client = client.client() + await client.initialize() + + def teardown(): + async def ateardown(): + if "username" in kwargs: + return + try: + await client.flushdb() + except redis.ConnectionError: + # handle cases where a test disconnected a client + # just manually retry the flushdb + await client.flushdb() + await client.close() + await client.connection_pool.disconnect() + + if event_loop.is_running(): + event_loop.create_task(ateardown()) + else: + event_loop.run_until_complete(ateardown()) + + request.addfinalizer(teardown) + + return client + + return f + + +@pytest.fixture() +async def r(create_redis): + yield await create_redis() + + +@pytest.fixture() +async def r2(create_redis): + """A second client for tests that need multiple""" + yield await create_redis() + + +def _gen_cluster_mock_resp(r, response): + connection = mock.AsyncMock() + connection.read_response.return_value = response + r.connection = connection + return r + + +@pytest.fixture() +async def mock_cluster_resp_ok(create_redis, **kwargs): + r = await create_redis(**kwargs) + return _gen_cluster_mock_resp(r, "OK") + + +@pytest.fixture() +async def mock_cluster_resp_int(create_redis, **kwargs): + r = await create_redis(**kwargs) + return _gen_cluster_mock_resp(r, "2") + + +@pytest.fixture() +async def mock_cluster_resp_info(create_redis, **kwargs): + r = await create_redis(**kwargs) + response = ( + "cluster_state:ok\r\ncluster_slots_assigned:16384\r\n" + "cluster_slots_ok:16384\r\ncluster_slots_pfail:0\r\n" + "cluster_slots_fail:0\r\ncluster_known_nodes:7\r\n" + "cluster_size:3\r\ncluster_current_epoch:7\r\n" + "cluster_my_epoch:2\r\ncluster_stats_messages_sent:170262\r\n" + "cluster_stats_messages_received:105653\r\n" + ) + return _gen_cluster_mock_resp(r, response) + + +@pytest.fixture() +async def mock_cluster_resp_nodes(create_redis, **kwargs): + r = await create_redis(**kwargs) + response = ( + "c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 " + "slave aa90da731f673a99617dfe930306549a09f83a6b 0 " + "1447836263059 5 connected\n" + "9bd595fe4821a0e8d6b99d70faa660638a7612b3 172.17.0.7:7008 " + "master - 0 1447836264065 0 connected\n" + "aa90da731f673a99617dfe930306549a09f83a6b 172.17.0.7:7003 " + "myself,master - 0 0 2 connected 5461-10922\n" + "1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 " + "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " + "1447836262556 3 connected\n" + "4ad9a12e63e8f0207025eeba2354bcf4c85e5b22 172.17.0.7:7005 " + "master - 0 1447836262555 7 connected 0-5460\n" + "19efe5a631f3296fdf21a5441680f893e8cc96ec 172.17.0.7:7004 " + "master - 0 1447836263562 3 connected 10923-16383\n" + "fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 " + "master,fail - 1447829446956 1447829444948 1 disconnected\n" + ) + return _gen_cluster_mock_resp(r, response) + + +@pytest.fixture() +async def mock_cluster_resp_slaves(create_redis, **kwargs): + r = await create_redis(**kwargs) + response = ( + "['1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 " + "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " + "1447836789290 3 connected']" + ) + return _gen_cluster_mock_resp(r, response) + + 
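The ``create_redis`` factory above parametrizes every test over parser implementation and single-connection vs. pooled clients, so individual tests stay tiny. A hypothetical test module consuming these fixtures might look like the following (``test_ping`` is an invented name; the ``r`` fixture and the ``pytestmark`` convention are the ones defined in this conftest and used by the test files below):

    import pytest

    pytestmark = pytest.mark.asyncio

    async def test_ping(r):
        # ``r`` is an initialized redis.asyncio.Redis from the fixture
        # above; parser and pooling variants come from its params.
        assert await r.ping() is True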
+@pytest.fixture(scope="session") +def master_host(request): + url = request.config.getoption("--redis-url") + parts = urlparse(url) + yield parts.hostname + + +async def wait_for_command( + client: redis.Redis, monitor: Monitor, command: str, key: str | None = None +): + # issue a command with a key name that's local to this process. + # if we find a command with our key before the command we're waiting + # for, something went wrong + if key is None: + # generate key + redis_version = REDIS_INFO["version"] + if Version(redis_version) >= Version("5.0.0"): + id_str = str(client.client_id()) + else: + id_str = f"{random.randrange(2 ** 32):08x}" + key = f"__REDIS-PY-{id_str}__" + await client.get(key) + while True: + monitor_response = await monitor.next_command() + if command in monitor_response["command"]: + return monitor_response + if key in monitor_response["command"]: + return None diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py new file mode 100644 index 0000000000..12855c9326 --- /dev/null +++ b/tests/test_asyncio/test_commands.py @@ -0,0 +1,3 @@ +""" +Tests async overrides of commands from their mixins +""" diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py new file mode 100644 index 0000000000..25ec0d8a49 --- /dev/null +++ b/tests/test_asyncio/test_connection.py @@ -0,0 +1,61 @@ +import asyncio +import types + +import pytest + +from redis.asyncio.connection import PythonParser, UnixDomainSocketConnection +from redis.exceptions import InvalidResponse +from redis.utils import HIREDIS_AVAILABLE +from tests.conftest import skip_if_server_version_lt + +from .compat import mock + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.onlynoncluster +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_invalid_response(create_redis): + r = await create_redis(single_connection_client=True) + + raw = b"x" + readline_mock = mock.AsyncMock(return_value=raw) + + parser: "PythonParser" = r.connection._parser + with mock.patch.object(parser._buffer, "readline", readline_mock): + with pytest.raises(InvalidResponse) as cm: + await parser.read_response() + assert str(cm.value) == "Protocol Error: %r" % raw + + +@skip_if_server_version_lt("4.0.0") +@pytest.mark.redismod +async def test_loading_external_modules(modclient): + def inner(): + pass + + modclient.load_external_module("myfuncname", inner) + assert getattr(modclient, "myfuncname") == inner + assert isinstance(getattr(modclient, "myfuncname"), types.FunctionType) + + # and call it + from redis.commands import RedisModuleCommands + + j = RedisModuleCommands.json + modclient.load_external_module("sometestfuncname", j) + + # d = {'hello': 'world!'} + # mod = j(modclient) + # mod.set("fookey", ".", d) + # assert mod.get('fookey') == d + + +async def test_socket_param_regression(r): + """A regression test for issue #1060""" + conn = UnixDomainSocketConnection() + _ = await conn.disconnect() is True + + +async def test_can_run_concurrent_commands(r): + assert await r.ping() is True + assert all(await asyncio.gather(*(r.ping() for _ in range(10)))) diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py new file mode 100644 index 0000000000..36b5226a21 --- /dev/null +++ b/tests/test_asyncio/test_connection_pool.py @@ -0,0 +1,802 @@ +import asyncio +import os +import re + +import pytest + +import redis.asyncio as redis +from redis.asyncio.connection import Connection, to_bool +from redis import 
exceptions + +from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt +from .compat import mock +from .test_pubsub import wait_for_message + +pytestmark = pytest.mark.asyncio + + +class DummyConnection(Connection): + description_format = "DummyConnection<>" + + def __init__(self, **kwargs): + self.kwargs = kwargs + self.pid = os.getpid() + + async def connect(self): + pass + + async def disconnect(self): + pass + + async def can_read(self, timeout: float = 0): + return False + + +class TestConnectionPool: + def get_pool( + self, + connection_kwargs=None, + max_connections=None, + connection_class=redis.Connection, + ): + connection_kwargs = connection_kwargs or {} + pool = redis.ConnectionPool( + connection_class=connection_class, + max_connections=max_connections, + **connection_kwargs, + ) + return pool + + async def test_connection_creation(self): + connection_kwargs = {"foo": "bar", "biz": "baz"} + pool = self.get_pool( + connection_kwargs=connection_kwargs, connection_class=DummyConnection + ) + connection = await pool.get_connection("_") + assert isinstance(connection, DummyConnection) + assert connection.kwargs == connection_kwargs + + async def test_multiple_connections(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(connection_kwargs=connection_kwargs) + c1 = await pool.get_connection("_") + c2 = await pool.get_connection("_") + assert c1 != c2 + + async def test_max_connections(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(max_connections=2, connection_kwargs=connection_kwargs) + await pool.get_connection("_") + await pool.get_connection("_") + with pytest.raises(exceptions.ConnectionError): + await pool.get_connection("_") + + async def test_reuse_previously_released_connection(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(connection_kwargs=connection_kwargs) + c1 = await pool.get_connection("_") + await pool.release(c1) + c2 = await pool.get_connection("_") + assert c1 == c2 + + def test_repr_contains_db_info_tcp(self): + connection_kwargs = { + "host": "localhost", + "port": 6379, + "db": 1, + "client_name": "test-client", + } + pool = self.get_pool( + connection_kwargs=connection_kwargs, connection_class=redis.Connection + ) + expected = ( + "ConnectionPool>" + ) + assert repr(pool) == expected + + def test_repr_contains_db_info_unix(self): + connection_kwargs = {"path": "/abc", "db": 1, "client_name": "test-client"} + pool = self.get_pool( + connection_kwargs=connection_kwargs, + connection_class=redis.UnixDomainSocketConnection, + ) + expected = ( + "ConnectionPool>" + ) + assert repr(pool) == expected + + +class TestBlockingConnectionPool: + def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): + connection_kwargs = connection_kwargs or {} + pool = redis.BlockingConnectionPool( + connection_class=DummyConnection, + max_connections=max_connections, + timeout=timeout, + **connection_kwargs, + ) + return pool + + async def test_connection_creation(self, master_host): + connection_kwargs = { + "foo": "bar", + "biz": "baz", + "host": master_host[0], + "port": master_host[1], + } + pool = self.get_pool(connection_kwargs=connection_kwargs) + connection = await pool.get_connection("_") + assert isinstance(connection, DummyConnection) + assert connection.kwargs == connection_kwargs + + async def test_disconnect(self, master_host): + """A regression test for #1047""" + connection_kwargs = { + "foo": "bar", + "biz": 
"baz", + "host": master_host[0], + "port": master_host[1], + } + pool = self.get_pool(connection_kwargs=connection_kwargs) + await pool.get_connection("_") + await pool.disconnect() + + async def test_multiple_connections(self, master_host): + connection_kwargs = {"host": master_host[0], "port": master_host[1]} + pool = self.get_pool(connection_kwargs=connection_kwargs) + c1 = await pool.get_connection("_") + c2 = await pool.get_connection("_") + assert c1 != c2 + + async def test_connection_pool_blocks_until_timeout(self, master_host): + """When out of connections, block for timeout seconds, then raise""" + connection_kwargs = {"host": master_host} + pool = self.get_pool( + max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs + ) + await pool.get_connection("_") + + start = asyncio.get_event_loop().time() + with pytest.raises(exceptions.ConnectionError): + await pool.get_connection("_") + # we should have waited at least 0.1 seconds + assert asyncio.get_event_loop().time() - start >= 0.1 + + async def test_connection_pool_blocks_until_conn_available(self, master_host): + """ + When out of connections, block until another connection is released + to the pool + """ + connection_kwargs = {"host": master_host[0], "port": master_host[1]} + pool = self.get_pool( + max_connections=1, timeout=2, connection_kwargs=connection_kwargs + ) + c1 = await pool.get_connection("_") + + async def target(): + await asyncio.sleep(0.1) + await pool.release(c1) + + start = asyncio.get_event_loop().time() + await asyncio.gather(target(), pool.get_connection("_")) + assert asyncio.get_event_loop().time() - start >= 0.1 + + async def test_reuse_previously_released_connection(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(connection_kwargs=connection_kwargs) + c1 = await pool.get_connection("_") + await pool.release(c1) + c2 = await pool.get_connection("_") + assert c1 == c2 + + def test_repr_contains_db_info_tcp(self): + pool = redis.ConnectionPool( + host="localhost", port=6379, client_name="test-client" + ) + expected = ( + "ConnectionPool>" + ) + assert repr(pool) == expected + + def test_repr_contains_db_info_unix(self): + pool = redis.ConnectionPool( + connection_class=redis.UnixDomainSocketConnection, + path="abc", + client_name="test-client", + ) + expected = ( + "ConnectionPool>" + ) + assert repr(pool) == expected + + +class TestConnectionPoolURLParsing: + def test_hostname(self): + pool = redis.ConnectionPool.from_url("redis://my.host") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "my.host", + } + + def test_quoted_hostname(self): + pool = redis.ConnectionPool.from_url("redis://my %2F host %2B%3D+") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "my / host +=+", + } + + def test_port(self): + pool = redis.ConnectionPool.from_url("redis://localhost:6380") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "port": 6380, + } + + @skip_if_server_version_lt("6.0.0") + def test_username(self): + pool = redis.ConnectionPool.from_url("redis://myuser:@localhost") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "username": "myuser", + } + + @skip_if_server_version_lt("6.0.0") + def test_quoted_username(self): + pool = redis.ConnectionPool.from_url( + "redis://%2Fmyuser%2F%2B name%3D%24+:@localhost" + ) + assert pool.connection_class 
== redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "username": "/myuser/+ name=$+", + } + + def test_password(self): + pool = redis.ConnectionPool.from_url("redis://:mypassword@localhost") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "password": "mypassword", + } + + def test_quoted_password(self): + pool = redis.ConnectionPool.from_url( + "redis://:%2Fmypass%2F%2B word%3D%24+@localhost" + ) + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "password": "/mypass/+ word=$+", + } + + @skip_if_server_version_lt("6.0.0") + def test_username_and_password(self): + pool = redis.ConnectionPool.from_url("redis://myuser:mypass@localhost") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "username": "myuser", + "password": "mypass", + } + + def test_db_as_argument(self): + pool = redis.ConnectionPool.from_url("redis://localhost", db=1) + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 1, + } + + def test_db_in_path(self): + pool = redis.ConnectionPool.from_url("redis://localhost/2", db=1) + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 2, + } + + def test_db_in_querystring(self): + pool = redis.ConnectionPool.from_url("redis://localhost/2?db=3", db=1) + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 3, + } + + def test_extra_typed_querystring_options(self): + pool = redis.ConnectionPool.from_url( + "redis://localhost/2?socket_timeout=20&socket_connect_timeout=10" + "&socket_keepalive=&retry_on_timeout=Yes&max_connections=10" + ) + + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 2, + "socket_timeout": 20.0, + "socket_connect_timeout": 10.0, + "retry_on_timeout": True, + } + assert pool.max_connections == 10 + + def test_boolean_parsing(self): + for expected, value in ( + (None, None), + (None, ""), + (False, 0), + (False, "0"), + (False, "f"), + (False, "F"), + (False, "False"), + (False, "n"), + (False, "N"), + (False, "No"), + (True, 1), + (True, "1"), + (True, "y"), + (True, "Y"), + (True, "Yes"), + ): + assert expected is to_bool(value) + + def test_client_name_in_querystring(self): + pool = redis.ConnectionPool.from_url( + "redis://location?client_name=test-client" + ) + assert pool.connection_kwargs["client_name"] == "test-client" + + def test_invalid_extra_typed_querystring_options(self): + with pytest.raises(ValueError): + redis.ConnectionPool.from_url( + "redis://localhost/2?socket_timeout=_&" "socket_connect_timeout=abc" + ) + + def test_extra_querystring_options(self): + pool = redis.ConnectionPool.from_url("redis://localhost?a=1&b=2") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == {"host": "localhost", "a": "1", "b": "2"} + + def test_calling_from_subclass_returns_correct_instance(self): + pool = redis.BlockingConnectionPool.from_url("redis://localhost") + assert isinstance(pool, redis.BlockingConnectionPool) + + def test_client_creates_connection_pool(self): + r = redis.Redis.from_url("redis://myhost") + assert r.connection_pool.connection_class == redis.Connection + assert r.connection_pool.connection_kwargs == { + "host": "myhost", + } + + def 
test_invalid_scheme_raises_error(self): + with pytest.raises(ValueError) as cm: + redis.ConnectionPool.from_url("localhost") + assert str(cm.value) == ( + "Redis URL must specify one of the following schemes " + "(redis://, rediss://, unix://)" + ) + + +class TestConnectionPoolUnixSocketURLParsing: + def test_defaults(self): + pool = redis.ConnectionPool.from_url("unix:///socket") + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + } + + @skip_if_server_version_lt("6.0.0") + def test_username(self): + pool = redis.ConnectionPool.from_url("unix://myuser:@/socket") + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "username": "myuser", + } + + @skip_if_server_version_lt("6.0.0") + def test_quoted_username(self): + pool = redis.ConnectionPool.from_url( + "unix://%2Fmyuser%2F%2B name%3D%24+:@/socket" + ) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "username": "/myuser/+ name=$+", + } + + def test_password(self): + pool = redis.ConnectionPool.from_url("unix://:mypassword@/socket") + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "password": "mypassword", + } + + def test_quoted_password(self): + pool = redis.ConnectionPool.from_url( + "unix://:%2Fmypass%2F%2B word%3D%24+@/socket" + ) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "password": "/mypass/+ word=$+", + } + + def test_quoted_path(self): + pool = redis.ConnectionPool.from_url( + "unix://:mypassword@/my%2Fpath%2Fto%2F..%2F+_%2B%3D%24ocket" + ) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/my/path/to/../+_+=$ocket", + "password": "mypassword", + } + + def test_db_as_argument(self): + pool = redis.ConnectionPool.from_url("unix:///socket", db=1) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "db": 1, + } + + def test_db_in_querystring(self): + pool = redis.ConnectionPool.from_url("unix:///socket?db=2", db=1) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "db": 2, + } + + def test_client_name_in_querystring(self): + pool = redis.ConnectionPool.from_url( + "redis://location?client_name=test-client" + ) + assert pool.connection_kwargs["client_name"] == "test-client" + + def test_extra_querystring_options(self): + pool = redis.ConnectionPool.from_url("unix:///socket?a=1&b=2") + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == {"path": "/socket", "a": "1", "b": "2"} + + +class TestSSLConnectionURLParsing: + def test_host(self): + pool = redis.ConnectionPool.from_url("rediss://my.host") + assert pool.connection_class == redis.SSLConnection + assert pool.connection_kwargs == { + "host": "my.host", + } + + def test_cert_reqs_options(self): + import ssl + + class DummyConnectionPool(redis.ConnectionPool): + def get_connection(self, *args, **kwargs): + return self.make_connection() + + pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=none") + assert pool.get_connection("_").cert_reqs == ssl.CERT_NONE + + pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=optional") 
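+ # the "rediss" URL option parsing is expected to map each ssl_cert_reqs value onto the matching ssl-module constant ("none" -> ssl.CERT_NONE, "optional" -> ssl.CERT_OPTIONAL, "required" -> ssl.CERT_REQUIRED); each assert here reads that constant back off a freshly made connection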
+ assert pool.get_connection("_").cert_reqs == ssl.CERT_OPTIONAL + + pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=required") + assert pool.get_connection("_").cert_reqs == ssl.CERT_REQUIRED + + pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=False") + assert pool.get_connection("_").check_hostname is False + + pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=True") + assert pool.get_connection("_").check_hostname is True + + +class TestConnection: + async def test_on_connect_error(self): + """ + An error in Connection.on_connect should disconnect from the server + see for details: https://github.com/andymccurdy/redis-py/issues/368 + """ + # this assumes the Redis server being tested against doesn't have + # 9999 databases ;) + bad_connection = redis.Redis(db=9999) + # an error should be raised on connect + with pytest.raises(exceptions.RedisError): + await bad_connection.info() + pool = bad_connection.connection_pool + assert len(pool._available_connections) == 1 + assert not pool._available_connections[0]._reader + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.8.8") + @skip_if_redis_enterprise() + async def test_busy_loading_disconnects_socket(self, r): + """ + If Redis raises a LOADING error, the connection should be + disconnected and a BusyLoadingError raised + """ + with pytest.raises(exceptions.BusyLoadingError): + await r.execute_command("DEBUG", "ERROR", "LOADING fake message") + if r.connection: + assert not r.connection._reader + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.8.8") + @skip_if_redis_enterprise() + async def test_busy_loading_from_pipeline_immediate_command(self, r): + """ + BusyLoadingErrors should raise from Pipelines that execute a + command immediately, like WATCH does. + """ + pipe = r.pipeline() + with pytest.raises(exceptions.BusyLoadingError): + await pipe.immediate_execute_command( + "DEBUG", "ERROR", "LOADING fake message" + ) + pool = r.connection_pool + assert not pipe.connection + assert len(pool._available_connections) == 1 + assert not pool._available_connections[0]._reader + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.8.8") + @skip_if_redis_enterprise() + async def test_busy_loading_from_pipeline(self, r): + """ + BusyLoadingErrors should be raised from a pipeline execution + regardless of the raise_on_error flag. 
+ """ + pipe = r.pipeline() + pipe.execute_command("DEBUG", "ERROR", "LOADING fake message") + with pytest.raises(exceptions.BusyLoadingError): + await pipe.execute() + pool = r.connection_pool + assert not pipe.connection + assert len(pool._available_connections) == 1 + assert not pool._available_connections[0]._reader + + @skip_if_server_version_lt("2.8.8") + @skip_if_redis_enterprise() + async def test_read_only_error(self, r): + """READONLY errors get turned in ReadOnlyError exceptions""" + with pytest.raises(exceptions.ReadOnlyError): + await r.execute_command("DEBUG", "ERROR", "READONLY blah blah") + + def test_connect_from_url_tcp(self): + connection = redis.Redis.from_url("redis://localhost") + pool = connection.connection_pool + + assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == ( + "ConnectionPool", + "Connection", + "host=localhost,port=6379,db=0", + ) + + def test_connect_from_url_unix(self): + connection = redis.Redis.from_url("unix:///path/to/socket") + pool = connection.connection_pool + + assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == ( + "ConnectionPool", + "UnixDomainSocketConnection", + "path=/path/to/socket,db=0", + ) + + @skip_if_redis_enterprise() + async def test_connect_no_auth_supplied_when_required(self, r): + """ + AuthenticationError should be raised when the server requires a + password but one isn't supplied. + """ + with pytest.raises(exceptions.AuthenticationError): + await r.execute_command( + "DEBUG", "ERROR", "ERR Client sent AUTH, but no password is set" + ) + + @skip_if_redis_enterprise() + async def test_connect_invalid_password_supplied(self, r): + """AuthenticationError should be raised when sending the wrong password""" + with pytest.raises(exceptions.AuthenticationError): + await r.execute_command("DEBUG", "ERROR", "ERR invalid password") + + +@pytest.mark.onlynoncluster +class TestMultiConnectionClient: + @pytest.fixture() + async def r(self, create_redis, server): + redis = await create_redis(single_connection_client=False) + yield redis + await redis.flushall() + + +@pytest.mark.onlynoncluster +class TestHealthCheck: + interval = 60 + + @pytest.fixture() + async def r(self, create_redis): + redis = await create_redis(health_check_interval=self.interval) + yield redis + await redis.flushall() + + def assert_interval_advanced(self, connection): + diff = connection.next_health_check - asyncio.get_event_loop().time() + assert self.interval > diff > (self.interval - 1) + + async def test_health_check_runs(self, r): + if r.connection: + r.connection.next_health_check = asyncio.get_event_loop().time() - 1 + await r.connection.check_health() + self.assert_interval_advanced(r.connection) + + async def test_arbitrary_command_invokes_health_check(self, r): + # invoke a command to make sure the connection is entirely setup + if r.connection: + await r.get("foo") + r.connection.next_health_check = asyncio.get_event_loop().time() + with mock.patch.object( + r.connection, "send_command", wraps=r.connection.send_command + ) as m: + await r.get("foo") + m.assert_called_with("PING", check_health=False) + + self.assert_interval_advanced(r.connection) + + async def test_arbitrary_command_advances_next_health_check(self, r): + if r.connection: + await r.get("foo") + next_health_check = r.connection.next_health_check + await r.get("foo") + assert next_health_check < r.connection.next_health_check + + async def test_health_check_not_invoked_within_interval(self, r): + if r.connection: + await r.get("foo") + with mock.patch.object( + 
r.connection, "send_command", wraps=r.connection.send_command + ) as m: + await r.get("foo") + ping_call_spec = (("PING",), {"check_health": False}) + assert ping_call_spec not in m.call_args_list + + async def test_health_check_in_pipeline(self, r): + async with r.pipeline(transaction=False) as pipe: + pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection.next_health_check = 0 + with mock.patch.object( + pipe.connection, "send_command", wraps=pipe.connection.send_command + ) as m: + responses = await pipe.set("foo", "bar").get("foo").execute() + m.assert_any_call("PING", check_health=False) + assert responses == [True, b"bar"] + + async def test_health_check_in_transaction(self, r): + async with r.pipeline(transaction=True) as pipe: + pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection.next_health_check = 0 + with mock.patch.object( + pipe.connection, "send_command", wraps=pipe.connection.send_command + ) as m: + responses = await pipe.set("foo", "bar").get("foo").execute() + m.assert_any_call("PING", check_health=False) + assert responses == [True, b"bar"] + + async def test_health_check_in_watched_pipeline(self, r): + await r.set("foo", "bar") + async with r.pipeline(transaction=False) as pipe: + pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection.next_health_check = 0 + with mock.patch.object( + pipe.connection, "send_command", wraps=pipe.connection.send_command + ) as m: + await pipe.watch("foo") + # the health check should be called when watching + m.assert_called_with("PING", check_health=False) + self.assert_interval_advanced(pipe.connection) + assert await pipe.get("foo") == b"bar" + + # reset the mock to clear the call list and schedule another + # health check + m.reset_mock() + pipe.connection.next_health_check = 0 + + pipe.multi() + responses = await pipe.set("foo", "not-bar").get("foo").execute() + assert responses == [True, b"not-bar"] + m.assert_any_call("PING", check_health=False) + + async def test_health_check_in_pubsub_before_subscribe(self, r): + """A health check happens before the first [p]subscribe""" + p = r.pubsub() + p.connection = await p.connection_pool.get_connection("_") + p.connection.next_health_check = 0 + with mock.patch.object( + p.connection, "send_command", wraps=p.connection.send_command + ) as m: + assert not p.subscribed + await p.subscribe("foo") + # the connection is not yet in pubsub mode, so the normal + # ping/pong within connection.send_command should check + # the health of the connection + m.assert_any_call("PING", check_health=False) + self.assert_interval_advanced(p.connection) + + subscribe_message = await wait_for_message(p) + assert subscribe_message["type"] == "subscribe" + + async def test_health_check_in_pubsub_after_subscribed(self, r): + """ + Pubsub can handle a new subscribe when it's time to check the + connection health + """ + p = r.pubsub() + p.connection = await p.connection_pool.get_connection("_") + p.connection.next_health_check = 0 + with mock.patch.object( + p.connection, "send_command", wraps=p.connection.send_command + ) as m: + await p.subscribe("foo") + subscribe_message = await wait_for_message(p) + assert subscribe_message["type"] == "subscribe" + self.assert_interval_advanced(p.connection) + # because we weren't subscribed when sending the subscribe + # message to 'foo', the connection's standard check_health ran + # prior to subscribing. 
+ m.assert_any_call("PING", check_health=False) + + p.connection.next_health_check = 0 + m.reset_mock() + + await p.subscribe("bar") + # the second subscribe issues exactly only command (the subscribe) + # and the health check is not invoked + m.assert_called_once_with("SUBSCRIBE", "bar", check_health=False) + + # since no message has been read since the health check was + # reset, it should still be 0 + assert p.connection.next_health_check == 0 + + subscribe_message = await wait_for_message(p) + assert subscribe_message["type"] == "subscribe" + assert await wait_for_message(p) is None + # now that the connection is subscribed, the pubsub health + # check should have taken over and include the HEALTH_CHECK_MESSAGE + m.assert_any_call("PING", p.HEALTH_CHECK_MESSAGE, check_health=False) + self.assert_interval_advanced(p.connection) + + async def test_health_check_in_pubsub_poll(self, r): + """ + Polling a pubsub connection that's subscribed will regularly + check the connection's health. + """ + p = r.pubsub() + p.connection = await p.connection_pool.get_connection("_") + with mock.patch.object( + p.connection, "send_command", wraps=p.connection.send_command + ) as m: + await p.subscribe("foo") + subscribe_message = await wait_for_message(p) + assert subscribe_message["type"] == "subscribe" + self.assert_interval_advanced(p.connection) + + # polling the connection before the health check interval + # doesn't result in another health check + m.reset_mock() + next_health_check = p.connection.next_health_check + assert await wait_for_message(p) is None + assert p.connection.next_health_check == next_health_check + m.assert_not_called() + + # reset the health check and poll again + # we should not receive a pong message, but the next_health_check + # should be advanced + p.connection.next_health_check = 0 + assert await wait_for_message(p) is None + m.assert_called_with("PING", p.HEALTH_CHECK_MESSAGE, check_health=False) + self.assert_interval_advanced(p.connection) diff --git a/tests/test_asyncio/test_encoding.py b/tests/test_asyncio/test_encoding.py new file mode 100644 index 0000000000..c3c69f5055 --- /dev/null +++ b/tests/test_asyncio/test_encoding.py @@ -0,0 +1,116 @@ +import pytest + +import redis.asyncio as redis +from redis.exceptions import DataError + +pytestmark = pytest.mark.asyncio + + +class TestEncoding: + @pytest.fixture() + async def r(self, create_redis): + redis = await create_redis(decode_responses=True) + yield redis + await redis.flushall() + + @pytest.fixture() + async def r_no_decode(self, create_redis): + redis = await create_redis(decode_responses=False) + yield redis + await redis.flushall() + + async def test_simple_encoding(self, r_no_decode: redis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + await r_no_decode.set("unicode-string", unicode_string.encode("utf-8")) + cached_val = await r_no_decode.get("unicode-string") + assert isinstance(cached_val, bytes) + assert unicode_string == cached_val.decode("utf-8") + + async def test_simple_encoding_and_decoding(self, r: redis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + await r.set("unicode-string", unicode_string) + cached_val = await r.get("unicode-string") + assert isinstance(cached_val, str) + assert unicode_string == cached_val + + async def test_memoryview_encoding(self, r_no_decode: redis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + unicode_string_view = memoryview(unicode_string.encode("utf-8")) + await r_no_decode.set("unicode-string-memoryview", unicode_string_view) + 
cached_val = await r_no_decode.get("unicode-string-memoryview") + # The cached value won't be a memoryview because it's a copy from Redis + assert isinstance(cached_val, bytes) + assert unicode_string == cached_val.decode("utf-8") + + async def test_memoryview_encoding_and_decoding(self, r: redis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + unicode_string_view = memoryview(unicode_string.encode("utf-8")) + await r.set("unicode-string-memoryview", unicode_string_view) + cached_val = await r.get("unicode-string-memoryview") + assert isinstance(cached_val, str) + assert unicode_string == cached_val + + async def test_list_encoding(self, r: redis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + result = [unicode_string, unicode_string, unicode_string] + await r.rpush("a", *result) + assert await r.lrange("a", 0, -1) == result + + +class TestEncodingErrors: + async def test_ignore(self, create_redis): + r = await create_redis( + decode_responses=True, + encoding_errors="ignore", + ) + await r.set("a", b"foo\xff") + assert await r.get("a") == "foo" + + async def test_replace(self, create_redis): + r = await create_redis( + decode_responses=True, + encoding_errors="replace", + ) + await r.set("a", b"foo\xff") + assert await r.get("a") == "foo\ufffd" + + +class TestMemoryviewsAreNotPacked: + async def test_memoryviews_are_not_packed(self, r): + arg = memoryview(b"some_arg") + arg_list = ["SOME_COMMAND", arg] + c = r.connection or await r.connection_pool.get_connection("_") + cmd = c.pack_command(*arg_list) + assert cmd[1] is arg + cmds = c.pack_commands([arg_list, arg_list]) + assert cmds[1] is arg + assert cmds[3] is arg + + +class TestCommandsAreNotEncoded: + @pytest.fixture() + async def r(self, create_redis): + redis = await create_redis(encoding="utf-16") + yield redis + await redis.flushall() + + async def test_basic_command(self, r: redis.Redis): + await r.set("hello", "world") + + +class TestInvalidUserInput: + async def test_boolean_fails(self, r: redis.Redis): + with pytest.raises(DataError): + await r.set("a", True) # type: ignore + + async def test_none_fails(self, r: redis.Redis): + with pytest.raises(DataError): + await r.set("a", None) # type: ignore + + async def test_user_type_fails(self, r: redis.Redis): + class Foo: + def __str__(self): + return "Foo" + + with pytest.raises(DataError): + await r.set("a", Foo()) # type: ignore diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py new file mode 100644 index 0000000000..d5c2081493 --- /dev/null +++ b/tests/test_asyncio/test_lock.py @@ -0,0 +1,236 @@ +import asyncio + +import pytest + +from redis.exceptions import LockError, LockNotOwnedError +from redis.asyncio.lock import Lock + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.onlynoncluster +class TestLock: + @pytest.fixture() + async def r_decoded(self, create_redis): + redis = await create_redis(decode_responses=True) + yield redis + await redis.flushall() + + def get_lock(self, redis, *args, **kwargs): + kwargs["lock_class"] = Lock + return redis.lock(*args, **kwargs) + + async def test_lock(self, r): + lock = self.get_lock(r, "foo") + assert await lock.acquire(blocking=False) + assert await r.get("foo") == lock.local.token + assert await r.ttl("foo") == -1 + await lock.release() + assert await r.get("foo") is None + + async def test_lock_token(self, r): + lock = self.get_lock(r, "foo") + await self._test_lock_token(r, lock) + + async def test_lock_token_thread_local_false(self, r): + lock = self.get_lock(r, "foo", 
thread_local=False) + await self._test_lock_token(r, lock) + + async def _test_lock_token(self, r, lock): + assert await lock.acquire(blocking=False, token="test") + assert await r.get("foo") == b"test" + assert lock.local.token == b"test" + assert await r.ttl("foo") == -1 + await lock.release() + assert await r.get("foo") is None + assert lock.local.token is None + + async def test_locked(self, r): + lock = self.get_lock(r, "foo") + assert await lock.locked() is False + await lock.acquire(blocking=False) + assert await lock.locked() is True + await lock.release() + assert await lock.locked() is False + + async def _test_owned(self, client): + lock = self.get_lock(client, "foo") + assert await lock.owned() is False + await lock.acquire(blocking=False) + assert await lock.owned() is True + await lock.release() + assert await lock.owned() is False + + lock2 = self.get_lock(client, "foo") + assert await lock.owned() is False + assert await lock2.owned() is False + await lock2.acquire(blocking=False) + assert await lock.owned() is False + assert await lock2.owned() is True + await lock2.release() + assert await lock.owned() is False + assert await lock2.owned() is False + + async def test_owned(self, r): + await self._test_owned(r) + + async def test_owned_with_decoded_responses(self, r_decoded): + await self._test_owned(r_decoded) + + async def test_competing_locks(self, r): + lock1 = self.get_lock(r, "foo") + lock2 = self.get_lock(r, "foo") + assert await lock1.acquire(blocking=False) + assert not await lock2.acquire(blocking=False) + await lock1.release() + assert await lock2.acquire(blocking=False) + assert not await lock1.acquire(blocking=False) + await lock2.release() + + async def test_timeout(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + assert 8 < (await r.ttl("foo")) <= 10 + await lock.release() + + async def test_float_timeout(self, r): + lock = self.get_lock(r, "foo", timeout=9.5) + assert await lock.acquire(blocking=False) + assert 8 < (await r.pttl("foo")) <= 9500 + await lock.release() + + async def test_blocking_timeout(self, r, event_loop): + lock1 = self.get_lock(r, "foo") + assert await lock1.acquire(blocking=False) + bt = 0.2 + sleep = 0.05 + lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt) + start = event_loop.time() + assert not await lock2.acquire() + # The elapsed duration should be less than the total blocking_timeout + assert bt > (event_loop.time() - start) > bt - sleep + await lock1.release() + + async def test_context_manager(self, r): + # blocking_timeout prevents a deadlock if the lock can't be acquired + # for some reason + async with self.get_lock(r, "foo", blocking_timeout=0.2) as lock: + assert await r.get("foo") == lock.local.token + assert await r.get("foo") is None + + async def test_context_manager_raises_when_locked_not_acquired(self, r): + await r.set("foo", "bar") + with pytest.raises(LockError): + async with self.get_lock(r, "foo", blocking_timeout=0.1): + pass + + async def test_high_sleep_small_blocking_timeout(self, r): + lock1 = self.get_lock(r, "foo") + assert await lock1.acquire(blocking=False) + sleep = 60 + bt = 1 + lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt) + start = asyncio.get_event_loop().time() + assert not await lock2.acquire() + # the elapsed timed is less than the blocking_timeout as the lock is + # unattainable given the sleep/blocking_timeout configuration + assert bt > (asyncio.get_event_loop().time() - start) + await lock1.release() + + 
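# A minimal usage sketch of the Lock API these tests exercise, assuming a reachable Redis on the default port; the key name "resource" is illustrative: + # + # import asyncio + # import redis.asyncio as redis + # from redis.asyncio.lock import Lock + # + # async def main(): + # r = redis.Redis() + # lock = r.lock("resource", timeout=10, lock_class=Lock) + # if await lock.acquire(blocking=False): + # try: + # ... # critical section; the key expires after 10 seconds + # await lock.extend(10) # add 10 more seconds to the TTL + # finally: + # await lock.release() + # + # asyncio.run(main()) + + 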
async def test_releasing_unlocked_lock_raises_error(self, r): + lock = self.get_lock(r, "foo") + with pytest.raises(LockError): + await lock.release() + + async def test_releasing_lock_no_longer_owned_raises_error(self, r): + lock = self.get_lock(r, "foo") + await lock.acquire(blocking=False) + # manually change the token + await r.set("foo", "a") + with pytest.raises(LockNotOwnedError): + await lock.release() + # even though we errored, the token is still cleared + assert lock.local.token is None + + async def test_extend_lock(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + assert 8000 < (await r.pttl("foo")) <= 10000 + assert await lock.extend(10) + assert 16000 < (await r.pttl("foo")) <= 20000 + await lock.release() + + async def test_extend_lock_replace_ttl(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + assert 8000 < (await r.pttl("foo")) <= 10000 + assert await lock.extend(10, replace_ttl=True) + assert 8000 < (await r.pttl("foo")) <= 10000 + await lock.release() + + async def test_extend_lock_float(self, r): + lock = self.get_lock(r, "foo", timeout=10.0) + assert await lock.acquire(blocking=False) + assert 8000 < (await r.pttl("foo")) <= 10000 + assert await lock.extend(10.0) + assert 16000 < (await r.pttl("foo")) <= 20000 + await lock.release() + + async def test_extending_unlocked_lock_raises_error(self, r): + lock = self.get_lock(r, "foo", timeout=10) + with pytest.raises(LockError): + await lock.extend(10) + + async def test_extending_lock_with_no_timeout_raises_error(self, r): + lock = self.get_lock(r, "foo") + assert await lock.acquire(blocking=False) + with pytest.raises(LockError): + await lock.extend(10) + await lock.release() + + async def test_extending_lock_no_longer_owned_raises_error(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + await r.set("foo", "a") + with pytest.raises(LockNotOwnedError): + await lock.extend(10) + + async def test_reacquire_lock(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + assert await r.pexpire("foo", 5000) + assert await r.pttl("foo") <= 5000 + assert await lock.reacquire() + assert 8000 < (await r.pttl("foo")) <= 10000 + await lock.release() + + async def test_reacquiring_unlocked_lock_raises_error(self, r): + lock = self.get_lock(r, "foo", timeout=10) + with pytest.raises(LockError): + await lock.reacquire() + + async def test_reacquiring_lock_with_no_timeout_raises_error(self, r): + lock = self.get_lock(r, "foo") + assert await lock.acquire(blocking=False) + with pytest.raises(LockError): + await lock.reacquire() + await lock.release() + + async def test_reacquiring_lock_no_longer_owned_raises_error(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + await r.set("foo", "a") + with pytest.raises(LockNotOwnedError): + await lock.reacquire() + + +@pytest.mark.onlynoncluster +class TestLockClassSelection: + def test_lock_class_argument(self, r): + class MyLock: + def __init__(self, *args, **kwargs): + + pass + + lock = r.lock("foo", lock_class=MyLock) + assert type(lock) == MyLock diff --git a/tests/test_asyncio/test_monitor.py b/tests/test_asyncio/test_monitor.py new file mode 100644 index 0000000000..baeb9cc445 --- /dev/null +++ b/tests/test_asyncio/test_monitor.py @@ -0,0 +1,69 @@ +import pytest + +from tests.conftest import ( + skip_if_redis_enterprise, + 
skip_ifnot_redis_enterprise, +) +from .conftest import wait_for_command + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.onlynoncluster +class TestMonitor: + async def test_wait_command_not_found(self, r): + """Make sure the wait_for_command func works when command is not found""" + async with r.monitor() as m: + response = await wait_for_command(r, m, "nothing") + assert response is None + + async def test_response_values(self, r): + db = r.connection_pool.connection_kwargs.get("db", 0) + async with r.monitor() as m: + await r.ping() + response = await wait_for_command(r, m, "PING") + assert isinstance(response["time"], float) + assert response["db"] == db + assert response["client_type"] in ("tcp", "unix") + assert isinstance(response["client_address"], str) + assert isinstance(response["client_port"], str) + assert response["command"] == "PING" + + async def test_command_with_quoted_key(self, r): + async with r.monitor() as m: + await r.get('foo"bar') + response = await wait_for_command(r, m, 'GET foo"bar') + assert response["command"] == 'GET foo"bar' + + async def test_command_with_binary_data(self, r): + async with r.monitor() as m: + byte_string = b"foo\x92" + await r.get(byte_string) + response = await wait_for_command(r, m, "GET foo\\x92") + assert response["command"] == "GET foo\\x92" + + async def test_command_with_escaped_data(self, r): + async with r.monitor() as m: + byte_string = b"foo\\x92" + await r.get(byte_string) + response = await wait_for_command(r, m, "GET foo\\\\x92") + assert response["command"] == "GET foo\\\\x92" + + @skip_if_redis_enterprise() + async def test_lua_script(self, r): + async with r.monitor() as m: + script = 'return redis.call("GET", "foo")' + assert await r.eval(script, 0) is None + response = await wait_for_command(r, m, "GET foo") + assert response["command"] == "GET foo" + assert response["client_type"] == "lua" + assert response["client_address"] == "lua" + assert response["client_port"] == "" + + @skip_ifnot_redis_enterprise() + async def test_lua_script_in_enterprise(self, r): + async with r.monitor() as m: + script = 'return redis.call("GET", "foo")' + assert await r.eval(script, 0) is None + response = await wait_for_command(r, m, "GET foo") + assert response is None diff --git a/tests/test_asyncio/test_multiprocessing.py b/tests/test_asyncio/test_multiprocessing.py new file mode 100644 index 0000000000..bd21c289be --- /dev/null +++ b/tests/test_asyncio/test_multiprocessing.py @@ -0,0 +1,181 @@ +import asyncio +import contextlib +import multiprocessing + +import pytest + +from redis.asyncio.connection import Connection, ConnectionPool +from redis.exceptions import ConnectionError + +pytestmark = pytest.mark.asyncio + + +@contextlib.asynccontextmanager +async def exit_callback(callback, *args): + try: + yield + finally: + await callback(*args) + + +@pytest.mark.xfail() +class TestMultiprocessing: + # Test connection sharing between forks. + # See issue #1085 for details. + + # use a multi-connection client as that's the only type that is + # actually fork/process-safe + @pytest.fixture() + async def r(self, create_redis): + redis = await create_redis( + single_connection_client=False, + ) + yield redis + await redis.flushall() + + async def test_close_connection_in_child(self, master_host): + """ + A connection owned by a parent and closed by a child doesn't + destroy the file descriptors so a parent can still use it. 
+ """ + conn = Connection(host=master_host) + await conn.send_command("ping") + assert await conn.read_response() == b"PONG" + + def target(conn): + async def atarget(conn): + await conn.send_command("ping") + assert conn.read_response() == b"PONG" + await conn.disconnect() + + asyncio.get_event_loop().run_until_complete(atarget(conn)) + + proc = multiprocessing.Process(target=target, args=(conn,)) + proc.start() + proc.join(3) + assert proc.exitcode == 0 + + # The connection was created in the parent but disconnected in the + # child. The child called socket.close() but did not call + # socket.shutdown() because it wasn't the "owning" process. + # Therefore the connection still works in the parent. + await conn.send_command("ping") + assert await conn.read_response() == b"PONG" + + async def test_close_connection_in_parent(self, master_host): + """ + A connection owned by a parent is unusable by a child if the parent + (the owning process) closes the connection. + """ + conn = Connection(host=master_host) + await conn.send_command("ping") + assert await conn.read_response() == b"PONG" + + def target(conn, ev): + ev.wait() + # the parent closed the connection. because it also created the + # connection, the connection is shutdown and the child + # cannot use it. + with pytest.raises(ConnectionError): + asyncio.get_event_loop().run_until_complete(conn.send_command("ping")) + + ev = multiprocessing.Event() + proc = multiprocessing.Process(target=target, args=(conn, ev)) + proc.start() + + await conn.disconnect() + ev.set() + + proc.join(3) + assert proc.exitcode == 0 + + @pytest.mark.parametrize("max_connections", [1, 2, None]) + async def test_pool(self, max_connections, master_host): + """ + A child will create its own connections when using a pool created + by a parent. + """ + pool = ConnectionPool.from_url( + f"redis://{master_host}", max_connections=max_connections + ) + + conn = await pool.get_connection("ping") + main_conn_pid = conn.pid + async with exit_callback(pool.release, conn): + await conn.send_command("ping") + assert await conn.read_response() == b"PONG" + + def target(pool): + async def atarget(pool): + async with exit_callback(pool.disconnect): + conn = await pool.get_connection("ping") + assert conn.pid != main_conn_pid + async with exit_callback(pool.release, conn): + assert await conn.send_command("ping") is None + assert await conn.read_response() == b"PONG" + + asyncio.get_event_loop().run_until_complete(atarget(pool)) + + proc = multiprocessing.Process(target=target, args=(pool,)) + proc.start() + proc.join(3) + assert proc.exitcode == 0 + + # Check that connection is still alive after fork process has exited + # and disconnected the connections in its pool + conn = pool.get_connection("ping") + async with exit_callback(pool.release, conn): + assert await conn.send_command("ping") is None + assert await conn.read_response() == b"PONG" + + @pytest.mark.parametrize("max_connections", [1, 2, None]) + async def test_close_pool_in_main(self, max_connections, master_host): + """ + A child process that uses the same pool as its parent isn't affected + when the parent disconnects all connections within the pool. 
+ """ + pool = ConnectionPool.from_url( + f"redis://{master_host}", max_connections=max_connections + ) + + conn = await pool.get_connection("ping") + assert await conn.send_command("ping") is None + assert await conn.read_response() == b"PONG" + + def target(pool, disconnect_event): + async def atarget(pool, disconnect_event): + conn = await pool.get_connection("ping") + async with exit_callback(pool.release, conn): + assert await conn.send_command("ping") is None + assert await conn.read_response() == b"PONG" + disconnect_event.wait() + assert await conn.send_command("ping") is None + assert await conn.read_response() == b"PONG" + + asyncio.get_event_loop().run_until_complete(atarget(pool, disconnect_event)) + + ev = multiprocessing.Event() + + proc = multiprocessing.Process(target=target, args=(pool, ev)) + proc.start() + + await pool.disconnect() + ev.set() + proc.join(3) + assert proc.exitcode == 0 + + async def test_aioredis_client(self, r): + """A aioredis client created in a parent can also be used in a child""" + assert await r.ping() is True + + def target(client): + run = asyncio.get_event_loop().run_until_complete + assert run(client.ping()) is True + del client + + proc = multiprocessing.Process(target=target, args=(r,)) + proc.start() + proc.join(3) + assert proc.exitcode == 0 + + assert await r.ping() is True diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py new file mode 100644 index 0000000000..8011c7258b --- /dev/null +++ b/tests/test_asyncio/test_pipeline.py @@ -0,0 +1,408 @@ +import pytest + +import redis + +from tests.conftest import skip_if_server_version_lt +from .conftest import wait_for_command + +pytestmark = pytest.mark.asyncio + + +class TestPipeline: + async def test_pipeline_is_true(self, r): + """Ensure pipeline instances are not false-y""" + async with r.pipeline() as pipe: + assert pipe + + async def test_pipeline(self, r): + async with r.pipeline() as pipe: + ( + pipe.set("a", "a1") + .get("a") + .zadd("z", {"z1": 1}) + .zadd("z", {"z2": 4}) + .zincrby("z", 1, "z1") + .zrange("z", 0, 5, withscores=True) + ) + assert await pipe.execute() == [ + True, + b"a1", + True, + True, + 2.0, + [(b"z1", 2.0), (b"z2", 4)], + ] + + async def test_pipeline_memoryview(self, r): + async with r.pipeline() as pipe: + (pipe.set("a", memoryview(b"a1")).get("a")) + assert await pipe.execute() == [ + True, + b"a1", + ] + + async def test_pipeline_length(self, r): + async with r.pipeline() as pipe: + # Initially empty. + assert len(pipe) == 0 + + # Fill 'er up! + pipe.set("a", "a1").set("b", "b1").set("c", "c1") + assert len(pipe) == 3 + + # Execute calls reset(), so empty once again. 
+ await pipe.execute() + assert len(pipe) == 0 + + @pytest.mark.onlynoncluster + async def test_pipeline_no_transaction(self, r): + async with r.pipeline(transaction=False) as pipe: + pipe.set("a", "a1").set("b", "b1").set("c", "c1") + assert await pipe.execute() == [True, True, True] + assert await r.get("a") == b"a1" + assert await r.get("b") == b"b1" + assert await r.get("c") == b"c1" + + async def test_pipeline_no_transaction_watch(self, r): + await r.set("a", 0) + + async with r.pipeline(transaction=False) as pipe: + await pipe.watch("a") + a = await pipe.get("a") + + pipe.multi() + pipe.set("a", int(a) + 1) + assert await pipe.execute() == [True] + + async def test_pipeline_no_transaction_watch_failure(self, r): + await r.set("a", 0) + + async with r.pipeline(transaction=False) as pipe: + await pipe.watch("a") + a = await pipe.get("a") + + await r.set("a", "bad") + + pipe.multi() + pipe.set("a", int(a) + 1) + + with pytest.raises(redis.WatchError): + await pipe.execute() + + assert await r.get("a") == b"bad" + + async def test_exec_error_in_response(self, r): + """ + an invalid pipeline command at exec time adds the exception instance + to the list of returned values + """ + await r.set("c", "a") + async with r.pipeline() as pipe: + pipe.set("a", 1).set("b", 2).lpush("c", 3).set("d", 4) + result = await pipe.execute(raise_on_error=False) + + assert result[0] + assert await r.get("a") == b"1" + assert result[1] + assert await r.get("b") == b"2" + + # we can't lpush to a key that's a string value, so this should + # be a ResponseError exception + assert isinstance(result[2], redis.ResponseError) + assert await r.get("c") == b"a" + + # since this isn't a transaction, the other commands after the + # error are still executed + assert result[3] + assert await r.get("d") == b"4" + + # make sure the pipe was restored to a working state + assert await pipe.set("z", "zzz").execute() == [True] + assert await r.get("z") == b"zzz" + + async def test_exec_error_raised(self, r): + await r.set("c", "a") + async with r.pipeline() as pipe: + pipe.set("a", 1).set("b", 2).lpush("c", 3).set("d", 4) + with pytest.raises(redis.ResponseError) as ex: + await pipe.execute() + assert str(ex.value).startswith( + "Command # 3 (LPUSH c 3) of " "pipeline caused error: " + ) + + # make sure the pipe was restored to a working state + assert await pipe.set("z", "zzz").execute() == [True] + assert await r.get("z") == b"zzz" + + @pytest.mark.onlynoncluster + async def test_transaction_with_empty_error_command(self, r): + """ + Commands with custom EMPTY_ERROR functionality return their default + values in the pipeline no matter the raise_on_error preference + """ + for error_switch in (True, False): + async with r.pipeline() as pipe: + pipe.set("a", 1).mget([]).set("c", 3) + result = await pipe.execute(raise_on_error=error_switch) + + assert result[0] + assert result[1] == [] + assert result[2] + + @pytest.mark.onlynoncluster + async def test_pipeline_with_empty_error_command(self, r): + """ + Commands with custom EMPTY_ERROR functionality return their default + values in the pipeline no matter the raise_on_error preference + """ + for error_switch in (True, False): + async with r.pipeline(transaction=False) as pipe: + pipe.set("a", 1).mget([]).set("c", 3) + result = await pipe.execute(raise_on_error=error_switch) + + assert result[0] + assert result[1] == [] + assert result[2] + + async def test_parse_error_raised(self, r): + async with r.pipeline() as pipe: + # the zrem is invalid because we don't pass any keys to 
it + pipe.set("a", 1).zrem("b").set("b", 2) + with pytest.raises(redis.ResponseError) as ex: + await pipe.execute() + + assert str(ex.value).startswith( + "Command # 2 (ZREM b) of " "pipeline caused error: " + ) + + # make sure the pipe was restored to a working state + assert await pipe.set("z", "zzz").execute() == [True] + assert await r.get("z") == b"zzz" + + @pytest.mark.onlynoncluster + async def test_parse_error_raised_transaction(self, r): + async with r.pipeline() as pipe: + pipe.multi() + # the zrem is invalid because we don't pass any keys to it + pipe.set("a", 1).zrem("b").set("b", 2) + with pytest.raises(redis.ResponseError) as ex: + await pipe.execute() + + assert str(ex.value).startswith( + "Command # 2 (ZREM b) of " "pipeline caused error: " + ) + + # make sure the pipe was restored to a working state + assert await pipe.set("z", "zzz").execute() == [True] + assert await r.get("z") == b"zzz" + + @pytest.mark.onlynoncluster + async def test_watch_succeed(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + assert pipe.watching + a_value = await pipe.get("a") + b_value = await pipe.get("b") + assert a_value == b"1" + assert b_value == b"2" + pipe.multi() + + pipe.set("c", 3) + assert await pipe.execute() == [True] + assert not pipe.watching + + @pytest.mark.onlynoncluster + async def test_watch_failure(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + await r.set("b", 3) + pipe.multi() + pipe.get("a") + with pytest.raises(redis.WatchError): + await pipe.execute() + + assert not pipe.watching + + @pytest.mark.onlynoncluster + async def test_watch_failure_in_empty_transaction(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + await r.set("b", 3) + pipe.multi() + with pytest.raises(redis.WatchError): + await pipe.execute() + + assert not pipe.watching + + @pytest.mark.onlynoncluster + async def test_unwatch(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + await r.set("b", 3) + await pipe.unwatch() + assert not pipe.watching + pipe.get("a") + assert await pipe.execute() == [b"1"] + + @pytest.mark.onlynoncluster + async def test_watch_exec_no_unwatch(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.monitor() as m: + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + assert pipe.watching + a_value = await pipe.get("a") + b_value = await pipe.get("b") + assert a_value == b"1" + assert b_value == b"2" + pipe.multi() + pipe.set("c", 3) + assert await pipe.execute() == [True] + assert not pipe.watching + + unwatch_command = await wait_for_command(r, m, "UNWATCH") + assert unwatch_command is None, "should not send UNWATCH" + + @pytest.mark.onlynoncluster + async def test_watch_reset_unwatch(self, r): + await r.set("a", 1) + + async with r.monitor() as m: + async with r.pipeline() as pipe: + await pipe.watch("a") + assert pipe.watching + await pipe.reset() + assert not pipe.watching + + unwatch_command = await wait_for_command(r, m, "UNWATCH") + assert unwatch_command is not None + assert unwatch_command["command"] == "UNWATCH" + + @pytest.mark.onlynoncluster + async def test_transaction_callable(self, r): + await r.set("a", 1) + await r.set("b", 2) + has_run = [] + + async def my_transaction(pipe): + a_value = await pipe.get("a") + assert a_value in (b"1", b"2") + b_value = 
await pipe.get("b") + assert b_value == b"2" + + # silly run-once code... incr's "a" so WatchError should be raised + # forcing this all to run again. this should incr "a" once to "2" + if not has_run: + await r.incr("a") + has_run.append("it has") + + pipe.multi() + pipe.set("c", int(a_value) + int(b_value)) + + result = await r.transaction(my_transaction, "a", "b") + assert result == [True] + assert await r.get("c") == b"4" + + @pytest.mark.onlynoncluster + async def test_transaction_callable_returns_value_from_callable(self, r): + async def callback(pipe): + # No need to do anything here since we only want the return value + return "a" + + res = await r.transaction(callback, "my-key", value_from_callable=True) + assert res == "a" + + async def test_exec_error_in_no_transaction_pipeline(self, r): + await r.set("a", 1) + async with r.pipeline(transaction=False) as pipe: + pipe.llen("a") + pipe.expire("a", 100) + + with pytest.raises(redis.ResponseError) as ex: + await pipe.execute() + + assert str(ex.value).startswith( + "Command # 1 (LLEN a) of " "pipeline caused error: " + ) + + assert await r.get("a") == b"1" + + async def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r): + key = chr(3456) + "abcd" + chr(3421) + await r.set(key, 1) + async with r.pipeline(transaction=False) as pipe: + pipe.llen(key) + pipe.expire(key, 100) + + with pytest.raises(redis.ResponseError) as ex: + await pipe.execute() + + expected = "Command # 1 (LLEN %s) of pipeline caused error: " % key + assert str(ex.value).startswith(expected) + + assert await r.get(key) == b"1" + + async def test_pipeline_with_bitfield(self, r): + async with r.pipeline() as pipe: + pipe.set("a", "1") + bf = pipe.bitfield("b") + pipe2 = ( + bf.set("u8", 8, 255) + .get("u8", 0) + .get("u4", 8) # 1111 + .get("u4", 12) # 1111 + .get("u4", 13) # 1110 + .execute() + ) + pipe.get("a") + response = await pipe.execute() + + assert pipe == pipe2 + assert response == [True, [0, 0, 15, 15, 14], b"1"] + + async def test_pipeline_get(self, r): + await r.set("a", "a1") + async with r.pipeline() as pipe: + await pipe.get("a") + assert await pipe.execute() == [b"a1"] + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.0.0") + async def test_pipeline_discard(self, r): + + # empty pipeline should raise an error + async with r.pipeline() as pipe: + pipe.set("key", "someval") + await pipe.discard() + with pytest.raises(redis.exceptions.ResponseError): + await pipe.execute() + + # setting a pipeline and discarding should do the same + async with r.pipeline() as pipe: + pipe.set("key", "someval") + pipe.set("someotherkey", "val") + response = await pipe.execute() + pipe.set("key", "another value!") + await pipe.discard() + pipe.set("key", "another vae!") + with pytest.raises(redis.exceptions.ResponseError): + await pipe.execute() + + pipe.set("foo", "bar") + response = await pipe.execute() + assert response[0] + assert await r.get("foo") == b"bar" diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py new file mode 100644 index 0000000000..0fb8585670 --- /dev/null +++ b/tests/test_asyncio/test_pubsub.py @@ -0,0 +1,626 @@ +import asyncio +from typing import Optional + +import pytest + +import redis.asyncio as redis +from redis.exceptions import ConnectionError +from redis.typing import EncodableT + +from .compat import mock +from tests.conftest import skip_if_server_version_lt + +pytestmark = pytest.mark.asyncio(forbid_global_loop=True) + + +async def wait_for_message(pubsub, timeout=0.1, 
ignore_subscribe_messages=False): + now = asyncio.get_event_loop().time() + timeout = now + timeout + while now < timeout: + message = await pubsub.get_message( + ignore_subscribe_messages=ignore_subscribe_messages + ) + if message is not None: + return message + await asyncio.sleep(0.01) + now = asyncio.get_event_loop().time() + return None + + +def make_message( + type, channel: Optional[str], data: EncodableT, pattern: Optional[str] = None +): + return { + "type": type, + "pattern": pattern and pattern.encode("utf-8") or None, + "channel": channel and channel.encode("utf-8") or None, + "data": data.encode("utf-8") if isinstance(data, str) else data, + } + + +def make_subscribe_test_data(pubsub, type): + if type == "channel": + return { + "p": pubsub, + "sub_type": "subscribe", + "unsub_type": "unsubscribe", + "sub_func": pubsub.subscribe, + "unsub_func": pubsub.unsubscribe, + "keys": ["foo", "bar", "uni" + chr(4456) + "code"], + } + elif type == "pattern": + return { + "p": pubsub, + "sub_type": "psubscribe", + "unsub_type": "punsubscribe", + "sub_func": pubsub.psubscribe, + "unsub_func": pubsub.punsubscribe, + "keys": ["f*", "b*", "uni" + chr(4456) + "*"], + } + assert False, "invalid subscribe type: %s" % type + + +class TestPubSubSubscribeUnsubscribe: + async def _test_subscribe_unsubscribe( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + for key in keys: + assert await sub_func(key) is None + + # should be a message for each channel/pattern we just subscribed to + for i, key in enumerate(keys): + assert await wait_for_message(p) == make_message(sub_type, key, i + 1) + + for key in keys: + assert await unsub_func(key) is None + + # should be a message for each channel/pattern we just unsubscribed + # from + for i, key in enumerate(keys): + i = len(keys) - 1 - i + assert await wait_for_message(p) == make_message(unsub_type, key, i) + + async def test_channel_subscribe_unsubscribe(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_subscribe_unsubscribe(**kwargs) + + async def test_pattern_subscribe_unsubscribe(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_subscribe_unsubscribe(**kwargs) + + @pytest.mark.onlynoncluster + async def _test_resubscribe_on_reconnection( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + + for key in keys: + assert await sub_func(key) is None + + # should be a message for each channel/pattern we just subscribed to + for i, key in enumerate(keys): + assert await wait_for_message(p) == make_message(sub_type, key, i + 1) + + # manually disconnect + await p.connection.disconnect() + + # calling get_message again reconnects and resubscribes + # note, we may not re-subscribe to channels in exactly the same order + # so we have to do some extra checks to make sure we got them all + messages = [] + for i in range(len(keys)): + messages.append(await wait_for_message(p)) + + unique_channels = set() + assert len(messages) == len(keys) + for i, message in enumerate(messages): + assert message["type"] == sub_type + assert message["data"] == i + 1 + assert isinstance(message["channel"], bytes) + channel = message["channel"].decode("utf-8") + unique_channels.add(channel) + + assert len(unique_channels) == len(keys) + for channel in unique_channels: + assert channel in keys + + async def test_resubscribe_to_channels_on_reconnection(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await 
self._test_resubscribe_on_reconnection(**kwargs) + + async def test_resubscribe_to_patterns_on_reconnection(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_resubscribe_on_reconnection(**kwargs) + + async def _test_subscribed_property( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + + assert p.subscribed is False + await sub_func(keys[0]) + # we're now subscribed even though we haven't processed the + # reply from the server just yet + assert p.subscribed is True + assert await wait_for_message(p) == make_message(sub_type, keys[0], 1) + # we're still subscribed + assert p.subscribed is True + + # unsubscribe from all channels + await unsub_func() + # we're still technically subscribed until we process the + # response messages from the server + assert p.subscribed is True + assert await wait_for_message(p) == make_message(unsub_type, keys[0], 0) + # now we're no longer subscribed as no more messages can be delivered + # to any channels we were listening to + assert p.subscribed is False + + # subscribing again flips the flag back + await sub_func(keys[0]) + assert p.subscribed is True + assert await wait_for_message(p) == make_message(sub_type, keys[0], 1) + + # unsubscribe again + await unsub_func() + assert p.subscribed is True + # subscribe to another channel before reading the unsubscribe response + await sub_func(keys[1]) + assert p.subscribed is True + # read the unsubscribe for key1 + assert await wait_for_message(p) == make_message(unsub_type, keys[0], 0) + # we're still subscribed to key2, so subscribed should still be True + assert p.subscribed is True + # read the key2 subscribe message + assert await wait_for_message(p) == make_message(sub_type, keys[1], 1) + await unsub_func() + # haven't read the message yet, so we're still subscribed + assert p.subscribed is True + assert await wait_for_message(p) == make_message(unsub_type, keys[1], 0) + # now we're finally unsubscribed + assert p.subscribed is False + + async def test_subscribe_property_with_channels(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_subscribed_property(**kwargs) + + @pytest.mark.onlynoncluster + async def test_subscribe_property_with_patterns(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_subscribed_property(**kwargs) + + async def test_ignore_all_subscribe_messages(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + + checks = ( + (p.subscribe, "foo"), + (p.unsubscribe, "foo"), + (p.psubscribe, "f*"), + (p.punsubscribe, "f*"), + ) + + assert p.subscribed is False + for func, channel in checks: + assert await func(channel) is None + assert p.subscribed is True + assert await wait_for_message(p) is None + assert p.subscribed is False + + async def test_ignore_individual_subscribe_messages(self, r: redis.Redis): + p = r.pubsub() + + checks = ( + (p.subscribe, "foo"), + (p.unsubscribe, "foo"), + (p.psubscribe, "f*"), + (p.punsubscribe, "f*"), + ) + + assert p.subscribed is False + for func, channel in checks: + assert await func(channel) is None + assert p.subscribed is True + message = await wait_for_message(p, ignore_subscribe_messages=True) + assert message is None + assert p.subscribed is False + + async def test_sub_unsub_resub_channels(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_sub_unsub_resub(**kwargs) + + @pytest.mark.onlynoncluster + async def 
test_sub_unsub_resub_patterns(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_sub_unsub_resub(**kwargs) + + async def _test_sub_unsub_resub( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + # https://github.com/andymccurdy/redis-py/issues/764 + key = keys[0] + await sub_func(key) + await unsub_func(key) + await sub_func(key) + assert p.subscribed is True + assert await wait_for_message(p) == make_message(sub_type, key, 1) + assert await wait_for_message(p) == make_message(unsub_type, key, 0) + assert await wait_for_message(p) == make_message(sub_type, key, 1) + assert p.subscribed is True + + async def test_sub_unsub_all_resub_channels(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_sub_unsub_all_resub(**kwargs) + + async def test_sub_unsub_all_resub_patterns(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_sub_unsub_all_resub(**kwargs) + + async def _test_sub_unsub_all_resub( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + # https://github.com/andymccurdy/redis-py/issues/764 + key = keys[0] + await sub_func(key) + await unsub_func() + await sub_func(key) + assert p.subscribed is True + assert await wait_for_message(p) == make_message(sub_type, key, 1) + assert await wait_for_message(p) == make_message(unsub_type, key, 0) + assert await wait_for_message(p) == make_message(sub_type, key, 1) + assert p.subscribed is True + + +class TestPubSubMessages: + def setup_method(self, method): + self.message = None + + def message_handler(self, message): + self.message = message + + async def test_published_message_to_channel(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe("foo") + assert await wait_for_message(p) == make_message("subscribe", "foo", 1) + assert await r.publish("foo", "test message") == 1 + + message = await wait_for_message(p) + assert isinstance(message, dict) + assert message == make_message("message", "foo", "test message") + + async def test_published_message_to_pattern(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe("foo") + await p.psubscribe("f*") + assert await wait_for_message(p) == make_message("subscribe", "foo", 1) + assert await wait_for_message(p) == make_message("psubscribe", "f*", 2) + # 1 to pattern, 1 to channel + assert await r.publish("foo", "test message") == 2 + + message1 = await wait_for_message(p) + message2 = await wait_for_message(p) + assert isinstance(message1, dict) + assert isinstance(message2, dict) + + expected = [ + make_message("message", "foo", "test message"), + make_message("pmessage", "foo", "test message", pattern="f*"), + ] + + assert message1 in expected + assert message2 in expected + assert message1 != message2 + + async def test_channel_message_handler(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(foo=self.message_handler) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message("message", "foo", "test message") + + @pytest.mark.onlynoncluster + async def test_pattern_message_handler(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + await p.psubscribe(**{"f*": self.message_handler}) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await wait_for_message(p) is None + assert self.message 
== make_message( + "pmessage", "foo", "test message", pattern="f*" + ) + + async def test_unicode_channel_message_handler(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + channel = "uni" + chr(4456) + "code" + channels = {channel: self.message_handler} + await p.subscribe(**channels) + assert await wait_for_message(p) is None + assert await r.publish(channel, "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message("message", channel, "test message") + + @pytest.mark.onlynoncluster + # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html + # #known-limitations-with-pubsub + async def test_unicode_pattern_message_handler(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + pattern = "uni" + chr(4456) + "*" + channel = "uni" + chr(4456) + "code" + await p.psubscribe(**{pattern: self.message_handler}) + assert await wait_for_message(p) is None + assert await r.publish(channel, "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message( + "pmessage", channel, "test message", pattern=pattern + ) + + async def test_get_message_without_subscribe(self, r: redis.Redis): + p = r.pubsub() + with pytest.raises(RuntimeError) as info: + await p.get_message() + expect = ( + "connection not set: " "did you forget to call subscribe() or psubscribe()?" + ) + assert expect in info.exconly() + + +class TestPubSubAutoDecoding: + """These tests only validate that we get unicode values back""" + + channel = "uni" + chr(4456) + "code" + pattern = "uni" + chr(4456) + "*" + data = "abc" + chr(4458) + "123" + + def make_message(self, type, channel, data, pattern=None): + return {"type": type, "channel": channel, "pattern": pattern, "data": data} + + def setup_method(self, method): + self.message = None + + def message_handler(self, message): + self.message = message + + @pytest.fixture() + async def r(self, create_redis): + return await create_redis( + decode_responses=True, + ) + + async def test_channel_subscribe_unsubscribe(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe(self.channel) + assert await wait_for_message(p) == self.make_message( + "subscribe", self.channel, 1 + ) + + await p.unsubscribe(self.channel) + assert await wait_for_message(p) == self.make_message( + "unsubscribe", self.channel, 0 + ) + + async def test_pattern_subscribe_unsubscribe(self, r: redis.Redis): + p = r.pubsub() + await p.psubscribe(self.pattern) + assert await wait_for_message(p) == self.make_message( + "psubscribe", self.pattern, 1 + ) + + await p.punsubscribe(self.pattern) + assert await wait_for_message(p) == self.make_message( + "punsubscribe", self.pattern, 0 + ) + + async def test_channel_publish(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe(self.channel) + assert await wait_for_message(p) == self.make_message( + "subscribe", self.channel, 1 + ) + await r.publish(self.channel, self.data) + assert await wait_for_message(p) == self.make_message( + "message", self.channel, self.data + ) + + @pytest.mark.onlynoncluster + async def test_pattern_publish(self, r: redis.Redis): + p = r.pubsub() + await p.psubscribe(self.pattern) + assert await wait_for_message(p) == self.make_message( + "psubscribe", self.pattern, 1 + ) + await r.publish(self.channel, self.data) + assert await wait_for_message(p) == self.make_message( + "pmessage", self.channel, self.data, pattern=self.pattern + ) + + async def test_channel_message_handler(self, r: redis.Redis): + p = 
r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(**{self.channel: self.message_handler}) + assert await wait_for_message(p) is None + await r.publish(self.channel, self.data) + assert await wait_for_message(p) is None + assert self.message == self.make_message("message", self.channel, self.data) + + # test that we reconnected to the correct channel + self.message = None + await p.connection.disconnect() + assert await wait_for_message(p) is None # should reconnect + new_data = self.data + "new data" + await r.publish(self.channel, new_data) + assert await wait_for_message(p) is None + assert self.message == self.make_message("message", self.channel, new_data) + + async def test_pattern_message_handler(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + await p.psubscribe(**{self.pattern: self.message_handler}) + assert await wait_for_message(p) is None + await r.publish(self.channel, self.data) + assert await wait_for_message(p) is None + assert self.message == self.make_message( + "pmessage", self.channel, self.data, pattern=self.pattern + ) + + # test that we reconnected to the correct pattern + self.message = None + await p.connection.disconnect() + assert await wait_for_message(p) is None # should reconnect + new_data = self.data + "new data" + await r.publish(self.channel, new_data) + assert await wait_for_message(p) is None + assert self.message == self.make_message( + "pmessage", self.channel, new_data, pattern=self.pattern + ) + + async def test_context_manager(self, r: redis.Redis): + async with r.pubsub() as pubsub: + await pubsub.subscribe("foo") + assert pubsub.connection is not None + + assert pubsub.connection is None + assert pubsub.channels == {} + assert pubsub.patterns == {} + + +class TestPubSubRedisDown: + async def test_channel_subscribe(self, r: redis.Redis): + r = redis.Redis(host="localhost", port=6390) + p = r.pubsub() + with pytest.raises(ConnectionError): + await p.subscribe("foo") + + +class TestPubSubSubcommands: + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.8.0") + async def test_pubsub_channels(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe("foo", "bar", "baz", "quux") + for i in range(4): + assert (await wait_for_message(p))["type"] == "subscribe" + expected = [b"bar", b"baz", b"foo", b"quux"] + assert all([channel in await r.pubsub_channels() for channel in expected]) + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.8.0") + async def test_pubsub_numsub(self, r: redis.Redis): + p1 = r.pubsub() + await p1.subscribe("foo", "bar", "baz") + for i in range(3): + assert (await wait_for_message(p1))["type"] == "subscribe" + p2 = r.pubsub() + await p2.subscribe("bar", "baz") + for i in range(2): + assert (await wait_for_message(p2))["type"] == "subscribe" + p3 = r.pubsub() + await p3.subscribe("baz") + assert (await wait_for_message(p3))["type"] == "subscribe" + + channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] + assert await r.pubsub_numsub("foo", "bar", "baz") == channels + + @skip_if_server_version_lt("2.8.0") + async def test_pubsub_numpat(self, r: redis.Redis): + p = r.pubsub() + await p.psubscribe("*oo", "*ar", "b*z") + for i in range(3): + assert (await wait_for_message(p))["type"] == "psubscribe" + assert await r.pubsub_numpat() == 3 + + +class TestPubSubPings: + @skip_if_server_version_lt("3.0.0") + async def test_send_pubsub_ping(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe("foo") + await p.ping() + assert await wait_for_message(p) == 
make_message(
+            type="pong", channel=None, data="", pattern=None
+        )
+
+    @skip_if_server_version_lt("3.0.0")
+    async def test_send_pubsub_ping_message(self, r: redis.Redis):
+        p = r.pubsub(ignore_subscribe_messages=True)
+        await p.subscribe("foo")
+        await p.ping(message="hello world")
+        assert await wait_for_message(p) == make_message(
+            type="pong", channel=None, data="hello world", pattern=None
+        )
+
+
+@pytest.mark.onlynoncluster
+class TestPubSubConnectionKilled:
+    @skip_if_server_version_lt("3.0.0")
+    async def test_connection_error_raised_when_connection_dies(
+        self, r: redis.Redis
+    ):
+        p = r.pubsub()
+        await p.subscribe("foo")
+        assert await wait_for_message(p) == make_message("subscribe", "foo", 1)
+        for client in await r.client_list():
+            if client["cmd"] == "subscribe":
+                await r.client_kill_filter(_id=client["id"])
+        with pytest.raises(ConnectionError):
+            await wait_for_message(p)
+
+
+class TestPubSubTimeouts:
+    async def test_get_message_with_timeout_returns_none(self, r: redis.Redis):
+        p = r.pubsub()
+        await p.subscribe("foo")
+        assert await wait_for_message(p) == make_message("subscribe", "foo", 1)
+        assert await p.get_message(timeout=0.01) is None
+
+
+class TestPubSubRun:
+    async def _subscribe(self, p, *args, **kwargs):
+        await p.subscribe(*args, **kwargs)
+        # Wait for the server to act on the subscription, to be sure that
+        # a subsequent publish on another connection will reach the pubsub.
+        while True:
+            message = await p.get_message(timeout=1)
+            if (
+                message is not None
+                and message["type"] == "subscribe"
+                and message["channel"] == b"foo"
+            ):
+                return
+
+    async def test_callbacks(self, r: redis.Redis):
+        def callback(message):
+            messages.put_nowait(message)
+
+        messages = asyncio.Queue()
+        p = r.pubsub()
+        await self._subscribe(p, foo=callback)
+        task = asyncio.get_event_loop().create_task(p.run())
+        await r.publish("foo", "bar")
+        message = await messages.get()
+        task.cancel()
+        try:
+            await task
+        except asyncio.CancelledError:
+            pass
+        assert message == {
+            "channel": b"foo",
+            "data": b"bar",
+            "pattern": None,
+            "type": "message",
+        }
+
+    async def test_exception_handler(self, r: redis.Redis):
+        def exception_handler_callback(e, pubsub) -> None:
+            assert pubsub == p
+            exceptions.put_nowait(e)
+
+        exceptions = asyncio.Queue()
+        p = r.pubsub()
+        await self._subscribe(p, foo=lambda x: None)
+        with mock.patch.object(p, "get_message", side_effect=Exception("error")):
+            task = asyncio.get_event_loop().create_task(
+                p.run(exception_handler=exception_handler_callback)
+            )
+            e = await exceptions.get()
+        task.cancel()
+        try:
+            await task
+        except asyncio.CancelledError:
+            pass
+        assert str(e) == "error"
diff --git a/tests/test_asyncio/test_retry.py b/tests/test_asyncio/test_retry.py
new file mode 100644
index 0000000000..6e277ae38f
--- /dev/null
+++ b/tests/test_asyncio/test_retry.py
@@ -0,0 +1,68 @@
+import pytest
+
+from redis.asyncio.connection import Connection, UnixDomainSocketConnection
+from redis.asyncio.retry import Retry
+from redis.backoff import AbstractBackoff, NoBackoff
+from redis.exceptions import ConnectionError
+
+
+class BackoffMock(AbstractBackoff):
+    def __init__(self):
+        self.reset_calls = 0
+        self.calls = 0
+
+    def reset(self):
+        self.reset_calls += 1
+
+    def compute(self, failures):
+        self.calls += 1
+        return 0
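+
+
+# A minimal sketch of the retry flow these tests exercise, assuming only the
+# Retry/NoBackoff/ConnectionError APIs imported above: call_with_retry()
+# awaits the target coroutine, reports each failure through the second
+# callback, and asks the backoff policy how long to sleep before the next
+# attempt, until the retry budget is exhausted. The helper below is
+# illustrative only and is not collected by pytest.
+async def _retry_flow_sketch():
+    retry = Retry(NoBackoff(), retries=3)
+
+    async def flaky():
+        raise ConnectionError("node unavailable")
+
+    async def note_failure(error):
+        print(f"attempt failed: {error}")
+
+    # one initial attempt plus three retries, then the error propagates
+    await retry.call_with_retry(flaky, note_failure)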
+
+
+class TestConnectionConstructorWithRetry:
+    "Test that the Connection constructors properly handle Retry objects"
+
+    @pytest.mark.parametrize("retry_on_timeout", [False, True])
+    @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
+    def test_retry_on_timeout_boolean(self, Class, retry_on_timeout):
+        c = Class(retry_on_timeout=retry_on_timeout)
+        assert c.retry_on_timeout == retry_on_timeout
+        assert isinstance(c.retry, Retry)
+        assert c.retry._retries == (1 if retry_on_timeout else 0)
+
+    @pytest.mark.parametrize("retries", range(10))
+    @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
+    def test_retry_on_timeout_retry(self, Class, retries: int):
+        retry_on_timeout = retries > 0
+        c = Class(retry_on_timeout=retry_on_timeout, retry=Retry(NoBackoff(), retries))
+        assert c.retry_on_timeout == retry_on_timeout
+        assert isinstance(c.retry, Retry)
+        assert c.retry._retries == retries
+
+
+class TestRetry:
+    "Test that Retry calls backoff and retries the expected number of times"
+
+    def setup_method(self, test_method):
+        self.actual_attempts = 0
+        self.actual_failures = 0
+
+    async def _do(self):
+        self.actual_attempts += 1
+        raise ConnectionError()
+
+    async def _fail(self, error):
+        self.actual_failures += 1
+
+    @pytest.mark.parametrize("retries", range(10))
+    @pytest.mark.asyncio
+    async def test_retry(self, retries: int):
+        backoff = BackoffMock()
+        retry = Retry(backoff, retries)
+        with pytest.raises(ConnectionError):
+            await retry.call_with_retry(self._do, self._fail)
+
+        assert self.actual_attempts == 1 + retries
+        assert self.actual_failures == 1 + retries
+        assert backoff.reset_calls == 1
+        assert backoff.calls == retries
diff --git a/tests/test_asyncio/test_scripting.py b/tests/test_asyncio/test_scripting.py
new file mode 100644
index 0000000000..f23a924f66
--- /dev/null
+++ b/tests/test_asyncio/test_scripting.py
@@ -0,0 +1,163 @@
+import pytest
+
+from redis import exceptions
+from tests.conftest import skip_if_server_version_lt
+
+multiply_script = """
+local value = redis.call('GET', KEYS[1])
+value = tonumber(value)
+return value * ARGV[1]"""
+
+msgpack_hello_script = """
+local message = cmsgpack.unpack(ARGV[1])
+local name = message['name']
+return "hello " .. name
+"""
+msgpack_hello_script_broken = """
+local message = cmsgpack.unpack(ARGV[1])
+local names = message['name']
+return "hello " .. name
+"""
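+
+# A note on the fixture scripts above: multiply_script reads the value at
+# KEYS[1] and multiplies it by ARGV[1]; msgpack_hello_script_broken
+# deliberately assigns the unpacked field to `names` but concatenates the
+# undefined variable `name`, so EVAL fails inside Lua and the tests can
+# assert that the failure surfaces as a ResponseError.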
+
+
+@pytest.mark.onlynoncluster
+class TestScripting:
+    @pytest.fixture
+    async def r(self, create_redis):
+        redis = await create_redis()
+        yield redis
+        await redis.script_flush()
+
+    @pytest.mark.asyncio(forbid_global_loop=True)
+    async def test_eval(self, r):
+        await r.flushdb()
+        await r.set("a", 2)
+        # 2 * 3 == 6
+        assert await r.eval(multiply_script, 1, "a", 3) == 6
+
+    @pytest.mark.asyncio(forbid_global_loop=True)
+    @skip_if_server_version_lt("6.2.0")
+    async def test_script_flush(self, r):
+        await r.set("a", 2)
+        await r.script_load(multiply_script)
+        await r.script_flush("ASYNC")
+
+        await r.set("a", 2)
+        await r.script_load(multiply_script)
+        await r.script_flush("SYNC")
+
+        await r.set("a", 2)
+        await r.script_load(multiply_script)
+        await r.script_flush()
+
+        with pytest.raises(exceptions.DataError):
+            await r.set("a", 2)
+            await r.script_load(multiply_script)
+            await r.script_flush("NOTREAL")
+
+    @pytest.mark.asyncio(forbid_global_loop=True)
+    async def test_script_flush(self, r):
+        await r.set("a", 2)
+        await r.script_load(multiply_script)
+        await r.script_flush(None)
+
+        with pytest.raises(exceptions.DataError):
+            await r.set("a", 2)
+            await r.script_load(multiply_script)
+            await r.script_flush("NOTREAL")
+
+    @pytest.mark.asyncio(forbid_global_loop=True)
+    async def test_evalsha(self, r):
+        await r.set("a", 2)
+        sha = await r.script_load(multiply_script)
+        # 2 * 3 == 6
+        assert await r.evalsha(sha, 1, "a", 3) == 6
+
+    @pytest.mark.asyncio(forbid_global_loop=True)
+    async def test_evalsha_script_not_loaded(self, r):
+        await r.set("a", 2)
+        sha = await r.script_load(multiply_script)
+        # remove the script from Redis's cache
+        await r.script_flush()
+        with pytest.raises(exceptions.NoScriptError):
+            await r.evalsha(sha, 1, "a", 3)
+
+    @pytest.mark.asyncio(forbid_global_loop=True)
+    async def test_script_loading(self, r):
+        # get the sha, then clear the cache
+        sha = await r.script_load(multiply_script)
+        await r.script_flush()
+        assert await r.script_exists(sha) == [False]
+        await r.script_load(multiply_script)
+        assert await r.script_exists(sha) == [True]
+
+    @pytest.mark.asyncio(forbid_global_loop=True)
+    async def test_script_object(self, r):
+        await r.script_flush()
+        await r.set("a", 2)
+        multiply = r.register_script(multiply_script)
+        precalculated_sha = multiply.sha
+        assert precalculated_sha
+        assert await r.script_exists(multiply.sha) == [False]
+        # Test second evalsha block (after NoScriptError)
+        assert await multiply(keys=["a"], args=[3]) == 6
+        # At this point, the script should be loaded
+        assert await r.script_exists(multiply.sha) == [True]
+        # Test that the precalculated sha matches the one from redis
+        assert multiply.sha == precalculated_sha
+        # Test first evalsha block
+        assert await multiply(keys=["a"], args=[3]) == 6
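+
+    # A minimal sketch, not a collected test, of the Script-object round trip
+    # the surrounding tests rely on: register_script() computes the SHA1
+    # locally, calling the object attempts EVALSHA first, and a NoScriptError
+    # falls back to EVAL, which also loads the script into the server cache.
+    # The helper name below is illustrative only.
+    async def _script_cache_sketch(self, r):
+        multiply = r.register_script(multiply_script)
+        await r.set("a", 2)
+        # the first call may EVAL and cache; later calls can hit EVALSHA
+        assert await multiply(keys=["a"], args=[3]) == 6
+        assert await r.script_exists(multiply.sha) == [True]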
+
+    @pytest.mark.asyncio(forbid_global_loop=True)
+    async def test_script_object_in_pipeline(self, r):
+        await r.script_flush()
+        multiply = r.register_script(multiply_script)
+        precalculated_sha = multiply.sha
+        assert precalculated_sha
+        pipe = r.pipeline()
+        pipe.set("a", 2)
+        pipe.get("a")
+        await multiply(keys=["a"], args=[3], client=pipe)
+        assert await r.script_exists(multiply.sha) == [False]
+        # [SET worked, GET 'a', result of the multiply script]
+        assert await pipe.execute() == [True, b"2", 6]
+        # The script should have been loaded by pipe.execute()
+        assert await r.script_exists(multiply.sha) == [True]
+        # The precalculated sha should have been the correct one
+        assert multiply.sha == precalculated_sha
+
+        # purge the script from Redis's cache and re-run the pipeline
+        # the multiply script should be reloaded by pipe.execute()
+        await r.script_flush()
+        pipe = r.pipeline()
+        pipe.set("a", 2)
+        pipe.get("a")
+        await multiply(keys=["a"], args=[3], client=pipe)
+        assert await r.script_exists(multiply.sha) == [False]
+        # [SET worked, GET 'a', result of the multiply script]
+        assert await pipe.execute() == [True, b"2", 6]
+        assert await r.script_exists(multiply.sha) == [True]
+
+    @pytest.mark.asyncio(forbid_global_loop=True)
+    async def test_eval_msgpack_pipeline_error_in_lua(self, r):
+        msgpack_hello = r.register_script(msgpack_hello_script)
+        assert msgpack_hello.sha
+
+        pipe = r.pipeline()
+
+        # to avoid a dependency on msgpack, this is the output of
+        # msgpack.dumps({"name": "Joe"})
+        msgpack_message_1 = b"\x81\xa4name\xa3Joe"
+
+        await msgpack_hello(args=[msgpack_message_1], client=pipe)
+
+        assert await r.script_exists(msgpack_hello.sha) == [False]
+        assert (await pipe.execute())[0] == b"hello Joe"
+        assert await r.script_exists(msgpack_hello.sha) == [True]
+
+        msgpack_hello_broken = r.register_script(msgpack_hello_script_broken)
+
+        await msgpack_hello_broken(args=[msgpack_message_1], client=pipe)
+        with pytest.raises(exceptions.ResponseError) as excinfo:
+            await pipe.execute()
+        assert excinfo.type == exceptions.ResponseError
diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py
new file mode 100644
index 0000000000..2b22d6a339
--- /dev/null
+++ b/tests/test_asyncio/test_sentinel.py
@@ -0,0 +1,243 @@
+import socket
+
+import pytest
+
+import redis.asyncio.sentinel
+from redis import exceptions
+from redis.asyncio.sentinel import (
+    MasterNotFoundError,
+    Sentinel,
+    SentinelConnectionPool,
+    SlaveNotFoundError,
+)
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.fixture(scope="module")
+def master_ip(master_host):
+    yield socket.gethostbyname(master_host)
+
+
+class SentinelTestClient:
+    def __init__(self, cluster, id):
+        self.cluster = cluster
+        self.id = id
+
+    async def sentinel_masters(self):
+        self.cluster.connection_error_if_down(self)
+        self.cluster.timeout_if_down(self)
+        return {self.cluster.service_name: self.cluster.master}
+
+    async def sentinel_slaves(self, master_name):
+        self.cluster.connection_error_if_down(self)
+        self.cluster.timeout_if_down(self)
+        if master_name != self.cluster.service_name:
+            return []
+        return self.cluster.slaves
+
+    async def execute_command(self, *args, **kwargs):
+        # wrapper purely to validate the calls don't explode
+        from redis.asyncio.client import bool_ok
+
+        return bool_ok
+
+
+class SentinelTestCluster:
+    def __init__(self, service_name="mymaster", ip="127.0.0.1", port=6379):
+        self.clients = {}
+        self.master = {
+            "ip": ip,
+            "port": port,
+            "is_master": True,
+            "is_sdown": False,
+            "is_odown": False,
+            "num-other-sentinels": 0,
+        }
+        self.service_name = service_name
+        self.slaves = []
+        self.nodes_down = set()
+        self.nodes_timeout = set()
+
+    def connection_error_if_down(self, node):
+        if node.id in self.nodes_down:
+            raise exceptions.ConnectionError
+
+    def timeout_if_down(self, node):
+        if node.id in self.nodes_timeout:
+            raise exceptions.TimeoutError
+
+    def client(self, host, port, **kwargs):
+        return SentinelTestClient(self, (host, port))
+
+
+@pytest.fixture()
+async def cluster(master_ip):
+    cluster = SentinelTestCluster(ip=master_ip)
+    saved_Redis = redis.asyncio.sentinel.Redis
+    redis.asyncio.sentinel.Redis = cluster.client
+    yield cluster
+    redis.asyncio.sentinel.Redis = saved_Redis
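+
+
+# A note on the mocking strategy above: Sentinel discovers masters by creating
+# a Redis client per sentinel node, so the cluster fixture swaps the
+# redis.asyncio.sentinel.Redis factory for SentinelTestCluster.client, which
+# hands back in-memory SentinelTestClient objects instead of opening sockets.
+# A minimal equivalent sketch using pytest's monkeypatch fixture (an assumed
+# alternative to the manual save/restore above):
+#
+#     @pytest.fixture()
+#     async def cluster(master_ip, monkeypatch):
+#         cluster = SentinelTestCluster(ip=master_ip)
+#         monkeypatch.setattr(redis.asyncio.sentinel, "Redis", cluster.client)
+#         yield cluster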
+
+
+@pytest.fixture()
+def sentinel(request, cluster):
+    return Sentinel([("foo", 26379), ("bar", 26379)])
+
+
+@pytest.mark.onlynoncluster
+async def test_discover_master(sentinel, master_ip):
+    address = await sentinel.discover_master("mymaster")
+    assert address == (master_ip, 6379)
+
+
+@pytest.mark.onlynoncluster
+async def test_discover_master_error(sentinel):
+    with pytest.raises(MasterNotFoundError):
+        await sentinel.discover_master("xxx")
+
+
+@pytest.mark.onlynoncluster
+async def test_discover_master_sentinel_down(cluster, sentinel, master_ip):
+    # Put first sentinel 'foo' down
+    cluster.nodes_down.add(("foo", 26379))
+    address = await sentinel.discover_master("mymaster")
+    assert address == (master_ip, 6379)
+    # 'bar' is now first sentinel
+    assert sentinel.sentinels[0].id == ("bar", 26379)
+
+
+@pytest.mark.onlynoncluster
+async def test_discover_master_sentinel_timeout(cluster, sentinel, master_ip):
+    # Put first sentinel 'foo' down
+    cluster.nodes_timeout.add(("foo", 26379))
+    address = await sentinel.discover_master("mymaster")
+    assert address == (master_ip, 6379)
+    # 'bar' is now first sentinel
+    assert sentinel.sentinels[0].id == ("bar", 26379)
+
+
+@pytest.mark.onlynoncluster
+async def test_master_min_other_sentinels(cluster, master_ip):
+    sentinel = Sentinel([("foo", 26379)], min_other_sentinels=1)
+    # the single known sentinel does not satisfy min_other_sentinels
+    with pytest.raises(MasterNotFoundError):
+        await sentinel.discover_master("mymaster")
+    cluster.master["num-other-sentinels"] = 2
+    address = await sentinel.discover_master("mymaster")
+    assert address == (master_ip, 6379)
+
+
+@pytest.mark.onlynoncluster
+async def test_master_odown(cluster, sentinel):
+    cluster.master["is_odown"] = True
+    with pytest.raises(MasterNotFoundError):
+        await sentinel.discover_master("mymaster")
+
+
+@pytest.mark.onlynoncluster
+async def test_master_sdown(cluster, sentinel):
+    cluster.master["is_sdown"] = True
+    with pytest.raises(MasterNotFoundError):
+        await sentinel.discover_master("mymaster")
+
+
+@pytest.mark.onlynoncluster
+async def test_discover_slaves(cluster, sentinel):
+    assert await sentinel.discover_slaves("mymaster") == []
+
+    cluster.slaves = [
+        {"ip": "slave0", "port": 1234, "is_odown": False, "is_sdown": False},
+        {"ip": "slave1", "port": 1234, "is_odown": False, "is_sdown": False},
+    ]
+    assert await sentinel.discover_slaves("mymaster") == [
+        ("slave0", 1234),
+        ("slave1", 1234),
+    ]
+
+    # slave0 -> ODOWN
+    cluster.slaves[0]["is_odown"] = True
+    assert await sentinel.discover_slaves("mymaster") == [("slave1", 1234)]
+
+    # slave1 -> SDOWN
+    cluster.slaves[1]["is_sdown"] = True
+    assert await sentinel.discover_slaves("mymaster") == []
+
+    cluster.slaves[0]["is_odown"] = False
+    cluster.slaves[1]["is_sdown"] = False
+
+    # node0 -> DOWN
+    cluster.nodes_down.add(("foo", 26379))
+    assert await sentinel.discover_slaves("mymaster") == [
+        ("slave0", 1234),
+        ("slave1", 1234),
+    ]
+    cluster.nodes_down.clear()
+
+    # node0 -> TIMEOUT
+    cluster.nodes_timeout.add(("foo", 26379))
+    assert await sentinel.discover_slaves("mymaster") == [
+        ("slave0", 1234),
+        ("slave1", 1234),
+    ]
+
+
+@pytest.mark.onlynoncluster
+async def test_master_for(cluster, sentinel, master_ip):
+    master = sentinel.master_for("mymaster", db=9)
+    assert await master.ping()
+    assert master.connection_pool.master_address == (master_ip, 6379)
+
+    # Use internal connection check
+    master = sentinel.master_for("mymaster", db=9, check_connection=True)
+    assert await master.ping()
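+
+
+# A minimal sketch, not run by this suite, of how the master_for()/slave_for()
+# proxies exercised above behave against a real deployment: both return a
+# lazily connecting Redis client, and (assuming SentinelConnectionPool's
+# documented behavior) discovery is repeated when a connection is established,
+# so a failover is picked up without rebuilding the client.
+async def _master_for_sketch(sentinel):
+    master = sentinel.master_for("mymaster", db=9)
+    replica = sentinel.slave_for("mymaster", db=9)
+    # issuing a command connects to whichever node is currently elected
+    await master.set("key", "value")
+    assert await replica.ping()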
+
+
+@pytest.mark.onlynoncluster
+async def test_slave_for(cluster, sentinel):
+    cluster.slaves = [
+        {"ip": "127.0.0.1", "port": 6379, "is_odown": False, "is_sdown": False},
+    ]
+    slave = sentinel.slave_for("mymaster", db=9)
+    assert await slave.ping()
+
+
+@pytest.mark.onlynoncluster
+async def test_slave_for_slave_not_found_error(cluster, sentinel):
+    cluster.master["is_odown"] = True
+    slave = sentinel.slave_for("mymaster", db=9)
+    with pytest.raises(SlaveNotFoundError):
+        await slave.ping()
+
+
+@pytest.mark.onlynoncluster
+async def test_slave_round_robin(cluster, sentinel, master_ip):
+    cluster.slaves = [
+        {"ip": "slave0", "port": 6379, "is_odown": False, "is_sdown": False},
+        {"ip": "slave1", "port": 6379, "is_odown": False, "is_sdown": False},
+    ]
+    pool = SentinelConnectionPool("mymaster", sentinel)
+    rotator = pool.rotate_slaves()
+    assert await rotator.__anext__() in (("slave0", 6379), ("slave1", 6379))
+    assert await rotator.__anext__() in (("slave0", 6379), ("slave1", 6379))
+    # Fallback to master
+    assert await rotator.__anext__() == (master_ip, 6379)
+    with pytest.raises(SlaveNotFoundError):
+        await rotator.__anext__()
+
+
+@pytest.mark.onlynoncluster
+async def test_ckquorum(cluster, sentinel):
+    assert await sentinel.sentinel_ckquorum("mymaster")
+
+
+@pytest.mark.onlynoncluster
+async def test_flushconfig(cluster, sentinel):
+    assert await sentinel.sentinel_flushconfig()
+
+
+@pytest.mark.onlynoncluster
+async def test_reset(cluster, sentinel):
+    cluster.master["is_odown"] = True
+    assert await sentinel.sentinel_reset("mymaster")
diff --git a/tox.ini b/tox.ini
index abebf004ba..6ebcfc8143 100644
--- a/tox.ini
+++ b/tox.ini
@@ -10,7 +10,7 @@ markers =
 [tox]
 minversion = 3.2.0
 requires = tox-docker
-envlist = {standalone,cluster}-{plain,hiredis,ocsp}-{py36,py37,py38,py39,pypy3},linters,docs
+envlist = {standalone,cluster}-{plain,hiredis,ocsp}-{uvloop,asyncio}-{py36,py37,py38,py39,pypy3},linters,docs
@@ -133,8 +133,10 @@ setenv =
     CLUSTER_URL = "redis://localhost:16379/0"
 run_before = {toxinidir}/docker/stunnel/create_certs.sh
 commands =
-    standalone: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' {posargs}
-    cluster: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} {posargs}
+    standalone: pytest --cov=./ --cov-report=xml:coverage_redis.xml --asyncio-mode=auto -W always -m 'not onlycluster' {posargs}
+    standalone-uvloop: pytest --cov=./ --cov-report=xml:coverage_redis.xml --asyncio-mode=auto -W always -m 'not onlycluster' --uvloop {posargs}
+    cluster: pytest --cov=./ --cov-report=xml:coverage_cluster.xml --asyncio-mode=auto -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} {posargs}
+    cluster-uvloop: pytest --cov=./ --cov-report=xml:coverage_cluster.xml --asyncio-mode=auto -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} --uvloop {posargs}

 [testenv:devenv]
 skipsdist = true

From 6e629adb530684ab4c4508cd45202e2ab150f6ac Mon Sep 17 00:00:00 2001
From: Andrew-Chen-Wang
Date: Mon, 24 Jan 2022 19:40:41 -0500
Subject: [PATCH 03/24] Lint

---
 redis/__init__.py                          | 10 +--
 redis/asyncio/__init__.py                  |  3 +-
 redis/asyncio/client.py                    |  4 +-
 redis/asyncio/connection.py                | 12 +++-
 redis/asyncio/sentinel.py                  | 12 ++--
 redis/commands/core.py                     | 83 +++++++++++-----------
 tests/test_asyncio/conftest.py             |  4 +-
 tests/test_asyncio/test_connection.py      |  2 +-
 tests/test_asyncio/test_connection_pool.py | 29 ++++----
 tests/test_asyncio/test_lock.py            |  2 +-
 tests/test_asyncio/test_monitor.py         |  6 +-
tests/test_asyncio/test_pipeline.py | 4 +- tests/test_asyncio/test_pubsub.py | 8 +-- tests/test_asyncio/test_scripting.py | 11 --- tox.ini | 2 +- 15 files changed, 86 insertions(+), 106 deletions(-) diff --git a/redis/__init__.py b/redis/__init__.py index 35044be29d..b7560a6715 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -1,10 +1,5 @@ import sys -if sys.version_info >= (3, 8): - from importlib import metadata -else: - import importlib_metadata as metadata - from redis.client import Redis, StrictRedis from redis.cluster import RedisCluster from redis.connection import ( @@ -37,6 +32,11 @@ ) from redis.utils import from_url +if sys.version_info >= (3, 8): + from importlib import metadata +else: + import importlib_metadata as metadata + def int_or_str(value): try: diff --git a/redis/asyncio/__init__.py b/redis/asyncio/__init__.py index 3959b9acee..c655c7da4b 100644 --- a/redis/asyncio/__init__.py +++ b/redis/asyncio/__init__.py @@ -6,13 +6,13 @@ SSLConnection, UnixDomainSocketConnection, ) -from redis.asyncio.utils import from_url from redis.asyncio.sentinel import ( Sentinel, SentinelConnectionPool, SentinelManagedConnection, SentinelManagedSSLConnection, ) +from redis.asyncio.utils import from_url from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -29,7 +29,6 @@ WatchError, ) - __all__ = [ "AuthenticationError", "AuthenticationWrongNumberOfArgsError", diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index d69279c4b7..62d81ce49d 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -5,6 +5,7 @@ import re import warnings from typing import ( + TYPE_CHECKING, Any, AsyncIterator, Awaitable, @@ -22,7 +23,6 @@ TypeVar, Union, cast, - TYPE_CHECKING, ) from redis.asyncio.connection import ( @@ -33,8 +33,8 @@ ) from redis.commands import ( AsyncCoreCommands, - RedisModuleCommands, AsyncSentinelCommands, + RedisModuleCommands, list_or_args, ) from redis.compat import Protocol, TypedDict diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index cefb5f2cc2..b043c18c4b 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -379,7 +379,9 @@ def on_disconnect(self): async def can_read(self, timeout: float): return self._buffer and bool(await self._buffer.can_read(timeout)) - async def read_response(self, disable_decoding: bool = False) -> Union[EncodableT, ResponseError, None]: + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None]: if not self._buffer or not self.encoder: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) raw = await self._buffer.readline() @@ -421,7 +423,9 @@ async def read_response(self, disable_decoding: bool = False) -> Union[Encodable length = int(response) if length == -1: return None - response = [(await self.read_response(disable_decoding)) for _ in range(length)] + response = [ + (await self.read_response(disable_decoding)) for _ in range(length) + ] if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) return response @@ -504,7 +508,9 @@ async def read_from_socket( return False raise ConnectionError(f"Error while reading from socket: {ex.args}") - async def read_response(self, disable_decoding: bool = False) -> Union[EncodableT, List[EncodableT]]: + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, List[EncodableT]]: if not self._stream or not self._reader: self.on_disconnect() raise 
ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 824802754f..eb707e037e 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -5,15 +5,13 @@ from redis.asyncio.client import Redis from redis.asyncio.connection import ( - Connection, ConnectionPool, EncodableT, SSLConnection, + Connection, + ConnectionPool, + EncodableT, + SSLConnection, ) from redis.commands import SentinelCommands -from redis.exceptions import ( - ConnectionError, - ReadOnlyError, - ResponseError, - TimeoutError, -) +from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError from redis.utils import str_if_bytes diff --git a/redis/commands/core.py b/redis/commands/core.py index e7e08d1a66..1cf7c28726 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -11,9 +11,10 @@ Awaitable, Callable, Iterable, + Iterator, Mapping, Sequence, - Union, Iterator, + Union, ) from redis.compat import Literal @@ -386,7 +387,7 @@ def role(self) -> ResponseT: """ return self.execute_command("ROLE") - def client_kill(self, address: str, **kwargs) -> ResponseT: + def client_kill(self, address: str, **kwargs) -> ResponseT: """Disconnects the client at ``address`` (ip:port) For more information check https://redis.io/commands/client-kill @@ -665,7 +666,7 @@ def client_unblock( args.append(b"ERROR") return self.execute_command(*args, **kwargs) - def client_pause(self, timeout: int , all: bool = True, **kwargs) -> ResponseT: + def client_pause(self, timeout: int, all: bool = True, **kwargs) -> ResponseT: """ Suspend all the Redis clients for the specified amount of time :param timeout: milliseconds to pause clients @@ -978,7 +979,9 @@ def memory_malloc_stats(self, **kwargs) -> ResponseT: """ return self.execute_command("MEMORY MALLOC-STATS", **kwargs) - def memory_usage(self, key: KeyT, samples: int | None = None, **kwargs) -> ResponseT: + def memory_usage( + self, key: KeyT, samples: int | None = None, **kwargs + ) -> ResponseT: """ Return the total memory usage for key, its value and associated administrative overheads. @@ -1061,7 +1064,9 @@ def shutdown(self, save: bool = False, nosave: bool = False, **kwargs) -> None: return raise RedisError("SHUTDOWN seems to have failed.") - def slaveof(self, host: str | None = None, port: int | None = None, **kwargs) -> ResponseT: + def slaveof( + self, host: str | None = None, port: int | None = None, **kwargs + ) -> ResponseT: """ Set the server to be a replicated slave of the instance identified by the ``host`` and ``port``. If called without arguments, the @@ -1143,7 +1148,9 @@ async def memory_doctor(self, **kwargs) -> None: async def memory_help(self, **kwargs) -> None: return super().memory_help(**kwargs) - async def shutdown(self, save: bool = False, nosave: bool = False, **kwargs) -> None: + async def shutdown( + self, save: bool = False, nosave: bool = False, **kwargs + ) -> None: """Shutdown the Redis server. If Redis has persistence configured, data will be flushed before shutdown. 
If the "save" option is set, a data flush will be attempted even if there is no persistence @@ -2429,7 +2436,7 @@ def scan_iter( match: PatternT | None = None, count: int | None = None, _type: str | None = None, - **kwargs + **kwargs, ) -> Iterator: """ Make an iterator using the SCAN command so that the client doesn't @@ -2601,7 +2608,7 @@ async def scan_iter( match: PatternT | None = None, count: int | None = None, _type: str | None = None, - **kwargs + **kwargs, ) -> AsyncIterator: """ Make an iterator using the SCAN command so that the client doesn't @@ -2785,7 +2792,9 @@ def smembers(self, name: KeyT) -> ResponseT: """ return self.execute_command("SMEMBERS", name) - def smismember(self, name: KeyT, values: Sequence[EncodableT], *args: EncodableT) -> ResponseT: + def smismember( + self, name: KeyT, values: Sequence[EncodableT], *args: EncodableT + ) -> ResponseT: """ Return whether each value in ``values`` is a member of the set ``name`` as a list of ``bool`` in the order of ``values`` @@ -3134,7 +3143,10 @@ def xgroup_destroy(self, name: KeyT, groupname: GroupT) -> ResponseT: return self.execute_command("XGROUP DESTROY", name, groupname) def xgroup_createconsumer( - self, name: KeyT, groupname: GroupT, consumername: ConsumerT, + self, + name: KeyT, + groupname: GroupT, + consumername: ConsumerT, ) -> ResponseT: """ Consumers in a consumer group are auto-created every time a new @@ -3542,12 +3554,7 @@ def zcard(self, name: KeyT): """ return self.execute_command("ZCARD", name) - def zcount( - self, - name: KeyT, - min: ZScoreBoundT, - max: ZScoreBoundT - ) -> ResponseT: + def zcount(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: """ Returns the number of elements in the sorted set at key ``name`` with a score between ``min`` and ``max``. @@ -3849,7 +3856,7 @@ def zrevrange( start: int, end: int, withscores: bool = False, - score_cast_func: type | Callable = float + score_cast_func: type | Callable = float, ) -> ResponseT: """ Return a range of values from sorted set ``name`` between @@ -3926,7 +3933,7 @@ def zrangebylex( min: EncodableT, max: EncodableT, start: int | None = None, - num: int | None = None + num: int | None = None, ) -> ResponseT: """ Return the lexicographical range of values from sorted set ``name`` @@ -3950,7 +3957,7 @@ def zrevrangebylex( max: EncodableT, min: EncodableT, start: int | None = None, - num: int | None = None + num: int | None = None, ) -> ResponseT: """ Return the reversed lexicographical range of values from sorted set @@ -4053,12 +4060,7 @@ def zrem(self, name: KeyT, *values: EncodableT) -> ResponseT: """ return self.execute_command("ZREM", name, *values) - def zremrangebylex( - self, - name: KeyT, - min: EncodableT, - max: EncodableT - ) -> ResponseT: + def zremrangebylex(self, name: KeyT, min: EncodableT, max: EncodableT) -> ResponseT: """ Remove all elements in the sorted set ``name`` between the lexicographical range specified by ``min`` and ``max``. @@ -4069,12 +4071,7 @@ def zremrangebylex( """ return self.execute_command("ZREMRANGEBYLEX", name, min, max) - def zremrangebyrank( - self, - name: KeyT, - min: int, - max: int - ) -> ResponseT: + def zremrangebyrank(self, name: KeyT, min: int, max: int) -> ResponseT: """ Remove all elements in the sorted set ``name`` with ranks between ``min`` and ``max``. 
Values are 0-based, ordered from smallest score @@ -4085,7 +4082,9 @@ def zremrangebyrank( """ return self.execute_command("ZREMRANGEBYRANK", name, min, max) - def zremrangebyscore(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: + def zremrangebyscore( + self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT + ) -> ResponseT: """ Remove all elements in the sorted set ``name`` with scores between ``min`` and ``max``. Returns the number of elements removed. @@ -4115,7 +4114,7 @@ def zunion( self, keys: Sequence[KeyT] | Mapping[AnyKeyT, float], aggregate: str | None = None, - withscores: bool = False + withscores: bool = False, ) -> ResponseT: """ Return the union of multiple sorted sets specified by ``keys``. @@ -4131,7 +4130,7 @@ def zunionstore( self, dest: KeyT, keys: Sequence[KeyT] | Mapping[AnyKeyT, float], - aggregate: str | None = None + aggregate: str | None = None, ) -> ResponseT: """ Union multiple sorted sets specified by ``keys`` into @@ -4306,7 +4305,7 @@ def hset( name: KeyT, key: FieldT = None, value: EncodableT = None, - mapping: Mapping[AnyFieldT, EncodableT] = None + mapping: Mapping[AnyFieldT, EncodableT] = None, ) -> ResponseT: """ Set ``key`` to ``value`` within hash ``name``, @@ -4437,10 +4436,7 @@ class ScriptCommands(CommandsProtocol): """ def eval( - self, - script: ScriptTextT, - numkeys: int, - *keys_and_args: EncodableT + self, script: ScriptTextT, numkeys: int, *keys_and_args: EncodableT ) -> ResponseT: """ Execute the Lua ``script``, specifying the ``numkeys`` the script @@ -4489,8 +4485,7 @@ def script_debug(self, *args) -> None: ) def script_flush( - self, - sync_type: Literal["SYNC"] | Literal["ASYNC"] = None + self, sync_type: Literal["SYNC"] | Literal["ASYNC"] = None ) -> ResponseT: """Flush all scripts from the script cache. ``sync_type`` is by default SYNC (synchronous) but it can also be @@ -5060,7 +5055,7 @@ def __call__( self, keys: Sequence[KeyT] | None = None, args: Iterable[EncodableT] | None = None, - client: Redis | None = None + client: Redis | None = None, ): """Execute the script, passing any required ``args``""" keys = keys or [] @@ -5134,7 +5129,9 @@ class BitFieldOperation: Command builder for BITFIELD commands. 
""" - def __init__(self, client: Redis | AsyncRedis, key: str, default_overflow: str | None = None): + def __init__( + self, client: Redis | AsyncRedis, key: str, default_overflow: str | None = None + ): self.client = client self.key = key self._default_overflow = default_overflow diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 0657088eb5..d9d95561d4 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -2,10 +2,10 @@ import asyncio import random -from packaging.version import Version from urllib.parse import urlparse import pytest +from packaging.version import Version import redis.asyncio as redis from redis.asyncio.client import Monitor @@ -15,8 +15,8 @@ PythonParser, parse_url, ) - from tests.conftest import REDIS_INFO + from .compat import mock diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 25ec0d8a49..78dbe4d14f 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -25,7 +25,7 @@ async def test_invalid_response(create_redis): with mock.patch.object(parser._buffer, "readline", readline_mock): with pytest.raises(InvalidResponse) as cm: await parser.read_response() - assert str(cm.value) == "Protocol Error: %r" % raw + assert str(cm.value) == f"Protocol Error: {raw!r}" @skip_if_server_version_lt("4.0.0") diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index 36b5226a21..3cf1b7c915 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -6,9 +6,8 @@ import redis.asyncio as redis from redis.asyncio.connection import Connection, to_bool -from redis import exceptions - from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt + from .compat import mock from .test_pubsub import wait_for_message @@ -68,7 +67,7 @@ async def test_max_connections(self, master_host): pool = self.get_pool(max_connections=2, connection_kwargs=connection_kwargs) await pool.get_connection("_") await pool.get_connection("_") - with pytest.raises(exceptions.ConnectionError): + with pytest.raises(redis.ConnectionError): await pool.get_connection("_") async def test_reuse_previously_released_connection(self, master_host): @@ -159,7 +158,7 @@ async def test_connection_pool_blocks_until_timeout(self, master_host): await pool.get_connection("_") start = asyncio.get_event_loop().time() - with pytest.raises(exceptions.ConnectionError): + with pytest.raises(redis.ConnectionError): await pool.get_connection("_") # we should have waited at least 0.1 seconds assert asyncio.get_event_loop().time() - start >= 0.1 @@ -346,9 +345,7 @@ def test_boolean_parsing(self): assert expected is to_bool(value) def test_client_name_in_querystring(self): - pool = redis.ConnectionPool.from_url( - "redis://location?client_name=test-client" - ) + pool = redis.ConnectionPool.from_url("redis://location?client_name=test-client") assert pool.connection_kwargs["client_name"] == "test-client" def test_invalid_extra_typed_querystring_options(self): @@ -455,9 +452,7 @@ def test_db_in_querystring(self): } def test_client_name_in_querystring(self): - pool = redis.ConnectionPool.from_url( - "redis://location?client_name=test-client" - ) + pool = redis.ConnectionPool.from_url("redis://location?client_name=test-client") assert pool.connection_kwargs["client_name"] == "test-client" def test_extra_querystring_options(self): @@ -507,7 +502,7 @@ async def test_on_connect_error(self): # 
9999 databases ;) bad_connection = redis.Redis(db=9999) # an error should be raised on connect - with pytest.raises(exceptions.RedisError): + with pytest.raises(redis.RedisError): await bad_connection.info() pool = bad_connection.connection_pool assert len(pool._available_connections) == 1 @@ -521,7 +516,7 @@ async def test_busy_loading_disconnects_socket(self, r): If Redis raises a LOADING error, the connection should be disconnected and a BusyLoadingError raised """ - with pytest.raises(exceptions.BusyLoadingError): + with pytest.raises(redis.BusyLoadingError): await r.execute_command("DEBUG", "ERROR", "LOADING fake message") if r.connection: assert not r.connection._reader @@ -535,7 +530,7 @@ async def test_busy_loading_from_pipeline_immediate_command(self, r): command immediately, like WATCH does. """ pipe = r.pipeline() - with pytest.raises(exceptions.BusyLoadingError): + with pytest.raises(redis.BusyLoadingError): await pipe.immediate_execute_command( "DEBUG", "ERROR", "LOADING fake message" ) @@ -554,7 +549,7 @@ async def test_busy_loading_from_pipeline(self, r): """ pipe = r.pipeline() pipe.execute_command("DEBUG", "ERROR", "LOADING fake message") - with pytest.raises(exceptions.BusyLoadingError): + with pytest.raises(redis.BusyLoadingError): await pipe.execute() pool = r.connection_pool assert not pipe.connection @@ -565,7 +560,7 @@ async def test_busy_loading_from_pipeline(self, r): @skip_if_redis_enterprise() async def test_read_only_error(self, r): """READONLY errors get turned in ReadOnlyError exceptions""" - with pytest.raises(exceptions.ReadOnlyError): + with pytest.raises(redis.ReadOnlyError): await r.execute_command("DEBUG", "ERROR", "READONLY blah blah") def test_connect_from_url_tcp(self): @@ -594,7 +589,7 @@ async def test_connect_no_auth_supplied_when_required(self, r): AuthenticationError should be raised when the server requires a password but one isn't supplied. 
""" - with pytest.raises(exceptions.AuthenticationError): + with pytest.raises(redis.AuthenticationError): await r.execute_command( "DEBUG", "ERROR", "ERR Client sent AUTH, but no password is set" ) @@ -602,7 +597,7 @@ async def test_connect_no_auth_supplied_when_required(self, r): @skip_if_redis_enterprise() async def test_connect_invalid_password_supplied(self, r): """AuthenticationError should be raised when sending the wrong password""" - with pytest.raises(exceptions.AuthenticationError): + with pytest.raises(redis.AuthenticationError): await r.execute_command("DEBUG", "ERROR", "ERR invalid password") diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index d5c2081493..f497fac0c0 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -2,8 +2,8 @@ import pytest -from redis.exceptions import LockError, LockNotOwnedError from redis.asyncio.lock import Lock +from redis.exceptions import LockError, LockNotOwnedError pytestmark = pytest.mark.asyncio diff --git a/tests/test_asyncio/test_monitor.py b/tests/test_asyncio/test_monitor.py index baeb9cc445..783ba262b0 100644 --- a/tests/test_asyncio/test_monitor.py +++ b/tests/test_asyncio/test_monitor.py @@ -1,9 +1,7 @@ import pytest -from tests.conftest import ( - skip_if_redis_enterprise, - skip_ifnot_redis_enterprise, -) +from tests.conftest import skip_if_redis_enterprise, skip_ifnot_redis_enterprise + from .conftest import wait_for_command pytestmark = pytest.mark.asyncio diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py index 8011c7258b..4eb4daccae 100644 --- a/tests/test_asyncio/test_pipeline.py +++ b/tests/test_asyncio/test_pipeline.py @@ -1,8 +1,8 @@ import pytest import redis - from tests.conftest import skip_if_server_version_lt + from .conftest import wait_for_command pytestmark = pytest.mark.asyncio @@ -351,7 +351,7 @@ async def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r): with pytest.raises(redis.ResponseError) as ex: await pipe.execute() - expected = "Command # 1 (LLEN %s) of pipeline caused error: " % key + expected = f"Command # 1 (LLEN {key}) of pipeline caused error: " assert str(ex.value).startswith(expected) assert await r.get(key) == b"1" diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 0fb8585670..55958c18df 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -6,9 +6,9 @@ import redis.asyncio as redis from redis.exceptions import ConnectionError from redis.typing import EncodableT +from tests.conftest import skip_if_server_version_lt from .compat import mock -from tests.conftest import skip_if_server_version_lt pytestmark = pytest.mark.asyncio(forbid_global_loop=True) @@ -57,7 +57,7 @@ def make_subscribe_test_data(pubsub, type): "unsub_func": pubsub.punsubscribe, "keys": ["f*", "b*", "uni" + chr(4456) + "*"], } - assert False, "invalid subscribe type: %s" % type + assert False, f"invalid subscribe type: {type}" class TestPubSubSubscribeUnsubscribe: @@ -548,9 +548,7 @@ async def test_send_pubsub_ping_message(self, r: redis.Redis): @pytest.mark.onlynoncluster class TestPubSubConnectionKilled: @skip_if_server_version_lt("3.0.0") - async def test_connection_error_raised_when_connection_dies( - self, r: redis.Redis - ): + async def test_connection_error_raised_when_connection_dies(self, r: redis.Redis): p = r.pubsub() await p.subscribe("foo") assert await wait_for_message(p) == make_message("subscribe", "foo", 1) diff --git 
a/tests/test_asyncio/test_scripting.py b/tests/test_asyncio/test_scripting.py index f23a924f66..5d01f25ff5 100644 --- a/tests/test_asyncio/test_scripting.py +++ b/tests/test_asyncio/test_scripting.py @@ -55,17 +55,6 @@ async def test_script_flush(self, r): await r.script_load(multiply_script) await r.script_flush("NOTREAL") - @pytest.mark.asyncio(forbid_global_loop=True) - async def test_script_flush(self, r): - await r.set("a", 2) - await r.script_load(multiply_script) - await r.script_flush(None) - - with pytest.raises(exceptions.DataError): - await r.set("a", 2) - await r.script_load(multiply_script) - await r.script_flush("NOTREAL") - @pytest.mark.asyncio(forbid_global_loop=True) async def test_evalsha(self, r): await r.set("a", 2) diff --git a/tox.ini b/tox.ini index 6ebcfc8143..8a1822e713 100644 --- a/tox.ini +++ b/tox.ini @@ -151,7 +151,7 @@ deps_files = dev_requirements.txt docker = commands = flake8 - black --target-version py36 --check --diff . + black --target-version py37 --check --diff . isort --check-only --diff . vulture redis whitelist.py --min-confidence 80 flynt --fail-on-change --dry-run . From b90d7d9fdea39f31ece394771db815cab5108d57 Mon Sep 17 00:00:00 2001 From: Andrew-Chen-Wang Date: Mon, 24 Jan 2022 20:02:24 -0500 Subject: [PATCH 04/24] Abstract Redis and fix CI --- .github/workflows/integration.yaml | 1 + dev_requirements.txt | 1 + redis/asyncio/client.py | 783 +---------------------------- redis/asyncio/connection.py | 4 +- redis/asyncio/utils.py | 4 +- redis/client.py | 28 +- setup.py | 1 + whitelist.py | 3 + 8 files changed, 36 insertions(+), 789 deletions(-) diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 4b8b5fa73e..8bb40a015c 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -48,6 +48,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: run tests run: | + pip install -U setuptools wheel pip install -r dev_requirements.txt bash docker/stunnel/create_certs.sh tox -e ${{matrix.test-type}}-${{matrix.connection-type}} diff --git a/dev_requirements.txt b/dev_requirements.txt index 637d93aaae..1011313a76 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -2,6 +2,7 @@ black==21.11b1 flake8==4.0.1 flynt~=0.69.0 isort==5.10.1 +mock==4.0.3 pytest==6.2.5 pytest-timeout==2.0.1 pytest-asyncio==0.17.2 diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 62d81ce49d..b738f0c463 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -1,6 +1,5 @@ import asyncio import copy -import datetime import inspect import re import warnings @@ -31,6 +30,13 @@ SSLConnection, UnixDomainSocketConnection, ) +from redis.client import ( + EMPTY_RESPONSE, + NEVER_DECODE, + AbstractRedis, + CaseInsensitiveDict, + bool_ok, +) from redis.commands import ( AsyncCoreCommands, AsyncSentinelCommands, @@ -41,7 +47,6 @@ from redis.exceptions import ( ConnectionError, ExecAbortError, - ModuleError, PubSubError, RedisError, ResponseError, @@ -61,621 +66,6 @@ if TYPE_CHECKING: from redis.commands.core import Script -SYM_EMPTY = b"" -EMPTY_RESPONSE = "EMPTY_RESPONSE" - -# some responses (ie. 
dump) are binary, and just meant to never be decoded -NEVER_DECODE = "NEVER_DECODE" - - -def timestamp_to_datetime(response): - """Converts a unix timestamp to a Python datetime object""" - if not response: - return None - try: - response = int(response) - except ValueError: - return None - return datetime.datetime.fromtimestamp(response) - - -def string_keys_to_dict(key_string, callback): - return dict.fromkeys(key_string.split(), callback) - - -class CaseInsensitiveDict(dict): - """Case insensitive dict implementation. Assumes string keys only.""" - - def __init__(self, data: Mapping[str, Any]): - for k, v in data.items(): - self[k.upper()] = v - - def __contains__(self, k): - return super().__contains__(k.upper()) - - def __delitem__(self, k): - super().__delitem__(k.upper()) - - def __getitem__(self, k): - return super().__getitem__(k.upper()) - - def get(self, k, default=None): - return super().get(k.upper(), default) - - def __setitem__(self, k, v): - super().__setitem__(k.upper(), v) - - def update(self, data): - data = CaseInsensitiveDict(data) - super().update(data) - - -def parse_debug_object(response): - """Parse the results of Redis's DEBUG OBJECT command into a Python dict""" - # The 'type' of the object is the first item in the response, but isn't - # prefixed with a name - response = str_if_bytes(response) - response = "type:" + response - response = dict(kv.split(":") for kv in response.split()) - - # parse some expected int values from the string response - # note: this cmd isn't spec'd so these may not appear in all redis versions - int_fields = ("refcount", "serializedlength", "lru", "lru_seconds_idle") - for field in int_fields: - if field in response: - response[field] = int(response[field]) - - return response - - -def parse_object(response, infotype): - """Parse the results of an OBJECT command""" - if infotype in ("idletime", "refcount"): - return int_or_none(response) - return response - - -def parse_info(response): - """Parse the result of Redis's INFO command into a Python dict""" - info: Dict[str, Any] = {} - response = str_if_bytes(response) - - def get_value(value): - if "," not in value or "=" not in value: - try: - if "." in value: - return float(value) - else: - return int(value) - except ValueError: - return value - else: - sub_dict = {} - for item in value.split(","): - k, v = item.rsplit("=", 1) - sub_dict[k] = get_value(v) - return sub_dict - - for line in response.splitlines(): - if line and not line.startswith("#"): - if line.find(":") != -1: - # Split, the info fields keys and values. - # Note that the value may contain ':'. 
but the 'host:' - # pseudo-command is the only case where the key contains ':' - key, value = line.split(":", 1) - if key == "cmdstat_host": - key, value = line.rsplit(":", 1) - - if key == "module": - # Hardcode a list for key 'modules' since there could be - # multiple lines that started with 'module' - info.setdefault("modules", []).append(get_value(value)) - else: - info[key] = get_value(value) - else: - # if the line isn't splittable, append it to the "__raw__" key - info.setdefault("__raw__", []).append(line) - - return info - - -def parse_memory_stats(response, **kwargs): - """Parse the results of MEMORY STATS""" - stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True) - for key, value in stats.items(): - if key.startswith("db."): - stats[key] = pairs_to_dict( - value, decode_keys=True, decode_string_values=True - ) - return stats - - -SENTINEL_STATE_TYPES = { - "can-failover-its-master": int, - "config-epoch": int, - "down-after-milliseconds": int, - "failover-timeout": int, - "info-refresh": int, - "last-hello-message": int, - "last-ok-ping-reply": int, - "last-ping-reply": int, - "last-ping-sent": int, - "master-link-down-time": int, - "master-port": int, - "num-other-sentinels": int, - "num-slaves": int, - "o-down-time": int, - "pending-commands": int, - "parallel-syncs": int, - "port": int, - "quorum": int, - "role-reported-time": int, - "s-down-time": int, - "slave-priority": int, - "slave-repl-offset": int, - "voted-leader-epoch": int, -} - - -def parse_sentinel_state(item): - result = pairs_to_dict_typed(item, SENTINEL_STATE_TYPES) - flags = set(result["flags"].split(",")) - for name, flag in ( - ("is_master", "master"), - ("is_slave", "slave"), - ("is_sdown", "s_down"), - ("is_odown", "o_down"), - ("is_sentinel", "sentinel"), - ("is_disconnected", "disconnected"), - ("is_master_down", "master_down"), - ): - result[name] = flag in flags - return result - - -def parse_sentinel_master(response): - return parse_sentinel_state(map(str_if_bytes, response)) - - -def parse_sentinel_masters(response): - result = {} - for item in response: - state = parse_sentinel_state(map(str_if_bytes, item)) - result[state["name"]] = state - return result - - -def parse_sentinel_slaves_and_sentinels(response): - return [parse_sentinel_state(map(str_if_bytes, item)) for item in response] - - -def parse_sentinel_get_master(response): - return response and (response[0], int(response[1])) or None - - -def pairs_to_dict(response, decode_keys=False, decode_string_values=False): - """Create a dict given a list of key/value pairs""" - if response is None: - return {} - if decode_keys or decode_string_values: - # the iter form is faster, but I don't know how to make that work - # with a str_if_bytes() map - keys = response[::2] - if decode_keys: - keys = map(str_if_bytes, keys) - values = response[1::2] - if decode_string_values: - values = map(str_if_bytes, values) - return dict(zip(keys, values)) - else: - it = iter(response) - return dict(zip(it, it)) - - -def pairs_to_dict_typed(response, type_info): - it = iter(response) - result = {} - for key, value in zip(it, it): - if key in type_info: - try: - value = type_info[key](value) - except Exception: - # if for some reason the value can't be coerced, just use - # the string value - pass - result[key] = value - return result - - -def zset_score_pairs(response, **options): - """ - If ``withscores`` is specified in the options, return the response as - a list of (value, score) pairs - """ - if not response or not 
options.get("withscores"): - return response - score_cast_func = options.get("score_cast_func", float) - it = iter(response) - return list(zip(it, map(score_cast_func, it))) - - -def sort_return_tuples(response, **options): - """ - If ``groups`` is specified, return the response as a list of - n-element tuples with n being the value found in options['groups'] - """ - if not response or not options.get("groups"): - return response - n = options["groups"] - return list(zip(*(response[i::n] for i in range(n)))) - - -def int_or_none(response): - if response is None: - return None - return int(response) - - -def parse_stream_list(response): - if response is None: - return None - data = [] - for r in response: - if r is not None: - data.append((r[0], pairs_to_dict(r[1]))) - else: - data.append((None, None)) - return data - - -def pairs_to_dict_with_str_keys(response): - return pairs_to_dict(response, decode_keys=True) - - -def parse_list_of_dicts(response): - return list(map(pairs_to_dict_with_str_keys, response)) - - -def parse_xclaim(response, **options): - if options.get("parse_justid", False): - return response - return parse_stream_list(response) - - -def parse_xautoclaim(response, **options): - if options.get("parse_justid", False): - return response[1] - return parse_stream_list(response[1]) - - -def parse_xinfo_stream(response, **options): - data = pairs_to_dict(response, decode_keys=True) - if not options.get("full", False): - first = data["first-entry"] - if first is not None: - data["first-entry"] = (first[0], pairs_to_dict(first[1])) - last = data["last-entry"] - if last is not None: - data["last-entry"] = (last[0], pairs_to_dict(last[1])) - else: - data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]} - data["groups"] = [ - pairs_to_dict(group, decode_keys=True) for group in data["groups"] - ] - return data - - -def parse_xread(response): - if response is None: - return [] - return [[r[0], parse_stream_list(r[1])] for r in response] - - -def parse_xpending(response, **options): - if options.get("parse_detail", False): - return parse_xpending_range(response) - consumers = [{"name": n, "pending": int(p)} for n, p in response[3] or []] - return { - "pending": response[0], - "min": response[1], - "max": response[2], - "consumers": consumers, - } - - -def parse_xpending_range(response): - k = ("message_id", "consumer", "time_since_delivered", "times_delivered") - return [dict(zip(k, r)) for r in response] - - -def float_or_none(response): - if response is None: - return None - return float(response) - - -def bool_ok(response): - return str_if_bytes(response) == "OK" - - -def parse_zadd(response, **options): - if response is None: - return None - if options.get("as_score"): - return float(response) - return int(response) - - -def parse_client_list(response, **options): - clients = [] - for c in str_if_bytes(response).splitlines(): - # Values might contain '=' - clients.append(dict(pair.split("=", 1) for pair in c.split(" "))) - return clients - - -def parse_config_get(response, **options): - response = [str_if_bytes(i) if i is not None else None for i in response] - return response and pairs_to_dict(response) or {} - - -def parse_scan(response, **options): - cursor, r = response - return int(cursor), r - - -def parse_hscan(response, **options): - cursor, r = response - return int(cursor), r and pairs_to_dict(r) or {} - - -def parse_zscan(response, **options): - score_cast_func = options.get("score_cast_func", float) - cursor, r = response - it = iter(r) - return 
int(cursor), list(zip(it, map(score_cast_func, it))) - - -def parse_zmscore(response, **options): - # zmscore: list of scores (double precision floating point number) or nil - return [float(score) if score is not None else None for score in response] - - -def parse_slowlog_get(response, **options): - space: Union[str, bytes] = " " if options.get("decode_responses", False) else b" " - - def parse_item(item): - result = { - "id": item[0], - "start_time": int(item[1]), - "duration": int(item[2]), - } - # Redis Enterprise injects another entry at index [3], which has - # the complexity info (i.e. the value N in case the command has - # an O(N) complexity) instead of the command. - if isinstance(item[3], list): - result["command"] = space.join(item[3]) - else: - result["complexity"] = item[3] - result["command"] = space.join(item[4]) - return result - - return [parse_item(item) for item in response] - - -def parse_stralgo(response, **options): - """ - Parse the response from `STRALGO` command. - Without modifiers the returned value is string. - When LEN is given the command returns the length of the result - (i.e integer). - When IDX is given the command returns a dictionary with the LCS - length and all the ranges in both the strings, start and end - offset for each string, where there are matches. - When WITHMATCHLEN is given, each array representing a match will - also have the length of the match at the beginning of the array. - """ - if options.get("len", False): - return int(response) - if options.get("idx", False): - if options.get("withmatchlen", False): - matches = [ - [(int(match[-1]))] + list(map(tuple, match[:-1])) - for match in response[1] - ] - else: - matches = [list(map(tuple, match)) for match in response[1]] - return { - str_if_bytes(response[0]): matches, - str_if_bytes(response[2]): int(response[3]), - } - return str_if_bytes(response) - - -def parse_cluster_info(response, **options): - response = str_if_bytes(response) - return dict(line.split(":") for line in response.splitlines() if line) - - -def _parse_node_line(line): - line_items = line.split(" ") - node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8] - addr = addr.split("@")[0] - slots = [sl.split("-") for sl in line_items[8:]] - node_dict = { - "node_id": node_id, - "flags": flags, - "master_id": master_id, - "last_ping_sent": ping, - "last_pong_rcvd": pong, - "epoch": epoch, - "slots": slots, - "connected": True if connected == "connected" else False, - } - return addr, node_dict - - -def parse_cluster_nodes(response, **options): - """ - @see: https://redis.io/commands/cluster-nodes # string - @see: https://redis.io/commands/cluster-replicas # list of string - """ - if isinstance(response, str): - response = response.splitlines() - return dict(_parse_node_line(str_if_bytes(node)) for node in response) - - -def parse_geosearch_generic(response, **options): - """ - Parse the response of 'GEOSEARCH', GEORADIUS' and 'GEORADIUSBYMEMBER' - commands according to 'withdist', 'withhash' and 'withcoord' labels. - """ - if options["store"] or options["store_dist"]: - # `store` and `store_dist` cant be combined - # with other command arguments. 
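Since the CLUSTER NODES line format is easy to misread, here is a minimal sketch of what the _parse_node_line helper above yields for a single node line (the node id, address, and slot range below are made up for illustration):

    line = "07c3 127.0.0.1:30001@31001 myself,master - 0 0 1 connected 0-5460"
    addr, node = _parse_node_line(line)
    assert addr == "127.0.0.1:30001"          # cluster-bus port after '@' is dropped
    assert node["flags"] == "myself,master"
    assert node["connected"] is True
    assert node["slots"] == [["0", "5460"]]   # slot ranges are kept as strings
    # parse_cluster_nodes() then maps addr -> node_dict over every such line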
- # relevant to 'GEORADIUS' and 'GEORADIUSBYMEMBER' - return response - - if type(response) != list: - response_list = [response] - else: - response_list = response - - if not options["withdist"] and not options["withcoord"] and not options["withhash"]: - # just a bunch of places - return response_list - - cast: Dict[str, Callable] = { - "withdist": float, - "withcoord": lambda ll: (float(ll[0]), float(ll[1])), - "withhash": int, - } - - # zip all output results with each casting function to get - # the properly native Python value. - f = [lambda x: x] - f += [cast[o] for o in ["withdist", "withhash", "withcoord"] if options[o]] - return [list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list] - - -def parse_command(response, **options): - commands = {} - for command in response: - cmd_dict = {} - cmd_name = str_if_bytes(command[0]) - cmd_dict["name"] = cmd_name - cmd_dict["arity"] = int(command[1]) - cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]] - cmd_dict["first_key_pos"] = command[3] - cmd_dict["last_key_pos"] = command[4] - cmd_dict["step_count"] = command[5] - commands[cmd_name] = cmd_dict - return commands - - -def parse_pubsub_numsub(response, **options): - return list(zip(response[0::2], response[1::2])) - - -def parse_client_kill(response, **options): - if isinstance(response, int): - return response - return str_if_bytes(response) == "OK" - - -def parse_acl_getuser(response, **options): - if response is None: - return None - data = pairs_to_dict(response, decode_keys=True) - - # convert everything but user-defined data in 'keys' to native strings - data["flags"] = list(map(str_if_bytes, data["flags"])) - data["passwords"] = list(map(str_if_bytes, data["passwords"])) - data["commands"] = str_if_bytes(data["commands"]) - - # split 'commands' into separate 'categories' and 'commands' lists - commands, categories = [], [] - for command in data["commands"].split(" "): - if "@" in command: - categories.append(command) - else: - commands.append(command) - - data["commands"] = commands - data["categories"] = categories - data["enabled"] = "on" in data["flags"] - return data - - -def parse_acl_log(response, **options): - if response is None: - return None - if isinstance(response, list): - data = [] - for log in response: - log_data = pairs_to_dict(log, True, True) - client_info = log_data.get("client-info", "") - log_data["client-info"] = parse_client_info(client_info) - - # float() is lossy comparing to the "double" in C - log_data["age-seconds"] = float(log_data["age-seconds"]) - data.append(log_data) - else: - data = bool_ok(response) - return data - - -def parse_client_info(value): - """ - Parsing client-info in ACL Log in following format. - "key1=value1 key2=value2 key3=value3" - """ - client_info = {} - infos = str_if_bytes(value).split(" ") - for info in infos: - key, value = info.split("=") - client_info[key] = value - - # Those fields are defined as int in networking.c - for int_key in { - "id", - "age", - "idle", - "db", - "sub", - "psub", - "multi", - "qbuf", - "qbuf-free", - "obl", - "argv-mem", - "oll", - "omem", - "tot-mem", - }: - client_info[int_key] = int(client_info[int_key]) - return client_info - - -def parse_module_result(response): - if isinstance(response, ModuleError): - raise response - return True - - -def parse_set_result(response, **options): - """ - Handle SET result since GET argument is available since Redis 6.2. 
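As a compact illustration of the reply-shaping helpers above, the following asserts run against hand-written replies; only the helper names come from this file, every value is made up:

    d = CaseInsensitiveDict({"get": 1})
    assert d["GET"] == 1 and "get" in d       # keys are upcased on insert and lookup

    raw = [b"name", b"mymaster", b"port", b"6379"]
    assert pairs_to_dict(raw) == {b"name": b"mymaster", b"port": b"6379"}
    assert pairs_to_dict_typed(raw, {b"port": int})[b"port"] == 6379

    reply = [b"a", b"1.5", b"b", b"2"]
    assert zset_score_pairs(reply, withscores=True) == [(b"a", 1.5), (b"b", 2.0)]
    assert zset_score_pairs(reply) == reply   # without the flag, raw passthrough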
- Parsing SET result into: - - BOOL - - String when GET argument is used - """ - if options.get("get"): - # Redis will return a getCommand result. - # See `setGenericCommand` in t_string.c - return response - return response and str_if_bytes(response) == "OK" - class ResponseCallbackProtocol(Protocol): def __call__(self, response: Any, **kwargs): @@ -690,10 +80,9 @@ async def __call__(self, response: Any, **kwargs): ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol] -_R = TypeVar("_R") - - -class Redis(RedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands): +class Redis( + AbstractRedis, RedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands +): """ Implementation of the Redis protocol. @@ -706,156 +95,6 @@ class Redis(RedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands): Connection object to talk to redis. """ - RESPONSE_CALLBACKS = { - **string_keys_to_dict( - "AUTH COPY EXPIRE EXPIREAT PEXPIRE PEXPIREAT HEXISTS HMSET LMOVE BLMOVE MOVE " - "MSETNX PERSIST PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX", - bool, - ), - **string_keys_to_dict( - "BITCOUNT BITPOS DECRBY DEL EXISTS GEOADD GETBIT HDEL HLEN " - "HSTRLEN INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD " - "SCARD SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN " - "SUNIONSTORE UNLINK XACK XDEL XLEN XTRIM ZCARD ZLEXCOUNT ZREM " - "ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE", - int, - ), - **string_keys_to_dict("INCRBYFLOAT HINCRBYFLOAT", float), - **string_keys_to_dict( - # these return OK, or int if redis-server is >=1.3.4 - "LPUSH RPUSH", - lambda r: isinstance(r, int) and r or str_if_bytes(r) == "OK", - ), - **string_keys_to_dict("SORT", sort_return_tuples), - **string_keys_to_dict("ZSCORE ZINCRBY GEODIST", float_or_none), - **string_keys_to_dict( - "FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE READONLY READWRITE " - "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH ", - bool_ok, - ), - **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None), - **string_keys_to_dict( - "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() - ), - **string_keys_to_dict( - "ZPOPMAX ZPOPMIN ZINTER ZDIFF ZUNION ZRANGE ZRANGEBYSCORE ZREVRANGE " - "ZREVRANGEBYSCORE", - zset_score_pairs, - ), - **string_keys_to_dict( - "BZPOPMIN BZPOPMAX", lambda r: r and (r[0], r[1], float(r[2])) or None - ), - **string_keys_to_dict("ZRANK ZREVRANK", int_or_none), - **string_keys_to_dict("XREVRANGE XRANGE", parse_stream_list), - **string_keys_to_dict("XREAD XREADGROUP", parse_xread), - **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True), - "ACL CAT": lambda r: list(map(str_if_bytes, r)), - "ACL DELUSER": int, - "ACL GENPASS": str_if_bytes, - "ACL GETUSER": parse_acl_getuser, - "ACL HELP": lambda r: list(map(str_if_bytes, r)), - "ACL LIST": lambda r: list(map(str_if_bytes, r)), - "ACL LOAD": bool_ok, - "ACL LOG": parse_acl_log, - "ACL SAVE": bool_ok, - "ACL SETUSER": bool_ok, - "ACL USERS": lambda r: list(map(str_if_bytes, r)), - "ACL WHOAMI": str_if_bytes, - "CLIENT GETNAME": str_if_bytes, - "CLIENT ID": int, - "CLIENT KILL": parse_client_kill, - "CLIENT LIST": parse_client_list, - "CLIENT INFO": parse_client_info, - "CLIENT SETNAME": bool_ok, - "CLIENT UNBLOCK": lambda r: r and int(r) == 1 or False, - "CLIENT PAUSE": bool_ok, - "CLIENT GETREDIR": int, - "CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)), - "CLUSTER ADDSLOTS": bool_ok, - "CLUSTER COUNT-FAILURE-REPORTS": lambda x: int(x), - "CLUSTER COUNTKEYSINSLOT": lambda x: int(x), - "CLUSTER DELSLOTS": 
bool_ok, - "CLUSTER FAILOVER": bool_ok, - "CLUSTER FORGET": bool_ok, - "CLUSTER INFO": parse_cluster_info, - "CLUSTER KEYSLOT": lambda x: int(x), - "CLUSTER MEET": bool_ok, - "CLUSTER NODES": parse_cluster_nodes, - "CLUSTER REPLICATE": bool_ok, - "CLUSTER RESET": bool_ok, - "CLUSTER SAVECONFIG": bool_ok, - "CLUSTER SET-CONFIG-EPOCH": bool_ok, - "CLUSTER SETSLOT": bool_ok, - "CLUSTER SLAVES": parse_cluster_nodes, - "CLUSTER REPLICAS": parse_cluster_nodes, - "COMMAND": parse_command, - "COMMAND COUNT": int, - "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), - "CONFIG GET": parse_config_get, - "CONFIG RESETSTAT": bool_ok, - "CONFIG SET": bool_ok, - "DEBUG OBJECT": parse_debug_object, - "GEOHASH": lambda r: list(map(str_if_bytes, r)), - "GEOPOS": lambda r: list( - map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r) - ), - "GEOSEARCH": parse_geosearch_generic, - "GEORADIUS": parse_geosearch_generic, - "GEORADIUSBYMEMBER": parse_geosearch_generic, - "HGETALL": lambda r: r and pairs_to_dict(r) or {}, - "HSCAN": parse_hscan, - "INFO": parse_info, - "LASTSAVE": timestamp_to_datetime, - "MEMORY PURGE": bool_ok, - "MEMORY STATS": parse_memory_stats, - "MEMORY USAGE": int_or_none, - "MODULE LOAD": parse_module_result, - "MODULE UNLOAD": parse_module_result, - "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r], - "OBJECT": parse_object, - "PING": lambda r: str_if_bytes(r) == "PONG", - "QUIT": bool_ok, - "STRALGO": parse_stralgo, - "PUBSUB NUMSUB": parse_pubsub_numsub, - "RANDOMKEY": lambda r: r and r or None, - "SCAN": parse_scan, - "SCRIPT EXISTS": lambda r: list(map(bool, r)), - "SCRIPT FLUSH": bool_ok, - "SCRIPT KILL": bool_ok, - "SCRIPT LOAD": str_if_bytes, - "SENTINEL CKQUORUM": bool_ok, - "SENTINEL FAILOVER": bool_ok, - "SENTINEL FLUSHCONFIG": bool_ok, - "SENTINEL GET-MASTER-ADDR-BY-NAME": parse_sentinel_get_master, - "SENTINEL MASTER": parse_sentinel_master, - "SENTINEL MASTERS": parse_sentinel_masters, - "SENTINEL MONITOR": bool_ok, - "SENTINEL RESET": bool_ok, - "SENTINEL REMOVE": bool_ok, - "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, - "SENTINEL SET": bool_ok, - "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, - "SET": parse_set_result, - "SLOWLOG GET": parse_slowlog_get, - "SLOWLOG LEN": int, - "SLOWLOG RESET": bool_ok, - "SSCAN": parse_scan, - "TIME": lambda x: (int(x[0]), int(x[1])), - "XCLAIM": parse_xclaim, - "XAUTOCLAIM": parse_xautoclaim, - "XGROUP CREATE": bool_ok, - "XGROUP DELCONSUMER": int, - "XGROUP DESTROY": bool, - "XGROUP SETID": bool_ok, - "XINFO CONSUMERS": parse_list_of_dicts, - "XINFO GROUPS": parse_list_of_dicts, - "XINFO STREAM": parse_xinfo_stream, - "XPENDING": parse_xpending, - "ZADD": parse_zadd, - "ZSCAN": parse_zscan, - "ZMSCORE": parse_zmscore, - } - response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT] @classmethod @@ -1500,7 +739,7 @@ def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT: """ encode = self.encoder.encode decode = self.encoder.decode - return {decode(encode(k)): v for k, v in data.items()} # type: ignore[return-value] + return {decode(encode(k)): v for k, v in data.items()} # type: ignore[return-value] # noqa: E501 async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): """ diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index b043c18c4b..dc1ccee1cc 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -153,10 +153,10 @@ class BaseParser: "invalid password": AuthenticationError, # some Redis server 
versions report invalid command syntax # in lowercase - "wrong number of arguments for 'auth' command": AuthenticationWrongNumberOfArgsError, + "wrong number of arguments for 'auth' command": AuthenticationWrongNumberOfArgsError, # noqa: E501 # some Redis server versions report invalid command syntax # in uppercase - "wrong number of arguments for 'AUTH' command": AuthenticationWrongNumberOfArgsError, + "wrong number of arguments for 'AUTH' command": AuthenticationWrongNumberOfArgsError, # noqa: E501 MODULE_LOAD_ERROR: ModuleError, MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, NO_SUCH_MODULE_ERROR: ModuleError, diff --git a/redis/asyncio/utils.py b/redis/asyncio/utils.py index 2090e893fa..5a55b36a33 100644 --- a/redis/asyncio/utils.py +++ b/redis/asyncio/utils.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from redis.asyncio.client import Redis, Pipeline + from redis.asyncio.client import Pipeline, Redis def from_url(url, **kwargs): @@ -23,6 +23,6 @@ def __init__(self, redis_obj: "Redis"): async def __aenter__(self) -> "Pipeline": return self.p - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type, exc_value, traceback): await self.p.execute() del self.p diff --git a/redis/client.py b/redis/client.py index 612f91170a..866badf10e 100755 --- a/redis/client.py +++ b/redis/client.py @@ -642,19 +642,7 @@ def parse_set_result(response, **options): return response and str_if_bytes(response) == "OK" -class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): - """ - Implementation of the Redis protocol. - - This abstract class provides a Python interface to all Redis commands - and an implementation of the Redis protocol. - - Pipelines derive from this, implementing how - the commands are sent and received to the Redis server. Based on - configuration, an instance will either use a ConnectionPool, or - Connection object to talk to redis. - """ - +class AbstractRedis: RESPONSE_CALLBACKS = { **string_keys_to_dict( "AUTH COPY EXPIRE EXPIREAT PEXPIRE PEXPIREAT " @@ -807,6 +795,20 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): "ZMSCORE": parse_zmscore, } + +class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands): + """ + Implementation of the Redis protocol. + + This abstract class provides a Python interface to all Redis commands + and an implementation of the Redis protocol. + + Pipelines derive from this, implementing how + the commands are sent and received to the Redis server. Based on + configuration, an instance will either use a ConnectionPool, or + Connection object to talk to redis. 
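The effect of the AbstractRedis extraction above is that the sync client and the new asyncio client share one callback table through inheritance rather than by copy. A minimal sketch of the shape; the SyncClient/AsyncClient names below are stand-ins for redis.client.Redis and redis.asyncio.client.Redis:

    from redis.client import AbstractRedis

    class SyncClient(AbstractRedis):
        pass

    class AsyncClient(AbstractRedis):
        pass

    # Both subclasses resolve the class attribute to the very same dict object.
    assert SyncClient.RESPONSE_CALLBACKS is AsyncClient.RESPONSE_CALLBACKS

Instances still copy the table into a per-instance CaseInsensitiveDict at construction time, so mutating one client's response_callbacks does not leak into other clients.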
+ """ + @classmethod def from_url(cls, url, **kwargs): """ diff --git a/setup.py b/setup.py index 1565a9e4d2..25da1d607e 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ packages=find_packages( include=[ "redis", + "redis.asyncio", "redis.commands", "redis.commands.bf", "redis.commands.json", diff --git a/whitelist.py b/whitelist.py index 891ccd6022..a800bcf482 100644 --- a/whitelist.py +++ b/whitelist.py @@ -10,3 +10,6 @@ exc_type # unused variable (/data/repos/redis/redis-py/redis/lock.py:156) exc_value # unused variable (/data/repos/redis/redis-py/redis/lock.py:156) traceback # unused variable (/data/repos/redis/redis-py/redis/lock.py:156) +exc_type # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26) +exc_value # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26) +traceback # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26) From 3080960e853d03f3ee8bbed016319c565322e41a Mon Sep 17 00:00:00 2001 From: "Chayim I. Kirshen" Date: Thu, 27 Jan 2022 09:22:54 +0200 Subject: [PATCH 05/24] moved async-timeout to package requirements --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 25da1d607e..f57fe848ee 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,7 @@ "packaging>=20.4", 'importlib-metadata >= 1.0; python_version < "3.8"', "typing-extensions", + "async-timeout>=4.0.2", ], classifiers=[ "Development Status :: 5 - Production/Stable", @@ -54,7 +55,6 @@ "Programming Language :: Python :: Implementation :: PyPy", ], extras_require={ - "async": ["async-timeout"], "hiredis": ["hiredis>=1.0.0"], "ocsp": ["cryptography>=36.0.1", "pyopenssl==20.0.1", "requests>=2.26.0"], }, From 66a0efd1bd12bdd0402ff8102368b424833f8382 Mon Sep 17 00:00:00 2001 From: "Chayim I. Kirshen" Date: Thu, 10 Feb 2022 09:05:28 +0200 Subject: [PATCH 06/24] PR discussion, test removal --- tests/test_asyncio/test_multiprocessing.py | 181 --------------------- 1 file changed, 181 deletions(-) delete mode 100644 tests/test_asyncio/test_multiprocessing.py diff --git a/tests/test_asyncio/test_multiprocessing.py b/tests/test_asyncio/test_multiprocessing.py deleted file mode 100644 index bd21c289be..0000000000 --- a/tests/test_asyncio/test_multiprocessing.py +++ /dev/null @@ -1,181 +0,0 @@ -import asyncio -import contextlib -import multiprocessing - -import pytest - -from redis.asyncio.connection import Connection, ConnectionPool -from redis.exceptions import ConnectionError - -pytestmark = pytest.mark.asyncio - - -@contextlib.contextmanager -async def exit_callback(callback, *args): - try: - yield - finally: - await callback(*args) - - -@pytest.mark.xfail() -class TestMultiprocessing: - # Test connection sharing between forks. - # See issue #1085 for details. - - # use a multi-connection client as that's the only type that is - # actually fork/process-safe - @pytest.fixture() - async def r(self, create_redis): - redis = await create_redis( - single_connection_client=False, - ) - yield redis - await redis.flushall() - - async def test_close_connection_in_child(self, master_host): - """ - A connection owned by a parent and closed by a child doesn't - destroy the file descriptors so a parent can still use it. 
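With async-timeout promoted from the `async` extra to a core install requirement above, async code in this repo can rely on it unconditionally. A usage sketch of the idiom, not taken from the diff: the helper name and the one-second deadline are illustrative, and the same pattern appears in the asyncio examples added later in this series:

    import asyncio
    import async_timeout

    async def get_message_with_deadline(pubsub, seconds=1.0):
        # Bound a pub/sub read so a silent channel cannot block forever.
        try:
            async with async_timeout.timeout(seconds):
                return await pubsub.get_message(ignore_subscribe_messages=True)
        except asyncio.TimeoutError:
            return None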
- """ - conn = Connection(host=master_host) - await conn.send_command("ping") - assert await conn.read_response() == b"PONG" - - def target(conn): - async def atarget(conn): - await conn.send_command("ping") - assert conn.read_response() == b"PONG" - await conn.disconnect() - - asyncio.get_event_loop().run_until_complete(atarget(conn)) - - proc = multiprocessing.Process(target=target, args=(conn,)) - proc.start() - proc.join(3) - assert proc.exitcode == 0 - - # The connection was created in the parent but disconnected in the - # child. The child called socket.close() but did not call - # socket.shutdown() because it wasn't the "owning" process. - # Therefore the connection still works in the parent. - await conn.send_command("ping") - assert await conn.read_response() == b"PONG" - - async def test_close_connection_in_parent(self, master_host): - """ - A connection owned by a parent is unusable by a child if the parent - (the owning process) closes the connection. - """ - conn = Connection(host=master_host) - await conn.send_command("ping") - assert await conn.read_response() == b"PONG" - - def target(conn, ev): - ev.wait() - # the parent closed the connection. because it also created the - # connection, the connection is shutdown and the child - # cannot use it. - with pytest.raises(ConnectionError): - asyncio.get_event_loop().run_until_complete(conn.send_command("ping")) - - ev = multiprocessing.Event() - proc = multiprocessing.Process(target=target, args=(conn, ev)) - proc.start() - - await conn.disconnect() - ev.set() - - proc.join(3) - assert proc.exitcode == 0 - - @pytest.mark.parametrize("max_connections", [1, 2, None]) - async def test_pool(self, max_connections, master_host): - """ - A child will create its own connections when using a pool created - by a parent. - """ - pool = ConnectionPool.from_url( - f"redis://{master_host}", max_connections=max_connections - ) - - conn = await pool.get_connection("ping") - main_conn_pid = conn.pid - async with exit_callback(pool.release, conn): - await conn.send_command("ping") - assert await conn.read_response() == b"PONG" - - def target(pool): - async def atarget(pool): - async with exit_callback(pool.disconnect): - conn = await pool.get_connection("ping") - assert conn.pid != main_conn_pid - async with exit_callback(pool.release, conn): - assert await conn.send_command("ping") is None - assert await conn.read_response() == b"PONG" - - asyncio.get_event_loop().run_until_complete(atarget(pool)) - - proc = multiprocessing.Process(target=target, args=(pool,)) - proc.start() - proc.join(3) - assert proc.exitcode == 0 - - # Check that connection is still alive after fork process has exited - # and disconnected the connections in its pool - conn = pool.get_connection("ping") - async with exit_callback(pool.release, conn): - assert await conn.send_command("ping") is None - assert await conn.read_response() == b"PONG" - - @pytest.mark.parametrize("max_connections", [1, 2, None]) - async def test_close_pool_in_main(self, max_connections, master_host): - """ - A child process that uses the same pool as its parent isn't affected - when the parent disconnects all connections within the pool. 
- """ - pool = ConnectionPool.from_url( - f"redis://{master_host}", max_connections=max_connections - ) - - conn = await pool.get_connection("ping") - assert await conn.send_command("ping") is None - assert await conn.read_response() == b"PONG" - - def target(pool, disconnect_event): - async def atarget(pool, disconnect_event): - conn = await pool.get_connection("ping") - async with exit_callback(pool.release, conn): - assert await conn.send_command("ping") is None - assert await conn.read_response() == b"PONG" - disconnect_event.wait() - assert await conn.send_command("ping") is None - assert await conn.read_response() == b"PONG" - - asyncio.get_event_loop().run_until_complete(atarget(pool, disconnect_event)) - - ev = multiprocessing.Event() - - proc = multiprocessing.Process(target=target, args=(pool, ev)) - proc.start() - - await pool.disconnect() - ev.set() - proc.join(3) - assert proc.exitcode == 0 - - async def test_aioredis_client(self, r): - """A aioredis client created in a parent can also be used in a child""" - assert await r.ping() is True - - def target(client): - run = asyncio.get_event_loop().run_until_complete - assert run(client.ping()) is True - del client - - proc = multiprocessing.Process(target=target, args=(r,)) - proc.start() - proc.join(3) - assert proc.exitcode == 0 - - assert await r.ping() is True From ba541e73f4a49f3a1da361514e9af5f7a4a4fe06 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 10 Feb 2022 09:10:44 +0200 Subject: [PATCH 07/24] Mark tests as onlynoncluster --- tests/test_asyncio/test_connection.py | 3 +++ tests/test_asyncio/test_connection_pool.py | 6 ++++++ tests/test_asyncio/test_encoding.py | 3 +++ tests/test_asyncio/test_pipeline.py | 1 + tests/test_asyncio/test_pubsub.py | 8 ++++++++ tests/test_asyncio/test_retry.py | 2 ++ 6 files changed, 23 insertions(+) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 78dbe4d14f..46abec01d6 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -30,6 +30,7 @@ async def test_invalid_response(create_redis): @skip_if_server_version_lt("4.0.0") @pytest.mark.redismod +@pytest.mark.onlynoncluster async def test_loading_external_modules(modclient): def inner(): pass @@ -50,12 +51,14 @@ def inner(): # assert mod.get('fookey') == d +@pytest.mark.onlynoncluster async def test_socket_param_regression(r): """A regression test for issue #1060""" conn = UnixDomainSocketConnection() _ = await conn.disconnect() is True +@pytest.mark.onlynoncluster async def test_can_run_concurrent_commands(r): assert await r.ping() is True assert all(await asyncio.gather(*(r.ping() for _ in range(10)))) diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index 3cf1b7c915..f8900e0039 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -31,6 +31,7 @@ async def can_read(self, timeout: float = 0): return False +@pytest.mark.onlynoncluster class TestConnectionPool: def get_pool( self, @@ -107,6 +108,7 @@ def test_repr_contains_db_info_unix(self): assert repr(pool) == expected +@pytest.mark.onlynoncluster class TestBlockingConnectionPool: def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): connection_kwargs = connection_kwargs or {} @@ -213,6 +215,7 @@ def test_repr_contains_db_info_unix(self): assert repr(pool) == expected +@pytest.mark.onlynoncluster class TestConnectionPoolURLParsing: def test_hostname(self): pool = 
redis.ConnectionPool.from_url("redis://my.host") @@ -379,6 +382,7 @@ def test_invalid_scheme_raises_error(self): ) +@pytest.mark.onlynoncluster class TestConnectionPoolUnixSocketURLParsing: def test_defaults(self): pool = redis.ConnectionPool.from_url("unix:///socket") @@ -461,6 +465,7 @@ def test_extra_querystring_options(self): assert pool.connection_kwargs == {"path": "/socket", "a": "1", "b": "2"} +@pytest.mark.onlynoncluster class TestSSLConnectionURLParsing: def test_host(self): pool = redis.ConnectionPool.from_url("rediss://my.host") @@ -492,6 +497,7 @@ def get_connection(self, *args, **kwargs): assert pool.get_connection("_").check_hostname is True +@pytest.mark.onlynoncluster class TestConnection: async def test_on_connect_error(self): """ diff --git a/tests/test_asyncio/test_encoding.py b/tests/test_asyncio/test_encoding.py index c3c69f5055..b68c7fc1f4 100644 --- a/tests/test_asyncio/test_encoding.py +++ b/tests/test_asyncio/test_encoding.py @@ -6,6 +6,7 @@ pytestmark = pytest.mark.asyncio +@pytest.mark.onlynoncluster class TestEncoding: @pytest.fixture() async def r(self, create_redis): @@ -57,6 +58,7 @@ async def test_list_encoding(self, r: redis.Redis): assert await r.lrange("a", 0, -1) == result +@pytest.mark.onlynoncluster class TestEncodingErrors: async def test_ignore(self, create_redis): r = await create_redis( @@ -75,6 +77,7 @@ async def test_replace(self, create_redis): assert await r.get("a") == "foo\ufffd" +@pytest.mark.onlynoncluster class TestMemoryviewsAreNotPacked: async def test_memoryviews_are_not_packed(self, r): arg = memoryview(b"some_arg") diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py index 4eb4daccae..5bb1a8a4e0 100644 --- a/tests/test_asyncio/test_pipeline.py +++ b/tests/test_asyncio/test_pipeline.py @@ -8,6 +8,7 @@ pytestmark = pytest.mark.asyncio +@pytest.mark.onlynoncluster class TestPipeline: async def test_pipeline_is_true(self, r): """Ensure pipeline instances are not false-y""" diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 55958c18df..de3e8e28f1 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -60,6 +60,7 @@ def make_subscribe_test_data(pubsub, type): assert False, f"invalid subscribe type: {type}" +@pytest.mark.onlynoncluster class TestPubSubSubscribeUnsubscribe: async def _test_subscribe_unsubscribe( self, p, sub_type, unsub_type, sub_func, unsub_func, keys @@ -268,6 +269,7 @@ async def _test_sub_unsub_all_resub( assert p.subscribed is True +@pytest.mark.onlynoncluster class TestPubSubMessages: def setup_method(self, method): self.message = None @@ -362,6 +364,7 @@ async def test_get_message_without_subscribe(self, r: redis.Redis): assert expect in info.exconly() +@pytest.mark.onlynoncluster class TestPubSubAutoDecoding: """These tests only validate that we get unicode values back""" @@ -479,6 +482,7 @@ async def test_context_manager(self, r: redis.Redis): assert pubsub.patterns == {} +@pytest.mark.onlynoncluster class TestPubSubRedisDown: async def test_channel_subscribe(self, r: redis.Redis): r = redis.Redis(host="localhost", port=6390) @@ -487,6 +491,7 @@ async def test_channel_subscribe(self, r: redis.Redis): await p.subscribe("foo") +@pytest.mark.onlynoncluster class TestPubSubSubcommands: @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.0") @@ -525,6 +530,7 @@ async def test_pubsub_numpat(self, r: redis.Redis): assert await r.pubsub_numpat() == 3 +@pytest.mark.onlynoncluster class TestPubSubPings: 
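Every hunk in this patch applies the same one-line change, so a single sketch covers them all: a marked test class, assuming `onlynoncluster` is registered as a custom marker in the project's pytest configuration (as tox.ini does in this repo):

    import pytest

    @pytest.mark.onlynoncluster
    class TestStandaloneOnly:
        async def test_ping(self, r):          # 'r' is the repo's redis fixture
            assert await r.ping() is True

    # Cluster CI runs can then deselect the whole group, e.g.:
    #   pytest -m "not onlynoncluster"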
@skip_if_server_version_lt("3.0.0") async def test_send_pubsub_ping(self, r: redis.Redis): @@ -559,6 +565,7 @@ async def test_connection_error_raised_when_connection_dies(self, r: redis.Redis await wait_for_message(p) +@pytest.mark.onlynoncluster class TestPubSubTimeouts: async def test_get_message_with_timeout_returns_none(self, r: redis.Redis): p = r.pubsub() @@ -567,6 +574,7 @@ async def test_get_message_with_timeout_returns_none(self, r: redis.Redis): assert await p.get_message(timeout=0.01) is None +@pytest.mark.onlynoncluster class TestPubSubRun: async def _subscribe(self, p, *args, **kwargs): await p.subscribe(*args, **kwargs) diff --git a/tests/test_asyncio/test_retry.py b/tests/test_asyncio/test_retry.py index 6e277ae38f..e83e001847 100644 --- a/tests/test_asyncio/test_retry.py +++ b/tests/test_asyncio/test_retry.py @@ -19,6 +19,7 @@ def compute(self, failures): return 0 +@pytest.mark.onlynoncluster class TestConnectionConstructorWithRetry: "Test that the Connection constructors properly handles Retry objects" @@ -40,6 +41,7 @@ def test_retry_on_timeout_retry(self, Class, retries: int): assert c.retry._retries == retries +@pytest.mark.onlynoncluster class TestRetry: "Test that Retry calls backoff and retries the expected number of times" From 6690caf1de225f36a52d788901a5235970e87171 Mon Sep 17 00:00:00 2001 From: "Chayim I. Kirshen" Date: Thu, 10 Feb 2022 09:11:28 +0200 Subject: [PATCH 08/24] fixing linters --- redis/commands/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redis/commands/core.py b/redis/commands/core.py index a7de40a1ba..a970c19abc 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -4,7 +4,6 @@ import hashlib import time import warnings -from typing import List, Optional, Union from typing import ( TYPE_CHECKING, Any, @@ -13,7 +12,9 @@ Callable, Iterable, Iterator, + List, Mapping, + Optional, Sequence, Union, ) @@ -22,7 +23,6 @@ from redis.exceptions import ConnectionError, DataError, NoScriptError, RedisError from redis.typing import ( AbsExpiryT, - AnyFieldT, AnyKeyT, BitfieldOffsetT, ChannelT, From cdf7ce88e6c040babf6b8a2ce47c9570a1567ffd Mon Sep 17 00:00:00 2001 From: Andrew-Chen-Wang Date: Sat, 12 Feb 2022 00:13:13 -0500 Subject: [PATCH 09/24] Fix asyncio SentinelCommands --- redis/asyncio/sentinel.py | 4 ++-- redis/commands/sentinel.py | 20 ++++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index eb707e037e..43900eb495 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -10,7 +10,7 @@ EncodableT, SSLConnection, ) -from redis.commands import SentinelCommands +from redis.commands import AsyncSentinelCommands from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError from redis.utils import str_if_bytes @@ -145,7 +145,7 @@ async def rotate_slaves(self) -> AsyncIterator: raise SlaveNotFoundError(f"No slave found for {self.service_name!r}") -class Sentinel(SentinelCommands): +class Sentinel(AsyncSentinelCommands): """ Redis Sentinel cluster client diff --git a/redis/commands/sentinel.py b/redis/commands/sentinel.py index bb12f14568..e054ec6a38 100644 --- a/redis/commands/sentinel.py +++ b/redis/commands/sentinel.py @@ -8,39 +8,39 @@ class SentinelCommands: """ def sentinel(self, *args): - "Redis Sentinel's SENTINEL command." 
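The substantive fix above swaps SentinelCommands for AsyncSentinelCommands in the asyncio Sentinel's bases, so the sentinel_* helpers now resolve through the async execute_command and must be awaited. A usage sketch; the host, port, and service name are placeholders:

    import asyncio
    from redis.asyncio.sentinel import Sentinel

    async def main():
        sentinel = Sentinel([("localhost", 26379)])
        # Awaitable now that the async mixin sits in the MRO.
        masters = await sentinel.sentinel_masters()
        print(masters.keys())

    asyncio.run(main())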
+ """Redis Sentinel's SENTINEL command.""" warnings.warn(DeprecationWarning("Use the individual sentinel_* methods")) def sentinel_get_master_addr_by_name(self, service_name): - "Returns a (host, port) pair for the given ``service_name``" + """Returns a (host, port) pair for the given ``service_name``""" return self.execute_command("SENTINEL GET-MASTER-ADDR-BY-NAME", service_name) def sentinel_master(self, service_name): - "Returns a dictionary containing the specified masters state." + """Returns a dictionary containing the specified masters state.""" return self.execute_command("SENTINEL MASTER", service_name) def sentinel_masters(self): - "Returns a list of dictionaries containing each master's state." + """Returns a list of dictionaries containing each master's state.""" return self.execute_command("SENTINEL MASTERS") def sentinel_monitor(self, name, ip, port, quorum): - "Add a new master to Sentinel to be monitored" + """Add a new master to Sentinel to be monitored""" return self.execute_command("SENTINEL MONITOR", name, ip, port, quorum) def sentinel_remove(self, name): - "Remove a master from Sentinel's monitoring" + """Remove a master from Sentinel's monitoring""" return self.execute_command("SENTINEL REMOVE", name) def sentinel_sentinels(self, service_name): - "Returns a list of sentinels for ``service_name``" + """Returns a list of sentinels for ``service_name``""" return self.execute_command("SENTINEL SENTINELS", service_name) def sentinel_set(self, name, option, value): - "Set Sentinel monitoring parameters for a given master" + """Set Sentinel monitoring parameters for a given master""" return self.execute_command("SENTINEL SET", name, option, value) def sentinel_slaves(self, service_name): - "Returns a list of slaves for ``service_name``" + """Returns a list of slaves for ``service_name``""" return self.execute_command("SENTINEL SLAVES", service_name) def sentinel_reset(self, pattern): @@ -95,5 +95,5 @@ def sentinel_flushconfig(self): class AsyncSentinelCommands(SentinelCommands): async def sentinel(self, *args) -> None: - "Redis Sentinel's SENTINEL command." + """Redis Sentinel's SENTINEL command.""" super().sentinel(*args) From 90e129ce431e5005f6fde7a4ec213754b504e805 Mon Sep 17 00:00:00 2001 From: Andrew-Chen-Wang Date: Sat, 12 Feb 2022 00:25:54 -0500 Subject: [PATCH 10/24] Add aio-libs/aioredis-py#1256 & aio-libs/aioredis-py#1284 --- redis/asyncio/client.py | 35 ++++++++-- tests/test_asyncio/test_connection_pool.py | 75 ++++++++++++++++++++++ tests/test_asyncio/test_pubsub.py | 22 +++++++ 3 files changed, 126 insertions(+), 6 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index b738f0c463..619592ef76 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -58,7 +58,7 @@ from redis.typing import ChannelT, EncodableT, KeyT from redis.utils import safe_str, str_if_bytes -PubSubHandler = Callable[[Dict[str, str]], None] +PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] _KeyT = TypeVar("_KeyT", bound=KeyT) _ArgT = TypeVar("_ArgT", KeyT, EncodableT) _RedisT = TypeVar("_RedisT", bound="Redis") @@ -170,6 +170,7 @@ def __init__( client_name: Optional[str] = None, username: Optional[str] = None, retry: Optional[Retry] = None, + auto_close_connection_pool: bool = True, ): """ Initialize a new Redis client. @@ -177,6 +178,13 @@ def __init__( then set `retry` to a valid `Retry` object """ kwargs: Dict[str, Any] + # auto_close_connection_pool only has an effect if connection_pool is + # None. 
This is a similar feature to the missing __del__ to resolve #1103, + # but it accounts for whether a user wants to manually close the connection + # pool, as a similar feature to ConnectionPool's __del__. + self.auto_close_connection_pool = ( + auto_close_connection_pool if connection_pool is None else False + ) if not connection_pool: kwargs = { "db": db, @@ -412,11 +420,23 @@ def __del__(self, _warnings: Any = warnings) -> None: context = {"client": self, "message": self._DEL_MESSAGE} asyncio.get_event_loop().call_exception_handler(context) - async def close(self): + async def close(self, close_connection_pool: Optional[bool] = None) -> None: + """ + Closes Redis client connection + + :param close_connection_pool: decides whether to close the connection pool used + by this Redis client, overriding Redis.auto_close_connection_pool. By default, + let Redis.auto_close_connection_pool decide whether to close the connection + pool. + """ conn = self.connection if conn: self.connection = None await self.connection_pool.release(conn) + if close_connection_pool or ( + close_connection_pool is None and self.auto_close_connection_pool + ): + await self.connection_pool.disconnect() async def _send_command_parse_response(self, conn, command_name, *args, **options): """ @@ -815,7 +835,7 @@ def unsubscribe(self, *args) -> Awaitable: async def listen(self) -> AsyncIterator: """Listen for messages on channels this client has been subscribed to""" while self.subscribed: - response = self.handle_message(await self.parse_response(block=True)) + response = await self.handle_message(await self.parse_response(block=True)) if response is not None: yield response @@ -831,7 +851,7 @@ async def get_message( """ response = await self.parse_response(block=False, timeout=timeout) if response: - return self.handle_message(response, ignore_subscribe_messages) + return await self.handle_message(response, ignore_subscribe_messages) return None def ping(self, message=None) -> Awaitable: @@ -841,7 +861,7 @@ def ping(self, message=None) -> Awaitable: message = "" if message is None else message return self.execute_command("PING", message) - def handle_message(self, response, ignore_subscribe_messages=False): + async def handle_message(self, response, ignore_subscribe_messages=False): """ Parses a pub/sub message. If the channel or pattern was subscribed to with a message handler, the handler is invoked instead of a parsed @@ -890,7 +910,10 @@ def handle_message(self, response, ignore_subscribe_messages=False): else: handler = self.channels.get(message["channel"], None) if handler: - handler(message) + if inspect.iscoroutinefunction(handler): + await handler(message) + else: + handler(message) return None elif message_type != "pong": # this is a subscribe/unsubscribe message. 
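Two user-visible behaviors land in this hunk: close() now also tears down an implicitly created connection pool (subject to auto_close_connection_pool and the close_connection_pool override), and handle_message awaits coroutine handlers. A combined usage sketch; the channel name and handler body are illustrative:

    import asyncio
    import redis.asyncio as redis

    async def on_message(message):             # coroutine handler, awaited internally
        print(message["channel"], message["data"])

    async def main():
        r = redis.Redis()                      # implicit pool: owned and auto-closed
        p = r.pubsub(ignore_subscribe_messages=True)
        await p.subscribe(foo=on_message)      # plain callables still work unchanged
        await r.publish("foo", "hello")
        await p.get_message(timeout=1)         # drives handle_message -> on_message
        await r.close()                        # disconnects the implicit pool too

    asyncio.run(main())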
ignore if we don't diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index f8900e0039..d3afad9800 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -14,6 +14,81 @@ pytestmark = pytest.mark.asyncio +class TestRedisAutoReleaseConnectionPool: + @pytest.fixture + async def r(self, create_redis) -> redis.Redis: + """This is necessary since r and r2 create ConnectionPools behind the scenes""" + r = await create_redis() + r.auto_close_connection_pool = True + yield r + + @staticmethod + def get_total_connected_connections(pool): + return len(pool._available_connections) + len(pool._in_use_connections) + + @staticmethod + async def create_two_conn(r: redis.Redis): + if not r.single_connection_client: # Single already initialized connection + r.connection = await r.connection_pool.get_connection("_") + return await r.connection_pool.get_connection("_") + + @staticmethod + def has_no_connected_connections(pool: redis.ConnectionPool): + return not any( + x.is_connected + for x in pool._available_connections + list(pool._in_use_connections) + ) + + async def test_auto_disconnect_redis_created_pool(self, r: redis.Redis): + new_conn = await self.create_two_conn(r) + assert new_conn != r.connection + assert self.get_total_connected_connections(r.connection_pool) == 2 + await r.close() + assert self.has_no_connected_connections(r.connection_pool) + + async def test_do_not_auto_disconnect_redis_created_pool(self, r2: redis.Redis): + assert r2.auto_close_connection_pool is False, ( + "The connection pool should not be disconnected as a manually created " + "connection pool was passed in in conftest.py" + ) + new_conn = await self.create_two_conn(r2) + assert self.get_total_connected_connections(r2.connection_pool) == 2 + await r2.close() + assert r2.connection_pool._in_use_connections == {new_conn} + assert new_conn.is_connected + assert len(r2.connection_pool._available_connections) == 1 + assert r2.connection_pool._available_connections[0].is_connected + + async def test_auto_release_override_true_manual_created_pool(self, r: redis.Redis): + assert r.auto_close_connection_pool is True, "This is from the class fixture" + await self.create_two_conn(r) + await r.close() + assert self.get_total_connected_connections(r.connection_pool) == 2, ( + "The connection pool should not be disconnected as a manually created " + "connection pool was passed in in conftest.py" + ) + assert self.has_no_connected_connections(r.connection_pool) + + @pytest.mark.parametrize("auto_close_conn_pool", [True, False]) + async def test_close_override(self, r: redis.Redis, auto_close_conn_pool): + r.auto_close_connection_pool = auto_close_conn_pool + await self.create_two_conn(r) + await r.close(close_connection_pool=True) + assert self.has_no_connected_connections(r.connection_pool) + + @pytest.mark.parametrize("auto_close_conn_pool", [True, False]) + async def test_negate_auto_close_client_pool( + self, r: redis.Redis, auto_close_conn_pool + ): + r.auto_close_connection_pool = auto_close_conn_pool + new_conn = await self.create_two_conn(r) + await r.close(close_connection_pool=False) + assert not self.has_no_connected_connections(r.connection_pool) + assert r.connection_pool._in_use_connections == {new_conn} + assert r.connection_pool._available_connections[0].is_connected + assert self.get_total_connected_connections(r.connection_pool) == 2 + + class DummyConnection(Connection): description_format = "DummyConnection<>" diff 
--git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index de3e8e28f1..d889d4f5ef 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -277,6 +277,9 @@ def setup_method(self, method): def message_handler(self, message): self.message = message + async def async_message_handler(self, message): + self.async_message = message + async def test_published_message_to_channel(self, r: redis.Redis): p = r.pubsub() await p.subscribe("foo") @@ -318,6 +321,25 @@ async def test_channel_message_handler(self, r: redis.Redis): assert await wait_for_message(p) is None assert self.message == make_message("message", "foo", "test message") + async def test_channel_async_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(foo=self.async_message_handler) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await wait_for_message(p) is None + assert self.async_message == make_message("message", "foo", "test message") + + async def test_channel_sync_async_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(foo=self.message_handler) + await p.subscribe(bar=self.async_message_handler) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await r.publish("bar", "test message 2") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message("message", "foo", "test message") + assert self.async_message == make_message("message", "bar", "test message 2") + @pytest.mark.onlynoncluster async def test_pattern_message_handler(self, r: redis.Redis): p = r.pubsub(ignore_subscribe_messages=True) From 56393637bd9e5bb1ce20e50900c0dd102f1cc462 Mon Sep 17 00:00:00 2001 From: Andrew-Chen-Wang Date: Sat, 12 Feb 2022 00:56:58 -0500 Subject: [PATCH 11/24] Add asyncio example to docs --- docs/examples.rst | 1 + docs/examples/asyncio_examples.ipynb | 301 +++++++++++++++++++++++++++ 2 files changed, 302 insertions(+) create mode 100644 docs/examples/asyncio_examples.ipynb diff --git a/docs/examples.rst b/docs/examples.rst index cf70c09bf9..a2bbb2131e 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -8,3 +8,4 @@ Examples examples/connection_examples examples/ssl_connection_examples examples/search_json_examples + examples/asyncio_examples diff --git a/docs/examples/asyncio_examples.ipynb b/docs/examples/asyncio_examples.ipynb new file mode 100644 index 0000000000..66d435835b --- /dev/null +++ b/docs/examples/asyncio_examples.ipynb @@ -0,0 +1,301 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": true, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# Asyncio Examples\n", + "\n", + "All commands are coroutine functions.\n", + "\n", + "## Connecting and Disconnecting\n", + "\n", + "Utilizing asyncio Redis requires an explicit disconnect of the connection since there is no asyncio deconstructor magic method. By default, a connection pool is created on `redis.Redis()` and attached to this `Redis` instance. The connection pool closes automatically on the call to `Redis.close` which disconnects all connections." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ping successful: True\n" + ] + } + ], + "source": [ + "import redis.asyncio as redis\n", + "\n", + "connection = redis.Redis()\n", + "print(f\"Ping successful: {await connection.ping()}\")\n", + "await connection.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "If you supply a custom `ConnectionPool` that is shared by several `Redis` instances, you may want to disconnect the connection pool explicitly. Disconnecting the connection pool simply disconnects all connections hosted in the pool." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "import redis.asyncio as redis\n", + "\n", + "connection = redis.Redis(auto_close_connection_pool=False)\n", + "await connection.close()\n", + "# Or: await connection.close(close_connection_pool=False)\n", + "await connection.connection_pool.disconnect()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Transactions (Multi/Exec)\n", + "\n", + "The aioredis.Redis.pipeline will return an aioredis.Pipeline object, which will buffer all commands in memory and compile them into batches using the Redis Bulk String protocol. Additionally, each command will return the Pipeline instance, allowing you to chain your commands, i.e., p.set('foo', 1).set('bar', 2).mget('foo', 'bar').\n", + "\n", + "The commands will not be reflected in Redis until execute() is called & awaited.\n", + "\n", + "Usually, when performing a bulk operation, you will want to take advantage of a “transaction” (e.g., Multi/Exec), as it also adds a layer of atomicity to your bulk operation."
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "import redis.asyncio as redis\n", + "\n", + "r = await redis.from_url(\"redis://localhost\")\n", + "async with r.pipeline(transaction=True) as pipe:\n", + " ok1, ok2 = await (pipe.set(\"key1\", \"value1\").set(\"key2\", \"value2\").execute())\n", + "assert ok1\n", + "assert ok2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pub/Sub Mode\n", + "\n", + "Subscribing to specific channels:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(Reader) Message Received: {'type': 'message', 'pattern': None, 'channel': b'channel:1', 'data': b'Hello'}\n", + "(Reader) Message Received: {'type': 'message', 'pattern': None, 'channel': b'channel:2', 'data': b'World'}\n", + "(Reader) Message Received: {'type': 'message', 'pattern': None, 'channel': b'channel:1', 'data': b'STOP'}\n", + "(Reader) STOP\n" + ] + } + ], + "source": [ + "import asyncio\n", + "\n", + "import async_timeout\n", + "\n", + "import redis.asyncio as redis\n", + "\n", + "STOPWORD = \"STOP\"\n", + "\n", + "\n", + "async def reader(channel: redis.client.PubSub):\n", + " while True:\n", + " try:\n", + " async with async_timeout.timeout(1):\n", + " message = await channel.get_message(ignore_subscribe_messages=True)\n", + " if message is not None:\n", + " print(f\"(Reader) Message Received: {message}\")\n", + " if message[\"data\"].decode() == STOPWORD:\n", + " print(\"(Reader) STOP\")\n", + " break\n", + " await asyncio.sleep(0.01)\n", + " except asyncio.TimeoutError:\n", + " pass\n", + "\n", + "r = redis.from_url(\"redis://localhost\")\n", + "pubsub = r.pubsub()\n", + "await pubsub.subscribe(\"channel:1\", \"channel:2\")\n", + "\n", + "future = asyncio.create_task(reader(pubsub))\n", + "\n", + "await r.publish(\"channel:1\", \"Hello\")\n", + "await r.publish(\"channel:2\", \"World\")\n", + "await r.publish(\"channel:1\", STOPWORD)\n", + "\n", + "await future" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Subscribing to channels matching a glob-style pattern:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(Reader) Message Received: {'type': 'pmessage', 'pattern': b'channel:*', 'channel': b'channel:1', 'data': b'Hello'}\n", + "(Reader) Message Received: {'type': 'pmessage', 'pattern': b'channel:*', 'channel': b'channel:2', 'data': b'World'}\n", + "(Reader) Message Received: {'type': 'pmessage', 'pattern': b'channel:*', 'channel': b'channel:1', 'data': b'STOP'}\n", + "(Reader) STOP\n" + ] + } + ], + "source": [ + "import asyncio\n", + "\n", + "import async_timeout\n", + "\n", + "import redis.asyncio as redis\n", + "\n", + "STOPWORD = \"STOP\"\n", + "\n", + "\n", + "async def reader(channel: redis.client.PubSub):\n", + " while True:\n", + " try:\n", + " async with async_timeout.timeout(1):\n", + " message = await channel.get_message(ignore_subscribe_messages=True)\n", + " if message is not None:\n", + " print(f\"(Reader) Message Received: {message}\")\n", + " if message[\"data\"].decode() == STOPWORD:\n", + " print(\"(Reader) STOP\")\n", + " break\n", + " await asyncio.sleep(0.01)\n", + " except asyncio.TimeoutError:\n", + " pass\n", + 
"\n", + "\n", + "r = await redis.from_url(\"redis://localhost\")\n", + "pubsub = r.pubsub()\n", + "await pubsub.psubscribe(\"channel:*\")\n", + "\n", + "future = asyncio.create_task(reader(pubsub))\n", + "\n", + "await r.publish(\"channel:1\", \"Hello\")\n", + "await r.publish(\"channel:2\", \"World\")\n", + "await r.publish(\"channel:1\", STOPWORD)\n", + "\n", + "await future" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sentinel Client\n", + "\n", + "The Sentinel client requires a list of Redis Sentinel addresses to connect to and start discovering services.\n", + "\n", + "Calling aioredis.sentinel.Sentinel.master_for or aioredis.sentinel.Sentinel.slave_for methods will return Redis clients connected to specified services monitored by Sentinel.\n", + "\n", + "Sentinel client will detect failover and reconnect Redis clients automatically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "import asyncio\n", + "\n", + "from redis.asyncio.sentinel import Sentinel\n", + "\n", + "\n", + "sentinel = Sentinel([(\"localhost\", 26379), (\"sentinel2\", 26379)])\n", + "r = sentinel.master_for(\"mymaster\")\n", + "\n", + "ok = await r.set(\"key\", \"value\")\n", + "assert ok\n", + "val = await r.get(\"key\")\n", + "assert val == b\"value\"" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} \ No newline at end of file From c2f0af483b61b62df449a3a5cb4b2ecc962a2c20 Mon Sep 17 00:00:00 2001 From: Andrew-Chen-Wang Date: Sun, 13 Feb 2022 11:48:22 -0500 Subject: [PATCH 12/24] Drop Python 3.6 from tox --- CONTRIBUTING.md | 6 +++--- redis/asyncio/connection.py | 4 +--- tox.ini | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ebb66bbf3e..827a25f9c4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -112,12 +112,12 @@ You can see the logging output of a containers like this: `$ docker logs -f ` The command make test runs all tests in all tested Python -environments. To run the tests in a single environment, like Python 3.6, +environments. To run the tests in a single environment, like Python 3.9, use a command like this: -`$ docker-compose run test tox -e py36 -- --redis-url=redis://master:6379/9` +`$ docker-compose run test tox -e py39 -- --redis-url=redis://master:6379/9` -Here, the flag `-e py36` runs tests against the Python 3.6 tox +Here, the flag `-e py39` runs tests against the Python 3.9 tox environment. And note from the example that whenever you run tests like this, instead of using make test, you need to pass `-- --redis-url=redis://master:6379/9`. 
This points the tests at the diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index dc1ccee1cc..8d2b73dfc4 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -821,9 +821,7 @@ async def disconnect(self): try: if os.getpid() == self.pid: self._writer.close() # type: ignore[union-attr] - # py3.6 doesn't have this method - if hasattr(self._writer, "wait_closed"): - await self._writer.wait_closed() # type: ignore[union-attr] + await self._writer.wait_closed() # type: ignore[union-attr] except OSError: pass self._reader = None diff --git a/tox.ini b/tox.ini index 851daf422d..9012118b33 100644 --- a/tox.ini +++ b/tox.ini @@ -10,7 +10,7 @@ markers = [tox] minversion = 3.2.0 requires = tox-docker -envlist = {standalone,cluster}-{plain,hiredis,ocsp}-{uvloop,asyncio}-{py36,py37,py38,py39,pypy3},linters,docs +envlist = {standalone,cluster}-{plain,hiredis,ocsp}-{uvloop,asyncio}-{py37,py38,py39,pypy3},linters,docs [docker:master] name = master From d1abb6c0d6ee047806a4b8cab82dcffda7b96ced Mon Sep 17 00:00:00 2001 From: "Chayim I. Kirshen" Date: Tue, 15 Feb 2022 13:46:10 +0200 Subject: [PATCH 13/24] black, to python 3.6 --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 9012118b33..42b1e037d6 100644 --- a/tox.ini +++ b/tox.ini @@ -325,7 +325,7 @@ deps_files = dev_requirements.txt docker = commands = flake8 - black --target-version py37 --check --diff . + black --target-version py36 --check --diff . isort --check-only --diff . vulture redis whitelist.py --min-confidence 80 flynt --fail-on-change --dry-run . From a4d791f3217b32597ad6bc41b10147f28ee0bd59 Mon Sep 17 00:00:00 2001 From: "Chayim I. Kirshen" Date: Tue, 15 Feb 2022 13:47:17 +0200 Subject: [PATCH 14/24] python 3.6 --- .github/workflows/integration.yaml | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 0e36012f2b..ea5c25ffad 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -34,7 +34,7 @@ jobs: strategy: max-parallel: 15 matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', 'pypy-3.7'] + python-version: ['3.6','3.7', '3.8', '3.9', '3.10', 'pypy-3.7'] test-type: ['standalone', 'cluster'] connection-type: ['hiredis', 'plain'] env: diff --git a/tox.ini b/tox.ini index 42b1e037d6..ca8c9b33fd 100644 --- a/tox.ini +++ b/tox.ini @@ -10,7 +10,7 @@ markers = [tox] minversion = 3.2.0 requires = tox-docker -envlist = {standalone,cluster}-{plain,hiredis,ocsp}-{uvloop,asyncio}-{py37,py38,py39,pypy3},linters,docs +envlist = {standalone,cluster}-{plain,hiredis,ocsp}-{uvloop,asyncio}-{py36,py37,py38,py39,pypy3},linters,docs [docker:master] name = master From e49f9e7d9122a09363a2a0c3bf246cd29eff3090 Mon Sep 17 00:00:00 2001 From: "Chayim I. Kirshen" Date: Tue, 15 Feb 2022 14:02:30 +0200 Subject: [PATCH 15/24] requirements --- dev_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index 4b12a647b7..ce7216bd10 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -5,7 +5,7 @@ isort==5.10.1 mock==4.0.3 pytest==6.2.5 pytest-timeout==2.0.1 -pytest-asyncio==0.17.2 +pytest-asyncio>=0.16.0 tox==3.24.4 tox-docker==3.1.0 invoke==1.6.0 From b50b265690babedda31dbdfe629529715cb1d4fc Mon Sep 17 00:00:00 2001 From: "Chayim I. 
Kirshen" Date: Tue, 15 Feb 2022 14:13:59 +0200 Subject: [PATCH 16/24] uvloop --- dev_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index ce7216bd10..0c4bee9f35 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -12,5 +12,5 @@ invoke==1.6.0 pytest-cov>=3.0.0 vulture>=2.3.0 ujson>=4.2.0 -uvloop>=0.16.0 wheel>=0.30.0 +uvloop From 0a2713bc1e9ae35c15ea5fa016f8c9d26d998b55 Mon Sep 17 00:00:00 2001 From: "Chayim I. Kirshen" Date: Tue, 15 Feb 2022 14:28:27 +0200 Subject: [PATCH 17/24] 3.6 trove --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 8c9e9721ed..047ed902d9 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,7 @@ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", From ba024cb592ab90466b43a49e6355c9dc005bac2a Mon Sep 17 00:00:00 2001 From: Andrew-Chen-Wang Date: Tue, 15 Feb 2022 09:37:33 -0500 Subject: [PATCH 18/24] Add back Python 3.6 support --- redis/asyncio/connection.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 8d2b73dfc4..dc1ccee1cc 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -821,7 +821,9 @@ async def disconnect(self): try: if os.getpid() == self.pid: self._writer.close() # type: ignore[union-attr] - await self._writer.wait_closed() # type: ignore[union-attr] + # py3.6 doesn't have this method + if hasattr(self._writer, "wait_closed"): + await self._writer.wait_closed() # type: ignore[union-attr] except OSError: pass self._reader = None From 86409546030eb26ffd92d3dff7fd1dcecc106eef Mon Sep 17 00:00:00 2001 From: Andrew-Chen-Wang Date: Tue, 15 Feb 2022 11:40:01 -0500 Subject: [PATCH 19/24] Add support to Python 3.6 in setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 047ed902d9..a6d5cf0679 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ }, author="Redis Inc.", author_email="oss@redis.com", - python_requires=">=3.7", + python_requires=">=3.6", install_requires=[ "deprecated>=1.2.3", "packaging>=20.4", From f8a119cc5de7722ce75fe31a5ecc12a56c659394 Mon Sep 17 00:00:00 2001 From: "Chayim I. 
Kirshen" Date: Thu, 17 Feb 2022 18:22:42 +0200 Subject: [PATCH 20/24] python 3.6 fixes fixing async tests core.command typing that broke command tests --- redis/asyncio/connection.py | 19 +- redis/asyncio/lock.py | 7 +- redis/asyncio/retry.py | 10 +- redis/asyncio/sentinel.py | 2 +- redis/commands/core.py | 721 +++++++++++---------- redis/typing.py | 4 +- tests/test_asyncio/conftest.py | 24 +- tests/test_asyncio/test_connection_pool.py | 7 +- tests/test_asyncio/test_encoding.py | 7 +- tests/test_asyncio/test_lock.py | 3 +- tests/test_asyncio/test_pubsub.py | 3 +- tests/test_asyncio/test_scripting.py | 3 +- tests/test_asyncio/test_sentinel.py | 7 +- 13 files changed, 418 insertions(+), 399 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index dc1ccee1cc..ae54d8147a 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -7,6 +7,7 @@ import os import socket import ssl +import sys import threading import weakref from itertools import chain @@ -845,10 +846,12 @@ async def _ping_failed(self, error): async def check_health(self): """Check the health of the connection with a PING/PONG""" - if ( - self.health_check_interval - and asyncio.get_running_loop().time() > self.next_health_check - ): + if sys.version_info[0:2] == (3, 6): + func = asyncio.get_event_loop + else: + func = asyncio.get_running_loop + + if self.health_check_interval and func().time() > self.next_health_check: await self.retry.call_with_retry(self._send_ping, self._ping_failed) async def _send_packed_command(self, command: Iterable[bytes]) -> None: @@ -930,9 +933,11 @@ async def read_response(self, disable_decoding: bool = False): raise if self.health_check_interval: - self.next_health_check = ( - asyncio.get_running_loop().time() + self.health_check_interval - ) + if sys.version_info[0:2] == (3, 6): + func = asyncio.get_event_loop + else: + func = asyncio.get_running_loop + self.next_health_check = func().time() + self.health_check_interval if isinstance(response, ResponseError): raise response from None diff --git a/redis/asyncio/lock.py b/redis/asyncio/lock.py index 784594e3af..d4861329a0 100644 --- a/redis/asyncio/lock.py +++ b/redis/asyncio/lock.py @@ -1,4 +1,5 @@ import asyncio +import sys import threading import uuid from types import SimpleNamespace @@ -185,7 +186,11 @@ async def acquire( object with the default encoding. If a token isn't specified, a UUID will be generated. """ - loop = asyncio.get_running_loop() + if sys.version_info[0:2] != (3, 6): + loop = asyncio.get_running_loop() + else: + loop = asyncio.get_event_loop() + sleep = self.sleep if token is None: token = uuid.uuid1().hex.encode() diff --git a/redis/asyncio/retry.py b/redis/asyncio/retry.py index d98a5fec87..9b5349402c 100644 --- a/redis/asyncio/retry.py +++ b/redis/asyncio/retry.py @@ -1,7 +1,5 @@ -from __future__ import annotations - from asyncio import sleep -from typing import TYPE_CHECKING, Awaitable, Callable, TypeVar +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, TypeVar from redis.exceptions import ConnectionError, RedisError, TimeoutError @@ -19,9 +17,9 @@ class Retry: def __init__( self, - backoff: AbstractBackoff, + backoff: "AbstractBackoff", retries: int, - supported_errors: type[tuple[RedisError, ...]] = ( + supported_errors: Tuple[RedisError, ...] = ( ConnectionError, TimeoutError, ), @@ -37,7 +35,7 @@ def __init__( self._supported_errors = supported_errors async def call_with_retry( - self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], ...] 
+ self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any] ) -> T: """ Execute an operation that might fail and returns its result, or diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 43900eb495..5aefd09ebd 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -208,7 +208,7 @@ async def execute_command(self, *args, **kwargs): if once: tasks = [ - asyncio.create_task(sentinel.execute_command(*args, **kwargs)) + asyncio.Task(sentinel.execute_command(*args, **kwargs)) for sentinel in self.sentinels ] await asyncio.gather(*tasks) diff --git a/redis/commands/core.py b/redis/commands/core.py index 3110acf3ad..56b07d8f5e 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -1,4 +1,4 @@ -from __future__ import annotations +# from __future__ import annotations import datetime import hashlib @@ -10,12 +10,14 @@ AsyncIterator, Awaitable, Callable, + Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, + Tuple, Union, ) @@ -56,7 +58,7 @@ class ACLCommands(CommandsProtocol): see: https://redis.io/topics/acl """ - def acl_cat(self, category: str | None = None, **kwargs) -> ResponseT: + def acl_cat(self, category: Union[str, None] = None, **kwargs) -> ResponseT: """ Returns a list of categories or commands within a category. @@ -77,7 +79,7 @@ def acl_deluser(self, *username: str, **kwargs) -> ResponseT: """ return self.execute_command("ACL DELUSER", *username, **kwargs) - def acl_genpass(self, bits: int | None = None, **kwargs) -> ResponseT: + def acl_genpass(self, bits: Union[int, None] = None, **kwargs) -> ResponseT: """Generate a random password value. If ``bits`` is supplied then use this number of bits, rounded to the next multiple of 4. @@ -121,7 +123,7 @@ def acl_list(self, **kwargs) -> ResponseT: """ return self.execute_command("ACL LIST", **kwargs) - def acl_log(self, count: int | None = None, **kwargs) -> ResponseT: + def acl_log(self, count: Union[int, None] = None, **kwargs) -> ResponseT: """ Get ACL logs as a list. :param int count: Get logs[0:count]. 
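The hunks above and below are the heart of this patch: every PEP 604 annotation (`int | None`) is rewritten as `typing.Union`, and `from __future__ import annotations` is commented out, because Python 3.6 evaluates annotations eagerly and does not recognize that future import at all. A minimal sketch of the failure mode, using a hypothetical function rather than anything from `redis/commands/core.py`:

```python
from typing import Union


# Hypothetical illustration. On Python 3.6 this spelling imports cleanly,
# whereas `bits: int | None = None` raises TypeError at definition time
# (`|` is applied to the type objects themselves before Python 3.10), and
# `from __future__ import annotations` is a SyntaxError before Python 3.7.
def genpass(bits: Union[int, None] = None) -> str:
    return "ACL GENPASS" if bits is None else "ACL GENPASS {}".format(bits)
```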
@@ -174,11 +176,11 @@ def acl_setuser( username: str, enabled: bool = False, nopass: bool = False, - passwords: str | Iterable[str] | None = None, - hashed_passwords: str | Iterable[str] | None = None, - categories: Iterable[str] | None = None, - commands: Iterable[str] | None = None, - keys: Iterable[KeyT] | None = None, + passwords: Union[str, Iterable[str], None] = None, + hashed_passwords: Union[str, Iterable[str], None] = None, + categories: Union[Iterable[str], None] = None, + commands: Union[Iterable[str], None] = None, + keys: Union[Iterable[KeyT], None] = None, reset: bool = False, reset_keys: bool = False, reset_passwords: bool = False, @@ -406,11 +408,11 @@ def client_kill(self, address: str, **kwargs) -> ResponseT: def client_kill_filter( self, - _id: str | None = None, - _type: str | None = None, - addr: str | None = None, - skipme: bool | None = None, - laddr: bool | None = None, + _id: Union[str, None] = None, + _type: Union[str, None] = None, + addr: Union[str, None] = None, + skipme: Union[bool, None] = None, + laddr: Union[bool, None] = None, user: str = None, **kwargs, ) -> ResponseT: @@ -465,8 +467,8 @@ def client_info(self, **kwargs) -> ResponseT: def client_list( self, - _type: str | None = None, - client_id: list[EncodableT] = [], + _type: Union[str, None] = None, + client_id: List[EncodableT] = [], **kwargs, ) -> ResponseT: """ @@ -511,7 +513,7 @@ def client_getredir(self, **kwargs) -> ResponseT: def client_reply( self, - reply: Literal["ON"] | Literal["OFF"] | Literal["SKIP"], + reply: Union[Literal["ON"], Literal["OFF"], Literal["SKIP"]], **kwargs, ) -> ResponseT: """ @@ -544,7 +546,7 @@ def client_id(self, **kwargs) -> ResponseT: def client_tracking_on( self, - clientid: int | None = None, + clientid: Union[int, None] = None, prefix: Sequence[KeyT] = [], bcast: bool = False, optin: bool = False, @@ -563,7 +565,7 @@ def client_tracking_on( def client_tracking_off( self, - clientid: int | None = None, + clientid: Union[int, None] = None, prefix: Sequence[KeyT] = [], bcast: bool = False, optin: bool = False, @@ -583,7 +585,7 @@ def client_tracking_off( def client_tracking( self, on: bool = True, - clientid: int | None = None, + clientid: Union[int, None] = None, prefix: Sequence[KeyT] = [], bcast: bool = False, optin: bool = False, @@ -863,7 +865,7 @@ def select(self, index: int, **kwargs) -> ResponseT: """ return self.execute_command("SELECT", index, **kwargs) - def info(self, section: str | None = None, **kwargs) -> ResponseT: + def info(self, section: Union[str, None] = None, **kwargs) -> ResponseT: """ Returns a dictionary containing information about the Redis server @@ -889,7 +891,7 @@ def lastsave(self, **kwargs) -> ResponseT: """ return self.execute_command("LASTSAVE", **kwargs) - def lolwut(self, *version_numbers: str | float, **kwargs) -> ResponseT: + def lolwut(self, *version_numbers: Union[str, float], **kwargs) -> ResponseT: """ Get the Redis version and a piece of generative computer art @@ -916,7 +918,7 @@ def migrate( timeout: int, copy: bool = False, replace: bool = False, - auth: str | None = None, + auth: Union[str, None] = None, **kwargs, ) -> ResponseT: """ @@ -998,7 +1000,7 @@ def memory_malloc_stats(self, **kwargs) -> ResponseT: return self.execute_command("MEMORY MALLOC-STATS", **kwargs) def memory_usage( - self, key: KeyT, samples: int | None = None, **kwargs + self, key: KeyT, samples: Union[int, None] = None, **kwargs ) -> ResponseT: """ Return the total memory usage for key, its value and associated @@ -1083,7 +1085,7 @@ def 
shutdown(self, save: bool = False, nosave: bool = False, **kwargs) -> None: raise RedisError("SHUTDOWN seems to have failed.") def slaveof( - self, host: str | None = None, port: int | None = None, **kwargs + self, host: Union[str, None] = None, port: Union[int, None] = None, **kwargs ) -> ResponseT: """ Set the server to be a replicated slave of the instance identified @@ -1096,7 +1098,7 @@ def slaveof( return self.execute_command("SLAVEOF", b"NO", b"ONE", **kwargs) return self.execute_command("SLAVEOF", host, port, **kwargs) - def slowlog_get(self, num: int | None = None, **kwargs) -> ResponseT: + def slowlog_get(self, num: Union[int, None] = None, **kwargs) -> ResponseT: """ Get the entries from the slowlog. If ``num`` is specified, get the most recent ``num`` items. @@ -1210,6 +1212,118 @@ async def shutdown( raise RedisError("SHUTDOWN seems to have failed.") +class BitFieldOperation: + """ + Command builder for BITFIELD commands. + """ + + def __init__( + self, + client: Union["Redis", "AsyncRedis"], + key: str, + default_overflow: Union[str, None] = None, + ): + self.client = client + self.key = key + self._default_overflow = default_overflow + # for typing purposes, run the following in constructor and in reset() + self.operations: list[tuple[EncodableT, ...]] = [] + self._last_overflow = "WRAP" + self.reset() + + def reset(self): + """ + Reset the state of the instance to when it was constructed + """ + self.operations = [] + self._last_overflow = "WRAP" + self.overflow(self._default_overflow or self._last_overflow) + + def overflow(self, overflow: str): + """ + Update the overflow algorithm of successive INCRBY operations + :param overflow: Overflow algorithm, one of WRAP, SAT, FAIL. See the + Redis docs for descriptions of these algorithmsself. + :returns: a :py:class:`BitFieldOperation` instance. + """ + overflow = overflow.upper() + if overflow != self._last_overflow: + self._last_overflow = overflow + self.operations.append(("OVERFLOW", overflow)) + return self + + def incrby( + self, + fmt: str, + offset: BitfieldOffsetT, + increment: int, + overflow: Union[str, None] = None, + ): + """ + Increment a bitfield by a given amount. + :param fmt: format-string for the bitfield being updated, e.g. 'u8' + for an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :param int increment: value to increment the bitfield by. + :param str overflow: overflow algorithm. Defaults to WRAP, but other + acceptable values are SAT and FAIL. See the Redis docs for + descriptions of these algorithms. + :returns: a :py:class:`BitFieldOperation` instance. + """ + if overflow is not None: + self.overflow(overflow) + + self.operations.append(("INCRBY", fmt, offset, increment)) + return self + + def get(self, fmt: str, offset: BitfieldOffsetT): + """ + Get the value of a given bitfield. + :param fmt: format-string for the bitfield being read, e.g. 'u8' for + an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :returns: a :py:class:`BitFieldOperation` instance. + """ + self.operations.append(("GET", fmt, offset)) + return self + + def set(self, fmt: str, offset: BitfieldOffsetT, value: int): + """ + Set the value of a given bitfield. + :param fmt: format-string for the bitfield being read, e.g. 
'u8' for + an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :param int value: value to set at the given position. + :returns: a :py:class:`BitFieldOperation` instance. + """ + self.operations.append(("SET", fmt, offset, value)) + return self + + @property + def command(self): + cmd = ["BITFIELD", self.key] + for ops in self.operations: + cmd.extend(ops) + return cmd + + def execute(self) -> ResponseT: + """ + Execute the operation(s) in a single BITFIELD command. The return value + is a list of values corresponding to each operation. If the client + used to create this instance was a pipeline, the list of values + will be present within the pipeline's execute. + """ + command = self.command + self.reset() + return self.client.execute_command(*command) + + class BasicKeyCommands(CommandsProtocol): """ Redis basic key-based commands @@ -1228,8 +1342,8 @@ def append(self, key: KeyT, value: EncodableT) -> ResponseT: def bitcount( self, key: KeyT, - start: int | None = None, - end: int | None = None, + start: Union[int, None] = None, + end: Union[int, None] = None, ) -> ResponseT: """ Returns the count of set bits in the value of ``key``. Optional @@ -1246,9 +1360,9 @@ def bitcount( return self.execute_command("BITCOUNT", *params) def bitfield( - self: Redis | AsyncRedis, + self: Union["Redis", "AsyncRedis"], key: KeyT, - default_overflow: str | None = None, + default_overflow: Union[str, None] = None, ) -> BitFieldOperation: """ Return a BitFieldOperation instance to conveniently construct one or @@ -1276,8 +1390,8 @@ def bitpos( self, key: KeyT, bit: int, - start: int | None = None, - end: int | None = None, + start: Union[int, None] = None, + end: Union[int, None] = None, ) -> ResponseT: """ Return the position of the first bit set to 1 or 0 in a string. 
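`BitFieldOperation`, moved up here so it is already defined when `BasicKeyCommands.bitfield()` references it (the old copy is deleted further down in this diff), accumulates sub-commands and flushes them as a single `BITFIELD` call. A short usage sketch, assuming a server on localhost:6379; with the async client the chain is identical but the final `execute()` must be awaited:

```python
import redis

r = redis.Redis()
results = (
    r.bitfield("stats", default_overflow="SAT")  # saturate instead of wrapping
    .incrby("u8", "#0", 200)  # first unsigned 8-bit slot -> 200
    .incrby("u8", "#0", 200)  # would be 400; SAT clamps it to 255
    .get("u8", "#0")          # read it back in the same round trip
    .execute()                # one BITFIELD command, one result per operation
)
print(results)  # [200, 255, 255]
```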
@@ -1303,7 +1417,7 @@ def copy( self, source: str, destination: str, - destination_db: str | None = None, + destination_db: Union[str, None] = None, replace: bool = False, ) -> ResponseT: """ @@ -1412,10 +1526,10 @@ def getdel(self, name: KeyT) -> ResponseT: def getex( self, name: KeyT, - ex: ExpiryT | None = None, - px: ExpiryT | None = None, - exat: AbsExpiryT | None = None, - pxat: AbsExpiryT | None = None, + ex: Union[ExpiryT, None] = None, + px: Union[ExpiryT, None] = None, + exat: Union[AbsExpiryT, None] = None, + pxat: Union[AbsExpiryT, None] = None, persist: bool = False, ) -> ResponseT: """ @@ -1741,8 +1855,8 @@ def restore( value: EncodableT, replace: bool = False, absttl: bool = False, - idletime: int | None = None, - frequency: int | None = None, + idletime: Union[int, None] = None, + frequency: Union[int, None] = None, ) -> ResponseT: """ Create a key using the provided serialized value, previously obtained @@ -1788,14 +1902,14 @@ def set( self, name: KeyT, value: EncodableT, - ex: ExpiryT | None = None, - px: ExpiryT | None = None, + ex: Union[ExpiryT, None] = None, + px: Union[ExpiryT, None] = None, nx: bool = False, xx: bool = False, keepttl: bool = False, get: bool = False, - exat: AbsExpiryT | None = None, - pxat: AbsExpiryT | None = None, + exat: Union[AbsExpiryT, None] = None, + pxat: Union[AbsExpiryT, None] = None, ) -> ResponseT: """ Set the value at key ``name`` to ``value`` @@ -1927,10 +2041,10 @@ def stralgo( algo: Literal["LCS"], value1: KeyT, value2: KeyT, - specific_argument: Literal["strings"] | Literal["keys"] = "strings", + specific_argument: Union[Literal["strings"], Literal["keys"]] = "strings", len: bool = False, idx: bool = False, - minmatchlen: int | None = None, + minmatchlen: Union[int, None] = None, withmatchlen: bool = False, **kwargs, ) -> ResponseT: @@ -2479,9 +2593,9 @@ class ScanCommands(CommandsProtocol): def scan( self, cursor: int = 0, - match: PatternT | None = None, - count: int | None = None, - _type: str | None = None, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + _type: Union[str, None] = None, **kwargs, ) -> ResponseT: """ @@ -2511,9 +2625,9 @@ def scan( def scan_iter( self, - match: PatternT | None = None, - count: int | None = None, - _type: str | None = None, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + _type: Union[str, None] = None, **kwargs, ) -> Iterator: """ @@ -2541,8 +2655,8 @@ def sscan( self, name: KeyT, cursor: int = 0, - match: PatternT | None = None, - count: int | None = None, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, ) -> ResponseT: """ Incrementally return lists of elements in a set. Also return a cursor @@ -2564,8 +2678,8 @@ def sscan( def sscan_iter( self, name: KeyT, - match: PatternT | None = None, - count: int | None = None, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, ) -> Iterator: """ Make an iterator using the SSCAN command so that the client doesn't @@ -2584,8 +2698,8 @@ def hscan( self, name: KeyT, cursor: int = 0, - match: PatternT | None = None, - count: int | None = None, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, ) -> ResponseT: """ Incrementally return key/value slices in a hash. 
Also return a cursor @@ -2607,8 +2721,8 @@ def hscan( def hscan_iter( self, name: str, - match: PatternT | None = None, - count: int | None = None, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, ) -> Iterator: """ Make an iterator using the HSCAN command so that the client doesn't @@ -2627,9 +2741,9 @@ def zscan( self, name: KeyT, cursor: int = 0, - match: PatternT | None = None, - count: int | None = None, - score_cast_func: type | Callable = float, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + score_cast_func: Union[type, Callable] = float, ) -> ResponseT: """ Incrementally return lists of elements in a sorted set. Also return a @@ -2654,9 +2768,9 @@ def zscan( def zscan_iter( self, name: KeyT, - match: PatternT | None = None, - count: int | None = None, - score_cast_func: type | Callable = float, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + score_cast_func: Union[type, Callable] = float, ) -> Iterator: """ Make an iterator using the ZSCAN command so that the client doesn't @@ -2683,9 +2797,9 @@ def zscan_iter( class AsyncScanCommands(ScanCommands): async def scan_iter( self, - match: PatternT | None = None, - count: int | None = None, - _type: str | None = None, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + _type: Union[str, None] = None, **kwargs, ) -> AsyncIterator: """ @@ -2713,8 +2827,8 @@ async def scan_iter( async def sscan_iter( self, name: KeyT, - match: PatternT | None = None, - count: int | None = None, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, ) -> AsyncIterator: """ Make an iterator using the SSCAN command so that the client doesn't @@ -2735,8 +2849,8 @@ async def sscan_iter( async def hscan_iter( self, name: str, - match: PatternT | None = None, - count: int | None = None, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, ) -> AsyncIterator: """ Make an iterator using the HSCAN command so that the client doesn't @@ -2757,9 +2871,9 @@ async def hscan_iter( async def zscan_iter( self, name: KeyT, - match: PatternT | None = None, - count: int | None = None, - score_cast_func: type | Callable = float, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + score_cast_func: Union[type, Callable] = float, ) -> AsyncIterator: """ Make an iterator using the ZSCAN command so that the client doesn't @@ -2973,13 +3087,13 @@ def xack( def xadd( self, name: KeyT, - fields: dict[FieldT, EncodableT], + fields: Dict[FieldT, EncodableT], id: StreamIdT = "*", - maxlen: int | None = None, + maxlen: Union[int, None] = None, approximate: bool = True, nomkstream: bool = False, - minid: StreamIdT | None = None, - limit: int | None = None, + minid: Union[StreamIdT, None] = None, + limit: Union[int, None] = None, ) -> ResponseT: """ Add to a stream. 
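The stream commands retyped in this stretch (`fields: Dict[FieldT, EncodableT]`, `Union[int, None]` counts and limits) behave exactly as before; a brief sketch of the `xadd`/`xread` pair under the same local-server assumption:

```python
import redis

r = redis.Redis()
# XADD with id="*" lets the server assign the entry id; maxlen caps growth.
entry_id = r.xadd("events", {"sensor": "temp", "value": "21.5"}, maxlen=1000)
# XREAD takes a mapping of stream name -> last id seen; 0 means "from the start".
for stream, entries in r.xread({"events": 0}, count=10):
    print(stream, entries)  # [(entry_id, {b'sensor': b'temp', ...})]
```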
@@ -3032,7 +3146,7 @@ def xautoclaim( consumername: ConsumerT, min_idle_time: int, start_id: int = 0, - count: int | None = None, + count: Union[int, None] = None, justid: bool = False, ) -> ResponseT: """ @@ -3082,10 +3196,10 @@ def xclaim( groupname: GroupT, consumername: ConsumerT, min_idle_time: int, - message_ids: list[StreamIdT] | tuple[StreamIdT], - idle: int | None = None, - time: int | None = None, - retrycount: int | None = None, + message_ids: [List[StreamIdT], Tuple[StreamIdT]], + idle: Union[int, None] = None, + time: Union[int, None] = None, + retrycount: Union[int, None] = None, force: bool = False, justid: bool = False, ) -> ResponseT: @@ -3302,8 +3416,8 @@ def xpending_range( min: StreamIdT, max: StreamIdT, count: int, - consumername: ConsumerT | None = None, - idle: int | None = None, + consumername: Union[ConsumerT, None] = None, + idle: Union[int, None] = None, ) -> ResponseT: """ Returns information about pending messages, in a range. @@ -3357,7 +3471,7 @@ def xrange( name: KeyT, min: StreamIdT = "-", max: StreamIdT = "+", - count: int | None = None, + count: Union[int, None] = None, ) -> ResponseT: """ Read stream values within an interval. @@ -3382,9 +3496,9 @@ def xrange( def xread( self, - streams: dict[KeyT, StreamIdT], - count: int | None = None, - block: int | None = None, + streams: Dict[KeyT, StreamIdT], + count: Union[int, None] = None, + block: Union[int, None] = None, ) -> ResponseT: """ Block and monitor multiple streams for new data. @@ -3419,9 +3533,9 @@ def xreadgroup( self, groupname: str, consumername: str, - streams: dict[KeyT, StreamIdT], - count: int | None = None, - block: int | None = None, + streams: Dict[KeyT, StreamIdT], + count: Union[int, None] = None, + block: Union[int, None] = None, noack: bool = False, ) -> ResponseT: """ @@ -3462,7 +3576,7 @@ def xrevrange( name: KeyT, max: StreamIdT = "+", min: StreamIdT = "-", - count: int | None = None, + count: Union[int, None] = None, ) -> ResponseT: """ Read stream values within an interval, in reverse order. @@ -3490,8 +3604,8 @@ def xtrim( name: KeyT, maxlen: int, approximate: bool = True, - minid: StreamIdT | None = None, - limit: int | None = None, + minid: Union[StreamIdT, None] = None, + limit: Union[int, None] = None, ) -> ResponseT: """ Trims old messages from a stream. 
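`xreadgroup`, `xack`, and `xpending_range` get the same `Union` treatment just above. For context, a hedged consumer-group flow (the `jobs`/`workers` names are illustrative only):

```python
import redis

r = redis.Redis()
r.xadd("jobs", {"task": "resize"})
try:
    r.xgroup_create("jobs", "workers", id="0")  # consume from the beginning
except redis.ResponseError:
    pass  # BUSYGROUP: the group already exists
# ">" requests entries never delivered to this group before.
for _stream, entries in r.xreadgroup("workers", "worker-1", {"jobs": ">"}):
    for entry_id, fields in entries:
        r.xack("jobs", "workers", entry_id)  # drop from the pending entries list
```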
@@ -3666,7 +3780,7 @@ def zincrby( def zinter( self, keys: KeysT, - aggregate: str | None = None, + aggregate: Union[str, None] = None, withscores: bool = False, ) -> ResponseT: """ @@ -3685,8 +3799,8 @@ def zinter( def zinterstore( self, dest: KeyT, - keys: Sequence[KeyT] | Mapping[AnyKeyT, float], - aggregate: str | None = None, + keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], + aggregate: Union[str, None] = None, ) -> ResponseT: """ Intersect multiple sorted sets specified by ``keys`` into a new @@ -3726,7 +3840,7 @@ def zlexcount(self, name, min, max): def zpopmax( self, name: KeyT, - count: int | None = None, + count: Union[int, None] = None, ) -> ResponseT: """ Remove and return up to ``count`` members with the highest scores @@ -3741,7 +3855,7 @@ def zpopmax( def zpopmin( self, name: KeyT, - count: int | None = None, + count: Union[int, None] = None, ) -> ResponseT: """ Remove and return up to ``count`` members with the lowest scores @@ -3880,7 +3994,7 @@ def bzmpop( def _zrange( self, command, - dest: KeyT | None, + dest: Union[KeyT, None], name: KeyT, start: int, end: int, @@ -3888,9 +4002,9 @@ def _zrange( byscore: bool = False, bylex: bool = False, withscores: bool = False, - score_cast_func: type | Callable | None = float, - offset: int | None = None, - num: int | None = None, + score_cast_func: Union[type, Callable, None] = float, + offset: Union[int, None] = None, + num: Union[int, None] = None, ) -> ResponseT: if byscore and bylex: raise DataError( @@ -3926,7 +4040,7 @@ def zrange( end: int, desc: bool = False, withscores: bool = False, - score_cast_func: type | Callable = float, + score_cast_func: Union[type, Callable] = float, byscore: bool = False, bylex: bool = False, offset: int = None, @@ -3986,7 +4100,7 @@ def zrevrange( start: int, end: int, withscores: bool = False, - score_cast_func: type | Callable = float, + score_cast_func: Union[type, Callable] = float, ) -> ResponseT: """ Return a range of values from sorted set ``name`` between @@ -4016,8 +4130,8 @@ def zrangestore( byscore: bool = False, bylex: bool = False, desc: bool = False, - offset: int | None = None, - num: int | None = None, + offset: Union[int, None] = None, + num: Union[int, None] = None, ) -> ResponseT: """ Stores in ``dest`` the result of a range of values from sorted set @@ -4062,8 +4176,8 @@ def zrangebylex( name: KeyT, min: EncodableT, max: EncodableT, - start: int | None = None, - num: int | None = None, + start: Union[int, None] = None, + num: Union[int, None] = None, ) -> ResponseT: """ Return the lexicographical range of values from sorted set ``name`` @@ -4086,8 +4200,8 @@ def zrevrangebylex( name: KeyT, max: EncodableT, min: EncodableT, - start: int | None = None, - num: int | None = None, + start: Union[int, None] = None, + num: Union[int, None] = None, ) -> ResponseT: """ Return the reversed lexicographical range of values from sorted set @@ -4110,10 +4224,10 @@ def zrangebyscore( name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT, - start: int | None = None, - num: int | None = None, + start: Union[int, None] = None, + num: Union[int, None] = None, withscores: bool = False, - score_cast_func: type | Callable = float, + score_cast_func: Union[type, Callable] = float, ) -> ResponseT: """ Return a range of values from the sorted set ``name`` with scores @@ -4144,10 +4258,10 @@ def zrevrangebyscore( name: KeyT, max: ZScoreBoundT, min: ZScoreBoundT, - start: int | None = None, - num: int | None = None, + start: Union[int, None] = None, + num: Union[int, None] = None, withscores: bool = 
False, - score_cast_func: type | Callable = float, + score_cast_func: Union[type, Callable] = float, ): """ Return a range of values from the sorted set ``name`` with scores @@ -4242,8 +4356,8 @@ def zscore(self, name: KeyT, value: EncodableT) -> ResponseT: def zunion( self, - keys: Sequence[KeyT] | Mapping[AnyKeyT, float], - aggregate: str | None = None, + keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], + aggregate: Union[str, None] = None, withscores: bool = False, ) -> ResponseT: """ @@ -4259,8 +4373,8 @@ def zunion( def zunionstore( self, dest: KeyT, - keys: Sequence[KeyT] | Mapping[AnyKeyT, float], - aggregate: str | None = None, + keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], + aggregate: Union[str, None] = None, ) -> ResponseT: """ Union multiple sorted sets specified by ``keys`` into @@ -4274,7 +4388,7 @@ def zunionstore( def zmscore( self, key: KeyT, - members: list[str], + members: List[str], ) -> ResponseT: """ Returns the scores associated with the specified members @@ -4294,9 +4408,9 @@ def zmscore( def _zaggregate( self, command: str, - dest: KeyT | None, - keys: Sequence[KeyT] | Mapping[AnyKeyT, float], - aggregate: str | None = None, + dest: Union[KeyT, None], + keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], + aggregate: Union[str, None] = None, **options, ) -> ResponseT: pieces: list[EncodableT] = [command] @@ -4515,6 +4629,96 @@ def hstrlen(self, name: str, key: str) -> int: AsyncHashCommands = HashCommands +class Script: + """ + An executable Lua script object returned by ``register_script`` + """ + + def __init__(self, registered_client: "Redis", script: ScriptTextT): + self.registered_client = registered_client + self.script = script + # Precalculate and store the SHA1 hex digest of the script. + + if isinstance(script, str): + # We need the encoding from the client in order to generate an + # accurate byte representation of the script + encoder = registered_client.connection_pool.get_encoder() + script = encoder.encode(script) + self.sha = hashlib.sha1(script).hexdigest() + + def __call__( + self, + keys: Union[Sequence[KeyT], None] = None, + args: Union[Iterable[EncodableT], None] = None, + client: Union["Redis", None] = None, + ): + """Execute the script, passing any required ``args``""" + keys = keys or [] + args = args or [] + if client is None: + client = self.registered_client + args = tuple(keys) + tuple(args) + # make sure the Redis server knows about the script + from redis.client import Pipeline + + if isinstance(client, Pipeline): + # Make sure the pipeline can register the script before executing. + client.scripts.add(self) + try: + return client.evalsha(self.sha, len(keys), *args) + except NoScriptError: + # Maybe the client is pointed to a different server than the client + # that created this instance? + # Overwrite the sha just in case there was a discrepancy. + self.sha = client.script_load(self.script) + return client.evalsha(self.sha, len(keys), *args) + + +class AsyncScript: + """ + An executable Lua script object returned by ``register_script`` + """ + + def __init__(self, registered_client: "AsyncRedis", script: ScriptTextT): + self.registered_client = registered_client + self.script = script + # Precalculate and store the SHA1 hex digest of the script. 
+ + if isinstance(script, str): + # We need the encoding from the client in order to generate an + # accurate byte representation of the script + encoder = registered_client.connection_pool.get_encoder() + script = encoder.encode(script) + self.sha = hashlib.sha1(script).hexdigest() + + async def __call__( + self, + keys: Union[Sequence[KeyT], None] = None, + args: Union[Iterable[EncodableT], None] = None, + client: Union["AsyncRedis", None] = None, + ): + """Execute the script, passing any required ``args``""" + keys = keys or [] + args = args or [] + if client is None: + client = self.registered_client + args = tuple(keys) + tuple(args) + # make sure the Redis server knows about the script + from redis.asyncio.client import Pipeline + + if isinstance(client, Pipeline): + # Make sure the pipeline can register the script before executing. + client.scripts.add(self) + try: + return await client.evalsha(self.sha, len(keys), *args) + except NoScriptError: + # Maybe the client is pointed to a different server than the client + # that created this instance? + # Overwrite the sha just in case there was a discrepancy. + self.sha = await client.script_load(self.script) + return await client.evalsha(self.sha, len(keys), *args) + + class PubSubCommands(CommandsProtocol): """ Redis PubSub commands. @@ -4643,7 +4847,7 @@ def script_debug(self, *args) -> None: ) def script_flush( - self, sync_type: Literal["SYNC"] | Literal["ASYNC"] = None + self, sync_type: Union[Literal["SYNC"], Literal["ASYNC"]] = None ) -> ResponseT: """Flush all scripts from the script cache. ``sync_type`` is by default SYNC (synchronous) but it can also be @@ -4680,7 +4884,7 @@ def script_load(self, script: ScriptTextT) -> ResponseT: """ return self.execute_command("SCRIPT LOAD", script) - def register_script(self: Redis, script: ScriptTextT) -> Script: + def register_script(self: "Redis", script: ScriptTextT) -> Script: """ Register a Lua ``script`` specifying the ``keys`` it will touch. Returns a Script object that is callable and hides the complexity of @@ -4694,7 +4898,7 @@ class AsyncScriptCommands(ScriptCommands): async def script_debug(self, *args) -> None: return super().script_debug() - def register_script(self: AsyncRedis, script: ScriptTextT) -> AsyncScript: + def register_script(self: "AsyncRedis", script: ScriptTextT) -> AsyncScript: """ Register a Lua ``script`` specifying the ``keys`` it will touch. 
Returns a Script object that is callable and hides the complexity of @@ -4757,7 +4961,7 @@ def geodist( name: KeyT, place1: FieldT, place2: FieldT, - unit: str | None = None, + unit: Union[str, None] = None, ) -> ResponseT: """ Return the distance between ``place1`` and ``place2`` members of the @@ -4799,14 +5003,14 @@ def georadius( longitude: float, latitude: float, radius: float, - unit: str | None = None, + unit: Union[str, None] = None, withdist: bool = False, withcoord: bool = False, withhash: bool = False, - count: int | None = None, - sort: str | None = None, - store: KeyT | None = None, - store_dist: KeyT | None = None, + count: Union[int, None] = None, + sort: Union[str, None] = None, + store: Union[KeyT, None] = None, + store_dist: Union[KeyT, None] = None, any: bool = False, ) -> ResponseT: """ @@ -4861,14 +5065,14 @@ def georadiusbymember( name: KeyT, member: FieldT, radius: float, - unit: str | None = None, + unit: Union[str, None] = None, withdist: bool = False, withcoord: bool = False, withhash: bool = False, - count: int | None = None, - sort: str | None = None, - store: KeyT | None = None, - store_dist: KeyT | None = None, + count: Union[int, None] = None, + sort: Union[str, None] = None, + store: Union[KeyT, None] = None, + store_dist: Union[KeyT, None] = None, any: bool = False, ) -> ResponseT: """ @@ -4899,7 +5103,7 @@ def _georadiusgeneric( self, command: str, *args: EncodableT, - **kwargs: EncodableT | None, + **kwargs: Union[EncodableT, None], ) -> ResponseT: pieces = list(args) if kwargs["unit"] and kwargs["unit"] not in ("m", "km", "mi", "ft"): @@ -4949,15 +5153,15 @@ def _georadiusgeneric( def geosearch( self, name: KeyT, - member: FieldT | None = None, - longitude: float | None = None, - latitude: float | None = None, + member: Union[FieldT, None] = None, + longitude: Union[float, None] = None, + latitude: Union[float, None] = None, unit: str = "m", - radius: float | None = None, - width: float | None = None, - height: float | None = None, - sort: str | None = None, - count: int | None = None, + radius: Union[float, None] = None, + width: Union[float, None] = None, + height: Union[float, None] = None, + sort: Union[str, None] = None, + count: Union[int, None] = None, any: bool = False, withcoord: bool = False, withdist: bool = False, @@ -5022,15 +5226,15 @@ def geosearchstore( self, dest: KeyT, name: KeyT, - member: FieldT | None = None, - longitude: float | None = None, - latitude: float | None = None, + member: Union[FieldT, None] = None, + longitude: Union[float, None] = None, + latitude: Union[float, None] = None, unit: str = "m", - radius: float | None = None, - width: float | None = None, - height: float | None = None, - sort: str | None = None, - count: int | None = None, + radius: Union[float, None] = None, + width: Union[float, None] = None, + height: Union[float, None] = None, + sort: Union[str, None] = None, + count: Union[int, None] = None, any: bool = False, storedist: bool = False, ) -> ResponseT: @@ -5069,7 +5273,7 @@ def _geosearchgeneric( self, command: str, *args: EncodableT, - **kwargs: EncodableT | None, + **kwargs: Union[EncodableT, None], ) -> ResponseT: pieces = list(args) @@ -5192,205 +5396,6 @@ async def command_info(self) -> None: return super().command_info() -class Script: - """ - An executable Lua script object returned by ``register_script`` - """ - - def __init__(self, registered_client: Redis, script: ScriptTextT): - self.registered_client = registered_client - self.script = script - # Precalculate and store the SHA1 hex digest of 
the script. - - if isinstance(script, str): - # We need the encoding from the client in order to generate an - # accurate byte representation of the script - encoder = registered_client.connection_pool.get_encoder() - script = encoder.encode(script) - self.sha = hashlib.sha1(script).hexdigest() - - def __call__( - self, - keys: Sequence[KeyT] | None = None, - args: Iterable[EncodableT] | None = None, - client: Redis | None = None, - ): - """Execute the script, passing any required ``args``""" - keys = keys or [] - args = args or [] - if client is None: - client = self.registered_client - args = tuple(keys) + tuple(args) - # make sure the Redis server knows about the script - from redis.client import Pipeline - - if isinstance(client, Pipeline): - # Make sure the pipeline can register the script before executing. - client.scripts.add(self) - try: - return client.evalsha(self.sha, len(keys), *args) - except NoScriptError: - # Maybe the client is pointed to a different server than the client - # that created this instance? - # Overwrite the sha just in case there was a discrepancy. - self.sha = client.script_load(self.script) - return client.evalsha(self.sha, len(keys), *args) - - -class AsyncScript: - """ - An executable Lua script object returned by ``register_script`` - """ - - def __init__(self, registered_client: AsyncRedis, script: ScriptTextT): - self.registered_client = registered_client - self.script = script - # Precalculate and store the SHA1 hex digest of the script. - - if isinstance(script, str): - # We need the encoding from the client in order to generate an - # accurate byte representation of the script - encoder = registered_client.connection_pool.get_encoder() - script = encoder.encode(script) - self.sha = hashlib.sha1(script).hexdigest() - - async def __call__( - self, - keys: Sequence[KeyT] | None = None, - args: Iterable[EncodableT] | None = None, - client: AsyncRedis | None = None, - ): - """Execute the script, passing any required ``args``""" - keys = keys or [] - args = args or [] - if client is None: - client = self.registered_client - args = tuple(keys) + tuple(args) - # make sure the Redis server knows about the script - from redis.asyncio.client import Pipeline - - if isinstance(client, Pipeline): - # Make sure the pipeline can register the script before executing. - client.scripts.add(self) - try: - return await client.evalsha(self.sha, len(keys), *args) - except NoScriptError: - # Maybe the client is pointed to a different server than the client - # that created this instance? - # Overwrite the sha just in case there was a discrepancy. - self.sha = await client.script_load(self.script) - return await client.evalsha(self.sha, len(keys), *args) - - -class BitFieldOperation: - """ - Command builder for BITFIELD commands. - """ - - def __init__( - self, client: Redis | AsyncRedis, key: str, default_overflow: str | None = None - ): - self.client = client - self.key = key - self._default_overflow = default_overflow - # for typing purposes, run the following in constructor and in reset() - self.operations: list[tuple[EncodableT, ...]] = [] - self._last_overflow = "WRAP" - self.reset() - - def reset(self): - """ - Reset the state of the instance to when it was constructed - """ - self.operations = [] - self._last_overflow = "WRAP" - self.overflow(self._default_overflow or self._last_overflow) - - def overflow(self, overflow: str): - """ - Update the overflow algorithm of successive INCRBY operations - :param overflow: Overflow algorithm, one of WRAP, SAT, FAIL. 
See the - Redis docs for descriptions of these algorithmsself. - :returns: a :py:class:`BitFieldOperation` instance. - """ - overflow = overflow.upper() - if overflow != self._last_overflow: - self._last_overflow = overflow - self.operations.append(("OVERFLOW", overflow)) - return self - - def incrby( - self, - fmt: str, - offset: BitfieldOffsetT, - increment: int, - overflow: str | None = None, - ): - """ - Increment a bitfield by a given amount. - :param fmt: format-string for the bitfield being updated, e.g. 'u8' - for an unsigned 8-bit integer. - :param offset: offset (in number of bits). If prefixed with a - '#', this is an offset multiplier, e.g. given the arguments - fmt='u8', offset='#2', the offset will be 16. - :param int increment: value to increment the bitfield by. - :param str overflow: overflow algorithm. Defaults to WRAP, but other - acceptable values are SAT and FAIL. See the Redis docs for - descriptions of these algorithms. - :returns: a :py:class:`BitFieldOperation` instance. - """ - if overflow is not None: - self.overflow(overflow) - - self.operations.append(("INCRBY", fmt, offset, increment)) - return self - - def get(self, fmt: str, offset: BitfieldOffsetT): - """ - Get the value of a given bitfield. - :param fmt: format-string for the bitfield being read, e.g. 'u8' for - an unsigned 8-bit integer. - :param offset: offset (in number of bits). If prefixed with a - '#', this is an offset multiplier, e.g. given the arguments - fmt='u8', offset='#2', the offset will be 16. - :returns: a :py:class:`BitFieldOperation` instance. - """ - self.operations.append(("GET", fmt, offset)) - return self - - def set(self, fmt: str, offset: BitfieldOffsetT, value: int): - """ - Set the value of a given bitfield. - :param fmt: format-string for the bitfield being read, e.g. 'u8' for - an unsigned 8-bit integer. - :param offset: offset (in number of bits). If prefixed with a - '#', this is an offset multiplier, e.g. given the arguments - fmt='u8', offset='#2', the offset will be 16. - :param int value: value to set at the given position. - :returns: a :py:class:`BitFieldOperation` instance. - """ - self.operations.append(("SET", fmt, offset, value)) - return self - - @property - def command(self): - cmd = ["BITFIELD", self.key] - for ops in self.operations: - cmd.extend(ops) - return cmd - - def execute(self) -> ResponseT: - """ - Execute the operation(s) in a single BITFIELD command. The return value - is a list of values corresponding to each operation. If the client - used to create this instance was a pipeline, the list of values - will be present within the pipeline's execute. - """ - command = self.command - self.reset() - return self.client.execute_command(*command) - - class ClusterCommands(CommandsProtocol): """ Class for Redis Cluster commands diff --git a/redis/typing.py b/redis/typing.py index 12372a3db6..d96e4e3a5d 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -1,4 +1,4 @@ -from __future__ import annotations +# from __future__ import annotations from datetime import datetime, timedelta from typing import TYPE_CHECKING, Iterable, TypeVar, Union @@ -39,7 +39,7 @@ class CommandsProtocol(Protocol): - connection_pool: Union[AsyncConnectionPool, ConnectionPool] + connection_pool: Union["AsyncConnectionPool", "ConnectionPool"] def execute_command(self, *args, **options): ... 
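The `redis/typing.py` hunk above shows the companion pattern: with `from __future__ import annotations` commented out for 3.6, names that are only imported under `TYPE_CHECKING` have to be quoted so the annotation stays an unevaluated string. A distilled sketch of why this works on every supported version:

```python
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Visible to type checkers only; never imported at runtime.
    from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool
    from redis.connection import ConnectionPool


class CommandsProtocol:
    # String annotations are never evaluated, so the names above may be
    # absent at runtime without raising NameError.
    connection_pool: Union["AsyncConnectionPool", "ConnectionPool"]

    def execute_command(self, *args, **options):
        ...
```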
diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index d9d95561d4..e61f59c492 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -1,9 +1,9 @@ -from __future__ import annotations - import asyncio import random +from typing import Union from urllib.parse import urlparse +import pytest_asyncio import pytest from packaging.version import Version @@ -27,7 +27,7 @@ async def _get_info(redis_url): return info -@pytest.fixture( +@pytest_asyncio.fixture( params=[ (True, PythonParser), (False, PythonParser), @@ -91,12 +91,12 @@ async def ateardown(): return f -@pytest.fixture() +@pytest_asyncio.fixture() async def r(create_redis): yield await create_redis() -@pytest.fixture() +@pytest_asyncio.fixture() async def r2(create_redis): """A second client for tests that need multiple""" yield await create_redis() @@ -109,19 +109,19 @@ def _gen_cluster_mock_resp(r, response): return r -@pytest.fixture() +@pytest_asyncio.fixture() async def mock_cluster_resp_ok(create_redis, **kwargs): r = await create_redis(**kwargs) return _gen_cluster_mock_resp(r, "OK") -@pytest.fixture() +@pytest_asyncio.fixture() async def mock_cluster_resp_int(create_redis, **kwargs): r = await create_redis(**kwargs) return _gen_cluster_mock_resp(r, "2") -@pytest.fixture() +@pytest_asyncio.fixture() async def mock_cluster_resp_info(create_redis, **kwargs): r = await create_redis(**kwargs) response = ( @@ -135,7 +135,7 @@ async def mock_cluster_resp_info(create_redis, **kwargs): return _gen_cluster_mock_resp(r, response) -@pytest.fixture() +@pytest_asyncio.fixture() async def mock_cluster_resp_nodes(create_redis, **kwargs): r = await create_redis(**kwargs) response = ( @@ -159,7 +159,7 @@ async def mock_cluster_resp_nodes(create_redis, **kwargs): return _gen_cluster_mock_resp(r, response) -@pytest.fixture() +@pytest_asyncio.fixture() async def mock_cluster_resp_slaves(create_redis, **kwargs): r = await create_redis(**kwargs) response = ( @@ -170,7 +170,7 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs): return _gen_cluster_mock_resp(r, response) -@pytest.fixture(scope="session") +@pytest_asyncio.fixture(scope="session") def master_host(request): url = request.config.getoption("--redis-url") parts = urlparse(url) @@ -178,7 +178,7 @@ def master_host(request): async def wait_for_command( - client: redis.Redis, monitor: Monitor, command: str, key: str | None = None + client: redis.Redis, monitor: Monitor, command: str, key: Union[str, None] = None ): # issue a command with a key name that's local to this process. 
# if we find a command with our key before the command we're waiting diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index d3afad9800..d479d05f00 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -3,6 +3,7 @@ import re import pytest +import pytest_asyncio import redis.asyncio as redis from redis.asyncio.connection import Connection, to_bool @@ -15,7 +16,7 @@ class TestRedisAutoReleaseConnectionPool: - @pytest.fixture + @pytest_asyncio.fixture async def r(self, create_redis) -> redis.Redis: """This is necessary since r and r2 create ConnectionPools behind the scenes""" r = await create_redis() @@ -684,7 +685,7 @@ async def test_connect_invalid_password_supplied(self, r): @pytest.mark.onlynoncluster class TestMultiConnectionClient: - @pytest.fixture() + @pytest_asyncio.fixture() async def r(self, create_redis, server): redis = await create_redis(single_connection_client=False) yield redis @@ -695,7 +696,7 @@ async def r(self, create_redis, server): class TestHealthCheck: interval = 60 - @pytest.fixture() + @pytest_asyncio.fixture() async def r(self, create_redis): redis = await create_redis(health_check_interval=self.interval) yield redis diff --git a/tests/test_asyncio/test_encoding.py b/tests/test_asyncio/test_encoding.py index b68c7fc1f4..efad80f741 100644 --- a/tests/test_asyncio/test_encoding.py +++ b/tests/test_asyncio/test_encoding.py @@ -1,4 +1,5 @@ import pytest +import pytest_asyncio import redis.asyncio as redis from redis.exceptions import DataError @@ -8,13 +9,13 @@ @pytest.mark.onlynoncluster class TestEncoding: - @pytest.fixture() + @pytest_asyncio.fixture() async def r(self, create_redis): redis = await create_redis(decode_responses=True) yield redis await redis.flushall() - @pytest.fixture() + @pytest_asyncio.fixture() async def r_no_decode(self, create_redis): redis = await create_redis(decode_responses=False) yield redis @@ -91,7 +92,7 @@ async def test_memoryviews_are_not_packed(self, r): class TestCommandsAreNotEncoded: - @pytest.fixture() + @pytest_asyncio.fixture() async def r(self, create_redis): redis = await create_redis(encoding="utf-16") yield redis diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index f497fac0c0..4f2a5ffba0 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -1,6 +1,7 @@ import asyncio import pytest +import pytest_asyncio from redis.asyncio.lock import Lock from redis.exceptions import LockError, LockNotOwnedError @@ -10,7 +11,7 @@ @pytest.mark.onlynoncluster class TestLock: - @pytest.fixture() + @pytest_asyncio.fixture() async def r_decoded(self, create_redis): redis = await create_redis(decode_responses=True) yield redis diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index d889d4f5ef..7c980c3afb 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -2,6 +2,7 @@ from typing import Optional import pytest +import pytest_asyncio import redis.asyncio as redis from redis.exceptions import ConnectionError @@ -403,7 +404,7 @@ def setup_method(self, method): def message_handler(self, message): self.message = message - @pytest.fixture() + @pytest_asyncio.fixture() async def r(self, create_redis): return await create_redis( decode_responses=True, diff --git a/tests/test_asyncio/test_scripting.py b/tests/test_asyncio/test_scripting.py index 5d01f25ff5..3776d12cb7 100644 --- a/tests/test_asyncio/test_scripting.py 
+++ b/tests/test_asyncio/test_scripting.py @@ -1,4 +1,5 @@ import pytest +import pytest_asyncio from redis import exceptions from tests.conftest import skip_if_server_version_lt @@ -22,7 +23,7 @@ @pytest.mark.onlynoncluster class TestScripting: - @pytest.fixture + @pytest_asyncio.fixture async def r(self, create_redis): redis = await create_redis() yield redis diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py index 2b22d6a339..2f99537f18 100644 --- a/tests/test_asyncio/test_sentinel.py +++ b/tests/test_asyncio/test_sentinel.py @@ -1,6 +1,7 @@ import socket import pytest +import pytest_asyncio import redis.asyncio.sentinel from redis import exceptions @@ -14,7 +15,7 @@ pytestmark = pytest.mark.asyncio -@pytest.fixture(scope="module") +@pytest_asyncio.fixture(scope="module") def master_ip(master_host): yield socket.gethostbyname(master_host) @@ -71,7 +72,7 @@ def client(self, host, port, **kwargs): return SentinelTestClient(self, (host, port)) -@pytest.fixture() +@pytest_asyncio.fixture() async def cluster(master_ip): cluster = SentinelTestCluster(ip=master_ip) @@ -81,7 +82,7 @@ async def cluster(master_ip): redis.asyncio.sentinel.Redis = saved_Redis -@pytest.fixture() +@pytest_asyncio.fixture() def sentinel(request, cluster): return Sentinel([("foo", 26379), ("bar", 26379)]) From f1c60c9ce8026b50337dd1e95136981e7b8fa81c Mon Sep 17 00:00:00 2001 From: "Chayim I. Kirshen" Date: Thu, 17 Feb 2022 21:29:44 +0200 Subject: [PATCH 21/24] standalone fix --- tox.ini | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tox.ini b/tox.ini index ca8c9b33fd..d7da572ff5 100644 --- a/tox.ini +++ b/tox.ini @@ -269,10 +269,10 @@ extras = setenv = CLUSTER_URL = "redis://localhost:16379/0" commands = - standalone: pytest --cov=./ --cov-report=xml:coverage_redis.xml --asyncio-mode=auto -W always -m 'not onlycluster' {posargs} - standalone-uvloop: pytest --cov=./ --cov-report=xml:coverage_redis.xml --asyncio-mode=auto -W always -m 'not onlycluster' --uvloop {posargs} - cluster: pytest --cov=./ --cov-report=xml:coverage_cluster.xml --asyncio-mode=auto -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} {posargs} - cluster-uvloop: pytest --cov=./ --cov-report=xml:coverage_redis.xml --asyncio-mode=auto -W always -m 'not onlycluster' --uvloop {posargs} + standalone: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' {posargs} + standalone-uvloop: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' --uvloop {posargs} + cluster: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} {posargs} + cluster-uvloop: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' --uvloop {posargs} [testenv:redis5] deps = From f4c062bae490a02c3be13cae9fa552a3c972e660 Mon Sep 17 00:00:00 2001 From: "Chayim I. 
Kirshen" Date: Tue, 22 Feb 2022 11:12:50 +0200 Subject: [PATCH 22/24] fixing 3.6 tests, and cluster tests updating docs to point to async connections adding python deprecation notice --- .gitignore | 2 +- README.md | 5 +++++ docs/connections.rst | 12 +++++++++++- requirements.txt | 2 +- tests/test_asyncio/conftest.py | 7 ++++++- tests/test_asyncio/test_connection_pool.py | 7 ++++++- tests/test_asyncio/test_encoding.py | 8 +++++++- tests/test_asyncio/test_lock.py | 7 ++++++- tests/test_asyncio/test_pubsub.py | 7 ++++++- tests/test_asyncio/test_scripting.py | 8 +++++++- tests/test_asyncio/test_sentinel.py | 7 ++++++- tox.ini | 1 + 12 files changed, 63 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 96fbdd5646..b392a2d748 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,7 @@ vagrant/.vagrant env venv coverage.xml -.venv +.venv* *.xml .coverage* docker/stunnel/keys diff --git a/README.md b/README.md index 006a513225..c6ee3c10eb 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,11 @@ The Python interface to the Redis key-value store. --------------------------------------------- +## Python Notice + +redis-py 4.2.x will be the last generation of redis-py to support python 3.6 as it has been [End of Life'd](https://www.python.org/dev/peps/pep-0494/#schedule-last-security-only-release). Async support was introduced in redis-py 4.2.x thanks to [aioredis](https://github.com/aio-libs/aioredis-py), which necessitates this change. We will continue to maintain 3.6 support as long as possible - but the plan is for redis-py version 5+ to offically remove 3.6. + +--------------------------- ## Installation diff --git a/docs/connections.rst b/docs/connections.rst index ba39f3341f..9804a15bf1 100644 --- a/docs/connections.rst +++ b/docs/connections.rst @@ -42,4 +42,14 @@ Connection Pools .. autoclass:: redis.connection.ConnectionPool :members: -More connection examples can be found `here `_. \ No newline at end of file +More connection examples can be found `here `_. + +Async Client +************ + +This client is used for communicating with Redis, asynchronously. + +.. 
diff --git a/requirements.txt b/requirements.txt
index 001ecb6bfd..7f0ebf0c69 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-async-timeout
+async-timeout>=4.0.2
 deprecated>=1.2.3
 packaging>=20.4
 typing-extensions
diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py
index e61f59c492..0e9c73ec61 100644
--- a/tests/test_asyncio/conftest.py
+++ b/tests/test_asyncio/conftest.py
@@ -1,9 +1,14 @@
 import asyncio
 import random
+import sys
 from typing import Union
 from urllib.parse import urlparse
 
-import pytest_asyncio
+if sys.version_info[0:2] == (3, 6):
+    import pytest as pytest_asyncio
+else:
+    import pytest_asyncio
+
 import pytest
 from packaging.version import Version
diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py
index d479d05f00..f9dfefd5cc 100644
--- a/tests/test_asyncio/test_connection_pool.py
+++ b/tests/test_asyncio/test_connection_pool.py
@@ -1,9 +1,14 @@
 import asyncio
 import os
 import re
+import sys
 
 import pytest
-import pytest_asyncio
+
+if sys.version_info[0:2] == (3, 6):
+    import pytest as pytest_asyncio
+else:
+    import pytest_asyncio
 
 import redis.asyncio as redis
 from redis.asyncio.connection import Connection, to_bool
diff --git a/tests/test_asyncio/test_encoding.py b/tests/test_asyncio/test_encoding.py
index efad80f741..da2983738e 100644
--- a/tests/test_asyncio/test_encoding.py
+++ b/tests/test_asyncio/test_encoding.py
@@ -1,5 +1,11 @@
+import sys
+
 import pytest
-import pytest_asyncio
+
+if sys.version_info[0:2] == (3, 6):
+    import pytest as pytest_asyncio
+else:
+    import pytest_asyncio
 
 import redis.asyncio as redis
 from redis.exceptions import DataError
diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py
index 4f2a5ffba0..c496718a67 100644
--- a/tests/test_asyncio/test_lock.py
+++ b/tests/test_asyncio/test_lock.py
@@ -1,7 +1,12 @@
 import asyncio
+import sys
 
 import pytest
-import pytest_asyncio
+
+if sys.version_info[0:2] == (3, 6):
+    import pytest as pytest_asyncio
+else:
+    import pytest_asyncio
 
 from redis.asyncio.lock import Lock
 from redis.exceptions import LockError, LockNotOwnedError
diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py
index 7c980c3afb..9efcd3cf8a 100644
--- a/tests/test_asyncio/test_pubsub.py
+++ b/tests/test_asyncio/test_pubsub.py
@@ -1,8 +1,13 @@
 import asyncio
+import sys
 from typing import Optional
 
 import pytest
-import pytest_asyncio
+
+if sys.version_info[0:2] == (3, 6):
+    import pytest as pytest_asyncio
+else:
+    import pytest_asyncio
 
 import redis.asyncio as redis
 from redis.exceptions import ConnectionError
diff --git a/tests/test_asyncio/test_scripting.py b/tests/test_asyncio/test_scripting.py
index 3776d12cb7..764525fb4a 100644
--- a/tests/test_asyncio/test_scripting.py
+++ b/tests/test_asyncio/test_scripting.py
@@ -1,5 +1,11 @@
+import sys
+
 import pytest
-import pytest_asyncio
+
+if sys.version_info[0:2] == (3, 6):
+    import pytest as pytest_asyncio
+else:
+    import pytest_asyncio
 
 from redis import exceptions
 from tests.conftest import skip_if_server_version_lt
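The same guard is stamped into every async test module: on Python 3.6, where only older pytest-asyncio releases are installable, `pytest` itself is aliased as `pytest_asyncio`, so `@pytest_asyncio.fixture()` degrades to `@pytest.fixture()`, which the older plugin already handles for coroutine fixtures. The net effect, shown as a sketch (the `create_redis` factory fixture comes from the suite's conftest.py):

```python
import sys

if sys.version_info[0:2] == (3, 6):
    import pytest as pytest_asyncio  # 3.6: old plugin drives async fixtures
else:
    import pytest_asyncio  # 3.7+: dedicated decorator from pytest-asyncio


# The same source line now works on both interpreter generations.
@pytest_asyncio.fixture()
async def r(create_redis):
    redis = await create_redis()
    yield redis
    await redis.flushall()
```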
diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py
index 2f99537f18..cd6810c1b5 100644
--- a/tests/test_asyncio/test_sentinel.py
+++ b/tests/test_asyncio/test_sentinel.py
@@ -1,7 +1,12 @@
 import socket
+import sys
 
 import pytest
-import pytest_asyncio
+
+if sys.version_info[0:2] == (3, 6):
+    import pytest as pytest_asyncio
+else:
+    import pytest_asyncio
 
 import redis.asyncio.sentinel
 from redis import exceptions
diff --git a/tox.ini b/tox.ini
index d7da572ff5..3ef01ddead 100644
--- a/tox.ini
+++ b/tox.ini
@@ -352,6 +352,7 @@ exclude =
     dist,
    docker,
     venv*,
+    .venv*,
     whitelist.py
 ignore =
     F405

From f4a6955c4fe9e80884f39f52026a856126e4ddcf Mon Sep 17 00:00:00 2001
From: "Chayim I. Kirshen"
Date: Tue, 22 Feb 2022 12:05:31 +0200
Subject: [PATCH 23/24] vulture whitelist

---
 whitelist.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/whitelist.py b/whitelist.py
index a800bcf482..52c238c926 100644
--- a/whitelist.py
+++ b/whitelist.py
@@ -13,3 +13,5 @@
 exc_type  # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26)
 exc_value  # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26)
 traceback  # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26)
+AsyncConnectionPool  # unused import (//data/repos/redis/redis-py/redis/typing.py:9)
+AsyncRedis  # unused import (//data/repos/redis/redis-py/redis/commands/core.py:49)
\ No newline at end of file

From d688f557d61960c1ede91911f4b9fabfb7d5bd0f Mon Sep 17 00:00:00 2001
From: "Chayim I. Kirshen"
Date: Tue, 22 Feb 2022 12:06:07 +0200
Subject: [PATCH 24/24] linter: black

---
 whitelist.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/whitelist.py b/whitelist.py
index 52c238c926..27210284c7 100644
--- a/whitelist.py
+++ b/whitelist.py
@@ -14,4 +14,4 @@
 exc_value  # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26)
 traceback  # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26)
 AsyncConnectionPool  # unused import (//data/repos/redis/redis-py/redis/typing.py:9)
-AsyncRedis  # unused import (//data/repos/redis/redis-py/redis/commands/core.py:49)
\ No newline at end of file
+AsyncRedis  # unused import (//data/repos/redis/redis-py/redis/commands/core.py:49)