diff --git a/.changeset/fluffy-buses-jump.md b/.changeset/fluffy-buses-jump.md new file mode 100644 index 00000000000..0525a4d8e43 --- /dev/null +++ b/.changeset/fluffy-buses-jump.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`Comparators`: A library of comparator functions, useful for customizing the behavior of the Heap structure. diff --git a/.changeset/great-pianos-work.md b/.changeset/great-pianos-work.md new file mode 100644 index 00000000000..da54483e47e --- /dev/null +++ b/.changeset/great-pianos-work.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`Heap`: A data structure that implements a heap-based priority queue. diff --git a/.githooks/pre-push b/.githooks/pre-push index a51f3884a5d..f028ce58e0b 100755 --- a/.githooks/pre-push +++ b/.githooks/pre-push @@ -3,5 +3,6 @@ set -euo pipefail if [ "${CI:-"false"}" != "true" ]; then + npm run test:generation npm run lint fi diff --git a/contracts/mocks/Stateless.sol b/contracts/mocks/Stateless.sol index 7f18d573fda..846c77d98e8 100644 --- a/contracts/mocks/Stateless.sol +++ b/contracts/mocks/Stateless.sol @@ -22,6 +22,7 @@ import {ERC165} from "../utils/introspection/ERC165.sol"; import {ERC165Checker} from "../utils/introspection/ERC165Checker.sol"; import {ERC1967Utils} from "../proxy/ERC1967/ERC1967Utils.sol"; import {ERC721Holder} from "../token/ERC721/utils/ERC721Holder.sol"; +import {Heap} from "../utils/structs/Heap.sol"; import {Math} from "../utils/math/Math.sol"; import {MerkleProof} from "../utils/cryptography/MerkleProof.sol"; import {MessageHashUtils} from "../utils/cryptography/MessageHashUtils.sol"; diff --git a/contracts/utils/Arrays.sol b/contracts/utils/Arrays.sol index d67ae90ba2b..fe54bafee7c 100644 --- a/contracts/utils/Arrays.sol +++ b/contracts/utils/Arrays.sol @@ -4,6 +4,7 @@ pragma solidity ^0.8.20; +import {Comparators} from "./Comparators.sol"; import {SlotDerivation} from "./SlotDerivation.sol"; import {StorageSlot} from "./StorageSlot.sol"; import
{Math} from "./math/Math.sol"; @@ -16,7 +17,7 @@ library Arrays { using StorageSlot for bytes32; /** - * @dev Sort an array of bytes32 (in memory) following the provided comparator function. + * @dev Sort an array of uint256 (in memory) following the provided comparator function. * * This function does the sorting "in place", meaning that it overrides the input. The object is returned for * convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array. @@ -27,18 +28,18 @@ library Arrays { * consume more gas than is available in a block, leading to potential DoS. */ function sort( - bytes32[] memory array, - function(bytes32, bytes32) pure returns (bool) comp - ) internal pure returns (bytes32[] memory) { + uint256[] memory array, + function(uint256, uint256) pure returns (bool) comp + ) internal pure returns (uint256[] memory) { _quickSort(_begin(array), _end(array), comp); return array; } /** - * @dev Variant of {sort} that sorts an array of bytes32 in increasing order. + * @dev Variant of {sort} that sorts an array of uint256 in increasing order. */ - function sort(bytes32[] memory array) internal pure returns (bytes32[] memory) { - sort(array, _defaultComp); + function sort(uint256[] memory array) internal pure returns (uint256[] memory) { + sort(array, Comparators.lt); return array; } @@ -57,7 +58,7 @@ library Arrays { address[] memory array, function(address, address) pure returns (bool) comp ) internal pure returns (address[] memory) { - sort(_castToBytes32Array(array), _castToBytes32Comp(comp)); + sort(_castToUint256Array(array), _castToUint256Comp(comp)); return array; } @@ -65,12 +66,12 @@ library Arrays { * @dev Variant of {sort} that sorts an array of address in increasing order. 
*/ function sort(address[] memory array) internal pure returns (address[] memory) { - sort(_castToBytes32Array(array), _defaultComp); + sort(_castToUint256Array(array), Comparators.lt); return array; } /** - * @dev Sort an array of uint256 (in memory) following the provided comparator function. + * @dev Sort an array of bytes32 (in memory) following the provided comparator function. * * This function does the sorting "in place", meaning that it overrides the input. The object is returned for * convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array. @@ -81,18 +82,18 @@ library Arrays { * consume more gas than is available in a block, leading to potential DoS. */ function sort( - uint256[] memory array, - function(uint256, uint256) pure returns (bool) comp - ) internal pure returns (uint256[] memory) { - sort(_castToBytes32Array(array), _castToBytes32Comp(comp)); + bytes32[] memory array, + function(bytes32, bytes32) pure returns (bool) comp + ) internal pure returns (bytes32[] memory) { + sort(_castToUint256Array(array), _castToUint256Comp(comp)); return array; } /** - * @dev Variant of {sort} that sorts an array of uint256 in increasing order. + * @dev Variant of {sort} that sorts an array of bytes32 in increasing order. */ - function sort(uint256[] memory array) internal pure returns (uint256[] memory) { - sort(_castToBytes32Array(array), _defaultComp); + function sort(bytes32[] memory array) internal pure returns (bytes32[] memory) { + sort(_castToUint256Array(array), Comparators.lt); return array; } @@ -105,12 +106,12 @@ library Arrays { * IMPORTANT: Memory locations between `begin` and `end` are not validated/zeroed. This function should * be used only if the limits are within a memory array. 
*/ - function _quickSort(uint256 begin, uint256 end, function(bytes32, bytes32) pure returns (bool) comp) private pure { + function _quickSort(uint256 begin, uint256 end, function(uint256, uint256) pure returns (bool) comp) private pure { unchecked { if (end - begin < 0x40) return; // Use first element as pivot - bytes32 pivot = _mload(begin); + uint256 pivot = _mload(begin); // Position where the pivot should be at the end of the loop uint256 pos = begin; @@ -132,7 +133,7 @@ library Arrays { /** * @dev Pointer to the memory location of the first element of `array`. */ - function _begin(bytes32[] memory array) private pure returns (uint256 ptr) { + function _begin(uint256[] memory array) private pure returns (uint256 ptr) { /// @solidity memory-safe-assembly assembly { ptr := add(array, 0x20) @@ -143,16 +144,16 @@ library Arrays { * @dev Pointer to the memory location of the first memory word (32bytes) after `array`. This is the memory word * that comes just after the last element of the array. */ - function _end(bytes32[] memory array) private pure returns (uint256 ptr) { + function _end(uint256[] memory array) private pure returns (uint256 ptr) { unchecked { return _begin(array) + array.length * 0x20; } } /** - * @dev Load memory word (as a bytes32) at location `ptr`. + * @dev Load memory word (as a uint256) at location `ptr`. */ - function _mload(uint256 ptr) private pure returns (bytes32 value) { + function _mload(uint256 ptr) private pure returns (uint256 value) { assembly { value := mload(ptr) } @@ -170,38 +171,33 @@ library Arrays { } } - /// @dev Comparator for sorting arrays in increasing order. 
- function _defaultComp(bytes32 a, bytes32 b) private pure returns (bool) { - return a < b; - } - /// @dev Helper: low level cast address memory array to uint256 memory array - function _castToBytes32Array(address[] memory input) private pure returns (bytes32[] memory output) { + function _castToUint256Array(address[] memory input) private pure returns (uint256[] memory output) { assembly { output := input } } - /// @dev Helper: low level cast uint256 memory array to uint256 memory array - function _castToBytes32Array(uint256[] memory input) private pure returns (bytes32[] memory output) { + /// @dev Helper: low level cast bytes32 memory array to uint256 memory array + function _castToUint256Array(bytes32[] memory input) private pure returns (uint256[] memory output) { assembly { output := input } } - /// @dev Helper: low level cast address comp function to bytes32 comp function - function _castToBytes32Comp( + /// @dev Helper: low level cast address comp function to uint256 comp function + function _castToUint256Comp( function(address, address) pure returns (bool) input - ) private pure returns (function(bytes32, bytes32) pure returns (bool) output) { + ) private pure returns (function(uint256, uint256) pure returns (bool) output) { assembly { output := input } } - /// @dev Helper: low level cast uint256 comp function to bytes32 comp function - function _castToBytes32Comp( - function(uint256, uint256) pure returns (bool) input - ) private pure returns (function(bytes32, bytes32) pure returns (bool) output) { + /// @dev Helper: low level cast bytes32 comp function to uint256 comp function + function _castToUint256Comp( + function(bytes32, bytes32) pure returns (bool) input + ) private pure returns (function(uint256, uint256) pure returns (bool) output) { assembly { output := input } diff --git a/contracts/utils/Comparators.sol b/contracts/utils/Comparators.sol new file mode 100644 index 00000000000..3a63aa0e8ee --- /dev/null +++ b/contracts/utils/Comparators.sol @@ 
-0,0 +1,13 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.20; + +library Comparators { + function lt(uint256 a, uint256 b) internal pure returns (bool) { + return a < b; + } + + function gt(uint256 a, uint256 b) internal pure returns (bool) { + return a > b; + } +} diff --git a/contracts/utils/README.adoc b/contracts/utils/README.adoc index 71c39dfa9c5..da9eae6a887 100644 --- a/contracts/utils/README.adoc +++ b/contracts/utils/README.adoc @@ -25,6 +25,7 @@ Miscellaneous contracts and libraries containing utility functions you can use t * {DoubleEndedQueue}: An implementation of a https://en.wikipedia.org/wiki/Double-ended_queue[double ended queue] whose values can be removed added or remove from both sides. Useful for FIFO and LIFO structures. * {CircularBuffer}: A data structure to store the last N values pushed to it. * {Checkpoints}: A data structure to store values mapped to an strictly increasing key. Can be used for storing and accessing values over time. + * {Heap}: A library that implements a https://en.wikipedia.org/wiki/Binary_heap[binary heap] in storage. * {MerkleTree}: A library with https://wikipedia.org/wiki/Merkle_Tree[Merkle Tree] data structures and helper functions. * {Create2}: Wrapper around the https://blog.openzeppelin.com/getting-the-most-out-of-create2/[`CREATE2` EVM opcode] for safe use without having to deal with low-level assembly. * {Address}: Collection of functions for overloading Solidity's https://docs.soliditylang.org/en/latest/types.html#address[`address`] type. @@ -38,6 +39,7 @@ Miscellaneous contracts and libraries containing utility functions you can use t * {Context}: An utility for abstracting the sender and calldata in the current execution context. * {Packing}: A library for packing and unpacking multiple values into bytes32 * {Panic}: A library to revert with https://docs.soliditylang.org/en/v0.8.20/control-structures.html#panic-via-assert-and-error-via-require[Solidity panic codes]. 
+ * {Comparators}: A library that contains comparator functions to use with the {Heap} library. [NOTE] ==== @@ -102,6 +104,8 @@ Ethereum contracts have no native concept of an interface, so applications must {{Checkpoints}} +{{Heap}} + {{MerkleTree}} == Libraries @@ -129,3 +133,5 @@ Ethereum contracts have no native concept of an interface, so applications must {{Packing}} {{Panic}} + +{{Comparators}} diff --git a/contracts/utils/structs/Heap.sol b/contracts/utils/structs/Heap.sol new file mode 100644 index 00000000000..ad684d40bdb --- /dev/null +++ b/contracts/utils/structs/Heap.sol @@ -0,0 +1,578 @@ +// SPDX-License-Identifier: MIT +// This file was procedurally generated from scripts/generate/templates/Heap.js. + +pragma solidity ^0.8.20; + +import {Math} from "../math/Math.sol"; +import {SafeCast} from "../math/SafeCast.sol"; +import {Comparators} from "../Comparators.sol"; +import {Panic} from "../Panic.sol"; + +/** + * @dev Library for managing https://en.wikipedia.org/wiki/Binary_heap[binary heap] that can be used as + * https://en.wikipedia.org/wiki/Priority_queue[priority queue]. + * + * Heaps are represented as an array of Node objects. This array stores two overlapping structures: + * * A tree structure where the first element (index 0) is the root, and where the node at index i is the child of the + * node at index (i-1)/2 and the father of nodes at index 2*i+1 and 2*i+2. Each node stores the index (in the array) + * where the corresponding value is stored. + * * A list of payload values where each index contains a value and a lookup index. The type of the value depends on + * the variant being used. The lookup is the index of the node (in the tree) that points to this value. + * + * Some invariants: + * ``` + * i == heap.data[heap.data[i].index].lookup // for all indices i + * i == heap.data[heap.data[i].lookup].index // for all indices i + * ``` + * + * The structure is ordered so that each node is bigger than its parent.
An immediate consequence is that the + * highest priority value is the one at the root. This value can be looked up in constant time (O(1)) at + * `heap.data[heap.data[0].index].value` + * + * The structure is designed to perform the following operations with the corresponding complexities: + * + * * peek (get the highest priority in set): O(1) + * * insert (insert a value in the set): O(log(n)) + * * pop (remove the highest priority value in set): O(log(n)) + * * replace (replace the highest priority value in set with a new value): O(log(n)) + * * length (get the number of elements in the set): O(1) + * * clear (remove all elements in the set): O(1) + */ +library Heap { + using Math for *; + using SafeCast for *; + + /** + * @dev Binary heap that support values of type uint256. + * + * Each element of that structures uses 2 storage slots. + */ + struct Uint256Heap { + Uint256HeapNode[] data; + } + + /** + * @dev Internal node type for Uint256Heap. Stores a value of type uint256. + */ + struct Uint256HeapNode { + uint256 value; + uint64 index; // position -> value + uint64 lookup; // value -> position + } + + /** + * @dev Lookup the root element of the heap. + */ + function peek(Uint256Heap storage self) internal view returns (uint256) { + // self.data[0] will `ARRAY_ACCESS_OUT_OF_BOUNDS` panic if heap is empty. + return _unsafeNodeAccess(self, self.data[0].index).value; + } + + /** + * @dev Remove (and return) the root element for the heap using the default comparator. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. + */ + function pop(Uint256Heap storage self) internal returns (uint256) { + return pop(self, Comparators.lt); + } + + /** + * @dev Remove (and return) the root element for the heap using the provided comparator. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator.
Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. + */ + function pop( + Uint256Heap storage self, + function(uint256, uint256) view returns (bool) comp + ) internal returns (uint256) { + unchecked { + uint64 size = length(self); + if (size == 0) Panic.panic(Panic.EMPTY_ARRAY_POP); + + uint64 last = size - 1; + + // get root location (in the data array) and value + Uint256HeapNode storage rootNode = _unsafeNodeAccess(self, 0); + uint64 rootIdx = rootNode.index; + Uint256HeapNode storage rootData = _unsafeNodeAccess(self, rootIdx); + Uint256HeapNode storage lastNode = _unsafeNodeAccess(self, last); + uint256 rootDataValue = rootData.value; + + // if root is not the last element of the data array (that will get pop-ed), reorder the data array. + if (rootIdx != last) { + // get details about the value stored in the last element of the array (that will get pop-ed) + uint64 lastDataIdx = lastNode.lookup; + uint256 lastDataValue = lastNode.value; + // copy these values to the location of the root (that is safe, and that we no longer use) + rootData.value = lastDataValue; + rootData.lookup = lastDataIdx; + // update the tree node that used to point to that last element (value now located where the root was) + _unsafeNodeAccess(self, lastDataIdx).index = rootIdx; + } + + // get last leaf location (in the data array) and value + uint64 lastIdx = lastNode.index; + uint256 lastValue = _unsafeNodeAccess(self, lastIdx).value; + + // move the last leaf to the root, pop last leaf ... + rootNode.index = lastIdx; + _unsafeNodeAccess(self, lastIdx).lookup = 0; + self.data.pop(); + + // ... and heapify + _siftDown(self, last, 0, lastValue, comp); + + // return root value + return rootDataValue; + } + } + + /** + * @dev Insert a new element in the heap using the default comparator. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. 
Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. + */ + function insert(Uint256Heap storage self, uint256 value) internal { + insert(self, value, Comparators.lt); + } + + /** + * @dev Insert a new element in the heap using the provided comparator. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. + */ + function insert( + Uint256Heap storage self, + uint256 value, + function(uint256, uint256) view returns (bool) comp + ) internal { + uint64 size = length(self); + if (size == type(uint64).max) Panic.panic(Panic.RESOURCE_ERROR); + + self.data.push(Uint256HeapNode({index: size, lookup: size, value: value})); + _siftUp(self, size, value, comp); + } + + /** + * @dev Return the root element for the heap, and replace it with a new value, using the default comparator. + * This is equivalent to using {pop} and {insert}, but requires only one rebalancing operation. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. + */ + function replace(Uint256Heap storage self, uint256 newValue) internal returns (uint256) { + return replace(self, newValue, Comparators.lt); + } + + /** + * @dev Return the root element for the heap, and replace it with a new value, using the provided comparator. + * This is equivalent to using {pop} and {insert}, but requires only one rebalancing operation. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. 
+ */ + function replace( + Uint256Heap storage self, + uint256 newValue, + function(uint256, uint256) view returns (bool) comp + ) internal returns (uint256) { + uint64 size = length(self); + if (size == 0) Panic.panic(Panic.EMPTY_ARRAY_POP); + + // position of the node that holds the data for the root + uint64 rootIdx = _unsafeNodeAccess(self, 0).index; + // storage pointer to the node that holds the data for the root + Uint256HeapNode storage rootData = _unsafeNodeAccess(self, rootIdx); + + // cache old value and replace it + uint256 oldValue = rootData.value; + rootData.value = newValue; + + // re-heapify + _siftDown(self, size, 0, newValue, comp); + + // return old root value + return oldValue; + } + + /** + * @dev Returns the number of elements in the heap. + */ + function length(Uint256Heap storage self) internal view returns (uint64) { + return self.data.length.toUint64(); + } + + /** + * @dev Removes all elements in the heap. + */ + function clear(Uint256Heap storage self) internal { + Uint256HeapNode[] storage data = self.data; + /// @solidity memory-safe-assembly + assembly { + sstore(data.slot, 0) + } + } + + /* + * @dev Swap node `i` and `j` in the tree. + */ + function _swap(Uint256Heap storage self, uint64 i, uint64 j) private { + Uint256HeapNode storage ni = _unsafeNodeAccess(self, i); + Uint256HeapNode storage nj = _unsafeNodeAccess(self, j); + uint64 ii = ni.index; + uint64 jj = nj.index; + // update pointers to the data (swap the value) + ni.index = jj; + nj.index = ii; + // update lookup pointers for consistency + _unsafeNodeAccess(self, ii).lookup = j; + _unsafeNodeAccess(self, jj).lookup = i; + } + + /** + * @dev Perform heap maintenance on `self`, starting at position `pos` (with the `value`), using `comp` as a + * comparator, and moving toward the leafs of the underlying tree. + * + * NOTE: This is a private function that is called in a trusted context with already cached parameters. 
`length` + * and `value` could be extracted from `self` and `pos`, but that would require redundant storage read. These + * parameters are not verified. It is the caller role to make sure the parameters are correct. + */ + function _siftDown( + Uint256Heap storage self, + uint64 size, + uint64 pos, + uint256 value, + function(uint256, uint256) view returns (bool) comp + ) private { + uint256 left = 2 * pos + 1; // this could overflow uint64 + uint256 right = 2 * pos + 2; // this could overflow uint64 + + if (right < size) { + // the check guarantees that `left` and `right` are both valid uint32 + uint64 lIndex = uint64(left); + uint64 rIndex = uint64(right); + uint256 lValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, lIndex).index).value; + uint256 rValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, rIndex).index).value; + if (comp(lValue, value) || comp(rValue, value)) { + uint64 index = uint64(comp(lValue, rValue).ternary(lIndex, rIndex)); + _swap(self, pos, index); + _siftDown(self, size, index, value, comp); + } + } else if (left < size) { + // the check guarantees that `left` is a valid uint32 + uint64 lIndex = uint64(left); + uint256 lValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, lIndex).index).value; + if (comp(lValue, value)) { + _swap(self, pos, lIndex); + _siftDown(self, size, lIndex, value, comp); + } + } + } + + /** + * @dev Perform heap maintenance on `self`, starting at position `pos` (with the `value`), using `comp` as a + * comparator, and moving toward the root of the underlying tree. + * + * NOTE: This is a private function that is called in a trusted context with already cached parameters. `value` + * could be extracted from `self` and `pos`, but that would require redundant storage read. This parameters is not + * verified. It is the caller role to make sure the parameters are correct. 
+ */ + function _siftUp( + Uint256Heap storage self, + uint64 pos, + uint256 value, + function(uint256, uint256) view returns (bool) comp + ) private { + unchecked { + while (pos > 0) { + uint64 parent = (pos - 1) / 2; + uint256 parentValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, parent).index).value; + if (comp(parentValue, value)) break; + _swap(self, pos, parent); + pos = parent; + } + } + } + + function _unsafeNodeAccess( + Uint256Heap storage self, + uint64 pos + ) private pure returns (Uint256HeapNode storage result) { + assembly ("memory-safe") { + mstore(0x00, self.slot) + result.slot := add(keccak256(0x00, 0x20), mul(pos, 2)) + } + } + + /** + * @dev Binary heap that support values of type uint208. + * + * Each element of that structures uses 1 storage slots. + */ + struct Uint208Heap { + Uint208HeapNode[] data; + } + + /** + * @dev Internal node type for Uint208Heap. Stores a value of type uint208. + */ + struct Uint208HeapNode { + uint208 value; + uint24 index; // position -> value + uint24 lookup; // value -> position + } + + /** + * @dev Lookup the root element of the heap. + */ + function peek(Uint208Heap storage self) internal view returns (uint208) { + // self.data[0] will `ARRAY_ACCESS_OUT_OF_BOUNDS` panic if heap is empty. + return _unsafeNodeAccess(self, self.data[0].index).value; + } + + /** + * @dev Remove (and return) the root element for the heap using the default comparator. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. + */ + function pop(Uint208Heap storage self) internal returns (uint208) { + return pop(self, Comparators.lt); + } + + /** + * @dev Remove (and return) the root element for the heap using the provided comparator. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. 
Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. + */ + function pop( + Uint208Heap storage self, + function(uint256, uint256) view returns (bool) comp + ) internal returns (uint208) { + unchecked { + uint24 size = length(self); + if (size == 0) Panic.panic(Panic.EMPTY_ARRAY_POP); + + uint24 last = size - 1; + + // get root location (in the data array) and value + Uint208HeapNode storage rootNode = _unsafeNodeAccess(self, 0); + uint24 rootIdx = rootNode.index; + Uint208HeapNode storage rootData = _unsafeNodeAccess(self, rootIdx); + Uint208HeapNode storage lastNode = _unsafeNodeAccess(self, last); + uint208 rootDataValue = rootData.value; + + // if root is not the last element of the data array (that will get pop-ed), reorder the data array. + if (rootIdx != last) { + // get details about the value stored in the last element of the array (that will get pop-ed) + uint24 lastDataIdx = lastNode.lookup; + uint208 lastDataValue = lastNode.value; + // copy these values to the location of the root (that is safe, and that we no longer use) + rootData.value = lastDataValue; + rootData.lookup = lastDataIdx; + // update the tree node that used to point to that last element (value now located where the root was) + _unsafeNodeAccess(self, lastDataIdx).index = rootIdx; + } + + // get last leaf location (in the data array) and value + uint24 lastIdx = lastNode.index; + uint208 lastValue = _unsafeNodeAccess(self, lastIdx).value; + + // move the last leaf to the root, pop last leaf ... + rootNode.index = lastIdx; + _unsafeNodeAccess(self, lastIdx).lookup = 0; + self.data.pop(); + + // ... and heapify + _siftDown(self, last, 0, lastValue, comp); + + // return root value + return rootDataValue; + } + } + + /** + * @dev Insert a new element in the heap using the default comparator. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. 
Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. + */ + function insert(Uint208Heap storage self, uint208 value) internal { + insert(self, value, Comparators.lt); + } + + /** + * @dev Insert a new element in the heap using the provided comparator. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. + */ + function insert( + Uint208Heap storage self, + uint208 value, + function(uint256, uint256) view returns (bool) comp + ) internal { + uint24 size = length(self); + if (size == type(uint24).max) Panic.panic(Panic.RESOURCE_ERROR); + + self.data.push(Uint208HeapNode({index: size, lookup: size, value: value})); + _siftUp(self, size, value, comp); + } + + /** + * @dev Return the root element for the heap, and replace it with a new value, using the default comparator. + * This is equivalent to using {pop} and {insert}, but requires only one rebalancing operation. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. + */ + function replace(Uint208Heap storage self, uint208 newValue) internal returns (uint208) { + return replace(self, newValue, Comparators.lt); + } + + /** + * @dev Return the root element for the heap, and replace it with a new value, using the provided comparator. + * This is equivalent to using {pop} and {insert}, but requires only one rebalancing operation. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. 
+ */ + function replace( + Uint208Heap storage self, + uint208 newValue, + function(uint256, uint256) view returns (bool) comp + ) internal returns (uint208) { + uint24 size = length(self); + if (size == 0) Panic.panic(Panic.EMPTY_ARRAY_POP); + + // position of the node that holds the data for the root + uint24 rootIdx = _unsafeNodeAccess(self, 0).index; + // storage pointer to the node that holds the data for the root + Uint208HeapNode storage rootData = _unsafeNodeAccess(self, rootIdx); + + // cache old value and replace it + uint208 oldValue = rootData.value; + rootData.value = newValue; + + // re-heapify + _siftDown(self, size, 0, newValue, comp); + + // return old root value + return oldValue; + } + + /** + * @dev Returns the number of elements in the heap. + */ + function length(Uint208Heap storage self) internal view returns (uint24) { + return self.data.length.toUint24(); + } + + /** + * @dev Removes all elements in the heap. + */ + function clear(Uint208Heap storage self) internal { + Uint208HeapNode[] storage data = self.data; + /// @solidity memory-safe-assembly + assembly { + sstore(data.slot, 0) + } + } + + /* + * @dev Swap node `i` and `j` in the tree. + */ + function _swap(Uint208Heap storage self, uint24 i, uint24 j) private { + Uint208HeapNode storage ni = _unsafeNodeAccess(self, i); + Uint208HeapNode storage nj = _unsafeNodeAccess(self, j); + uint24 ii = ni.index; + uint24 jj = nj.index; + // update pointers to the data (swap the value) + ni.index = jj; + nj.index = ii; + // update lookup pointers for consistency + _unsafeNodeAccess(self, ii).lookup = j; + _unsafeNodeAccess(self, jj).lookup = i; + } + + /** + * @dev Perform heap maintenance on `self`, starting at position `pos` (with the `value`), using `comp` as a + * comparator, and moving toward the leafs of the underlying tree. + * + * NOTE: This is a private function that is called in a trusted context with already cached parameters. 
`length` + * and `value` could be extracted from `self` and `pos`, but that would require redundant storage read. These + * parameters are not verified. It is the caller role to make sure the parameters are correct. + */ + function _siftDown( + Uint208Heap storage self, + uint24 size, + uint24 pos, + uint208 value, + function(uint256, uint256) view returns (bool) comp + ) private { + uint256 left = 2 * pos + 1; // this could overflow uint24 + uint256 right = 2 * pos + 2; // this could overflow uint24 + + if (right < size) { + // the check guarantees that `left` and `right` are both valid uint32 + uint24 lIndex = uint24(left); + uint24 rIndex = uint24(right); + uint208 lValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, lIndex).index).value; + uint208 rValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, rIndex).index).value; + if (comp(lValue, value) || comp(rValue, value)) { + uint24 index = uint24(comp(lValue, rValue).ternary(lIndex, rIndex)); + _swap(self, pos, index); + _siftDown(self, size, index, value, comp); + } + } else if (left < size) { + // the check guarantees that `left` is a valid uint32 + uint24 lIndex = uint24(left); + uint208 lValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, lIndex).index).value; + if (comp(lValue, value)) { + _swap(self, pos, lIndex); + _siftDown(self, size, lIndex, value, comp); + } + } + } + + /** + * @dev Perform heap maintenance on `self`, starting at position `pos` (with the `value`), using `comp` as a + * comparator, and moving toward the root of the underlying tree. + * + * NOTE: This is a private function that is called in a trusted context with already cached parameters. `value` + * could be extracted from `self` and `pos`, but that would require redundant storage read. This parameters is not + * verified. It is the caller role to make sure the parameters are correct. 
+ */ + function _siftUp( + Uint208Heap storage self, + uint24 pos, + uint208 value, + function(uint256, uint256) view returns (bool) comp + ) private { + unchecked { + while (pos > 0) { + uint24 parent = (pos - 1) / 2; + uint208 parentValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, parent).index).value; + if (comp(parentValue, value)) break; + _swap(self, pos, parent); + pos = parent; + } + } + } + + function _unsafeNodeAccess( + Uint208Heap storage self, + uint24 pos + ) private pure returns (Uint208HeapNode storage result) { + assembly ("memory-safe") { + mstore(0x00, self.slot) + result.slot := add(keccak256(0x00, 0x20), pos) + } + } +} diff --git a/docs/modules/ROOT/pages/utilities.adoc b/docs/modules/ROOT/pages/utilities.adoc index 31d4d5e33ed..ecb950caf91 100644 --- a/docs/modules/ROOT/pages/utilities.adoc +++ b/docs/modules/ROOT/pages/utilities.adoc @@ -189,6 +189,7 @@ Some use cases require more powerful data structures than arrays and mappings of - xref:api:utils.adoc#EnumerableSet[`EnumerableSet`]: A https://en.wikipedia.org/wiki/Set_(abstract_data_type)[set] with enumeration capabilities. - xref:api:utils.adoc#EnumerableMap[`EnumerableMap`]: A `mapping` variant with enumeration capabilities. - xref:api:utils.adoc#MerkleTree[`MerkleTree`]: An on-chain https://wikipedia.org/wiki/Merkle_Tree[Merkle Tree] with helper functions. +- xref:api:utils.adoc#Heap[`Heap`]: A https://en.wikipedia.org/wiki/Binary_heap[binary heap] that can be used as a priority queue. The `Enumerable*` structures are similar to mappings in that they store and remove elements in constant time and don't allow for repeated entries, but they also support _enumeration_, which means you can easily query all stored entries both on and off-chain. @@ -240,6 +241,32 @@ function _hashFn(bytes32 a, bytes32 b) internal view returns(bytes32) { } ---- +=== Using a Heap + +A https://en.wikipedia.org/wiki/Binary_heap[binary heap] is a data structure that always stores the most important element at its peak, so it can be used as a priority queue.
+ +To define what is most important, a heap typically takes a comparator function that tells it whether one value has higher priority than another. + +OpenZeppelin Contracts implements a Heap data structure with the properties of a binary heap. The heap uses the xref:api:utils.adoc#Comparators-lt-uint256-uint256-[`lt`] function by default but allows customizing its comparator. + +When using a custom comparator, it's recommended to wrap the heap functions to avoid the possibility of mistakenly using a different comparator function: + +[source,solidity] +---- +function pop(Uint256Heap storage self) internal returns (uint256) { + return pop(self, Comparators.gt); +} + +function insert(Uint256Heap storage self, uint256 value) internal { + insert(self, value, Comparators.gt); +} + +function replace(Uint256Heap storage self, uint256 newValue) internal returns (uint256) { + return replace(self, newValue, Comparators.gt); +} +---- + + [[misc]] == Misc @@ -292,7 +319,7 @@ function _setImplementation(address newImplementation) internal { } ---- -The xref:api:utils.adoc#StorageSlot[`StorageSlot`] library also supports transient storage through user defined value types (UDVTs[https://docs.soliditylang.org/en/latest/types.html#user-defined-value-types]), which enables the same value types as in Solidity. +The xref:api:utils.adoc#StorageSlot[`StorageSlot`] library also supports transient storage through user defined value types (https://docs.soliditylang.org/en/latest/types.html#user-defined-value-types[UDVTs]), which enables the same value types as in Solidity.
[source,solidity] ---- diff --git a/scripts/generate/run.js b/scripts/generate/run.js index f8ec17606ac..801e1eee90c 100755 --- a/scripts/generate/run.js +++ b/scripts/generate/run.js @@ -33,9 +33,10 @@ function generateFromTemplate(file, template, outputPrefix = '') { // Contracts for (const [file, template] of Object.entries({ 'utils/math/SafeCast.sol': './templates/SafeCast.js', + 'utils/structs/Checkpoints.sol': './templates/Checkpoints.js', 'utils/structs/EnumerableSet.sol': './templates/EnumerableSet.js', 'utils/structs/EnumerableMap.sol': './templates/EnumerableMap.js', - 'utils/structs/Checkpoints.sol': './templates/Checkpoints.js', + 'utils/structs/Heap.sol': './templates/Heap.js', 'utils/SlotDerivation.sol': './templates/SlotDerivation.js', 'utils/StorageSlot.sol': './templates/StorageSlot.js', 'utils/Arrays.sol': './templates/Arrays.js', @@ -48,6 +49,7 @@ for (const [file, template] of Object.entries({ // Tests for (const [file, template] of Object.entries({ 'utils/structs/Checkpoints.t.sol': './templates/Checkpoints.t.js', + 'utils/structs/Heap.t.sol': './templates/Heap.t.js', 'utils/Packing.t.sol': './templates/Packing.t.js', 'utils/SlotDerivation.t.sol': './templates/SlotDerivation.t.js', })) { diff --git a/scripts/generate/templates/Arrays.js b/scripts/generate/templates/Arrays.js index 30a6e069aa6..9823e4e5d7b 100644 --- a/scripts/generate/templates/Arrays.js +++ b/scripts/generate/templates/Arrays.js @@ -5,6 +5,7 @@ const { TYPES } = require('./Arrays.opts'); const header = `\ pragma solidity ^0.8.20; +import {Comparators} from "./Comparators.sol"; import {SlotDerivation} from "./SlotDerivation.sol"; import {StorageSlot} from "./StorageSlot.sol"; import {Math} from "./math/Math.sol"; @@ -31,9 +32,9 @@ function sort( function(${type}, ${type}) pure returns (bool) comp ) internal pure returns (${type}[] memory) { ${ - type === 'bytes32' + type === 'uint256' ? 
'_quickSort(_begin(array), _end(array), comp);' - : 'sort(_castToBytes32Array(array), _castToBytes32Comp(comp));' + : 'sort(_castToUint256Array(array), _castToUint256Comp(comp));' } return array; } @@ -42,7 +43,7 @@ function sort( * @dev Variant of {sort} that sorts an array of ${type} in increasing order. */ function sort(${type}[] memory array) internal pure returns (${type}[] memory) { - ${type === 'bytes32' ? 'sort(array, _defaultComp);' : 'sort(_castToBytes32Array(array), _defaultComp);'} + ${type === 'uint256' ? 'sort(array, Comparators.lt);' : 'sort(_castToUint256Array(array), Comparators.lt);'} return array; } `; @@ -57,12 +58,12 @@ const quickSort = `\ * IMPORTANT: Memory locations between \`begin\` and \`end\` are not validated/zeroed. This function should * be used only if the limits are within a memory array. */ -function _quickSort(uint256 begin, uint256 end, function(bytes32, bytes32) pure returns (bool) comp) private pure { +function _quickSort(uint256 begin, uint256 end, function(uint256, uint256) pure returns (bool) comp) private pure { unchecked { if (end - begin < 0x40) return; // Use first element as pivot - bytes32 pivot = _mload(begin); + uint256 pivot = _mload(begin); // Position where the pivot should be at the end of the loop uint256 pos = begin; @@ -84,7 +85,7 @@ function _quickSort(uint256 begin, uint256 end, function(bytes32, bytes32) pure /** * @dev Pointer to the memory location of the first element of \`array\`. */ -function _begin(bytes32[] memory array) private pure returns (uint256 ptr) { +function _begin(uint256[] memory array) private pure returns (uint256 ptr) { /// @solidity memory-safe-assembly assembly { ptr := add(array, 0x20) @@ -95,16 +96,16 @@ function _begin(bytes32[] memory array) private pure returns (uint256 ptr) { * @dev Pointer to the memory location of the first memory word (32bytes) after \`array\`. This is the memory word * that comes just after the last element of the array. 
*/ -function _end(bytes32[] memory array) private pure returns (uint256 ptr) { +function _end(uint256[] memory array) private pure returns (uint256 ptr) { unchecked { return _begin(array) + array.length * 0x20; } } /** - * @dev Load memory word (as a bytes32) at location \`ptr\`. + * @dev Load memory word (as a uint256) at location \`ptr\`. */ -function _mload(uint256 ptr) private pure returns (bytes32 value) { +function _mload(uint256 ptr) private pure returns (uint256 value) { assembly { value := mload(ptr) } @@ -123,16 +124,9 @@ function _swap(uint256 ptr1, uint256 ptr2) private pure { } `; -const defaultComparator = `\ -/// @dev Comparator for sorting arrays in increasing order. -function _defaultComp(bytes32 a, bytes32 b) private pure returns (bool) { - return a < b; -} -`; - const castArray = type => `\ /// @dev Helper: low level cast ${type} memory array to uint256 memory array -function _castToBytes32Array(${type}[] memory input) private pure returns (bytes32[] memory output) { +function _castToUint256Array(${type}[] memory input) private pure returns (uint256[] memory output) { assembly { output := input } @@ -140,10 +134,10 @@ function _castToBytes32Array(${type}[] memory input) private pure returns (bytes `; const castComparator = type => `\ -/// @dev Helper: low level cast ${type} comp function to bytes32 comp function -function _castToBytes32Comp( +/// @dev Helper: low level cast ${type} comp function to uint256 comp function +function _castToUint256Comp( function(${type}, ${type}) pure returns (bool) input -) private pure returns (function(bytes32, bytes32) pure returns (bool) output) { +) private pure returns (function(uint256, uint256) pure returns (bool) output) { assembly { output := input } @@ -374,12 +368,11 @@ module.exports = format( 'using StorageSlot for bytes32;', '', // sorting, comparator, helpers and internal - sort('bytes32'), - TYPES.filter(type => type !== 'bytes32').map(sort), + sort('uint256'), + TYPES.filter(type => type !== 
'uint256').map(sort), quickSort, - defaultComparator, - TYPES.filter(type => type !== 'bytes32').map(castArray), - TYPES.filter(type => type !== 'bytes32').map(castComparator), + TYPES.filter(type => type !== 'uint256').map(castArray), + TYPES.filter(type => type !== 'uint256').map(castComparator), // lookup search, // unsafe (direct) storage and memory access diff --git a/scripts/generate/templates/Heap.js b/scripts/generate/templates/Heap.js new file mode 100644 index 00000000000..7ed99939bb5 --- /dev/null +++ b/scripts/generate/templates/Heap.js @@ -0,0 +1,328 @@ +const format = require('../format-lines'); +const { TYPES } = require('./Heap.opts'); +const { capitalize } = require('../../helpers'); + +/* eslint-disable max-len */ +const header = `\ +pragma solidity ^0.8.20; + +import {Math} from "../math/Math.sol"; +import {SafeCast} from "../math/SafeCast.sol"; +import {Comparators} from "../Comparators.sol"; +import {Panic} from "../Panic.sol"; + +/** + * @dev Library for managing https://en.wikipedia.org/wiki/Binary_heap[binary heap] that can be used as + * https://en.wikipedia.org/wiki/Priority_queue[priority queue]. + * + * Heaps are represented as an array of Node objects. This array stores two overlapping structures: + * * A tree structure where the first element (index 0) is the root, and where the node at index i is the child of the + * node at index (i-1)/2 and the father of nodes at index 2*i+1 and 2*i+2. Each node stores the index (in the array) + * where the corresponding value is stored. + * * A list of payloads values where each index contains a value and a lookup index. The type of the value depends on + * the variant being used. The lookup is the index of the node (in the tree) that points to this value. 
+ * + * Some invariants: + * \`\`\` + * i == heap.data[heap.data[i].index].lookup // for all indices i + * i == heap.data[heap.data[i].lookup].index // for all indices i + * \`\`\` + * + * The structure is ordered so that each node is bigger than its parent. An immediate consequence is that the + * highest priority value is the one at the root. This value can be lookup up in constant time (O(1)) at + * \`heap.data[heap.data[0].index].value\` + * + * The structure is designed to perform the following operations with the corresponding complexities: + * + * * peek (get the highest priority in set): O(1) + * * insert (insert a value in the set): 0(log(n)) + * * pop (remove the highest priority value in set): O(log(n)) + * * replace (replace the highest priority value in set with a new value): O(log(n)) + * * length (get the number of elements in the set): O(1) + * * clear (remove all elements in the set): O(1) + */ +`; + +const generate = ({ struct, node, valueType, indexType, blockSize }) => `\ +/** + * @dev Binary heap that support values of type ${valueType}. + * + * Each element of that structures uses ${blockSize} storage slots. + */ +struct ${struct} { + ${node}[] data; +} + +/** + * @dev Internal node type for ${struct}. Stores a value of type ${valueType}. + */ +struct ${node} { + ${valueType} value; + ${indexType} index; // position -> value + ${indexType} lookup; // value -> position +} + +/** + * @dev Lookup the root element of the heap. + */ +function peek(${struct} storage self) internal view returns (${valueType}) { + // self.data[0] will \`ARRAY_ACCESS_OUT_OF_BOUNDS\` panic if heap is empty. + return _unsafeNodeAccess(self, self.data[0].index).value; +} + +/** + * @dev Remove (and return) the root element for the heap using the default comparator. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. 
+ */ +function pop(${struct} storage self) internal returns (${valueType}) { + return pop(self, Comparators.lt); +} + +/** + * @dev Remove (and return) the root element for the heap using the provided comparator. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. + */ +function pop( + ${struct} storage self, + function(uint256, uint256) view returns (bool) comp +) internal returns (${valueType}) { + unchecked { + ${indexType} size = length(self); + if (size == 0) Panic.panic(Panic.EMPTY_ARRAY_POP); + + ${indexType} last = size - 1; + + // get root location (in the data array) and value + ${node} storage rootNode = _unsafeNodeAccess(self, 0); + ${indexType} rootIdx = rootNode.index; + ${node} storage rootData = _unsafeNodeAccess(self, rootIdx); + ${node} storage lastNode = _unsafeNodeAccess(self, last); + ${valueType} rootDataValue = rootData.value; + + // if root is not the last element of the data array (that will get pop-ed), reorder the data array. + if (rootIdx != last) { + // get details about the value stored in the last element of the array (that will get pop-ed) + ${indexType} lastDataIdx = lastNode.lookup; + ${valueType} lastDataValue = lastNode.value; + // copy these values to the location of the root (that is safe, and that we no longer use) + rootData.value = lastDataValue; + rootData.lookup = lastDataIdx; + // update the tree node that used to point to that last element (value now located where the root was) + _unsafeNodeAccess(self, lastDataIdx).index = rootIdx; + } + + // get last leaf location (in the data array) and value + ${indexType} lastIdx = lastNode.index; + ${valueType} lastValue = _unsafeNodeAccess(self, lastIdx).value; + + // move the last leaf to the root, pop last leaf ... + rootNode.index = lastIdx; + _unsafeNodeAccess(self, lastIdx).lookup = 0; + self.data.pop(); + + // ... 
and heapify + _siftDown(self, last, 0, lastValue, comp); + + // return root value + return rootDataValue; + } +} + +/** + * @dev Insert a new element in the heap using the default comparator. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. + */ +function insert(${struct} storage self, ${valueType} value) internal { + insert(self, value, Comparators.lt); +} + +/** + * @dev Insert a new element in the heap using the provided comparator. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. + */ +function insert( + ${struct} storage self, + ${valueType} value, + function(uint256, uint256) view returns (bool) comp +) internal { + ${indexType} size = length(self); + if (size == type(${indexType}).max) Panic.panic(Panic.RESOURCE_ERROR); + + self.data.push(${struct}Node({index: size, lookup: size, value: value})); + _siftUp(self, size, value, comp); +} + +/** + * @dev Return the root element for the heap, and replace it with a new value, using the default comparator. + * This is equivalent to using {pop} and {insert}, but requires only one rebalancing operation. + * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. + */ +function replace(${struct} storage self, ${valueType} newValue) internal returns (${valueType}) { + return replace(self, newValue, Comparators.lt); +} + +/** + * @dev Return the root element for the heap, and replace it with a new value, using the provided comparator. + * This is equivalent to using {pop} and {insert}, but requires only one rebalancing operation. 
+ * + * NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator + * during the lifecycle of a heap will result in undefined behavior. + */ +function replace( + ${struct} storage self, + ${valueType} newValue, + function(uint256, uint256) view returns (bool) comp +) internal returns (${valueType}) { + ${indexType} size = length(self); + if (size == 0) Panic.panic(Panic.EMPTY_ARRAY_POP); + + // position of the node that holds the data for the root + ${indexType} rootIdx = _unsafeNodeAccess(self, 0).index; + // storage pointer to the node that holds the data for the root + ${node} storage rootData = _unsafeNodeAccess(self, rootIdx); + + // cache old value and replace it + ${valueType} oldValue = rootData.value; + rootData.value = newValue; + + // re-heapify + _siftDown(self, size, 0, newValue, comp); + + // return old root value + return oldValue; +} + +/** + * @dev Returns the number of elements in the heap. + */ +function length(${struct} storage self) internal view returns (${indexType}) { + return self.data.length.to${capitalize(indexType)}(); +} + +/** + * @dev Removes all elements in the heap. + */ +function clear(${struct} storage self) internal { + ${struct}Node[] storage data = self.data; + /// @solidity memory-safe-assembly + assembly { + sstore(data.slot, 0) + } +} + +/* + * @dev Swap node \`i\` and \`j\` in the tree. 
+ */ +function _swap(${struct} storage self, ${indexType} i, ${indexType} j) private { + ${node} storage ni = _unsafeNodeAccess(self, i); + ${node} storage nj = _unsafeNodeAccess(self, j); + ${indexType} ii = ni.index; + ${indexType} jj = nj.index; + // update pointers to the data (swap the value) + ni.index = jj; + nj.index = ii; + // update lookup pointers for consistency + _unsafeNodeAccess(self, ii).lookup = j; + _unsafeNodeAccess(self, jj).lookup = i; +} + +/** + * @dev Perform heap maintenance on \`self\`, starting at position \`pos\` (with the \`value\`), using \`comp\` as a + * comparator, and moving toward the leafs of the underlying tree. + * + * NOTE: This is a private function that is called in a trusted context with already cached parameters. \`length\` + * and \`value\` could be extracted from \`self\` and \`pos\`, but that would require redundant storage read. These + * parameters are not verified. It is the caller role to make sure the parameters are correct. + */ +function _siftDown( + ${struct} storage self, + ${indexType} size, + ${indexType} pos, + ${valueType} value, + function(uint256, uint256) view returns (bool) comp +) private { + uint256 left = 2 * pos + 1; // this could overflow ${indexType} + uint256 right = 2 * pos + 2; // this could overflow ${indexType} + + if (right < size) { + // the check guarantees that \`left\` and \`right\` are both valid uint32 + ${indexType} lIndex = ${indexType}(left); + ${indexType} rIndex = ${indexType}(right); + ${valueType} lValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, lIndex).index).value; + ${valueType} rValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, rIndex).index).value; + if (comp(lValue, value) || comp(rValue, value)) { + ${indexType} index = ${indexType}(comp(lValue, rValue).ternary(lIndex, rIndex)); + _swap(self, pos, index); + _siftDown(self, size, index, value, comp); + } + } else if (left < size) { + // the check guarantees that \`left\` is a valid uint32 + ${indexType} 
lIndex = ${indexType}(left); + ${valueType} lValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, lIndex).index).value; + if (comp(lValue, value)) { + _swap(self, pos, lIndex); + _siftDown(self, size, lIndex, value, comp); + } + } +} + +/** + * @dev Perform heap maintenance on \`self\`, starting at position \`pos\` (with the \`value\`), using \`comp\` as a + * comparator, and moving toward the root of the underlying tree. + * + * NOTE: This is a private function that is called in a trusted context with already cached parameters. \`value\` + * could be extracted from \`self\` and \`pos\`, but that would require redundant storage read. This parameters is not + * verified. It is the caller role to make sure the parameters are correct. + */ +function _siftUp( + ${struct} storage self, + ${indexType} pos, + ${valueType} value, + function(uint256, uint256) view returns (bool) comp +) private { + unchecked { + while (pos > 0) { + ${indexType} parent = (pos - 1) / 2; + ${valueType} parentValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, parent).index).value; + if (comp(parentValue, value)) break; + _swap(self, pos, parent); + pos = parent; + } + } +} + +function _unsafeNodeAccess( + ${struct} storage self, + ${indexType} pos +) private pure returns (${node} storage result) { + assembly ("memory-safe") { + mstore(0x00, self.slot) + result.slot := add(keccak256(0x00, 0x20), ${blockSize == 1 ? 
'pos' : `mul(pos, ${blockSize})`}) + } +} +`; + +// GENERATE +module.exports = format( + header.trimEnd(), + 'library Heap {', + format( + [].concat( + 'using Math for *;', + 'using SafeCast for *;', + '', + TYPES.map(type => generate(type)), + ), + ).trimEnd(), + '}', +); diff --git a/scripts/generate/templates/Heap.opts.js b/scripts/generate/templates/Heap.opts.js new file mode 100644 index 00000000000..8b8be0afdfa --- /dev/null +++ b/scripts/generate/templates/Heap.opts.js @@ -0,0 +1,13 @@ +const makeType = (valueSize, indexSize) => ({ + struct: `Uint${valueSize}Heap`, + node: `Uint${valueSize}HeapNode`, + valueSize, + valueType: `uint${valueSize}`, + indexSize, + indexType: `uint${indexSize}`, + blockSize: Math.ceil((valueSize + 2 * indexSize) / 256), +}); + +module.exports = { + TYPES: [makeType(256, 64), makeType(208, 24)], +}; diff --git a/scripts/generate/templates/Heap.t.js b/scripts/generate/templates/Heap.t.js new file mode 100644 index 00000000000..04b3152ba3a --- /dev/null +++ b/scripts/generate/templates/Heap.t.js @@ -0,0 +1,89 @@ +const format = require('../format-lines'); +const { TYPES } = require('./Heap.opts'); + +/* eslint-disable max-len */ +const header = `\ +pragma solidity ^0.8.20; + +import {Test} from "forge-std/Test.sol"; +import {Math} from "@openzeppelin/contracts/utils/math/Math.sol"; +import {Heap} from "@openzeppelin/contracts/utils/structs/Heap.sol"; +import {Comparators} from "@openzeppelin/contracts/utils/Comparators.sol"; +`; + +const generate = ({ struct, valueType }) => `\ +contract ${struct}Test is Test { + using Heap for Heap.${struct}; + + Heap.${struct} internal heap; + + function _validateHeap(function(uint256, uint256) view returns (bool) comp) internal { + for (uint32 i = 0; i < heap.length(); ++i) { + // lookups + assertEq(i, heap.data[heap.data[i].index].lookup); + assertEq(i, heap.data[heap.data[i].lookup].index); + + // ordering: each node has a value bigger then its parent + if (i > 0) + 
assertFalse(comp(heap.data[heap.data[i].index].value, heap.data[heap.data[(i - 1) / 2].index].value)); + } + } + + function testFuzz(${valueType}[] calldata input) public { + vm.assume(input.length < 0x20); + assertEq(heap.length(), 0); + + uint256 min = type(uint256).max; + for (uint256 i = 0; i < input.length; ++i) { + heap.insert(input[i]); + assertEq(heap.length(), i + 1); + _validateHeap(Comparators.lt); + + min = Math.min(min, input[i]); + assertEq(heap.peek(), min); + } + + uint256 max = 0; + for (uint256 i = 0; i < input.length; ++i) { + ${valueType} top = heap.peek(); + ${valueType} pop = heap.pop(); + assertEq(heap.length(), input.length - i - 1); + _validateHeap(Comparators.lt); + + assertEq(pop, top); + assertGe(pop, max); + max = pop; + } + } + + function testFuzzGt(${valueType}[] calldata input) public { + vm.assume(input.length < 0x20); + assertEq(heap.length(), 0); + + uint256 max = 0; + for (uint256 i = 0; i < input.length; ++i) { + heap.insert(input[i], Comparators.gt); + assertEq(heap.length(), i + 1); + _validateHeap(Comparators.gt); + + max = Math.max(max, input[i]); + assertEq(heap.peek(), max); + } + + uint256 min = type(uint256).max; + for (uint256 i = 0; i < input.length; ++i) { + ${valueType} top = heap.peek(); + ${valueType} pop = heap.pop(Comparators.gt); + assertEq(heap.length(), input.length - i - 1); + _validateHeap(Comparators.gt); + + assertEq(pop, top); + assertLe(pop, min); + min = pop; + } + } +} +`; + +// GENERATE +module.exports = format(header, ...TYPES.map(type => generate(type))); diff --git a/test/utils/structs/Heap.t.sol b/test/utils/structs/Heap.t.sol new file mode 100644 index 00000000000..b9d0b98787c --- /dev/null +++ b/test/utils/structs/Heap.t.sol @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT +// This file was procedurally generated from scripts/generate/templates/Heap.t.js. 
+ +pragma solidity ^0.8.20; + +import {Test} from "forge-std/Test.sol"; +import {Math} from "@openzeppelin/contracts/utils/math/Math.sol"; +import {Heap} from "@openzeppelin/contracts/utils/structs/Heap.sol"; +import {Comparators} from "@openzeppelin/contracts/utils/Comparators.sol"; + +contract Uint256HeapTest is Test { + using Heap for Heap.Uint256Heap; + + Heap.Uint256Heap internal heap; + + function _validateHeap(function(uint256, uint256) view returns (bool) comp) internal { + for (uint32 i = 0; i < heap.length(); ++i) { + // lookups + assertEq(i, heap.data[heap.data[i].index].lookup); + assertEq(i, heap.data[heap.data[i].lookup].index); + + // ordering: each node has a value bigger then its parent + if (i > 0) + assertFalse(comp(heap.data[heap.data[i].index].value, heap.data[heap.data[(i - 1) / 2].index].value)); + } + } + + function testFuzz(uint256[] calldata input) public { + vm.assume(input.length < 0x20); + assertEq(heap.length(), 0); + + uint256 min = type(uint256).max; + for (uint256 i = 0; i < input.length; ++i) { + heap.insert(input[i]); + assertEq(heap.length(), i + 1); + _validateHeap(Comparators.lt); + + min = Math.min(min, input[i]); + assertEq(heap.peek(), min); + } + + uint256 max = 0; + for (uint256 i = 0; i < input.length; ++i) { + uint256 top = heap.peek(); + uint256 pop = heap.pop(); + assertEq(heap.length(), input.length - i - 1); + _validateHeap(Comparators.lt); + + assertEq(pop, top); + assertGe(pop, max); + max = pop; + } + } + + function testFuzzGt(uint256[] calldata input) public { + vm.assume(input.length < 0x20); + assertEq(heap.length(), 0); + + uint256 max = 0; + for (uint256 i = 0; i < input.length; ++i) { + heap.insert(input[i], Comparators.gt); + assertEq(heap.length(), i + 1); + _validateHeap(Comparators.gt); + + max = Math.max(max, input[i]); + assertEq(heap.peek(), max); + } + + uint256 min = type(uint256).max; + for (uint256 i = 0; i < input.length; ++i) { + uint256 top = heap.peek(); + uint256 pop = 
heap.pop(Comparators.gt); + assertEq(heap.length(), input.length - i - 1); + _validateHeap(Comparators.gt); + + assertEq(pop, top); + assertLe(pop, min); + min = pop; + } + } +} + +contract Uint208HeapTest is Test { + using Heap for Heap.Uint208Heap; + + Heap.Uint208Heap internal heap; + + function _validateHeap(function(uint256, uint256) view returns (bool) comp) internal { + for (uint32 i = 0; i < heap.length(); ++i) { + // lookups + assertEq(i, heap.data[heap.data[i].index].lookup); + assertEq(i, heap.data[heap.data[i].lookup].index); + + // ordering: each node has a value bigger then its parent + if (i > 0) + assertFalse(comp(heap.data[heap.data[i].index].value, heap.data[heap.data[(i - 1) / 2].index].value)); + } + } + + function testFuzz(uint208[] calldata input) public { + vm.assume(input.length < 0x20); + assertEq(heap.length(), 0); + + uint256 min = type(uint256).max; + for (uint256 i = 0; i < input.length; ++i) { + heap.insert(input[i]); + assertEq(heap.length(), i + 1); + _validateHeap(Comparators.lt); + + min = Math.min(min, input[i]); + assertEq(heap.peek(), min); + } + + uint256 max = 0; + for (uint256 i = 0; i < input.length; ++i) { + uint208 top = heap.peek(); + uint208 pop = heap.pop(); + assertEq(heap.length(), input.length - i - 1); + _validateHeap(Comparators.lt); + + assertEq(pop, top); + assertGe(pop, max); + max = pop; + } + } + + function testFuzzGt(uint208[] calldata input) public { + vm.assume(input.length < 0x20); + assertEq(heap.length(), 0); + + uint256 max = 0; + for (uint256 i = 0; i < input.length; ++i) { + heap.insert(input[i], Comparators.gt); + assertEq(heap.length(), i + 1); + _validateHeap(Comparators.gt); + + max = Math.max(max, input[i]); + assertEq(heap.peek(), max); + } + + uint256 min = type(uint256).max; + for (uint256 i = 0; i < input.length; ++i) { + uint208 top = heap.peek(); + uint208 pop = heap.pop(Comparators.gt); + assertEq(heap.length(), input.length - i - 1); + _validateHeap(Comparators.gt); + + assertEq(pop, 
top); + assertLe(pop, min); + min = pop; + } + } +} diff --git a/test/utils/structs/Heap.test.js b/test/utils/structs/Heap.test.js new file mode 100644 index 00000000000..7e95a0e7adb --- /dev/null +++ b/test/utils/structs/Heap.test.js @@ -0,0 +1,131 @@ +const { ethers } = require('hardhat'); +const { expect } = require('chai'); +const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); +const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic'); + +const { TYPES } = require('../../../scripts/generate/templates/Heap.opts'); + +async function fixture() { + const mock = await ethers.deployContract('$Heap'); + return { mock }; +} + +describe('Heap', function () { + beforeEach(async function () { + Object.assign(this, await loadFixture(fixture)); + }); + + for (const { struct, valueType } of TYPES) { + describe(struct, function () { + const popEvent = `return$pop_Heap_${struct}`; + const replaceEvent = `return$replace_Heap_${struct}_${valueType}`; + + beforeEach(async function () { + this.helper = { + clear: (...args) => this.mock[`$clear_Heap_${struct}`](0, ...args), + insert: (...args) => this.mock[`$insert(uint256,${valueType})`](0, ...args), + replace: (...args) => this.mock[`$replace(uint256,${valueType})`](0, ...args), + length: (...args) => this.mock[`$length_Heap_${struct}`](0, ...args), + pop: (...args) => this.mock[`$pop_Heap_${struct}`](0, ...args), + peek: (...args) => this.mock[`$peek_Heap_${struct}`](0, ...args), + }; + }); + + it('starts empty', async function () { + expect(await this.helper.length()).to.equal(0n); + }); + + it('peek, pop and replace from empty', async function () { + await expect(this.helper.peek()).to.be.revertedWithPanic(PANIC_CODES.ARRAY_ACCESS_OUT_OF_BOUNDS); + await expect(this.helper.pop()).to.be.revertedWithPanic(PANIC_CODES.POP_ON_EMPTY_ARRAY); + await expect(this.helper.replace(0n)).to.be.revertedWithPanic(PANIC_CODES.POP_ON_EMPTY_ARRAY); + }); + + it('clear', async function () { + await 
this.helper.insert(42n); + + expect(await this.helper.length()).to.equal(1n); + expect(await this.helper.peek()).to.equal(42n); + + await this.helper.clear(); + + expect(await this.helper.length()).to.equal(0n); + await expect(this.helper.peek()).to.be.revertedWithPanic(PANIC_CODES.ARRAY_ACCESS_OUT_OF_BOUNDS); + }); + + it('support duplicated items', async function () { + expect(await this.helper.length()).to.equal(0n); + + // insert 5 times + await this.helper.insert(42n); + await this.helper.insert(42n); + await this.helper.insert(42n); + await this.helper.insert(42n); + await this.helper.insert(42n); + + // pop 5 times + await expect(this.helper.pop()).to.emit(this.mock, popEvent).withArgs(42n); + await expect(this.helper.pop()).to.emit(this.mock, popEvent).withArgs(42n); + await expect(this.helper.pop()).to.emit(this.mock, popEvent).withArgs(42n); + await expect(this.helper.pop()).to.emit(this.mock, popEvent).withArgs(42n); + await expect(this.helper.pop()).to.emit(this.mock, popEvent).withArgs(42n); + + // popping a 6th time panics + await expect(this.helper.pop()).to.be.revertedWithPanic(PANIC_CODES.POP_ON_EMPTY_ARRAY); + }); + + it('insert, pop and replace', async function () { + const heap = []; + for (const { op, value } of [ + { op: 'insert', value: 712 }, // [712] + { op: 'insert', value: 20 }, // [20, 712] + { op: 'insert', value: 4337 }, // [20, 712, 4337] + { op: 'pop' }, // 20, [712, 4337] + { op: 'insert', value: 1559 }, // [712, 1559, 4337] + { op: 'insert', value: 165 }, // [165, 712, 1559, 4337] + { op: 'insert', value: 155 }, // [155, 165, 712, 1559, 4337] + { op: 'insert', value: 7702 }, // [155, 165, 712, 1559, 4337, 7702] + { op: 'pop' }, // 155, [165, 712, 1559, 4337, 7702] + { op: 'replace', value: 721 }, // 165, [712, 721, 1559, 4337, 7702] + { op: 'pop' }, // 712, [721, 1559, 4337, 7702] + { op: 'pop' }, // 721, [1559, 4337, 7702] + { op: 'pop' }, // 1559, [4337, 7702] + { op: 'pop' }, // 4337, [7702] + { op: 'pop' }, // 7702, [] + { op: 
'pop' }, // panic + { op: 'replace', value: '1363' }, // panic + ]) { + switch (op) { + case 'insert': + await this.helper.insert(value); + heap.push(value); + heap.sort((a, b) => a - b); + break; + case 'pop': + if (heap.length == 0) { + await expect(this.helper.pop()).to.be.revertedWithPanic(PANIC_CODES.POP_ON_EMPTY_ARRAY); + } else { + await expect(this.helper.pop()).to.emit(this.mock, popEvent).withArgs(heap.shift()); + } + break; + case 'replace': + if (heap.length == 0) { + await expect(this.helper.replace(value)).to.be.revertedWithPanic(PANIC_CODES.POP_ON_EMPTY_ARRAY); + } else { + await expect(this.helper.replace(value)).to.emit(this.mock, replaceEvent).withArgs(heap.shift()); + heap.push(value); + heap.sort((a, b) => a - b); + } + break; + } + expect(await this.helper.length()).to.equal(heap.length); + if (heap.length == 0) { + await expect(this.helper.peek()).to.be.revertedWithPanic(PANIC_CODES.ARRAY_ACCESS_OUT_OF_BOUNDS); + } else { + expect(await this.helper.peek()).to.equal(heap[0]); + } + } + }); + }); + } +});