Skip to content

Atomic keys update #115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 39 additions & 41 deletions src/cryptojwt/key_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,16 @@ def __init__(
self.source = None
if isinstance(keys, dict):
if "keys" in keys:
self._add_jwk_dicts(keys["keys"])
initial_keys = keys["keys"]
else:
self._add_jwk_dicts([keys])
initial_keys = [keys]
else:
self._add_jwk_dicts(keys)
initial_keys = keys
self._keys = self.jwk_dicts_as_keys(initial_keys)
else:
self._set_source(source, fileformat)
if self.local:
self._do_local(kid)
self._keys = self._do_local(kid)

def _set_source(self, source, fileformat):
if source.startswith("file://"):
Expand All @@ -283,9 +284,10 @@ def _set_source(self, source, fileformat):

def _do_local(self, kid):
if self.fileformat in ["jwks", "jwk"]:
self._do_local_jwk(self.source)
updated, keys = self._do_local_jwk(self.source)
elif self.fileformat == "der":
self._do_local_der(self.source, self.keytype, self.keyusage, kid)
updated, keys = self._do_local_der(self.source, self.keytype, self.keyusage, kid)
return keys

def _local_update_required(self) -> bool:
stat = os.stat(self.source)
Expand All @@ -309,13 +311,8 @@ def add_jwk_dicts(self, keys):
:param keys: List of JWK dictionaries
:return:
"""
self._add_jwk_dicts(keys)

def _add_jwk_dicts(self, keys):
_new_keys = self.jwk_dicts_as_keys(keys)
if _new_keys:
self._keys.extend(_new_keys)
self.last_updated = time.time()
self._keys.extend(self.jwk_dicts_as_keys(keys))
self.last_updated = time.time()

def jwk_dicts_as_keys(self, keys):
"""
Expand Down Expand Up @@ -384,18 +381,19 @@ def _do_local_jwk(self, filename):
:return: True if load was successful or False if file hasn't been modified
"""
if not self._local_update_required():
return False
return False, None

LOGGER.info("Reading local JWKS from %s", filename)
with open(filename) as input_file:
_info = json.load(input_file)
if "keys" in _info:
self._add_jwk_dicts(_info["keys"])
new_keys = self.jwk_dicts_as_keys(_info["keys"])
else:
self._add_jwk_dicts([_info])
new_keys = self.jwk_dicts_as_keys([_info])

self.last_local = time.time()
self.time_out = self.last_local + self.cache_time
return True
return True, new_keys

def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
"""
Expand All @@ -407,7 +405,7 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
:return: True if load was successful or False if file hasn't been modified
"""
if not self._local_update_required():
return False
return False, None

LOGGER.info("Reading local DER from %s", filename)
key_args = {}
Expand All @@ -428,12 +426,12 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
if kid:
key_args["kid"] = kid

self._add_jwk_dicts([key_args])
new_keys = self.jwk_dicts_as_keys([key_args])
self.last_local = time.time()
self.time_out = self.last_local + self.cache_time
return True
return True, new_keys

def _do_remote(self):
def _do_remote(self, set_keys=True):
"""
Load a JWKS from a webpage.

Expand All @@ -448,7 +446,7 @@ def _do_remote(self):
self.source,
datetime.fromtimestamp(self.ignore_errors_until),
)
return False
return False, None

LOGGER.info("Reading remote JWKS from %s", self.source)
try:
Expand Down Expand Up @@ -497,11 +495,12 @@ def _do_remote(self):
self.ignore_errors_until = time.time() + self.ignore_errors_period
raise UpdateFailed(REMOTE_FAILED.format(self.source, _http_resp.status_code))

if new_keys is not None:
if set_keys and new_keys:
self._keys = new_keys

self.last_updated = time.time()
self.ignore_errors_until = None
return load_successful
return load_successful, new_keys

def _parse_remote_response(self, response):
"""
Expand Down Expand Up @@ -542,34 +541,31 @@ def update(self):
:return: True if update was ok or False if we encountered an error during update.
"""
if self.source:
_old_keys = self._keys # just in case

# reread everything
self._keys = []
new_keys = []
updated = None

try:
if self.local:
if self.fileformat in ["jwks", "jwk"]:
updated = self._do_local_jwk(self.source)
updated, k = self._do_local_jwk(self.source)
elif self.fileformat == "der":
updated = self._do_local_der(self.source, self.keytype, self.keyusage)
updated, k = self._do_local_der(self.source, self.keytype, self.keyusage)
elif self.remote:
updated = self._do_remote()
updated, k = self._do_remote(set_keys=False)
if k:
new_keys.extend(k)
except Exception as err:
LOGGER.error("Key bundle update failed: %s", err)
self._keys = _old_keys # restore
return False

