Skip to content

Add util.assert_never for static checks of must-be-unreachable code #51

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 14 additions & 23 deletions bitcointx/core/psbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from ..wallet import CCoinExtPubKey

from ..util import (
ensure_isinstance, no_bool_use_as_property,
ensure_isinstance, no_bool_use_as_property, assert_never,
ClassMappingDispatcher, activate_class_dispatcher
)

Expand Down Expand Up @@ -135,13 +135,8 @@ class PSBT_OutKeyType(Enum):
('key_type', int), ('key_data', bytes), ('value', bytes)
])

T_KeyTypeEnum_Type = Union[
Type[PSBT_GlobalKeyType],
Type[PSBT_OutKeyType],
Type[PSBT_InKeyType],
]

T_KeyTypeEnum = Union[PSBT_GlobalKeyType, PSBT_OutKeyType, PSBT_InKeyType]
T_KeyTypeEnum = TypeVar(
'T_KeyTypeEnum', PSBT_GlobalKeyType, PSBT_OutKeyType, PSBT_InKeyType)


def proprietary_field_repr(
Expand Down Expand Up @@ -296,7 +291,7 @@ def merge_unknown_fields(
def read_psbt_keymap(
f: ByteStream_Type,
keys_seen: Set[bytes],
keys_enum_class: T_KeyTypeEnum_Type,
keys_enum_class: Type[T_KeyTypeEnum],
proprietary_fields: Dict[bytes, List[PSBT_ProprietaryTypeData]],
unknown_fields: List[PSBT_UnknownTypeData]
) -> Generator[Tuple[T_KeyTypeEnum, bytes, bytes], None, None]:
Expand Down Expand Up @@ -1295,10 +1290,7 @@ def check_witness_and_nonwitness_utxo_in_sync(
ensure_empty_key_data(key_type, key_data, descr(''))
proof_of_reserves_commitment = value
else:
raise AssertionError(
f'If key type {key_type} is recognized, '
f'it must be handled, and this statement '
f'should not be reached.')
assert_never(key_type)

# non_witness_utxo is preferred over witness_utxo for `utxo` kwarg
# because non_witness_utxo is a full transaction,
Expand Down Expand Up @@ -1646,13 +1638,13 @@ def descr(msg: str) -> str:
read_psbt_keymap(f, keys_seen, PSBT_OutKeyType,
proprietary_fields, unknown_fields):

if key_type == PSBT_OutKeyType.REDEEM_SCRIPT:
if key_type is PSBT_OutKeyType.REDEEM_SCRIPT:
ensure_empty_key_data(key_type, key_data, descr(''))
redeem_script = CScript(value)
elif key_type == PSBT_OutKeyType.WITNESS_SCRIPT:
elif key_type is PSBT_OutKeyType.WITNESS_SCRIPT:
ensure_empty_key_data(key_type, key_data, descr(''))
witness_script = CScript(value)
elif key_type == PSBT_OutKeyType.BIP32_DERIVATION:
elif key_type is PSBT_OutKeyType.BIP32_DERIVATION:
pub = CPubKey(key_data)
if not pub.is_fullyvalid():
raise SerializationError(
Expand All @@ -1662,6 +1654,8 @@ def descr(msg: str) -> str:
("duplicate keys should have been catched "
"inside read_psbt_keymap()")
derivation_map[pub] = PSBT_KeyDerivationInfo.deserialize(value)
else:
assert_never(key_type)

return cls(redeem_script=redeem_script, witness_script=witness_script,
derivation_map=derivation_map,
Expand Down Expand Up @@ -2107,10 +2101,10 @@ def stream_deserialize(cls: Type[T_PartiallySignedTransaction],
read_psbt_keymap(f, keys_seen, PSBT_GlobalKeyType,
proprietary_fields, unknown_fields):

if key_type == PSBT_GlobalKeyType.UNSIGNED_TX:
if key_type is PSBT_GlobalKeyType.UNSIGNED_TX:
ensure_empty_key_data(key_type, key_data)
unsigned_tx = CTransaction.deserialize(value)
elif key_type == PSBT_GlobalKeyType.XPUB:
elif key_type is PSBT_GlobalKeyType.XPUB:
if key_data[:4] != CCoinExtPubKey.base58_prefix:
raise ValueError(
f'One of global xpubs has unknown prefix: expected '
Expand All @@ -2121,17 +2115,14 @@ def stream_deserialize(cls: Type[T_PartiallySignedTransaction],
("duplicate keys should have been catched "
"inside read_psbt_keymap()")
xpubs[xpub] = PSBT_KeyDerivationInfo.deserialize(value)
elif key_type == PSBT_GlobalKeyType.VERSION:
elif key_type is PSBT_GlobalKeyType.VERSION:
ensure_empty_key_data(key_type, key_data)
if len(value) != 4:
raise SerializationError(
f'Incorrect data length for {key_type.name}')
version = struct.unpack(b'<I', value)[0]
else:
raise AssertionError(
f'If key type {key_type} is present in PSBT_GLOBAL_KEYS, '
f'it must be handled, and this statement '
f'should not be reached.')
assert_never(key_type)

if unsigned_tx is None:
raise ValueError(
Expand Down
10 changes: 5 additions & 5 deletions bitcointx/tests/test_psbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,20 +679,20 @@ def get_input_key_types(psbt_bytes: bytes) -> Set[PSBT_InKeyType]:
assert magic == CoreCoinParams.PSBT_MAGIC_HEADER_BYTES
unsigned_tx = None
keys_seen: Set[bytes] = set()
for key_type, key_data, value in \
for key_type_g, key_data, value in \
read_psbt_keymap(f, keys_seen, PSBT_GlobalKeyType,
OrderedDict(), list()):
if key_type == PSBT_GlobalKeyType.UNSIGNED_TX:
if key_type_g == PSBT_GlobalKeyType.UNSIGNED_TX:
unsigned_tx = CTransaction.deserialize(value)
assert unsigned_tx
keys_seen = set()
key_types_seen: Set[PSBT_InKeyType] = set()
assert len(unsigned_tx.vin) == 1
for key_type, key_data, value in \
for key_type_in, key_data, value in \
read_psbt_keymap(f, keys_seen, PSBT_InKeyType,
OrderedDict(), list()):
assert isinstance(key_type, PSBT_InKeyType)
key_types_seen.add(key_type)
assert isinstance(key_type_in, PSBT_InKeyType)
key_types_seen.add(key_type_in)

return key_types_seen

Expand Down
26 changes: 25 additions & 1 deletion bitcointx/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
has_contextvars = False

import functools
from enum import Enum
from types import FunctionType
from abc import ABCMeta, ABC, abstractmethod
from typing import (
Type, Set, Tuple, List, Dict, Union, Any, Callable, Iterable, Optional,
TypeVar, Generic, cast
TypeVar, Generic, cast, NoReturn
)

_secp256k1_library_path: Optional[str] = None
Expand Down Expand Up @@ -471,6 +472,28 @@ def ensure_isinstance(var: object,
raise TypeError(msg)


def assert_never(x: NoReturn) -> NoReturn:
"""For use with static checking. The checker such as mypy will raise
error if the statement `assert_never(...)` is reached. At runtime,
an `AssertionError` will be raised.
Useful to ensure that all variants of Enum is handled.
Might become useful in other ways, and because of this, the message
for `AssertionError` at runtime can differ on actual type of the argument.
For full control of the message, just pass a string as the argument.
"""

if isinstance(x, Enum):
msg = f'Enum {x} is not handled'
elif isinstance(x, str):
msg = x
elif isinstance(x, type):
msg = f'{x.__name__} is not handled'
else:
msg = f'{x.__class__.__name__} is not handled'

raise AssertionError(msg)


class ReadOnlyFieldGuard(ABC):
"""A unique class that is used as a guard type for ReadOnlyField.
It cannot be instantiated at runtime, and the static check will also
Expand Down Expand Up @@ -604,6 +627,7 @@ def set_dispatcher_class(self, identity: str,
'ClassMappingDispatcher',
'classgetter',
'ensure_isinstance',
'assert_never',
'ReadOnlyField',
'WriteableField',
'ContextVarsCompat',
Expand Down