From cfde0f2c9666113e35dad2e34c0ac0b1d361a342 Mon Sep 17 00:00:00 2001 From: Anas Date: Sun, 24 Apr 2022 23:44:17 +0300 Subject: [PATCH 1/3] Made sync lock consistent and added types to it --- redis/client.py | 9 ++++++ redis/lock.py | 72 +++++++++++++++++++++++++++++----------------- redis/typing.py | 2 +- tests/test_lock.py | 20 +++++++++++++ 4 files changed, 76 insertions(+), 27 deletions(-) diff --git a/redis/client.py b/redis/client.py index d8d7a75ce0..a9201cf559 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1089,6 +1089,7 @@ def lock( name, timeout=None, sleep=0.1, + blocking=True, blocking_timeout=None, lock_class=None, thread_local=True, @@ -1104,6 +1105,13 @@ def lock( when the lock is in blocking mode and another client is currently holding the lock. + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to spend trying to acquire the lock. A value of ``None`` indicates continue trying forever. ``blocking_timeout`` can be specified as a @@ -1146,6 +1154,7 @@ def lock( name, timeout=timeout, sleep=sleep, + blocking=blocking, blocking_timeout=blocking_timeout, thread_local=thread_local, ) diff --git a/redis/lock.py b/redis/lock.py index 74e769bfea..2a79941cd2 100644 --- a/redis/lock.py +++ b/redis/lock.py @@ -1,9 +1,13 @@ +from __future__ import annotations + import threading import time as mod_time import uuid -from types import SimpleNamespace +from types import SimpleNamespace, TracebackType +from typing import Optional, Type from redis.exceptions import LockError, LockNotOwnedError +from redis.typing import Number class Lock: @@ -73,13 +77,14 @@ class Lock: def __init__( self, - redis, - name, - timeout=None, - sleep=0.1, - blocking=True, - blocking_timeout=None, - thread_local=True, + redis: "Redis", + name: str, + *, + timeout: Optional[Number] = None, + sleep: Number = 0.1, + blocking: bool = True, + blocking_timeout: Optional[Number] = None, + thread_local: bool = True, ): """ Create a new Lock instance named ``name`` using the Redis client @@ -142,7 +147,7 @@ def __init__( self.local.token = None self.register_scripts() - def register_scripts(self): + def register_scripts(self) -> None: cls = self.__class__ client = self.redis if cls.lua_release is None: @@ -152,15 +157,27 @@ def register_scripts(self): if cls.lua_reacquire is None: cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT) - def __enter__(self): + def __enter__(self) -> "Lock": if self.acquire(): return self raise LockError("Unable to acquire lock within the time specified") - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType] + ) -> None: self.release() - def acquire(self, blocking=None, blocking_timeout=None, token=None): + def acquire( + self, + *, + sleep: Optional[Number] = None, + blocking: Optional[bool] = None, + blocking_timeout: Optional[Number] = None, + token: Optional[str] = None + ): """ Use Redis to hold a shared, distributed lock named ``name``. Returns True once the lock is acquired. @@ -176,7 +193,8 @@ def acquire(self, blocking=None, blocking_timeout=None, token=None): object with the default encoding. If a token isn't specified, a UUID will be generated. """ - sleep = self.sleep + if sleep is None: + sleep = self.sleep if token is None: token = uuid.uuid1().hex.encode() else: @@ -200,7 +218,7 @@ def acquire(self, blocking=None, blocking_timeout=None, token=None): return False mod_time.sleep(sleep) - def do_acquire(self, token): + def do_acquire(self, token: str) -> bool: if self.timeout: # convert to milliseconds timeout = int(self.timeout * 1000) @@ -210,13 +228,13 @@ def do_acquire(self, token): return True return False - def locked(self): + def locked(self) -> bool: """ Returns True if this key is locked by any process, otherwise False. """ return self.redis.get(self.name) is not None - def owned(self): + def owned(self) -> bool: """ Returns True if this key is locked by this lock, otherwise False. """ @@ -228,21 +246,23 @@ def owned(self): stored_token = encoder.encode(stored_token) return self.local.token is not None and stored_token == self.local.token - def release(self): - "Releases the already acquired lock" + def release(self) -> None: + """ + Releases the already acquired lock + """ expected_token = self.local.token if expected_token is None: raise LockError("Cannot release an unlocked lock") self.local.token = None self.do_release(expected_token) - def do_release(self, expected_token): + def do_release(self, expected_token: str) -> None: if not bool( self.lua_release(keys=[self.name], args=[expected_token], client=self.redis) ): raise LockNotOwnedError("Cannot release a lock" " that's no longer owned") - def extend(self, additional_time, replace_ttl=False): + def extend(self, additional_time: int, replace_ttl: bool = False) -> bool: """ Adds more time to an already acquired lock. @@ -259,19 +279,19 @@ def extend(self, additional_time, replace_ttl=False): raise LockError("Cannot extend a lock with no timeout") return self.do_extend(additional_time, replace_ttl) - def do_extend(self, additional_time, replace_ttl): + def do_extend(self, additional_time: int, replace_ttl: bool) -> bool: additional_time = int(additional_time * 1000) if not bool( self.lua_extend( keys=[self.name], - args=[self.local.token, additional_time, replace_ttl and "1" or "0"], + args=[self.local.token, additional_time, "1" if replace_ttl else "0"], client=self.redis, ) ): - raise LockNotOwnedError("Cannot extend a lock that's" " no longer owned") + raise LockNotOwnedError("Cannot extend a lock that's no longer owned") return True - def reacquire(self): + def reacquire(self) -> bool: """ Resets a TTL of an already acquired lock back to a timeout value. """ @@ -281,12 +301,12 @@ def reacquire(self): raise LockError("Cannot reacquire a lock with no timeout") return self.do_reacquire() - def do_reacquire(self): + def do_reacquire(self) -> bool: timeout = int(self.timeout * 1000) if not bool( self.lua_reacquire( keys=[self.name], args=[self.local.token, timeout], client=self.redis ) ): - raise LockNotOwnedError("Cannot reacquire a lock that's" " no longer owned") + raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned") return True diff --git a/redis/typing.py b/redis/typing.py index 73ae411f4d..ee3abfa84a 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -10,6 +10,7 @@ from redis.connection import ConnectionPool +Number = Union[int, float] EncodedT = Union[bytes, memoryview] DecodedT = Union[str, int, float] EncodableT = Union[EncodedT, DecodedT] @@ -37,7 +38,6 @@ AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview) - class CommandsProtocol(Protocol): connection_pool: Union["AsyncConnectionPool", "ConnectionPool"] diff --git a/tests/test_lock.py b/tests/test_lock.py index 0a63f1e06c..176d6b569a 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -116,6 +116,16 @@ def test_context_manager(self, r): assert r.get("foo") == lock.local.token assert r.get("foo") is None + def test_context_manager_blocking_timeout(self, r): + with self.get_lock(r, "foo", blocking=False) as lock1: + bt = 0.4 + sleep = 0.05 + lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt) + start = time.monotonic() + assert not lock2.acquire() + # The elapsed duration should be less than the total blocking_timeout + assert bt > (time.monotonic() - start) > bt - sleep + def test_context_manager_raises_when_locked_not_acquired(self, r): r.set("foo", "bar") with pytest.raises(LockError): @@ -221,6 +231,16 @@ def test_reacquiring_lock_no_longer_owned_raises_error(self, r): with pytest.raises(LockNotOwnedError): lock.reacquire() + def test_context_manager_reacquiring_lock_with_no_timeout_raises_error(self, r): + with self.get_lock(r, "foo", timeout=None, blocking=False) as lock: + with pytest.raises(LockError): + lock.reacquire() + + def test_context_manager_reacquiring_lock_no_longer_owned_raises_error(self, r): + with pytest.raises(LockNotOwnedError): + with self.get_lock(r, "foo", timeout=10, blocking=False): + r.set("foo", "a") + class TestLockClassSelection: def test_lock_class_argument(self, r): From f347e06783b5f4af250324fd0252010df7bd6b9b Mon Sep 17 00:00:00 2001 From: Anas Date: Sun, 24 Apr 2022 23:59:36 +0300 Subject: [PATCH 2/3] Made linters happy --- redis/lock.py | 8 +++----- redis/typing.py | 1 + tests/test_lock.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/redis/lock.py b/redis/lock.py index 2a79941cd2..c509f7d9db 100644 --- a/redis/lock.py +++ b/redis/lock.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import threading import time as mod_time import uuid @@ -77,7 +75,7 @@ class Lock: def __init__( self, - redis: "Redis", + redis, name: str, *, timeout: Optional[Number] = None, @@ -166,7 +164,7 @@ def __exit__( self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], - traceback: Optional[TracebackType] + traceback: Optional[TracebackType], ) -> None: self.release() @@ -176,7 +174,7 @@ def acquire( sleep: Optional[Number] = None, blocking: Optional[bool] = None, blocking_timeout: Optional[Number] = None, - token: Optional[str] = None + token: Optional[str] = None, ): """ Use Redis to hold a shared, distributed lock named ``name``. diff --git a/redis/typing.py b/redis/typing.py index ee3abfa84a..59b255071c 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -38,6 +38,7 @@ AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview) + class CommandsProtocol(Protocol): connection_pool: Union["AsyncConnectionPool", "ConnectionPool"] diff --git a/tests/test_lock.py b/tests/test_lock.py index 176d6b569a..10ad7e1539 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -117,7 +117,7 @@ def test_context_manager(self, r): assert r.get("foo") is None def test_context_manager_blocking_timeout(self, r): - with self.get_lock(r, "foo", blocking=False) as lock1: + with self.get_lock(r, "foo", blocking=False): bt = 0.4 sleep = 0.05 lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt) From 80c47b6cc92aa5f40f75f13b1ec26eb90e2d76a2 Mon Sep 17 00:00:00 2001 From: Anas Date: Mon, 25 Apr 2022 00:11:02 +0300 Subject: [PATCH 3/3] Fixed cluster client lock signature --- redis/client.py | 1 - redis/cluster.py | 8 ++++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/redis/client.py b/redis/client.py index a9201cf559..baf15dddb6 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1111,7 +1111,6 @@ def lock( Note this value can be overridden by passing a ``blocking`` argument to ``acquire``. - ``blocking_timeout`` indicates the maximum amount of time in seconds to spend trying to acquire the lock. A value of ``None`` indicates continue trying forever. ``blocking_timeout`` can be specified as a diff --git a/redis/cluster.py b/redis/cluster.py index 221df856c1..e6938dc62d 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -801,6 +801,7 @@ def lock( name, timeout=None, sleep=0.1, + blocking=True, blocking_timeout=None, lock_class=None, thread_local=True, @@ -816,6 +817,12 @@ def lock( when the lock is in blocking mode and another client is currently holding the lock. + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + ``blocking_timeout`` indicates the maximum amount of time in seconds to spend trying to acquire the lock. A value of ``None`` indicates continue trying forever. ``blocking_timeout`` can be specified as a @@ -858,6 +865,7 @@ def lock( name, timeout=timeout, sleep=sleep, + blocking=blocking, blocking_timeout=blocking_timeout, thread_local=thread_local, )