From 83e4d5adf3e42e3ca5971d4de056d870a59409d9 Mon Sep 17 00:00:00 2001
From: Kyle Altendorf
Date: Sat, 22 Feb 2025 11:09:53 -0500
Subject: [PATCH] more datalayer rust transition prep

---
 chia/_tests/core/data_layer/test_data_rpc.py   | 12 ++--
 .../_tests/core/data_layer/test_data_store.py  | 25 +++++----
 chia/data_layer/data_layer.py                  |  3 +-
 chia/data_layer/data_layer_util.py             |  6 +-
 chia/data_layer/data_store.py                  | 55 +++++++++++--------
 poetry.lock                                    |  2 +-
 6 files changed, 61 insertions(+), 42 deletions(-)

diff --git a/chia/_tests/core/data_layer/test_data_rpc.py b/chia/_tests/core/data_layer/test_data_rpc.py
index b7f2a7ad2565..f7f6d19c072f 100644
--- a/chia/_tests/core/data_layer/test_data_rpc.py
+++ b/chia/_tests/core/data_layer/test_data_rpc.py
@@ -19,6 +19,7 @@ from typing import Any, Optional, cast
 
 import anyio
+import chia_rs.datalayer
 import pytest
 
 from chia._tests.util.misc import boolean_datacases
@@ -1017,10 +1018,8 @@ async def process_for_data_layer_keys(
     for sleep_time in backoff_times():
         try:
             value = await data_layer.get_value(store_id=store_id, key=expected_key)
-        except Exception as e:
-            # TODO: more specific exceptions...
-            if "Key not found" not in str(e):
-                raise  # pragma: no cover
+        except (KeyNotFoundError, chia_rs.datalayer.UnknownKeyError):
+            pass
         else:
             if expected_value is None or value == expected_value:
                 break
@@ -3437,7 +3436,10 @@ async def test_unsubmitted_batch_update(
         count=NUM_BLOCKS_WITHOUT_SUBMIT, guarantee_transaction_blocks=True
     )
     keys_values = await data_rpc_api.get_keys_values({"id": store_id.hex()})
-    assert keys_values == prev_keys_values
+    # order agnostic comparison of the list of dicts
+    assert {item["key"]: item for item in keys_values["keys_values"]} == {
+        item["key"]: item for item in prev_keys_values["keys_values"]
+    }
 
     pending_root = await data_layer.data_store.get_pending_root(store_id=store_id)
     assert pending_root is not None
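Two patterns recur throughout this patch and first appear in the hunks above: exceptions are caught as a (Python, Rust) pair while both merkle-blob implementations coexist, and equality assertions on get_keys_values results become order agnostic, apparently because ordering is not guaranteed to match across the two implementations. A minimal sketch of the order-agnostic comparison, using hypothetical records shaped like the RPC's "keys_values" entries:

    # Keying each record by its unique "key" makes the comparison insensitive
    # to iteration order; these sample records are hypothetical stand-ins for
    # the RPC's "keys_values" entries.
    old = [{"key": "0x01", "value": "a"}, {"key": "0x02", "value": "b"}]
    new = [{"key": "0x02", "value": "b"}, {"key": "0x01", "value": "a"}]
    assert {item["key"]: item for item in old} == {item["key"]: item for item in new}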
diff --git a/chia/_tests/core/data_layer/test_data_store.py b/chia/_tests/core/data_layer/test_data_store.py
index d341d7a1aedf..3eb856a27b0e 100644
--- a/chia/_tests/core/data_layer/test_data_store.py
+++ b/chia/_tests/core/data_layer/test_data_store.py
@@ -14,6 +14,7 @@ from typing import Any, BinaryIO, Callable, Optional
 
 import aiohttp
+import chia_rs.datalayer
 import pytest
 from chia_rs.datalayer import TreeIndex
 
@@ -39,10 +40,9 @@
     get_full_tree_filename_path,
     leaf_hash,
 )
-from chia.data_layer.data_store import DataStore
+from chia.data_layer.data_store import DataStore, InternalTypes, LeafTypes, MerkleBlobHint
 from chia.data_layer.download_data import insert_from_delta_file, write_files_for_root
 from chia.data_layer.util.benchmark import generate_datastore
-from chia.data_layer.util.merkle_blob import MerkleBlob, RawInternalMerkleNode, RawLeafMerkleNode
 from chia.types.blockchain_format.program import Program
 from chia.types.blockchain_format.sized_bytes import bytes32
 from chia.util.byte_types import hexstr_to_bytes
@@ -492,8 +492,8 @@ async def test_insert_batch_reference_and_side(
     nodes_with_indexes = merkle_blob.get_nodes_with_indexes()
     nodes = [pair[1] for pair in nodes_with_indexes]
     assert len(nodes) == 3
-    assert isinstance(nodes[1], RawLeafMerkleNode)
-    assert isinstance(nodes[2], RawLeafMerkleNode)
+    assert isinstance(nodes[1], LeafTypes)
+    assert isinstance(nodes[2], LeafTypes)
     left_terminal_node = await data_store.get_terminal_node(nodes[1].key, nodes[1].value, store_id)
     right_terminal_node = await data_store.get_terminal_node(nodes[2].key, nodes[2].value, store_id)
     if side == Side.LEFT:
@@ -1246,7 +1246,7 @@ async def write_tree_to_file_old_format(
     node_hash: bytes32,
    store_id: bytes32,
     writer: BinaryIO,
-    merkle_blob: Optional[MerkleBlob] = None,
+    merkle_blob: Optional[MerkleBlobHint] = None,
     hash_to_index: Optional[dict[bytes32, TreeIndex]] = None,
 ) -> None:
     if node_hash == bytes32.zeros:
@@ -1266,13 +1266,13 @@ async def write_tree_to_file_old_format(
     raw_node = merkle_blob.get_raw_node(raw_index)
 
     to_write = b""
-    if isinstance(raw_node, RawInternalMerkleNode):
+    if isinstance(raw_node, InternalTypes):
         left_hash = merkle_blob.get_hash_at_index(raw_node.left)
         right_hash = merkle_blob.get_hash_at_index(raw_node.right)
         await write_tree_to_file_old_format(data_store, root, left_hash, store_id, writer, merkle_blob, hash_to_index)
         await write_tree_to_file_old_format(data_store, root, right_hash, store_id, writer, merkle_blob, hash_to_index)
         to_write = bytes(SerializedNode(False, bytes(left_hash), bytes(right_hash)))
-    elif isinstance(raw_node, RawLeafMerkleNode):
+    elif isinstance(raw_node, LeafTypes):
         node = await data_store.get_terminal_node(raw_node.key, raw_node.value, store_id)
         to_write = bytes(SerializedNode(True, node.key, node.value))
     else:
@@ -1680,7 +1680,8 @@ async def mock_http_download_2(
     filenames = {entry.name for entry in entries}
     assert len(filenames) == num_files + max_full_files - 1
     kv = await data_store.get_keys_values(store_id=store_id)
-    assert kv == kv_before
+    # order agnostic comparison of the list
+    assert set(kv) == set(kv_before)
 
 
 @pytest.mark.anyio
@@ -1718,7 +1719,7 @@ async def test_get_node_by_key_with_overlapping_keys(raw_data_store: DataStore)
         if random.randint(0, 4) == 0:
             batch = [{"action": "delete", "key": key}]
             await raw_data_store.insert_batch(store_id, batch, status=Status.COMMITTED)
-            with pytest.raises(KeyNotFoundError, match=f"Key not found: {key.hex()}"):
+            with pytest.raises((KeyNotFoundError, chia_rs.datalayer.UnknownKeyError)):
                 await raw_data_store.get_node_by_key(store_id=store_id, key=key)
 
 
@@ -1787,7 +1788,8 @@ async def test_insert_from_delta_file_correct_file_exists(
     filenames = {entry.name for entry in entries}
     assert len(filenames) == num_files + 2  # 1 full and 6 deltas
     kv = await data_store.get_keys_values(store_id=store_id)
-    assert kv == kv_before
+    # order agnostic comparison of the list
+    assert set(kv) == set(kv_before)
 
 
 @pytest.mark.anyio
@@ -2002,7 +2004,8 @@ async def test_migration(
     data_store.recent_merkle_blobs = LRUCache(capacity=128)
     assert await data_store.get_keys_values(store_id=store_id) == []
     await data_store.migrate_db(tmp_path)
-    assert await data_store.get_keys_values(store_id=store_id) == kv_before
+    # order agnostic comparison of the list
+    assert set(await data_store.get_keys_values(store_id=store_id)) == set(kv_before)
 
 
 @pytest.mark.anyio
diff --git a/chia/data_layer/data_layer.py b/chia/data_layer/data_layer.py
index 1c65952bb1a8..814c13b4ff13 100644
--- a/chia/data_layer/data_layer.py
+++ b/chia/data_layer/data_layer.py
@@ -34,6 +34,7 @@
     ProofOfInclusionLayer,
     Root,
     ServerInfo,
+    Side,
     Status,
     StoreProofs,
     Subscription,
@@ -1065,7 +1066,7 @@ async def process_offered_stores(self, offer_stores: tuple[OfferStore, ...]) ->
                     node_hash=proof_of_inclusion.node_hash,
                     layers=tuple(
                         Layer(
-                            other_hash_side=layer.other_hash_side,
+                            other_hash_side=Side(layer.other_hash_side),
                             other_hash=layer.other_hash,
                             combined_hash=layer.combined_hash,
                         )
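The data_layer.py change follows from get_proof_of_inclusion_by_hash now being able to return the Rust proof type, whose layers may expose other_hash_side as a plain integer rather than the Python Side enum, so the value is coerced with Side(...) before a Layer is built. A minimal sketch of that normalization, with a simplified stand-in enum:

    from enum import Enum

    class Side(Enum):  # simplified stand-in for chia.data_layer's Side enum
        LEFT = 0
        RIGHT = 1

    raw_side = 1  # assume an integer-valued side returned by the Rust bindings
    assert Side(raw_side) is Side.RIGHT  # normalized before it is stored in a Layer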
diff --git a/chia/data_layer/data_layer_util.py b/chia/data_layer/data_layer_util.py
index a7d0ba6e0848..3c87880e2d63 100644
--- a/chia/data_layer/data_layer_util.py
+++ b/chia/data_layer/data_layer_util.py
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union
 
 import aiosqlite
+import chia_rs.datalayer
 from typing_extensions import final
 
 from chia.data_layer.data_layer_errors import ProofIntegrityError
@@ -24,6 +25,9 @@
     from chia.data_layer.data_store import DataStore
     from chia.wallet.wallet_node import WalletNode
 
+ProofOfInclusionHint = Union["ProofOfInclusion", chia_rs.datalayer.ProofOfInclusion]
+ProofOfInclusionLayerHint = Union["ProofOfInclusionLayer", chia_rs.datalayer.ProofOfInclusionLayer]
+
 
 def internal_hash(left_hash: bytes32, right_hash: bytes32) -> bytes32:
     # see test for the definition this is optimized from
@@ -187,7 +191,7 @@ class NodeType(IntEnum):
 
 
 @final
-class Side(IntEnum):
+class Side(uint8, Enum):
     LEFT = 0
     RIGHT = 1
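data_layer_util.py introduces for proofs the same "hint" aliasing that data_store.py uses for merkle blobs: a Union type alias that lets annotations accept either the Python dataclass or its chia_rs counterpart while both implementations coexist. A minimal sketch of the pattern, with stand-in classes (the real ones live in this module and in chia_rs.datalayer):

    from typing import Union

    class ProofOfInclusion: ...      # stand-in for the Python dataclass
    class RustProofOfInclusion: ...  # stand-in for chia_rs.datalayer.ProofOfInclusion

    ProofOfInclusionHint = Union[ProofOfInclusion, RustProofOfInclusion]

    def describe(proof: ProofOfInclusionHint) -> str:
        # Annotated call sites accept either implementation during the transition.
        return type(proof).__name__

    assert describe(ProofOfInclusion()) == "ProofOfInclusion"
    assert describe(RustProofOfInclusion()) == "RustProofOfInclusion"

The Side base-class change from IntEnum to (uint8, Enum) presumably makes the members uint8 instances, so they stream as a fixed-width byte while still allowing Side(...) coercion from raw values, as in the data_layer.py hunk above.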
diff --git a/chia/data_layer/data_store.py b/chia/data_layer/data_store.py
index dfcb3bcb95bc..194e6c2ac4b6 100644
--- a/chia/data_layer/data_store.py
+++ b/chia/data_layer/data_store.py
@@ -12,6 +12,7 @@ from typing import Any, BinaryIO, Callable, Optional, Union
 
 import aiosqlite
+import chia_rs.datalayer
 from chia_rs.datalayer import KeyId, TreeIndex, ValueId
 
 from chia.data_layer.data_layer_errors import KeyNotFoundError, MerkleBlobNotFoundError, TreeGenerationIncrementingError
@@ -26,7 +27,7 @@
     Node,
     NodeType,
     OperationType,
-    ProofOfInclusion,
+    ProofOfInclusionHint,
     Root,
     SerializedNode,
     ServerInfo,
@@ -60,13 +61,17 @@
 # TODO: review exceptions for values that shouldn't be displayed
 # TODO: pick exception types other than Exception
 
+MerkleBlobHint = Union[MerkleBlob, chia_rs.datalayer.MerkleBlob]
+LeafTypes = (RawLeafMerkleNode, chia_rs.datalayer.LeafNode)
+InternalTypes = (RawInternalMerkleNode, chia_rs.datalayer.InternalNode)
+
 
 @dataclass
 class DataStore:
     """A key/value store with the pairs being terminal nodes in a CLVM object tree."""
 
     db_wrapper: DBWrapper2
-    recent_merkle_blobs: LRUCache[bytes32, MerkleBlob]
+    recent_merkle_blobs: LRUCache[bytes32, MerkleBlobHint]
 
     @classmethod
     @contextlib.asynccontextmanager
@@ -86,7 +91,7 @@ async def managed(
             row_factory=aiosqlite.Row,
             log_path=sql_log_path,
         ) as db_wrapper:
-            recent_merkle_blobs: LRUCache[bytes32, MerkleBlob] = LRUCache(capacity=128)
+            recent_merkle_blobs: LRUCache[bytes32, MerkleBlobHint] = LRUCache(capacity=128)
             self = cls(db_wrapper=db_wrapper, recent_merkle_blobs=recent_merkle_blobs)
 
             async with db_wrapper.writer() as writer:
@@ -297,9 +302,9 @@ async def insert_into_data_store_from_file(
                 nodes = merkle_blob.get_nodes_with_indexes(index=index)
                 index_to_hash = {index: bytes32(node.hash) for index, node in nodes}
                 for _, node in nodes:
-                    if isinstance(node, RawLeafMerkleNode):
+                    if isinstance(node, LeafTypes):
                         terminal_nodes[bytes32(node.hash)] = (node.key, node.value)
-                    elif isinstance(node, RawInternalMerkleNode):
+                    elif isinstance(node, InternalTypes):
                         internal_nodes[bytes32(node.hash)] = (index_to_hash[node.left], index_to_hash[node.right])
 
         merkle_blob = MerkleBlob.from_node_list(internal_nodes, terminal_nodes, root_hash)
@@ -392,7 +397,7 @@ async def get_merkle_blob(
         root_hash: Optional[bytes32],
         read_only: bool = False,
         update_cache: bool = True,
-    ) -> MerkleBlob:
+    ) -> MerkleBlobHint:
         if root_hash is None:
             return MerkleBlob(blob=bytearray())
 
@@ -422,7 +427,7 @@ async def get_merkle_blob(
 
     async def insert_root_from_merkle_blob(
         self,
-        merkle_blob: MerkleBlob,
+        merkle_blob: MerkleBlobHint,
         store_id: bytes32,
         status: Status,
         old_root: Optional[Root] = None,
@@ -836,11 +841,11 @@ async def get_ancestors(
         merkle_blob = await self.get_merkle_blob(root_hash=root_hash)
 
         reference_kid, _ = merkle_blob.get_node_by_hash(node_hash)
-        reference_index = merkle_blob.key_to_index[reference_kid]
+        reference_index = merkle_blob.get_key_index(reference_kid)
         lineage = merkle_blob.get_lineage_with_indexes(reference_index)
         result: list[InternalNode] = []
         for index, node in itertools.islice(lineage, 1, None):
-            assert isinstance(node, RawInternalMerkleNode)
+            assert isinstance(node, InternalTypes)
             result.append(
                 InternalNode(
                     hash=node.hash,
@@ -1096,17 +1101,19 @@ async def get_keys(
 
         return keys
 
-    def get_reference_kid_side(self, merkle_blob: MerkleBlob, seed: bytes32) -> tuple[KeyId, Side]:
+    def get_reference_kid_side(self, merkle_blob: MerkleBlobHint, seed: bytes32) -> tuple[KeyId, Side]:
         side_seed = bytes(seed)[0]
         side = Side.LEFT if side_seed < 128 else Side.RIGHT
         reference_node = merkle_blob.get_random_leaf_node(seed)
         kid = reference_node.key
         return (kid, side)
 
-    async def get_terminal_node_from_kid(self, merkle_blob: MerkleBlob, kid: KeyId, store_id: bytes32) -> TerminalNode:
-        index = merkle_blob.key_to_index[kid]
+    async def get_terminal_node_from_kid(
+        self, merkle_blob: MerkleBlobHint, kid: KeyId, store_id: bytes32
+    ) -> TerminalNode:
+        index = merkle_blob.get_key_index(kid)
         raw_node = merkle_blob.get_raw_node(index)
-        assert isinstance(raw_node, RawLeafMerkleNode)
+        assert isinstance(raw_node, LeafTypes)
         return await self.get_terminal_node(raw_node.key, raw_node.value, store_id)
 
     async def get_terminal_node_for_seed(self, seed: bytes32, store_id: bytes32) -> Optional[TerminalNode]:
@@ -1249,7 +1256,11 @@ async def insert_batch(
                     key_hashed = key_hash(key)
                     kid, vid = await self.add_key_value(key, value, store_id)
-                    if merkle_blob.key_exists(kid):
+                    try:
+                        merkle_blob.get_key_index(kid)
+                    except (KeyError, chia_rs.datalayer.UnknownKeyError):
+                        pass
+                    else:
                         raise Exception(f"Key already present: {key.hex()}")
                     hash = leaf_hash(key, value)
@@ -1362,8 +1373,6 @@ async def get_node_by_key(
         if kvid is None:
             raise KeyNotFoundError(key=key)
         kid = KeyId(kvid)
-        if not merkle_blob.key_exists(kid):
-            raise KeyNotFoundError(key=key)
         return await self.get_terminal_node_from_kid(merkle_blob, kid, store_id)
 
     async def get_node(self, node_hash: bytes32) -> Node:
@@ -1389,14 +1398,14 @@ async def get_tree_as_nodes(self, store_id: bytes32) -> Node:
         hash_to_node: dict[bytes32, Node] = {}
         tree_node: Node
         for _, node in reversed(nodes):
-            if isinstance(node, RawInternalMerkleNode):
+            if isinstance(node, InternalTypes):
                 left_hash = merkle_blob.get_hash_at_index(node.left)
                 right_hash = merkle_blob.get_hash_at_index(node.right)
                 tree_node = InternalNode.from_child_nodes(
                     left=hash_to_node[left_hash], right=hash_to_node[right_hash]
                 )
             else:
-                assert isinstance(node, RawLeafMerkleNode)
+                assert isinstance(node, LeafTypes)
                 tree_node = await self.get_terminal_node(node.key, node.value, store_id)
             hash_to_node[bytes32(node.hash)] = tree_node
@@ -1409,7 +1418,7 @@ async def get_proof_of_inclusion_by_hash(
         node_hash: bytes32,
         store_id: bytes32,
         root_hash: Optional[bytes32] = None,
-    ) -> ProofOfInclusion:
+    ) -> ProofOfInclusionHint:
         if root_hash is None:
             root = await self.get_tree_root(store_id=store_id)
             root_hash = root.node_hash
@@ -1421,7 +1430,7 @@ async def get_proof_of_inclusion_by_key(
         self,
         key: bytes,
         store_id: bytes32,
-    ) -> ProofOfInclusion:
+    ) -> ProofOfInclusionHint:
         root = await self.get_tree_root(store_id=store_id)
         merkle_blob = await self.get_merkle_blob(root_hash=root.node_hash)
         kvid = await self.get_kvid(key, store_id)
@@ -1437,7 +1446,7 @@ async def write_tree_to_file(
         store_id: bytes32,
         deltas_only: bool,
         writer: BinaryIO,
-        merkle_blob: Optional[MerkleBlob] = None,
+        merkle_blob: Optional[MerkleBlobHint] = None,
         hash_to_index: Optional[dict[bytes32, TreeIndex]] = None,
         existing_hashes: Optional[set[bytes32]] = None,
     ) -> None:
@@ -1465,7 +1474,7 @@ async def write_tree_to_file(
         raw_node = merkle_blob.get_raw_node(raw_index)
 
         to_write = b""
-        if isinstance(raw_node, RawInternalMerkleNode):
+        if isinstance(raw_node, InternalTypes):
             left_hash = merkle_blob.get_hash_at_index(raw_node.left)
             right_hash = merkle_blob.get_hash_at_index(raw_node.right)
             await self.write_tree_to_file(
@@ -1475,7 +1484,7 @@ async def write_tree_to_file(
                 root, right_hash, store_id, deltas_only, writer, merkle_blob, hash_to_index, existing_hashes
             )
             to_write = bytes(SerializedNode(False, bytes(left_hash), bytes(right_hash)))
-        elif isinstance(raw_node, RawLeafMerkleNode):
+        elif isinstance(raw_node, LeafTypes):
             node = await self.get_terminal_node(raw_node.key, raw_node.value, store_id)
             to_write = bytes(SerializedNode(True, node.key, node.value))
         else:
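The insert_batch hunk swaps the Python-only key_exists() predicate for an EAFP probe of get_key_index(), which both implementations provide, and get_node_by_key drops its now-redundant key_exists() guard for the same reason: a missing key surfaces from get_terminal_node_from_kid instead, which is why the tests above accept either KeyNotFoundError or chia_rs.datalayer.UnknownKeyError. A minimal sketch of the probe, using a hypothetical helper name and a fake blob (the real code catches KeyError from the Python blob and chia_rs.datalayer.UnknownKeyError from the Rust one):

    def assert_key_absent(merkle_blob, kid: bytes, key: bytes) -> None:
        try:
            merkle_blob.get_key_index(kid)
        except KeyError:  # the patch also catches chia_rs.datalayer.UnknownKeyError
            pass  # key absent, safe to insert
        else:
            raise Exception(f"Key already present: {key.hex()}")

    class FakeMerkleBlob:
        # Minimal stand-in whose get_key_index raises KeyError for missing
        # keys, mirroring the Python MerkleBlob's dict-backed lookup.
        def __init__(self) -> None:
            self.key_to_index: dict[bytes, int] = {}

        def get_key_index(self, kid: bytes) -> int:
            return self.key_to_index[kid]

    blob = FakeMerkleBlob()
    assert_key_absent(blob, kid=b"k1", key=b"k1")  # no error: key not present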
diff --git a/poetry.lock b/poetry.lock
index e89b062cb384..90c6f368c448 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -782,7 +782,7 @@ typing-extensions = "*"
 type = "git"
 url = "https://github.com/chia-network/chia_rs"
 reference = "long_lived/initial_datalayer"
-resolved_reference = "cbbfa261633e174dfffb0b2062171e99467c366c"
+resolved_reference = "da91206dc3b8f8909b4d6025930e13e35935255e"
 subdirectory = "wheel/"
 
 [[package]]
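Finally, a note on the LeafTypes and InternalTypes tuples that replace the direct RawLeafMerkleNode/RawInternalMerkleNode checks throughout data_store.py: isinstance() accepts a tuple of classes, so a single check covers nodes coming from either implementation. A minimal sketch with stand-in classes:

    # Grouping the Python and Rust node classes in plain tuples lets every
    # isinstance() check accept either implementation unchanged; these
    # classes are stand-ins for the real node types.
    class RawLeafMerkleNode: ...  # Python implementation
    class RustLeafNode: ...       # stand-in for chia_rs.datalayer.LeafNode

    LeafTypes = (RawLeafMerkleNode, RustLeafNode)

    for node in (RawLeafMerkleNode(), RustLeafNode()):
        assert isinstance(node, LeafTypes)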