diff --git a/eth/beacon/db/chain.py b/eth/beacon/db/chain.py index 9bf76581d5..eccc4c5e6a 100644 --- a/eth/beacon/db/chain.py +++ b/eth/beacon/db/chain.py @@ -243,7 +243,7 @@ def _get_canonical_head(cls, db: BaseDB) -> BaseBeaconBlock: canonical_head_hash = db[SchemaV1.make_canonical_head_hash_lookup_key()] except KeyError: raise CanonicalHeadNotFound("No canonical head set for this chain") - return cls._get_block_by_hash(db, canonical_head_hash) + return cls._get_block_by_hash(db, Hash32(canonical_head_hash)) def get_block_by_hash(self, block_hash: Hash32) -> BaseBeaconBlock: return self._get_block_by_hash(self.db, block_hash) diff --git a/eth/beacon/state_machines/base.py b/eth/beacon/state_machines/base.py index 460a56fa4a..c859b86084 100644 --- a/eth/beacon/state_machines/base.py +++ b/eth/beacon/state_machines/base.py @@ -95,7 +95,7 @@ def __init__(self, chaindb: BaseBeaconChainDB, block: BaseBeaconBlock=None) -> N # Logging # @property - def logger(self): + def logger(self) -> logging.Logger: return logging.getLogger('eth.beacon.state_machines.base.BeaconStateMachine.{0}'.format( self.__class__.__name__ )) diff --git a/eth/beacon/state_machines/forks/serenity/crystallized_states.py b/eth/beacon/state_machines/forks/serenity/crystallized_states.py index 5191bc0a84..92e32e02fe 100644 --- a/eth/beacon/state_machines/forks/serenity/crystallized_states.py +++ b/eth/beacon/state_machines/forks/serenity/crystallized_states.py @@ -3,7 +3,9 @@ class SerenityCrystallizedState(CrystallizedState): @classmethod - def from_crystallized_state(cls, crystallized_state): + def from_crystallized_state(cls, + crystallized_state: CrystallizedState + ) -> "SerenityCrystallizedState": return cls( validators=crystallized_state.validators, last_state_recalc=crystallized_state.last_state_recalc, diff --git a/eth/beacon/types/crystallized_states.py b/eth/beacon/types/crystallized_states.py index 97ac9f944c..8d0fa7c9a8 100644 --- a/eth/beacon/types/crystallized_states.py +++ b/eth/beacon/types/crystallized_states.py @@ -117,5 +117,5 @@ def num_validators(self) -> int: return len(self.validators) @property - def num_crosslink_records(self): + def num_crosslink_records(self) -> int: return len(self.crosslink_records) diff --git a/eth/chains/base.py b/eth/chains/base.py index 1fe027bcbb..ceaa9c6b28 100644 --- a/eth/chains/base.py +++ b/eth/chains/base.py @@ -12,6 +12,7 @@ cast, Dict, Generator, + Iterable, Iterator, List, Optional, @@ -19,6 +20,8 @@ Type, TYPE_CHECKING, Union, + TypeVar, + Generic, ) import logging @@ -32,6 +35,12 @@ encode_hex, ) +from eth.constants import ( + BLANK_ROOT_HASH, + EMPTY_UNCLE_HASH, + MAX_UNCLE_DEPTH, +) + from eth.db.backends.base import BaseAtomicDB from eth.db.chain import ( BaseChainDB, @@ -40,11 +49,7 @@ from eth.db.header import ( HeaderDB, ) -from eth.constants import ( - BLANK_ROOT_HASH, - EMPTY_UNCLE_HASH, - MAX_UNCLE_DEPTH, -) + from eth.estimators import ( get_gas_estimator, ) @@ -53,15 +58,7 @@ TransactionNotFound, VMNotFound, ) -from eth.utils.spoof import ( - SpoofTransaction, -) -from eth.validation import ( - validate_block_number, - validate_uint256, - validate_word, - validate_vm_configuration, -) + from eth.rlp.blocks import ( BaseBlock, ) @@ -76,6 +73,13 @@ BaseTransaction, BaseUnsignedTransaction, ) + +from eth.typing import ( # noqa: F401 + AccountState, + BaseOrSpoofTransaction, + StaticMethod, +) + from eth.utils.db import ( apply_state_dict, ) @@ -88,9 +92,15 @@ from eth.utils.rlp import ( validate_imported_block_unchanged, ) -from eth.typing import ( - AccountState, + +from eth.validation import ( + validate_block_number, + validate_uint256, + validate_word, + validate_vm_configuration, ) +from eth.vm.computation import BaseComputation +from eth.vm.state import BaseState # noqa: F401 from eth._warnings import catch_and_ignore_import_warning with catch_and_ignore_import_warning(): @@ -107,7 +117,9 @@ ) if TYPE_CHECKING: - from eth.vm.base import BaseVM # noqa: F401 + from eth.vm.base import ( # noqa: F401 + BaseVM, + ) class BaseChain(Configurable, ABC): @@ -196,11 +208,11 @@ def get_block_header_by_hash(self, block_hash: Hash32) -> BlockHeader: raise NotImplementedError("Chain classes must implement this method") @abstractmethod - def get_canonical_head(self): + def get_canonical_head(self) -> BlockHeader: raise NotImplementedError("Chain classes must implement this method") @abstractmethod - def get_score(self, block_hash): + def get_score(self, block_hash: Hash32) -> int: raise NotImplementedError("Chain classes must implement this method") # @@ -227,11 +239,15 @@ def get_canonical_block_by_number(self, block_number: BlockNumber) -> BaseBlock: raise NotImplementedError("Chain classes must implement this method") @abstractmethod - def get_canonical_block_hash(self, block_number): + def get_canonical_block_hash(self, block_number: BlockNumber) -> Hash32: raise NotImplementedError("Chain classes must implement this method") @abstractmethod - def build_block_with_transactions(self, transactions, parent_header): + def build_block_with_transactions( + self, + transactions: Tuple[BaseTransaction, ...], + parent_header: BlockHeader=None + ) -> Tuple[BaseBlock, Tuple[Receipt, ...], Tuple[BaseComputation, ...]]: raise NotImplementedError("Chain classes must implement this method") # @@ -262,14 +278,14 @@ def get_canonical_transaction(self, transaction_hash: Hash32) -> BaseTransaction @abstractmethod def get_transaction_result( self, - transaction: Union[BaseTransaction, SpoofTransaction], + transaction: BaseOrSpoofTransaction, at_header: BlockHeader) -> bytes: raise NotImplementedError("Chain classes must implement this method") @abstractmethod def estimate_gas( self, - transaction: Union[BaseTransaction, SpoofTransaction], + transaction: BaseOrSpoofTransaction, at_header: BlockHeader=None) -> int: raise NotImplementedError("Chain classes must implement this method") @@ -320,7 +336,7 @@ class Chain(BaseChain): current block number. """ logger = logging.getLogger("eth.chain.chain.Chain") - gas_estimator = None # type: Callable + gas_estimator = None # type: StaticMethod[Callable[[BaseState, BaseOrSpoofTransaction], int]] chaindb_class = ChainDB # type: Type[BaseChainDB] @@ -335,7 +351,7 @@ def __init__(self, base_db: BaseAtomicDB) -> None: self.chaindb = self.get_chaindb_class()(base_db) self.headerdb = HeaderDB(base_db) if self.gas_estimator is None: - self.gas_estimator = get_gas_estimator() # type: ignore + self.gas_estimator = get_gas_estimator() # # Helpers @@ -414,7 +430,9 @@ def get_vm(self, at_header: BlockHeader=None) -> 'BaseVM': # # Header API # - def create_header_from_parent(self, parent_header, **header_params): + def create_header_from_parent(self, + parent_header: BlockHeader, + **header_params: HeaderParams) -> BlockHeader: """ Passthrough helper to the VM class of the block descending from the given header. @@ -432,7 +450,7 @@ def get_block_header_by_hash(self, block_hash: Hash32) -> BlockHeader: validate_word(block_hash, title="Block Hash") return self.chaindb.get_block_header_by_hash(block_hash) - def get_canonical_head(self): + def get_canonical_head(self) -> BlockHeader: """ Returns the block header at the canonical chain head. @@ -440,7 +458,7 @@ def get_canonical_head(self): """ return self.chaindb.get_canonical_head() - def get_score(self, block_hash): + def get_score(self, block_hash: Hash32) -> int: """ Returns the difficulty score of the block with the given hash. @@ -498,7 +516,7 @@ def get_block_by_hash(self, block_hash: Hash32) -> BaseBlock: block_header = self.get_block_header_by_hash(block_hash) return self.get_block_by_header(block_header) - def get_block_by_header(self, block_header): + def get_block_by_header(self, block_header: BlockHeader) -> BaseBlock: """ Returns the requested block as specified by the block header. """ @@ -524,7 +542,11 @@ def get_canonical_block_hash(self, block_number: BlockNumber) -> Hash32: """ return self.chaindb.get_canonical_block_hash(block_number) - def build_block_with_transactions(self, transactions, parent_header=None): + def build_block_with_transactions( + self, + transactions: Tuple[BaseTransaction, ...], + parent_header: BlockHeader=None + ) -> Tuple[BaseBlock, Tuple[Receipt, ...], Tuple[BaseComputation, ...]]: """ Generate a block with the provided transactions. This does *not* import that block into your chain. If you want this new block in your chain, @@ -554,12 +576,12 @@ def get_canonical_transaction(self, transaction_hash: Hash32) -> BaseTransaction found in the main chain. """ (block_num, index) = self.chaindb.get_transaction_index(transaction_hash) - VM = self.get_vm_class_for_block_number(block_num) + VM_class = self.get_vm_class_for_block_number(block_num) transaction = self.chaindb.get_transaction_by_index( block_num, index, - VM.get_transaction_class(), + VM_class.get_transaction_class(), ) if transaction.hash == transaction_hash: @@ -603,22 +625,21 @@ def create_unsigned_transaction(self, # def get_transaction_result( self, - transaction: Union[BaseTransaction, SpoofTransaction], + transaction: BaseOrSpoofTransaction, at_header: BlockHeader) -> bytes: """ Return the result of running the given transaction. This is referred to as a `call()` in web3. """ with self.get_vm(at_header).state_in_temp_block() as state: - # Ignore is to not bleed the SpoofTransaction deeper into the code base - computation = state.costless_execute_transaction(transaction) # type: ignore + computation = state.costless_execute_transaction(transaction) computation.raise_if_error() return computation.output def estimate_gas( self, - transaction: Union[BaseTransaction, SpoofTransaction], + transaction: BaseOrSpoofTransaction, at_header: BlockHeader=None) -> int: """ Returns an estimation of the amount of gas the given transaction will @@ -689,8 +710,8 @@ def import_block(self, # Validation API # def validate_receipt(self, receipt: Receipt, at_header: BlockHeader) -> None: - VM = self.get_vm_class(at_header) - VM.validate_receipt(receipt) + VM_class = self.get_vm_class(at_header) + VM_class.validate_receipt(receipt) def validate_block(self, block: BaseBlock) -> None: """ @@ -704,9 +725,9 @@ def validate_block(self, block: BaseBlock) -> None: """ if block.is_genesis: raise ValidationError("Cannot validate genesis block this way") - VM = self.get_vm_class_for_block_number(BlockNumber(block.number)) + VM_class = self.get_vm_class_for_block_number(BlockNumber(block.number)) parent_block = self.get_block_by_hash(block.header.parent_hash) - VM.validate_header(block.header, parent_block.header, check_seal=True) + VM_class.validate_header(block.header, parent_block.header, check_seal=True) self.validate_uncles(block) self.validate_gaslimit(block.header) @@ -714,8 +735,8 @@ def validate_seal(self, header: BlockHeader) -> None: """ Validate the seal on the given header. """ - VM = self.get_vm_class_for_block_number(BlockNumber(header.block_number)) - VM.validate_seal(header) + VM_class = self.get_vm_class_for_block_number(BlockNumber(header.block_number)) + VM_class.validate_seal(header) def validate_gaslimit(self, header: BlockHeader) -> None: """ @@ -830,7 +851,7 @@ def validate_chain( @to_set -def _extract_uncle_hashes(blocks): +def _extract_uncle_hashes(blocks: Iterable[BaseBlock]) -> Iterable[Hash32]: for block in blocks: for uncle in block.uncles: yield uncle.hash @@ -843,7 +864,9 @@ def __init__(self, base_db: BaseAtomicDB, header: BlockHeader=None) -> None: super().__init__(base_db) self.header = self.ensure_header(header) - def apply_transaction(self, transaction): + def apply_transaction(self, + transaction: BaseTransaction + ) -> Tuple[BaseBlock, Receipt, BaseComputation]: """ Applies the transaction to the current tip block. diff --git a/eth/chains/header.py b/eth/chains/header.py index 270443d263..ad00fa7cd3 100644 --- a/eth/chains/header.py +++ b/eth/chains/header.py @@ -1,12 +1,21 @@ from abc import ABC, abstractmethod -from typing import Dict, Any, Tuple, Type # noqa: F401 +from typing import ( # noqa: F401 + Any, + cast, + Dict, + Tuple, + Type, +) from eth_typing import ( BlockNumber, Hash32, ) -from eth.db.backends.base import BaseDB +from eth.db.backends.base import ( + BaseAtomicDB, + BaseDB, +) from eth.db.header import ( # noqa: F401 BaseHeaderDB, HeaderDB, @@ -47,7 +56,7 @@ def from_genesis_header(cls, # @classmethod @abstractmethod - def get_headerdb_class(cls): + def get_headerdb_class(cls) -> Type[BaseHeaderDB]: raise NotImplementedError("Chain classes must implement this method") # @@ -73,7 +82,9 @@ def header_exists(self, block_hash: Hash32) -> bool: raise NotImplementedError("Chain classes must implement this method") @abstractmethod - def import_header(self, header: BlockHeader) -> Tuple[BlockHeader, ...]: + def import_header(self, + header: BlockHeader + ) -> Tuple[Tuple[BlockHeader, ...], Tuple[BlockHeader, ...]]: raise NotImplementedError("Chain classes must implement this method") @@ -82,7 +93,7 @@ class HeaderChain(BaseHeaderChain): def __init__(self, base_db: BaseDB, header: BlockHeader=None) -> None: self.base_db = base_db - self.headerdb = self.get_headerdb_class()(base_db) + self.headerdb = self.get_headerdb_class()(cast(BaseAtomicDB, base_db)) if header is None: self.header = self.get_canonical_head() @@ -99,7 +110,7 @@ def from_genesis_header(cls, """ Initializes the chain from the genesis header. """ - headerdb = cls.get_headerdb_class()(base_db) + headerdb = cls.get_headerdb_class()(cast(BaseAtomicDB, base_db)) headerdb.persist_header(genesis_header) return cls(base_db, genesis_header) @@ -107,7 +118,7 @@ def from_genesis_header(cls, # Helpers # @classmethod - def get_headerdb_class(cls): + def get_headerdb_class(cls) -> Type[BaseHeaderDB]: """ Returns the class which should be used for the `headerdb` """ @@ -151,7 +162,9 @@ def header_exists(self, block_hash: Hash32) -> bool: """ return self.headerdb.header_exists(block_hash) - def import_header(self, header: BlockHeader) -> Tuple[BlockHeader, ...]: + def import_header(self, + header: BlockHeader + ) -> Tuple[Tuple[BlockHeader, ...], Tuple[BlockHeader, ...]]: """ Direct passthrough to `headerdb` diff --git a/eth/chains/mainnet/__init__.py b/eth/chains/mainnet/__init__.py index e37780e17b..99bce66a28 100644 --- a/eth/chains/mainnet/__init__.py +++ b/eth/chains/mainnet/__init__.py @@ -1,4 +1,8 @@ -from typing import Tuple, Type # noqa: F401 +from typing import ( # noqa: F401 + Tuple, + Type, + TypeVar, +) from eth_utils import ( decode_hex, @@ -31,13 +35,16 @@ ) -class MainnetDAOValidatorVM: +class MainnetDAOValidatorVM(HomesteadVM): """Only on mainnet, TheDAO fork is accompanied by special extra data. Validate those headers""" @classmethod - def validate_header(cls, header, previous_header, check_seal=True): - # ignore mypy warnings, because super's validate_header is defined by mixing w/ other class - super().validate_header(header, previous_header, check_seal) # type: ignore + def validate_header(cls, + header: BlockHeader, + previous_header: BlockHeader, + check_seal: bool=True) -> None: + + super().validate_header(header, previous_header, check_seal) # The special extra_data is set on the ten headers starting at the fork dao_fork_at = cls.get_dao_fork_block_number() @@ -61,7 +68,7 @@ def validate_header(cls, header, previous_header, check_seal=True): ) -class MainnetHomesteadVM(MainnetDAOValidatorVM, HomesteadVM): +class MainnetHomesteadVM(MainnetDAOValidatorVM): _dao_fork_block_number = DAO_FORK_MAINNET_BLOCK diff --git a/eth/chains/tester/__init__.py b/eth/chains/tester/__init__.py index f8339ead73..7ab9a080c2 100644 --- a/eth/chains/tester/__init__.py +++ b/eth/chains/tester/__init__.py @@ -21,6 +21,9 @@ from eth.chains.base import Chain from eth.chains.mainnet import MainnetChain +from eth.rlp.blocks import ( + BaseBlock, +) from eth.rlp.headers import ( BlockHeader ) @@ -144,7 +147,7 @@ class MainnetTesterChain(BaseMainnetTesterChain): configuration of fork rules. """ @classmethod - def validate_seal(cls, block): + def validate_seal(cls, block: BaseBlock) -> None: """ We don't validate the proof of work seal on the tester chain. """ diff --git a/eth/db/account.py b/eth/db/account.py index 7cf4957d29..c3756a7509 100644 --- a/eth/db/account.py +++ b/eth/db/account.py @@ -28,6 +28,9 @@ BLANK_ROOT_HASH, EMPTY_SHA3, ) +from eth.db.backends.base import ( + BaseDB, +) from eth.db.batch import ( BatchDB, ) @@ -71,7 +74,7 @@ def __init__(self) -> None: @property @abstractmethod - def state_root(self): + def state_root(self) -> Hash32: raise NotImplementedError("Must be implemented by subclasses") @abstractmethod @@ -82,11 +85,11 @@ def has_root(self, state_root: bytes) -> bool: # Storage # @abstractmethod - def get_storage(self, address, slot): + def get_storage(self, address: Address, slot: int) -> int: raise NotImplementedError("Must be implemented by subclasses") @abstractmethod - def set_storage(self, address, slot, value): + def set_storage(self, address: Address, slot: int, value: int) -> None: raise NotImplementedError("Must be implemented by subclasses") # @@ -104,40 +107,40 @@ def set_nonce(self, address: Address, nonce: int) -> None: # Balance # @abstractmethod - def get_balance(self, address): + def get_balance(self, address: Address) -> int: raise NotImplementedError("Must be implemented by subclasses") @abstractmethod - def set_balance(self, address, balance): + def set_balance(self, address: Address, balance: int) -> None: raise NotImplementedError("Must be implemented by subclasses") - def delta_balance(self, address, delta): + def delta_balance(self, address: Address, delta: int) -> None: self.set_balance(address, self.get_balance(address) + delta) # # Code # @abstractmethod - def set_code(self, address, code): + def set_code(self, address: Address, code: bytes) -> None: raise NotImplementedError("Must be implemented by subclasses") @abstractmethod - def get_code(self, address): + def get_code(self, address: Address) -> bytes: raise NotImplementedError("Must be implemented by subclasses") @abstractmethod - def get_code_hash(self, address): + def get_code_hash(self, address: Address) -> Hash32: raise NotImplementedError("Must be implemented by subclasses") @abstractmethod - def delete_code(self, address): + def delete_code(self, address: Address) -> None: raise NotImplementedError("Must be implemented by subclasses") # # Account Methods # @abstractmethod - def account_is_empty(self, address): + def account_is_empty(self, address: Address) -> bool: raise NotImplementedError("Must be implemented by subclass") # @@ -177,7 +180,7 @@ class AccountDB(BaseAccountDB): logger = cast(TraceLogger, logging.getLogger('eth.db.account.AccountDB')) - def __init__(self, db, state_root=BLANK_ROOT_HASH): + def __init__(self, db: BaseDB, state_root: Hash32=BLANK_ROOT_HASH) -> None: r""" Internal implementation details (subject to rapid change): Database entries go through several pipes, like so... @@ -225,11 +228,11 @@ def __init__(self, db, state_root=BLANK_ROOT_HASH): self._journaltrie = JournalDB(self._trie_cache) @property - def state_root(self): + def state_root(self) -> Hash32: return self._trie.root_hash @state_root.setter - def state_root(self, value): + def state_root(self, value: Hash32) -> None: self._trie_cache.reset_cache() self._trie.root_hash = value @@ -239,7 +242,7 @@ def has_root(self, state_root: bytes) -> bool: # # Storage # - def get_storage(self, address, slot, from_journal=True): + def get_storage(self, address: Address, slot: int, from_journal: bool=True) -> int: validate_canonical_address(address, title="Storage Address") validate_uint256(slot, title="Storage Slot") @@ -254,7 +257,7 @@ def get_storage(self, address, slot, from_journal=True): else: return 0 - def set_storage(self, address, slot, value): + def set_storage(self, address: Address, slot: int, value: int) -> None: validate_uint256(value, title="Storage Value") validate_uint256(slot, title="Storage Slot") validate_canonical_address(address, title="Storage Address") @@ -272,7 +275,7 @@ def set_storage(self, address, slot, value): self._set_account(address, account.copy(storage_root=storage.root_hash)) - def delete_storage(self, address): + def delete_storage(self, address: Address) -> None: validate_canonical_address(address, title="Storage Address") account = self._get_account(address) @@ -281,13 +284,13 @@ def delete_storage(self, address): # # Balance # - def get_balance(self, address): + def get_balance(self, address: Address) -> int: validate_canonical_address(address, title="Storage Address") account = self._get_account(address) return account.balance - def set_balance(self, address, balance): + def set_balance(self, address: Address, balance: int) -> None: validate_canonical_address(address, title="Storage Address") validate_uint256(balance, title="Account Balance") @@ -297,27 +300,27 @@ def set_balance(self, address, balance): # # Nonce # - def get_nonce(self, address): + def get_nonce(self, address: Address) -> int: validate_canonical_address(address, title="Storage Address") account = self._get_account(address) return account.nonce - def set_nonce(self, address, nonce): + def set_nonce(self, address: Address, nonce: int) -> None: validate_canonical_address(address, title="Storage Address") validate_uint256(nonce, title="Nonce") account = self._get_account(address) self._set_account(address, account.copy(nonce=nonce)) - def increment_nonce(self, address): + def increment_nonce(self, address: Address) -> None: current_nonce = self.get_nonce(address) self.set_nonce(address, current_nonce + 1) # # Code # - def get_code(self, address): + def get_code(self, address: Address) -> bytes: validate_canonical_address(address, title="Storage Address") try: @@ -325,7 +328,7 @@ def get_code(self, address): except KeyError: return b"" - def set_code(self, address, code): + def set_code(self, address: Address, code: bytes) -> None: validate_canonical_address(address, title="Storage Address") validate_is_bytes(code, title="Code") @@ -335,13 +338,13 @@ def set_code(self, address, code): self._journaldb[code_hash] = code self._set_account(address, account.copy(code_hash=code_hash)) - def get_code_hash(self, address): + def get_code_hash(self, address: Address) -> Hash32: validate_canonical_address(address, title="Storage Address") account = self._get_account(address) return account.code_hash - def delete_code(self, address): + def delete_code(self, address: Address) -> None: validate_canonical_address(address, title="Storage Address") account = self._get_account(address) @@ -350,32 +353,32 @@ def delete_code(self, address): # # Account Methods # - def account_has_code_or_nonce(self, address): + def account_has_code_or_nonce(self, address: Address) -> bool: return self.get_nonce(address) != 0 or self.get_code_hash(address) != EMPTY_SHA3 - def delete_account(self, address): + def delete_account(self, address: Address) -> None: validate_canonical_address(address, title="Storage Address") del self._journaltrie[address] - def account_exists(self, address): + def account_exists(self, address: Address) -> bool: validate_canonical_address(address, title="Storage Address") return self._journaltrie.get(address, b'') != b'' - def touch_account(self, address): + def touch_account(self, address: Address) -> None: validate_canonical_address(address, title="Storage Address") account = self._get_account(address) self._set_account(address, account) - def account_is_empty(self, address): + def account_is_empty(self, address: Address) -> bool: return not self.account_has_code_or_nonce(address) and self.get_balance(address) == 0 # # Internal # - def _get_account(self, address, from_journal=True): + def _get_account(self, address: Address, from_journal: bool=True) -> Account: rlp_account = (self._journaltrie if from_journal else self._trie_cache).get(address, b'') if rlp_account: account = rlp.decode(rlp_account, sedes=Account) @@ -383,7 +386,7 @@ def _get_account(self, address, from_journal=True): account = Account() return account - def _set_account(self, address, account): + def _set_account(self, address: Address, account: Account) -> None: rlp_account = rlp.encode(account, sedes=Account) self._journaltrie[address] = rlp_account @@ -424,7 +427,7 @@ def _log_pending_accounts(self) -> None: continue else: accounts_displayed.add(address) - account = self._get_account(address) + account = self._get_account(Address(address)) self.logger.trace( "Account %s: balance %d, nonce %d, storage root %s, code hash %s", encode_hex(address), diff --git a/eth/db/atomic.py b/eth/db/atomic.py index 2d9fafacd7..48899b3ac4 100644 --- a/eth/db/atomic.py +++ b/eth/db/atomic.py @@ -1,6 +1,9 @@ from contextlib import contextmanager import logging -from typing import Generator +from typing import ( + Generator, + Iterator, +) from eth_utils import ( ValidationError, @@ -110,14 +113,14 @@ def _exists(self, key: bytes) -> bool: @classmethod @contextmanager - def _commit_unless_raises(cls, write_target_db): + def _commit_unless_raises(cls, write_target_db: BaseDB) -> Iterator['AtomicDBWriteBatch']: """ Commit all writes inside the context, unless an exception was raised. Although this is technically an external API, it (and this whole class) is only intended to be used by AtomicDB. """ - readable_write_batch = cls(write_target_db) + readable_write_batch = cls(write_target_db) # type: AtomicDBWriteBatch try: yield readable_write_batch except Exception: diff --git a/eth/db/backends/base.py b/eth/db/backends/base.py index 88423fc0f4..49ad884f3c 100644 --- a/eth/db/backends/base.py +++ b/eth/db/backends/base.py @@ -6,8 +6,18 @@ MutableMapping, ) +from typing import ( + Any, + TYPE_CHECKING +) + +if TYPE_CHECKING: + MM = MutableMapping[bytes, bytes] +else: + MM = MutableMapping + -class BaseDB(MutableMapping, ABC): +class BaseDB(MM, ABC): """ This is an abstract key/value lookup with all :class:`bytes` values, with some convenience methods for databases. As much as possible, @@ -35,9 +45,10 @@ def set(self, key: bytes, value: bytes) -> None: def exists(self, key: bytes) -> bool: return self.__contains__(key) - def __contains__(self, key): + def __contains__(self, key: bytes) -> bool: # type: ignore # Breaks LSP if hasattr(self, '_exists'): - return self._exists(key) + # Classes which inherit this class would have `_exists` attr + return self._exists(key) # type: ignore else: return super().__contains__(key) @@ -47,10 +58,10 @@ def delete(self, key: bytes) -> None: except KeyError: return None - def __iter__(self): - raise NotImplementedError("By default, DB classes cannot by iterated.") + def __iter__(self) -> None: + raise NotImplementedError("By default, DB classes cannot be iterated.") - def __len__(self): + def __len__(self) -> int: raise NotImplementedError("By default, DB classes cannot return the total number of keys.") @@ -80,5 +91,5 @@ class BaseAtomicDB(BaseDB): # or neither will """ @abstractmethod - def atomic_batch(self): + def atomic_batch(self) -> Any: raise NotImplementedError diff --git a/eth/db/batch.py b/eth/db/batch.py index 23c4da4de2..cfd5407f02 100644 --- a/eth/db/batch.py +++ b/eth/db/batch.py @@ -37,7 +37,7 @@ def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: self.clear() self.logger.exception("Unexpected error occurred during batch update") - def clear(self): + def clear(self) -> None: self._track_diff = DBDiffTracker() def commit(self, apply_deletes: bool = True) -> None: diff --git a/eth/db/cache.py b/eth/db/cache.py index db9e5aaf24..d43f62f239 100644 --- a/eth/db/cache.py +++ b/eth/db/cache.py @@ -8,24 +8,24 @@ class CacheDB(BaseDB): Set and get decoded RLP objects, where the underlying db stores encoded objects. """ - def __init__(self, db, cache_size=2048): + def __init__(self, db: BaseDB, cache_size: int=2048) -> None: self._db = db self._cache_size = cache_size self.reset_cache() - def reset_cache(self): + def reset_cache(self) -> None: self._cached_values = LRU(self._cache_size) - def __getitem__(self, key): + def __getitem__(self, key: bytes) -> bytes: if key not in self._cached_values: self._cached_values[key] = self._db[key] return self._cached_values[key] - def __setitem__(self, key, value): + def __setitem__(self, key: bytes, value: bytes) -> None: self._cached_values[key] = value self._db[key] = value - def __delitem__(self, key): + def __delitem__(self, key: bytes) -> None: if key in self._cached_values: del self._cached_values[key] del self._db[key] diff --git a/eth/db/diff.py b/eth/db/diff.py index ea58f75422..89d238bc60 100644 --- a/eth/db/diff.py +++ b/eth/db/diff.py @@ -2,18 +2,30 @@ Mapping, MutableMapping, ) -from typing import ( # noqa: F401 +from typing import ( Dict, Iterable, Union, + TYPE_CHECKING, ) +from eth.db.backends.base import BaseDB + +if TYPE_CHECKING: + ABC_Mutable_Mapping = MutableMapping[bytes, Union[bytes, 'MissingReason']] + ABC_Mapping = Mapping[bytes, Union[bytes, 'MissingReason']] +else: + ABC_Mutable_Mapping = MutableMapping + ABC_Mapping = Mapping + class MissingReason: - def __init__(self, reason): + def __init__(self, reason: str) -> None: self.reason = reason - def __str__(self, reason): + def __str__(self, reason: str) -> str: # type: ignore + # Ignoring mypy type here because the function signature + # has been overwritten from the traditional `def __str__(self): ...` return "Key is missing because it was {}".format(self.reason) @@ -34,11 +46,11 @@ def __init__(self, missing_key: bytes, reason: MissingReason) -> None: super().__init__(missing_key, reason) @property - def is_deleted(self): + def is_deleted(self) -> bool: return self.reason == DELETED -class DBDiffTracker(MutableMapping): +class DBDiffTracker(ABC_Mutable_Mapping): """ Records changes to a :class:`~eth.db.BaseDB` @@ -53,41 +65,41 @@ class DBDiffTracker(MutableMapping): When it's time to take the tracked changes and write them to your database, get the :class:`DBDiff` with :meth:`DBDiffTracker.diff` and use the attached methods. """ - def __init__(self): - self._changes = {} # type: Dict[bytes, Union[bytes, DiffMissingError]] + def __init__(self) -> None: + self._changes = {} # type: Dict[bytes, Union[bytes, MissingReason]] - def __contains__(self, key): + def __contains__(self, key: bytes) -> bool: # type: ignore # Breaks LSP result = self._changes.get(key, NEVER_INSERTED) return result not in (DELETED, NEVER_INSERTED) - def __getitem__(self, key): + def __getitem__(self, key: bytes) -> bytes: result = self._changes.get(key, NEVER_INSERTED) if result in (DELETED, NEVER_INSERTED): - raise DiffMissingError(key, result) + raise DiffMissingError(key, result) # type: ignore # ignore over cast for perf reasons else: - return result + return result # type: ignore # ignore over cast for perf reasons - def __setitem__(self, key, value): + def __setitem__(self, key: bytes, value: Union[bytes, MissingReason]) -> None: self._changes[key] = value - def __delitem__(self, key): + def __delitem__(self, key: bytes) -> None: # The diff does not have access to any underlying db, # so it cannot check if the key exists before deleting. self._changes[key] = DELETED - def __iter__(self): + def __iter__(self) -> None: raise NotImplementedError( "Cannot iterate through changes, use diff().apply_to(db) to update a database" ) - def __len__(self): + def __len__(self) -> int: return len(self._changes) - def diff(self): + def diff(self) -> 'DBDiff': return DBDiff(dict(self._changes)) -class DBDiff(Mapping): +class DBDiff(ABC_Mapping): """ DBDiff is a read-only view of the updates/inserts and deletes generated when tracking changes with :class:`DBDiffTracker`. @@ -95,30 +107,32 @@ class DBDiff(Mapping): The primary usage is to apply these changes to your underlying database with :meth:`apply_to`. """ - _changes = None # type: Dict[bytes, Union[bytes, DiffMissingError]] + _changes = None # type: Dict[bytes, Union[bytes, MissingReason]] - def __init__(self, changes: Dict[bytes, Union[bytes, DiffMissingError]] = None) -> None: + def __init__(self, changes: Dict[bytes, Union[bytes, MissingReason]] = None) -> None: if changes is None: self._changes = {} else: self._changes = changes - def __getitem__(self, key): + def __getitem__(self, key: bytes) -> bytes: result = self._changes.get(key, NEVER_INSERTED) if result in (DELETED, NEVER_INSERTED): - raise DiffMissingError(key, result) + raise DiffMissingError(key, result) # type: ignore # ignore over cast for perf reasons else: - return result + return result # type: ignore # ignore over cast for perf reasons - def __iter__(self): + def __iter__(self) -> None: raise NotImplementedError( "Cannot iterate through changes, use apply_to(db) to update a database" ) - def __len__(self): + def __len__(self) -> int: return len(self._changes) - def apply_to(self, db: MutableMapping, apply_deletes: bool = True) -> None: + def apply_to(self, + db: Union[BaseDB, ABC_Mutable_Mapping], + apply_deletes: bool = True) -> None: """ Apply the changes in this diff to the given database. You may choose to opt out of deleting any underlying keys. @@ -136,7 +150,7 @@ def apply_to(self, db: MutableMapping, apply_deletes: bool = True) -> None: else: pass else: - db[key] = value + db[key] = value # type: ignore # ignore over cast for perf reasons @classmethod def join(cls, diffs: Iterable['DBDiff']) -> 'DBDiff': diff --git a/eth/db/header.py b/eth/db/header.py index 04ff738024..c9751b3633 100644 --- a/eth/db/header.py +++ b/eth/db/header.py @@ -146,7 +146,7 @@ def _get_canonical_head(cls, db: BaseDB) -> BlockHeader: canonical_head_hash = db[SchemaV1.make_canonical_head_hash_lookup_key()] except KeyError: raise CanonicalHeadNotFound("No canonical head set for this chain") - return cls._get_block_header_by_hash(db, canonical_head_hash) + return cls._get_block_header_by_hash(db, Hash32(canonical_head_hash)) # # Header API diff --git a/eth/db/journal.py b/eth/db/journal.py index e47c575696..bc683b7006 100644 --- a/eth/db/journal.py +++ b/eth/db/journal.py @@ -121,7 +121,7 @@ def commit_changeset(self, changeset_id: uuid.UUID) -> Dict[bytes, bytes]: # # Database API # - def __getitem__(self, key: bytes) -> Union[bytes, DeletedEntry]: + def __getitem__(self, key: bytes) -> Union[bytes, DeletedEntry]: # type: ignore # Breaks LSP """ For key lookups we need to iterate through the changesets in reverse order, returning from the first one in which the key is present. diff --git a/eth/db/keymap.py b/eth/db/keymap.py index 22eeb68dcf..82ea25ccee 100644 --- a/eth/db/keymap.py +++ b/eth/db/keymap.py @@ -2,6 +2,10 @@ abstractmethod, ) +from typing import ( + Any, +) + from eth.db.backends.base import BaseDB @@ -10,7 +14,7 @@ class KeyMapDB(BaseDB): Modify keys when accessing the database, according to the abstract keymap function set in the subclass. """ - def __init__(self, db): + def __init__(self, db: BaseDB) -> None: self._db = db @staticmethod @@ -18,26 +22,26 @@ def __init__(self, db): def keymap(key: bytes) -> bytes: raise NotImplementedError - def __getitem__(self, key): + def __getitem__(self, key: bytes) -> bytes: mapped_key = self.keymap(key) return self._db[mapped_key] - def __setitem__(self, key, val): + def __setitem__(self, key: bytes, val: bytes) -> None: mapped_key = self.keymap(key) self._db[mapped_key] = val - def __delitem__(self, key): + def __delitem__(self, key: bytes) -> None: mapped_key = self.keymap(key) del self._db[mapped_key] - def __contains__(self, key): + def __contains__(self, key: bytes) -> bool: # type: ignore # Breaks LSP mapped_key = self.keymap(key) return mapped_key in self._db - def __getattr__(self, attr): + def __getattr__(self, attr: Any) -> Any: return getattr(self._db, attr) - def __setattr__(self, attr, val): + def __setattr__(self, attr: Any, val: Any) -> None: if attr in ('_db', 'keymap'): super().__setattr__(attr, val) else: diff --git a/eth/estimators/__init__.py b/eth/estimators/__init__.py index f879ab0107..1d03990dc2 100644 --- a/eth/estimators/__init__.py +++ b/eth/estimators/__init__.py @@ -1,19 +1,26 @@ import os from typing import ( Callable, - cast + cast, ) - -from eth_utils import import_string - -from eth.rlp.transactions import BaseTransaction -from eth.vm.state import BaseState +from eth.typing import ( + BaseOrSpoofTransaction, +) +from eth.utils.module_loading import ( + import_string, +) +from eth.vm.state import ( + BaseState, +) -def get_gas_estimator() -> Callable[[BaseState, BaseTransaction], int]: +def get_gas_estimator() -> Callable[[BaseState, BaseOrSpoofTransaction], int]: import_path = os.environ.get( 'GAS_ESTIMATOR_BACKEND_FUNC', 'eth.estimators.gas.binary_gas_search_intrinsic_tolerance', ) - return cast(Callable[[BaseState, BaseTransaction], int], import_string(import_path)) + return cast( + Callable[[BaseState, BaseOrSpoofTransaction], int], + import_string(import_path) + ) diff --git a/eth/estimators/gas.py b/eth/estimators/gas.py index 43eb3f7088..ef5a250ab9 100644 --- a/eth/estimators/gas.py +++ b/eth/estimators/gas.py @@ -1,14 +1,29 @@ +from typing import ( + Optional, +) + from cytoolz import ( curry, ) +from eth.exceptions import ( + VMError, +) + +from eth.rlp.transactions import ( + BaseTransaction, +) from eth.utils.spoof import ( SpoofTransaction, ) +from eth.vm.state import ( + BaseState, +) + -def _get_computation_error(state, transaction): +def _get_computation_error(state: BaseState, transaction: SpoofTransaction) -> Optional[VMError]: snapshot = state.snapshot() @@ -24,7 +39,7 @@ def _get_computation_error(state, transaction): @curry -def binary_gas_search(state, transaction, tolerance=1): +def binary_gas_search(state: BaseState, transaction: BaseTransaction, tolerance: int=1) -> int: """ Run the transaction with various gas limits, progressively approaching the minimum needed to succeed without an OutOfGas exception. diff --git a/eth/precompiles/ecadd.py b/eth/precompiles/ecadd.py index 5debf2b2b7..27938b0e24 100644 --- a/eth/precompiles/ecadd.py +++ b/eth/precompiles/ecadd.py @@ -1,3 +1,5 @@ +from typing import Tuple + from py_ecc import ( optimized_bn128 as bn128, ) @@ -21,8 +23,12 @@ pad32r, ) +from eth.vm.computation import ( + BaseComputation, +) + -def ecadd(computation): +def ecadd(computation: BaseComputation) -> BaseComputation: computation.consume_gas(constants.GAS_ECADD, reason='ECADD Precompile') try: @@ -39,7 +45,7 @@ def ecadd(computation): return computation -def _ecadd(data): +def _ecadd(data: bytes) -> Tuple[bn128.FQ, bn128.FQ]: x1_bytes = pad32r(data[:32]) y1_bytes = pad32r(data[32:64]) x2_bytes = pad32r(data[64:96]) diff --git a/eth/precompiles/ecmul.py b/eth/precompiles/ecmul.py index 30951e6d26..0df5238c5c 100644 --- a/eth/precompiles/ecmul.py +++ b/eth/precompiles/ecmul.py @@ -1,3 +1,5 @@ +from typing import Tuple + from py_ecc import ( optimized_bn128 as bn128, ) @@ -21,8 +23,12 @@ pad32r, ) +from eth.vm.computation import ( + BaseComputation, +) + -def ecmul(computation): +def ecmul(computation: BaseComputation) -> BaseComputation: computation.consume_gas(constants.GAS_ECMUL, reason='ECMUL Precompile') try: @@ -39,7 +45,7 @@ def ecmul(computation): return computation -def _ecmull(data): +def _ecmull(data: bytes) -> Tuple[bn128.FQ, bn128.FQ]: x_bytes = pad32r(data[:32]) y_bytes = pad32r(data[32:64]) m_bytes = pad32r(data[64:96]) diff --git a/eth/precompiles/ecpairing.py b/eth/precompiles/ecpairing.py index c9ce0b2df3..d7628570fc 100644 --- a/eth/precompiles/ecpairing.py +++ b/eth/precompiles/ecpairing.py @@ -1,3 +1,5 @@ +from typing import Tuple + from cytoolz import ( curry, pipe, @@ -16,6 +18,7 @@ from eth.exceptions import ( VMError, ) + from eth.utils.bn128 import ( validate_point, ) @@ -23,12 +26,16 @@ pad32, ) +from eth.vm.computation import ( + BaseComputation, +) + ZERO = (bn128.FQ2.one(), bn128.FQ2.one(), bn128.FQ2.zero()) EXPONENT = bn128.FQ12.one() -def ecpairing(computation): +def ecpairing(computation: BaseComputation) -> BaseComputation: if len(computation.msg.data) % 192: # data length must be an exact multiple of 192 raise VMError("Invalid ECPAIRING parameters") @@ -52,7 +59,7 @@ def ecpairing(computation): return computation -def _ecpairing(data): +def _ecpairing(data: bytes) -> bool: exponent = bn128.FQ12.one() processing_pipeline = ( @@ -67,7 +74,7 @@ def _ecpairing(data): @curry -def _process_point(data_buffer, exponent): +def _process_point(data_buffer: bytes, exponent: int) -> int: x1, y1, x2_i, x2_r, y2_i, y2_r = _extract_point(data_buffer) p1 = validate_point(x1, y1) @@ -91,7 +98,7 @@ def _process_point(data_buffer, exponent): return exponent * bn128.pairing(p2, p1, final_exponentiate=False) -def _extract_point(data_slice): +def _extract_point(data_slice: bytes) -> Tuple[int, int, int, int, int, int]: x1_bytes = data_slice[:32] y1_bytes = data_slice[32:64] x2_i_bytes = data_slice[64:96] diff --git a/eth/precompiles/ecrecover.py b/eth/precompiles/ecrecover.py index e860c201dc..d04ec02b24 100644 --- a/eth/precompiles/ecrecover.py +++ b/eth/precompiles/ecrecover.py @@ -10,19 +10,22 @@ from eth import constants +from eth.utils.padding import ( + pad32, + pad32r, +) + from eth.validation import ( validate_lt_secpk1n, validate_gte, validate_lte, ) - -from eth.utils.padding import ( - pad32, - pad32r, +from eth.vm.computation import ( + BaseComputation, ) -def ecrecover(computation): +def ecrecover(computation: BaseComputation) -> BaseComputation: computation.consume_gas(constants.GAS_ECRECOVER, reason="ECRecover Precompile") raw_message_hash = computation.msg.data[:32] message_hash = pad32r(raw_message_hash) diff --git a/eth/precompiles/identity.py b/eth/precompiles/identity.py index 111cf44663..a2661559c0 100644 --- a/eth/precompiles/identity.py +++ b/eth/precompiles/identity.py @@ -3,8 +3,12 @@ ceil32, ) +from eth.vm.computation import ( + BaseComputation, +) + -def identity(computation): +def identity(computation: BaseComputation) -> BaseComputation: word_count = ceil32(len(computation.msg.data)) // 32 gas_fee = constants.GAS_IDENTITY + word_count * constants.GAS_IDENTITYWORD diff --git a/eth/precompiles/modexp.py b/eth/precompiles/modexp.py index 8eafbb0b79..d6d0772653 100644 --- a/eth/precompiles/modexp.py +++ b/eth/precompiles/modexp.py @@ -1,3 +1,7 @@ +from typing import ( + Tuple, +) + from eth_utils import ( big_endian_to_int, int_to_big_endian, @@ -14,8 +18,13 @@ zpad_left, ) +from eth.vm.computation import ( + BaseComputation, +) + -def _compute_adjusted_exponent_length(exponent_length, first_32_exponent_bytes): +def _compute_adjusted_exponent_length(exponent_length: int, + first_32_exponent_bytes: bytes) -> int: exponent = big_endian_to_int(first_32_exponent_bytes) if exponent_length <= 32 and exponent == 0: @@ -30,7 +39,7 @@ def _compute_adjusted_exponent_length(exponent_length, first_32_exponent_bytes): ) -def _compute_complexity(length): +def _compute_complexity(length: int) -> int: if length <= 64: return length ** 2 elif length <= 1024: @@ -41,7 +50,7 @@ def _compute_complexity(length): return length ** 2 // 16 + 480 * length - 199680 -def _extract_lengths(data): +def _extract_lengths(data: bytes) -> Tuple[int, int, int]: # extract argument lengths base_length_bytes = pad32r(data[:32]) base_length = big_endian_to_int(base_length_bytes) @@ -55,7 +64,7 @@ def _extract_lengths(data): return base_length, exponent_length, modulus_length -def _compute_modexp_gas_fee(data): +def _compute_modexp_gas_fee(data: bytes) -> int: base_length, exponent_length, modulus_length = _extract_lengths(data) first_32_exponent_bytes = zpad_right( @@ -76,7 +85,7 @@ def _compute_modexp_gas_fee(data): return gas_fee -def _modexp(data): +def _modexp(data: bytes) -> int: base_length, exponent_length, modulus_length = _extract_lengths(data) if base_length == 0: @@ -112,7 +121,7 @@ def _modexp(data): return result -def modexp(computation): +def modexp(computation: BaseComputation) -> BaseComputation: """ https://github.com/ethereum/EIPs/pull/198 """ diff --git a/eth/precompiles/ripemd160.py b/eth/precompiles/ripemd160.py index f906a7c4cf..3b692507b4 100644 --- a/eth/precompiles/ripemd160.py +++ b/eth/precompiles/ripemd160.py @@ -8,9 +8,12 @@ from eth.utils.padding import ( pad32, ) +from eth.vm.computation import ( + BaseComputation, +) -def ripemd160(computation): +def ripemd160(computation: BaseComputation) -> BaseComputation: word_count = ceil32(len(computation.msg.data)) // 32 gas_fee = constants.GAS_RIPEMD160 + word_count * constants.GAS_RIPEMD160WORD diff --git a/eth/precompiles/sha256.py b/eth/precompiles/sha256.py index a835e5a673..933f9f6ca9 100644 --- a/eth/precompiles/sha256.py +++ b/eth/precompiles/sha256.py @@ -6,8 +6,12 @@ ceil32, ) +from eth.vm.computation import ( + BaseComputation, +) + -def sha256(computation): +def sha256(computation: BaseComputation) -> BaseComputation: word_count = ceil32(len(computation.msg.data)) // 32 gas_fee = constants.GAS_SHA256 + word_count * constants.GAS_SHA256WORD diff --git a/eth/rlp/headers.py b/eth/rlp/headers.py index 6382582c47..2bc39e7500 100644 --- a/eth/rlp/headers.py +++ b/eth/rlp/headers.py @@ -108,22 +108,22 @@ def __init__(self, nonce: bytes=GENESIS_NONCE) -> None: ... - def __init__(self, # noqa: F811 - difficulty, - block_number, - gas_limit, - timestamp=None, - coinbase=ZERO_ADDRESS, - parent_hash=ZERO_HASH32, - uncles_hash=EMPTY_UNCLE_HASH, - state_root=BLANK_ROOT_HASH, - transaction_root=BLANK_ROOT_HASH, - receipt_root=BLANK_ROOT_HASH, - bloom=0, - gas_used=0, - extra_data=b'', - mix_hash=ZERO_HASH32, - nonce=GENESIS_NONCE): + def __init__(self, # type: ignore # noqa: F811 + difficulty: int, + block_number: int, + gas_limit: int, + timestamp: int=None, + coinbase: Address=ZERO_ADDRESS, + parent_hash: Hash32=ZERO_HASH32, + uncles_hash: Hash32=EMPTY_UNCLE_HASH, + state_root: Hash32=BLANK_ROOT_HASH, + transaction_root: Hash32=BLANK_ROOT_HASH, + receipt_root: Hash32=BLANK_ROOT_HASH, + bloom: int=0, + gas_used: int=0, + extra_data: bytes=b'', + mix_hash: Hash32=ZERO_HASH32, + nonce: bytes=GENESIS_NONCE) -> None: if timestamp is None: timestamp = int(time.time()) super().__init__( @@ -163,7 +163,7 @@ def mining_hash(self) -> Hash32: return keccak(rlp.encode(self[:-2], MiningHeader)) @property - def hex_hash(self): + def hex_hash(self) -> str: return encode_hex(self.hash) @classmethod diff --git a/eth/rlp/logs.py b/eth/rlp/logs.py index adac2e3733..99766ea46e 100644 --- a/eth/rlp/logs.py +++ b/eth/rlp/logs.py @@ -4,7 +4,10 @@ binary, ) -from typing import List +from typing import ( + List, + Tuple, +) from .sedes import ( address, @@ -23,7 +26,7 @@ def __init__(self, address: bytes, topics: List[int], data: bytes) -> None: super().__init__(address, topics, data) @property - def bloomables(self): + def bloomables(self) -> Tuple[bytes, ...]: return ( self.address, ) + tuple( diff --git a/eth/typing.py b/eth/typing.py index 4ec64f866f..eef187652b 100644 --- a/eth/typing.py +++ b/eth/typing.py @@ -2,11 +2,14 @@ Any, Callable, Dict, + Generic, Iterable, List, NewType, Tuple, Union, + TypeVar, + TYPE_CHECKING, ) from eth_typing import ( @@ -17,6 +20,14 @@ TypedDict, ) +if TYPE_CHECKING: + from eth.rlp.transactions import ( # noqa: F401 + BaseTransaction + ) + from eth.utils.spoof import ( # noqa: F401 + SpoofTransaction + ) + # TODO: Move into eth_typing @@ -30,6 +41,8 @@ AccountDiff = Iterable[Tuple[Address, str, Union[int, bytes], Union[int, bytes]]] +BaseOrSpoofTransaction = Union['BaseTransaction', 'SpoofTransaction'] + GeneralState = Union[ AccountState, List[Tuple[Address, Dict[str, Union[int, bytes, Dict[int, int]]]]] @@ -52,3 +65,18 @@ VRS = NewType("VRS", Tuple[int, int, int]) IntConvertible = Union[int, bytes, HexStr, str] + + +TFunc = TypeVar('TFunc') + + +class StaticMethod(Generic[TFunc]): + """ + A property class purely to convince mypy to let us assign a function to an + instance variable. See more at: https://github.com/python/mypy/issues/708#issuecomment-405812141 + """ + def __get__(self, oself: Any, owner: Any) -> TFunc: + return self._func + + def __set__(self, oself: Any, value: TFunc) -> None: + self._func = value diff --git a/eth/vm/base.py b/eth/vm/base.py index 71cfcf7774..41d582a017 100644 --- a/eth/vm/base.py +++ b/eth/vm/base.py @@ -127,6 +127,14 @@ def execute_bytecode(self, code_address: Address=None) -> BaseComputation: raise NotImplementedError("VM classes must implement this method") + @abstractmethod + def apply_all_transactions( + self, + transactions: Tuple[BaseTransaction, ...], + base_header: BlockHeader + ) -> Tuple[BlockHeader, Tuple[Receipt, ...], Tuple[BaseComputation, ...]]: + raise NotImplementedError("VM classes must implement this method") + @abstractmethod def make_receipt(self, base_header: BlockHeader, @@ -452,9 +460,11 @@ def execute_bytecode(self, transaction_context, ) - def apply_all_transactions(self, - transactions: Tuple[BaseTransaction, ...], - base_header: BlockHeader) -> Tuple[BlockHeader, Tuple[Receipt, ...], Tuple[BaseComputation, ...]]: # noqa: E501 + def apply_all_transactions( + self, + transactions: Tuple[BaseTransaction, ...], + base_header: BlockHeader + ) -> Tuple[BlockHeader, Tuple[Receipt, ...], Tuple[BaseComputation, ...]]: """ Determine the results of applying all transactions to the base header. This does *not* update the current block or header of the VM. diff --git a/eth/vm/computation.py b/eth/vm/computation.py index 7c9fc756df..6a9c7c37a4 100644 --- a/eth/vm/computation.py +++ b/eth/vm/computation.py @@ -115,7 +115,7 @@ class BaseComputation(Configurable, ABC): # VM configuration opcodes = None # type: Dict[int, Any] - _precompiles = None # type: Dict[Address, Callable[['BaseComputation'], Any]] + _precompiles = None # type: Dict[Address, Callable[['BaseComputation'], 'BaseComputation']] logger = cast(TraceLogger, logging.getLogger('eth.vm.computation.Computation')) diff --git a/eth/vm/forks/frontier/computation.py b/eth/vm/forks/frontier/computation.py index bc2fa9014e..105d5c03ac 100644 --- a/eth/vm/forks/frontier/computation.py +++ b/eth/vm/forks/frontier/computation.py @@ -42,7 +42,7 @@ class FrontierComputation(BaseComputation): """ # Override opcodes = FRONTIER_OPCODES - _precompiles = FRONTIER_PRECOMPILES + _precompiles = FRONTIER_PRECOMPILES # type: ignore # https://github.com/python/mypy/issues/708 # noqa: E501 def apply_message(self) -> BaseComputation: snapshot = self.state.snapshot() diff --git a/eth/vm/forks/frontier/state.py b/eth/vm/forks/frontier/state.py index be29ceb322..27ffcad1bc 100644 --- a/eth/vm/forks/frontier/state.py +++ b/eth/vm/forks/frontier/state.py @@ -1,5 +1,5 @@ from __future__ import absolute_import -from typing import Type # noqa: F401 +from typing import Type, Union # noqa: F401 from eth_hash.auto import keccak from eth_utils import ( @@ -14,10 +14,9 @@ ContractCreationCollision, ) -from eth.rlp.transactions import ( - BaseTransaction, +from eth.typing import ( + BaseOrSpoofTransaction, ) - from eth.utils.address import ( generate_contract_address, ) @@ -45,7 +44,7 @@ class FrontierTransactionExecutor(BaseTransactionExecutor): - def validate_transaction(self, transaction: BaseTransaction) -> BaseTransaction: + def validate_transaction(self, transaction: BaseOrSpoofTransaction) -> BaseOrSpoofTransaction: # Validate the transaction transaction.validate() @@ -53,7 +52,7 @@ def validate_transaction(self, transaction: BaseTransaction) -> BaseTransaction: return transaction - def build_evm_message(self, transaction: BaseTransaction) -> Message: + def build_evm_message(self, transaction: BaseOrSpoofTransaction) -> Message: gas_fee = transaction.gas * transaction.gas_price @@ -105,7 +104,9 @@ def build_evm_message(self, transaction: BaseTransaction) -> Message: ) return message - def build_computation(self, message: Message, transaction: BaseTransaction) -> BaseComputation: + def build_computation(self, + message: Message, + transaction: BaseOrSpoofTransaction) -> BaseComputation: """Apply the message to the VM.""" transaction_context = self.vm_state.get_transaction_context(transaction) if message.is_create: @@ -139,7 +140,7 @@ def build_computation(self, message: Message, transaction: BaseTransaction) -> B return computation def finalize_computation(self, - transaction: BaseTransaction, + transaction: BaseOrSpoofTransaction, computation: BaseComputation) -> BaseComputation: # Self Destruct Refunds num_deletions = len(computation.get_accounts_for_deletion()) @@ -192,9 +193,9 @@ class FrontierState(BaseState): account_db_class = AccountDB # Type[BaseAccountDB] transaction_executor = FrontierTransactionExecutor # Type[BaseTransactionExecutor] - def validate_transaction(self, transaction: BaseTransaction) -> None: + def validate_transaction(self, transaction: BaseOrSpoofTransaction) -> None: validate_frontier_transaction(self.account_db, transaction) - def execute_transaction(self, transaction: BaseTransaction) -> BaseTransactionExecutor: + def execute_transaction(self, transaction: BaseOrSpoofTransaction) -> BaseTransactionExecutor: executor = self.get_transaction_executor() return executor(transaction) diff --git a/eth/vm/forks/frontier/validation.py b/eth/vm/forks/frontier/validation.py index ea85bfad39..898f25b894 100644 --- a/eth/vm/forks/frontier/validation.py +++ b/eth/vm/forks/frontier/validation.py @@ -5,12 +5,18 @@ from eth.db.account import BaseAccountDB from eth.rlp.headers import BlockHeader + from eth.rlp.transactions import BaseTransaction +from eth.typing import ( + BaseOrSpoofTransaction +) + from eth.vm.base import BaseVM -def validate_frontier_transaction(account_db: BaseAccountDB, transaction: BaseTransaction) -> None: +def validate_frontier_transaction(account_db: BaseAccountDB, + transaction: BaseOrSpoofTransaction) -> None: gas_cost = transaction.gas * transaction.gas_price sender_balance = account_db.get_balance(transaction.sender) diff --git a/eth/vm/forks/homestead/state.py b/eth/vm/forks/homestead/state.py index 6dccd70a4e..baf39dbb3e 100644 --- a/eth/vm/forks/homestead/state.py +++ b/eth/vm/forks/homestead/state.py @@ -1,4 +1,6 @@ -from eth.rlp.transactions import BaseTransaction +from eth.typing import ( + BaseOrSpoofTransaction, +) from eth.vm.forks.frontier.state import ( FrontierState, @@ -12,7 +14,7 @@ class HomesteadState(FrontierState): computation_class = HomesteadComputation - def validate_transaction(self, transaction: BaseTransaction) -> None: + def validate_transaction(self, transaction: BaseOrSpoofTransaction) -> None: validate_homestead_transaction(self.account_db, transaction) diff --git a/eth/vm/forks/homestead/validation.py b/eth/vm/forks/homestead/validation.py index add2be3c68..3938d9427d 100644 --- a/eth/vm/forks/homestead/validation.py +++ b/eth/vm/forks/homestead/validation.py @@ -8,14 +8,15 @@ from eth.db.account import BaseAccountDB -from eth.rlp.transactions import BaseTransaction +from eth.typing import BaseOrSpoofTransaction from eth.vm.forks.frontier.validation import ( validate_frontier_transaction, ) -def validate_homestead_transaction(account_db: BaseAccountDB, transaction: BaseTransaction) -> None: +def validate_homestead_transaction(account_db: BaseAccountDB, + transaction: BaseOrSpoofTransaction) -> None: if transaction.s > SECPK1_N // 2 or transaction.s == 0: raise ValidationError("Invalid signature S value") diff --git a/eth/vm/forks/spurious_dragon/state.py b/eth/vm/forks/spurious_dragon/state.py index 8bc01fd903..36909e0d1c 100644 --- a/eth/vm/forks/spurious_dragon/state.py +++ b/eth/vm/forks/spurious_dragon/state.py @@ -1,9 +1,11 @@ -from eth.rlp.transactions import BaseTransaction - from eth_utils import ( encode_hex, ) +from eth.typing import ( + BaseOrSpoofTransaction, +) + from eth.vm.computation import BaseComputation from eth.vm.forks.homestead.state import ( @@ -17,7 +19,7 @@ class SpuriousDragonTransactionExecutor(HomesteadTransactionExecutor): def finalize_computation(self, - transaction: BaseTransaction, + transaction: BaseOrSpoofTransaction, computation: BaseComputation) -> BaseComputation: computation = super().finalize_computation(transaction, computation) diff --git a/eth/vm/state.py b/eth/vm/state.py index 056177ec94..a927ae20e0 100644 --- a/eth/vm/state.py +++ b/eth/vm/state.py @@ -11,6 +11,7 @@ Tuple, Type, TYPE_CHECKING, + Union, ) from uuid import UUID @@ -34,6 +35,9 @@ from eth.tools.logging import ( TraceLogger, ) +from eth.typing import ( + BaseOrSpoofTransaction, +) from eth.utils.datatypes import ( Configurable, ) @@ -49,6 +53,7 @@ from eth.rlp.transactions import ( # noqa: F401 BaseTransaction, ) + from eth.vm.transaction_context import ( # noqa: F401 BaseTransactionContext, ) @@ -250,7 +255,8 @@ def apply_transaction(self, transaction: 'BaseTransaction') -> Tuple[bytes, 'Bas def get_transaction_executor(self) -> 'BaseTransactionExecutor': return self.transaction_executor(self) - def costless_execute_transaction(self, transaction: 'BaseTransaction') -> 'BaseComputation': + def costless_execute_transaction(self, + transaction: BaseOrSpoofTransaction) -> 'BaseComputation': with self.override_transaction_context(gas_price=transaction.gas_price): free_transaction = transaction.copy(gas_price=0) return self.execute_transaction(free_transaction) @@ -270,15 +276,16 @@ def get_custom_transaction_context(transaction: 'BaseTransaction') -> 'BaseTrans self.get_transaction_context = original_context # type: ignore # Remove ignore if https://github.com/python/mypy/issues/708 is fixed. # noqa: E501 @abstractmethod - def execute_transaction(self, transaction: 'BaseTransaction') -> 'BaseComputation': + def execute_transaction(self, transaction: BaseOrSpoofTransaction) -> 'BaseComputation': raise NotImplementedError() @abstractmethod - def validate_transaction(self, transaction: 'BaseTransaction') -> None: + def validate_transaction(self, transaction: BaseOrSpoofTransaction) -> None: raise NotImplementedError @classmethod - def get_transaction_context(cls, transaction: 'BaseTransaction') -> 'BaseTransactionContext': + def get_transaction_context(cls, + transaction: BaseOrSpoofTransaction) -> 'BaseTransactionContext': return cls.get_transaction_context_class()( gas_price=transaction.gas_price, origin=transaction.sender, @@ -289,7 +296,7 @@ class BaseTransactionExecutor(ABC): def __init__(self, vm_state: BaseState) -> None: self.vm_state = vm_state - def __call__(self, transaction: 'BaseTransaction') -> 'BaseComputation': + def __call__(self, transaction: BaseOrSpoofTransaction) -> 'BaseComputation': valid_transaction = self.validate_transaction(transaction) message = self.build_evm_message(valid_transaction) computation = self.build_computation(message, valid_transaction) @@ -297,21 +304,21 @@ def __call__(self, transaction: 'BaseTransaction') -> 'BaseComputation': return finalized_computation @abstractmethod - def validate_transaction(self, transaction: 'BaseTransaction') -> 'BaseTransaction': + def validate_transaction(self, transaction: BaseOrSpoofTransaction) -> BaseOrSpoofTransaction: raise NotImplementedError @abstractmethod - def build_evm_message(self, transaction: 'BaseTransaction') -> Message: + def build_evm_message(self, transaction: BaseOrSpoofTransaction) -> Message: raise NotImplementedError() @abstractmethod def build_computation(self, message: Message, - transaction: 'BaseTransaction') -> 'BaseComputation': + transaction: BaseOrSpoofTransaction) -> 'BaseComputation': raise NotImplementedError() @abstractmethod def finalize_computation(self, - transaction: 'BaseTransaction', + transaction: BaseOrSpoofTransaction, computation: 'BaseComputation') -> 'BaseComputation': raise NotImplementedError() diff --git a/p2p/discovery.py b/p2p/discovery.py index a750d48152..f90682963a 100644 --- a/p2p/discovery.py +++ b/p2p/discovery.py @@ -1084,7 +1084,7 @@ def __repr__(self) -> str: @to_list def _extract_nodes_from_payload( sender: kademlia.Address, - payload: List[Tuple[str, str, str, str]], + payload: List[Tuple[str, str, str, bytes]], logger: TraceLogger) -> Iterator[kademlia.Node]: for item in payload: ip, udp_port, tcp_port, node_id = item diff --git a/tox.ini b/tox.ini index b29ded06d3..cdac687357 100644 --- a/tox.ini +++ b/tox.ini @@ -104,10 +104,7 @@ commands= {[common-lint]commands} flake8 {toxinidir}/tests --exclude="trinity,p2p" # TODO: Drop --ignore-missing-imports once we have type annotations for eth_utils, coincurve and cytoolz - mypy --follow-imports=silent --warn-unused-ignores --ignore-missing-imports --no-strict-optional --check-untyped-defs --disallow-incomplete-defs -p eth - mypy --follow-imports=silent --warn-unused-ignores --ignore-missing-imports --no-strict-optional --check-untyped-defs --disallow-incomplete-defs --disallow-untyped-defs --disallow-any-generics -p eth.utils - mypy --follow-imports=silent --warn-unused-ignores --ignore-missing-imports --no-strict-optional --check-untyped-defs --disallow-incomplete-defs --disallow-untyped-defs --disallow-any-generics -p eth.tools - mypy --follow-imports=silent --warn-unused-ignores --ignore-missing-imports --no-strict-optional --check-untyped-defs --disallow-incomplete-defs --disallow-untyped-defs --disallow-any-generics -p eth.vm + mypy --follow-imports=silent --warn-unused-ignores --ignore-missing-imports --no-strict-optional --check-untyped-defs --disallow-incomplete-defs --disallow-untyped-defs --disallow-any-generics -p eth [testenv:py36-lint] diff --git a/trinity/chains/header.py b/trinity/chains/header.py index 785913591d..3571410f85 100644 --- a/trinity/chains/header.py +++ b/trinity/chains/header.py @@ -7,6 +7,7 @@ from typing import Tuple, Type from eth.db.backends.base import BaseDB +from eth.db.header import BaseHeaderDB from eth.chains.header import ( BaseHeaderChain, HeaderChain, @@ -51,7 +52,7 @@ def from_genesis_header(cls, raise NotImplementedError("Chain classes must implement this method") @classmethod - def get_headerdb_class(cls) -> BaseDB: + def get_headerdb_class(cls) -> Type[BaseHeaderDB]: raise NotImplementedError("Chain classes must implement this method") coro_get_block_header_by_hash = async_method('get_block_header_by_hash') diff --git a/trinity/chains/light.py b/trinity/chains/light.py index ce06ba22d2..54fa69945b 100644 --- a/trinity/chains/light.py +++ b/trinity/chains/light.py @@ -47,8 +47,8 @@ BaseTransaction, BaseUnsignedTransaction, ) -from eth.utils.spoof import ( - SpoofTransaction, +from eth.typing import ( + BaseOrSpoofTransaction, ) from eth.vm.computation import ( BaseComputation @@ -175,8 +175,9 @@ async def coro_get_canonical_block_by_number(self, block_number: BlockNumber) -> def get_canonical_block_hash(self, block_number: BlockNumber) -> Hash32: return self._headerdb.get_canonical_block_hash(block_number) - def build_block_with_transactions( - self, transactions: Tuple[BaseTransaction, ...], parent_header: BlockHeader) -> None: + def build_block_with_transactions(self, + transactions: Tuple[BaseTransaction, ...], + parent_header: BlockHeader=None) -> Tuple[BaseBlock, Tuple[Receipt, ...], Tuple[BaseComputation, ...]]: # noqa: E501 raise NotImplementedError("Chain classes must implement " + inspect.stack()[0][3]) # @@ -208,13 +209,13 @@ def apply_transaction( def get_transaction_result( self, - transaction: Union[BaseTransaction, SpoofTransaction], + transaction: BaseOrSpoofTransaction, at_header: BlockHeader) -> bytes: raise NotImplementedError("Chain classes must implement " + inspect.stack()[0][3]) def estimate_gas( self, - transaction: Union[BaseTransaction, SpoofTransaction], + transaction: BaseOrSpoofTransaction, at_header: BlockHeader=None) -> int: raise NotImplementedError("Chain classes must implement " + inspect.stack()[0][3]) diff --git a/trinity/utils/datastructures.py b/trinity/utils/datastructures.py index d1eba7a873..922101c00b 100644 --- a/trinity/utils/datastructures.py +++ b/trinity/utils/datastructures.py @@ -44,29 +44,20 @@ pipe, ) +from eth.typing import ( + StaticMethod, +) + from trinity.utils.queues import ( queue_get_batch, queue_get_nowait, ) -TFunc = TypeVar('TFunc') TPrerequisite = TypeVar('TPrerequisite', bound=Enum) TTask = TypeVar('TTask') TTaskID = TypeVar('TTaskID') -class StaticMethod(Generic[TFunc]): - """ - A property class purely to convince mypy to let us assign a function to an - instance variable. See more at: https://github.com/python/mypy/issues/708#issuecomment-405812141 - """ - def __get__(self, oself: Any, owner: Any) -> TFunc: - return self._func - - def __set__(self, oself: Any, value: TFunc) -> None: - self._func = value - - @total_ordering class SortableTask(Generic[TTask]): _order_fn: StaticMethod[Callable[[TTask], Any]] = None