diff --git a/crates/abi/abi/HEVM.sol b/crates/abi/abi/HEVM.sol index 47c66ae51774..76c08ddac069 100644 --- a/crates/abi/abi/HEVM.sol +++ b/crates/abi/abi/HEVM.sol @@ -5,6 +5,9 @@ struct DirEntry { string errorMessage; string path; uint64 depth; bool isDir; bo struct FsMetadata { bool isDir; bool isSymlink; uint256 length; bool readOnly; uint256 modified; uint256 accessed; uint256 created; } struct Wallet { address addr; uint256 publicKeyX; uint256 publicKeyY; uint256 privateKey; } struct FfiResult { int32 exitCode; bytes stdout; bytes stderr; } +struct ChainInfo { uint256 forkId; uint256 chainId; } +struct AccountAccess { ChainInfo chainInfo; uint256 kind; address account; address accessor; bool initialized; uint256 oldBalance; uint256 newBalance; bytes deployedCode; uint256 value; bytes data; bool reverted; StorageAccess[] storageAccesses; } +struct StorageAccess { address account; bytes32 slot; bool isWrite; bytes32 previousValue; bytes32 newValue; bool reverted; } allowCheatcodes(address) @@ -84,6 +87,9 @@ record() accesses(address)(bytes32[], bytes32[]) skip(bool) +startStateDiffRecording() +stopAndReturnStateDiff()(AccountAccess[]) + recordLogs() getRecordedLogs()(Log[]) diff --git a/crates/abi/src/bindings/hevm.rs b/crates/abi/src/bindings/hevm.rs index 8ed992f4d1a5..8591d4911b80 100644 --- a/crates/abi/src/bindings/hevm.rs +++ b/crates/abi/src/bindings/hevm.rs @@ -4891,6 +4891,77 @@ pub mod hevm { }, ], ), + ( + ::std::borrow::ToOwned::to_owned("startStateDiffRecording"), + ::std::vec![ + ::ethers_core::abi::ethabi::Function { + name: ::std::borrow::ToOwned::to_owned( + "startStateDiffRecording", + ), + inputs: ::std::vec![], + outputs: ::std::vec![], + constant: ::core::option::Option::None, + state_mutability: ::ethers_core::abi::ethabi::StateMutability::NonPayable, + }, + ], + ), + ( + ::std::borrow::ToOwned::to_owned("stopAndReturnStateDiff"), + ::std::vec![ + ::ethers_core::abi::ethabi::Function { + name: ::std::borrow::ToOwned::to_owned( + "stopAndReturnStateDiff", + ), + inputs: ::std::vec![], + outputs: ::std::vec![ + ::ethers_core::abi::ethabi::Param { + name: ::std::string::String::new(), + kind: ::ethers_core::abi::ethabi::ParamType::Array( + ::std::boxed::Box::new( + ::ethers_core::abi::ethabi::ParamType::Tuple( + ::std::vec![ + ::ethers_core::abi::ethabi::ParamType::Tuple( + ::std::vec![ + ::ethers_core::abi::ethabi::ParamType::Uint(256usize), + ::ethers_core::abi::ethabi::ParamType::Uint(256usize), + ], + ), + ::ethers_core::abi::ethabi::ParamType::Uint(256usize), + ::ethers_core::abi::ethabi::ParamType::Address, + ::ethers_core::abi::ethabi::ParamType::Address, + ::ethers_core::abi::ethabi::ParamType::Bool, + ::ethers_core::abi::ethabi::ParamType::Uint(256usize), + ::ethers_core::abi::ethabi::ParamType::Uint(256usize), + ::ethers_core::abi::ethabi::ParamType::Bytes, + ::ethers_core::abi::ethabi::ParamType::Uint(256usize), + ::ethers_core::abi::ethabi::ParamType::Bytes, + ::ethers_core::abi::ethabi::ParamType::Bool, + ::ethers_core::abi::ethabi::ParamType::Array( + ::std::boxed::Box::new( + ::ethers_core::abi::ethabi::ParamType::Tuple( + ::std::vec![ + ::ethers_core::abi::ethabi::ParamType::Address, + ::ethers_core::abi::ethabi::ParamType::FixedBytes(32usize), + ::ethers_core::abi::ethabi::ParamType::Bool, + ::ethers_core::abi::ethabi::ParamType::FixedBytes(32usize), + ::ethers_core::abi::ethabi::ParamType::FixedBytes(32usize), + ::ethers_core::abi::ethabi::ParamType::Bool, + ], + ), + ), + ), + ], + ), + ), + ), + internal_type: ::core::option::Option::None, + }, + ], + constant: ::core::option::Option::None, + state_mutability: ::ethers_core::abi::ethabi::StateMutability::NonPayable, + }, + ], + ), ( ::std::borrow::ToOwned::to_owned("stopBroadcast"), ::std::vec![ @@ -7349,6 +7420,49 @@ pub mod hevm { .method_hash([69, 181, 96, 120], (p0, p1)) .expect("method not found (this should never happen)") } + ///Calls the contract's `startStateDiffRecording` (0xcf22e3c9) function + pub fn start_state_diff_recording( + &self, + ) -> ::ethers_contract::builders::ContractCall { + self.0 + .method_hash([207, 34, 227, 201], ()) + .expect("method not found (this should never happen)") + } + ///Calls the contract's `stopAndReturnStateDiff` (0xaa5cf90e) function + pub fn stop_and_return_state_diff( + &self, + ) -> ::ethers_contract::builders::ContractCall< + M, + ::std::vec::Vec< + ( + (::ethers_core::types::U256, ::ethers_core::types::U256), + ::ethers_core::types::U256, + ::ethers_core::types::Address, + ::ethers_core::types::Address, + bool, + ::ethers_core::types::U256, + ::ethers_core::types::U256, + ::ethers_core::types::Bytes, + ::ethers_core::types::U256, + ::ethers_core::types::Bytes, + bool, + ::std::vec::Vec< + ( + ::ethers_core::types::Address, + [u8; 32], + bool, + [u8; 32], + [u8; 32], + bool, + ), + >, + ), + >, + > { + self.0 + .method_hash([170, 92, 249, 14], ()) + .expect("method not found (this should never happen)") + } ///Calls the contract's `stopBroadcast` (0x76eadd36) function pub fn stop_broadcast( &self, @@ -10321,6 +10435,32 @@ pub mod hevm { pub ::ethers_core::types::Address, pub ::ethers_core::types::Address, ); + ///Container type for all input parameters for the `startStateDiffRecording` function with signature `startStateDiffRecording()` and selector `0xcf22e3c9` + #[derive( + Clone, + ::ethers_contract::EthCall, + ::ethers_contract::EthDisplay, + Default, + Debug, + PartialEq, + Eq, + Hash + )] + #[ethcall(name = "startStateDiffRecording", abi = "startStateDiffRecording()")] + pub struct StartStateDiffRecordingCall; + ///Container type for all input parameters for the `stopAndReturnStateDiff` function with signature `stopAndReturnStateDiff()` and selector `0xaa5cf90e` + #[derive( + Clone, + ::ethers_contract::EthCall, + ::ethers_contract::EthDisplay, + Default, + Debug, + PartialEq, + Eq, + Hash + )] + #[ethcall(name = "stopAndReturnStateDiff", abi = "stopAndReturnStateDiff()")] + pub struct StopAndReturnStateDiffCall; ///Container type for all input parameters for the `stopBroadcast` function with signature `stopBroadcast()` and selector `0x76eadd36` #[derive( Clone, @@ -10796,6 +10936,8 @@ pub mod hevm { StartMappingRecording(StartMappingRecordingCall), StartPrank0(StartPrank0Call), StartPrank1(StartPrank1Call), + StartStateDiffRecording(StartStateDiffRecordingCall), + StopAndReturnStateDiff(StopAndReturnStateDiffCall), StopBroadcast(StopBroadcastCall), StopMappingRecording(StopMappingRecordingCall), StopPrank(StopPrankCall), @@ -11783,6 +11925,16 @@ pub mod hevm { ) { return Ok(Self::StartPrank1(decoded)); } + if let Ok(decoded) = ::decode( + data, + ) { + return Ok(Self::StartStateDiffRecording(decoded)); + } + if let Ok(decoded) = ::decode( + data, + ) { + return Ok(Self::StopAndReturnStateDiff(decoded)); + } if let Ok(decoded) = ::decode( data, ) { @@ -12348,6 +12500,12 @@ pub mod hevm { Self::StartPrank1(element) => { ::ethers_core::abi::AbiEncode::encode(element) } + Self::StartStateDiffRecording(element) => { + ::ethers_core::abi::AbiEncode::encode(element) + } + Self::StopAndReturnStateDiff(element) => { + ::ethers_core::abi::AbiEncode::encode(element) + } Self::StopBroadcast(element) => { ::ethers_core::abi::AbiEncode::encode(element) } @@ -12619,6 +12777,12 @@ pub mod hevm { } Self::StartPrank0(element) => ::core::fmt::Display::fmt(element, f), Self::StartPrank1(element) => ::core::fmt::Display::fmt(element, f), + Self::StartStateDiffRecording(element) => { + ::core::fmt::Display::fmt(element, f) + } + Self::StopAndReturnStateDiff(element) => { + ::core::fmt::Display::fmt(element, f) + } Self::StopBroadcast(element) => ::core::fmt::Display::fmt(element, f), Self::StopMappingRecording(element) => { ::core::fmt::Display::fmt(element, f) @@ -13605,6 +13769,16 @@ pub mod hevm { Self::StartPrank1(value) } } + impl ::core::convert::From for HEVMCalls { + fn from(value: StartStateDiffRecordingCall) -> Self { + Self::StartStateDiffRecording(value) + } + } + impl ::core::convert::From for HEVMCalls { + fn from(value: StopAndReturnStateDiffCall) -> Self { + Self::StopAndReturnStateDiff(value) + } + } impl ::core::convert::From for HEVMCalls { fn from(value: StopBroadcastCall) -> Self { Self::StopBroadcast(value) @@ -15126,6 +15300,44 @@ pub mod hevm { Hash )] pub struct SnapshotReturn(pub ::ethers_core::types::U256); + ///Container type for all return fields from the `stopAndReturnStateDiff` function with signature `stopAndReturnStateDiff()` and selector `0xaa5cf90e` + #[derive( + Clone, + ::ethers_contract::EthAbiType, + ::ethers_contract::EthAbiCodec, + Default, + Debug, + PartialEq, + Eq, + Hash + )] + pub struct StopAndReturnStateDiffReturn( + pub ::std::vec::Vec< + ( + (::ethers_core::types::U256, ::ethers_core::types::U256), + ::ethers_core::types::U256, + ::ethers_core::types::Address, + ::ethers_core::types::Address, + bool, + ::ethers_core::types::U256, + ::ethers_core::types::U256, + ::ethers_core::types::Bytes, + ::ethers_core::types::U256, + ::ethers_core::types::Bytes, + bool, + ::std::vec::Vec< + ( + ::ethers_core::types::Address, + [u8; 32], + bool, + [u8; 32], + [u8; 32], + bool, + ), + >, + ), + >, + ); ///Container type for all return fields from the `tryFfi` function with signature `tryFfi(string[])` and selector `0xf45c1ce7` #[derive( Clone, @@ -15152,6 +15364,46 @@ pub mod hevm { Hash )] pub struct UnixTimeReturn(pub ::ethers_core::types::U256); + ///`AccountAccess((uint256,uint256),uint256,address,address,bool,uint256,uint256,bytes,uint256,bytes,bool,(address,bytes32,bool,bytes32,bytes32,bool)[])` + #[derive( + Clone, + ::ethers_contract::EthAbiType, + ::ethers_contract::EthAbiCodec, + Default, + Debug, + PartialEq, + Eq, + Hash + )] + pub struct AccountAccess { + pub chain_info: ChainInfo, + pub kind: ::ethers_core::types::U256, + pub account: ::ethers_core::types::Address, + pub accessor: ::ethers_core::types::Address, + pub initialized: bool, + pub old_balance: ::ethers_core::types::U256, + pub new_balance: ::ethers_core::types::U256, + pub deployed_code: ::ethers_core::types::Bytes, + pub value: ::ethers_core::types::U256, + pub data: ::ethers_core::types::Bytes, + pub reverted: bool, + pub storage_accesses: ::std::vec::Vec, + } + ///`ChainInfo(uint256,uint256)` + #[derive( + Clone, + ::ethers_contract::EthAbiType, + ::ethers_contract::EthAbiCodec, + Default, + Debug, + PartialEq, + Eq, + Hash + )] + pub struct ChainInfo { + pub fork_id: ::ethers_core::types::U256, + pub chain_id: ::ethers_core::types::U256, + } ///`DirEntry(string,string,uint64,bool,bool)` #[derive( Clone, @@ -15258,6 +15510,25 @@ pub mod hevm { pub name: ::std::string::String, pub url: ::std::string::String, } + ///`StorageAccess(address,bytes32,bool,bytes32,bytes32,bool)` + #[derive( + Clone, + ::ethers_contract::EthAbiType, + ::ethers_contract::EthAbiCodec, + Default, + Debug, + PartialEq, + Eq, + Hash + )] + pub struct StorageAccess { + pub account: ::ethers_core::types::Address, + pub slot: [u8; 32], + pub is_write: bool, + pub previous_value: [u8; 32], + pub new_value: [u8; 32], + pub reverted: bool, + } ///`Wallet(address,uint256,uint256,uint256)` #[derive( Clone, diff --git a/crates/cheatcodes/assets/cheatcodes.json b/crates/cheatcodes/assets/cheatcodes.json index a35fd71c0982..c835e4633997 100644 --- a/crates/cheatcodes/assets/cheatcodes.json +++ b/crates/cheatcodes/assets/cheatcodes.json @@ -33,6 +33,40 @@ "description": "A recurrent prank triggered by a `vm.startPrank()` call is currently active." } ] + }, + { + "name": "AccountAccessKind", + "description": "The kind of account access that occurred.", + "variants": [ + { + "name": "Call", + "description": "The account was called." + }, + { + "name": "DelegateCall", + "description": "The account was called via delegatecall." + }, + { + "name": "CallCode", + "description": "The account was called via callcode." + }, + { + "name": "StaticCall", + "description": "The account was called via staticcall." + }, + { + "name": "Create", + "description": "The account was created." + }, + { + "name": "SelfDestruct", + "description": "The account was selfdestructed." + }, + { + "name": "Resume", + "description": "Synthetic access indicating the current context has resumed after a previous sub-context (AccountAccess)." + } + ] } ], "structs": [ @@ -242,6 +276,124 @@ "description": "The `stderr` data." } ] + }, + { + "name": "ChainInfo", + "description": "Information on the chain and fork.", + "fields": [ + { + "name": "forkId", + "ty": "uint256", + "description": "The fork identifier. Set to zero if no fork is active." + }, + { + "name": "chainId", + "ty": "uint256", + "description": "The chain ID of the current fork." + } + ] + }, + { + "name": "AccountAccess", + "description": "The result of a `stopAndReturnStateDiff` call.", + "fields": [ + { + "name": "chainInfo", + "ty": "ChainInfo", + "description": "The chain and fork the access occurred." + }, + { + "name": "kind", + "ty": "AccountAccessKind", + "description": "The kind of account access that determines what the account is.\n If kind is Call, DelegateCall, StaticCall or CallCode, then the account is the callee.\n If kind is Create, then the account is the newly created account.\n If kind is SelfDestruct, then the account is the selfdestruct recipient.\n If kind is a Resume, then account represents a account context that has resumed." + }, + { + "name": "account", + "ty": "address", + "description": "The account that was accessed.\n It's either the account created, callee or a selfdestruct recipient for CREATE, CALL or SELFDESTRUCT." + }, + { + "name": "accessor", + "ty": "address", + "description": "What accessed the account." + }, + { + "name": "initialized", + "ty": "bool", + "description": "If the account was initialized or empty prior to the access.\n An account is considered initialized if it has code, a\n non-zero nonce, or a non-zero balance." + }, + { + "name": "oldBalance", + "ty": "uint256", + "description": "The previous balance of the accessed account." + }, + { + "name": "newBalance", + "ty": "uint256", + "description": "The potential new balance of the accessed account.\n That is, all balance changes are recorded here, even if reverts occurred." + }, + { + "name": "deployedCode", + "ty": "bytes", + "description": "Code of the account deployed by CREATE." + }, + { + "name": "value", + "ty": "uint256", + "description": "Value passed along with the account access" + }, + { + "name": "data", + "ty": "bytes", + "description": "Input data provided to the CREATE or CALL" + }, + { + "name": "reverted", + "ty": "bool", + "description": "If this access reverted in either the current or parent context." + }, + { + "name": "storageAccesses", + "ty": "StorageAccess[]", + "description": "An ordered list of storage accesses made during an account access operation." + } + ] + }, + { + "name": "StorageAccess", + "description": "The storage accessed during an `AccountAccess`.", + "fields": [ + { + "name": "account", + "ty": "address", + "description": "The account whose storage was accessed." + }, + { + "name": "slot", + "ty": "bytes32", + "description": "The slot that was accessed." + }, + { + "name": "isWrite", + "ty": "bool", + "description": "If the access was a write." + }, + { + "name": "previousValue", + "ty": "bytes32", + "description": "The previous value of the slot." + }, + { + "name": "newValue", + "ty": "bytes32", + "description": "The new value of the slot." + }, + { + "name": "reverted", + "ty": "bool", + "description": "If the access was reverted." + } + ] } ], "cheatcodes": [ @@ -4145,6 +4297,46 @@ "status": "stable", "safety": "unsafe" }, + { + "func": { + "id": "startStateDiffRecording", + "description": "Record all account accesses as part of CREATE, CALL or SELFDESTRUCT opcodes in order,\nalong with the context of the calls", + "declaration": "function startStateDiffRecording() external;", + "visibility": "external", + "mutability": "", + "signature": "startStateDiffRecording()", + "selector": "0xcf22e3c9", + "selectorBytes": [ + 207, + 34, + 227, + 201 + ] + }, + "group": "evm", + "status": "stable", + "safety": "safe" + }, + { + "func": { + "id": "stopAndReturnStateDiff", + "description": "Returns an ordered array of all account accesses from a `vm.startStateDiffRecording` session.", + "declaration": "function stopAndReturnStateDiff() external returns (AccountAccess[] memory accesses);", + "visibility": "external", + "mutability": "", + "signature": "stopAndReturnStateDiff()", + "selector": "0xaa5cf90e", + "selectorBytes": [ + 170, + 92, + 249, + 14 + ] + }, + "group": "evm", + "status": "stable", + "safety": "safe" + }, { "func": { "id": "stopBroadcast", diff --git a/crates/cheatcodes/spec/src/lib.rs b/crates/cheatcodes/spec/src/lib.rs index 3877806ee6fc..9e0cf12dde02 100644 --- a/crates/cheatcodes/spec/src/lib.rs +++ b/crates/cheatcodes/spec/src/lib.rs @@ -80,8 +80,14 @@ impl Cheatcodes<'static> { Vm::FsMetadata::STRUCT.clone(), Vm::Wallet::STRUCT.clone(), Vm::FfiResult::STRUCT.clone(), + Vm::ChainInfo::STRUCT.clone(), + Vm::AccountAccess::STRUCT.clone(), + Vm::StorageAccess::STRUCT.clone(), + ]), + enums: Cow::Owned(vec![ + Vm::CallerMode::ENUM.clone(), + Vm::AccountAccessKind::ENUM.clone(), ]), - enums: Cow::Owned(vec![Vm::CallerMode::ENUM.clone()]), errors: Vm::VM_ERRORS.iter().map(|&x| x.clone()).collect(), events: Cow::Borrowed(&[]), // events: Vm::VM_EVENTS.iter().map(|&x| x.clone()).collect(), diff --git a/crates/cheatcodes/spec/src/vm.rs b/crates/cheatcodes/spec/src/vm.rs index 519d6bad4d18..b3c65e1fb031 100644 --- a/crates/cheatcodes/spec/src/vm.rs +++ b/crates/cheatcodes/spec/src/vm.rs @@ -36,6 +36,24 @@ interface Vm { RecurrentPrank, } + /// The kind of account access that occurred. + enum AccountAccessKind { + /// The account was called. + Call, + /// The account was called via delegatecall. + DelegateCall, + /// The account was called via callcode. + CallCode, + /// The account was called via staticcall. + StaticCall, + /// The account was created. + Create, + /// The account was selfdestructed. + SelfDestruct, + /// Synthetic access indicating the current context has resumed after a previous sub-context (AccountAccess). + Resume, + } + /// An Ethereum log. Returned by `getRecordedLogs`. struct Log { /// The topics of the log, including the signature, if any. @@ -134,6 +152,66 @@ interface Vm { bytes stderr; } + /// Information on the chain and fork. + struct ChainInfo { + /// The fork identifier. Set to zero if no fork is active. + uint256 forkId; + /// The chain ID of the current fork. + uint256 chainId; + } + + /// The result of a `stopAndReturnStateDiff` call. + struct AccountAccess { + /// The chain and fork the access occurred. + ChainInfo chainInfo; + /// The kind of account access that determines what the account is. + /// If kind is Call, DelegateCall, StaticCall or CallCode, then the account is the callee. + /// If kind is Create, then the account is the newly created account. + /// If kind is SelfDestruct, then the account is the selfdestruct recipient. + /// If kind is a Resume, then account represents a account context that has resumed. + AccountAccessKind kind; + /// The account that was accessed. + /// It's either the account created, callee or a selfdestruct recipient for CREATE, CALL or SELFDESTRUCT. + address account; + /// What accessed the account. + address accessor; + /// If the account was initialized or empty prior to the access. + /// An account is considered initialized if it has code, a + /// non-zero nonce, or a non-zero balance. + bool initialized; + /// The previous balance of the accessed account. + uint256 oldBalance; + /// The potential new balance of the accessed account. + /// That is, all balance changes are recorded here, even if reverts occurred. + uint256 newBalance; + /// Code of the account deployed by CREATE. + bytes deployedCode; + /// Value passed along with the account access + uint256 value; + /// Input data provided to the CREATE or CALL + bytes data; + /// If this access reverted in either the current or parent context. + bool reverted; + /// An ordered list of storage accesses made during an account access operation. + StorageAccess[] storageAccesses; + } + + /// The storage accessed during an `AccountAccess`. + struct StorageAccess { + /// The account whose storage was accessed. + address account; + /// The slot that was accessed. + bytes32 slot; + /// If the access was a write. + bool isWrite; + /// The previous value of the slot. + bytes32 previousValue; + /// The new value of the slot. + bytes32 newValue; + /// If the access was reverted. + bool reverted; + } + // ======== EVM ======== /// Gets the address for a given private key. @@ -166,6 +244,15 @@ interface Vm { #[cheatcode(group = Evm, safety = Safe)] function accesses(address target) external returns (bytes32[] memory readSlots, bytes32[] memory writeSlots); + /// Record all account accesses as part of CREATE, CALL or SELFDESTRUCT opcodes in order, + /// along with the context of the calls + #[cheatcode(group = Evm, safety = Safe)] + function startStateDiffRecording() external; + + /// Returns an ordered array of all account accesses from a `vm.startStateDiffRecording` session. + #[cheatcode(group = Evm, safety = Safe)] + function stopAndReturnStateDiff() external returns (AccountAccess[] memory accesses); + // -------- Recording Map Writes -------- /// Starts recording all map SSTOREs for later retrieval. diff --git a/crates/cheatcodes/src/evm.rs b/crates/cheatcodes/src/evm.rs index 8585c5410ce0..630e3a8f9fad 100644 --- a/crates/cheatcodes/src/evm.rs +++ b/crates/cheatcodes/src/evm.rs @@ -341,6 +341,21 @@ impl Cheatcode for revertToCall { } } +impl Cheatcode for startStateDiffRecordingCall { + fn apply(&self, state: &mut Cheatcodes) -> Result { + let Self {} = self; + state.recorded_account_diffs_stack = Some(Default::default()); + Ok(Default::default()) + } +} + +impl Cheatcode for stopAndReturnStateDiffCall { + fn apply(&self, state: &mut Cheatcodes) -> Result { + let Self {} = self; + get_state_diff(state) + } +} + pub(super) fn get_nonce(ccx: &mut CheatsCtxt, address: &Address) -> Result { super::script::correct_sender_nonce(ccx)?; let (account, _) = ccx.data.journaled_state.load_account(*address, ccx.data.db)?; @@ -404,3 +419,22 @@ pub(super) fn journaled_account<'a, DB: DatabaseExt>( data.journaled_state.touch(&addr); Ok(data.journaled_state.state.get_mut(&addr).expect("account is loaded")) } + +/// Consumes recorded account accesses and returns them as an abi encoded +/// array of [AccountAccess]. If there are no accounts were +/// recorded as accessed, an abi encoded empty array is returned. +/// +/// In the case where `stopAndReturnStateDiff` is called at a lower +/// depth than `startStateDiffRecording`, multiple `Vec` +/// will be flattened, preserving the order of the accesses. +fn get_state_diff(state: &mut Cheatcodes) -> Result { + let res = state + .recorded_account_diffs_stack + .replace(Default::default()) + .unwrap_or_default() + .into_iter() + .flatten() + .map(|record| record.access) + .collect::>(); + Ok(res.abi_encode()) +} diff --git a/crates/cheatcodes/src/inspector.rs b/crates/cheatcodes/src/inspector.rs index ab5368c44f5d..02bea32a418c 100644 --- a/crates/cheatcodes/src/inspector.rs +++ b/crates/cheatcodes/src/inspector.rs @@ -28,7 +28,9 @@ use foundry_evm_core::{ use foundry_utils::types::ToEthers; use itertools::Itertools; use revm::{ - interpreter::{opcode, CallInputs, CreateInputs, Gas, InstructionResult, Interpreter}, + interpreter::{ + opcode, CallInputs, CallScheme, CreateInputs, Gas, InstructionResult, Interpreter, + }, primitives::{BlockEnv, CreateScheme, TransactTo}, EVMData, Inspector, }; @@ -85,6 +87,14 @@ pub struct BroadcastableTransaction { /// List of transactions that can be broadcasted. pub type BroadcastableTransactions = VecDeque; +#[derive(Debug, Clone)] +pub struct AccountAccess { + /// The account access. + pub access: crate::Vm::AccountAccess, + /// The call depth the account was accessed. + pub depth: u64, +} + /// An EVM inspector that handles calls to various cheatcodes, each with their own behavior. /// /// Cheatcodes can be called by contracts during execution to modify the VM environment, such as @@ -137,6 +147,13 @@ pub struct Cheatcodes { /// Recorded storage reads and writes pub accesses: Option, + /// Recorded account accesses (calls, creates) organized by relative call depth, where the + /// topmost vector corresponds to accesses at the depth at which account access recording + /// began. Each vector in the matrix represents a list of accesses at a specific call + /// depth. Once that call context has ended, the last vector is removed from the matrix and + /// merged into the previous vector. + pub recorded_account_diffs_stack: Option>>, + /// Recorded logs pub recorded_logs: Option>, @@ -226,6 +243,7 @@ impl Cheatcodes { } /// Determines the address of the contract and marks it as allowed + /// Returns the address of the contract created /// /// There may be cheatcodes in the constructor of the new contract, in order to allow them /// automatically we need to determine the new address @@ -233,13 +251,7 @@ impl Cheatcodes { &self, data: &mut EVMData<'_, DB>, inputs: &CreateInputs, - ) { - if data.journaled_state.depth > 1 && !data.db.has_cheatcode_access(inputs.caller) { - // we only grant cheat code access for new contracts if the caller also has - // cheatcode access and the new contract is created in top most call - return - } - + ) -> Address { let old_nonce = data .journaled_state .state @@ -248,7 +260,15 @@ impl Cheatcodes { .unwrap_or_default(); let created_address = get_create_address(inputs, old_nonce); + if data.journaled_state.depth > 1 && !data.db.has_cheatcode_access(inputs.caller) { + // we only grant cheat code access for new contracts if the caller also has + // cheatcode access and the new contract is created in top most call + return created_address + } + data.db.allow_cheatcode_access(created_address); + + created_address } /// Called when there was a revert. @@ -392,6 +412,114 @@ impl Inspector for Cheatcodes { } } + // Record account access via SELFDESTRUCT if `recordAccountAccesses` has been called + if let Some(account_accesses) = &mut self.recorded_account_diffs_stack { + if interpreter.current_opcode() == opcode::SELFDESTRUCT { + let target = try_or_continue!(interpreter.stack().peek(0)); + // load balance of this account + let value = if let Ok((account, _)) = + data.journaled_state.load_account(interpreter.contract().address, data.db) + { + account.info.balance + } else { + U256::ZERO + }; + let account = Address::from_word(B256::from(target)); + // get previous balance and initialized status of the target account + let (initialized, old_balance) = + if let Ok((account, _)) = data.journaled_state.load_account(account, data.db) { + (account.info.exists(), account.info.balance) + } else { + (false, U256::ZERO) + }; + // register access for the target account + let access = crate::Vm::AccountAccess { + chainInfo: crate::Vm::ChainInfo { + forkId: data.db.active_fork_id().unwrap_or_default(), + chainId: U256::from(data.env.cfg.chain_id), + }, + accessor: interpreter.contract().address, + account, + kind: crate::Vm::AccountAccessKind::SelfDestruct, + initialized, + oldBalance: old_balance, + newBalance: old_balance + value, + value, + data: vec![], + reverted: false, + deployedCode: vec![], + storageAccesses: vec![], + }; + // Ensure that we're not selfdestructing a context recording was initiated on + if let Some(last) = account_accesses.last_mut() { + last.push(AccountAccess { access, depth: data.journaled_state.depth() }); + } + } + } + + // Record granular ordered storage accesses if `startStateDiffRecording` has been called + if let Some(recorded_account_diffs_stack) = &mut self.recorded_account_diffs_stack { + match interpreter.current_opcode() { + opcode::SLOAD => { + let key = try_or_continue!(interpreter.stack().peek(0)); + let address = interpreter.contract().address; + + // Try to include present value for informational purposes, otherwise assume + // it's not set (zero value) + let mut present_value = U256::ZERO; + // Try to load the account and the slot's present value + if data.journaled_state.load_account(address, data.db).is_ok() { + if let Ok((previous, _)) = data.journaled_state.sload(address, key, data.db) + { + present_value = previous; + } + } + let access = crate::Vm::StorageAccess { + account: interpreter.contract().address, + slot: key.into(), + isWrite: false, + previousValue: present_value.into(), + newValue: present_value.into(), + reverted: false, + }; + append_storage_access( + recorded_account_diffs_stack, + access, + data.journaled_state.depth(), + ); + } + opcode::SSTORE => { + let key = try_or_continue!(interpreter.stack().peek(0)); + let value = try_or_continue!(interpreter.stack().peek(1)); + let address = interpreter.contract().address; + // Try to load the account and the slot's previous value, otherwise, assume it's + // not set (zero value) + let mut previous_value = U256::ZERO; + if data.journaled_state.load_account(address, data.db).is_ok() { + if let Ok((previous, _)) = data.journaled_state.sload(address, key, data.db) + { + previous_value = previous; + } + } + + let access = crate::Vm::StorageAccess { + account: address, + slot: key.into(), + isWrite: true, + previousValue: previous_value.into(), + newValue: value.into(), + reverted: false, + }; + append_storage_access( + recorded_account_diffs_stack, + access, + data.journaled_state.depth(), + ); + } + _ => (), + } + } + // If the allowed memory writes cheatcode is active at this context depth, check to see // if the current opcode can either mutate directly or expand memory. If the opcode at // the current program counter is a match, check if the modified memory lies within the @@ -689,6 +817,52 @@ impl Inspector for Cheatcodes { } } + // Record called accounts if `startStateDiffRecording` has been called + if let Some(recorded_account_diffs_stack) = &mut self.recorded_account_diffs_stack { + // Determine if account is "initialized," ie, it has a non-zero balance, a non-zero + // nonce, a non-zero KECCAK_EMPTY codehash, or non-empty code + let initialized; + let old_balance; + if let Ok((acc, _)) = data.journaled_state.load_account(call.contract, data.db) { + initialized = acc.info.exists(); + old_balance = acc.info.balance; + } else { + initialized = false; + old_balance = U256::ZERO; + } + let kind = match call.context.scheme { + CallScheme::Call => crate::Vm::AccountAccessKind::Call, + CallScheme::CallCode => crate::Vm::AccountAccessKind::CallCode, + CallScheme::DelegateCall => crate::Vm::AccountAccessKind::DelegateCall, + CallScheme::StaticCall => crate::Vm::AccountAccessKind::StaticCall, + }; + // Record this call by pushing it to a new pending vector; all subsequent calls at + // that depth will be pushed to the same vector. When the call ends, the + // RecordedAccountAccess (and all subsequent RecordedAccountAccesses) will be + // updated with the revert status of this call, since the EVM does not mark accounts + // as "warm" if the call from which they were accessed is reverted + recorded_account_diffs_stack.push(vec![AccountAccess { + access: crate::Vm::AccountAccess { + chainInfo: crate::Vm::ChainInfo { + forkId: data.db.active_fork_id().unwrap_or_default(), + chainId: U256::from(data.env.cfg.chain_id), + }, + accessor: call.context.caller, + account: call.contract, + kind, + initialized, + oldBalance: old_balance, + newBalance: U256::ZERO, // updated on call_end + value: call.transfer.value, + data: call.input.to_vec(), + reverted: false, + deployedCode: vec![], + storageAccesses: vec![], // updated on step + }, + depth: data.journaled_state.depth(), + }]); + } + (InstructionResult::Continue, gas, Bytes::new()) } @@ -755,6 +929,47 @@ impl Inspector for Cheatcodes { } } + // If `startStateDiffRecording` has been called, update the `reverted` status of the + // previous call depth's recorded accesses, if any + if let Some(recorded_account_diffs_stack) = &mut self.recorded_account_diffs_stack { + // The root call cannot be recorded. + if data.journaled_state.depth() > 0 { + let mut last_recorded_depth = + recorded_account_diffs_stack.pop().expect("missing CALL account accesses"); + // Update the reverted status of all deeper calls if this call reverted, in + // accordance with EVM behavior + if status.is_revert() { + last_recorded_depth.iter_mut().for_each(|element| { + element.access.reverted = true; + element + .access + .storageAccesses + .iter_mut() + .for_each(|storage_access| storage_access.reverted = true); + }) + } + let call_access = last_recorded_depth.first_mut().expect("empty AccountAccesses"); + // Assert that we're at the correct depth before recording post-call state changes. + // Depending on the depth the cheat was called at, there may not be any pending + // calls to update if execution has percolated up to a higher depth. + if call_access.depth == data.journaled_state.depth() { + if let Ok((acc, _)) = data.journaled_state.load_account(call.contract, data.db) + { + debug_assert!(access_is_call(call_access.access.kind)); + call_access.access.newBalance = acc.info.balance; + } + } + // Merge the last depth's AccountAccesses into the AccountAccesses at the current + // depth, or push them back onto the pending vector if higher depths were not + // recorded. This preserves ordering of accesses. + if let Some(last) = recorded_account_diffs_stack.last_mut() { + last.append(&mut last_recorded_depth); + } else { + recorded_account_diffs_stack.push(last_recorded_depth); + } + } + } + // At the end of the call, // we need to check if we've found all the emits. // We know we've found all the expected emits in the right order @@ -897,7 +1112,7 @@ impl Inspector for Cheatcodes { let gas = Gas::new(call.gas_limit); // allow cheatcodes from the address of the new contract - self.allow_cheatcodes_on_create(data, call); + let address = self.allow_cheatcodes_on_create(data, call); // Apply our prank if let Some(prank) = &self.prank { @@ -965,6 +1180,32 @@ impl Inspector for Cheatcodes { } } + // If `recordAccountAccesses` has been called, record the create + if let Some(recorded_account_diffs_stack) = &mut self.recorded_account_diffs_stack { + // Record the create context as an account access and create a new vector to record all + // subsequent account accesses + recorded_account_diffs_stack.push(vec![AccountAccess { + access: crate::Vm::AccountAccess { + chainInfo: crate::Vm::ChainInfo { + forkId: data.db.active_fork_id().unwrap_or_default(), + chainId: U256::from(data.env.cfg.chain_id), + }, + accessor: call.caller, + account: address, + kind: crate::Vm::AccountAccessKind::Create, + initialized: true, + oldBalance: U256::ZERO, // updated on create_end + newBalance: U256::ZERO, // updated on create_end + value: call.value, + data: call.init_code.to_vec(), + reverted: false, + deployedCode: vec![], // updated on create_end + storageAccesses: vec![], // updated on create_end + }, + depth: data.journaled_state.depth(), + }]); + } + (InstructionResult::Continue, None, gas, Bytes::new()) } @@ -1021,6 +1262,61 @@ impl Inspector for Cheatcodes { } } + // If `startStateDiffRecording` has been called, update the `reverted` status of the + // previous call depth's recorded accesses, if any + if let Some(recorded_account_diffs_stack) = &mut self.recorded_account_diffs_stack { + // The root call cannot be recorded. + if data.journaled_state.depth() > 0 { + let mut last_depth = + recorded_account_diffs_stack.pop().expect("missing CREATE account accesses"); + // Update the reverted status of all deeper calls if this call reverted, in + // accordance with EVM behavior + if status.is_revert() { + last_depth.iter_mut().for_each(|element| { + element.access.reverted = true; + element + .access + .storageAccesses + .iter_mut() + .for_each(|storage_access| storage_access.reverted = true); + }) + } + let create_access = last_depth.first_mut().expect("empty AccountAccesses"); + // Assert that we're at the correct depth before recording post-create state + // changes. Depending on what depth the cheat was called at, there + // may not be any pending calls to update if execution has + // percolated up to a higher depth. + if create_access.depth == data.journaled_state.depth() { + debug_assert_eq!( + create_access.access.kind as u8, + crate::Vm::AccountAccessKind::Create as u8 + ); + if let Some(address) = address { + if let Ok((created_acc, _)) = + data.journaled_state.load_account(address, data.db) + { + create_access.access.newBalance = created_acc.info.balance; + create_access.access.deployedCode = created_acc + .info + .code + .clone() + .unwrap_or_default() + .original_bytes() + .into(); + } + } + } + // Merge the last depth's AccountAccesses into the AccountAccesses at the current + // depth, or push them back onto the pending vector if higher depths were not + // recorded. This preserves ordering of accesses. + if let Some(last) = recorded_account_diffs_stack.last_mut() { + last.append(&mut last_depth); + } else { + recorded_account_diffs_stack.push(last_depth); + } + } + } + (status, address, remaining_gas, retdata) } } @@ -1117,3 +1413,61 @@ fn apply_dispatch(calls: &Vm::VmCalls, ccx: &mut CheatsCtxt } vm_calls!(match_) } + +/// Returns true if the kind of account access is a call. +fn access_is_call(kind: crate::Vm::AccountAccessKind) -> bool { + matches!( + kind, + crate::Vm::AccountAccessKind::Call | + crate::Vm::AccountAccessKind::StaticCall | + crate::Vm::AccountAccessKind::CallCode | + crate::Vm::AccountAccessKind::DelegateCall + ) +} + +/// Appends an AccountAccess that resumes the recording of the current context. +fn append_storage_access( + accesses: &mut [Vec], + storage_access: crate::Vm::StorageAccess, + storage_depth: u64, +) { + if let Some(last) = accesses.last_mut() { + // Assert that there's an existing record for the current context. + if !last.is_empty() && last.first().unwrap().depth < storage_depth { + // Three cases to consider: + // 1. If there hasn't been a context switch since the start of this context, then add + // the storage access to the current context record. + // 2. If there's an existing Resume record, then add the storage access to it. + // 3. Otherwise, create a new Resume record based on the current context. + if last.len() == 1 { + last.first_mut().unwrap().access.storageAccesses.push(storage_access); + } else { + let last_record = last.last_mut().unwrap(); + if last_record.access.kind as u8 == crate::Vm::AccountAccessKind::Resume as u8 { + last_record.access.storageAccesses.push(storage_access); + } else { + let entry = last.first().unwrap(); + let resume_record = crate::Vm::AccountAccess { + chainInfo: crate::Vm::ChainInfo { + forkId: entry.access.chainInfo.forkId, + chainId: entry.access.chainInfo.chainId, + }, + accessor: entry.access.accessor, + account: entry.access.account, + kind: crate::Vm::AccountAccessKind::Resume, + initialized: entry.access.initialized, + storageAccesses: vec![storage_access], + reverted: entry.access.reverted, + // The remaining fields are defaults + oldBalance: U256::ZERO, + newBalance: U256::ZERO, + value: U256::ZERO, + data: vec![], + deployedCode: vec![], + }; + last.push(AccountAccess { access: resume_record, depth: entry.depth }); + } + } + } + } +} diff --git a/testdata/cheats/RecordAccountAccesses.t.sol b/testdata/cheats/RecordAccountAccesses.t.sol new file mode 100644 index 000000000000..60652c130ed0 --- /dev/null +++ b/testdata/cheats/RecordAccountAccesses.t.sol @@ -0,0 +1,1204 @@ +// SPDX-License-Identifier: Unlicense +pragma solidity 0.8.18; + +import "ds-test/test.sol"; +import "./Vm.sol"; + +/// @notice Helper contract with a constructo that makes a call to itself then +/// optionally reverts if zero-length data is passed +contract SelfCaller { + constructor(bytes memory) payable { + assembly { + // call self to test that the cheatcode correctly reports the + // account as initialized even when there is no code at the + // contract address + pop(call(gas(), address(), div(callvalue(), 10), 0, 0, 0, 0)) + if eq(calldataload(0x04), 1) { revert(0, 0) } + } + } +} + +/// @notice Helper contract with a constructor that stores a value in storage +/// and then optionally reverts. +contract ConstructorStorer { + constructor(bool shouldRevert) { + assembly { + sstore(0x00, 0x01) + if shouldRevert { revert(0, 0) } + } + } +} + +/// @notice Helper contract that calls itself from the run method +contract Doer { + uint256[10] spacer; + mapping(bytes32 key => uint256 value) slots; + + function run() public payable { + slots[bytes32("doer 1")]++; + this.doStuff{value: msg.value / 10}(); + } + + function doStuff() external payable { + slots[bytes32("doer 2")]++; + } +} + +/// @notice Helper contract that selfdestructs to a target address within its +/// constructor +contract SelfDestructor { + constructor(address target) payable { + selfdestruct(payable(target)); + } +} + +/// @notice Helper contract that calls a Doer from the run method +contract Create2or { + function create2(bytes32 salt, bytes memory initcode) external payable returns (address result) { + assembly { + result := create2(callvalue(), add(initcode, 0x20), mload(initcode), salt) + } + } +} + +/// @notice Helper contract that calls a Doer from the run method and then +/// reverts +contract Reverter { + Doer immutable doer; + mapping(bytes32 key => uint256 value) slots; + + constructor(Doer _doer) { + doer = _doer; + } + + function run() public payable { + slots[bytes32("reverter")]++; + doer.run{value: msg.value / 10}(); + revert(); + } +} + +/// @notice Helper contract that calls a Doer from the run method +contract Succeeder { + Doer immutable doer; + mapping(bytes32 key => uint256 value) slots; + + constructor(Doer _doer) { + doer = _doer; + } + + function run() public payable { + slots[bytes32("succeeder")]++; + doer.run{value: msg.value / 10}(); + } +} + +/// @notice Helper contract that calls a Reverter and Succeeder from the run +/// method +contract NestedRunner { + Doer public immutable doer; + Reverter public immutable reverter; + Succeeder public immutable succeeder; + mapping(bytes32 key => uint256 value) slots; + + constructor() { + doer = new Doer(); + reverter = new Reverter(doer); + succeeder = new Succeeder(doer); + } + + function run(bool shouldRevert) public payable { + slots[bytes32("runner")]++; + try reverter.run{value: msg.value / 10}() { + if (shouldRevert) { + revert(); + } + } catch {} + succeeder.run{value: msg.value / 10}(); + if (shouldRevert) { + revert(); + } + } +} + +/// @notice Helper contract that writes to storage in a nested call +contract NestedStorer { + mapping(bytes32 key => uint256 value) slots; + + constructor() {} + + function run() public payable { + slots[bytes32("nested_storer 1")]++; + this.run2(); + slots[bytes32("nested_storer 2")]++; + } + + function run2() external payable { + slots[bytes32("nested_storer 3")]++; + slots[bytes32("nested_storer 4")]++; + } +} + +/// @notice Helper contract that directly reads from and writes to storage +contract StorageAccessor { + function read(bytes32 slot) public view returns (bytes32 value) { + assembly { + value := sload(slot) + } + } + + function write(bytes32 slot, bytes32 value) public { + assembly { + sstore(slot, value) + } + } +} + +/// @notice Proxy contract +contract Proxy { + bytes32 public constant IMPL_ADDR = bytes32(uint256(keccak256("ekans implementation"))); + + constructor(address _delegate) { + bytes32 impl = IMPL_ADDR; + assembly { + sstore(impl, _delegate) + } + } + + receive() external payable { + doProxyCall(); + } + + fallback() external payable { + doProxyCall(); + } + + function doProxyCall() internal { + address _target; + bytes32 impl = IMPL_ADDR; + assembly { + _target := sload(impl) + calldatacopy(0x0, 0x0, calldatasize()) + let result := delegatecall(gas(), _target, 0x0, calldatasize(), 0x0, 0) + returndatacopy(0x0, 0x0, returndatasize()) + switch result + case 0 { revert(0, 0) } + default { return(0, returndatasize()) } + } + } +} + +/// @notice Test that the cheatcode correctly records account accesses +contract RecordAccountAccessesTest is DSTest { + Vm constant cheats = Vm(HEVM_ADDRESS); + NestedRunner runner; + NestedStorer nestedStorer; + Create2or create2or; + StorageAccessor test1; + StorageAccessor test2; + + function setUp() public { + runner = new NestedRunner(); + nestedStorer = new NestedStorer(); + create2or = new Create2or(); + test1 = new StorageAccessor(); + test2 = new StorageAccessor(); + } + + function testStorageAccessDelegateCall() public { + StorageAccessor one = test1; + Proxy proxy = new Proxy(address(one)); + + cheats.startStateDiffRecording(); + address(proxy).call(abi.encodeCall(StorageAccessor.read, bytes32(uint256(1234)))); + Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + + assertEq(called.length, 2, "incorrect length"); + + assertEq(toUint(called[0].kind), toUint(Vm.AccountAccessKind.Call), "incorrect kind"); + assertEq(called[0].accessor, address(this)); + assertEq(called[0].account, address(proxy)); + + assertEq(toUint(called[1].kind), toUint(Vm.AccountAccessKind.DelegateCall), "incorrect kind"); + assertEq(called[1].account, address(one), "incorrect account"); + assertEq(called[1].accessor, address(this), "incorrect accessor"); + assertEq( + called[1].storageAccesses[0], + Vm.StorageAccess({ + account: address(proxy), + slot: bytes32(uint256(1234)), + isWrite: false, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(0)), + reverted: false + }) + ); + } + + /// @notice Test normal, non-nested storage accesses + function testStorageAccesses() public { + StorageAccessor one = test1; + StorageAccessor two = test2; + cheats.startStateDiffRecording(); + + one.read(bytes32(uint256(1234))); + one.write(bytes32(uint256(1235)), bytes32(uint256(5678))); + two.write(bytes32(uint256(5678)), bytes32(uint256(123469))); + two.write(bytes32(uint256(5678)), bytes32(uint256(1234))); + + Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + assertEq(called.length, 4, "incorrect length"); + + assertEq(called[0].storageAccesses.length, 1, "incorrect storage length"); + Vm.StorageAccess memory access = called[0].storageAccesses[0]; + assertEq( + access, + Vm.StorageAccess({ + account: address(one), + slot: bytes32(uint256(1234)), + isWrite: false, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(0)), + reverted: false + }) + ); + + assertEq(called[1].storageAccesses.length, 1, "incorrect storage length"); + access = called[1].storageAccesses[0]; + assertEq( + access, + Vm.StorageAccess({ + account: address(one), + slot: bytes32(uint256(1235)), + isWrite: true, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(5678)), + reverted: false + }) + ); + + assertEq(called[2].storageAccesses.length, 1, "incorrect storage length"); + access = called[2].storageAccesses[0]; + assertEq( + access, + Vm.StorageAccess({ + account: address(two), + slot: bytes32(uint256(5678)), + isWrite: true, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(123469)), + reverted: false + }) + ); + + assertEq(called[3].storageAccesses.length, 1, "incorrect storage length"); + access = called[3].storageAccesses[0]; + assertEq( + access, + Vm.StorageAccess({ + account: address(two), + slot: bytes32(uint256(5678)), + isWrite: true, + previousValue: bytes32(uint256(123469)), + newValue: bytes32(uint256(1234)), + reverted: false + }) + ); + } + + /// @notice Test that basic account accesses are correctly recorded + function testRecordAccountAccesses() public { + cheats.startStateDiffRecording(); + + (bool succ,) = address(1234).call(""); + (succ,) = address(5678).call{value: 1 ether}(""); + (succ,) = address(123469).call("hello world"); + (succ,) = address(5678).call(""); + // contract calls to self in constructor + SelfCaller caller = new SelfCaller{value: 2 ether}('hello2 world2'); + + Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + assertEq(called.length, 6); + assertEq( + called[0], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(this), + account: address(1234), + kind: Vm.AccountAccessKind.Call, + initialized: false, + oldBalance: 0, + newBalance: 0, + deployedCode: hex"", + value: 0, + data: "", + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }) + ); + + assertEq( + called[1], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(this), + account: address(5678), + kind: Vm.AccountAccessKind.Call, + initialized: false, + oldBalance: 0, + newBalance: 1 ether, + deployedCode: hex"", + value: 1 ether, + data: "", + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }) + ); + assertEq( + called[2], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(this), + account: address(123469), + kind: Vm.AccountAccessKind.Call, + initialized: false, + oldBalance: 0, + newBalance: 0, + deployedCode: hex"", + value: 0, + data: "hello world", + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }) + ); + assertEq( + called[3], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(this), + account: address(5678), + kind: Vm.AccountAccessKind.Call, + initialized: true, + oldBalance: 1 ether, + newBalance: 1 ether, + deployedCode: hex"", + value: 0, + data: "", + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }) + ); + assertEq( + called[4], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(this), + account: address(caller), + kind: Vm.AccountAccessKind.Create, + initialized: true, + oldBalance: 0, + newBalance: 2 ether, + deployedCode: address(caller).code, + value: 2 ether, + data: abi.encodePacked(type(SelfCaller).creationCode, abi.encode("hello2 world2")), + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }) + ); + assertEq( + called[5], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(caller), + account: address(caller), + kind: Vm.AccountAccessKind.Call, + initialized: true, + oldBalance: 2 ether, + newBalance: 2 ether, + deployedCode: hex"", + value: 0.2 ether, + data: "", + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }) + ); + } + + /// @notice Test that account accesses are correctly recorded when a call + /// reverts + function testRevertingCall() public { + uint256 initBalance = address(this).balance; + cheats.startStateDiffRecording(); + try this.revertingCall{value: 1 ether}(address(1234), "") {} catch {} + Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + assertEq(called.length, 2); + assertEq( + called[0], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(this), + account: address(this), + kind: Vm.AccountAccessKind.Call, + initialized: true, + oldBalance: initBalance, + newBalance: initBalance, + deployedCode: hex"", + value: 1 ether, + data: abi.encodeCall(this.revertingCall, (address(1234), "")), + reverted: true, + storageAccesses: new Vm.StorageAccess[](0) + }) + ); + assertEq( + called[1], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(this), + account: address(1234), + kind: Vm.AccountAccessKind.Call, + initialized: false, + oldBalance: 0, + newBalance: 0.1 ether, + deployedCode: hex"", + value: 0.1 ether, + data: "", + reverted: true, + storageAccesses: new Vm.StorageAccess[](0) + }) + ); + } + + /// @notice Test that nested account accesses are correctly recorded + function testNested() public { + cheats.startStateDiffRecording(); + runNested(false, false); + } + + /// @notice Test that nested account accesses are correctly recorded when + /// the first call reverts + function testNested_Revert() public { + cheats.startStateDiffRecording(); + runNested(true, false); + } + + /// @notice Helper function to test nested account accesses + /// @param shouldRevert Whether the first call should revert + function runNested(bool shouldRevert, bool expectFirstCall) public { + try runner.run{value: 1 ether}(shouldRevert) {} catch {} + Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + assertEq(called.length, 7 + toUint(expectFirstCall), "incorrect length"); + + uint256 startingIndex = toUint(expectFirstCall); + if (expectFirstCall) { + assertEq( + called[0], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(this), + account: address(1234), + kind: Vm.AccountAccessKind.Call, + oldBalance: 0, + newBalance: 0, + deployedCode: "", + initialized: false, + value: 0, + data: "", + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }) + ); + } + + assertEq(called[startingIndex].storageAccesses.length, 2, "incorrect length"); + assertIncrementEq( + called[startingIndex].storageAccesses[0], + called[startingIndex].storageAccesses[1], + Vm.StorageAccess({ + account: address(runner), + slot: keccak256(abi.encodePacked(bytes32("runner"), bytes32(0))), + isWrite: true, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(1)), + reverted: shouldRevert + }) + ); + assertEq( + called[startingIndex], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(this), + account: address(runner), + kind: Vm.AccountAccessKind.Call, + oldBalance: 0, + newBalance: shouldRevert ? 0 : 0.9 ether, + deployedCode: "", + initialized: true, + value: 1 ether, + data: abi.encodeCall(NestedRunner.run, (shouldRevert)), + reverted: shouldRevert, + storageAccesses: new Vm.StorageAccess[](0) + }), + false + ); + + assertEq(called[startingIndex + 1].storageAccesses.length, 2, "incorrect length"); + assertIncrementEq( + called[startingIndex + 1].storageAccesses[0], + called[startingIndex + 1].storageAccesses[1], + Vm.StorageAccess({ + account: address(runner.reverter()), + slot: keccak256(abi.encodePacked(bytes32("reverter"), bytes32(0))), + isWrite: true, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(1)), + reverted: true + }) + ); + assertEq( + called[startingIndex + 1], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(runner), + account: address(runner.reverter()), + kind: Vm.AccountAccessKind.Call, + oldBalance: 0, + newBalance: 0, + deployedCode: "", + initialized: true, + value: 0.1 ether, + data: abi.encodeCall(Reverter.run, ()), + reverted: true, + storageAccesses: new Vm.StorageAccess[](0) + }), + false + ); + + assertEq(called[startingIndex + 2].storageAccesses.length, 2, "incorrect length"); + assertIncrementEq( + called[startingIndex + 2].storageAccesses[0], + called[startingIndex + 2].storageAccesses[1], + Vm.StorageAccess({ + account: address(runner.doer()), + slot: keccak256(abi.encodePacked(bytes32("doer 1"), uint256(10))), + isWrite: true, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(1)), + reverted: true + }) + ); + assertEq( + called[startingIndex + 2], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(runner.reverter()), + account: address(runner.doer()), + kind: Vm.AccountAccessKind.Call, + oldBalance: 0, + newBalance: 0.01 ether, + deployedCode: "", + initialized: true, + value: 0.01 ether, + data: abi.encodeCall(Doer.run, ()), + reverted: true, + storageAccesses: new Vm.StorageAccess[](0) + }), + false + ); + + assertEq(called[startingIndex + 3].storageAccesses.length, 2, "incorrect length"); + assertIncrementEq( + called[startingIndex + 3].storageAccesses[0], + called[startingIndex + 3].storageAccesses[1], + Vm.StorageAccess({ + account: address(runner.doer()), + slot: keccak256(abi.encodePacked(bytes32("doer 2"), uint256(10))), + isWrite: true, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(1)), + reverted: true + }) + ); + assertEq( + called[startingIndex + 3], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(runner.doer()), + account: address(runner.doer()), + kind: Vm.AccountAccessKind.Call, + oldBalance: 0.01 ether, + newBalance: 0.01 ether, + deployedCode: "", + initialized: true, + value: 0.001 ether, + data: abi.encodeCall(Doer.doStuff, ()), + reverted: true, + storageAccesses: new Vm.StorageAccess[](0) + }), + false + ); + + assertEq(called[startingIndex + 4].storageAccesses.length, 2, "incorrect length"); + assertIncrementEq( + called[startingIndex + 4].storageAccesses[0], + called[startingIndex + 4].storageAccesses[1], + Vm.StorageAccess({ + account: address(runner.succeeder()), + slot: keccak256(abi.encodePacked(bytes32("succeeder"), uint256(0))), + isWrite: true, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(1)), + reverted: shouldRevert + }) + ); + assertEq( + called[startingIndex + 4], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(runner), + account: address(runner.succeeder()), + kind: Vm.AccountAccessKind.Call, + oldBalance: 0, + newBalance: 0.09 ether, + deployedCode: "", + initialized: true, + value: 0.1 ether, + data: abi.encodeCall(Succeeder.run, ()), + reverted: shouldRevert, + storageAccesses: new Vm.StorageAccess[](0) + }), + false + ); + + assertEq(called[startingIndex + 5].storageAccesses.length, 2, "incorrect length"); + assertIncrementEq( + called[startingIndex + 5].storageAccesses[0], + called[startingIndex + 5].storageAccesses[1], + Vm.StorageAccess({ + account: address(runner.doer()), + slot: keccak256(abi.encodePacked(bytes32("doer 1"), uint256(10))), + isWrite: true, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(1)), + reverted: shouldRevert + }) + ); + assertEq( + called[startingIndex + 5], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(runner.succeeder()), + account: address(runner.doer()), + kind: Vm.AccountAccessKind.Call, + oldBalance: 0, + newBalance: 0.01 ether, + deployedCode: "", + initialized: true, + value: 0.01 ether, + data: abi.encodeCall(Doer.run, ()), + reverted: shouldRevert, + storageAccesses: new Vm.StorageAccess[](0) + }), + false + ); + + assertEq(called[startingIndex + 3].storageAccesses.length, 2, "incorrect length"); + assertIncrementEq( + called[startingIndex + 6].storageAccesses[0], + called[startingIndex + 6].storageAccesses[1], + Vm.StorageAccess({ + account: address(runner.doer()), + slot: keccak256(abi.encodePacked(bytes32("doer 2"), uint256(10))), + isWrite: true, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(1)), + reverted: shouldRevert + }) + ); + assertEq( + called[startingIndex + 6], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(runner.doer()), + account: address(runner.doer()), + kind: Vm.AccountAccessKind.Call, + oldBalance: 0.01 ether, + newBalance: 0.01 ether, + deployedCode: "", + initialized: true, + value: 0.001 ether, + data: abi.encodeCall(Doer.doStuff, ()), + reverted: shouldRevert, + storageAccesses: new Vm.StorageAccess[](0) + }), + false + ); + } + + function testNestedStorage() public { + cheats.startStateDiffRecording(); + nestedStorer.run(); + Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + assertEq(called.length, 3, "incorrect account access length"); + + assertEq(called[0].storageAccesses.length, 2, "incorrect run storage length"); + assertIncrementEq( + called[0].storageAccesses[0], + called[0].storageAccesses[1], + Vm.StorageAccess({ + account: address(nestedStorer), + slot: keccak256(abi.encodePacked(bytes32("nested_storer 1"), bytes32(0))), + isWrite: true, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(1)), + reverted: false + }) + ); + assertEq( + called[0], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(this), + account: address(nestedStorer), + kind: Vm.AccountAccessKind.Call, + oldBalance: 0, + newBalance: 0, + deployedCode: "", + initialized: true, + value: 0, + data: abi.encodeCall(NestedStorer.run, ()), + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }), + false + ); + + assertEq(called[1].storageAccesses.length, 4, "incorrect run2 storage length"); + assertIncrementEq( + called[1].storageAccesses[0], + called[1].storageAccesses[1], + Vm.StorageAccess({ + account: address(nestedStorer), + slot: keccak256(abi.encodePacked(bytes32("nested_storer 3"), bytes32(0))), + isWrite: true, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(1)), + reverted: false + }) + ); + assertIncrementEq( + called[1].storageAccesses[2], + called[1].storageAccesses[3], + Vm.StorageAccess({ + account: address(nestedStorer), + slot: keccak256(abi.encodePacked(bytes32("nested_storer 4"), bytes32(0))), + isWrite: true, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(1)), + reverted: false + }) + ); + assertEq( + called[1], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(nestedStorer), + account: address(nestedStorer), + kind: Vm.AccountAccessKind.Call, + oldBalance: 0, + newBalance: 0, + deployedCode: "", + initialized: true, + value: 0, + data: abi.encodeCall(NestedStorer.run2, ()), + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }), + false + ); + + assertEq(called[2].storageAccesses.length, 2, "incorrect resume storage length"); + assertIncrementEq( + called[2].storageAccesses[0], + called[2].storageAccesses[1], + Vm.StorageAccess({ + account: address(nestedStorer), + slot: keccak256(abi.encodePacked(bytes32("nested_storer 2"), bytes32(0))), + isWrite: true, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(1)), + reverted: false + }) + ); + assertEq( + called[2], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(this), + account: address(nestedStorer), + kind: Vm.AccountAccessKind.Resume, + oldBalance: 0, + newBalance: 0, + deployedCode: "", + initialized: true, + value: 0, + data: "", + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }), + false + ); + } + + /// @notice Test that constructor account and storage accesses are recorded, including reverts + function testConstructorStorage() public { + cheats.startStateDiffRecording(); + address storer = address(new ConstructorStorer(false)); + try create2or.create2(bytes32(0), abi.encodePacked(type(ConstructorStorer).creationCode, abi.encode(true))) {} + catch {} + bytes memory creationCode = abi.encodePacked(type(ConstructorStorer).creationCode, abi.encode(true)); + address hypotheticalStorer = deriveCreate2Address(address(create2or), bytes32(0), keccak256(creationCode)); + + Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + assertEq(called.length, 3, "incorrect account access length"); + assertEq(toUint(called[0].kind), toUint(Vm.AccountAccessKind.Create), "incorrect kind"); + assertEq(toUint(called[1].kind), toUint(Vm.AccountAccessKind.Call), "incorrect kind"); + assertEq(toUint(called[2].kind), toUint(Vm.AccountAccessKind.Create), "incorrect kind"); + + assertEq(called[0].storageAccesses.length, 1, "incorrect storage access length"); + Vm.StorageAccess[] memory storageAccesses = new Vm.StorageAccess[](1); + storageAccesses[0] = Vm.StorageAccess({ + account: storer, + slot: bytes32(uint256(0)), + isWrite: true, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(1)), + reverted: false + }); + assertEq( + called[0], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(this), + account: address(storer), + kind: Vm.AccountAccessKind.Create, + oldBalance: 0, + newBalance: 0, + deployedCode: storer.code, + initialized: true, + value: 0, + data: abi.encodePacked(type(ConstructorStorer).creationCode, abi.encode(false)), + reverted: false, + storageAccesses: storageAccesses + }) + ); + + assertEq(called[1].storageAccesses.length, 0, "incorrect storage access length"); + assertEq( + called[1], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(this), + account: address(create2or), + kind: Vm.AccountAccessKind.Call, + oldBalance: 0, + newBalance: 0, + deployedCode: "", + initialized: true, + value: 0, + data: abi.encodeCall( + Create2or.create2, + (bytes32(0), abi.encodePacked(type(ConstructorStorer).creationCode, abi.encode(true))) + ), + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }) + ); + + assertEq(called[2].storageAccesses.length, 1, "incorrect storage access length"); + storageAccesses = new Vm.StorageAccess[](1); + storageAccesses[0] = Vm.StorageAccess({ + account: hypotheticalStorer, + slot: bytes32(uint256(0)), + isWrite: true, + previousValue: bytes32(uint256(0)), + newValue: bytes32(uint256(1)), + reverted: true + }); + assertEq( + called[2], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(create2or), + account: hypotheticalStorer, + kind: Vm.AccountAccessKind.Create, + oldBalance: 0, + newBalance: 0, + deployedCode: address(hypotheticalStorer).code, + initialized: true, + value: 0, + data: creationCode, + reverted: true, + storageAccesses: storageAccesses + }) + ); + } + + /// @notice Test that account accesses are correctly recorded when the + /// recording is started from a lower depth than they are + /// retrieved + function testNested_LowerDepth() public { + this.startRecordingFromLowerDepth(); + runNested(false, true); + } + + /// @notice Test that account accesses are correctly recorded when + /// the first call reverts the and recording is started from + /// a lower depth than they are retrieved. + function testNested_LowerDepth_Revert() public { + this.startRecordingFromLowerDepth(); + runNested(true, true); + } + + /// @notice Test that constructor calls and calls made within a constructor + /// are correctly recorded, even if it reverts + function testCreateRevert() public { + cheats.startStateDiffRecording(); + bytes memory creationCode = abi.encodePacked(type(SelfCaller).creationCode, abi.encode("")); + try create2or.create2(bytes32(0), creationCode) {} catch {} + address hypotheticalAddress = deriveCreate2Address(address(create2or), bytes32(0), keccak256(creationCode)); + + Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + assertEq(called.length, 3, "incorrect length"); + assertEq( + called[1], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(create2or), + account: hypotheticalAddress, + kind: Vm.AccountAccessKind.Create, + oldBalance: 0, + newBalance: 0, + deployedCode: address(hypotheticalAddress).code, + initialized: true, + value: 0, + data: creationCode, + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }) + ); + assertEq( + called[2], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: hypotheticalAddress, + account: hypotheticalAddress, + kind: Vm.AccountAccessKind.Call, + oldBalance: 0, + newBalance: 0, + deployedCode: hex"", + initialized: true, + value: 0, + data: "", + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }) + ); + } + + /// @notice It is important to test SELFDESTRUCT behavior as long as there + /// are public networks that support the opcode, regardless of whether + /// or not Ethereum mainnet does. + function testSelfDestruct() public { + uint256 startingBalance = address(this).balance; + this.startRecordingFromLowerDepth(); + address a = address(new SelfDestructor{value:1 ether}(address(this))); + address b = address(new SelfDestructor{value:1 ether}(address(bytes20("doesn't exist yet")))); + Vm.AccountAccess[] memory called = cheats.stopAndReturnStateDiff(); + assertEq(called.length, 5, "incorrect length"); + assertEq( + called[1], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(this), + account: a, + kind: Vm.AccountAccessKind.Create, + oldBalance: 0, + newBalance: 0, + deployedCode: "", + initialized: true, + value: 1 ether, + data: abi.encodePacked(type(SelfDestructor).creationCode, abi.encode(address(this))), + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }) + ); + assertEq( + called[2], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(a), + account: address(this), + kind: Vm.AccountAccessKind.SelfDestruct, + oldBalance: startingBalance - 1 ether, + newBalance: startingBalance, + deployedCode: "", + initialized: true, + value: 1 ether, + data: "", + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }) + ); + assertEq( + called[3], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(this), + account: b, + kind: Vm.AccountAccessKind.Create, + oldBalance: 0, + newBalance: 0, + deployedCode: "", + initialized: true, + value: 1 ether, + data: abi.encodePacked(type(SelfDestructor).creationCode, abi.encode(address(bytes20("doesn't exist yet")))), + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }) + ); + assertEq( + called[4], + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: address(b), + account: address(bytes20("doesn't exist yet")), + kind: Vm.AccountAccessKind.SelfDestruct, + oldBalance: 0, + newBalance: 1 ether, + deployedCode: hex"", + initialized: false, + value: 1 ether, + data: "", + reverted: false, + storageAccesses: new Vm.StorageAccess[](0) + }) + ); + } + + function startRecordingFromLowerDepth() external { + cheats.startStateDiffRecording(); + assembly { + pop(call(gas(), 1234, 0, 0, 0, 0, 0)) + } + } + + function revertingCall(address target, bytes memory data) external payable { + assembly { + pop(call(gas(), target, div(callvalue(), 10), add(data, 0x20), mload(data), 0, 0)) + } + revert(); + } + + /// Asserts that the given account access is a resume of the given parent + function assertResumeEq(Vm.AccountAccess memory actual, Vm.AccountAccess memory expected) internal { + assertEq( + actual, + Vm.AccountAccess({ + chainInfo: Vm.ChainInfo({forkId: 0, chainId: 0}), + accessor: expected.accessor, + account: expected.account, + kind: Vm.AccountAccessKind.Resume, + oldBalance: 0, + newBalance: 0, + deployedCode: "", + initialized: expected.initialized, + value: 0, + data: "", + reverted: expected.reverted, + storageAccesses: new Vm.StorageAccess[](0) + }), + false + ); + } + + function assertIncrementEq( + Vm.StorageAccess memory read, + Vm.StorageAccess memory write, + Vm.StorageAccess memory expected + ) internal { + assertEq( + read, + Vm.StorageAccess({ + account: expected.account, + slot: expected.slot, + isWrite: false, + previousValue: expected.previousValue, + newValue: expected.previousValue, + reverted: expected.reverted + }) + ); + assertEq( + write, + Vm.StorageAccess({ + account: expected.account, + slot: expected.slot, + isWrite: true, + previousValue: expected.previousValue, + newValue: expected.newValue, + reverted: expected.reverted + }) + ); + } + + function assertEq(Vm.AccountAccess memory actualAccess, Vm.AccountAccess memory expectedAccess) internal { + assertEq(actualAccess, expectedAccess, true); + } + + function assertEq(Vm.AccountAccess memory actualAccess, Vm.AccountAccess memory expectedAccess, bool checkStorage) + internal + { + assertEq(toUint(actualAccess.kind), toUint(expectedAccess.kind), "incorrect kind"); + assertEq(actualAccess.account, expectedAccess.account, "incorrect account"); + assertEq(actualAccess.accessor, expectedAccess.accessor, "incorrect accessor"); + assertEq(toUint(actualAccess.initialized), toUint(expectedAccess.initialized), "incorrect initialized"); + assertEq(actualAccess.oldBalance, expectedAccess.oldBalance, "incorrect oldBalance"); + assertEq(actualAccess.newBalance, expectedAccess.newBalance, "incorrect newBalance"); + assertEq(actualAccess.deployedCode, expectedAccess.deployedCode, "incorrect deployedCode"); + assertEq(actualAccess.value, expectedAccess.value, "incorrect value"); + assertEq(actualAccess.data, expectedAccess.data, "incorrect data"); + assertEq(toUint(actualAccess.reverted), toUint(expectedAccess.reverted), "incorrect reverted"); + if (checkStorage) { + assertEq( + actualAccess.storageAccesses.length, + expectedAccess.storageAccesses.length, + "incorrect storageAccesses length" + ); + for (uint256 i = 0; i < actualAccess.storageAccesses.length; i++) { + assertEq(actualAccess.storageAccesses[i], expectedAccess.storageAccesses[i]); + } + } + } + + function assertEq(Vm.StorageAccess memory actual, Vm.StorageAccess memory expected) internal { + assertEq(actual.account, expected.account, "incorrect storageAccess account"); + assertEq(actual.slot, expected.slot, "incorrect storageAccess slot"); + assertEq(toUint(actual.isWrite), toUint(expected.isWrite), "incorrect storageAccess isWrite"); + assertEq(actual.previousValue, expected.previousValue, "incorrect storageAccess previousValue"); + assertEq(actual.newValue, expected.newValue, "incorrect storageAccess newValue"); + assertEq(toUint(actual.reverted), toUint(expected.reverted), "incorrect storageAccess reverted"); + } + + function toUint(Vm.AccountAccessKind kind) internal pure returns (uint256 value) { + assembly { + value := and(kind, 0xff) + } + } + + function toUint(bool a) internal pure returns (uint256) { + return a ? 1 : 0; + } + + function deriveCreate2Address(address deployer, bytes32 salt, bytes32 codeHash) internal pure returns (address) { + return address(uint160(uint256(keccak256(abi.encodePacked(bytes1(0xff), deployer, salt, codeHash))))); + } +} diff --git a/testdata/cheats/Vm.sol b/testdata/cheats/Vm.sol index 0d718ccd967f..4fc5c3ce51b2 100644 --- a/testdata/cheats/Vm.sol +++ b/testdata/cheats/Vm.sol @@ -7,6 +7,7 @@ pragma solidity ^0.8.4; interface Vm { error CheatcodeError(string message); enum CallerMode { None, Broadcast, RecurrentBroadcast, Prank, RecurrentPrank } + enum AccountAccessKind { Call, DelegateCall, CallCode, StaticCall, Create, SelfDestruct, Resume } struct Log { bytes32[] topics; bytes data; address emitter; } struct Rpc { string key; string url; } struct EthGetLogs { address emitter; bytes32[] topics; bytes data; bytes32 blockHash; uint64 blockNumber; bytes32 transactionHash; uint64 transactionIndex; uint256 logIndex; bool removed; } @@ -14,6 +15,9 @@ interface Vm { struct FsMetadata { bool isDir; bool isSymlink; uint256 length; bool readOnly; uint256 modified; uint256 accessed; uint256 created; } struct Wallet { address addr; uint256 publicKeyX; uint256 publicKeyY; uint256 privateKey; } struct FfiResult { int32 exitCode; bytes stdout; bytes stderr; } + struct ChainInfo { uint256 forkId; uint256 chainId; } + struct AccountAccess { ChainInfo chainInfo; AccountAccessKind kind; address account; address accessor; bool initialized; uint256 oldBalance; uint256 newBalance; bytes deployedCode; uint256 value; bytes data; bool reverted; StorageAccess[] storageAccesses; } + struct StorageAccess { address account; bytes32 slot; bool isWrite; bytes32 previousValue; bytes32 newValue; bool reverted; } function accesses(address target) external returns (bytes32[] memory readSlots, bytes32[] memory writeSlots); function activeFork() external view returns (uint256 forkId); function addr(uint256 privateKey) external pure returns (address keyAddr); @@ -209,6 +213,8 @@ interface Vm { function startMappingRecording() external; function startPrank(address msgSender) external; function startPrank(address msgSender, address txOrigin) external; + function startStateDiffRecording() external; + function stopAndReturnStateDiff() external returns (AccountAccess[] memory accesses); function stopBroadcast() external; function stopMappingRecording() external; function stopPrank() external;