diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 2387f69b..475b1a6f 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -252,15 +252,16 @@ def __init__( self.source = None if isinstance(keys, dict): if "keys" in keys: - self._add_jwk_dicts(keys["keys"]) + initial_keys = keys["keys"] else: - self._add_jwk_dicts([keys]) + initial_keys = [keys] else: - self._add_jwk_dicts(keys) + initial_keys = keys + self._keys = self.jwk_dicts_as_keys(initial_keys) else: self._set_source(source, fileformat) if self.local: - self._do_local(kid) + self._keys = self._do_local(kid) def _set_source(self, source, fileformat): if source.startswith("file://"): @@ -283,9 +284,10 @@ def _set_source(self, source, fileformat): def _do_local(self, kid): if self.fileformat in ["jwks", "jwk"]: - self._do_local_jwk(self.source) + updated, keys = self._do_local_jwk(self.source) elif self.fileformat == "der": - self._do_local_der(self.source, self.keytype, self.keyusage, kid) + updated, keys = self._do_local_der(self.source, self.keytype, self.keyusage, kid) + return keys def _local_update_required(self) -> bool: stat = os.stat(self.source) @@ -309,13 +311,8 @@ def add_jwk_dicts(self, keys): :param keys: List of JWK dictionaries :return: """ - self._add_jwk_dicts(keys) - - def _add_jwk_dicts(self, keys): - _new_keys = self.jwk_dicts_as_keys(keys) - if _new_keys: - self._keys.extend(_new_keys) - self.last_updated = time.time() + self._keys.extend(self.jwk_dicts_as_keys(keys)) + self.last_updated = time.time() def jwk_dicts_as_keys(self, keys): """ @@ -384,18 +381,19 @@ def _do_local_jwk(self, filename): :return: True if load was successful or False if file hasn't been modified """ if not self._local_update_required(): - return False + return False, None LOGGER.info("Reading local JWKS from %s", filename) with open(filename) as input_file: _info = json.load(input_file) if "keys" in _info: - self._add_jwk_dicts(_info["keys"]) + new_keys = self.jwk_dicts_as_keys(_info["keys"]) else: - self._add_jwk_dicts([_info]) + new_keys = self.jwk_dicts_as_keys([_info]) + self.last_local = time.time() self.time_out = self.last_local + self.cache_time - return True + return True, new_keys def _do_local_der(self, filename, keytype, keyusage=None, kid=""): """ @@ -407,7 +405,7 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""): :return: True if load was successful or False if file hasn't been modified """ if not self._local_update_required(): - return False + return False, None LOGGER.info("Reading local DER from %s", filename) key_args = {} @@ -428,12 +426,12 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""): if kid: key_args["kid"] = kid - self._add_jwk_dicts([key_args]) + new_keys = self.jwk_dicts_as_keys([key_args]) self.last_local = time.time() self.time_out = self.last_local + self.cache_time - return True + return True, new_keys - def _do_remote(self): + def _do_remote(self, set_keys=True): """ Load a JWKS from a webpage. @@ -448,7 +446,7 @@ def _do_remote(self): self.source, datetime.fromtimestamp(self.ignore_errors_until), ) - return False + return False, None LOGGER.info("Reading remote JWKS from %s", self.source) try: @@ -497,11 +495,12 @@ def _do_remote(self): self.ignore_errors_until = time.time() + self.ignore_errors_period raise UpdateFailed(REMOTE_FAILED.format(self.source, _http_resp.status_code)) - if new_keys is not None: + if set_keys and new_keys: self._keys = new_keys + self.last_updated = time.time() self.ignore_errors_until = None - return load_successful + return load_successful, new_keys def _parse_remote_response(self, response): """ @@ -542,34 +541,31 @@ def update(self): :return: True if update was ok or False if we encountered an error during update. """ if self.source: - _old_keys = self._keys # just in case - - # reread everything - self._keys = [] + new_keys = [] updated = None try: if self.local: if self.fileformat in ["jwks", "jwk"]: - updated = self._do_local_jwk(self.source) + updated, k = self._do_local_jwk(self.source) elif self.fileformat == "der": - updated = self._do_local_der(self.source, self.keytype, self.keyusage) + updated, k = self._do_local_der(self.source, self.keytype, self.keyusage) elif self.remote: - updated = self._do_remote() + updated, k = self._do_remote(set_keys=False) + if k: + new_keys.extend(k) except Exception as err: LOGGER.error("Key bundle update failed: %s", err) - self._keys = _old_keys # restore return False if updated: now = time.time() - for _key in _old_keys: - if _key not in self._keys: + for _key in self._keys: + if _key not in new_keys: if not _key.inactive_since: # If already marked don't mess _key.inactive_since = now - self._keys.append(_key) - else: - self._keys = _old_keys + new_keys.append(_key) + self._keys = new_keys return True @@ -585,9 +581,9 @@ def get(self, typ="", only_active=True): if typ: _typs = [typ.lower(), typ.upper()] - _keys = [k for k in self._keys[:] if k.kty in _typs] + _keys = [k for k in self._keys if k.kty in _typs] else: - _keys = self._keys[:] + _keys = self._keys if only_active: return [k for k in _keys if not k.inactive_since] @@ -602,7 +598,7 @@ def keys(self, update: bool = True): """ if update: self._uptodate() - return self._keys[:] + return self._keys def active_keys(self): """Return the set of active keys.""" @@ -829,9 +825,11 @@ def load(self, spec): :param spec: Dictionary with attributes and value to populate the instance with :return: The instance itself """ + _keys = spec.get("keys", []) if _keys: - self._add_jwk_dicts(_keys) + self._keys.extend(self.jwk_dicts_as_keys(_keys)) + self.last_updated = time.time() for attr, default in self.params.items(): val = spec.get(attr) diff --git a/tests/test_03_key_bundle.py b/tests/test_03_key_bundle.py index 95f83c04..048ca958 100755 --- a/tests/test_03_key_bundle.py +++ b/tests/test_03_key_bundle.py @@ -480,7 +480,8 @@ def test_httpc_params_1(): rsps.add(method=responses.GET, url=source, json=JWKS_DICT, status=200) httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds kb = KeyBundle(source=source, httpc=requests.request, httpc_params=httpc_params) - assert kb._do_remote() + updated, _ = kb._do_remote() + assert updated == True @pytest.mark.network @@ -920,7 +921,7 @@ def test_export_inactive(): def test_remote(): - source = "https://example.com/keys.json" + source = "https://example.com/test_remote/keys.json" # Mock response with responses.RequestsMock() as rsps: rsps.add(method="GET", url=source, json=JWKS_DICT, status=200) @@ -941,7 +942,7 @@ def test_remote(): def test_remote_not_modified(): - source = "https://example.com/keys.json" + source = "https://example.com/test_remote_not_modified/keys.json" headers = { "Date": "Fri, 15 Mar 2019 10:14:25 GMT", "Last-Modified": "Fri, 1 Jan 1970 00:00:00 GMT", @@ -954,13 +955,15 @@ def test_remote_not_modified(): with responses.RequestsMock() as rsps: rsps.add(method="GET", url=source, json=JWKS_DICT, status=200, headers=headers) - assert kb._do_remote() + updated, _ = kb._do_remote() + assert updated == True assert kb.last_remote == headers.get("Last-Modified") timeout1 = kb.time_out with responses.RequestsMock() as rsps: rsps.add(method="GET", url=source, status=304, headers=headers) - assert not kb._do_remote() + updated, _ = kb._do_remote() + assert not updated assert kb.last_remote == headers.get("Last-Modified") timeout2 = kb.time_out @@ -980,8 +983,8 @@ def test_remote_not_modified(): def test_ignore_errors_period(): - source_good = "https://example.com/keys.json" - source_bad = "https://example.com/keys-bad.json" + source_good = "https://example.com/test_ignore_errors_period/keys.json" + source_bad = "https://example.com/test_ignore_errors_period/keys-bad.json" ignore_errors_period = 1 # Mock response with responses.RequestsMock() as rsps: @@ -994,19 +997,19 @@ def test_ignore_errors_period(): httpc_params=httpc_params, ignore_errors_period=ignore_errors_period, ) - res = kb._do_remote() + res, _ = kb._do_remote() assert res == True assert kb.ignore_errors_until is None # refetch, but fail by using a bad source kb.source = source_bad try: - res = kb._do_remote() + res, _ = kb._do_remote() except UpdateFailed: pass # retry should fail silently as we're in holddown - res = kb._do_remote() + res, _ = kb._do_remote() assert kb.ignore_errors_until is not None assert res == False @@ -1015,7 +1018,7 @@ def test_ignore_errors_period(): # try again kb.source = source_good - res = kb._do_remote() + res, _ = kb._do_remote() assert res == True @@ -1031,7 +1034,7 @@ def test_ignore_invalid_keys(): def test_exclude_attributes(): - source = "https://example.com/keys.json" + source = "https://example.com/test_exclude_attributes/keys.json" # Mock response with responses.RequestsMock() as rsps: rsps.add(method="GET", url=source, json=JWKS_DICT, status=200)