From 7a6d141a863c590ed06ec74e97b5ed167457e31f Mon Sep 17 00:00:00 2001 From: "Chayim I. Kirshen" Date: Wed, 20 Sep 2023 15:25:45 +0300 Subject: [PATCH 01/19] Some type hints --- dev_requirements.txt | 2 + redis/_parsers/base.py | 1 - redis/_parsers/resp3.py | 6 +- redis/asyncio/connection.py | 1 - redis/client.py | 111 ++++++++++++----------- redis/cluster.py | 1 - redis/commands/core.py | 2 +- redis/commands/json/commands.py | 1 - redis/commands/search/__init__.py | 1 - redis/commands/search/aggregation.py | 44 ++++----- redis/commands/search/commands.py | 74 ++++++++------- redis/commands/search/field.py | 1 - redis/commands/search/query.py | 124 +++++++++++++++----------- redis/commands/search/reducers.py | 30 ++++--- redis/commands/search/result.py | 2 +- redis/commands/search/suggestion.py | 10 ++- redis/connection.py | 62 ++++++------- tests/test_asyncio/test_cwe_404.py | 4 - tests/test_asyncio/test_json.py | 6 -- tests/test_asyncio/test_lock.py | 1 - tests/test_asyncio/test_pipeline.py | 1 - tests/test_asyncio/test_pubsub.py | 2 - tests/test_asyncio/test_search.py | 51 +++++------ tests/test_asyncio/test_sentinel.py | 1 - tests/test_asyncio/test_timeseries.py | 1 - tests/test_commands.py | 2 - tests/test_graph_utils/test_edge.py | 1 - tests/test_json.py | 7 -- tests/test_lock.py | 1 - tests/test_pipeline.py | 1 - tests/test_pubsub.py | 2 - tests/test_search.py | 3 - tests/test_timeseries.py | 1 - 33 files changed, 282 insertions(+), 276 deletions(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index 3715599af0..6bd418bf46 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -15,3 +15,5 @@ ujson>=4.2.0 wheel>=0.30.0 urllib3<2 uvloop +types-requests +types-pyOpenSSL \ No newline at end of file diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index f77296df6a..4a005669af 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -46,7 +46,6 @@ class BaseParser(ABC): - EXCEPTION_CLASSES = { "ERR": { "max number of 
clients reached": ConnectionError, diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 1275686710..ad766a8f95 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -243,10 +243,8 @@ async def _read_response( ] res = self.push_handler_func(response) if not push_request: - return await ( - self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) + return await self._read_response( + disable_decoding=disable_decoding, push_request=push_request ) else: return res diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 71d0e92002..cce0576bf0 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1148,7 +1148,6 @@ def __init__( queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, # deprecated **connection_kwargs, ): - super().__init__( connection_class=connection_class, max_connections=max_connections, diff --git a/redis/client.py b/redis/client.py index 1e1ff57605..6ef71f083c 100755 --- a/redis/client.py +++ b/redis/client.py @@ -4,8 +4,9 @@ import time import warnings from itertools import chain -from typing import Optional, Type +from typing import Any, Callable, Dict, List, Optional, Type, Union +from redis._parsers.encoders import Encoder from redis._parsers.helpers import ( _RedisCallbacks, _RedisCallbacksRESP2, @@ -49,7 +50,7 @@ class CaseInsensitiveDict(dict): "Case insensitive dict implementation. Assumes string keys only." 
- def __init__(self, data): + def __init__(self, data) -> None: for k, v in data.items(): self[k.upper()] = v @@ -93,7 +94,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): """ @classmethod - def from_url(cls, url, **kwargs): + def from_url(cls, url, **kwargs) -> None: """ Return a Redis client object configured from the given URL @@ -202,7 +203,7 @@ def __init__( redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - ): + ) -> None: """ Initialize a new Redis client. To specify a retry policy for specific errors, first set @@ -309,14 +310,14 @@ def __init__( else: self.response_callbacks.update(_RedisCallbacksRESP2) - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}<{repr(self.connection_pool)}>" - def get_encoder(self): + def get_encoder(self) -> "Encoder": """Get the connection pool's encoder""" return self.connection_pool.get_encoder() - def get_connection_kwargs(self): + def get_connection_kwargs(self) -> Dict: """Get the connection's key-word arguments""" return self.connection_pool.connection_kwargs @@ -327,11 +328,11 @@ def set_retry(self, retry: "Retry") -> None: self.get_connection_kwargs().update({"retry": retry}) self.connection_pool.set_retry(retry) - def set_response_callback(self, command, callback): + def set_response_callback(self, command, callback) -> None: """Set a custom Response Callback""" self.response_callbacks[command] = callback - def load_external_module(self, funcname, func): + def load_external_module(self, funcname, func) -> None: """ This function can be used to add externally defined redis modules, and their namespaces to the redis client. 
@@ -354,7 +355,7 @@ def load_external_module(self, funcname, func): """ setattr(self, funcname, func) - def pipeline(self, transaction=True, shard_hint=None): + def pipeline(self, transaction=True, shard_hint=None) -> "Pipeline": """ Return a new pipeline object that can queue multiple commands for later execution. ``transaction`` indicates whether all commands @@ -366,7 +367,7 @@ def pipeline(self, transaction=True, shard_hint=None): self.connection_pool, self.response_callbacks, transaction, shard_hint ) - def transaction(self, func, *watches, **kwargs): + def transaction(self, func: Callable["..."], *watches, **kwargs) -> None: """ Convenience method for executing the callable `func` as a transaction while watching all keys specified in `watches`. The 'func' callable @@ -390,13 +391,13 @@ def transaction(self, func, *watches, **kwargs): def lock( self, - name, - timeout=None, - sleep=0.1, - blocking=True, - blocking_timeout=None, - lock_class=None, - thread_local=True, + name: str, + timeout: Union[None, float] = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: Union[None, float] = None, + lock_class: Union[None, Any] = None, + thread_local: bool = True, ): """ Return a new Lock object using key ``name`` that mimics @@ -648,9 +649,9 @@ def __init__( self, connection_pool, shard_hint=None, - ignore_subscribe_messages=False, - encoder=None, - push_handler_func=None, + ignore_subscribe_messages: bool = False, + encoder: Union[None, "Encoder"] = None, + push_handler_func: Union[None, Callable["..."]] = None, ): self.connection_pool = connection_pool self.shard_hint = shard_hint @@ -672,13 +673,13 @@ def __init__( _set_info_logger() self.reset() - def __enter__(self): + def __enter__(self) -> "PubSub": return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> None: self.reset() - def __del__(self): + def __del__(self) -> None: try: # if this object went out of scope prior to 
shutting down # subscriptions, close the connection manually before @@ -687,7 +688,7 @@ def __del__(self): except Exception: pass - def reset(self): + def reset(self) -> None: if self.connection: self.connection.disconnect() self.connection.clear_connect_callbacks() @@ -702,10 +703,10 @@ def reset(self): self.pending_unsubscribe_patterns = set() self.subscribed_event.clear() - def close(self): + def close(self) -> None: self.reset() - def on_connect(self, connection): + def on_connect(self, connection) -> None: "Re-subscribe to any channels and patterns previously subscribed to" # NOTE: for python3, we can't pass bytestrings as keyword arguments # so we need to decode channel/pattern names back to unicode strings @@ -731,7 +732,7 @@ def on_connect(self, connection): self.ssubscribe(**shard_channels) @property - def subscribed(self): + def subscribed(self) -> bool: """Indicates if there are subscriptions to any channels or patterns""" return self.subscribed_event.is_set() @@ -757,7 +758,7 @@ def execute_command(self, *args): self.clean_health_check_responses() self._execute(connection, connection.send_command, *args, **kwargs) - def clean_health_check_responses(self): + def clean_health_check_responses(self) -> None: """ If any health check responses are present, clean them """ @@ -775,7 +776,7 @@ def clean_health_check_responses(self): ) ttl -= 1 - def _disconnect_raise_connect(self, conn, error): + def _disconnect_raise_connect(self, conn, error) -> None: """ Close the connection and raise an exception if retry_on_timeout is not set or the error @@ -826,7 +827,7 @@ def try_read(): return None return response - def is_health_check_response(self, response): + def is_health_check_response(self, response) -> bool: """ Check if the response is a health check response. 
If there are no subscriptions redis responds to PING command with a @@ -837,7 +838,7 @@ def is_health_check_response(self, response): self.health_check_response_b, # If there wasn't ] - def check_health(self): + def check_health(self) -> None: conn = self.connection if conn is None: raise RuntimeError( @@ -849,7 +850,7 @@ def check_health(self): conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False) self.health_check_response_counter += 1 - def _normalize_keys(self, data): + def _normalize_keys(self, data) -> Dict: """ normalize channel/pattern names to be either bytes or strings based on whether responses are automatically decoded. this saves us @@ -983,7 +984,9 @@ def listen(self): if response is not None: yield response - def get_message(self, ignore_subscribe_messages=False, timeout=0.0): + def get_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): """ Get the next message if one is available, otherwise None. @@ -1012,7 +1015,7 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0): get_sharded_message = get_message - def ping(self, message=None): + def ping(self, message: Union[str, None] = None) -> bool: """ Ping the Redis server """ @@ -1093,7 +1096,9 @@ def handle_message(self, response, ignore_subscribe_messages=False): return message - def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None): + def run_in_thread( + self, sleep_time=0, daemon=False, exception_handler=None + ) -> "PubSubWorkerThread": for channel, handler in self.channels.items(): if handler is None: raise PubSubError(f"Channel: '{channel}' has no handler registered") @@ -1114,7 +1119,13 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None): class PubSubWorkerThread(threading.Thread): - def __init__(self, pubsub, sleep_time, daemon=False, exception_handler=None): + def __init__( + self, + pubsub, + sleep_time: float, + daemon: bool = False, + exception_handler: 
Union[Callable["..."], None] = None, + ): super().__init__() self.daemon = daemon self.pubsub = pubsub @@ -1122,7 +1133,7 @@ def __init__(self, pubsub, sleep_time, daemon=False, exception_handler=None): self.exception_handler = exception_handler self._running = threading.Event() - def run(self): + def run(self) -> None: if self._running.is_set(): return self._running.set() @@ -1137,7 +1148,7 @@ def run(self): self.exception_handler(e, pubsub, self) pubsub.close() - def stop(self): + def stop(self) -> None: # trip the flag so the run loop exits. the run loop will # close the pubsub connection, which disconnects the socket # and returns the connection to the pool. @@ -1175,7 +1186,7 @@ def __init__(self, connection_pool, response_callbacks, transaction, shard_hint) self.watching = False self.reset() - def __enter__(self): + def __enter__(self) -> "Pipeline": return self def __exit__(self, exc_type, exc_value, traceback): @@ -1187,14 +1198,14 @@ def __del__(self): except Exception: pass - def __len__(self): + def __len__(self) -> int: return len(self.command_stack) - def __bool__(self): + def __bool__(self) -> bool: """Pipeline instances should always evaluate to True""" return True - def reset(self): + def reset(self) -> None: self.command_stack = [] self.scripts = set() # make sure to reset the connection state in the event that we were @@ -1217,11 +1228,11 @@ def reset(self): self.connection_pool.release(self.connection) self.connection = None - def close(self): + def close(self) -> None: """Close the pipeline""" self.reset() - def multi(self): + def multi(self) -> None: """ Start a transactional block of the pipeline after WATCH commands are issued. End the transactional block with `execute`. 
@@ -1239,7 +1250,7 @@ def execute_command(self, *args, **kwargs): return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) - def _disconnect_reset_raise(self, conn, error): + def _disconnect_reset_raise(self, conn, error) -> None: """ Close the connection, reset watching state and raise an exception if we were watching, @@ -1282,7 +1293,7 @@ def immediate_execute_command(self, *args, **options): lambda error: self._disconnect_reset_raise(conn, error), ) - def pipeline_execute_command(self, *args, **options): + def pipeline_execute_command(self, *args, **options) -> "Pipeline": """ Stage a command to be executed when execute() is next called @@ -1297,7 +1308,7 @@ def pipeline_execute_command(self, *args, **options): self.command_stack.append((args, options)) return self - def _execute_transaction(self, connection, commands, raise_on_error): + def _execute_transaction(self, connection, commands, raise_on_error) -> List: cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})]) all_cmds = connection.pack_commands( [args for args, options in cmds if EMPTY_RESPONSE not in options] @@ -1415,7 +1426,7 @@ def load_scripts(self): if not exist: s.sha = immediate("SCRIPT LOAD", s.script) - def _disconnect_raise_reset(self, conn, error): + def _disconnect_raise_reset(self, conn: Redis, error: Exception) -> None: """ Close the connection, raise an exception if we were watching, and raise an exception if TimeoutError is not part of retry_on_error, @@ -1477,6 +1488,6 @@ def watch(self, *names): raise RedisError("Cannot issue a WATCH after a MULTI") return self.execute_command("WATCH", *names) - def unwatch(self): + def unwatch(self) -> bool: """Unwatches all previously specified keys""" return self.watching and self.execute_command("UNWATCH") or True diff --git a/redis/cluster.py b/redis/cluster.py index 2ce9c54f85..620cba322e 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2457,7 +2457,6 @@ def read(self): """ """ 
connection = self.connection for c in self.commands: - # if there is a result on this command, # it means we ran into an exception # like a connection error. Trying to parse diff --git a/redis/commands/core.py b/redis/commands/core.py index 9d81e9772c..49802576ec 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -403,7 +403,7 @@ class ManagementCommands(CommandsProtocol): Redis management commands """ - def auth(self, password, username=None, **kwargs): + def auth(self, password: str, username=None, **kwargs): """ Authenticates the user. If you do not pass username, Redis will try to authenticate for the "default" user. If you do pass username, it will diff --git a/redis/commands/json/commands.py b/redis/commands/json/commands.py index 3abe155796..0f92e0d6c9 100644 --- a/redis/commands/json/commands.py +++ b/redis/commands/json/commands.py @@ -80,7 +80,6 @@ def arrpop( path: Optional[str] = Path.root_path(), index: Optional[int] = -1, ) -> List[Union[str, None]]: - """Pop the element at ``index`` in the array JSON value under ``path`` at key ``name``. 
diff --git a/redis/commands/search/__init__.py b/redis/commands/search/__init__.py index e635f91e99..a2bb23b76d 100644 --- a/redis/commands/search/__init__.py +++ b/redis/commands/search/__init__.py @@ -27,7 +27,6 @@ class BatchIndexer: """ def __init__(self, client, chunk_size=1000): - self.client = client self.execute_command = client.execute_command self._pipeline = client.pipeline(transaction=False, shard_hint=None) diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 93a3d9273b..aa42148470 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -1,8 +1,10 @@ +from typing import List, Union + FIELDNAME = object() class Limit: - def __init__(self, offset=0, count=0): + def __init__(self, offset: int = 0, count: int = 0) -> None: self.offset = offset self.count = count @@ -22,12 +24,12 @@ class Reducer: NAME = None - def __init__(self, *args): + def __init__(self, *args: List[str]) -> None: self._args = args self._field = None self._alias = None - def alias(self, alias): + def alias(self, alias: str) -> "Reducer": """ Set the alias for this reducer. @@ -51,7 +53,7 @@ def alias(self, alias): return self @property - def args(self): + def args(self) -> List[str]: return self._args @@ -62,7 +64,7 @@ class SortDirection: DIRSTRING = None - def __init__(self, field): + def __init__(self, field: str) -> None: self.field = field @@ -87,7 +89,7 @@ class AggregateRequest: Aggregation request which can be passed to `Client.aggregate`. """ - def __init__(self, query="*"): + def __init__(self, query: str = "*") -> None: """ Create an aggregation request. This request may then be passed to `client.aggregate()`. @@ -110,7 +112,7 @@ def __init__(self, query="*"): self._cursor = [] self._dialect = None - def load(self, *fields): + def load(self, *fields: List[str]) -> "AggregateRequest": """ Indicate the fields to be returned in the response. 
These fields are returned in addition to any others implicitly specified. @@ -126,7 +128,9 @@ def load(self, *fields): self._loadall = True return self - def group_by(self, fields, *reducers): + def group_by( + self, fields: List[str], *reducers: Union[Reducer, List[Reducer]] + ) -> "AggregateRequest": """ Specify by which fields to group the aggregation. @@ -151,7 +155,7 @@ def group_by(self, fields, *reducers): self._aggregateplan.extend(ret) return self - def apply(self, **kwexpr): + def apply(self, **kwexpr) -> "AggregateRequest": """ Specify one or more projection expressions to add to each result @@ -169,7 +173,7 @@ def apply(self, **kwexpr): return self - def limit(self, offset, num): + def limit(self, offset: int, num: int) -> "AggregateRequest": """ Sets the limit for the most recent group or query. @@ -215,7 +219,7 @@ def limit(self, offset, num): self._aggregateplan.extend(_limit.build_args()) return self - def sort_by(self, *fields, **kwargs): + def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest": """ Indicate how the results should be sorted. This can also be used for *top-N* style queries @@ -262,7 +266,7 @@ def sort_by(self, *fields, **kwargs): self._aggregateplan.extend(ret) return self - def filter(self, expressions): + def filter(self, expressions: Union(str, List[str])) -> "AggregateRequest": """ Specify filter for post-query results using predicates relating to values in the result set. @@ -280,7 +284,7 @@ def filter(self, expressions): return self - def with_schema(self): + def with_schema(self) -> "AggregateRequest": """ If set, the `schema` property will contain a list of `[field, type]` entries in the result object. 
@@ -288,11 +292,11 @@ def with_schema(self): self._with_schema = True return self - def verbatim(self): + def verbatim(self) -> "AggregateRequest": self._verbatim = True return self - def cursor(self, count=0, max_idle=0.0): + def cursor(self, count: int = 0, max_idle: float = 0.0) -> "AggregateRequest": args = ["WITHCURSOR"] if count: args += ["COUNT", str(count)] @@ -301,7 +305,7 @@ def cursor(self, count=0, max_idle=0.0): self._cursor = args return self - def build_args(self): + def build_args(self) -> List[str]: # @foo:bar ... ret = [self._query] @@ -329,7 +333,7 @@ def build_args(self): return ret - def dialect(self, dialect): + def dialect(self, dialect: int) -> "AggregateRequest": """ Add a dialect field to the aggregate command. @@ -340,7 +344,7 @@ def dialect(self, dialect): class Cursor: - def __init__(self, cid): + def __init__(self, cid: int) -> None: self.cid = cid self.max_idle = 0 self.count = 0 @@ -355,12 +359,12 @@ def build_args(self): class AggregateResult: - def __init__(self, rows, cursor, schema): + def __init__(self, rows, cursor: Cursor, schema) -> None: self.rows = rows self.cursor = cursor self.schema = schema - def __repr__(self): + def __repr__(self) -> (str, str): cid = self.cursor.cid if self.cursor else -1 return ( f"<{self.__class__.__name__} at 0x{id(self):x} " diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 83dea106d2..c00abdde98 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -1,6 +1,6 @@ import itertools import time -from typing import Dict, Optional, Union +from typing import Dict, List, Optional, Union from redis.client import Pipeline from redis.utils import deprecated_function @@ -220,7 +220,7 @@ def create_index( return self.execute_command(*args) - def alter_schema_add(self, fields): + def alter_schema_add(self, fields: List[str]): """ Alter the existing search index by adding new fields. The index must already exist. 
@@ -240,7 +240,7 @@ def alter_schema_add(self, fields): return self.execute_command(*args) - def dropindex(self, delete_documents=False): + def dropindex(self, delete_documents: bool = False): """ Drop the index if it exists. Replaced `drop_index` in RediSearch 2.0. @@ -322,15 +322,15 @@ def _add_document_hash( ) def add_document( self, - doc_id, - nosave=False, - score=1.0, - payload=None, - replace=False, - partial=False, - language=None, - no_create=False, - **fields, + doc_id: str, + nosave: bool = False, + score: float = 1.0, + payload: bool = None, + replace: bool = False, + partial: bool = False, + language: Union[str, None] = None, + no_create: str = False, + **fields: List[str], ): """ Add a single document to the index. @@ -554,7 +554,7 @@ def aggregate( AGGREGATE_CMD, raw, query=query, has_cursor=has_cursor ) - def _get_aggregate_result(self, raw, query, has_cursor): + def _get_aggregate_result(self, raw: List, query: str, has_cursor: bool): if has_cursor: if isinstance(query, Cursor): query.cid = raw[1] @@ -642,7 +642,7 @@ def spellcheck(self, query, distance=None, include=None, exclude=None): return self._parse_results(SPELLCHECK_CMD, res) - def dict_add(self, name, *terms): + def dict_add(self, name: str, *terms: List[str]): """Adds terms to a dictionary. ### Parameters @@ -656,7 +656,7 @@ def dict_add(self, name, *terms): cmd.extend(terms) return self.execute_command(*cmd) - def dict_del(self, name, *terms): + def dict_del(self, name: str, *terms: List[str]): """Deletes terms from a dictionary. ### Parameters @@ -670,7 +670,7 @@ def dict_del(self, name, *terms): cmd.extend(terms) return self.execute_command(*cmd) - def dict_dump(self, name): + def dict_dump(self, name: str): """Dumps all terms in the given dictionary. 
### Parameters @@ -682,7 +682,7 @@ def dict_dump(self, name): cmd = [DICT_DUMP_CMD, name] return self.execute_command(*cmd) - def config_set(self, option, value): + def config_set(self, option: str, value: str) -> bool: """Set runtime configuration option. ### Parameters @@ -696,7 +696,7 @@ def config_set(self, option, value): raw = self.execute_command(*cmd) return raw == "OK" - def config_get(self, option): + def config_get(self, option: str) -> str: """Get runtime configuration option value. ### Parameters @@ -709,7 +709,7 @@ def config_get(self, option): res = self.execute_command(*cmd) return self._parse_results(CONFIG_CMD, res) - def tagvals(self, tagfield): + def tagvals(self, tagfield: str): """ Return a list of all possible tag values @@ -722,7 +722,7 @@ def tagvals(self, tagfield): return self.execute_command(TAGVALS_CMD, self.index_name, tagfield) - def aliasadd(self, alias): + def aliasadd(self, alias: str): """ Alias a search index - will fail if alias already exists @@ -735,7 +735,7 @@ def aliasadd(self, alias): return self.execute_command(ALIAS_ADD_CMD, alias, self.index_name) - def aliasupdate(self, alias): + def aliasupdate(self, alias: str): """ Updates an alias - will fail if alias does not already exist @@ -748,7 +748,7 @@ def aliasupdate(self, alias): return self.execute_command(ALIAS_UPDATE_CMD, alias, self.index_name) - def aliasdel(self, alias): + def aliasdel(self, alias: str): """ Removes an alias to a search index @@ -783,7 +783,7 @@ def sugadd(self, key, *suggestions, **kwargs): return pipe.execute()[-1] - def suglen(self, key): + def suglen(self, key: str) -> int: """ Return the number of entries in the AutoCompleter index. @@ -791,7 +791,7 @@ def suglen(self, key): """ # noqa return self.execute_command(SUGLEN_COMMAND, key) - def sugdel(self, key, string): + def sugdel(self, key: str, string: str) -> int: """ Delete a string from the AutoCompleter index. Returns 1 if the string was found and deleted, 0 otherwise. 
@@ -801,8 +801,14 @@ def sugdel(self, key, string): return self.execute_command(SUGDEL_COMMAND, key, string) def sugget( - self, key, prefix, fuzzy=False, num=10, with_scores=False, with_payloads=False - ): + self, + key: str, + prefix: str, + fuzzy: bool = False, + num: int = 10, + with_scores: bool = False, + with_payloads: bool = False, + ) -> List[SuggestionParser]: """ Get a list of suggestions from the AutoCompleter, for a given prefix. @@ -850,7 +856,7 @@ def sugget( parser = SuggestionParser(with_scores, with_payloads, res) return [s for s in parser] - def synupdate(self, groupid, skipinitial=False, *terms): + def synupdate(self, groupid: str, skipinitial: bool = False, *terms: List[str]): """ Updates a synonym group. The command is used to create or update a synonym group with @@ -986,7 +992,7 @@ async def spellcheck(self, query, distance=None, include=None, exclude=None): return self._parse_results(SPELLCHECK_CMD, res) - async def config_set(self, option, value): + async def config_set(self, option: str, value: str) -> bool: """Set runtime configuration option. ### Parameters @@ -1000,7 +1006,7 @@ async def config_set(self, option, value): raw = await self.execute_command(*cmd) return raw == "OK" - async def config_get(self, option): + async def config_get(self, option: str) -> str: """Get runtime configuration option value. ### Parameters @@ -1053,8 +1059,14 @@ async def sugadd(self, key, *suggestions, **kwargs): return (await pipe.execute())[-1] async def sugget( - self, key, prefix, fuzzy=False, num=10, with_scores=False, with_payloads=False - ): + self, + key: str, + prefix: str, + fuzzy: bool = False, + num: int = 10, + with_scores: bool = False, + with_payloads: bool = False, + ) -> List[SuggestionParser]: """ Get a list of suggestions from the AutoCompleter, for a given prefix. 
diff --git a/redis/commands/search/field.py b/redis/commands/search/field.py index 6f31ce1fc2..76eb58c2d7 100644 --- a/redis/commands/search/field.py +++ b/redis/commands/search/field.py @@ -4,7 +4,6 @@ class Field: - NUMERIC = "NUMERIC" TEXT = "TEXT" WEIGHT = "WEIGHT" diff --git a/redis/commands/search/query.py b/redis/commands/search/query.py index 362dd6c72a..ba0481dd64 100644 --- a/redis/commands/search/query.py +++ b/redis/commands/search/query.py @@ -1,3 +1,6 @@ +from typing import List, Union + + class Query: """ Query is used to build complex queries that have more parameters than just @@ -8,52 +11,52 @@ class Query: i.e. `Query("foo").verbatim().filter(...)` etc. """ - def __init__(self, query_string): + def __init__(self, query_string: str) -> None: """ Create a new query object. The query string is set in the constructor, and other options have setter functions. """ - self._query_string = query_string - self._offset = 0 - self._num = 10 - self._no_content = False - self._no_stopwords = False - self._fields = None - self._verbatim = False - self._with_payloads = False - self._with_scores = False - self._scorer = False - self._filters = list() - self._ids = None - self._slop = -1 - self._timeout = None - self._in_order = False - self._sortby = None - self._return_fields = [] - self._summarize_fields = [] - self._highlight_fields = [] - self._language = None - self._expander = None - self._dialect = None - - def query_string(self): + self._query_string: str = query_string + self._offset: int = 0 + self._num: int = 10 + self._no_content: bool = False + self._no_stopwords: bool = False + self._fields: Union[List[str], None] = None + self._verbatim: bool = False + self._with_payloads: bool = False + self._with_scores: bool = False + self._scorer: bool = False + self._filters: List = list() + self._ids: Union[List[str], None] = None + self._slop: int = -1 + self._timeout: Union[float, None] = None + self._in_order: bool = False + self._sortby: 
Union[SortbyField, None] = None + self._return_fields: List = [] + self._summarize_fields: List = [] + self._highlight_fields: List = [] + self._language: Union[str, None] = None + self._expander: Union[str, None] = None + self._dialect: Union[int, None] = None + + def query_string(self) -> str: """Return the query string of this query only.""" return self._query_string - def limit_ids(self, *ids): + def limit_ids(self, *ids) -> "Query": """Limit the results to a specific set of pre-known document ids of any length.""" self._ids = ids return self - def return_fields(self, *fields): + def return_fields(self, *fields) -> "Query": """Add fields to return fields.""" self._return_fields += fields return self - def return_field(self, field, as_field=None): + def return_field(self, field: str, as_field: Union[str, None] = None) -> "Query": """Add field to return fields (Optional: add 'AS' name to the field).""" self._return_fields.append(field) @@ -61,12 +64,18 @@ def return_field(self, field, as_field=None): self._return_fields += ("AS", as_field) return self - def _mk_field_list(self, fields): + def _mk_field_list(self, fields: List[str]) -> List: if not fields: return [] return [fields] if isinstance(fields, str) else list(fields) - def summarize(self, fields=None, context_len=None, num_frags=None, sep=None): + def summarize( + self, + fields: Union[None, List] = None, + context_len: Union[None, int] = None, + num_frags: Union[None, int] = None, + sep: [Union, str] = None, + ) -> "Query": """ Return an abridged format of the field, containing only the segments of the field which contain the matching term(s). @@ -98,7 +107,9 @@ def summarize(self, fields=None, context_len=None, num_frags=None, sep=None): self._summarize_fields = args return self - def highlight(self, fields=None, tags=None): + def highlight( + self, fields: Union[List[str], None] = None, tags: List[str, str] = None + ) -> None: """ Apply specified markup to matched term(s) within the returned field(s). 
@@ -116,7 +127,7 @@ def highlight(self, fields=None, tags=None): self._highlight_fields = args return self - def language(self, language): + def language(self, language: str) -> "Query": """ Analyze the query as being in the specified language. @@ -125,19 +136,19 @@ def language(self, language): self._language = language return self - def slop(self, slop): + def slop(self, slop: int) -> "Query": """Allow a maximum of N intervening non matched terms between phrase terms (0 means exact phrase). """ self._slop = slop return self - def timeout(self, timeout): + def timeout(self, timeout: float) -> "Query": """overrides the timeout parameter of the module""" self._timeout = timeout return self - def in_order(self): + def in_order(self) -> "Query": """ Match only documents where the query terms appear in the same order in the document. @@ -146,7 +157,7 @@ def in_order(self): self._in_order = True return self - def scorer(self, scorer): + def scorer(self, scorer: str) -> "Query": """ Use a different scoring function to evaluate document relevance. Default is `TFIDF`. @@ -157,7 +168,7 @@ def scorer(self, scorer): self._scorer = scorer return self - def get_args(self): + def get_args(self) -> List[str]: """Format the redis arguments for this query and return them.""" args = [self._query_string] args += self._get_args_tags() @@ -165,7 +176,7 @@ def get_args(self): args += ["LIMIT", self._offset, self._num] return args - def _get_args_tags(self): + def _get_args_tags(self) -> List[str]: args = [] if self._no_content: args.append("NOCONTENT") @@ -216,7 +227,7 @@ def _get_args_tags(self): return args - def paging(self, offset, num): + def paging(self, offset: int, num: int) -> "Query": """ Set the paging for the query (defaults to 0..10). @@ -227,19 +238,19 @@ def paging(self, offset, num): self._num = num return self - def verbatim(self): + def verbatim(self) -> "Query": """Set the query to be verbatim, i.e. use no query expansion or stemming. 
""" self._verbatim = True return self - def no_content(self): + def no_content(self) -> "Query": """Set the query to only return ids and not the document content.""" self._no_content = True return self - def no_stopwords(self): + def no_stopwords(self) -> "Query": """ Prevent the query from being filtered for stopwords. Only useful in very big queries that you are certain contain @@ -248,17 +259,17 @@ def no_stopwords(self): self._no_stopwords = True return self - def with_payloads(self): + def with_payloads(self) -> "Query": """Ask the engine to return document payloads.""" self._with_payloads = True return self - def with_scores(self): + def with_scores(self) -> "Query": """Ask the engine to return document search scores.""" self._with_scores = True return self - def limit_fields(self, *fields): + def limit_fields(self, *fields: List[str]) -> "Query": """ Limit the search to specific TEXT fields only. @@ -268,7 +279,7 @@ def limit_fields(self, *fields): self._fields = fields return self - def add_filter(self, flt): + def add_filter(self, flt: "Filter") -> "Query": """ Add a numeric or geo filter to the query. **Currently only one of each filter is supported by the engine** @@ -280,7 +291,7 @@ def add_filter(self, flt): self._filters.append(flt) return self - def sort_by(self, field, asc=True): + def sort_by(self, field: str, asc: bool = True) -> "Query": """ Add a sortby field to the query. @@ -290,7 +301,7 @@ def sort_by(self, field, asc=True): self._sortby = SortbyField(field, asc) return self - def expander(self, expander): + def expander(self, expander: str) -> "Query": """ Add a expander field to the query. 
@@ -310,7 +321,7 @@ def dialect(self, dialect: int) -> "Query": class Filter: - def __init__(self, keyword, field, *args): + def __init__(self, keyword: str, field: str, *args: List[str]) -> None: self.args = [keyword, field] + list(args) @@ -318,7 +329,14 @@ class NumericFilter(Filter): INF = "+inf" NEG_INF = "-inf" - def __init__(self, field, minval, maxval, minExclusive=False, maxExclusive=False): + def __init__( + self, + field: str, + minval: Union[int, str], + maxval: Union[int, str], + minExclusive: bool = False, + maxExclusive: bool = False, + ) -> None: args = [ minval if not minExclusive else f"({minval}", maxval if not maxExclusive else f"({maxval}", @@ -333,10 +351,12 @@ class GeoFilter(Filter): FEET = "ft" MILES = "mi" - def __init__(self, field, lon, lat, radius, unit=KILOMETERS): + def __init__( + self, field: str, lon: float, lat: float, radius: float, unit: str = KILOMETERS + ) -> None: Filter.__init__(self, "GEOFILTER", field, lon, lat, radius, unit) class SortbyField: - def __init__(self, field, asc=True): + def __init__(self, field: str, asc=True) -> None: self.args = [field, "ASC" if asc else "DESC"] diff --git a/redis/commands/search/reducers.py b/redis/commands/search/reducers.py index 41ed11a238..8b60f23283 100644 --- a/redis/commands/search/reducers.py +++ b/redis/commands/search/reducers.py @@ -1,8 +1,12 @@ -from .aggregation import Reducer, SortDirection +from typing import Union + +from .aggregation import Asc, Desc, Reducer, SortDirection class FieldOnlyReducer(Reducer): - def __init__(self, field): + """See https://redis.io/docs/interact/search-and-query/search/aggregations/""" + + def __init__(self, field: str) -> None: super().__init__(field) self._field = field @@ -14,7 +18,7 @@ class count(Reducer): NAME = "COUNT" - def __init__(self): + def __init__(self) -> None: super().__init__() @@ -25,7 +29,7 @@ class sum(FieldOnlyReducer): NAME = "SUM" - def __init__(self, field): + def __init__(self, field: str) -> None: 
super().__init__(field) @@ -36,7 +40,7 @@ class min(FieldOnlyReducer): NAME = "MIN" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -47,7 +51,7 @@ class max(FieldOnlyReducer): NAME = "MAX" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -58,7 +62,7 @@ class avg(FieldOnlyReducer): NAME = "AVG" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -69,7 +73,7 @@ class tolist(FieldOnlyReducer): NAME = "TOLIST" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -81,7 +85,7 @@ class count_distinct(FieldOnlyReducer): NAME = "COUNT_DISTINCT" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -103,7 +107,7 @@ class quantile(Reducer): NAME = "QUANTILE" - def __init__(self, field, pct): + def __init__(self, field: str, pct: float) -> None: super().__init__(field, str(pct)) self._field = field @@ -115,7 +119,7 @@ class stddev(FieldOnlyReducer): NAME = "STDDEV" - def __init__(self, field): + def __init__(self, field: str) -> None: super().__init__(field) @@ -126,7 +130,7 @@ class first_value(Reducer): NAME = "FIRST_VALUE" - def __init__(self, field, *byfields): + def __init__(self, field: str, *byfields: Union[Asc, Desc]) -> None: """ Selects the first value of the given field within the group. 
@@ -166,7 +170,7 @@ class random_sample(Reducer): NAME = "RANDOM_SAMPLE" - def __init__(self, field, size): + def __init__(self, field: str, size: int) -> None: """ ### Parameter diff --git a/redis/commands/search/result.py b/redis/commands/search/result.py index 451bf89bb7..5b19e6faa4 100644 --- a/redis/commands/search/result.py +++ b/redis/commands/search/result.py @@ -69,5 +69,5 @@ def __init__( ) self.docs.append(doc) - def __repr__(self): + def __repr__(self) -> str: return f"Result{{{self.total} total, docs: {self.docs}}}" diff --git a/redis/commands/search/suggestion.py b/redis/commands/search/suggestion.py index 5d1eba64b8..02e90f384b 100644 --- a/redis/commands/search/suggestion.py +++ b/redis/commands/search/suggestion.py @@ -1,3 +1,5 @@ +from typing import Union + from ._util import to_string @@ -7,12 +9,14 @@ class Suggestion: autocomplete server """ - def __init__(self, string, score=1.0, payload=None): + def __init__( + self, string: str, score=1.0, payload: Union(str, None) = None + ) -> None: self.string = to_string(string) self.payload = to_string(payload) self.score = score - def __repr__(self): + def __repr__(self) -> str: return self.string @@ -23,7 +27,7 @@ class SuggestionParser: the return value depending on what objects were requested """ - def __init__(self, with_scores, with_payloads, ret): + def __init__(self, with_scores: bool, with_payloads, ret) -> None: self.with_scores = with_scores self.with_payloads = with_payloads diff --git a/redis/connection.py b/redis/connection.py index 45ecd2a370..8b64152cb5 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -9,7 +9,7 @@ from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Optional, Type, Union +from typing import Any, Callable, List, Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser @@ -55,7 +55,7 @@ class 
HiredisRespSerializer: - def pack(self, *args): + def pack(self, *args: List): """Pack a series of arguments into the Redis protocol""" output = [] @@ -128,27 +128,27 @@ class AbstractConnection: def __init__( self, - db=0, - password=None, - socket_timeout=None, - socket_connect_timeout=None, - retry_on_timeout=False, + db: int = 0, + password: Union[str, None] = None, + socket_timeout: Union[float, None] = None, + socket_connect_timeout: Union[float, None] = None, + retry_on_timeout: bool = False, retry_on_error=SENTINEL, - encoding="utf-8", - encoding_errors="strict", - decode_responses=False, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, parser_class=DefaultParser, - socket_read_size=65536, - health_check_interval=0, - client_name=None, - lib_name="redis-py", - lib_version=get_lib_version(), - username=None, - retry=None, - redis_connect_func=None, + socket_read_size: int = 65536, + health_check_interval: int = 0, + client_name: Union[str, None] = None, + lib_name: Union[str, None] = "redis-py", + lib_version: float = get_lib_version(), + username: Union[str, None] = None, + retry: Union[Any, None] = None, + redis_connect_func: Union[None, Callable["..."]] = None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - command_packer=None, + command_packer: (Any | HiredisRespSerializer | PythonRespSerializer) = None, ): """ Initialize a new Connection. @@ -970,7 +970,10 @@ class initializer. 
In the case of conflicting arguments, querystring return cls(**kwargs) def __init__( - self, connection_class=Connection, max_connections=None, **connection_kwargs + self, + connection_class=Connection, + max_connections: Union[int, None] = None, + **connection_kwargs, ): max_connections = max_connections or 2**31 if not isinstance(max_connections, int) or max_connections < 0: @@ -991,13 +994,13 @@ def __init__( self._fork_lock = threading.Lock() self.reset() - def __repr__(self): + def __repr__(self) -> str: return ( f"{type(self).__name__}" f"<{repr(self.connection_class(**self.connection_kwargs))}>" ) - def reset(self): + def reset(self) -> None: self._lock = threading.Lock() self._created_connections = 0 self._available_connections = [] @@ -1014,7 +1017,7 @@ def reset(self): # reset() and they will immediately release _fork_lock and continue on. self.pid = os.getpid() - def _checkpid(self): + def _checkpid(self) -> None: # _checkpid() attempts to keep ConnectionPool fork-safe on modern # systems. this is called by all ConnectionPool methods that # manipulate the pool's state such as get_connection() and release(). 
@@ -1061,7 +1064,7 @@ def _checkpid(self): finally: self._fork_lock.release() - def get_connection(self, command_name, *keys, **options): + def get_connection(self, command_name: str, *keys, **options) -> type[Connection]: "Get a connection from the pool" self._checkpid() with self._lock: @@ -1094,7 +1097,7 @@ def get_connection(self, command_name, *keys, **options): return connection - def get_encoder(self): + def get_encoder(self) -> Encoder: "Return an encoder based on encoding settings" kwargs = self.connection_kwargs return Encoder( @@ -1103,14 +1106,14 @@ def get_encoder(self): decode_responses=kwargs.get("decode_responses", False), ) - def make_connection(self): + def make_connection(self) -> type[Connection]: "Create a new connection" if self._created_connections >= self.max_connections: raise ConnectionError("Too many connections") self._created_connections += 1 return self.connection_class(**self.connection_kwargs) - def release(self, connection): + def release(self, connection: type[Connection]) -> None: "Releases the connection back to the pool" self._checkpid() with self._lock: @@ -1131,10 +1134,10 @@ def release(self, connection): connection.disconnect() return - def owns_connection(self, connection): + def owns_connection(self, connection: type[Connection]) -> int: return connection.pid == self.pid - def disconnect(self, inuse_connections=True): + def disconnect(self, inuse_connections: bool = True) -> None: """ Disconnects connections in the pool @@ -1208,7 +1211,6 @@ def __init__( queue_class=LifoQueue, **connection_kwargs, ): - self.queue_class = queue_class self.timeout = timeout super().__init__( diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py index 76ec2bbd26..17ed6822ac 100644 --- a/tests/test_asyncio/test_cwe_404.py +++ b/tests/test_asyncio/test_cwe_404.py @@ -99,17 +99,14 @@ async def pipe( @pytest.mark.onlynoncluster @pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2]) async def 
test_standalone(delay, master_host): - # create a tcp socket proxy that relays data to Redis and back, # inserting 0.1 seconds of delay async with DelayProxy(addr=("127.0.0.1", 5380), redis_addr=master_host) as dp: - for b in [True, False]: # note that we connect to proxy, rather than to Redis directly async with Redis( host="127.0.0.1", port=5380, single_connection_client=b ) as r: - await r.set("foo", "foo") await r.set("bar", "bar") @@ -189,7 +186,6 @@ async def op(pipe): @pytest.mark.onlycluster async def test_cluster(master_host): - delay = 0.1 cluster_port = 16379 remap_base = 7372 diff --git a/tests/test_asyncio/test_json.py b/tests/test_asyncio/test_json.py index ed651cd903..a35bd4795f 100644 --- a/tests/test_asyncio/test_json.py +++ b/tests/test_asyncio/test_json.py @@ -490,7 +490,6 @@ async def test_json_mget_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_numby_commands_dollar(decoded_r: redis.Redis): - # Test NUMINCRBY await decoded_r.json().set( "doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]} @@ -546,7 +545,6 @@ async def test_numby_commands_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_strappend_dollar(decoded_r: redis.Redis): - await decoded_r.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} ) @@ -578,7 +576,6 @@ async def test_strappend_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_strlen_dollar(decoded_r: redis.Redis): - # Test multi await decoded_r.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} @@ -713,7 +710,6 @@ async def test_arrinsert_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_arrlen_dollar(decoded_r: redis.Redis): - await decoded_r.json().set( "doc1", "$", @@ -802,7 +798,6 @@ async def test_arrpop_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_arrtrim_dollar(decoded_r: redis.Redis): - await decoded_r.json().set( "doc1", "$", @@ -960,7 
+955,6 @@ async def test_type_dollar(decoded_r: redis.Redis): @pytest.mark.redismod async def test_clear_dollar(decoded_r: redis.Redis): - await decoded_r.json().set( "doc1", "$", diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index 75484a2791..c052eae2a0 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -234,7 +234,6 @@ class TestLockClassSelection: def test_lock_class_argument(self, r): class MyLock: def __init__(self, *args, **kwargs): - pass lock = r.lock("foo", lock_class=MyLock) diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py index 0fa1204750..3d271bf1d0 100644 --- a/tests/test_asyncio/test_pipeline.py +++ b/tests/test_asyncio/test_pipeline.py @@ -396,7 +396,6 @@ async def test_pipeline_get(self, r): @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.0.0") async def test_pipeline_discard(self, r): - # empty pipeline should raise an error async with r.pipeline() as pipe: pipe.set("key", "someval") diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 8fef34d83d..19d4b1c650 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -121,7 +121,6 @@ async def test_pattern_subscribe_unsubscribe(self, pubsub): async def _test_resubscribe_on_reconnection( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): - for key in keys: assert await sub_func(key) is None @@ -163,7 +162,6 @@ async def test_resubscribe_to_patterns_on_reconnection(self, pubsub): async def _test_subscribed_property( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): - assert p.subscribed is False await sub_func(keys[0]) # we're now subscribed even though we haven't processed the diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index e46de39c70..efc5bf549c 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -77,7 +77,6 @@ async def 
createIndex(decoded_r, num_docs=100, definition=None): r = csv.reader(bzfp, delimiter=";") for n, line in enumerate(r): - play, chapter, _, text = line[1], line[2], line[4], line[5] key = f"{play}:{chapter}".lower() @@ -163,10 +162,8 @@ async def test_client(decoded_r: redis.Redis): ) ).total both_total = ( - await ( - decoded_r.ft().search( - Query("henry").no_content().limit_fields("play", "txt") - ) + await decoded_r.ft().search( + Query("henry").no_content().limit_fields("play", "txt") ) ).total assert 129 == txt_total @@ -370,18 +367,14 @@ async def test_stopwords(decoded_r: redis.Redis): @pytest.mark.redismod async def test_filters(decoded_r: redis.Redis): - await ( - decoded_r.ft().create_index( - (TextField("txt"), NumericField("num"), GeoField("loc")) - ) + await decoded_r.ft().create_index( + (TextField("txt"), NumericField("num"), GeoField("loc")) ) - await ( - decoded_r.hset( - "doc1", mapping={"txt": "foo bar", "num": 3.141, "loc": "-0.441,51.458"} - ) + await decoded_r.hset( + "doc1", mapping={"txt": "foo bar", "num": 3.141, "loc": "-0.441,51.458"} ) - await ( - decoded_r.hset("doc2", mapping={"txt": "foo baz", "num": 2, "loc": "-0.1,51.2"}) + await decoded_r.hset( + "doc2", mapping={"txt": "foo baz", "num": 2, "loc": "-0.1,51.2"} ) await waitForIndex(decoded_r, "idx") @@ -432,10 +425,8 @@ async def test_filters(decoded_r: redis.Redis): @pytest.mark.redismod async def test_sort_by(decoded_r: redis.Redis): - await ( - decoded_r.ft().create_index( - (TextField("txt"), NumericField("num", sortable=True)) - ) + await decoded_r.ft().create_index( + (TextField("txt"), NumericField("num", sortable=True)) ) await decoded_r.hset("doc1", mapping={"txt": "foo bar", "num": 1}) await decoded_r.hset("doc2", mapping={"txt": "foo baz", "num": 2}) @@ -488,8 +479,8 @@ async def test_drop_index(decoded_r: redis.Redis): @pytest.mark.redismod async def test_example(decoded_r: redis.Redis): # Creating the index definition and schema - await ( - 
decoded_r.ft().create_index((TextField("title", weight=5.0), TextField("body"))) + await decoded_r.ft().create_index( + (TextField("title", weight=5.0), TextField("body")) ) # Indexing a document @@ -550,8 +541,8 @@ async def test_auto_complete(decoded_r: redis.Redis): await decoded_r.ft().sugadd("ac", Suggestion("pay2", payload="pl2")) await decoded_r.ft().sugadd("ac", Suggestion("pay3", payload="pl3")) - sugs = await ( - decoded_r.ft().sugget("ac", "pay", with_payloads=True, with_scores=True) + sugs = await decoded_r.ft().sugget( + "ac", "pay", with_payloads=True, with_scores=True ) assert 3 == len(sugs) for sug in sugs: @@ -639,8 +630,8 @@ async def test_no_index(decoded_r: redis.Redis): @pytest.mark.redismod async def test_explain(decoded_r: redis.Redis): - await ( - decoded_r.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) + await decoded_r.ft().create_index( + (TextField("f1"), TextField("f2"), TextField("f3")) ) res = await decoded_r.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val") assert res @@ -903,10 +894,8 @@ async def test_alter_schema_add(decoded_r: redis.Redis): async def test_spell_check(decoded_r: redis.Redis): await decoded_r.ft().create_index((TextField("f1"), TextField("f2"))) - await ( - decoded_r.hset( - "doc1", mapping={"f1": "some valid content", "f2": "this is sample text"} - ) + await decoded_r.hset( + "doc1", mapping={"f1": "some valid content", "f2": "this is sample text"} ) await decoded_r.hset("doc2", mapping={"f1": "very important", "f2": "lorem ipsum"}) await waitForIndex(decoded_r, "idx") @@ -1042,8 +1031,8 @@ async def test_scorer(decoded_r: redis.Redis): assert 1.0 == res.docs[0].score res = await decoded_r.ft().search(Query("quick").scorer("TFIDF").with_scores()) assert 1.0 == res.docs[0].score - res = await ( - decoded_r.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) + res = await decoded_r.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores() ) assert 0.14285714285714285 == 
res.docs[0].score res = await decoded_r.ft().search(Query("quick").scorer("BM25").with_scores()) diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py index a2d52f17b7..25bd7730da 100644 --- a/tests/test_asyncio/test_sentinel.py +++ b/tests/test_asyncio/test_sentinel.py @@ -72,7 +72,6 @@ def client(self, host, port, **kwargs): @pytest_asyncio.fixture() async def cluster(master_ip): - cluster = SentinelTestCluster(ip=master_ip) saved_Redis = redis.asyncio.sentinel.Redis redis.asyncio.sentinel.Redis = cluster.client diff --git a/tests/test_asyncio/test_timeseries.py b/tests/test_asyncio/test_timeseries.py index 48ffdfd889..91c15c3db2 100644 --- a/tests/test_asyncio/test_timeseries.py +++ b/tests/test_asyncio/test_timeseries.py @@ -108,7 +108,6 @@ async def test_add(decoded_r: redis.Redis): @pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") async def test_add_duplicate_policy(r: redis.Redis): - # Test for duplicate policy BLOCK assert 1 == await r.ts().add("time-serie-add-ooo-block", 1, 5.0) with pytest.raises(Exception): diff --git a/tests/test_commands.py b/tests/test_commands.py index b538dc3038..6660c2c6b0 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -516,7 +516,6 @@ def test_client_trackinginfo(self, r): @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() def test_client_tracking(self, r, r2): - # simple case assert r.client_tracking_on() assert r.client_tracking_off() @@ -5011,7 +5010,6 @@ def test_module_loadex(self, r: redis.Redis): @skip_if_server_version_lt("2.6.0") def test_restore(self, r): - # standard restore key = "foo" r.set(key, "bar") diff --git a/tests/test_graph_utils/test_edge.py b/tests/test_graph_utils/test_edge.py index 581ebfab5d..d2a1e3f39e 100644 --- a/tests/test_graph_utils/test_edge.py +++ b/tests/test_graph_utils/test_edge.py @@ -4,7 +4,6 @@ @pytest.mark.redismod def test_init(): - with pytest.raises(AssertionError): edge.Edge(None, None, None) 
edge.Edge(node.Node(), None, None) diff --git a/tests/test_json.py b/tests/test_json.py index be347f6677..73d72b8cc9 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -464,7 +464,6 @@ def test_json_mget_dollar(client): def test_numby_commands_dollar(client): - # Test NUMINCRBY client.json().set("doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]}) # Test multi @@ -508,7 +507,6 @@ def test_numby_commands_dollar(client): def test_strappend_dollar(client): - client.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} ) @@ -539,7 +537,6 @@ def test_strappend_dollar(client): def test_strlen_dollar(client): - # Test multi client.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} @@ -672,7 +669,6 @@ def test_arrinsert_dollar(client): def test_arrlen_dollar(client): - client.json().set( "doc1", "$", @@ -762,7 +758,6 @@ def test_arrpop_dollar(client): def test_arrtrim_dollar(client): - client.json().set( "doc1", "$", @@ -1015,7 +1010,6 @@ def test_toggle_dollar(client): def test_resp_dollar(client): - data = { "L1": { "a": { @@ -1244,7 +1238,6 @@ def test_resp_dollar(client): def test_arrindex_dollar(client): - client.json().set( "store", "$", diff --git a/tests/test_lock.py b/tests/test_lock.py index b34f7f0159..72af87fa81 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -247,7 +247,6 @@ class TestLockClassSelection: def test_lock_class_argument(self, r): class MyLock: def __init__(self, *args, **kwargs): - pass lock = r.lock("foo", lock_class=MyLock) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index e64a763bae..7f10fcad4f 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -390,7 +390,6 @@ def test_pipeline_with_bitfield(self, r): @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.0.0") def test_pipeline_discard(self, r): - # empty pipeline should raise an error with r.pipeline() as pipe: pipe.set("key", "someval") diff --git 
a/tests/test_pubsub.py b/tests/test_pubsub.py index ba097e3194..fb46772af3 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -152,7 +152,6 @@ def test_shard_channel_subscribe_unsubscribe_cluster(self, r): def _test_resubscribe_on_reconnection( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): - for key in keys: assert sub_func(key) is None @@ -201,7 +200,6 @@ def test_resubscribe_to_shard_channels_on_reconnection(self, r): def _test_subscribed_property( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): - assert p.subscribed is False sub_func(keys[0]) # we're now subscribed even though we haven't processed the diff --git a/tests/test_search.py b/tests/test_search.py index 7612332470..9bbfc3c696 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -86,7 +86,6 @@ def createIndex(client, num_docs=100, definition=None): r = csv.reader(bzfp, delimiter=";") for n, line in enumerate(r): - play, chapter, _, text = line[1], line[2], line[4], line[5] key = f"{play}:{chapter}".lower() @@ -820,7 +819,6 @@ def test_spell_check(client): waitForIndex(client, getattr(client.ft(), "index_name", "idx")) if is_resp2_connection(client): - # test spellcheck res = client.ft().spellcheck("impornant") assert "important" == res["impornant"][0]["suggestion"] @@ -2100,7 +2098,6 @@ def test_numeric_params(client): @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") def test_geo_params(client): - client.ft().create_index((GeoField("g"))) client.hset("doc1", mapping={"g": "29.69465, 34.95126"}) client.hset("doc2", mapping={"g": "29.69350, 34.94737"}) diff --git a/tests/test_timeseries.py b/tests/test_timeseries.py index 4ab86cd56e..6b59967f3c 100644 --- a/tests/test_timeseries.py +++ b/tests/test_timeseries.py @@ -104,7 +104,6 @@ def test_add(client): @skip_ifmodversion_lt("1.4.0", "timeseries") def test_add_duplicate_policy(client): - # Test for duplicate policy BLOCK assert 1 == client.ts().add("time-serie-add-ooo-block", 1, 5.0) with 
pytest.raises(Exception): From 14ec59ef2eceede65def68ef5e422e3b153a2660 Mon Sep 17 00:00:00 2001 From: "Chayim I. Kirshen" Date: Wed, 20 Sep 2023 15:58:16 +0300 Subject: [PATCH 02/19] fixed callable[T] --- redis/client.py | 10 +++++++--- redis/commands/search/aggregation.py | 2 +- redis/commands/search/query.py | 2 +- redis/commands/search/suggestion.py | 2 +- redis/connection.py | 4 ++-- 5 files changed, 12 insertions(+), 8 deletions(-) diff --git a/redis/client.py b/redis/client.py index 6ef71f083c..0d0add744a 100755 --- a/redis/client.py +++ b/redis/client.py @@ -367,7 +367,9 @@ def pipeline(self, transaction=True, shard_hint=None) -> "Pipeline": self.connection_pool, self.response_callbacks, transaction, shard_hint ) - def transaction(self, func: Callable["..."], *watches, **kwargs) -> None: + def transaction( + self, func: Callable[["Pipeline"], None], *watches, **kwargs + ) -> None: """ Convenience method for executing the callable `func` as a transaction while watching all keys specified in `watches`. 
The 'func' callable @@ -651,7 +653,7 @@ def __init__( shard_hint=None, ignore_subscribe_messages: bool = False, encoder: Union[None, "Encoder"] = None, - push_handler_func: Union[None, Callable["..."]] = None, + push_handler_func: Union[None, Callable[[str], None]] = None, ): self.connection_pool = connection_pool self.shard_hint = shard_hint @@ -1124,7 +1126,9 @@ def __init__( pubsub, sleep_time: float, daemon: bool = False, - exception_handler: Union[Callable["..."], None] = None, + exception_handler: Union[ + Callable[[Exception, "PubSub", "PubSubWorkerThread"], None], None + ] = None, ): super().__init__() self.daemon = daemon diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index aa42148470..50d18f476a 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -266,7 +266,7 @@ def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest": self._aggregateplan.extend(ret) return self - def filter(self, expressions: Union(str, List[str])) -> "AggregateRequest": + def filter(self, expressions: Union[str, List[str]]) -> "AggregateRequest": """ Specify filter for post-query results using predicates relating to values in the result set. diff --git a/redis/commands/search/query.py b/redis/commands/search/query.py index ba0481dd64..973ccb6eb4 100644 --- a/redis/commands/search/query.py +++ b/redis/commands/search/query.py @@ -108,7 +108,7 @@ def summarize( return self def highlight( - self, fields: Union[List[str], None] = None, tags: List[str, str] = None + self, fields: Union[List[str], None] = None, tags: List[str] = None ) -> None: """ Apply specified markup to matched term(s) within the returned field(s). 
diff --git a/redis/commands/search/suggestion.py b/redis/commands/search/suggestion.py index 02e90f384b..ed3f533fab 100644 --- a/redis/commands/search/suggestion.py +++ b/redis/commands/search/suggestion.py @@ -10,7 +10,7 @@ class Suggestion: """ def __init__( - self, string: str, score=1.0, payload: Union(str, None) = None + self, string: str, score=1.0, payload: Union[str, None] = None ) -> None: self.string = to_string(string) self.payload = to_string(payload) diff --git a/redis/connection.py b/redis/connection.py index 8b64152cb5..44c92ae8f8 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -145,10 +145,10 @@ def __init__( lib_version: float = get_lib_version(), username: Union[str, None] = None, retry: Union[Any, None] = None, - redis_connect_func: Union[None, Callable["..."]] = None, + redis_connect_func: Union[None, Callable[[], None]] = None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - command_packer: (Any | HiredisRespSerializer | PythonRespSerializer) = None, + command_packer: Union[Callable[[], None], None] = None, ): """ Initialize a new Connection. From 809687e4b3fde3691e23409e4e5df5f3102e04da Mon Sep 17 00:00:00 2001 From: "Chayim I. Kirshen" Date: Wed, 20 Sep 2023 16:59:03 +0300 Subject: [PATCH 03/19] con --- redis/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/connection.py b/redis/connection.py index 44c92ae8f8..a57a982c1d 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1064,7 +1064,7 @@ def _checkpid(self) -> None: finally: self._fork_lock.release() - def get_connection(self, command_name: str, *keys, **options) -> type[Connection]: + def get_connection(self, command_name: str, *keys, **options) -> "Connection": "Get a connection from the pool" self._checkpid() with self._lock: From 12673c6ac97dd6de9facce6e1b6214b93a98c236 Mon Sep 17 00:00:00 2001 From: "Chayim I. 
Kirshen" Date: Wed, 20 Sep 2023 17:03:53 +0300 Subject: [PATCH 04/19] more connections --- redis/connection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index a57a982c1d..32fad09c40 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1106,14 +1106,14 @@ def get_encoder(self) -> Encoder: decode_responses=kwargs.get("decode_responses", False), ) - def make_connection(self) -> type[Connection]: + def make_connection(self) -> "Connection": "Create a new connection" if self._created_connections >= self.max_connections: raise ConnectionError("Too many connections") self._created_connections += 1 return self.connection_class(**self.connection_kwargs) - def release(self, connection: type[Connection]) -> None: + def release(self, connection: "Connection") -> None: "Releases the connection back to the pool" self._checkpid() with self._lock: @@ -1134,7 +1134,7 @@ def release(self, connection: type[Connection]) -> None: connection.disconnect() return - def owns_connection(self, connection: type[Connection]) -> int: + def owns_connection(self, connection: "Connection") -> int: return connection.pid == self.pid def disconnect(self, inuse_connections: bool = True) -> None: From fc851c0ad2db1a5394e2cad6e8e47b3ebbf330a9 Mon Sep 17 00:00:00 2001 From: "Chayim I.
Kirshen" Date: Thu, 21 Sep 2023 10:02:06 +0300 Subject: [PATCH 05/19] restoring dev reqs --- dev_requirements.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index 6bd418bf46..3715599af0 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -15,5 +15,3 @@ ujson>=4.2.0 wheel>=0.30.0 urllib3<2 uvloop -types-requests -types-pyOpenSSL \ No newline at end of file From f72b412fab601198ce4897b17db8392064856c83 Mon Sep 17 00:00:00 2001 From: Chayim Date: Thu, 21 Sep 2023 11:43:26 +0300 Subject: [PATCH 06/19] Update redis/commands/search/suggestion.py Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> --- redis/commands/search/suggestion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/commands/search/suggestion.py b/redis/commands/search/suggestion.py index ed3f533fab..d71921f38b 100644 --- a/redis/commands/search/suggestion.py +++ b/redis/commands/search/suggestion.py @@ -10,7 +10,7 @@ class Suggestion: """ def __init__( - self, string: str, score=1.0, payload: Union[str, None] = None + self, string: str, score: float = 1.0, payload: Optional[str] = None ) -> None: self.string = to_string(string) self.payload = to_string(payload) From f3b42a815e81b18d141dc52255af664e085411b4 Mon Sep 17 00:00:00 2001 From: Chayim Date: Thu, 21 Sep 2023 11:43:34 +0300 Subject: [PATCH 07/19] Update redis/commands/core.py Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> --- redis/commands/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/commands/core.py b/redis/commands/core.py index 49802576ec..e73553e47e 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -403,7 +403,7 @@ class ManagementCommands(CommandsProtocol): Redis management commands """ - def auth(self, password: str, username=None, **kwargs): + def auth(self, password: str, username: Optional[str] = None, **kwargs): """ Authenticates the user. 
If you do not pass username, Redis will try to authenticate for the "default" user. If you do pass username, it will From 50fd175591fca1eccb95170b333635284f7322a0 Mon Sep 17 00:00:00 2001 From: Chayim Date: Thu, 21 Sep 2023 11:43:40 +0300 Subject: [PATCH 08/19] Update redis/commands/search/suggestion.py Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> --- redis/commands/search/suggestion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/commands/search/suggestion.py b/redis/commands/search/suggestion.py index d71921f38b..c59eae1944 100644 --- a/redis/commands/search/suggestion.py +++ b/redis/commands/search/suggestion.py @@ -27,7 +27,7 @@ class SuggestionParser: the return value depending on what objects were requested """ - def __init__(self, with_scores: bool, with_payloads, ret) -> None: + def __init__(self, with_scores: bool, with_payloads: bool, ret) -> None: self.with_scores = with_scores self.with_payloads = with_payloads From aa883ffe9bcd3d43c188d4c395f42664807b2e0c Mon Sep 17 00:00:00 2001 From: Chayim Date: Thu, 21 Sep 2023 11:43:51 +0300 Subject: [PATCH 09/19] Update redis/commands/search/commands.py Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> --- redis/commands/search/commands.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index c00abdde98..b2ca3e8d4d 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -554,7 +554,7 @@ def aggregate( AGGREGATE_CMD, raw, query=query, has_cursor=has_cursor ) - def _get_aggregate_result(self, raw: List, query: str, has_cursor: bool): + def _get_aggregate_result(self, raw: List, query: Union[str, Query, AggregateRequest], has_cursor: bool): if has_cursor: if isinstance(query, Cursor): query.cid = raw[1] From c2084f7048ed1eefa97f064fe091459dc91f700e Mon Sep 17 00:00:00 2001 From: Chayim Date: Thu, 21 Sep 2023 11:43:57 +0300 
Subject: [PATCH 10/19] Update redis/client.py Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> --- redis/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/client.py b/redis/client.py index 0d0add744a..0b7cfad43c 100755 --- a/redis/client.py +++ b/redis/client.py @@ -50,7 +50,7 @@ class CaseInsensitiveDict(dict): "Case insensitive dict implementation. Assumes string keys only." - def __init__(self, data) -> None: + def __init__(self, data: Dict[str, str]) -> None: for k, v in data.items(): self[k.upper()] = v From 6257558150c4d891d60a4c35dbef2dde921fbb4c Mon Sep 17 00:00:00 2001 From: Chayim Date: Thu, 21 Sep 2023 11:44:03 +0300 Subject: [PATCH 11/19] Update redis/commands/search/suggestion.py Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> --- redis/commands/search/suggestion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/commands/search/suggestion.py b/redis/commands/search/suggestion.py index c59eae1944..499c8d917e 100644 --- a/redis/commands/search/suggestion.py +++ b/redis/commands/search/suggestion.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Optional from ._util import to_string From 083ee8298ccbee99ec2d6c4487825464692eaced Mon Sep 17 00:00:00 2001 From: Chayim Date: Thu, 21 Sep 2023 11:44:12 +0300 Subject: [PATCH 12/19] Update redis/connection.py Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> --- redis/connection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 32fad09c40..ee17d52680 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -129,9 +129,9 @@ class AbstractConnection: def __init__( self, db: int = 0, - password: Union[str, None] = None, - socket_timeout: Union[float, None] = None, - socket_connect_timeout: Union[float, None] = None, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + 
socket_connect_timeout: Optional[float] = None, retry_on_timeout: bool = False, retry_on_error=SENTINEL, encoding: str = "utf-8", From 58087401106e24f5d48bd10260d930172d966007 Mon Sep 17 00:00:00 2001 From: Chayim Date: Thu, 21 Sep 2023 11:44:20 +0300 Subject: [PATCH 13/19] Update redis/connection.py Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> --- redis/connection.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index ee17d52680..a061b9f2ce 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -140,12 +140,12 @@ def __init__( parser_class=DefaultParser, socket_read_size: int = 65536, health_check_interval: int = 0, - client_name: Union[str, None] = None, - lib_name: Union[str, None] = "redis-py", - lib_version: float = get_lib_version(), - username: Union[str, None] = None, + client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), + username: Optional[str] = None, retry: Union[Any, None] = None, - redis_connect_func: Union[None, Callable[[], None]] = None, + redis_connect_func: Optional[Callable[[], None]] = None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, command_packer: Union[Callable[[], None], None] = None, From 0f40cdfa93911cf7430544c88c61ccd0390151a5 Mon Sep 17 00:00:00 2001 From: Chayim Date: Thu, 21 Sep 2023 11:44:27 +0300 Subject: [PATCH 14/19] Update redis/connection.py Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> --- redis/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/connection.py b/redis/connection.py index a061b9f2ce..61e47290a5 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -972,7 +972,7 @@ class initializer. 
In the case of conflicting arguments, querystring def __init__( self, connection_class=Connection, - max_connections: Union[int, None] = None, + max_connections: Optional[int] = None, **connection_kwargs, ): max_connections = max_connections or 2**31 From b3b47cfb28204912f0f3d27a000a7db479952640 Mon Sep 17 00:00:00 2001 From: Chayim Date: Thu, 21 Sep 2023 11:44:33 +0300 Subject: [PATCH 15/19] Update redis/connection.py Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> --- redis/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/connection.py b/redis/connection.py index 61e47290a5..ae52c27d1e 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -148,7 +148,7 @@ def __init__( redis_connect_func: Optional[Callable[[], None]] = None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - command_packer: Union[Callable[[], None], None] = None, + command_packer: Optional[Callable[[], None]] = None, ): """ Initialize a new Connection. 
From 5bcae5aa376399d8f51bac8c7042ed682f2e97a7 Mon Sep 17 00:00:00 2001 From: Chayim Date: Thu, 21 Sep 2023 12:35:00 +0300 Subject: [PATCH 16/19] Update redis/client.py Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> --- redis/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/client.py b/redis/client.py index 0d9eea1ce6..34bb4ff9c2 100755 --- a/redis/client.py +++ b/redis/client.py @@ -94,7 +94,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): """ @classmethod - def from_url(cls, url, **kwargs) -> None: + def from_url(cls, url: str, **kwargs) -> None: """ Return a Redis client object configured from the given URL From b3efe2ca504fd284605f5cbadc8a0df387c80334 Mon Sep 17 00:00:00 2001 From: Chayim Date: Thu, 21 Sep 2023 12:35:16 +0300 Subject: [PATCH 17/19] Update redis/client.py Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> --- redis/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/client.py b/redis/client.py index 34bb4ff9c2..a938d07098 100755 --- a/redis/client.py +++ b/redis/client.py @@ -652,7 +652,7 @@ def __init__( connection_pool, shard_hint=None, ignore_subscribe_messages: bool = False, - encoder: Union[None, "Encoder"] = None, + encoder: Optional["Encoder"] = None, push_handler_func: Union[None, Callable[[str], None]] = None, ): self.connection_pool = connection_pool From 745f8d0ba9cf4b9b7b4f723f29fe418768059f71 Mon Sep 17 00:00:00 2001 From: Chayim Date: Thu, 21 Sep 2023 12:36:35 +0300 Subject: [PATCH 18/19] Apply suggestions from code review Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> --- redis/client.py | 10 ++++++---- redis/commands/search/commands.py | 2 +- redis/commands/search/query.py | 30 +++++++++++++++--------------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/redis/client.py b/redis/client.py index a938d07098..be872d9600 100755 --- a/redis/client.py +++ b/redis/client.py @@ 
-328,7 +328,7 @@ def set_retry(self, retry: "Retry") -> None: self.get_connection_kwargs().update({"retry": retry}) self.connection_pool.set_retry(retry) - def set_response_callback(self, command, callback) -> None: + def set_response_callback(self, command: str, callback: Callable) -> None: """Set a custom Response Callback""" self.response_callbacks[command] = callback @@ -394,10 +394,10 @@ def transaction( def lock( self, name: str, - timeout: Union[None, float] = None, + timeout: Optional[float] = None, sleep: float = 0.1, blocking: bool = True, - blocking_timeout: Union[None, float] = None, + blocking_timeout: Optional[float] = None, lock_class: Union[None, Any] = None, thread_local: bool = True, ): @@ -1099,7 +1099,9 @@ def handle_message(self, response, ignore_subscribe_messages=False): return message def run_in_thread( - self, sleep_time=0, daemon=False, exception_handler=None + self, sleep_time: int = 0, + daemon: bool = False, + exception_handler: Optional[Callable] = None, ) -> "PubSubWorkerThread": for channel, handler in self.channels.items(): if handler is None: diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index b2ca3e8d4d..a5f66d4164 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -328,7 +328,7 @@ def add_document( payload: bool = None, replace: bool = False, partial: bool = False, - language: Union[str, None] = None, + language: Optional[str] = None, no_create: str = False, **fields: List[str], ): diff --git a/redis/commands/search/query.py b/redis/commands/search/query.py index 973ccb6eb4..3ae3208587 100644 --- a/redis/commands/search/query.py +++ b/redis/commands/search/query.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Optional, Union class Query: @@ -23,23 +23,23 @@ def __init__(self, query_string: str) -> None: self._num: int = 10 self._no_content: bool = False self._no_stopwords: bool = False - self._fields: Union[List[str], None] = 
None + self._fields: Optional[List[str]] = None self._verbatim: bool = False self._with_payloads: bool = False self._with_scores: bool = False - self._scorer: bool = False + self._scorer: Optional[str] = None self._filters: List = list() - self._ids: Union[List[str], None] = None + self._ids: Optional[List[str]] = None self._slop: int = -1 - self._timeout: Union[float, None] = None + self._timeout: Optional[float] = None self._in_order: bool = False - self._sortby: Union[SortbyField, None] = None + self._sortby: Optional[SortbyField] = None self._return_fields: List = [] self._summarize_fields: List = [] self._highlight_fields: List = [] - self._language: Union[str, None] = None - self._expander: Union[str, None] = None - self._dialect: Union[int, None] = None + self._language: Optional[str] = None + self._expander: Optional[str] = None + self._dialect: Optional[int] = None def query_string(self) -> str: """Return the query string of this query only.""" @@ -56,7 +56,7 @@ def return_fields(self, *fields) -> "Query": self._return_fields += fields return self - def return_field(self, field: str, as_field: Union[str, None] = None) -> "Query": + def return_field(self, field: str, as_field: Optional[str] = None) -> "Query": """Add field to return fields (Optional: add 'AS' name to the field).""" self._return_fields.append(field) @@ -71,10 +71,10 @@ def _mk_field_list(self, fields: List[str]) -> List: def summarize( self, - fields: Union[None, List] = None, - context_len: Union[None, int] = None, - num_frags: Union[None, int] = None, - sep: [Union, str] = None, + fields: Optional[List] = None, + context_len: Optional[int] = None, + num_frags: Optional[int] = None, + sep: Optional[str] = None, ) -> "Query": """ Return an abridged format of the field, containing only the segments of @@ -108,7 +108,7 @@ def summarize( return self def highlight( - self, fields: Union[List[str], None] = None, tags: List[str] = None + self, fields: Optional[List[str]] = None, tags: Optional 
[List[str]] = None ) -> None: """ Apply specified markup to matched term(s) within the returned field(s). From c080909b5d027d6e527d2784611ccf206963c770 Mon Sep 17 00:00:00 2001 From: "Chayim I. Kirshen" Date: Thu, 21 Sep 2023 12:39:40 +0300 Subject: [PATCH 19/19] linters --- redis/client.py | 3 ++- redis/commands/search/commands.py | 4 +++- redis/commands/search/query.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/redis/client.py b/redis/client.py index be872d9600..4923143543 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1099,7 +1099,8 @@ def handle_message(self, response, ignore_subscribe_messages=False): return message def run_in_thread( - self, sleep_time: int = 0, + self, + sleep_time: int = 0, daemon: bool = False, exception_handler: Optional[Callable] = None, ) -> "PubSubWorkerThread": diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index a5f66d4164..2df2b5a754 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -554,7 +554,9 @@ def aggregate( AGGREGATE_CMD, raw, query=query, has_cursor=has_cursor ) - def _get_aggregate_result(self, raw: List, query: Union[str, Query, AggregateRequest], has_cursor: bool): + def _get_aggregate_result( + self, raw: List, query: Union[str, Query, AggregateRequest], has_cursor: bool + ): if has_cursor: if isinstance(query, Cursor): query.cid = raw[1] diff --git a/redis/commands/search/query.py b/redis/commands/search/query.py index 3ae3208587..113ddf9da8 100644 --- a/redis/commands/search/query.py +++ b/redis/commands/search/query.py @@ -108,7 +108,7 @@ def summarize( return self def highlight( - self, fields: Optional[List[str]] = None, tags: Optional [List[str]] = None + self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None ) -> None: """ Apply specified markup to matched term(s) within the returned field(s).