diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 021f516..fd97202 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,6 +34,7 @@ repos: "-rn", # Only display messages "-sn", # Don't display the score ] + require_serial: true - repo: local hooks: @@ -42,3 +43,4 @@ repos: entry: mypy language: system types: [ python ] + require_serial: true diff --git a/multiproof/core.py b/multiproof/core.py index 835b023..25fb325 100644 --- a/multiproof/core.py +++ b/multiproof/core.py @@ -1,5 +1,6 @@ import math from dataclasses import dataclass +from itertools import pairwise from typing import Any from web3 import Web3 @@ -8,9 +9,9 @@ @dataclass -class MultiProof: - leaves: list[Any] - proof: list[Any] +class CoreMultiProof: + leaves: list[bytes] + proof: list[bytes] proof_flags: list[bool] @@ -41,15 +42,15 @@ def sibling_index(i: int) -> int: raise ValueError('Root has no siblings') -def is_tree_node(tree: list[Any], i: int) -> bool: +def is_tree_node(tree: list[bytes], i: int) -> bool: return 0 <= i < len(tree) -def is_internal_node(tree: list[Any], i: int) -> bool: +def is_internal_node(tree: list[bytes], i: int) -> bool: return is_tree_node(tree, left_child_index(i)) -def is_leaf_node(tree: list[Any], i: int) -> bool: +def is_leaf_node(tree: list[bytes], i: int) -> bool: return is_tree_node(tree, i) and not is_internal_node(tree, i) @@ -57,24 +58,24 @@ def is_valid_merkle_node(node: bytes) -> bool: return len(node) == 32 -def check_tree_node(tree: list[Any], i: int) -> None: +def check_tree_node(tree: list[bytes], i: int) -> None: if not is_tree_node(tree, i): raise ValueError("Index is not in tree") -def check_internal_node(tree: list[Any], i: int) -> None: +def check_internal_node(tree: list[bytes], i: int) -> None: if not is_internal_node(tree, i): raise ValueError("Index is not an internal tree node") -def check_leaf_node(tree: list[Any], i: int) -> None: +def check_leaf_node(tree: list[bytes], i: int) -> None: if not is_leaf_node(tree, i): raise ValueError("Index is not a leaf") def check_valid_merkle_node(node: bytes) -> None: if not is_valid_merkle_node(node): - raise ValueError("Merkle tree nodes must be Uint8Array of length 32") + raise ValueError("Merkle tree nodes must be byte array of length 32") def make_merkle_tree(leaves: list[bytes]) -> list[bytes]: @@ -118,14 +119,14 @@ def process_proof(leaf: bytes, proof: list[bytes]) -> bytes: return result -def get_multi_proof(tree: list[bytes], indices: list[int]) -> MultiProof: +def get_multi_proof(tree: list[bytes], indices: list[int]) -> CoreMultiProof: for index in indices: check_leaf_node(tree, index) indices = sorted(indices, reverse=True) - for i, p in enumerate(indices[1:]): - if p == indices[i]: + for prev_index, next_index in pairwise(indices): + if prev_index == next_index: raise ValueError("Cannot prove duplicated index") stack = indices[:] @@ -149,14 +150,14 @@ def get_multi_proof(tree: list[bytes], indices: list[int]) -> MultiProof: if len(indices) == 0: proof.append(tree[0]) - return MultiProof( + return CoreMultiProof( leaves=[tree[i] for i in indices], proof=proof, proof_flags=proof_flags, ) -def process_multi_proof(multiproof: MultiProof) -> bytes: +def process_multi_proof(multiproof: CoreMultiProof) -> bytes: for leaf in multiproof.leaves: check_valid_merkle_node(leaf) diff --git a/multiproof/standard.py b/multiproof/standard.py index dc886c8..1a7f908 100644 --- a/multiproof/standard.py +++ b/multiproof/standard.py @@ -1,13 +1,13 @@ from dataclasses import dataclass from functools import cmp_to_key -from typing import Any, Generic, TypeVar +from typing import Generic, TypeVar from eth_abi import encode as abi_encode from eth_typing import HexStr from web3 import Web3 from multiproof.bytes import compare_bytes, equals_bytes, hex_to_bytes, to_hex -from multiproof.core import (MultiProof, get_multi_proof, get_proof, +from multiproof.core import (CoreMultiProof, get_multi_proof, get_proof, is_valid_merkle_tree, left_child_index, make_merkle_tree, process_multi_proof, process_proof, right_child_index) @@ -23,31 +23,41 @@ class LeafValue(Generic[T]): @dataclass -class StandardMerkleTreeData: +class MultiProof(Generic[T]): + """ + User-friendly version of multiproof, compare with CoreMultiProof + """ + leaves: list[T] + proof: list[HexStr] + proof_flags: list[bool] + + +@dataclass +class StandardMerkleTreeData(Generic[T]): tree: list[HexStr] - values: list[LeafValue] + values: list[LeafValue[T]] leaf_encoding: list[str] format: str = 'standard-v1' @dataclass -class HashedValue: - value: Any +class HashedValue(Generic[T]): + value: T index: int hash: bytes -def standard_leaf_hash(values: Any, types: list[str]) -> bytes: - return Web3.keccak(Web3.keccak(abi_encode(types, values))) +def standard_leaf_hash(values: T, types: list[str]) -> bytes: + return Web3.keccak(Web3.keccak(abi_encode(types, values))) # type: ignore -class StandardMerkleTree: +class StandardMerkleTree(Generic[T]): _hash_lookup: dict[HexStr, int] tree: list[bytes] - values: list[LeafValue] + values: list[LeafValue[T]] leaf_encoding: list[str] - def __init__(self, tree: list[bytes], values: list[LeafValue], leaf_encoding: list[str]): + def __init__(self, tree: list[bytes], values: list[LeafValue[T]], leaf_encoding: list[str]): self.tree = tree self.values = values self.leaf_encoding = leaf_encoding @@ -56,8 +66,8 @@ def __init__(self, tree: list[bytes], values: list[LeafValue], leaf_encoding: li self._hash_lookup[to_hex(standard_leaf_hash(leaf_value.value, leaf_encoding))] = index @staticmethod - def of(values: list[Any], leaf_encoding: list[str]) -> 'StandardMerkleTree': - hashed_values: list[HashedValue] = [] + def of(values: list[T], leaf_encoding: list[str]) -> 'StandardMerkleTree[T]': + hashed_values: list[HashedValue[T]] = [] for index, value in enumerate(values): hashed_values.append( HashedValue(value=value, index=index, hash=standard_leaf_hash(value, leaf_encoding)) @@ -77,7 +87,7 @@ def of(values: list[Any], leaf_encoding: list[str]) -> 'StandardMerkleTree': return StandardMerkleTree(tree, indexed_values, leaf_encoding) @staticmethod - def load(data: StandardMerkleTreeData) -> 'StandardMerkleTree': + def load(data: StandardMerkleTreeData[T]) -> 'StandardMerkleTree[T]': if data.format != 'standard-v1': raise ValueError(f"Unknown format '{data.format}'") return StandardMerkleTree( @@ -88,7 +98,7 @@ def load(data: StandardMerkleTreeData) -> 'StandardMerkleTree': @staticmethod def verify( - root: HexStr, leaf_encoding: list[str], leaf_value: Any, proof: list[HexStr] + root: HexStr, leaf_encoding: list[str], leaf_value: T, proof: list[HexStr] ) -> bool: leaf_hash = standard_leaf_hash(leaf_value, leaf_encoding) implied_root = process_proof(leaf_hash, [hex_to_bytes(x) for x in proof]) @@ -99,7 +109,7 @@ def verify_multi_proof(root: HexStr, leaf_encoding: list[str], multiproof: Multi leaf_hashes = [standard_leaf_hash(value, leaf_encoding) for value in multiproof.leaves] proof_bytes = [hex_to_bytes(x) for x in multiproof.proof] implied_root = process_multi_proof( - multiproof=MultiProof( + multiproof=CoreMultiProof( leaves=leaf_hashes, proof=proof_bytes, proof_flags=multiproof.proof_flags, @@ -108,7 +118,7 @@ def verify_multi_proof(root: HexStr, leaf_encoding: list[str], multiproof: Multi return equals_bytes(implied_root, hex_to_bytes(root)) - def dump(self) -> StandardMerkleTreeData: + def dump(self) -> StandardMerkleTreeData[T]: return StandardMerkleTreeData( format='standard-v1', tree=[to_hex(v) for v in self.tree], @@ -127,10 +137,10 @@ def validate(self) -> None: if not is_valid_merkle_tree(self.tree): raise ValueError("Merkle tree is invalid") - def leaf_hash(self, leaf: Any) -> HexStr: + def leaf_hash(self, leaf: T) -> HexStr: return to_hex(standard_leaf_hash(leaf, self.leaf_encoding)) - def leaf_lookup(self, leaf: Any) -> int: + def leaf_lookup(self, leaf: T) -> int: v = self._hash_lookup[self.leaf_hash(leaf)] if v is None: raise ValueError("Leaf is not in tree") @@ -156,13 +166,12 @@ def get_proof(self, leaf: T | int) -> list[HexStr]: return [to_hex(p) for p in proof] - def get_multi_proof(self, leaves: list[Any]) -> MultiProof: + def get_multi_proof(self, leaves: list[int] | list[T]) -> MultiProof: # input validity - value_indices = [] + value_indices: list[int] = [] for leaf in leaves: - value_index = leaf if isinstance(leaf, int): - value_indices.append(value_index) + value_indices.append(leaf) else: value_indices.append(self.leaf_lookup(leaf)) @@ -194,20 +203,20 @@ def _verify_leaf(self, leaf_hash: bytes, proof: list[bytes]) -> bool: def verify_multi_proof_leaf(self, multiproof: MultiProof) -> bool: return self._verify_multi_proof_leaf( - MultiProof( + CoreMultiProof( leaves=[self._get_leaf_hash(leaf) for leaf in multiproof.leaves], proof=[hex_to_bytes(proof) for proof in multiproof.proof], proof_flags=multiproof.proof_flags, ) ) - def _verify_multi_proof_leaf(self, multi_proof: MultiProof) -> bool: + def _verify_multi_proof_leaf(self, multi_proof: CoreMultiProof) -> bool: implied_root = process_multi_proof(multi_proof) return equals_bytes(implied_root, self.tree[0]) def _validate_value(self, value_index: int) -> bytes: check_bounds(self.values, value_index) - leaf: LeafValue = self.values[value_index] + leaf = self.values[value_index] check_bounds(self.tree, leaf.tree_index) leaf_hash = standard_leaf_hash(leaf.value, self.leaf_encoding)