diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 1d8a33a22a..bfd1841669 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -36,7 +36,7 @@ jobs: strategy: max-parallel: 15 matrix: - python-version: ['3.6', '3.7', '3.8', '3.9', '3.10', 'pypy-3.7'] + python-version: ['3.6','3.7', '3.8', '3.9', '3.10', 'pypy-3.7'] test-type: ['standalone', 'cluster'] connection-type: ['hiredis', 'plain'] env: @@ -50,6 +50,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: run tests run: | + pip install -U setuptools wheel pip install -r dev_requirements.txt tox -e ${{matrix.test-type}}-${{matrix.connection-type}} - name: Upload codecov coverage @@ -79,7 +80,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.6', '3.7', '3.8', '3.9', '3.10', 'pypy-3.7'] + python-version: ['3.7', '3.8', '3.9', '3.10', 'pypy-3.7'] steps: - uses: actions/checkout@v2 - name: install python ${{ matrix.python-version }} diff --git a/.gitignore b/.gitignore index 96fbdd5646..b392a2d748 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,7 @@ vagrant/.vagrant env venv coverage.xml -.venv +.venv* *.xml .coverage* docker/stunnel/keys diff --git a/.mypy.ini b/.mypy.ini new file mode 100644 index 0000000000..942574e0f3 --- /dev/null +++ b/.mypy.ini @@ -0,0 +1,24 @@ +[mypy] +#, docs/examples, tests +files = redis +check_untyped_defs = True +follow_imports_for_stubs = True +#disallow_any_decorated = True +disallow_subclassing_any = True +#disallow_untyped_calls = True +disallow_untyped_decorators = True +#disallow_untyped_defs = True +implicit_reexport = False +no_implicit_optional = True +show_error_codes = True +strict_equality = True +warn_incomplete_stub = True +warn_redundant_casts = True +warn_unreachable = True +warn_unused_ignores = True +disallow_any_unimported = True +#warn_return_any = True + +[mypy-redis.asyncio.lock] +# TODO: Remove once the locks have been rewritten +ignore_errors = True diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ebb66bbf3e..827a25f9c4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -112,12 +112,12 @@ You can see the logging output of a container like this: `$ docker logs -f ` The command make test runs all tests in all tested Python -environments. To run the tests in a single environment, like Python 3.6, +environments. To run the tests in a single environment, like Python 3.9, use a command like this: -`$ docker-compose run test tox -e py36 -- --redis-url=redis://master:6379/9` +`$ docker-compose run test tox -e py39 -- --redis-url=redis://master:6379/9` -Here, the flag `-e py36` runs tests against the Python 3.6 tox +Here, the flag `-e py39` runs tests against the Python 3.9 tox environment. And note from the example that whenever you run tests like this, instead of using make test, you need to pass `-- --redis-url=redis://master:6379/9`. This points the tests at the diff --git a/README.md b/README.md index f35c7bb13a..c6ee3c10eb 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,11 @@ The Python interface to the Redis key-value store. --------------------------------------------- +## Python Notice + +redis-py 4.2.x will be the last generation of redis-py to support Python 3.6, as it has been [End of Life'd](https://www.python.org/dev/peps/pep-0494/#schedule-last-security-only-release). Async support was introduced in redis-py 4.2.x thanks to [aioredis](https://github.com/aio-libs/aioredis-py), which necessitates this change. 
We will continue to maintain 3.6 support as long as possible - but the plan is for redis-py version 5+ to officially remove 3.6. + +--------------------------- ## Installation @@ -51,7 +56,7 @@ contributing](https://github.com/redis/redis-py/blob/master/CONTRIBUTING.md). ## Getting Started -redis-py supports Python 3.6+. +redis-py supports Python 3.7+. ``` pycon >>> import redis diff --git a/dev_requirements.txt b/dev_requirements.txt index 2a4f37762f..0c4bee9f35 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -2,8 +2,10 @@ black==21.11b1 flake8==4.0.1 flynt~=0.69.0 isort==5.10.1 +mock==4.0.3 pytest==6.2.5 pytest-timeout==2.0.1 +pytest-asyncio>=0.16.0 tox==3.24.4 tox-docker==3.1.0 invoke==1.6.0 @@ -11,3 +13,4 @@ pytest-cov>=3.0.0 vulture>=2.3.0 ujson>=4.2.0 wheel>=0.30.0 +uvloop diff --git a/docs/connections.rst b/docs/connections.rst index ba39f3341f..9804a15bf1 100644 --- a/docs/connections.rst +++ b/docs/connections.rst @@ -42,4 +42,14 @@ Connection Pools .. autoclass:: redis.connection.ConnectionPool :members: -More connection examples can be found `here `_. \ No newline at end of file +More connection examples can be found `here `_. + +Async Client +************ + +This client is used for communicating with Redis asynchronously. + +.. autoclass:: redis.asyncio.connection.Connection + :members: + +More connection examples can be found `here `_ \ No newline at end of file diff --git a/docs/examples.rst b/docs/examples.rst index fb57499e65..6d659a0b75 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -7,5 +7,6 @@ Examples examples/connection_examples examples/ssl_connection_examples + examples/asyncio_examples examples/search_json_examples examples/set_and_get_examples diff --git a/docs/examples/asyncio_examples.ipynb b/docs/examples/asyncio_examples.ipynb new file mode 100644 index 0000000000..66d435835b --- /dev/null +++ b/docs/examples/asyncio_examples.ipynb @@ -0,0 +1,301 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": true, + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# Asyncio Examples\n", + "\n", + "All commands are coroutine functions.\n", + "\n", + "## Connecting and Disconnecting\n", + "\n", + "Utilizing asyncio Redis requires an explicit disconnect of the connection since there is no asyncio destructor magic method. By default, a connection pool is created on `redis.Redis()` and attached to this `Redis` instance. The connection pool closes automatically on the call to `Redis.close`, which disconnects all connections." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ping successful: True\n" + ] + } + ], + "source": [ + "import redis.asyncio as redis\n", + "\n", + "connection = redis.Redis()\n", + "print(f\"Ping successful: {await connection.ping()}\")\n", + "await connection.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "If you create a custom `ConnectionPool` that is shared by several `Redis` instances, you may want to disconnect the connection pool explicitly. Disconnecting the connection pool simply disconnects all connections hosted in the pool." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "import redis.asyncio as redis\n", + "\n", + "connection = redis.Redis(auto_close_connection_pool=False)\n", + "await connection.close()\n", + "# Or: await connection.close(close_connection_pool=False)\n", + "await connection.connection_pool.disconnect()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Transactions (Multi/Exec)\n", + "\n", + "The redis.asyncio.Redis.pipeline method will return a Pipeline object, which will buffer all commands in-memory and compile them into batches using the Redis Bulk String protocol. Additionally, each command will return the Pipeline instance, allowing you to chain your commands, i.e., p.set('foo', 1).set('bar', 2).mget('foo', 'bar').\n", + "\n", + "The commands will not be reflected in Redis until execute() is called and awaited.\n", + "\n", + "Usually, when performing a bulk operation, taking advantage of a “transaction” (e.g., Multi/Exec) is desirable, as it will also add a layer of atomicity to your bulk operation." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "import redis.asyncio as redis\n", + "\n", + "r = await redis.from_url(\"redis://localhost\")\n", + "async with r.pipeline(transaction=True) as pipe:\n", + " ok1, ok2 = await (pipe.set(\"key1\", \"value1\").set(\"key2\", \"value2\").execute())\n", + "assert ok1\n", + "assert ok2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pub/Sub Mode\n", + "\n", + "Subscribing to specific channels:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(Reader) Message Received: {'type': 'message', 'pattern': None, 'channel': b'channel:1', 'data': b'Hello'}\n", + "(Reader) Message Received: {'type': 'message', 'pattern': None, 'channel': b'channel:2', 'data': b'World'}\n", + "(Reader) Message Received: {'type': 'message', 'pattern': None, 'channel': b'channel:1', 'data': b'STOP'}\n", + "(Reader) STOP\n" + ] + } + ], + "source": [ + "import asyncio\n", + "\n", + "import async_timeout\n", + "\n", + "import redis.asyncio as redis\n", + "\n", + "STOPWORD = \"STOP\"\n", + "\n", + "\n", + "async def reader(channel: redis.client.PubSub):\n", + " while True:\n", + " try:\n", + " async with async_timeout.timeout(1):\n", + " message = await channel.get_message(ignore_subscribe_messages=True)\n", + " if message is not None:\n", + " print(f\"(Reader) Message Received: {message}\")\n", + " if message[\"data\"].decode() == STOPWORD:\n", + " print(\"(Reader) STOP\")\n", + " break\n", + " await asyncio.sleep(0.01)\n", + " except asyncio.TimeoutError:\n", + " pass\n", + "\n", + "r = redis.from_url(\"redis://localhost\")\n", + "pubsub = r.pubsub()\n", + "await pubsub.subscribe(\"channel:1\", \"channel:2\")\n", + "\n", + "future = asyncio.create_task(reader(pubsub))\n", + "\n", + "await r.publish(\"channel:1\", \"Hello\")\n", + "await r.publish(\"channel:2\", \"World\")\n", + "await r.publish(\"channel:1\", STOPWORD)\n", + "\n", + "await future" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Subscribing to channels matching a glob-style pattern:" + ] + }, + { + "cell_type": "code", 
"execution_count": 5, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(Reader) Message Received: {'type': 'pmessage', 'pattern': b'channel:*', 'channel': b'channel:1', 'data': b'Hello'}\n", + "(Reader) Message Received: {'type': 'pmessage', 'pattern': b'channel:*', 'channel': b'channel:2', 'data': b'World'}\n", + "(Reader) Message Received: {'type': 'pmessage', 'pattern': b'channel:*', 'channel': b'channel:1', 'data': b'STOP'}\n", + "(Reader) STOP\n" + ] + } + ], + "source": [ + "import asyncio\n", + "\n", + "import async_timeout\n", + "\n", + "import redis.asyncio as redis\n", + "\n", + "STOPWORD = \"STOP\"\n", + "\n", + "\n", + "async def reader(channel: redis.client.PubSub):\n", + " while True:\n", + " try:\n", + " async with async_timeout.timeout(1):\n", + " message = await channel.get_message(ignore_subscribe_messages=True)\n", + " if message is not None:\n", + " print(f\"(Reader) Message Received: {message}\")\n", + " if message[\"data\"].decode() == STOPWORD:\n", + " print(\"(Reader) STOP\")\n", + " break\n", + " await asyncio.sleep(0.01)\n", + " except asyncio.TimeoutError:\n", + " pass\n", + "\n", + "\n", + "r = await redis.from_url(\"redis://localhost\")\n", + "pubsub = r.pubsub()\n", + "await pubsub.psubscribe(\"channel:*\")\n", + "\n", + "future = asyncio.create_task(reader(pubsub))\n", + "\n", + "await r.publish(\"channel:1\", \"Hello\")\n", + "await r.publish(\"channel:2\", \"World\")\n", + "await r.publish(\"channel:1\", STOPWORD)\n", + "\n", + "await future" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sentinel Client\n", + "\n", + "The Sentinel client requires a list of Redis Sentinel addresses to connect to and start discovering services.\n", + "\n", + "Calling aioredis.sentinel.Sentinel.master_for or aioredis.sentinel.Sentinel.slave_for methods will return Redis clients connected to specified services monitored by Sentinel.\n", + "\n", + "Sentinel client will detect failover and reconnect Redis clients automatically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "import asyncio\n", + "\n", + "from redis.asyncio.sentinel import Sentinel\n", + "\n", + "\n", + "sentinel = Sentinel([(\"localhost\", 26379), (\"sentinel2\", 26379)])\n", + "r = sentinel.master_for(\"mymaster\")\n", + "\n", + "ok = await r.set(\"key\", \"value\")\n", + "assert ok\n", + "val = await r.get(\"key\")\n", + "assert val == b\"value\"" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 2e42acb81b..630bad4f55 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,7 +9,7 @@ Welcome to redis-py's documentation! Getting Started **************** -`redis-py `_ requires a running Redis server, and Python 3.6+. See the `Redis +`redis-py `_ requires a running Redis server, and Python 3.7+. See the `Redis quickstart `_ for Redis installation instructions. redis-py can be installed using pip via ``pip install redis``. 
diff --git a/redis/__init__.py b/redis/__init__.py index 35044be29d..b7560a6715 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -1,10 +1,5 @@ import sys -if sys.version_info >= (3, 8): - from importlib import metadata -else: - import importlib_metadata as metadata - from redis.client import Redis, StrictRedis from redis.cluster import RedisCluster from redis.connection import ( @@ -37,6 +32,11 @@ ) from redis.utils import from_url +if sys.version_info >= (3, 8): + from importlib import metadata +else: + import importlib_metadata as metadata + def int_or_str(value): try: diff --git a/redis/asyncio/__init__.py b/redis/asyncio/__init__.py new file mode 100644 index 0000000000..c655c7da4b --- /dev/null +++ b/redis/asyncio/__init__.py @@ -0,0 +1,58 @@ +from redis.asyncio.client import Redis, StrictRedis +from redis.asyncio.connection import ( + BlockingConnectionPool, + Connection, + ConnectionPool, + SSLConnection, + UnixDomainSocketConnection, +) +from redis.asyncio.sentinel import ( + Sentinel, + SentinelConnectionPool, + SentinelManagedConnection, + SentinelManagedSSLConnection, +) +from redis.asyncio.utils import from_url +from redis.exceptions import ( + AuthenticationError, + AuthenticationWrongNumberOfArgsError, + BusyLoadingError, + ChildDeadlockedError, + ConnectionError, + DataError, + InvalidResponse, + PubSubError, + ReadOnlyError, + RedisError, + ResponseError, + TimeoutError, + WatchError, +) + +__all__ = [ + "AuthenticationError", + "AuthenticationWrongNumberOfArgsError", + "BlockingConnectionPool", + "BusyLoadingError", + "ChildDeadlockedError", + "Connection", + "ConnectionError", + "ConnectionPool", + "DataError", + "from_url", + "InvalidResponse", + "PubSubError", + "ReadOnlyError", + "Redis", + "RedisError", + "ResponseError", + "Sentinel", + "SentinelConnectionPool", + "SentinelManagedConnection", + "SentinelManagedSSLConnection", + "SSLConnection", + "StrictRedis", + "TimeoutError", + "UnixDomainSocketConnection", + "WatchError", +] diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py new file mode 100644 index 0000000000..619592ef76 --- /dev/null +++ b/redis/asyncio/client.py @@ -0,0 +1,1344 @@ +import asyncio +import copy +import inspect +import re +import warnings +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Awaitable, + Callable, + Dict, + Iterable, + List, + Mapping, + MutableMapping, + NoReturn, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from redis.asyncio.connection import ( + Connection, + ConnectionPool, + SSLConnection, + UnixDomainSocketConnection, +) +from redis.client import ( + EMPTY_RESPONSE, + NEVER_DECODE, + AbstractRedis, + CaseInsensitiveDict, + bool_ok, +) +from redis.commands import ( + AsyncCoreCommands, + AsyncSentinelCommands, + RedisModuleCommands, + list_or_args, +) +from redis.compat import Protocol, TypedDict +from redis.exceptions import ( + ConnectionError, + ExecAbortError, + PubSubError, + RedisError, + ResponseError, + TimeoutError, + WatchError, +) +from redis.lock import Lock +from redis.retry import Retry +from redis.typing import ChannelT, EncodableT, KeyT +from redis.utils import safe_str, str_if_bytes + +PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] +_KeyT = TypeVar("_KeyT", bound=KeyT) +_ArgT = TypeVar("_ArgT", KeyT, EncodableT) +_RedisT = TypeVar("_RedisT", bound="Redis") +_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object]) +if TYPE_CHECKING: + from redis.commands.core import Script + + +class 
ResponseCallbackProtocol(Protocol): + def __call__(self, response: Any, **kwargs): + ... + + +class AsyncResponseCallbackProtocol(Protocol): + async def __call__(self, response: Any, **kwargs): + ... + + +ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol] + + +class Redis( + AbstractRedis, RedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands +): + """ + Implementation of the Redis protocol. + + This abstract class provides a Python interface to all Redis commands + and an implementation of the Redis protocol. + + Pipelines derive from this, implementing how + the commands are sent to and received from the Redis server. Based on + configuration, an instance will either use a ConnectionPool, or + Connection object to talk to redis. + """ + + response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT] + + @classmethod + def from_url(cls, url: str, **kwargs): + """ + Return a Redis client object configured from the given URL + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[[username]:[password]]@/path/to/socket.sock?db=0 + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + + - `rediss://` creates an SSL-wrapped TCP socket connection. See more at: + + - ``unix://``: creates a Unix Domain Socket connection. + + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. 
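+ + A minimal illustrative call (a sketch only; the URL is a placeholder): + + >>> import redis.asyncio as redis + >>> client = redis.Redis.from_url("redis://localhost:6379/0")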
+ + """ + connection_pool = ConnectionPool.from_url(url, **kwargs) + return cls(connection_pool=connection_pool) + + def __init__( + self, + *, + host: str = "localhost", + port: int = 6379, + db: Union[str, int] = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: Optional[bool] = None, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + connection_pool: Optional[ConnectionPool] = None, + unix_socket_path: Optional[str] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + retry_on_timeout: bool = False, + ssl: bool = False, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_ca_certs: Optional[str] = None, + ssl_check_hostname: bool = False, + max_connections: Optional[int] = None, + single_connection_client: bool = False, + health_check_interval: int = 0, + client_name: Optional[str] = None, + username: Optional[str] = None, + retry: Optional[Retry] = None, + auto_close_connection_pool: bool = True, + ): + """ + Initialize a new Redis client. + To specify a retry policy, first set `retry_on_timeout` to `True` + then set `retry` to a valid `Retry` object + """ + kwargs: Dict[str, Any] + # auto_close_connection_pool only has an effect if connection_pool is + # None. This is a similar feature to the missing __del__ to resolve #1103, + # but it accounts for whether a user wants to manually close the connection + # pool, as a similar feature to ConnectionPool's __del__. + self.auto_close_connection_pool = ( + auto_close_connection_pool if connection_pool is None else False + ) + if not connection_pool: + kwargs = { + "db": db, + "username": username, + "password": password, + "socket_timeout": socket_timeout, + "encoding": encoding, + "encoding_errors": encoding_errors, + "decode_responses": decode_responses, + "retry_on_timeout": retry_on_timeout, + "retry": copy.deepcopy(retry), + "max_connections": max_connections, + "health_check_interval": health_check_interval, + "client_name": client_name, + } + # based on input, setup appropriate connection args + if unix_socket_path is not None: + kwargs.update( + { + "path": unix_socket_path, + "connection_class": UnixDomainSocketConnection, + } + ) + else: + # TCP specific options + kwargs.update( + { + "host": host, + "port": port, + "socket_connect_timeout": socket_connect_timeout, + "socket_keepalive": socket_keepalive, + "socket_keepalive_options": socket_keepalive_options, + } + ) + + if ssl: + kwargs.update( + { + "connection_class": SSLConnection, + "ssl_keyfile": ssl_keyfile, + "ssl_certfile": ssl_certfile, + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_ca_certs": ssl_ca_certs, + "ssl_check_hostname": ssl_check_hostname, + } + ) + connection_pool = ConnectionPool(**kwargs) + self.connection_pool = connection_pool + self.single_connection_client = single_connection_client + self.connection: Optional[Connection] = None + + self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + + def __repr__(self): + return f"{self.__class__.__name__}<{self.connection_pool!r}>" + + def __await__(self): + return self.initialize().__await__() + + async def initialize(self: _RedisT) -> _RedisT: + if self.single_connection_client and self.connection is None: + self.connection = await self.connection_pool.get_connection("_") + return self + + def set_response_callback(self, command: str, callback: 
ResponseCallbackT): + """Set a custom Response Callback""" + self.response_callbacks[command] = callback + + def load_external_module( + self, + funcname, + func, + ): + """ + This function can be used to add externally defined redis modules, + and their namespaces to the redis client. + + funcname - A string containing the name of the function to create + func - The function being added to this class. + + ex: Assume that one has a custom redis module named foomod that + creates commands named 'foo.dothing' and 'foo.anotherthing' in redis. + To load these functions into this namespace: + + from redis import Redis + from foomodule import F + r = Redis() + r.load_external_module("foo", F) + r.foo().dothing('your', 'arguments') + + For a concrete example see the reimport of the redisjson module in + tests/test_connection.py::test_loading_external_modules + """ + setattr(self, funcname, func) + + def pipeline( + self, transaction: bool = True, shard_hint: Optional[str] = None + ) -> "Pipeline": + """ + Return a new pipeline object that can queue multiple commands for + later execution. ``transaction`` indicates whether all commands + should be executed atomically. Apart from making a group of operations + atomic, pipelines are useful for reducing the back-and-forth overhead + between the client and server. + """ + return Pipeline( + self.connection_pool, self.response_callbacks, transaction, shard_hint + ) + + async def transaction( + self, + func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], + *watches: KeyT, + shard_hint: Optional[str] = None, + value_from_callable: bool = False, + watch_delay: Optional[float] = None, + ): + """ + Convenience method for executing the callable `func` as a transaction + while watching all keys specified in `watches`. The 'func' callable + should expect a single argument which is a Pipeline object. + """ + pipe: Pipeline + async with self.pipeline(True, shard_hint) as pipe: + while True: + try: + if watches: + await pipe.watch(*watches) + func_value = func(pipe) + if inspect.isawaitable(func_value): + func_value = await func_value + exec_value = await pipe.execute() + return func_value if value_from_callable else exec_value + except WatchError: + if watch_delay is not None and watch_delay > 0: + await asyncio.sleep(watch_delay) + continue + + def lock( + self, + name: KeyT, + timeout: Optional[float] = None, + sleep: float = 0.1, + blocking_timeout: Optional[float] = None, + lock_class: Optional[Type[Lock]] = None, + thread_local: bool = True, + ) -> Lock: + """ + Return a new Lock object using key ``name`` that mimics + the behavior of threading.Lock. + + If specified, ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``lock_class`` forces the specified lock implementation. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. 
Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release thread-2's lock. + + In some use cases it's necessary to disable thread local storage. For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. Our assumption + is that these cases aren't common and as such default to using + thread local storage.""" + if lock_class is None: + lock_class = Lock + return lock_class( + self, + name, + timeout=timeout, + sleep=sleep, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) + + def pubsub(self, **kwargs) -> "PubSub": + """ + Return a Publish/Subscribe object. With this object, you can + subscribe to channels and listen for messages that get published to + them. + """ + return PubSub(self.connection_pool, **kwargs) + + def monitor(self) -> "Monitor": + return Monitor(self.connection_pool) + + def client(self) -> "Redis": + return self.__class__( + connection_pool=self.connection_pool, single_connection_client=True + ) + + async def __aenter__(self: _RedisT) -> _RedisT: + return await self.initialize() + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + _DEL_MESSAGE = "Unclosed Redis client" + + def __del__(self, _warnings: Any = warnings) -> None: + if self.connection is not None: + _warnings.warn( + f"Unclosed client session {self!r}", + ResourceWarning, + source=self, + ) + context = {"client": self, "message": self._DEL_MESSAGE} + asyncio.get_event_loop().call_exception_handler(context) + + async def close(self, close_connection_pool: Optional[bool] = None) -> None: + """ + Closes Redis client connection + + :param close_connection_pool: decides whether to close the connection pool used + by this Redis client, overriding Redis.auto_close_connection_pool. By default, + let Redis.auto_close_connection_pool decide whether to close the connection + pool. 
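+ + For illustration (this mirrors the notebook example added earlier in this patch): + + >>> r = Redis(auto_close_connection_pool=False) + >>> await r.close() + >>> await r.connection_pool.disconnect()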
+ """ + conn = self.connection + if conn: + self.connection = None + await self.connection_pool.release(conn) + if close_connection_pool or ( + close_connection_pool is None and self.auto_close_connection_pool + ): + await self.connection_pool.disconnect() + + async def _send_command_parse_response(self, conn, command_name, *args, **options): + """ + Send a command and parse the response + """ + await conn.send_command(*args) + return await self.parse_response(conn, command_name, **options) + + async def _disconnect_raise(self, conn: Connection, error: Exception): + """ + Close the connection and raise an exception + if retry_on_timeout is not set or the error + is not a TimeoutError + """ + await conn.disconnect() + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + raise error + + # COMMAND EXECUTION AND PROTOCOL PARSING + async def execute_command(self, *args, **options): + """Execute a command and return a parsed response""" + await self.initialize() + pool = self.connection_pool + command_name = args[0] + conn = self.connection or await pool.get_connection(command_name, **options) + + try: + return await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) + finally: + if not self.connection: + await pool.release(conn) + + async def parse_response( + self, connection: Connection, command_name: Union[str, bytes], **options + ): + """Parses a response from the Redis server""" + try: + if NEVER_DECODE in options: + response = await connection.read_response(disable_encoding=True) + else: + response = await connection.read_response() + except ResponseError: + if EMPTY_RESPONSE in options: + return options[EMPTY_RESPONSE] + raise + if command_name in self.response_callbacks: + # Mypy bug: https://github.com/python/mypy/issues/10977 + command_name = cast(str, command_name) + retval = self.response_callbacks[command_name](response, **options) + return await retval if inspect.isawaitable(retval) else retval + return response + + +StrictRedis = Redis + + +class MonitorCommandInfo(TypedDict): + time: float + db: int + client_address: str + client_port: str + client_type: str + command: str + + +class Monitor: + """ + Monitor is useful for handling the MONITOR command to the redis server. + next_command() method returns one command from monitor + listen() method yields commands from monitor. + """ + + monitor_re = re.compile(r"\[(\d+) (.*)\] (.*)") + command_re = re.compile(r'"(.*?)(? MonitorCommandInfo: + """Parse the response from a monitor command""" + await self.connect() + response = await self.connection.read_response() + if isinstance(response, bytes): + response = self.connection.encoder.decode(response, force=True) + command_time, command_data = response.split(" ", 1) + m = self.monitor_re.match(command_data) + db_id, client_info, command = m.groups() + command = " ".join(self.command_re.findall(command)) + # Redis escapes double quotes because each piece of the command + # string is surrounded by double quotes. We don't have that + # requirement so remove the escaping and leave the quote. 
+ command = command.replace('\\"', '"') + + if client_info == "lua": + client_address = "lua" + client_port = "" + client_type = "lua" + elif client_info.startswith("unix"): + client_address = "unix" + client_port = client_info[5:] + client_type = "unix" + else: + # use rsplit as ipv6 addresses contain colons + client_address, client_port = client_info.rsplit(":", 1) + client_type = "tcp" + return { + "time": float(command_time), + "db": int(db_id), + "client_address": client_address, + "client_port": client_port, + "client_type": client_type, + "command": command, + } + + async def listen(self) -> AsyncIterator[MonitorCommandInfo]: + """Listen for commands coming to the server.""" + while True: + yield await self.next_command() + + +class PubSub: + """ + PubSub provides publish, subscribe and listen support to Redis channels. + + After subscribing to one or more channels, the listen() method will block + until a message arrives on one of the subscribed channels. That message + will be returned and it's safe to start listening again. + """ + + PUBLISH_MESSAGE_TYPES = ("message", "pmessage") + UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe") + HEALTH_CHECK_MESSAGE = "redis-py-health-check" + + def __init__( + self, + connection_pool: ConnectionPool, + shard_hint: Optional[str] = None, + ignore_subscribe_messages: bool = False, + encoder=None, + ): + self.connection_pool = connection_pool + self.shard_hint = shard_hint + self.ignore_subscribe_messages = ignore_subscribe_messages + self.connection = None + # we need to know the encoding options for this connection in order + # to lookup channel and pattern names for callback handlers. + self.encoder = encoder + if self.encoder is None: + self.encoder = self.connection_pool.get_encoder() + if self.encoder.decode_responses: + self.health_check_response: Iterable[Union[str, bytes]] = [ + "pong", + self.HEALTH_CHECK_MESSAGE, + ] + else: + self.health_check_response = [ + b"pong", + self.encoder.encode(self.HEALTH_CHECK_MESSAGE), + ] + self.channels = {} + self.pending_unsubscribe_channels = set() + self.patterns = {} + self.pending_unsubscribe_patterns = set() + self._lock = asyncio.Lock() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.reset() + + def __del__(self): + if self.connection: + self.connection.clear_connect_callbacks() + + async def reset(self): + async with self._lock: + if self.connection: + await self.connection.disconnect() + self.connection.clear_connect_callbacks() + await self.connection_pool.release(self.connection) + self.connection = None + self.channels = {} + self.pending_unsubscribe_channels = set() + self.patterns = {} + self.pending_unsubscribe_patterns = set() + + def close(self) -> Awaitable[NoReturn]: + return self.reset() + + async def on_connect(self, connection: Connection): + """Re-subscribe to any channels and patterns previously subscribed to""" + # NOTE: for python3, we can't pass bytestrings as keyword arguments + # so we need to decode channel/pattern names back to unicode strings + # before passing them to [p]subscribe. 
+ self.pending_unsubscribe_channels.clear() + self.pending_unsubscribe_patterns.clear() + if self.channels: + channels = {} + for k, v in self.channels.items(): + channels[self.encoder.decode(k, force=True)] = v + await self.subscribe(**channels) + if self.patterns: + patterns = {} + for k, v in self.patterns.items(): + patterns[self.encoder.decode(k, force=True)] = v + await self.psubscribe(**patterns) + + @property + def subscribed(self): + """Indicates if there are subscriptions to any channels or patterns""" + return bool(self.channels or self.patterns) + + async def execute_command(self, *args: EncodableT): + """Execute a publish/subscribe command""" + + # NOTE: don't parse the response in this function -- it could pull a + # legitimate message off the stack if the connection is already + # subscribed to one or more channels + + if self.connection is None: + self.connection = await self.connection_pool.get_connection( + "pubsub", self.shard_hint + ) + # register a callback that re-subscribes to any channels we + # were listening to when we were disconnected + self.connection.register_connect_callback(self.on_connect) + connection = self.connection + kwargs = {"check_health": not self.subscribed} + await self._execute(connection, connection.send_command, *args, **kwargs) + + async def _disconnect_raise_connect(self, conn, error): + """ + Close the connection and raise an exception + if retry_on_timeout is not set or the error + is not a TimeoutError. Otherwise, try to reconnect + """ + await conn.disconnect() + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + raise error + await conn.connect() + + async def _execute(self, conn, command, *args, **kwargs): + """ + Connect manually upon disconnection. If the Redis server is down, + this will fail and raise a ConnectionError as desired. + After reconnection, the ``on_connect`` callback should have been + called by the connection to resubscribe us to any channels and + patterns we were previously listening to + """ + return await conn.retry.call_with_retry( + lambda: command(*args, **kwargs), + lambda error: self._disconnect_raise_connect(conn, error), + ) + + async def parse_response(self, block: bool = True, timeout: float = 0): + """Parse the response from a publish/subscribe command""" + conn = self.connection + if conn is None: + raise RuntimeError( + "pubsub connection not set: " + "did you forget to call subscribe() or psubscribe()?" + ) + + await self.check_health() + + if not block and not await self._execute(conn, conn.can_read, timeout=timeout): + return None + response = await self._execute(conn, conn.read_response) + + if conn.health_check_interval and response == self.health_check_response: + # ignore the health check message as user might not expect it + return None + return response + + async def check_health(self): + conn = self.connection + if conn is None: + raise RuntimeError( + "pubsub connection not set: " + "did you forget to call subscribe() or psubscribe()?" + ) + + if ( + conn.health_check_interval + and asyncio.get_event_loop().time() > conn.next_health_check + ): + await conn.send_command( + "PING", self.HEALTH_CHECK_MESSAGE, check_health=False + ) + + def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT: + """ + normalize channel/pattern names to be either bytes or strings + based on whether responses are automatically decoded. this saves us + from coercing the value for each message coming in. 
+ """ + encode = self.encoder.encode + decode = self.encoder.decode + return {decode(encode(k)): v for k, v in data.items()} # type: ignore[return-value] # noqa: E501 + + async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): + """ + Subscribe to channel patterns. Patterns supplied as keyword arguments + expect a pattern name as the key and a callable as the value. A + pattern's callable will be invoked automatically when a message is + received on that pattern rather than producing a message via + ``listen()``. + """ + parsed_args = list_or_args((args[0],), args[1:]) if args else args + new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args) + # Mypy bug: https://github.com/python/mypy/issues/10970 + new_patterns.update(kwargs) # type: ignore[arg-type] + ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys()) + # update the patterns dict AFTER we send the command. we don't want to + # subscribe twice to these patterns, once for the command and again + # for the reconnection. + new_patterns = self._normalize_keys(new_patterns) + self.patterns.update(new_patterns) + self.pending_unsubscribe_patterns.difference_update(new_patterns) + return ret_val + + def punsubscribe(self, *args: ChannelT) -> Awaitable: + """ + Unsubscribe from the supplied patterns. If empty, unsubscribe from + all patterns. + """ + patterns: Iterable[ChannelT] + if args: + parsed_args = list_or_args((args[0],), args[1:]) + patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys() + else: + parsed_args = [] + patterns = self.patterns + self.pending_unsubscribe_patterns.update(patterns) + return self.execute_command("PUNSUBSCRIBE", *parsed_args) + + async def subscribe(self, *args: ChannelT, **kwargs: Callable): + """ + Subscribe to channels. Channels supplied as keyword arguments expect + a channel name as the key and a callable as the value. A channel's + callable will be invoked automatically when a message is received on + that channel rather than producing a message via ``listen()`` or + ``get_message()``. + """ + parsed_args = list_or_args((args[0],), args[1:]) if args else () + new_channels = dict.fromkeys(parsed_args) + # Mypy bug: https://github.com/python/mypy/issues/10970 + new_channels.update(kwargs) # type: ignore[arg-type] + ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys()) + # update the channels dict AFTER we send the command. we don't want to + # subscribe twice to these channels, once for the command and again + # for the reconnection. + new_channels = self._normalize_keys(new_channels) + self.channels.update(new_channels) + self.pending_unsubscribe_channels.difference_update(new_channels) + return ret_val + + def unsubscribe(self, *args) -> Awaitable: + """ + Unsubscribe from the supplied channels. 
If empty, unsubscribe from + all channels + """ + if args: + parsed_args = list_or_args(args[0], args[1:]) + channels = self._normalize_keys(dict.fromkeys(parsed_args)) + else: + parsed_args = [] + channels = self.channels + self.pending_unsubscribe_channels.update(channels) + return self.execute_command("UNSUBSCRIBE", *parsed_args) + + async def listen(self) -> AsyncIterator: + """Listen for messages on channels this client has been subscribed to""" + while self.subscribed: + response = await self.handle_message(await self.parse_response(block=True)) + if response is not None: + yield response + + async def get_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): + """ + Get the next message if one is available, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number. + """ + response = await self.parse_response(block=False, timeout=timeout) + if response: + return await self.handle_message(response, ignore_subscribe_messages) + return None + + def ping(self, message=None) -> Awaitable: + """ + Ping the Redis server + """ + message = "" if message is None else message + return self.execute_command("PING", message) + + async def handle_message(self, response, ignore_subscribe_messages=False): + """ + Parses a pub/sub message. If the channel or pattern was subscribed to + with a message handler, the handler is invoked instead of a parsed + message being returned. + """ + message_type = str_if_bytes(response[0]) + if message_type == "pmessage": + message = { + "type": message_type, + "pattern": response[1], + "channel": response[2], + "data": response[3], + } + elif message_type == "pong": + message = { + "type": message_type, + "pattern": None, + "channel": None, + "data": response[1], + } + else: + message = { + "type": message_type, + "pattern": None, + "channel": response[1], + "data": response[2], + } + + # if this is an unsubscribe message, remove it from memory + if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES: + if message_type == "punsubscribe": + pattern = response[1] + if pattern in self.pending_unsubscribe_patterns: + self.pending_unsubscribe_patterns.remove(pattern) + self.patterns.pop(pattern, None) + else: + channel = response[1] + if channel in self.pending_unsubscribe_channels: + self.pending_unsubscribe_channels.remove(channel) + self.channels.pop(channel, None) + + if message_type in self.PUBLISH_MESSAGE_TYPES: + # if there's a message handler, invoke it + if message_type == "pmessage": + handler = self.patterns.get(message["pattern"], None) + else: + handler = self.channels.get(message["channel"], None) + if handler: + if inspect.iscoroutinefunction(handler): + await handler(message) + else: + handler(message) + return None + elif message_type != "pong": + # this is a subscribe/unsubscribe message. ignore if we don't + # want them + if ignore_subscribe_messages or self.ignore_subscribe_messages: + return None + + return message + + async def run( + self, + *, + exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None, + poll_timeout: float = 1.0, + ) -> None: + """Process pub/sub messages using registered callbacks. + + This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in + redis-py, but it is a coroutine. 
To launch it as a separate task, use + ``asyncio.create_task``: + + >>> task = asyncio.create_task(pubsub.run()) + + To shut it down, use asyncio cancellation: + + >>> task.cancel() + >>> await task + """ + for channel, handler in self.channels.items(): + if handler is None: + raise PubSubError(f"Channel: '{channel}' has no handler registered") + for pattern, handler in self.patterns.items(): + if handler is None: + raise PubSubError(f"Pattern: '{pattern}' has no handler registered") + + while True: + try: + await self.get_message( + ignore_subscribe_messages=True, timeout=poll_timeout + ) + except asyncio.CancelledError: + raise + except BaseException as e: + if exception_handler is None: + raise + res = exception_handler(e, self) + if inspect.isawaitable(res): + await res + # Ensure that other tasks on the event loop get a chance to run + # if we didn't have to block for I/O anywhere. + await asyncio.sleep(0) + + +class PubsubWorkerExceptionHandler(Protocol): + def __call__(self, e: BaseException, pubsub: PubSub): + ... + + +class AsyncPubsubWorkerExceptionHandler(Protocol): + async def __call__(self, e: BaseException, pubsub: PubSub): + ... + + +PSWorkerThreadExcHandlerT = Union[ + PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler +] + + +CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]] +CommandStackT = List[CommandT] + + +class Pipeline(Redis): # lgtm [py/init-calls-subclass] + """ + Pipelines provide a way to transmit multiple commands to the Redis server + in one transmission. This is convenient for batch processing, such as + saving all the values in a list to Redis. + + All commands executed within a pipeline are wrapped with MULTI and EXEC + calls. This guarantees all commands executed in the pipeline will be + executed atomically. + + Any command raising an exception does *not* halt the execution of + subsequent commands in the pipeline. Instead, the exception is caught + and its instance is placed into the response list returned by execute(). + Code iterating over the response list should be able to deal with an + instance of an exception as a potential value. In general, these will be + ResponseError exceptions, such as those raised when issuing a command + on a key of a different datatype. 
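+ + A short usage sketch (key names are illustrative; see the notebook example in this patch): + + >>> async with r.pipeline(transaction=True) as pipe: + ... ok1, ok2 = await pipe.set("key1", "value1").set("key2", "value2").execute()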
+ """ + + UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} + + def __init__( + self, + connection_pool: ConnectionPool, + response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT], + transaction: bool, + shard_hint: Optional[str], + ): + self.connection_pool = connection_pool + self.connection = None + self.response_callbacks = response_callbacks + self.is_transaction = transaction + self.shard_hint = shard_hint + self.watching = False + self.command_stack: CommandStackT = [] + self.scripts: Set["Script"] = set() + self.explicit_transaction = False + + async def __aenter__(self: _RedisT) -> _RedisT: + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.reset() + + def __await__(self): + return self._async_self().__await__() + + _DEL_MESSAGE = "Unclosed Pipeline client" + + def __len__(self): + return len(self.command_stack) + + def __bool__(self): + """Pipeline instances should always evaluate to True""" + return True + + async def _async_self(self): + return self + + async def reset(self): + self.command_stack = [] + self.scripts = set() + # make sure to reset the connection state in the event that we were + # watching something + if self.watching and self.connection: + try: + # call this manually since our unwatch or + # immediate_execute_command methods can call reset() + await self.connection.send_command("UNWATCH") + await self.connection.read_response() + except ConnectionError: + # disconnect will also remove any previous WATCHes + if self.connection: + await self.connection.disconnect() + # clean up the other instance attributes + self.watching = False + self.explicit_transaction = False + # we can safely return the connection to the pool here since we're + # sure we're no longer WATCHing anything + if self.connection: + await self.connection_pool.release(self.connection) + self.connection = None + + def multi(self): + """ + Start a transactional block of the pipeline after WATCH commands + are issued. End the transactional block with `execute`. + """ + if self.explicit_transaction: + raise RedisError("Cannot issue nested calls to MULTI") + if self.command_stack: + raise RedisError( + "Commands without an initial WATCH have already " "been issued" + ) + self.explicit_transaction = True + + def execute_command( + self, *args, **kwargs + ) -> Union["Pipeline", Awaitable["Pipeline"]]: + if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: + return self.immediate_execute_command(*args, **kwargs) + return self.pipeline_execute_command(*args, **kwargs) + + async def _disconnect_reset_raise(self, conn, error): + """ + Close the connection, reset watching state and + raise an exception if we were watching, + retry_on_timeout is not set, + or the error is not a TimeoutError + """ + await conn.disconnect() + # if we were already watching a variable, the watch is no longer + # valid since this connection has died. raise a WatchError, which + # indicates the user should retry this transaction. + if self.watching: + await self.reset() + raise WatchError( + "A ConnectionError occurred on while " "watching one or more keys" + ) + # if retry_on_timeout is not set, or the error is not + # a TimeoutError, raise it + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + await self.reset() + raise + + async def immediate_execute_command(self, *args, **options): + """ + Execute a command immediately, but don't auto-retry on a + ConnectionError if we're already WATCHing a variable. 
Used when + issuing WATCH or subsequent commands retrieving their values but before + MULTI is called. + """ + command_name = args[0] + conn = self.connection + # if this is the first call, we need a connection + if not conn: + conn = await self.connection_pool.get_connection( + command_name, self.shard_hint + ) + self.connection = conn + + return await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_reset_raise(conn, error), + ) + + def pipeline_execute_command(self, *args, **options): + """ + Stage a command to be executed when execute() is next called + + Returns the current Pipeline object back so commands can be + chained together, such as: + + pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + + At some other point, you can then run: pipe.execute(), + which will execute all commands queued in the pipe. + """ + self.command_stack.append((args, options)) + return self + + async def _execute_transaction( # noqa: C901 + self, connection: Connection, commands: CommandStackT, raise_on_error + ): + pre: CommandT = (("MULTI",), {}) + post: CommandT = (("EXEC",), {}) + cmds = (pre, *commands, post) + all_cmds = connection.pack_commands( + args for args, options in cmds if EMPTY_RESPONSE not in options + ) + await connection.send_packed_command(all_cmds) + errors = [] + + # parse off the response for MULTI + # NOTE: we need to handle ResponseErrors here and continue + # so that we read all the additional command messages from + # the socket + try: + await self.parse_response(connection, "_") + except ResponseError as err: + errors.append((0, err)) + + # and all the other commands + for i, command in enumerate(commands): + if EMPTY_RESPONSE in command[1]: + errors.append((i, command[1][EMPTY_RESPONSE])) + else: + try: + await self.parse_response(connection, "_") + except ResponseError as err: + self.annotate_exception(err, i + 1, command[0]) + errors.append((i, err)) + + # parse the EXEC. 
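+ # EXEC replies with an array holding one reply per queued command, or a + # nil reply (None here) when a WATCHed key changed and the transaction + # was aborted.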
+ try: + response = await self.parse_response(connection, "_") + except ExecAbortError as err: + if errors: + raise errors[0][1] from err + raise + + # EXEC clears any watched keys + self.watching = False + + if response is None: + raise WatchError("Watched variable changed.") from None + + # put any parse errors into the response + for i, e in errors: + response.insert(i, e) + + if len(response) != len(commands): + if self.connection: + await self.connection.disconnect() + raise ResponseError( + "Wrong number of response items from pipeline execution" + ) from None + + # find any errors in the response and raise if necessary + if raise_on_error: + self.raise_first_error(commands, response) + + # We have to run response callbacks manually + data = [] + for r, cmd in zip(response, commands): + if not isinstance(r, Exception): + args, options = cmd + command_name = args[0] + if command_name in self.response_callbacks: + r = self.response_callbacks[command_name](r, **options) + if inspect.isawaitable(r): + r = await r + data.append(r) + return data + + async def _execute_pipeline( + self, connection: Connection, commands: CommandStackT, raise_on_error: bool + ): + # build up all commands into a single request to increase network perf + all_cmds = connection.pack_commands([args for args, _ in commands]) + await connection.send_packed_command(all_cmds) + + response = [] + for args, options in commands: + try: + response.append( + await self.parse_response(connection, args[0], **options) + ) + except ResponseError as e: + response.append(e) + + if raise_on_error: + self.raise_first_error(commands, response) + return response + + def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]): + for i, r in enumerate(response): + if isinstance(r, ResponseError): + self.annotate_exception(r, i + 1, commands[i][0]) + raise r + + def annotate_exception( + self, exception: Exception, number: int, command: Iterable[object] + ) -> None: + cmd = " ".join(map(safe_str, command)) + msg = f"Command # {number} ({cmd}) of pipeline caused error: {exception.args}" + exception.args = (msg,) + exception.args[1:] + + async def parse_response( + self, connection: Connection, command_name: Union[str, bytes], **options + ): + result = await super().parse_response(connection, command_name, **options) + if command_name in self.UNWATCH_COMMANDS: + self.watching = False + elif command_name == "WATCH": + self.watching = True + return result + + async def load_scripts(self): + # make sure all scripts that are about to be run on this pipeline exist + scripts = list(self.scripts) + immediate = self.immediate_execute_command + shas = [s.sha for s in scripts] + # we can't use the normal script_* methods because they would just + # get buffered in the pipeline. + exists = await immediate("SCRIPT EXISTS", *shas) + if not all(exists): + for s, exist in zip(scripts, exists): + if not exist: + s.sha = await immediate("SCRIPT LOAD", s.script) + + async def _disconnect_raise_reset(self, conn: Connection, error: Exception): + """ + Close the connection, raise an exception if we were watching, + and raise an exception if retry_on_timeout is not set, + or the error is not a TimeoutError + """ + await conn.disconnect() + # if we were watching a variable, the watch is no longer valid + # since this connection has died. raise a WatchError, which + # indicates the user should retry this transaction. 
+ if self.watching: + raise WatchError( + "A ConnectionError occurred while " "watching one or more keys" + ) + # if retry_on_timeout is not set, or the error is not + # a TimeoutError, raise it + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + await self.reset() + raise + + async def execute(self, raise_on_error: bool = True): + """Execute all the commands in the current pipeline""" + stack = self.command_stack + if not stack and not self.watching: + return [] + if self.scripts: + await self.load_scripts() + if self.is_transaction or self.explicit_transaction: + execute = self._execute_transaction + else: + execute = self._execute_pipeline + + conn = self.connection + if not conn: + conn = await self.connection_pool.get_connection("MULTI", self.shard_hint) + # assign to self.connection so reset() releases the connection + # back to the pool after we're done + self.connection = conn + conn = cast(Connection, conn) + + try: + return await conn.retry.call_with_retry( + lambda: execute(conn, stack, raise_on_error), + lambda error: self._disconnect_raise_reset(conn, error), + ) + finally: + await self.reset() + + async def discard(self): + """Flushes all previously queued commands + See: https://redis.io/commands/DISCARD + """ + await self.execute_command("DISCARD") + + async def watch(self, *names: KeyT): + """Watches the values at keys ``names``""" + if self.explicit_transaction: + raise RedisError("Cannot issue a WATCH after a MULTI") + return await self.execute_command("WATCH", *names) + + async def unwatch(self): + """Unwatches all previously specified keys""" + return self.watching and await self.execute_command("UNWATCH") or True diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py new file mode 100644 index 0000000000..ae54d8147a --- /dev/null +++ b/redis/asyncio/connection.py @@ -0,0 +1,1707 @@ +import asyncio +import copy +import enum +import errno +import inspect +import io +import os +import socket +import ssl +import sys +import threading +import weakref +from itertools import chain +from types import MappingProxyType +from typing import ( + Any, + Callable, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) +from urllib.parse import ParseResult, parse_qs, unquote, urlparse + +import async_timeout + +from redis.asyncio.retry import Retry +from redis.backoff import NoBackoff +from redis.compat import Protocol, TypedDict +from redis.exceptions import ( + AuthenticationError, + AuthenticationWrongNumberOfArgsError, + BusyLoadingError, + ChildDeadlockedError, + ConnectionError, + DataError, + ExecAbortError, + InvalidResponse, + ModuleError, + NoPermissionError, + NoScriptError, + ReadOnlyError, + RedisError, + ResponseError, + TimeoutError, +) +from redis.typing import EncodableT, EncodedT +from redis.utils import HIREDIS_AVAILABLE, str_if_bytes + +hiredis = None +if HIREDIS_AVAILABLE: + import hiredis + +NONBLOCKING_EXCEPTION_ERROR_NUMBERS = { + BlockingIOError: errno.EWOULDBLOCK, + ssl.SSLWantReadError: 2, + ssl.SSLWantWriteError: 2, + ssl.SSLError: 2, +} + +NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) + + +SYM_STAR = b"*" +SYM_DOLLAR = b"$" +SYM_CRLF = b"\r\n" +SYM_LF = b"\n" +SYM_EMPTY = b"" + +SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." + + +class _Sentinel(enum.Enum): + sentinel = object() + + +SENTINEL = _Sentinel.sentinel +MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." 
+NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" +MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." +MODULE_EXPORTS_DATA_TYPES_ERROR = ( + "Error unloading module: the module " + "exports one or more module-side data " + "types, can't unload" +) + + +class _HiredisReaderArgs(TypedDict, total=False): + protocolError: Callable[[str], Exception] + replyError: Callable[[str], Exception] + encoding: Optional[str] + errors: Optional[str] + + +class Encoder: + """Encode strings to bytes-like and decode bytes-like to strings""" + + __slots__ = "encoding", "encoding_errors", "decode_responses" + + def __init__(self, encoding: str, encoding_errors: str, decode_responses: bool): + self.encoding = encoding + self.encoding_errors = encoding_errors + self.decode_responses = decode_responses + + def encode(self, value: EncodableT) -> EncodedT: + """Return a bytestring or bytes-like representation of the value""" + if isinstance(value, (bytes, memoryview)): + return value + if isinstance(value, bool): + # special case bool since it is a subclass of int + raise DataError( + "Invalid input of type: 'bool'. " + "Convert to a bytes, string, int or float first." + ) + if isinstance(value, (int, float)): + return repr(value).encode() + if not isinstance(value, str): + # a value we don't know how to deal with. throw an error + typename = value.__class__.__name__ # type: ignore[unreachable] + raise DataError( + f"Invalid input of type: {typename!r}. " + "Convert to a bytes, string, int or float first." + ) + return value.encode(self.encoding, self.encoding_errors) + + def decode(self, value: EncodableT, force=False) -> EncodableT: + """Return a unicode string from the bytes-like representation""" + if self.decode_responses or force: + if isinstance(value, memoryview): + return value.tobytes().decode(self.encoding, self.encoding_errors) + if isinstance(value, bytes): + return value.decode(self.encoding, self.encoding_errors) + return value + + +ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]] + + +class BaseParser: + """Plain Python parsing class""" + + __slots__ = "_stream", "_buffer", "_read_size" + + EXCEPTION_CLASSES: ExceptionMappingT = { + "ERR": { + "max number of clients reached": ConnectionError, + "Client sent AUTH, but no password is set": AuthenticationError, + "invalid password": AuthenticationError, + # some Redis server versions report invalid command syntax + # in lowercase + "wrong number of arguments for 'auth' command": AuthenticationWrongNumberOfArgsError, # noqa: E501 + # some Redis server versions report invalid command syntax + # in uppercase + "wrong number of arguments for 'AUTH' command": AuthenticationWrongNumberOfArgsError, # noqa: E501 + MODULE_LOAD_ERROR: ModuleError, + MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, + NO_SUCH_MODULE_ERROR: ModuleError, + MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, + }, + "EXECABORT": ExecAbortError, + "LOADING": BusyLoadingError, + "NOSCRIPT": NoScriptError, + "READONLY": ReadOnlyError, + "NOAUTH": AuthenticationError, + "NOPERM": NoPermissionError, + } + + def __init__(self, socket_read_size: int): + self._stream: Optional[asyncio.StreamReader] = None + self._buffer: Optional[SocketBuffer] = None + self._read_size = socket_read_size + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + def parse_error(self, response: str) -> ResponseError: + """Parse an error response""" + error_code = response.split(" ")[0] + if 
error_code in self.EXCEPTION_CLASSES: + response = response[len(error_code) + 1 :] + exception_class = self.EXCEPTION_CLASSES[error_code] + if isinstance(exception_class, dict): + exception_class = exception_class.get(response, ResponseError) + return exception_class(response) + return ResponseError(response) + + def on_disconnect(self): + raise NotImplementedError() + + def on_connect(self, connection: "Connection"): + raise NotImplementedError() + + async def can_read(self, timeout: float) -> bool: + raise NotImplementedError() + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: + raise NotImplementedError() + + +class SocketBuffer: + """Async-friendly re-impl of redis-py's SocketBuffer. + + TODO: We're currently passing through two buffers, + the asyncio.StreamReader and this. I imagine we can reduce the layers here + while maintaining compliance with prior art. + """ + + def __init__( + self, + stream_reader: asyncio.StreamReader, + socket_read_size: int, + socket_timeout: Optional[float], + ): + self._stream: Optional[asyncio.StreamReader] = stream_reader + self.socket_read_size = socket_read_size + self.socket_timeout = socket_timeout + self._buffer: Optional[io.BytesIO] = io.BytesIO() + # number of bytes written to the buffer from the socket + self.bytes_written = 0 + # number of bytes read from the buffer + self.bytes_read = 0 + + @property + def length(self): + return self.bytes_written - self.bytes_read + + async def _read_from_socket( + self, + length: Optional[int] = None, + timeout: Union[float, None, _Sentinel] = SENTINEL, + raise_on_timeout: bool = True, + ) -> bool: + buf = self._buffer + if buf is None or self._stream is None: + raise RedisError("Buffer is closed.") + buf.seek(self.bytes_written) + marker = 0 + timeout = timeout if timeout is not SENTINEL else self.socket_timeout + + try: + while True: + async with async_timeout.timeout(timeout): + data = await self._stream.read(self.socket_read_size) + # an empty string indicates the server shutdown the socket + if isinstance(data, bytes) and len(data) == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + buf.write(data) + data_length = len(data) + self.bytes_written += data_length + marker += data_length + + if length is not None and length > marker: + continue + return True + except (socket.timeout, asyncio.TimeoutError): + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. 
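+                # the table above maps each nonblocking exception type to the
+                # errno that means "no data available yet" (EWOULDBLOCK for
+                # BlockingIOError; the ssl entries use a placeholder of 2).
+                # a miss yields -1, which never matches a real errno, so we
+                # fall through to raising a ConnectionError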
+ allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + + async def can_read(self, timeout: float) -> bool: + return bool(self.length) or await self._read_from_socket( + timeout=timeout, raise_on_timeout=False + ) + + async def read(self, length: int) -> bytes: + length = length + 2 # make sure to read the \r\n terminator + # make sure we've read enough data from the socket + if length > self.length: + await self._read_from_socket(length - self.length) + + if self._buffer is None: + raise RedisError("Buffer is closed.") + + self._buffer.seek(self.bytes_read) + data = self._buffer.read(length) + self.bytes_read += len(data) + + # purge the buffer when we've consumed it all so it doesn't + # grow forever + if self.bytes_read == self.bytes_written: + self.purge() + + return data[:-2] + + async def readline(self) -> bytes: + buf = self._buffer + if buf is None: + raise RedisError("Buffer is closed.") + + buf.seek(self.bytes_read) + data = buf.readline() + while not data.endswith(SYM_CRLF): + # there's more data in the socket that we need + await self._read_from_socket() + buf.seek(self.bytes_read) + data = buf.readline() + + self.bytes_read += len(data) + + # purge the buffer when we've consumed it all so it doesn't + # grow forever + if self.bytes_read == self.bytes_written: + self.purge() + + return data[:-2] + + def purge(self): + if self._buffer is None: + raise RedisError("Buffer is closed.") + + self._buffer.seek(0) + self._buffer.truncate() + self.bytes_written = 0 + self.bytes_read = 0 + + def close(self): + try: + self.purge() + self._buffer.close() # type: ignore[union-attr] + except Exception: + # issue #633 suggests the purge/close somehow raised a + # BadFileDescriptor error. Perhaps the client ran out of + # memory or something else? It's probably OK to ignore + # any error being raised from purge/close since we're + # removing the reference to the instance below. 
+ pass + self._buffer = None + self._stream = None + + +class PythonParser(BaseParser): + """Plain Python parsing class""" + + __slots__ = BaseParser.__slots__ + ("encoder",) + + def __init__(self, socket_read_size: int): + super().__init__(socket_read_size) + self.encoder: Optional[Encoder] = None + + def on_connect(self, connection: "Connection"): + """Called when the stream connects""" + self._stream = connection._reader + if self._stream is None: + raise RedisError("Buffer is closed.") + + self._buffer = SocketBuffer( + self._stream, self._read_size, connection.socket_timeout + ) + self.encoder = connection.encoder + + def on_disconnect(self): + """Called when the stream disconnects""" + if self._stream is not None: + self._stream = None + if self._buffer is not None: + self._buffer.close() + self._buffer = None + self.encoder = None + + async def can_read(self, timeout: float): + return self._buffer and bool(await self._buffer.can_read(timeout)) + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None]: + if not self._buffer or not self.encoder: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + raw = await self._buffer.readline() + if not raw: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + response: Any + byte, response = raw[:1], raw[1:] + + if byte not in (b"-", b"+", b":", b"$", b"*"): + raise InvalidResponse(f"Protocol Error: {raw!r}") + + # server returned an error + if byte == b"-": + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. 
+ return error + # single value + elif byte == b"+": + pass + # int value + elif byte == b":": + response = int(response) + # bulk response + elif byte == b"$": + length = int(response) + if length == -1: + return None + response = await self._buffer.read(length) + # multi-bulk response + elif byte == b"*": + length = int(response) + if length == -1: + return None + response = [ + (await self.read_response(disable_decoding)) for _ in range(length) + ] + if isinstance(response, bytes) and disable_decoding is False: + response = self.encoder.decode(response) + return response + + +class HiredisParser(BaseParser): + """Parser class for connections using Hiredis""" + + __slots__ = BaseParser.__slots__ + ("_next_response", "_reader", "_socket_timeout") + + _next_response: bool + + def __init__(self, socket_read_size: int): + if not HIREDIS_AVAILABLE: + raise RedisError("Hiredis is not available.") + super().__init__(socket_read_size=socket_read_size) + self._reader: Optional[hiredis.Reader] = None + self._socket_timeout: Optional[float] = None + + def on_connect(self, connection: "Connection"): + self._stream = connection._reader + kwargs: _HiredisReaderArgs = { + "protocolError": InvalidResponse, + "replyError": self.parse_error, + } + if connection.encoder.decode_responses: + kwargs["encoding"] = connection.encoder.encoding + kwargs["errors"] = connection.encoder.encoding_errors + + self._reader = hiredis.Reader(**kwargs) + self._next_response = False + self._socket_timeout = connection.socket_timeout + + def on_disconnect(self): + self._stream = None + self._reader = None + self._next_response = False + + async def can_read(self, timeout: float): + if not self._reader: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + if self._next_response is False: + self._next_response = self._reader.gets() + if self._next_response is False: + return await self.read_from_socket(timeout=timeout, raise_on_timeout=False) + return True + + async def read_from_socket( + self, + timeout: Union[float, None, _Sentinel] = SENTINEL, + raise_on_timeout: bool = True, + ): + if self._stream is None or self._reader is None: + raise RedisError("Parser already closed.") + + timeout = self._socket_timeout if timeout is SENTINEL else timeout + try: + async with async_timeout.timeout(timeout): + buffer = await self._stream.read(self._read_size) + if not isinstance(buffer, bytes) or len(buffer) == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + self._reader.feed(buffer) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. + return True + except asyncio.CancelledError: + raise + except (socket.timeout, asyncio.TimeoutError): + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") from None + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. 
+ allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, List[EncodableT]]: + if not self._stream or not self._reader: + self.on_disconnect() + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + + response: Union[ + EncodableT, ConnectionError, List[Union[EncodableT, ConnectionError]] + ] + # _next_response might be cached from a can_read() call + if self._next_response is not False: + response = self._next_response + self._next_response = False + return response + + response = self._reader.gets() + while response is False: + await self.read_from_socket() + response = self._reader.gets() + + # if the response is a ConnectionError or the response is a list and + # the first item is a ConnectionError, raise it as something bad + # happened + if isinstance(response, ConnectionError): + raise response + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ConnectionError) + ): + raise response[0] + # cast as there won't be a ConnectionError here. + return cast(Union[EncodableT, List[EncodableT]], response) + + +DefaultParser: Type[Union[PythonParser, HiredisParser]] +if HIREDIS_AVAILABLE: + DefaultParser = HiredisParser +else: + DefaultParser = PythonParser + + +class ConnectCallbackProtocol(Protocol): + def __call__(self, connection: "Connection"): + ... + + +class AsyncConnectCallbackProtocol(Protocol): + async def __call__(self, connection: "Connection"): + ... + + +ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol] + + +class Connection: + """Manages TCP communication to and from a Redis server""" + + __slots__ = ( + "pid", + "host", + "port", + "db", + "username", + "client_name", + "password", + "socket_timeout", + "socket_connect_timeout", + "socket_keepalive", + "socket_keepalive_options", + "socket_type", + "redis_connect_func", + "retry_on_timeout", + "health_check_interval", + "next_health_check", + "last_active_at", + "encoder", + "ssl_context", + "_reader", + "_writer", + "_parser", + "_connect_callbacks", + "_buffer_cutoff", + "_lock", + "_socket_read_size", + "__dict__", + ) + + def __init__( + self, + *, + host: str = "localhost", + port: Union[str, int] = 6379, + db: Union[str, int] = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: bool = False, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + socket_type: int = 0, + retry_on_timeout: bool = False, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + parser_class: Type[BaseParser] = DefaultParser, + socket_read_size: int = 65536, + health_check_interval: float = 0, + client_name: Optional[str] = None, + username: Optional[str] = None, + retry: Optional[Retry] = None, + redis_connect_func: Optional[ConnectCallbackT] = None, + encoder_class: Type[Encoder] = Encoder, + ): + self.pid = os.getpid() + self.host = host + self.port = int(port) + self.db = db + self.username = username + self.client_name = client_name + self.password = password + self.socket_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None + self.socket_keepalive = socket_keepalive + self.socket_keepalive_options = 
socket_keepalive_options or {} + self.socket_type = socket_type + self.retry_on_timeout = retry_on_timeout + if retry_on_timeout: + if retry is None: + self.retry = Retry(NoBackoff(), 1) + else: + # deep-copy the Retry object as it is mutable + self.retry = copy.deepcopy(retry) + else: + self.retry = Retry(NoBackoff(), 0) + self.health_check_interval = health_check_interval + self.next_health_check: float = -1 + self.ssl_context: Optional[RedisSSLContext] = None + self.encoder = encoder_class(encoding, encoding_errors, decode_responses) + self.redis_connect_func = redis_connect_func + self._reader: Optional[asyncio.StreamReader] = None + self._writer: Optional[asyncio.StreamWriter] = None + self._socket_read_size = socket_read_size + self.set_parser(parser_class) + self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = [] + self._buffer_cutoff = 6000 + self._lock = asyncio.Lock() + + def __repr__(self): + repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) + return f"{self.__class__.__name__}<{repr_args}>" + + def repr_pieces(self): + pieces = [("host", self.host), ("port", self.port), ("db", self.db)] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + def __del__(self): + try: + if self.is_connected: + loop = asyncio.get_event_loop() + coro = self.disconnect() + if loop.is_running(): + loop.create_task(coro) + else: + loop.run_until_complete(coro) + except Exception: + pass + + @property + def is_connected(self): + return bool(self._reader and self._writer) + + def register_connect_callback(self, callback): + self._connect_callbacks.append(weakref.WeakMethod(callback)) + + def clear_connect_callbacks(self): + self._connect_callbacks = [] + + def set_parser(self, parser_class): + """ + Creates a new instance of parser_class with socket size: + _socket_read_size and assigns it to the parser for the connection + :param parser_class: The required parser class + """ + self._parser = parser_class(socket_read_size=self._socket_read_size) + + async def connect(self): + """Connects to the Redis server if not already connected""" + if self.is_connected: + return + try: + await self._connect() + except asyncio.CancelledError: + raise + except (socket.timeout, asyncio.TimeoutError): + raise TimeoutError("Timeout connecting to server") + except OSError as e: + raise ConnectionError(self._error_message(e)) + except Exception as exc: + raise ConnectionError(exc) from exc + + try: + if self.redis_connect_func is None: + # Use the default on_connect function + await self.on_connect() + else: + # Use the passed function redis_connect_func + await self.redis_connect_func(self) if asyncio.iscoroutinefunction( + self.redis_connect_func + ) else self.redis_connect_func(self) + except RedisError: + # clean up after any error in on_connect + await self.disconnect() + raise + + # run any user callbacks. 
+        # right now the only internal callback
+        # is for pubsub channel/pattern resubscription
+        for ref in self._connect_callbacks:
+            callback = ref()
+            task = callback(self)
+            if task and inspect.isawaitable(task):
+                await task
+
+    async def _connect(self):
+        """Create a TCP socket connection"""
+        async with async_timeout.timeout(self.socket_connect_timeout):
+            reader, writer = await asyncio.open_connection(
+                host=self.host,
+                port=self.port,
+                ssl=self.ssl_context.get() if self.ssl_context else None,
+            )
+        self._reader = reader
+        self._writer = writer
+        sock = writer.transport.get_extra_info("socket")
+        if sock is not None:
+            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+            try:
+                # TCP_KEEPALIVE
+                if self.socket_keepalive:
+                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+                    for k, v in self.socket_keepalive_options.items():
+                        sock.setsockopt(socket.SOL_TCP, k, v)
+
+            except (OSError, TypeError):
+                # `socket_keepalive_options` might contain invalid options
+                # causing an error. Do not leave the connection open.
+                writer.close()
+                raise
+
+    def _error_message(self, exception):
+        # args for socket.error can either be (errno, "message")
+        # or just "message"
+        if len(exception.args) == 1:
+            return f"Error connecting to {self.host}:{self.port}. {exception.args[0]}."
+        else:
+            return (
+                f"Error {exception.args[0]} connecting to {self.host}:{self.port}. "
+                f"{exception.args[1]}."
+            )
+
+    async def on_connect(self):
+        """Initialize the connection, authenticate and select a database"""
+        self._parser.on_connect(self)
+
+        # if username and/or password are set, authenticate
+        if self.username or self.password:
+            auth_args: Union[Tuple[str], Tuple[str, str]]
+            if self.username:
+                auth_args = (self.username, self.password or "")
+            else:
+                # Mypy bug: https://github.com/python/mypy/issues/10944
+                auth_args = (self.password or "",)
+            # avoid checking health here -- PING will fail if we try
+            # to check the health prior to the AUTH
+            await self.send_command("AUTH", *auth_args, check_health=False)
+
+            try:
+                auth_response = await self.read_response()
+            except AuthenticationWrongNumberOfArgsError:
+                # a username and password were specified but the Redis
+                # server seems to be < 6.0.0 which expects a single password
+                # arg. retry auth with just the password.
+ # https://github.com/andymccurdy/redis-py/issues/1274 + await self.send_command("AUTH", self.password, check_health=False) + auth_response = await self.read_response() + + if str_if_bytes(auth_response) != "OK": + raise AuthenticationError("Invalid Username or Password") + + # if a client_name is given, set it + if self.client_name: + await self.send_command("CLIENT", "SETNAME", self.client_name) + if str_if_bytes(await self.read_response()) != "OK": + raise ConnectionError("Error setting client name") + + # if a database is specified, switch to it + if self.db: + await self.send_command("SELECT", self.db) + if str_if_bytes(await self.read_response()) != "OK": + raise ConnectionError("Invalid Database") + + async def disconnect(self): + """Disconnects from the Redis server""" + try: + async with async_timeout.timeout(self.socket_connect_timeout): + self._parser.on_disconnect() + if not self.is_connected: + return + try: + if os.getpid() == self.pid: + self._writer.close() # type: ignore[union-attr] + # py3.6 doesn't have this method + if hasattr(self._writer, "wait_closed"): + await self._writer.wait_closed() # type: ignore[union-attr] + except OSError: + pass + self._reader = None + self._writer = None + except asyncio.TimeoutError: + raise TimeoutError( + f"Timed out closing connection after {self.socket_connect_timeout}" + ) from None + + async def _send_ping(self): + """Send PING, expect PONG in return""" + await self.send_command("PING", check_health=False) + if str_if_bytes(await self.read_response()) != "PONG": + raise ConnectionError("Bad response from PING health check") + + async def _ping_failed(self, error): + """Function to call when PING fails""" + await self.disconnect() + + async def check_health(self): + """Check the health of the connection with a PING/PONG""" + if sys.version_info[0:2] == (3, 6): + func = asyncio.get_event_loop + else: + func = asyncio.get_running_loop + + if self.health_check_interval and func().time() > self.next_health_check: + await self.retry.call_with_retry(self._send_ping, self._ping_failed) + + async def _send_packed_command(self, command: Iterable[bytes]) -> None: + if self._writer is None: + raise RedisError("Connection already closed.") + + self._writer.writelines(command) + await self._writer.drain() + + async def send_packed_command( + self, + command: Union[bytes, str, Iterable[bytes]], + check_health: bool = True, + ): + """Send an already packed command to the Redis server""" + if not self._writer: + await self.connect() + # guard against health check recursion + if check_health: + await self.check_health() + try: + if isinstance(command, str): + command = command.encode() + if isinstance(command, bytes): + command = [command] + await asyncio.wait_for( + self._send_packed_command(command), + self.socket_timeout, + ) + except asyncio.TimeoutError: + await self.disconnect() + raise TimeoutError("Timeout writing to socket") from None + except OSError as e: + await self.disconnect() + if len(e.args) == 1: + err_no, errmsg = "UNKNOWN", e.args[0] + else: + err_no = e.args[0] + errmsg = e.args[1] + raise ConnectionError( + f"Error {err_no} while writing to socket. {errmsg}." 
+ ) from e + except BaseException: + await self.disconnect() + raise + + async def send_command(self, *args, **kwargs): + """Pack and send a command to the Redis server""" + if not self.is_connected: + await self.connect() + await self.send_packed_command( + self.pack_command(*args), check_health=kwargs.get("check_health", True) + ) + + async def can_read(self, timeout: float = 0): + """Poll the socket to see if there's data that can be read.""" + if not self.is_connected: + await self.connect() + return await self._parser.can_read(timeout) + + async def read_response(self, disable_decoding: bool = False): + """Read the response from a previously sent command""" + try: + async with self._lock: + async with async_timeout.timeout(self.socket_timeout): + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) + except asyncio.TimeoutError: + await self.disconnect() + raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") + except OSError as e: + await self.disconnect() + raise ConnectionError( + f"Error while reading from {self.host}:{self.port} : {e.args}" + ) + except BaseException: + await self.disconnect() + raise + + if self.health_check_interval: + if sys.version_info[0:2] == (3, 6): + func = asyncio.get_event_loop + else: + func = asyncio.get_running_loop + self.next_health_check = func().time() + self.health_check_interval + + if isinstance(response, ResponseError): + raise response from None + return response + + def pack_command(self, *args: EncodableT) -> List[bytes]: + """Pack a series of arguments into the Redis protocol""" + output = [] + # the client might have included 1 or more literal arguments in + # the command name, e.g., 'CONFIG GET'. The Redis server expects these + # arguments to be sent separately, so split the first argument + # manually. These arguments should be bytestrings so that they are + # not encoded. 
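+        # the first argument must already be a command name (str or bytes);
+        # a float here means the caller passed arguments in the wrong order,
+        # which the assertion below catches early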
+ assert not isinstance(args[0], float) + if isinstance(args[0], str): + args = tuple(args[0].encode().split()) + args[1:] + elif b" " in args[0]: + args = tuple(args[0].split()) + args[1:] + + buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF)) + + buffer_cutoff = self._buffer_cutoff + for arg in map(self.encoder.encode, args): + # to avoid large string mallocs, chunk the command into the + # output list if we're sending large values or memoryviews + arg_length = len(arg) + if ( + len(buff) > buffer_cutoff + or arg_length > buffer_cutoff + or isinstance(arg, memoryview) + ): + buff = SYM_EMPTY.join( + (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) + ) + output.append(buff) + output.append(arg) + buff = SYM_CRLF + else: + buff = SYM_EMPTY.join( + ( + buff, + SYM_DOLLAR, + str(arg_length).encode(), + SYM_CRLF, + arg, + SYM_CRLF, + ) + ) + output.append(buff) + return output + + def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]: + """Pack multiple commands into the Redis protocol""" + output: List[bytes] = [] + pieces: List[bytes] = [] + buffer_length = 0 + buffer_cutoff = self._buffer_cutoff + + for cmd in commands: + for chunk in self.pack_command(*cmd): + chunklen = len(chunk) + if ( + buffer_length > buffer_cutoff + or chunklen > buffer_cutoff + or isinstance(chunk, memoryview) + ): + output.append(SYM_EMPTY.join(pieces)) + buffer_length = 0 + pieces = [] + + if chunklen > buffer_cutoff or isinstance(chunk, memoryview): + output.append(chunk) + else: + pieces.append(chunk) + buffer_length += chunklen + + if pieces: + output.append(SYM_EMPTY.join(pieces)) + return output + + +class SSLConnection(Connection): + def __init__( + self, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_ca_certs: Optional[str] = None, + ssl_check_hostname: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.ssl_context: RedisSSLContext = RedisSSLContext( + keyfile=ssl_keyfile, + certfile=ssl_certfile, + cert_reqs=ssl_cert_reqs, + ca_certs=ssl_ca_certs, + check_hostname=ssl_check_hostname, + ) + + @property + def keyfile(self): + return self.ssl_context.keyfile + + @property + def certfile(self): + return self.ssl_context.certfile + + @property + def cert_reqs(self): + return self.ssl_context.cert_reqs + + @property + def ca_certs(self): + return self.ssl_context.ca_certs + + @property + def check_hostname(self): + return self.ssl_context.check_hostname + + +class RedisSSLContext: + __slots__ = ( + "keyfile", + "certfile", + "cert_reqs", + "ca_certs", + "context", + "check_hostname", + ) + + def __init__( + self, + keyfile: Optional[str] = None, + certfile: Optional[str] = None, + cert_reqs: Optional[str] = None, + ca_certs: Optional[str] = None, + check_hostname: bool = False, + ): + self.keyfile = keyfile + self.certfile = certfile + if cert_reqs is None: + self.cert_reqs = ssl.CERT_NONE + elif isinstance(cert_reqs, str): + CERT_REQS = { + "none": ssl.CERT_NONE, + "optional": ssl.CERT_OPTIONAL, + "required": ssl.CERT_REQUIRED, + } + if cert_reqs not in CERT_REQS: + raise RedisError( + f"Invalid SSL Certificate Requirements Flag: {cert_reqs}" + ) + self.cert_reqs = CERT_REQS[cert_reqs] + self.ca_certs = ca_certs + self.check_hostname = check_hostname + self.context: Optional[ssl.SSLContext] = None + + def get(self) -> ssl.SSLContext: + if not self.context: + context = ssl.create_default_context() + context.check_hostname = self.check_hostname + context.verify_mode = 
self.cert_reqs + if self.certfile and self.keyfile: + context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile) + if self.ca_certs: + context.load_verify_locations(self.ca_certs) + self.context = context + return self.context + + +class UnixDomainSocketConnection(Connection): # lgtm [py/missing-call-to-init] + def __init__( + self, + *, + path: str = "", + db: Union[str, int] = 0, + username: Optional[str] = None, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + retry_on_timeout: bool = False, + parser_class: Type[BaseParser] = DefaultParser, + socket_read_size: int = 65536, + health_check_interval: float = 0.0, + client_name: str = None, + retry: Optional[Retry] = None, + ): + """ + Initialize a new UnixDomainSocketConnection. + To specify a retry policy, first set `retry_on_timeout` to `True` + then set `retry` to a valid `Retry` object + """ + self.pid = os.getpid() + self.path = path + self.db = db + self.username = username + self.client_name = client_name + self.password = password + self.socket_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None + self.retry_on_timeout = retry_on_timeout + if retry_on_timeout: + if retry is None: + self.retry = Retry(NoBackoff(), 1) + else: + # deep-copy the Retry object as it is mutable + self.retry = copy.deepcopy(retry) + else: + self.retry = Retry(NoBackoff(), 0) + self.health_check_interval = health_check_interval + self.next_health_check = -1 + self.encoder = Encoder(encoding, encoding_errors, decode_responses) + self._sock = None + self._reader = None + self._writer = None + self._socket_read_size = socket_read_size + self.set_parser(parser_class) + self._connect_callbacks = [] + self._buffer_cutoff = 6000 + self._lock = asyncio.Lock() + + def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]: + pieces = [ + ("path", self.path), + ("db", self.db), + ] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + async def _connect(self): + async with async_timeout.timeout(self.socket_connect_timeout): + reader, writer = await asyncio.open_unix_connection(path=self.path) + self._reader = reader + self._writer = writer + await self.on_connect() + + def _error_message(self, exception): + # args for socket.error can either be (errno, "message") + # or just "message" + if len(exception.args) == 1: + return f"Error connecting to unix socket: {self.path}. {exception.args[0]}." + else: + return ( + f"Error {exception.args[0]} connecting to unix socket: " + f"{self.path}. {exception.args[1]}." 
+ ) + + +FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") + + +def to_bool(value) -> Optional[bool]: + if value is None or value == "": + return None + if isinstance(value, str) and value.upper() in FALSE_STRINGS: + return False + return bool(value) + + +URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType( + { + "db": int, + "socket_timeout": float, + "socket_connect_timeout": float, + "socket_keepalive": to_bool, + "retry_on_timeout": to_bool, + "max_connections": int, + "health_check_interval": int, + "ssl_check_hostname": to_bool, + } +) + + +class ConnectKwargs(TypedDict, total=False): + username: str + password: str + connection_class: Type[Connection] + host: str + port: int + db: int + path: str + + +def parse_url(url: str) -> ConnectKwargs: + parsed: ParseResult = urlparse(url) + kwargs: ConnectKwargs = {} + + for name, value_list in parse_qs(parsed.query).items(): + if value_list and len(value_list) > 0: + value = unquote(value_list[0]) + parser = URL_QUERY_ARGUMENT_PARSERS.get(name) + if parser: + try: + # We can't type this. + kwargs[name] = parser(value) # type: ignore[misc] + except (TypeError, ValueError): + raise ValueError(f"Invalid value for `{name}` in connection URL.") + else: + kwargs[name] = value # type: ignore[misc] + + if parsed.username: + kwargs["username"] = unquote(parsed.username) + if parsed.password: + kwargs["password"] = unquote(parsed.password) + + # We only support redis://, rediss:// and unix:// schemes. + if parsed.scheme == "unix": + if parsed.path: + kwargs["path"] = unquote(parsed.path) + kwargs["connection_class"] = UnixDomainSocketConnection + + elif parsed.scheme in ("redis", "rediss"): + if parsed.hostname: + kwargs["host"] = unquote(parsed.hostname) + if parsed.port: + kwargs["port"] = int(parsed.port) + + # If there's a path argument, use it as the db argument if a + # querystring value wasn't specified + if parsed.path and "db" not in kwargs: + try: + kwargs["db"] = int(unquote(parsed.path).replace("/", "")) + except (AttributeError, ValueError): + pass + + if parsed.scheme == "rediss": + kwargs["connection_class"] = SSLConnection + else: + valid_schemes = "redis://, rediss://, unix://" + raise ValueError( + f"Redis URL must specify one of the following schemes ({valid_schemes})" + ) + + return kwargs + + +_CP = TypeVar("_CP", bound="ConnectionPool") + + +class ConnectionPool: + """ + Create a connection pool. ``If max_connections`` is set, then this + object raises :py:class:`~redis.ConnectionError` when the pool's + limit is reached. + + By default, TCP connections are created unless ``connection_class`` + is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for + unix sockets. + + Any additional keyword arguments are passed to the constructor of + ``connection_class``. + """ + + @classmethod + def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: + """ + Return a connection pool configured from the given URL. + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[[username]:[password]]@/path/to/socket.sock?db=0 + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + + - ``unix://``: creates a Unix Domain Socket connection. 
+ + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + """ + url_options = parse_url(url) + kwargs.update(url_options) + return cls(**kwargs) + + def __init__( + self, + connection_class: Type[Connection] = Connection, + max_connections: Optional[int] = None, + **connection_kwargs, + ): + max_connections = max_connections or 2 ** 31 + if not isinstance(max_connections, int) or max_connections < 0: + raise ValueError('"max_connections" must be a positive integer') + + self.connection_class = connection_class + self.connection_kwargs = connection_kwargs + self.max_connections = max_connections + + # a lock to protect the critical section in _checkpid(). + # this lock is acquired when the process id changes, such as + # after a fork. during this time, multiple threads in the child + # process could attempt to acquire this lock. the first thread + # to acquire the lock will reset the data structures and lock + # object of this pool. subsequent threads acquiring this lock + # will notice the first thread already did the work and simply + # release the lock. + self._fork_lock = threading.Lock() + self._lock = asyncio.Lock() + self._created_connections: int + self._available_connections: List[Connection] + self._in_use_connections: Set[Connection] + self.reset() # lgtm [py/init-calls-subclass] + self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder) + + def __repr__(self): + return ( + f"{self.__class__.__name__}" + f"<{self.connection_class(**self.connection_kwargs)!r}>" + ) + + def reset(self): + self._lock = asyncio.Lock() + self._created_connections = 0 + self._available_connections = [] + self._in_use_connections = set() + + # this must be the last operation in this method. while reset() is + # called when holding _fork_lock, other threads in this process + # can call _checkpid() which compares self.pid and os.getpid() without + # holding any lock (for performance reasons). keeping this assignment + # as the last operation ensures that those other threads will also + # notice a pid difference and block waiting for the first thread to + # release _fork_lock. when each of these threads eventually acquire + # _fork_lock, they will notice that another thread already called + # reset() and they will immediately release _fork_lock and continue on. + self.pid = os.getpid() + + def _checkpid(self): + # _checkpid() attempts to keep ConnectionPool fork-safe on modern + # systems. this is called by all ConnectionPool methods that + # manipulate the pool's state such as get_connection() and release(). 
+ # + # _checkpid() determines whether the process has forked by comparing + # the current process id to the process id saved on the ConnectionPool + # instance. if these values are the same, _checkpid() simply returns. + # + # when the process ids differ, _checkpid() assumes that the process + # has forked and that we're now running in the child process. the child + # process cannot use the parent's file descriptors (e.g., sockets). + # therefore, when _checkpid() sees the process id change, it calls + # reset() in order to reinitialize the child's ConnectionPool. this + # will cause the child to make all new connection objects. + # + # _checkpid() is protected by self._fork_lock to ensure that multiple + # threads in the child process do not call reset() multiple times. + # + # there is an extremely small chance this could fail in the following + # scenario: + # 1. process A calls _checkpid() for the first time and acquires + # self._fork_lock. + # 2. while holding self._fork_lock, process A forks (the fork() + # could happen in a different thread owned by process A) + # 3. process B (the forked child process) inherits the + # ConnectionPool's state from the parent. that state includes + # a locked _fork_lock. process B will not be notified when + # process A releases the _fork_lock and will thus never be + # able to acquire the _fork_lock. + # + # to mitigate this possible deadlock, _checkpid() will only wait 5 + # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in + # that time it is assumed that the child is deadlocked and a + # redis.ChildDeadlockedError error is raised. + if self.pid != os.getpid(): + acquired = self._fork_lock.acquire(timeout=5) + if not acquired: + raise ChildDeadlockedError + # reset() the instance for the new process if another thread + # hasn't already done so + try: + if self.pid != os.getpid(): + self.reset() + finally: + self._fork_lock.release() + + async def get_connection(self, command_name, *keys, **options): + """Get a connection from the pool""" + self._checkpid() + async with self._lock: + try: + connection = self._available_connections.pop() + except IndexError: + connection = self.make_connection() + self._in_use_connections.add(connection) + + try: + # ensure this connection is connected to Redis + await connection.connect() + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the + # pool before all data has been read or the socket has been + # closed. either way, reconnect and verify everything is good. 
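+        # illustrative caller-side sketch (assumes ``pool`` is a
+        # ConnectionPool instance; not executed here):
+        #
+        #     conn = await pool.get_connection("PING")
+        #     try:
+        #         await conn.send_command("PING")
+        #         reply = await conn.read_response()
+        #     finally:
+        #         await pool.release(conn)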
+ try: + if await connection.can_read(): + raise ConnectionError("Connection has data") from None + except ConnectionError: + await connection.disconnect() + await connection.connect() + if await connection.can_read(): + raise ConnectionError("Connection not ready") from None + except BaseException: + # release the connection back to the pool so that we don't + # leak it + await self.release(connection) + raise + + return connection + + def get_encoder(self): + """Return an encoder based on encoding settings""" + kwargs = self.connection_kwargs + return self.encoder_class( + encoding=kwargs.get("encoding", "utf-8"), + encoding_errors=kwargs.get("encoding_errors", "strict"), + decode_responses=kwargs.get("decode_responses", False), + ) + + def make_connection(self): + """Create a new connection""" + if self._created_connections >= self.max_connections: + raise ConnectionError("Too many connections") + self._created_connections += 1 + return self.connection_class(**self.connection_kwargs) + + async def release(self, connection: Connection): + """Releases the connection back to the pool""" + self._checkpid() + async with self._lock: + try: + self._in_use_connections.remove(connection) + except KeyError: + # Gracefully fail when a connection is returned to this pool + # that the pool doesn't actually own + pass + + if self.owns_connection(connection): + self._available_connections.append(connection) + else: + # pool doesn't own this connection. do not add it back + # to the pool and decrement the count so that another + # connection can take its place if needed + self._created_connections -= 1 + await connection.disconnect() + return + + def owns_connection(self, connection: Connection): + return connection.pid == self.pid + + async def disconnect(self, inuse_connections: bool = True): + """ + Disconnects connections in the pool + + If ``inuse_connections`` is True, disconnect connections that are + current in use, potentially by other tasks. Otherwise only disconnect + connections that are idle in the pool. + """ + self._checkpid() + async with self._lock: + if inuse_connections: + connections: Iterable[Connection] = chain( + self._available_connections, self._in_use_connections + ) + else: + connections = self._available_connections + resp = await asyncio.gather( + *(connection.disconnect() for connection in connections), + return_exceptions=True, + ) + exc = next((r for r in resp if isinstance(r, BaseException)), None) + if exc: + raise exc + + +class BlockingConnectionPool(ConnectionPool): + """ + Thread-safe blocking connection pool:: + + >>> from redis.client import Redis + >>> client = Redis(connection_pool=BlockingConnectionPool()) + + It performs the same function as the default + :py:class:`~redis.ConnectionPool` implementation, in that, + it maintains a pool of reusable connections that can be shared by + multiple redis clients (safely across threads if required). + + The difference is that, in the event that a client tries to get a + connection from the pool when all of connections are in use, rather than + raising a :py:class:`~redis.ConnectionError` (as the default + :py:class:`~redis.ConnectionPool` implementation does), it + makes the client wait ("blocks") for a specified number of seconds until + a connection becomes available. 
+
+    Use ``max_connections`` to increase / decrease the pool size::
+
+        >>> pool = BlockingConnectionPool(max_connections=10)
+
+    Use ``timeout`` to tell it either how many seconds to wait for a connection
+    to become available, or to block forever::
+
+        >>> # Block forever.
+        >>> pool = BlockingConnectionPool(timeout=None)
+
+        >>> # Raise a ``ConnectionError`` after five seconds if a connection is
+        >>> # not available.
+        >>> pool = BlockingConnectionPool(timeout=5)
+    """
+
+    def __init__(
+        self,
+        max_connections: int = 50,
+        timeout: Optional[int] = 20,
+        connection_class: Type[Connection] = Connection,
+        queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,
+        **connection_kwargs,
+    ):
+
+        self.queue_class = queue_class
+        self.timeout = timeout
+        self._connections: List[Connection]
+        super().__init__(
+            connection_class=connection_class,
+            max_connections=max_connections,
+            **connection_kwargs,
+        )
+
+    def reset(self):
+        # Create and fill up a thread-safe queue with ``None`` values.
+        self.pool = self.queue_class(self.max_connections)
+        while True:
+            try:
+                self.pool.put_nowait(None)
+            except asyncio.QueueFull:
+                break
+
+        # Keep a list of actual connection instances so that we can
+        # disconnect them later.
+        self._connections = []
+
+        # this must be the last operation in this method. while reset() is
+        # called when holding _fork_lock, other threads in this process
+        # can call _checkpid() which compares self.pid and os.getpid() without
+        # holding any lock (for performance reasons). keeping this assignment
+        # as the last operation ensures that those other threads will also
+        # notice a pid difference and block waiting for the first thread to
+        # release _fork_lock. when each of these threads eventually acquire
+        # _fork_lock, they will notice that another thread already called
+        # reset() and they will immediately release _fork_lock and continue on.
+        self.pid = os.getpid()
+
+    def make_connection(self):
+        """Make a fresh connection."""
+        connection = self.connection_class(**self.connection_kwargs)
+        self._connections.append(connection)
+        return connection
+
+    async def get_connection(self, command_name, *keys, **options):
+        """
+        Get a connection, blocking for ``self.timeout`` until a connection
+        is available from the pool.
+
+        If the connection returned is ``None``, a new connection is created.
+        Because we use a last-in first-out queue, the existing connections
+        (having been returned to the pool after the initial ``None`` values
+        were added) will be returned before ``None`` values. This means we only
+        create new connections when we need to, i.e.: the actual number of
+        connections will only increase in response to demand.
+        """
+        # Make sure we haven't changed process.
+        self._checkpid()
+
+        # Try and get a connection from the pool. If one isn't available within
+        # self.timeout then raise a ``ConnectionError``.
+        connection = None
+        try:
+            async with async_timeout.timeout(self.timeout):
+                connection = await self.pool.get()
+        except (asyncio.QueueEmpty, asyncio.TimeoutError):
+            # Note that this is not caught by the redis client and will be
+            # raised unless handled by application code. If you want never to
+            # see this error, pass ``timeout=None`` so the pool blocks until a
+            # connection becomes available.
+            raise ConnectionError("No connection available.")
+
+        # If the ``connection`` is actually ``None`` then that's a cue to make
+        # a new connection to add to the pool.
+ if connection is None: + connection = self.make_connection() + + try: + # ensure this connection is connected to Redis + await connection.connect() + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the + # pool before all data has been read or the socket has been + # closed. either way, reconnect and verify everything is good. + try: + if await connection.can_read(): + raise ConnectionError("Connection has data") from None + except ConnectionError: + await connection.disconnect() + await connection.connect() + if await connection.can_read(): + raise ConnectionError("Connection not ready") from None + except BaseException: + # release the connection back to the pool so that we don't leak it + await self.release(connection) + raise + + return connection + + async def release(self, connection: Connection): + """Releases the connection back to the pool.""" + # Make sure we haven't changed process. + self._checkpid() + if not self.owns_connection(connection): + # pool doesn't own this connection. do not add it back + # to the pool. instead add a None value which is a placeholder + # that will cause the pool to recreate the connection if + # its needed. + await connection.disconnect() + self.pool.put_nowait(None) + return + + # Put the connection back into the pool. + try: + self.pool.put_nowait(connection) + except asyncio.QueueFull: + # perhaps the pool has been reset() after a fork? regardless, + # we don't want this connection + pass + + async def disconnect(self, inuse_connections: bool = True): + """Disconnects all connections in the pool.""" + self._checkpid() + async with self._lock: + resp = await asyncio.gather( + *(connection.disconnect() for connection in self._connections), + return_exceptions=True, + ) + exc = next((r for r in resp if isinstance(r, BaseException)), None) + if exc: + raise exc diff --git a/redis/asyncio/lock.py b/redis/asyncio/lock.py new file mode 100644 index 0000000000..d4861329a0 --- /dev/null +++ b/redis/asyncio/lock.py @@ -0,0 +1,311 @@ +import asyncio +import sys +import threading +import uuid +from types import SimpleNamespace +from typing import TYPE_CHECKING, Awaitable, NoReturn, Optional, Union + +from redis.exceptions import LockError, LockNotOwnedError + +if TYPE_CHECKING: + from redis.asyncio import Redis + + +class Lock: + """ + A shared, distributed Lock. Using Redis for locking allows the Lock + to be shared across processes and/or machines. + + It's left to the user to resolve deadlock issues and make sure + multiple clients play nicely together. 
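+
+    A minimal usage sketch (assumes ``client`` is a connected
+    ``redis.asyncio.Redis`` instance)::
+
+        async with Lock(client, "my-resource", timeout=10.0):
+            ...  # critical section; the lock is released on exit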
+ """ + + lua_release = None + lua_extend = None + lua_reacquire = None + + # KEYS[1] - lock name + # ARGV[1] - token + # return 1 if the lock was released, otherwise 0 + LUA_RELEASE_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('del', KEYS[1]) + return 1 + """ + + # KEYS[1] - lock name + # ARGV[1] - token + # ARGV[2] - additional milliseconds + # ARGV[3] - "0" if the additional time should be added to the lock's + # existing ttl or "1" if the existing ttl should be replaced + # return 1 if the locks time was extended, otherwise 0 + LUA_EXTEND_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + local expiration = redis.call('pttl', KEYS[1]) + if not expiration then + expiration = 0 + end + if expiration < 0 then + return 0 + end + + local newttl = ARGV[2] + if ARGV[3] == "0" then + newttl = ARGV[2] + expiration + end + redis.call('pexpire', KEYS[1], newttl) + return 1 + """ + + # KEYS[1] - lock name + # ARGV[1] - token + # ARGV[2] - milliseconds + # return 1 if the locks time was reacquired, otherwise 0 + LUA_REACQUIRE_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('pexpire', KEYS[1], ARGV[2]) + return 1 + """ + + def __init__( + self, + redis: "Redis", + name: Union[str, bytes, memoryview], + timeout: Optional[float] = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: Optional[float] = None, + thread_local: bool = True, + ): + """ + Create a new Lock instance named ``name`` using the Redis client + supplied by ``redis``. + + ``timeout`` indicates a maximum life for the lock in seconds. + By default, it will remain locked until release() is called. + ``timeout`` can be specified as a float or integer, both representing + the number of seconds to wait. + + ``sleep`` indicates the amount of time to sleep in seconds per loop + iteration when the lock is in blocking mode and another client is + currently holding the lock. + + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release the thread-2's lock. 
+ + In some use cases it's necessary to disable thread local storage. For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. Our assumption + is that these cases aren't common and as such default to using + thread local storage. + """ + self.redis = redis + self.name = name + self.timeout = timeout + self.sleep = sleep + self.blocking = blocking + self.blocking_timeout = blocking_timeout + self.thread_local = bool(thread_local) + self.local = threading.local() if self.thread_local else SimpleNamespace() + self.local.token = None + self.register_scripts() + + def register_scripts(self): + cls = self.__class__ + client = self.redis + if cls.lua_release is None: + cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT) + if cls.lua_extend is None: + cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT) + if cls.lua_reacquire is None: + cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT) + + async def __aenter__(self): + if await self.acquire(): + return self + raise LockError("Unable to acquire lock within the time specified") + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.release() + + async def acquire( + self, + blocking: Optional[bool] = None, + blocking_timeout: Optional[float] = None, + token: Optional[Union[str, bytes]] = None, + ): + """ + Use Redis to hold a shared, distributed lock named ``name``. + Returns True once the lock is acquired. + + If ``blocking`` is False, always return immediately. If the lock + was acquired, return True, otherwise return False. + + ``blocking_timeout`` specifies the maximum number of seconds to + wait trying to acquire the lock. + + ``token`` specifies the token value to be used. If provided, token + must be a bytes object or a string that can be encoded to a bytes + object with the default encoding. If a token isn't specified, a UUID + will be generated. + """ + if sys.version_info[0:2] != (3, 6): + loop = asyncio.get_running_loop() + else: + loop = asyncio.get_event_loop() + + sleep = self.sleep + if token is None: + token = uuid.uuid1().hex.encode() + else: + encoder = self.redis.connection_pool.get_encoder() + token = encoder.encode(token) + if blocking is None: + blocking = self.blocking + if blocking_timeout is None: + blocking_timeout = self.blocking_timeout + stop_trying_at = None + if blocking_timeout is not None: + stop_trying_at = loop.time() + blocking_timeout + while True: + if await self.do_acquire(token): + self.local.token = token + return True + if not blocking: + return False + next_try_at = loop.time() + sleep + if stop_trying_at is not None and next_try_at > stop_trying_at: + return False + await asyncio.sleep(sleep) + + async def do_acquire(self, token: Union[str, bytes]) -> bool: + if self.timeout: + # convert to milliseconds + timeout = int(self.timeout * 1000) + else: + timeout = None + if await self.redis.set(self.name, token, nx=True, px=timeout): + return True + return False + + async def locked(self) -> bool: + """ + Returns True if this key is locked by any process, otherwise False. + """ + return await self.redis.get(self.name) is not None + + async def owned(self) -> bool: + """ + Returns True if this key is locked by this lock, otherwise False. 
+ """ + stored_token = await self.redis.get(self.name) + # need to always compare bytes to bytes + # TODO: this can be simplified when the context manager is finished + if stored_token and not isinstance(stored_token, bytes): + encoder = self.redis.connection_pool.get_encoder() + stored_token = encoder.encode(stored_token) + return self.local.token is not None and stored_token == self.local.token + + def release(self) -> Awaitable[NoReturn]: + """Releases the already acquired lock""" + expected_token = self.local.token + if expected_token is None: + raise LockError("Cannot release an unlocked lock") + self.local.token = None + return self.do_release(expected_token) + + async def do_release(self, expected_token: bytes): + if not bool( + await self.lua_release( + keys=[self.name], args=[expected_token], client=self.redis + ) + ): + raise LockNotOwnedError("Cannot release a lock" " that's no longer owned") + + def extend( + self, additional_time: float, replace_ttl: bool = False + ) -> Awaitable[bool]: + """ + Adds more time to an already acquired lock. + + ``additional_time`` can be specified as an integer or a float, both + representing the number of seconds to add. + + ``replace_ttl`` if False (the default), add `additional_time` to + the lock's existing ttl. If True, replace the lock's ttl with + `additional_time`. + """ + if self.local.token is None: + raise LockError("Cannot extend an unlocked lock") + if self.timeout is None: + raise LockError("Cannot extend a lock with no timeout") + return self.do_extend(additional_time, replace_ttl) + + async def do_extend(self, additional_time, replace_ttl) -> bool: + additional_time = int(additional_time * 1000) + if not bool( + await self.lua_extend( + keys=[self.name], + args=[self.local.token, additional_time, replace_ttl and "1" or "0"], + client=self.redis, + ) + ): + raise LockNotOwnedError("Cannot extend a lock that's" " no longer owned") + return True + + def reacquire(self) -> Awaitable[bool]: + """ + Resets a TTL of an already acquired lock back to a timeout value. + """ + if self.local.token is None: + raise LockError("Cannot reacquire an unlocked lock") + if self.timeout is None: + raise LockError("Cannot reacquire a lock with no timeout") + return self.do_reacquire() + + async def do_reacquire(self) -> bool: + timeout = int(self.timeout * 1000) + if not bool( + await self.lua_reacquire( + keys=[self.name], args=[self.local.token, timeout], client=self.redis + ) + ): + raise LockNotOwnedError("Cannot reacquire a lock that's" " no longer owned") + return True diff --git a/redis/asyncio/retry.py b/redis/asyncio/retry.py new file mode 100644 index 0000000000..9b5349402c --- /dev/null +++ b/redis/asyncio/retry.py @@ -0,0 +1,58 @@ +from asyncio import sleep +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, TypeVar + +from redis.exceptions import ConnectionError, RedisError, TimeoutError + +if TYPE_CHECKING: + from redis.backoff import AbstractBackoff + + +T = TypeVar("T") + + +class Retry: + """Retry a specific number of times after a failure""" + + __slots__ = "_backoff", "_retries", "_supported_errors" + + def __init__( + self, + backoff: "AbstractBackoff", + retries: int, + supported_errors: Tuple[RedisError, ...] = ( + ConnectionError, + TimeoutError, + ), + ): + """ + Initialize a `Retry` object with a `Backoff` object + that retries a maximum of `retries` times. + You can specify the types of supported errors which trigger + a retry with the `supported_errors` parameter. 
+ """ + self._backoff = backoff + self._retries = retries + self._supported_errors = supported_errors + + async def call_with_retry( + self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any] + ) -> T: + """ + Execute an operation that might fail and returns its result, or + raise the exception that was thrown depending on the `Backoff` object. + `do`: the operation to call. Expects no argument. + `fail`: the failure handler, expects the last error that was thrown + """ + self._backoff.reset() + failures = 0 + while True: + try: + return await do() + except self._supported_errors as error: + failures += 1 + await fail(error) + if failures > self._retries: + raise error + backoff = self._backoff.compute(failures) + if backoff > 0: + await sleep(backoff) diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py new file mode 100644 index 0000000000..5aefd09ebd --- /dev/null +++ b/redis/asyncio/sentinel.py @@ -0,0 +1,353 @@ +import asyncio +import random +import weakref +from typing import AsyncIterator, Iterable, Mapping, Sequence, Tuple, Type + +from redis.asyncio.client import Redis +from redis.asyncio.connection import ( + Connection, + ConnectionPool, + EncodableT, + SSLConnection, +) +from redis.commands import AsyncSentinelCommands +from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError +from redis.utils import str_if_bytes + + +class MasterNotFoundError(ConnectionError): + pass + + +class SlaveNotFoundError(ConnectionError): + pass + + +class SentinelManagedConnection(Connection): + def __init__(self, **kwargs): + self.connection_pool = kwargs.pop("connection_pool") + super().__init__(**kwargs) + + def __repr__(self): + pool = self.connection_pool + s = f"{self.__class__.__name__}" + + async def connect_to(self, address): + self.host, self.port = address + await super().connect() + if self.connection_pool.check_connection: + await self.send_command("PING") + if str_if_bytes(await self.read_response()) != "PONG": + raise ConnectionError("PING failed") + + async def connect(self): + if self._reader: + return # already connected + if self.connection_pool.is_master: + await self.connect_to(await self.connection_pool.get_master_address()) + else: + async for slave in self.connection_pool.rotate_slaves(): + try: + return await self.connect_to(slave) + except ConnectionError: + continue + raise SlaveNotFoundError # Never be here + + async def read_response(self, disable_decoding: bool = False): + try: + return await super().read_response(disable_decoding=disable_decoding) + except ReadOnlyError: + if self.connection_pool.is_master: + # When talking to a master, a ReadOnlyError when likely + # indicates that the previous master that we're still connected + # to has been demoted to a slave and there's a new master. + # calling disconnect will force the connection to re-query + # sentinel during the next connect() attempt. + await self.disconnect() + raise ConnectionError("The previous master is now a slave") + raise + + +class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection): + pass + + +class SentinelConnectionPool(ConnectionPool): + """ + Sentinel backed connection pool. + + If ``check_connection`` flag is set to True, SentinelManagedConnection + sends a PING command right after establishing the connection. 
+ """ + + def __init__(self, service_name, sentinel_manager, **kwargs): + kwargs["connection_class"] = kwargs.get( + "connection_class", + SentinelManagedSSLConnection + if kwargs.pop("ssl", False) + else SentinelManagedConnection, + ) + self.is_master = kwargs.pop("is_master", True) + self.check_connection = kwargs.pop("check_connection", False) + super().__init__(**kwargs) + self.connection_kwargs["connection_pool"] = weakref.proxy(self) + self.service_name = service_name + self.sentinel_manager = sentinel_manager + self.master_address = None + self.slave_rr_counter = None + + def __repr__(self): + return ( + f"{self.__class__.__name__}" + f"" + ) + + def reset(self): + super().reset() + self.master_address = None + self.slave_rr_counter = None + + def owns_connection(self, connection: Connection): + check = not self.is_master or ( + self.is_master and self.master_address == (connection.host, connection.port) + ) + return check and super().owns_connection(connection) + + async def get_master_address(self): + master_address = await self.sentinel_manager.discover_master(self.service_name) + if self.is_master: + if self.master_address != master_address: + self.master_address = master_address + # disconnect any idle connections so that they reconnect + # to the new master the next time that they are used. + await self.disconnect(inuse_connections=False) + return master_address + + async def rotate_slaves(self) -> AsyncIterator: + """Round-robin slave balancer""" + slaves = await self.sentinel_manager.discover_slaves(self.service_name) + if slaves: + if self.slave_rr_counter is None: + self.slave_rr_counter = random.randint(0, len(slaves) - 1) + for _ in range(len(slaves)): + self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves) + slave = slaves[self.slave_rr_counter] + yield slave + # Fallback to the master connection + try: + yield await self.get_master_address() + except MasterNotFoundError: + pass + raise SlaveNotFoundError(f"No slave found for {self.service_name!r}") + + +class Sentinel(AsyncSentinelCommands): + """ + Redis Sentinel cluster client + + >>> from redis.sentinel import Sentinel + >>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1) + >>> master = sentinel.master_for('mymaster', socket_timeout=0.1) + >>> await master.set('foo', 'bar') + >>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1) + >>> await slave.get('foo') + b'bar' + + ``sentinels`` is a list of sentinel nodes. Each node is represented by + a pair (hostname, port). + + ``min_other_sentinels`` defined a minimum number of peers for a sentinel. + When querying a sentinel, if it doesn't meet this threshold, responses + from that sentinel won't be considered valid. + + ``sentinel_kwargs`` is a dictionary of connection arguments used when + connecting to sentinel instances. Any argument that can be passed to + a normal Redis connection can be specified here. If ``sentinel_kwargs`` is + not specified, any socket_timeout and socket_keepalive options specified + in ``connection_kwargs`` will be used. + + ``connection_kwargs`` are keyword arguments that will be used when + establishing a connection to a Redis server. 
+ """ + + def __init__( + self, + sentinels, + min_other_sentinels=0, + sentinel_kwargs=None, + **connection_kwargs, + ): + # if sentinel_kwargs isn't defined, use the socket_* options from + # connection_kwargs + if sentinel_kwargs is None: + sentinel_kwargs = { + k: v for k, v in connection_kwargs.items() if k.startswith("socket_") + } + self.sentinel_kwargs = sentinel_kwargs + + self.sentinels = [ + Redis(host=hostname, port=port, **self.sentinel_kwargs) + for hostname, port in sentinels + ] + self.min_other_sentinels = min_other_sentinels + self.connection_kwargs = connection_kwargs + + async def execute_command(self, *args, **kwargs): + """ + Execute Sentinel command in sentinel nodes. + once - If set to True, then execute the resulting command on a single + node at random, rather than across the entire sentinel cluster. + """ + once = bool(kwargs.get("once", False)) + if "once" in kwargs.keys(): + kwargs.pop("once") + + if once: + tasks = [ + asyncio.Task(sentinel.execute_command(*args, **kwargs)) + for sentinel in self.sentinels + ] + await asyncio.gather(*tasks) + else: + await random.choice(self.sentinels).execute_command(*args, **kwargs) + return True + + def __repr__(self): + sentinel_addresses = [] + for sentinel in self.sentinels: + sentinel_addresses.append( + f"{sentinel.connection_pool.connection_kwargs['host']}:" + f"{sentinel.connection_pool.connection_kwargs['port']}" + ) + return f"{self.__class__.__name__}" + + def check_master_state(self, state: dict, service_name: str) -> bool: + if not state["is_master"] or state["is_sdown"] or state["is_odown"]: + return False + # Check if our sentinel doesn't see other nodes + if state["num-other-sentinels"] < self.min_other_sentinels: + return False + return True + + async def discover_master(self, service_name: str): + """ + Asks sentinel servers for the Redis master's address corresponding + to the service labeled ``service_name``. + + Returns a pair (address, port) or raises MasterNotFoundError if no + master is found. + """ + for sentinel_no, sentinel in enumerate(self.sentinels): + try: + masters = await sentinel.sentinel_masters() + except (ConnectionError, TimeoutError): + continue + state = masters.get(service_name) + if state and self.check_master_state(state, service_name): + # Put this sentinel at the top of the list + self.sentinels[0], self.sentinels[sentinel_no] = ( + sentinel, + self.sentinels[0], + ) + return state["ip"], state["port"] + raise MasterNotFoundError(f"No master found for {service_name!r}") + + def filter_slaves( + self, slaves: Iterable[Mapping] + ) -> Sequence[Tuple[EncodableT, EncodableT]]: + """Remove slaves that are in an ODOWN or SDOWN state""" + slaves_alive = [] + for slave in slaves: + if slave["is_odown"] or slave["is_sdown"]: + continue + slaves_alive.append((slave["ip"], slave["port"])) + return slaves_alive + + async def discover_slaves( + self, service_name: str + ) -> Sequence[Tuple[EncodableT, EncodableT]]: + """Returns a list of alive slaves for service ``service_name``""" + for sentinel in self.sentinels: + try: + slaves = await sentinel.sentinel_slaves(service_name) + except (ConnectionError, ResponseError, TimeoutError): + continue + slaves = self.filter_slaves(slaves) + if slaves: + return slaves + return [] + + def master_for( + self, + service_name: str, + redis_class: Type[Redis] = Redis, + connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool, + **kwargs, + ): + """ + Returns a redis client instance for the ``service_name`` master. 
+ + A :py:class:`~redis.sentinel.SentinelConnectionPool` class is + used to retrieve the master's address before establishing a new + connection. + + NOTE: If the master's address has changed, any cached connections to + the old master are closed. + + By default clients will be a :py:class:`~redis.Redis` instance. + Specify a different class to the ``redis_class`` argument if you + desire something different. + + The ``connection_pool_class`` specifies the connection pool to + use. The :py:class:`~redis.sentinel.SentinelConnectionPool` + will be used by default. + + All other keyword arguments are merged with any connection_kwargs + passed to this class and passed to the connection pool as keyword + arguments to be used to initialize Redis connections. + """ + kwargs["is_master"] = True + connection_kwargs = dict(self.connection_kwargs) + connection_kwargs.update(kwargs) + return redis_class( + connection_pool=connection_pool_class( + service_name, self, **connection_kwargs + ) + ) + + def slave_for( + self, + service_name: str, + redis_class: Type[Redis] = Redis, + connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool, + **kwargs, + ): + """ + Returns redis client instance for the ``service_name`` slave(s). + + A SentinelConnectionPool class is used to retrieve the slave's + address before establishing a new connection. + + By default clients will be a :py:class:`~redis.Redis` instance. + Specify a different class to the ``redis_class`` argument if you + desire something different. + + The ``connection_pool_class`` specifies the connection pool to use. + The SentinelConnectionPool will be used by default. + + All other keyword arguments are merged with any connection_kwargs + passed to this class and passed to the connection pool as keyword + arguments to be used to initialize Redis connections. + """ + kwargs["is_master"] = False + connection_kwargs = dict(self.connection_kwargs) + connection_kwargs.update(kwargs) + return redis_class( + connection_pool=connection_pool_class( + service_name, self, **connection_kwargs + ) + ) diff --git a/redis/asyncio/utils.py b/redis/asyncio/utils.py new file mode 100644 index 0000000000..5a55b36a33 --- /dev/null +++ b/redis/asyncio/utils.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from redis.asyncio.client import Pipeline, Redis + + +def from_url(url, **kwargs): + """ + Returns an active Redis client generated from the given database URL. + + Will attempt to extract the database id from the path url fragment, if + none is provided. + """ + from redis.asyncio.client import Redis + + return Redis.from_url(url, **kwargs) + + +class pipeline: + def __init__(self, redis_obj: "Redis"): + self.p: "Pipeline" = redis_obj.pipeline() + + async def __aenter__(self) -> "Pipeline": + return self.p + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.p.execute() + del self.p diff --git a/redis/client.py b/redis/client.py index 1317cb6ffc..22c5dc1e57 100755 --- a/redis/client.py +++ b/redis/client.py @@ -642,19 +642,7 @@ def parse_set_result(response, **options): return response and str_if_bytes(response) == "OK" -class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): - """ - Implementation of the Redis protocol. - - This abstract class provides a Python interface to all Redis commands - and an implementation of the Redis protocol. - - Pipelines derive from this, implementing how - the commands are sent and received to the Redis server. 
Based on - configuration, an instance will either use a ConnectionPool, or - Connection object to talk to redis. - """ - +class AbstractRedis: RESPONSE_CALLBACKS = { **string_keys_to_dict( "AUTH COPY EXPIRE EXPIREAT PEXPIRE PEXPIREAT " @@ -807,6 +795,20 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): "ZMSCORE": parse_zmscore, } + +class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands): + """ + Implementation of the Redis protocol. + + This abstract class provides a Python interface to all Redis commands + and an implementation of the Redis protocol. + + Pipelines derive from this, implementing how + the commands are sent and received to the Redis server. Based on + configuration, an instance will either use a ConnectionPool, or + Connection object to talk to redis. + """ + @classmethod def from_url(cls, url, **kwargs): """ diff --git a/redis/commands/__init__.py b/redis/commands/__init__.py index 07fa7f1431..b9dd0b7210 100644 --- a/redis/commands/__init__.py +++ b/redis/commands/__init__.py @@ -1,15 +1,17 @@ from .cluster import RedisClusterCommands -from .core import CoreCommands +from .core import AsyncCoreCommands, CoreCommands from .helpers import list_or_args from .parser import CommandsParser from .redismodules import RedisModuleCommands -from .sentinel import SentinelCommands +from .sentinel import AsyncSentinelCommands, SentinelCommands __all__ = [ "RedisClusterCommands", "CommandsParser", + "AsyncCoreCommands", "CoreCommands", "list_or_args", "RedisModuleCommands", + "AsyncSentinelCommands", "SentinelCommands", ] diff --git a/redis/commands/core.py b/redis/commands/core.py index 1e221c9402..80bc55ff22 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -1,21 +1,64 @@ +# from __future__ import annotations + import datetime import hashlib import time import warnings -from typing import List, Optional, Union - +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Awaitable, + Callable, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) + +from redis.compat import Literal from redis.exceptions import ConnectionError, DataError, NoScriptError, RedisError +from redis.typing import ( + AbsExpiryT, + AnyKeyT, + BitfieldOffsetT, + ChannelT, + CommandsProtocol, + ConsumerT, + EncodableT, + ExpiryT, + FieldT, + GroupT, + KeysT, + KeyT, + PatternT, + ScriptTextT, + StreamIdT, + TimeoutSecT, + ZScoreBoundT, +) from .helpers import list_or_args +if TYPE_CHECKING: + from redis.asyncio.client import Redis as AsyncRedis + from redis.client import Redis + +ResponseT = Union[Awaitable, Any] + -class ACLCommands: +class ACLCommands(CommandsProtocol): """ Redis Access Control List (ACL) commands. see: https://redis.io/topics/acl """ - def acl_cat(self, category=None, **kwargs): + def acl_cat(self, category: Union[str, None] = None, **kwargs) -> ResponseT: """ Returns a list of categories or commands within a category. 
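
A rough illustration of the `ResponseT = Union[Awaitable, Any]` pattern introduced above (not part of the diff): the same command mixin method returns a plain response on the sync client and an awaitable on the async one. A local Redis server is assumed.

```python
import asyncio

import redis
import redis.asyncio


def sync_usage() -> None:
    r = redis.Redis()
    print(r.acl_whoami())  # the mixin method returns the response directly


async def async_usage() -> None:
    r = redis.asyncio.Redis()
    print(await r.acl_whoami())  # the same method returns an awaitable here
    await r.close()              # release the pooled connections explicitly


if __name__ == "__main__":
    sync_usage()
    asyncio.run(async_usage())
```
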
@@ -25,7 +68,7 @@ def acl_cat(self, category=None, **kwargs): For more information check https://redis.io/commands/acl-cat """ - pieces = [category] if category else [] + pieces: list[EncodableT] = [category] if category else [] return self.execute_command("ACL CAT", *pieces, **kwargs) def acl_dryrun(self, username, *args, **kwargs): @@ -36,7 +79,7 @@ def acl_dryrun(self, username, *args, **kwargs): """ return self.execute_command("ACL DRYRUN", username, *args, **kwargs) - def acl_deluser(self, *username, **kwargs): + def acl_deluser(self, *username: str, **kwargs) -> ResponseT: """ Delete the ACL for the specified ``username``s @@ -44,7 +87,7 @@ def acl_deluser(self, *username, **kwargs): """ return self.execute_command("ACL DELUSER", *username, **kwargs) - def acl_genpass(self, bits=None, **kwargs): + def acl_genpass(self, bits: Union[int, None] = None, **kwargs) -> ResponseT: """Generate a random password value. If ``bits`` is supplied then use this number of bits, rounded to the next multiple of 4. @@ -62,7 +105,7 @@ def acl_genpass(self, bits=None, **kwargs): ) return self.execute_command("ACL GENPASS", *pieces, **kwargs) - def acl_getuser(self, username, **kwargs): + def acl_getuser(self, username: str, **kwargs) -> ResponseT: """ Get the ACL details for the specified ``username``. @@ -72,7 +115,7 @@ def acl_getuser(self, username, **kwargs): """ return self.execute_command("ACL GETUSER", username, **kwargs) - def acl_help(self, **kwargs): + def acl_help(self, **kwargs) -> ResponseT: """The ACL HELP command returns helpful text describing the different subcommands. @@ -80,7 +123,7 @@ def acl_help(self, **kwargs): """ return self.execute_command("ACL HELP", **kwargs) - def acl_list(self, **kwargs): + def acl_list(self, **kwargs) -> ResponseT: """ Return a list of all ACLs on the server @@ -88,7 +131,7 @@ def acl_list(self, **kwargs): """ return self.execute_command("ACL LIST", **kwargs) - def acl_log(self, count=None, **kwargs): + def acl_log(self, count: Union[int, None] = None, **kwargs) -> ResponseT: """ Get ACL logs as a list. :param int count: Get logs[0:count]. @@ -104,7 +147,7 @@ def acl_log(self, count=None, **kwargs): return self.execute_command("ACL LOG", *args, **kwargs) - def acl_log_reset(self, **kwargs): + def acl_log_reset(self, **kwargs) -> ResponseT: """ Reset ACL logs. :rtype: Boolean. @@ -114,7 +157,7 @@ def acl_log_reset(self, **kwargs): args = [b"RESET"] return self.execute_command("ACL LOG", *args, **kwargs) - def acl_load(self, **kwargs): + def acl_load(self, **kwargs) -> ResponseT: """ Load ACL rules from the configured ``aclfile``. @@ -125,7 +168,7 @@ def acl_load(self, **kwargs): """ return self.execute_command("ACL LOAD", **kwargs) - def acl_save(self, **kwargs): + def acl_save(self, **kwargs) -> ResponseT: """ Save ACL rules to the configured ``aclfile``. 
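
A minimal usage sketch for the ACL inspection commands typed above, assuming a local Redis 6+ server (output shapes depend on the server's configuration; this is illustrative, not part of the diff):

```python
import redis

r = redis.Redis(decode_responses=True)

print(r.acl_list())         # every ACL rule configured on the server
print(r.acl_cat())          # all command categories
print(r.acl_cat("read"))    # commands inside a single category
print(r.acl_genpass(256))   # random password, bits rounded to a multiple of 4
print(r.acl_log(count=10))  # the ten most recent ACL security events
```
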
@@ -138,19 +181,19 @@ def acl_save(self, **kwargs): def acl_setuser( self, - username, - enabled=False, - nopass=False, - passwords=None, - hashed_passwords=None, - categories=None, - commands=None, - keys=None, - reset=False, - reset_keys=False, - reset_passwords=False, + username: str, + enabled: bool = False, + nopass: bool = False, + passwords: Union[str, Iterable[str], None] = None, + hashed_passwords: Union[str, Iterable[str], None] = None, + categories: Union[Iterable[str], None] = None, + commands: Union[Iterable[str], None] = None, + keys: Union[Iterable[KeyT], None] = None, + reset: bool = False, + reset_keys: bool = False, + reset_passwords: bool = False, **kwargs, - ): + ) -> ResponseT: """ Create or update an ACL user. @@ -213,7 +256,7 @@ def acl_setuser( For more information check https://redis.io/commands/acl-setuser """ encoder = self.get_encoder() - pieces = [username] + pieces: list[str | bytes] = [username] if reset: pieces.append(b"reset") @@ -303,14 +346,14 @@ def acl_setuser( return self.execute_command("ACL SETUSER", *pieces, **kwargs) - def acl_users(self, **kwargs): + def acl_users(self, **kwargs) -> ResponseT: """Returns a list of all registered users on the server. For more information check https://redis.io/commands/acl-users """ return self.execute_command("ACL USERS", **kwargs) - def acl_whoami(self, **kwargs): + def acl_whoami(self, **kwargs) -> ResponseT: """Get the username for the current connection For more information check https://redis.io/commands/acl-whoami @@ -318,7 +361,10 @@ def acl_whoami(self, **kwargs): return self.execute_command("ACL WHOAMI", **kwargs) -class ManagementCommands: +AsyncACLCommands = ACLCommands + + +class ManagementCommands(CommandsProtocol): """ Redis management commands """ @@ -339,7 +385,7 @@ def bgrewriteaof(self, **kwargs): """ return self.execute_command("BGREWRITEAOF", **kwargs) - def bgsave(self, schedule=True, **kwargs): + def bgsave(self, schedule: bool = True, **kwargs) -> ResponseT: """ Tell the Redis server to save its data to disk. Unlike save(), this method is asynchronous and returns immediately. 
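
A hedged sketch of `acl_setuser` with the newly typed parameters; the username, password, and key pattern below are illustrative only:

```python
import redis

r = redis.Redis()

# Create a user restricted to GET/SET on a key prefix.
r.acl_setuser(
    "app_reader",               # hypothetical username
    enabled=True,
    passwords=["+s3cret!"],     # "+" adds a password, "-" removes one
    categories=["-@all"],       # start from no permissions...
    commands=["+get", "+set"],  # ...then allow individual commands
    keys=["cache:*"],           # the client adds the "~" pattern prefix
)
print(r.acl_getuser("app_reader"))
print(r.acl_users())
```
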
@@ -351,7 +397,7 @@ def bgsave(self, schedule=True, **kwargs): pieces.append("SCHEDULE") return self.execute_command("BGSAVE", *pieces, **kwargs) - def role(self): + def role(self) -> ResponseT: """ Provide information on the role of a Redis instance in the context of replication, by returning if the instance @@ -361,7 +407,7 @@ def role(self): """ return self.execute_command("ROLE") - def client_kill(self, address, **kwargs): + def client_kill(self, address: str, **kwargs) -> ResponseT: """Disconnects the client at ``address`` (ip:port) For more information check https://redis.io/commands/client-kill @@ -370,18 +416,18 @@ def client_kill(self, address, **kwargs): def client_kill_filter( self, - _id=None, - _type=None, - addr=None, - skipme=None, - laddr=None, - user=None, + _id: Union[str, None] = None, + _type: Union[str, None] = None, + addr: Union[str, None] = None, + skipme: Union[bool, None] = None, + laddr: Union[bool, None] = None, + user: str = None, **kwargs, - ): + ) -> ResponseT: """ Disconnects client(s) using a variety of filter options - :param id: Kills a client by its unique ID field - :param type: Kills a client by type where type is one of 'normal', + :param _id: Kills a client by its unique ID field + :param _type: Kills a client by type where type is one of 'normal', 'master', 'slave' or 'pubsub' :param addr: Kills a client by its 'address:port' :param skipme: If True, then the client calling the command @@ -418,7 +464,7 @@ def client_kill_filter( ) return self.execute_command("CLIENT KILL", *args, **kwargs) - def client_info(self, **kwargs): + def client_info(self, **kwargs) -> ResponseT: """ Returns information and statistics about the current client connection. @@ -427,7 +473,12 @@ def client_info(self, **kwargs): """ return self.execute_command("CLIENT INFO", **kwargs) - def client_list(self, _type=None, client_id=[], **kwargs): + def client_list( + self, + _type: Union[str, None] = None, + client_id: List[EncodableT] = [], + **kwargs, + ) -> ResponseT: """ Returns a list of currently connected clients. If type of client specified, only that type will be returned. @@ -446,12 +497,12 @@ def client_list(self, _type=None, client_id=[], **kwargs): args.append(_type) if not isinstance(client_id, list): raise DataError("client_id must be a list") - if client_id != []: + if client_id: args.append(b"ID") args.append(" ".join(client_id)) return self.execute_command("CLIENT LIST", *args, **kwargs) - def client_getname(self, **kwargs): + def client_getname(self, **kwargs) -> ResponseT: """ Returns the current connection name @@ -459,7 +510,7 @@ def client_getname(self, **kwargs): """ return self.execute_command("CLIENT GETNAME", **kwargs) - def client_getredir(self, **kwargs): + def client_getredir(self, **kwargs) -> ResponseT: """ Returns the ID (an integer) of the client to whom we are redirecting tracking notifications. @@ -468,7 +519,11 @@ def client_getredir(self, **kwargs): """ return self.execute_command("CLIENT GETREDIR", **kwargs) - def client_reply(self, reply, **kwargs): + def client_reply( + self, + reply: Union[Literal["ON"], Literal["OFF"], Literal["SKIP"]], + **kwargs, + ) -> ResponseT: """ Enable and disable redis server replies. 
``reply`` Must be ON OFF or SKIP, @@ -489,7 +544,7 @@ def client_reply(self, reply, **kwargs): raise DataError(f"CLIENT REPLY must be one of {replies!r}") return self.execute_command("CLIENT REPLY", reply, **kwargs) - def client_id(self, **kwargs): + def client_id(self, **kwargs) -> ResponseT: """ Returns the current connection id @@ -499,13 +554,13 @@ def client_id(self, **kwargs): def client_tracking_on( self, - clientid=None, - prefix=[], - bcast=False, - optin=False, - optout=False, - noloop=False, - ): + clientid: Union[int, None] = None, + prefix: Sequence[KeyT] = [], + bcast: bool = False, + optin: bool = False, + optout: bool = False, + noloop: bool = False, + ) -> ResponseT: """ Turn on the tracking mode. For more information about the options look at client_tracking func. @@ -518,13 +573,13 @@ def client_tracking_on( def client_tracking_off( self, - clientid=None, - prefix=[], - bcast=False, - optin=False, - optout=False, - noloop=False, - ): + clientid: Union[int, None] = None, + prefix: Sequence[KeyT] = [], + bcast: bool = False, + optin: bool = False, + optout: bool = False, + noloop: bool = False, + ) -> ResponseT: """ Turn off the tracking mode. For more information about the options look at client_tracking func. @@ -537,15 +592,15 @@ def client_tracking_off( def client_tracking( self, - on=True, - clientid=None, - prefix=[], - bcast=False, - optin=False, - optout=False, - noloop=False, + on: bool = True, + clientid: Union[int, None] = None, + prefix: Sequence[KeyT] = [], + bcast: bool = False, + optin: bool = False, + optout: bool = False, + noloop: bool = False, **kwargs, - ): + ) -> ResponseT: """ Enables the tracking feature of the Redis server, that is used for server assisted client side caching. @@ -595,7 +650,7 @@ def client_tracking( return self.execute_command("CLIENT TRACKING", *pieces) - def client_trackinginfo(self, **kwargs): + def client_trackinginfo(self, **kwargs) -> ResponseT: """ Returns the information about the current client connection's use of the server assisted client side cache. @@ -604,7 +659,7 @@ def client_trackinginfo(self, **kwargs): """ return self.execute_command("CLIENT TRACKINGINFO", **kwargs) - def client_setname(self, name, **kwargs): + def client_setname(self, name: str, **kwargs) -> ResponseT: """ Sets the current connection name @@ -612,7 +667,12 @@ def client_setname(self, name, **kwargs): """ return self.execute_command("CLIENT SETNAME", name, **kwargs) - def client_unblock(self, client_id, error=False, **kwargs): + def client_unblock( + self, + client_id: int, + error: bool = False, + **kwargs, + ) -> ResponseT: """ Unblocks a connection by its client id. If ``error`` is True, unblocks the client with a special error message. 
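
A small sketch of how `client_id` and `client_unblock` fit together, assuming two separate connections to the same server (not part of the diff):

```python
import redis

blocked = redis.Redis()
admin = redis.Redis()

blocked_id = blocked.client_id()

# Suppose ``blocked`` now issues BLPOP with a long timeout from another
# thread. The admin connection can release it remotely:
admin.client_unblock(blocked_id)              # unblocks as if it timed out
admin.client_unblock(blocked_id, error=True)  # unblocks with an UNBLOCKED error
```
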
@@ -626,7 +686,7 @@ def client_unblock(self, client_id, error=False, **kwargs): args.append(b"ERROR") return self.execute_command(*args, **kwargs) - def client_pause(self, timeout, all=True, **kwargs): + def client_pause(self, timeout: int, all: bool = True, **kwargs) -> ResponseT: """ Suspend all the Redis clients for the specified amount of time :param timeout: milliseconds to pause clients @@ -649,7 +709,7 @@ def client_pause(self, timeout, all=True, **kwargs): args.append("WRITE") return self.execute_command(*args, **kwargs) - def client_unpause(self, **kwargs): + def client_unpause(self, **kwargs) -> ResponseT: """ Unpause all redis clients @@ -673,15 +733,15 @@ def command(self, **kwargs): """ return self.execute_command("COMMAND", **kwargs) - def command_info(self, **kwargs): + def command_info(self, **kwargs) -> None: raise NotImplementedError( "COMMAND INFO is intentionally not implemented in the client." ) - def command_count(self, **kwargs): + def command_count(self, **kwargs) -> ResponseT: return self.execute_command("COMMAND COUNT", **kwargs) - def config_get(self, pattern="*", **kwargs): + def config_get(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ Return a dictionary of configuration based on the ``pattern`` @@ -689,14 +749,14 @@ def config_get(self, pattern="*", **kwargs): """ return self.execute_command("CONFIG GET", pattern, **kwargs) - def config_set(self, name, value, **kwargs): + def config_set(self, name: KeyT, value: EncodableT, **kwargs) -> ResponseT: """Set config item ``name`` with ``value`` For more information check https://redis.io/commands/config-set """ return self.execute_command("CONFIG SET", name, value, **kwargs) - def config_resetstat(self, **kwargs): + def config_resetstat(self, **kwargs) -> ResponseT: """ Reset runtime statistics @@ -704,7 +764,7 @@ def config_resetstat(self, **kwargs): """ return self.execute_command("CONFIG RESETSTAT", **kwargs) - def config_rewrite(self, **kwargs): + def config_rewrite(self, **kwargs) -> ResponseT: """ Rewrite config file with the minimal change to reflect running config. @@ -712,7 +772,7 @@ def config_rewrite(self, **kwargs): """ return self.execute_command("CONFIG REWRITE", **kwargs) - def dbsize(self, **kwargs): + def dbsize(self, **kwargs) -> ResponseT: """ Returns the number of keys in the current database @@ -720,7 +780,7 @@ def dbsize(self, **kwargs): """ return self.execute_command("DBSIZE", **kwargs) - def debug_object(self, key, **kwargs): + def debug_object(self, key: KeyT, **kwargs) -> ResponseT: """ Returns version specific meta information about a given key @@ -728,7 +788,7 @@ def debug_object(self, key, **kwargs): """ return self.execute_command("DEBUG OBJECT", key, **kwargs) - def debug_segfault(self, **kwargs): + def debug_segfault(self, **kwargs) -> None: raise NotImplementedError( """ DEBUG SEGFAULT is intentionally not implemented in the client. @@ -737,7 +797,7 @@ def debug_segfault(self, **kwargs): """ ) - def echo(self, value, **kwargs): + def echo(self, value: EncodableT, **kwargs) -> ResponseT: """ Echo the string back from the server @@ -745,7 +805,7 @@ def echo(self, value, **kwargs): """ return self.execute_command("ECHO", value, **kwargs) - def flushall(self, asynchronous=False, **kwargs): + def flushall(self, asynchronous: bool = False, **kwargs) -> ResponseT: """ Delete all keys in all databases on the current host. 
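
A brief usage sketch for the CONFIG family typed above, assuming a local server where runtime reconfiguration is permitted (illustrative, not part of the diff):

```python
import redis

r = redis.Redis(decode_responses=True)

print(r.config_get("maxmemory*"))  # dict of settings matching the pattern
r.config_set("maxmemory-policy", "allkeys-lru")
r.config_resetstat()               # zero the INFO statistics counters
print(r.dbsize())                  # number of keys in the current database
print(r.echo("hello"))
```
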
@@ -759,7 +819,7 @@ def flushall(self, asynchronous=False, **kwargs): args.append(b"ASYNC") return self.execute_command("FLUSHALL", *args, **kwargs) - def flushdb(self, asynchronous=False, **kwargs): + def flushdb(self, asynchronous: bool = False, **kwargs) -> ResponseT: """ Delete all keys in the current database. @@ -773,7 +833,7 @@ def flushdb(self, asynchronous=False, **kwargs): args.append(b"ASYNC") return self.execute_command("FLUSHDB", *args, **kwargs) - def sync(self): + def sync(self) -> ResponseT: """ Initiates a replication stream from the master. @@ -785,7 +845,7 @@ def sync(self): options[NEVER_DECODE] = [] return self.execute_command("SYNC", **options) - def psync(self, replicationid, offset): + def psync(self, replicationid: str, offset: int): """ Initiates a replication stream from the master. Newer version for `sync`. @@ -798,7 +858,7 @@ def psync(self, replicationid, offset): options[NEVER_DECODE] = [] return self.execute_command("PSYNC", replicationid, offset, **options) - def swapdb(self, first, second, **kwargs): + def swapdb(self, first: int, second: int, **kwargs) -> ResponseT: """ Swap two databases @@ -806,14 +866,14 @@ def swapdb(self, first, second, **kwargs): """ return self.execute_command("SWAPDB", first, second, **kwargs) - def select(self, index, **kwargs): + def select(self, index: int, **kwargs) -> ResponseT: """Select the Redis logical database at index. See: https://redis.io/commands/select """ return self.execute_command("SELECT", index, **kwargs) - def info(self, section=None, **kwargs): + def info(self, section: Union[str, None] = None, **kwargs) -> ResponseT: """ Returns a dictionary containing information about the Redis server @@ -830,7 +890,7 @@ def info(self, section=None, **kwargs): else: return self.execute_command("INFO", section, **kwargs) - def lastsave(self, **kwargs): + def lastsave(self, **kwargs) -> ResponseT: """ Return a Python datetime object representing the last time the Redis database was saved to disk @@ -839,7 +899,7 @@ def lastsave(self, **kwargs): """ return self.execute_command("LASTSAVE", **kwargs) - def lolwut(self, *version_numbers, **kwargs): + def lolwut(self, *version_numbers: Union[str, float], **kwargs) -> ResponseT: """ Get the Redis version and a piece of generative computer art @@ -850,7 +910,7 @@ def lolwut(self, *version_numbers, **kwargs): else: return self.execute_command("LOLWUT", **kwargs) - def reset(self): + def reset(self) -> ResponseT: """Perform a full reset on the connection's server side contenxt. See: https://redis.io/commands/reset @@ -859,16 +919,16 @@ def reset(self): def migrate( self, - host, - port, - keys, - destination_db, - timeout, - copy=False, - replace=False, - auth=None, + host: str, + port: int, + keys: KeysT, + destination_db: int, + timeout: int, + copy: bool = False, + replace: bool = False, + auth: Union[str, None] = None, **kwargs, - ): + ) -> ResponseT: """ Migrate 1 or more keys from the current Redis server to a different server specified by the ``host``, ``port`` and ``destination_db``. 
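
A short sketch of the introspection and database-management commands above, assuming a local server with at least two logical databases (illustrative only):

```python
import redis

r = redis.Redis(decode_responses=True)

mem = r.info("memory")  # just the "memory" section of INFO
print(mem["used_memory_human"])
print(r.lastsave())     # datetime of the last successful save to disk
r.swapdb(0, 1)          # atomically exchange databases 0 and 1
```
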
@@ -905,7 +965,7 @@ def migrate( "MIGRATE", host, port, "", destination_db, timeout, *pieces, **kwargs ) - def object(self, infotype, key, **kwargs): + def object(self, infotype: str, key: KeyT, **kwargs) -> ResponseT: """ Return the encoding, idletime, or refcount about the key """ @@ -913,7 +973,7 @@ def object(self, infotype, key, **kwargs): "OBJECT", infotype, key, infotype=infotype, **kwargs ) - def memory_doctor(self, **kwargs): + def memory_doctor(self, **kwargs) -> None: raise NotImplementedError( """ MEMORY DOCTOR is intentionally not implemented in the client. @@ -922,7 +982,7 @@ def memory_doctor(self, **kwargs): """ ) - def memory_help(self, **kwargs): + def memory_help(self, **kwargs) -> None: raise NotImplementedError( """ MEMORY HELP is intentionally not implemented in the client. @@ -931,7 +991,7 @@ def memory_help(self, **kwargs): """ ) - def memory_stats(self, **kwargs): + def memory_stats(self, **kwargs) -> ResponseT: """ Return a dictionary of memory stats @@ -939,7 +999,7 @@ def memory_stats(self, **kwargs): """ return self.execute_command("MEMORY STATS", **kwargs) - def memory_malloc_stats(self, **kwargs): + def memory_malloc_stats(self, **kwargs) -> ResponseT: """ Return an internal statistics report from the memory allocator. @@ -947,7 +1007,9 @@ def memory_malloc_stats(self, **kwargs): """ return self.execute_command("MEMORY MALLOC-STATS", **kwargs) - def memory_usage(self, key, samples=None, **kwargs): + def memory_usage( + self, key: KeyT, samples: Union[int, None] = None, **kwargs + ) -> ResponseT: """ Return the total memory usage for key, its value and associated administrative overheads. @@ -963,7 +1025,7 @@ def memory_usage(self, key, samples=None, **kwargs): args.extend([b"SAMPLES", samples]) return self.execute_command("MEMORY USAGE", key, *args, **kwargs) - def memory_purge(self, **kwargs): + def memory_purge(self, **kwargs) -> ResponseT: """ Attempts to purge dirty pages for reclamation by allocator @@ -971,7 +1033,7 @@ def memory_purge(self, **kwargs): """ return self.execute_command("MEMORY PURGE", **kwargs) - def ping(self, **kwargs): + def ping(self, **kwargs) -> ResponseT: """ Ping the Redis server @@ -979,7 +1041,7 @@ def ping(self, **kwargs): """ return self.execute_command("PING", **kwargs) - def quit(self, **kwargs): + def quit(self, **kwargs) -> ResponseT: """ Ask the server to close the connection. @@ -987,7 +1049,7 @@ def quit(self, **kwargs): """ return self.execute_command("QUIT", **kwargs) - def replicaof(self, *args, **kwargs): + def replicaof(self, *args, **kwargs) -> ResponseT: """ Update the replication settings of a redis replica, on the fly. Examples of valid arguments include: @@ -998,7 +1060,7 @@ def replicaof(self, *args, **kwargs): """ return self.execute_command("REPLICAOF", *args, **kwargs) - def save(self, **kwargs): + def save(self, **kwargs) -> ResponseT: """ Tell the Redis server to save its data to disk, blocking until the save is complete @@ -1007,7 +1069,7 @@ def save(self, **kwargs): """ return self.execute_command("SAVE", **kwargs) - def shutdown(self, save=False, nosave=False, **kwargs): + def shutdown(self, save: bool = False, nosave: bool = False, **kwargs) -> None: """Shutdown the Redis server. If Redis has persistence configured, data will be flushed before shutdown. 
If the "save" option is set, a data flush will be attempted even if there is no persistence @@ -1030,7 +1092,9 @@ def shutdown(self, save=False, nosave=False, **kwargs): return raise RedisError("SHUTDOWN seems to have failed.") - def slaveof(self, host=None, port=None, **kwargs): + def slaveof( + self, host: Union[str, None] = None, port: Union[int, None] = None, **kwargs + ) -> ResponseT: """ Set the server to be a replicated slave of the instance identified by the ``host`` and ``port``. If called without arguments, the @@ -1042,7 +1106,7 @@ def slaveof(self, host=None, port=None, **kwargs): return self.execute_command("SLAVEOF", b"NO", b"ONE", **kwargs) return self.execute_command("SLAVEOF", host, port, **kwargs) - def slowlog_get(self, num=None, **kwargs): + def slowlog_get(self, num: Union[int, None] = None, **kwargs) -> ResponseT: """ Get the entries from the slowlog. If ``num`` is specified, get the most recent ``num`` items. @@ -1059,7 +1123,7 @@ def slowlog_get(self, num=None, **kwargs): kwargs[NEVER_DECODE] = [] return self.execute_command(*args, **kwargs) - def slowlog_len(self, **kwargs): + def slowlog_len(self, **kwargs) -> ResponseT: """ Get the number of items in the slowlog @@ -1067,7 +1131,7 @@ def slowlog_len(self, **kwargs): """ return self.execute_command("SLOWLOG LEN", **kwargs) - def slowlog_reset(self, **kwargs): + def slowlog_reset(self, **kwargs) -> ResponseT: """ Remove all items in the slowlog @@ -1075,7 +1139,7 @@ def slowlog_reset(self, **kwargs): """ return self.execute_command("SLOWLOG RESET", **kwargs) - def time(self, **kwargs): + def time(self, **kwargs) -> ResponseT: """ Returns the server time as a 2-item tuple of ints: (seconds since epoch, microseconds into this second). @@ -1084,7 +1148,7 @@ def time(self, **kwargs): """ return self.execute_command("TIME", **kwargs) - def wait(self, num_replicas, timeout, **kwargs): + def wait(self, num_replicas: int, timeout: int, **kwargs) -> ResponseT: """ Redis synchronous replication That returns the number of replicas that processed the query when @@ -1114,12 +1178,166 @@ def failover(self): ) -class BasicKeyCommands: +AsyncManagementCommands = ManagementCommands + + +class AsyncManagementCommands(ManagementCommands): + async def command_info(self, **kwargs) -> None: + return super().command_info(**kwargs) + + async def debug_segfault(self, **kwargs) -> None: + return super().debug_segfault(**kwargs) + + async def memory_doctor(self, **kwargs) -> None: + return super().memory_doctor(**kwargs) + + async def memory_help(self, **kwargs) -> None: + return super().memory_help(**kwargs) + + async def shutdown( + self, save: bool = False, nosave: bool = False, **kwargs + ) -> None: + """Shutdown the Redis server. If Redis has persistence configured, + data will be flushed before shutdown. If the "save" option is set, + a data flush will be attempted even if there is no persistence + configured. If the "nosave" option is set, no data flush will be + attempted. The "save" and "nosave" options cannot both be set. + + For more information check https://redis.io/commands/shutdown + """ + if save and nosave: + raise DataError("SHUTDOWN save and nosave cannot both be set") + args = ["SHUTDOWN"] + if save: + args.append("SAVE") + if nosave: + args.append("NOSAVE") + try: + await self.execute_command(*args, **kwargs) + except ConnectionError: + # a ConnectionError here is expected + return + raise RedisError("SHUTDOWN seems to have failed.") + + +class BitFieldOperation: + """ + Command builder for BITFIELD commands. 
+ """ + + def __init__( + self, + client: Union["Redis", "AsyncRedis"], + key: str, + default_overflow: Union[str, None] = None, + ): + self.client = client + self.key = key + self._default_overflow = default_overflow + # for typing purposes, run the following in constructor and in reset() + self.operations: list[tuple[EncodableT, ...]] = [] + self._last_overflow = "WRAP" + self.reset() + + def reset(self): + """ + Reset the state of the instance to when it was constructed + """ + self.operations = [] + self._last_overflow = "WRAP" + self.overflow(self._default_overflow or self._last_overflow) + + def overflow(self, overflow: str): + """ + Update the overflow algorithm of successive INCRBY operations + :param overflow: Overflow algorithm, one of WRAP, SAT, FAIL. See the + Redis docs for descriptions of these algorithmsself. + :returns: a :py:class:`BitFieldOperation` instance. + """ + overflow = overflow.upper() + if overflow != self._last_overflow: + self._last_overflow = overflow + self.operations.append(("OVERFLOW", overflow)) + return self + + def incrby( + self, + fmt: str, + offset: BitfieldOffsetT, + increment: int, + overflow: Union[str, None] = None, + ): + """ + Increment a bitfield by a given amount. + :param fmt: format-string for the bitfield being updated, e.g. 'u8' + for an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :param int increment: value to increment the bitfield by. + :param str overflow: overflow algorithm. Defaults to WRAP, but other + acceptable values are SAT and FAIL. See the Redis docs for + descriptions of these algorithms. + :returns: a :py:class:`BitFieldOperation` instance. + """ + if overflow is not None: + self.overflow(overflow) + + self.operations.append(("INCRBY", fmt, offset, increment)) + return self + + def get(self, fmt: str, offset: BitfieldOffsetT): + """ + Get the value of a given bitfield. + :param fmt: format-string for the bitfield being read, e.g. 'u8' for + an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :returns: a :py:class:`BitFieldOperation` instance. + """ + self.operations.append(("GET", fmt, offset)) + return self + + def set(self, fmt: str, offset: BitfieldOffsetT, value: int): + """ + Set the value of a given bitfield. + :param fmt: format-string for the bitfield being read, e.g. 'u8' for + an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :param int value: value to set at the given position. + :returns: a :py:class:`BitFieldOperation` instance. + """ + self.operations.append(("SET", fmt, offset, value)) + return self + + @property + def command(self): + cmd = ["BITFIELD", self.key] + for ops in self.operations: + cmd.extend(ops) + return cmd + + def execute(self) -> ResponseT: + """ + Execute the operation(s) in a single BITFIELD command. The return value + is a list of values corresponding to each operation. If the client + used to create this instance was a pipeline, the list of values + will be present within the pipeline's execute. 
+ """ + command = self.command + self.reset() + return self.client.execute_command(*command) + + +class BasicKeyCommands(CommandsProtocol): """ Redis basic key-based commands """ - def append(self, key, value): + def append(self, key: KeyT, value: EncodableT) -> ResponseT: """ Appends the string ``value`` to the value at ``key``. If ``key`` doesn't already exist, create it with a value of ``value``. @@ -1129,7 +1347,12 @@ def append(self, key, value): """ return self.execute_command("APPEND", key, value) - def bitcount(self, key, start=None, end=None): + def bitcount( + self, + key: KeyT, + start: Union[int, None] = None, + end: Union[int, None] = None, + ) -> ResponseT: """ Returns the count of set bits in the value of ``key``. Optional ``start`` and ``end`` parameters indicate which bytes to consider @@ -1144,7 +1367,11 @@ def bitcount(self, key, start=None, end=None): raise DataError("Both start and end must be specified") return self.execute_command("BITCOUNT", *params) - def bitfield(self, key, default_overflow=None): + def bitfield( + self: Union["Redis", "AsyncRedis"], + key: KeyT, + default_overflow: Union[str, None] = None, + ) -> BitFieldOperation: """ Return a BitFieldOperation instance to conveniently construct one or more bitfield operations on ``key``. @@ -1153,7 +1380,12 @@ def bitfield(self, key, default_overflow=None): """ return BitFieldOperation(self, key, default_overflow=default_overflow) - def bitop(self, operation, dest, *keys): + def bitop( + self, + operation: str, + dest: KeyT, + *keys: KeyT, + ) -> ResponseT: """ Perform a bitwise operation using ``operation`` between ``keys`` and store the result in ``dest``. @@ -1162,7 +1394,13 @@ def bitop(self, operation, dest, *keys): """ return self.execute_command("BITOP", operation, dest, *keys) - def bitpos(self, key, bit, start=None, end=None): + def bitpos( + self, + key: KeyT, + bit: int, + start: Union[int, None] = None, + end: Union[int, None] = None, + ) -> ResponseT: """ Return the position of the first bit set to 1 or 0 in a string. ``start`` and ``end`` defines search range. The range is interpreted @@ -1183,7 +1421,13 @@ def bitpos(self, key, bit, start=None, end=None): raise DataError("start argument is not set, " "when end is specified") return self.execute_command("BITPOS", *params) - def copy(self, source, destination, destination_db=None, replace=False): + def copy( + self, + source: str, + destination: str, + destination_db: Union[str, None] = None, + replace: bool = False, + ) -> ResponseT: """ Copy the value stored in the ``source`` key to the ``destination`` key. @@ -1203,7 +1447,7 @@ def copy(self, source, destination, destination_db=None, replace=False): params.append("REPLACE") return self.execute_command("COPY", *params) - def decrby(self, name, amount=1): + def decrby(self, name: KeyT, amount: int = 1) -> ResponseT: """ Decrements the value of ``key`` by ``amount``. If no key exists, the value will be initialized as 0 - ``amount`` @@ -1214,16 +1458,16 @@ def decrby(self, name, amount=1): decr = decrby - def delete(self, *names): + def delete(self, *names: KeyT) -> ResponseT: """ Delete one or more keys specified by ``names`` """ return self.execute_command("DEL", *names) - def __delitem__(self, name): + def __delitem__(self, name: KeyT): self.delete(name) - def dump(self, name): + def dump(self, name: KeyT) -> ResponseT: """ Return a serialized version of the value stored at the specified key. If key does not exist a nil bulk reply is returned. 
@@ -1236,7 +1480,7 @@ def dump(self, name): options[NEVER_DECODE] = [] return self.execute_command("DUMP", name, **options) - def exists(self, *names): + def exists(self, *names: KeyT) -> ResponseT: """ Returns the number of ``names`` that exist @@ -1246,7 +1490,7 @@ def exists(self, *names): __contains__ = exists - def expire(self, name, time): + def expire(self, name: KeyT, time: ExpiryT) -> ResponseT: """ Set an expire flag on key ``name`` for ``time`` seconds. ``time`` can be represented by an integer or a Python timedelta object. @@ -1257,7 +1501,7 @@ def expire(self, name, time): time = int(time.total_seconds()) return self.execute_command("EXPIRE", name, time) - def expireat(self, name, when): + def expireat(self, name: KeyT, when: AbsExpiryT) -> ResponseT: """ Set an expire flag on key ``name``. ``when`` can be represented as an integer indicating unix time or a Python datetime object. @@ -1268,7 +1512,7 @@ def expireat(self, name, when): when = int(time.mktime(when.timetuple())) return self.execute_command("EXPIREAT", name, when) - def get(self, name): + def get(self, name: KeyT) -> ResponseT: """ Return the value at key ``name``, or None if the key doesn't exist @@ -1276,7 +1520,7 @@ def get(self, name): """ return self.execute_command("GET", name) - def getdel(self, name): + def getdel(self, name: KeyT) -> ResponseT: """ Get the value at key ``name`` and delete the key. This command is similar to GET, except for the fact that it also deletes @@ -1287,7 +1531,15 @@ def getdel(self, name): """ return self.execute_command("GETDEL", name) - def getex(self, name, ex=None, px=None, exat=None, pxat=None, persist=False): + def getex( + self, + name: KeyT, + ex: Union[ExpiryT, None] = None, + px: Union[ExpiryT, None] = None, + exat: Union[AbsExpiryT, None] = None, + pxat: Union[AbsExpiryT, None] = None, + persist: bool = False, + ) -> ResponseT: """ Get the value of key and optionally set its expiration. GETEX is similar to GET, but is a write command with @@ -1316,7 +1568,7 @@ def getex(self, name, ex=None, px=None, exat=None, pxat=None, persist=False): "and ``persist`` are mutually exclusive." ) - pieces = [] + pieces: list[EncodableT] = [] # similar to set command if ex is not None: pieces.append("EX") @@ -1346,7 +1598,7 @@ def getex(self, name, ex=None, px=None, exat=None, pxat=None, persist=False): return self.execute_command("GETEX", name, *pieces) - def __getitem__(self, name): + def __getitem__(self, name: KeyT): """ Return the value at key ``name``, raises a KeyError if the key doesn't exist. @@ -1356,7 +1608,7 @@ def __getitem__(self, name): return value raise KeyError(name) - def getbit(self, name, offset): + def getbit(self, name: KeyT, offset: int) -> ResponseT: """ Returns a boolean indicating the value of ``offset`` in ``name`` @@ -1364,7 +1616,7 @@ def getbit(self, name, offset): """ return self.execute_command("GETBIT", name, offset) - def getrange(self, key, start, end): + def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: """ Returns the substring of the string value stored at ``key``, determined by the offsets ``start`` and ``end`` (both are inclusive) @@ -1373,7 +1625,7 @@ def getrange(self, key, start, end): """ return self.execute_command("GETRANGE", key, start, end) - def getset(self, name, value): + def getset(self, name: KeyT, value: EncodableT) -> ResponseT: """ Sets the value at key ``name`` to ``value`` and returns the old value at key ``name`` atomically. 
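
A small sketch of `getex`, which pairs a read with TTL manipulation (requires Redis 6.2+; the key name is illustrative):

```python
from datetime import timedelta

import redis

r = redis.Redis()

r.set("session:42", "payload")
r.getex("session:42", ex=timedelta(minutes=5))  # read and attach a 5-minute TTL
print(r.ttl("session:42"))                      # roughly 300 seconds
r.getex("session:42", persist=True)             # read and clear the TTL again
```
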
@@ -1385,7 +1637,7 @@ def getset(self, name, value): """ return self.execute_command("GETSET", name, value) - def incrby(self, name, amount=1): + def incrby(self, name: KeyT, amount: int = 1) -> ResponseT: """ Increments the value of ``key`` by ``amount``. If no key exists, the value will be initialized as ``amount`` @@ -1396,7 +1648,7 @@ def incrby(self, name, amount=1): incr = incrby - def incrbyfloat(self, name, amount=1.0): + def incrbyfloat(self, name: KeyT, amount: float = 1.0) -> ResponseT: """ Increments the value at key ``name`` by floating ``amount``. If no key exists, the value will be initialized as ``amount`` @@ -1405,7 +1657,7 @@ def incrbyfloat(self, name, amount=1.0): """ return self.execute_command("INCRBYFLOAT", name, amount) - def keys(self, pattern="*", **kwargs): + def keys(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ Returns a list of keys matching ``pattern`` @@ -1413,7 +1665,13 @@ def keys(self, pattern="*", **kwargs): """ return self.execute_command("KEYS", pattern, **kwargs) - def lmove(self, first_list, second_list, src="LEFT", dest="RIGHT"): + def lmove( + self, + first_list: str, + second_list: str, + src: str = "LEFT", + dest: str = "RIGHT", + ) -> ResponseT: """ Atomically returns and removes the first/last element of a list, pushing it as the first/last element on the destination list. @@ -1424,7 +1682,14 @@ def lmove(self, first_list, second_list, src="LEFT", dest="RIGHT"): params = [first_list, second_list, src, dest] return self.execute_command("LMOVE", *params) - def blmove(self, first_list, second_list, timeout, src="LEFT", dest="RIGHT"): + def blmove( + self, + first_list: str, + second_list: str, + timeout: int, + src: str = "LEFT", + dest: str = "RIGHT", + ) -> ResponseT: """ Blocking version of lmove. @@ -1433,7 +1698,7 @@ def blmove(self, first_list, second_list, timeout, src="LEFT", dest="RIGHT"): params = [first_list, second_list, src, dest, timeout] return self.execute_command("BLMOVE", *params) - def mget(self, keys, *args): + def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT: """ Returns a list of values ordered identically to ``keys`` @@ -1447,7 +1712,7 @@ def mget(self, keys, *args): options[EMPTY_RESPONSE] = [] return self.execute_command("MGET", *args, **options) - def mset(self, mapping): + def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: """ Sets key/values based on a mapping. Mapping is a dictionary of key/value pairs. Both keys and values should be strings or types that @@ -1460,7 +1725,7 @@ def mset(self, mapping): items.extend(pair) return self.execute_command("MSET", *items) - def msetnx(self, mapping): + def msetnx(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: """ Sets key/values based on a mapping if none of the keys are already set. Mapping is a dictionary of key/value pairs. 
Both keys and values @@ -1474,7 +1739,7 @@ def msetnx(self, mapping): items.extend(pair) return self.execute_command("MSETNX", *items) - def move(self, name, db): + def move(self, name: KeyT, db: int) -> ResponseT: """ Moves the key ``name`` to a different Redis database ``db`` @@ -1482,7 +1747,7 @@ def move(self, name, db): """ return self.execute_command("MOVE", name, db) - def persist(self, name): + def persist(self, name: KeyT) -> ResponseT: """ Removes an expiration on ``name`` @@ -1490,7 +1755,7 @@ def persist(self, name): """ return self.execute_command("PERSIST", name) - def pexpire(self, name, time): + def pexpire(self, name: KeyT, time: ExpiryT) -> ResponseT: """ Set an expire flag on key ``name`` for ``time`` milliseconds. ``time`` can be represented by an integer or a Python timedelta @@ -1502,7 +1767,7 @@ def pexpire(self, name, time): time = int(time.total_seconds() * 1000) return self.execute_command("PEXPIRE", name, time) - def pexpireat(self, name, when): + def pexpireat(self, name: KeyT, when: AbsExpiryT) -> ResponseT: """ Set an expire flag on key ``name``. ``when`` can be represented as an integer representing unix time in milliseconds (unix time * 1000) @@ -1515,7 +1780,12 @@ def pexpireat(self, name, when): when = int(time.mktime(when.timetuple())) * 1000 + ms return self.execute_command("PEXPIREAT", name, when) - def psetex(self, name, time_ms, value): + def psetex( + self, + name: KeyT, + time_ms: ExpiryT, + value: EncodableT, + ) -> ResponseT: """ Set the value of key ``name`` to ``value`` that expires in ``time_ms`` milliseconds. ``time_ms`` can be represented by an integer or a Python @@ -1527,7 +1797,7 @@ def psetex(self, name, time_ms, value): time_ms = int(time_ms.total_seconds() * 1000) return self.execute_command("PSETEX", name, time_ms, value) - def pttl(self, name): + def pttl(self, name: KeyT) -> ResponseT: """ Returns the number of milliseconds until the key ``name`` will expire @@ -1535,7 +1805,12 @@ def pttl(self, name): """ return self.execute_command("PTTL", name) - def hrandfield(self, key, count=None, withvalues=False): + def hrandfield( + self, + key: str, + count: int = None, + withvalues: bool = False, + ) -> ResponseT: """ Return a random field from the hash value stored at key. @@ -1557,7 +1832,7 @@ def hrandfield(self, key, count=None, withvalues=False): return self.execute_command("HRANDFIELD", key, *params) - def randomkey(self, **kwargs): + def randomkey(self, **kwargs) -> ResponseT: """ Returns the name of a random key @@ -1565,7 +1840,7 @@ def randomkey(self, **kwargs): """ return self.execute_command("RANDOMKEY", **kwargs) - def rename(self, src, dst): + def rename(self, src: KeyT, dst: KeyT) -> ResponseT: """ Rename key ``src`` to ``dst`` @@ -1573,7 +1848,7 @@ def rename(self, src, dst): """ return self.execute_command("RENAME", src, dst) - def renamenx(self, src, dst): + def renamenx(self, src: KeyT, dst: KeyT) -> ResponseT: """ Rename key ``src`` to ``dst`` if ``dst`` doesn't already exist @@ -1583,14 +1858,14 @@ def renamenx(self, src, dst): def restore( self, - name, - ttl, - value, - replace=False, - absttl=False, - idletime=None, - frequency=None, - ): + name: KeyT, + ttl: float, + value: EncodableT, + replace: bool = False, + absttl: bool = False, + idletime: Union[int, None] = None, + frequency: Union[int, None] = None, + ) -> ResponseT: """ Create a key using the provided serialized value, previously obtained using DUMP.
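To make the intent of the ``dump``/``restore`` pair concrete, here is a minimal usage sketch; the server address and key names are assumptions, not part of this diff:

``` python
import redis

# Assumes a Redis server reachable at localhost:6379; key names are illustrative.
r = redis.Redis()

r.set("src", "hello")
payload = r.dump("src")  # opaque, RDB-serialized bytes
r.restore("dst", ttl=0, value=payload, replace=True)  # ttl=0 -> no expiry
assert r.get("dst") == b"hello"
```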
@@ -1633,17 +1908,17 @@ def restore( def set( self, - name, - value, - ex=None, - px=None, - nx=False, - xx=False, - keepttl=False, - get=False, - exat=None, - pxat=None, - ): + name: KeyT, + value: EncodableT, + ex: Union[ExpiryT, None] = None, + px: Union[ExpiryT, None] = None, + nx: bool = False, + xx: bool = False, + keepttl: bool = False, + get: bool = False, + exat: Union[AbsExpiryT, None] = None, + pxat: Union[AbsExpiryT, None] = None, + ) -> ResponseT: """ Set the value at key ``name`` to ``value`` @@ -1672,7 +1947,7 @@ def set( For more information check https://redis.io/commands/set """ - pieces = [name, value] + pieces: list[EncodableT] = [name, value] options = {} if ex is not None: pieces.append("EX") @@ -1716,10 +1991,10 @@ def set( return self.execute_command("SET", *pieces, **options) - def __setitem__(self, name, value): + def __setitem__(self, name: KeyT, value: EncodableT): self.set(name, value) - def setbit(self, name, offset, value): + def setbit(self, name: KeyT, offset: int, value: int) -> ResponseT: """ Flag the ``offset`` in ``name`` as ``value``. Returns a boolean indicating the previous value of ``offset``. @@ -1729,7 +2004,7 @@ def setbit(self, name, offset, value): value = value and 1 or 0 return self.execute_command("SETBIT", name, offset, value) - def setex(self, name, time, value): + def setex(self, name: KeyT, time: ExpiryT, value: EncodableT) -> ResponseT: """ Set the value of key ``name`` to ``value`` that expires in ``time`` seconds. ``time`` can be represented by an integer or a Python @@ -1741,7 +2016,7 @@ def setex(self, name, time, value): time = int(time.total_seconds()) return self.execute_command("SETEX", name, time, value) - def setnx(self, name, value): + def setnx(self, name: KeyT, value: EncodableT) -> ResponseT: """ Set the value of key ``name`` to ``value`` if key doesn't exist @@ -1749,7 +2024,12 @@ def setnx(self, name, value): """ return self.execute_command("SETNX", name, value) - def setrange(self, name, offset, value): + def setrange( + self, + name: KeyT, + offset: int, + value: EncodableT, + ) -> ResponseT: """ Overwrite bytes in the value of ``name`` starting at ``offset`` with ``value``. If ``offset`` plus the length of ``value`` exceeds the @@ -1766,16 +2046,16 @@ def setrange(self, name, offset, value): def stralgo( self, - algo, - value1, - value2, - specific_argument="strings", - len=False, - idx=False, - minmatchlen=None, - withmatchlen=False, + algo: Literal["LCS"], + value1: KeyT, + value2: KeyT, + specific_argument: Union[Literal["strings"], Literal["keys"]] = "strings", + len: bool = False, + idx: bool = False, + minmatchlen: Union[int, None] = None, + withmatchlen: bool = False, **kwargs, - ): + ) -> ResponseT: """ Implements complex algorithms that operate on strings. 
Right now the only algorithm implemented is the LCS algorithm @@ -1805,7 +2085,7 @@ def stralgo( if len and idx: raise DataError("len and idx cannot be provided together.") - pieces = [algo, specific_argument.upper(), value1, value2] + pieces: list[EncodableT] = [algo, specific_argument.upper(), value1, value2] if len: pieces.append(b"LEN") if idx: @@ -1828,7 +2108,7 @@ def stralgo( **kwargs, ) - def strlen(self, name): + def strlen(self, name: KeyT) -> ResponseT: """ Return the number of bytes stored in the value of ``name`` @@ -1836,14 +2116,14 @@ def strlen(self, name): """ return self.execute_command("STRLEN", name) - def substr(self, name, start, end=-1): + def substr(self, name: KeyT, start: int, end: int = -1) -> ResponseT: """ Return a substring of the string at key ``name``. ``start`` and ``end`` are 0-based integers specifying the portion of the string to return. """ return self.execute_command("SUBSTR", name, start, end) - def touch(self, *args): + def touch(self, *args: KeyT) -> ResponseT: """ Alters the last access time of the given key(s) ``*args``. A key is ignored if it does not exist. @@ -1852,7 +2132,7 @@ def touch(self, *args): """ return self.execute_command("TOUCH", *args) - def ttl(self, name): + def ttl(self, name: KeyT) -> ResponseT: """ Returns the number of seconds until the key ``name`` will expire @@ -1860,7 +2140,7 @@ def ttl(self, name): """ return self.execute_command("TTL", name) - def type(self, name): + def type(self, name: KeyT) -> ResponseT: """ Returns the type of key ``name`` @@ -1868,7 +2148,7 @@ def type(self, name): """ return self.execute_command("TYPE", name) - def watch(self, *names): + def watch(self, *names: KeyT) -> None: """ Watches the values at keys ``names``, or None if the key doesn't exist @@ -1876,7 +2156,7 @@ def watch(self, *names): """ warnings.warn(DeprecationWarning("Call WATCH from a Pipeline object")) - def unwatch(self): + def unwatch(self) -> None: """ Unwatches the value at key ``name``, or None if the key doesn't exist @@ -1884,7 +2164,7 @@ def unwatch(self): """ warnings.warn(DeprecationWarning("Call UNWATCH from a Pipeline object")) - def unlink(self, *names): + def unlink(self, *names: KeyT) -> ResponseT: """ Unlink one or more keys specified by ``names`` @@ -1922,7 +2202,27 @@ def lcs( return self.execute_command("LCS", *pieces) -class ListCommands: +class AsyncBasicKeyCommands(BasicKeyCommands): + def __delitem__(self, name: KeyT): + raise TypeError("Async Redis client does not support class deletion") + + def __contains__(self, name: KeyT): + raise TypeError("Async Redis client does not support class inclusion") + + def __getitem__(self, name: KeyT): + raise TypeError("Async Redis client does not support class retrieval") + + def __setitem__(self, name: KeyT, value: EncodableT): + raise TypeError("Async Redis client does not support class assignment") + + async def watch(self, *names: KeyT) -> None: + return super().watch(*names) + + async def unwatch(self) -> None: + return super().unwatch() + + +class ListCommands(CommandsProtocol): """ Redis commands for List data type.
see: https://redis.io/topics/data-types#lists @@ -2204,7 +2504,7 @@ def lpos( For more information check https://redis.io/commands/lpos """ - pieces = [name, value] + pieces: list[EncodableT] = [name, value] if rank is not None: pieces.extend(["RANK", rank]) @@ -2256,7 +2556,7 @@ def sort( if (start is not None and num is None) or (num is not None and start is None): raise DataError("``start`` and ``num`` must both be specified") - pieces = [name] + pieces: list[EncodableT] = [name] if by is not None: pieces.extend([b"BY", by]) if start is not None and num is not None: @@ -2289,13 +2589,23 @@ def sort( return self.execute_command("SORT", *pieces, **options) -class ScanCommands: +AsyncListCommands = ListCommands + + +class ScanCommands(CommandsProtocol): """ Redis SCAN commands. see: https://redis.io/commands/scan """ - def scan(self, cursor=0, match=None, count=None, _type=None, **kwargs): + def scan( + self, + cursor: int = 0, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + _type: Union[str, None] = None, + **kwargs, + ) -> ResponseT: """ Incrementally return lists of key names. Also return a cursor indicating the scan position. @@ -2312,7 +2622,7 @@ def scan(self, cursor=0, match=None, count=None, _type=None, **kwargs): For more information check https://redis.io/commands/scan """ - pieces = [cursor] + pieces: list[EncodableT] = [cursor] if match is not None: pieces.extend([b"MATCH", match]) if count is not None: @@ -2321,7 +2631,13 @@ def scan(self, cursor=0, match=None, count=None, _type=None, **kwargs): pieces.extend([b"TYPE", _type]) return self.execute_command("SCAN", *pieces, **kwargs) - def scan_iter(self, match=None, count=None, _type=None, **kwargs): + def scan_iter( + self, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + _type: Union[str, None] = None, + **kwargs, + ) -> Iterator: """ Make an iterator using the SCAN command so that the client doesn't need to remember the cursor position. @@ -2343,7 +2659,13 @@ def scan_iter(self, match=None, count=None, _type=None, **kwargs): ) yield from data - def sscan(self, name, cursor=0, match=None, count=None): + def sscan( + self, + name: KeyT, + cursor: int = 0, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> ResponseT: """ Incrementally return lists of elements in a set. Also return a cursor indicating the scan position. @@ -2354,14 +2676,19 @@ def sscan(self, name, cursor=0, match=None, count=None): For more information check https://redis.io/commands/sscan """ - pieces = [name, cursor] + pieces: list[EncodableT] = [name, cursor] if match is not None: pieces.extend([b"MATCH", match]) if count is not None: pieces.extend([b"COUNT", count]) return self.execute_command("SSCAN", *pieces) - def sscan_iter(self, name, match=None, count=None): + def sscan_iter( + self, + name: KeyT, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> Iterator: """ Make an iterator using the SSCAN command so that the client doesn't need to remember the cursor position. @@ -2375,7 +2702,13 @@ def sscan_iter(self, name, match=None, count=None): cursor, data = self.sscan(name, cursor=cursor, match=match, count=count) yield from data - def hscan(self, name, cursor=0, match=None, count=None): + def hscan( + self, + name: KeyT, + cursor: int = 0, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> ResponseT: """ Incrementally return key/value slices in a hash. Also return a cursor indicating the scan position. 
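As a quick illustration of the cursor handling that the ``scan_iter`` method above wraps, a minimal sketch (local server and key names are assumptions):

``` python
import redis

r = redis.Redis()  # assumes localhost:6379

r.mset({f"user:{i}": i for i in range(3)})

# scan_iter drives the cursor internally; _type narrows results server-side.
for key in r.scan_iter(match="user:*", count=100, _type="STRING"):
    print(key)
```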
@@ -2386,14 +2719,19 @@ def hscan(self, name, cursor=0, match=None, count=None): For more information check https://redis.io/commands/hscan """ - pieces = [name, cursor] + pieces: list[EncodableT] = [name, cursor] if match is not None: pieces.extend([b"MATCH", match]) if count is not None: pieces.extend([b"COUNT", count]) return self.execute_command("HSCAN", *pieces) - def hscan_iter(self, name, match=None, count=None): + def hscan_iter( + self, + name: str, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> Iterator: """ Make an iterator using the HSCAN command so that the client doesn't need to remember the cursor position. @@ -2407,7 +2745,14 @@ def hscan_iter(self, name, match=None, count=None): cursor, data = self.hscan(name, cursor=cursor, match=match, count=count) yield from data.items() - def zscan(self, name, cursor=0, match=None, count=None, score_cast_func=float): + def zscan( + self, + name: KeyT, + cursor: int = 0, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + score_cast_func: Union[type, Callable] = float, + ) -> ResponseT: """ Incrementally return lists of elements in a sorted set. Also return a cursor indicating the scan position. @@ -2428,7 +2773,13 @@ def zscan(self, name, cursor=0, match=None, count=None, score_cast_func=float): options = {"score_cast_func": score_cast_func} return self.execute_command("ZSCAN", *pieces, **options) - def zscan_iter(self, name, match=None, count=None, score_cast_func=float): + def zscan_iter( + self, + name: KeyT, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + score_cast_func: Union[type, Callable] = float, + ) -> Iterator: """ Make an iterator using the ZSCAN command so that the client doesn't need to remember the cursor position. @@ -2451,7 +2802,111 @@ def zscan_iter(self, name, match=None, count=None, score_cast_func=float): yield from data -class SetCommands: +class AsyncScanCommands(ScanCommands): + async def scan_iter( + self, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + _type: Union[str, None] = None, + **kwargs, + ) -> AsyncIterator: + """ + Make an iterator using the SCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` provides a hint to Redis about the number of keys to + return per batch. + + ``_type`` filters the returned values by a particular Redis type. + Stock Redis instances allow for the following types: + HASH, LIST, SET, STREAM, STRING, ZSET + Additionally, Redis modules can expose other types as well. + """ + cursor = "0" + while cursor != 0: + cursor, data = await self.scan( + cursor=cursor, match=match, count=count, _type=_type, **kwargs + ) + for d in data: + yield d + + async def sscan_iter( + self, + name: KeyT, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> AsyncIterator: + """ + Make an iterator using the SSCAN command so that the client doesn't + need to remember the cursor position. 
+ + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + """ + cursor = "0" + while cursor != 0: + cursor, data = await self.sscan( + name, cursor=cursor, match=match, count=count + ) + for d in data: + yield d + + async def hscan_iter( + self, + name: str, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> AsyncIterator: + """ + Make an iterator using the HSCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + """ + cursor = "0" + while cursor != 0: + cursor, data = await self.hscan( + name, cursor=cursor, match=match, count=count + ) + for it in data.items(): + yield it + + async def zscan_iter( + self, + name: KeyT, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + score_cast_func: Union[type, Callable] = float, + ) -> AsyncIterator: + """ + Make an iterator using the ZSCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + + ``score_cast_func`` a callable used to cast the score return value + """ + cursor = "0" + while cursor != 0: + cursor, data = await self.zscan( + name, + cursor=cursor, + match=match, + count=count, + score_cast_func=score_cast_func, + ) + for d in data: + yield d + + +class SetCommands(CommandsProtocol): """ Redis commands for Set data type. see: https://redis.io/topics/data-types#sets @@ -2612,13 +3067,21 @@ def sunionstore(self, dest: str, keys: List, *args: List) -> int: return self.execute_command("SUNIONSTORE", dest, *args) -class StreamCommands: +AsyncSetCommands = SetCommands + + +class StreamCommands(CommandsProtocol): """ Redis commands for Stream data type. see: https://redis.io/topics/streams-intro """ - def xack(self, name, groupname, *ids): + def xack( + self, + name: KeyT, + groupname: GroupT, + *ids: StreamIdT, + ) -> ResponseT: """ Acknowledges the successful processing of one or more messages. name: name of the stream. @@ -2631,15 +3094,15 @@ def xack(self, name, groupname, *ids): def xadd( self, - name, - fields, - id="*", - maxlen=None, - approximate=True, - nomkstream=False, - minid=None, - limit=None, - ): + name: KeyT, + fields: Dict[FieldT, EncodableT], + id: StreamIdT = "*", + maxlen: Union[int, None] = None, + approximate: bool = True, + nomkstream: bool = False, + minid: Union[StreamIdT, None] = None, + limit: Union[int, None] = None, + ) -> ResponseT: """ Add to a stream. name: name of the stream @@ -2655,7 +3118,7 @@ def xadd( For more information check https://redis.io/commands/xadd """ - pieces = [] + pieces: list[EncodableT] = [] if maxlen is not None and minid is not None: raise DataError( "Only one of ```maxlen``` or ```minid``` " "may be specified" @@ -2686,14 +3149,14 @@ def xadd( def xautoclaim( self, - name, - groupname, - consumername, - min_idle_time, - start_id=0, - count=None, - justid=False, - ): + name: KeyT, + groupname: GroupT, + consumername: ConsumerT, + min_idle_time: int, + start_id: int = 0, + count: Union[int, None] = None, + justid: bool = False, + ) -> ResponseT: """ Transfers ownership of pending stream entries that match the specified criteria. 
Conceptually, equivalent to calling XPENDING and then XCLAIM, @@ -2737,17 +3200,17 @@ def xautoclaim( def xclaim( self, - name, - groupname, - consumername, - min_idle_time, - message_ids, - idle=None, - time=None, - retrycount=None, - force=False, - justid=False, - ): + name: KeyT, + groupname: GroupT, + consumername: ConsumerT, + min_idle_time: int, + message_ids: Union[List[StreamIdT], Tuple[StreamIdT]], + idle: Union[int, None] = None, + time: Union[int, None] = None, + retrycount: Union[int, None] = None, + force: bool = False, + justid: bool = False, + ) -> ResponseT: """ Changes the ownership of a pending message. name: name of the stream. @@ -2781,7 +3244,7 @@ def xclaim( ) kwargs = {} - pieces = [name, groupname, consumername, str(min_idle_time)] + pieces: list[EncodableT] = [name, groupname, consumername, str(min_idle_time)] pieces.extend(list(message_ids)) if idle is not None: @@ -2808,7 +3271,7 @@ def xclaim( kwargs["parse_justid"] = True return self.execute_command("XCLAIM", *pieces, **kwargs) - def xdel(self, name, *ids): + def xdel(self, name: KeyT, *ids: StreamIdT) -> ResponseT: """ Deletes one or more messages from a stream. name: name of the stream. @@ -2818,7 +3281,13 @@ def xdel(self, name, *ids): """ return self.execute_command("XDEL", name, *ids) - def xgroup_create(self, name, groupname, id="$", mkstream=False): + def xgroup_create( + self, + name: KeyT, + groupname: GroupT, + id: StreamIdT = "$", + mkstream: bool = False, + ) -> ResponseT: """ Create a new consumer group associated with a stream. name: name of the stream. @@ -2827,12 +3296,17 @@ def xgroup_create(self, name, groupname, id="$", mkstream=False): For more information check https://redis.io/commands/xgroup-create """ - pieces = ["XGROUP CREATE", name, groupname, id] + pieces: list[EncodableT] = ["XGROUP CREATE", name, groupname, id] if mkstream: pieces.append(b"MKSTREAM") return self.execute_command(*pieces) - def xgroup_delconsumer(self, name, groupname, consumername): + def xgroup_delconsumer( + self, + name: KeyT, + groupname: GroupT, + consumername: ConsumerT, + ) -> ResponseT: """ Remove a specific consumer from a consumer group. Returns the number of pending messages that the consumer had before it @@ -2845,7 +3319,7 @@ def xgroup_delconsumer(self, name, groupname, consumername): """ return self.execute_command("XGROUP DELCONSUMER", name, groupname, consumername) - def xgroup_destroy(self, name, groupname): + def xgroup_destroy(self, name: KeyT, groupname: GroupT) -> ResponseT: """ Destroy a consumer group. name: name of the stream. @@ -2855,7 +3329,12 @@ def xgroup_destroy(self, name, groupname): """ return self.execute_command("XGROUP DESTROY", name, groupname) - def xgroup_createconsumer(self, name, groupname, consumername): + def xgroup_createconsumer( + self, + name: KeyT, + groupname: GroupT, + consumername: ConsumerT, + ) -> ResponseT: """ Consumers in a consumer group are auto-created every time a new consumer name is mentioned by some command. @@ -2870,7 +3349,12 @@ def xgroup_createconsumer(self, name, groupname, consumername): "XGROUP CREATECONSUMER", name, groupname, consumername ) - def xgroup_setid(self, name, groupname, id): + def xgroup_setid( + self, + name: KeyT, + groupname: GroupT, + id: StreamIdT, + ) -> ResponseT: """ Set the consumer group last delivered ID to something else. name: name of the stream.
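A minimal sketch of the stream and consumer-group setup these typed signatures describe; the stream and group names are illustrative assumptions:

``` python
import redis

r = redis.Redis()  # assumes localhost:6379

# "*" asks the server to auto-generate the entry ID.
entry_id = r.xadd("events", {"type": "click", "page": "/home"})

# id="0" starts the group at the beginning; mkstream=True creates the
# stream if it does not exist yet.
try:
    r.xgroup_create("events", "workers", id="0", mkstream=True)
except redis.ResponseError:
    pass  # BUSYGROUP: the group already exists
```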
@@ -2881,7 +3365,7 @@ def xgroup_setid(self, name, groupname, id): """ return self.execute_command("XGROUP SETID", name, groupname, id) - def xinfo_consumers(self, name, groupname): + def xinfo_consumers(self, name: KeyT, groupname: GroupT) -> ResponseT: """ Returns general information about the consumers in the group. name: name of the stream. @@ -2891,7 +3375,7 @@ def xinfo_consumers(self, name, groupname): """ return self.execute_command("XINFO CONSUMERS", name, groupname) - def xinfo_groups(self, name): + def xinfo_groups(self, name: KeyT) -> ResponseT: """ Returns general information about the consumer groups of the stream. name: name of the stream. @@ -2900,7 +3384,7 @@ def xinfo_groups(self, name): """ return self.execute_command("XINFO GROUPS", name) - def xinfo_stream(self, name, full=False): + def xinfo_stream(self, name: KeyT, full: bool = False) -> ResponseT: """ Returns general information about the stream. name: name of the stream. @@ -2915,7 +3399,7 @@ def xinfo_stream(self, name, full=False): options = {"full": full} return self.execute_command("XINFO STREAM", *pieces, **options) - def xlen(self, name): + def xlen(self, name: KeyT) -> ResponseT: """ Returns the number of elements in a given stream. @@ -2923,7 +3407,7 @@ def xlen(self, name): """ return self.execute_command("XLEN", name) - def xpending(self, name, groupname): + def xpending(self, name: KeyT, groupname: GroupT) -> ResponseT: """ Returns information about pending messages of a group. name: name of the stream. @@ -2935,14 +3419,14 @@ def xpending(self, name, groupname): def xpending_range( self, - name, - groupname, - idle=None, - min=None, - max=None, - count=None, - consumername=None, - ): + name: KeyT, + groupname: GroupT, + min: StreamIdT, + max: StreamIdT, + count: int, + consumername: Union[ConsumerT, None] = None, + idle: Union[int, None] = None, + ) -> ResponseT: """ Returns information about pending messages, in a range. @@ -2990,7 +3474,13 @@ def xpending_range( return self.execute_command("XPENDING", *pieces, parse_detail=True) - def xrange(self, name, min="-", max="+", count=None): + def xrange( + self, + name: KeyT, + min: StreamIdT = "-", + max: StreamIdT = "+", + count: Union[int, None] = None, + ) -> ResponseT: """ Read stream values within an interval. name: name of the stream. @@ -3012,7 +3502,12 @@ def xrange(self, name, min="-", max="+", count=None): return self.execute_command("XRANGE", name, *pieces) - def xread(self, streams, count=None, block=None): + def xread( + self, + streams: Dict[KeyT, StreamIdT], + count: Union[int, None] = None, + block: Union[int, None] = None, + ) -> ResponseT: """ Block and monitor multiple streams for new data. streams: a dict of stream names to stream IDs, where @@ -3043,8 +3538,14 @@ def xread(self, streams, count=None, block=None): return self.execute_command("XREAD", *pieces) def xreadgroup( - self, groupname, consumername, streams, count=None, block=None, noack=False - ): + self, + groupname: str, + consumername: str, + streams: Dict[KeyT, StreamIdT], + count: Union[int, None] = None, + block: Union[int, None] = None, + noack: bool = False, + ) -> ResponseT: """ Read from a stream via a consumer group. groupname: name of the consumer group. 
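Continuing the sketch above, a consumer would read and acknowledge entries roughly like this (names are again assumptions):

``` python
import redis

r = redis.Redis()  # assumes the stream/group from the previous sketch

# ">" requests only entries never delivered to this group before.
resp = r.xreadgroup("workers", "consumer-1", {"events": ">"}, count=10, block=1000)
for stream, messages in resp:
    for message_id, fields in messages:
        print(stream, message_id, fields)
        r.xack("events", "workers", message_id)
```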
@@ -3058,7 +3559,7 @@ def xreadgroup( For more information check https://redis.io/commands/xreadgroup """ - pieces = [b"GROUP", groupname, consumername] + pieces: list[EncodableT] = [b"GROUP", groupname, consumername] if count is not None: if not isinstance(count, int) or count < 1: raise DataError("XREADGROUP count must be a positive integer") @@ -3078,7 +3579,13 @@ def xreadgroup( pieces.extend(streams.values()) return self.execute_command("XREADGROUP", *pieces) - def xrevrange(self, name, max="+", min="-", count=None): + def xrevrange( + self, + name: KeyT, + max: StreamIdT = "+", + min: StreamIdT = "-", + count: Union[int, None] = None, + ) -> ResponseT: """ Read stream values within an interval, in reverse order. name: name of the stream @@ -3091,7 +3598,7 @@ def xrevrange(self, name, max="+", min="-", count=None): For more information check https://redis.io/commands/xrevrange """ - pieces = [max, min] + pieces: list[EncodableT] = [max, min] if count is not None: if not isinstance(count, int) or count < 1: raise DataError("XREVRANGE count must be a positive integer") @@ -3100,7 +3607,14 @@ def xrevrange(self, name, max="+", min="-", count=None): return self.execute_command("XREVRANGE", name, *pieces) - def xtrim(self, name, maxlen=None, approximate=True, minid=None, limit=None): + def xtrim( + self, + name: KeyT, + maxlen: Union[int, None] = None, + approximate: bool = True, + minid: Union[StreamIdT, None] = None, + limit: Union[int, None] = None, + ) -> ResponseT: """ Trims old messages from a stream. name: name of the stream. @@ -3113,7 +3627,7 @@ def xtrim(self, name, maxlen=None, approximate=True, minid=None, limit=None): For more information check https://redis.io/commands/xtrim """ - pieces = [] + pieces: list[EncodableT] = [] if maxlen is not None and minid is not None: raise DataError("Only one of ``maxlen`` or ``minid`` " "may be specified") @@ -3134,15 +3648,26 @@ def xtrim(self, name, maxlen=None, approximate=True, minid=None, limit=None): return self.execute_command("XTRIM", name, *pieces) -class SortedSetCommands: +AsyncStreamCommands = StreamCommands + + +class SortedSetCommands(CommandsProtocol): """ Redis commands for Sorted Sets data type. see: https://redis.io/topics/data-types-intro#redis-sorted-sets """ def zadd( - self, name, mapping, nx=False, xx=False, ch=False, incr=False, gt=None, lt=None - ): + self, + name: KeyT, + mapping: Mapping[AnyKeyT, EncodableT], + nx: bool = False, + xx: bool = False, + ch: bool = False, + incr: bool = False, + gt: Union[bool, None] = None, + lt: Union[bool, None] = None, + ) -> ResponseT: """ Set any number of element-name, score pairs to the key ``name``. Pairs are specified as a dict of element-names keys to score values. @@ -3188,7 +3713,7 @@ def zadd( if nx is True and (gt is not None or lt is not None): raise DataError("Only one of 'nx', 'lt', or 'gt' may be defined.") - pieces = [] + pieces: list[EncodableT] = [] options = {} if nx: pieces.append(b"NX") @@ -3208,7 +3733,7 @@ def zadd( pieces.append(pair[0]) return self.execute_command("ZADD", name, *pieces, **options) - def zcard(self, name): + def zcard(self, name: KeyT) -> ResponseT: """ Return the number of elements in the sorted set ``name`` @@ -3216,7 +3741,7 @@ def zcard(self, name): """ return self.execute_command("ZCARD", name) - def zcount(self, name, min, max): + def zcount(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: """ Returns the number of elements in the sorted set at key ``name`` with a score between ``min`` and ``max``.
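For reference, the ``zadd``/``zcard``/``zcount`` trio typed above in everyday use; a minimal sketch with assumed key names:

``` python
import redis

r = redis.Redis()  # assumes localhost:6379

r.zadd("scores", {"alice": 10, "bob": 25, "carol": 40})
print(r.zcard("scores"))           # 3
print(r.zcount("scores", 20, 50))  # 2 (bob and carol)
```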
@@ -3225,7 +3750,7 @@ def zcount(self, name, min, max): """ return self.execute_command("ZCOUNT", name, min, max) - def zdiff(self, keys, withscores=False): + def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT: """ Returns the difference between the first and all successive input sorted sets provided in ``keys``. @@ -3237,7 +3762,7 @@ def zdiff(self, keys, withscores=False): pieces.append("WITHSCORES") return self.execute_command("ZDIFF", *pieces) - def zdiffstore(self, dest, keys): + def zdiffstore(self, dest: KeyT, keys: KeysT) -> ResponseT: """ Computes the difference between the first and all successive input sorted sets provided in ``keys`` and stores the result in ``dest``. @@ -3247,7 +3772,12 @@ def zdiffstore(self, dest, keys): pieces = [len(keys), *keys] return self.execute_command("ZDIFFSTORE", dest, *pieces) - def zincrby(self, name, amount, value): + def zincrby( + self, + name: KeyT, + amount: float, + value: EncodableT, + ) -> ResponseT: """ Increment the score of ``value`` in sorted set ``name`` by ``amount`` @@ -3255,7 +3785,12 @@ def zincrby(self, name, amount, value): """ return self.execute_command("ZINCRBY", name, amount, value) - def zinter(self, keys, aggregate=None, withscores=False): + def zinter( + self, + keys: KeysT, + aggregate: Union[str, None] = None, + withscores: bool = False, + ) -> ResponseT: """ Return the intersect of multiple sorted sets specified by ``keys``. With the ``aggregate`` option, it is possible to specify how the @@ -3269,7 +3804,12 @@ def zinter(self, keys, aggregate=None, withscores=False): """ return self._zaggregate("ZINTER", None, keys, aggregate, withscores=withscores) - def zinterstore(self, dest, keys, aggregate=None): + def zinterstore( + self, + dest: KeyT, + keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], + aggregate: Union[str, None] = None, + ) -> ResponseT: """ Intersect multiple sorted sets specified by ``keys`` into a new sorted set, ``dest``. Scores in the destination will be aggregated @@ -3305,7 +3845,11 @@ def zlexcount(self, name, min, max): """ return self.execute_command("ZLEXCOUNT", name, min, max) - def zpopmax(self, name, count=None): + def zpopmax( + self, + name: KeyT, + count: Union[int, None] = None, + ) -> ResponseT: """ Remove and return up to ``count`` members with the highest scores from the sorted set ``name``. @@ -3316,7 +3860,11 @@ def zpopmax(self, name, count=None): options = {"withscores": True} return self.execute_command("ZPOPMAX", name, *args, **options) - def zpopmin(self, name, count=None): + def zpopmin( + self, + name: KeyT, + count: Union[int, None] = None, + ) -> ResponseT: """ Remove and return up to ``count`` members with the lowest scores from the sorted set ``name``. @@ -3327,7 +3875,12 @@ def zpopmin(self, name, count=None): options = {"withscores": True} return self.execute_command("ZPOPMIN", name, *args, **options) - def zrandmember(self, key, count=None, withscores=False): + def zrandmember( + self, + key: KeyT, + count: int = None, + withscores: bool = False, + ) -> ResponseT: """ Return a random element from the sorted set value stored at key. @@ -3351,7 +3904,7 @@ def zrandmember(self, key, count=None, withscores=False): return self.execute_command("ZRANDMEMBER", key, *params) - def bzpopmax(self, keys, timeout=0): + def bzpopmax(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: """ ZPOPMAX a value off of the first non-empty sorted set named in the ``keys`` list. 
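A short sketch of ``zinter`` with the ``aggregate`` option described above (keys and scores are illustrative):

``` python
import redis

r = redis.Redis()  # assumes localhost:6379

r.zadd("a", {"x": 1, "y": 2})
r.zadd("b", {"y": 10, "z": 20})

# Only members present in every input survive; aggregate="MAX" keeps
# the larger of the per-set scores.
print(r.zinter(["a", "b"], aggregate="MAX", withscores=True))  # [(b'y', 10.0)]
```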
@@ -3370,7 +3923,7 @@ def bzpopmax(self, keys, timeout=0): keys.append(timeout) return self.execute_command("BZPOPMAX", *keys) - def bzpopmin(self, keys, timeout=0): + def bzpopmin(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: """ ZPOPMIN a value off of the first non-empty sorted set named in the ``keys`` list. @@ -3385,7 +3938,7 @@ def bzpopmin(self, keys, timeout=0): """ if timeout is None: timeout = 0 - keys = list_or_args(keys, None) + keys: list[EncodableT] = list_or_args(keys, None) keys.append(timeout) return self.execute_command("BZPOPMIN", *keys) @@ -3449,18 +4002,18 @@ def bzmpop( def _zrange( self, command, - dest, - name, - start, - end, - desc=False, - byscore=False, - bylex=False, - withscores=False, - score_cast_func=float, - offset=None, - num=None, - ): + dest: Union[KeyT, None], + name: KeyT, + start: int, + end: int, + desc: bool = False, + byscore: bool = False, + bylex: bool = False, + withscores: bool = False, + score_cast_func: Union[type, Callable, None] = float, + offset: Union[int, None] = None, + num: Union[int, None] = None, + ) -> ResponseT: if byscore and bylex: raise DataError( "``byscore`` and ``bylex`` can not be " "specified together." @@ -3490,17 +4043,17 @@ def _zrange( def zrange( self, - name, - start, - end, - desc=False, - withscores=False, - score_cast_func=float, - byscore=False, - bylex=False, - offset=None, - num=None, - ): + name: KeyT, + start: int, + end: int, + desc: bool = False, + withscores: bool = False, + score_cast_func: Union[type, Callable] = float, + byscore: bool = False, + bylex: bool = False, + offset: int = None, + num: int = None, + ) -> ResponseT: """ Return a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in ascending order. @@ -3549,7 +4102,14 @@ def zrange( num, ) - def zrevrange(self, name, start, end, withscores=False, score_cast_func=float): + def zrevrange( + self, + name: KeyT, + start: int, + end: int, + withscores: bool = False, + score_cast_func: Union[type, Callable] = float, + ) -> ResponseT: """ Return a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in descending order. @@ -3571,16 +4131,16 @@ def zrevrange(self, name, start, end, withscores=False, score_cast_func=float): def zrangestore( self, - dest, - name, - start, - end, - byscore=False, - bylex=False, - desc=False, - offset=None, - num=None, - ): + dest: KeyT, + name: KeyT, + start: int, + end: int, + byscore: bool = False, + bylex: bool = False, + desc: bool = False, + offset: Union[int, None] = None, + num: Union[int, None] = None, + ) -> ResponseT: """ Stores in ``dest`` the result of a range of values from sorted set ``name`` between ``start`` and ``end`` sorted in ascending order. @@ -3619,7 +4179,14 @@ def zrangestore( num, ) - def zrangebylex(self, name, min, max, start=None, num=None): + def zrangebylex( + self, + name: KeyT, + min: EncodableT, + max: EncodableT, + start: Union[int, None] = None, + num: Union[int, None] = None, + ) -> ResponseT: """ Return the lexicographical range of values from sorted set ``name`` between ``min`` and ``max``. 
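The ``byscore``/``offset``/``num`` combination accepted by the unified ``zrange`` above, sketched with assumed data:

``` python
import redis

r = redis.Redis()  # assumes localhost:6379

r.zadd("scores", {"alice": 10, "bob": 25, "carol": 40})

# byscore=True treats start/end as score bounds instead of rank indexes;
# offset/num paginate, mirroring the LIMIT clause of ZRANGEBYSCORE.
print(r.zrange("scores", 10, 40, byscore=True, offset=0, num=2, withscores=True))
```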
@@ -3636,7 +4203,14 @@ def zrangebylex(self, name, min, max, start=None, num=None): pieces.extend([b"LIMIT", start, num]) return self.execute_command(*pieces) - def zrevrangebylex(self, name, max, min, start=None, num=None): + def zrevrangebylex( + self, + name: KeyT, + max: EncodableT, + min: EncodableT, + start: Union[int, None] = None, + num: Union[int, None] = None, + ) -> ResponseT: """ Return the reversed lexicographical range of values from sorted set ``name`` between ``max`` and ``min``. @@ -3655,14 +4229,14 @@ def zrevrangebylex(self, name, max, min, start=None, num=None): def zrangebyscore( self, - name, - min, - max, - start=None, - num=None, - withscores=False, - score_cast_func=float, - ): + name: KeyT, + min: ZScoreBoundT, + max: ZScoreBoundT, + start: Union[int, None] = None, + num: Union[int, None] = None, + withscores: bool = False, + score_cast_func: Union[type, Callable] = float, + ) -> ResponseT: """ Return a range of values from the sorted set ``name`` with scores between ``min`` and ``max``. @@ -3689,13 +4263,13 @@ def zrangebyscore( def zrevrangebyscore( self, - name, - max, - min, - start=None, - num=None, - withscores=False, - score_cast_func=float, + name: KeyT, + max: ZScoreBoundT, + min: ZScoreBoundT, + start: Union[int, None] = None, + num: Union[int, None] = None, + withscores: bool = False, + score_cast_func: Union[type, Callable] = float, ): """ Return a range of values from the sorted set ``name`` with scores @@ -3721,7 +4295,7 @@ def zrevrangebyscore( options = {"withscores": withscores, "score_cast_func": score_cast_func} return self.execute_command(*pieces, **options) - def zrank(self, name, value): + def zrank(self, name: KeyT, value: EncodableT) -> ResponseT: """ Returns a 0-based value indicating the rank of ``value`` in sorted set ``name`` @@ -3730,7 +4304,7 @@ def zrank(self, name, value): """ return self.execute_command("ZRANK", name, value) - def zrem(self, name, *values): + def zrem(self, name: KeyT, *values: EncodableT) -> ResponseT: """ Remove member ``values`` from sorted set ``name`` @@ -3738,7 +4312,7 @@ def zrem(self, name, *values): """ return self.execute_command("ZREM", name, *values) - def zremrangebylex(self, name, min, max): + def zremrangebylex(self, name: KeyT, min: EncodableT, max: EncodableT) -> ResponseT: """ Remove all elements in the sorted set ``name`` between the lexicographical range specified by ``min`` and ``max``. @@ -3749,7 +4323,7 @@ def zremrangebylex(self, name, min, max): """ return self.execute_command("ZREMRANGEBYLEX", name, min, max) - def zremrangebyrank(self, name, min, max): + def zremrangebyrank(self, name: KeyT, min: int, max: int) -> ResponseT: """ Remove all elements in the sorted set ``name`` with ranks between ``min`` and ``max``. Values are 0-based, ordered from smallest score @@ -3760,7 +4334,9 @@ def zremrangebyrank(self, name, min, max): """ return self.execute_command("ZREMRANGEBYRANK", name, min, max) - def zremrangebyscore(self, name, min, max): + def zremrangebyscore( + self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT + ) -> ResponseT: """ Remove all elements in the sorted set ``name`` with scores between ``min`` and ``max``. Returns the number of elements removed. 
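A minimal sketch of ``zremrangebyscore`` with an exclusive bound, using assumed data:

``` python
import redis

r = redis.Redis()  # assumes localhost:6379

r.zadd("scores", {"alice": 10, "bob": 25, "carol": 40})

# "(20" is an exclusive bound, so exactly the members scoring below 20 go.
removed = r.zremrangebyscore("scores", "-inf", "(20")
print(removed)  # 1
```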
@@ -3769,7 +4345,7 @@ def zremrangebyscore(self, name, min, max): """ return self.execute_command("ZREMRANGEBYSCORE", name, min, max) - def zrevrank(self, name, value): + def zrevrank(self, name: KeyT, value: EncodableT) -> ResponseT: """ Returns a 0-based value indicating the descending rank of ``value`` in sorted set ``name`` @@ -3778,7 +4354,7 @@ def zrevrank(self, name, value): """ return self.execute_command("ZREVRANK", name, value) - def zscore(self, name, value): + def zscore(self, name: KeyT, value: EncodableT) -> ResponseT: """ Return the score of element ``value`` in sorted set ``name`` @@ -3786,7 +4362,12 @@ def zscore(self, name, value): """ return self.execute_command("ZSCORE", name, value) - def zunion(self, keys, aggregate=None, withscores=False): + def zunion( + self, + keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], + aggregate: Union[str, None] = None, + withscores: bool = False, + ) -> ResponseT: """ Return the union of multiple sorted sets specified by ``keys``. ``keys`` can be provided as dictionary of keys and their weights. @@ -3797,7 +4378,12 @@ def zunion(self, keys, aggregate=None, withscores=False): """ return self._zaggregate("ZUNION", None, keys, aggregate, withscores=withscores) - def zunionstore(self, dest, keys, aggregate=None): + def zunionstore( + self, + dest: KeyT, + keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], + aggregate: Union[str, None] = None, + ) -> ResponseT: """ Union multiple sorted sets specified by ``keys`` into a new sorted set, ``dest``. Scores in the destination will be @@ -3807,7 +4393,11 @@ def zunionstore(self, dest, keys, aggregate=None): """ return self._zaggregate("ZUNIONSTORE", dest, keys, aggregate) - def zmscore(self, key, members): + def zmscore( + self, + key: KeyT, + members: List[str], + ) -> ResponseT: """ Returns the scores associated with the specified members in the sorted set stored at key. @@ -3823,8 +4413,15 @@ def zmscore(self, key, members): pieces = [key] + members return self.execute_command("ZMSCORE", *pieces) - def _zaggregate(self, command, dest, keys, aggregate=None, **options): - pieces = [command] + def _zaggregate( + self, + command: str, + dest: Union[KeyT, None], + keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], + aggregate: Union[str, None] = None, + **options, + ) -> ResponseT: + pieces: list[EncodableT] = [command] if dest is not None: pieces.append(dest) pieces.append(len(keys)) @@ -3847,13 +4444,16 @@ def _zaggregate(self, command, dest, keys, aggregate=None, **options): return self.execute_command(*pieces, **options) -class HyperlogCommands: +AsyncSortedSetCommands = SortedSetCommands + + +class HyperlogCommands(CommandsProtocol): """ Redis commands of HyperLogLogs data type. see: https://redis.io/topics/data-types-intro#hyperloglogs """ - def pfadd(self, name, *values): + def pfadd(self, name: KeyT, *values: EncodableT) -> ResponseT: """ Adds the specified elements to the specified HyperLogLog. @@ -3861,7 +4461,7 @@ def pfadd(self, name, *values): """ return self.execute_command("PFADD", name, *values) - def pfcount(self, *sources): + def pfcount(self, *sources: KeyT) -> ResponseT: """ Return the approximated cardinality of the set observed by the HyperLogLog at key(s). @@ -3870,7 +4470,7 @@ def pfcount(self, *sources): """ return self.execute_command("PFCOUNT", *sources) - def pfmerge(self, dest, *sources): + def pfmerge(self, dest: KeyT, *sources: KeyT) -> ResponseT: """ Merge N different HyperLogLogs into a single one. 
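The three HyperLogLog commands typed above compose as follows; a minimal sketch with assumed key names:

``` python
import redis

r = redis.Redis()  # assumes localhost:6379

r.pfadd("visitors:mon", "alice", "bob")
r.pfadd("visitors:tue", "bob", "carol")

r.pfmerge("visitors:week", "visitors:mon", "visitors:tue")
print(r.pfcount("visitors:week"))  # ~3; HyperLogLog counts are approximate
```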
@@ -3879,7 +4479,10 @@ def pfmerge(self, dest, *sources): return self.execute_command("PFMERGE", dest, *sources) -class HashCommands: +AsyncHyperlogCommands = HyperlogCommands + + +class HashCommands(CommandsProtocol): """ Redis commands for Hash data type. see: https://redis.io/topics/data-types-intro#redis-hashes @@ -4031,13 +4634,106 @@ def hstrlen(self, name: str, key: str) -> int: return self.execute_command("HSTRLEN", name, key) -class PubSubCommands: +AsyncHashCommands = HashCommands + + +class Script: + """ + An executable Lua script object returned by ``register_script`` + """ + + def __init__(self, registered_client: "Redis", script: ScriptTextT): + self.registered_client = registered_client + self.script = script + # Precalculate and store the SHA1 hex digest of the script. + + if isinstance(script, str): + # We need the encoding from the client in order to generate an + # accurate byte representation of the script + encoder = registered_client.connection_pool.get_encoder() + script = encoder.encode(script) + self.sha = hashlib.sha1(script).hexdigest() + + def __call__( + self, + keys: Union[Sequence[KeyT], None] = None, + args: Union[Iterable[EncodableT], None] = None, + client: Union["Redis", None] = None, + ): + """Execute the script, passing any required ``args``""" + keys = keys or [] + args = args or [] + if client is None: + client = self.registered_client + args = tuple(keys) + tuple(args) + # make sure the Redis server knows about the script + from redis.client import Pipeline + + if isinstance(client, Pipeline): + # Make sure the pipeline can register the script before executing. + client.scripts.add(self) + try: + return client.evalsha(self.sha, len(keys), *args) + except NoScriptError: + # Maybe the client is pointed to a different server than the client + # that created this instance? + # Overwrite the sha just in case there was a discrepancy. + self.sha = client.script_load(self.script) + return client.evalsha(self.sha, len(keys), *args) + + +class AsyncScript: + """ + An executable Lua script object returned by ``register_script`` + """ + + def __init__(self, registered_client: "AsyncRedis", script: ScriptTextT): + self.registered_client = registered_client + self.script = script + # Precalculate and store the SHA1 hex digest of the script. + + if isinstance(script, str): + # We need the encoding from the client in order to generate an + # accurate byte representation of the script + encoder = registered_client.connection_pool.get_encoder() + script = encoder.encode(script) + self.sha = hashlib.sha1(script).hexdigest() + + async def __call__( + self, + keys: Union[Sequence[KeyT], None] = None, + args: Union[Iterable[EncodableT], None] = None, + client: Union["AsyncRedis", None] = None, + ): + """Execute the script, passing any required ``args``""" + keys = keys or [] + args = args or [] + if client is None: + client = self.registered_client + args = tuple(keys) + tuple(args) + # make sure the Redis server knows about the script + from redis.asyncio.client import Pipeline + + if isinstance(client, Pipeline): + # Make sure the pipeline can register the script before executing. + client.scripts.add(self) + try: + return await client.evalsha(self.sha, len(keys), *args) + except NoScriptError: + # Maybe the client is pointed to a different server than the client + # that created this instance? + # Overwrite the sha just in case there was a discrepancy. 
+ self.sha = await client.script_load(self.script) + return await client.evalsha(self.sha, len(keys), *args) + + +class PubSubCommands(CommandsProtocol): """ Redis PubSub commands. see https://redis.io/topics/pubsub """ - def publish(self, channel, message, **kwargs): + def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT: """ Publish ``message`` on ``channel``. Returns the number of subscribers the message was delivered to. @@ -4046,7 +4742,7 @@ def publish(self, channel, message, **kwargs): """ return self.execute_command("PUBLISH", channel, message, **kwargs) - def pubsub_channels(self, pattern="*", **kwargs): + def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: """ Return a list of channels that have at least one subscriber @@ -4054,7 +4750,7 @@ def pubsub_channels(self, pattern="*", **kwargs): """ return self.execute_command("PUBSUB CHANNELS", pattern, **kwargs) - def pubsub_numpat(self, **kwargs): + def pubsub_numpat(self, **kwargs) -> ResponseT: """ Returns the number of subscriptions to patterns @@ -4062,7 +4758,7 @@ def pubsub_numpat(self, **kwargs): """ return self.execute_command("PUBSUB NUMPAT", **kwargs) - def pubsub_numsub(self, *args, **kwargs): + def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT: """ Return a list of (channel, number of subscribers) tuples for each channel given in ``*args`` @@ -4072,7 +4768,10 @@ def pubsub_numsub(self, *args, **kwargs): return self.execute_command("PUBSUB NUMSUB", *args, **kwargs) -class ScriptCommands: +AsyncPubSubCommands = PubSubCommands + + +class ScriptCommands(CommandsProtocol): """ Redis Lua script commands. see: https://redis.com/ebook/part-3-next-steps/chapter-11-scripting-redis-with-lua/ @@ -4140,7 +4839,7 @@ def evalsha_ro(self, sha: str, numkeys: int, *keys_and_args: list) -> str: """ return self._evalsha("EVALSHA_RO", sha, numkeys, *keys_and_args) - def script_exists(self, *args): + def script_exists(self, *args: str): """ Check if a script exists in the script cache by specifying the SHAs of each script as ``args``. Returns a list of boolean values indicating if @@ -4150,12 +4849,14 @@ def script_exists(self, *args): """ return self.execute_command("SCRIPT EXISTS", *args) - def script_debug(self, *args): + def script_debug(self, *args) -> None: raise NotImplementedError( "SCRIPT DEBUG is intentionally not implemented in the client." ) - def script_flush(self, sync_type=None): + def script_flush( + self, sync_type: Union[Literal["SYNC"], Literal["ASYNC"]] = None + ) -> ResponseT: """Flush all scripts from the script cache. ``sync_type`` is by default SYNC (synchronous) but it can also be ASYNC. @@ -4175,7 +4876,7 @@ def script_flush(self, sync_type=None): pieces = [sync_type] return self.execute_command("SCRIPT FLUSH", *pieces) - def script_kill(self): + def script_kill(self) -> ResponseT: """ Kill the currently executing Lua script @@ -4183,7 +4884,7 @@ def script_kill(self): """ return self.execute_command("SCRIPT KILL") - def script_load(self, script): + def script_load(self, script: ScriptTextT) -> ResponseT: """ Load a Lua ``script`` into the script cache. Returns the SHA. @@ -4191,7 +4892,7 @@ def script_load(self, script): """ return self.execute_command("SCRIPT LOAD", script) - def register_script(self, script): + def register_script(self: "Redis", script: ScriptTextT) -> Script: """ Register a Lua ``script`` specifying the ``keys`` it will touch. 
Returns a Script object that is callable and hides the complexity of @@ -4201,13 +4902,34 @@ def register_script(self, script): return Script(self, script) -class GeoCommands: +class AsyncScriptCommands(ScriptCommands): + async def script_debug(self, *args) -> None: + return super().script_debug(*args) + + def register_script(self: "AsyncRedis", script: ScriptTextT) -> AsyncScript: + """ + Register a Lua ``script`` specifying the ``keys`` it will touch. + Returns a Script object that is callable and hides the complexity of + dealing with scripts, keys, and shas. This is the preferred way to work + with Lua scripts. + """ + return AsyncScript(self, script) + + +class GeoCommands(CommandsProtocol): """ Redis Geospatial commands. see: https://redis.com/redis-best-practices/indexing-patterns/geospatial/ """ def geoadd( + self, + name: KeyT, + values: Sequence[EncodableT], + nx: bool = False, + xx: bool = False, + ch: bool = False, + ) -> ResponseT: """ Add the specified geospatial items to the specified key identified by the ``name`` argument. The Geospatial items are given as ordered @@ -4242,7 +4964,13 @@ def geoadd(self, name, values, nx=False, xx=False, ch=False): pieces.extend(values) return self.execute_command("GEOADD", *pieces) - def geodist(self, name, place1, place2, unit=None): + def geodist( + self, + name: KeyT, + place1: FieldT, + place2: FieldT, + unit: Union[str, None] = None, + ) -> ResponseT: """ Return the distance between ``place1`` and ``place2`` members of the ``name`` key. @@ -4251,14 +4979,14 @@ def geodist(self, name, place1, place2, unit=None): For more information check https://redis.io/commands/geodist """ - pieces = [name, place1, place2] + pieces: list[EncodableT] = [name, place1, place2] if unit and unit not in ("m", "km", "mi", "ft"): raise DataError("GEODIST invalid unit") elif unit: pieces.append(unit) return self.execute_command("GEODIST", *pieces) - def geohash(self, name, *values): + def geohash(self, name: KeyT, *values: FieldT) -> ResponseT: """ Return the geo hash string for each item of ``values`` members of the specified key identified by the ``name`` argument. @@ -4267,7 +4995,7 @@ def geohash(self, name, *values): """ return self.execute_command("GEOHASH", name, *values) - def geopos(self, name, *values): + def geopos(self, name: KeyT, *values: FieldT) -> ResponseT: """ Return the positions of each item of ``values`` as members of the specified key identified by the ``name`` argument.
Each position @@ -4279,20 +5007,20 @@ def geopos(self, name, *values): def georadius( self, - name, - longitude, - latitude, - radius, - unit=None, - withdist=False, - withcoord=False, - withhash=False, - count=None, - sort=None, - store=None, - store_dist=None, - any=False, - ): + name: KeyT, + longitude: float, + latitude: float, + radius: float, + unit: Union[str, None] = None, + withdist: bool = False, + withcoord: bool = False, + withhash: bool = False, + count: Union[int, None] = None, + sort: Union[str, None] = None, + store: Union[KeyT, None] = None, + store_dist: Union[KeyT, None] = None, + any: bool = False, + ) -> ResponseT: """ Return the members of the specified key identified by the ``name`` argument which are within the borders of the area specified @@ -4342,19 +5070,19 @@ def georadius( def georadiusbymember( self, - name, - member, - radius, - unit=None, - withdist=False, - withcoord=False, - withhash=False, - count=None, - sort=None, - store=None, - store_dist=None, - any=False, - ): + name: KeyT, + member: FieldT, + radius: float, + unit: Union[str, None] = None, + withdist: bool = False, + withcoord: bool = False, + withhash: bool = False, + count: Union[int, None] = None, + sort: Union[str, None] = None, + store: Union[KeyT, None] = None, + store_dist: Union[KeyT, None] = None, + any: bool = False, + ) -> ResponseT: """ This command is exactly like ``georadius`` with the sole difference that instead of taking, as the center of the area to query, a longitude @@ -4379,7 +5107,12 @@ def georadiusbymember( any=any, ) - def _georadiusgeneric(self, command, *args, **kwargs): + def _georadiusgeneric( + self, + command: str, + *args: EncodableT, + **kwargs: Union[EncodableT, None], + ) -> ResponseT: pieces = list(args) if kwargs["unit"] and kwargs["unit"] not in ("m", "km", "mi", "ft"): raise DataError("GEORADIUS invalid unit") @@ -4427,21 +5160,21 @@ def _georadiusgeneric(self, command, *args, **kwargs): def geosearch( self, - name, - member=None, - longitude=None, - latitude=None, - unit="m", - radius=None, - width=None, - height=None, - sort=None, - count=None, - any=False, - withcoord=False, - withdist=False, - withhash=False, - ): + name: KeyT, + member: Union[FieldT, None] = None, + longitude: Union[float, None] = None, + latitude: Union[float, None] = None, + unit: str = "m", + radius: Union[float, None] = None, + width: Union[float, None] = None, + height: Union[float, None] = None, + sort: Union[str, None] = None, + count: Union[int, None] = None, + any: bool = False, + withcoord: bool = False, + withdist: bool = False, + withhash: bool = False, + ) -> ResponseT: """ Return the members of specified key identified by the ``name`` argument, which are within the borders of the @@ -4499,20 +5232,20 @@ def geosearch( def geosearchstore( self, - dest, - name, - member=None, - longitude=None, - latitude=None, - unit="m", - radius=None, - width=None, - height=None, - sort=None, - count=None, - any=False, - storedist=False, - ): + dest: KeyT, + name: KeyT, + member: Union[FieldT, None] = None, + longitude: Union[float, None] = None, + latitude: Union[float, None] = None, + unit: str = "m", + radius: Union[float, None] = None, + width: Union[float, None] = None, + height: Union[float, None] = None, + sort: Union[str, None] = None, + count: Union[int, None] = None, + any: bool = False, + storedist: bool = False, + ) -> ResponseT: """ This command is like GEOSEARCH, but stores the result in ``dest``. 
By default, it stores the results in the destination @@ -4544,7 +5277,12 @@ def geosearchstore( store_dist=storedist, ) - def _geosearchgeneric(self, command, *args, **kwargs): + def _geosearchgeneric( + self, + command: str, + *args: EncodableT, + **kwargs: Union[EncodableT, None], + ) -> ResponseT: pieces = list(args) # FROMMEMBER or FROMLONLAT @@ -4609,13 +5347,16 @@ def _geosearchgeneric(self, command, *args, **kwargs): return self.execute_command(command, *pieces, **kwargs) -class ModuleCommands: +AsyncGeoCommands = GeoCommands + + +class ModuleCommands(CommandsProtocol): """ Redis Module commands. see: https://redis.io/topics/modules-intro """ - def module_load(self, path, *args): + def module_load(self, path, *args) -> ResponseT: """ Loads the module from ``path``. Passes all ``*args`` to the module, during loading. @@ -4625,7 +5366,7 @@ def module_load(self, path, *args): """ return self.execute_command("MODULE LOAD", path, *args) - def module_unload(self, name): + def module_unload(self, name) -> ResponseT: """ Unloads the module ``name``. Raises ``ModuleError`` if ``name`` is not in loaded modules. @@ -4634,7 +5375,7 @@ def module_unload(self, name): """ return self.execute_command("MODULE UNLOAD", name) - def module_list(self): + def module_list(self) -> ResponseT: """ Returns a list of dictionaries containing the name and version of all loaded modules. @@ -4643,166 +5384,35 @@ def module_list(self): """ return self.execute_command("MODULE LIST") - def command_info(self): + def command_info(self) -> None: raise NotImplementedError( "COMMAND INFO is intentionally not implemented in the client." ) - def command_count(self): + def command_count(self) -> ResponseT: return self.execute_command("COMMAND COUNT") - def command_getkeys(self, *args): + def command_getkeys(self, *args) -> ResponseT: return self.execute_command("COMMAND GETKEYS", *args) - def command(self): + def command(self) -> ResponseT: return self.execute_command("COMMAND") -class Script: - """ - An executable Lua script object returned by ``register_script`` - """ - - def __init__(self, registered_client, script): - self.registered_client = registered_client - self.script = script - # Precalculate and store the SHA1 hex digest of the script. - - if isinstance(script, str): - # We need the encoding from the client in order to generate an - # accurate byte representation of the script - encoder = registered_client.connection_pool.get_encoder() - script = encoder.encode(script) - self.sha = hashlib.sha1(script).hexdigest() - - def __call__(self, keys=[], args=[], client=None): - "Execute the script, passing any required ``args``" - if client is None: - client = self.registered_client - args = tuple(keys) + tuple(args) - # make sure the Redis server knows about the script - from redis.client import Pipeline - - if isinstance(client, Pipeline): - # Make sure the pipeline can register the script before executing. - client.scripts.add(self) - try: - return client.evalsha(self.sha, len(keys), *args) - except NoScriptError: - # Maybe the client is pointed to a different server than the client - # that created this instance? - # Overwrite the sha just in case there was a discrepancy. - self.sha = client.script_load(self.script) - return client.evalsha(self.sha, len(keys), *args) - - -class BitFieldOperation: - """ - Command builder for BITFIELD commands. 
- """ - - def __init__(self, client, key, default_overflow=None): - self.client = client - self.key = key - self._default_overflow = default_overflow - self.reset() - - def reset(self): - """ - Reset the state of the instance to when it was constructed - """ - self.operations = [] - self._last_overflow = "WRAP" - self.overflow(self._default_overflow or self._last_overflow) - - def overflow(self, overflow): - """ - Update the overflow algorithm of successive INCRBY operations - :param overflow: Overflow algorithm, one of WRAP, SAT, FAIL. See the - Redis docs for descriptions of these algorithmsself. - :returns: a :py:class:`BitFieldOperation` instance. - """ - overflow = overflow.upper() - if overflow != self._last_overflow: - self._last_overflow = overflow - self.operations.append(("OVERFLOW", overflow)) - return self - - def incrby(self, fmt, offset, increment, overflow=None): - """ - Increment a bitfield by a given amount. - :param fmt: format-string for the bitfield being updated, e.g. 'u8' - for an unsigned 8-bit integer. - :param offset: offset (in number of bits). If prefixed with a - '#', this is an offset multiplier, e.g. given the arguments - fmt='u8', offset='#2', the offset will be 16. - :param int increment: value to increment the bitfield by. - :param str overflow: overflow algorithm. Defaults to WRAP, but other - acceptable values are SAT and FAIL. See the Redis docs for - descriptions of these algorithms. - :returns: a :py:class:`BitFieldOperation` instance. - """ - if overflow is not None: - self.overflow(overflow) +class AsyncModuleCommands(ModuleCommands): + async def command_info(self) -> None: + return super().command_info() - self.operations.append(("INCRBY", fmt, offset, increment)) - return self - def get(self, fmt, offset): - """ - Get the value of a given bitfield. - :param fmt: format-string for the bitfield being read, e.g. 'u8' for - an unsigned 8-bit integer. - :param offset: offset (in number of bits). If prefixed with a - '#', this is an offset multiplier, e.g. given the arguments - fmt='u8', offset='#2', the offset will be 16. - :returns: a :py:class:`BitFieldOperation` instance. - """ - self.operations.append(("GET", fmt, offset)) - return self - - def set(self, fmt, offset, value): - """ - Set the value of a given bitfield. - :param fmt: format-string for the bitfield being read, e.g. 'u8' for - an unsigned 8-bit integer. - :param offset: offset (in number of bits). If prefixed with a - '#', this is an offset multiplier, e.g. given the arguments - fmt='u8', offset='#2', the offset will be 16. - :param int value: value to set at the given position. - :returns: a :py:class:`BitFieldOperation` instance. - """ - self.operations.append(("SET", fmt, offset, value)) - return self - - @property - def command(self): - cmd = ["BITFIELD", self.key] - for ops in self.operations: - cmd.extend(ops) - return cmd - - def execute(self): - """ - Execute the operation(s) in a single BITFIELD command. The return value - is a list of values corresponding to each operation. If the client - used to create this instance was a pipeline, the list of values - will be present within the pipeline's execute. 
- """ - command = self.command - self.reset() - return self.client.execute_command(*command) - - -class ClusterCommands: +class ClusterCommands(CommandsProtocol): """ Class for Redis Cluster commands """ - def cluster(self, cluster_arg, *args, **kwargs): + def cluster(self, cluster_arg, *args, **kwargs) -> ResponseT: return self.execute_command(f"CLUSTER {cluster_arg.upper()}", *args, **kwargs) - def readwrite(self, **kwargs): + def readwrite(self, **kwargs) -> ResponseT: """ Disables read queries for a connection to a Redis Cluster slave node. @@ -4810,7 +5420,7 @@ def readwrite(self, **kwargs): """ return self.execute_command("READWRITE", **kwargs) - def readonly(self, **kwargs): + def readonly(self, **kwargs) -> ResponseT: """ Enables read queries for a connection to a Redis Cluster replica node. @@ -4819,6 +5429,9 @@ def readonly(self, **kwargs): return self.execute_command("READONLY", **kwargs) +AsyncClusterCommands = ClusterCommands + + class DataAccessCommands( BasicKeyCommands, HyperlogCommands, @@ -4832,7 +5445,24 @@ class DataAccessCommands( ): """ A class containing all of the implemented data access redis commands. - This class is to be used as a mixin. + This class is to be used as a mixin for synchronous Redis clients. + """ + + +class AsyncDataAccessCommands( + AsyncBasicKeyCommands, + AsyncHyperlogCommands, + AsyncHashCommands, + AsyncGeoCommands, + AsyncListCommands, + AsyncScanCommands, + AsyncSetCommands, + AsyncStreamCommands, + AsyncSortedSetCommands, +): + """ + A class containing all of the implemented data access redis commands. + This class is to be used as a mixin for asynchronous Redis clients. """ @@ -4847,5 +5477,20 @@ class CoreCommands( ): """ A class containing all of the implemented redis commands. This class is - to be used as a mixin. + to be used as a mixin for synchronous Redis clients. + """ + + +class AsyncCoreCommands( + AsyncACLCommands, + AsyncClusterCommands, + AsyncDataAccessCommands, + AsyncManagementCommands, + AsyncModuleCommands, + AsyncPubSubCommands, + AsyncScriptCommands, +): + """ + A class containing all of the implemented redis commands. This class is + to be used as a mixin for asynchronous Redis clients. """ diff --git a/redis/commands/sentinel.py b/redis/commands/sentinel.py index a9b06c2f6e..e054ec6a38 100644 --- a/redis/commands/sentinel.py +++ b/redis/commands/sentinel.py @@ -8,39 +8,39 @@ class SentinelCommands: """ def sentinel(self, *args): - "Redis Sentinel's SENTINEL command." + """Redis Sentinel's SENTINEL command.""" warnings.warn(DeprecationWarning("Use the individual sentinel_* methods")) def sentinel_get_master_addr_by_name(self, service_name): - "Returns a (host, port) pair for the given ``service_name``" + """Returns a (host, port) pair for the given ``service_name``""" return self.execute_command("SENTINEL GET-MASTER-ADDR-BY-NAME", service_name) def sentinel_master(self, service_name): - "Returns a dictionary containing the specified masters state." + """Returns a dictionary containing the specified masters state.""" return self.execute_command("SENTINEL MASTER", service_name) def sentinel_masters(self): - "Returns a list of dictionaries containing each master's state." 
+ """Returns a list of dictionaries containing each master's state.""" return self.execute_command("SENTINEL MASTERS") def sentinel_monitor(self, name, ip, port, quorum): - "Add a new master to Sentinel to be monitored" + """Add a new master to Sentinel to be monitored""" return self.execute_command("SENTINEL MONITOR", name, ip, port, quorum) def sentinel_remove(self, name): - "Remove a master from Sentinel's monitoring" + """Remove a master from Sentinel's monitoring""" return self.execute_command("SENTINEL REMOVE", name) def sentinel_sentinels(self, service_name): - "Returns a list of sentinels for ``service_name``" + """Returns a list of sentinels for ``service_name``""" return self.execute_command("SENTINEL SENTINELS", service_name) def sentinel_set(self, name, option, value): - "Set Sentinel monitoring parameters for a given master" + """Set Sentinel monitoring parameters for a given master""" return self.execute_command("SENTINEL SET", name, option, value) def sentinel_slaves(self, service_name): - "Returns a list of slaves for ``service_name``" + """Returns a list of slaves for ``service_name``""" return self.execute_command("SENTINEL SLAVES", service_name) def sentinel_reset(self, pattern): @@ -91,3 +91,9 @@ def sentinel_flushconfig(self): completely missing. """ return self.execute_command("SENTINEL FLUSHCONFIG") + + +class AsyncSentinelCommands(SentinelCommands): + async def sentinel(self, *args) -> None: + """Redis Sentinel's SENTINEL command.""" + super().sentinel(*args) diff --git a/redis/compat.py b/redis/compat.py new file mode 100644 index 0000000000..738687f645 --- /dev/null +++ b/redis/compat.py @@ -0,0 +1,9 @@ +# flake8: noqa +try: + from typing import Literal, Protocol, TypedDict # lgtm [py/unused-import] +except ImportError: + from typing_extensions import ( # lgtm [py/unused-import] + Literal, + Protocol, + TypedDict, + ) diff --git a/redis/typing.py b/redis/typing.py new file mode 100644 index 0000000000..d96e4e3a5d --- /dev/null +++ b/redis/typing.py @@ -0,0 +1,45 @@ +# from __future__ import annotations + +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Iterable, TypeVar, Union + +from redis.compat import Protocol + +if TYPE_CHECKING: + from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool + from redis.connection import ConnectionPool + + +EncodedT = Union[bytes, memoryview] +DecodedT = Union[str, int, float] +EncodableT = Union[EncodedT, DecodedT] +AbsExpiryT = Union[int, datetime] +ExpiryT = Union[float, timedelta] +ZScoreBoundT = Union[float, str] # str allows for the [ or ( prefix +BitfieldOffsetT = Union[int, str] # str allows for #x syntax +_StringLikeT = Union[bytes, str, memoryview] +KeyT = _StringLikeT # Main redis key space +PatternT = _StringLikeT # Patterns matched against keys, fields etc +FieldT = EncodableT # Fields within hash tables, streams and geo commands +KeysT = Union[KeyT, Iterable[KeyT]] +ChannelT = _StringLikeT +GroupT = _StringLikeT # Consumer group +ConsumerT = _StringLikeT # Consumer name +StreamIdT = Union[int, _StringLikeT] +ScriptTextT = _StringLikeT +TimeoutSecT = Union[int, float, _StringLikeT] +# Mapping is not covariant in the key type, which prevents +# Mapping[_StringLikeT, X from accepting arguments of type Dict[str, X]. Using +# a TypeVar instead of a Union allows mappings with any of the permitted types +# to be passed. Care is needed if there is more than one such mapping in a +# type signature because they will all be required to be the same key type. 
+AnyKeyT = TypeVar("AnyKeyT", bytes, str, memoryview) +AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) +AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview) + + +class CommandsProtocol(Protocol): + connection_pool: Union["AsyncConnectionPool", "ConnectionPool"] + + def execute_command(self, *args, **options): + ... diff --git a/requirements.txt b/requirements.txt index b05ff454bf..7f0ebf0c69 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,4 @@ +async-timeout>=4.0.2 deprecated>=1.2.3 packaging>=20.4 +typing-extensions diff --git a/setup.py b/setup.py index a44fe620d7..a6d5cf0679 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ packages=find_packages( include=[ "redis", + "redis.asyncio", "redis.commands", "redis.commands.bf", "redis.commands.json", @@ -34,6 +35,8 @@ "deprecated>=1.2.3", "packaging>=20.4", 'importlib-metadata >= 1.0; python_version < "3.8"', + "typing-extensions", + "async-timeout>=4.0.2", ], classifiers=[ "Development Status :: 5 - Production/Stable", diff --git a/tests/conftest.py b/tests/conftest.py index d9de876933..2534ca0c95 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ +import argparse import random import time +from typing import Callable, TypeVar from unittest.mock import Mock from urllib.parse import urlparse @@ -21,6 +23,54 @@ default_redis_ssl_url = "rediss://localhost:6666" default_cluster_nodes = 6 +_DecoratedTest = TypeVar("_DecoratedTest", bound="Callable") +_TestDecorator = Callable[[_DecoratedTest], _DecoratedTest] + + +# Taken from python3.9 +class BooleanOptionalAction(argparse.Action): + def __init__( + self, + option_strings, + dest, + default=None, + type=None, + choices=None, + required=False, + help=None, + metavar=None, + ): + + _option_strings = [] + for option_string in option_strings: + _option_strings.append(option_string) + + if option_string.startswith("--"): + option_string = "--no-" + option_string[2:] + _option_strings.append(option_string) + + if help is not None and default is not None: + help += f" (default: {default})" + + super().__init__( + option_strings=_option_strings, + dest=dest, + nargs=0, + default=default, + type=type, + choices=choices, + required=required, + help=help, + metavar=metavar, + ) + + def __call__(self, parser, namespace, values, option_string=None): + if option_string in self.option_strings: + setattr(namespace, self.dest, not option_string.startswith("--no-")) + + def format_usage(self): + return " | ".join(self.option_strings) + def pytest_addoption(parser): parser.addoption( @@ -62,6 +112,9 @@ def pytest_addoption(parser): help="Redis unstable (latest version) connection string " "defaults to %(default)s`", ) + parser.addoption( + "--uvloop", action=BooleanOptionalAction, help="Run tests with uvloop" + ) def _get_info(redis_url): @@ -101,6 +154,18 @@ def pytest_sessionstart(session): cluster_nodes = session.config.getoption("--redis-cluster-nodes") wait_for_cluster_creation(redis_url, cluster_nodes) + use_uvloop = session.config.getoption("--uvloop") + + if use_uvloop: + try: + import uvloop + + uvloop.install() + except ImportError as e: + raise RuntimeError( + "Can not import uvloop, make sure it is installed" + ) from e + def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=60): """ @@ -133,19 +198,19 @@ def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=60): ) -def skip_if_server_version_lt(min_version): +def skip_if_server_version_lt(min_version: str) -> _TestDecorator: redis_version = REDIS_INFO["version"] check = 
Version(redis_version) < Version(min_version) return pytest.mark.skipif(check, reason=f"Redis version required >= {min_version}") -def skip_if_server_version_gte(min_version): +def skip_if_server_version_gte(min_version: str) -> _TestDecorator: redis_version = REDIS_INFO["version"] check = Version(redis_version) >= Version(min_version) return pytest.mark.skipif(check, reason=f"Redis version required < {min_version}") -def skip_unless_arch_bits(arch_bits): +def skip_unless_arch_bits(arch_bits: int) -> _TestDecorator: return pytest.mark.skipif( REDIS_INFO["arch_bits"] != arch_bits, reason=f"server is not {arch_bits}-bit" ) @@ -169,17 +234,17 @@ def skip_ifmodversion_lt(min_version: str, module_name: str): raise AttributeError(f"No redis module named {module_name}") -def skip_if_redis_enterprise(): +def skip_if_redis_enterprise() -> _TestDecorator: check = REDIS_INFO["enterprise"] is True return pytest.mark.skipif(check, reason="Redis enterprise") -def skip_ifnot_redis_enterprise(): +def skip_ifnot_redis_enterprise() -> _TestDecorator: check = REDIS_INFO["enterprise"] is False return pytest.mark.skipif(check, reason="Not running in redis enterprise") -def skip_if_nocryptography(): +def skip_if_nocryptography() -> _TestDecorator: try: import cryptography # noqa @@ -188,7 +253,7 @@ def skip_if_nocryptography(): return pytest.mark.skipif(True, reason="No cryptography dependency") -def skip_if_cryptography(): +def skip_if_cryptography() -> _TestDecorator: try: import cryptography # noqa diff --git a/tests/test_asyncio/__init__.py b/tests/test_asyncio/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_asyncio/compat.py b/tests/test_asyncio/compat.py new file mode 100644 index 0000000000..ced4974196 --- /dev/null +++ b/tests/test_asyncio/compat.py @@ -0,0 +1,6 @@ +from unittest import mock + +try: + mock.AsyncMock +except AttributeError: + import mock diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py new file mode 100644 index 0000000000..0e9c73ec61 --- /dev/null +++ b/tests/test_asyncio/conftest.py @@ -0,0 +1,205 @@ +import asyncio +import random +import sys +from typing import Union +from urllib.parse import urlparse + +if sys.version_info[0:2] == (3, 6): + import pytest as pytest_asyncio +else: + import pytest_asyncio + +import pytest +from packaging.version import Version + +import redis.asyncio as redis +from redis.asyncio.client import Monitor +from redis.asyncio.connection import ( + HIREDIS_AVAILABLE, + HiredisParser, + PythonParser, + parse_url, +) +from tests.conftest import REDIS_INFO + +from .compat import mock + + +async def _get_info(redis_url): + client = redis.Redis.from_url(redis_url) + info = await client.info() + await client.connection_pool.disconnect() + return info + + +@pytest_asyncio.fixture( + params=[ + (True, PythonParser), + (False, PythonParser), + pytest.param( + (True, HiredisParser), + marks=pytest.mark.skipif( + not HIREDIS_AVAILABLE, reason="hiredis is not installed" + ), + ), + pytest.param( + (False, HiredisParser), + marks=pytest.mark.skipif( + not HIREDIS_AVAILABLE, reason="hiredis is not installed" + ), + ), + ], + ids=[ + "single-python-parser", + "pool-python-parser", + "single-hiredis", + "pool-hiredis", + ], +) +def create_redis(request, event_loop: asyncio.BaseEventLoop): + """Wrapper around redis.create_redis.""" + single_connection, parser_cls = request.param + + async def f(url: str = request.config.getoption("--redis-url"), **kwargs): + single = kwargs.pop("single_connection_client", False) 
or single_connection + parser_class = kwargs.pop("parser_class", None) or parser_cls + url_options = parse_url(url) + url_options.update(kwargs) + pool = redis.ConnectionPool(parser_class=parser_class, **url_options) + client: redis.Redis = redis.Redis(connection_pool=pool) + if single: + client = client.client() + await client.initialize() + + def teardown(): + async def ateardown(): + if "username" in kwargs: + return + try: + await client.flushdb() + except redis.ConnectionError: + # handle cases where a test disconnected a client + # just manually retry the flushdb + await client.flushdb() + await client.close() + await client.connection_pool.disconnect() + + if event_loop.is_running(): + event_loop.create_task(ateardown()) + else: + event_loop.run_until_complete(ateardown()) + + request.addfinalizer(teardown) + + return client + + return f + + +@pytest_asyncio.fixture() +async def r(create_redis): + yield await create_redis() + + +@pytest_asyncio.fixture() +async def r2(create_redis): + """A second client for tests that need multiple""" + yield await create_redis() + + +def _gen_cluster_mock_resp(r, response): + connection = mock.AsyncMock() + connection.read_response.return_value = response + r.connection = connection + return r + + +@pytest_asyncio.fixture() +async def mock_cluster_resp_ok(create_redis, **kwargs): + r = await create_redis(**kwargs) + return _gen_cluster_mock_resp(r, "OK") + + +@pytest_asyncio.fixture() +async def mock_cluster_resp_int(create_redis, **kwargs): + r = await create_redis(**kwargs) + return _gen_cluster_mock_resp(r, "2") + + +@pytest_asyncio.fixture() +async def mock_cluster_resp_info(create_redis, **kwargs): + r = await create_redis(**kwargs) + response = ( + "cluster_state:ok\r\ncluster_slots_assigned:16384\r\n" + "cluster_slots_ok:16384\r\ncluster_slots_pfail:0\r\n" + "cluster_slots_fail:0\r\ncluster_known_nodes:7\r\n" + "cluster_size:3\r\ncluster_current_epoch:7\r\n" + "cluster_my_epoch:2\r\ncluster_stats_messages_sent:170262\r\n" + "cluster_stats_messages_received:105653\r\n" + ) + return _gen_cluster_mock_resp(r, response) + + +@pytest_asyncio.fixture() +async def mock_cluster_resp_nodes(create_redis, **kwargs): + r = await create_redis(**kwargs) + response = ( + "c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 " + "slave aa90da731f673a99617dfe930306549a09f83a6b 0 " + "1447836263059 5 connected\n" + "9bd595fe4821a0e8d6b99d70faa660638a7612b3 172.17.0.7:7008 " + "master - 0 1447836264065 0 connected\n" + "aa90da731f673a99617dfe930306549a09f83a6b 172.17.0.7:7003 " + "myself,master - 0 0 2 connected 5461-10922\n" + "1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 " + "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " + "1447836262556 3 connected\n" + "4ad9a12e63e8f0207025eeba2354bcf4c85e5b22 172.17.0.7:7005 " + "master - 0 1447836262555 7 connected 0-5460\n" + "19efe5a631f3296fdf21a5441680f893e8cc96ec 172.17.0.7:7004 " + "master - 0 1447836263562 3 connected 10923-16383\n" + "fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 " + "master,fail - 1447829446956 1447829444948 1 disconnected\n" + ) + return _gen_cluster_mock_resp(r, response) + + +@pytest_asyncio.fixture() +async def mock_cluster_resp_slaves(create_redis, **kwargs): + r = await create_redis(**kwargs) + response = ( + "['1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 " + "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " + "1447836789290 3 connected']" + ) + return _gen_cluster_mock_resp(r, response) + + +@pytest_asyncio.fixture(scope="session") +def 
master_host(request): + url = request.config.getoption("--redis-url") + parts = urlparse(url) + yield parts.hostname + + +async def wait_for_command( + client: redis.Redis, monitor: Monitor, command: str, key: Union[str, None] = None +): + # issue a command with a key name that's local to this process. + # if we find a command with our key before the command we're waiting + # for, something went wrong + if key is None: + # generate key + redis_version = REDIS_INFO["version"] + if Version(redis_version) >= Version("5.0.0"): + id_str = str(client.client_id()) + else: + id_str = f"{random.randrange(2 ** 32):08x}" + key = f"__REDIS-PY-{id_str}__" + await client.get(key) + while True: + monitor_response = await monitor.next_command() + if command in monitor_response["command"]: + return monitor_response + if key in monitor_response["command"]: + return None diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py new file mode 100644 index 0000000000..12855c9326 --- /dev/null +++ b/tests/test_asyncio/test_commands.py @@ -0,0 +1,3 @@ +""" +Tests async overrides of commands from their mixins +""" diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py new file mode 100644 index 0000000000..46abec01d6 --- /dev/null +++ b/tests/test_asyncio/test_connection.py @@ -0,0 +1,64 @@ +import asyncio +import types + +import pytest + +from redis.asyncio.connection import PythonParser, UnixDomainSocketConnection +from redis.exceptions import InvalidResponse +from redis.utils import HIREDIS_AVAILABLE +from tests.conftest import skip_if_server_version_lt + +from .compat import mock + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.onlynoncluster +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_invalid_response(create_redis): + r = await create_redis(single_connection_client=True) + + raw = b"x" + readline_mock = mock.AsyncMock(return_value=raw) + + parser: "PythonParser" = r.connection._parser + with mock.patch.object(parser._buffer, "readline", readline_mock): + with pytest.raises(InvalidResponse) as cm: + await parser.read_response() + assert str(cm.value) == f"Protocol Error: {raw!r}" + + +@skip_if_server_version_lt("4.0.0") +@pytest.mark.redismod +@pytest.mark.onlynoncluster +async def test_loading_external_modules(modclient): + def inner(): + pass + + modclient.load_external_module("myfuncname", inner) + assert getattr(modclient, "myfuncname") == inner + assert isinstance(getattr(modclient, "myfuncname"), types.FunctionType) + + # and call it + from redis.commands import RedisModuleCommands + + j = RedisModuleCommands.json + modclient.load_external_module("sometestfuncname", j) + + # d = {'hello': 'world!'} + # mod = j(modclient) + # mod.set("fookey", ".", d) + # assert mod.get('fookey') == d + + +@pytest.mark.onlynoncluster +async def test_socket_param_regression(r): + """A regression test for issue #1060""" + conn = UnixDomainSocketConnection() + _ = await conn.disconnect() is True + + +@pytest.mark.onlynoncluster +async def test_can_run_concurrent_commands(r): + assert await r.ping() is True + assert all(await asyncio.gather(*(r.ping() for _ in range(10)))) diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py new file mode 100644 index 0000000000..f9dfefd5cc --- /dev/null +++ b/tests/test_asyncio/test_connection_pool.py @@ -0,0 +1,884 @@ +import asyncio +import os +import re +import sys + +import pytest + +if sys.version_info[0:2] == (3, 6): + import 
pytest as pytest_asyncio +else: + import pytest_asyncio + +import redis.asyncio as redis +from redis.asyncio.connection import Connection, to_bool +from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt + +from .compat import mock +from .test_pubsub import wait_for_message + +pytestmark = pytest.mark.asyncio + + +class TestRedisAutoReleaseConnectionPool: + @pytest_asyncio.fixture + async def r(self, create_redis) -> redis.Redis: + """This is necessary since r and r2 create ConnectionPools behind the scenes""" + r = await create_redis() + r.auto_close_connection_pool = True + yield r + + @staticmethod + def get_total_connected_connections(pool): + return len(pool._available_connections) + len(pool._in_use_connections) + + @staticmethod + async def create_two_conn(r: redis.Redis): + if not r.single_connection_client: # Single already initialized connection + r.connection = await r.connection_pool.get_connection("_") + return await r.connection_pool.get_connection("_") + + @staticmethod + def has_no_connected_connections(pool: redis.ConnectionPool): + return not any( + x.is_connected + for x in pool._available_connections + list(pool._in_use_connections) + ) + + async def test_auto_disconnect_redis_created_pool(self, r: redis.Redis): + new_conn = await self.create_two_conn(r) + assert new_conn != r.connection + assert self.get_total_connected_connections(r.connection_pool) == 2 + await r.close() + assert self.has_no_connected_connections(r.connection_pool) + + async def test_do_not_auto_disconnect_redis_created_pool(self, r2: redis.Redis): + assert r2.auto_close_connection_pool is False, ( + "The connection pool should not be disconnected as a manually created " + "connection pool was passed in in conftest.py" + ) + new_conn = await self.create_two_conn(r2) + assert self.get_total_connected_connections(r2.connection_pool) == 2 + await r2.close() + assert r2.connection_pool._in_use_connections == {new_conn} + assert new_conn.is_connected + assert len(r2.connection_pool._available_connections) == 1 + assert r2.connection_pool._available_connections[0].is_connected + + async def test_auto_release_override_true_manual_created_pool(self, r: redis.Redis): + assert r.auto_close_connection_pool is True, "This is from the class fixture" + await self.create_two_conn(r) + await r.close() + assert self.get_total_connected_connections(r.connection_pool) == 2, ( + "The connection pool should not be disconnected as a manually created " + "connection pool was passed in in conftest.py" + ) + assert self.has_no_connected_connections(r.connection_pool) + + @pytest.mark.parametrize("auto_close_conn_pool", [True, False]) + async def test_close_override(self, r: redis.Redis, auto_close_conn_pool): + r.auto_close_connection_pool = auto_close_conn_pool + await self.create_two_conn(r) + await r.close(close_connection_pool=True) + assert self.has_no_connected_connections(r.connection_pool) + + @pytest.mark.parametrize("auto_close_conn_pool", [True, False]) + async def test_negate_auto_close_client_pool( + self, r: redis.Redis, auto_close_conn_pool + ): + r.auto_close_connection_pool = auto_close_conn_pool + new_conn = await self.create_two_conn(r) + await r.close(close_connection_pool=False) + assert not self.has_no_connected_connections(r.connection_pool) + assert r.connection_pool._in_use_connections == {new_conn} + assert r.connection_pool._available_connections[0].is_connected + assert self.get_total_connected_connections(r.connection_pool) == 2 + + +class DummyConnection(Connection): + 
description_format = "DummyConnection<>" + + def __init__(self, **kwargs): + self.kwargs = kwargs + self.pid = os.getpid() + + async def connect(self): + pass + + async def disconnect(self): + pass + + async def can_read(self, timeout: float = 0): + return False + + +@pytest.mark.onlynoncluster +class TestConnectionPool: + def get_pool( + self, + connection_kwargs=None, + max_connections=None, + connection_class=redis.Connection, + ): + connection_kwargs = connection_kwargs or {} + pool = redis.ConnectionPool( + connection_class=connection_class, + max_connections=max_connections, + **connection_kwargs, + ) + return pool + + async def test_connection_creation(self): + connection_kwargs = {"foo": "bar", "biz": "baz"} + pool = self.get_pool( + connection_kwargs=connection_kwargs, connection_class=DummyConnection + ) + connection = await pool.get_connection("_") + assert isinstance(connection, DummyConnection) + assert connection.kwargs == connection_kwargs + + async def test_multiple_connections(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(connection_kwargs=connection_kwargs) + c1 = await pool.get_connection("_") + c2 = await pool.get_connection("_") + assert c1 != c2 + + async def test_max_connections(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(max_connections=2, connection_kwargs=connection_kwargs) + await pool.get_connection("_") + await pool.get_connection("_") + with pytest.raises(redis.ConnectionError): + await pool.get_connection("_") + + async def test_reuse_previously_released_connection(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(connection_kwargs=connection_kwargs) + c1 = await pool.get_connection("_") + await pool.release(c1) + c2 = await pool.get_connection("_") + assert c1 == c2 + + def test_repr_contains_db_info_tcp(self): + connection_kwargs = { + "host": "localhost", + "port": 6379, + "db": 1, + "client_name": "test-client", + } + pool = self.get_pool( + connection_kwargs=connection_kwargs, connection_class=redis.Connection + ) + expected = ( + "ConnectionPool>" + ) + assert repr(pool) == expected + + def test_repr_contains_db_info_unix(self): + connection_kwargs = {"path": "/abc", "db": 1, "client_name": "test-client"} + pool = self.get_pool( + connection_kwargs=connection_kwargs, + connection_class=redis.UnixDomainSocketConnection, + ) + expected = ( + "ConnectionPool>" + ) + assert repr(pool) == expected + + +@pytest.mark.onlynoncluster +class TestBlockingConnectionPool: + def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): + connection_kwargs = connection_kwargs or {} + pool = redis.BlockingConnectionPool( + connection_class=DummyConnection, + max_connections=max_connections, + timeout=timeout, + **connection_kwargs, + ) + return pool + + async def test_connection_creation(self, master_host): + connection_kwargs = { + "foo": "bar", + "biz": "baz", + "host": master_host[0], + "port": master_host[1], + } + pool = self.get_pool(connection_kwargs=connection_kwargs) + connection = await pool.get_connection("_") + assert isinstance(connection, DummyConnection) + assert connection.kwargs == connection_kwargs + + async def test_disconnect(self, master_host): + """A regression test for #1047""" + connection_kwargs = { + "foo": "bar", + "biz": "baz", + "host": master_host[0], + "port": master_host[1], + } + pool = self.get_pool(connection_kwargs=connection_kwargs) + await pool.get_connection("_") + await pool.disconnect() + + async 
def test_multiple_connections(self, master_host): + connection_kwargs = {"host": master_host[0], "port": master_host[1]} + pool = self.get_pool(connection_kwargs=connection_kwargs) + c1 = await pool.get_connection("_") + c2 = await pool.get_connection("_") + assert c1 != c2 + + async def test_connection_pool_blocks_until_timeout(self, master_host): + """When out of connections, block for timeout seconds, then raise""" + connection_kwargs = {"host": master_host} + pool = self.get_pool( + max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs + ) + await pool.get_connection("_") + + start = asyncio.get_event_loop().time() + with pytest.raises(redis.ConnectionError): + await pool.get_connection("_") + # we should have waited at least 0.1 seconds + assert asyncio.get_event_loop().time() - start >= 0.1 + + async def test_connection_pool_blocks_until_conn_available(self, master_host): + """ + When out of connections, block until another connection is released + to the pool + """ + connection_kwargs = {"host": master_host[0], "port": master_host[1]} + pool = self.get_pool( + max_connections=1, timeout=2, connection_kwargs=connection_kwargs + ) + c1 = await pool.get_connection("_") + + async def target(): + await asyncio.sleep(0.1) + await pool.release(c1) + + start = asyncio.get_event_loop().time() + await asyncio.gather(target(), pool.get_connection("_")) + assert asyncio.get_event_loop().time() - start >= 0.1 + + async def test_reuse_previously_released_connection(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(connection_kwargs=connection_kwargs) + c1 = await pool.get_connection("_") + await pool.release(c1) + c2 = await pool.get_connection("_") + assert c1 == c2 + + def test_repr_contains_db_info_tcp(self): + pool = redis.ConnectionPool( + host="localhost", port=6379, client_name="test-client" + ) + expected = ( + "ConnectionPool>" + ) + assert repr(pool) == expected + + def test_repr_contains_db_info_unix(self): + pool = redis.ConnectionPool( + connection_class=redis.UnixDomainSocketConnection, + path="abc", + client_name="test-client", + ) + expected = ( + "ConnectionPool>" + ) + assert repr(pool) == expected + + +@pytest.mark.onlynoncluster +class TestConnectionPoolURLParsing: + def test_hostname(self): + pool = redis.ConnectionPool.from_url("redis://my.host") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "my.host", + } + + def test_quoted_hostname(self): + pool = redis.ConnectionPool.from_url("redis://my %2F host %2B%3D+") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "my / host +=+", + } + + def test_port(self): + pool = redis.ConnectionPool.from_url("redis://localhost:6380") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "port": 6380, + } + + @skip_if_server_version_lt("6.0.0") + def test_username(self): + pool = redis.ConnectionPool.from_url("redis://myuser:@localhost") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "username": "myuser", + } + + @skip_if_server_version_lt("6.0.0") + def test_quoted_username(self): + pool = redis.ConnectionPool.from_url( + "redis://%2Fmyuser%2F%2B name%3D%24+:@localhost" + ) + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "username": "/myuser/+ name=$+", + } + + def test_password(self): + pool = 
redis.ConnectionPool.from_url("redis://:mypassword@localhost") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "password": "mypassword", + } + + def test_quoted_password(self): + pool = redis.ConnectionPool.from_url( + "redis://:%2Fmypass%2F%2B word%3D%24+@localhost" + ) + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "password": "/mypass/+ word=$+", + } + + @skip_if_server_version_lt("6.0.0") + def test_username_and_password(self): + pool = redis.ConnectionPool.from_url("redis://myuser:mypass@localhost") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "username": "myuser", + "password": "mypass", + } + + def test_db_as_argument(self): + pool = redis.ConnectionPool.from_url("redis://localhost", db=1) + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 1, + } + + def test_db_in_path(self): + pool = redis.ConnectionPool.from_url("redis://localhost/2", db=1) + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 2, + } + + def test_db_in_querystring(self): + pool = redis.ConnectionPool.from_url("redis://localhost/2?db=3", db=1) + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 3, + } + + def test_extra_typed_querystring_options(self): + pool = redis.ConnectionPool.from_url( + "redis://localhost/2?socket_timeout=20&socket_connect_timeout=10" + "&socket_keepalive=&retry_on_timeout=Yes&max_connections=10" + ) + + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 2, + "socket_timeout": 20.0, + "socket_connect_timeout": 10.0, + "retry_on_timeout": True, + } + assert pool.max_connections == 10 + + def test_boolean_parsing(self): + for expected, value in ( + (None, None), + (None, ""), + (False, 0), + (False, "0"), + (False, "f"), + (False, "F"), + (False, "False"), + (False, "n"), + (False, "N"), + (False, "No"), + (True, 1), + (True, "1"), + (True, "y"), + (True, "Y"), + (True, "Yes"), + ): + assert expected is to_bool(value) + + def test_client_name_in_querystring(self): + pool = redis.ConnectionPool.from_url("redis://location?client_name=test-client") + assert pool.connection_kwargs["client_name"] == "test-client" + + def test_invalid_extra_typed_querystring_options(self): + with pytest.raises(ValueError): + redis.ConnectionPool.from_url( + "redis://localhost/2?socket_timeout=_&" "socket_connect_timeout=abc" + ) + + def test_extra_querystring_options(self): + pool = redis.ConnectionPool.from_url("redis://localhost?a=1&b=2") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == {"host": "localhost", "a": "1", "b": "2"} + + def test_calling_from_subclass_returns_correct_instance(self): + pool = redis.BlockingConnectionPool.from_url("redis://localhost") + assert isinstance(pool, redis.BlockingConnectionPool) + + def test_client_creates_connection_pool(self): + r = redis.Redis.from_url("redis://myhost") + assert r.connection_pool.connection_class == redis.Connection + assert r.connection_pool.connection_kwargs == { + "host": "myhost", + } + + def test_invalid_scheme_raises_error(self): + with pytest.raises(ValueError) as cm: + redis.ConnectionPool.from_url("localhost") + assert str(cm.value) == ( + "Redis URL must 
specify one of the following schemes " + "(redis://, rediss://, unix://)" + ) + + +@pytest.mark.onlynoncluster +class TestConnectionPoolUnixSocketURLParsing: + def test_defaults(self): + pool = redis.ConnectionPool.from_url("unix:///socket") + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + } + + @skip_if_server_version_lt("6.0.0") + def test_username(self): + pool = redis.ConnectionPool.from_url("unix://myuser:@/socket") + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "username": "myuser", + } + + @skip_if_server_version_lt("6.0.0") + def test_quoted_username(self): + pool = redis.ConnectionPool.from_url( + "unix://%2Fmyuser%2F%2B name%3D%24+:@/socket" + ) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "username": "/myuser/+ name=$+", + } + + def test_password(self): + pool = redis.ConnectionPool.from_url("unix://:mypassword@/socket") + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "password": "mypassword", + } + + def test_quoted_password(self): + pool = redis.ConnectionPool.from_url( + "unix://:%2Fmypass%2F%2B word%3D%24+@/socket" + ) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "password": "/mypass/+ word=$+", + } + + def test_quoted_path(self): + pool = redis.ConnectionPool.from_url( + "unix://:mypassword@/my%2Fpath%2Fto%2F..%2F+_%2B%3D%24ocket" + ) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/my/path/to/../+_+=$ocket", + "password": "mypassword", + } + + def test_db_as_argument(self): + pool = redis.ConnectionPool.from_url("unix:///socket", db=1) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "db": 1, + } + + def test_db_in_querystring(self): + pool = redis.ConnectionPool.from_url("unix:///socket?db=2", db=1) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "db": 2, + } + + def test_client_name_in_querystring(self): + pool = redis.ConnectionPool.from_url("redis://location?client_name=test-client") + assert pool.connection_kwargs["client_name"] == "test-client" + + def test_extra_querystring_options(self): + pool = redis.ConnectionPool.from_url("unix:///socket?a=1&b=2") + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == {"path": "/socket", "a": "1", "b": "2"} + + +@pytest.mark.onlynoncluster +class TestSSLConnectionURLParsing: + def test_host(self): + pool = redis.ConnectionPool.from_url("rediss://my.host") + assert pool.connection_class == redis.SSLConnection + assert pool.connection_kwargs == { + "host": "my.host", + } + + def test_cert_reqs_options(self): + import ssl + + class DummyConnectionPool(redis.ConnectionPool): + def get_connection(self, *args, **kwargs): + return self.make_connection() + + pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=none") + assert pool.get_connection("_").cert_reqs == ssl.CERT_NONE + + pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=optional") + assert pool.get_connection("_").cert_reqs == ssl.CERT_OPTIONAL + + pool = 
DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=required") + assert pool.get_connection("_").cert_reqs == ssl.CERT_REQUIRED + + pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=False") + assert pool.get_connection("_").check_hostname is False + + pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=True") + assert pool.get_connection("_").check_hostname is True + + +@pytest.mark.onlynoncluster +class TestConnection: + async def test_on_connect_error(self): + """ + An error in Connection.on_connect should disconnect from the server + see for details: https://github.com/andymccurdy/redis-py/issues/368 + """ + # this assumes the Redis server being tested against doesn't have + # 9999 databases ;) + bad_connection = redis.Redis(db=9999) + # an error should be raised on connect + with pytest.raises(redis.RedisError): + await bad_connection.info() + pool = bad_connection.connection_pool + assert len(pool._available_connections) == 1 + assert not pool._available_connections[0]._reader + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.8.8") + @skip_if_redis_enterprise() + async def test_busy_loading_disconnects_socket(self, r): + """ + If Redis raises a LOADING error, the connection should be + disconnected and a BusyLoadingError raised + """ + with pytest.raises(redis.BusyLoadingError): + await r.execute_command("DEBUG", "ERROR", "LOADING fake message") + if r.connection: + assert not r.connection._reader + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.8.8") + @skip_if_redis_enterprise() + async def test_busy_loading_from_pipeline_immediate_command(self, r): + """ + BusyLoadingErrors should raise from Pipelines that execute a + command immediately, like WATCH does. + """ + pipe = r.pipeline() + with pytest.raises(redis.BusyLoadingError): + await pipe.immediate_execute_command( + "DEBUG", "ERROR", "LOADING fake message" + ) + pool = r.connection_pool + assert not pipe.connection + assert len(pool._available_connections) == 1 + assert not pool._available_connections[0]._reader + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.8.8") + @skip_if_redis_enterprise() + async def test_busy_loading_from_pipeline(self, r): + """ + BusyLoadingErrors should be raised from a pipeline execution + regardless of the raise_on_error flag. 
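+        (The DEBUG command is only buffered by the pipeline here, so the
+        error cannot surface until execute() sends the batch.)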
+ """ + pipe = r.pipeline() + pipe.execute_command("DEBUG", "ERROR", "LOADING fake message") + with pytest.raises(redis.BusyLoadingError): + await pipe.execute() + pool = r.connection_pool + assert not pipe.connection + assert len(pool._available_connections) == 1 + assert not pool._available_connections[0]._reader + + @skip_if_server_version_lt("2.8.8") + @skip_if_redis_enterprise() + async def test_read_only_error(self, r): + """READONLY errors get turned in ReadOnlyError exceptions""" + with pytest.raises(redis.ReadOnlyError): + await r.execute_command("DEBUG", "ERROR", "READONLY blah blah") + + def test_connect_from_url_tcp(self): + connection = redis.Redis.from_url("redis://localhost") + pool = connection.connection_pool + + assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == ( + "ConnectionPool", + "Connection", + "host=localhost,port=6379,db=0", + ) + + def test_connect_from_url_unix(self): + connection = redis.Redis.from_url("unix:///path/to/socket") + pool = connection.connection_pool + + assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == ( + "ConnectionPool", + "UnixDomainSocketConnection", + "path=/path/to/socket,db=0", + ) + + @skip_if_redis_enterprise() + async def test_connect_no_auth_supplied_when_required(self, r): + """ + AuthenticationError should be raised when the server requires a + password but one isn't supplied. + """ + with pytest.raises(redis.AuthenticationError): + await r.execute_command( + "DEBUG", "ERROR", "ERR Client sent AUTH, but no password is set" + ) + + @skip_if_redis_enterprise() + async def test_connect_invalid_password_supplied(self, r): + """AuthenticationError should be raised when sending the wrong password""" + with pytest.raises(redis.AuthenticationError): + await r.execute_command("DEBUG", "ERROR", "ERR invalid password") + + +@pytest.mark.onlynoncluster +class TestMultiConnectionClient: + @pytest_asyncio.fixture() + async def r(self, create_redis, server): + redis = await create_redis(single_connection_client=False) + yield redis + await redis.flushall() + + +@pytest.mark.onlynoncluster +class TestHealthCheck: + interval = 60 + + @pytest_asyncio.fixture() + async def r(self, create_redis): + redis = await create_redis(health_check_interval=self.interval) + yield redis + await redis.flushall() + + def assert_interval_advanced(self, connection): + diff = connection.next_health_check - asyncio.get_event_loop().time() + assert self.interval > diff > (self.interval - 1) + + async def test_health_check_runs(self, r): + if r.connection: + r.connection.next_health_check = asyncio.get_event_loop().time() - 1 + await r.connection.check_health() + self.assert_interval_advanced(r.connection) + + async def test_arbitrary_command_invokes_health_check(self, r): + # invoke a command to make sure the connection is entirely setup + if r.connection: + await r.get("foo") + r.connection.next_health_check = asyncio.get_event_loop().time() + with mock.patch.object( + r.connection, "send_command", wraps=r.connection.send_command + ) as m: + await r.get("foo") + m.assert_called_with("PING", check_health=False) + + self.assert_interval_advanced(r.connection) + + async def test_arbitrary_command_advances_next_health_check(self, r): + if r.connection: + await r.get("foo") + next_health_check = r.connection.next_health_check + await r.get("foo") + assert next_health_check < r.connection.next_health_check + + async def test_health_check_not_invoked_within_interval(self, r): + if r.connection: + await r.get("foo") + with mock.patch.object( + r.connection, 
"send_command", wraps=r.connection.send_command + ) as m: + await r.get("foo") + ping_call_spec = (("PING",), {"check_health": False}) + assert ping_call_spec not in m.call_args_list + + async def test_health_check_in_pipeline(self, r): + async with r.pipeline(transaction=False) as pipe: + pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection.next_health_check = 0 + with mock.patch.object( + pipe.connection, "send_command", wraps=pipe.connection.send_command + ) as m: + responses = await pipe.set("foo", "bar").get("foo").execute() + m.assert_any_call("PING", check_health=False) + assert responses == [True, b"bar"] + + async def test_health_check_in_transaction(self, r): + async with r.pipeline(transaction=True) as pipe: + pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection.next_health_check = 0 + with mock.patch.object( + pipe.connection, "send_command", wraps=pipe.connection.send_command + ) as m: + responses = await pipe.set("foo", "bar").get("foo").execute() + m.assert_any_call("PING", check_health=False) + assert responses == [True, b"bar"] + + async def test_health_check_in_watched_pipeline(self, r): + await r.set("foo", "bar") + async with r.pipeline(transaction=False) as pipe: + pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection.next_health_check = 0 + with mock.patch.object( + pipe.connection, "send_command", wraps=pipe.connection.send_command + ) as m: + await pipe.watch("foo") + # the health check should be called when watching + m.assert_called_with("PING", check_health=False) + self.assert_interval_advanced(pipe.connection) + assert await pipe.get("foo") == b"bar" + + # reset the mock to clear the call list and schedule another + # health check + m.reset_mock() + pipe.connection.next_health_check = 0 + + pipe.multi() + responses = await pipe.set("foo", "not-bar").get("foo").execute() + assert responses == [True, b"not-bar"] + m.assert_any_call("PING", check_health=False) + + async def test_health_check_in_pubsub_before_subscribe(self, r): + """A health check happens before the first [p]subscribe""" + p = r.pubsub() + p.connection = await p.connection_pool.get_connection("_") + p.connection.next_health_check = 0 + with mock.patch.object( + p.connection, "send_command", wraps=p.connection.send_command + ) as m: + assert not p.subscribed + await p.subscribe("foo") + # the connection is not yet in pubsub mode, so the normal + # ping/pong within connection.send_command should check + # the health of the connection + m.assert_any_call("PING", check_health=False) + self.assert_interval_advanced(p.connection) + + subscribe_message = await wait_for_message(p) + assert subscribe_message["type"] == "subscribe" + + async def test_health_check_in_pubsub_after_subscribed(self, r): + """ + Pubsub can handle a new subscribe when it's time to check the + connection health + """ + p = r.pubsub() + p.connection = await p.connection_pool.get_connection("_") + p.connection.next_health_check = 0 + with mock.patch.object( + p.connection, "send_command", wraps=p.connection.send_command + ) as m: + await p.subscribe("foo") + subscribe_message = await wait_for_message(p) + assert subscribe_message["type"] == "subscribe" + self.assert_interval_advanced(p.connection) + # because we weren't subscribed when sending the subscribe + # message to 'foo', the connection's standard check_health ran + # prior to subscribing. 
+ m.assert_any_call("PING", check_health=False) + + p.connection.next_health_check = 0 + m.reset_mock() + + await p.subscribe("bar") + # the second subscribe issues exactly only command (the subscribe) + # and the health check is not invoked + m.assert_called_once_with("SUBSCRIBE", "bar", check_health=False) + + # since no message has been read since the health check was + # reset, it should still be 0 + assert p.connection.next_health_check == 0 + + subscribe_message = await wait_for_message(p) + assert subscribe_message["type"] == "subscribe" + assert await wait_for_message(p) is None + # now that the connection is subscribed, the pubsub health + # check should have taken over and include the HEALTH_CHECK_MESSAGE + m.assert_any_call("PING", p.HEALTH_CHECK_MESSAGE, check_health=False) + self.assert_interval_advanced(p.connection) + + async def test_health_check_in_pubsub_poll(self, r): + """ + Polling a pubsub connection that's subscribed will regularly + check the connection's health. + """ + p = r.pubsub() + p.connection = await p.connection_pool.get_connection("_") + with mock.patch.object( + p.connection, "send_command", wraps=p.connection.send_command + ) as m: + await p.subscribe("foo") + subscribe_message = await wait_for_message(p) + assert subscribe_message["type"] == "subscribe" + self.assert_interval_advanced(p.connection) + + # polling the connection before the health check interval + # doesn't result in another health check + m.reset_mock() + next_health_check = p.connection.next_health_check + assert await wait_for_message(p) is None + assert p.connection.next_health_check == next_health_check + m.assert_not_called() + + # reset the health check and poll again + # we should not receive a pong message, but the next_health_check + # should be advanced + p.connection.next_health_check = 0 + assert await wait_for_message(p) is None + m.assert_called_with("PING", p.HEALTH_CHECK_MESSAGE, check_health=False) + self.assert_interval_advanced(p.connection) diff --git a/tests/test_asyncio/test_encoding.py b/tests/test_asyncio/test_encoding.py new file mode 100644 index 0000000000..da2983738e --- /dev/null +++ b/tests/test_asyncio/test_encoding.py @@ -0,0 +1,126 @@ +import sys + +import pytest + +if sys.version_info[0:2] == (3, 6): + import pytest as pytest_asyncio +else: + import pytest_asyncio + +import redis.asyncio as redis +from redis.exceptions import DataError + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.onlynoncluster +class TestEncoding: + @pytest_asyncio.fixture() + async def r(self, create_redis): + redis = await create_redis(decode_responses=True) + yield redis + await redis.flushall() + + @pytest_asyncio.fixture() + async def r_no_decode(self, create_redis): + redis = await create_redis(decode_responses=False) + yield redis + await redis.flushall() + + async def test_simple_encoding(self, r_no_decode: redis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + await r_no_decode.set("unicode-string", unicode_string.encode("utf-8")) + cached_val = await r_no_decode.get("unicode-string") + assert isinstance(cached_val, bytes) + assert unicode_string == cached_val.decode("utf-8") + + async def test_simple_encoding_and_decoding(self, r: redis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + await r.set("unicode-string", unicode_string) + cached_val = await r.get("unicode-string") + assert isinstance(cached_val, str) + assert unicode_string == cached_val + + async def test_memoryview_encoding(self, r_no_decode: redis.Redis): + unicode_string = chr(3456) 
+ "abcd" + chr(3421) + unicode_string_view = memoryview(unicode_string.encode("utf-8")) + await r_no_decode.set("unicode-string-memoryview", unicode_string_view) + cached_val = await r_no_decode.get("unicode-string-memoryview") + # The cached value won't be a memoryview because it's a copy from Redis + assert isinstance(cached_val, bytes) + assert unicode_string == cached_val.decode("utf-8") + + async def test_memoryview_encoding_and_decoding(self, r: redis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + unicode_string_view = memoryview(unicode_string.encode("utf-8")) + await r.set("unicode-string-memoryview", unicode_string_view) + cached_val = await r.get("unicode-string-memoryview") + assert isinstance(cached_val, str) + assert unicode_string == cached_val + + async def test_list_encoding(self, r: redis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + result = [unicode_string, unicode_string, unicode_string] + await r.rpush("a", *result) + assert await r.lrange("a", 0, -1) == result + + +@pytest.mark.onlynoncluster +class TestEncodingErrors: + async def test_ignore(self, create_redis): + r = await create_redis( + decode_responses=True, + encoding_errors="ignore", + ) + await r.set("a", b"foo\xff") + assert await r.get("a") == "foo" + + async def test_replace(self, create_redis): + r = await create_redis( + decode_responses=True, + encoding_errors="replace", + ) + await r.set("a", b"foo\xff") + assert await r.get("a") == "foo\ufffd" + + +@pytest.mark.onlynoncluster +class TestMemoryviewsAreNotPacked: + async def test_memoryviews_are_not_packed(self, r): + arg = memoryview(b"some_arg") + arg_list = ["SOME_COMMAND", arg] + c = r.connection or await r.connection_pool.get_connection("_") + cmd = c.pack_command(*arg_list) + assert cmd[1] is arg + cmds = c.pack_commands([arg_list, arg_list]) + assert cmds[1] is arg + assert cmds[3] is arg + + +class TestCommandsAreNotEncoded: + @pytest_asyncio.fixture() + async def r(self, create_redis): + redis = await create_redis(encoding="utf-16") + yield redis + await redis.flushall() + + async def test_basic_command(self, r: redis.Redis): + await r.set("hello", "world") + + +class TestInvalidUserInput: + async def test_boolean_fails(self, r: redis.Redis): + with pytest.raises(DataError): + await r.set("a", True) # type: ignore + + async def test_none_fails(self, r: redis.Redis): + with pytest.raises(DataError): + await r.set("a", None) # type: ignore + + async def test_user_type_fails(self, r: redis.Redis): + class Foo: + def __str__(self): + return "Foo" + + with pytest.raises(DataError): + await r.set("a", Foo()) # type: ignore diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py new file mode 100644 index 0000000000..c496718a67 --- /dev/null +++ b/tests/test_asyncio/test_lock.py @@ -0,0 +1,242 @@ +import asyncio +import sys + +import pytest + +if sys.version_info[0:2] == (3, 6): + import pytest as pytest_asyncio +else: + import pytest_asyncio + +from redis.asyncio.lock import Lock +from redis.exceptions import LockError, LockNotOwnedError + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.onlynoncluster +class TestLock: + @pytest_asyncio.fixture() + async def r_decoded(self, create_redis): + redis = await create_redis(decode_responses=True) + yield redis + await redis.flushall() + + def get_lock(self, redis, *args, **kwargs): + kwargs["lock_class"] = Lock + return redis.lock(*args, **kwargs) + + async def test_lock(self, r): + lock = self.get_lock(r, "foo") + assert await lock.acquire(blocking=False) 
+ assert await r.get("foo") == lock.local.token + assert await r.ttl("foo") == -1 + await lock.release() + assert await r.get("foo") is None + + async def test_lock_token(self, r): + lock = self.get_lock(r, "foo") + await self._test_lock_token(r, lock) + + async def test_lock_token_thread_local_false(self, r): + lock = self.get_lock(r, "foo", thread_local=False) + await self._test_lock_token(r, lock) + + async def _test_lock_token(self, r, lock): + assert await lock.acquire(blocking=False, token="test") + assert await r.get("foo") == b"test" + assert lock.local.token == b"test" + assert await r.ttl("foo") == -1 + await lock.release() + assert await r.get("foo") is None + assert lock.local.token is None + + async def test_locked(self, r): + lock = self.get_lock(r, "foo") + assert await lock.locked() is False + await lock.acquire(blocking=False) + assert await lock.locked() is True + await lock.release() + assert await lock.locked() is False + + async def _test_owned(self, client): + lock = self.get_lock(client, "foo") + assert await lock.owned() is False + await lock.acquire(blocking=False) + assert await lock.owned() is True + await lock.release() + assert await lock.owned() is False + + lock2 = self.get_lock(client, "foo") + assert await lock.owned() is False + assert await lock2.owned() is False + await lock2.acquire(blocking=False) + assert await lock.owned() is False + assert await lock2.owned() is True + await lock2.release() + assert await lock.owned() is False + assert await lock2.owned() is False + + async def test_owned(self, r): + await self._test_owned(r) + + async def test_owned_with_decoded_responses(self, r_decoded): + await self._test_owned(r_decoded) + + async def test_competing_locks(self, r): + lock1 = self.get_lock(r, "foo") + lock2 = self.get_lock(r, "foo") + assert await lock1.acquire(blocking=False) + assert not await lock2.acquire(blocking=False) + await lock1.release() + assert await lock2.acquire(blocking=False) + assert not await lock1.acquire(blocking=False) + await lock2.release() + + async def test_timeout(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + assert 8 < (await r.ttl("foo")) <= 10 + await lock.release() + + async def test_float_timeout(self, r): + lock = self.get_lock(r, "foo", timeout=9.5) + assert await lock.acquire(blocking=False) + assert 8 < (await r.pttl("foo")) <= 9500 + await lock.release() + + async def test_blocking_timeout(self, r, event_loop): + lock1 = self.get_lock(r, "foo") + assert await lock1.acquire(blocking=False) + bt = 0.2 + sleep = 0.05 + lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt) + start = event_loop.time() + assert not await lock2.acquire() + # The elapsed duration should be less than the total blocking_timeout + assert bt > (event_loop.time() - start) > bt - sleep + await lock1.release() + + async def test_context_manager(self, r): + # blocking_timeout prevents a deadlock if the lock can't be acquired + # for some reason + async with self.get_lock(r, "foo", blocking_timeout=0.2) as lock: + assert await r.get("foo") == lock.local.token + assert await r.get("foo") is None + + async def test_context_manager_raises_when_locked_not_acquired(self, r): + await r.set("foo", "bar") + with pytest.raises(LockError): + async with self.get_lock(r, "foo", blocking_timeout=0.1): + pass + + async def test_high_sleep_small_blocking_timeout(self, r): + lock1 = self.get_lock(r, "foo") + assert await lock1.acquire(blocking=False) + sleep = 60 + bt = 1 + lock2 = 
self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt)
+ start = asyncio.get_event_loop().time()
+ assert not await lock2.acquire()
+ # the elapsed time is less than the blocking_timeout as the lock is
+ # unattainable given the sleep/blocking_timeout configuration
+ assert bt > (asyncio.get_event_loop().time() - start)
+ await lock1.release()
+
+ async def test_releasing_unlocked_lock_raises_error(self, r):
+ lock = self.get_lock(r, "foo")
+ with pytest.raises(LockError):
+ await lock.release()
+
+ async def test_releasing_lock_no_longer_owned_raises_error(self, r):
+ lock = self.get_lock(r, "foo")
+ await lock.acquire(blocking=False)
+ # manually change the token
+ await r.set("foo", "a")
+ with pytest.raises(LockNotOwnedError):
+ await lock.release()
+ # even though we errored, the token is still cleared
+ assert lock.local.token is None
+
+ async def test_extend_lock(self, r):
+ lock = self.get_lock(r, "foo", timeout=10)
+ assert await lock.acquire(blocking=False)
+ assert 8000 < (await r.pttl("foo")) <= 10000
+ assert await lock.extend(10)
+ assert 16000 < (await r.pttl("foo")) <= 20000
+ await lock.release()
+
+ async def test_extend_lock_replace_ttl(self, r):
+ lock = self.get_lock(r, "foo", timeout=10)
+ assert await lock.acquire(blocking=False)
+ assert 8000 < (await r.pttl("foo")) <= 10000
+ assert await lock.extend(10, replace_ttl=True)
+ assert 8000 < (await r.pttl("foo")) <= 10000
+ await lock.release()
+
+ async def test_extend_lock_float(self, r):
+ lock = self.get_lock(r, "foo", timeout=10.0)
+ assert await lock.acquire(blocking=False)
+ assert 8000 < (await r.pttl("foo")) <= 10000
+ assert await lock.extend(10.0)
+ assert 16000 < (await r.pttl("foo")) <= 20000
+ await lock.release()
+
+ async def test_extending_unlocked_lock_raises_error(self, r):
+ lock = self.get_lock(r, "foo", timeout=10)
+ with pytest.raises(LockError):
+ await lock.extend(10)
+
+ async def test_extending_lock_with_no_timeout_raises_error(self, r):
+ lock = self.get_lock(r, "foo")
+ assert await lock.acquire(blocking=False)
+ with pytest.raises(LockError):
+ await lock.extend(10)
+ await lock.release()
+
+ async def test_extending_lock_no_longer_owned_raises_error(self, r):
+ lock = self.get_lock(r, "foo", timeout=10)
+ assert await lock.acquire(blocking=False)
+ await r.set("foo", "a")
+ with pytest.raises(LockNotOwnedError):
+ await lock.extend(10)
+
+ async def test_reacquire_lock(self, r):
+ lock = self.get_lock(r, "foo", timeout=10)
+ assert await lock.acquire(blocking=False)
+ assert await r.pexpire("foo", 5000)
+ assert await r.pttl("foo") <= 5000
+ assert await lock.reacquire()
+ assert 8000 < (await r.pttl("foo")) <= 10000
+ await lock.release()
+
+ async def test_reacquiring_unlocked_lock_raises_error(self, r):
+ lock = self.get_lock(r, "foo", timeout=10)
+ with pytest.raises(LockError):
+ await lock.reacquire()
+
+ async def test_reacquiring_lock_with_no_timeout_raises_error(self, r):
+ lock = self.get_lock(r, "foo")
+ assert await lock.acquire(blocking=False)
+ with pytest.raises(LockError):
+ await lock.reacquire()
+ await lock.release()
+
+ async def test_reacquiring_lock_no_longer_owned_raises_error(self, r):
+ lock = self.get_lock(r, "foo", timeout=10)
+ assert await lock.acquire(blocking=False)
+ await r.set("foo", "a")
+ with pytest.raises(LockNotOwnedError):
+ await lock.reacquire()
+
+
+ @pytest.mark.onlynoncluster
+ class TestLockClassSelection:
+ def test_lock_class_argument(self, r):
+ class MyLock:
+ def __init__(self, *args, **kwargs):
+ pass
+
+ lock = r.lock("foo", 
lock_class=MyLock) + assert type(lock) == MyLock diff --git a/tests/test_asyncio/test_monitor.py b/tests/test_asyncio/test_monitor.py new file mode 100644 index 0000000000..783ba262b0 --- /dev/null +++ b/tests/test_asyncio/test_monitor.py @@ -0,0 +1,67 @@ +import pytest + +from tests.conftest import skip_if_redis_enterprise, skip_ifnot_redis_enterprise + +from .conftest import wait_for_command + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.onlynoncluster +class TestMonitor: + async def test_wait_command_not_found(self, r): + """Make sure the wait_for_command func works when command is not found""" + async with r.monitor() as m: + response = await wait_for_command(r, m, "nothing") + assert response is None + + async def test_response_values(self, r): + db = r.connection_pool.connection_kwargs.get("db", 0) + async with r.monitor() as m: + await r.ping() + response = await wait_for_command(r, m, "PING") + assert isinstance(response["time"], float) + assert response["db"] == db + assert response["client_type"] in ("tcp", "unix") + assert isinstance(response["client_address"], str) + assert isinstance(response["client_port"], str) + assert response["command"] == "PING" + + async def test_command_with_quoted_key(self, r): + async with r.monitor() as m: + await r.get('foo"bar') + response = await wait_for_command(r, m, 'GET foo"bar') + assert response["command"] == 'GET foo"bar' + + async def test_command_with_binary_data(self, r): + async with r.monitor() as m: + byte_string = b"foo\x92" + await r.get(byte_string) + response = await wait_for_command(r, m, "GET foo\\x92") + assert response["command"] == "GET foo\\x92" + + async def test_command_with_escaped_data(self, r): + async with r.monitor() as m: + byte_string = b"foo\\x92" + await r.get(byte_string) + response = await wait_for_command(r, m, "GET foo\\\\x92") + assert response["command"] == "GET foo\\\\x92" + + @skip_if_redis_enterprise() + async def test_lua_script(self, r): + async with r.monitor() as m: + script = 'return redis.call("GET", "foo")' + assert await r.eval(script, 0) is None + response = await wait_for_command(r, m, "GET foo") + assert response["command"] == "GET foo" + assert response["client_type"] == "lua" + assert response["client_address"] == "lua" + assert response["client_port"] == "" + + @skip_ifnot_redis_enterprise() + async def test_lua_script_in_enterprise(self, r): + async with r.monitor() as m: + script = 'return redis.call("GET", "foo")' + assert await r.eval(script, 0) is None + response = await wait_for_command(r, m, "GET foo") + assert response is None diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py new file mode 100644 index 0000000000..5bb1a8a4e0 --- /dev/null +++ b/tests/test_asyncio/test_pipeline.py @@ -0,0 +1,409 @@ +import pytest + +import redis +from tests.conftest import skip_if_server_version_lt + +from .conftest import wait_for_command + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.onlynoncluster +class TestPipeline: + async def test_pipeline_is_true(self, r): + """Ensure pipeline instances are not false-y""" + async with r.pipeline() as pipe: + assert pipe + + async def test_pipeline(self, r): + async with r.pipeline() as pipe: + ( + pipe.set("a", "a1") + .get("a") + .zadd("z", {"z1": 1}) + .zadd("z", {"z2": 4}) + .zincrby("z", 1, "z1") + .zrange("z", 0, 5, withscores=True) + ) + assert await pipe.execute() == [ + True, + b"a1", + True, + True, + 2.0, + [(b"z1", 2.0), (b"z2", 4)], + ] + + async def test_pipeline_memoryview(self, r): + async with 
r.pipeline() as pipe: + (pipe.set("a", memoryview(b"a1")).get("a")) + assert await pipe.execute() == [ + True, + b"a1", + ] + + async def test_pipeline_length(self, r): + async with r.pipeline() as pipe: + # Initially empty. + assert len(pipe) == 0 + + # Fill 'er up! + pipe.set("a", "a1").set("b", "b1").set("c", "c1") + assert len(pipe) == 3 + + # Execute calls reset(), so empty once again. + await pipe.execute() + assert len(pipe) == 0 + + @pytest.mark.onlynoncluster + async def test_pipeline_no_transaction(self, r): + async with r.pipeline(transaction=False) as pipe: + pipe.set("a", "a1").set("b", "b1").set("c", "c1") + assert await pipe.execute() == [True, True, True] + assert await r.get("a") == b"a1" + assert await r.get("b") == b"b1" + assert await r.get("c") == b"c1" + + async def test_pipeline_no_transaction_watch(self, r): + await r.set("a", 0) + + async with r.pipeline(transaction=False) as pipe: + await pipe.watch("a") + a = await pipe.get("a") + + pipe.multi() + pipe.set("a", int(a) + 1) + assert await pipe.execute() == [True] + + async def test_pipeline_no_transaction_watch_failure(self, r): + await r.set("a", 0) + + async with r.pipeline(transaction=False) as pipe: + await pipe.watch("a") + a = await pipe.get("a") + + await r.set("a", "bad") + + pipe.multi() + pipe.set("a", int(a) + 1) + + with pytest.raises(redis.WatchError): + await pipe.execute() + + assert await r.get("a") == b"bad" + + async def test_exec_error_in_response(self, r): + """ + an invalid pipeline command at exec time adds the exception instance + to the list of returned values + """ + await r.set("c", "a") + async with r.pipeline() as pipe: + pipe.set("a", 1).set("b", 2).lpush("c", 3).set("d", 4) + result = await pipe.execute(raise_on_error=False) + + assert result[0] + assert await r.get("a") == b"1" + assert result[1] + assert await r.get("b") == b"2" + + # we can't lpush to a key that's a string value, so this should + # be a ResponseError exception + assert isinstance(result[2], redis.ResponseError) + assert await r.get("c") == b"a" + + # since this isn't a transaction, the other commands after the + # error are still executed + assert result[3] + assert await r.get("d") == b"4" + + # make sure the pipe was restored to a working state + assert await pipe.set("z", "zzz").execute() == [True] + assert await r.get("z") == b"zzz" + + async def test_exec_error_raised(self, r): + await r.set("c", "a") + async with r.pipeline() as pipe: + pipe.set("a", 1).set("b", 2).lpush("c", 3).set("d", 4) + with pytest.raises(redis.ResponseError) as ex: + await pipe.execute() + assert str(ex.value).startswith( + "Command # 3 (LPUSH c 3) of " "pipeline caused error: " + ) + + # make sure the pipe was restored to a working state + assert await pipe.set("z", "zzz").execute() == [True] + assert await r.get("z") == b"zzz" + + @pytest.mark.onlynoncluster + async def test_transaction_with_empty_error_command(self, r): + """ + Commands with custom EMPTY_ERROR functionality return their default + values in the pipeline no matter the raise_on_error preference + """ + for error_switch in (True, False): + async with r.pipeline() as pipe: + pipe.set("a", 1).mget([]).set("c", 3) + result = await pipe.execute(raise_on_error=error_switch) + + assert result[0] + assert result[1] == [] + assert result[2] + + @pytest.mark.onlynoncluster + async def test_pipeline_with_empty_error_command(self, r): + """ + Commands with custom EMPTY_ERROR functionality return their default + values in the pipeline no matter the raise_on_error preference + """ + 
for error_switch in (True, False): + async with r.pipeline(transaction=False) as pipe: + pipe.set("a", 1).mget([]).set("c", 3) + result = await pipe.execute(raise_on_error=error_switch) + + assert result[0] + assert result[1] == [] + assert result[2] + + async def test_parse_error_raised(self, r): + async with r.pipeline() as pipe: + # the zrem is invalid because we don't pass any keys to it + pipe.set("a", 1).zrem("b").set("b", 2) + with pytest.raises(redis.ResponseError) as ex: + await pipe.execute() + + assert str(ex.value).startswith( + "Command # 2 (ZREM b) of " "pipeline caused error: " + ) + + # make sure the pipe was restored to a working state + assert await pipe.set("z", "zzz").execute() == [True] + assert await r.get("z") == b"zzz" + + @pytest.mark.onlynoncluster + async def test_parse_error_raised_transaction(self, r): + async with r.pipeline() as pipe: + pipe.multi() + # the zrem is invalid because we don't pass any keys to it + pipe.set("a", 1).zrem("b").set("b", 2) + with pytest.raises(redis.ResponseError) as ex: + await pipe.execute() + + assert str(ex.value).startswith( + "Command # 2 (ZREM b) of " "pipeline caused error: " + ) + + # make sure the pipe was restored to a working state + assert await pipe.set("z", "zzz").execute() == [True] + assert await r.get("z") == b"zzz" + + @pytest.mark.onlynoncluster + async def test_watch_succeed(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + assert pipe.watching + a_value = await pipe.get("a") + b_value = await pipe.get("b") + assert a_value == b"1" + assert b_value == b"2" + pipe.multi() + + pipe.set("c", 3) + assert await pipe.execute() == [True] + assert not pipe.watching + + @pytest.mark.onlynoncluster + async def test_watch_failure(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + await r.set("b", 3) + pipe.multi() + pipe.get("a") + with pytest.raises(redis.WatchError): + await pipe.execute() + + assert not pipe.watching + + @pytest.mark.onlynoncluster + async def test_watch_failure_in_empty_transaction(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + await r.set("b", 3) + pipe.multi() + with pytest.raises(redis.WatchError): + await pipe.execute() + + assert not pipe.watching + + @pytest.mark.onlynoncluster + async def test_unwatch(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + await r.set("b", 3) + await pipe.unwatch() + assert not pipe.watching + pipe.get("a") + assert await pipe.execute() == [b"1"] + + @pytest.mark.onlynoncluster + async def test_watch_exec_no_unwatch(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.monitor() as m: + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + assert pipe.watching + a_value = await pipe.get("a") + b_value = await pipe.get("b") + assert a_value == b"1" + assert b_value == b"2" + pipe.multi() + pipe.set("c", 3) + assert await pipe.execute() == [True] + assert not pipe.watching + + unwatch_command = await wait_for_command(r, m, "UNWATCH") + assert unwatch_command is None, "should not send UNWATCH" + + @pytest.mark.onlynoncluster + async def test_watch_reset_unwatch(self, r): + await r.set("a", 1) + + async with r.monitor() as m: + async with r.pipeline() as pipe: + await pipe.watch("a") + assert pipe.watching + await pipe.reset() + assert not pipe.watching + + 
unwatch_command = await wait_for_command(r, m, "UNWATCH")
+ assert unwatch_command is not None
+ assert unwatch_command["command"] == "UNWATCH"
+
+ @pytest.mark.onlynoncluster
+ async def test_transaction_callable(self, r):
+ await r.set("a", 1)
+ await r.set("b", 2)
+ has_run = []
+
+ async def my_transaction(pipe):
+ a_value = await pipe.get("a")
+ assert a_value in (b"1", b"2")
+ b_value = await pipe.get("b")
+ assert b_value == b"2"
+
+ # silly run-once code... incr's "a" so WatchError should be raised
+ # forcing this all to run again. this should incr "a" once to "2"
+ if not has_run:
+ await r.incr("a")
+ has_run.append("it has")
+
+ pipe.multi()
+ pipe.set("c", int(a_value) + int(b_value))
+
+ result = await r.transaction(my_transaction, "a", "b")
+ assert result == [True]
+ assert await r.get("c") == b"4"
+
+ @pytest.mark.onlynoncluster
+ async def test_transaction_callable_returns_value_from_callable(self, r):
+ async def callback(pipe):
+ # No need to do anything here since we only want the return value
+ return "a"
+
+ res = await r.transaction(callback, "my-key", value_from_callable=True)
+ assert res == "a"
+
+ async def test_exec_error_in_no_transaction_pipeline(self, r):
+ await r.set("a", 1)
+ async with r.pipeline(transaction=False) as pipe:
+ pipe.llen("a")
+ pipe.expire("a", 100)
+
+ with pytest.raises(redis.ResponseError) as ex:
+ await pipe.execute()
+
+ assert str(ex.value).startswith(
+ "Command # 1 (LLEN a) of " "pipeline caused error: "
+ )
+
+ assert await r.get("a") == b"1"
+
+ async def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r):
+ key = chr(3456) + "abcd" + chr(3421)
+ await r.set(key, 1)
+ async with r.pipeline(transaction=False) as pipe:
+ pipe.llen(key)
+ pipe.expire(key, 100)
+
+ with pytest.raises(redis.ResponseError) as ex:
+ await pipe.execute()
+
+ expected = f"Command # 1 (LLEN {key}) of pipeline caused error: "
+ assert str(ex.value).startswith(expected)
+
+ assert await r.get(key) == b"1"
+
+ async def test_pipeline_with_bitfield(self, r):
+ async with r.pipeline() as pipe:
+ pipe.set("a", "1")
+ bf = pipe.bitfield("b")
+ pipe2 = (
+ bf.set("u8", 8, 255)
+ .get("u8", 0)
+ .get("u4", 8) # 1111
+ .get("u4", 12) # 1111
+ .get("u4", 13) # 1110
+ .execute()
+ )
+ pipe.get("a")
+ response = await pipe.execute()
+
+ assert pipe == pipe2
+ assert response == [True, [0, 0, 15, 15, 14], b"1"]
+
+ async def test_pipeline_get(self, r):
+ await r.set("a", "a1")
+ async with r.pipeline() as pipe:
+ await pipe.get("a")
+ assert await pipe.execute() == [b"a1"]
+
+ @pytest.mark.onlynoncluster
+ @skip_if_server_version_lt("2.0.0")
+ async def test_pipeline_discard(self, r):
+
+ # empty pipeline should raise an error
+ async with r.pipeline() as pipe:
+ pipe.set("key", "someval")
+ await pipe.discard()
+ with pytest.raises(redis.exceptions.ResponseError):
+ await pipe.execute()
+
+ # setting a pipeline and discarding should do the same
+ async with r.pipeline() as pipe:
+ pipe.set("key", "someval")
+ pipe.set("someotherkey", "val")
+ response = await pipe.execute()
+ pipe.set("key", "another value!")
+ await pipe.discard()
+ pipe.set("key", "another value!")
+ with pytest.raises(redis.exceptions.ResponseError):
+ await pipe.execute()
+
+ pipe.set("foo", "bar")
+ response = await pipe.execute()
+ assert response[0]
+ assert await r.get("foo") == b"bar"
diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py
new file mode 100644
index 0000000000..9efcd3cf8a
--- /dev/null
+++ b/tests/test_asyncio/test_pubsub.py
@@ -0,0 
+1,660 @@ +import asyncio +import sys +from typing import Optional + +import pytest + +if sys.version_info[0:2] == (3, 6): + import pytest as pytest_asyncio +else: + import pytest_asyncio + +import redis.asyncio as redis +from redis.exceptions import ConnectionError +from redis.typing import EncodableT +from tests.conftest import skip_if_server_version_lt + +from .compat import mock + +pytestmark = pytest.mark.asyncio(forbid_global_loop=True) + + +async def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): + now = asyncio.get_event_loop().time() + timeout = now + timeout + while now < timeout: + message = await pubsub.get_message( + ignore_subscribe_messages=ignore_subscribe_messages + ) + if message is not None: + return message + await asyncio.sleep(0.01) + now = asyncio.get_event_loop().time() + return None + + +def make_message( + type, channel: Optional[str], data: EncodableT, pattern: Optional[str] = None +): + return { + "type": type, + "pattern": pattern and pattern.encode("utf-8") or None, + "channel": channel and channel.encode("utf-8") or None, + "data": data.encode("utf-8") if isinstance(data, str) else data, + } + + +def make_subscribe_test_data(pubsub, type): + if type == "channel": + return { + "p": pubsub, + "sub_type": "subscribe", + "unsub_type": "unsubscribe", + "sub_func": pubsub.subscribe, + "unsub_func": pubsub.unsubscribe, + "keys": ["foo", "bar", "uni" + chr(4456) + "code"], + } + elif type == "pattern": + return { + "p": pubsub, + "sub_type": "psubscribe", + "unsub_type": "punsubscribe", + "sub_func": pubsub.psubscribe, + "unsub_func": pubsub.punsubscribe, + "keys": ["f*", "b*", "uni" + chr(4456) + "*"], + } + assert False, f"invalid subscribe type: {type}" + + +@pytest.mark.onlynoncluster +class TestPubSubSubscribeUnsubscribe: + async def _test_subscribe_unsubscribe( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + for key in keys: + assert await sub_func(key) is None + + # should be a message for each channel/pattern we just subscribed to + for i, key in enumerate(keys): + assert await wait_for_message(p) == make_message(sub_type, key, i + 1) + + for key in keys: + assert await unsub_func(key) is None + + # should be a message for each channel/pattern we just unsubscribed + # from + for i, key in enumerate(keys): + i = len(keys) - 1 - i + assert await wait_for_message(p) == make_message(unsub_type, key, i) + + async def test_channel_subscribe_unsubscribe(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_subscribe_unsubscribe(**kwargs) + + async def test_pattern_subscribe_unsubscribe(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_subscribe_unsubscribe(**kwargs) + + @pytest.mark.onlynoncluster + async def _test_resubscribe_on_reconnection( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + + for key in keys: + assert await sub_func(key) is None + + # should be a message for each channel/pattern we just subscribed to + for i, key in enumerate(keys): + assert await wait_for_message(p) == make_message(sub_type, key, i + 1) + + # manually disconnect + await p.connection.disconnect() + + # calling get_message again reconnects and resubscribes + # note, we may not re-subscribe to channels in exactly the same order + # so we have to do some extra checks to make sure we got them all + messages = [] + for i in range(len(keys)): + messages.append(await wait_for_message(p)) + + unique_channels = set() + assert 
len(messages) == len(keys) + for i, message in enumerate(messages): + assert message["type"] == sub_type + assert message["data"] == i + 1 + assert isinstance(message["channel"], bytes) + channel = message["channel"].decode("utf-8") + unique_channels.add(channel) + + assert len(unique_channels) == len(keys) + for channel in unique_channels: + assert channel in keys + + async def test_resubscribe_to_channels_on_reconnection(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_resubscribe_on_reconnection(**kwargs) + + async def test_resubscribe_to_patterns_on_reconnection(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_resubscribe_on_reconnection(**kwargs) + + async def _test_subscribed_property( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + + assert p.subscribed is False + await sub_func(keys[0]) + # we're now subscribed even though we haven't processed the + # reply from the server just yet + assert p.subscribed is True + assert await wait_for_message(p) == make_message(sub_type, keys[0], 1) + # we're still subscribed + assert p.subscribed is True + + # unsubscribe from all channels + await unsub_func() + # we're still technically subscribed until we process the + # response messages from the server + assert p.subscribed is True + assert await wait_for_message(p) == make_message(unsub_type, keys[0], 0) + # now we're no longer subscribed as no more messages can be delivered + # to any channels we were listening to + assert p.subscribed is False + + # subscribing again flips the flag back + await sub_func(keys[0]) + assert p.subscribed is True + assert await wait_for_message(p) == make_message(sub_type, keys[0], 1) + + # unsubscribe again + await unsub_func() + assert p.subscribed is True + # subscribe to another channel before reading the unsubscribe response + await sub_func(keys[1]) + assert p.subscribed is True + # read the unsubscribe for key1 + assert await wait_for_message(p) == make_message(unsub_type, keys[0], 0) + # we're still subscribed to key2, so subscribed should still be True + assert p.subscribed is True + # read the key2 subscribe message + assert await wait_for_message(p) == make_message(sub_type, keys[1], 1) + await unsub_func() + # haven't read the message yet, so we're still subscribed + assert p.subscribed is True + assert await wait_for_message(p) == make_message(unsub_type, keys[1], 0) + # now we're finally unsubscribed + assert p.subscribed is False + + async def test_subscribe_property_with_channels(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_subscribed_property(**kwargs) + + @pytest.mark.onlynoncluster + async def test_subscribe_property_with_patterns(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_subscribed_property(**kwargs) + + async def test_ignore_all_subscribe_messages(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + + checks = ( + (p.subscribe, "foo"), + (p.unsubscribe, "foo"), + (p.psubscribe, "f*"), + (p.punsubscribe, "f*"), + ) + + assert p.subscribed is False + for func, channel in checks: + assert await func(channel) is None + assert p.subscribed is True + assert await wait_for_message(p) is None + assert p.subscribed is False + + async def test_ignore_individual_subscribe_messages(self, r: redis.Redis): + p = r.pubsub() + + checks = ( + (p.subscribe, "foo"), + (p.unsubscribe, "foo"), + (p.psubscribe, 
"f*"), + (p.punsubscribe, "f*"), + ) + + assert p.subscribed is False + for func, channel in checks: + assert await func(channel) is None + assert p.subscribed is True + message = await wait_for_message(p, ignore_subscribe_messages=True) + assert message is None + assert p.subscribed is False + + async def test_sub_unsub_resub_channels(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_sub_unsub_resub(**kwargs) + + @pytest.mark.onlynoncluster + async def test_sub_unsub_resub_patterns(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_sub_unsub_resub(**kwargs) + + async def _test_sub_unsub_resub( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + # https://github.com/andymccurdy/redis-py/issues/764 + key = keys[0] + await sub_func(key) + await unsub_func(key) + await sub_func(key) + assert p.subscribed is True + assert await wait_for_message(p) == make_message(sub_type, key, 1) + assert await wait_for_message(p) == make_message(unsub_type, key, 0) + assert await wait_for_message(p) == make_message(sub_type, key, 1) + assert p.subscribed is True + + async def test_sub_unsub_all_resub_channels(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_sub_unsub_all_resub(**kwargs) + + async def test_sub_unsub_all_resub_patterns(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_sub_unsub_all_resub(**kwargs) + + async def _test_sub_unsub_all_resub( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + # https://github.com/andymccurdy/redis-py/issues/764 + key = keys[0] + await sub_func(key) + await unsub_func() + await sub_func(key) + assert p.subscribed is True + assert await wait_for_message(p) == make_message(sub_type, key, 1) + assert await wait_for_message(p) == make_message(unsub_type, key, 0) + assert await wait_for_message(p) == make_message(sub_type, key, 1) + assert p.subscribed is True + + +@pytest.mark.onlynoncluster +class TestPubSubMessages: + def setup_method(self, method): + self.message = None + + def message_handler(self, message): + self.message = message + + async def async_message_handler(self, message): + self.async_message = message + + async def test_published_message_to_channel(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe("foo") + assert await wait_for_message(p) == make_message("subscribe", "foo", 1) + assert await r.publish("foo", "test message") == 1 + + message = await wait_for_message(p) + assert isinstance(message, dict) + assert message == make_message("message", "foo", "test message") + + async def test_published_message_to_pattern(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe("foo") + await p.psubscribe("f*") + assert await wait_for_message(p) == make_message("subscribe", "foo", 1) + assert await wait_for_message(p) == make_message("psubscribe", "f*", 2) + # 1 to pattern, 1 to channel + assert await r.publish("foo", "test message") == 2 + + message1 = await wait_for_message(p) + message2 = await wait_for_message(p) + assert isinstance(message1, dict) + assert isinstance(message2, dict) + + expected = [ + make_message("message", "foo", "test message"), + make_message("pmessage", "foo", "test message", pattern="f*"), + ] + + assert message1 in expected + assert message2 in expected + assert message1 != message2 + + async def test_channel_message_handler(self, r: redis.Redis): + p = 
r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(foo=self.message_handler) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message("message", "foo", "test message") + + async def test_channel_async_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(foo=self.async_message_handler) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await wait_for_message(p) is None + assert self.async_message == make_message("message", "foo", "test message") + + async def test_channel_sync_async_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(foo=self.message_handler) + await p.subscribe(bar=self.async_message_handler) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await r.publish("bar", "test message 2") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message("message", "foo", "test message") + assert self.async_message == make_message("message", "bar", "test message 2") + + @pytest.mark.onlynoncluster + async def test_pattern_message_handler(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + await p.psubscribe(**{"f*": self.message_handler}) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message( + "pmessage", "foo", "test message", pattern="f*" + ) + + async def test_unicode_channel_message_handler(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + channel = "uni" + chr(4456) + "code" + channels = {channel: self.message_handler} + await p.subscribe(**channels) + assert await wait_for_message(p) is None + assert await r.publish(channel, "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message("message", channel, "test message") + + @pytest.mark.onlynoncluster + # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html + # #known-limitations-with-pubsub + async def test_unicode_pattern_message_handler(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + pattern = "uni" + chr(4456) + "*" + channel = "uni" + chr(4456) + "code" + await p.psubscribe(**{pattern: self.message_handler}) + assert await wait_for_message(p) is None + assert await r.publish(channel, "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message( + "pmessage", channel, "test message", pattern=pattern + ) + + async def test_get_message_without_subscribe(self, r: redis.Redis): + p = r.pubsub() + with pytest.raises(RuntimeError) as info: + await p.get_message() + expect = ( + "connection not set: " "did you forget to call subscribe() or psubscribe()?" 
+ ) + assert expect in info.exconly() + + +@pytest.mark.onlynoncluster +class TestPubSubAutoDecoding: + """These tests only validate that we get unicode values back""" + + channel = "uni" + chr(4456) + "code" + pattern = "uni" + chr(4456) + "*" + data = "abc" + chr(4458) + "123" + + def make_message(self, type, channel, data, pattern=None): + return {"type": type, "channel": channel, "pattern": pattern, "data": data} + + def setup_method(self, method): + self.message = None + + def message_handler(self, message): + self.message = message + + @pytest_asyncio.fixture() + async def r(self, create_redis): + return await create_redis( + decode_responses=True, + ) + + async def test_channel_subscribe_unsubscribe(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe(self.channel) + assert await wait_for_message(p) == self.make_message( + "subscribe", self.channel, 1 + ) + + await p.unsubscribe(self.channel) + assert await wait_for_message(p) == self.make_message( + "unsubscribe", self.channel, 0 + ) + + async def test_pattern_subscribe_unsubscribe(self, r: redis.Redis): + p = r.pubsub() + await p.psubscribe(self.pattern) + assert await wait_for_message(p) == self.make_message( + "psubscribe", self.pattern, 1 + ) + + await p.punsubscribe(self.pattern) + assert await wait_for_message(p) == self.make_message( + "punsubscribe", self.pattern, 0 + ) + + async def test_channel_publish(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe(self.channel) + assert await wait_for_message(p) == self.make_message( + "subscribe", self.channel, 1 + ) + await r.publish(self.channel, self.data) + assert await wait_for_message(p) == self.make_message( + "message", self.channel, self.data + ) + + @pytest.mark.onlynoncluster + async def test_pattern_publish(self, r: redis.Redis): + p = r.pubsub() + await p.psubscribe(self.pattern) + assert await wait_for_message(p) == self.make_message( + "psubscribe", self.pattern, 1 + ) + await r.publish(self.channel, self.data) + assert await wait_for_message(p) == self.make_message( + "pmessage", self.channel, self.data, pattern=self.pattern + ) + + async def test_channel_message_handler(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(**{self.channel: self.message_handler}) + assert await wait_for_message(p) is None + await r.publish(self.channel, self.data) + assert await wait_for_message(p) is None + assert self.message == self.make_message("message", self.channel, self.data) + + # test that we reconnected to the correct channel + self.message = None + await p.connection.disconnect() + assert await wait_for_message(p) is None # should reconnect + new_data = self.data + "new data" + await r.publish(self.channel, new_data) + assert await wait_for_message(p) is None + assert self.message == self.make_message("message", self.channel, new_data) + + async def test_pattern_message_handler(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + await p.psubscribe(**{self.pattern: self.message_handler}) + assert await wait_for_message(p) is None + await r.publish(self.channel, self.data) + assert await wait_for_message(p) is None + assert self.message == self.make_message( + "pmessage", self.channel, self.data, pattern=self.pattern + ) + + # test that we reconnected to the correct pattern + self.message = None + await p.connection.disconnect() + assert await wait_for_message(p) is None # should reconnect + new_data = self.data + "new data" + await r.publish(self.channel, new_data) + assert await wait_for_message(p) is 
None + assert self.message == self.make_message( + "pmessage", self.channel, new_data, pattern=self.pattern + ) + + async def test_context_manager(self, r: redis.Redis): + async with r.pubsub() as pubsub: + await pubsub.subscribe("foo") + assert pubsub.connection is not None + + assert pubsub.connection is None + assert pubsub.channels == {} + assert pubsub.patterns == {} + + +@pytest.mark.onlynoncluster +class TestPubSubRedisDown: + async def test_channel_subscribe(self, r: redis.Redis): + r = redis.Redis(host="localhost", port=6390) + p = r.pubsub() + with pytest.raises(ConnectionError): + await p.subscribe("foo") + + +@pytest.mark.onlynoncluster +class TestPubSubSubcommands: + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.8.0") + async def test_pubsub_channels(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe("foo", "bar", "baz", "quux") + for i in range(4): + assert (await wait_for_message(p))["type"] == "subscribe" + expected = [b"bar", b"baz", b"foo", b"quux"] + assert all([channel in await r.pubsub_channels() for channel in expected]) + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.8.0") + async def test_pubsub_numsub(self, r: redis.Redis): + p1 = r.pubsub() + await p1.subscribe("foo", "bar", "baz") + for i in range(3): + assert (await wait_for_message(p1))["type"] == "subscribe" + p2 = r.pubsub() + await p2.subscribe("bar", "baz") + for i in range(2): + assert (await wait_for_message(p2))["type"] == "subscribe" + p3 = r.pubsub() + await p3.subscribe("baz") + assert (await wait_for_message(p3))["type"] == "subscribe" + + channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] + assert await r.pubsub_numsub("foo", "bar", "baz") == channels + + @skip_if_server_version_lt("2.8.0") + async def test_pubsub_numpat(self, r: redis.Redis): + p = r.pubsub() + await p.psubscribe("*oo", "*ar", "b*z") + for i in range(3): + assert (await wait_for_message(p))["type"] == "psubscribe" + assert await r.pubsub_numpat() == 3 + + +@pytest.mark.onlynoncluster +class TestPubSubPings: + @skip_if_server_version_lt("3.0.0") + async def test_send_pubsub_ping(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe("foo") + await p.ping() + assert await wait_for_message(p) == make_message( + type="pong", channel=None, data="", pattern=None + ) + + @skip_if_server_version_lt("3.0.0") + async def test_send_pubsub_ping_message(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe("foo") + await p.ping(message="hello world") + assert await wait_for_message(p) == make_message( + type="pong", channel=None, data="hello world", pattern=None + ) + + +@pytest.mark.onlynoncluster +class TestPubSubConnectionKilled: + @skip_if_server_version_lt("3.0.0") + async def test_connection_error_raised_when_connection_dies(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe("foo") + assert await wait_for_message(p) == make_message("subscribe", "foo", 1) + for client in await r.client_list(): + if client["cmd"] == "subscribe": + await r.client_kill_filter(_id=client["id"]) + with pytest.raises(ConnectionError): + await wait_for_message(p) + + +@pytest.mark.onlynoncluster +class TestPubSubTimeouts: + async def test_get_message_with_timeout_returns_none(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe("foo") + assert await wait_for_message(p) == make_message("subscribe", "foo", 1) + assert await p.get_message(timeout=0.01) is None + + +@pytest.mark.onlynoncluster +class TestPubSubRun: + async def _subscribe(self, 
p, *args, **kwargs):
+ await p.subscribe(*args, **kwargs)
+ # Wait for the server to act on the subscription, to be sure that
+ # a subsequent publish on another connection will reach the pubsub.
+ while True:
+ message = await p.get_message(timeout=1)
+ if (
+ message is not None
+ and message["type"] == "subscribe"
+ and message["channel"] == b"foo"
+ ):
+ return
+
+ async def test_callbacks(self, r: redis.Redis):
+ def callback(message):
+ messages.put_nowait(message)
+
+ messages = asyncio.Queue()
+ p = r.pubsub()
+ await self._subscribe(p, foo=callback)
+ task = asyncio.get_event_loop().create_task(p.run())
+ await r.publish("foo", "bar")
+ message = await messages.get()
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+ assert message == {
+ "channel": b"foo",
+ "data": b"bar",
+ "pattern": None,
+ "type": "message",
+ }
+
+ async def test_exception_handler(self, r: redis.Redis):
+ def exception_handler_callback(e, pubsub) -> None:
+ assert pubsub == p
+ exceptions.put_nowait(e)
+
+ exceptions = asyncio.Queue()
+ p = r.pubsub()
+ await self._subscribe(p, foo=lambda x: None)
+ with mock.patch.object(p, "get_message", side_effect=Exception("error")):
+ task = asyncio.get_event_loop().create_task(
+ p.run(exception_handler=exception_handler_callback)
+ )
+ e = await exceptions.get()
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+ assert str(e) == "error"
diff --git a/tests/test_asyncio/test_retry.py b/tests/test_asyncio/test_retry.py
new file mode 100644
index 0000000000..e83e001847
--- /dev/null
+++ b/tests/test_asyncio/test_retry.py
@@ -0,0 +1,70 @@
+import pytest
+
+from redis.asyncio.connection import Connection, UnixDomainSocketConnection
+from redis.asyncio.retry import Retry
+from redis.backoff import AbstractBackoff, NoBackoff
+from redis.exceptions import ConnectionError
+
+
+class BackoffMock(AbstractBackoff):
+ def __init__(self):
+ self.reset_calls = 0
+ self.calls = 0
+
+ def reset(self):
+ self.reset_calls += 1
+
+ def compute(self, failures):
+ self.calls += 1
+ return 0
+
+
+@pytest.mark.onlynoncluster
+class TestConnectionConstructorWithRetry:
+ "Test that the Connection constructors properly handle Retry objects"
+
+ @pytest.mark.parametrize("retry_on_timeout", [False, True])
+ @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
+ def test_retry_on_timeout_boolean(self, Class, retry_on_timeout):
+ c = Class(retry_on_timeout=retry_on_timeout)
+ assert c.retry_on_timeout == retry_on_timeout
+ assert isinstance(c.retry, Retry)
+ assert c.retry._retries == (1 if retry_on_timeout else 0)
+
+ @pytest.mark.parametrize("retries", range(10))
+ @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
+ def test_retry_on_timeout_retry(self, Class, retries: int):
+ retry_on_timeout = retries > 0
+ c = Class(retry_on_timeout=retry_on_timeout, retry=Retry(NoBackoff(), retries))
+ assert c.retry_on_timeout == retry_on_timeout
+ assert isinstance(c.retry, Retry)
+ assert c.retry._retries == retries
+
+
+@pytest.mark.onlynoncluster
+class TestRetry:
+ "Test that Retry calls backoff and retries the expected number of times"
+
+ def setup_method(self, test_method):
+ self.actual_attempts = 0
+ self.actual_failures = 0
+
+ async def _do(self):
+ self.actual_attempts += 1
+ raise ConnectionError()
+
+ async def _fail(self, error):
+ self.actual_failures += 1
+
+ @pytest.mark.parametrize("retries", range(10))
+ @pytest.mark.asyncio
+ async def test_retry(self, retries: int):
+ backoff = 
BackoffMock()
+ retry = Retry(backoff, retries)
+ with pytest.raises(ConnectionError):
+ await retry.call_with_retry(self._do, self._fail)
+
+ assert self.actual_attempts == 1 + retries
+ assert self.actual_failures == 1 + retries
+ assert backoff.reset_calls == 1
+ assert backoff.calls == retries
diff --git a/tests/test_asyncio/test_scripting.py b/tests/test_asyncio/test_scripting.py
new file mode 100644
index 0000000000..764525fb4a
--- /dev/null
+++ b/tests/test_asyncio/test_scripting.py
@@ -0,0 +1,159 @@
+import sys
+
+import pytest
+
+if sys.version_info[0:2] == (3, 6):
+ import pytest as pytest_asyncio
+else:
+ import pytest_asyncio
+
+from redis import exceptions
+from tests.conftest import skip_if_server_version_lt
+
+multiply_script = """
+local value = redis.call('GET', KEYS[1])
+value = tonumber(value)
+return value * ARGV[1]"""
+
+msgpack_hello_script = """
+local message = cmsgpack.unpack(ARGV[1])
+local name = message['name']
+return "hello " .. name
+"""
+msgpack_hello_script_broken = """
+local message = cmsgpack.unpack(ARGV[1])
+local names = message['name']
+return "hello " .. name
+"""
+
+
+@pytest.mark.onlynoncluster
+class TestScripting:
+ @pytest_asyncio.fixture
+ async def r(self, create_redis):
+ redis = await create_redis()
+ yield redis
+ await redis.script_flush()
+
+ @pytest.mark.asyncio(forbid_global_loop=True)
+ async def test_eval(self, r):
+ await r.flushdb()
+ await r.set("a", 2)
+ # 2 * 3 == 6
+ assert await r.eval(multiply_script, 1, "a", 3) == 6
+
+ @pytest.mark.asyncio(forbid_global_loop=True)
+ @skip_if_server_version_lt("6.2.0")
+ async def test_script_flush(self, r):
+ await r.set("a", 2)
+ await r.script_load(multiply_script)
+ await r.script_flush("ASYNC")
+
+ await r.set("a", 2)
+ await r.script_load(multiply_script)
+ await r.script_flush("SYNC")
+
+ await r.set("a", 2)
+ await r.script_load(multiply_script)
+ await r.script_flush()
+
+ with pytest.raises(exceptions.DataError):
+ await r.set("a", 2)
+ await r.script_load(multiply_script)
+ await r.script_flush("NOTREAL")
+
+ @pytest.mark.asyncio(forbid_global_loop=True)
+ async def test_evalsha(self, r):
+ await r.set("a", 2)
+ sha = await r.script_load(multiply_script)
+ # 2 * 3 == 6
+ assert await r.evalsha(sha, 1, "a", 3) == 6
+
+ @pytest.mark.asyncio(forbid_global_loop=True)
+ async def test_evalsha_script_not_loaded(self, r):
+ await r.set("a", 2)
+ sha = await r.script_load(multiply_script)
+ # remove the script from Redis's cache
+ await r.script_flush()
+ with pytest.raises(exceptions.NoScriptError):
+ await r.evalsha(sha, 1, "a", 3)
+
+ @pytest.mark.asyncio(forbid_global_loop=True)
+ async def test_script_loading(self, r):
+ # get the sha, then clear the cache
+ sha = await r.script_load(multiply_script)
+ await r.script_flush()
+ assert await r.script_exists(sha) == [False]
+ await r.script_load(multiply_script)
+ assert await r.script_exists(sha) == [True]
+
+ @pytest.mark.asyncio(forbid_global_loop=True)
+ async def test_script_object(self, r):
+ await r.script_flush()
+ await r.set("a", 2)
+ multiply = r.register_script(multiply_script)
+ precalculated_sha = multiply.sha
+ assert precalculated_sha
+ assert await r.script_exists(multiply.sha) == [False]
+ # First call: evalsha misses with NoScriptError, so the script is
+ # loaded and the call retried
+ assert await multiply(keys=["a"], args=[3]) == 6
+ # At this point, the script should be loaded
+ assert await r.script_exists(multiply.sha) == [True]
+ # Test that the precalculated sha matches the one from redis
+ assert multiply.sha == precalculated_sha
+ # Second call: the script is now cached, so evalsha succeeds directly
+ assert await multiply(keys=["a"], args=[3]) == 6
+
+ @pytest.mark.asyncio(forbid_global_loop=True)
+ async def test_script_object_in_pipeline(self, r):
+ await r.script_flush()
+ multiply = r.register_script(multiply_script)
+ precalculated_sha = multiply.sha
+ assert precalculated_sha
+ pipe = r.pipeline()
+ pipe.set("a", 2)
+ pipe.get("a")
+ await multiply(keys=["a"], args=[3], client=pipe)
+ assert await r.script_exists(multiply.sha) == [False]
+ # [SET worked, GET 'a', result of multiply script]
+ assert await pipe.execute() == [True, b"2", 6]
+ # The script should have been loaded by pipe.execute()
+ assert await r.script_exists(multiply.sha) == [True]
+ # The precalculated sha should have been the correct one
+ assert multiply.sha == precalculated_sha
+
+ # purge the script from redis's cache and re-run the pipeline
+ # the multiply script should be reloaded by pipe.execute()
+ await r.script_flush()
+ pipe = r.pipeline()
+ pipe.set("a", 2)
+ pipe.get("a")
+ await multiply(keys=["a"], args=[3], client=pipe)
+ assert await r.script_exists(multiply.sha) == [False]
+ # [SET worked, GET 'a', result of multiply script]
+ assert await pipe.execute() == [True, b"2", 6]
+ assert await r.script_exists(multiply.sha) == [True]
+
+ @pytest.mark.asyncio(forbid_global_loop=True)
+ async def test_eval_msgpack_pipeline_error_in_lua(self, r):
+ msgpack_hello = r.register_script(msgpack_hello_script)
+ assert msgpack_hello.sha
+
+ pipe = r.pipeline()
+
+ # avoiding a dependency on msgpack, this is the output of
+ # msgpack.dumps({"name": "Joe"})
+ msgpack_message_1 = b"\x81\xa4name\xa3Joe"
+
+ await msgpack_hello(args=[msgpack_message_1], client=pipe)
+
+ assert await r.script_exists(msgpack_hello.sha) == [False]
+ assert (await pipe.execute())[0] == b"hello Joe"
+ assert await r.script_exists(msgpack_hello.sha) == [True]
+
+ msgpack_hello_broken = r.register_script(msgpack_hello_script_broken)
+
+ await msgpack_hello_broken(args=[msgpack_message_1], client=pipe)
+ with pytest.raises(exceptions.ResponseError) as excinfo:
+ await pipe.execute()
+ assert excinfo.type == exceptions.ResponseError
diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py
new file mode 100644
index 0000000000..cd6810c1b5
--- /dev/null
+++ b/tests/test_asyncio/test_sentinel.py
@@ -0,0 +1,249 @@
+import socket
+import sys
+
+import pytest
+
+if sys.version_info[0:2] == (3, 6):
+ import pytest as pytest_asyncio
+else:
+ import pytest_asyncio
+
+import redis.asyncio.sentinel
+from redis import exceptions
+from redis.asyncio.sentinel import (
+ MasterNotFoundError,
+ Sentinel,
+ SentinelConnectionPool,
+ SlaveNotFoundError,
+)
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest_asyncio.fixture(scope="module")
+def master_ip(master_host):
+ yield socket.gethostbyname(master_host)
+
+
+class SentinelTestClient:
+ def __init__(self, cluster, id):
+ self.cluster = cluster
+ self.id = id
+
+ async def sentinel_masters(self):
+ self.cluster.connection_error_if_down(self)
+ self.cluster.timeout_if_down(self)
+ return {self.cluster.service_name: self.cluster.master}
+
+ async def sentinel_slaves(self, master_name):
+ self.cluster.connection_error_if_down(self)
+ self.cluster.timeout_if_down(self)
+ if master_name != self.cluster.service_name:
+ return []
+ return self.cluster.slaves
+
+ async def execute_command(self, *args, **kwargs):
+ # wrapper purely to validate the calls don't explode
+ from redis.asyncio.client import bool_ok
+
+ return bool_ok
+
+
+class SentinelTestCluster:
+ def __init__(self, 
service_name="mymaster", ip="127.0.0.1", port=6379): + self.clients = {} + self.master = { + "ip": ip, + "port": port, + "is_master": True, + "is_sdown": False, + "is_odown": False, + "num-other-sentinels": 0, + } + self.service_name = service_name + self.slaves = [] + self.nodes_down = set() + self.nodes_timeout = set() + + def connection_error_if_down(self, node): + if node.id in self.nodes_down: + raise exceptions.ConnectionError + + def timeout_if_down(self, node): + if node.id in self.nodes_timeout: + raise exceptions.TimeoutError + + def client(self, host, port, **kwargs): + return SentinelTestClient(self, (host, port)) + + +@pytest_asyncio.fixture() +async def cluster(master_ip): + + cluster = SentinelTestCluster(ip=master_ip) + saved_Redis = redis.asyncio.sentinel.Redis + redis.asyncio.sentinel.Redis = cluster.client + yield cluster + redis.asyncio.sentinel.Redis = saved_Redis + + +@pytest_asyncio.fixture() +def sentinel(request, cluster): + return Sentinel([("foo", 26379), ("bar", 26379)]) + + +@pytest.mark.onlynoncluster +async def test_discover_master(sentinel, master_ip): + address = await sentinel.discover_master("mymaster") + assert address == (master_ip, 6379) + + +@pytest.mark.onlynoncluster +async def test_discover_master_error(sentinel): + with pytest.raises(MasterNotFoundError): + await sentinel.discover_master("xxx") + + +@pytest.mark.onlynoncluster +async def test_discover_master_sentinel_down(cluster, sentinel, master_ip): + # Put first sentinel 'foo' down + cluster.nodes_down.add(("foo", 26379)) + address = await sentinel.discover_master("mymaster") + assert address == (master_ip, 6379) + # 'bar' is now first sentinel + assert sentinel.sentinels[0].id == ("bar", 26379) + + +@pytest.mark.onlynoncluster +async def test_discover_master_sentinel_timeout(cluster, sentinel, master_ip): + # Put first sentinel 'foo' down + cluster.nodes_timeout.add(("foo", 26379)) + address = await sentinel.discover_master("mymaster") + assert address == (master_ip, 6379) + # 'bar' is now first sentinel + assert sentinel.sentinels[0].id == ("bar", 26379) + + +@pytest.mark.onlynoncluster +async def test_master_min_other_sentinels(cluster, master_ip): + sentinel = Sentinel([("foo", 26379)], min_other_sentinels=1) + # min_other_sentinels + with pytest.raises(MasterNotFoundError): + await sentinel.discover_master("mymaster") + cluster.master["num-other-sentinels"] = 2 + address = await sentinel.discover_master("mymaster") + assert address == (master_ip, 6379) + + +@pytest.mark.onlynoncluster +async def test_master_odown(cluster, sentinel): + cluster.master["is_odown"] = True + with pytest.raises(MasterNotFoundError): + await sentinel.discover_master("mymaster") + + +@pytest.mark.onlynoncluster +async def test_master_sdown(cluster, sentinel): + cluster.master["is_sdown"] = True + with pytest.raises(MasterNotFoundError): + await sentinel.discover_master("mymaster") + + +@pytest.mark.onlynoncluster +async def test_discover_slaves(cluster, sentinel): + assert await sentinel.discover_slaves("mymaster") == [] + + cluster.slaves = [ + {"ip": "slave0", "port": 1234, "is_odown": False, "is_sdown": False}, + {"ip": "slave1", "port": 1234, "is_odown": False, "is_sdown": False}, + ] + assert await sentinel.discover_slaves("mymaster") == [ + ("slave0", 1234), + ("slave1", 1234), + ] + + # slave0 -> ODOWN + cluster.slaves[0]["is_odown"] = True + assert await sentinel.discover_slaves("mymaster") == [("slave1", 1234)] + + # slave1 -> SDOWN + cluster.slaves[1]["is_sdown"] = True + assert await 
sentinel.discover_slaves("mymaster") == []
+
+ cluster.slaves[0]["is_odown"] = False
+ cluster.slaves[1]["is_sdown"] = False
+
+ # node0 -> DOWN
+ cluster.nodes_down.add(("foo", 26379))
+ assert await sentinel.discover_slaves("mymaster") == [
+ ("slave0", 1234),
+ ("slave1", 1234),
+ ]
+ cluster.nodes_down.clear()
+
+ # node0 -> TIMEOUT
+ cluster.nodes_timeout.add(("foo", 26379))
+ assert await sentinel.discover_slaves("mymaster") == [
+ ("slave0", 1234),
+ ("slave1", 1234),
+ ]
+
+
+@pytest.mark.onlynoncluster
+async def test_master_for(cluster, sentinel, master_ip):
+ master = sentinel.master_for("mymaster", db=9)
+ assert await master.ping()
+ assert master.connection_pool.master_address == (master_ip, 6379)
+
+ # Use internal connection check
+ master = sentinel.master_for("mymaster", db=9, check_connection=True)
+ assert await master.ping()
+
+
+@pytest.mark.onlynoncluster
+async def test_slave_for(cluster, sentinel):
+ cluster.slaves = [
+ {"ip": "127.0.0.1", "port": 6379, "is_odown": False, "is_sdown": False},
+ ]
+ slave = sentinel.slave_for("mymaster", db=9)
+ assert await slave.ping()
+
+
+@pytest.mark.onlynoncluster
+async def test_slave_for_slave_not_found_error(cluster, sentinel):
+ cluster.master["is_odown"] = True
+ slave = sentinel.slave_for("mymaster", db=9)
+ with pytest.raises(SlaveNotFoundError):
+ await slave.ping()
+
+
+@pytest.mark.onlynoncluster
+async def test_slave_round_robin(cluster, sentinel, master_ip):
+ cluster.slaves = [
+ {"ip": "slave0", "port": 6379, "is_odown": False, "is_sdown": False},
+ {"ip": "slave1", "port": 6379, "is_odown": False, "is_sdown": False},
+ ]
+ pool = SentinelConnectionPool("mymaster", sentinel)
+ rotator = pool.rotate_slaves()
+ assert await rotator.__anext__() in (("slave0", 6379), ("slave1", 6379))
+ assert await rotator.__anext__() in (("slave0", 6379), ("slave1", 6379))
+ # Fallback to master
+ assert await rotator.__anext__() == (master_ip, 6379)
+ with pytest.raises(SlaveNotFoundError):
+ await rotator.__anext__()
+
+
+@pytest.mark.onlynoncluster
+async def test_ckquorum(cluster, sentinel):
+ assert await sentinel.sentinel_ckquorum("mymaster")
+
+
+@pytest.mark.onlynoncluster
+async def test_flushconfig(cluster, sentinel):
+ assert await sentinel.sentinel_flushconfig()
+
+
+@pytest.mark.onlynoncluster
+async def test_reset(cluster, sentinel):
+ cluster.master["is_odown"] = True
+ assert await sentinel.sentinel_reset("mymaster")
diff --git a/tox.ini b/tox.ini
index 2639cb7a2e..3ef01ddead 100644
--- a/tox.ini
+++ b/tox.ini
@@ -10,7 +10,7 @@ markers =
[tox]
minversion = 3.2.0
requires = tox-docker
-envlist = {standalone,cluster}-{plain,hiredis,ocsp}-{py36,py37,py38,py39,pypy3},linters,docs
+envlist = {standalone,cluster}-{plain,hiredis,ocsp}-{uvloop,asyncio}-{py36,py37,py38,py39,pypy3},linters,docs
[docker:master]
name = master
@@ -270,7 +270,9 @@ setenv =
CLUSTER_URL = "redis://localhost:16379/0"
commands =
standalone: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' {posargs}
+ standalone-uvloop: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' --uvloop {posargs}
cluster: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} {posargs}
+ cluster-uvloop: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} --uvloop {posargs}
[testenv:redis5]
deps =
@@ -350,6 +352,7 @@ exclude =
dist,
docker,
venv*,
+ .venv*,
whitelist.py
ignore =
F405
diff --git 
a/whitelist.py b/whitelist.py
index 891ccd6022..27210284c7 100644
--- a/whitelist.py
+++ b/whitelist.py
@@ -10,3 +10,8 @@
exc_type # unused variable (/data/repos/redis/redis-py/redis/lock.py:156)
exc_value # unused variable (/data/repos/redis/redis-py/redis/lock.py:156)
traceback # unused variable (/data/repos/redis/redis-py/redis/lock.py:156)
+exc_type # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26)
+exc_value # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26)
+traceback # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26)
+AsyncConnectionPool # unused import (/data/repos/redis/redis-py/redis/typing.py:9)
+AsyncRedis # unused import (/data/repos/redis/redis-py/redis/commands/core.py:49)
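For reference, a minimal sketch of the asyncio API exercised by the tests above, assuming a Redis server reachable at localhost:6379:

    import asyncio

    import redis.asyncio as redis


    async def main():
        r = redis.Redis()  # assumes the default localhost:6379, db 0

        # Every command is a coroutine function.
        await r.set("greeting", "hello")
        assert await r.get("greeting") == b"hello"

        # A pipeline buffers commands client-side and sends them in one round trip.
        async with r.pipeline(transaction=True) as pipe:
            pipe.set("counter", 1)
            pipe.incr("counter")
            assert await pipe.execute() == [True, 2]

        # A lock stores a unique token under its key and verifies it on release.
        async with r.lock("resource", timeout=10, blocking_timeout=5):
            pass  # critical section

        # Close the client explicitly to disconnect the connection pool.
        await r.close()


    asyncio.run(main())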