diff --git a/chia/consensus/blockchain.py b/chia/consensus/blockchain.py index b555cab0523d..2d7943651b94 100644 --- a/chia/consensus/blockchain.py +++ b/chia/consensus/blockchain.py @@ -252,6 +252,9 @@ async def receive_block( block, None, ) + peak_height = None + records: List[BlockRecord] = [] + success = True # Always add the block to the database async with self.block_store.db_wrapper.lock: try: @@ -261,32 +264,36 @@ async def receive_block( block_record, genesis, fork_point_with_peak, npc_result ) await self.block_store.db_wrapper.commit_transaction() - self.add_block_record(block_record) - for fetched_block_record in records: - self.__height_to_hash[fetched_block_record.height] = fetched_block_record.header_hash - if fetched_block_record.sub_epoch_summary_included is not None: - if summaries_to_check is not None: - # make sure this matches the summary list we got - ses_n = len(self.get_ses_heights()) - if ( - fetched_block_record.sub_epoch_summary_included.get_hash() - != summaries_to_check[ses_n].get_hash() - ): - log.error( - f"block ses does not match list, " - f"got {fetched_block_record.sub_epoch_summary_included} " - f"expected {summaries_to_check[ses_n]}" - ) - return ReceiveBlockResult.INVALID_BLOCK, Err.INVALID_SUB_EPOCH_SUMMARY, None - self.__sub_epoch_summaries[ - fetched_block_record.height - ] = fetched_block_record.sub_epoch_summary_included - if peak_height is not None: - self._peak_height = peak_height - self.block_store.cache_block(block) except BaseException: + success = False await self.block_store.db_wrapper.rollback_transaction() raise + finally: + if success is True: + self.add_block_record(block_record) + for fetched_block_record in records: + self.__height_to_hash[fetched_block_record.height] = fetched_block_record.header_hash + if fetched_block_record.sub_epoch_summary_included is not None: + if summaries_to_check is not None: + # make sure this matches the summary list we got + ses_n = len(self.get_ses_heights()) + if ( + fetched_block_record.sub_epoch_summary_included.get_hash() + != summaries_to_check[ses_n].get_hash() + ): + log.error( + f"block ses does not match list, " + f"got {fetched_block_record.sub_epoch_summary_included} " + f"expected {summaries_to_check[ses_n]}" + ) + return ReceiveBlockResult.INVALID_BLOCK, Err.INVALID_SUB_EPOCH_SUMMARY, None + self.__sub_epoch_summaries[ + fetched_block_record.height + ] = fetched_block_record.sub_epoch_summary_included + if peak_height is not None: + self._peak_height = peak_height + self.block_store.cache_block(block) + if fork_height is not None: return ReceiveBlockResult.NEW_PEAK, None, fork_height else: diff --git a/chia/full_node/coin_store.py b/chia/full_node/coin_store.py index 1c6ef9925d67..35715fbe4efa 100644 --- a/chia/full_node/coin_store.py +++ b/chia/full_node/coin_store.py @@ -55,7 +55,7 @@ async def create(cls, db_wrapper: DBWrapper, cache_size: uint32 = uint32(60000)) await self.coin_record_db.execute("CREATE INDEX IF NOT EXISTS coin_spent on coin_record(spent)") - await self.coin_record_db.execute("CREATE INDEX IF NOT EXISTS coin_spent on coin_record(puzzle_hash)") + await self.coin_record_db.execute("CREATE INDEX IF NOT EXISTS coin_puzzle_hash on coin_record(puzzle_hash)") await self.coin_record_db.commit() self.coin_record_cache = LRUCache(cache_size) @@ -106,7 +106,7 @@ async def new_block(self, block: FullBlock, tx_additions: List[Coin], tx_removal # Checks DB and DiffStores for CoinRecord with coin_name and returns it async def get_coin_record(self, coin_name: bytes32) -> Optional[CoinRecord]: - cached = self.coin_record_cache.get(coin_name.hex()) + cached = self.coin_record_cache.get(coin_name) if cached is not None: return cached cursor = await self.coin_record_db.execute("SELECT * from coin_record WHERE coin_name=?", (coin_name.hex(),)) @@ -114,7 +114,9 @@ async def get_coin_record(self, coin_name: bytes32) -> Optional[CoinRecord]: await cursor.close() if row is not None: coin = Coin(bytes32(bytes.fromhex(row[6])), bytes32(bytes.fromhex(row[5])), uint64.from_bytes(row[7])) - return CoinRecord(coin, row[1], row[2], row[3], row[4], row[8]) + record = CoinRecord(coin, row[1], row[2], row[3], row[4], row[8]) + self.coin_record_cache.put(record.coin.name(), record) + return record return None async def get_coins_added_at_height(self, height: uint32) -> List[CoinRecord]: @@ -205,7 +207,7 @@ async def rollback_to_block(self, block_index: int): coin_record.coinbase, coin_record.timestamp, ) - self.coin_record_cache.put(coin_record.coin.name().hex(), new_record) + self.coin_record_cache.put(coin_record.coin.name(), new_record) if int(coin_record.confirmed_block_index) > block_index: delete_queue.append(coin_name) @@ -223,6 +225,9 @@ async def rollback_to_block(self, block_index: int): # Store CoinRecord in DB and ram cache async def _add_coin_record(self, record: CoinRecord, allow_replace: bool) -> None: + if self.coin_record_cache.get(record.coin.name()) is not None: + self.coin_record_cache.remove(record.coin.name()) + cursor = await self.coin_record_db.execute( f"INSERT {'OR REPLACE ' if allow_replace else ''}INTO coin_record VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)", ( @@ -238,7 +243,6 @@ async def _add_coin_record(self, record: CoinRecord, allow_replace: bool) -> Non ), ) await cursor.close() - self.coin_record_cache.put(record.coin.name().hex(), record) # Update coin_record to be spent in DB async def _set_spent(self, coin_name: bytes32, index: uint32) -> uint64: diff --git a/chia/full_node/mempool_manager.py b/chia/full_node/mempool_manager.py index cca6f6705177..3aa0606ac8f8 100644 --- a/chia/full_node/mempool_manager.py +++ b/chia/full_node/mempool_manager.py @@ -71,6 +71,7 @@ def __init__(self, coin_store: CoinStore, consensus_constants: ConsensusConstant # The mempool will correspond to a certain peak self.peak: Optional[BlockRecord] = None self.mempool: Mempool = Mempool(self.mempool_max_total_cost) + self.lock = asyncio.Lock() def shut_down(self): self.pool.shutdown(wait=True) @@ -82,48 +83,49 @@ async def create_bundle_from_mempool( Returns aggregated spendbundle that can be used for creating new block, additions and removals in that spend_bundle """ - if ( - self.peak is None - or self.peak.header_hash != peak_header_hash - or int(time.time()) <= self.constants.INITIAL_FREEZE_END_TIMESTAMP - ): - return None - - cost_sum = 0 # Checks that total cost does not exceed block maximum - fee_sum = 0 # Checks that total fees don't exceed 64 bits - spend_bundles: List[SpendBundle] = [] - removals = [] - additions = [] - broke_from_inner_loop = False - log.info(f"Starting to make block, max cost: {self.constants.MAX_BLOCK_COST_CLVM}") - for dic in self.mempool.sorted_spends.values(): - if broke_from_inner_loop: - break - for item in dic.values(): - log.info(f"Cumulative cost: {cost_sum}") - if ( - item.cost + cost_sum <= self.limit_factor * self.constants.MAX_BLOCK_COST_CLVM - and item.fee + fee_sum <= self.constants.MAX_COIN_AMOUNT - ): - spend_bundles.append(item.spend_bundle) - cost_sum += item.cost - fee_sum += item.fee - removals.extend(item.removals) - additions.extend(item.additions) - else: - broke_from_inner_loop = True + async with self.lock: + if ( + self.peak is None + or self.peak.header_hash != peak_header_hash + or int(time.time()) <= self.constants.INITIAL_FREEZE_END_TIMESTAMP + ): + return None + + cost_sum = 0 # Checks that total cost does not exceed block maximum + fee_sum = 0 # Checks that total fees don't exceed 64 bits + spend_bundles: List[SpendBundle] = [] + removals = [] + additions = [] + broke_from_inner_loop = False + log.info(f"Starting to make block, max cost: {self.constants.MAX_BLOCK_COST_CLVM}") + for dic in self.mempool.sorted_spends.values(): + if broke_from_inner_loop: break - if len(spend_bundles) > 0: - log.info( - f"Cumulative cost of block (real cost should be less) {cost_sum}. Proportion " - f"full: {cost_sum / self.constants.MAX_BLOCK_COST_CLVM}" - ) - agg = SpendBundle.aggregate(spend_bundles) - assert set(agg.additions()) == set(additions) - assert set(agg.removals()) == set(removals) - return agg, additions, removals - else: - return None + for item in dic.values(): + log.info(f"Cumulative cost: {cost_sum}") + if ( + item.cost + cost_sum <= self.limit_factor * self.constants.MAX_BLOCK_COST_CLVM + and item.fee + fee_sum <= self.constants.MAX_COIN_AMOUNT + ): + spend_bundles.append(item.spend_bundle) + cost_sum += item.cost + fee_sum += item.fee + removals.extend(item.removals) + additions.extend(item.additions) + else: + broke_from_inner_loop = True + break + if len(spend_bundles) > 0: + log.info( + f"Cumulative cost of block (real cost should be less) {cost_sum}. Proportion " + f"full: {cost_sum / self.constants.MAX_BLOCK_COST_CLVM}" + ) + agg = SpendBundle.aggregate(spend_bundles) + assert set(agg.additions()) == set(additions) + assert set(agg.removals()) == set(removals) + return agg, additions, removals + else: + return None def get_filter(self) -> bytes: all_transactions: Set[bytes32] = set() @@ -227,6 +229,24 @@ async def add_spendbundle( spend_name: bytes32, validate_signature=True, program: Optional[SerializedProgram] = None, + locked: bool = False, + ) -> Tuple[Optional[uint64], MempoolInclusionStatus, Optional[Err]]: + if not locked: + await self.lock.acquire() + try: + result = await self._add_spendbundle(new_spend, npc_result, spend_name, validate_signature, program) + return result + finally: + if locked is False: + self.lock.release() + + async def _add_spendbundle( + self, + new_spend: SpendBundle, + npc_result: NPCResult, + spend_name: bytes32, + validate_signature=True, + program: Optional[SerializedProgram] = None, ) -> Tuple[Optional[uint64], MempoolInclusionStatus, Optional[Err]]: """ Tries to add spendbundle to either self.mempools or to_pool if it's specified. @@ -489,38 +509,41 @@ async def new_peak(self, new_peak: Optional[BlockRecord]) -> List[Tuple[SpendBun """ Called when a new peak is available, we try to recreate a mempool for the new tip. """ - if new_peak is None: - return [] - if new_peak.is_transaction_block is False: - return [] - if self.peak == new_peak: - return [] - assert new_peak.timestamp is not None - if new_peak.timestamp <= self.constants.INITIAL_FREEZE_END_TIMESTAMP: - return [] - - self.peak = new_peak - - old_pool = self.mempool - self.mempool = Mempool(self.mempool_max_total_cost) - - for item in old_pool.spends.values(): - await self.add_spendbundle(item.spend_bundle, item.npc_result, item.spend_bundle_name, False, item.program) - - potential_txs_copy = self.potential_txs.copy() - self.potential_txs = {} - txs_added = [] - for item in potential_txs_copy.values(): - cost, status, error = await self.add_spendbundle( - item.spend_bundle, item.npc_result, item.spend_bundle_name, program=item.program + async with self.lock: + if new_peak is None: + return [] + if new_peak.is_transaction_block is False: + return [] + if self.peak == new_peak: + return [] + assert new_peak.timestamp is not None + if new_peak.timestamp <= self.constants.INITIAL_FREEZE_END_TIMESTAMP: + return [] + + self.peak = new_peak + + old_pool = self.mempool + self.mempool = Mempool(self.mempool_max_total_cost) + + for item in old_pool.spends.values(): + await self.add_spendbundle( + item.spend_bundle, item.npc_result, item.spend_bundle_name, False, item.program, locked=True + ) + + potential_txs_copy = self.potential_txs.copy() + self.potential_txs = {} + txs_added = [] + for item in potential_txs_copy.values(): + cost, status, error = await self.add_spendbundle( + item.spend_bundle, item.npc_result, item.spend_bundle_name, program=item.program, locked=True + ) + if status == MempoolInclusionStatus.SUCCESS: + txs_added.append((item.spend_bundle, item.npc_result, item.spend_bundle_name)) + log.debug( + f"Size of mempool: {len(self.mempool.spends)} spends, cost: {self.mempool.total_mempool_cost} " + f"minimum fee to get in: {self.mempool.get_min_fee_rate(100000)}" ) - if status == MempoolInclusionStatus.SUCCESS: - txs_added.append((item.spend_bundle, item.npc_result, item.spend_bundle_name)) - log.debug( - f"Size of mempool: {len(self.mempool.spends)} spends, cost: {self.mempool.total_mempool_cost} " - f"minimum fee to get in: {self.mempool.get_min_fee_rate(100000)}" - ) - return txs_added + return txs_added async def get_items_not_in_filter(self, mempool_filter: PyBIP158) -> List[MempoolItem]: items: List[MempoolItem] = []