diff --git a/rvsol/src/RISCV.sol b/rvsol/src/RISCV.sol index 1bbb681e..3b190109 100644 --- a/rvsol/src/RISCV.sol +++ b/rvsol/src/RISCV.sol @@ -346,6 +346,13 @@ contract RISCV is IBigStepper { } if iszero(eq(_proof.offset, proofContentOffset())) { revert(0, 0) } + if mod(calldataload(sub(proofContentOffset(), 32)), mul(60, 32)) { + // proof offset must be stateContentOffset+paddedStateSize+32 + // proof size: 64-5+1=60 * 32 byte leaf, + // so the proofSize must be a multiple of 60*32 + revert(0, 0) + } + // // State loading // diff --git a/rvsol/test/RISCV.t.sol b/rvsol/test/RISCV.t.sol index 79849ff4..e62b2455 100644 --- a/rvsol/test/RISCV.t.sol +++ b/rvsol/test/RISCV.t.sol @@ -2371,11 +2371,25 @@ contract RISCV_Test is CommonTest { riscv.step(encodedState, proof, 0); } + function test_invalid_proof_size() public { + uint32 insn = encodeRType(0xff, 0, 0, 0, 0, 0); + (State memory state, bytes memory proof) = constructRISCVState(0, insn); + bytes memory encodedState = encodeState(state); + proof = hex"00"; // Invalid memory proof size + + vm.expectRevert(); + riscv.step(encodedState, proof, 0); + } + function test_invalid_proof() public { uint32 insn = encodeRType(0xff, 0, 0, 0, 0, 0); (State memory state, bytes memory proof) = constructRISCVState(0, insn); bytes memory encodedState = encodeState(state); - proof = hex"00"; // Invalid memory proof + + // Overwrite the first 60 bytes of the proof with zero to create invalid memory proof + for (uint256 i = 0; i < 60 && i < proof.length; i++) { + proof[i] = 0x00; + } vm.expectRevert(hex"00000000000000000000000000000000000000000000000000000000badf00d1"); riscv.step(encodedState, proof, 0);