Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved typing #11

Merged
merged 6 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ repos:
"-rn", # Only display messages
"-sn", # Don't display the score
]
require_serial: true

- repo: local
hooks:
Expand All @@ -42,3 +43,4 @@ repos:
entry: mypy
language: system
types: [ python ]
require_serial: true
31 changes: 16 additions & 15 deletions multiproof/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
from dataclasses import dataclass
from itertools import pairwise
from typing import Any

from web3 import Web3
Expand All @@ -8,9 +9,9 @@


@dataclass
class MultiProof:
leaves: list[Any]
proof: list[Any]
class CoreMultiProof:
leaves: list[bytes]
proof: list[bytes]
proof_flags: list[bool]


Expand Down Expand Up @@ -41,40 +42,40 @@ def sibling_index(i: int) -> int:
raise ValueError('Root has no siblings')


def is_tree_node(tree: list[Any], i: int) -> bool:
def is_tree_node(tree: list[bytes], i: int) -> bool:
return 0 <= i < len(tree)


def is_internal_node(tree: list[Any], i: int) -> bool:
def is_internal_node(tree: list[bytes], i: int) -> bool:
return is_tree_node(tree, left_child_index(i))


def is_leaf_node(tree: list[Any], i: int) -> bool:
def is_leaf_node(tree: list[bytes], i: int) -> bool:
return is_tree_node(tree, i) and not is_internal_node(tree, i)


def is_valid_merkle_node(node: bytes) -> bool:
return len(node) == 32


def check_tree_node(tree: list[Any], i: int) -> None:
def check_tree_node(tree: list[bytes], i: int) -> None:
if not is_tree_node(tree, i):
raise ValueError("Index is not in tree")


def check_internal_node(tree: list[Any], i: int) -> None:
def check_internal_node(tree: list[bytes], i: int) -> None:
if not is_internal_node(tree, i):
raise ValueError("Index is not an internal tree node")


def check_leaf_node(tree: list[Any], i: int) -> None:
def check_leaf_node(tree: list[bytes], i: int) -> None:
if not is_leaf_node(tree, i):
raise ValueError("Index is not a leaf")


def check_valid_merkle_node(node: bytes) -> None:
if not is_valid_merkle_node(node):
raise ValueError("Merkle tree nodes must be Uint8Array of length 32")
raise ValueError("Merkle tree nodes must be byte array of length 32")


def make_merkle_tree(leaves: list[bytes]) -> list[bytes]:
Expand Down Expand Up @@ -118,14 +119,14 @@ def process_proof(leaf: bytes, proof: list[bytes]) -> bytes:
return result


def get_multi_proof(tree: list[bytes], indices: list[int]) -> MultiProof:
def get_multi_proof(tree: list[bytes], indices: list[int]) -> CoreMultiProof:
for index in indices:
check_leaf_node(tree, index)

indices = sorted(indices, reverse=True)

for i, p in enumerate(indices[1:]):
if p == indices[i]:
for prev_index, next_index in pairwise(indices):
if prev_index == next_index:
raise ValueError("Cannot prove duplicated index")

stack = indices[:]
Expand All @@ -149,14 +150,14 @@ def get_multi_proof(tree: list[bytes], indices: list[int]) -> MultiProof:
if len(indices) == 0:
proof.append(tree[0])

