Skip to content

Commit

Permalink
Improved typing (#11)
Browse files Browse the repository at this point in the history
* Fix duplication check

* Add CoreMultiProof

* Add require_serial

* Fix error message

* Fix typing

* Del explicit typing
  • Loading branch information
evgeny-stakewise authored Apr 29, 2024
1 parent 0f984e6 commit 35ecaec
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 41 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ repos:
"-rn", # Only display messages
"-sn", # Don't display the score
]
require_serial: true

- repo: local
hooks:
Expand All @@ -42,3 +43,4 @@ repos:
entry: mypy
language: system
types: [ python ]
require_serial: true
31 changes: 16 additions & 15 deletions multiproof/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
from dataclasses import dataclass
from itertools import pairwise
from typing import Any

from web3 import Web3
Expand All @@ -8,9 +9,9 @@


@dataclass
class MultiProof:
leaves: list[Any]
proof: list[Any]
class CoreMultiProof:
leaves: list[bytes]
proof: list[bytes]
proof_flags: list[bool]


Expand Down Expand Up @@ -41,40 +42,40 @@ def sibling_index(i: int) -> int:
raise ValueError('Root has no siblings')


def is_tree_node(tree: list[Any], i: int) -> bool:
def is_tree_node(tree: list[bytes], i: int) -> bool:
return 0 <= i < len(tree)


def is_internal_node(tree: list[Any], i: int) -> bool:
def is_internal_node(tree: list[bytes], i: int) -> bool:
return is_tree_node(tree, left_child_index(i))


def is_leaf_node(tree: list[Any], i: int) -> bool:
def is_leaf_node(tree: list[bytes], i: int) -> bool:
return is_tree_node(tree, i) and not is_internal_node(tree, i)


def is_valid_merkle_node(node: bytes) -> bool:
return len(node) == 32


def check_tree_node(tree: list[Any], i: int) -> None:
def check_tree_node(tree: list[bytes], i: int) -> None:
if not is_tree_node(tree, i):
raise ValueError("Index is not in tree")


def check_internal_node(tree: list[Any], i: int) -> None:
def check_internal_node(tree: list[bytes], i: int) -> None:
if not is_internal_node(tree, i):
raise ValueError("Index is not an internal tree node")


def check_leaf_node(tree: list[Any], i: int) -> None:
def check_leaf_node(tree: list[bytes], i: int) -> None:
if not is_leaf_node(tree, i):
raise ValueError("Index is not a leaf")


def check_valid_merkle_node(node: bytes) -> None:
if not is_valid_merkle_node(node):
raise ValueError("Merkle tree nodes must be Uint8Array of length 32")
raise ValueError("Merkle tree nodes must be byte array of length 32")


def make_merkle_tree(leaves: list[bytes]) -> list[bytes]:
Expand Down Expand Up @@ -118,14 +119,14 @@ def process_proof(leaf: bytes, proof: list[bytes]) -> bytes:
return result


def get_multi_proof(tree: list[bytes], indices: list[int]) -> MultiProof:
def get_multi_proof(tree: list[bytes], indices: list[int]) -> CoreMultiProof:
for index in indices:
check_leaf_node(tree, index)

indices = sorted(indices, reverse=True)

for i, p in enumerate(indices[1:]):
if p == indices[i]:
for prev_index, next_index in pairwise(indices):
if prev_index == next_index:
raise ValueError("Cannot prove duplicated index")

stack = indices[:]
Expand All @@ -149,14 +150,14 @@ def get_multi_proof(tree: list[bytes], indices: list[int]) -> MultiProof:
if len(indices) == 0:
proof.append(tree[0])

return MultiProof(
return CoreMultiProof(
leaves=[tree[i] for i in indices],
proof=proof,
proof_flags=proof_flags,
)


def process_multi_proof(multiproof: MultiProof) -> bytes:
def process_multi_proof(multiproof: CoreMultiProof) -> bytes:
for leaf in multiproof.leaves:
check_valid_merkle_node(leaf)

Expand Down
61 changes: 35 additions & 26 deletions multiproof/standard.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from dataclasses import dataclass
from functools import cmp_to_key
from typing import Any, Generic, TypeVar
from typing import Generic, TypeVar

from eth_abi import encode as abi_encode
from eth_typing import HexStr
from web3 import Web3

from multiproof.bytes import compare_bytes, equals_bytes, hex_to_bytes, to_hex
from multiproof.core import (MultiProof, get_multi_proof, get_proof,
from multiproof.core import (CoreMultiProof, get_multi_proof, get_proof,
is_valid_merkle_tree, left_child_index,
make_merkle_tree, process_multi_proof,
process_proof, right_child_index)
Expand All @@ -23,31 +23,41 @@ class LeafValue(Generic[T]):


@dataclass
class StandardMerkleTreeData:
class MultiProof(Generic[T]):
"""
User-friendly version of multiproof, compare with CoreMultiProof
"""
leaves: list[T]
proof: list[HexStr]
proof_flags: list[bool]


@dataclass
class StandardMerkleTreeData(Generic[T]):
tree: list[HexStr]
values: list[LeafValue]
values: list[LeafValue[T]]
leaf_encoding: list[str]
format: str = 'standard-v1'


@dataclass
class HashedValue:
value: Any
class HashedValue(Generic[T]):
value: T
index: int
hash: bytes


def standard_leaf_hash(values: Any, types: list[str]) -> bytes:
return Web3.keccak(Web3.keccak(abi_encode(types, values)))
def standard_leaf_hash(values: T, types: list[str]) -> bytes:
return Web3.keccak(Web3.keccak(abi_encode(types, values))) # type: ignore


class StandardMerkleTree:
class StandardMerkleTree(Generic[T]):
_hash_lookup: dict[HexStr, int]
tree: list[bytes]
values: list[LeafValue]
values: list[LeafValue[T]]
leaf_encoding: list[str]

def __init__(self, tree: list[bytes], values: list[LeafValue], leaf_encoding: list[str]):
def __init__(self, tree: list[bytes], values: list[LeafValue[T]], leaf_encoding: list[str]):
self.tree = tree
self.values = values
self.leaf_encoding = leaf_encoding
Expand All @@ -56,8 +66,8 @@ def __init__(self, tree: list[bytes], values: list[LeafValue], leaf_encoding: li
self._hash_lookup[to_hex(standard_leaf_hash(leaf_value.value, leaf_encoding))] = index

@staticmethod
def of(values: list[Any], leaf_encoding: list[str]) -> 'StandardMerkleTree':
hashed_values: list[HashedValue] = []
def of(values: list[T], leaf_encoding: list[str]) -> 'StandardMerkleTree[T]':
hashed_values: list[HashedValue[T]] = []
for index, value in enumerate(values):
hashed_values.append(
HashedValue(value=value, index=index, hash=standard_leaf_hash(value, leaf_encoding))
Expand All @@ -77,7 +87,7 @@ def of(values: list[Any], leaf_encoding: list[str]) -> 'StandardMerkleTree':
return StandardMerkleTree(tree, indexed_values, leaf_encoding)

@staticmethod
def load(data: StandardMerkleTreeData) -> 'StandardMerkleTree':
def load(data: StandardMerkleTreeData[T]) -> 'StandardMerkleTree[T]':
if data.format != 'standard-v1':
raise ValueError(f"Unknown format '{data.format}'")
return StandardMerkleTree(
Expand All @@ -88,7 +98,7 @@ def load(data: StandardMerkleTreeData) -> 'StandardMerkleTree':

@staticmethod
def verify(
root: HexStr, leaf_encoding: list[str], leaf_value: Any, proof: list[HexStr]
root: HexStr, leaf_encoding: list[str], leaf_value: T, proof: list[HexStr]
) -> bool:
leaf_hash = standard_leaf_hash(leaf_value, leaf_encoding)
implied_root = process_proof(leaf_hash, [hex_to_bytes(x) for x in proof])
Expand All @@ -99,7 +109,7 @@ def verify_multi_proof(root: HexStr, leaf_encoding: list[str], multiproof: Multi
leaf_hashes = [standard_leaf_hash(value, leaf_encoding) for value in multiproof.leaves]
proof_bytes = [hex_to_bytes(x) for x in multiproof.proof]
implied_root = process_multi_proof(
multiproof=MultiProof(
multiproof=CoreMultiProof(
leaves=leaf_hashes,
proof=proof_bytes,
proof_flags=multiproof.proof_flags,
Expand All @@ -108,7 +118,7 @@ def verify_multi_proof(root: HexStr, leaf_encoding: list[str], multiproof: Multi

return equals_bytes(implied_root, hex_to_bytes(root))

def dump(self) -> StandardMerkleTreeData:
def dump(self) -> StandardMerkleTreeData[T]:
return StandardMerkleTreeData(
format='standard-v1',
tree=[to_hex(v) for v in self.tree],
Expand All @@ -127,10 +137,10 @@ def validate(self) -> None:
if not is_valid_merkle_tree(self.tree):
raise ValueError("Merkle tree is invalid")

def leaf_hash(self, leaf: Any) -> HexStr:
def leaf_hash(self, leaf: T) -> HexStr:
return to_hex(standard_leaf_hash(leaf, self.leaf_encoding))

def leaf_lookup(self, leaf: Any) -> int:
def leaf_lookup(self, leaf: T) -> int:
v = self._hash_lookup[self.leaf_hash(leaf)]
if v is None:
raise ValueError("Leaf is not in tree")
Expand All @@ -156,13 +166,12 @@ def get_proof(self, leaf: T | int) -> list[HexStr]:

return [to_hex(p) for p in proof]

def get_multi_proof(self, leaves: list[Any]) -> MultiProof:
def get_multi_proof(self, leaves: list[int] | list[T]) -> MultiProof:
# input validity
value_indices = []
value_indices: list[int] = []
for leaf in leaves:
value_index = leaf
if isinstance(leaf, int):
value_indices.append(value_index)
value_indices.append(leaf)
else:
value_indices.append(self.leaf_lookup(leaf))

Expand Down Expand Up @@ -194,20 +203,20 @@ def _verify_leaf(self, leaf_hash: bytes, proof: list[bytes]) -> bool:

def verify_multi_proof_leaf(self, multiproof: MultiProof) -> bool:
return self._verify_multi_proof_leaf(
MultiProof(
CoreMultiProof(
leaves=[self._get_leaf_hash(leaf) for leaf in multiproof.leaves],
proof=[hex_to_bytes(proof) for proof in multiproof.proof],
proof_flags=multiproof.proof_flags,
)
)

def _verify_multi_proof_leaf(self, multi_proof: MultiProof) -> bool:
def _verify_multi_proof_leaf(self, multi_proof: CoreMultiProof) -> bool:
implied_root = process_multi_proof(multi_proof)
return equals_bytes(implied_root, self.tree[0])

def _validate_value(self, value_index: int) -> bytes:
check_bounds(self.values, value_index)
leaf: LeafValue = self.values[value_index]
leaf = self.values[value_index]
check_bounds(self.tree, leaf.tree_index)
leaf_hash = standard_leaf_hash(leaf.value, self.leaf_encoding)

Expand Down

0 comments on commit 35ecaec

Please sign in to comment.