Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support more efficient merkle proofs through calldata #3200

Merged
merged 9 commits into from
Jun 1, 2022
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
* `EnumerableMap`: add new `Bytes32ToUintMap` map type. ([#3416](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3416))
* `SafeCast`: add support for many more types, using procedural code generation. ([#3245](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3245))
* `MerkleProof`: add `multiProofVerify` to prove multiple values are part of a Merkle tree. ([#3276](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3276))
* `MerkleProof`: add calldata versions of the functions to avoid copying input arrays to memory and save gas. ([#3200](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3200))
* `ERC721`, `ERC1155`: simplified revert reasons. ([#3254](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3254))
* `ERC721`: removed redundant require statement. ([#3434](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3434))
* `PaymentSplitter`: add `releasable` getters. ([#3350](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3350))
Expand Down
24 changes: 18 additions & 6 deletions contracts/mocks/MerkleProofWrapper.sol
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,35 @@ contract MerkleProofWrapper {
return MerkleProof.verify(proof, root, leaf);
}

function verifyCalldata(
bytes32[] calldata proof,
bytes32 root,
bytes32 leaf
) public pure returns (bool) {
return MerkleProof.verifyCalldata(proof, root, leaf);
}

function processProof(bytes32[] memory proof, bytes32 leaf) public pure returns (bytes32) {
return MerkleProof.processProof(proof, leaf);
}

function processProofCalldata(bytes32[] calldata proof, bytes32 leaf) public pure returns (bytes32) {
return MerkleProof.processProofCalldata(proof, leaf);
}

function multiProofVerify(
bytes32 root,
bytes32[] memory leafs,
bytes32[] memory proofs,
bool[] memory proofFlag
bytes32[] calldata leafs,
bytes32[] calldata proofs,
bool[] calldata proofFlag
) public pure returns (bool) {
return MerkleProof.multiProofVerify(root, leafs, proofs, proofFlag);
}

function processMultiProof(
bytes32[] memory leafs,
bytes32[] memory proofs,
bool[] memory proofFlag
bytes32[] calldata leafs,
bytes32[] calldata proofs,
bool[] calldata proofFlag
) public pure returns (bytes32) {
return MerkleProof.processMultiProof(leafs, proofs, proofFlag);
}
Expand Down
55 changes: 40 additions & 15 deletions contracts/utils/cryptography/MerkleProof.sol
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@ library MerkleProof {
return processProof(proof, leaf) == root;
}

/**
* @dev Calldata version of {verify}
*
* _Available since v4.7._
*/
function verifyCalldata(
bytes32[] calldata proof,
bytes32 root,
bytes32 leaf
) internal pure returns (bool) {
return processProofCalldata(proof, leaf) == root;
}

/**
* @dev Returns the rebuilt hash obtained by traversing a Merkle tree up
* from `leaf` using `proof`. A `proof` is valid if and only if the rebuilt
Expand All @@ -48,6 +61,19 @@ library MerkleProof {
return computedHash;
}

/**
* @dev Calldata version of {processProof}
*
* _Available since v4.7._
*/
function processProofCalldata(bytes32[] calldata proof, bytes32 leaf) internal pure returns (bytes32) {
bytes32 computedHash = leaf;
for (uint256 i = 0; i < proof.length; i++) {
computedHash = _hashPair(computedHash, proof[i]);
}
return computedHash;
}

/**
* @dev Returns true if a `leafs` can be proved to be a part of a Merkle tree
* defined by `root`. For this, `proofs` for each leaf must be provided, containing
Expand All @@ -58,11 +84,11 @@ library MerkleProof {
*/
function multiProofVerify(
bytes32 root,
bytes32[] memory leafs,
bytes32[] memory proofs,
bool[] memory proofFlag
bytes32[] calldata leaves,
bytes32[] calldata proofs,
bool[] calldata proofFlag
) internal pure returns (bool) {
return processMultiProof(leafs, proofs, proofFlag) == root;
return processMultiProof(leaves, proofs, proofFlag) == root;
}

/**
Expand All @@ -73,20 +99,19 @@ library MerkleProof {
* _Available since v4.7._
*/
function processMultiProof(
bytes32[] memory leafs,
bytes32[] memory proofs,
bool[] memory proofFlag
bytes32[] calldata leaves,
bytes32[] calldata proofs,
bool[] calldata proofFlag
) internal pure returns (bytes32 merkleRoot) {
// This function rebuild the root hash by traversing the tree up from the leaves. The root is rebuilt by
// consuming and producing values on a queue. The queue starts with the `leafs` array, then goes onto the
// consuming and producing values on a queue. The queue starts with the `leaves` array, then goes onto the
// `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of
// the merkle tree.
uint256 leafsLen = leafs.length;
uint256 proofsLen = proofs.length;
uint256 leavesLen = leaves.length;
uint256 totalHashes = proofFlag.length;

// Check proof validity.
require(leafsLen + proofsLen - 1 == totalHashes, "MerkleProof: invalid multiproof");
require(leavesLen + proofs.length - 1 == totalHashes, "MerkleProof: invalid multiproof");

// The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using
// `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop".
Expand All @@ -100,15 +125,15 @@ library MerkleProof {
// - depending on the flag, either another value for the "main queue" (merging branches) or an element from the
// `proofs` array.
for (uint256 i = 0; i < totalHashes; i++) {
bytes32 a = leafPos < leafsLen ? leafs[leafPos++] : hashes[hashPos++];
bytes32 b = proofFlag[i] ? leafPos < leafsLen ? leafs[leafPos++] : hashes[hashPos++] : proofs[proofPos++];
bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++];
bytes32 b = proofFlag[i] ? leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++] : proofs[proofPos++];
hashes[i] = _hashPair(a, b);
}

if (totalHashes > 0) {
return hashes[totalHashes - 1];
} else if (leafsLen > 0) {
return leafs[0];
} else if (leavesLen > 0) {
return leaves[0];
} else {
return proofs[0];
}
Expand Down
12 changes: 8 additions & 4 deletions test/utils/cryptography/MerkleProof.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ contract('MerkleProof', function (accounts) {
const proof = merkleTree.getHexProof(leaf);

expect(await this.merkleProof.verify(proof, root, leaf)).to.equal(true);
expect(await this.merkleProof.verifyCalldata(proof, root, leaf)).to.equal(true);

// For demonstration, it is also possible to create valid proofs for certain 64-byte values *not* in elements:
const noSuchLeaf = keccak256(
Buffer.concat([keccak256(elements[0]), keccak256(elements[1])].sort(Buffer.compare)),
);
expect(await this.merkleProof.verify(proof.slice(1), root, noSuchLeaf)).to.equal(true);
expect(await this.merkleProof.verifyCalldata(proof.slice(1), root, noSuchLeaf)).to.equal(true);
});

it('returns false for an invalid Merkle proof', async function () {
Expand All @@ -47,6 +49,7 @@ contract('MerkleProof', function (accounts) {
const badProof = badMerkleTree.getHexProof(badElements[0]);

expect(await this.merkleProof.verify(badProof, correctRoot, correctLeaf)).to.equal(false);
expect(await this.merkleProof.verifyCalldata(badProof, correctRoot, correctLeaf)).to.equal(false);
});

it('returns false for a Merkle proof of invalid length', async function () {
Expand All @@ -61,6 +64,7 @@ contract('MerkleProof', function (accounts) {
const badProof = proof.slice(0, proof.length - 5);

expect(await this.merkleProof.verify(badProof, root, leaf)).to.equal(false);
expect(await this.merkleProof.verifyCalldata(badProof, root, leaf)).to.equal(false);
});
});

Expand Down Expand Up @@ -93,15 +97,15 @@ contract('MerkleProof', function (accounts) {
it('revert with invalid multi proof #1', async function () {
const fill = Buffer.alloc(32); // This could be anything, we are reconstructing a fake branch
const leaves = ['a', 'b', 'c', 'd'].map(keccak256).sort(Buffer.compare);
const badLeave = keccak256('e');
const badLeaf = keccak256('e');
const merkleTree = new MerkleTree(leaves, keccak256, { sort: true });

const root = merkleTree.getRoot();

await expectRevert(
this.merkleProof.multiProofVerify(
root,
[ leaves[0], badLeave ], // A, E
[ leaves[0], badLeaf ], // A, E
[ leaves[1], fill, merkleTree.layers[1][1] ],
[ false, false, false ],
),
Expand All @@ -112,15 +116,15 @@ contract('MerkleProof', function (accounts) {
it('revert with invalid multi proof #2', async function () {
const fill = Buffer.alloc(32); // This could be anything, we are reconstructing a fake branch
const leaves = ['a', 'b', 'c', 'd'].map(keccak256).sort(Buffer.compare);
const badLeave = keccak256('e');
const badLeaf = keccak256('e');
const merkleTree = new MerkleTree(leaves, keccak256, { sort: true });

const root = merkleTree.getRoot();

await expectRevert(
this.merkleProof.multiProofVerify(
root,
[ badLeave, leaves[0] ], // A, E
[ badLeaf, leaves[0] ], // A, E
[ leaves[1], fill, merkleTree.layers[1][1] ],
[ false, false, false, false ],
),
Expand Down