Skip to content

Commit

Permalink
Introduced ForcedFixedLenIter
Browse files Browse the repository at this point in the history
  • Loading branch information
paulhauner committed Aug 15, 2020
1 parent 52ab576 commit f2590e3
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions consensus/types/src/beacon_state/tree_hash_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use rayon::prelude::*;
use ssz_derive::{Decode, Encode};
use ssz_types::VariableList;
use std::cmp::Ordering;
use std::iter::ExactSizeIterator;
use tree_hash::{mix_in_length, MerkleHasher, TreeHash};

/// The number of fields on a beacon state.
Expand Down Expand Up @@ -288,24 +289,47 @@ impl ValidatorsListTreeHashCache {
fn recalculate_tree_hash_root(&mut self, validators: &[Validator]) -> Result<Hash256, Error> {
let mut list_arena = std::mem::take(&mut self.list_arena);

let leaves = self
.values
.leaves(validators)?
.into_iter()
.flatten()
.map(|h| h.to_fixed_bytes())
.collect::<Vec<_>>();
let leaves = self.values.leaves(validators)?;
let num_leaves = leaves.iter().map(|arena| arena.len()).sum();

let leaves_iter = ForcedLengthIterator {
iter: leaves.into_iter().flatten().map(|h| h.to_fixed_bytes()),
len: num_leaves,
};

let list_root = self
.list_cache
.recalculate_merkle_root(&mut list_arena, leaves.into_iter())?;
.recalculate_merkle_root(&mut list_arena, leaves_iter)?;

self.list_arena = list_arena;

Ok(mix_in_length(&list_root, validators.len()))
}
}

/// Provides a wrapper around some `iter` if the number of items in the iterator is known to the
/// programmer but not the compiler. This allows use of `ExactSizeIterator` in some occasions.
///
/// Care should be taken to ensure `len` is accurate.
struct ForcedLengthIterator<I> {
iter: I,
len: usize,
}

impl<V, I: Iterator<Item = V>> Iterator for ForcedLengthIterator<I> {
type Item = V;

fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}
}

impl<V, I: Iterator<Item = V>> ExactSizeIterator for ForcedLengthIterator<I> {
fn len(&self) -> usize {
self.len
}
}

/// Provides a cache for each of the `Validator` objects in `state.validators` and computes the
/// roots of these using Rayon parallelization.
#[derive(Debug, PartialEq, Clone, Default, Encode, Decode)]
Expand Down

0 comments on commit f2590e3

Please sign in to comment.