diff --git a/chain/chain/src/chain.rs b/chain/chain/src/chain.rs index 72787e82b85..70ea7748348 100644 --- a/chain/chain/src/chain.rs +++ b/chain/chain/src/chain.rs @@ -546,7 +546,7 @@ impl Chain { }) .cloned() .collect(); - runtime_adapter.get_tries().load_mem_tries_for_enabled_shards(&tracked_shards)?; + runtime_adapter.get_tries().load_mem_tries_for_enabled_shards(&tracked_shards, true)?; info!(target: "chain", "Init: header head @ #{} {}; block head @ #{} {}", header_head.height, header_head.last_block_hash, diff --git a/core/store/src/test_utils.rs b/core/store/src/test_utils.rs index db4480f85ef..5a5bb2f8dda 100644 --- a/core/store/src/test_utils.rs +++ b/core/store/src/test_utils.rs @@ -179,7 +179,7 @@ impl TestTriesBuilder { } update_for_chunk_extra.commit().unwrap(); - tries.load_mem_tries_for_enabled_shards(&shard_uids).unwrap(); + tries.load_mem_tries_for_enabled_shards(&shard_uids, false).unwrap(); } tries } diff --git a/core/store/src/trie/mem/arena/alloc.rs b/core/store/src/trie/mem/arena/alloc.rs index 04ef5221c43..d78c2a70563 100644 --- a/core/store/src/trie/mem/arena/alloc.rs +++ b/core/store/src/trie/mem/arena/alloc.rs @@ -45,11 +45,11 @@ pub struct Allocator { const MAX_ALLOC_SIZE: usize = 16 * 1024; const ROUND_UP_TO_8_BYTES_UNDER: usize = 256; const ROUND_UP_TO_64_BYTES_UNDER: usize = 1024; -const CHUNK_SIZE: usize = 4 * 1024 * 1024; +pub(crate) const CHUNK_SIZE: usize = 4 * 1024 * 1024; /// Calculates the allocation class (an index from 0 to NUM_ALLOCATION_CLASSES) /// for the given size that we wish to allocate. -const fn allocation_class(size: usize) -> usize { +pub(crate) const fn allocation_class(size: usize) -> usize { if size <= ROUND_UP_TO_8_BYTES_UNDER { (size + 7) / 8 - 1 } else if size <= ROUND_UP_TO_64_BYTES_UNDER { @@ -61,7 +61,7 @@ const fn allocation_class(size: usize) -> usize { } /// Calculates the size of the actual allocation for the given size class. -const fn allocation_size(size_class: usize) -> usize { +pub(crate) const fn allocation_size(size_class: usize) -> usize { if size_class <= allocation_class(ROUND_UP_TO_8_BYTES_UNDER) { (size_class + 1) * 8 } else if size_class <= allocation_class(ROUND_UP_TO_64_BYTES_UNDER) { @@ -88,13 +88,30 @@ impl Allocator { } } + pub fn new_with_initial_stats( + name: String, + active_allocs_bytes: usize, + active_allocs_count: usize, + ) -> Self { + let mut allocator = Self::new(name); + allocator.active_allocs_bytes = active_allocs_bytes; + allocator.active_allocs_count = active_allocs_count; + allocator.active_allocs_bytes_gauge.set(active_allocs_bytes as i64); + allocator.active_allocs_count_gauge.set(active_allocs_count as i64); + allocator + } + + pub fn update_memory_usage_gauge(&self, memory: &STArenaMemory) { + self.memory_usage_gauge.set(memory.chunks.len() as i64 * CHUNK_SIZE as i64); + } + /// Adds a new chunk to the arena, and updates the next_alloc_pos to the beginning of /// the new chunk. fn new_chunk(&mut self, memory: &mut STArenaMemory) { memory.chunks.push(vec![0; CHUNK_SIZE]); self.next_alloc_pos = ArenaPos { chunk: u32::try_from(memory.chunks.len() - 1).unwrap(), pos: 0 }; - self.memory_usage_gauge.set(memory.chunks.len() as i64 * CHUNK_SIZE as i64); + self.update_memory_usage_gauge(memory); } /// Allocates a slice of the given size in the arena. 
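For reference, the size classes defined above round every requested size up before it is placed in the arena, which is what lets a freed slot be reused by any later request of the same class. A minimal sketch (a hypothetical test module, assuming it sits alongside these `pub(crate)` helpers in alloc.rs); the concrete values match the allocations exercised by the new concurrent-arena test further down:

    #[cfg(test)]
    mod size_class_examples {
        use super::*;

        #[test]
        fn rounding_examples() {
            // Round a requested size up to its class size, composing the two helpers above.
            let rounded = |size: usize| allocation_size(allocation_class(size));
            assert_eq!(rounded(17), 24); // 8-byte steps below 256 bytes
            assert_eq!(rounded(25), 32);
            assert_eq!(rounded(40), 40); // already a multiple of 8
        }
    }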
@@ -145,6 +162,11 @@ impl Allocator { pub fn num_active_allocs(&self) -> usize { self.active_allocs_count } + + #[cfg(test)] + pub fn active_allocs_bytes(&self) -> usize { + self.active_allocs_bytes + } } #[cfg(test)] diff --git a/core/store/src/trie/mem/arena/concurrent.rs b/core/store/src/trie/mem/arena/concurrent.rs new file mode 100644 index 00000000000..4f6bf34f138 --- /dev/null +++ b/core/store/src/trie/mem/arena/concurrent.rs @@ -0,0 +1,261 @@ +use super::alloc::{allocation_class, allocation_size, CHUNK_SIZE}; +use super::{Arena, ArenaMemory, ArenaPos, ArenaSliceMut, STArena}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +/// Arena that can be allocated on from multiple threads, but still allowing conversion to a +/// single-threaded `STArena` afterwards. +/// +/// The `ConcurrentArena` cannot be directly used; rather, for each thread wishing to use it, +/// `for_thread` must be called to get a `ConcurrentArenaForThread`, which then acts like a +/// normal arena, except that deallocation is not supported. +/// +/// The only synchronization they need is the chunk counter. The purpose is so that after +/// multiple threads allocate on their own arenas, the resulting memory can still be combined +/// into a single arena while allowing the pointers (ArenaPos) to still be valid. This is what +/// allows a memtrie to be loaded in parallel. +pub struct ConcurrentArena { + /// Chunks allocated by each `ConcurrentArenaForThread` share the same "logical" memory + /// space. This counter is used to ensure that each `ConcurrentArenaForThread` gets unique + /// chunk positions, so that the allocations made by different threads do not conflict in + /// their positions. + /// + /// The goal here is so that allocations coming from multiple threads can be merged into a + /// single arena at the end, without having to alter any arena pointers (ArenaPos). + next_chunk_pos: Arc, +} + +impl ConcurrentArena { + pub fn new() -> Self { + Self { next_chunk_pos: Arc::new(AtomicUsize::new(0)) } + } + + /// Returns an arena that can be used for one thread. + pub fn for_thread(&self) -> ConcurrentArenaForThread { + ConcurrentArenaForThread::new(self.next_chunk_pos.clone()) + } + + /// Converts the arena to a single-threaded arena. All returned values of `for_thread` must be + /// passed in. + /// + /// There is a caveat that may be fixed in the future if desired: the returned arena will have + /// some memory wasted. This is because the last chunk of each thread may not be full, and + /// the single-threaded arena is unable to make use of multiple partially filled chunks. The + /// maximum memory wasted is 4MB * number of threads; the average is 2MB * number of threads. + /// The wasted memory will be reclaimed when the memtrie shard is unloaded. 
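+    /// (Each chunk is `CHUNK_SIZE` = 4 MiB, and on average half of each thread's last chunk goes unused.)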
+    pub fn to_single_threaded(
+        self,
+        name: String,
+        threads: Vec<ConcurrentArenaForThread>,
+    ) -> STArena {
+        let mut chunks = vec![Vec::new(); self.next_chunk_pos.load(Ordering::Relaxed)];
+        let mut active_allocs_bytes = 0;
+        let mut active_allocs_count = 0;
+        for thread in threads {
+            let memory = thread.memory;
+            for (pos, chunk) in memory.chunks.into_iter() {
+                assert!(
+                    chunks[pos].is_empty(),
+                    "Arena threads must all come from the same ConcurrentArena"
+                );
+                chunks[pos] = chunk;
+            }
+            active_allocs_bytes += thread.allocator.active_allocs_bytes;
+            active_allocs_count += thread.allocator.active_allocs_count;
+        }
+        for chunk in &chunks {
+            assert!(!chunk.is_empty(), "Not all arena threads are passed in");
+            assert_eq!(chunk.len(), CHUNK_SIZE); // may as well check this
+        }
+        STArena::new_from_existing_chunks(name, chunks, active_allocs_bytes, active_allocs_count)
+    }
+}
+
+/// Arena to be used for a single thread.
+pub struct ConcurrentArenaForThread {
+    memory: ConcurrentArenaMemory,
+    allocator: ConcurrentArenaAllocator,
+}
+
+pub struct ConcurrentArenaMemory {
+    /// Chunks of memory allocated for this thread. The usize is the global chunk position.
+    chunks: Vec<(usize, Vec<u8>)>,
+    /// Index is global chunk position, value is local chunk position.
+    /// For a chunk position that does not belong to the thread, the value is `usize::MAX`.
+    /// This vector is as large as needed to contain the largest global chunk position used
+    /// by this thread, but might not be as large as the total number of chunks allocated
+    /// globally.
+    chunk_pos_global_to_local: Vec<usize>,
+}
+
+impl ConcurrentArenaMemory {
+    pub fn new() -> Self {
+        Self { chunks: Vec::new(), chunk_pos_global_to_local: Vec::new() }
+    }
+
+    pub fn add_chunk(&mut self, pos: usize) {
+        while self.chunk_pos_global_to_local.len() <= pos {
+            self.chunk_pos_global_to_local.push(usize::MAX);
+        }
+        self.chunk_pos_global_to_local[pos] = self.chunks.len();
+        self.chunks.push((pos, vec![0; CHUNK_SIZE]));
+    }
+
+    pub fn chunk(&self, pos: usize) -> &[u8] {
+        let index = self.chunk_pos_global_to_local[pos];
+        &self.chunks[index].1
+    }
+
+    pub fn chunk_mut(&mut self, pos: usize) -> &mut [u8] {
+        let index = self.chunk_pos_global_to_local[pos];
+        &mut self.chunks[index].1
+    }
+}
+
+impl ArenaMemory for ConcurrentArenaMemory {
+    fn raw_slice(&self, pos: ArenaPos, len: usize) -> &[u8] {
+        &self.chunk(pos.chunk())[pos.pos()..pos.pos() + len]
+    }
+
+    fn raw_slice_mut(&mut self, pos: ArenaPos, len: usize) -> &mut [u8] {
+        &mut self.chunk_mut(pos.chunk())[pos.pos()..pos.pos() + len]
+    }
+}
+
+/// Allocator for a single thread. Unlike the allocator for `STArena`, this one only supports
+/// allocation and not deallocation, so it is substantially simpler.
+pub struct ConcurrentArenaAllocator {
+    next_chunk_pos: Arc<AtomicUsize>,
+    next_pos: ArenaPos,
+
+    // Stats that will be transferred to the single-threaded arena.
+    active_allocs_bytes: usize,
+    active_allocs_count: usize,
+}
+
+impl ConcurrentArenaAllocator {
+    fn new(next_chunk_pos: Arc<AtomicUsize>) -> Self {
+        Self {
+            next_chunk_pos,
+            next_pos: ArenaPos::invalid(),
+            active_allocs_bytes: 0,
+            active_allocs_count: 0,
+        }
+    }
+
+    pub fn allocate<'a>(
+        &mut self,
+        arena: &'a mut ConcurrentArenaMemory,
+        size: usize,
+    ) -> ArenaSliceMut<'a, ConcurrentArenaMemory> {
+        // We must allocate in the same kind of sizes as the single-threaded arena,
+        // so that after converting to `STArena`, these allocations can be properly
+        // reused.
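+        // For example, a 17-byte request occupies a 24-byte slot here, exactly as it would in `STArena`.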
+        let size_class = allocation_class(size);
+        let allocation_size = allocation_size(size_class);
+        if self.next_pos.is_invalid() || self.next_pos.pos() + allocation_size > CHUNK_SIZE {
+            let next_chunk_pos = self.next_chunk_pos.fetch_add(1, Ordering::Relaxed);
+            self.next_pos = ArenaPos { chunk: next_chunk_pos as u32, pos: 0 };
+            arena.add_chunk(next_chunk_pos);
+        }
+        let pos = self.next_pos;
+        self.next_pos = pos.offset_by(allocation_size);
+        self.active_allocs_bytes += allocation_size;
+        self.active_allocs_count += 1;
+        ArenaSliceMut::new(arena, pos, size)
+    }
+}
+
+impl ConcurrentArenaForThread {
+    fn new(next_chunk_pos: Arc<AtomicUsize>) -> Self {
+        Self {
+            memory: ConcurrentArenaMemory::new(),
+            allocator: ConcurrentArenaAllocator::new(next_chunk_pos),
+        }
+    }
+}
+
+impl Arena for ConcurrentArenaForThread {
+    type Memory = ConcurrentArenaMemory;
+
+    fn memory(&self) -> &Self::Memory {
+        &self.memory
+    }
+
+    fn memory_mut(&mut self) -> &mut Self::Memory {
+        &mut self.memory
+    }
+
+    fn alloc(&mut self, size: usize) -> ArenaSliceMut<Self::Memory> {
+        self.allocator.allocate(&mut self.memory, size)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::ConcurrentArena;
+    use crate::trie::mem::arena::alloc::CHUNK_SIZE;
+    use crate::trie::mem::arena::metrics::MEM_TRIE_ARENA_MEMORY_USAGE_BYTES;
+    use crate::trie::mem::arena::{Arena, ArenaMemory, ArenaWithDealloc};
+
+    #[test]
+    fn test_concurrent_arena() {
+        let arena = ConcurrentArena::new();
+        let mut thread1 = arena.for_thread();
+        let mut thread2 = arena.for_thread();
+        let mut thread3 = arena.for_thread();
+
+        let mut alloc1 = thread1.alloc(17);
+        let mut alloc2 = thread2.alloc(25);
+        let mut alloc3 = thread3.alloc(40);
+        alloc1.raw_slice_mut().copy_from_slice(&[1; 17]);
+        alloc2.raw_slice_mut().copy_from_slice(&[2; 25]);
+        alloc3.raw_slice_mut().copy_from_slice(&[3; 40]);
+        let ptr1 = alloc1.raw_pos();
+        let ptr2 = alloc2.raw_pos();
+        let ptr3 = alloc3.raw_pos();
+
+        let name = rand::random::<u64>().to_string();
+        let mut starena = arena.to_single_threaded(name.clone(), vec![thread1, thread2, thread3]);
+
+        assert_eq!(starena.num_active_allocs(), 3);
+        assert_eq!(starena.active_allocs_bytes(), 24 + 32 + 40);
+        assert_eq!(
+            MEM_TRIE_ARENA_MEMORY_USAGE_BYTES.get_metric_with_label_values(&[&name]).unwrap().get(),
+            3 * CHUNK_SIZE as i64
+        );
+
+        let mut alloc4 = starena.alloc(17);
+        alloc4.raw_slice_mut().copy_from_slice(&[4; 17]);
+        let ptr4 = alloc4.raw_pos();
+
+        assert_eq!(starena.memory().raw_slice(ptr1, 17), &[1; 17]);
+        assert_eq!(starena.memory().raw_slice(ptr2, 25), &[2; 25]);
+        assert_eq!(starena.memory().raw_slice(ptr3, 40), &[3; 40]);
+        assert_eq!(starena.memory().raw_slice(ptr4, 17), &[4; 17]);
+
+        // Allocations from the concurrent arena can be deallocated and reused in the converted STArena.
+        // Allocations of the same size class are reusable.
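+        // For example, 17 and 23 both round up to the 24-byte class, and 25 and 32 both round up to 32 bytes.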
+ starena.dealloc(ptr1, 17); + let mut alloc5 = starena.alloc(23); + assert_eq!(alloc5.raw_pos(), ptr1); + alloc5.raw_slice_mut().copy_from_slice(&[5; 23]); + starena.dealloc(ptr2, 25); + let mut alloc6 = starena.alloc(32); + assert_eq!(alloc6.raw_pos(), ptr2); + alloc6.raw_slice_mut().copy_from_slice(&[6; 32]); + starena.dealloc(ptr3, 40); + let mut alloc7 = starena.alloc(37); + assert_eq!(alloc7.raw_pos(), ptr3); + alloc7.raw_slice_mut().copy_from_slice(&[7; 37]); + starena.dealloc(ptr4, 17); + let mut alloc8 = starena.alloc(24); + assert_eq!(alloc8.raw_pos(), ptr4); + alloc8.raw_slice_mut().copy_from_slice(&[8; 24]); + + assert_eq!(starena.memory().raw_slice(ptr1, 23), &[5; 23]); + assert_eq!(starena.memory().raw_slice(ptr2, 32), &[6; 32]); + assert_eq!(starena.memory().raw_slice(ptr3, 37), &[7; 37]); + assert_eq!(starena.memory().raw_slice(ptr4, 24), &[8; 24]); + } +} diff --git a/core/store/src/trie/mem/arena/mod.rs b/core/store/src/trie/mem/arena/mod.rs index d7213c3c816..235a58a9e91 100644 --- a/core/store/src/trie/mem/arena/mod.rs +++ b/core/store/src/trie/mem/arena/mod.rs @@ -1,4 +1,5 @@ mod alloc; +pub mod concurrent; mod metrics; use self::alloc::Allocator; @@ -153,11 +154,34 @@ impl STArena { Self { memory: STArenaMemory::new(), allocator: Allocator::new(name) } } + pub(crate) fn new_from_existing_chunks( + name: String, + chunks: Vec>, + active_allocs_bytes: usize, + active_allocs_count: usize, + ) -> Self { + let arena = Self { + memory: STArenaMemory { chunks }, + allocator: Allocator::new_with_initial_stats( + name, + active_allocs_bytes, + active_allocs_count, + ), + }; + arena.allocator.update_memory_usage_gauge(&arena.memory); + arena + } + /// Number of active allocations (alloc calls minus dealloc calls). #[cfg(test)] pub fn num_active_allocs(&self) -> usize { self.allocator.num_active_allocs() } + + #[cfg(test)] + pub fn active_allocs_bytes(&self) -> usize { + self.allocator.active_allocs_bytes() + } } impl Arena for STArena { diff --git a/core/store/src/trie/mem/construction.rs b/core/store/src/trie/mem/construction.rs index 7e5af49c1b2..9975671783a 100644 --- a/core/store/src/trie/mem/construction.rs +++ b/core/store/src/trie/mem/construction.rs @@ -220,8 +220,7 @@ impl<'a, A: Arena> TrieConstructor<'a, A> { /// Adds a leaf to the trie. The key must be greater than all previous keys /// inserted. - pub fn add_leaf(&mut self, key: &[u8], value: FlatStateValue) { - let mut nibbles = NibbleSlice::new(key); + pub fn add_leaf(&mut self, mut nibbles: NibbleSlice, value: FlatStateValue) { let mut i = 0; // We'll go down the segments to find where our nibbles deviate. // If the deviation happens in the middle of a segment, we would split diff --git a/core/store/src/trie/mem/flexible_data/encoding.rs b/core/store/src/trie/mem/flexible_data/encoding.rs index 99164b6fd6d..a76f845e2c1 100644 --- a/core/store/src/trie/mem/flexible_data/encoding.rs +++ b/core/store/src/trie/mem/flexible_data/encoding.rs @@ -1,5 +1,5 @@ use super::FlexibleDataHeader; -use crate::trie::mem::arena::{Arena, ArenaMemory, ArenaPtr, ArenaPtrMut, ArenaSliceMut}; +use crate::trie::mem::arena::{Arena, ArenaMemory, ArenaPtr, ArenaSliceMut}; use borsh::{BorshDeserialize, BorshSerialize}; use std::io::Write; @@ -103,37 +103,3 @@ impl<'a, M: ArenaMemory> RawDecoder<'a, M> { view } } - -/// Provides ability to decode, but also to overwrite some data. 
-pub struct RawDecoderMut<'a, M: ArenaMemory> { - data: ArenaPtrMut<'a, M>, - pos: usize, -} - -impl<'a, M: ArenaMemory> RawDecoderMut<'a, M> { - pub fn new(data: ArenaPtrMut<'a, M>) -> Self { - RawDecoderMut { data, pos: 0 } - } - - /// Same with `RawDecoder::decode`. - pub fn decode(&mut self) -> T { - let slice = self.data.slice(self.pos, T::SERIALIZED_SIZE); - let result = T::try_from_slice(slice.raw_slice()).unwrap(); - self.pos += T::SERIALIZED_SIZE; - result - } - - /// Same with `RawDecoder::peek`. - pub fn peek(&mut self) -> T { - let slice = self.data.slice(self.pos, T::SERIALIZED_SIZE); - T::try_from_slice(slice.raw_slice()).unwrap() - } - - /// Overwrites the data at the current position with the given data, - /// and advances the position by the size of the data. - pub fn overwrite(&mut self, data: T) { - let mut slice = self.data.slice_mut(self.pos, T::SERIALIZED_SIZE); - data.serialize(&mut slice.raw_slice_mut()).unwrap(); - self.pos += T::SERIALIZED_SIZE; - } -} diff --git a/core/store/src/trie/mem/loading.rs b/core/store/src/trie/mem/loading.rs index 04fe952076c..74b3e263378 100644 --- a/core/store/src/trie/mem/loading.rs +++ b/core/store/src/trie/mem/loading.rs @@ -6,29 +6,55 @@ use crate::flat::store_helper::{ use crate::flat::{FlatStorageError, FlatStorageStatus}; use crate::trie::mem::arena::Arena; use crate::trie::mem::construction::TrieConstructor; +use crate::trie::mem::parallel_loader::load_memtrie_in_parallel; use crate::trie::mem::updating::apply_memtrie_changes; -use crate::{DBCol, Store}; +use crate::{DBCol, NibbleSlice, Store}; use near_primitives::errors::StorageError; use near_primitives::hash::CryptoHash; use near_primitives::shard_layout::{get_block_shard_uid, ShardUId}; use near_primitives::state::FlatStateValue; use near_primitives::types::chunk_extra::ChunkExtra; use near_primitives::types::{BlockHeight, StateRoot}; -use rayon::prelude::{IntoParallelIterator, ParallelIterator}; use std::collections::BTreeSet; use std::time::Instant; use tracing::{debug, info}; /// Loads a trie from the FlatState column. The returned `MemTries` contains /// exactly one trie root. +/// +/// `parallelize` can be used to speed up reading from db. However, it should +/// only be used when no other work is being done, such as during initial +/// startup. It also incurs a higher peak memory usage. 
 pub fn load_trie_from_flat_state(
     store: &Store,
     shard_uid: ShardUId,
     state_root: CryptoHash,
     block_height: BlockHeight,
+    parallelize: bool,
 ) -> Result<MemTries, StorageError> {
-    let mut tries = MemTries::new(shard_uid);
+    if parallelize && state_root != CryptoHash::default() {
+        const NUM_PARALLEL_SUBTREES_DESIRED: usize = 256;
+        let load_start = Instant::now();
+        let (arena, root_id) = load_memtrie_in_parallel(
+            store.clone(),
+            shard_uid,
+            state_root,
+            NUM_PARALLEL_SUBTREES_DESIRED,
+            shard_uid.to_string(),
+        )?;
+        info!(target: "memtrie", shard_uid=%shard_uid, "Done loading trie from flat state, took {:?}", load_start.elapsed());
+        let root = root_id.as_ptr(arena.memory());
+        assert_eq!(
+            root.view().node_hash(),
+            state_root,
+            "In-memory trie for shard {} has incorrect state root",
+            shard_uid
+        );
+        return Ok(MemTries::new_from_arena_and_root(shard_uid, block_height, arena, root_id));
+    }
+
+    let mut tries = MemTries::new(shard_uid);
     tries.construct_root(block_height, |arena| -> Result<Option<MemTrieNodeId>, StorageError> {
         info!(target: "memtrie", shard_uid=%shard_uid, "Loading trie from flat state...");
         let load_start = Instant::now();
@@ -44,7 +70,7 @@ pub fn load_trie_from_flat_state(
                     FlatStorageError::StorageInternalError(format!(
                         "invalid FlatState key format: {err}"
                     ))})?;
-                recon.add_leaf(&key, value);
+                recon.add_leaf(NibbleSlice::new(&key), value);
                 num_keys_loaded += 1;
                 if num_keys_loaded % 1000000 == 0 {
                     debug!(
@@ -67,15 +93,9 @@ pub fn load_trie_from_flat_state(
             debug!(
                 target: "memtrie",
                 %shard_uid,
-                "Loaded {} keys; computing hash and memory usage...",
+                "Loaded {} keys in total",
                 num_keys_loaded
             );
-            let mut subtrees = Vec::new();
-            root_id.as_ptr_mut(arena.memory_mut()).take_small_subtrees(1024 * 1024, &mut subtrees);
-            subtrees.into_par_iter().for_each(|mut subtree| {
-                subtree.compute_hash_recursively();
-            });
-            root_id.as_ptr_mut(arena.memory_mut()).compute_hash_recursively();
             info!(target: "memtrie", shard_uid=%shard_uid, "Done loading trie from flat state, took {:?}", load_start.elapsed());
 
             let root = root_id.as_ptr(arena.memory());
@@ -119,6 +139,7 @@ pub fn load_trie_from_flat_state_and_delta(
     store: &Store,
     shard_uid: ShardUId,
     state_root: Option<StateRoot>,
+    parallelize: bool,
 ) -> Result<MemTries, StorageError> {
     debug!(target: "memtrie", %shard_uid, "Loading base trie from flat state...");
     let flat_head = match get_flat_storage_status(&store, shard_uid)?
{ @@ -137,7 +158,8 @@ pub fn load_trie_from_flat_state_and_delta( }; let mut mem_tries = - load_trie_from_flat_state(&store, shard_uid, state_root, flat_head.height).unwrap(); + load_trie_from_flat_state(&store, shard_uid, state_root, flat_head.height, parallelize) + .unwrap(); debug!(target: "memtrie", %shard_uid, "Loading flat state deltas..."); // We load the deltas in order of height, so that we always have the previous state root @@ -199,7 +221,7 @@ mod tests { use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; - fn check(keys: Vec>) { + fn check_maybe_parallelize(keys: Vec>, parallelize: bool) { let shard_tries = TestTriesBuilder::new().with_flat_storage(true).build(); let shard_uid = ShardUId::single_shard(); let changes = keys.iter().map(|key| (key.to_vec(), Some(key.to_vec()))).collect::>(); @@ -214,9 +236,14 @@ mod tests { let state_root = test_populate_trie(&shard_tries, &Trie::EMPTY_ROOT, shard_uid, changes); eprintln!("Trie and flat storage populated"); - let in_memory_trie = - load_trie_from_flat_state(&shard_tries.get_store(), shard_uid, state_root, 123) - .unwrap(); + let in_memory_trie = load_trie_from_flat_state( + &shard_tries.get_store(), + shard_uid, + state_root, + 123, + parallelize, + ) + .unwrap(); eprintln!("In memory trie loaded"); if keys.is_empty() { @@ -247,7 +274,7 @@ mod tests { // Do another access with the trie to see how many nodes we're supposed to // have accessed. let temp_trie = shard_tries.get_trie_for_shard(shard_uid, state_root); - temp_trie.get_optimized_ref(key, crate::KeyLookupMode::Trie).unwrap(); + temp_trie.get_optimized_ref(key, KeyLookupMode::Trie).unwrap(); assert_eq!( temp_trie.get_trie_nodes_count().db_reads, nodes_accessed.len() as u64, @@ -281,6 +308,11 @@ mod tests { } } + fn check(keys: Vec>) { + check_maybe_parallelize(keys.clone(), false); + check_maybe_parallelize(keys, true); + } + fn nibbles(hex: &str) -> Vec { if hex == "_" { return vec![]; @@ -466,7 +498,7 @@ mod tests { // Load into memory. It should load the base flat state (block 0), plus all // four deltas. We'll check against the state roots at each block; they should // all exist in the loaded memtrie. - let mem_tries = load_trie_from_flat_state_and_delta(&store, shard_uid, None).unwrap(); + let mem_tries = load_trie_from_flat_state_and_delta(&store, shard_uid, None, true).unwrap(); assert_eq!( memtrie_lookup(mem_tries.get_root(&state_root_0).unwrap(), &test_key.to_vec(), None) diff --git a/core/store/src/trie/mem/mod.rs b/core/store/src/trie/mem/mod.rs index 16794de1284..73146e3a145 100644 --- a/core/store/src/trie/mem/mod.rs +++ b/core/store/src/trie/mem/mod.rs @@ -17,6 +17,7 @@ pub mod loading; pub mod lookup; pub mod metrics; pub mod node; +mod parallel_loader; pub mod updating; /// Check this, because in the code we conveniently assume usize is 8 bytes. @@ -56,6 +57,18 @@ impl MemTries { } } + pub fn new_from_arena_and_root( + shard_uid: ShardUId, + block_height: BlockHeight, + arena: STArena, + root: MemTrieNodeId, + ) -> Self { + let mut tries = + Self { arena, roots: HashMap::new(), heights: Default::default(), shard_uid }; + tries.insert_root(root.as_ptr(tries.arena.memory()).view().node_hash(), root, block_height); + tries + } + /// Inserts a new root into the trie. The given function should perform /// the entire construction of the new trie, possibly based on some existing /// trie nodes. This internally takes care of refcounting. 
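For orientation, `new_from_arena_and_root` is what the parallel loading path hands its output to. A condensed sketch of that flow, mirroring the parallelized branch of `load_trie_from_flat_state` in loading.rs (the wrapping function name here is hypothetical, and `store`, `shard_uid`, `state_root` and `block_height` are assumed to be in scope):

    fn load_memtrie_parallel_sketch(
        store: &Store,
        shard_uid: ShardUId,
        state_root: CryptoHash,
        block_height: BlockHeight,
    ) -> Result<MemTries, StorageError> {
        let (arena, root_id) = load_memtrie_in_parallel(
            store.clone(),
            shard_uid,
            state_root,
            256,                   // desired number of subtrees to load in parallel
            shard_uid.to_string(), // arena name, used as the metrics label
        )?;
        // Hashes are computed eagerly during construction, so the root can be checked right away.
        assert_eq!(root_id.as_ptr(arena.memory()).view().node_hash(), state_root);
        Ok(MemTries::new_from_arena_and_root(shard_uid, block_height, arena, root_id))
    }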
@@ -232,7 +245,6 @@ mod tests { extension: &NibbleSlice::new(&[]).encoded(true), }, ); - root.as_ptr_mut(arena.memory_mut()).compute_hash_recursively(); Ok(Some(root)) }) .unwrap(); diff --git a/core/store/src/trie/mem/node/encoding.rs b/core/store/src/trie/mem/node/encoding.rs index 7d69c9b29cc..33ed9ffb2f2 100644 --- a/core/store/src/trie/mem/node/encoding.rs +++ b/core/store/src/trie/mem/node/encoding.rs @@ -5,10 +5,9 @@ use crate::trie::mem::flexible_data::encoding::{BorshFixedSize, RawDecoder, RawE use crate::trie::mem::flexible_data::extension::EncodedExtensionHeader; use crate::trie::mem::flexible_data::value::EncodedValueHeader; use crate::trie::mem::flexible_data::FlexibleDataHeader; -use crate::trie::TRIE_COSTS; use borsh::{BorshDeserialize, BorshSerialize}; use near_primitives::hash::CryptoHash; -use near_primitives::state::FlatStateValue; +use std::mem::size_of; use smallvec::SmallVec; @@ -37,8 +36,8 @@ pub(crate) struct NonLeafHeader { } impl NonLeafHeader { - pub(crate) fn new(memory_usage: u64, node_hash: Option) -> Self { - Self { hash: node_hash.unwrap_or_default(), memory_usage } + pub(crate) fn new(memory_usage: u64, node_hash: CryptoHash) -> Self { + Self { hash: node_hash, memory_usage } } } @@ -128,43 +127,14 @@ impl MemTrieNodeId { } _ => {} } - // Let's also compute the memory usage of the node. We only do this for - // non-leaf nodes, because for leaf node it is very easy to just - // compute it on demand, so there's no need to store it. - let memory_usage = match &node { - InputMemTrieNode::Leaf { .. } => 0, - InputMemTrieNode::Extension { extension, child } => { - TRIE_COSTS.node_cost - + extension.len() as u64 * TRIE_COSTS.byte_of_key - + child.as_ptr(arena.memory()).view().memory_usage() - } - InputMemTrieNode::Branch { children } => { - let mut memory_usage = TRIE_COSTS.node_cost; - for child in children.iter() { - if let Some(child) = child { - memory_usage += child.as_ptr(arena.memory()).view().memory_usage(); - } - } - memory_usage - } - InputMemTrieNode::BranchWithValue { children, value } => { - let value_len = match value { - FlatStateValue::Ref(value_ref) => value_ref.len(), - FlatStateValue::Inlined(value) => value.len(), - }; - let mut memory_usage = TRIE_COSTS.node_cost - + value_len as u64 * TRIE_COSTS.byte_of_value - + TRIE_COSTS.node_cost; - for child in children.iter() { - if let Some(child) = child { - memory_usage += child.as_ptr(arena.memory()).view().memory_usage(); - } - } - memory_usage - } + // Prepare the raw node, for memory usage and hash computation. + let raw_node_with_size = if matches!(&node, InputMemTrieNode::Leaf { .. }) { + None + } else { + Some(node.to_raw_trie_node_with_size_non_leaf(arena.memory())) }; - // Finally, encode the data. We're still leaving the hash empty; that - // will be computed later in parallel. + + // Finally, encode the data. 
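+        // The hash is now filled in eagerly (either passed in via `node_hash` or computed from the
+        // raw node below), so there is no separate hashing pass afterwards.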
let data = match node { InputMemTrieNode::Leaf { value, extension } => { let extension_header = EncodedExtensionHeader::from_input(extension); @@ -190,9 +160,13 @@ impl MemTrieNodeId { arena, ExtensionHeader::SERIALIZED_SIZE + extension_header.flexible_data_length(), ); + let raw_node_with_size = raw_node_with_size.unwrap(); data.encode(ExtensionHeader { common: CommonHeader { refcount: 0, kind: NodeKind::Extension }, - nonleaf: NonLeafHeader::new(memory_usage, node_hash), + nonleaf: NonLeafHeader::new( + raw_node_with_size.memory_usage, + node_hash.unwrap_or_else(|| raw_node_with_size.hash()), + ), child: child.pos, extension: extension_header, }); @@ -205,9 +179,13 @@ impl MemTrieNodeId { arena, BranchHeader::SERIALIZED_SIZE + children_header.flexible_data_length(), ); + let raw_node_with_size = raw_node_with_size.unwrap(); data.encode(BranchHeader { common: CommonHeader { refcount: 0, kind: NodeKind::Branch }, - nonleaf: NonLeafHeader::new(memory_usage, node_hash), + nonleaf: NonLeafHeader::new( + raw_node_with_size.memory_usage, + node_hash.unwrap_or_else(|| raw_node_with_size.hash()), + ), children: children_header, }); data.encode_flexible(&children_header, &children); @@ -222,9 +200,13 @@ impl MemTrieNodeId { + children_header.flexible_data_length() + value_header.flexible_data_length(), ); + let raw_node_with_size = raw_node_with_size.unwrap(); data.encode(BranchWithValueHeader { common: CommonHeader { refcount: 0, kind: NodeKind::BranchWithValue }, - nonleaf: NonLeafHeader::new(memory_usage, node_hash), + nonleaf: NonLeafHeader::new( + raw_node_with_size.memory_usage, + node_hash.unwrap_or_else(|| raw_node_with_size.hash()), + ), children: children_header, value: value_header, }); @@ -238,24 +220,22 @@ impl MemTrieNodeId { /// Increments the refcount, returning the new refcount. pub(crate) fn add_ref(&self, memory: &mut impl ArenaMemory) -> u32 { - let mut ptr = self.as_ptr_mut(memory); - let mut decoder = ptr.decoder_mut(); - let mut header = decoder.peek::(); - let new_refcount = header.refcount + 1; - header.refcount = new_refcount; - decoder.overwrite(header); + // Refcount is always encoded as the first four bytes of the node memory. + let refcount_memory = memory.raw_slice_mut(self.pos, size_of::()); + let refcount = u32::from_le_bytes(refcount_memory.try_into().unwrap()); + let new_refcount = refcount.checked_add(1).unwrap(); + refcount_memory.copy_from_slice(new_refcount.to_le_bytes().as_ref()); new_refcount } /// Decrements the refcount, deallocating the node if it reaches zero. /// Returns the new refcount. pub(crate) fn remove_ref(&self, arena: &mut impl ArenaWithDealloc) -> u32 { - let mut ptr = self.as_ptr_mut(arena.memory_mut()); - let mut decoder = ptr.decoder_mut(); - let mut header = decoder.peek::(); - let new_refcount = header.refcount - 1; - header.refcount = new_refcount; - decoder.overwrite(header); + // Refcount is always encoded as the first four bytes of the node memory. 
+ let refcount_memory = arena.memory_mut().raw_slice_mut(self.pos, size_of::()); + let refcount = u32::from_le_bytes(refcount_memory.try_into().unwrap()); + let new_refcount = refcount.checked_sub(1).unwrap(); + refcount_memory.copy_from_slice(new_refcount.to_le_bytes().as_ref()); if new_refcount == 0 { let mut children_to_unref: SmallVec<[ArenaPos; 16]> = SmallVec::new(); let node_ptr = self.as_ptr(arena.memory()); diff --git a/core/store/src/trie/mem/node/mod.rs b/core/store/src/trie/mem/node/mod.rs index d23e2bdf98e..1690b67dacc 100644 --- a/core/store/src/trie/mem/node/mod.rs +++ b/core/store/src/trie/mem/node/mod.rs @@ -1,13 +1,14 @@ -use super::arena::{Arena, ArenaMemory, ArenaPos, ArenaPtr, ArenaPtrMut}; +use super::arena::{Arena, ArenaMemory, ArenaPos, ArenaPtr}; use super::flexible_data::children::ChildrenView; use super::flexible_data::value::ValueView; +use crate::trie::{Children, TRIE_COSTS}; +use crate::{RawTrieNode, RawTrieNodeWithSize}; use derive_where::derive_where; use near_primitives::hash::CryptoHash; use near_primitives::state::FlatStateValue; use std::fmt::{Debug, Formatter}; mod encoding; -mod mutation; #[cfg(test)] mod tests; mod view; @@ -42,13 +43,6 @@ impl MemTrieNodeId { pub fn as_ptr<'a, M: ArenaMemory>(&self, arena: &'a M) -> MemTrieNodePtr<'a, M> { MemTrieNodePtr { ptr: arena.ptr(self.pos) } } - - pub(crate) fn as_ptr_mut<'a, M: ArenaMemory>( - &self, - arena: &'a mut M, - ) -> MemTrieNodePtrMut<'a, M> { - MemTrieNodePtrMut { ptr: arena.ptr_mut(self.pos) } - } } /// This is for internal use only, so that we can put `MemTrieNodeId` in an @@ -66,13 +60,6 @@ pub struct MemTrieNodePtr<'a, M: ArenaMemory> { ptr: ArenaPtr<'a, M>, } -/// Pointer to an in-memory trie node that allows mutable access to the node -/// and all its descendants. This is only for computing hashes, and internal -/// reference counting. -pub struct MemTrieNodePtrMut<'a, M: ArenaMemory> { - ptr: ArenaPtrMut<'a, M>, -} - impl<'a, M: ArenaMemory> Debug for MemTrieNodePtr<'a, M> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { self.id().fmt(f) @@ -128,3 +115,60 @@ pub enum MemTrieNodeView<'a, M: ArenaMemory> { value: ValueView<'a>, }, } + +impl<'a> InputMemTrieNode<'a> { + /// Converts the input node into a `RawTrieNodeWithSize`; this is used to initialize + /// memory usage and to calculate hash when constructing the memtrie node. + /// + /// This must not be called if the node is a leaf. + pub fn to_raw_trie_node_with_size_non_leaf( + &self, + arena: &Memory, + ) -> RawTrieNodeWithSize { + match self { + Self::Leaf { .. } => { + unreachable!("Leaf nodes do not need hash computation") + } + Self::Extension { extension, child, .. } => { + let view = child.as_ptr(arena).view(); + let memory_usage = TRIE_COSTS.node_cost + + extension.len() as u64 * TRIE_COSTS.byte_of_key + + view.memory_usage(); + let node = RawTrieNode::Extension(extension.to_vec(), view.node_hash()); + RawTrieNodeWithSize { node, memory_usage } + } + Self::Branch { children, .. } => { + let mut memory_usage = TRIE_COSTS.node_cost; + let mut hashes = [None; 16]; + for (i, child) in children.iter().enumerate() { + if let Some(child) = child { + let view = child.as_ptr(arena).view(); + hashes[i] = Some(view.node_hash()); + memory_usage += view.memory_usage(); + } + } + let node = RawTrieNode::BranchNoValue(Children(hashes)); + RawTrieNodeWithSize { node, memory_usage } + } + Self::BranchWithValue { children, value, .. 
} => { + let value_len = match value { + FlatStateValue::Ref(value_ref) => value_ref.len(), + FlatStateValue::Inlined(value) => value.len(), + }; + let mut memory_usage = TRIE_COSTS.node_cost + + value_len as u64 * TRIE_COSTS.byte_of_value + + TRIE_COSTS.node_cost; + let mut hashes = [None; 16]; + for (i, child) in children.iter().enumerate() { + if let Some(child) = child { + let view = child.as_ptr(arena).view(); + hashes[i] = Some(view.node_hash()); + memory_usage += view.memory_usage(); + } + } + let node = RawTrieNode::BranchWithValue(value.to_value_ref(), Children(hashes)); + RawTrieNodeWithSize { node, memory_usage } + } + } + } +} diff --git a/core/store/src/trie/mem/node/mutation.rs b/core/store/src/trie/mem/node/mutation.rs deleted file mode 100644 index f0c9543aa49..00000000000 --- a/core/store/src/trie/mem/node/mutation.rs +++ /dev/null @@ -1,103 +0,0 @@ -use super::encoding::{CommonHeader, NodeKind, NonLeafHeader}; -use super::{MemTrieNodePtr, MemTrieNodePtrMut}; -use crate::trie::mem::arena::ArenaMemory; -use crate::trie::mem::flexible_data::encoding::RawDecoderMut; -use near_primitives::hash::{hash, CryptoHash}; - -impl<'a, M: ArenaMemory> MemTrieNodePtrMut<'a, M> { - fn as_const<'b>(&'b self) -> MemTrieNodePtr<'b, M> { - MemTrieNodePtr { ptr: self.ptr.ptr() } - } - - pub(crate) fn decoder_mut(&mut self) -> RawDecoderMut { - RawDecoderMut::new(self.ptr.ptr_mut()) - } - - /// Obtains a list of mutable references to the children of this node, - /// destroying this mutable reference. - /// - /// Despite being implemented with unsafe code, this is a safe operation - /// because the children subtrees are disjoint (even if there are multiple - /// roots). It is very similar to `split_at_mut` on mutable slices. - fn split_children_mut(mut self) -> Vec> { - let arena_mut = self.ptr.arena_mut() as *mut M; - let mut result = Vec::new(); - let view = self.as_const().view(); - for child in view.iter_children() { - let child_id = child.id(); - let arena_mut_ref = unsafe { &mut *arena_mut }; - result.push(child_id.as_ptr_mut(arena_mut_ref)); - } - result - } - - /// Like `split_children_mut`, but does not destroy the reference itself. - /// This is possible because of the returned references can only be used - /// while this reference is being mutably held, but it does result in a - /// different lifetime. - fn children_mut<'b>(&'b mut self) -> Vec> { - let arena_mut = self.ptr.arena_mut() as *mut M; - let mut result = Vec::new(); - let view = self.as_const().view(); - for child in view.iter_children() { - let child_id = child.id(); - let arena_mut_ref = unsafe { &mut *arena_mut }; - result.push(child_id.as_ptr_mut(arena_mut_ref)); - } - result - } - - /// Computes the hash for this node, assuming children nodes already have - /// computed hashes. - fn compute_hash(&mut self) { - let raw_trie_node_with_size = self.as_const().view().to_raw_trie_node_with_size(); - let mut decoder = self.decoder_mut(); - match decoder.decode::().kind { - NodeKind::Leaf => {} - _ => { - let mut nonleaf = decoder.peek::(); - nonleaf.hash = hash(&borsh::to_vec(&raw_trie_node_with_size).unwrap()); - decoder.overwrite(nonleaf); - } - } - } - - /// Whether the hash is computed for this node. 
- fn is_hash_computed(&self) -> bool { - let mut decoder = self.as_const().decoder(); - match decoder.decode::().kind { - NodeKind::Leaf => true, - _ => decoder.peek::().hash != CryptoHash::default(), - } - } - - /// Computes the hashes of this subtree recursively, stopping at any nodes - /// whose hashes are already computed. - pub(crate) fn compute_hash_recursively(&mut self) { - if self.is_hash_computed() { - return; - } - for mut child in self.children_mut() { - child.compute_hash_recursively(); - } - self.compute_hash(); - } - - /// Recursively expand the current subtree until we arrive at subtrees - /// that are small enough (by memory usage); we store these subtrees in - /// the provided vector. The returned subtrees cover all leaves but are - /// disjoint. - pub(crate) fn take_small_subtrees( - self, - threshold_memory_usage: u64, - trees: &mut Vec>, - ) { - if self.as_const().view().memory_usage() < threshold_memory_usage { - trees.push(self); - } else { - for child in self.split_children_mut() { - child.take_small_subtrees(threshold_memory_usage, trees); - } - } - } -} diff --git a/core/store/src/trie/mem/node/tests.rs b/core/store/src/trie/mem/node/tests.rs index 200e21b7cfb..5dd61d3faed 100644 --- a/core/store/src/trie/mem/node/tests.rs +++ b/core/store/src/trie/mem/node/tests.rs @@ -110,7 +110,6 @@ fn test_basic_extension_node() { &mut arena, InputMemTrieNode::Extension { extension: &[5, 6, 7, 8, 9], child }, ); - node.as_ptr_mut(arena.memory_mut()).compute_hash_recursively(); let child_ptr = child.as_ptr(arena.memory()); let node_ptr = node.as_ptr(arena.memory()); assert_eq!( @@ -159,7 +158,6 @@ fn test_basic_branch_node() { &mut arena, InputMemTrieNode::Branch { children: branch_array(vec![(3, child1), (5, child2)]) }, ); - node.as_ptr_mut(arena.memory_mut()).compute_hash_recursively(); let child1_ptr = child1.as_ptr(arena.memory()); let child2_ptr = child2.as_ptr(arena.memory()); let node_ptr = node.as_ptr(arena.memory()); @@ -227,8 +225,6 @@ fn test_basic_branch_with_value_node() { }, ); - node.as_ptr_mut(arena.memory_mut()).compute_hash_recursively(); - let child1_ptr = child1.as_ptr(arena.memory()); let child2_ptr = child2.as_ptr(arena.memory()); let node_ptr = node.as_ptr(arena.memory()); diff --git a/core/store/src/trie/mem/parallel_loader.rs b/core/store/src/trie/mem/parallel_loader.rs new file mode 100644 index 00000000000..8917a35d2ad --- /dev/null +++ b/core/store/src/trie/mem/parallel_loader.rs @@ -0,0 +1,463 @@ +use super::arena::concurrent::{ConcurrentArena, ConcurrentArenaForThread}; +use super::arena::{Arena, STArena}; +use super::construction::TrieConstructor; +use super::node::{InputMemTrieNode, MemTrieNodeId}; +use crate::flat::FlatStorageError; +use crate::trie::Children; +use crate::{DBCol, NibbleSlice, RawTrieNode, RawTrieNodeWithSize, Store}; +use borsh::BorshDeserialize; +use near_primitives::errors::{MissingTrieValueContext, StorageError}; +use near_primitives::hash::CryptoHash; +use near_primitives::shard_layout::ShardUId; +use near_primitives::state::FlatStateValue; +use near_primitives::types::StateRoot; +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; +use std::fmt::Debug; +use std::sync::Mutex; + +/// Top-level entry function to load a memtrie in parallel. 
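+/// Returns the resulting arena together with the id of the root node; the caller is expected to
+/// verify the root hash and wrap the result in `MemTries` (see `load_trie_from_flat_state`).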
+pub fn load_memtrie_in_parallel(
+    store: Store,
+    shard_uid: ShardUId,
+    root: StateRoot,
+    num_subtrees_desired: usize,
+    name: String,
+) -> Result<(STArena, MemTrieNodeId), StorageError> {
+    let reader = ParallelMemTrieLoader::new(store, shard_uid, root, num_subtrees_desired);
+    let plan = reader.make_loading_plan()?;
+    tracing::info!("Loading {} subtrees in parallel", plan.subtrees_to_load.len());
+    reader.load_in_parallel(plan, name)
+}
+
+/// Logic to load a memtrie in parallel. It consists of three stages:
+/// - First, we use the State column to visit the trie starting from the root. We recursively
+///   expand the trie until all the unexpanded subtrees are small enough
+///   (memory_usage <= `max_subtree_size`). The trie we have expanded is represented as a "plan",
+///   which is a structure similar to the trie itself.
+/// - Then, we load each small subtree (the keys under which all share a common prefix) in
+///   parallel, by reading the FlatState column for keys that correspond to the prefix of that
+///   subtree. The result of each construction is a `MemTrieNodeId` representing the root of that
+///   subtree.
+/// - Finally, we construct the final trie by using the loaded subtree roots and converting the
+///   plan into a complete memtrie, returning the final root.
+///
+/// This loader is only suitable for loading a single trie. It does not load multiple state roots,
+/// or multiple shards.
+pub struct ParallelMemTrieLoader {
+    store: Store,
+    shard_uid: ShardUId,
+    root: StateRoot,
+    num_subtrees_desired: usize,
+}
+
+impl ParallelMemTrieLoader {
+    pub fn new(
+        store: Store,
+        shard_uid: ShardUId,
+        root: StateRoot,
+        num_subtrees_desired: usize,
+    ) -> Self {
+        Self { store, shard_uid, root, num_subtrees_desired }
+    }
+
+    /// Implements stage 1: recursively expanding the trie until all subtrees are small enough.
+    fn make_loading_plan(&self) -> Result<PartialTrieLoadingPlan, StorageError> {
+        let subtrees_to_load = Mutex::new(Vec::new());
+        let root = self.make_loading_plan_recursive(
+            self.root,
+            NibblePrefix::new(),
+            &subtrees_to_load,
+            None,
+        )?;
+        Ok(PartialTrieLoadingPlan {
+            root,
+            subtrees_to_load: subtrees_to_load.into_inner().unwrap(),
+        })
+    }
+
+    /// Helper function to implement stage 1, visiting a single node identified by this hash,
+    /// whose prefix is the given prefix. While expanding this node, any small subtrees
+    /// encountered are appended to the `subtrees_to_load` array.
+    fn make_loading_plan_recursive(
+        &self,
+        hash: CryptoHash,
+        mut prefix: NibblePrefix,
+        subtrees_to_load: &Mutex<Vec<NibblePrefix>>,
+        max_subtree_size: Option<u64>,
+    ) -> Result<TrieLoadingPlanNode, StorageError> {
+        // Read the node from the State column.
+        let mut key = [0u8; 40];
+        key[0..8].copy_from_slice(&self.shard_uid.to_bytes());
+        key[8..40].copy_from_slice(&hash.0);
+        let node = RawTrieNodeWithSize::try_from_slice(
+            &self
+                .store
+                .get(DBCol::State, &key)
+                .map_err(|e| StorageError::StorageInconsistentState(e.to_string()))?
+                .ok_or(StorageError::MissingTrieValue(MissingTrieValueContext::TrieStorage, hash))?
+                .as_slice(),
+        )
+        .map_err(|e| StorageError::StorageInconsistentState(e.to_string()))?;
+
+        let max_subtree_size = max_subtree_size
+            .unwrap_or_else(|| node.memory_usage / self.num_subtrees_desired as u64);
+
+        // If subtree is small enough, add it to the list of subtrees to load, and we're done.
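+        // (`max_subtree_size` is fixed at the root as memory_usage / num_subtrees_desired and then
+        // passed down unchanged, so every subtree is compared against the same threshold.)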
+        if node.memory_usage <= max_subtree_size {
+            let mut lock = subtrees_to_load.lock().unwrap();
+            let subtree_id = lock.len();
+            lock.push(prefix);
+            return Ok(TrieLoadingPlanNode::Load { subtree_id });
+        }
+
+        match node.node {
+            RawTrieNode::Leaf(extension, value_ref) => {
+                // If we happen to visit a leaf, we'll have to just read the leaf's value. This is
+                // almost like a corner case because we're not really interested in values here
+                // (that's the job of the parallel loading part), but if we do get here, we have to
+                // deal with it.
+                key[8..40].copy_from_slice(&value_ref.hash.0);
+                let value = self
+                    .store
+                    .get(DBCol::State, &key)
+                    .map_err(|e| StorageError::StorageInconsistentState(e.to_string()))?
+                    .ok_or(StorageError::MissingTrieValue(
+                        MissingTrieValueContext::TrieStorage,
+                        hash,
+                    ))?;
+                let flat_value = FlatStateValue::on_disk(&value);
+                Ok(TrieLoadingPlanNode::Leaf {
+                    extension: extension.into_boxed_slice(),
+                    value: flat_value,
+                })
+            }
+            RawTrieNode::BranchNoValue(children_hashes) => {
+                // If we visit a branch, recursively visit all children.
+                let children = self.make_children_plans_in_parallel(
+                    children_hashes,
+                    &prefix,
+                    subtrees_to_load,
+                    max_subtree_size,
+                )?;
+
+                Ok(TrieLoadingPlanNode::Branch { children, value: None })
+            }
+            RawTrieNode::BranchWithValue(value_ref, children_hashes) => {
+                // Similar here, except we have to also look up the value.
+                key[8..40].copy_from_slice(&value_ref.hash.0);
+                let value = self
+                    .store
+                    .get(DBCol::State, &key)
+                    .map_err(|e| StorageError::StorageInconsistentState(e.to_string()))?
+                    .ok_or(StorageError::MissingTrieValue(
+                        MissingTrieValueContext::TrieStorage,
+                        hash,
+                    ))?;
+                let flat_value = FlatStateValue::on_disk(&value);
+
+                let children = self.make_children_plans_in_parallel(
+                    children_hashes,
+                    &prefix,
+                    subtrees_to_load,
+                    max_subtree_size,
+                )?;
+
+                Ok(TrieLoadingPlanNode::Branch { children, value: Some(flat_value) })
+            }
+            RawTrieNode::Extension(extension, child) => {
+                let nibbles = NibbleSlice::from_encoded(&extension).0;
+                prefix.append(&nibbles);
+                let child = self.make_loading_plan_recursive(
+                    child,
+                    prefix,
+                    subtrees_to_load,
+                    Some(max_subtree_size),
+                )?;
+                Ok(TrieLoadingPlanNode::Extension {
+                    extension: extension.into_boxed_slice(),
+                    child: Box::new(child),
+                })
+            }
+        }
+    }
+
+    fn make_children_plans_in_parallel(
+        &self,
+        children_hashes: Children,
+        prefix: &NibblePrefix,
+        subtrees_to_load: &Mutex<Vec<NibblePrefix>>,
+        max_subtree_size: u64,
+    ) -> Result<Vec<(u8, Box<TrieLoadingPlanNode>)>, StorageError> {
+        let existing_children = children_hashes.iter().collect::<Vec<_>>();
+        let children = existing_children
+            .into_par_iter()
+            .map(|(i, child_hash)| -> Result<_, StorageError> {
+                let mut prefix = prefix.clone();
+                prefix.push(i as u8);
+                let node = self.make_loading_plan_recursive(
+                    *child_hash,
+                    prefix,
+                    subtrees_to_load,
+                    Some(max_subtree_size),
+                )?;
+                Ok((i, Box::new(node)))
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+        Ok(children)
+    }
+
+    /// This implements the loading of each subtree in stage 2.
+    fn load_one_subtree(
+        &self,
+        subtree_to_load: &NibblePrefix,
+        arena: &mut impl Arena,
+    ) -> Result<MemTrieNodeId, StorageError> {
+        // Figure out which range corresponds to the prefix of this subtree.
+        let (start, end) = subtree_to_load.to_iter_range(self.shard_uid);
+
+        // Load all the keys in this range from the FlatState column.
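+        // FlatState keys are prefixed by the 8-byte shard UID, so below we strip that prefix and skip
+        // the nibbles already covered by this subtree's prefix before feeding the rest to the constructor.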
+        let mut recon = TrieConstructor::new(arena);
+        for item in self.store.iter_range(DBCol::FlatState, Some(&start), Some(&end)) {
+            let (key, value) = item.map_err(|err| {
+                FlatStorageError::StorageInternalError(format!(
+                    "Error iterating over FlatState: {err}"
+                ))
+            })?;
+            let key = NibbleSlice::new(&key[8..]).mid(subtree_to_load.num_nibbles());
+            let value = FlatStateValue::try_from_slice(&value).map_err(|err| {
+                FlatStorageError::StorageInternalError(format!(
+                    "invalid FlatState value format: {err}"
+                ))
+            })?;
+            recon.add_leaf(key, value);
+        }
+        Ok(recon.finalize().unwrap())
+    }
+
+    /// This implements stages 2 and 3, loading the subtrees in parallel and then constructing the
+    /// final trie.
+    fn load_in_parallel(
+        &self,
+        plan: PartialTrieLoadingPlan,
+        name: String,
+    ) -> Result<(STArena, MemTrieNodeId), StorageError> {
+        let arena = ConcurrentArena::new();
+
+        // A bit of an awkward Rayon dance. We run a multi-threaded fold; the fold state contains
+        // both a sparse vector of the loading results as well as the arena used for the thread.
+        // We need to collect both in the end, so fold is the only suitable method.
+        let (roots, threads): (
+            Vec<Vec<Result<(usize, MemTrieNodeId), StorageError>>>,
+            Vec<ConcurrentArenaForThread>,
+        ) = plan
+            .subtrees_to_load
+            .into_par_iter()
+            .enumerate()
+            .fold(
+                || -> (Vec<Result<(usize, MemTrieNodeId), StorageError>>, ConcurrentArenaForThread) {
+                    (Vec::new(), arena.for_thread())
+                },
+                |(mut roots, mut arena), (i, prefix)| {
+                    roots.push(self.load_one_subtree(&prefix, &mut arena).map(|root| (i, root)));
+                    (roots, arena)
+                },
+            )
+            .unzip();
+
+        let mut roots = roots.into_iter().flatten().collect::<Result<Vec<_>, _>>()?;
+        roots.sort_by_key(|(i, _)| *i);
+        let roots = roots.into_iter().map(|(_, root)| root).collect::<Vec<_>>();
+
+        let mut arena = arena.to_single_threaded(name, threads);
+        let root = plan.root.to_node(&mut arena, &roots);
+        Ok((arena, root))
+    }
+}
+
+/// Specifies exactly what to do to create a node in the final trie.
+#[derive(Debug)]
+enum TrieLoadingPlanNode {
+    // The first three cases correspond exactly to the trie structure.
+    Branch { children: Vec<(u8, Box<TrieLoadingPlanNode>)>, value: Option<FlatStateValue> },
+    Extension { extension: Box<[u8]>, child: Box<TrieLoadingPlanNode> },
+    Leaf { extension: Box<[u8]>, value: FlatStateValue },
+    // This means this trie node is whatever loading this subtree yields.
+    Load { subtree_id: usize },
+}
+
+impl TrieLoadingPlanNode {
+    /// This implements the construction part of stage 3, where we convert a plan node to
+    /// a memtrie node. The `subtree_roots` are the parallel loading results.
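+    /// They are indexed by `subtree_id`, in the order the subtrees were registered in the plan.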
+    fn to_node(self, arena: &mut impl Arena, subtree_roots: &[MemTrieNodeId]) -> MemTrieNodeId {
+        match self {
+            TrieLoadingPlanNode::Branch { children, value } => {
+                let mut res_children = [None; 16];
+                for (nibble, child) in children {
+                    res_children[nibble as usize] = Some(child.to_node(arena, subtree_roots));
+                }
+                let input = match &value {
+                    Some(value) => {
+                        InputMemTrieNode::BranchWithValue { children: res_children, value }
+                    }
+                    None => InputMemTrieNode::Branch { children: res_children },
+                };
+                MemTrieNodeId::new(arena, input)
+            }
+            TrieLoadingPlanNode::Extension { extension, child } => {
+                let child = child.to_node(arena, subtree_roots);
+                let input = InputMemTrieNode::Extension { extension: &extension, child };
+                MemTrieNodeId::new(arena, input)
+            }
+            TrieLoadingPlanNode::Leaf { extension, value } => {
+                let input = InputMemTrieNode::Leaf { extension: &extension, value: &value };
+                MemTrieNodeId::new(arena, input)
+            }
+            TrieLoadingPlanNode::Load { subtree_id } => subtree_roots[subtree_id],
+        }
+    }
+}
+
+#[derive(Debug)]
+struct PartialTrieLoadingPlan {
+    root: TrieLoadingPlanNode,
+    subtrees_to_load: Vec<NibblePrefix>,
+}
+
+/// Represents a prefix of nibbles. Allows appending to the prefix, and implements logic of
+/// calculating a range of keys that correspond to this prefix.
+///
+/// A nibble just means a 4 bit number.
+#[derive(Clone)]
+struct NibblePrefix {
+    /// Big endian encoding of the nibbles. If there are an odd number of nibbles, this is
+    /// the encoding of the nibbles as if there were one more nibble at the end being zero.
+    prefix: Vec<u8>,
+    /// Whether the last byte of `prefix` represents one nibble rather than two.
+    odd: bool,
+}
+
+impl Debug for NibblePrefix {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        if self.odd {
+            write!(
+                f,
+                "{}{:x}",
+                hex::encode(&self.prefix[..self.prefix.len() - 1]),
+                self.prefix.last().unwrap() >> 4
+            )
+        } else {
+            write!(f, "{}", hex::encode(&self.prefix))
+        }
+    }
+}
+
+impl NibblePrefix {
+    pub fn new() -> Self {
+        Self { prefix: Vec::new(), odd: false }
+    }
+
+    pub fn num_nibbles(&self) -> usize {
+        self.prefix.len() * 2 - if self.odd { 1 } else { 0 }
+    }
+
+    pub fn push(&mut self, nibble: u8) {
+        debug_assert!(nibble < 16, "nibble must be less than 16");
+        if self.odd {
+            *self.prefix.last_mut().unwrap() |= nibble;
+        } else {
+            self.prefix.push(nibble << 4);
+        }
+        self.odd = !self.odd;
+    }
+
+    pub fn append(&mut self, nibbles: &NibbleSlice) {
+        for nibble in nibbles.iter() {
+            self.push(nibble);
+        }
+    }
+
+    /// Converts the nibble prefix to an equivalent range of FlatState keys.
+    ///
+    /// If the number of nibbles is even, this is straight-forward; the keys will be in the form of
+    /// e.g. 0x123456 - 0x123457. If the number of nibbles is odd, the keys will cover the whole
+    /// range for the last 4 bits, e.g. 0x123450 - 0x123460.
+    pub fn to_iter_range(&self, shard_uid: ShardUId) -> (Vec<u8>, Vec<u8>) {
+        let start = shard_uid
+            .to_bytes()
+            .into_iter()
+            .chain(self.prefix.clone().into_iter())
+            .collect::<Vec<u8>>();
+        // The end key should always exist because we have a shard UID prefix to absorb the overflow.
+        let end =
+            calculate_end_key(&start, if self.odd { 16 } else { 1 }).expect("Should not overflow");
+        (start, end)
+    }
+}
+
+/// Calculates the end key of a lexically ordered key range where all the keys start with `start_key`
+/// except that the i-th byte may be within [b, b + last_byte_increment), where i == start_key.len() - 1,
+/// and b == start_key[i]. Returns None if the end key is unbounded.
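+/// For example, [0, 0, 255] with increment 1 yields Some([0, 1]), while [255, 255, 254] with
+/// increment 2 yields None (see the tests below).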
+fn calculate_end_key(start_key: &Vec<u8>, last_byte_increment: u8) -> Option<Vec<u8>> {
+    let mut v = start_key.clone();
+    let mut carry = last_byte_increment;
+    for i in (0..v.len()).rev() {
+        let (new_val, overflowing) = v[i].overflowing_add(carry);
+        if overflowing {
+            carry = 1;
+            v.pop();
+        } else {
+            v[i] = new_val;
+            return Some(v);
+        }
+    }
+    return None;
+}
+
+#[cfg(test)]
+mod tests {
+    use super::NibblePrefix;
+    use crate::trie::mem::parallel_loader::calculate_end_key;
+    use crate::NibbleSlice;
+    use near_primitives::shard_layout::ShardUId;
+
+    #[test]
+    fn test_increment_vec_as_num() {
+        assert_eq!(calculate_end_key(&vec![0, 0, 0], 1), Some(vec![0, 0, 1]));
+        assert_eq!(calculate_end_key(&vec![0, 0, 255], 1), Some(vec![0, 1]));
+        assert_eq!(calculate_end_key(&vec![0, 5, 255], 1), Some(vec![0, 6]));
+        assert_eq!(calculate_end_key(&vec![0, 255, 255], 1), Some(vec![1]));
+        assert_eq!(calculate_end_key(&vec![255, 255, 254], 2), None);
+    }
+
+    #[test]
+    fn test_nibble_prefix() {
+        let shard_uid = ShardUId { shard_id: 3, version: 2 };
+        let iter_range = |prefix: &NibblePrefix| {
+            let (start, end) = prefix.to_iter_range(shard_uid);
+            format!("{}..{}", hex::encode(&start), hex::encode(&end))
+        };
+
+        let mut prefix = NibblePrefix::new();
+        assert_eq!(format!("{:?}", prefix), "");
+        assert_eq!(iter_range(&prefix), "0200000003000000..0200000003000001");
+
+        prefix.push(4);
+        assert_eq!(format!("{:?}", prefix), "4");
+        assert_eq!(iter_range(&prefix), "020000000300000040..020000000300000050");
+
+        prefix.push(15);
+        assert_eq!(format!("{:?}", prefix), "4f");
+        assert_eq!(iter_range(&prefix), "02000000030000004f..020000000300000050");
+
+        prefix.append(&NibbleSlice::new(&hex::decode("5123").unwrap()).mid(1));
+        assert_eq!(format!("{:?}", prefix), "4f123");
+        assert_eq!(iter_range(&prefix), "02000000030000004f1230..02000000030000004f1240");
+
+        prefix.append(&NibbleSlice::new(&hex::decode("ff").unwrap()));
+        assert_eq!(format!("{:?}", prefix), "4f123ff");
+        assert_eq!(iter_range(&prefix), "02000000030000004f123ff0..02000000030000004f1240");
+
+        let mut prefix = NibblePrefix::new();
+        prefix.push(15);
+        prefix.push(15);
+        assert_eq!(format!("{:?}", prefix), "ff");
+        assert_eq!(iter_range(&prefix), "0200000003000000ff..0200000003000001");
+    }
+}
diff --git a/core/store/src/trie/raw_node.rs b/core/store/src/trie/raw_node.rs
index 427e25fc3d7..5c1d117cde2 100644
--- a/core/store/src/trie/raw_node.rs
+++ b/core/store/src/trie/raw_node.rs
@@ -12,6 +12,12 @@ pub struct RawTrieNodeWithSize {
     pub(super) memory_usage: u64,
 }
 
+impl RawTrieNodeWithSize {
+    pub fn hash(&self) -> CryptoHash {
+        CryptoHash::hash_bytes(&borsh::to_vec(self).unwrap())
+    }
+}
+
 /// Trie node.
 #[derive(BorshSerialize, BorshDeserialize, Clone, Debug, PartialEq, Eq)]
 #[allow(clippy::large_enum_variant)]
diff --git a/core/store/src/trie/shard_tries.rs b/core/store/src/trie/shard_tries.rs
index 39e87a5261d..035f6665d30 100644
--- a/core/store/src/trie/shard_tries.rs
+++ b/core/store/src/trie/shard_tries.rs
@@ -417,9 +417,15 @@ impl ShardTries {
         &self,
         shard_uid: &ShardUId,
         state_root: Option<StateRoot>,
+        parallelize: bool,
     ) -> Result<(), StorageError> {
         info!(target: "memtrie", "Loading trie to memory for shard {:?}...", shard_uid);
-        let mem_tries = load_trie_from_flat_state_and_delta(&self.0.store, *shard_uid, state_root)?;
+        let mem_tries = load_trie_from_flat_state_and_delta(
+            &self.0.store,
+            *shard_uid,
+            state_root,
+            parallelize,
+        )?;
         self.0.mem_tries.write().unwrap().insert(*shard_uid, Arc::new(RwLock::new(mem_tries)));
         info!(target: "memtrie", "Memtrie loading complete for shard {:?}", shard_uid);
         Ok(())
@@ -438,7 +444,7 @@ impl ShardTries {
         // It should not happen that memtrie is already loaded for a shard
         // for which we just did state sync.
         debug_assert!(!self.0.mem_tries.read().unwrap().contains_key(shard_uid));
-        self.load_mem_trie(shard_uid, Some(*state_root))
+        self.load_mem_trie(shard_uid, Some(*state_root), false)
     }
 
     /// Loads in-memory tries upon startup. The given shard_uids are possible candidates to load,
@@ -447,6 +453,7 @@ impl ShardTries {
     pub fn load_mem_tries_for_enabled_shards(
         &self,
         tracked_shards: &[ShardUId],
+        parallelize: bool,
     ) -> Result<(), StorageError> {
         let trie_config = &self.0.trie_config;
         let shard_uids_to_load = tracked_shards
             .iter()
@@ -461,7 +468,7 @@ impl ShardTries {
         info!(target: "memtrie", "Loading tries to memory for shards {:?}...", shard_uids_to_load);
         shard_uids_to_load
             .par_iter()
-            .map(|shard_uid| self.load_mem_trie(shard_uid, None))
+            .map(|shard_uid| self.load_mem_trie(shard_uid, None, parallelize))
             .collect::<Vec<Result<(), StorageError>>>()
             .into_iter()
             .collect::<Result<(), StorageError>>()?;
diff --git a/core/store/src/trie/trie_recording.rs b/core/store/src/trie/trie_recording.rs
index 112b682d498..e29e2bb5069 100644
--- a/core/store/src/trie/trie_recording.rs
+++ b/core/store/src/trie/trie_recording.rs
@@ -360,7 +360,7 @@ mod trie_recording_tests {
         // Now let's do this again with memtries enabled. Check that counters
         // are the same.
         assert_eq!(MEM_TRIE_NUM_LOOKUPS.get(), mem_trie_lookup_counts_before);
-        tries.load_mem_trie(&shard_uid, None).unwrap();
+        tries.load_mem_trie(&shard_uid, None, false).unwrap();
         // Delete the on-disk state so that we really know we're using
         // in-memory tries.
destructively_delete_in_memory_state_from_disk(&store, &data_in_trie); diff --git a/tools/database/src/memtrie.rs b/tools/database/src/memtrie.rs index d61af3dc04e..4b5196eafa6 100644 --- a/tools/database/src/memtrie.rs +++ b/tools/database/src/memtrie.rs @@ -19,6 +19,8 @@ use std::time::Duration; pub struct LoadMemTrieCommand { #[clap(long, use_value_delimiter = true, value_delimiter = ',')] shard_id: Option>, + #[clap(long)] + no_parallel: bool, } impl LoadMemTrieCommand { @@ -59,8 +61,14 @@ impl LoadMemTrieCommand { .context("could not create the transaction runtime")?; println!("Loading memtries for shards {:?}...", selected_shard_uids); - runtime.get_tries().load_mem_tries_for_enabled_shards(&selected_shard_uids)?; - println!("Finished loading memtries, press Ctrl-C to exit."); + let start_time = std::time::Instant::now(); + runtime + .get_tries() + .load_mem_tries_for_enabled_shards(&selected_shard_uids, !self.no_parallel)?; + println!( + "Finished loading memtries, took {:?}, press Ctrl-C to exit.", + start_time.elapsed() + ); std::thread::sleep(Duration::from_secs(10_000_000_000)); Ok(()) } diff --git a/tools/fork-network/src/cli.rs b/tools/fork-network/src/cli.rs index 9a007ae4818..8c747a95391 100644 --- a/tools/fork-network/src/cli.rs +++ b/tools/fork-network/src/cli.rs @@ -332,7 +332,7 @@ impl ForkNetworkCommand { let runtime = NightshadeRuntime::from_config(home_dir, store.clone(), &near_config, epoch_manager) .context("could not create the transaction runtime")?; - runtime.get_tries().load_mem_tries_for_enabled_shards(&all_shard_uids).unwrap(); + runtime.get_tries().load_mem_tries_for_enabled_shards(&all_shard_uids, true).unwrap(); let make_storage_mutator: MakeSingleShardStorageMutatorFn = Arc::new(move |prev_state_root| {
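Taken together, the call sites updated in this change follow one convention for the new `parallelize` flag: pass `true` only when nothing else is competing for the node (startup in chain.rs, the `load-mem-trie` and fork-network tools), and `false` for test setup and for loading right after state sync. Roughly, as quoted from the diff above:

    // Node startup and offline tools: the process is otherwise idle, so read flat state in parallel.
    runtime.get_tries().load_mem_tries_for_enabled_shards(&tracked_shards, true)?;

    // Test setup and post-state-sync loading: keep the sequential path.
    tries.load_mem_tries_for_enabled_shards(&shard_uids, false)?;
    self.load_mem_trie(shard_uid, Some(*state_root), false)?;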