diff --git a/redis/commands/core.py b/redis/commands/core.py index 2937780577..a387c51860 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -1492,16 +1492,42 @@ def exists(self, *names: KeyT) -> ResponseT: __contains__ = exists - def expire(self, name: KeyT, time: ExpiryT) -> ResponseT: + def expire( + self, + name: KeyT, + time: ExpiryT, + nx: bool = False, + xx: bool = False, + gt: bool = False, + lt: bool = False, + ) -> ResponseT: """ - Set an expire flag on key ``name`` for ``time`` seconds. ``time`` - can be represented by an integer or a Python timedelta object. + Set an expire flag on key ``name`` for ``time`` seconds with given + ``option``. ``time`` can be represented by an integer or a Python timedelta + object. + + Valid options are: + NX -> Set expiry only when the key has no expiry + XX -> Set expiry only when the key has an existing expiry + GT -> Set expiry only when the new expiry is greater than current one + LT -> Set expiry only when the new expiry is less than current one For more information check https://redis.io/commands/expire """ if isinstance(time, datetime.timedelta): time = int(time.total_seconds()) - return self.execute_command("EXPIRE", name, time) + + exp_option = list() + if nx: + exp_option.append("NX") + if xx: + exp_option.append("XX") + if gt: + exp_option.append("GT") + if lt: + exp_option.append("LT") + + return self.execute_command("EXPIRE", name, time, *exp_option) def expireat(self, name: KeyT, when: AbsExpiryT) -> ResponseT: """ diff --git a/tests/test_commands.py b/tests/test_commands.py index dc10cde165..2e1291415a 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1058,6 +1058,31 @@ def test_expire(self, r): assert r.persist("a") assert r.ttl("a") == -1 + @skip_if_server_version_lt("7.0.0") + def test_expire_option_nx(self, r): + r.set("key", "val") + assert r.expire("key", 100, nx=True) == 1 + assert r.expire("key", 500, nx=True) == 0 + + @skip_if_server_version_lt("7.0.0") + def test_expire_option_xx(self, r): + r.set("key", "val") + assert r.expire("key", 100, xx=True) == 0 + assert r.expire("key", 100) + assert r.expire("key", 500, nx=True) == 1 + + @skip_if_server_version_lt("7.0.0") + def test_expire_option_gt(self, r): + r.set("key", "val", 100) + assert r.expire("key", 50, gt=True) == 0 + assert r.expire("key", 500, gt=True) == 1 + + @skip_if_server_version_lt("7.0.0") + def test_expire_option_lt(self, r): + r.set("key", "val", 100) + assert r.expire("key", 50, lt=True) == 1 + assert r.expire("key", 150, lt=True) == 0 + def test_expireat_datetime(self, r): expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) r["a"] = "foo"