Skip to content

Inactive key fixes #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 34 additions & 23 deletions src/cryptojwt/key_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,11 @@ def do_local_jwk(self, filename):
Load a JWKS from a local file

:param filename: Name of the file from which the JWKS should be loaded
:return: True if load was successful or False if file hasn't been modified
"""
if not self._local_update_required():
return False

LOGGER.info("Reading local JWKS from %s", filename)
with open(filename) as input_file:
_info = json.load(input_file)
Expand All @@ -328,6 +332,7 @@ def do_local_jwk(self, filename):
self.do_keys([_info])
self.last_local = time.time()
self.time_out = self.last_local + self.cache_time
return True

def do_local_der(self, filename, keytype, keyusage=None, kid=""):
"""
Expand All @@ -336,7 +341,11 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""):
:param filename: Name of the file
:param keytype: Presently 'rsa' and 'ec' supported
:param keyusage: encryption ('enc') or signing ('sig') or both
:return: True if load was successful or False if file hasn't been modified
"""
if not self._local_update_required():
return False

LOGGER.info("Reading local DER from %s", filename)
key_args = {}
_kty = keytype.lower()
Expand All @@ -359,12 +368,13 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""):
self.do_keys([key_args])
self.last_local = time.time()
self.time_out = self.last_local + self.cache_time
return True

def do_remote(self):
"""
Load a JWKS from a webpage.

:return: True or False if load was successful
:return: True if load was successful or False if remote hasn't been modified
"""
# if self.verify_ssl is not None:
# self.httpc_params["verify"] = self.verify_ssl
Expand All @@ -390,7 +400,10 @@ def do_remote(self):
LOGGER.error(err)
raise UpdateFailed(REMOTE_FAILED.format(self.source, str(err)))

if _http_resp.status_code == 200: # New content
load_successful = _http_resp.status_code == 200
not_modified = _http_resp.status_code == 304

if load_successful:
self.time_out = time.time() + self.cache_time

self.imp_jwks = self._parse_remote_response(_http_resp)
Expand All @@ -408,11 +421,9 @@ def do_remote(self):
if hasattr(_http_resp, "headers"):
headers = getattr(_http_resp, "headers")
self.last_remote = headers.get("last-modified") or headers.get("date")

elif _http_resp.status_code == 304: # Not modified
elif not_modified:
LOGGER.debug("%s not modified since %s", self.source, self.last_remote)
self.time_out = time.time() + self.cache_time

else:
LOGGER.warning(
"HTTP status %d reading remote JWKS from %s",
Expand All @@ -424,7 +435,7 @@ def do_remote(self):

self.last_updated = time.time()
self.ignore_errors_until = None
return True
return load_successful

def _parse_remote_response(self, response):
"""
Expand All @@ -449,23 +460,20 @@ def _parse_remote_response(self, response):
return None

def _uptodate(self):
res = False
if self.remote or self.local:
if time.time() > self.time_out:
if self.local and not self._local_update_required():
res = True
elif self.update():
res = True
return res
return self.update()
return False

def update(self):
"""
Reload the keys if necessary.

This is a forced update, will happen even if cache time has not elapsed.
Replaced keys will be marked as inactive and not removed.

:return: True if update was ok or False if we encountered an error during update.
"""
res = True # An update was successful
if self.source:
_old_keys = self._keys # just in case

Expand All @@ -475,24 +483,27 @@ def update(self):
try:
if self.local:
if self.fileformat in ["jwks", "jwk"]:
self.do_local_jwk(self.source)
updated = self.do_local_jwk(self.source)
elif self.fileformat == "der":
self.do_local_der(self.source, self.keytype, self.keyusage)
updated = self.do_local_der(self.source, self.keytype, self.keyusage)
elif self.remote:
res = self.do_remote()
updated = self.do_remote()
except Exception as err:
LOGGER.error("Key bundle update failed: %s", err)
self._keys = _old_keys # restore
return False

now = time.time()
for _key in _old_keys:
if _key not in self._keys:
if not _key.inactive_since: # If already marked don't mess
_key.inactive_since = now
self._keys.append(_key)
if updated:
now = time.time()
for _key in _old_keys:
if _key not in self._keys:
if not _key.inactive_since: # If already marked don't mess
_key.inactive_since = now
self._keys.append(_key)
else:
self._keys = _old_keys

return res
return True

def get(self, typ="", only_active=True):
"""
Expand Down
4 changes: 3 additions & 1 deletion tests/test_03_key_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ def test_update_2():
ec_key = new_ec_key(crv="P-256", key_ops=["sign"])
_jwks = {"keys": [rsa_key.serialize(), ec_key.serialize()]}

time.sleep(0.5)
with open(fname, "w") as fp:
fp.write(json.dumps(_jwks))

Expand Down Expand Up @@ -1009,7 +1010,7 @@ def test_remote_not_modified():

with responses.RequestsMock() as rsps:
rsps.add(method="GET", url=source, status=304, headers=headers)
assert kb.do_remote()
assert not kb.do_remote()
assert kb.last_remote == headers.get("Last-Modified")
timeout2 = kb.time_out

Expand All @@ -1019,6 +1020,7 @@ def test_remote_not_modified():
kb2 = KeyBundle().load(exp)
assert kb2.source == source
assert len(kb2.keys()) == 3
assert len(kb2.active_keys()) == 3
assert len(kb2.get("rsa")) == 1
assert len(kb2.get("oct")) == 1
assert len(kb2.get("ec")) == 1
Expand Down
6 changes: 6 additions & 0 deletions tests/test_04_key_jar.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,12 @@ def test_aud(self):
keys = self.bob_keyjar.get_jwt_verify_keys(_jwt.jwt, no_kid_issuer=no_kid_issuer)
assert len(keys) == 1

def test_inactive_verify_key(self):
_jwt = factory(self.sjwt_b)
self.alice_keyjar.return_issuer("Bob")[0].mark_all_as_inactive()
keys = self.alice_keyjar.get_jwt_verify_keys(_jwt.jwt)
assert len(keys) == 0


def test_copy():
kj = KeyJar()
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ envlist = py{36,37,38},quality
[testenv]
passenv = CI TRAVIS TRAVIS_*
commands =
py.test --cov=cryptojwt --isort --black {posargs}
pytest -vvv -ra --cov=cryptojwt --isort --black {posargs}
codecov
extras = testing
deps =
Expand Down