Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update mypy defs, update packages #9

Merged
merged 1 commit into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion multiproof/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .standard import StandardMerkleTree, StandardMerkleTreeData
from .standard import MultiProof, StandardMerkleTree, StandardMerkleTreeData
19 changes: 12 additions & 7 deletions multiproof/bytes.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
def has_hex_prefix(hex_string: str) -> bool:
from typing import cast

from eth_typing import HexStr


def has_hex_prefix(hex_string: str | HexStr) -> bool:
return hex_string.startswith("0x")


def add_hex_prefix(hex_string: str) -> str:
def add_hex_prefix(hex_string: str) -> HexStr:
if not has_hex_prefix(hex_string):
return "0x" + hex_string
return hex_string
return HexStr("0x" + hex_string)
return cast(HexStr, hex_string)


def remove_hex_prefix(hex_string: str) -> str:
def remove_hex_prefix(hex_string: HexStr) -> str:
if has_hex_prefix(hex_string):
return hex_string[len("0x"):]

return hex_string


def to_hex(b: bytes) -> str:
def to_hex(b: bytes) -> HexStr:
return add_hex_prefix(b.hex())


def hex_to_bytes(hex_string: str) -> bytes:
def hex_to_bytes(hex_string: HexStr) -> bytes:
return bytes.fromhex(remove_hex_prefix(hex_string))


Expand Down
27 changes: 16 additions & 11 deletions multiproof/standard.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from dataclasses import dataclass
from functools import cmp_to_key
from typing import Any
from typing import Any, Generic, TypeVar

from eth_abi import encode as abi_encode
from eth_typing import HexStr
from web3 import Web3

from multiproof.bytes import compare_bytes, equals_bytes, hex_to_bytes, to_hex
Expand All @@ -12,16 +13,18 @@
process_proof, right_child_index)
from multiproof.utils import check_bounds

T = TypeVar('T')


@dataclass
class LeafValue:
value: Any | None
class LeafValue(Generic[T]):
value: T
tree_index: int


@dataclass
class StandardMerkleTreeData:
tree: list[str]
tree: list[HexStr]
values: list[LeafValue]
leaf_encoding: list[str]
format: str = 'standard-v1'
Expand All @@ -39,7 +42,7 @@ def standard_leaf_hash(values: Any, types: list[str]) -> bytes:


class StandardMerkleTree:
_hash_lookup: dict[str, int]
_hash_lookup: dict[HexStr, int]
tree: list[bytes]
values: list[LeafValue]
leaf_encoding: list[str]
Expand Down Expand Up @@ -84,13 +87,15 @@ def load(data: StandardMerkleTreeData) -> 'StandardMerkleTree':
)

@staticmethod
def verify(root: str, leaf_encoding: list[str], leaf_value: Any, proof: list[str]) -> bool:
def verify(
root: HexStr, leaf_encoding: list[str], leaf_value: Any, proof: list[HexStr]
) -> bool:
leaf_hash = standard_leaf_hash(leaf_value, leaf_encoding)
implied_root = process_proof(leaf_hash, [hex_to_bytes(x) for x in proof])
return equals_bytes(implied_root, hex_to_bytes(root))

@staticmethod
def verify_multi_proof(root: str, leaf_encoding: list[str], multiproof: MultiProof) -> bool:
def verify_multi_proof(root: HexStr, leaf_encoding: list[str], multiproof: MultiProof) -> bool:
leaf_hashes = [standard_leaf_hash(value, leaf_encoding) for value in multiproof.leaves]
proof_bytes = [hex_to_bytes(x) for x in multiproof.proof]
implied_root = process_multi_proof(
Expand All @@ -112,7 +117,7 @@ def dump(self) -> StandardMerkleTreeData:
)

@property
def root(self) -> str:
def root(self) -> HexStr:
return to_hex(self.tree[0])

def validate(self) -> None:
Expand All @@ -122,7 +127,7 @@ def validate(self) -> None:
if not is_valid_merkle_tree(self.tree):
raise ValueError("Merkle tree is invalid")

def leaf_hash(self, leaf: Any) -> str:
def leaf_hash(self, leaf: Any) -> HexStr:
return to_hex(standard_leaf_hash(leaf, self.leaf_encoding))

def leaf_lookup(self, leaf: Any) -> int:
Expand All @@ -131,7 +136,7 @@ def leaf_lookup(self, leaf: Any) -> int:
raise ValueError("Leaf is not in tree")
return v

def get_proof(self, leaf: LeafValue | int) -> list[str]:
def get_proof(self, leaf: T | int) -> list[HexStr]:
# input validity
value_index: int = leaf # type: ignore
if not isinstance(leaf, int):
Expand Down Expand Up @@ -180,7 +185,7 @@ def get_multi_proof(self, leaves: list[Any]) -> MultiProof:
proof_flags=proof.proof_flags,
)

def verify_leaf(self, leaf: int, proof: list[str]) -> bool:
def verify_leaf(self, leaf: int, proof: list[HexStr]) -> bool:
return self._verify_leaf(self._get_leaf_hash(leaf), [hex_to_bytes(p) for p in proof])

def _verify_leaf(self, leaf_hash: bytes, proof: list[bytes]) -> bool:
Expand Down
Loading
Loading