diff --git a/src/cryptojwt/jws/utils.py b/src/cryptojwt/jws/utils.py index 8ee30953..709c7853 100644 --- a/src/cryptojwt/jws/utils.py +++ b/src/cryptojwt/jws/utils.py @@ -1,4 +1,5 @@ # import struct + from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding diff --git a/src/cryptojwt/utils.py b/src/cryptojwt/utils.py index 8471840f..5c13d91a 100644 --- a/src/cryptojwt/utils.py +++ b/src/cryptojwt/utils.py @@ -13,6 +13,7 @@ DEFAULT_HTTPC_TIMEOUT = 10 + # --------------------------------------------------------------------------- # Helper functions @@ -193,7 +194,7 @@ def split_token(token): def deser(val): """ - Deserialize from a string representation of an long integer + Deserialize from a string representation of a long integer to the python representation of a long integer. :param val: The string representation of the long integer. @@ -212,12 +213,12 @@ def modsplit(name): if ":" in name: _part = name.split(":") if len(_part) != 2: - raise ValueError(f"Syntax error: {s}") + raise ValueError(f"Syntax error: {name}") return _part[0], _part[1] _part = name.split(".") if len(_part) < 2: - raise ValueError(f"Syntax error: {s}") + raise ValueError(f"Syntax error: {name}") return ".".join(_part[:-1]), _part[-1] @@ -273,3 +274,94 @@ def check_content_type(content_type, mime_type): msg["content-type"] = content_type mt = msg.get_content_type() return mime_type == mt + + +def is_compact_jws(token): + token = as_bytes(token) + + try: + part = split_token(token) + except BadSyntax: + return False + + # Should be three parts + if len(part) != 3: + return False + + # All base64 encoded + try: + part = [b64d(p) for p in part] + except Exception: + return False + + # header should be a JSON object, 'alg' most be one parameter + try: + _header = json.loads(part[0]) + except Exception: + return False + + if "alg" not in _header: + return False + + return True + + +def is_jwe(token): + token = as_bytes(token) + + try: + part = split_token(token) + except BadSyntax: + return False + + # Should be five parts + if len(part) != 5: + return False + + # All base64 encoded + try: + part = [b64d(p) for p in part] + except Exception: + return False + + # header should be a JSON object, 'alg' most be one parameter + try: + _header = json.loads(part[0]) + except Exception: + return False + + if "alg" not in _header or "enc" not in _header: + return False + + return True + + +def is_json_jws(token): + if isinstance(token, str): + try: + token = json.loads(token) + except Exception: + return False + + for arg in ["payload", "signatures"]: + if arg not in token: + return False + + if not isinstance(token["signatures"], list): + return False + + for sign in token["signatures"]: + if not isinstance(sign, dict): + return False + if "signature" not in sign: + return False + + return True + + +def is_jws(token): + if is_json_jws(token): + return "json" + elif is_compact_jws(token): + return "compact" + return False diff --git a/tests/test_06_jws.py b/tests/test_06_jws.py index 6045c4cb..c62d0e42 100644 --- a/tests/test_06_jws.py +++ b/tests/test_06_jws.py @@ -7,6 +7,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec +from cryptojwt import as_unicode from cryptojwt.exception import BadSignature from cryptojwt.exception import UnknownAlgorithm from cryptojwt.exception import WrongNumberOfParts @@ -25,10 +26,13 @@ from cryptojwt.jws.utils import left_hash from cryptojwt.jws.utils import parse_rsa_algorithm from cryptojwt.key_bundle import KeyBundle +from cryptojwt.utils import as_bytes from cryptojwt.utils import b64d from cryptojwt.utils import b64d_enc_dec from cryptojwt.utils import b64e from cryptojwt.utils import intarr2bin +from cryptojwt.utils import is_compact_jws +from cryptojwt.utils import is_json_jws BASEDIR = os.path.abspath(os.path.dirname(__file__)) @@ -297,7 +301,6 @@ def full_path(local_file): ] } - SIGJWKS = KeyBundle(JWKS_b) @@ -1020,3 +1023,50 @@ def test_verify_json_missing_key(): # With both assert JWS().verify_json(_jwt, keys=[vkeys[0], sym_key]) + + +def test_is_compact_jws(): + _header = {"foo": "bar", "alg": "HS384"} + _payload = "hello world" + _sym_key = SYMKey(key=b"My hollow echo chamber", alg="HS384") + + _jwt = JWS(msg=_payload, alg="HS384").sign_compact(keys=[_sym_key]) + + assert is_compact_jws(_jwt) + + # Faulty examples + + # to few parts + assert is_compact_jws("abc.def") is False + + # right number of parts but not base64 + + assert is_compact_jws("abc.def.ghi") is False + + # not base64 illegal characters + assert is_compact_jws("abc.::::.ghi") is False + + # Faulty header + _faulty_header = {"foo": "bar"} # alg is a MUST + _jwt = ".".join([as_unicode(b64e(as_bytes(json.dumps(_faulty_header)))), "def", "ghi"]) + assert is_compact_jws(_jwt) is False + + +def test_is_json_jws(): + ec_key = ECKey().load_key(P256()) + sym_key = SYMKey(key=b"My hollow echo chamber", alg="HS384") + + protected_headers_1 = {"foo": "bar", "alg": "ES256"} + unprotected_headers_1 = {"abc": "xyz"} + protected_headers_2 = {"foo": "bar", "alg": "HS384"} + unprotected_headers_2 = {"abc": "zeb"} + payload = "hello world" + _jwt = JWS(msg=payload).sign_json( + headers=[ + (protected_headers_1, unprotected_headers_1), + (protected_headers_2, unprotected_headers_2), + ], + keys=[ec_key, sym_key], + ) + + assert is_json_jws(_jwt) diff --git a/tests/test_07_jwe.py b/tests/test_07_jwe.py index 82a31607..eade8c9a 100644 --- a/tests/test_07_jwe.py +++ b/tests/test_07_jwe.py @@ -37,6 +37,8 @@ __author__ = "rohe0002" +from cryptojwt.utils import is_jwe + def rndstr(size=16): """ @@ -717,3 +719,10 @@ def test_fernet_blake2s(): decrypter = encrypter resp = decrypter.decrypt(_token) assert resp == plain + + +def test_is_jwe(): + encryption_key = SYMKey(use="enc", key="DukeofHazardpass", kid="some-key-id") + jwe = JWE(plain, alg="A128KW", enc="A128CBC-HS256") + _jwe = jwe.encrypt(keys=[encryption_key], kid="some-key-id") + assert is_jwe(_jwe)