Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 48 additions & 12 deletions src/cryptojwt/key_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,

self._keys = []
self.remote = False
self.local = False
self.cache_time = cache_time
self.time_out = 0
self.etag = ""
Expand All @@ -189,6 +190,8 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
self.keyusage = keyusage
self.imp_jwks = None
self.last_updated = 0
self.last_remote = None # HTTP Date of last remote update
self.last_local = None # UNIX timestamp of last local update

if httpc:
self.httpc = httpc
Expand All @@ -208,13 +211,13 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
self.do_keys(keys)
else:
self._set_source(source, fileformat)

if not self.remote and self.source: # local file
if self.local:
self._do_local(kid)

def _set_source(self, source, fileformat):
if source.startswith("file://"):
self.source = source[7:]
self.local = True
elif source.startswith("http://") or source.startswith("https://"):
self.source = source
self.remote = True
Expand All @@ -224,6 +227,7 @@ def _set_source(self, source, fileformat):
if fileformat.lower() in ['rsa', 'der', 'jwks']:
if os.path.isfile(source):
self.source = source
self.local = True
else:
raise ImportError('No such file')
else:
Expand All @@ -235,6 +239,16 @@ def _do_local(self, kid):
elif self.fileformat == "der":
self.do_local_der(self.source, self.keytype, self.keyusage, kid)

def _local_update_required(self) -> bool:
stat = os.stat(self.source)
if self.last_local and stat.st_mtime < self.last_local:
LOGGER.debug("%s not modfied", self.source)
return False
else:
LOGGER.debug("%s modfied", self.source)
self.last_local = stat.st_mtime
return True

def do_keys(self, keys):
"""
Go from JWK description to binary keys
Expand Down Expand Up @@ -290,12 +304,15 @@ def do_local_jwk(self, filename):

:param filename: Name of the file from which the JWKS should be loaded
"""
LOGGER.debug("Reading JWKS from %s", filename)
with open(filename) as input_file:
_info = json.load(input_file)
if 'keys' in _info:
self.do_keys(_info["keys"])
else:
self.do_keys([_info])
self.last_local = time.time()
self.time_out = self.last_local + self.cache_time

def do_local_der(self, filename, keytype, keyusage=None, kid=''):
"""
Expand All @@ -305,6 +322,7 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=''):
:param keytype: Presently 'rsa' and 'ec' supported
:param keyusage: encryption ('enc') or signing ('sig') or both
"""
LOGGER.debug("Reading DER from %s", filename)
key_args = {}
_kty = keytype.lower()
if _kty in ['rsa', 'ec']:
Expand All @@ -324,6 +342,8 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=''):
key_args['kid'] = kid

self.do_keys([key_args])
self.last_local = time.time()
self.time_out = self.last_local + self.cache_time

def do_remote(self):
"""
Expand All @@ -336,6 +356,10 @@ def do_remote(self):

try:
LOGGER.debug('KeyBundle fetch keys from: %s', self.source)
if self.last_remote is not None:
if "headers" not in self.httpc_params:
self.httpc_params["headers"] = {}
self.httpc_params["headers"]["If-Modified-Since"] = self.last_remote
_http_resp = self.httpc('GET', self.source, **self.httpc_params)
except Exception as err:
LOGGER.error(err)
Expand All @@ -357,6 +381,14 @@ def do_remote(self):
LOGGER.error("No 'keys' keyword in JWKS")
raise UpdateFailed(MALFORMED.format(self.source))

if hasattr(_http_resp, "headers"):
headers = getattr(_http_resp, "headers")
self.last_remote = headers.get("last-modified") or headers.get("date")

elif _http_resp.status_code == 304: # Not modified
LOGGER.debug("%s not modified since %s", self.source, self.last_remote)
pass

else:
raise UpdateFailed(
REMOTE_FAILED.format(self.source, _http_resp.status_code))
Expand Down Expand Up @@ -387,14 +419,12 @@ def _parse_remote_response(self, response):

def _uptodate(self):
res = False
if not self._keys:
if self.remote: # verify that it's not to old
if time.time() > self.time_out:
if self.update():
res = True
elif self.remote:
if self.update():
res = True
if self.remote or self.local:
if time.time() > self.time_out:
if self.local and not self._local_update_required():
res = True
elif self.update():
res = True
return res

def update(self):
Expand All @@ -412,13 +442,13 @@ def update(self):
self._keys = []

try:
if self.remote is False:
if self.local:
if self.fileformat in ["jwks", "jwk"]:
self.do_local_jwk(self.source)
elif self.fileformat == "der":
self.do_local_der(self.source, self.keytype,
self.keyusage)
else:
elif self.remote:
res = self.do_remote()
except Exception as err:
LOGGER.error('Key bundle update failed: %s', err)
Expand Down Expand Up @@ -661,8 +691,11 @@ def dump(self):
"keys": _keys,
"fileformat": self.fileformat,
"last_updated": self.last_updated,
"last_remote": self.last_remote,
"last_local": self.last_local,
"httpc_params": self.httpc_params,
"remote": self.remote,
"local": self.local,
"imp_jwks": self.imp_jwks,
"time_out": self.time_out,
"cache_time": self.cache_time
Expand All @@ -680,7 +713,10 @@ def load(self, spec):
self.source = spec.get("source", None)
self.fileformat = spec.get("fileformat", "jwks")
self.last_updated = spec.get("last_updated", 0)
self.last_remote = spec.get("last_remote", None)
self.last_local = spec.get("last_local", None)
self.remote = spec.get("remote", False)
self.local = spec.get("local", False)
self.imp_jwks = spec.get('imp_jwks', None)
self.time_out = spec.get('time_out', 0)
self.cache_time = spec.get('cache_time', 0)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_03_key_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,10 @@ def test_export_inactive():
'imp_jwks',
'keys',
'last_updated',
'last_remote',
'last_local',
'remote',
'local',
'time_out'}

kb2 = KeyBundle().load(res)
Expand Down