From 3369c7530b9e608f3bcebd6e2336edf11ad7f2b7 Mon Sep 17 00:00:00 2001 From: nonergodic Date: Tue, 3 Sep 2024 01:37:04 -0700 Subject: [PATCH] fix guardian set rotation issue, stick to custom errors only --- src/QueryResponse.sol | 65 +++++++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 30 deletions(-) diff --git a/src/QueryResponse.sol b/src/QueryResponse.sol index 5e96c74..e03377a 100644 --- a/src/QueryResponse.sol +++ b/src/QueryResponse.sol @@ -128,8 +128,6 @@ struct SolanaPdaResult { uint8 bump; } -// Custom errors - error WrongQueryType(uint8 received, uint8 expected); error InvalidResponseVersion(); error VersionMismatch(); @@ -144,7 +142,6 @@ error InvalidFunctionSignature(); error InvalidChainId(); error StaleBlockNum(); error StaleBlockTime(); -error NoQuorum(); error VerificationFailed(); //QueryResponse is a library that implements the parsing and verification of @@ -163,7 +160,6 @@ library QueryResponseLib { return keccak256(abi.encodePacked(RESPONSE_PREFIX, keccak256(response))); } - //WARNING: see verifyQueryResponse WARNING function parseAndVerifyQueryResponse( address wormhole, bytes memory response, @@ -173,27 +169,33 @@ library QueryResponseLib { return parseQueryResponse(response); } - //WARNING: This call can fail during times of guardian set rotation. - // Unlikely, but possible: - // Since only the current guardian set is considered when verifying signatures here, a response - // will be rejected with a failed verification, even if the signing guardian is still within - // what would otherwise be the 24 hour transition window. function verifyQueryResponse( address wormhole, bytes memory response, IWormhole.Signature[] memory signatures ) internal view { unchecked { IWormhole wormhole_ = IWormhole(wormhole); - IWormhole.GuardianSet memory guardianSet = - wormhole_.getGuardianSet(wormhole_.getCurrentGuardianSetIndex()); - uint quorum = guardianSet.keys.length * 2 / 3 + 1; - if (signatures.length < quorum) - revert NoQuorum(); - - (bool signaturesValid, ) = - wormhole_.verifySignatures(calcPrefixedResponseHash(response), signatures, guardianSet); - if(!signaturesValid) - revert VerificationFailed(); + uint32 guardianSetIndex = wormhole_.getCurrentGuardianSetIndex(); + IWormhole.GuardianSet memory guardianSet = wormhole_.getGuardianSet(guardianSetIndex); + + while (true) { + uint quorum = guardianSet.keys.length * 2 / 3 + 1; + if (signatures.length >= quorum) { + (bool signaturesValid, ) = + wormhole_.verifySignatures(calcPrefixedResponseHash(response), signatures, guardianSet); + if (signaturesValid) + return; + } + + //check if the previous guardian set is still valid and if yes, try with that + if (guardianSetIndex > 0) { + guardianSet = wormhole_.getGuardianSet(--guardianSetIndex); + if (guardianSet.expirationTime < block.timestamp) + revert VerificationFailed(); + } + else + revert VerificationFailed(); + } }} function parseQueryResponse( @@ -207,8 +209,8 @@ library QueryResponseLib { (ret.senderChainId, offset) = response.asUint16Unchecked(offset); - //For off chain requests (chainID zero), the requestId is the 65 byte signature. - //For on chain requests, it is the 32 byte VAA hash. + //for off-chain requests (chainID zero), the requestId is the 65 byte signature + //for on-chain requests, it is the 32 byte VAA hash (ret.requestId, offset) = response.sliceUnchecked(offset, ret.senderChainId == 0 ? 65 : 32); uint32 queryReqLen; @@ -227,7 +229,7 @@ library QueryResponseLib { uint8 numPerChainQueries; (numPerChainQueries, reqOff) = response.asUint8Unchecked(reqOff); - //A valid query request has at least one per chain query + //a valid query request must have at least one per-chain-query if (numPerChainQueries == 0) revert ZeroQueries(); @@ -242,7 +244,7 @@ library QueryResponseLib { ret.responses = new PerChainQueryResponse[](numPerChainQueries); - //Walk through the requests and responses in lock step. + //walk through the requests and responses in lock step. for (uint i; i < numPerChainQueries; ++i) { (ret.responses[i].chainId, reqOff) = response.asUint16Unchecked(reqOff); uint16 respChainId; @@ -262,7 +264,7 @@ library QueryResponseLib { (ret.responses[i].response, respOff) = response.sliceUint32PrefixedUnchecked(respOff); } - //End of request body should align with start of response body + //end of request body should align with start of response body if (startOfResponse != reqOff) revert InvalidPayloadLength(startOfResponse, reqOff); @@ -294,7 +296,7 @@ library QueryResponseLib { ret.results = new EthCallRecord[](numBatchCallData); - //Walk through the call inputs and outputs in lock step. + //walk through the call inputs and outputs in lock step. for (uint i; i < numBatchCallData; ++i) { (ret.results[i].contractAddress, reqOff) = pcr.request.asAddressUnchecked(reqOff); (ret.results[i].callData, reqOff) = pcr.request.sliceUint32PrefixedUnchecked(reqOff); @@ -336,7 +338,7 @@ library QueryResponseLib { ret.results = new EthCallRecord[](numBatchCallData); - // Walk through the call inputs and outputs in lock step. + //walk through the call inputs and outputs in lock step. for (uint i; i < numBatchCallData; ++i) { (ret.results[i].contractAddress, reqOff) = pcr.request.asAddressUnchecked(reqOff); (ret.results[i].callData, reqOff) = pcr.request.sliceUint32PrefixedUnchecked(reqOff); @@ -373,7 +375,7 @@ library QueryResponseLib { ret.results = new EthCallRecord[](numBatchCallData); - //Walk through the call inputs and outputs in lock step. + //walk through the call inputs and outputs in lock step. for (uint i; i < numBatchCallData; ++i) { (ret.results[i].contractAddress, reqOff) = pcr.request.asAddressUnchecked(reqOff); (ret.results[i].callData, reqOff) = pcr.request.sliceUint32PrefixedUnchecked(reqOff); @@ -412,7 +414,7 @@ library QueryResponseLib { ret.results = new SolanaAccountResult[](numAccounts); - //Walk through the call inputs and outputs in lock step. + //walk through the call inputs and outputs in lock step. for (uint i; i < numAccounts; ++i) { (ret.results[i].account, reqOff) = pcr.request.asBytes32Unchecked(reqOff); @@ -455,7 +457,7 @@ library QueryResponseLib { ret.results = new SolanaPdaResult[](numPdas); - //Walk through the call inputs and outputs in lock step. + //walk through the call inputs and outputs in lock step. for (uint i; i < numPdas; ++i) { (ret.results[i].programId, reqOff) = pcr.request.asBytes32Unchecked(reqOff); @@ -539,7 +541,10 @@ library QueryResponseLib { validateContractAddress(ecd.contractAddress, validContractAddresses); if (validFunctionSignatures.length > 0) { - (bytes4 funcSig,) = ecd.callData.asBytes4(0); + if (ecd.callData.length < 4) + revert InvalidFunctionSignature(); + + (bytes4 funcSig, uint offset) = ecd.callData.asBytes4Unchecked(0); validateFunctionSignature(funcSig, validFunctionSignatures); } }}