diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 3f8205d7..4a30faf6 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -318,7 +318,11 @@ def do_local_jwk(self, filename): Load a JWKS from a local file :param filename: Name of the file from which the JWKS should be loaded + :return: True if load was successful or False if file hasn't been modified """ + if not self._local_update_required(): + return False + LOGGER.info("Reading local JWKS from %s", filename) with open(filename) as input_file: _info = json.load(input_file) @@ -328,6 +332,7 @@ def do_local_jwk(self, filename): self.do_keys([_info]) self.last_local = time.time() self.time_out = self.last_local + self.cache_time + return True def do_local_der(self, filename, keytype, keyusage=None, kid=""): """ @@ -336,7 +341,11 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""): :param filename: Name of the file :param keytype: Presently 'rsa' and 'ec' supported :param keyusage: encryption ('enc') or signing ('sig') or both + :return: True if load was successful or False if file hasn't been modified """ + if not self._local_update_required(): + return False + LOGGER.info("Reading local DER from %s", filename) key_args = {} _kty = keytype.lower() @@ -359,12 +368,13 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""): self.do_keys([key_args]) self.last_local = time.time() self.time_out = self.last_local + self.cache_time + return True def do_remote(self): """ Load a JWKS from a webpage. - :return: True or False if load was successful + :return: True if load was successful or False if remote hasn't been modified """ # if self.verify_ssl is not None: # self.httpc_params["verify"] = self.verify_ssl @@ -390,7 +400,10 @@ def do_remote(self): LOGGER.error(err) raise UpdateFailed(REMOTE_FAILED.format(self.source, str(err))) - if _http_resp.status_code == 200: # New content + load_successful = _http_resp.status_code == 200 + not_modified = _http_resp.status_code == 304 + + if load_successful: self.time_out = time.time() + self.cache_time self.imp_jwks = self._parse_remote_response(_http_resp) @@ -408,11 +421,9 @@ def do_remote(self): if hasattr(_http_resp, "headers"): headers = getattr(_http_resp, "headers") self.last_remote = headers.get("last-modified") or headers.get("date") - - elif _http_resp.status_code == 304: # Not modified + elif not_modified: LOGGER.debug("%s not modified since %s", self.source, self.last_remote) self.time_out = time.time() + self.cache_time - else: LOGGER.warning( "HTTP status %d reading remote JWKS from %s", @@ -424,7 +435,7 @@ def do_remote(self): self.last_updated = time.time() self.ignore_errors_until = None - return True + return load_successful def _parse_remote_response(self, response): """ @@ -449,14 +460,10 @@ def _parse_remote_response(self, response): return None def _uptodate(self): - res = False if self.remote or self.local: if time.time() > self.time_out: - if self.local and not self._local_update_required(): - res = True - elif self.update(): - res = True - return res + return self.update() + return False def update(self): """ @@ -464,8 +471,9 @@ def update(self): This is a forced update, will happen even if cache time has not elapsed. Replaced keys will be marked as inactive and not removed. + + :return: True if update was ok or False if we encountered an error during update. """ - res = True # An update was successful if self.source: _old_keys = self._keys # just in case @@ -475,24 +483,27 @@ def update(self): try: if self.local: if self.fileformat in ["jwks", "jwk"]: - self.do_local_jwk(self.source) + updated = self.do_local_jwk(self.source) elif self.fileformat == "der": - self.do_local_der(self.source, self.keytype, self.keyusage) + updated = self.do_local_der(self.source, self.keytype, self.keyusage) elif self.remote: - res = self.do_remote() + updated = self.do_remote() except Exception as err: LOGGER.error("Key bundle update failed: %s", err) self._keys = _old_keys # restore return False - now = time.time() - for _key in _old_keys: - if _key not in self._keys: - if not _key.inactive_since: # If already marked don't mess - _key.inactive_since = now - self._keys.append(_key) + if updated: + now = time.time() + for _key in _old_keys: + if _key not in self._keys: + if not _key.inactive_since: # If already marked don't mess + _key.inactive_since = now + self._keys.append(_key) + else: + self._keys = _old_keys - return res + return True def get(self, typ="", only_active=True): """ diff --git a/tests/test_03_key_bundle.py b/tests/test_03_key_bundle.py index 35fbdff2..7d120269 100755 --- a/tests/test_03_key_bundle.py +++ b/tests/test_03_key_bundle.py @@ -567,6 +567,7 @@ def test_update_2(): ec_key = new_ec_key(crv="P-256", key_ops=["sign"]) _jwks = {"keys": [rsa_key.serialize(), ec_key.serialize()]} + time.sleep(0.5) with open(fname, "w") as fp: fp.write(json.dumps(_jwks)) @@ -1009,7 +1010,7 @@ def test_remote_not_modified(): with responses.RequestsMock() as rsps: rsps.add(method="GET", url=source, status=304, headers=headers) - assert kb.do_remote() + assert not kb.do_remote() assert kb.last_remote == headers.get("Last-Modified") timeout2 = kb.time_out @@ -1019,6 +1020,7 @@ def test_remote_not_modified(): kb2 = KeyBundle().load(exp) assert kb2.source == source assert len(kb2.keys()) == 3 + assert len(kb2.active_keys()) == 3 assert len(kb2.get("rsa")) == 1 assert len(kb2.get("oct")) == 1 assert len(kb2.get("ec")) == 1 diff --git a/tests/test_04_key_jar.py b/tests/test_04_key_jar.py index b31e5ba8..53cb5ef5 100755 --- a/tests/test_04_key_jar.py +++ b/tests/test_04_key_jar.py @@ -746,6 +746,12 @@ def test_aud(self): keys = self.bob_keyjar.get_jwt_verify_keys(_jwt.jwt, no_kid_issuer=no_kid_issuer) assert len(keys) == 1 + def test_inactive_verify_key(self): + _jwt = factory(self.sjwt_b) + self.alice_keyjar.return_issuer("Bob")[0].mark_all_as_inactive() + keys = self.alice_keyjar.get_jwt_verify_keys(_jwt.jwt) + assert len(keys) == 0 + def test_copy(): kj = KeyJar() diff --git a/tox.ini b/tox.ini index def7f92e..272756a9 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ envlist = py{36,37,38},quality [testenv] passenv = CI TRAVIS TRAVIS_* commands = - py.test --cov=cryptojwt --isort --black {posargs} + pytest -vvv -ra --cov=cryptojwt --isort --black {posargs} codecov extras = testing deps =