if updated:
now = time.time()
for _key in _old_keys:
if _key not in self._keys:
for _key in self._keys:
if _key not in new_keys:
if not _key.inactive_since: # If already marked don't mess
_key.inactive_since = now
self._keys.append(_key)
else:
self._keys = _old_keys
new_keys.append(_key)
self._keys = new_keys

return True

Expand All @@ -585,9 +581,9 @@ def get(self, typ="", only_active=True):

if typ:
_typs = [typ.lower(), typ.upper()]
_keys = [k for k in self._keys[:] if k.kty in _typs]
_keys = [k for k in self._keys if k.kty in _typs]
else:
_keys = self._keys[:]
_keys = self._keys

if only_active:
return [k for k in _keys if not k.inactive_since]
Expand All @@ -602,7 +598,7 @@ def keys(self, update: bool = True):
"""
if update:
self._uptodate()
return self._keys[:]
return self._keys

def active_keys(self):
"""Return the set of active keys."""
Expand Down Expand Up @@ -829,9 +825,11 @@ def load(self, spec):
:param spec: Dictionary with attributes and value to populate the instance with
:return: The instance itself
"""

_keys = spec.get("keys", [])
if _keys:
self._add_jwk_dicts(_keys)
self._keys.extend(self.jwk_dicts_as_keys(_keys))
self.last_updated = time.time()

for attr, default in self.params.items():
val = spec.get(attr)
Expand Down
27 changes: 15 additions & 12 deletions tests/test_03_key_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,8 @@ def test_httpc_params_1():
rsps.add(method=responses.GET, url=source, json=JWKS_DICT, status=200)
httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds
kb = KeyBundle(source=source, httpc=requests.request, httpc_params=httpc_params)
assert kb._do_remote()
updated, _ = kb._do_remote()
assert updated == True


@pytest.mark.network
Expand Down Expand Up @@ -920,7 +921,7 @@ def test_export_inactive():


def test_remote():
source = "https://example.com/keys.json"
source = "https://example.com/test_remote/keys.json"
# Mock response
with responses.RequestsMock() as rsps:
rsps.add(method="GET", url=source, json=JWKS_DICT, status=200)
Expand All @@ -941,7 +942,7 @@ def test_remote():


def test_remote_not_modified():
source = "https://example.com/keys.json"
source = "https://example.com/test_remote_not_modified/keys.json"
headers = {
"Date": "Fri, 15 Mar 2019 10:14:25 GMT",
"Last-Modified": "Fri, 1 Jan 1970 00:00:00 GMT",
Expand All @@ -954,13 +955,15 @@ def test_remote_not_modified():

with responses.RequestsMock() as rsps:
rsps.add(method="GET", url=source, json=JWKS_DICT, status=200, headers=headers)
assert kb._do_remote()
updated, _ = kb._do_remote()
assert updated == True
assert kb.last_remote == headers.get("Last-Modified")
timeout1 = kb.time_out

with responses.RequestsMock() as rsps:
rsps.add(method="GET", url=source, status=304, headers=headers)
assert not kb._do_remote()
updated, _ = kb._do_remote()
assert not updated
assert kb.last_remote == headers.get("Last-Modified")
timeout2 = kb.time_out

Expand All @@ -980,8 +983,8 @@ def test_remote_not_modified():


def test_ignore_errors_period():
source_good = "https://example.com/keys.json"
source_bad = "https://example.com/keys-bad.json"
source_good = "https://example.com/test_ignore_errors_period/keys.json"
source_bad = "https://example.com/test_ignore_errors_period/keys-bad.json"
ignore_errors_period = 1
# Mock response
with responses.RequestsMock() as rsps:
Expand All @@ -994,19 +997,19 @@ def test_ignore_errors_period():
httpc_params=httpc_params,
ignore_errors_period=ignore_errors_period,
)
res = kb._do_remote()
res, _ = kb._do_remote()
assert res == True
assert kb.ignore_errors_until is None

# refetch, but fail by using a bad source
kb.source = source_bad
try:
res = kb._do_remote()
res, _ = kb._do_remote()
except UpdateFailed:
pass

# retry should fail silently as we're in holddown
res = kb._do_remote()
res, _ = kb._do_remote()
assert kb.ignore_errors_until is not None
assert res == False

Expand All @@ -1015,7 +1018,7 @@ def test_ignore_errors_period():

# try again
kb.source = source_good
res = kb._do_remote()
res, _ = kb._do_remote()
assert res == True


Expand All @@ -1031,7 +1034,7 @@ def test_ignore_invalid_keys():


def test_exclude_attributes():
source = "https://example.com/keys.json"
source = "https://example.com/test_exclude_attributes/keys.json"
# Mock response
with responses.RequestsMock() as rsps:
rsps.add(method="GET", url=source, json=JWKS_DICT, status=200)
Expand Down