From 008da5945ce1c542f64628c984bc1397395e675f Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Fri, 4 Mar 2022 10:44:37 +0100 Subject: [PATCH 1/5] use locked version, set last_updated --- src/cryptojwt/key_bundle.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 2387f69b..dfd48524 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -252,11 +252,11 @@ def __init__( self.source = None if isinstance(keys, dict): if "keys" in keys: - self._add_jwk_dicts(keys["keys"]) + self.add_jwk_dicts(keys["keys"]) else: - self._add_jwk_dicts([keys]) + self.add_jwk_dicts([keys]) else: - self._add_jwk_dicts(keys) + self.add_jwk_dicts(keys) else: self._set_source(source, fileformat) if self.local: @@ -310,12 +310,12 @@ def add_jwk_dicts(self, keys): :return: """ self._add_jwk_dicts(keys) + self.last_updated = time.time() def _add_jwk_dicts(self, keys): _new_keys = self.jwk_dicts_as_keys(keys) if _new_keys: self._keys.extend(_new_keys) - self.last_updated = time.time() def jwk_dicts_as_keys(self, keys): """ From 99a9ce99fbf51cd64b4a945cd02609a24d5723e9 Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Fri, 4 Mar 2022 11:31:07 +0100 Subject: [PATCH 2/5] more --- src/cryptojwt/key_bundle.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index dfd48524..c9ee1347 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -252,15 +252,16 @@ def __init__( self.source = None if isinstance(keys, dict): if "keys" in keys: - self.add_jwk_dicts(keys["keys"]) + initial_keys = keys["keys"] else: - self.add_jwk_dicts([keys]) + initial_keys = [keys] else: - self.add_jwk_dicts(keys) + initial_keys = keys + self._keys = self.jwk_dicts_as_keys(initial_keys) else: self._set_source(source, fileformat) if self.local: - self._do_local(kid) + self._keys = self._do_local(kid) def _set_source(self, source, fileformat): if source.startswith("file://"): @@ -286,6 +287,7 @@ def _do_local(self, kid): self._do_local_jwk(self.source) elif self.fileformat == "der": self._do_local_der(self.source, self.keytype, self.keyusage, kid) + return self._keys def _local_update_required(self) -> bool: stat = os.stat(self.source) From 87a99087514e4dcc96b90437789e334f13e04e7b Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Fri, 4 Mar 2022 11:36:20 +0100 Subject: [PATCH 3/5] more --- src/cryptojwt/key_bundle.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index c9ee1347..cec07f86 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -284,10 +284,10 @@ def _set_source(self, source, fileformat): def _do_local(self, kid): if self.fileformat in ["jwks", "jwk"]: - self._do_local_jwk(self.source) + updated, res = self._do_local_jwk(self.source) elif self.fileformat == "der": - self._do_local_der(self.source, self.keytype, self.keyusage, kid) - return self._keys + updated, res = self._do_local_der(self.source, self.keytype, self.keyusage, kid) + return res def _local_update_required(self) -> bool: stat = os.stat(self.source) @@ -386,18 +386,19 @@ def _do_local_jwk(self, filename): :return: True if load was successful or False if file hasn't been modified """ if not self._local_update_required(): - return False + return False, None LOGGER.info("Reading local JWKS from %s", filename) with open(filename) as input_file: _info = json.load(input_file) if "keys" in _info: - self._add_jwk_dicts(_info["keys"]) + res = self.jwk_dicts_as_keys(_info["keys"]) else: - self._add_jwk_dicts([_info]) + res = self.jwk_dicts_as_keys([_info]) + self.last_local = time.time() self.time_out = self.last_local + self.cache_time - return True + return True, res def _do_local_der(self, filename, keytype, keyusage=None, kid=""): """ @@ -409,7 +410,7 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""): :return: True if load was successful or False if file hasn't been modified """ if not self._local_update_required(): - return False + return False, None LOGGER.info("Reading local DER from %s", filename) key_args = {} @@ -430,10 +431,10 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""): if kid: key_args["kid"] = kid - self._add_jwk_dicts([key_args]) + res = self.jwk_dicts_as_keys([key_args]) self.last_local = time.time() self.time_out = self.last_local + self.cache_time - return True + return True, res def _do_remote(self): """ @@ -553,9 +554,13 @@ def update(self): try: if self.local: if self.fileformat in ["jwks", "jwk"]: - updated = self._do_local_jwk(self.source) + updated, k = self._do_local_jwk(self.source) + if k: + self._keys.extend(k) elif self.fileformat == "der": - updated = self._do_local_der(self.source, self.keytype, self.keyusage) + updated, k = self._do_local_der(self.source, self.keytype, self.keyusage) + if k: + self._keys.extend(k) elif self.remote: updated = self._do_remote() except Exception as err: From 8d85731077e1f0a9ffdad1e3f4c20f9a8c75934f Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Fri, 4 Mar 2022 13:10:18 +0100 Subject: [PATCH 4/5] ensure we only update keys once --- src/cryptojwt/key_bundle.py | 66 ++++++++++++++++--------------------- tests/test_03_key_bundle.py | 27 ++++++++------- 2 files changed, 44 insertions(+), 49 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index cec07f86..4af7bfca 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -263,6 +263,7 @@ def __init__( if self.local: self._keys = self._do_local(kid) + def _set_source(self, source, fileformat): if source.startswith("file://"): self.source = source[7:] @@ -284,10 +285,10 @@ def _set_source(self, source, fileformat): def _do_local(self, kid): if self.fileformat in ["jwks", "jwk"]: - updated, res = self._do_local_jwk(self.source) + updated, keys = self._do_local_jwk(self.source) elif self.fileformat == "der": - updated, res = self._do_local_der(self.source, self.keytype, self.keyusage, kid) - return res + updated, keys = self._do_local_der(self.source, self.keytype, self.keyusage, kid) + return keys def _local_update_required(self) -> bool: stat = os.stat(self.source) @@ -311,14 +312,9 @@ def add_jwk_dicts(self, keys): :param keys: List of JWK dictionaries :return: """ - self._add_jwk_dicts(keys) + self._keys.extend(self.jwk_dicts_as_keys(keys)) self.last_updated = time.time() - def _add_jwk_dicts(self, keys): - _new_keys = self.jwk_dicts_as_keys(keys) - if _new_keys: - self._keys.extend(_new_keys) - def jwk_dicts_as_keys(self, keys): """ Return JWK dictionaries as list of JWK objects @@ -392,13 +388,13 @@ def _do_local_jwk(self, filename): with open(filename) as input_file: _info = json.load(input_file) if "keys" in _info: - res = self.jwk_dicts_as_keys(_info["keys"]) + new_keys = self.jwk_dicts_as_keys(_info["keys"]) else: - res = self.jwk_dicts_as_keys([_info]) + new_keys = self.jwk_dicts_as_keys([_info]) self.last_local = time.time() self.time_out = self.last_local + self.cache_time - return True, res + return True, new_keys def _do_local_der(self, filename, keytype, keyusage=None, kid=""): """ @@ -431,12 +427,12 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""): if kid: key_args["kid"] = kid - res = self.jwk_dicts_as_keys([key_args]) + new_keys = self.jwk_dicts_as_keys([key_args]) self.last_local = time.time() self.time_out = self.last_local + self.cache_time - return True, res + return True, new_keys - def _do_remote(self): + def _do_remote(self, set_keys=True): """ Load a JWKS from a webpage. @@ -451,7 +447,7 @@ def _do_remote(self): self.source, datetime.fromtimestamp(self.ignore_errors_until), ) - return False + return False, None LOGGER.info("Reading remote JWKS from %s", self.source) try: @@ -500,11 +496,12 @@ def _do_remote(self): self.ignore_errors_until = time.time() + self.ignore_errors_period raise UpdateFailed(REMOTE_FAILED.format(self.source, _http_resp.status_code)) - if new_keys is not None: + if set_keys and new_keys: self._keys = new_keys + self.last_updated = time.time() self.ignore_errors_until = None - return load_successful + return load_successful, new_keys def _parse_remote_response(self, response): """ @@ -545,38 +542,31 @@ def update(self): :return: True if update was ok or False if we encountered an error during update. """ if self.source: - _old_keys = self._keys # just in case - - # reread everything - self._keys = [] + new_keys = [] updated = None try: if self.local: if self.fileformat in ["jwks", "jwk"]: updated, k = self._do_local_jwk(self.source) - if k: - self._keys.extend(k) elif self.fileformat == "der": updated, k = self._do_local_der(self.source, self.keytype, self.keyusage) - if k: - self._keys.extend(k) elif self.remote: - updated = self._do_remote() + updated, k = self._do_remote(set_keys=False) + if k: + new_keys.extend(k) except Exception as err: LOGGER.error("Key bundle update failed: %s", err) - self._keys = _old_keys # restore return False if updated: now = time.time() - for _key in _old_keys: - if _key not in self._keys: + for _key in self._keys: + if _key not in new_keys: if not _key.inactive_since: # If already marked don't mess _key.inactive_since = now - self._keys.append(_key) - else: - self._keys = _old_keys + new_keys.append(_key) + self._keys = new_keys return True @@ -592,9 +582,9 @@ def get(self, typ="", only_active=True): if typ: _typs = [typ.lower(), typ.upper()] - _keys = [k for k in self._keys[:] if k.kty in _typs] + _keys = [k for k in self._keys if k.kty in _typs] else: - _keys = self._keys[:] + _keys = self._keys if only_active: return [k for k in _keys if not k.inactive_since] @@ -609,7 +599,7 @@ def keys(self, update: bool = True): """ if update: self._uptodate() - return self._keys[:] + return self._keys def active_keys(self): """Return the set of active keys.""" @@ -836,9 +826,11 @@ def load(self, spec): :param spec: Dictionary with attributes and value to populate the instance with :return: The instance itself """ + _keys = spec.get("keys", []) if _keys: - self._add_jwk_dicts(_keys) + self._keys.extend(self.jwk_dicts_as_keys(_keys)) + self.last_updated = time.time() for attr, default in self.params.items(): val = spec.get(attr) diff --git a/tests/test_03_key_bundle.py b/tests/test_03_key_bundle.py index 95f83c04..048ca958 100755 --- a/tests/test_03_key_bundle.py +++ b/tests/test_03_key_bundle.py @@ -480,7 +480,8 @@ def test_httpc_params_1(): rsps.add(method=responses.GET, url=source, json=JWKS_DICT, status=200) httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds kb = KeyBundle(source=source, httpc=requests.request, httpc_params=httpc_params) - assert kb._do_remote() + updated, _ = kb._do_remote() + assert updated == True @pytest.mark.network @@ -920,7 +921,7 @@ def test_export_inactive(): def test_remote(): - source = "https://example.com/keys.json" + source = "https://example.com/test_remote/keys.json" # Mock response with responses.RequestsMock() as rsps: rsps.add(method="GET", url=source, json=JWKS_DICT, status=200) @@ -941,7 +942,7 @@ def test_remote(): def test_remote_not_modified(): - source = "https://example.com/keys.json" + source = "https://example.com/test_remote_not_modified/keys.json" headers = { "Date": "Fri, 15 Mar 2019 10:14:25 GMT", "Last-Modified": "Fri, 1 Jan 1970 00:00:00 GMT", @@ -954,13 +955,15 @@ def test_remote_not_modified(): with responses.RequestsMock() as rsps: rsps.add(method="GET", url=source, json=JWKS_DICT, status=200, headers=headers) - assert kb._do_remote() + updated, _ = kb._do_remote() + assert updated == True assert kb.last_remote == headers.get("Last-Modified") timeout1 = kb.time_out with responses.RequestsMock() as rsps: rsps.add(method="GET", url=source, status=304, headers=headers) - assert not kb._do_remote() + updated, _ = kb._do_remote() + assert not updated assert kb.last_remote == headers.get("Last-Modified") timeout2 = kb.time_out @@ -980,8 +983,8 @@ def test_remote_not_modified(): def test_ignore_errors_period(): - source_good = "https://example.com/keys.json" - source_bad = "https://example.com/keys-bad.json" + source_good = "https://example.com/test_ignore_errors_period/keys.json" + source_bad = "https://example.com/test_ignore_errors_period/keys-bad.json" ignore_errors_period = 1 # Mock response with responses.RequestsMock() as rsps: @@ -994,19 +997,19 @@ def test_ignore_errors_period(): httpc_params=httpc_params, ignore_errors_period=ignore_errors_period, ) - res = kb._do_remote() + res, _ = kb._do_remote() assert res == True assert kb.ignore_errors_until is None # refetch, but fail by using a bad source kb.source = source_bad try: - res = kb._do_remote() + res, _ = kb._do_remote() except UpdateFailed: pass # retry should fail silently as we're in holddown - res = kb._do_remote() + res, _ = kb._do_remote() assert kb.ignore_errors_until is not None assert res == False @@ -1015,7 +1018,7 @@ def test_ignore_errors_period(): # try again kb.source = source_good - res = kb._do_remote() + res, _ = kb._do_remote() assert res == True @@ -1031,7 +1034,7 @@ def test_ignore_invalid_keys(): def test_exclude_attributes(): - source = "https://example.com/keys.json" + source = "https://example.com/test_exclude_attributes/keys.json" # Mock response with responses.RequestsMock() as rsps: rsps.add(method="GET", url=source, json=JWKS_DICT, status=200) From baacdd13c18cb84c74774f1726b97abf8a52b3a6 Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Fri, 4 Mar 2022 13:30:45 +0100 Subject: [PATCH 5/5] reformat --- src/cryptojwt/key_bundle.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 4af7bfca..475b1a6f 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -263,7 +263,6 @@ def __init__( if self.local: self._keys = self._do_local(kid) - def _set_source(self, source, fileformat): if source.startswith("file://"): self.source = source[7:] @@ -552,7 +551,7 @@ def update(self): elif self.fileformat == "der": updated, k = self._do_local_der(self.source, self.keytype, self.keyusage) elif self.remote: - updated, k = self._do_remote(set_keys=False) + updated, k = self._do_remote(set_keys=False) if k: new_keys.extend(k) except Exception as err: