refactor: remove part functions from epoch manager #12834

Open · wants to merge 6 commits into base: master · Changes from 4 commits
10 changes: 7 additions & 3 deletions chain/chain/src/chain.rs
@@ -13,7 +13,9 @@ use crate::orphan::{Orphan, OrphanBlockPool};
use crate::rayon_spawner::RayonAsyncComputationSpawner;
use crate::resharding::manager::ReshardingManager;
use crate::resharding::types::ReshardingSender;
use crate::sharding::{get_receipts_shuffle_salt, shuffle_receipt_proofs};
use crate::sharding::{
get_part_owner, get_receipts_shuffle_salt, num_total_parts, shuffle_receipt_proofs,
};
use crate::signature_verification::{
verify_block_header_signature_with_epoch_manager, verify_block_vrf,
verify_chunk_header_signature_with_epoch_manager,
@@ -1408,8 +1410,10 @@ impl Chain {
return Ok(true);
}
}
for part_id in 0..self.epoch_manager.num_total_parts() {
if &Some(self.epoch_manager.get_part_owner(&epoch_id, part_id as u64)?) == me {
let total_parts = num_total_parts(self.epoch_manager.as_ref());
for part_id in 0..total_parts {
if &Some(get_part_owner(self.epoch_manager.as_ref(), &epoch_id, part_id as u64)?) == me
{
return Ok(true);
}
}
39 changes: 39 additions & 0 deletions chain/chain/src/sharding.rs
@@ -2,11 +2,50 @@ use near_epoch_manager::EpochManagerAdapter;
use near_primitives::block::Block;
use near_primitives::errors::EpochError;
use near_primitives::hash::CryptoHash;
use near_primitives::types::{AccountId, EpochId};
use near_primitives::version::ProtocolFeature;
use rand::seq::SliceRandom;
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;

/// Number of Reed-Solomon parts we split each chunk into.
///
/// Note: this shouldn't be too large; our Reed-Solomon implementation
/// supports at most 256 parts.
pub fn num_total_parts(epoch_manager: &dyn EpochManagerAdapter) -> usize {
let seats = epoch_manager.get_genesis_num_block_producer_seats();
if seats > 1 {
seats as usize
} else {
2
}
}
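
For intuition, a minimal sketch of the seats-to-parts rule above, with the adapter call replaced by a plain seats value (illustrative, not part of this diff):

fn total_parts_for(seats: u64) -> usize {
    // Mirrors num_total_parts: one part per genesis block producer seat,
    // with a floor of two so single-seat chains still split chunks.
    if seats > 1 { seats as usize } else { 2 }
}

fn main() {
    assert_eq!(total_parts_for(100), 100);
    assert_eq!(total_parts_for(1), 2); // floor of two parts
}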

/// How many Reed-Solomon parts are data parts.
///
/// That is, fetching this many parts should be enough to reconstruct a
/// chunk, if there are no errors.
pub fn num_data_parts(epoch_manager: &dyn EpochManagerAdapter) -> usize {
let total_parts = num_total_parts(epoch_manager);
if total_parts <= 3 {
1
} else {
(total_parts - 1) / 3
}
}
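
The (total_parts - 1) / 3 split makes roughly a third of the parts data parts and the rest parity, so any num_data_parts parts are enough to reconstruct a chunk. A minimal sketch of that property (not part of this diff), assuming the reed-solomon-erasure crate's galois_8 API behind the ReedSolomon usage in shards_manager_actor.rs below; shard contents and sizes are made up:

use reed_solomon_erasure::galois_8::ReedSolomon;

fn main() {
    // 12 total parts -> (12 - 1) / 3 = 3 data parts, 9 parity parts.
    let (data, parity) = (3, 9);
    let rs = ReedSolomon::new(data, parity).unwrap();

    // Three data shards of equal length, plus zeroed parity placeholders.
    let mut shards: Vec<Vec<u8>> = vec![vec![0u8; 4]; data + parity];
    shards[0].copy_from_slice(&[1, 2, 3, 4]);
    shards[1].copy_from_slice(&[5, 6, 7, 8]);
    shards[2].copy_from_slice(&[9, 10, 11, 12]);
    rs.encode(&mut shards).unwrap();

    // Drop the first 9 shards, keeping only the last 3 (all parity);
    // the data shards are still recoverable because any 3 of the 12
    // parts suffice.
    let mut partial: Vec<Option<Vec<u8>>> = shards.into_iter().map(Some).collect();
    for slot in partial.iter_mut().take(parity) {
        *slot = None;
    }
    rs.reconstruct(&mut partial).unwrap();
    assert_eq!(partial[0].as_deref(), Some(&[1, 2, 3, 4][..]));
}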

/// Returns `account_id` that is supposed to have the `part_id`.
pub fn get_part_owner(
epoch_manager: &dyn EpochManagerAdapter,
epoch_id: &EpochId,
part_id: u64,
) -> Result<AccountId, EpochError> {
let epoch_info = epoch_manager.get_epoch_info(&epoch_id)?;
let settlement = epoch_info.block_producers_settlement();
let validator_id = settlement[part_id as usize % settlement.len()];
Ok(epoch_info.get_validator(validator_id).account_id().clone())
}
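
Part ownership is a plain round-robin over the epoch's block producer settlement. A sketch of just the index rule (illustrative, not part of this diff):

fn owner_index(part_id: u64, settlement_len: usize) -> usize {
    // Mirrors get_part_owner: part ids wrap around the settlement.
    part_id as usize % settlement_len
}

fn main() {
    // With 4 block producers, 10 parts are assigned 0,1,2,3,0,1,2,3,0,1.
    let owners: Vec<usize> = (0..10).map(|p| owner_index(p, 4)).collect();
    assert_eq!(owners, vec![0, 1, 2, 3, 0, 1, 2, 3, 0, 1]);
}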

/// Gets salt for shuffling receipts grouped by **source shards** before
/// processing them in the target shard.
pub fn get_receipts_shuffle_salt<'a>(
34 changes: 12 additions & 22 deletions chain/chain/src/test_utils/kv_runtime.rs
@@ -431,28 +431,18 @@ impl EpochManagerAdapter for MockEpochManager {
Ok(self.get_shard_layout(epoch_id)?.shard_ids().collect())
}

fn num_total_parts(&self) -> usize {
12 + (self.num_shards as usize + 1) % 50
}

fn num_data_parts(&self) -> usize {
// Same as in Nightshade Runtime
let total_parts = self.num_total_parts();
if total_parts <= 3 {
1
} else {
(total_parts - 1) / 3
}
}

fn get_part_owner(&self, epoch_id: &EpochId, part_id: u64) -> Result<AccountId, EpochError> {
let validators =
&self.get_epoch_block_producers_ordered(epoch_id, &CryptoHash::default())?;
// if we don't use data_parts and total_parts as part of the formula here, the part owner
// would not depend on height, and tests wouldn't catch passing wrong height here
let idx = part_id as usize + self.num_data_parts() + self.num_total_parts();
Ok(validators[idx as usize % validators.len()].0.account_id().clone())
}
fn get_genesis_num_block_producer_seats(&self) -> u64 {
12 + (self.num_shards + 1) % 50
}

// fn get_part_owner(&self, epoch_id: &EpochId, part_id: u64) -> Result<AccountId, EpochError> {
// let validators =
// &self.get_epoch_block_producers_ordered(epoch_id, &CryptoHash::default())?;
// // if we don't use data_parts and total_parts as part of the formula here, the part owner
// // would not depend on height, and tests wouldn't catch passing wrong height here
// let idx = part_id as usize + num_data_parts(self) + num_total_parts(self);
// Ok(validators[idx as usize % validators.len()].0.account_id().clone())
// }

fn account_id_to_shard_id(
&self,
4 changes: 3 additions & 1 deletion chain/chain/src/validate.rs
@@ -18,6 +18,7 @@ use near_primitives::transaction::SignedTransaction;
use near_primitives::types::chunk_extra::ChunkExtra;
use near_primitives::types::{AccountId, BlockHeight, EpochId, Nonce};

use crate::sharding::num_data_parts;
use crate::signature_verification::{
verify_block_header_signature_with_epoch_manager,
verify_chunk_header_signature_with_epoch_manager,
@@ -343,7 +344,8 @@ fn validate_chunk_proofs_challenge(
let tmp_chunk;
let chunk_ref = match &*chunk_proofs.chunk {
MaybeEncodedShardChunk::Encoded(encoded_chunk) => {
match encoded_chunk.decode_chunk(epoch_manager.num_data_parts()) {
let data_parts = num_data_parts(epoch_manager);
match encoded_chunk.decode_chunk(data_parts) {
Ok(chunk) => {
tmp_chunk = Some(chunk);
tmp_chunk.as_ref().unwrap()
6 changes: 4 additions & 2 deletions chain/chunks/src/logic.rs
@@ -1,3 +1,4 @@
use near_chain::sharding::{get_part_owner, num_data_parts};
use near_chain::ChainStoreAccess;
use near_chain::{
types::EpochManagerAdapter, validate::validate_chunk_proofs, BlockHeader, Chain, ChainStore,
@@ -34,7 +35,7 @@ pub fn need_part(
epoch_manager: &dyn EpochManagerAdapter,
) -> Result<bool, EpochError> {
let epoch_id = epoch_manager.get_epoch_id_from_prev_block(prev_block_hash)?;
Ok(Some(&epoch_manager.get_part_owner(&epoch_id, part_ord)?) == me)
Ok(Some(&get_part_owner(epoch_manager, &epoch_id, part_ord)?) == me)
}

pub fn get_shards_cares_about_this_or_next_epoch(
@@ -166,8 +167,9 @@ pub fn decode_encoded_chunk(
?chunk_hash)
.entered();

let data_parts = num_data_parts(epoch_manager);
if let Ok(shard_chunk) = encoded_chunk
.decode_chunk(epoch_manager.num_data_parts())
.decode_chunk(data_parts)
.map_err(|err| Error::from(err))
.and_then(|shard_chunk| {
if !validate_chunk_proofs(&shard_chunk, epoch_manager)? {
54 changes: 27 additions & 27 deletions chain/chunks/src/shards_manager_actor.rs
@@ -95,6 +95,7 @@ use near_async::time::Duration;
use near_async::time::{self, Clock};
use near_chain::byzantine_assert;
use near_chain::near_chain_primitives::error::Error::DBNotFoundErr;
use near_chain::sharding::{get_part_owner, num_data_parts, num_total_parts};
use near_chain::signature_verification::{
verify_chunk_header_signature_with_epoch_manager,
verify_chunk_header_signature_with_epoch_manager_and_parts,
@@ -359,6 +360,7 @@ impl ShardsManagerActor {
initial_chain_header_head: Tip,
chunk_request_retry_period: Duration,
) -> Self {
let data_parts = num_data_parts(epoch_manager.as_ref());
Self {
clock,
validator_signer,
@@ -368,11 +370,8 @@
shard_tracker,
peer_manager_adapter: network_adapter,
client_adapter,
rs: ReedSolomon::new(
epoch_manager.num_data_parts(),
epoch_manager.num_total_parts() - epoch_manager.num_data_parts(),
)
.unwrap(),
rs: ReedSolomon::new(data_parts, num_total_parts(epoch_manager.as_ref()) - data_parts)
.unwrap(),
encoded_chunks: EncodedChunksCache::new(),
requested_partial_encoded_chunks: RequestPool::new(
CHUNK_REQUEST_RETRY,
@@ -480,7 +479,8 @@

let epoch_id = self.epoch_manager.get_epoch_id_from_prev_block(ancestor_hash)?;

for part_ord in 0..self.epoch_manager.num_total_parts() {
let total_parts = num_total_parts(self.epoch_manager.as_ref());
for part_ord in 0..total_parts {
let part_ord = part_ord as u64;
if cache_entry.is_some_and(|cache_entry| cache_entry.parts.contains_key(&part_ord)) {
continue;
@@ -489,7 +489,7 @@
// Note: If request_from_archival is true, we potentially call
// get_part_owner unnecessarily. It’s probably not worth optimizing
// though unless you can think of a concise way to do it.
let part_owner = self.epoch_manager.get_part_owner(&epoch_id, part_ord)?;
let part_owner = get_part_owner(self.epoch_manager.as_ref(), &epoch_id, part_ord)?;
let we_own_part = Some(&part_owner) == me;
if !request_full && !we_own_part {
continue;
@@ -1126,7 +1126,7 @@ impl ShardsManagerActor {
chunk_hash = ?chunk.chunk_hash())
.entered();

let data_parts = self.epoch_manager.num_data_parts();
let data_parts = num_data_parts(self.epoch_manager.as_ref());
if chunk.content().num_fetched_parts() < data_parts {
debug!(target: "chunks", num_fetched_parts = chunk.content().num_fetched_parts(), data_parts, "Incomplete");
return ChunkStatus::Incomplete;
@@ -1212,7 +1212,7 @@ impl ShardsManagerActor {
}

// check part merkle proofs
let num_total_parts = self.epoch_manager.num_total_parts();
let num_total_parts = num_total_parts(self.epoch_manager.as_ref());
for part_info in forward.parts.iter() {
self.validate_part(forward.merkle_root, part_info, num_total_parts)?;
}
@@ -1260,7 +1260,7 @@ impl ShardsManagerActor {

fn insert_forwarded_chunk(&mut self, forward: PartialEncodedChunkForwardMsg) {
let chunk_hash = forward.chunk_hash.clone();
let num_total_parts = self.epoch_manager.num_total_parts() as u64;
let num_total_parts = num_total_parts(self.epoch_manager.as_ref()) as u64;
match self.chunk_forwards_cache.get_mut(&chunk_hash) {
None => {
// Never seen this chunk hash before, collect the parts and cache them
@@ -1505,9 +1505,9 @@ impl ShardsManagerActor {
if entry.complete {
return Ok(ProcessPartialEncodedChunkResult::Known);
}
debug!(target: "chunks", num_parts_in_cache = entry.parts.len(), total_needed = self.epoch_manager.num_data_parts());
debug!(target: "chunks", num_parts_in_cache = entry.parts.len(), total_needed = num_data_parts(self.epoch_manager.as_ref()));
} else {
debug!(target: "chunks", num_parts_in_cache = 0, total_needed = self.epoch_manager.num_data_parts());
debug!(target: "chunks", num_parts_in_cache = 0, total_needed = num_data_parts(self.epoch_manager.as_ref()));
}

// 1.b Checking chunk height
@@ -1548,7 +1548,7 @@ impl ShardsManagerActor {
let partial_encoded_chunk = partial_encoded_chunk.as_ref().into_inner();

// 1.d Checking part_ords' validity
let num_total_parts = self.epoch_manager.num_total_parts();
let num_total_parts = num_total_parts(self.epoch_manager.as_ref());
for part_info in partial_encoded_chunk.parts.iter() {
// TODO: only validate parts we care about
// https://github.com/near/nearcore/issues/5885
@@ -1714,7 +1714,7 @@ impl ShardsManagerActor {
let have_all_parts = self.has_all_parts(&prev_block_hash, entry, me)?;
let have_all_receipts = self.has_all_receipts(&prev_block_hash, entry, me)?;

let can_reconstruct = entry.parts.len() >= self.epoch_manager.num_data_parts();
let can_reconstruct = entry.parts.len() >= num_data_parts(self.epoch_manager.as_ref());
let chunk_producer = self
.epoch_manager
.get_chunk_producer_info(&ChunkProductionKey {
@@ -1765,7 +1765,7 @@ impl ShardsManagerActor {
let protocol_version = self.epoch_manager.get_epoch_protocol_version(&epoch_id)?;
let mut encoded_chunk = EncodedShardChunk::from_header(
header.clone(),
self.epoch_manager.num_total_parts(),
num_total_parts(self.epoch_manager.as_ref()),
protocol_version,
);

Expand Down Expand Up @@ -1854,9 +1854,7 @@ impl ShardsManagerActor {
.iter()
.filter(|part| {
part_ords.contains(&part.part_ord)
&& self
.epoch_manager
.get_part_owner(epoch_id, part.part_ord)
&& get_part_owner(self.epoch_manager.as_ref(), epoch_id, part.part_ord)
.is_ok_and(|owner| &owner == me)
})
.cloned()
@@ -1994,7 +1992,8 @@ impl ShardsManagerActor {
chunk_entry: &EncodedChunksCacheEntry,
me: Option<&AccountId>,
) -> Result<bool, Error> {
for part_ord in 0..self.epoch_manager.num_total_parts() {
let total_parts = num_total_parts(self.epoch_manager.as_ref());
for part_ord in 0..total_parts {
let part_ord = part_ord as u64;
if !chunk_entry.parts.contains_key(&part_ord) {
if need_part(prev_block_hash, part_ord, me, self.epoch_manager.as_ref())? {
@@ -2072,9 +2071,10 @@ impl ShardsManagerActor {

let mut block_producer_mapping = HashMap::new();
let epoch_id = self.epoch_manager.get_epoch_id_from_prev_block(&prev_block_hash)?;
for part_ord in 0..self.epoch_manager.num_total_parts() {
let total_parts = num_total_parts(self.epoch_manager.as_ref());
for part_ord in 0..total_parts {
let part_ord = part_ord as u64;
let to_whom = self.epoch_manager.get_part_owner(&epoch_id, part_ord).unwrap();
let to_whom = get_part_owner(self.epoch_manager.as_ref(), &epoch_id, part_ord).unwrap();

let entry = block_producer_mapping.entry(to_whom).or_insert_with(Vec::new);
entry.push(part_ord);
@@ -2518,7 +2518,7 @@ mod test {
})
.count()
};
let non_owned_part_ords: Vec<u64> = (0..(fixture.epoch_manager.num_total_parts() as u64))
let non_owned_part_ords: Vec<u64> = (0..(num_total_parts(&fixture.epoch_manager) as u64))
.filter(|ord| !fixture.mock_part_ords.contains(ord))
.collect();
// Received 3 partial encoded chunks; the owned part is received 3 times, but should
@@ -2934,7 +2934,7 @@ mod test {
let mut update = fixture.chain_store.store_update();
let shard_chunk = fixture
.mock_encoded_chunk
.decode_chunk(fixture.epoch_manager.num_data_parts())
.decode_chunk(num_data_parts(&fixture.epoch_manager))
.unwrap();
update.save_chunk(shard_chunk);
update.commit().unwrap();
@@ -3026,7 +3026,7 @@ mod test {
let mut update = fixture.chain_store.store_update();
let shard_chunk = fixture
.mock_encoded_chunk
.decode_chunk(fixture.epoch_manager.num_data_parts())
.decode_chunk(num_data_parts(&fixture.epoch_manager))
.unwrap();
update.save_chunk(shard_chunk);
update.commit().unwrap();
@@ -3158,15 +3158,15 @@ mod test {
let mut update = fixture.chain_store.store_update();
let shard_chunk = fixture
.mock_encoded_chunk
.decode_chunk(fixture.epoch_manager.num_data_parts())
.decode_chunk(num_data_parts(&fixture.epoch_manager))
.unwrap();
update.save_chunk(shard_chunk);
update.commit().unwrap();

let (source, response) =
shards_manager.prepare_partial_encoded_chunk_response(PartialEncodedChunkRequestMsg {
chunk_hash: fixture.mock_chunk_header.chunk_hash(),
part_ords: vec![0, fixture.epoch_manager.num_total_parts() as u64],
part_ords: vec![0, num_total_parts(&fixture.epoch_manager) as u64],
tracking_shards: HashSet::new(),
});
assert_eq!(source, PartialEncodedChunkResponseSource::ShardChunkOnDisk);
@@ -3194,7 +3194,7 @@ mod test {
let mut update = fixture.chain_store.store_update();
let shard_chunk = fixture
.mock_encoded_chunk
.decode_chunk(fixture.epoch_manager.num_data_parts())
.decode_chunk(num_data_parts(&fixture.epoch_manager))
.unwrap();
update.save_chunk(shard_chunk);
update.commit().unwrap();
7 changes: 4 additions & 3 deletions chain/chunks/src/test_utils.rs
@@ -1,4 +1,5 @@
use near_async::messaging::CanSend;
use near_chain::sharding::{get_part_owner, num_data_parts, num_total_parts};
use near_chain::types::{EpochManagerAdapter, Tip};
use near_chain::{Chain, ChainStore};
use near_epoch_manager::shard_tracker::{ShardTracker, TrackedConfig};
@@ -87,8 +88,8 @@ impl ChunkTestFixture {
let mock_network = Arc::new(MockPeerManagerAdapter::default());
let mock_client_adapter = Arc::new(MockClientAdapterForShardsManager::default());

let data_parts = epoch_manager.num_data_parts();
let parity_parts = epoch_manager.num_total_parts() - data_parts;
let data_parts = num_data_parts(&epoch_manager);
let parity_parts = num_total_parts(&epoch_manager) - data_parts;
let rs = ReedSolomon::new(data_parts, parity_parts).unwrap();
let mock_ancestor_hash = CryptoHash::default();
// generate a random block hash for the block at height 1
@@ -175,7 +176,7 @@ impl ChunkTestFixture {
.iter()
.copied()
.filter(|p| {
epoch_manager.get_part_owner(&mock_epoch_id, *p).unwrap() == mock_chunk_part_owner
get_part_owner(&epoch_manager, &mock_epoch_id, *p).unwrap() == mock_chunk_part_owner
})
.collect();
let encoded_chunk = mock_chunk.create_partial_encoded_chunk(