From ca5f154ef6db349eaa05c9cf60b0e8ed064370b3 Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Wed, 7 Apr 2021 10:24:42 +0200 Subject: [PATCH 1/9] add thread lock while updating the key bundle (or the bundle will be corrupt) --- src/cryptojwt/key_bundle.py | 54 +++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 34b44205..425d5e0a 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -8,6 +8,7 @@ from functools import cmp_to_key from typing import List from typing import Optional +import threading import requests @@ -507,34 +508,35 @@ def update(self): :return: True if update was ok or False if we encountered an error during update. """ if self.source: - _old_keys = self._keys # just in case + with threading.Lock(): + _old_keys = self._keys # just in case - # reread everything - self._keys = [] - updated = None + # reread everything + self._keys = [] + updated = None - try: - if self.local: - if self.fileformat in ["jwks", "jwk"]: - updated = self.do_local_jwk(self.source) - elif self.fileformat == "der": - updated = self.do_local_der(self.source, self.keytype, self.keyusage) - elif self.remote: - updated = self.do_remote() - except Exception as err: - LOGGER.error("Key bundle update failed: %s", err) - self._keys = _old_keys # restore - return False - - if updated: - now = time.time() - for _key in _old_keys: - if _key not in self._keys: - if not _key.inactive_since: # If already marked don't mess - _key.inactive_since = now - self._keys.append(_key) - else: - self._keys = _old_keys + try: + if self.local: + if self.fileformat in ["jwks", "jwk"]: + updated = self.do_local_jwk(self.source) + elif self.fileformat == "der": + updated = self.do_local_der(self.source, self.keytype, self.keyusage) + elif self.remote: + updated = self.do_remote() + except Exception as err: + LOGGER.error("Key bundle update failed: %s", err) + self._keys = _old_keys # restore + return False + + if updated: + now = time.time() + for _key in _old_keys: + if _key not in self._keys: + if not _key.inactive_since: # If already marked don't mess + _key.inactive_since = now + self._keys.append(_key) + else: + self._keys = _old_keys return True From ed139d1d1705da412601a09b899abbaac8e971de Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Wed, 7 Apr 2021 10:27:30 +0200 Subject: [PATCH 2/9] isort --- src/cryptojwt/key_bundle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 425d5e0a..2527e162 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -3,12 +3,12 @@ import json import logging import os +import threading import time from datetime import datetime from functools import cmp_to_key from typing import List from typing import Optional -import threading import requests From 27aba0370bb7744199d173595266a9be2e3572c7 Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Wed, 7 Apr 2021 15:57:28 +0200 Subject: [PATCH 3/9] bump --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 01b995c3..1dc63136 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ exclude_lines = [ [tool.poetry] name = "cryptojwt" -version = "1.5.0" +version = "1.5.1" description = "Python implementation of JWT, JWE, JWS and JWK" authors = ["Roland Hedberg "] license = "Apache-2.0" From c9e59bc3d606d4ef421417e3322b0d9801657c90 Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Wed, 7 Apr 2021 16:08:54 +0200 Subject: [PATCH 4/9] use global lock for update, ok from @janste63 --- src/cryptojwt/key_bundle.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 2527e162..4750a078 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -47,6 +47,8 @@ MAP = {"dec": "enc", "enc": "enc", "ver": "sig", "sig": "sig"} +update_lock = threading.Lock() + def harmonize_usage(use): """ @@ -508,7 +510,7 @@ def update(self): :return: True if update was ok or False if we encountered an error during update. """ if self.source: - with threading.Lock(): + with update_lock: _old_keys = self._keys # just in case # reread everything From 9656a5586b3d98dc17834a08f073c8386b40c968 Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Wed, 7 Apr 2021 18:08:44 +0200 Subject: [PATCH 5/9] first cut at proper read/write locking for KeyBundle --- poetry.lock | 19 +++++- pyproject.toml | 1 + src/cryptojwt/key_bundle.py | 123 +++++++++++++++++++++--------------- 3 files changed, 90 insertions(+), 53 deletions(-) diff --git a/poetry.lock b/poetry.lock index d1495f92..2afcfa31 100644 --- a/poetry.lock +++ b/poetry.lock @@ -403,6 +403,17 @@ category = "dev" optional = false python-versions = "*" +[[package]] +name = "readerwriterlock" +version = "1.0.8" +description = "A python implementation of the three Reader-Writer problems." +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +typing-extensions = "*" + [[package]] name = "regex" version = "2020.11.13" @@ -607,7 +618,7 @@ python-versions = "*" name = "typing-extensions" version = "3.7.4.3" description = "Backported and Experimental Type Hints for Python 3.5+" -category = "dev" +category = "main" optional = false python-versions = "*" @@ -639,7 +650,7 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pyt [metadata] lock-version = "1.1" python-versions = "^3.6" -content-hash = "a39019ac3af30b197f08099583fc06c61a9077638c3e351e34afd9ed019ef352" +content-hash = "27c1fe3ba4fef2096f46f02e72a8b80b0d972f78ac6739d29af5d03609c49c52" [metadata.files] alabaster = [ @@ -909,6 +920,10 @@ pytz = [ {file = "pytz-2021.1-py2.py3-none-any.whl", hash = "sha256:eb10ce3e7736052ed3623d49975ce333bcd712c7bb19a58b9e2089d4057d0798"}, {file = "pytz-2021.1.tar.gz", hash = "sha256:83a4a90894bf38e243cf052c8b58f381bfe9a7a483f6a9cab140bc7f702ac4da"}, ] +readerwriterlock = [ + {file = "readerwriterlock-1.0.8-py3-none-any.whl", hash = "sha256:35c6277a4fdc23b7449025f708578a5069698bc7eea63ce7c6c50de175c71435"}, + {file = "readerwriterlock-1.0.8.tar.gz", hash = "sha256:b806126d0c5ca90e84eb6f73a2bf9761df6c0b3184b9063bc229078cdf1464a7"}, +] regex = [ {file = "regex-2020.11.13-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:8b882a78c320478b12ff024e81dc7d43c1462aa4a3341c754ee65d857a521f85"}, {file = "regex-2020.11.13-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:a63f1a07932c9686d2d416fb295ec2c01ab246e89b4d58e5fa468089cab44b70"}, diff --git a/pyproject.toml b/pyproject.toml index 1dc63136..b8e8c9dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ jwtpeek = "cryptojwt.tools.jwtpeek:main" python = "^3.6" cryptography = "^3.4.6" requests = "^2.25.1" +readerwriterlock = "^1.0.8" [tool.poetry.dev-dependencies] alabaster = "^0.7.12" diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 4750a078..88e1b53a 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -3,7 +3,6 @@ import json import logging import os -import threading import time from datetime import datetime from functools import cmp_to_key @@ -11,6 +10,7 @@ from typing import Optional import requests +from readerwriterlock import rwlock from cryptojwt.jwk.ec import NIST2SEC from cryptojwt.jwk.hmac import new_sym_key @@ -47,8 +47,6 @@ MAP = {"dec": "enc", "enc": "enc", "ver": "sig", "sig": "sig"} -update_lock = threading.Lock() - def harmonize_usage(use): """ @@ -153,6 +151,14 @@ def ec_init(spec): return _kb +def keys_writer(func): + def wrapper(self, *args, **kwargs): + with self._lock_writer: + return func(self, *args, **kwargs) + + return wrapper + + class KeyBundle: """The Key Bundle""" @@ -230,6 +236,10 @@ def __init__( self.source = None self.time_out = 0 + self._lock = rwlock.RWLockFairD() + self._lock_reader = self._lock.gen_rlock() + self._lock_writer = self._lock.gen_wlock() + if httpc: self.httpc = httpc else: @@ -500,6 +510,7 @@ def _uptodate(self): return self.update() return False + @keys_writer def update(self): """ Reload the keys if necessary. @@ -510,35 +521,34 @@ def update(self): :return: True if update was ok or False if we encountered an error during update. """ if self.source: - with update_lock: - _old_keys = self._keys # just in case + _old_keys = self._keys # just in case - # reread everything - self._keys = [] - updated = None + # reread everything + self._keys = [] + updated = None - try: - if self.local: - if self.fileformat in ["jwks", "jwk"]: - updated = self.do_local_jwk(self.source) - elif self.fileformat == "der": - updated = self.do_local_der(self.source, self.keytype, self.keyusage) - elif self.remote: - updated = self.do_remote() - except Exception as err: - LOGGER.error("Key bundle update failed: %s", err) - self._keys = _old_keys # restore - return False - - if updated: - now = time.time() - for _key in _old_keys: - if _key not in self._keys: - if not _key.inactive_since: # If already marked don't mess - _key.inactive_since = now - self._keys.append(_key) - else: - self._keys = _old_keys + try: + if self.local: + if self.fileformat in ["jwks", "jwk"]: + updated = self.do_local_jwk(self.source) + elif self.fileformat == "der": + updated = self.do_local_der(self.source, self.keytype, self.keyusage) + elif self.remote: + updated = self.do_remote() + except Exception as err: + LOGGER.error("Key bundle update failed: %s", err) + self._keys = _old_keys # restore + return False + + if updated: + now = time.time() + for _key in _old_keys: + if _key not in self._keys: + if not _key.inactive_since: # If already marked don't mess + _key.inactive_since = now + self._keys.append(_key) + else: + self._keys = _old_keys return True @@ -551,32 +561,34 @@ def get(self, typ="", only_active=True): otherwise the appropriate keys in a list """ self._uptodate() - _typs = [typ.lower(), typ.upper()] - if typ: - _keys = [k for k in self._keys if k.kty in _typs] - else: - _keys = self._keys + with self._lock_reader: + if typ: + _typs = [typ.lower(), typ.upper()] + _keys = [k for k in self._keys if k.kty in _typs] + else: + _keys = self._keys if only_active: return [k for k in _keys if not k.inactive_since] return _keys - def keys(self): + def keys(self, update: bool = True): """ Return all keys after having updated them :return: List of all keys """ - self._uptodate() - - return self._keys + if update: + self._uptodate() + with self._lock_reader: + return self._keys def active_keys(self): """Return the set of active keys.""" _res = [] - for k in self._keys: + for k in self.keys(): try: ias = k.inactive_since except ValueError: @@ -586,6 +598,7 @@ def active_keys(self): _res.append(k) return _res + @keys_writer def remove_keys_by_type(self, typ): """ Remove keys that are of a specific type. @@ -605,9 +618,8 @@ def jwks(self, private=False): :param private: Whether private key information should be included. :return: A JWKS JSON representation of the keys in this bundle """ - self._uptodate() keys = list() - for k in self._keys: + for k in self.keys(): if private: key = k.serialize(private) else: @@ -617,6 +629,7 @@ def jwks(self, private=False): keys.append(key) return json.dumps({"keys": keys}) + @keys_writer def append(self, key): """ Add a key to list of keys in this bundle @@ -625,10 +638,12 @@ def append(self, key): """ self._keys.append(key) + @keys_writer def extend(self, keys): """Add a key to the list of keys.""" self._keys.extend(keys) + @keys_writer def remove(self, key): """ Remove a specific key from this bundle @@ -648,6 +663,7 @@ def __len__(self): """ return len(self._keys) + @keys_writer def set(self, keys): """Set the keys to the set provided.""" self._keys = keys @@ -659,13 +675,15 @@ def get_key_with_kid(self, kid): :param kid: The Key ID :return: The key or None """ + self._uptodate() + with self._lock_reader: + return self._get_key_with_kid(kid) + + def _get_key_with_kid(self, kid): for key in self._keys: if key.kid == kid: return key - # Try updating since there might have been an update to the key file - self.update() - for key in self._keys: if key.kid == kid: return key @@ -680,16 +698,16 @@ def kids(self): The reason might be that there are some keys with no key ID. :return: A list of all the key IDs that exists in this bundle """ - self._uptodate() - return [key.kid for key in self._keys if key.kid != ""] + return [key.kid for key in self.keys() if key.kid != ""] + @keys_writer def mark_as_inactive(self, kid): """ Mark a specific key as inactive based on the keys KeyID. :param kid: The Key Identifier """ - k = self.get_key_with_kid(kid) + k = self._get_key_with_kid(kid) if k: self._keys.remove(k) k.inactive_since = time.time() @@ -698,17 +716,19 @@ def mark_as_inactive(self, kid): else: return False + @keys_writer def mark_all_as_inactive(self): """ Mark a specific key as inactive based on the keys KeyID. """ - _keys = self.keys() + _keys = self._keys _updated = [] for k in _keys: k.inactive_since = time.time() _updated.append(k) self._keys = _updated + @keys_writer def remove_outdated(self, after, when=0): """ Remove keys that should not be available any more. @@ -775,7 +795,7 @@ def difference(self, bundle): if not isinstance(bundle, KeyBundle): return ValueError("Not a KeyBundle instance") - return [k for k in self._keys if k not in bundle] + return [k for k in self.keys() if k not in bundle] def dump(self, exclude_attributes: Optional[List[str]] = None): if exclude_attributes is None: @@ -785,7 +805,7 @@ def dump(self, exclude_attributes: Optional[List[str]] = None): if "keys" not in exclude_attributes: _keys = [] - for _k in self._keys: + for _k in self.keys(update=False): _ser = _k.to_dict() if _k.inactive_since: _ser["inactive_since"] = _k.inactive_since @@ -819,6 +839,7 @@ def load(self, spec): return self + @keys_writer def flush(self): self._keys = [] self.cache_time = (300,) From caead3fcff17264b6fbe8fd9d282788c6d42ed09 Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Mon, 12 Apr 2021 16:34:07 +0200 Subject: [PATCH 6/9] more lock fixes --- src/cryptojwt/key_bundle.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 88e1b53a..588cf48a 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -151,6 +151,14 @@ def ec_init(spec): return _kb +def keys_reader(func): + def wrapper(self, *args, **kwargs): + with self._lock_reader: + return func(self, *args, **kwargs) + + return wrapper + + def keys_writer(func): def wrapper(self, *args, **kwargs): with self._lock_writer: @@ -655,6 +663,7 @@ def remove(self, key): except ValueError: pass + @keys_reader def __len__(self): """ The number of keys. @@ -760,8 +769,9 @@ def remove_outdated(self, after, when=0): return changed def __contains__(self, key): - return key in self._keys + return key in self.keys() + @keys_reader def copy(self): """ Make deep copy of this KeyBundle @@ -782,6 +792,7 @@ def copy(self): return _bundle + @keys_reader def __iter__(self): return self._keys.__iter__() From 670d0bc087c9adfa4ad945f68f80921202d5e138 Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Mon, 12 Apr 2021 16:38:46 +0200 Subject: [PATCH 7/9] use copies --- src/cryptojwt/key_bundle.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 588cf48a..5c84f2a8 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -575,7 +575,7 @@ def get(self, typ="", only_active=True): _typs = [typ.lower(), typ.upper()] _keys = [k for k in self._keys if k.kty in _typs] else: - _keys = self._keys + _keys = copy.copy(self._keys) if only_active: return [k for k in _keys if not k.inactive_since] @@ -591,7 +591,7 @@ def keys(self, update: bool = True): if update: self._uptodate() with self._lock_reader: - return self._keys + return copy.copy(self._keys) def active_keys(self): """Return the set of active keys.""" @@ -792,9 +792,8 @@ def copy(self): return _bundle - @keys_reader def __iter__(self): - return self._keys.__iter__() + return self.keys().__iter__() def difference(self, bundle): """ From 2f2e8cb3bac07ff8edd3251e31a8284ed0da4eab Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Mon, 12 Apr 2021 16:53:30 +0200 Subject: [PATCH 8/9] - mark do_local_der and do_local_jwk as private - create public version of do_keys (with write lock) --- src/cryptojwt/key_bundle.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 5c84f2a8..fdb19b4b 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -259,11 +259,11 @@ def __init__( self.source = None if isinstance(keys, dict): if "keys" in keys: - self.do_keys(keys["keys"]) + self._do_keys(keys["keys"]) else: - self.do_keys([keys]) + self._do_keys([keys]) else: - self.do_keys(keys) + self._do_keys(keys) else: self._set_source(source, fileformat) if self.local: @@ -290,9 +290,9 @@ def _set_source(self, source, fileformat): def _do_local(self, kid): if self.fileformat in ["jwks", "jwk"]: - self.do_local_jwk(self.source) + self._do_local_jwk(self.source) elif self.fileformat == "der": - self.do_local_der(self.source, self.keytype, self.keyusage, kid) + self._do_local_der(self.source, self.keytype, self.keyusage, kid) def _local_update_required(self) -> bool: stat = os.stat(self.source) @@ -304,7 +304,11 @@ def _local_update_required(self) -> bool: self.last_local = stat.st_mtime return True + @keys_writer def do_keys(self, keys): + return self._do_keys(keys) + + def _do_keys(self, keys): """ Go from JWK description to binary keys @@ -366,7 +370,7 @@ def do_keys(self, keys): self.last_updated = time.time() - def do_local_jwk(self, filename): + def _do_local_jwk(self, filename): """ Load a JWKS from a local file @@ -380,14 +384,14 @@ def do_local_jwk(self, filename): with open(filename) as input_file: _info = json.load(input_file) if "keys" in _info: - self.do_keys(_info["keys"]) + self._do_keys(_info["keys"]) else: - self.do_keys([_info]) + self._do_keys([_info]) self.last_local = time.time() self.time_out = self.last_local + self.cache_time return True - def do_local_der(self, filename, keytype, keyusage=None, kid=""): + def _do_local_der(self, filename, keytype, keyusage=None, kid=""): """ Load a DER encoded file amd create a key from it. @@ -418,7 +422,7 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""): if kid: key_args["kid"] = kid - self.do_keys([key_args]) + self._do_keys([key_args]) self.last_local = time.time() self.time_out = self.last_local + self.cache_time return True @@ -465,7 +469,7 @@ def do_remote(self): LOGGER.debug("Loaded JWKS: %s from %s", _http_resp.text, self.source) try: - self.do_keys(self.imp_jwks["keys"]) + self._do_keys(self.imp_jwks["keys"]) except KeyError: LOGGER.error("No 'keys' keyword in JWKS") self.ignore_errors_until = time.time() + self.ignore_errors_period @@ -538,9 +542,9 @@ def update(self): try: if self.local: if self.fileformat in ["jwks", "jwk"]: - updated = self.do_local_jwk(self.source) + updated = self._do_local_jwk(self.source) elif self.fileformat == "der": - updated = self.do_local_der(self.source, self.keytype, self.keyusage) + updated = self._do_local_der(self.source, self.keytype, self.keyusage) elif self.remote: updated = self.do_remote() except Exception as err: @@ -840,7 +844,7 @@ def load(self, spec): """ _keys = spec.get("keys", []) if _keys: - self.do_keys(_keys) + self._do_keys(_keys) for attr, default in self.params.items(): val = spec.get(attr) From cf76011439f213c365dd18403d5dc443d9260483 Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Mon, 12 Apr 2021 16:54:44 +0200 Subject: [PATCH 9/9] address more comments from @janste63 --- src/cryptojwt/key_bundle.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index fdb19b4b..4a327471 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -834,6 +834,7 @@ def dump(self, exclude_attributes: Optional[List[str]] = None): return res + @keys_writer def load(self, spec): """ Sets attributes according to a specification.