Skip to content

Commit

Permalink
re-execute call sequences, lookup exec. trace by tx hash
Browse files Browse the repository at this point in the history
  • Loading branch information
0xalpharush committed Jul 9, 2024
1 parent bfe0ba8 commit cc8c4af
Show file tree
Hide file tree
Showing 12 changed files with 191 additions and 88 deletions.
97 changes: 74 additions & 23 deletions chain/test_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"github.com/crytic/medusa/chain/config"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/tracing"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/eth/tracers"
"github.com/ethereum/go-ethereum/triedb"
"github.com/ethereum/go-ethereum/triedb/hashdb"
"github.com/holiman/uint256"
Expand Down Expand Up @@ -84,7 +86,7 @@ type TestChain struct {
// NewTestChain creates a simulated Ethereum backend used for testing, or returns an error if one occurred.
// This creates a test chain with a test chain configuration and the provided genesis allocation and config.
// If a nil config is provided, a default one is used.
func NewTestChain(genesisAlloc core.GenesisAlloc, testChainConfig *config.TestChainConfig) (*TestChain, error) {
func NewTestChain(genesisAlloc types.GenesisAlloc, testChainConfig *config.TestChainConfig) (*TestChain, error) {
// Copy our chain config, so it is not shared across chains.
chainConfig, err := utils.CopyChainConfig(params.TestChainConfig)
if err != nil {
Expand Down Expand Up @@ -143,7 +145,7 @@ func NewTestChain(genesisAlloc core.GenesisAlloc, testChainConfig *config.TestCh
return nil, err
}
for _, cheatContract := range cheatContracts {
genesisDefinition.Alloc[cheatContract.address] = core.GenesisAccount{
genesisDefinition.Alloc[cheatContract.address] = types.Account{
Balance: big.NewInt(0),
Code: []byte{0xFF},
}
Expand Down Expand Up @@ -251,7 +253,7 @@ func (t *TestChain) Clone(onCreateFunc func(chain *TestChain) error) (*TestChain
// Now add each transaction/message to it.
messages := t.blocks[i].Messages
for j := 0; j < len(messages); j++ {
err = targetChain.PendingBlockAddTx(messages[j])
err = targetChain.PendingBlockAddTx(messages[j], nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -561,7 +563,7 @@ func (t *TestChain) CallContract(msg *core.Message, state *state.StateDB, additi
}

// Obtain our state snapshot to revert any changes after our call
// snapshot := state.Snapshot()
snapshot := state.Snapshot()

// Set infinite balance to the fake caller account
state.AddBalance(msg.From, uint256.MustFromBig(math.MaxBig256), tracing.BalanceChangeUnspecified)
Expand All @@ -585,19 +587,57 @@ func (t *TestChain) CallContract(msg *core.Message, state *state.StateDB, additi
})
t.evm = evm

tx := utils.MessageToTransaction(msg)
if evm.Config.Tracer != nil && evm.Config.Tracer.OnTxStart != nil {
evm.Config.Tracer.OnTxStart(evm.GetVMContext(), utils.MessageToTransaction(msg), msg.From)
evm.Config.Tracer.OnTxStart(evm.GetVMContext(), tx, msg.From)
}
// Fund the gas pool, so it can execute endlessly (no block gas limit).
gasPool := new(core.GasPool).AddGas(math.MaxUint64)

// Perform our state transition to obtain the result.
res, err := core.NewStateTransition(evm, msg, gasPool).TransitionDb()
msgResult, err := core.ApplyMessage(evm, msg, gasPool)

// Revert to our state snapshot to undo any changes.
// state.RevertToSnapshot(snapshot)
if err != nil {
state.RevertToSnapshot(snapshot)
}

// Receipt:
var root []byte
if t.chainConfig.IsByzantium(blockContext.BlockNumber) {
t.state.Finalise(true)
} else {
root = state.IntermediateRoot(t.chainConfig.IsEIP158(blockContext.BlockNumber)).Bytes()
}

return res, err
// Create a new receipt for the transaction, storing the intermediate root and
// gas used by the tx.
receipt := &types.Receipt{Type: tx.Type(), PostState: root, CumulativeGasUsed: msgResult.UsedGas}
if msgResult.Failed() {
receipt.Status = types.ReceiptStatusFailed
} else {
receipt.Status = types.ReceiptStatusSuccessful
}
receipt.TxHash = tx.Hash()
receipt.GasUsed = msgResult.UsedGas

// If the transaction created a contract, store the creation address in the receipt.
if msg.To == nil {
receipt.ContractAddress = crypto.CreateAddress(evm.TxContext.Origin, tx.Nonce())
}

// Set the receipt logs and create the bloom filter.
receipt.Logs = t.state.GetLogs(tx.Hash(), blockContext.BlockNumber.Uint64(), blockContext.GetHash(blockContext.BlockNumber.Uint64()))
receipt.Bloom = types.CreateBloom(types.Receipts{receipt})
receipt.TransactionIndex = uint(0)

if evm.Config.Tracer != nil {
if evm.Config.Tracer.OnTxEnd != nil {
evm.Config.Tracer.OnTxEnd(receipt, nil)
}
}

return msgResult, err
}

// PendingBlock describes the current pending block which is being constructed and awaiting commitment to the chain.
Expand Down Expand Up @@ -701,16 +741,19 @@ func (t *TestChain) PendingBlockCreateWithParameters(blockNumber uint64, blockTi

// PendingBlockAddTx takes a message (internal txs) and adds it to the current pending block, updating the header
// with relevant execution information. If a pending block was not created, an error is returned.
// Returns the constructed block, or an error if one occurred.
func (t *TestChain) PendingBlockAddTx(message *core.Message) error {
// Returns an error if one occurred.
func (t *TestChain) PendingBlockAddTx(message *core.Message, getTracerFn func(txIndex int, txHash common.Hash) *tracers.Tracer) error {
if getTracerFn == nil {
getTracerFn = func(txIndex int, txHash common.Hash) *tracers.Tracer {
return t.transactionTracerRouter.NativeTracer.Tracer
}
}

// If we don't have a pending block, return an error
if t.pendingBlock == nil {
return errors.New("could not add tx to the chain's pending block because no pending block was created")
}

// Obtain our state root hash prior to execution.
// previousStateRoot := t.pendingBlock.Header.Root

// Create a gas pool indicating how much gas can be spent executing the transaction.
gasPool := new(core.GasPool).AddGas(t.pendingBlock.Header.GasLimit - t.pendingBlock.Header.GasUsed)

Expand All @@ -721,17 +764,25 @@ func (t *TestChain) PendingBlockAddTx(message *core.Message) error {
// TODO reuse
blockContext := newTestChainBlockContext(t, t.pendingBlock.Header)

// Create our EVM instance.
evm := vm.NewEVM(blockContext, core.NewEVMTxContext(message), t.state, t.chainConfig, vm.Config{
vmConfig := vm.Config{
//Debug: true,
Tracer: t.transactionTracerRouter.NativeTracer.Hooks,
NoBaseFee: true,
ConfigExtensions: t.vmConfigExtensions,
})
t.evm = evm
}

tracer := getTracerFn(len(t.pendingBlock.Messages), tx.Hash())
if tracer != nil {
vmConfig.Tracer = tracer.Hooks
}

t.state.SetTxContext(tx.Hash(), len(t.pendingBlock.Messages))

// Create our EVM instance.
evm := vm.NewEVM(blockContext, core.NewEVMTxContext(message), t.state, t.chainConfig, vmConfig)

// Set our EVM instance for the test chain in order for cheatcodes to access EVM interpreter's block context.
t.evm = evm

if evm.Config.Tracer != nil && evm.Config.Tracer.OnTxStart != nil {
evm.Config.Tracer.OnTxStart(evm.GetVMContext(), tx, message.From)
}
Expand Down Expand Up @@ -934,11 +985,11 @@ func (t *TestChain) emitContractChangeEvents(reverting bool, messageResults ...*
Contract: deploymentChange.Contract,
})
} else if deploymentChange.Destroyed {

Check failure on line 987 in chain/test_chain.go

View workflow job for this annotation

GitHub Actions / lint

SA9003: empty branch (staticcheck)
err = t.Events.ContractDeploymentAddedEventEmitter.Publish(ContractDeploymentsAddedEvent{
Chain: t,
Contract: deploymentChange.Contract,
DynamicDeployment: deploymentChange.DynamicCreation,
})
// err = t.Events.ContractDeploymentAddedEventEmitter.Publish(ContractDeploymentsAddedEvent{
// Chain: t,
// Contract: deploymentChange.Contract,
// DynamicDeployment: deploymentChange.DynamicCreation,
// })
}
if err != nil {
return err
Expand Down
14 changes: 7 additions & 7 deletions chain/test_chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ func createChain(t *testing.T) (*TestChain, []common.Address) {
assert.NoError(t, err)

// NOTE: Sharing GenesisAlloc between nodes will result in some accounts not being funded for some reason.
genesisAlloc := make(core.GenesisAlloc)
genesisAlloc := make(types.GenesisAlloc)

// Fund all of our sender addresses in the genesis block
initBalance := new(big.Int).Div(abi.MaxInt256, big.NewInt(2))
for _, sender := range senders {
genesisAlloc[sender] = core.GenesisAccount{
genesisAlloc[sender] = types.Account{
Balance: initBalance,
}
}
Expand Down Expand Up @@ -260,7 +260,7 @@ func TestChainDynamicDeployments(t *testing.T) {
assert.NoError(t, err)

// Add our transaction to the block
err = chain.PendingBlockAddTx(&msg)
err = chain.PendingBlockAddTx(&msg, nil)
assert.NoError(t, err)

// Commit the pending block to the chain, so it becomes the new head.
Expand Down Expand Up @@ -385,7 +385,7 @@ func TestChainDeploymentWithArgs(t *testing.T) {
assert.NoError(t, err)

// Add our transaction to the block
err = chain.PendingBlockAddTx(&msg)
err = chain.PendingBlockAddTx(&msg, nil)
assert.NoError(t, err)

// Commit the pending block to the chain, so it becomes the new head.
Expand Down Expand Up @@ -494,7 +494,7 @@ func TestChainCloning(t *testing.T) {
assert.NoError(t, err)

// Add our transaction to the block
err = chain.PendingBlockAddTx(&msg)
err = chain.PendingBlockAddTx(&msg, nil)
assert.NoError(t, err)

// Commit the pending block to the chain, so it becomes the new head.
Expand Down Expand Up @@ -588,7 +588,7 @@ func TestChainCallSequenceReplayMatchSimple(t *testing.T) {
assert.NoError(t, err)

// Add our transaction to the block
err = chain.PendingBlockAddTx(&msg)
err = chain.PendingBlockAddTx(&msg, nil)
assert.NoError(t, err)

// Commit the pending block to the chain, so it becomes the new head.
Expand Down Expand Up @@ -627,7 +627,7 @@ func TestChainCallSequenceReplayMatchSimple(t *testing.T) {
_, err := recreatedChain.PendingBlockCreate()
assert.NoError(t, err)
for _, message := range chain.blocks[i].Messages {
err = recreatedChain.PendingBlockAddTx(message)
err = recreatedChain.PendingBlockAddTx(message, nil)
assert.NoError(t, err)
}
err = recreatedChain.PendingBlockCommit()
Expand Down
3 changes: 1 addition & 2 deletions compilation/types/compiled_contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ func (c *CompiledContract) IsMatch(initBytecode []byte, runtimeBytecode []byte)
deploymentBytecodeHash := deploymentMetadata.ExtractBytecodeHash()
definitionBytecodeHash := definitionMetadata.ExtractBytecodeHash()
if deploymentBytecodeHash != nil && definitionBytecodeHash != nil {
x := bytes.Equal(deploymentBytecodeHash, definitionBytecodeHash)
return x
return bytes.Equal(deploymentBytecodeHash, definitionBytecodeHash)
}
}
}
Expand Down
22 changes: 19 additions & 3 deletions fuzzing/calls/call_sequence_execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package calls

import (
"fmt"

"github.com/crytic/medusa/chain"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/eth/tracers"
)

// ExecuteCallSequenceFetchElementFunc describes a function that is called to obtain the next call sequence element to
Expand All @@ -22,7 +25,7 @@ type ExecuteCallSequenceExecutionCheckFunc func(currentExecutedSequence CallSequ
// A "post element executed check" function is provided to check whether execution should stop after each element is
// executed.
// Returns the call sequence which was executed and an error if one occurs.
func ExecuteCallSequenceIteratively(chain *chain.TestChain, fetchElementFunc ExecuteCallSequenceFetchElementFunc, executionCheckFunc ExecuteCallSequenceExecutionCheckFunc) (CallSequence, error) {
func ExecuteCallSequenceIteratively(chain *chain.TestChain, fetchElementFunc ExecuteCallSequenceFetchElementFunc, executionCheckFunc ExecuteCallSequenceExecutionCheckFunc, getTracerFn func(txIndex int, txHash common.Hash) *tracers.Tracer) (CallSequence, error) {
// If there is no fetch element function provided, throw an error
if fetchElementFunc == nil {
return nil, fmt.Errorf("could not execute call sequence on chain as the 'fetch element function' provided was nil")
Expand Down Expand Up @@ -84,7 +87,8 @@ func ExecuteCallSequenceIteratively(chain *chain.TestChain, fetchElementFunc Exe
}

// Try to add our transaction to this block.
err = chain.PendingBlockAddTx(callSequenceElement.Call.ToCoreMessage())
err = chain.PendingBlockAddTx(callSequenceElement.Call.ToCoreMessage(), getTracerFn)

if err != nil {
// If we encountered a block gas limit error, this tx is too expensive to fit in this block.
// If there are other transactions in the block, this makes sense. The block is "full".
Expand Down Expand Up @@ -161,6 +165,18 @@ func ExecuteCallSequence(chain *chain.TestChain, callSequence CallSequence) (Cal
return nil, nil
}

return ExecuteCallSequenceIteratively(chain, fetchElementFunc, nil, nil)
}

func ExecuteCallSequenceWithTracer(chain *chain.TestChain, callSequence CallSequence, getTracerFn func(txIndex int, txHash common.Hash) *tracers.Tracer) (CallSequence, error) {
// Execute our sequence with a simple fetch operation provided to obtain each element.
fetchElementFunc := func(currentIndex int) (*CallSequenceElement, error) {
if currentIndex < len(callSequence) {
return callSequence[currentIndex], nil
}
return nil, nil
}

// Execute our provided call sequence iteratively.
return ExecuteCallSequenceIteratively(chain, fetchElementFunc, nil)
return ExecuteCallSequenceIteratively(chain, fetchElementFunc, nil, getTracerFn)
}
5 changes: 2 additions & 3 deletions fuzzing/corpus/corpus.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (c *Corpus) initializeSequences(sequenceFiles *corpusDirectory[calls.CallSe
}

// Execute each call sequence, populating runtime data and collecting coverage data along the way.
_, err = calls.ExecuteCallSequenceIteratively(testChain, fetchElementFunc, executionCheckFunc)
_, err = calls.ExecuteCallSequenceIteratively(testChain, fetchElementFunc, executionCheckFunc, nil)

// If we failed to replay a sequence and measure coverage due to an unexpected error, report it.
if err != nil {
Expand All @@ -228,8 +228,7 @@ func (c *Corpus) initializeSequences(sequenceFiles *corpusDirectory[calls.CallSe
}

// Revert chain state to our starting point to test the next sequence.
err = testChain.RevertToBlockNumber(baseBlockNumber)
if err != nil {
if err := testChain.RevertToBlockNumber(baseBlockNumber); err != nil {
return fmt.Errorf("failed to reset the chain while seeding coverage: %v\n", err)
}
}
Expand Down
25 changes: 20 additions & 5 deletions fuzzing/executiontracer/execution_tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/crytic/medusa/chain"
"github.com/crytic/medusa/fuzzing/contracts"
"github.com/crytic/medusa/utils"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/state"
Expand All @@ -23,15 +24,16 @@ import (
func CallWithExecutionTrace(testChain *chain.TestChain, contractDefinitions contracts.Contracts, msg *core.Message, state *state.StateDB) (*core.ExecutionResult, *ExecutionTrace, error) {
// Create an execution tracer
executionTracer := NewExecutionTracer(contractDefinitions, testChain.CheatCodeContracts())

defer executionTracer.Close()
// Call the contract on our chain with the provided state.
executionResult, err := testChain.CallContract(msg, state, executionTracer.NativeTracer)
if err != nil {
return nil, nil, err
}

// Obtain our trace
trace := executionTracer.Trace()
hash := utils.MessageToTransaction(msg).Hash()
trace := executionTracer.GetTrace(hash)

// Return the trace
return executionResult, trace, nil
Expand All @@ -49,6 +51,8 @@ type ExecutionTracer struct {
// trace represents the current execution trace captured by this tracer.
trace *ExecutionTrace

traceMap map[common.Hash]*ExecutionTrace

// currentCallFrame references the current call frame being traced.
currentCallFrame *CallFrame

Expand All @@ -72,11 +76,13 @@ func NewExecutionTracer(contractDefinitions contracts.Contracts, cheatCodeContra
tracer := &ExecutionTracer{
contractDefinitions: contractDefinitions,
cheatCodeContracts: cheatCodeContracts,
traceMap: make(map[common.Hash]*ExecutionTrace),
}
nativeTracer := &tracers.Tracer{
Hooks: &tracing.Hooks{
OnTxStart: tracer.OnTxStart,
OnEnter: tracer.OnEnter,
OnTxEnd: tracer.OnTxEnd,
OnExit: tracer.OnExit,
OnOpcode: tracer.OnOpcode,
},
Expand All @@ -85,10 +91,19 @@ func NewExecutionTracer(contractDefinitions contracts.Contracts, cheatCodeContra

return tracer
}
func (t *ExecutionTracer) Close() {
t.traceMap = nil
}

// Trace returns the currently recording or last recorded execution trace by the tracer.
func (t *ExecutionTracer) Trace() *ExecutionTrace {
return t.trace
// GetTrace returns the currently recording or last recorded execution trace by the tracer.
func (t *ExecutionTracer) GetTrace(txHash common.Hash) *ExecutionTrace {
if trace, ok := t.traceMap[txHash]; ok {
return trace
}
return nil
}
func (t *ExecutionTracer) OnTxEnd(receipt *coretypes.Receipt, err error) {
t.traceMap[receipt.TxHash] = t.trace
}

// CaptureTxStart is called upon the start of transaction execution, as defined by tracers.Tracer.
Expand Down
Loading

0 comments on commit cc8c4af

Please sign in to comment.