Skip to content

Commit

Permalink
Wallet (#4887)
Browse files Browse the repository at this point in the history
* spent height

* handle generator reorg

* tx cache

* coin cache

* rebuild cache if write fails

* save last few messages new_peak while syncing

* don't use dupe func

* tx reorg test

* lock not needed

* lint

* lock

* modify properly

* this shouldn't hit a disk ever

* use same number

* notify wallet only once, lock when getting a balance

* lock only if unspent coin records is None

* assert spent

* lint

* Add test for prev generator

Co-authored-by: Mariano <sorgente711@gmail.com>
  • Loading branch information
Yostra and mariano54 authored May 19, 2021
1 parent d46cd8e commit 1086471
Show file tree
Hide file tree
Showing 12 changed files with 285 additions and 184 deletions.
2 changes: 1 addition & 1 deletion chia/full_node/coin_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ async def rollback_to_block(self, block_index: int):
new_record = CoinRecord(
coin_record.coin,
coin_record.confirmed_block_index,
coin_record.spent_block_index,
uint32(0),
False,
coin_record.coinbase,
coin_record.timestamp,
Expand Down
7 changes: 6 additions & 1 deletion chia/full_node/full_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,7 +1171,12 @@ async def respond_block(
)
assert result_to_validate.required_iters == pre_validation_results[0].required_iters
added, error_code, fork_height = await self.blockchain.receive_block(block, result_to_validate, None)

if (
self.full_node_store.previous_generator is not None
and fork_height is not None
and fork_height < self.full_node_store.previous_generator.block_height
):
self.full_node_store.previous_generator = None
validation_time = time.time() - validation_start

if added == ReceiveBlockResult.ALREADY_HAVE_BLOCK:
Expand Down
21 changes: 11 additions & 10 deletions chia/rpc/wallet_rpc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,16 +442,17 @@ async def get_wallet_balance(self, request: Dict) -> Dict:
assert self.service.wallet_state_manager is not None
wallet_id = uint32(int(request["wallet_id"]))
wallet = self.service.wallet_state_manager.wallets[wallet_id]
unspent_records = await self.service.wallet_state_manager.coin_store.get_unspent_coins_for_wallet(wallet_id)
balance = await wallet.get_confirmed_balance(unspent_records)
pending_balance = await wallet.get_unconfirmed_balance(unspent_records)
spendable_balance = await wallet.get_spendable_balance(unspent_records)
pending_change = await wallet.get_pending_change_balance()
max_send_amount = await wallet.get_max_send_amount(unspent_records)

unconfirmed_removals: Dict[bytes32, Coin] = await wallet.wallet_state_manager.unconfirmed_removals_for_wallet(
wallet_id
)
async with self.service.wallet_state_manager.lock:
unspent_records = await self.service.wallet_state_manager.coin_store.get_unspent_coins_for_wallet(wallet_id)
balance = await wallet.get_confirmed_balance(unspent_records)
pending_balance = await wallet.get_unconfirmed_balance(unspent_records)
spendable_balance = await wallet.get_spendable_balance(unspent_records)
pending_change = await wallet.get_pending_change_balance()
max_send_amount = await wallet.get_max_send_amount(unspent_records)

unconfirmed_removals: Dict[
bytes32, Coin
] = await wallet.wallet_state_manager.unconfirmed_removals_for_wallet(wallet_id)

wallet_balance = {
"wallet_id": wallet_id,
Expand Down
2 changes: 1 addition & 1 deletion chia/wallet/rl_wallet/rl_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ async def _get_rl_parent(self) -> Optional[Coin]:
rl_parent_id = self.rl_coin_record.coin.parent_coin_info
if rl_parent_id == self.rl_info.rl_origin_id:
return self.rl_info.rl_origin
rl_parent = await self.wallet_state_manager.coin_store.get_coin_record_by_coin_id(rl_parent_id)
rl_parent = await self.wallet_state_manager.coin_store.get_coin_record(rl_parent_id)
if rl_parent is None:
return None

Expand Down
2 changes: 1 addition & 1 deletion chia/wallet/trade_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ async def get_locked_coins_in_spend_bundle(self, bundle: SpendBundle) -> Dict[by
result = {}
removals = bundle.removals()
for coin in removals:
coin_record = await self.wallet_state_manager.coin_store.get_coin_record_by_coin_id(coin.name())
coin_record = await self.wallet_state_manager.coin_store.get_coin_record(coin.name())
if coin_record is None:
continue
result[coin_record.name()] = coin_record
Expand Down
16 changes: 13 additions & 3 deletions chia/wallet/wallet_blockchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from chia.wallet.block_record import HeaderBlockRecord
from chia.wallet.wallet_block_store import WalletBlockStore
from chia.wallet.wallet_coin_store import WalletCoinStore
from chia.wallet.wallet_transaction_store import WalletTransactionStore

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,6 +59,7 @@ class WalletBlockchain(BlockchainInterface):
__sub_epoch_summaries: Dict[uint32, SubEpochSummary] = {}
# Unspent Store
coin_store: WalletCoinStore
tx_store: WalletTransactionStore
# Store
block_store: WalletBlockStore
# Used to verify blocks in parallel
Expand All @@ -77,6 +79,8 @@ class WalletBlockchain(BlockchainInterface):
@staticmethod
async def create(
block_store: WalletBlockStore,
coin_store: WalletCoinStore,
tx_store: WalletTransactionStore,
consensus_constants: ConsensusConstants,
coins_of_interest_received: Callable, # f(removals: List[Coin], additions: List[Coin], height: uint32)
reorg_rollback: Callable,
Expand All @@ -88,7 +92,9 @@ async def create(
in the consensus constants config.
"""
self = WalletBlockchain()
self.lock = asyncio.Lock() # External lock handled by full node
self.lock = asyncio.Lock()
self.coin_store = coin_store
self.tx_store = tx_store
cpu_count = multiprocessing.cpu_count()
if cpu_count > 61:
cpu_count = 61 # Windows Server 2016 has an issue https://bugs.python.org/issue26903
Expand Down Expand Up @@ -228,7 +234,10 @@ async def receive_block(
await self.block_store.db_wrapper.commit_transaction()
except BaseException as e:
self.log.error(f"Error during db transaction: {e}")
await self.block_store.db_wrapper.rollback_transaction()
if self.block_store.db_wrapper.db._connection is not None:
await self.block_store.db_wrapper.rollback_transaction()
await self.coin_store.rebuild_wallet_cache()
await self.tx_store.rebuild_tx_cache()
raise
if fork_height is not None:
self.log.info(f"💰 Updated wallet peak to height {block_record.height}, weight {block_record.weight}, ")
Expand Down Expand Up @@ -271,7 +280,8 @@ async def _reconsider_peak(

# Rollback to fork
self.log.debug(f"fork_h: {fork_h}, SB: {block_record.height}, peak: {peak.height}")
await self.reorg_rollback(fork_h)
if block_record.prev_hash != peak.header_hash:
await self.reorg_rollback(fork_h)

# Rollback sub_epoch_summaries
heights_to_delete = []
Expand Down
133 changes: 60 additions & 73 deletions chia/wallet/wallet_coin_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from typing import Dict, List, Optional, Set

import aiosqlite
Expand All @@ -18,9 +17,10 @@ class WalletCoinStore:
"""

db_connection: aiosqlite.Connection
# coin_record_cache keeps ALL coin records in memory. [record_name: record]
coin_record_cache: Dict[bytes32, WalletCoinRecord]
coin_wallet_record_cache: Dict[int, Dict[bytes32, WalletCoinRecord]]
wallet_cache_lock: asyncio.Lock
# unspent_coin_wallet_cache keeps ALL unspent coin records for wallet in memory [wallet_id: [record_name: record]]
unspent_coin_wallet_cache: Dict[int, Dict[bytes32, WalletCoinRecord]]
db_wrapper: DBWrapper

@classmethod
Expand Down Expand Up @@ -62,58 +62,66 @@ async def create(cls, wrapper: DBWrapper):
await self.db_connection.execute("CREATE INDEX IF NOT EXISTS wallet_id on coin_record(wallet_id)")

await self.db_connection.commit()
self.coin_record_cache = dict()
self.coin_wallet_record_cache = {}
all_coins = await self.get_all_coins()
for coin_record in all_coins:
self.coin_record_cache[coin_record.coin.name()] = coin_record

self.wallet_cache_lock = asyncio.Lock()
self.coin_record_cache = {}
self.unspent_coin_wallet_cache = {}
await self.rebuild_wallet_cache()
return self

async def _clear_database(self):
cursor = await self.db_connection.execute("DELETE FROM coin_record")
await cursor.close()
await self.db_connection.commit()

async def rebuild_wallet_cache(self):
# First update all coins that were reorged, then re-add coin_records
all_coins = await self.get_all_coins()
self.unspent_coin_wallet_cache = {}
self.coin_record_cache = {}
for coin_record in all_coins:
name = coin_record.name()
self.coin_record_cache[name] = coin_record
if coin_record.spent is False:
if coin_record.wallet_id not in self.unspent_coin_wallet_cache:
self.unspent_coin_wallet_cache[coin_record.wallet_id] = {}
self.unspent_coin_wallet_cache[coin_record.wallet_id][name] = coin_record

# Store CoinRecord in DB and ram cache
async def add_coin_record(self, record: WalletCoinRecord) -> None:
# update wallet cache
name = record.name()
self.coin_record_cache[name] = record
if record.wallet_id in self.unspent_coin_wallet_cache:
if record.spent and name in self.unspent_coin_wallet_cache[record.wallet_id]:
self.unspent_coin_wallet_cache[record.wallet_id].pop(name)
if not record.spent:
self.unspent_coin_wallet_cache[record.wallet_id][name] = record
else:
if not record.spent:
self.unspent_coin_wallet_cache[record.wallet_id] = {}
self.unspent_coin_wallet_cache[record.wallet_id][name] = record

await self.wallet_cache_lock.acquire()
try:
if record.wallet_id in self.coin_wallet_record_cache:
cache_dict = self.coin_wallet_record_cache[record.wallet_id]
if record.coin.name() in cache_dict and record.spent:
cache_dict.pop(record.coin.name())
else:
cache_dict[record.coin.name()] = record

cursor = await self.db_connection.execute(
"INSERT OR REPLACE INTO coin_record VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
record.coin.name().hex(),
record.confirmed_block_height,
record.spent_block_height,
int(record.spent),
int(record.coinbase),
str(record.coin.puzzle_hash.hex()),
str(record.coin.parent_coin_info.hex()),
bytes(record.coin.amount),
record.wallet_type,
record.wallet_id,
),
)
await cursor.close()
self.coin_record_cache[record.coin.name()] = record
finally:
self.wallet_cache_lock.release()
cursor = await self.db_connection.execute(
"INSERT OR REPLACE INTO coin_record VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
name.hex(),
record.confirmed_block_height,
record.spent_block_height,
int(record.spent),
int(record.coinbase),
str(record.coin.puzzle_hash.hex()),
str(record.coin.parent_coin_info.hex()),
bytes(record.coin.amount),
record.wallet_type,
record.wallet_id,
),
)
await cursor.close()

# Update coin_record to be spent in DB
async def set_spent(self, coin_name: bytes32, height: uint32):
async def set_spent(self, coin_name: bytes32, height: uint32) -> WalletCoinRecord:
current: Optional[WalletCoinRecord] = await self.get_coin_record(coin_name)
if current is None:
return None
assert current is not None
assert current.spent is False

spent: WalletCoinRecord = WalletCoinRecord(
current.coin,
Expand All @@ -126,6 +134,7 @@ async def set_spent(self, coin_name: bytes32, height: uint32):
)

await self.add_coin_record(spent)
return spent

def coin_record_from_row(self, row: sqlite3.Row) -> WalletCoinRecord:
coin = Coin(bytes32(bytes.fromhex(row[6])), bytes32(bytes.fromhex(row[5])), uint64.from_bytes(row[7]))
Expand Down Expand Up @@ -180,27 +189,11 @@ async def get_unspent_coins_at_height(self, height: Optional[uint32] = None) ->

async def get_unspent_coins_for_wallet(self, wallet_id: int) -> Set[WalletCoinRecord]:
""" Returns set of CoinRecords that have not been spent yet for a wallet. """
async with self.wallet_cache_lock:
if wallet_id in self.coin_wallet_record_cache:
wallet_coins: Dict[bytes32, WalletCoinRecord] = self.coin_wallet_record_cache[wallet_id]
return set(wallet_coins.values())

coin_set = set()

cursor = await self.db_connection.execute(
"SELECT * from coin_record WHERE spent=0 and wallet_id=?",
(wallet_id,),
)
rows = await cursor.fetchall()
await cursor.close()
cache_dict = {}
for row in rows:
coin_record = self.coin_record_from_row(row)
coin_set.add(coin_record)
cache_dict[coin_record.name()] = coin_record

self.coin_wallet_record_cache[wallet_id] = cache_dict
return coin_set
if wallet_id in self.unspent_coin_wallet_cache:
wallet_coins: Dict[bytes32, WalletCoinRecord] = self.unspent_coin_wallet_cache[wallet_id]
return set(wallet_coins.values())
else:
return set()

async def get_all_coins(self) -> Set[WalletCoinRecord]:
""" Returns set of all CoinRecords."""
Expand All @@ -219,43 +212,37 @@ async def get_coin_records_by_puzzle_hash(self, puzzle_hash: bytes32) -> List[Wa

return [self.coin_record_from_row(row) for row in rows]

async def get_coin_record_by_coin_id(self, coin_id: bytes32) -> Optional[WalletCoinRecord]:
"""Returns a coin records with the given name, if it exists"""
# TODO: This is a duplicate of get_coin_record()
return await self.get_coin_record(coin_id)

async def rollback_to_block(self, height: int):
"""
Rolls back the blockchain to block_index. All blocks confirmed after this point
are removed from the LCA. All coins confirmed after this point are removed.
All coins spent after this point are set to unspent. Can be -1 (rollback all)
"""
# Update memory cache

# Delete from storage
delete_queue: List[WalletCoinRecord] = []
for coin_name, coin_record in self.coin_record_cache.items():
if coin_record.spent_block_height > height:
new_record = WalletCoinRecord(
coin_record.coin,
coin_record.confirmed_block_height,
coin_record.spent_block_height,
uint32(0),
False,
coin_record.coinbase,
coin_record.wallet_type,
coin_record.wallet_id,
)
self.coin_record_cache[coin_record.coin.name()] = new_record
self.unspent_coin_wallet_cache[coin_record.wallet_id][coin_record.coin.name()] = new_record
if coin_record.confirmed_block_height > height:
delete_queue.append(coin_record)

for coin_record in delete_queue:
self.coin_record_cache.pop(coin_record.coin.name())
if coin_record.wallet_id in self.coin_wallet_record_cache:
coin_cache = self.coin_wallet_record_cache[coin_record.wallet_id]
if coin_record.wallet_id in self.unspent_coin_wallet_cache:
coin_cache = self.unspent_coin_wallet_cache[coin_record.wallet_id]
if coin_record.coin.name() in coin_cache:
coin_cache.pop(coin_record.coin.name())

# Delete from storage
c1 = await self.db_connection.execute("DELETE FROM coin_record WHERE confirmed_height>?", (height,))
await c1.close()
c2 = await self.db_connection.execute(
Expand Down
6 changes: 6 additions & 0 deletions chia/wallet/wallet_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from chia.util.errors import Err, ValidationError
from chia.util.ints import uint32, uint128
from chia.util.keychain import Keychain
from chia.util.lru_cache import LRUCache
from chia.util.merkle_set import MerkleSet, confirm_included_already_hashed, confirm_not_included_already_hashed
from chia.util.path import mkdir, path_from_root
from chia.wallet.block_record import HeaderBlockRecord
Expand Down Expand Up @@ -109,6 +110,7 @@ def __init__(
self.logged_in_fingerprint: Optional[int] = None
self.peer_task = None
self.logged_in = False
self.last_new_peak_messages = LRUCache(5)

def get_key_for_fingerprint(self, fingerprint: Optional[int]):
private_keys = self.keychain.get_all_private_keys()
Expand Down Expand Up @@ -442,6 +444,7 @@ async def new_peak_wallet(self, peak: wallet_protocol.NewPeakWallet, peer: WSChi
# Request weight proof
# Sync if PoW validates
if self.wallet_state_manager.sync_mode:
self.last_new_peak_messages.put(peer, peak)
return None
weight_request = RequestProofOfWeight(header_block.height, header_block.header_hash)
weight_proof_response: RespondProofOfWeight = await peer.request_proof_of_weight(
Expand Down Expand Up @@ -501,6 +504,7 @@ async def sync_job(self) -> None:
break
asyncio.create_task(self.check_new_peak())
await self.sync_event.wait()
self.last_new_peak_messages = LRUCache(5)
self.sync_event.clear()

if self._shut_down is True:
Expand All @@ -515,6 +519,8 @@ async def sync_job(self) -> None:
finally:
if self.wallet_state_manager is not None:
self.wallet_state_manager.set_sync_mode(False)
for peer, peak in self.last_new_peak_messages.cache.items():
asyncio.create_task(self.new_peak_wallet(peak, peer))
self.log.info("Loop end in sync job")

async def _sync(self) -> None:
Expand Down
Loading

0 comments on commit 1086471

Please sign in to comment.