some rs datalayer via store
altendky committed Feb 18, 2025
1 parent a8724f0 commit d8f083c
Showing 4 changed files with 82 additions and 48 deletions.
11 changes: 6 additions & 5 deletions chia/_tests/core/data_layer/test_data_store.py
@@ -35,14 +35,15 @@
Subscription,
TerminalNode,
_debug_dump,
+ as_program,
get_delta_filename_path,
get_full_tree_filename_path,
leaf_hash,
)
- from chia.data_layer.data_store import DataStore
+ from chia.data_layer.data_store import DataStore, MerkleBlobHint
from chia.data_layer.download_data import insert_from_delta_file, write_files_for_root
from chia.data_layer.util.benchmark import generate_datastore
- from chia.data_layer.util.merkle_blob import MerkleBlob, RawInternalMerkleNode, RawLeafMerkleNode
+ from chia.data_layer.util.merkle_blob import RawInternalMerkleNode, RawLeafMerkleNode
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
@@ -792,7 +793,7 @@ async def test_proof_of_inclusion_by_hash_program(data_store: DataStore, store_i

proof = await data_store.get_proof_of_inclusion_by_hash(node_hash=node.hash, store_id=store_id)

- assert proof.as_program() == [
+ assert as_program(proof) == [
b"\x04",
[
bytes32.fromhex("fb66fe539b3eb2020dfbfadfd601fa318521292b41f04c2057c16fca6b947ca1"),
@@ -833,7 +834,7 @@ async def test_proof_of_inclusion_by_hash_bytes(data_store: DataStore, store_id:
b"\xe2\xa0\xaeX\xe2\x80\x80"
)

- assert bytes(proof.as_program()) == expected
+ assert bytes(as_program(proof)) == expected


# @pytest.mark.anyio
@@ -1286,7 +1287,7 @@ async def write_tree_to_file_old_format(
node_hash: bytes32,
store_id: bytes32,
writer: BinaryIO,
- merkle_blob: Optional[MerkleBlob] = None,
+ merkle_blob: Optional[MerkleBlobHint] = None,
hash_to_index: Optional[dict[bytes32, TreeIndex]] = None,
) -> None:
if node_hash == bytes32.zeros:
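
The test changes above track an API move from methods on the Python ProofOfInclusion to free functions defined in data_layer_util.py (shown later in this commit). A minimal sketch of the motivation, using stand-in proof types rather than the real classes:

```python
# Sketch (stand-in types, not the commit's code): a free function can accept
# a Union of the Python and Rust proof types, so one implementation serves
# both, whereas proof.as_program() existed only on the Python class.
from typing import Union

class PyProof:  # stand-in for chia.data_layer's ProofOfInclusion
    def __init__(self) -> None:
        self.layers: list[object] = []

class RustProof:  # stand-in for chia_rs.datalayer.ProofOfInclusion
    def __init__(self) -> None:
        self.layers: list[object] = []

ProofHint = Union[PyProof, RustProof]

def layer_count(proof: ProofHint) -> int:
    return len(proof.layers)  # duck-typed: both backends expose .layers

assert layer_count(PyProof()) == 0
assert layer_count(RustProof()) == 0
```
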
11 changes: 7 additions & 4 deletions chia/data_layer/data_layer.py
@@ -34,6 +34,7 @@
ProofOfInclusionLayer,
Root,
ServerInfo,
+ Side,
Status,
StoreProofs,
Subscription,
@@ -44,6 +45,8 @@
get_delta_filename_path,
get_full_tree_filename_path,
leaf_hash,
+ sibling_hashes,
+ sibling_sides_integer,
unspecified,
)
from chia.data_layer.data_layer_wallet import DataLayerWallet, Mirror, verify_offer
@@ -1065,7 +1068,7 @@ async def process_offered_stores(self, offer_stores: tuple[OfferStore, ...]) ->
node_hash=proof_of_inclusion.node_hash,
layers=tuple(
Layer(
- other_hash_side=layer.other_hash_side,
+ other_hash_side=Side(layer.other_hash_side),
other_hash=layer.other_hash,
combined_hash=layer.combined_hash,
)
@@ -1164,12 +1167,12 @@ async def take_offer(
for layer in proof.layers
]
proof_of_inclusion = ProofOfInclusion(node_hash=proof.node_hash, layers=layers)
- sibling_sides_integer = proof_of_inclusion.sibling_sides_integer()
+ sibling_sides_integer_value = sibling_sides_integer(proof_of_inclusion)
proofs_of_inclusion.append(
(
root.hex(),
- str(sibling_sides_integer),
- ["0x" + sibling_hash.hex() for sibling_hash in proof_of_inclusion.sibling_hashes()],
+ str(sibling_sides_integer_value),
+ ["0x" + sibling_hash.hex() for sibling_hash in sibling_hashes(proof_of_inclusion)],
)
)

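One detail in the hunks above: process_offered_stores and take_offer now pass side values through Side(...). With Side redefined as class Side(uint8, Enum) (see data_layer_util.py below), the enum call normalizes a plain integer, such as one coming from a chia_rs layer, into the enum member. A small sketch, with a local uint8 stand-in for chia.util.ints.uint8:

```python
from enum import Enum

class uint8(int):  # stand-in for chia.util.ints.uint8
    pass

class Side(uint8, Enum):
    LEFT = 0
    RIGHT = 1

# Value lookup: a bare int maps to the member with that value, and an
# already-typed member passes through unchanged.
assert Side(0) is Side.LEFT
assert Side(Side.RIGHT) is Side.RIGHT
```
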
43 changes: 31 additions & 12 deletions chia/data_layer/data_layer_util.py
@@ -5,9 +5,10 @@
from enum import Enum, IntEnum
from hashlib import sha256
from pathlib import Path
- from typing import TYPE_CHECKING, Any, Optional, Union
+ from typing import TYPE_CHECKING, Any, Optional, Union, cast

import aiosqlite
+ import chia_rs.datalayer
from typing_extensions import final

from chia.data_layer.data_layer_errors import ProofIntegrityError
@@ -24,6 +25,9 @@
from chia.data_layer.data_store import DataStore
from chia.wallet.wallet_node import WalletNode

+ ProofOfInclusionHint = Union["ProofOfInclusion", chia_rs.datalayer.ProofOfInclusion]
+ ProofOfInclusionLayerHint = Union["ProofOfInclusionLayer", chia_rs.datalayer.ProofOfInclusionLayer]


def internal_hash(left_hash: bytes32, right_hash: bytes32) -> bytes32:
# see test for the definition this is optimized from
@@ -187,7 +191,7 @@ class NodeType(IntEnum):


@final
- class Side(IntEnum):
+ class Side(uint8, Enum):
LEFT = 0
RIGHT = 1

@@ -278,7 +282,31 @@ def from_hashes(cls, primary_hash: bytes32, other_hash_side: Side, other_hash: b
return cls(other_hash_side=other_hash_side, other_hash=other_hash, combined_hash=combined_hash)


- other_side_to_bit = {Side.LEFT: 1, Side.RIGHT: 0}
+ def sibling_sides_integer(proof: ProofOfInclusionHint) -> int:
+ # casting to work around this
+ # class C: ...
+ # class D: ...
+ #
+ # m: list[C | D]
+ # reveal_type(enumerate(m))
+ # # main.py:5: note: Revealed type is "builtins.enumerate[Union[__main__.C, __main__.D]]"
+ #
+ # n: list[C] | list[D]
+ # reveal_type(enumerate(n))
+ # # main.py:9: note: Revealed type is "builtins.enumerate[builtins.object]"

+ return sum(
+ (1 << index if cast(ProofOfInclusionLayerHint, layer).other_hash_side == Side.LEFT else 0)
+ for index, layer in enumerate(proof.layers)
+ )


+ def sibling_hashes(proof: ProofOfInclusionHint) -> list[bytes32]:
+ return [layer.other_hash for layer in proof.layers]


+ def as_program(proof: ProofOfInclusionHint) -> Program:
+ return Program.to([sibling_sides_integer(proof), sibling_hashes(proof)])


@dataclass(frozen=True)
@@ -293,15 +321,6 @@ def root_hash(self) -> bytes32:

return self.layers[-1].combined_hash

- def sibling_sides_integer(self) -> int:
- return sum(other_side_to_bit[layer.other_hash_side] << index for index, layer in enumerate(self.layers))

- def sibling_hashes(self) -> list[bytes32]:
- return [layer.other_hash for layer in self.layers]

- def as_program(self) -> Program:
- return Program.to([self.sibling_sides_integer(), self.sibling_hashes()])

def valid(self) -> bool:
existing_hash = self.node_hash

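For reference, a worked example of the packing done by the new sibling_sides_integer (a sketch with a minimal stand-in layer type, not the commit's code): layer i contributes bit 1 << i exactly when its sibling hash hangs on the left, so siblings on LEFT, RIGHT, LEFT pack to 0b101 == 5:

```python
from dataclasses import dataclass
from enum import IntEnum

class Side(IntEnum):
    LEFT = 0
    RIGHT = 1

@dataclass
class FakeLayer:  # stand-in; real layers also carry other_hash/combined_hash
    other_hash_side: Side

layers = [FakeLayer(Side.LEFT), FakeLayer(Side.RIGHT), FakeLayer(Side.LEFT)]
packed = sum(
    (1 << index if layer.other_hash_side == Side.LEFT else 0)
    for index, layer in enumerate(layers)
)
assert packed == 0b101  # bits 0 and 2 set: those layers' siblings are LEFT
```
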
65 changes: 38 additions & 27 deletions chia/data_layer/data_store.py
@@ -12,6 +12,7 @@
from typing import Any, BinaryIO, Callable, Optional, Union

import aiosqlite
+ import chia_rs.datalayer
from chia_rs.datalayer import KeyId, TreeIndex, ValueId

from chia.data_layer.data_layer_errors import KeyNotFoundError, MerkleBlobNotFoundError, TreeGenerationIncrementingError
@@ -26,7 +27,7 @@
Node,
NodeType,
OperationType,
- ProofOfInclusion,
+ ProofOfInclusionHint,
Root,
SerializedNode,
ServerInfo,
@@ -60,13 +61,17 @@
# TODO: review exceptions for values that shouldn't be displayed
# TODO: pick exception types other than Exception

+ MerkleBlobHint = Union[MerkleBlob, chia_rs.datalayer.MerkleBlob]
+ LeafTypes = (RawLeafMerkleNode, chia_rs.datalayer.LeafNode)
+ InternalTypes = (RawInternalMerkleNode, chia_rs.datalayer.InternalNode)


@dataclass
class DataStore:
"""A key/value store with the pairs being terminal nodes in a CLVM object tree."""

db_wrapper: DBWrapper2
- recent_merkle_blobs: LRUCache[bytes32, MerkleBlob]
+ recent_merkle_blobs: LRUCache[bytes32, MerkleBlobHint]

@classmethod
@contextlib.asynccontextmanager
@@ -86,7 +91,7 @@ async def managed(
row_factory=aiosqlite.Row,
log_path=sql_log_path,
) as db_wrapper:
- recent_merkle_blobs: LRUCache[bytes32, MerkleBlob] = LRUCache(capacity=128)
+ recent_merkle_blobs: LRUCache[bytes32, MerkleBlobHint] = LRUCache(capacity=128)
self = cls(db_wrapper=db_wrapper, recent_merkle_blobs=recent_merkle_blobs)

async with db_wrapper.writer() as writer:
@@ -297,9 +302,9 @@ async def insert_into_data_store_from_file(
nodes = merkle_blob.get_nodes_with_indexes(index=index)
index_to_hash = {index: bytes32(node.hash) for index, node in nodes}
for _, node in nodes:
- if isinstance(node, RawLeafMerkleNode):
+ if isinstance(node, LeafTypes):
terminal_nodes[bytes32(node.hash)] = (node.key, node.value)
- elif isinstance(node, RawInternalMerkleNode):
+ elif isinstance(node, InternalTypes):
internal_nodes[bytes32(node.hash)] = (index_to_hash[node.left], index_to_hash[node.right])

merkle_blob = MerkleBlob.from_node_list(internal_nodes, terminal_nodes, root_hash)
@@ -392,13 +397,14 @@ async def get_merkle_blob(
root_hash: Optional[bytes32],
read_only: bool = False,
update_cache: bool = True,
- ) -> MerkleBlob:
+ ) -> MerkleBlobHint:
if root_hash is None:
- return MerkleBlob(blob=bytearray())
+ return chia_rs.datalayer.MerkleBlob(blob=bytearray())

existing_blob = self.recent_merkle_blobs.get(root_hash)
if existing_blob is not None:
- return existing_blob if read_only else copy.deepcopy(existing_blob)
+ # return existing_blob if read_only else copy.deepcopy(existing_blob)
+ return existing_blob if read_only else chia_rs.datalayer.MerkleBlob(existing_blob.blob)

async with self.db_wrapper.reader() as reader:
cursor = await reader.execute(
@@ -413,7 +419,7 @@
if row is None:
raise MerkleBlobNotFoundError(root_hash=root_hash)

- merkle_blob = MerkleBlob(blob=bytearray(row["blob"]))
+ merkle_blob = chia_rs.datalayer.MerkleBlob(blob=bytearray(row["blob"]))

if update_cache:
self.recent_merkle_blobs.put(root_hash, copy.deepcopy(merkle_blob))
@@ -422,7 +428,7 @@

async def insert_root_from_merkle_blob(
self,
- merkle_blob: MerkleBlob,
+ merkle_blob: MerkleBlobHint,
store_id: bytes32,
status: Status,
old_root: Optional[Root] = None,
@@ -445,7 +451,8 @@ async def insert_root_from_merkle_blob(
(root_hash, merkle_blob.blob, store_id),
)
if update_cache:
- self.recent_merkle_blobs.put(root_hash, copy.deepcopy(merkle_blob))
+ # self.recent_merkle_blobs.put(root_hash, copy.deepcopy(merkle_blob))
+ self.recent_merkle_blobs.put(root_hash, chia_rs.datalayer.MerkleBlob(merkle_blob.blob))

return await self._insert_root(store_id, root_hash, status)
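
Both cache paths in this file (here and in get_merkle_blob above) drop copy.deepcopy in favor of rebuilding a fresh chia_rs.datalayer.MerkleBlob from the underlying bytes. A sketch of the idea with a stand-in class, assuming, as the diff suggests, that the blob bytes fully describe the object:

```python
class Blobby:  # stand-in for a blob-backed merkle structure
    def __init__(self, blob: bytearray) -> None:
        self.blob = bytearray(blob)  # take an independent copy of the bytes

original = Blobby(bytearray(b"\x01\x02"))
duplicate = Blobby(original.blob)  # "copy" = reconstruct from serialized state
duplicate.blob[0] = 0xFF
assert original.blob == bytearray(b"\x01\x02")  # original is untouched
```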

@@ -836,11 +843,11 @@ async def get_ancestors(
merkle_blob = await self.get_merkle_blob(root_hash=root_hash)
reference_kid, _ = merkle_blob.get_node_by_hash(node_hash)

- reference_index = merkle_blob.key_to_index[reference_kid]
+ reference_index = merkle_blob.get_key_index(reference_kid)
lineage = merkle_blob.get_lineage_with_indexes(reference_index)
result: list[InternalNode] = []
for index, node in itertools.islice(lineage, 1, None):
- assert isinstance(node, RawInternalMerkleNode)
+ assert isinstance(node, InternalTypes)
result.append(
InternalNode(
hash=node.hash,
@@ -1096,17 +1103,19 @@ async def get_keys(

return keys

- def get_reference_kid_side(self, merkle_blob: MerkleBlob, seed: bytes32) -> tuple[KeyId, Side]:
+ def get_reference_kid_side(self, merkle_blob: MerkleBlobHint, seed: bytes32) -> tuple[KeyId, Side]:
side_seed = bytes(seed)[0]
side = Side.LEFT if side_seed < 128 else Side.RIGHT
reference_node = merkle_blob.get_random_leaf_node(seed)
kid = reference_node.key
return (kid, side)

- async def get_terminal_node_from_kid(self, merkle_blob: MerkleBlob, kid: KeyId, store_id: bytes32) -> TerminalNode:
- index = merkle_blob.key_to_index[kid]
+ async def get_terminal_node_from_kid(
+ self, merkle_blob: MerkleBlobHint, kid: KeyId, store_id: bytes32
+ ) -> TerminalNode:
+ index = merkle_blob.get_key_index(kid)
raw_node = merkle_blob.get_raw_node(index)
- assert isinstance(raw_node, RawLeafMerkleNode)
+ assert isinstance(raw_node, LeafTypes)
return await self.get_terminal_node(raw_node.key, raw_node.value, store_id)

async def get_terminal_node_for_seed(self, seed: bytes32, store_id: bytes32) -> Optional[TerminalNode]:
@@ -1249,7 +1258,11 @@ async def insert_batch(

key_hashed = key_hash(key)
kid, vid = await self.add_key_value(key, value, store_id)
- if merkle_blob.key_exists(kid):
+ try:
+ merkle_blob.get_key_index(kid)
+ except (KeyError, chia_rs.datalayer.UnknownKeyError):
+ pass
+ else:
raise Exception(f"Key already present: {key.hex()}")
hash = leaf_hash(key, value)
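
The key_exists call above becomes a try/except probe because the two backends signal a missing key differently: the Python blob raises KeyError from get_key_index, while the Rust one raises its own UnknownKeyError. A self-contained sketch with stand-in classes (exception names as in the diff):

```python
class UnknownKeyError(Exception):  # stand-in for chia_rs.datalayer.UnknownKeyError
    pass

class FakeBlob:
    def __init__(self) -> None:
        self.key_to_index: dict[int, int] = {}

    def get_key_index(self, kid: int) -> int:
        try:
            return self.key_to_index[kid]
        except KeyError:
            raise UnknownKeyError(kid) from None

blob = FakeBlob()
try:
    blob.get_key_index(42)
except (KeyError, UnknownKeyError):
    pass  # key absent: safe to insert
else:
    raise Exception("Key already present: 42")
```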

@@ -1362,8 +1375,6 @@ async def get_node_by_key(
if kvid is None:
raise KeyNotFoundError(key=key)
kid = KeyId(kvid)
- if not merkle_blob.key_exists(kid):
- raise KeyNotFoundError(key=key)
return await self.get_terminal_node_from_kid(merkle_blob, kid, store_id)

async def get_node(self, node_hash: bytes32) -> Node:
@@ -1389,14 +1400,14 @@ async def get_tree_as_nodes(self, store_id: bytes32) -> Node:
hash_to_node: dict[bytes32, Node] = {}
tree_node: Node
for _, node in reversed(nodes):
- if isinstance(node, RawInternalMerkleNode):
+ if isinstance(node, InternalTypes):
left_hash = merkle_blob.get_hash_at_index(node.left)
right_hash = merkle_blob.get_hash_at_index(node.right)
tree_node = InternalNode.from_child_nodes(
left=hash_to_node[left_hash], right=hash_to_node[right_hash]
)
else:
- assert isinstance(node, RawLeafMerkleNode)
+ assert isinstance(node, LeafTypes)
tree_node = await self.get_terminal_node(node.key, node.value, store_id)
hash_to_node[bytes32(node.hash)] = tree_node

Expand All @@ -1409,7 +1420,7 @@ async def get_proof_of_inclusion_by_hash(
node_hash: bytes32,
store_id: bytes32,
root_hash: Optional[bytes32] = None,
- ) -> ProofOfInclusion:
+ ) -> ProofOfInclusionHint:
if root_hash is None:
root = await self.get_tree_root(store_id=store_id)
root_hash = root.node_hash
@@ -1421,7 +1432,7 @@ async def get_proof_of_inclusion_by_key(
self,
key: bytes,
store_id: bytes32,
- ) -> ProofOfInclusion:
+ ) -> ProofOfInclusionHint:
root = await self.get_tree_root(store_id=store_id)
merkle_blob = await self.get_merkle_blob(root_hash=root.node_hash)
kvid = await self.get_kvid(key, store_id)
@@ -1437,7 +1448,7 @@ async def write_tree_to_file(
store_id: bytes32,
deltas_only: bool,
writer: BinaryIO,
- merkle_blob: Optional[MerkleBlob] = None,
+ merkle_blob: Optional[MerkleBlobHint] = None,
hash_to_index: Optional[dict[bytes32, TreeIndex]] = None,
existing_hashes: Optional[set[bytes32]] = None,
) -> None:
@@ -1465,7 +1476,7 @@
raw_node = merkle_blob.get_raw_node(raw_index)

to_write = b""
- if isinstance(raw_node, RawInternalMerkleNode):
+ if isinstance(raw_node, InternalTypes):
left_hash = merkle_blob.get_hash_at_index(raw_node.left)
right_hash = merkle_blob.get_hash_at_index(raw_node.right)
await self.write_tree_to_file(
@@ -1475,7 +1486,7 @@
root, right_hash, store_id, deltas_only, writer, merkle_blob, hash_to_index, existing_hashes
)
to_write = bytes(SerializedNode(False, bytes(left_hash), bytes(right_hash)))
- elif isinstance(raw_node, RawLeafMerkleNode):
+ elif isinstance(raw_node, LeafTypes):
node = await self.get_terminal_node(raw_node.key, raw_node.value, store_id)
to_write = bytes(SerializedNode(True, node.key, node.value))
else:
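A pattern that recurs throughout this file: LeafTypes and InternalTypes are plain tuples of classes, and isinstance accepts a tuple, so a single check matches a node from either backend. A minimal sketch with stand-in node classes:

```python
class RawLeafMerkleNode:  # stand-in for the pure-Python leaf node
    pass

class RustLeafNode:  # stand-in for chia_rs.datalayer.LeafNode
    pass

LeafTypes = (RawLeafMerkleNode, RustLeafNode)

def describe(node: object) -> str:
    # isinstance(x, (A, B)) is true when x is an instance of A or of B
    return "leaf" if isinstance(node, LeafTypes) else "internal"

assert describe(RawLeafMerkleNode()) == "leaf"
assert describe(RustLeafNode()) == "leaf"
assert describe(object()) == "internal"
```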
