refactor(resharding) - small refactors, improvements and asserts (#12572)

Some small improvements, mostly following up on the previous PR. Most
importantly, the V2 ShardLayout is now better validated on both paths
(the v2 ctor and the serde deserialization).
wacban authored Dec 6, 2024
1 parent 63de148 commit 0287e07
Showing 3 changed files with 136 additions and 68 deletions.
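The commit description points at one invariant now enforced on both the v2 ctor and the serde path: shards_split_map and shards_parent_map must be both present or both absent, and the parent map must be exactly the one derived from the split map. The following is a rough standalone sketch of that check, with plain u64 ids and std maps standing in for the real ShardId, ShardsSplitMapV2 and ShardsParentMapV2 types; it is an illustration, not the actual nearcore API.

use std::collections::BTreeMap;

// Simplified stand-ins for the real nearcore types.
type ShardId = u64;
type SplitMap = BTreeMap<ShardId, Vec<ShardId>>;
type ParentMap = BTreeMap<ShardId, ShardId>;

// Derive the parent map from the split map, mirroring the shape of
// validate_and_derive_shard_parent_map_v2 in the diff below.
fn derive_parent_map(split_map: &SplitMap) -> ParentMap {
    let mut parent_map = ParentMap::new();
    for (&parent, children) in split_map {
        for &child in children {
            let prev = parent_map.insert(child, parent);
            assert!(prev.is_none(), "no shard should appear in the map twice");
        }
    }
    parent_map
}

// The check added on the serde deserialization path: both maps must be
// present or both absent, and a stored parent map must match the derived one.
fn validate(
    split_map: &Option<SplitMap>,
    parent_map: &Option<ParentMap>,
) -> Result<(), String> {
    match (split_map, parent_map) {
        (None, None) => Ok(()),
        (Some(split_map), Some(parent_map)) => {
            if &derive_parent_map(split_map) == parent_map {
                Ok(())
            } else {
                Err("shards_parent_map does not match the expected value".to_string())
            }
        }
        _ => Err(
            "shards_split_map and shards_parent_map must be both present or both absent"
                .to_string(),
        ),
    }
}

fn main() {
    // Shard 0 is unchanged; shard 1 was split into shards 2 and 3.
    let split_map = SplitMap::from([(0, vec![0]), (1, vec![2, 3])]);
    let parent_map = ParentMap::from([(0, 0), (2, 1), (3, 1)]);
    assert!(validate(&Some(split_map.clone()), &Some(parent_map)).is_ok());

    // A stored parent map that disagrees with the split map is rejected.
    let bad_parent_map = ParentMap::from([(0, 0), (2, 0), (3, 1)]);
    assert!(validate(&Some(split_map), &Some(bad_parent_map)).is_err());
}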
core/primitives/src/shard_layout.rs: 63 changes (49 additions, 14 deletions)
@@ -255,6 +255,23 @@ impl TryFrom<SerdeShardLayoutV2> for ShardLayoutV2 {
.map(|(k, v)| Ok((k.parse()?, v)))
.collect::<Result<_, Self::Error>>()?;

match (&shards_split_map, &shards_parent_map) {
(None, None) => {}
(Some(shard_split_map), Some(shards_parent_map)) => {
let expected_shards_parent_map =
validate_and_derive_shard_parent_map_v2(&shard_ids, &shard_split_map);
if &expected_shards_parent_map != shards_parent_map {
return Err("shards_parent_map does not match the expected value".into());
}
}
_ => {
return Err(
"shards_split_map and shards_parent_map must be both present or both absent"
.into(),
)
}
}

Ok(Self {
boundary_accounts,
shard_ids,
@@ -442,19 +459,8 @@ impl ShardLayout {
});
};

let mut shards_parent_map = ShardsParentMapV2::new();
for (&parent_shard_id, shard_ids) in shards_split_map.iter() {
for &shard_id in shard_ids {
let prev = shards_parent_map.insert(shard_id, parent_shard_id);
assert!(prev.is_none(), "no shard should appear in the map twice");
}
}

assert_eq!(
shard_ids.iter().copied().sorted().collect_vec(),
shards_parent_map.keys().copied().collect_vec()
);

let shards_parent_map =
validate_and_derive_shard_parent_map_v2(&shard_ids, &shards_split_map);
let shards_split_map = Some(shards_split_map);
let shards_parent_map = Some(shards_parent_map);
Self::V2(ShardLayoutV2 {
Expand Down Expand Up @@ -773,7 +779,7 @@ impl ShardLayout {

/// Returns all of the shards from the previous shard layout that were
/// split into multiple shards in this shard layout.
pub fn get_parent_shard_ids(&self) -> Result<BTreeSet<ShardId>, ShardLayoutError> {
pub fn get_split_parent_shard_ids(&self) -> Result<BTreeSet<ShardId>, ShardLayoutError> {
let mut parent_shard_ids = BTreeSet::new();
for shard_id in self.shard_ids() {
let parent_shard_id = self.try_get_parent_shard_id(shard_id)?;
@@ -789,6 +795,35 @@
}
}

// Validates the shards_split_map and derives the shards_parent_map from it.
fn validate_and_derive_shard_parent_map_v2(
shard_ids: &Vec<ShardId>,
shards_split_map: &ShardsSplitMapV2,
) -> ShardsParentMapV2 {
let mut shards_parent_map = ShardsParentMapV2::new();
for (&parent_shard_id, child_shard_ids) in shards_split_map.iter() {
for &child_shard_id in child_shard_ids {
let prev = shards_parent_map.insert(child_shard_id, parent_shard_id);
assert!(prev.is_none(), "no shard should appear in the map twice");
}
if let &[child_shard_id] = child_shard_ids.as_slice() {
// The parent shards with only one child shard are not split and
// should keep the same shard id.
assert_eq!(parent_shard_id, child_shard_id);
} else {
// The parent shards with multiple child shards are split.
// The parent shard id should no longer be used.
assert!(!shard_ids.contains(&parent_shard_id));
}
}

assert_eq!(
shard_ids.iter().copied().sorted().collect_vec(),
shards_parent_map.keys().copied().collect_vec()
);
shards_parent_map
}
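
For a concrete sense of what this helper computes, here is a small worked example, sketched with plain integers rather than the real ShardId type:

// Hypothetical resharding: shard 1 keeps its id, shard 2 is split into 3 and 4.
//
//   shard_ids (new layout):        [1, 3, 4]
//   shards_split_map (old -> new): {1: [1], 2: [3, 4]}
//   derived shards_parent_map:     {1: 1, 3: 2, 4: 2}
//
// The asserts above then hold: parent 1 has a single child, so the child keeps
// shard id 1; parent 2 has two children, so id 2 no longer appears in the new
// shard_ids; and the keys of the parent map are exactly the sorted new shard_ids.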

/// Maps an account to the shard that it belongs to given a shard_layout
pub fn account_id_to_shard_uid(account_id: &AccountId, shard_layout: &ShardLayout) -> ShardUId {
ShardUId::from_shard_id_and_layout(
integration-tests/src/test_loop/tests/resharding_v3.rs: 139 changes (86 additions, 53 deletions)
@@ -618,20 +618,14 @@ fn assert_state_sanity(
.unwrap();

for shard_uid in shard_layout.shard_uids() {
// TODO - the condition for checks is duplicated in the
// `get_epoch_check` method, refactor this.
if !load_mem_tries_for_tracked_shards {
// In the old layout do not enforce except for shards pending resharding.
if !is_resharded && !shards_pending_resharding.contains(&shard_uid) {
tracing::debug!(target: "test", ?shard_uid, "skipping shard not pending resharding");
continue;
}

// In the new layout do not enforce for shards that were not split.
if is_resharded && !shard_was_split(&shard_layout, shard_uid.shard_id()) {
tracing::debug!(target: "test", ?shard_uid, "skipping shard not split");
continue;
}
if !should_assert_state_sanity(
load_mem_tries_for_tracked_shards,
is_resharded,
&shards_pending_resharding,
&shard_layout,
&shard_uid,
) {
continue;
}

if !client_tracking_shard(client, shard_uid.shard_id(), &final_head.prev_block_hash) {
@@ -692,6 +686,31 @@ fn assert_state_sanity(
checked_shards
}

fn should_assert_state_sanity(
load_mem_tries_for_tracked_shards: bool,
is_resharded: bool,
shards_pending_resharding: &HashSet<ShardUId>,
shard_layout: &ShardLayout,
shard_uid: &ShardUId,
) -> bool {
// Always assert if the tracked shards are loaded into memory.
if load_mem_tries_for_tracked_shards {
return true;
}

// In the old layout do not enforce except for shards pending resharding.
if !is_resharded && !shards_pending_resharding.contains(&shard_uid) {
return false;
}

// In the new layout do not enforce for shards that were not split.
if is_resharded && !shard_was_split(shard_layout, shard_uid.shard_id()) {
return false;
}

true
}

// For each epoch, keep a map from AccountId to a map with keys equal to
// the set of shards that account tracks in that epoch, and bool values indicating
// whether the equality of flat storage and memtries has been checked for that shard
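
Taken together with the get_epoch_check and get_epoch_check_for_account methods below, the structure this comment describes is assumed to be a triple-nested map; the struct definition itself is not part of this diff, so the exact field types are a guess:

// Assumed shape, inferred from the accessors in this diff:
//
//   checks: HashMap<EpochId, HashMap<AccountId, HashMap<ShardUId, bool>>>
//
// i.e. epoch -> account -> (tracked shard -> "has flat storage vs memtrie
// equality been verified for this shard yet?").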
@@ -737,47 +756,61 @@ impl TrieSanityCheck {
let shard_layout = client.epoch_manager.get_shard_layout(&tip.epoch_id).unwrap();
let is_resharded = shard_layout.num_shards() == new_num_shards;

match self.checks.entry(tip.epoch_id) {
std::collections::hash_map::Entry::Occupied(e) => e.into_mut(),
std::collections::hash_map::Entry::Vacant(e) => {
let shard_uids = shard_layout.shard_uids().collect_vec();
let mut check = HashMap::new();
for account_id in self.accounts.iter() {
let tracked = shard_uids
.iter()
.filter_map(|uid| {
if !is_resharded
&& !self.load_mem_tries_for_tracked_shards
&& !shards_pending_resharding.contains(uid)
{
return None;
}

if is_resharded
&& !self.load_mem_tries_for_tracked_shards
&& !shard_was_split(&shard_layout, uid.shard_id())
{
return None;
}

let cares = client.shard_tracker.care_about_shard(
Some(account_id),
&tip.prev_block_hash,
uid.shard_id(),
false,
);
if cares {
Some((*uid, false))
} else {
None
}
})
.collect();
check.insert(account_id.clone(), tracked);
}
e.insert(check)
if self.checks.contains_key(&tip.epoch_id) {
return self.checks.get_mut(&tip.epoch_id).unwrap();
}

let mut check = HashMap::new();
for account_id in self.accounts.iter() {
let check_shard_uids = self.get_epoch_check_for_account(
client,
tip,
is_resharded,
&shards_pending_resharding,
&shard_layout,
account_id,
);
check.insert(account_id.clone(), check_shard_uids);
}

self.checks.insert(tip.epoch_id, check);
self.checks.get_mut(&tip.epoch_id).unwrap()
}

// Returns the expected shard uids for the given account.
fn get_epoch_check_for_account(
&self,
client: &Client,
tip: &Tip,
is_resharded: bool,
shards_pending_resharding: &HashSet<ShardUId>,
shard_layout: &ShardLayout,
account_id: &AccountId,
) -> HashMap<ShardUId, bool> {
let mut check_shard_uids = HashMap::new();
for shard_uid in shard_layout.shard_uids() {
if !should_assert_state_sanity(
self.load_mem_tries_for_tracked_shards,
is_resharded,
shards_pending_resharding,
shard_layout,
&shard_uid,
) {
continue;
}

let cares = client.shard_tracker.care_about_shard(
Some(account_id),
&tip.prev_block_hash,
shard_uid.shard_id(),
false,
);
if !cares {
continue;
}
check_shard_uids.insert(shard_uid, false);
}
check_shard_uids
}

// Check trie sanity and keep track of which shards were successfully fully checked
runtime/runtime/src/congestion_control.rs: 2 changes (1 addition, 1 deletion)
@@ -238,7 +238,7 @@ impl ReceiptSinkV2 {
if ProtocolFeature::SimpleNightshadeV4.enabled(protocol_version) {
(
shard_layout.shard_ids().collect_vec(),
shard_layout.get_parent_shard_ids().map_err(Into::<EpochError>::into)?,
shard_layout.get_split_parent_shard_ids().map_err(Into::<EpochError>::into)?,
)
} else {
(self.outgoing_limit.keys().copied().collect_vec(), BTreeSet::new())