diff --git a/Cargo.toml b/Cargo.toml index da251903dd8..19f7573851c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,3 +75,8 @@ rand_pcg = { path = "rand_pcg", version = "0.4.0" } bincode = "1.2.1" rayon = "1.5.3" criterion = { version = "0.4" } + +[[bench]] +name = "seq_choose" +path = "benches/seq_choose.rs" +harness = false \ No newline at end of file diff --git a/benches/seq.rs b/benches/seq.rs index 5b3a846f60b..3d57d4872e6 100644 --- a/benches/seq.rs +++ b/benches/seq.rs @@ -13,9 +13,9 @@ extern crate test; use test::Bencher; +use core::mem::size_of; use rand::prelude::*; use rand::seq::*; -use core::mem::size_of; // We force use of 32-bit RNG since seq code is optimised for use with 32-bit // generators on all platforms. @@ -74,76 +74,6 @@ seq_slice_choose_multiple!(seq_slice_choose_multiple_950_of_1000, 950, 1000); seq_slice_choose_multiple!(seq_slice_choose_multiple_10_of_100, 10, 100); seq_slice_choose_multiple!(seq_slice_choose_multiple_90_of_100, 90, 100); -#[bench] -fn seq_iter_choose_from_1000(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &mut [usize] = &mut [1; 1000]; - for (i, r) in x.iter_mut().enumerate() { - *r = i; - } - b.iter(|| { - let mut s = 0; - for _ in 0..RAND_BENCH_N { - s += x.iter().choose(&mut rng).unwrap(); - } - s - }); - b.bytes = size_of::() as u64 * crate::RAND_BENCH_N; -} - -#[derive(Clone)] -struct UnhintedIterator { - iter: I, -} -impl Iterator for UnhintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - self.iter.next() - } -} - -#[derive(Clone)] -struct WindowHintedIterator { - iter: I, - window_size: usize, -} -impl Iterator for WindowHintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - self.iter.next() - } - - fn size_hint(&self) -> (usize, Option) { - (core::cmp::min(self.iter.len(), self.window_size), None) - } -} - -#[bench] -fn seq_iter_unhinted_choose_from_1000(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &[usize] = &[1; 1000]; - b.iter(|| { - UnhintedIterator { iter: x.iter() } - .choose(&mut rng) - .unwrap() - }) -} - -#[bench] -fn seq_iter_window_hinted_choose_from_1000(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &[usize] = &[1; 1000]; - b.iter(|| { - WindowHintedIterator { - iter: x.iter(), - window_size: 7, - } - .choose(&mut rng) - }) -} - #[bench] fn seq_iter_choose_multiple_10_of_100(b: &mut Bencher) { let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); diff --git a/benches/seq_choose.rs b/benches/seq_choose.rs new file mode 100644 index 00000000000..44b4bdf9724 --- /dev/null +++ b/benches/seq_choose.rs @@ -0,0 +1,111 @@ +// Copyright 2018-2022 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::prelude::*; +use rand::SeedableRng; + +criterion_group!( +name = benches; +config = Criterion::default(); +targets = bench +); +criterion_main!(benches); + +pub fn bench(c: &mut Criterion) { + bench_rng::(c, "ChaCha20"); + bench_rng::(c, "Pcg32"); + bench_rng::(c, "Pcg64"); +} + +fn bench_rng(c: &mut Criterion, rng_name: &'static str) { + for length in [1, 2, 3, 10, 100, 1000].map(|x| black_box(x)) { + c.bench_function( + format!("choose_size-hinted_from_{length}_{rng_name}").as_str(), + |b| { + let mut rng = Rng::seed_from_u64(123); + b.iter(|| choose_size_hinted(length, &mut rng)) + }, + ); + + c.bench_function( + format!("choose_stable_from_{length}_{rng_name}").as_str(), + |b| { + let mut rng = Rng::seed_from_u64(123); + b.iter(|| choose_stable(length, &mut rng)) + }, + ); + + c.bench_function( + format!("choose_unhinted_from_{length}_{rng_name}").as_str(), + |b| { + let mut rng = Rng::seed_from_u64(123); + b.iter(|| choose_unhinted(length, &mut rng)) + }, + ); + + c.bench_function( + format!("choose_windowed_from_{length}_{rng_name}").as_str(), + |b| { + let mut rng = Rng::seed_from_u64(123); + b.iter(|| choose_windowed(length, 7, &mut rng)) + }, + ); + } +} + +fn choose_size_hinted(max: usize, rng: &mut R) -> Option { + let iterator = 0..max; + iterator.choose(rng) +} + +fn choose_stable(max: usize, rng: &mut R) -> Option { + let iterator = 0..max; + iterator.choose_stable(rng) +} + +fn choose_unhinted(max: usize, rng: &mut R) -> Option { + let iterator = UnhintedIterator { iter: (0..max) }; + iterator.choose(rng) +} + +fn choose_windowed(max: usize, window_size: usize, rng: &mut R) -> Option { + let iterator = WindowHintedIterator { + iter: (0..max), + window_size, + }; + iterator.choose(rng) +} + +#[derive(Clone)] +struct UnhintedIterator { + iter: I, +} +impl Iterator for UnhintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + self.iter.next() + } +} + +#[derive(Clone)] +struct WindowHintedIterator { + iter: I, + window_size: usize, +} +impl Iterator for WindowHintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + self.iter.next() + } + + fn size_hint(&self) -> (usize, Option) { + (core::cmp::min(self.iter.len(), self.window_size), None) + } +} diff --git a/src/seq/coin_flipper.rs b/src/seq/coin_flipper.rs new file mode 100644 index 00000000000..77c18ded432 --- /dev/null +++ b/src/seq/coin_flipper.rs @@ -0,0 +1,152 @@ +use crate::RngCore; + +pub(crate) struct CoinFlipper { + pub rng: R, + chunk: u32, //TODO(opt): this should depend on RNG word size + chunk_remaining: u32, +} + +impl CoinFlipper { + pub fn new(rng: R) -> Self { + Self { + rng, + chunk: 0, + chunk_remaining: 0, + } + } + + #[inline] + /// Returns true with a probability of 1 / d + /// Uses an expected two bits of randomness + /// Panics if d == 0 + pub fn gen_ratio_one_over(&mut self, d: usize) -> bool { + debug_assert_ne!(d, 0); + // This uses the same logic as `gen_ratio` but is optimized for the case that + // the starting numerator is one (which it always is for `Sequence::Choose()`) + + // In this case (but not `gen_ratio`), this way of calculating c is always accurate + let c = (usize::BITS - 1 - d.leading_zeros()).min(32); + + if self.flip_c_heads(c) { + let numerator = 1 << c; + return self.gen_ratio(numerator, d); + } else { + return false; + } + } + + #[inline] + /// Returns true with a probability of n / d + /// Uses an expected two bits of randomness + fn gen_ratio(&mut self, mut n: usize, d: usize) -> bool { + // Explanation: + // We are trying to return true with a probability of n / d + // If n >= d, we can just return true + // Otherwise there are two possibilities 2n < d and 2n >= d + // In either case we flip a coin. + // If 2n < d + // If it comes up tails, return false + // If it comes up heads, double n and start again + // This is fair because (0.5 * 0) + (0.5 * 2n / d) = n / d and 2n is less than d + // (if 2n was greater than d we would effectively round it down to 1 + // by returning true) + // If 2n >= d + // If it comes up tails, set n to 2n - d and start again + // If it comes up heads, return true + // This is fair because (0.5 * 1) + (0.5 * (2n - d) / d) = n / d + // Note that if 2n = d and the coin comes up tails, n will be set to 0 + // before restarting which is equivalent to returning false. + + // As a performance optimization we can flip multiple coins at once + // This is efficient because we can use the `lzcnt` intrinsic + // We can check up to 32 flips at once but we only receive one bit of information + // - all heads or at least one tail. + + // Let c be the number of coins to flip. 1 <= c <= 32 + // If 2n < d, n * 2^c < d + // If the result is all heads, then set n to n * 2^c + // If there was at least one tail, return false + // If 2n >= d, the order of results matters so we flip one coin at a time so c = 1 + // Ideally, c will be as high as possible within these constraints + + while n < d { + // Find a good value for c by counting leading zeros + // This will either give the highest possible c, or 1 less than that + let c = n + .leading_zeros() + .saturating_sub(d.leading_zeros() + 1) + .clamp(1, 32); + + if self.flip_c_heads(c) { + // All heads + // Set n to n * 2^c + // If 2n >= d, the while loop will exit and we will return `true` + // If n * 2^c > `usize::MAX` we always return `true` anyway + n = n.saturating_mul(2_usize.pow(c)); + } else { + //At least one tail + if c == 1 { + // Calculate 2n - d. + // We need to use wrapping as 2n might be greater than `usize::MAX` + let next_n = n.wrapping_add(n).wrapping_sub(d); + if next_n == 0 || next_n > n { + // This will happen if 2n < d + return false; + } + n = next_n; + } else { + // c > 1 so 2n < d so we can return false + return false; + } + } + } + true + } + + /// If the next `c` bits of randomness all represent heads, consume them, return true + /// Otherwise return false and consume the number of heads plus one. + /// Generates new bits of randomness when necessary (in 32 bit chunks) + /// Has a 1 in 2 to the `c` chance of returning true + /// `c` must be less than or equal to 32 + fn flip_c_heads(&mut self, mut c: u32) -> bool { + debug_assert!(c <= 32); + // Note that zeros on the left of the chunk represent heads. + // It needs to be this way round because zeros are filled in when left shifting + loop { + let zeros = self.chunk.leading_zeros(); + + if zeros < c { + // The happy path - we found a 1 and can return false + // Note that because a 1 bit was detected, + // We cannot have run out of random bits so we don't need to check + + // First consume all of the bits read + // Using shl seems to give worse performance for size-hinted iterators + self.chunk = self.chunk.wrapping_shl(zeros + 1); + + self.chunk_remaining = self.chunk_remaining.saturating_sub(zeros + 1); + return false; + } else { + // The number of zeros is larger than `c` + // There are two possibilities + if let Some(new_remaining) = self.chunk_remaining.checked_sub(c) { + // Those zeroes were all part of our random chunk, + // throw away `c` bits of randomness and return true + self.chunk_remaining = new_remaining; + self.chunk <<= c; + return true; + } else { + // Some of those zeroes were part of the random chunk + // and some were part of the space behind it + // We need to take into account only the zeroes that were random + c -= self.chunk_remaining; + + // Generate a new chunk + self.chunk = self.rng.next_u32(); + self.chunk_remaining = 32; + // Go back to start of loop + } + } + } + } +} diff --git a/src/seq/mod.rs b/src/seq/mod.rs index 24c65bc9f08..15fca6fc9c0 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -24,20 +24,25 @@ //! `usize` indices are sampled as a `u32` where possible (also providing a //! small performance boost in some cases). - +mod coin_flipper; #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] pub mod index; -#[cfg(feature = "alloc")] use core::ops::Index; +#[cfg(feature = "alloc")] +use core::ops::Index; -#[cfg(feature = "alloc")] use alloc::vec::Vec; +#[cfg(feature = "alloc")] +use alloc::vec::Vec; #[cfg(feature = "alloc")] use crate::distributions::uniform::{SampleBorrow, SampleUniform}; -#[cfg(feature = "alloc")] use crate::distributions::WeightedError; +#[cfg(feature = "alloc")] +use crate::distributions::WeightedError; use crate::Rng; +use self::coin_flipper::CoinFlipper; + /// Extension trait on slices, providing random mutation and sampling methods. /// /// This trait is implemented on all `[T]` slice types, providing several @@ -77,14 +82,16 @@ pub trait SliceRandom { /// assert_eq!(choices[..0].choose(&mut rng), None); /// ``` fn choose(&self, rng: &mut R) -> Option<&Self::Item> - where R: Rng + ?Sized; + where + R: Rng + ?Sized; /// Returns a mutable reference to one random element of the slice, or /// `None` if the slice is empty. /// /// For slices, complexity is `O(1)`. fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Item> - where R: Rng + ?Sized; + where + R: Rng + ?Sized; /// Chooses `amount` elements from the slice at random, without repetition, /// and in random order. The returned iterator is appropriate both for @@ -113,7 +120,8 @@ pub trait SliceRandom { #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter - where R: Rng + ?Sized; + where + R: Rng + ?Sized; /// Similar to [`choose`], but where the likelihood of each outcome may be /// specified. @@ -249,7 +257,8 @@ pub trait SliceRandom { /// println!("Shuffled: {:?}", y); /// ``` fn shuffle(&mut self, rng: &mut R) - where R: Rng + ?Sized; + where + R: Rng + ?Sized; /// Shuffle a slice in place, but exit early. /// @@ -271,7 +280,8 @@ pub trait SliceRandom { fn partial_shuffle( &mut self, rng: &mut R, amount: usize, ) -> (&mut [Self::Item], &mut [Self::Item]) - where R: Rng + ?Sized; + where + R: Rng + ?Sized; } /// Extension trait on iterators, providing random sampling methods. @@ -309,26 +319,30 @@ pub trait IteratorRandom: Iterator + Sized { /// `choose` returning different elements. If you want consistent results /// and RNG usage consider using [`IteratorRandom::choose_stable`]. fn choose(mut self, rng: &mut R) -> Option - where R: Rng + ?Sized { + where + R: Rng + ?Sized, + { let (mut lower, mut upper) = self.size_hint(); - let mut consumed = 0; let mut result = None; // Handling for this condition outside the loop allows the optimizer to eliminate the loop // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g. // seq_iter_choose_from_1000. if upper == Some(lower) { - return if lower == 0 { - None - } else { - self.nth(gen_index(rng, lower)) + return match lower { + 0 => None, + 1 => self.next(), + _ => self.nth(gen_index(rng, lower)), }; } + let mut coin_flipper = coin_flipper::CoinFlipper::new(rng); + let mut consumed = 0; + // Continue until the iterator is exhausted loop { if lower > 1 { - let ix = gen_index(rng, lower + consumed); + let ix = gen_index(coin_flipper.rng, lower + consumed); let skip = if ix < lower { result = self.nth(ix); lower - (ix + 1) @@ -348,7 +362,7 @@ pub trait IteratorRandom: Iterator + Sized { return result; } consumed += 1; - if gen_index(rng, consumed) == 0 { + if coin_flipper.gen_ratio_one_over(consumed) { result = elem; } } @@ -378,9 +392,12 @@ pub trait IteratorRandom: Iterator + Sized { /// /// [`choose`]: IteratorRandom::choose fn choose_stable(mut self, rng: &mut R) -> Option - where R: Rng + ?Sized { + where + R: Rng + ?Sized, + { let mut consumed = 0; let mut result = None; + let mut coin_flipper = CoinFlipper::new(rng); loop { // Currently the only way to skip elements is `nth()`. So we need to @@ -392,7 +409,7 @@ pub trait IteratorRandom: Iterator + Sized { let (lower, _) = self.size_hint(); if lower >= 2 { let highest_selected = (0..lower) - .filter(|ix| gen_index(rng, consumed+ix+1) == 0) + .filter(|ix| coin_flipper.gen_ratio_one_over(consumed + ix + 1)) .last(); consumed += lower; @@ -407,10 +424,10 @@ pub trait IteratorRandom: Iterator + Sized { let elem = self.nth(next); if elem.is_none() { - return result + return result; } - if gen_index(rng, consumed+1) == 0 { + if coin_flipper.gen_ratio_one_over(consumed + 1) { result = elem; } consumed += 1; @@ -431,7 +448,9 @@ pub trait IteratorRandom: Iterator + Sized { /// Complexity is `O(n)` where `n` is the length of the iterator. /// For slices, prefer [`SliceRandom::choose_multiple`]. fn choose_multiple_fill(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize - where R: Rng + ?Sized { + where + R: Rng + ?Sized, + { let amount = buf.len(); let mut len = 0; while len < amount { @@ -471,7 +490,9 @@ pub trait IteratorRandom: Iterator + Sized { #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] fn choose_multiple(mut self, rng: &mut R, amount: usize) -> Vec - where R: Rng + ?Sized { + where + R: Rng + ?Sized, + { let mut reservoir = Vec::with_capacity(amount); reservoir.extend(self.by_ref().take(amount)); @@ -495,12 +516,13 @@ pub trait IteratorRandom: Iterator + Sized { } } - impl SliceRandom for [T] { type Item = T; fn choose(&self, rng: &mut R) -> Option<&Self::Item> - where R: Rng + ?Sized { + where + R: Rng + ?Sized, + { if self.is_empty() { None } else { @@ -509,7 +531,9 @@ impl SliceRandom for [T] { } fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Item> - where R: Rng + ?Sized { + where + R: Rng + ?Sized, + { if self.is_empty() { None } else { @@ -520,7 +544,9 @@ impl SliceRandom for [T] { #[cfg(feature = "alloc")] fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter - where R: Rng + ?Sized { + where + R: Rng + ?Sized, + { let amount = ::core::cmp::min(amount, self.len()); SliceChooseIter { slice: self, @@ -591,7 +617,9 @@ impl SliceRandom for [T] { } fn shuffle(&mut self, rng: &mut R) - where R: Rng + ?Sized { + where + R: Rng + ?Sized, + { for i in (1..self.len()).rev() { // invariant: elements with index > i have been locked in place. self.swap(i, gen_index(rng, i + 1)); @@ -601,7 +629,9 @@ impl SliceRandom for [T] { fn partial_shuffle( &mut self, rng: &mut R, amount: usize, ) -> (&mut [Self::Item], &mut [Self::Item]) - where R: Rng + ?Sized { + where + R: Rng + ?Sized, + { // This applies Durstenfeld's algorithm for the // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) // for an unbiased permutation, but exits early after choosing `amount` @@ -621,7 +651,6 @@ impl SliceRandom for [T] { impl IteratorRandom for I where I: Iterator + Sized {} - /// An iterator over multiple slice elements. /// /// This struct is created by @@ -658,12 +687,12 @@ impl<'a, S: Index + ?Sized + 'a, T: 'a> ExactSizeIterator } } - // Sample a number uniformly between 0 and `ubound`. Uses 32-bit sampling where // possible, primarily in order to produce the same output on 32-bit and 64-bit // platforms. #[inline] fn gen_index(rng: &mut R, ubound: usize) -> usize { + if ubound <= (core::u32::MAX as usize) { rng.gen_range(0..ubound as u32) as usize } else { @@ -671,12 +700,13 @@ fn gen_index(rng: &mut R, ubound: usize) -> usize { } } - #[cfg(test)] mod test { use super::*; - #[cfg(feature = "alloc")] use crate::Rng; - #[cfg(all(feature = "alloc", not(feature = "std")))] use alloc::vec::Vec; + #[cfg(feature = "alloc")] + use crate::Rng; + #[cfg(all(feature = "alloc", not(feature = "std")))] + use alloc::vec::Vec; #[test] fn test_slice_choose() { @@ -837,28 +867,40 @@ mod test { #[cfg(feature = "alloc")] test_iter(r, (0..9).collect::>().into_iter()); test_iter(r, UnhintedIterator { iter: 0..9 }); - test_iter(r, ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: false, - }); - test_iter(r, ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: true, - }); - test_iter(r, WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: false, - }); - test_iter(r, WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: true, - }); + test_iter( + r, + ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: false, + }, + ); + test_iter( + r, + ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: true, + }, + ); + test_iter( + r, + WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: false, + }, + ); + test_iter( + r, + WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: true, + }, + ); assert_eq!((0..0).choose(r), None); assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); @@ -891,28 +933,40 @@ mod test { #[cfg(feature = "alloc")] test_iter(r, (0..9).collect::>().into_iter()); test_iter(r, UnhintedIterator { iter: 0..9 }); - test_iter(r, ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: false, - }); - test_iter(r, ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: true, - }); - test_iter(r, WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: false, - }); - test_iter(r, WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: true, - }); + test_iter( + r, + ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: false, + }, + ); + test_iter( + r, + ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: true, + }, + ); + test_iter( + r, + WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: false, + }, + ); + test_iter( + r, + WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: true, + }, + ); assert_eq!((0..0).choose(r), None); assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); @@ -932,33 +986,48 @@ mod test { } let reference = test_iter(0..9); - assert_eq!(test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), reference); + assert_eq!( + test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), + reference + ); #[cfg(feature = "alloc")] assert_eq!(test_iter((0..9).collect::>().into_iter()), reference); assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference); - assert_eq!(test_iter(ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: false, - }), reference); - assert_eq!(test_iter(ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: true, - }), reference); - assert_eq!(test_iter(WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: false, - }), reference); - assert_eq!(test_iter(WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: true, - }), reference); + assert_eq!( + test_iter(ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: false, + }), + reference + ); + assert_eq!( + test_iter(ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: true, + }), + reference + ); + assert_eq!( + test_iter(WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: false, + }), + reference + ); + assert_eq!( + test_iter(WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: true, + }), + reference + ); } #[test] @@ -1129,7 +1198,7 @@ mod test { assert_eq!(choose([].iter().cloned()), None); assert_eq!(choose(0..100), Some(33)); - assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(40)); + assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27)); assert_eq!( choose(ChunkHintedIterator { iter: 0..100, @@ -1174,8 +1243,8 @@ mod test { } assert_eq!(choose([].iter().cloned()), None); - assert_eq!(choose(0..100), Some(40)); - assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(40)); + assert_eq!(choose(0..100), Some(27)); + assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27)); assert_eq!( choose(ChunkHintedIterator { iter: 0..100, @@ -1183,7 +1252,7 @@ mod test { chunk_remaining: 32, hint_total_size: false, }), - Some(40) + Some(27) ); assert_eq!( choose(ChunkHintedIterator { @@ -1192,7 +1261,7 @@ mod test { chunk_remaining: 32, hint_total_size: true, }), - Some(40) + Some(27) ); assert_eq!( choose(WindowHintedIterator { @@ -1200,7 +1269,7 @@ mod test { window_size: 32, hint_total_size: false, }), - Some(40) + Some(27) ); assert_eq!( choose(WindowHintedIterator { @@ -1208,7 +1277,7 @@ mod test { window_size: 32, hint_total_size: true, }), - Some(40) + Some(27) ); } @@ -1260,9 +1329,13 @@ mod test { // Case 2: All of the weights are 0 let choices = [('a', 0), ('b', 0), ('c', 0)]; - assert_eq!(choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap().count(), 2); + assert_eq!( + choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .count(), + 2 + ); // Case 3: Negative weights let choices = [('a', -1), ('b', 1), ('c', 1)]; @@ -1275,9 +1348,13 @@ mod test { // Case 4: Empty list let choices = []; - assert_eq!(choices - .choose_multiple_weighted(&mut rng, 0, |_: &()| 0) - .unwrap().count(), 0); + assert_eq!( + choices + .choose_multiple_weighted(&mut rng, 0, |_: &()| 0) + .unwrap() + .count(), + 0 + ); // Case 5: NaN weights let choices = [('a', core::f64::NAN), ('b', 1.0), ('c', 1.0)];