diff --git a/proptest-regressions/math.txt b/proptest-regressions/math.txt index 5f9e272..0edc464 100644 --- a/proptest-regressions/math.txt +++ b/proptest-regressions/math.txt @@ -7,3 +7,4 @@ cc ea2a2598ee637946e47a8a744d25f76239671b82c5971a56b14ce5ee06838cb4 # shrinks to x = 4.8329699435311735, y = 9.38911339170414 cc cf16a8d08e8ee8f7f3d3cfd60840e136ac51d130dffcd42db1a9a68d7e51f394 # shrinks to (x, y) = ([2.9394791070664547e110, 0.0], [inf, 0.0]), a = -2.4153502104628106e222 cc 28897b64919482133f3885c3de51da0895409d23c9dd503a7b51a3e949bda307 # shrinks to (x1, x2, x3, y1, y2) = ([0.0], [0.0], [-4.0946726283401733e139], [0.0], [1.3157422010991668e73]) +cc acf6caef8a89a75ddab31ec3e391850723a625084df032aec2b650c2f95ba1fb # shrinks to (x, y) = ([0.0, 0.0, 0.0, 1.2271235629394547e205, 0.0, 0.0, -0.0, 0.0], [0.0, 0.0, 0.0, 7.121658452243713e81, 0.0, 0.0, 0.0, 0.0]), a = -6.261465657118442e-124 diff --git a/src/cpu_math.rs b/src/cpu_math.rs index f335e5e..81cb489 100644 --- a/src/cpu_math.rs +++ b/src/cpu_math.rs @@ -115,10 +115,12 @@ impl Math for CpuMath { } fn array_all_finite_and_nonzero(&mut self, array: &Self::Array) -> bool { - array - .col_ref(0) - .iter() - .all(|&x| x.is_finite() & (x != 0f64)) + self.arch.dispatch(|| { + array + .col_ref(0) + .iter() + .all(|&x| x.is_finite() & (x != 0f64)) + }) } fn array_mult(&mut self, array1: &Self::Array, array2: &Self::Array, dest: &mut Self::Array) { @@ -152,16 +154,18 @@ impl Math for CpuMath { value: &Self::Array, diff_scale: f64, // 1 / self.count ) { - izip!( - mean.col_mut(0).iter_mut(), - variance.col_mut(0).iter_mut(), - value.col_ref(0) - ) - .for_each(|(mean, mut var, x)| { - let diff = x - *mean; - *mean += diff * diff_scale; - *var += diff * diff; - }); + self.arch.dispatch(|| { + izip!( + mean.col_mut(0).iter_mut(), + variance.col_mut(0).iter_mut(), + value.col_ref(0) + ) + .for_each(|(mean, var, x)| { + let diff = x - *mean; + *mean += diff * diff_scale; + *var += diff * diff; + }); + }) } fn array_update_var_inv_std_draw_grad( @@ -173,24 +177,26 @@ impl Math for CpuMath { fill_invalid: Option, clamp: (f64, f64), ) { - izip!( - variance_out.col_mut(0).iter_mut(), - inv_std.col_mut(0).iter_mut(), - draw_var.col_ref(0).iter(), - grad_var.col_ref(0).iter(), - ) - .for_each(|(var_out, inv_std_out, &draw_var, &grad_var)| { - let val = (draw_var / grad_var).sqrt(); - if (!val.is_finite()) | (val == 0f64) { - if let Some(fill_val) = fill_invalid { - *var_out = fill_val; - *inv_std_out = fill_val.recip().sqrt(); + self.arch.dispatch(|| { + izip!( + variance_out.col_mut(0).iter_mut(), + inv_std.col_mut(0).iter_mut(), + draw_var.col_ref(0).iter(), + grad_var.col_ref(0).iter(), + ) + .for_each(|(var_out, inv_std_out, &draw_var, &grad_var)| { + let val = (draw_var / grad_var).sqrt(); + if (!val.is_finite()) | (val == 0f64) { + if let Some(fill_val) = fill_invalid { + *var_out = fill_val; + *inv_std_out = fill_val.recip().sqrt(); + } + } else { + let val = val.clamp(clamp.0, clamp.1); + *var_out = val; + *inv_std_out = val.recip().sqrt(); } - } else { - let val = val.clamp(clamp.0, clamp.1); - *var_out = val; - *inv_std_out = val.recip().sqrt(); - } + }); }); } @@ -202,16 +208,18 @@ impl Math for CpuMath { fill_invalid: f64, clamp: (f64, f64), ) { - izip!( - variance_out.col_mut(0).iter_mut(), - inv_std.col_mut(0).iter_mut(), - gradient.col_ref(0).iter(), - ) - .for_each(|(var_out, inv_std_out, &grad_var)| { - let val = grad_var.abs().clamp(clamp.0, clamp.1).recip(); - let val = if val.is_finite() { val } else { fill_invalid }; - *var_out = val; - *inv_std_out = val.recip().sqrt(); + self.arch.dispatch(|| { + izip!( + variance_out.col_mut(0).iter_mut(), + inv_std.col_mut(0).iter_mut(), + gradient.col_ref(0).iter(), + ) + .for_each(|(var_out, inv_std_out, &grad_var)| { + let val = grad_var.abs().clamp(clamp.0, clamp.1).recip(); + let val = if val.is_finite() { val } else { fill_invalid }; + *var_out = val; + *inv_std_out = val.recip().sqrt(); + }); }); } } diff --git a/src/mass_matrix.rs b/src/mass_matrix.rs index c6b89bf..f9878e6 100644 --- a/src/mass_matrix.rs +++ b/src/mass_matrix.rs @@ -1,6 +1,3 @@ -use itertools::izip; -use multiversion::multiversion; - use crate::{ math_base::Math, nuts::Collector, @@ -71,23 +68,6 @@ impl DiagMassMatrix { } } -#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] -fn update_diag( - variance_out: &mut [f64], - inv_std_out: &mut [f64], - new_variance: impl Iterator>, -) { - izip!(variance_out, inv_std_out, new_variance).for_each(|(var, inv_std, x)| { - if let Some(x) = x { - assert!(x.is_finite(), "Illegal value on mass matrix: {}", x); - assert!(x > 0f64, "Illegal value on mass matrix: {}", x); - //assert!(*var != x, "No change in mass matrix from {} to {}", *var, x); - *var = x; - *inv_std = (1. / x).sqrt(); - }; - }); -} - impl MassMatrix for DiagMassMatrix { fn update_velocity(&self, math: &mut M, state: &mut InnerState) { math.array_mult(&self.variance, &state.p, &mut state.v); diff --git a/src/math.rs b/src/math.rs index dadd322..7161fde 100644 --- a/src/math.rs +++ b/src/math.rs @@ -373,7 +373,7 @@ mod tests { let mut y = y.clone(); axpy(&x[..], &mut y[..], a); for ((&x, y), out) in x.iter().zip(orig).zip(y) { - assert_approx_eq(out, a * x + y); + assert_approx_eq(out, a.mul_add(x, y)); } } diff --git a/src/nuts.rs b/src/nuts.rs index 4787068..0efc3bf 100644 --- a/src/nuts.rs +++ b/src/nuts.rs @@ -1069,7 +1069,7 @@ mod tests { fn to_arrow() { let ndim = 10; let func = NormalLogp::new(ndim, 3.); - let mut math = CpuMath::new(func); + let math = CpuMath::new(func); let settings = SamplerArgs::default(); let mut rng = thread_rng(); diff --git a/src/state.rs b/src/state.rs index 489301b..45e83d2 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,8 +1,7 @@ use std::{ cell::RefCell, fmt::Debug, - marker::PhantomData, - ops::{Deref, DerefMut}, + ops::Deref, rc::{Rc, Weak}, }; @@ -54,7 +53,6 @@ pub(crate) struct InnerState { pub(crate) idx_in_trajectory: i64, pub(crate) kinetic_energy: f64, pub(crate) potential_energy: f64, - _phantom_todo: PhantomData, } pub(crate) struct InnerStateReusable { @@ -74,7 +72,6 @@ impl<'pool, M: Math> InnerStateReusable { idx_in_trajectory: 0, kinetic_energy: 0., potential_energy: 0., - _phantom_todo: PhantomData::default(), }, reuser: Rc::downgrade(&Rc::clone(&owner.storage)), } @@ -225,7 +222,7 @@ mod tests { fn crate_pool() { let logp = NormalLogp::new(10, 0.2); let mut math = CpuMath::new(logp); - let mut pool = StatePool::new(&mut math, 10); + let pool = StatePool::new(&mut math, 10); let mut state = pool.new_state(&mut math); assert!(state.p.nrows() == 10); assert!(state.p.ncols() == 1); @@ -241,7 +238,7 @@ mod tests { let dim = 10; let logp = NormalLogp::new(dim, 0.2); let mut math = CpuMath::new(logp); - let mut pool = StatePool::new(&mut math, 10); + let pool = StatePool::new(&mut math, 10); let a = pool.new_state(&mut math); assert_eq!(a.idx_in_trajectory, 0);