diff --git a/consensus/types/src/beacon_state/tree_hash_cache.rs b/consensus/types/src/beacon_state/tree_hash_cache.rs index 0c8899c0255..117c6aa136b 100644 --- a/consensus/types/src/beacon_state/tree_hash_cache.rs +++ b/consensus/types/src/beacon_state/tree_hash_cache.rs @@ -7,6 +7,7 @@ use rayon::prelude::*; use ssz_derive::{Decode, Encode}; use ssz_types::VariableList; use std::cmp::Ordering; +use std::iter::ExactSizeIterator; use tree_hash::{mix_in_length, MerkleHasher, TreeHash}; /// The number of fields on a beacon state. @@ -288,17 +289,17 @@ impl ValidatorsListTreeHashCache { fn recalculate_tree_hash_root(&mut self, validators: &[Validator]) -> Result { let mut list_arena = std::mem::take(&mut self.list_arena); - let leaves = self - .values - .leaves(validators)? - .into_iter() - .flatten() - .map(|h| h.to_fixed_bytes()) - .collect::>(); + let leaves = self.values.leaves(validators)?; + let num_leaves = leaves.iter().map(|arena| arena.len()).sum(); + + let leaves_iter = ForcedLengthIterator { + iter: leaves.into_iter().flatten().map(|h| h.to_fixed_bytes()), + len: num_leaves, + }; let list_root = self .list_cache - .recalculate_merkle_root(&mut list_arena, leaves.into_iter())?; + .recalculate_merkle_root(&mut list_arena, leaves_iter)?; self.list_arena = list_arena; @@ -306,6 +307,29 @@ impl ValidatorsListTreeHashCache { } } +/// Provides a wrapper around some `iter` if the number of items in the iterator is known to the +/// programmer but not the compiler. This allows use of `ExactSizeIterator` in some occasions. +/// +/// Care should be taken to ensure `len` is accurate. +struct ForcedLengthIterator { + iter: I, + len: usize, +} + +impl> Iterator for ForcedLengthIterator { + type Item = V; + + fn next(&mut self) -> Option { + self.iter.next() + } +} + +impl> ExactSizeIterator for ForcedLengthIterator { + fn len(&self) -> usize { + self.len + } +} + /// Provides a cache for each of the `Validator` objects in `state.validators` and computes the /// roots of these using Rayon parallelization. #[derive(Debug, PartialEq, Clone, Default, Encode, Decode)]