Skip to content

Commit

Permalink
central limit improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
TheIronBorn committed Jun 19, 2018
1 parent 2d04859 commit a418ea3
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 68 deletions.
4 changes: 2 additions & 2 deletions benches/box_muller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::mem::{size_of, transmute};
use stdsimd::simd::*;
use test::Bencher;

use rand::{Rng, FromEntropy};
use rand::{Rng, RngCore, FromEntropy};
use rand::prng::{SfcAltSplit64x2a, XorShiftRng};
use rand::prng::hc128::Hc128Rng;
use rand::distributions::box_muller::{BoxMuller, BoxMullerCore};
Expand Down Expand Up @@ -84,7 +84,7 @@ macro_rules! distr_fx {
}
}

// module structure to allow easy `cargo benchcmp` use
// module structure to ease `cargo benchcmp` use

// hacked sin_cos method
mod hacked {
Expand Down
1 change: 1 addition & 0 deletions benches/generators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ init_gen!(init_hc128, Hc128Rng);
init_gen!(init_isaac, IsaacRng);
init_gen!(init_isaac64, Isaac64Rng);
init_gen!(init_chacha, ChaChaRng);
#[cfg(features = "simd_support")]
init_gen!(init_sfc32x4, Sfc32x4Rng);

#[bench]
Expand Down
10 changes: 8 additions & 2 deletions rand_core/src/simd_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ use stdsimd::simd::*;
///
/// ```rust
/// #![feature(stdsimd)]
/// use std::simd::u32x4;
/// extern crate stdsimd;
/// extern crate rand_core;
/// use stdsimd::simd::u32x4;
/// use rand_core::simd_impls::SimdRng;
///
/// #[allow(dead_code)]
/// struct CountingSimdRng(u32x4);
///
/// impl SimdRng<u32x4> for CountingSimdRng {
Expand All @@ -39,10 +42,13 @@ pub trait SimdRng<Vector> {
///
/// ```rust
/// #![feature(stdsimd)]
/// use std::simd::u32x4;
/// extern crate stdsimd;
/// extern crate rand_core;
/// use stdsimd::simd::u32x4;
/// use rand_core::{RngCore, Error};
/// use rand_core::simd_impls::{SimdRng, SimdRngImpls};
///
/// #[allow(dead_code)]
/// struct CountingSimdRng(u32x4);
///
/// impl SimdRng<u32x4> for CountingSimdRng {
Expand Down
1 change: 1 addition & 0 deletions src/distributions/box_muller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use core::intrinsics::fmaf64;
#[cfg(feature="simd_support")]
use core::mem::*;
#[cfg(feature="simd_support")]
#[allow(unused_imports)]
use core::{f32, f64};

use Rng;
Expand Down
168 changes: 107 additions & 61 deletions src/distributions/central_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ use stdsimd::simd::*;
use Rng;
use distributions::Distribution;

///
pub trait CentralLimit<T> {
///
fn new(mean: T, std_dev: T) -> Self;
}

/// The normal distribution `N(mean, std_dev**2)`.
///
/// This uses the central limit theorem. It is well suited to an SIMD
Expand All @@ -22,113 +28,153 @@ use distributions::Distribution;
/// println!("{} is from a N(2, 9) distribution", v)
/// ```
#[derive(Clone, Copy, Debug)]
pub struct CentralLimit<T> {
mean: f64,
std_dev: f64,
phantom: PhantomData<T>,
pub struct CentralLimitVector<T> {
mean: T,
std_dev: T,
}

impl<T> CentralLimit<T> {
/// Construct a new `CentralLimit` distribution with the given mean and
/// standard deviation.
///
/// # Panics
///
/// Panics if `std_dev < 0`.
#[inline]
// TODO: implement for vectors/f32
pub fn new(mean: f64, std_dev: f64) -> Self {
assert!(std_dev >= 0.0, "CentralLimit::new called with `std_dev` < 0");
Self {
mean,
std_dev,
phantom: PhantomData,
}
}
}

macro_rules! impl_simd {
macro_rules! impl_clt_vector {
($ty:ident, $scalar:ty, $num:expr) => (
impl Distribution<$ty> for CentralLimit<$ty> {
impl CentralLimit<$ty> for CentralLimitVector<$ty> {
/// Construct a new `CentralLimitVector` distribution with the given mean and
/// standard deviation.
///
/// # Panics
///
/// Panics if `std_dev < 0`.
#[inline]
fn new(mean: $ty, std_dev: $ty) -> Self {
assert!(std_dev.ge($ty::splat(0.0)).all(), "CentralLimitVector::new called with `std_dev` < 0");
Self { mean, std_dev }
}
}

impl Distribution<$ty> for CentralLimitVector<$ty> {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
// Irwin–Hall mean and std_dev (sqrt variance)
// Irwin–Hall mean and std_dev (square root of variance)
// https://en.wikipedia.org/wiki/Irwin–Hall_distribution
const IH_MEAN: $scalar = $num as $scalar / 2.0;
// TODO: verify eval at compile
// TODO: benchmark `sqrt` vs `sqrte`
let ih_std_dev_inv = 1.0 / ($num as $scalar / 12.0).sqrt();

// get Irwin–Hall distr
// get Irwin–Hall distribution
let mut sum = $ty::default();
for _ in 0..$num {
sum += rng.gen::<$ty>();
}
// adjust Irwin–Hall distr to normal distr
// TODO: look into optimizing/combining the two distribution
// adjustments
let n = (sum - IH_MEAN) * ih_std_dev_inv;

self.mean + self.std_dev * n
// adjust Irwin–Hall distribution to specified normal distribution
// TODO: ensure optimized when mean and std_dev are SIMD vectors
// NOTE: variable names here might be misleading
// NOTE: this is fast when mean and std_dev are compile-time constant,
// slower than other math when not. We prioritize the constant
// case here.
let std_dev = self.std_dev * ih_std_dev_inv;
let mean = self.mean - IH_MEAN * std_dev;
mean + std_dev * sum
}
}
)
}

// TODO: tune for better number of samples?
/*impl_simd! { f32x2, f32, 4 }
impl_simd! { f32x4, f32, 4 }
impl_simd! { f32x8, f32, 4 }
impl_simd! { f32x16, f32, 4 }*/
impl_simd! { f64x2, f64, 4 }
impl_simd! { f64x4, f64, 4 }
impl_simd! { f64x8, f64, 4 }

macro_rules! impl_simd_to_scalar {
impl_clt_vector! { f32x2, f32, 4 }
impl_clt_vector! { f32x4, f32, 4 }
impl_clt_vector! { f32x8, f32, 4 }
impl_clt_vector! { f32x16, f32, 4 }
impl_clt_vector! { f64x2, f64, 4 }
impl_clt_vector! { f64x4, f64, 4 }
impl_clt_vector! { f64x8, f64, 4 }

/// The normal distribution `N(mean, std_dev**2)`.
///
/// This uses the central limit theorem. It is well suited to an SIMD
/// implementation, even on older hardware.
///
/// # Example
///
/// ```
/// use rand::distributions::{CentralLimit, Distribution};
///
/// // mean 2, standard deviation 3
/// let normal = CentralLimit::new(2.0, 3.0);
/// let v = normal.sample(&mut rand::thread_rng());
/// println!("{} is from a N(2, 9) distribution", v)
/// ```
#[derive(Clone, Copy, Debug)]
pub struct CentralLimitScalar<T, V> {
mean: T,
std_dev: T,
phantom: PhantomData<V>,
}


macro_rules! impl_clt_scalar {
($ty:ident, $scalar:ty) => (
impl Distribution<$scalar> for CentralLimit<$ty> {
impl CentralLimit<$scalar> for CentralLimitScalar<$scalar, $ty> {
/// Construct a new `CentralLimitScalar` distribution with the given mean and
/// standard deviation.
///
/// # Panics
///
/// Panics if `std_dev < 0`.
#[inline]
fn new(mean: $scalar, std_dev: $scalar) -> Self {
assert!(std_dev >= 0.0, "CentralLimitScalar::new called with `std_dev` < 0");
Self { mean, std_dev, phantom: PhantomData }
}
}

impl Distribution<$scalar> for CentralLimitScalar<$scalar, $ty> {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $scalar {
const IH_MEAN: $scalar = $ty::lanes / 2.0;
let ih_std_dev_inv = 1.0 / ($ty::lanes / 12.0).sqrt();
// Irwin–Hall mean and std_dev (square root of variance)
// https://en.wikipedia.org/wiki/Irwin–Hall_distribution
const IH_MEAN: $scalar = $ty::lanes() as $scalar / 2.0;
let ih_std_dev_inv = 1.0 / ($ty::lanes() as $scalar / 12.0).sqrt();

// get Irwin–Hall distribution
let sum = rng.gen::<$ty>().sum();

let n = (rng.gen::<$ty>().sum() - IH_MEAN) * ih_std_dev_inv;
self.mean + self.std_dev * n
// adjust Irwin–Hall distribution to specified normal distribution
let std_dev = self.std_dev * ih_std_dev_inv;
let mean = self.mean - IH_MEAN * std_dev;
mean + std_dev * sum
}
}
)
}

/*impl_simd_to_scalar! { f32x2, f32 }
impl_simd_to_scalar! { f32x4, f32 }
impl_simd_to_scalar! { f32x8, f32 }
impl_simd_to_scalar! { f32x16, f32 }*/
impl_simd_to_scalar! { f64x2, f64 }
impl_simd_to_scalar! { f64x4, f64 }
impl_simd_to_scalar! { f64x8, f64 }
impl_clt_scalar! { f32x2, f32 }
impl_clt_scalar! { f32x4, f32 }
impl_clt_scalar! { f32x8, f32 }
impl_clt_scalar! { f32x16, f32 }
impl_clt_scalar! { f64x2, f64 }
impl_clt_scalar! { f64x4, f64 }
impl_clt_scalar! { f64x8, f64 }

#[cfg(test)]
mod tests {
use stdsimd::simd::*;
use super::{Rng, CentralLimit};
use super::*;

#[test]
fn test_clt_vector() {
let norm = CentralLimit::<f64x2>::new(10.0, 10.0);
let norm = CentralLimitVector::new(f64x2::splat(10.0), f64x2::splat(10.0));
let mut rng = ::test::rng(210);
for _ in 0..1000 {
let _: f64x2 = rng.sample(norm);
rng.sample(norm);
}
}
#[test]
fn test_clt_scalar() {
let norm = CentralLimit::<f64x2>::new(10.0, 10.0);
let norm = CentralLimitScalar::<f64, f64x2>::new(10.0, 10.0);
let mut rng = ::test::rng(210);
for _ in 0..1000 {
let _: f64 = rng.sample(norm);
rng.sample(norm);
}
}
#[test]
#[should_panic]
fn test_clt_invalid_sd() {
CentralLimit::<f32x2>::new(10.0, -1.0);
CentralLimitScalar::<f64, f64x2>::new(10.0, -1.0);
}
}
6 changes: 3 additions & 3 deletions src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
//! of course [`sample`].
//!
//! Abstractly, a [probability distribution] describes the probability of
//! occurance of each value in its sample space.
//! occurrence of each value in its sample space.
//!
//! More concretely, an implementation of `Distribution<T>` for type `X` is an
//! algorithm for choosing values from the sample space (a subset of `T`)
Expand Down Expand Up @@ -182,7 +182,7 @@ pub use self::uniform::Uniform as Range;
#[cfg(feature = "std")]
#[doc(inline)] pub use self::binomial::Binomial;
#[doc(inline)] pub use self::bernoulli::Bernoulli;
#[cfg(feature="simd_support")] // neccessary for doc tests?
#[cfg(feature="simd_support")] // necessary for doc tests?
pub use self::box_muller::{BoxMuller, BoxMullerCore, LogBoxMuller};
#[doc(inline)] pub use self::cauchy::Cauchy;

Expand All @@ -192,7 +192,7 @@ pub mod uniform;
#[cfg(feature="std")]
#[doc(hidden)] pub mod normal;
#[cfg(all(feature="std", feature = "simd_support"))]
#[doc(hidden)] pub mod central_limit;
pub mod central_limit;
#[cfg(feature="std")]
#[doc(hidden)] pub mod exponential;
#[cfg(feature = "std")]
Expand Down

0 comments on commit a418ea3

Please sign in to comment.