diff --git a/CHANGELOG.md b/CHANGELOG.md index b45a8c5..2372f6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,12 @@ - ... +## `0.2.0` - 1/04/2023 + +#### Added + +- PR-#50: Split up DIDTokenError into DIDTokenExpired, DIDTokenMalformed, and DIDTokenInvalid. + ## `0.1.0` - 11/30/2022 #### Added diff --git a/magic_admin/error.py b/magic_admin/error.py index 8581c55..e58b041 100644 --- a/magic_admin/error.py +++ b/magic_admin/error.py @@ -17,7 +17,15 @@ def to_dict(self): return {'message': str(self)} -class DIDTokenError(MagicError): +class DIDTokenInvalid(MagicError): + pass + + +class DIDTokenMalformed(MagicError): + pass + + +class DIDTokenExpired(MagicError): pass diff --git a/magic_admin/resources/token.py b/magic_admin/resources/token.py index ab17023..c218158 100644 --- a/magic_admin/resources/token.py +++ b/magic_admin/resources/token.py @@ -4,7 +4,9 @@ from eth_account.messages import defunct_hash_message from web3.auto import w3 -from magic_admin.error import DIDTokenError +from magic_admin.error import DIDTokenExpired +from magic_admin.error import DIDTokenInvalid +from magic_admin.error import DIDTokenMalformed from magic_admin.resources.base import ResourceComponent from magic_admin.utils.did_token import parse_public_address_from_issuer from magic_admin.utils.time import apply_did_token_nbf_grace_period @@ -42,7 +44,7 @@ def _check_required_fields(cls, claim): missing_fields.append(field) if missing_fields: - raise DIDTokenError( + raise DIDTokenMalformed( message='DID token is missing required field(s): {}'.format( sorted(missing_fields), ), @@ -55,7 +57,7 @@ def decode(cls, did_token): did_token (base64.str): Base64 encoded string. Raises: - DIDTokenError: If token format is invalid. + DIDTokenMalformed: If token format is invalid. Returns: proof (str): A signed message. @@ -66,7 +68,7 @@ def decode(cls, did_token): base64.urlsafe_b64decode(did_token).decode('utf-8'), ) except Exception as e: - raise DIDTokenError( + raise DIDTokenMalformed( message='DID token is malformed. It has to be a based64 encoded ' 'JSON serialized string. {err} ({msg}).'.format( err=e.__class__.__name__, @@ -75,7 +77,7 @@ def decode(cls, did_token): ) if len(decoded_did_token) != EXPECTED_DID_TOKEN_CONTENT_LENGTH: - raise DIDTokenError( + raise DIDTokenMalformed( message='DID token is malformed. It has to have two parts ' '[proof, claim].', ) @@ -85,7 +87,7 @@ def decode(cls, did_token): try: claim = json.loads(decoded_did_token[1]) except Exception as e: - raise DIDTokenError( + raise DIDTokenMalformed( message='DID token is malformed. Given claim should be a JSON ' 'serialized string. {err} ({msg}).'.format( err=e.__class__.__name__, @@ -130,12 +132,20 @@ def validate(cls, did_token): did_token (base64.str): Base64 encoded string. Raises: - DIDTokenError: If DID token fails the validation. + DIDTokenInvalid: If DID token fails the validation. + DIDTokenExpired: If DID token has expired. Returns: None. """ proof, claim = cls.decode(did_token) + + if claim['ext'] is None: + raise DIDTokenInvalid( + message='Please check the "ext" field and regenerate a new token ' + 'with a suitable value.', + ) + recovered_address = w3.eth.account.recoverHash( defunct_hash_message( text=json.dumps(claim, separators=(',', ':')), @@ -144,7 +154,7 @@ def validate(cls, did_token): ) if recovered_address != cls.get_public_address(did_token): - raise DIDTokenError( + raise DIDTokenInvalid( message='Signature mismatch between "proof" and "claim". Please ' 'generate a new token with an intended issuer.', ) @@ -152,12 +162,12 @@ def validate(cls, did_token): current_time_in_s = epoch_time_now() if current_time_in_s > claim['ext']: - raise DIDTokenError( + raise DIDTokenExpired( message='Given DID token has expired. Please generate a new one.', ) if current_time_in_s < apply_did_token_nbf_grace_period(claim['nbf']): - raise DIDTokenError( + raise DIDTokenInvalid( message='Given DID token cannot be used at this time. Please ' 'check the "nbf" field and regenerate a new token with a suitable ' 'value.', diff --git a/magic_admin/utils/did_token.py b/magic_admin/utils/did_token.py index 2202f57..bd8ede6 100644 --- a/magic_admin/utils/did_token.py +++ b/magic_admin/utils/did_token.py @@ -1,4 +1,4 @@ -from magic_admin.error import DIDTokenError +from magic_admin.error import DIDTokenMalformed def parse_public_address_from_issuer(issuer): @@ -14,7 +14,7 @@ def parse_public_address_from_issuer(issuer): try: return issuer.split(':')[2] except IndexError: - raise DIDTokenError( + raise DIDTokenMalformed( 'Given issuer ({}) is malformed. Please make sure it follows the ' '`did:method-name:method-specific-id` format.'.format(issuer), ) diff --git a/tests/unit/error_test.py b/tests/unit/error_test.py index 07aa575..6e4f82f 100644 --- a/tests/unit/error_test.py +++ b/tests/unit/error_test.py @@ -2,7 +2,7 @@ from magic_admin.error import APIError from magic_admin.error import AuthenticationError from magic_admin.error import BadRequestError -from magic_admin.error import DIDTokenError +from magic_admin.error import DIDTokenInvalid from magic_admin.error import ForbiddenError from magic_admin.error import MagicError from magic_admin.error import RateLimitingError @@ -34,9 +34,9 @@ class TestMagicError(MagicErrorBase): error_class = MagicError -class TestDIDTokenError(MagicErrorBase): +class TestDIDTokenInvalid(MagicErrorBase): - error_class = DIDTokenError + error_class = DIDTokenInvalid class TestAPIConnectionError(MagicErrorBase): diff --git a/tests/unit/resources/token_test.py b/tests/unit/resources/token_test.py index f8099bc..f416ffb 100644 --- a/tests/unit/resources/token_test.py +++ b/tests/unit/resources/token_test.py @@ -4,7 +4,9 @@ import pytest -from magic_admin.error import DIDTokenError +from magic_admin.error import DIDTokenExpired +from magic_admin.error import DIDTokenInvalid +from magic_admin.error import DIDTokenMalformed from magic_admin.resources.token import Token @@ -24,7 +26,7 @@ def test_required_fields(self): ) == frozenset() def test_check_required_fields_raises_error(self): - with pytest.raises(DIDTokenError) as e: + with pytest.raises(DIDTokenMalformed) as e: Token._check_required_fields( self._generate_claim(fields=['nbf', 'sub', 'aud', 'tid', 'iat']), ) @@ -80,7 +82,7 @@ def setup_mocks(self): def test_decode_raises_error_if_did_token_is_malformed(self, setup_mocks): setup_mocks.urlsafe_b64decode.side_effect = Exception() - with pytest.raises(DIDTokenError) as e: + with pytest.raises(DIDTokenMalformed) as e: Token.decode(self.did_token) setup_mocks.urlsafe_b64decode.assert_called_once_with(self.did_token) @@ -90,7 +92,7 @@ def test_decode_raises_error_if_did_token_is_malformed(self, setup_mocks): def test_decode_raises_error_if_did_token_has_missing_parts(self, setup_mocks): setup_mocks.json_loads.return_value = ('miss one part') - with pytest.raises(DIDTokenError) as e: + with pytest.raises(DIDTokenMalformed) as e: Token.decode(self.did_token) setup_mocks.urlsafe_b64decode.assert_called_once_with(self.did_token) @@ -101,7 +103,7 @@ def test_decode_raises_error_if_did_token_has_missing_parts(self, setup_mocks): '[proof, claim].' def test_decode_raises_error_if_claim_is_not_json_serializable(self, setup_mocks): - with pytest.raises(DIDTokenError) as e: + with pytest.raises(DIDTokenMalformed) as e: setup_mocks.json_loads.side_effect = [ ('proof_in_str', 'claim_in_str'), # Succeeds the first time. Exception(), # Fails the second time. @@ -228,7 +230,7 @@ def _assert_validate_funcs_called( def test_validate_raises_error_if_signature_mismatch(self, setup_mocks): setup_mocks.get_public_address.return_value = 'random_public_address' - with pytest.raises(DIDTokenError) as e: + with pytest.raises(DIDTokenInvalid) as e: Token.validate(self.did_token) self._assert_validate_funcs_called(setup_mocks) @@ -239,7 +241,7 @@ def test_validate_raises_error_if_did_token_expires(self, setup_mocks): setup_mocks.epoch_time_now.return_value = \ setup_mocks.claim['ext'] + 1 - with pytest.raises(DIDTokenError) as e: + with pytest.raises(DIDTokenExpired) as e: Token.validate(self.did_token) self._assert_validate_funcs_called( @@ -249,11 +251,20 @@ def test_validate_raises_error_if_did_token_expires(self, setup_mocks): assert str(e.value) == 'Given DID token has expired. Please generate a ' \ 'new one.' + def test_validate_raises_error_if_did_token_has_no_expiration(self, setup_mocks): + setup_mocks.claim['ext'] = None + + with pytest.raises(DIDTokenInvalid) as e: + Token.validate(self.did_token) + + assert str(e.value) == 'Please check the "ext" field and regenerate a new' \ + ' token with a suitable value.' + def test_validate_raises_error_if_did_token_used_before_nbf(self, setup_mocks): setup_mocks.epoch_time_now.return_value = \ setup_mocks.claim['nbf'] - 1 - with pytest.raises(DIDTokenError) as e: + with pytest.raises(DIDTokenInvalid) as e: Token.validate(self.did_token) self._assert_validate_funcs_called( diff --git a/tests/unit/utils/did_token_test.py b/tests/unit/utils/did_token_test.py index 1eaeb26..2e02755 100644 --- a/tests/unit/utils/did_token_test.py +++ b/tests/unit/utils/did_token_test.py @@ -1,6 +1,6 @@ import pytest -from magic_admin.error import DIDTokenError +from magic_admin.error import DIDTokenMalformed from magic_admin.utils.did_token import construct_issuer_with_public_address from magic_admin.utils.did_token import parse_public_address_from_issuer from testing.data.did_token import issuer @@ -15,7 +15,7 @@ def test_parse_public_address_from_issuer(self): assert parse_public_address_from_issuer(issuer) == public_address def test_parse_public_address_from_issuer_raises_error(self): - with pytest.raises(DIDTokenError) as e: + with pytest.raises(DIDTokenMalformed) as e: parse_public_address_from_issuer(self.malformed_issuer) assert str(e.value) == \