Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions core/state/statedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ type StateDB struct {

// Transient storage
transientStorage transientStorage
// Overrides to apply after Prepare()
pendingTransientOverrides map[common.Address]map[common.Hash]common.Hash

// Journal of state modifications. This is the backbone of
// Snapshot and RevertToSnapshot.
Expand Down Expand Up @@ -1430,6 +1432,22 @@ func (s *StateDB) Prepare(rules params.Rules, sender, coinbase common.Address, d
}
// Reset transient storage at the beginning of transaction execution
s.transientStorage = newTransientStorage()

// Apply any pending transient storage overrides after reset
if s.pendingTransientOverrides != nil {
for addr, storage := range s.pendingTransientOverrides {
for key, value := range storage {
s.transientStorage.Set(addr, key, value)
}
}
s.pendingTransientOverrides = nil // Clear after applying
}
}

// SetPendingTransientOverrides stores transient storage overrides to be applied
// after the next Prepare() call. This ensures overrides are applied to a clean state.
func (s *StateDB) SetPendingTransientOverrides(overrides map[common.Address]map[common.Hash]common.Hash) {
s.pendingTransientOverrides = overrides
}

// AddAddressToAccessList adds the given address to the access list
Expand Down
21 changes: 13 additions & 8 deletions ethclient/gethclient/gethclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,21 +301,26 @@ type OverrideAccount struct {

// StateDiff allows overriding individual storage slots.
StateDiff map[common.Hash]common.Hash

// TransientStorage allows overriding transient storage slots.
TransientStorage map[common.Hash]common.Hash
}