return MultiProof(
return CoreMultiProof(
leaves=[tree[i] for i in indices],
proof=proof,
proof_flags=proof_flags,
)


def process_multi_proof(multiproof: MultiProof) -> bytes:
def process_multi_proof(multiproof: CoreMultiProof) -> bytes:
for leaf in multiproof.leaves:
check_valid_merkle_node(leaf)

Expand Down
61 changes: 35 additions & 26 deletions multiproof/standard.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from dataclasses import dataclass
from functools import cmp_to_key
from typing import Any, Generic, TypeVar
from typing import Generic, TypeVar

from eth_abi import encode as abi_encode
from eth_typing import HexStr
from web3 import Web3

from multiproof.bytes import compare_bytes, equals_bytes, hex_to_bytes, to_hex
from multiproof.core import (MultiProof, get_multi_proof, get_proof,
from multiproof.core import (CoreMultiProof, get_multi_proof, get_proof,
is_valid_merkle_tree, left_child_index,
make_merkle_tree, process_multi_proof,
process_proof, right_child_index)
Expand All @@ -23,31 +23,41 @@ class LeafValue(Generic[T]):


@dataclass
class StandardMerkleTreeData:
class MultiProof(Generic[T]):
"""
User-friendly version of multiproof, compare with CoreMultiProof
"""
leaves: list[T]
proof: list[HexStr]
proof_flags: list[bool]


@dataclass
class StandardMerkleTreeData(Generic[T]):
tree: list[HexStr]
values: list[LeafValue]
values: list[LeafValue[T]]
leaf_encoding: list[str]
format: str = 'standard-v1'


@dataclass
class HashedValue:
value: Any
class HashedValue(Generic[T]):
value: T
index: int
hash: bytes


def standard_leaf_hash(values: Any, types: list[str]) -> bytes:
return Web3.keccak(Web3.keccak(abi_encode(types, values)))
def standard_leaf_hash(values: T, types: list[str]) -> bytes:
return Web3.keccak(Web3.keccak(abi_encode(types, values))) # type: ignore


class StandardMerkleTree:
class StandardMerkleTree(Generic[T]):
_hash_lookup: dict[HexStr, int]
tree: list[bytes]
values: list[LeafValue]
values: list[LeafValue[T]]
leaf_encoding: list[str]

def __init__(self, tree: list[bytes], values: list[LeafValue], leaf_encoding: list[str]):
def __init__(self, tree: list[bytes], values: list[LeafValue[T]], leaf_encoding: list[str]):
self.tree = tree
self.values = values
self.leaf_encoding = leaf_encoding
Expand All @@ -56,8 +66,8 @@ def __init__(self, tree: list[bytes], values: list[LeafValue], leaf_encoding: li
self._hash_lookup[to_hex(standard_leaf_hash(leaf_value.value, leaf_encoding))] = index

@staticmethod
def of(values: list[Any], leaf_encoding: list[str]) -> 'StandardMerkleTree':
hashed_values: list[HashedValue] = []
def of(values: list[T], leaf_encoding: list[str]) -> 'StandardMerkleTree[T]':
hashed_values: list[HashedValue[T]] = []
for index, value in enumerate(values):
hashed_values.append(
HashedValue(value=value, index=index, hash=standard_leaf_hash(value, leaf_encoding))
Expand All @@ -77,7 +87,7 @@ def of(values: list[Any], leaf_encoding: list[str]) -> 'StandardMerkleTree':
return StandardMerkleTree(tree, indexed_values, leaf_encoding)

@staticmethod
def load(data: StandardMerkleTreeData) -> 'StandardMerkleTree':
def load(data: StandardMerkleTreeData[T]) -> 'StandardMerkleTree[T]':
if data.format != 'standard-v1':
raise ValueError(f"Unknown format '{data.format}'")
return StandardMerkleTree(
Expand All @@ -88,7 +98,7 @@ def load(data: StandardMerkleTreeData) -> 'StandardMerkleTree':

@staticmethod
def verify(
root: HexStr, leaf_encoding: list[str], leaf_value: Any, proof: list[HexStr]
root: HexStr, leaf_encoding: list[str], leaf_value: T, proof: list[HexStr]
) -> bool:
leaf_hash = standard_leaf_hash(leaf_value, leaf_encoding)
implied_root = process_proof(leaf_hash, [hex_to_bytes(x) for x in proof])
Expand All @@ -99,7 +109,7 @@ def verify_multi_proof(root: HexStr, leaf_encoding: list[str], multiproof: Multi
leaf_hashes = [standard_leaf_hash(value, leaf_encoding) for value in multiproof.leaves]
proof_bytes = [hex_to_bytes(x) for x in multiproof.proof]
implied_root = process_multi_proof(
multiproof=MultiProof(
multiproof=CoreMultiProof(
leaves=leaf_hashes,
proof=proof_bytes,
proof_flags=multiproof.proof_flags,
Expand All @@ -108,7 +118,7 @@ def verify_multi_proof(root: HexStr, leaf_encoding: list[str], multiproof: Multi

return equals_bytes(implied_root, hex_to_bytes(root))

def dump(self) -> StandardMerkleTreeData:
def dump(self) -> StandardMerkleTreeData[T]:
return StandardMerkleTreeData(
format='standard-v1',
tree=[to_hex(v) for v in self.tree],
Expand All @@ -127,10 +137,10 @@ def validate(self) -> None:
if not is_valid_merkle_tree(self.tree):
raise ValueError("Merkle tree is invalid")

def leaf_hash(self, leaf: Any) -> HexStr:
def leaf_hash(self, leaf: T) -> HexStr:
return to_hex(standard_leaf_hash(leaf, self.leaf_encoding))

def leaf_lookup(self, leaf: Any) -> int:
def leaf_lookup(self, leaf: T) -> int:
v = self._hash_lookup[self.leaf_hash(leaf)]
if v is None:
raise ValueError("Leaf is not in tree")
Expand All @@ -156,13 +166,12 @@ def get_proof(self, leaf: T | int) -> list[HexStr]:

return [to_hex(p) for p in proof]

def get_multi_proof(self, leaves: list[Any]) -> MultiProof:
def get_multi_proof(self, leaves: list[int] | list[T]) -> MultiProof:
# input validity
value_indices = []
value_indices: list[int] = []
for leaf in leaves:
value_index = leaf
if isinstance(leaf, int):
value_indices.append(value_index)
value_indices.append(leaf)
else:
value_indices.append(self.leaf_lookup(leaf))

Expand Down Expand Up @@ -194,20 +203,20 @@ def _verify_leaf(self, leaf_hash: bytes, proof: list[bytes]) -> bool:

def verify_multi_proof_leaf(self, multiproof: MultiProof) -> bool:
return self._verify_multi_proof_leaf(
MultiProof(
CoreMultiProof(
leaves=[self._get_leaf_hash(leaf) for leaf in multiproof.leaves],
proof=[hex_to_bytes(proof) for proof in multiproof.proof],
proof_flags=multiproof.proof_flags,
)
)

def _verify_multi_proof_leaf(self, multi_proof: MultiProof) -> bool:
def _verify_multi_proof_leaf(self, multi_proof: CoreMultiProof) -> bool:
implied_root = process_multi_proof(multi_proof)
return equals_bytes(implied_root, self.tree[0])

def _validate_value(self, value_index: int) -> bytes:
check_bounds(self.values, value_index)
leaf: LeafValue = self.values[value_index]
leaf = self.values[value_index]
check_bounds(self.tree, leaf.tree_index)
leaf_hash = standard_leaf_hash(leaf.value, self.leaf_encoding)

Expand Down