diff --git a/core/state/statedb.go b/core/state/statedb.go index b770698255e..8759e8faa5d 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -130,6 +130,8 @@ type StateDB struct { // Transient storage transientStorage transientStorage + // Overrides to apply after Prepare() + pendingTransientOverrides map[common.Address]map[common.Hash]common.Hash // Journal of state modifications. This is the backbone of // Snapshot and RevertToSnapshot. @@ -1430,6 +1432,22 @@ func (s *StateDB) Prepare(rules params.Rules, sender, coinbase common.Address, d } // Reset transient storage at the beginning of transaction execution s.transientStorage = newTransientStorage() + + // Apply any pending transient storage overrides after reset + if s.pendingTransientOverrides != nil { + for addr, storage := range s.pendingTransientOverrides { + for key, value := range storage { + s.transientStorage.Set(addr, key, value) + } + } + s.pendingTransientOverrides = nil // Clear after applying + } +} + +// SetPendingTransientOverrides stores transient storage overrides to be applied +// after the next Prepare() call. This ensures overrides are applied to a clean state. +func (s *StateDB) SetPendingTransientOverrides(overrides map[common.Address]map[common.Hash]common.Hash) { + s.pendingTransientOverrides = overrides } // AddAddressToAccessList adds the given address to the access list diff --git a/ethclient/gethclient/gethclient.go b/ethclient/gethclient/gethclient.go index 54997cbf51a..9dcf3eb3b19 100644 --- a/ethclient/gethclient/gethclient.go +++ b/ethclient/gethclient/gethclient.go @@ -167,6 +167,31 @@ func (ec *Client) CallContractWithBlockOverrides(ctx context.Context, msg ethere return hex, err } +// CallContractWithTransientOverrides executes a message call transaction, which is directly executed +// in the VM of the node, but never mined into the blockchain. +// +// blockNumber selects the block height at which the call runs. It can be nil, in which +// case the code is taken from the latest known block. Note that state from very old +// blocks might not be available. +// +// overrides specifies a map of contract states that should be overwritten before executing +// the message call. +// +// blockOverrides specifies block fields exposed to the EVM that can be overridden for the call. +// +// transientOverrides specifies transient storage slots that should be overwritten before executing +// the message call. Transient storage is reset at the beginning of each transaction. +// +// Please use ethclient.CallContract instead if you don't need the override functionality. +func (ec *Client) CallContractWithTransientOverrides(ctx context.Context, msg ethereum.CallMsg, blockNumber *big.Int, overrides *map[common.Address]OverrideAccount, blockOverrides *BlockOverrides, transientOverrides *TransientOverrides) ([]byte, error) { + var hex hexutil.Bytes + err := ec.c.CallContext( + ctx, &hex, "eth_call", toCallArg(msg), + toBlockNumArg(blockNumber), overrides, blockOverrides, transientOverrides, + ) + return hex, err +} + // GCStats retrieves the current garbage collection stats from a geth node. func (ec *Client) GCStats(ctx context.Context) (*debug.GCStats, error) { var result debug.GCStats @@ -348,6 +373,10 @@ type BlockOverrides struct { BaseFee *big.Int } +// TransientOverrides specifies transient storage slots to override for eth_call. +// The map key is the contract address, and the value is another map of slot to value. +type TransientOverrides map[common.Address]map[common.Hash]common.Hash + func (o BlockOverrides) MarshalJSON() ([]byte, error) { type override struct { Number *hexutil.Big `json:"number,omitempty"` diff --git a/graphql/graphql.go b/graphql/graphql.go index 0b2a77a3c4c..5fabb0a7957 100644 --- a/graphql/graphql.go +++ b/graphql/graphql.go @@ -1189,7 +1189,7 @@ func (c *CallResult) Status() hexutil.Uint64 { func (b *Block) Call(ctx context.Context, args struct { Data ethapi.TransactionArgs }) (*CallResult, error) { - result, err := ethapi.DoCall(ctx, b.r.backend, args.Data, *b.numberOrHash, nil, nil, b.r.backend.RPCEVMTimeout(), b.r.backend.RPCGasCap()) + result, err := ethapi.DoCall(ctx, b.r.backend, args.Data, *b.numberOrHash, nil, nil, nil, b.r.backend.RPCEVMTimeout(), b.r.backend.RPCGasCap()) if err != nil { return nil, err } @@ -1252,7 +1252,7 @@ func (p *Pending) Call(ctx context.Context, args struct { Data ethapi.TransactionArgs }) (*CallResult, error) { pendingBlockNr := rpc.BlockNumberOrHashWithNumber(rpc.PendingBlockNumber) - result, err := ethapi.DoCall(ctx, p.r.backend, args.Data, pendingBlockNr, nil, nil, p.r.backend.RPCEVMTimeout(), p.r.backend.RPCGasCap()) + result, err := ethapi.DoCall(ctx, p.r.backend, args.Data, pendingBlockNr, nil, nil, nil, p.r.backend.RPCEVMTimeout(), p.r.backend.RPCGasCap()) if err != nil { return nil, err } diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index ebb8ece7301..b3e9ee33947 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -669,7 +669,7 @@ func (context *ChainContext) Config() *params.ChainConfig { return context.b.ChainConfig() } -func doCall(ctx context.Context, b Backend, args TransactionArgs, state *state.StateDB, header *types.Header, overrides *override.StateOverride, blockOverrides *override.BlockOverrides, timeout time.Duration, globalGasCap uint64) (*core.ExecutionResult, error) { +func doCall(ctx context.Context, b Backend, args TransactionArgs, state *state.StateDB, header *types.Header, overrides *override.StateOverride, blockOverrides *override.BlockOverrides, transientOverrides *override.TransientStorageOverride, timeout time.Duration, globalGasCap uint64) (*core.ExecutionResult, error) { blockCtx := core.NewEVMBlockContext(header, NewChainContext(ctx, b), nil) if blockOverrides != nil { if err := blockOverrides.Apply(&blockCtx); err != nil { @@ -682,6 +682,9 @@ func doCall(ctx context.Context, b Backend, args TransactionArgs, state *state.S return nil, err } + // Apply transient storage overrides. These will be applied after Prepare() is called. + transientOverrides.Apply(state) + // Setup context so it may be cancelled the call has completed // or, in case of unmetered gas, setup a context with a timeout. var cancel context.CancelFunc @@ -750,14 +753,14 @@ func applyMessageWithEVM(ctx context.Context, evm *vm.EVM, msg *core.Message, ti return result, nil } -func DoCall(ctx context.Context, b Backend, args TransactionArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *override.StateOverride, blockOverrides *override.BlockOverrides, timeout time.Duration, globalGasCap uint64) (*core.ExecutionResult, error) { +func DoCall(ctx context.Context, b Backend, args TransactionArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *override.StateOverride, blockOverrides *override.BlockOverrides, transientOverrides *override.TransientStorageOverride, timeout time.Duration, globalGasCap uint64) (*core.ExecutionResult, error) { defer func(start time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now()) state, header, err := b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) if state == nil || err != nil { return nil, err } - return doCall(ctx, b, args, state, header, overrides, blockOverrides, timeout, globalGasCap) + return doCall(ctx, b, args, state, header, overrides, blockOverrides, transientOverrides, timeout, globalGasCap) } // Call executes the given transaction on the state for the given block number. @@ -766,12 +769,12 @@ func DoCall(ctx context.Context, b Backend, args TransactionArgs, blockNrOrHash // // Note, this function doesn't make and changes in the state/blockchain and is // useful to execute and retrieve values. -func (api *BlockChainAPI) Call(ctx context.Context, args TransactionArgs, blockNrOrHash *rpc.BlockNumberOrHash, overrides *override.StateOverride, blockOverrides *override.BlockOverrides) (hexutil.Bytes, error) { +func (api *BlockChainAPI) Call(ctx context.Context, args TransactionArgs, blockNrOrHash *rpc.BlockNumberOrHash, overrides *override.StateOverride, blockOverrides *override.BlockOverrides, transientOverrides *override.TransientStorageOverride) (hexutil.Bytes, error) { if blockNrOrHash == nil { latest := rpc.BlockNumberOrHashWithNumber(rpc.LatestBlockNumber) blockNrOrHash = &latest } - result, err := DoCall(ctx, api.b, args, *blockNrOrHash, overrides, blockOverrides, api.b.RPCEVMTimeout(), api.b.RPCGasCap()) + result, err := DoCall(ctx, api.b, args, *blockNrOrHash, overrides, blockOverrides, transientOverrides, api.b.RPCEVMTimeout(), api.b.RPCGasCap()) if err != nil { return nil, err } diff --git a/internal/ethapi/api_test.go b/internal/ethapi/api_test.go index c0a8fe9a583..96dec99565f 100644 --- a/internal/ethapi/api_test.go +++ b/internal/ethapi/api_test.go @@ -967,13 +967,14 @@ func TestCall(t *testing.T) { })) randomAccounts := newAccounts(3) var testSuite = []struct { - name string - blockNumber rpc.BlockNumber - overrides override.StateOverride - call TransactionArgs - blockOverrides override.BlockOverrides - expectErr error - want string + name string + blockNumber rpc.BlockNumber + overrides override.StateOverride + call TransactionArgs + blockOverrides override.BlockOverrides + transientOverrides override.TransientStorageOverride + expectErr error + want string }{ // transfer on genesis { @@ -1226,9 +1227,30 @@ func TestCall(t *testing.T) { }, expectErr: errors.New(`block override "withdrawals" is not supported for this RPC method`), }, + // Test transient storage override + { + name: "transient storage override takes effect", + blockNumber: rpc.LatestBlockNumber, + call: TransactionArgs{ + From: &accounts[1].addr, + To: &randomAccounts[2].addr, + }, + overrides: override.StateOverride{ + randomAccounts[2].addr: override.OverrideAccount{ + // PUSH1 0x00 TLOAD PUSH1 0x00 MSTORE PUSH1 0x20 PUSH1 0x00 RETURN + Code: hex2Bytes("0x60005c60005260206000f3"), + }, + }, + transientOverrides: override.TransientStorageOverride{ + randomAccounts[2].addr: map[common.Hash]common.Hash{ + common.Hash{}: common.HexToHash("0xabcd"), + }, + }, + want: "0x000000000000000000000000000000000000000000000000000000000000abcd", + }, } for _, tc := range testSuite { - result, err := api.Call(context.Background(), tc.call, &rpc.BlockNumberOrHash{BlockNumber: &tc.blockNumber}, &tc.overrides, &tc.blockOverrides) + result, err := api.Call(context.Background(), tc.call, &rpc.BlockNumberOrHash{BlockNumber: &tc.blockNumber}, &tc.overrides, &tc.blockOverrides, &tc.transientOverrides) if tc.expectErr != nil { if err == nil { t.Errorf("test %s: want error %v, have nothing", tc.name, tc.expectErr) diff --git a/internal/ethapi/override/override.go b/internal/ethapi/override/override.go index 9d57a78651d..853eecdd93f 100644 --- a/internal/ethapi/override/override.go +++ b/internal/ethapi/override/override.go @@ -33,7 +33,7 @@ import ( // OverrideAccount indicates the overriding fields of account during the execution // of a message call. // Note, state and stateDiff can't be specified at the same time. If state is -// set, message execution will only use the data in the given state. Otherwise +// set, message execution will only use the data in the given state. Otherwise, // if stateDiff is set, all diff will be applied first and then execute the call // message. type OverrideAccount struct { @@ -48,6 +48,9 @@ type OverrideAccount struct { // StateOverride is the collection of overridden accounts. type StateOverride map[common.Address]OverrideAccount +// TransientStorageOverride is the collection of transient storage overrides. +type TransientStorageOverride map[common.Address]map[common.Hash]common.Hash + func (diff *StateOverride) has(address common.Address) bool { _, ok := (*diff)[address] return ok @@ -119,6 +122,14 @@ func (diff *StateOverride) Apply(statedb *state.StateDB, precompiles vm.Precompi return nil } +// Apply stores transient storage overrides to be applied after Prepare(). +func (diff *TransientStorageOverride) Apply(statedb *state.StateDB) { + if diff == nil || len(*diff) == 0 { + return + } + statedb.SetPendingTransientOverrides(*diff) +} + // BlockOverrides is a set of header fields to override. type BlockOverrides struct { Number *hexutil.Big diff --git a/internal/ethapi/override/override_test.go b/internal/ethapi/override/override_test.go index 6feafaac756..9d1ed898dcd 100644 --- a/internal/ethapi/override/override_test.go +++ b/internal/ethapi/override/override_test.go @@ -26,6 +26,7 @@ import ( "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/triedb" ) @@ -128,3 +129,56 @@ func hex2Bytes(str string) *hexutil.Bytes { rpcBytes := hexutil.Bytes(common.FromHex(str)) return &rpcBytes } + +func TestStateOverrideTransientStorage(t *testing.T) { + db := state.NewDatabase(triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil), nil) + statedb, err := state.New(types.EmptyRootHash, db) + if err != nil { + t.Fatalf("failed to create statedb: %v", err) + } + + addr := common.BytesToAddress([]byte{0x1}) + key1 := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000001") + key2 := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000002") + value1 := common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111") + value2 := common.HexToHash("0x2222222222222222222222222222222222222222222222222222222222222222") + + // Verify initial state is empty + if got := statedb.GetTransientState(addr, key1); got != (common.Hash{}) { + t.Fatalf("expected initial transient state to be empty, got %s", got.Hex()) + } + if got := statedb.GetTransientState(addr, key2); got != (common.Hash{}) { + t.Fatalf("expected initial transient state to be empty, got %s", got.Hex()) + } + + // Apply transient storage override + transientOverride := TransientStorageOverride{ + addr: map[common.Hash]common.Hash{ + key1: value1, + key2: value2, + }, + } + + transientOverride.Apply(statedb) + + statedb.Prepare(params.Rules{}, common.Address{}, common.Address{}, nil, nil, nil) + + // Verify transient storage was set + if got := statedb.GetTransientState(addr, key1); got != value1 { + t.Errorf("expected transient state for key1 to be %s, got %s", value1.Hex(), got.Hex()) + } + if got := statedb.GetTransientState(addr, key2); got != value2 { + t.Errorf("expected transient state for key2 to be %s, got %s", value2.Hex(), got.Hex()) + } + + // Verify other addresses/keys remain empty + otherAddr := common.BytesToAddress([]byte{0x2}) + if got := statedb.GetTransientState(otherAddr, key1); got != (common.Hash{}) { + t.Errorf("expected transient state for different address to be empty, got %s", got.Hex()) + } + + otherKey := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000003") + if got := statedb.GetTransientState(addr, otherKey); got != (common.Hash{}) { + t.Errorf("expected transient state for different key to be empty, got %s", got.Hex()) + } +}