diff --git a/gossip/src/weighted_shuffle.rs b/gossip/src/weighted_shuffle.rs index 7c12debce469e0..656615449b2a79 100644 --- a/gossip/src/weighted_shuffle.rs +++ b/gossip/src/weighted_shuffle.rs @@ -9,6 +9,14 @@ use { std::ops::{AddAssign, Sub, SubAssign}, }; +// Each internal tree node has FANOUT many child nodes with indices: +// (index << BIT_SHIFT) + 1 ..= (index << BIT_SHIFT) + FANOUT +// Conversely, for each node, the parent node is obtained by: +// (index - 1) >> BIT_SHIFT +const BIT_SHIFT: usize = 4; +const FANOUT: usize = 1 << BIT_SHIFT; +const BIT_MASK: usize = FANOUT - 1; + /// Implements an iterator where indices are shuffled according to their /// weights: /// - Returned indices are unique in the range [0, weights.len()). @@ -18,12 +26,13 @@ use { /// non-zero weighted indices. #[derive(Clone)] pub struct WeightedShuffle { - // Underlying array implementing binary tree. - // tree[i] is the sum of weights in the left sub-tree of node i. - tree: Vec, + // Underlying array implementing the tree. + // tree[i][j] is the sum of all weights in the j'th sub-tree of node i. + tree: Vec<[T; FANOUT - 1]>, // Current sum of all weights, excluding already sampled ones. weight: T, - zeros: Vec, // Indices of zero weighted entries. + // Indices of zero weighted entries. + zeros: Vec, } impl WeightedShuffle @@ -34,7 +43,7 @@ where /// they are treated as zero. pub fn new(name: &'static str, weights: &[T]) -> Self { let zero = ::default(); - let mut tree = vec![zero; get_tree_size(weights.len())]; + let mut tree = vec![[zero; FANOUT - 1]; get_tree_size(weights.len())]; let mut sum = zero; let mut zeros = Vec::default(); let mut num_negative = 0; @@ -59,12 +68,14 @@ where continue; } }; - let mut index = tree.len() + k; + // Traverse the tree from the leaf node upwards to the root, + // updating the sub-tree sums along the way. + let mut index = tree.len() + k; // leaf node while index != 0 { - let offset = index & 1; - index = (index - 1) >> 1; + let offset = index & BIT_MASK; + index = (index - 1) >> BIT_SHIFT; // parent node if offset > 0 { - tree[index] += weight; + tree[index][offset - 1] += weight; } } } @@ -88,54 +99,73 @@ where { // Removes given weight at index k. fn remove(&mut self, k: usize, weight: T) { + debug_assert!(self.weight >= weight); self.weight -= weight; - let mut index = self.tree.len() + k; + // Traverse the tree from the leaf node upwards to the root, + // updating the sub-tree sums along the way. + let mut index = self.tree.len() + k; // leaf node while index != 0 { - let offset = index & 1; - index = (index - 1) >> 1; + let offset = index & BIT_MASK; + index = (index - 1) >> BIT_SHIFT; // parent node if offset > 0 { - self.tree[index] -= weight; + debug_assert!(self.tree[index][offset - 1] >= weight); + self.tree[index][offset - 1] -= weight; } } } - // Returns smallest index such that cumsum of weights[..=k] > val, + // Returns smallest index such that sum of weights[..=k] > val, // along with its respective weight. fn search(&self, mut val: T) -> (/*index:*/ usize, /*weight:*/ T) { let zero = ::default(); debug_assert!(val >= zero); debug_assert!(val < self.weight); - let mut index = 0; + // Traverse the tree downwards from the root while maintaining the + // weight of the subtree which contains the target leaf node. + let mut index = 0; // root let mut weight = self.weight; - while index < self.tree.len() { - if val < self.tree[index] { - weight = self.tree[index]; - index = (index << 1) + 1; - } else { - weight -= self.tree[index]; - val -= self.tree[index]; - index = (index << 1) + 2; + 'outer: while index < self.tree.len() { + for (j, &node) in self.tree[index].iter().enumerate() { + if val < node { + // Traverse to the j+1 subtree of self.tree[index]. + weight = node; + index = (index << BIT_SHIFT) + j + 1; + continue 'outer; + } else { + debug_assert!(weight >= node); + weight -= node; + val -= node; + } } + // Traverse to the right-most subtree of self.tree[index]. + index = (index << BIT_SHIFT) + FANOUT; } (index - self.tree.len(), weight) } pub fn remove_index(&mut self, k: usize) { - let mut index = self.tree.len() + k; + // Traverse the tree from the leaf node upwards to the root, while + // maintaining the sum of weights of subtrees *not* containing the leaf + // node. + let mut index = self.tree.len() + k; // leaf node let mut weight = ::default(); // zero while index != 0 { - let offset = index & 1; - index = (index - 1) >> 1; + let offset = index & BIT_MASK; + index = (index - 1) >> BIT_SHIFT; // parent node if offset > 0 { - if self.tree[index] != weight { - self.remove(k, self.tree[index] - weight); + if self.tree[index][offset - 1] != weight { + self.remove(k, self.tree[index][offset - 1] - weight); } else { self.remove_zero(k); } return; } - weight += self.tree[index]; + // The leaf node is in the right-most subtree of self.tree[index]. + for &node in &self.tree[index] { + weight += node; + } } + // The leaf node is the right-most node of the whole tree. if self.weight != weight { self.remove(k, self.weight - weight); } else { @@ -193,17 +223,16 @@ where } } -// Maps number of items to the "internal" size of the binary tree "implicitly" -// holding those items on the leaves. +// Maps number of items to the "internal" size of the tree +// which "implicitly" holds those items on the leaves. fn get_tree_size(count: usize) -> usize { - let shift = usize::BITS - - count.leading_zeros() - - if count.is_power_of_two() && count != 1 { - 1 - } else { - 0 - }; - (1usize << shift) - 1 + let mut size = if count == 1 { 1 } else { 0 }; + let mut nodes = 1; + while nodes < count { + size += nodes; + nodes *= FANOUT; + } + size } #[cfg(test)] @@ -251,25 +280,18 @@ mod tests { #[test] fn test_get_tree_size() { assert_eq!(get_tree_size(0), 0); - assert_eq!(get_tree_size(1), 1); - assert_eq!(get_tree_size(2), 1); - assert_eq!(get_tree_size(3), 3); - assert_eq!(get_tree_size(4), 3); - for count in 5..9 { - assert_eq!(get_tree_size(count), 7); + for count in 1..=16 { + assert_eq!(get_tree_size(count), 1); + } + for count in 17..=256 { + assert_eq!(get_tree_size(count), 1 + 16); } - for count in 9..17 { - assert_eq!(get_tree_size(count), 15); + for count in 257..=4096 { + assert_eq!(get_tree_size(count), 1 + 16 + 16 * 16); } - for count in 17..33 { - assert_eq!(get_tree_size(count), 31); + for count in 4097..=65536 { + assert_eq!(get_tree_size(count), 1 + 16 + 16 * 16 + 16 * 16 * 16); } - assert_eq!(get_tree_size((1 << 16) - 1), (1 << 16) - 1); - assert_eq!(get_tree_size(1 << 16), (1 << 16) - 1); - assert_eq!(get_tree_size((1 << 16) + 1), (1 << 17) - 1); - assert_eq!(get_tree_size((1 << 17) - 1), (1 << 17) - 1); - assert_eq!(get_tree_size(1 << 17), (1 << 17) - 1); - assert_eq!(get_tree_size((1 << 17) + 1), (1 << 18) - 1); } // Asserts that empty weights will return empty shuffle.