Skip to content

Commit

Permalink
simplify other_included() in MerkleSet to always truncate the proof
Browse files Browse the repository at this point in the history
  • Loading branch information
arvidn committed Apr 15, 2024
1 parent 49521dc commit efe81b4
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions chia/util/merkle_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from abc import ABCMeta, abstractmethod
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple

from chia.types.blockchain_format.sized_bytes import bytes32

Expand Down Expand Up @@ -109,7 +109,7 @@ def is_included(self, tocheck: bytes, depth: int, p: List[bytes]) -> bool:
pass

@abstractmethod
def other_included(self, tocheck: bytes, depth: int, p: List[bytes], collapse: bool) -> None:
def other_included(self, p: List[bytes]) -> None:
pass

@abstractmethod
Expand All @@ -120,13 +120,10 @@ def _audit(self, hashes: List[bytes], bits: List[int]) -> None:
class MerkleSet:
root: Node

def __init__(self, leafs: List[bytes32], *, set_root: Optional[Node] = None):
if set_root is not None:
self.root = set_root
else:
self.root = _empty
for leaf in leafs:
self.root = self.root.add(leaf, 0)
def __init__(self, leafs: Iterable[bytes32]):
self.root = _empty
for leaf in leafs:
self.root = self.root.add(leaf, 0)

def get_root(self) -> bytes32:
return compress_root(self.root.get_hash())
Expand All @@ -141,6 +138,12 @@ def _audit(self, hashes: List[bytes]) -> None:
self.root._audit(newhashes, [])
assert newhashes == sorted(newhashes)

@staticmethod
def _with_root(root: Node) -> MerkleSet:
s = MerkleSet([])
s.root = root
return s


class EmptyNode(Node):
def __init__(self) -> None:
Expand All @@ -165,7 +168,7 @@ def is_included(self, tocheck: bytes, depth: int, p: List[bytes]) -> bool:
p.append(EMPTY)
return False

def other_included(self, tocheck: bytes, depth: int, p: List[bytes], collapse: bool) -> None:
def other_included(self, p: List[bytes]) -> None:
p.append(EMPTY)

def _audit(self, hashes: List[bytes], bits: List[int]) -> None:
Expand Down Expand Up @@ -216,7 +219,7 @@ def is_included(self, tocheck: bytes, depth: int, p: List[bytes]) -> bool:
p.append(TERMINAL + self.hash)
return tocheck == self.hash

def other_included(self, tocheck: bytes, depth: int, p: List[bytes], collapse: bool) -> None:
def other_included(self, p: List[bytes]) -> None:
p.append(TERMINAL + self.hash)

def _audit(self, hashes: List[bytes], bits: List[int]) -> None:
Expand Down Expand Up @@ -271,17 +274,14 @@ def is_included(self, tocheck: bytes, depth: int, p: List[bytes]) -> bool:
p.append(MIDDLE)
if get_bit(tocheck, depth) == 0:
r = self.children[0].is_included(tocheck, depth + 1, p)
self.children[1].other_included(tocheck, depth + 1, p, not self.children[0].is_empty())
self.children[1].other_included(p)
return r
else:
self.children[0].other_included(tocheck, depth + 1, p, not self.children[1].is_empty())
self.children[0].other_included(p)
return self.children[1].is_included(tocheck, depth + 1, p)

def other_included(self, tocheck: bytes, depth: int, p: List[bytes], collapse: bool) -> None:
if collapse or not self.is_double():
p.append(TRUNCATED + self.hash)
else:
self.is_included(tocheck, depth, p)
def other_included(self, p: List[bytes]) -> None:
p.append(TRUNCATED + self.hash)

def _audit(self, hashes: List[bytes], bits: List[int]) -> None:
self.children[0]._audit(hashes, bits + [0])
Expand Down Expand Up @@ -310,7 +310,7 @@ def add(self, toadd: bytes, depth: int) -> Node:
def is_included(self, tocheck: bytes, depth: int, p: List[bytes]) -> bool:
raise SetError()

def other_included(self, tocheck: bytes, depth: int, p: List[bytes], collapse: bool) -> None:
def other_included(self, p: List[bytes]) -> None:
p.append(TRUNCATED + self.hash)

def _audit(self, hashes: List[bytes], bits: List[int]) -> None:
Expand Down Expand Up @@ -345,7 +345,7 @@ def deserialize_proof(proof: bytes) -> MerkleSet:
r, pos = _deserialize(proof, 0, [])
if pos != len(proof):
raise SetError()
return MerkleSet([], set_root=r)
return MerkleSet._with_root(r)
except IndexError:
raise SetError()

Expand Down

0 comments on commit efe81b4

Please sign in to comment.