func (a OverrideAccount) MarshalJSON() ([]byte, error) {
type acc struct {
Nonce hexutil.Uint64 `json:"nonce,omitempty"`
Code string `json:"code,omitempty"`
Balance *hexutil.Big `json:"balance,omitempty"`
State interface{} `json:"state,omitempty"`
StateDiff map[common.Hash]common.Hash `json:"stateDiff,omitempty"`
Nonce hexutil.Uint64 `json:"nonce,omitempty"`
Code string `json:"code,omitempty"`
Balance *hexutil.Big `json:"balance,omitempty"`
State interface{} `json:"state,omitempty"`
StateDiff map[common.Hash]common.Hash `json:"stateDiff,omitempty"`
TransientStorage map[common.Hash]common.Hash `json:"transientStorage,omitempty"`
}

output := acc{
Nonce: hexutil.Uint64(a.Nonce),
Balance: (*hexutil.Big)(a.Balance),
StateDiff: a.StateDiff,
Nonce: hexutil.Uint64(a.Nonce),
Balance: (*hexutil.Big)(a.Balance),
StateDiff: a.StateDiff,
TransientStorage: a.TransientStorage,
}
if a.Code != nil {
output.Code = hexutil.Encode(a.Code)
Expand Down
19 changes: 19 additions & 0 deletions internal/ethapi/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,25 @@ func TestCall(t *testing.T) {
},
expectErr: errors.New(`block override "withdrawals" is not supported for this RPC method`),
},
// Test transient storage override
{
name: "transient storage override takes effect",
blockNumber: rpc.LatestBlockNumber,
call: TransactionArgs{
From: &accounts[1].addr,
To: &randomAccounts[2].addr,
},
overrides: override.StateOverride{
randomAccounts[2].addr: override.OverrideAccount{
// PUSH1 0x00 TLOAD PUSH1 0x00 MSTORE PUSH1 0x20 PUSH1 0x00 RETURN
Code: hex2Bytes("0x60005c60005260206000f3"),
TransientStorage: map[common.Hash]common.Hash{
common.Hash{}: common.HexToHash("0xabcd"),
},
},
},
want: "0x000000000000000000000000000000000000000000000000000000000000abcd",
},
}
for _, tc := range testSuite {
result, err := api.Call(context.Background(), tc.call, &rpc.BlockNumberOrHash{BlockNumber: &tc.blockNumber}, &tc.overrides, &tc.blockOverrides)
Expand Down
17 changes: 16 additions & 1 deletion internal/ethapi/override/override.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (
// OverrideAccount indicates the overriding fields of account during the execution
// of a message call.
// Note, state and stateDiff can't be specified at the same time. If state is
// set, message execution will only use the data in the given state. Otherwise
// set, message execution will only use the data in the given state. Otherwise,
// if stateDiff is set, all diff will be applied first and then execute the call
// message.
type OverrideAccount struct {
Expand All @@ -42,6 +42,7 @@ type OverrideAccount struct {
Balance *hexutil.Big `json:"balance"`
State map[common.Hash]common.Hash `json:"state"`
StateDiff map[common.Hash]common.Hash `json:"stateDiff"`
TransientStorage map[common.Hash]common.Hash `json:"transientStorage"`
MovePrecompileTo *common.Address `json:"movePrecompileToAddress"`
}

Expand All @@ -58,6 +59,20 @@ func (diff *StateOverride) Apply(statedb *state.StateDB, precompiles vm.Precompi
if diff == nil {
return nil
}

// Get transient storage overrides to apply after Prepare()
transientOverrides := make(map[common.Address]map[common.Hash]common.Hash)
for addr, account := range *diff {
if account.TransientStorage != nil && len(account.TransientStorage) > 0 {
transientOverrides[addr] = account.TransientStorage
}
}

// Store transient storage overrides
if len(transientOverrides) > 0 {
statedb.SetPendingTransientOverrides(transientOverrides)
}

// Tracks destinations of precompiles that were moved.
dirtyAddrs := make(map[common.Address]struct{})
for addr, account := range *diff {
Expand Down
58 changes: 58 additions & 0 deletions internal/ethapi/override/override_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/triedb"
)

Expand Down Expand Up @@ -128,3 +129,60 @@ func hex2Bytes(str string) *hexutil.Bytes {
rpcBytes := hexutil.Bytes(common.FromHex(str))
return &rpcBytes
}

func TestStateOverrideTransientStorage(t *testing.T) {
db := state.NewDatabase(triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil), nil)
statedb, err := state.New(types.EmptyRootHash, db)
if err != nil {
t.Fatalf("failed to create statedb: %v", err)
}

addr := common.BytesToAddress([]byte{0x1})
key1 := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000001")
key2 := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000002")
value1 := common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111")
value2 := common.HexToHash("0x2222222222222222222222222222222222222222222222222222222222222222")

// Verify initial state is empty
if got := statedb.GetTransientState(addr, key1); got != (common.Hash{}) {
t.Fatalf("expected initial transient state to be empty, got %s", got.Hex())
}
if got := statedb.GetTransientState(addr, key2); got != (common.Hash{}) {
t.Fatalf("expected initial transient state to be empty, got %s", got.Hex())
}

// Apply override with transient storage
override := StateOverride{
addr: OverrideAccount{
TransientStorage: map[common.Hash]common.Hash{
key1: value1,
key2: value2,
},
},
}

if err := override.Apply(statedb, nil); err != nil {
t.Fatalf("failed to apply override: %v", err)
}

statedb.Prepare(params.Rules{}, common.Address{}, common.Address{}, nil, nil, nil)

// Verify transient storage was set
if got := statedb.GetTransientState(addr, key1); got != value1 {
t.Errorf("expected transient state for key1 to be %s, got %s", value1.Hex(), got.Hex())
}
if got := statedb.GetTransientState(addr, key2); got != value2 {
t.Errorf("expected transient state for key2 to be %s, got %s", value2.Hex(), got.Hex())
}

// Verify other addresses/keys remain empty
otherAddr := common.BytesToAddress([]byte{0x2})
if got := statedb.GetTransientState(otherAddr, key1); got != (common.Hash{}) {
t.Errorf("expected transient state for different address to be empty, got %s", got.Hex())
}

otherKey := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000003")
if got := statedb.GetTransientState(addr, otherKey); got != (common.Hash{}) {
t.Errorf("expected transient state for different key to be empty, got %s", got.Hex())
}
}