diff --git a/tests/e2e-evm/contracts/ABI_BasicTests.sol b/tests/e2e-evm/contracts/ABI_BasicTests.sol index f1077498d..4112d10cb 100644 --- a/tests/e2e-evm/contracts/ABI_BasicTests.sol +++ b/tests/e2e-evm/contracts/ABI_BasicTests.sol @@ -9,8 +9,19 @@ pragma solidity ^0.8.24; // Low level caller // contract Caller { + /** + * @dev Call a function via CALL with the current msg.value + */ function functionCall(address payable to, bytes calldata data) external payable { - (bool success, bytes memory result) = to.call{value: msg.value}(data); + this.functionCallWithValue(to, msg.value, data); + } + + /** + * @dev Call a function via CALL with a specific value that may be different + * from the current msg.value + */ + function functionCallWithValue(address payable to, uint256 value, bytes calldata data) public payable { + (bool success, bytes memory result) = to.call{value: value}(data); if (!success) { // solhint-disable-next-line gas-custom-errors @@ -18,14 +29,85 @@ contract Caller { // solhint-disable-next-line no-inline-assembly assembly { + // Bubble up errors: revert(pointer, length of revert reason) + // - result is a dynamic array, so the first 32 bytes is the + // length of the array. + // - add(32, result) skips the length of the array and points to + // the start of the data. + // - mload(result) reads 32 bytes from the memory location, + // which is the length of the revert reason. revert(add(32, result), mload(result)) } } } - // TODO: Callcode + /** + * @dev Call a contract function via CALLCODE with the current msg.value + */ + function functionCallCode(address to, bytes calldata data) external payable { + this.functionCallCodeWithValue(to, msg.value, data); + } + + /** + * @dev Call a contract function via CALLCODE with a specific value that may + * be different from the current msg.value + */ + function functionCallCodeWithValue(address to, uint256 value, bytes calldata data) external payable { + // solhint-disable-next-line no-inline-assembly + assembly { + // Copy the calldata to memory, as callcode uses memory pointers. + // offset is where the actual data starts in the calldata. Copy the + // data to memory starting at 0. + // Note: We are taking full control of memory as we do not return + // to high-level Solidity code. This would be not memory safe as it + // may exceed the scratch space in the first 64 bytes from 0. + // This should be safe to still do, similar to the OpenZeppelin + // proxy contract, overwriting the full memory scratch pad at + // target 0 AND never returning to high-level Solidity code. + calldatacopy(0, data.offset, data.length) + + // callcode(g, a, v, in, insize, out, outsize) + // returns 0 on error (eg. out of gas) and 1 on success + let result := callcode( + gas(), // gas + to, // to address + value, // value to send + 0, // in - pointer to start of input, 0 since we copied the data to 0 + data.length, // insize - size of the input + 0, // out + 0 // outsize - 0 since we don't know the size of the output + ) + + // Copy the returned data to memory. + // returndatacopy(t, f, s) + // - t: target location in memory + // - f: source location in return data + // - s: size + // Note: Same memory safety notes as above with calldatacopy() + returndatacopy(0, 0, returndatasize()) + + switch result + // 0 on error + case 0 { + revert(0, returndatasize()) + } + // 1 on success + case 1 { + return(0, returndatasize()) + } + // Invalid result + default { + revert(0, 0) + } + } + } - function functionDelegateCall(address to, bytes calldata data) external { + /** + * @dev Call a contract function via DELEGATECALL with the current msg.value + * and current msg.sender. DELEGATECALL cannot specify a different + * value. + */ + function functionDelegateCall(address to, bytes calldata data) external payable { // solhint-disable-next-line avoid-low-level-calls (bool success, bytes memory result) = to.delegatecall(data); @@ -40,7 +122,12 @@ contract Caller { } } - function functionStaticCall(address to, bytes calldata data) external view { + /** + * @dev Call a contract function via STATICCALL with the current msg.value + * and current msg.sender. + * @return The result of the static call in bytes. + */ + function functionStaticCall(address to, bytes calldata data) external view returns (bytes memory) { (bool success, bytes memory result) = to.staticcall(data); if (!success) { @@ -52,6 +139,8 @@ contract Caller { revert(add(32, result), mload(result)) } } + + return result; } } diff --git a/tests/e2e-evm/contracts/ABI_DisabledTests.sol b/tests/e2e-evm/contracts/ABI_DisabledTests.sol index aaa090deb..f7d452806 100644 --- a/tests/e2e-evm/contracts/ABI_DisabledTests.sol +++ b/tests/e2e-evm/contracts/ABI_DisabledTests.sol @@ -14,6 +14,10 @@ contract NoopDisabledMock is ABI_BasicTests.NoopReceivePayableFallback { function noopPayable() external payable { mockRevert(); } + /** + * @dev This function is intentionally not marked as pure to test the + * behavior of view functions in disabled contracts. + */ // solc-ignore-next-line func-mutability function noopView() external view { mockRevert(); diff --git a/tests/e2e-evm/contracts/ContextInspector.sol b/tests/e2e-evm/contracts/ContextInspector.sol new file mode 100644 index 000000000..5a1914a1c --- /dev/null +++ b/tests/e2e-evm/contracts/ContextInspector.sol @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +interface ContextInspector { + /** + * @dev Emitted when the emitMsgSender() function is called. + */ + event MsgSender(address sender); + + /** + * @dev Emitted when the emitMsgValue() function is called. + * + * Note that `value` may be zero. + */ + event MsgValue(uint256 value); + + function emitMsgSender() external; + function emitMsgValue() external payable; + function getMsgSender() external view returns (address); +} + +/** + * @title A contract to inspect the msg context. + * @notice This contract is used to test the expected msg.sender and msg.value + * of a contract call in various scenarios. + */ +contract ContextInspectorMock is ContextInspector { + function emitMsgSender() external { + emit MsgSender(msg.sender); + } + + function emitMsgValue() external payable { + emit MsgValue(msg.value); + } + + /** + * @dev Returns the current msg.sender. This is primarily used for testing + * staticcall as events are not emitted. + */ + function getMsgSender() external view returns (address) { + return msg.sender; + } +} diff --git a/tests/e2e-evm/contracts/StorageTests.sol b/tests/e2e-evm/contracts/StorageTests.sol new file mode 100644 index 000000000..3b707b7f6 --- /dev/null +++ b/tests/e2e-evm/contracts/StorageTests.sol @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +interface StorageBasic { + function setStorageValue(uint256 value) external; +} + +/** + * @title A basic contract with a storage value. + * @notice This contract is used to test storage reads and writes, primarily + * for testing storage behavior in delegateCall. + */ +contract StorageBasicMock is StorageBasic { + uint256 public storageValue; + + function setStorageValue(uint256 value) external { + storageValue = value; + } +} diff --git a/tests/e2e-evm/test/abi_basic.test.ts b/tests/e2e-evm/test/abi_basic.test.ts index 185232f94..03ce1e0e7 100644 --- a/tests/e2e-evm/test/abi_basic.test.ts +++ b/tests/e2e-evm/test/abi_basic.test.ts @@ -17,8 +17,8 @@ import { Abi } from "abitype"; import { getAbiFallbackFunction, getAbiReceiveFunction } from "./helpers/abi"; import { whaleAddress } from "./addresses"; -const defaultGas = 25000n; -const contractCallerGas = defaultGas + 10000n; +const defaultGas = 25_000n; +const contractCallerGas = defaultGas + 12_000n; interface ContractTestCase { interface: keyof ArtifactsMap; @@ -89,11 +89,11 @@ describe("ABI_BasicTests", function () { // let publicClient: PublicClient; let walletClient: WalletClient; - let caller: GetContractReturnType; + let lowLevelCaller: GetContractReturnType; before("setup clients", async function () { publicClient = await hre.viem.getPublicClient(); walletClient = await hre.viem.getWalletClient(whaleAddress); - caller = await hre.viem.deployContract("Caller"); + lowLevelCaller = await hre.viem.deployContract("Caller"); }); interface StateContext { @@ -170,9 +170,9 @@ describe("ABI_BasicTests", function () { { name: "can be called by low level contract call", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionCall", args: [ctx.address, funcSelector], }), @@ -180,6 +180,19 @@ describe("ABI_BasicTests", function () { }), expectedStatus: "success", }, + { + name: "can be called by callcode", + txParams: (ctx) => ({ + to: lowLevelCaller.address, + data: encodeFunctionData({ + abi: lowLevelCaller.abi, + functionName: "functionCallCode", + args: [ctx.address, funcSelector], + }), + gas: contractCallerGas, + }), + expectedStatus: "success", + }, { name: "can be called by high level contract call", txParams: (ctx) => ({ @@ -205,9 +218,9 @@ describe("ABI_BasicTests", function () { { name: "can be called by static call", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionStaticCall", args: [ctx.address, funcSelector], }), @@ -218,9 +231,9 @@ describe("ABI_BasicTests", function () { { name: "can be called by static call with extra data", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionStaticCall", args: [ctx.address, concat([funcSelector, "0x01"])], }), @@ -258,9 +271,9 @@ describe("ABI_BasicTests", function () { { name: "can be called by high level contract call with value", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionCall", args: [ctx.address, funcSelector], }), @@ -313,9 +326,9 @@ describe("ABI_BasicTests", function () { { name: "can not be called by high level contract call with value", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionCall", args: [ctx.address, funcSelector], }), @@ -397,9 +410,9 @@ describe("ABI_BasicTests", function () { { name: "can be called by another contract with no data", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionCall", args: [ctx.address, "0x"], }), @@ -410,9 +423,9 @@ describe("ABI_BasicTests", function () { { name: "can be called by static call with no data", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionStaticCall", args: [ctx.address, "0x"], }), @@ -433,9 +446,9 @@ describe("ABI_BasicTests", function () { { name: "can not receive zero value transfers by high level contract call with no data", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionCall", args: [ctx.address, "0x"], }), @@ -446,9 +459,9 @@ describe("ABI_BasicTests", function () { { name: "can not receive zero value transfers by static call with no data", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionStaticCall", args: [ctx.address, "0x"], }), @@ -470,9 +483,9 @@ describe("ABI_BasicTests", function () { { name: "can not receive plain transfers via message call", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionCall", args: [ctx.address, "0x"], }), @@ -495,9 +508,9 @@ describe("ABI_BasicTests", function () { { name: "can receive plain transfers via message call", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionCall", args: [ctx.address, "0x"], }), @@ -524,9 +537,9 @@ describe("ABI_BasicTests", function () { { name: "can be called with a non-matching function selector via message call", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionCall", args: [ctx.address, toFunctionSelector("does_not_exist()")], }), @@ -537,9 +550,9 @@ describe("ABI_BasicTests", function () { { name: "can be called with a non-matching function selector via static call", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionStaticCall", args: [ctx.address, toFunctionSelector("does_not_exist()")], }), @@ -559,9 +572,9 @@ describe("ABI_BasicTests", function () { { name: "can be called with an invalid (short) function selector via message call", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionCall", args: [ctx.address, "0x010203"], }), @@ -572,9 +585,9 @@ describe("ABI_BasicTests", function () { { name: "can be called with an invalid (short) function selector via static call", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionStaticCall", args: [ctx.address, "0x010203"], }), @@ -599,9 +612,9 @@ describe("ABI_BasicTests", function () { { name: "can not be called with a non-matching function selector via message call", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionCall", args: [ctx.address, toFunctionSelector("does_not_exist()")], }), @@ -612,9 +625,9 @@ describe("ABI_BasicTests", function () { { name: "can not be called with a non-matching function selector via static call", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionStaticCall", args: [ctx.address, toFunctionSelector("does_not_exist()")], }), @@ -634,9 +647,9 @@ describe("ABI_BasicTests", function () { { name: "can not be called with an invalid (short) function selector via message call", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionCall", args: [ctx.address, "0x010203"], }), @@ -647,9 +660,9 @@ describe("ABI_BasicTests", function () { { name: "can not be called with an invalid (short) function selector via static call", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionStaticCall", args: [ctx.address, "0x010203"], }), @@ -676,9 +689,9 @@ describe("ABI_BasicTests", function () { { name: "can receive value with a non-matching function selector via message call", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionCall", args: [ctx.address, toFunctionSelector("does_not_exist()")], }), @@ -702,9 +715,9 @@ describe("ABI_BasicTests", function () { { name: "can receive value with an invalid (short) function selector via message call", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionCall", args: [ctx.address, "0x010203"], }), @@ -733,9 +746,9 @@ describe("ABI_BasicTests", function () { { name: "can not receive value with a non-matching function selector via message call", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionCall", args: [ctx.address, toFunctionSelector("does_not_exist()")], }), @@ -759,9 +772,9 @@ describe("ABI_BasicTests", function () { { name: "can not receive value with an invalid (short) function selector via message call", txParams: (ctx) => ({ - to: caller.address, + to: lowLevelCaller.address, data: encodeFunctionData({ - abi: caller.abi, + abi: lowLevelCaller.abi, functionName: "functionCall", args: [ctx.address, "0x010203"], }), diff --git a/tests/e2e-evm/test/abi_disabled.test.ts b/tests/e2e-evm/test/abi_disabled.test.ts index b58335e92..17e54a99f 100644 --- a/tests/e2e-evm/test/abi_disabled.test.ts +++ b/tests/e2e-evm/test/abi_disabled.test.ts @@ -157,6 +157,13 @@ describe("ABI_DisabledTests", function () { encodeFunctionData({ abi: caller.abi, functionName: "functionCall", args: [ctx.address, data] }), gas: messageCallGas, }, + { + name: "message callcode", + to: () => caller.address, + mutateData: (data) => + encodeFunctionData({ abi: caller.abi, functionName: "functionCallCode", args: [ctx.address, data] }), + gas: messageCallGas, + }, { name: "message delegatecall", to: () => caller.address, diff --git a/tests/e2e-evm/test/message_call.test.ts b/tests/e2e-evm/test/message_call.test.ts new file mode 100644 index 000000000..205ce27d5 --- /dev/null +++ b/tests/e2e-evm/test/message_call.test.ts @@ -0,0 +1,515 @@ +import hre from "hardhat"; +import type { ArtifactsMap } from "hardhat/types/artifacts"; +import type { PublicClient, WalletClient, GetContractReturnType } from "@nomicfoundation/hardhat-viem/types"; +import { expect } from "chai"; +import { + Abi, + AbiParameter, + Address, + Chain, + ContractFunctionName, + decodeAbiParameters, + DecodeAbiParametersReturnType, + decodeFunctionResult, + encodeFunctionData, + getAddress, + Hex, + isAddress, + pad, + parseAbiParameters, + parseEventLogs, + SendTransactionParameters, + toHex, + TransactionReceipt, +} from "viem"; +import { whaleAddress } from "./addresses"; +import { ExtractAbiEventNames } from "abitype"; + +const defaultGas = 25000n; +const contractCallerGas = defaultGas + 10000n; + +describe("Message calls", () => { + // + // Client + Wallet Setup + // + let publicClient: PublicClient; + let walletClient: WalletClient; + let lowLevelCaller: GetContractReturnType; + before("setup clients", async function () { + publicClient = await hre.viem.getPublicClient(); + walletClient = await hre.viem.getWalletClient(whaleAddress); + lowLevelCaller = await hre.viem.deployContract("Caller"); + }); + + interface TestContext { + lowLevelCaller: GetContractReturnType; + // Using a generic to enforce type safety on functions that use this Abi, + // eg. decodeFunctionResult would have the functionName parameter type + // checked. + implementationAbi: ArtifactsMap[Impl]["abi"]; + implementationAddress: Address; + } + + type CallType = "call" | "callcode" | "delegatecall" | "staticcall"; + + /** + * buildCallData creates the transaction data for the given message call type, + * wrapping the function call data in the appropriate low level caller + * function. + */ + function buildCallData( + type: CallType, + implAddr: Address, + data: Hex, + callCodeValue?: bigint, + ): SendTransactionParameters { + if (callCodeValue && type !== "callcode") { + expect.fail("callCodeValue parameter is only valid for callcode"); + } + + switch (type) { + case "call": + // Direct call, just pass through the tx data + return { + account: walletClient.account, + to: implAddr, + gas: defaultGas, + data, + }; + + case "callcode": + if (callCodeValue) { + // Custom value, use functionCallCodeWithValue that calls the + // implementation contract with the given value that may be different + // from the parent value. + return { + account: walletClient.account, + to: lowLevelCaller.address, + gas: contractCallerGas, + data: encodeFunctionData({ + abi: lowLevelCaller.abi, + functionName: "functionCallCodeWithValue", + args: [implAddr, callCodeValue, data], + }), + }; + } + + // No custom value, use functionCallCode which just uses msg.value + return { + account: walletClient.account, + to: lowLevelCaller.address, + // Needs additional gas since it calls the external + // functionCallCodeWithValue function. + gas: contractCallerGas + 1_105n, + data: encodeFunctionData({ + abi: lowLevelCaller.abi, + functionName: "functionCallCode", + args: [implAddr, data], + }), + }; + + case "delegatecall": + return { + account: walletClient.account, + to: lowLevelCaller.address, + gas: contractCallerGas, + data: encodeFunctionData({ + abi: lowLevelCaller.abi, + functionName: "functionDelegateCall", + args: [implAddr, data], + }), + }; + + case "staticcall": + return { + account: walletClient.account, + to: lowLevelCaller.address, + gas: contractCallerGas * 100n, + data: encodeFunctionData({ + abi: lowLevelCaller.abi, + functionName: "functionStaticCall", + args: [implAddr, data], + }), + }; + } + } + + /** + * getResponseFromReceiptLogs decodes the event log data from the given + * transaction receipt and returns the decoded parameters. + */ + function getResponseFromReceiptLogs( + abi: abi, + eventName: ExtractAbiEventNames, + params: params, + txReceipt: TransactionReceipt, + ): DecodeAbiParametersReturnType { + // This can also be scoped to specific eventNames, but it would be more + // clear if this produces an eventName mismatch instead of silently + // returning an empty array. + const logs = parseEventLogs({ + abi: abi, + logs: txReceipt.logs, + }); + + expect(logs.length).to.equal(1, "unexpected number of logs"); + const [log] = logs; + + if (log.eventName !== eventName) { + expect.fail(`unexpected event name`); + } + + return decodeAbiParameters(params, log.data); + } + + function itHasCorrectMsgSender(getContext: () => TestContext<"ContextInspector">) { + describe("msg.sender", () => { + let ctx: TestContext<"ContextInspector">; + + before("Setup context", function () { + ctx = getContext(); + }); + + // senderTestCase is to validate the expected msg.sender for a given message call type. + interface senderTestCase { + name: string; + type: CallType; + wantChildMsgSender: (ctx: TestContext<"ContextInspector">, signer: Address) => Address; + } + + const testCases: senderTestCase[] = [ + { + name: "direct call is parent (signer)", + type: "call", + wantChildMsgSender: (_, signer) => signer, + }, + { + name: "callcode msg.sender is parent (caller)", + type: "callcode", + wantChildMsgSender: (ctx) => ctx.lowLevelCaller.address, + }, + { + name: "delegatecall propagates msg.sender (same as caller, parent signer)", + type: "delegatecall", + wantChildMsgSender: (_, signer) => signer, + }, + { + name: "staticcall msg.sender is parent (caller)", + type: "staticcall", + wantChildMsgSender: (ctx) => ctx.lowLevelCaller.address, + }, + ]; + + for (const tc of testCases) { + it(tc.name, async function () { + let functionName: ContractFunctionName = "emitMsgSender"; + + // Static Call cannot emit events so we use a different function that + // just returns the msg.sender. + if (tc.type === "staticcall") { + functionName = "getMsgSender"; + } + + // ContextInspector function call data + const baseData = encodeFunctionData({ + abi: ctx.implementationAbi, + functionName, + args: [], + }); + + // Modify the call data based on the test case type + const txData = buildCallData(tc.type, ctx.implementationAddress, baseData); + // No value for this test + txData.value = 0n; + + if (tc.type === "staticcall") { + // Check return value + const returnedMsgSender = await publicClient.call(txData); + if (returnedMsgSender.data === undefined) { + expect.fail("call return data is undefined"); + } + + // Decode low level caller first since it is a byte array that + // includes the offset, length, and data. + const dataBytes = decodeFunctionResult({ + abi: ctx.lowLevelCaller.abi, + functionName: "functionStaticCall", + data: returnedMsgSender.data, + }); + + // Decode dataBytes as an address + const address = decodeFunctionResult({ + abi: ctx.implementationAbi, + functionName: "getMsgSender", + data: dataBytes, + }); + + const expectedSender = tc.wantChildMsgSender(ctx, walletClient.account.address); + expect(getAddress(address)).to.equal(getAddress(expectedSender), "unexpected msg.sender"); + + // Skip the rest of the test since staticcall does not emit events. + return; + } + + await publicClient.call(txData); + + const txHash = await walletClient.sendTransaction(txData); + const txReceipt = await publicClient.waitForTransactionReceipt({ hash: txHash }); + expect(txReceipt.status).to.equal("success"); + + const [receivedAddress] = getResponseFromReceiptLogs( + ctx.implementationAbi, + "MsgSender", + parseAbiParameters("address"), + txReceipt, + ); + + expect(isAddress(receivedAddress), "log.data should be an address").to.be.true; + + const expectedSender = tc.wantChildMsgSender(ctx, walletClient.account.address); + expect(getAddress(receivedAddress)).to.equal(getAddress(expectedSender), "unexpected msg.sender"); + }); + } + }); + } + + function itHasCorrectMsgValue(getContext: () => TestContext<"ContextInspector">) { + describe("msg.value", () => { + let ctx: TestContext<"ContextInspector">; + + before("Setup context", function () { + ctx = getContext(); + }); + + interface valueTestCase { + name: string; + type: CallType; + // msg.value for the parent + giveParentValue: bigint; + // Call value for the child, only applicable for call, callcode + giveChildCallValue?: bigint; + wantRevertReason?: string; + // Expected msg.value for the child + wantChildMsgValue: (txValue: bigint) => bigint; + } + + const testCases: valueTestCase[] = [ + { + name: "direct call", + type: "call", + giveParentValue: 10n, + wantChildMsgValue: (txValue) => txValue, + }, + { + name: "delegatecall propagates msg.value", + type: "delegatecall", + giveParentValue: 10n, + wantChildMsgValue: (txValue) => txValue, + }, + { + name: "callcode with value == parent msg.value", + type: "callcode", + giveParentValue: 10n, + wantChildMsgValue: (txValue) => txValue, + }, + { + name: "callcode with a value != parent msg.value", + type: "callcode", + // Transfers 10 signer -> caller contract + giveParentValue: 10n, + // Transfers 5 Caller -> implementation contract + giveChildCallValue: 5n, + wantChildMsgValue: () => 5n, + }, + { + name: "staticcall", + type: "staticcall", + giveParentValue: 10n, + wantRevertReason: "non-payable function was called with value 10", + wantChildMsgValue: () => 0n, + }, + ]; + + for (const tc of testCases) { + it(tc.name, async function () { + // Initial data of a emitMsgSender() call. + const baseData = encodeFunctionData({ + abi: ctx.implementationAbi, + functionName: "emitMsgValue", + args: [], + }); + + // Modify the call data based on the test case type + const txData = buildCallData(tc.type, ctx.implementationAddress, baseData, tc.giveChildCallValue); + txData.value = tc.giveParentValue; + + if (!tc.wantRevertReason) { + // This throws an error with revert reason to make it easier to debug + // if the transaction fails. + await publicClient.call(txData); + } else { + // rejectedWith is an string includes matcher + await expect(publicClient.call(txData)).to.be.rejectedWith(tc.wantRevertReason); + + // Cannot include msg.value with static call as it changes state so + // skip the rest of the test. + return; + } + + const txHash = await walletClient.sendTransaction(txData); + const txReceipt = await publicClient.waitForTransactionReceipt({ hash: txHash }); + expect(txReceipt.status).to.equal("success"); + + const [emittedAmount] = getResponseFromReceiptLogs( + ctx.implementationAbi, + "MsgValue", + parseAbiParameters("uint256"), + txReceipt, + ); + + // Assert msg.value is as expected + const expectedValue = tc.wantChildMsgValue(txData.value); + expect(emittedAmount).to.equal(expectedValue, "unexpected msg.value"); + }); + } + }); + } + + function itHasCorrectStorageLocation(getContext: () => TestContext<"StorageBasic">) { + describe("storage location", () => { + let ctx: TestContext<"StorageBasic">; + + before("Setup context", function () { + ctx = getContext(); + }); + + interface storageTestCase { + name: string; + callType: CallType; + wantRevert?: boolean; + wantStorageContract?: (ctx: TestContext<"StorageBasic">) => Address; + } + + const testCases: storageTestCase[] = [ + { + name: "call storage in implementation", + callType: "call", + wantStorageContract: (ctx) => ctx.implementationAddress, + }, + { + name: "callcode storage in caller", + callType: "callcode", + // Storage in caller contract + wantStorageContract: (ctx) => ctx.lowLevelCaller.address, + }, + { + name: "delegatecall storage in caller", + callType: "delegatecall", + // Storage in caller contract + wantStorageContract: (ctx) => ctx.lowLevelCaller.address, + }, + { + name: "staticcall storage not allowed", + callType: "staticcall", + wantRevert: true, + // No expected storage due to revert + }, + ]; + + let giveStoreValue = 0n; + + for (const tc of testCases) { + it(tc.name, async function () { + // Increment storage value for a different value each test + giveStoreValue++; + + const baseData = encodeFunctionData({ + abi: ctx.implementationAbi, + functionName: "setStorageValue", + args: [giveStoreValue], + }); + + const txData = buildCallData(tc.callType, ctx.implementationAddress, baseData); + // Signer is the whale + txData.account = whaleAddress; + // Call gas + storage gas + txData.gas = contractCallerGas + 20_2000n; + + if (!txData.to) { + expect.fail("to field not set"); + } + + if (tc.wantRevert) { + await expect(publicClient.call(txData)).to.be.rejected; + + // No actual transaction or storage to check! Skip the rest of the test. + return; + } else { + // Throw revert errors if the transaction fails. Not using expect() + // here since this still fails the test if the transaction fails and + // does not mess up the formatting of the error message. + await publicClient.call(txData); + } + + if (!tc.wantStorageContract) { + expect.fail("expected storage contract set"); + } + + const txHash = await walletClient.sendTransaction(txData); + const txReceipt = await publicClient.waitForTransactionReceipt({ hash: txHash }); + expect(txReceipt.status).to.equal("success"); + + if (txData.gas) { + expect(txReceipt.gasUsed < txData.gas, "gas to not be exhausted").to.be.true; + } + + // Check which contract the storage was set on + const storageContract = tc.wantStorageContract(ctx); + const storageValue = await publicClient.getStorageAt({ + address: storageContract, + slot: toHex(0), + }); + + const expectedStorage = pad(toHex(giveStoreValue)); + expect(storageValue).to.equal(expectedStorage, "unexpected storage value"); + }); + } + }); + } + + // Mock contracts & precompiles need to implement these interfaces. These + // ABIs will be used to test the message call behavior. + const contextInspectorAbi = hre.artifacts.readArtifactSync("ContextInspector").abi; + const storageBasicAbi = hre.artifacts.readArtifactSync("StorageBasic").abi; + + // Test context and storage for mock contracts as a baseline check + describe("Mock", () => { + let contextInspectorMock: GetContractReturnType; + let storageBasicMock: GetContractReturnType; + + before("deploy mock contracts", async function () { + contextInspectorMock = await hre.viem.deployContract("ContextInspectorMock"); + storageBasicMock = await hre.viem.deployContract("StorageBasicMock"); + }); + + itHasCorrectMsgSender(() => ({ + lowLevelCaller: lowLevelCaller, + implementationAbi: contextInspectorAbi, + implementationAddress: contextInspectorMock.address, + })); + + itHasCorrectMsgValue(() => ({ + lowLevelCaller: lowLevelCaller, + implementationAbi: contextInspectorAbi, + implementationAddress: contextInspectorMock.address, + })); + + itHasCorrectStorageLocation(() => ({ + lowLevelCaller: lowLevelCaller, + implementationAbi: storageBasicAbi, + implementationAddress: storageBasicMock.address, + })); + }); + + // TODO: Test context and storage for precompiles +});