From 5b87e48ab335c633257f3d179b6f5f584cb4a798 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Sun, 22 Oct 2023 18:41:42 +0200 Subject: [PATCH] refactor: Prepare for GPU support --- Cargo.toml | 7 +- src/adapt_strategy.rs | 256 +++++++++--------- src/cpu_math.rs | 224 ++++++++++++++++ src/cpu_state.rs | 342 ------------------------- src/lib.rs | 34 +-- src/mass_matrix.rs | 154 ++++++----- src/math_base.rs | 102 ++++++++ src/nuts.rs | 266 +++++++++++-------- src/{cpu_potential.rs => potential.rs} | 153 +++++------ src/{cpu_sampler.rs => sampler.rs} | 191 +++----------- src/state.rs | 254 ++++++++++++++++++ src/stepsize.rs | 23 +- 12 files changed, 1097 insertions(+), 909 deletions(-) create mode 100644 src/cpu_math.rs delete mode 100644 src/cpu_state.rs create mode 100644 src/math_base.rs rename src/{cpu_potential.rs => potential.rs} (50%) rename src/{cpu_sampler.rs => sampler.rs} (55%) create mode 100644 src/state.rs diff --git a/Cargo.toml b/Cargo.toml index 74410bc..6dc0fcd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nuts-rs" -version = "0.8.0" +version = "0.9.0" authors = ["Adrian Seyboldt ", "PyMC Developers "] edition = "2021" license = "MIT" @@ -19,12 +19,13 @@ rand = { version = "0.8.5", features = ["small_rng"] } rand_distr = "0.4.3" multiversion = "0.7.2" itertools = "0.11.0" -crossbeam = "0.8.2" thiserror = "1.0.43" -rayon = "1.7.0" arrow2 = { version = "0.17.3", optional = true } rand_chacha = "0.3.1" anyhow = "1.0.72" +faer = "0.13.5" +faer-core = "0.13.5" +pulp = "0.17.0" [dev-dependencies] proptest = "1.2.0" diff --git a/src/adapt_strategy.rs b/src/adapt_strategy.rs index b0f50f8..7a2aced 100644 --- a/src/adapt_strategy.rs +++ b/src/adapt_strategy.rs @@ -5,12 +5,13 @@ use arrow2::{ array::{MutableArray, MutableFixedSizeListArray, MutablePrimitiveArray, StructArray, TryPush}, datatypes::{DataType, Field}, }; -use itertools::izip; use crate::{ - cpu_potential::{CpuLogpFunc, EuclideanPotential}, mass_matrix::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance}, - nuts::{AdaptStrategy, Collector, Hamiltonian, NutsOptions}, + math_base::Math, + nuts::{AdaptStrategy, Collector, NutsOptions}, + potential::EuclideanPotential, + state::State, stepsize::{AcceptanceRateCollector, DualAverage, DualAverageOptions}, DivergenceInfo, }; @@ -126,13 +127,13 @@ impl Default for DualAverageSettings { } } -impl AdaptStrategy for DualAverageStrategy { - type Potential = EuclideanPotential; - type Collector = AcceptanceRateCollector; +impl> AdaptStrategy for DualAverageStrategy { + type Potential = EuclideanPotential; + type Collector = AcceptanceRateCollector; type Stats = DualAverageStats; type Options = DualAverageSettings; - fn new(options: Self::Options, _num_tune: u64, _dim: usize) -> Self { + fn new(_math: &mut M, options: Self::Options, _num_tune: u64) -> Self { Self { options, enabled: true, @@ -146,15 +147,17 @@ impl AdaptStrategy for DualAverageStrategy fn init( &mut self, + _math: &mut M, _options: &mut NutsOptions, potential: &mut Self::Potential, - _state: &::State, + _state: &State, ) { potential.step_size = self.options.initial_step; } fn adapt( &mut self, + _math: &mut M, _options: &mut NutsOptions, potential: &mut Self::Potential, _draw: u64, @@ -179,12 +182,13 @@ impl AdaptStrategy for DualAverageStrategy potential.step_size = self.step_size_adapt.current_step_size() } - fn new_collector(&self) -> Self::Collector { + fn new_collector(&self, _math: &mut M) -> Self::Collector { AcceptanceRateCollector::new() } fn current_stats( &self, 
+ _math: &mut M, _options: &NutsOptions, _potential: &Self::Potential, collector: &Self::Collector, @@ -204,39 +208,30 @@ pub struct DiagAdaptExpSettings { pub store_mass_matrix: bool, } -pub(crate) struct ExpWindowDiagAdapt { - dim: usize, - exp_variance_draw: RunningVariance, - exp_variance_grad: RunningVariance, - exp_variance_grad_bg: RunningVariance, - exp_variance_draw_bg: RunningVariance, +pub(crate) struct ExpWindowDiagAdapt { + exp_variance_draw: RunningVariance, + exp_variance_grad: RunningVariance, + exp_variance_grad_bg: RunningVariance, + exp_variance_draw_bg: RunningVariance, settings: DiagAdaptExpSettings, - _phantom: PhantomData, + _phantom: PhantomData, } -impl ExpWindowDiagAdapt { - fn update_estimators(&mut self, collector: &DrawGradCollector) { +impl ExpWindowDiagAdapt { + fn update_estimators(&mut self, math: &mut M, collector: &DrawGradCollector) { if collector.is_good { - self.exp_variance_draw - .add_sample(collector.draw.iter().copied()); - self.exp_variance_grad - .add_sample(collector.grad.iter().copied()); - self.exp_variance_draw_bg - .add_sample(collector.draw.iter().copied()); - self.exp_variance_grad_bg - .add_sample(collector.grad.iter().copied()); + self.exp_variance_draw.add_sample(math, &collector.draw); + self.exp_variance_grad.add_sample(math, &collector.grad); + self.exp_variance_draw_bg.add_sample(math, &collector.draw); + self.exp_variance_grad_bg.add_sample(math, &collector.grad); } } - fn switch(&mut self) { - self.exp_variance_draw = std::mem::replace( - &mut self.exp_variance_draw_bg, - RunningVariance::new(self.dim), - ); - self.exp_variance_grad = std::mem::replace( - &mut self.exp_variance_grad_bg, - RunningVariance::new(self.dim), - ); + fn switch(&mut self, math: &mut M) { + self.exp_variance_draw = + std::mem::replace(&mut self.exp_variance_draw_bg, RunningVariance::new(math)); + self.exp_variance_grad = + std::mem::replace(&mut self.exp_variance_grad_bg, RunningVariance::new(math)); } fn current_count(&self) -> u64 { @@ -249,23 +244,25 @@ impl ExpWindowDiagAdapt { self.exp_variance_draw_bg.count() } - fn update_potential(&self, potential: &mut EuclideanPotential) { + fn update_potential( + &self, + math: &mut M, + potential: &mut EuclideanPotential>, + ) { if self.current_count() < 3 { return; } - potential.mass_matrix.update_diag( - izip!( - self.exp_variance_draw.current(), - self.exp_variance_grad.current(), - ) - .map(|(draw, grad)| { - let val = (draw / grad).sqrt(); - if (!val.is_finite()) || (val == 0f64) { - None - } else { - Some(val.clamp(LOWER_LIMIT, UPPER_LIMIT)) - } - }), + + let (draw_var, draw_scale) = self.exp_variance_draw.current(); + let (grad_var, grad_scale) = self.exp_variance_grad.current(); + assert!(draw_scale == grad_scale); + + potential.mass_matrix.update_diag_draw_grad( + math, + draw_var, + grad_var, + None, + (LOWER_LIMIT, UPPER_LIMIT), ); } } @@ -335,19 +332,18 @@ impl ArrowRow for ExpWindowDiagAdaptStats { } } -impl AdaptStrategy for ExpWindowDiagAdapt { - type Potential = EuclideanPotential; - type Collector = DrawGradCollector; +impl AdaptStrategy for ExpWindowDiagAdapt { + type Potential = EuclideanPotential>; + type Collector = DrawGradCollector; type Stats = ExpWindowDiagAdaptStats; type Options = DiagAdaptExpSettings; - fn new(options: Self::Options, _num_tune: u64, dim: usize) -> Self { + fn new(math: &mut M, options: Self::Options, _num_tune: u64) -> Self { Self { - dim, - exp_variance_draw: RunningVariance::new(dim), - exp_variance_grad: RunningVariance::new(dim), - exp_variance_draw_bg: 
RunningVariance::new(dim), - exp_variance_grad_bg: RunningVariance::new(dim), + exp_variance_draw: RunningVariance::new(math), + exp_variance_grad: RunningVariance::new(math), + exp_variance_draw_bg: RunningVariance::new(math), + exp_variance_grad_bg: RunningVariance::new(math), settings: options, _phantom: PhantomData::default(), } @@ -355,29 +351,27 @@ impl AdaptStrategy for ExpWindowDiagAdapt { fn init( &mut self, + math: &mut M, _options: &mut NutsOptions, potential: &mut Self::Potential, - state: &::State, + state: &State, ) { - self.exp_variance_draw.add_sample(state.q.iter().copied()); - self.exp_variance_draw_bg - .add_sample(state.q.iter().copied()); - self.exp_variance_grad - .add_sample(state.grad.iter().copied()); - self.exp_variance_grad_bg - .add_sample(state.grad.iter().copied()); - - potential.mass_matrix.update_diag( - state - .grad - .iter() - .map(|&grad| grad.abs().clamp(INIT_LOWER_LIMIT, INIT_UPPER_LIMIT).recip()) - .map(|var| if var.is_finite() { Some(var) } else { Some(1.) }), + self.exp_variance_draw.add_sample(math, &state.q); + self.exp_variance_draw_bg.add_sample(math, &state.q); + self.exp_variance_grad.add_sample(math, &state.grad); + self.exp_variance_grad_bg.add_sample(math, &state.grad); + + potential.mass_matrix.update_diag_grad( + math, + &state.grad, + 1f64, + (INIT_LOWER_LIMIT, INIT_UPPER_LIMIT), ); } fn adapt( &mut self, + _math: &mut M, _options: &mut NutsOptions, _potential: &mut Self::Potential, _draw: u64, @@ -386,18 +380,19 @@ impl AdaptStrategy for ExpWindowDiagAdapt { // Must be controlled from a different meta strategy } - fn new_collector(&self) -> Self::Collector { - DrawGradCollector::new(self.dim) + fn new_collector(&self, math: &mut M) -> Self::Collector { + DrawGradCollector::new(math) } fn current_stats( &self, + math: &mut M, _options: &NutsOptions, potential: &Self::Potential, _collector: &Self::Collector, ) -> Self::Stats { let diag = if self.settings.store_mass_matrix { - Some(potential.mass_matrix.variance.clone()) + Some(math.box_array(&potential.mass_matrix.variance)) } else { None }; @@ -407,9 +402,9 @@ impl AdaptStrategy for ExpWindowDiagAdapt { } } -pub(crate) struct GradDiagStrategy { - step_size: DualAverageStrategy, - mass_matrix: ExpWindowDiagAdapt, +pub(crate) struct GradDiagStrategy { + step_size: DualAverageStrategy>, + mass_matrix: ExpWindowDiagAdapt, options: GradDiagOptions, num_tune: u64, // The number of draws in the the early window @@ -442,16 +437,13 @@ impl Default for GradDiagOptions { } } -impl AdaptStrategy for GradDiagStrategy { - type Potential = EuclideanPotential; - type Collector = CombinedCollector< - AcceptanceRateCollector< as Hamiltonian>::State>, - DrawGradCollector, - >; +impl AdaptStrategy for GradDiagStrategy { + type Potential = EuclideanPotential>; + type Collector = CombinedCollector, DrawGradCollector>; type Stats = CombinedStats; type Options = GradDiagOptions; - fn new(options: Self::Options, num_tune: u64, dim: usize) -> Self { + fn new(math: &mut M, options: Self::Options, num_tune: u64) -> Self { let num_tune_f = num_tune as f64; let step_size_window = (options.step_size_window * num_tune_f) as u64; let early_end = (options.early_window * num_tune_f) as u64; @@ -460,8 +452,8 @@ impl AdaptStrategy for GradDiagStrategy { assert!(early_end < num_tune); Self { - step_size: DualAverageStrategy::new(options.dual_average_options, num_tune, dim), - mass_matrix: ExpWindowDiagAdapt::new(options.mass_matrix_options, num_tune, dim), + step_size: DualAverageStrategy::new(math, 
options.dual_average_options, num_tune), + mass_matrix: ExpWindowDiagAdapt::new(math, options.mass_matrix_options, num_tune), options, num_tune, early_end, @@ -471,17 +463,19 @@ impl AdaptStrategy for GradDiagStrategy { fn init( &mut self, + math: &mut M, options: &mut NutsOptions, potential: &mut Self::Potential, - state: &::State, + state: &State, ) { - self.step_size.init(options, potential, state); - self.mass_matrix.init(options, potential, state); + self.step_size.init(math, options, potential, state); + self.mass_matrix.init(math, options, potential, state); self.step_size.enable(); } fn adapt( &mut self, + math: &mut M, options: &mut NutsOptions, potential: &mut Self::Potential, draw: u64, @@ -499,21 +493,22 @@ impl AdaptStrategy for GradDiagStrategy { self.options.mass_matrix_switch_freq }; - self.mass_matrix.update_estimators(&collector.collector2); + self.mass_matrix + .update_estimators(math, &collector.collector2); // We only switch if we have switch_freq draws in the background estimate, // and if the number of remaining mass matrix steps is larger than // the switch frequency. let could_switch = self.mass_matrix.background_count() >= switch_freq; let is_late = switch_freq + draw > self.final_step_size_window; if could_switch && (!is_late) { - self.mass_matrix.switch(); + self.mass_matrix.switch(math); } - self.mass_matrix.update_potential(potential); + self.mass_matrix.update_potential(math, potential); if is_late { self.step_size.use_mean_sym(); } self.step_size - .adapt(options, potential, draw, &collector.collector1); + .adapt(math, options, potential, draw, &collector.collector1); return; } @@ -521,18 +516,20 @@ impl AdaptStrategy for GradDiagStrategy { self.step_size.finalize(); } self.step_size - .adapt(options, potential, draw, &collector.collector1); + .adapt(math, options, potential, draw, &collector.collector1); } - fn new_collector(&self) -> Self::Collector { + fn new_collector(&self, math: &mut M) -> Self::Collector { CombinedCollector { - collector1: self.step_size.new_collector(), - collector2: self.mass_matrix.new_collector(), + collector1: self.step_size.new_collector(math), + collector2: self.mass_matrix.new_collector(math), + _phantom: PhantomData::default(), } } fn current_stats( &self, + math: &mut M, options: &NutsOptions, potential: &Self::Potential, collector: &Self::Collector, @@ -540,10 +537,10 @@ impl AdaptStrategy for GradDiagStrategy { CombinedStats { stats1: self .step_size - .current_stats(options, potential, &collector.collector1), + .current_stats(math, options, potential, &collector.collector1), stats2: self .mass_matrix - .current_stats(options, potential, &collector.collector2), + .current_stats(math, options, potential, &collector.collector2), } } } @@ -610,47 +607,52 @@ impl ArrowBuilder { +pub(crate) struct CombinedCollector, C2: Collector> { collector1: C1, collector2: C2, + _phantom: PhantomData, } -impl Collector for CombinedCollector +impl Collector for CombinedCollector where - C1: Collector, - C2: Collector, + C1: Collector, + C2: Collector, { - type State = C1::State; - fn register_leapfrog( &mut self, - start: &Self::State, - end: &Self::State, + math: &mut M, + start: &State, + end: &State, divergence_info: Option<&DivergenceInfo>, ) { self.collector1 - .register_leapfrog(start, end, divergence_info); + .register_leapfrog(math, start, end, divergence_info); self.collector2 - .register_leapfrog(start, end, divergence_info); + .register_leapfrog(math, start, end, divergence_info); } - fn register_draw(&mut self, state: 
&Self::State, info: &crate::nuts::SampleInfo) { - self.collector1.register_draw(state, info); - self.collector2.register_draw(state, info); + fn register_draw(&mut self, math: &mut M, state: &State, info: &crate::nuts::SampleInfo) { + self.collector1.register_draw(math, state, info); + self.collector2.register_draw(math, state, info); } - fn register_init(&mut self, state: &Self::State, options: &crate::nuts::NutsOptions) { - self.collector1.register_init(state, options); - self.collector2.register_init(state, options); + fn register_init( + &mut self, + math: &mut M, + state: &State, + options: &crate::nuts::NutsOptions, + ) { + self.collector1.register_init(math, state, options); + self.collector2.register_init(math, state, options); } } #[cfg(test)] pub mod test_logps { - use crate::{cpu_potential::CpuLogpFunc, nuts::LogpError}; + use crate::{cpu_math::CpuLogpFunc, nuts::LogpError}; use thiserror::Error; - #[derive(Clone)] + #[derive(Clone, Debug)] pub struct NormalLogp { dim: usize, mu: f64, @@ -671,7 +673,7 @@ pub mod test_logps { } impl CpuLogpFunc for NormalLogp { - type Err = NormalLogpError; + type LogpError = NormalLogpError; fn dim(&self) -> usize { self.dim @@ -695,21 +697,25 @@ pub mod test_logps { mod test { use super::test_logps::NormalLogp; use super::*; - use crate::nuts::{AdaptStrategy, Chain, NutsChain, NutsOptions}; + use crate::{ + cpu_math::CpuMath, + nuts::{AdaptStrategy, Chain, NutsChain, NutsOptions}, + }; #[test] fn instanciate_adaptive_sampler() { let ndim = 10; let func = NormalLogp::new(ndim, 3.); + let mut math = CpuMath::new(func); let num_tune = 100; let options = GradDiagOptions::default(); - let strategy = GradDiagStrategy::new(options, num_tune, ndim); + let strategy = GradDiagStrategy::new(&mut math, options, num_tune); - let mass_matrix = DiagMassMatrix::new(ndim); + let mass_matrix = DiagMassMatrix::new(&mut math); let max_energy_error = 1000f64; let step_size = 0.1f64; - let potential = EuclideanPotential::new(func, mass_matrix, max_energy_error, step_size); + let potential = EuclideanPotential::new(mass_matrix, max_energy_error, step_size); let options = NutsOptions { maxdepth: 10u64, store_gradient: true, @@ -722,7 +728,7 @@ mod test { }; let chain = 0u64; - let mut sampler = NutsChain::new(potential, strategy, options, rng, chain); + let mut sampler = NutsChain::new(math, potential, strategy, options, rng, chain); sampler.set_position(&vec![1.5f64; ndim]).unwrap(); for _ in 0..200 { sampler.draw().unwrap(); diff --git a/src/cpu_math.rs b/src/cpu_math.rs new file mode 100644 index 0000000..f335e5e --- /dev/null +++ b/src/cpu_math.rs @@ -0,0 +1,224 @@ +use std::{error::Error, fmt::Debug}; + +use itertools::izip; +use thiserror::Error; + +use crate::{ + math::{axpy, axpy_out, multiply, scalar_prods2, scalar_prods3, vector_dot}, + math_base::Math, + LogpError, +}; + +pub struct CpuMath { + logp_func: F, + arch: pulp::Arch, +} + +impl CpuMath { + pub fn new(logp_func: F) -> Self { + let arch = pulp::Arch::new(); + Self { logp_func, arch } + } +} + +#[non_exhaustive] +#[derive(Error, Debug)] +pub enum CpuMathError { + #[error("Error during array operation")] + ArrayError(), +} + +impl Math for CpuMath { + type Array = faer_core::Mat; + type LogpErr = F::LogpError; + type Err = CpuMathError; + + fn new_array(&self) -> Self::Array { + faer_core::Mat::zeros(self.dim(), 1) + } + + fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result { + self.logp_func.logp(position, gradient) + } + + fn logp_array( + &mut self, + position: &Self::Array, + 
gradient: &mut Self::Array, + ) -> Result { + self.logp_func + .logp(position.col_ref(0), gradient.col_mut(0)) + } + + fn dim(&self) -> usize { + self.logp_func.dim() + } + + fn scalar_prods3( + &mut self, + positive1: &Self::Array, + negative1: &Self::Array, + positive2: &Self::Array, + x: &Self::Array, + y: &Self::Array, + ) -> (f64, f64) { + scalar_prods3( + positive1.col_ref(0), + negative1.col_ref(0), + positive2.col_ref(0), + x.col_ref(0), + y.col_ref(0), + ) + } + + fn scalar_prods2( + &mut self, + positive1: &Self::Array, + positive2: &Self::Array, + x: &Self::Array, + y: &Self::Array, + ) -> (f64, f64) { + scalar_prods2( + positive1.col_ref(0), + positive2.col_ref(0), + x.col_ref(0), + y.col_ref(0), + ) + } + + fn read_from_slice(&mut self, dest: &mut Self::Array, source: &[f64]) { + dest.col_mut(0).copy_from_slice(source); + } + + fn write_to_slice(&mut self, source: &Self::Array, dest: &mut [f64]) { + dest.copy_from_slice(source.col_ref(0)) + } + + fn copy_into(&mut self, array: &Self::Array, dest: &mut Self::Array) { + dest.clone_from(array) + } + + fn axpy_out(&mut self, x: &Self::Array, y: &Self::Array, a: f64, out: &mut Self::Array) { + axpy_out(x.col_ref(0), y.col_ref(0), a, out.col_mut(0)); + } + + fn axpy(&mut self, x: &Self::Array, y: &mut Self::Array, a: f64) { + axpy(x.col_ref(0), y.col_mut(0), a); + } + + fn fill_array(&mut self, array: &mut Self::Array, val: f64) { + array.fill(val); + } + + fn array_all_finite(&mut self, array: &Self::Array) -> bool { + array.is_all_finite() + } + + fn array_all_finite_and_nonzero(&mut self, array: &Self::Array) -> bool { + array + .col_ref(0) + .iter() + .all(|&x| x.is_finite() & (x != 0f64)) + } + + fn array_mult(&mut self, array1: &Self::Array, array2: &Self::Array, dest: &mut Self::Array) { + multiply(array1.col_ref(0), array2.col_ref(0), dest.col_mut(0)) + } + + fn array_vector_dot(&mut self, array1: &Self::Array, array2: &Self::Array) -> f64 { + vector_dot(array1.col_ref(0), array2.col_ref(0)) + } + + fn array_gaussian( + &mut self, + rng: &mut R, + dest: &mut Self::Array, + stds: &Self::Array, + ) { + let dist = rand_distr::StandardNormal; + dest.col_mut(0) + .iter_mut() + .zip(stds.col_ref(0).iter()) + .for_each(|(p, &s)| { + let norm: f64 = rng.sample(dist); + *p = s * norm; + }); + } + + fn array_update_variance( + &mut self, + mean: &mut Self::Array, + variance: &mut Self::Array, + value: &Self::Array, + diff_scale: f64, // 1 / self.count + ) { + izip!( + mean.col_mut(0).iter_mut(), + variance.col_mut(0).iter_mut(), + value.col_ref(0) + ) + .for_each(|(mean, mut var, x)| { + let diff = x - *mean; + *mean += diff * diff_scale; + *var += diff * diff; + }); + } + + fn array_update_var_inv_std_draw_grad( + &mut self, + variance_out: &mut Self::Array, + inv_std: &mut Self::Array, + draw_var: &Self::Array, + grad_var: &Self::Array, + fill_invalid: Option, + clamp: (f64, f64), + ) { + izip!( + variance_out.col_mut(0).iter_mut(), + inv_std.col_mut(0).iter_mut(), + draw_var.col_ref(0).iter(), + grad_var.col_ref(0).iter(), + ) + .for_each(|(var_out, inv_std_out, &draw_var, &grad_var)| { + let val = (draw_var / grad_var).sqrt(); + if (!val.is_finite()) | (val == 0f64) { + if let Some(fill_val) = fill_invalid { + *var_out = fill_val; + *inv_std_out = fill_val.recip().sqrt(); + } + } else { + let val = val.clamp(clamp.0, clamp.1); + *var_out = val; + *inv_std_out = val.recip().sqrt(); + } + }); + } + + fn array_update_var_inv_std_grad( + &mut self, + variance_out: &mut Self::Array, + inv_std: &mut Self::Array, + gradient: 
&Self::Array, + fill_invalid: f64, + clamp: (f64, f64), + ) { + izip!( + variance_out.col_mut(0).iter_mut(), + inv_std.col_mut(0).iter_mut(), + gradient.col_ref(0).iter(), + ) + .for_each(|(var_out, inv_std_out, &grad_var)| { + let val = grad_var.abs().clamp(clamp.0, clamp.1).recip(); + let val = if val.is_finite() { val } else { fill_invalid }; + *var_out = val; + *inv_std_out = val.recip().sqrt(); + }); + } +} + +pub trait CpuLogpFunc { + type LogpError: Debug + Send + Sync + Error + LogpError + 'static; + + fn dim(&self) -> usize; + fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result; +} diff --git a/src/cpu_state.rs b/src/cpu_state.rs deleted file mode 100644 index 4055a30..0000000 --- a/src/cpu_state.rs +++ /dev/null @@ -1,342 +0,0 @@ -use std::{ - cell::RefCell, - fmt::Debug, - ops::{Deref, DerefMut}, - rc::{Rc, Weak}, -}; - -use crate::math::{axpy, axpy_out, scalar_prods2, scalar_prods3}; - -#[derive(Debug)] -struct StateStorage { - free_states: RefCell>>, -} - -impl StateStorage { - fn new() -> StateStorage { - StateStorage { - free_states: RefCell::new(Vec::with_capacity(20)), - } - } -} - -impl ReuseState for StateStorage { - fn reuse_state(&self, state: Rc) { - self.free_states.borrow_mut().push(state) - } -} - -pub(crate) struct StatePool { - storage: Rc, - dim: usize, -} - -impl StatePool { - pub(crate) fn new(dim: usize) -> StatePool { - StatePool { - storage: Rc::new(StateStorage::new()), - dim, - } - } - - pub(crate) fn new_state(&mut self) -> State { - let inner = match self.storage.free_states.borrow_mut().pop() { - Some(inner) => { - if self.dim != inner.inner.q.len() { - panic!("dim mismatch"); - } - inner - } - None => { - let owner: Rc = self.storage.clone(); - Rc::new(InnerStateReusable::new(self.dim, &owner)) - } - }; - State { - inner: std::mem::ManuallyDrop::new(inner), - } - } -} - -trait ReuseState: Debug { - fn reuse_state(&self, state: Rc); -} - -#[derive(Debug, Clone)] -pub(crate) struct InnerState { - pub(crate) p: Box<[f64]>, - pub(crate) q: Box<[f64]>, - pub(crate) v: Box<[f64]>, - pub(crate) p_sum: Box<[f64]>, - pub(crate) grad: Box<[f64]>, - pub(crate) idx_in_trajectory: i64, - pub(crate) kinetic_energy: f64, - pub(crate) potential_energy: f64, -} - -#[derive(Debug)] -pub(crate) struct InnerStateReusable { - inner: InnerState, - reuser: Weak, -} - -#[derive(Debug)] -pub(crate) struct AlignedArray { - size: usize, - data: *mut f64, -} - -impl AlignedArray { - pub(crate) fn new(size: usize) -> Self { - let layout = AlignedArray::make_layout(size); - // Alignment must match alignment of AlignedArrayInner - let ptr = unsafe { std::alloc::alloc_zeroed(layout) }; - if ptr.is_null() { - std::alloc::handle_alloc_error(layout); - } - Self { - data: ptr as *mut f64, - size, - } - } - - fn make_layout(size: usize) -> std::alloc::Layout { - std::alloc::Layout::from_size_align( - std::mem::size_of::().checked_mul(size).unwrap(), - 64, - ) - .unwrap() - } -} - -impl Drop for AlignedArray { - fn drop(&mut self) { - let layout = AlignedArray::make_layout(self.size); - unsafe { std::alloc::dealloc(self.data as *mut u8, layout) }; - } -} - -impl Clone for AlignedArray { - fn clone(&self) -> Self { - let mut new = AlignedArray::new(self.size); - new.copy_from_slice(self); - new - } -} - -impl Deref for AlignedArray { - type Target = [f64]; - - fn deref(&self) -> &Self::Target { - unsafe { std::slice::from_raw_parts(self.data, self.size) } - } -} - -impl DerefMut for AlignedArray { - fn deref_mut(&mut self) -> &mut Self::Target { - unsafe { 
std::slice::from_raw_parts_mut(self.data, self.size) } - } -} - -unsafe impl Send for AlignedArray {} - -impl InnerStateReusable { - fn new(size: usize, owner: &Rc) -> InnerStateReusable { - InnerStateReusable { - inner: InnerState { - p: vec![0.; size].into(), - //p: AlignedArray::new(size), - q: vec![0.; size].into(), - //q: AlignedArray::new(size), - v: vec![0.; size].into(), - //v: AlignedArray::new(size), - p_sum: vec![0.; size].into(), - //p_sum: AlignedArray::new(size), - grad: vec![0.; size].into(), - //grad: AlignedArray::new(size), - idx_in_trajectory: 0, - kinetic_energy: 0., - potential_energy: 0., - }, - reuser: Rc::downgrade(owner), - } - } -} - -#[derive(Debug)] -pub(crate) struct State { - inner: std::mem::ManuallyDrop>, -} - -impl Deref for State { - type Target = InnerState; - - fn deref(&self) -> &Self::Target { - &self.inner.inner - } -} - -#[derive(Debug)] -pub(crate) struct StateInUse {} - -type Result = std::result::Result; - -impl State { - pub(crate) fn try_mut_inner(&mut self) -> Result<&mut InnerState> { - match Rc::get_mut(&mut self.inner) { - Some(val) => Ok(&mut val.inner), - None => Err(StateInUse {}), - } - } -} - -impl Drop for State { - fn drop(&mut self) { - let mut rc = unsafe { std::mem::ManuallyDrop::take(&mut self.inner) }; - if let Some(state_ref) = Rc::get_mut(&mut rc) { - if let Some(reuser) = &mut state_ref.reuser.upgrade() { - reuser.reuse_state(rc); - } - } - } -} - -impl Clone for State { - fn clone(&self) -> Self { - State { - inner: self.inner.clone(), - } - } -} - -impl crate::nuts::State for State { - type Pool = StatePool; - - fn is_turning(&self, other: &Self) -> bool { - let (start, end) = if self.idx_in_trajectory < other.idx_in_trajectory { - (self, other) - } else { - (other, self) - }; - - let a = start.idx_in_trajectory; - let b = end.idx_in_trajectory; - - assert!(a < b); - let (turn1, turn2) = if (a >= 0) & (b >= 0) { - scalar_prods3(&end.p_sum, &start.p_sum, &start.p, &end.v, &start.v) - } else if (b >= 0) & (a < 0) { - scalar_prods2(&end.p_sum, &start.p_sum, &end.v, &start.v) - } else { - assert!((a < 0) & (b < 0)); - scalar_prods3(&start.p_sum, &end.p_sum, &end.p, &end.v, &start.v) - }; - - (turn1 < 0.) | (turn2 < 0.) 
- } - - fn write_position(&self, out: &mut [f64]) { - out.copy_from_slice(&self.q) - } - - fn write_gradient(&self, out: &mut [f64]) { - out.copy_from_slice(&self.grad); - } - - fn write_momentum(&self, out: &mut [f64]) { - out.copy_from_slice(&self.p); - } - - fn energy(&self) -> f64 { - self.kinetic_energy + self.potential_energy - } - - fn index_in_trajectory(&self) -> i64 { - self.idx_in_trajectory - } - - fn make_init_point(&mut self) { - let inner = self.try_mut_inner().unwrap(); - inner.idx_in_trajectory = 0; - inner.p_sum.copy_from_slice(&inner.p); - } - - fn potential_energy(&self) -> f64 { - self.potential_energy - } -} - -impl State { - pub(crate) fn first_momentum_halfstep(&self, out: &mut Self, epsilon: f64) { - axpy_out( - &self.grad, - &self.p, - epsilon / 2., - &mut out.try_mut_inner().expect("State already in use").p, - ); - } - - pub(crate) fn position_step(&self, out: &mut Self, epsilon: f64) { - let out = out.try_mut_inner().expect("State already in use"); - axpy_out(&out.v, &self.q, epsilon, &mut out.q); - } - - pub(crate) fn second_momentum_halfstep(&mut self, epsilon: f64) { - let inner = self.try_mut_inner().expect("State already in use"); - axpy(&inner.grad, &mut inner.p, epsilon / 2.); - } - - pub(crate) fn set_psum(&self, target: &mut Self, _dir: crate::nuts::Direction) { - let out = target.try_mut_inner().expect("State already in use"); - - assert!(out.idx_in_trajectory != 0); - - if out.idx_in_trajectory == -1 { - out.p_sum.copy_from_slice(&out.p); - } else { - axpy_out(&out.p, &self.p_sum, 1., &mut out.p_sum); - } - } - - pub(crate) fn index_in_trajectory(&self) -> i64 { - self.idx_in_trajectory - } - - pub(crate) fn index_in_trajectory_mut(&mut self) -> &mut i64 { - &mut self - .try_mut_inner() - .expect("State already in use") - .idx_in_trajectory - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn crate_pool() { - let mut pool = StatePool::new(10); - let mut state = pool.new_state(); - assert!(state.p.len() == 10); - state.try_mut_inner().unwrap(); - - let mut state2 = state.clone(); - assert!(state.try_mut_inner().is_err()); - assert!(state2.try_mut_inner().is_err()); - } - - #[test] - fn make_state() { - let dim = 10; - let mut pool = StatePool::new(dim); - let a = pool.new_state(); - - assert_eq!(a.idx_in_trajectory, 0); - assert!(a.p_sum.iter().all(|&x| x == 0f64)); - assert_eq!(a.p_sum.len(), dim); - assert_eq!(a.grad.len(), dim); - assert_eq!(a.q.len(), dim); - assert_eq!(a.p.len(), dim); - } -} diff --git a/src/lib.rs b/src/lib.rs index 4c1a814..1f21b3f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,12 +12,13 @@ //! ## Usage //! //! ``` -//! use nuts_rs::{CpuLogpFunc, LogpError, new_sampler, SamplerArgs, Chain, SampleStats}; +//! use nuts_rs::{CpuLogpFunc, CpuMath, LogpError, new_sampler, SamplerArgs, Chain, SampleStats}; //! use thiserror::Error; //! use rand::thread_rng; //! //! // Define a function that computes the unnormalized posterior density //! // and its gradient. +//! #[derive(Debug)] //! struct PosteriorDensity {} //! //! // The density might fail in a recoverable or non-recoverable manner... @@ -28,13 +29,13 @@ //! } //! //! impl CpuLogpFunc for PosteriorDensity { -//! type Err = PosteriorLogpError; +//! type LogpError = PosteriorLogpError; //! //! // We define a 10 dimensional normal distribution //! fn dim(&self) -> usize { 10 } //! //! // The normal likelihood with mean 3 and its gradient. -//! fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result { +//! 
fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result { //! let mu = 3f64; //! let logp = position //! .iter() @@ -59,10 +60,11 @@ //! //! // We instanciate our posterior density function //! let logp_func = PosteriorDensity {}; +//! let math = CpuMath::new(logp_func); //! //! let chain = 0; //! let mut rng = thread_rng(); -//! let mut sampler = new_sampler(logp_func, sampler_args, chain, &mut rng); +//! let mut sampler = new_sampler(math, sampler_args, chain, &mut rng); //! //! // Set to some initial position and start drawing samples. //! sampler.set_position(&vec![0f64; 10]).expect("Unrecoverable error during init"); @@ -78,9 +80,6 @@ //! } //! ``` //! -//! Sampling several chains in parallel so that samples are accessable as they are generated -//! is implemented in [`sample_parallel`]. -//! //! ## Implementation details //! //! This crate mostly follows the implementation of NUTS in [Stan](https://mc-stan.org) and @@ -88,21 +87,22 @@ //! somewhat. pub(crate) mod adapt_strategy; -pub(crate) mod cpu_potential; -pub(crate) mod cpu_sampler; -pub(crate) mod cpu_state; +pub(crate) mod cpu_math; pub(crate) mod mass_matrix; -pub mod math; +pub(crate) mod math; +pub(crate) mod math_base; pub(crate) mod nuts; +pub(crate) mod potential; +pub(crate) mod sampler; +pub(crate) mod state; pub(crate) mod stepsize; pub use adapt_strategy::DualAverageSettings; -pub use cpu_potential::CpuLogpFunc; -pub use cpu_sampler::test_logps; -pub use cpu_sampler::{ - new_sampler, sample_parallel, sample_sequentially, CpuLogpFuncMaker, InitPointFunc, - JitterInitFunc, ParallelChainResult, ParallelSamplingError, SamplerArgs, -}; +pub use cpu_math::{CpuLogpFunc, CpuMath}; #[cfg(feature = "arrow")] pub use nuts::{ArrowBuilder, ArrowRow}; pub use nuts::{Chain, DivergenceInfo, LogpError, NutsError, SampleStats}; +pub use sampler::test_logps; +pub use sampler::{ + new_sampler, sample_sequentially, InitPointFunc, JitterInitFunc, MathMaker, SamplerArgs, +}; diff --git a/src/mass_matrix.rs b/src/mass_matrix.rs index d4c0af6..c6b89bf 100644 --- a/src/mass_matrix.rs +++ b/src/mass_matrix.rs @@ -2,39 +2,72 @@ use itertools::izip; use multiversion::multiversion; use crate::{ - cpu_state::{InnerState, State}, - math::{multiply, vector_dot}, + math_base::Math, nuts::Collector, + state::{InnerState, State}, }; -pub(crate) trait MassMatrix { - fn update_velocity(&self, state: &mut InnerState); - fn update_kinetic_energy(&self, state: &mut InnerState); - fn randomize_momentum(&self, state: &mut InnerState, rng: &mut R); +pub(crate) trait MassMatrix { + fn update_velocity(&self, math: &mut M, state: &mut InnerState); + fn update_kinetic_energy(&self, math: &mut M, state: &mut InnerState); + fn randomize_momentum( + &self, + math: &mut M, + state: &mut InnerState, + rng: &mut R, + ); } pub(crate) struct NullCollector {} -impl Collector for NullCollector { - type State = State; -} +impl Collector for NullCollector {} #[derive(Debug)] -pub(crate) struct DiagMassMatrix { - inv_stds: Box<[f64]>, - pub(crate) variance: Box<[f64]>, +pub(crate) struct DiagMassMatrix { + inv_stds: M::Array, + pub(crate) variance: M::Array, } -impl DiagMassMatrix { - pub(crate) fn new(ndim: usize) -> Self { +impl DiagMassMatrix { + pub(crate) fn new(math: &mut M) -> Self { Self { - inv_stds: vec![0f64; ndim].into(), - variance: vec![0f64; ndim].into(), + inv_stds: math.new_array(), + variance: math.new_array(), } } - pub(crate) fn update_diag(&mut self, new_variance: impl Iterator>) { - update_diag(&mut self.variance, &mut self.inv_stds, 
new_variance); + pub(crate) fn update_diag_draw_grad( + &mut self, + math: &mut M, + draw_var: &M::Array, + grad_var: &M::Array, + fill_invalid: Option, + clamp: (f64, f64), + ) { + math.array_update_var_inv_std_draw_grad( + &mut self.variance, + &mut self.inv_stds, + draw_var, + grad_var, + fill_invalid, + clamp, + ); + } + + pub(crate) fn update_diag_grad( + &mut self, + math: &mut M, + gradient: &M::Array, + fill_invalid: f64, + clamp: (f64, f64), + ) { + math.array_update_var_inv_std_grad( + &mut self.variance, + &mut self.inv_stds, + gradient, + fill_invalid, + clamp, + ); } } @@ -55,64 +88,55 @@ fn update_diag( }); } -impl MassMatrix for DiagMassMatrix { - fn update_velocity(&self, state: &mut InnerState) { - multiply(&self.variance, &state.p, &mut state.v); +impl MassMatrix for DiagMassMatrix { + fn update_velocity(&self, math: &mut M, state: &mut InnerState) { + math.array_mult(&self.variance, &state.p, &mut state.v); } - fn update_kinetic_energy(&self, state: &mut InnerState) { - state.kinetic_energy = 0.5 * vector_dot(&state.p, &state.v); + fn update_kinetic_energy(&self, math: &mut M, state: &mut InnerState) { + state.kinetic_energy = 0.5 * math.array_vector_dot(&state.p, &state.v); } - fn randomize_momentum(&self, state: &mut InnerState, rng: &mut R) { - let dist = rand_distr::StandardNormal; - state - .p - .iter_mut() - .zip(self.inv_stds.iter()) - .for_each(|(p, &s)| { - let norm: f64 = rng.sample(dist); - *p = s * norm; - }); + fn randomize_momentum( + &self, + math: &mut M, + state: &mut InnerState, + rng: &mut R, + ) { + math.array_gaussian(rng, &mut state.p, &self.inv_stds); } } #[derive(Debug)] -pub(crate) struct RunningVariance { - mean: Box<[f64]>, - variance: Box<[f64]>, +pub(crate) struct RunningVariance { + mean: M::Array, + variance: M::Array, count: u64, } -impl RunningVariance { - pub(crate) fn new(dim: usize) -> Self { +impl RunningVariance { + pub(crate) fn new(math: &mut M) -> Self { Self { - mean: vec![0f64; dim].into(), - variance: vec![0f64; dim].into(), + mean: math.new_array(), + variance: math.new_array(), count: 0, } } - pub(crate) fn add_sample(&mut self, value: impl Iterator) { + //pub(crate) fn add_sample(&mut self, value: impl Iterator) { + pub(crate) fn add_sample(&mut self, math: &mut M, value: &M::Array) { self.count += 1; if self.count == 1 { - izip!(self.mean.iter_mut(), value).for_each(|(mean, val)| { - *mean = val; - }); + math.copy_into(value, &mut self.mean); } else { - izip!(self.mean.iter_mut(), self.variance.iter_mut(), value).for_each( - |(mean, var, x)| { - let diff = x - *mean; - *mean += diff / (self.count as f64); - *var += diff * diff; - }, - ); + math.array_update_variance(&mut self.mean, &mut self.variance, value, (self.count as f64).recip()); } } - pub(crate) fn current(&self) -> impl Iterator + '_ { + /// Return current variance and scaling factor + pub(crate) fn current(&self) -> (&M::Array, f64) { assert!(self.count > 1); - self.variance.iter().map(|&x| x / ((self.count - 1) as f64)) + (&self.variance, ((self.count - 1) as f64).recip()) } pub(crate) fn count(&self) -> u64 { @@ -120,28 +144,26 @@ impl RunningVariance { } } -pub(crate) struct DrawGradCollector { - pub(crate) draw: Box<[f64]>, - pub(crate) grad: Box<[f64]>, +pub(crate) struct DrawGradCollector { + pub(crate) draw: M::Array, + pub(crate) grad: M::Array, pub(crate) is_good: bool, } -impl DrawGradCollector { - pub(crate) fn new(dim: usize) -> Self { +impl DrawGradCollector { + pub(crate) fn new(math: &mut M) -> Self { DrawGradCollector { - draw: vec![0f64; 
dim].into(), - grad: vec![0f64; dim].into(), + draw: math.new_array(), + grad: math.new_array(), is_good: true, } } } -impl Collector for DrawGradCollector { - type State = State; - - fn register_draw(&mut self, state: &Self::State, info: &crate::nuts::SampleInfo) { - self.draw.copy_from_slice(&state.q); - self.grad.copy_from_slice(&state.grad); +impl Collector for DrawGradCollector { + fn register_draw(&mut self, math: &mut M, state: &State, info: &crate::nuts::SampleInfo) { + math.copy_into(&state.q, &mut self.draw); + math.copy_into(&state.grad, &mut self.grad); let idx = state.index_in_trajectory(); if info.divergence_info.is_some() { self.is_good = idx.abs() > 4; diff --git a/src/math_base.rs b/src/math_base.rs new file mode 100644 index 0000000..019a1da --- /dev/null +++ b/src/math_base.rs @@ -0,0 +1,102 @@ +use std::{error::Error, fmt::Debug}; + +use crate::LogpError; + +pub trait Math { + type Array: Debug; + type LogpErr: Debug + Send + Sync + LogpError + 'static; + type Err: Debug + Send + Sync + Error + 'static; + + fn new_array(&self) -> Self::Array; + + /// Compute the unnormalized log probability density of the posterior + /// + /// This needs to be implemnted by users of the library to define + /// what distribution the users wants to sample from. + /// + /// Errors during that computation can be recoverable or non-recoverable. + /// If a non-recoverable error occurs during sampling, the sampler will + /// stop and return an error. + fn logp_array( + &mut self, + position: &Self::Array, + gradient: &mut Self::Array, + ) -> Result; + + fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result; + + fn dim(&self) -> usize; + + fn scalar_prods3( + &mut self, + positive1: &Self::Array, + negative1: &Self::Array, + positive2: &Self::Array, + x: &Self::Array, + y: &Self::Array, + ) -> (f64, f64); + + fn scalar_prods2( + &mut self, + positive1: &Self::Array, + positive2: &Self::Array, + x: &Self::Array, + y: &Self::Array, + ) -> (f64, f64); + + fn read_from_slice(&mut self, dest: &mut Self::Array, source: &[f64]); + fn write_to_slice(&mut self, source: &Self::Array, dest: &mut [f64]); + fn copy_into(&mut self, array: &Self::Array, dest: &mut Self::Array); + fn axpy_out(&mut self, x: &Self::Array, y: &Self::Array, a: f64, out: &mut Self::Array); + fn axpy(&mut self, x: &Self::Array, y: &mut Self::Array, a: f64); + + fn box_array(&mut self, array: &Self::Array) -> Box<[f64]> { + let mut data = vec![0f64; self.dim()]; + self.write_to_slice(array, &mut data); + data.into() + } + + fn fill_array(&mut self, array: &mut Self::Array, val: f64); + + fn array_all_finite(&mut self, array: &Self::Array) -> bool; + fn array_all_finite_and_nonzero(&mut self, array: &Self::Array) -> bool; + fn array_mult(&mut self, array1: &Self::Array, array2: &Self::Array, dest: &mut Self::Array); + fn array_vector_dot(&mut self, array1: &Self::Array, array2: &Self::Array) -> f64; + fn array_gaussian( + &mut self, + rng: &mut R, + dest: &mut Self::Array, + stds: &Self::Array, + ); + fn array_update_variance( + &mut self, + mean: &mut Self::Array, + variance: &mut Self::Array, + value: &Self::Array, + diff_scale: f64, + ); + fn array_update_var_inv_std_draw_grad( + &mut self, + variance_out: &mut Self::Array, + inv_std: &mut Self::Array, + draw_var: &Self::Array, + grad_var: &Self::Array, + fill_invalid: Option, + clamp: (f64, f64), + ); + + fn array_update_var_inv_std_grad( + &mut self, + variance_out: &mut Self::Array, + inv_std: &mut Self::Array, + gradient: &Self::Array, + fill_invalid: f64, + 
clamp: (f64, f64), + ); +} + +trait Array { + fn write(&self, out: &mut [f64]); + fn elemwise_mult(&self, other: &Self, out: &mut Self); + fn len(&self) -> usize; +} diff --git a/src/nuts.rs b/src/nuts.rs index 4d9f6bc..4787068 100644 --- a/src/nuts.rs +++ b/src/nuts.rs @@ -10,7 +10,9 @@ use thiserror::Error; use std::{fmt::Debug, marker::PhantomData}; use crate::math::logaddexp; +use crate::state::{State, StatePool}; +use crate::math_base::Math; #[cfg(feature = "arrow")] use crate::SamplerArgs; @@ -68,18 +70,17 @@ impl rand::distributions::Distribution for rand::distributions::Stand /// /// Collectors can compute statistics like the mean acceptance rate /// or collect data for mass matrix adaptation. -pub trait Collector { - type State: State; - +pub(crate) trait Collector { fn register_leapfrog( &mut self, - _start: &Self::State, - _end: &Self::State, + _math: &mut M, + _start: &State, + _end: &State, _divergence_info: Option<&DivergenceInfo>, ) { } - fn register_draw(&mut self, _state: &Self::State, _info: &SampleInfo) {} - fn register_init(&mut self, _state: &Self::State, _options: &NutsOptions) {} + fn register_draw(&mut self, _math: &mut M, _state: &State, _info: &SampleInfo) {} + fn register_init(&mut self, _math: &mut M, _state: &State, _options: &NutsOptions) {} } /// Errors that happen when we evaluate the logp and gradient function @@ -90,9 +91,12 @@ pub trait LogpError: std::error::Error { } /// The hamiltonian defined by the potential energy and the kinetic energy -pub trait Hamiltonian { +pub(crate) trait Hamiltonian +where + M: Math, +{ /// The type that stores a point in phase space - type State: State; + //type State: State; /// Errors that happen during logp evaluation type LogpError: LogpError + Send; /// Statistics that should be exported to the trace as part of the sampler stats @@ -105,14 +109,15 @@ pub trait Hamiltonian { /// Perform one leapfrog step. /// /// Return either an unrecoverable error, a new state or a divergence. - fn leapfrog>( + fn leapfrog>( &mut self, - pool: &mut ::Pool, - start: &Self::State, + math: &mut M, + pool: &mut StatePool, + start: &State, dir: Direction, initial_energy: f64, collector: &mut C, - ) -> Result>; + ) -> Result, DivergenceInfo>>; /// Initialize a state at a new location. /// @@ -120,25 +125,29 @@ pub trait Hamiltonian { /// it will later be set using Self::randomize_momentum. fn init_state( &mut self, - pool: &mut ::Pool, + math: &mut M, + pool: &mut StatePool, init: &[f64], - ) -> Result; + ) -> Result>; /// Randomize the momentum part of a state - fn randomize_momentum(&self, state: &mut Self::State, rng: &mut R); + fn randomize_momentum( + &self, + math: &mut M, + state: &mut State, + rng: &mut R, + ); /// Return sampler statistics defined in Self::Stats fn current_stats(&self) -> Self::Stats; - fn new_empty_state(&mut self, pool: &mut ::Pool) -> Self::State; + fn new_empty_state(&mut self, math: &mut M, pool: &mut StatePool) -> State; /// Crate a new state pool that can be used to crate new states. - fn new_pool(&mut self, capacity: usize) -> ::Pool; - - /// The dimension of the hamiltonian (position only). - fn dim(&self) -> usize; + fn new_pool(&mut self, math: &mut M, capacity: usize) -> StatePool; } +/* /// A point in phase space /// /// This also needs to store the sum of momentum terms @@ -176,6 +185,7 @@ pub trait State: Clone + Debug { (initial_energy - self.energy()).min(0.) 
} } +*/ /// Information about a draw, exported as part of the sampler stats #[derive(Debug)] @@ -196,17 +206,17 @@ pub struct SampleInfo { } /// A part of the trajectory tree during NUTS sampling. -struct NutsTree> { +struct NutsTree, C: Collector> { /// The left position of the tree. /// /// The left side always has the smaller index_in_trajectory. /// Leapfrogs in backward direction will replace the left. - left: P::State, - right: P::State, + left: State, + right: State, /// A draw from the trajectory between left and right using /// multinomial sampling. - draw: P::State, + draw: State, log_size: f64, depth: u64, initial_energy: f64, @@ -214,24 +224,25 @@ struct NutsTree> { /// A tree is the main tree if it contains the initial point /// of the trajectory. is_main: bool, - collector: PhantomData, + _phantom: PhantomData, + _phantom2: PhantomData, } -enum ExtendResult> { +enum ExtendResult, C: Collector> { /// The tree extension succeeded properly, and the termination /// criterion was not reached. - Ok(NutsTree), + Ok(NutsTree), /// An unrecoverable error happend during a leapfrog step Err(NutsError), /// Tree extension succeeded and the termination criterion /// was reached. - Turning(NutsTree), + Turning(NutsTree), /// A divergence happend during tree extension. - Diverging(NutsTree, DivergenceInfo), + Diverging(NutsTree, DivergenceInfo), } -impl> NutsTree { - fn new(state: P::State) -> NutsTree { +impl, C: Collector> NutsTree { + fn new(state: State) -> NutsTree { let initial_energy = state.energy(); NutsTree { right: state.clone(), @@ -241,24 +252,26 @@ impl> NutsTree { log_size: 0., initial_energy, is_main: true, - collector: PhantomData, + _phantom: PhantomData::default(), + _phantom2: PhantomData::default(), } } #[inline] fn extend( mut self, - pool: &mut ::Pool, + math: &mut M, + pool: &mut StatePool, rng: &mut R, - potential: &mut P, + potential: &mut H, direction: Direction, collector: &mut C, - ) -> ExtendResult + ) -> ExtendResult where - P: Hamiltonian, + H: Hamiltonian, R: rand::Rng + ?Sized, { - let mut other = match self.single_step(pool, potential, direction, collector) { + let mut other = match self.single_step(math, pool, potential, direction, collector) { Ok(Ok(tree)) => tree, Ok(Err(info)) => return ExtendResult::Diverging(self, info), Err(err) => return ExtendResult::Err(err), @@ -266,7 +279,7 @@ impl> NutsTree { while other.depth < self.depth { use ExtendResult::*; - other = match other.extend(pool, rng, potential, direction, collector) { + other = match other.extend(math, pool, rng, potential, direction, collector) { Ok(tree) => tree, Turning(_) => { return Turning(self); @@ -285,17 +298,17 @@ impl> NutsTree { Direction::Backward => (&other.left, &self.right), }; - let mut turning = first.is_turning(last); + let mut turning = first.is_turning(math, last); if self.depth > 0 { if !turning { - turning = self.right.is_turning(&other.right); + turning = self.right.is_turning(math, &other.right); } if !turning { - turning = self.left.is_turning(&other.left); + turning = self.left.is_turning(math, &other.left); } } - self.merge_into(other, rng, direction); + self.merge_into(math, other, rng, direction); if turning { ExtendResult::Turning(self) @@ -306,7 +319,8 @@ impl> NutsTree { fn merge_into( &mut self, - other: NutsTree, + _math: &mut M, + other: NutsTree, rng: &mut R, direction: Direction, ) { @@ -342,16 +356,24 @@ impl> NutsTree { fn single_step( &self, - pool: &mut ::Pool, - potential: &mut P, + math: &mut M, + pool: &mut StatePool, + potential: &mut H, 
direction: Direction, collector: &mut C, - ) -> Result, DivergenceInfo>> { + ) -> Result, DivergenceInfo>> { let start = match direction { Direction::Forward => &self.right, Direction::Backward => &self.left, }; - let end = match potential.leapfrog(pool, start, direction, self.initial_energy, collector) { + let end = match potential.leapfrog( + math, + pool, + start, + direction, + self.initial_energy, + collector, + ) { Ok(Ok(end)) => end, Ok(Err(info)) => return Ok(Err(info)), Err(error) => return Err(error), @@ -366,7 +388,8 @@ impl> NutsTree { log_size, initial_energy: self.initial_energy, is_main: false, - collector: PhantomData, + _phantom: PhantomData::default(), + _phantom2: PhantomData::default(), })) } @@ -387,36 +410,38 @@ pub struct NutsOptions { pub store_unconstrained: bool, } -pub(crate) fn draw( - pool: &mut ::Pool, - init: &mut P::State, +pub(crate) fn draw( + math: &mut M, + pool: &mut StatePool, + init: &mut State, rng: &mut R, potential: &mut P, options: &NutsOptions, collector: &mut C, -) -> Result<(P::State, SampleInfo)> +) -> Result<(State, SampleInfo)> where - P: Hamiltonian, + M: Math, + P: Hamiltonian, R: rand::Rng + ?Sized, - C: Collector, + C: Collector, { - potential.randomize_momentum(init, rng); - init.make_init_point(); - collector.register_init(init, options); + potential.randomize_momentum(math, init, rng); + init.make_init_point(math); + collector.register_init(math, init, options); let mut tree = NutsTree::new(init.clone()); while tree.depth < options.maxdepth { let direction: Direction = rng.gen(); - tree = match tree.extend(pool, rng, potential, direction, collector) { + tree = match tree.extend(math, pool, rng, potential, direction, collector) { ExtendResult::Ok(tree) => tree, ExtendResult::Turning(tree) => { let info = tree.info(false, None); - collector.register_draw(&tree.draw, &info); + collector.register_draw(math, &tree.draw, &info); return Ok((tree.draw, info)); } ExtendResult::Diverging(tree, info) => { let info = tree.info(false, Some(info)); - collector.register_draw(&tree.draw, &info); + collector.register_draw(math, &tree.draw, &info); return Ok((tree.draw, info)); } ExtendResult::Err(error) => { @@ -425,7 +450,7 @@ where }; } let info = tree.info(true, None); - collector.register_draw(&tree.draw, &info); + collector.register_draw(math, &tree.draw, &info); Ok((tree.draw, info)) } @@ -523,7 +548,7 @@ where } #[cfg(feature = "arrow")] -pub struct StatsBuilder { +pub(crate) struct StatsBuilder, A: AdaptStrategy> { depth: MutablePrimitiveArray, maxdepth_reached: MutableBooleanArray, index_in_trajectory: MutablePrimitiveArray, @@ -545,7 +570,7 @@ pub struct StatsBuilder { } #[cfg(feature = "arrow")] -impl StatsBuilder { +impl, A: AdaptStrategy> StatsBuilder { fn new_with_capacity(dim: usize, settings: &SamplerArgs) -> Self { let capacity = (settings.num_tune + settings.num_draws) as usize; @@ -625,8 +650,8 @@ impl StatsBuilder { } #[cfg(feature = "arrow")] -impl ArrowBuilder> - for StatsBuilder +impl, A: AdaptStrategy> + ArrowBuilder> for StatsBuilder { fn append_value(&mut self, value: &NutsSampleStats) { self.depth.push(Some(value.depth)); @@ -819,9 +844,9 @@ impl ArrowBuilder { + type Hamiltonian; //: Hamiltonian; + type AdaptStrategy; //: AdaptStrategy; type Stats: SampleStats + 'static; #[cfg(feature = "arrow")] @@ -843,34 +868,44 @@ pub trait Chain { fn stats_builder(&self, dim: usize, settings: &SamplerArgs) -> Self::Builder; } -pub(crate) struct NutsChain +pub(crate) struct NutsChain where - P: Hamiltonian, + M: Math, + P: Hamiltonian, 
R: rand::Rng, - S: AdaptStrategy, + A: AdaptStrategy, { - pool: ::Pool, + pool: StatePool, potential: P, - collector: S::Collector, + collector: A::Collector, options: NutsOptions, rng: R, - init: P::State, + init: State, chain: u64, draw_count: u64, - strategy: S, + strategy: A, + math: M, } -impl NutsChain +impl NutsChain where - P: Hamiltonian, + M: Math, + P: Hamiltonian, R: rand::Rng, - S: AdaptStrategy, + A: AdaptStrategy, { - pub fn new(mut potential: P, strategy: S, options: NutsOptions, rng: R, chain: u64) -> Self { + pub fn new( + mut math: M, + mut potential: P, + strategy: A, + options: NutsOptions, + rng: R, + chain: u64, + ) -> Self { let pool_size: usize = options.maxdepth.checked_mul(2).unwrap().try_into().unwrap(); - let mut pool = potential.new_pool(pool_size); - let init = potential.new_empty_state(&mut pool); - let collector = strategy.new_collector(); + let mut pool = potential.new_pool(&mut math, pool_size); + let init = potential.new_empty_state(&mut math, &mut pool); + let collector = strategy.new_collector(&mut math); NutsChain { pool, potential, @@ -881,66 +916,78 @@ where chain, draw_count: 0, strategy, + math, } } } -pub trait AdaptStrategy { - type Potential: Hamiltonian; - type Collector: Collector::State>; +pub(crate) trait AdaptStrategy { + type Potential: Hamiltonian; + type Collector: Collector; #[cfg(feature = "arrow")] type Stats: Send + Debug + ArrowRow + 'static; #[cfg(not(feature = "arrow"))] type Stats: Send + Debug + 'static; type Options: Copy + Send + Default; - fn new(options: Self::Options, num_tune: u64, dim: usize) -> Self; + fn new(math: &mut M, options: Self::Options, num_tune: u64) -> Self; fn init( &mut self, + math: &mut M, options: &mut NutsOptions, potential: &mut Self::Potential, - state: &::State, + state: &State, ); fn adapt( &mut self, + math: &mut M, options: &mut NutsOptions, potential: &mut Self::Potential, draw: u64, collector: &Self::Collector, ); - fn new_collector(&self) -> Self::Collector; + fn new_collector(&self, math: &mut M) -> Self::Collector; fn current_stats( &self, + math: &mut M, options: &NutsOptions, potential: &Self::Potential, collector: &Self::Collector, ) -> Self::Stats; } -impl Chain for NutsChain +impl Chain for NutsChain where - H: Hamiltonian, + M: Math, + H: Hamiltonian, R: rand::Rng, - S: AdaptStrategy, + A: AdaptStrategy, { type Hamiltonian = H; - type AdaptStrategy = S; - type Stats = NutsSampleStats; + type AdaptStrategy = A; + type Stats = NutsSampleStats; fn set_position(&mut self, position: &[f64]) -> Result<()> { - let state = self.potential.init_state(&mut self.pool, position)?; + let state = self + .potential + .init_state(&mut self.math, &mut self.pool, position)?; self.init = state; - self.strategy - .init(&mut self.options, &mut self.potential, &self.init); + self.strategy.init( + &mut self.math, + &mut self.options, + &mut self.potential, + &self.init, + ); Ok(()) } fn draw(&mut self) -> Result<(Box<[f64]>, Self::Stats)> { let (state, info) = draw( + &mut self.math, &mut self.pool, &mut self.init, &mut self.rng, @@ -948,8 +995,8 @@ where &self.options, &mut self.collector, )?; - let mut position: Box<[f64]> = vec![0f64; self.potential.dim()].into(); - state.write_position(&mut position); + let mut position: Box<[f64]> = vec![0f64; self.math.dim()].into(); + state.write_position(&mut self.math, &mut position); let stats = NutsSampleStats { depth: info.depth, maxdepth_reached: info.reached_maxdepth, @@ -962,26 +1009,28 @@ where draw: self.draw_count, potential_stats: 
self.potential.current_stats(), strategy_stats: self.strategy.current_stats( + &mut self.math, &self.options, &self.potential, &self.collector, ), gradient: if self.options.store_gradient { - let mut gradient: Box<[f64]> = vec![0f64; self.potential.dim()].into(); - state.write_gradient(&mut gradient); + let mut gradient: Box<[f64]> = vec![0f64; self.math.dim()].into(); + state.write_gradient(&mut self.math, &mut gradient); Some(gradient) } else { None }, unconstrained: if self.options.store_unconstrained { - let mut unconstrained: Box<[f64]> = vec![0f64; self.potential.dim()].into(); - state.write_position(&mut unconstrained); + let mut unconstrained: Box<[f64]> = vec![0f64; self.math.dim()].into(); + state.write_position(&mut self.math, &mut unconstrained); Some(unconstrained) } else { None }, }; self.strategy.adapt( + &mut self.math, &mut self.options, &mut self.potential, self.draw_count, @@ -993,11 +1042,11 @@ where } fn dim(&self) -> usize { - self.potential.dim() + self.math.dim() } #[cfg(feature = "arrow")] - type Builder = StatsBuilder; + type Builder = StatsBuilder; #[cfg(feature = "arrow")] fn stats_builder(&self, dim: usize, settings: &SamplerArgs) -> Self::Builder { @@ -1010,7 +1059,9 @@ where mod tests { use rand::thread_rng; - use crate::{adapt_strategy::test_logps::NormalLogp, new_sampler, Chain, SamplerArgs}; + use crate::{ + adapt_strategy::test_logps::NormalLogp, cpu_math::CpuMath, new_sampler, Chain, SamplerArgs, + }; use super::ArrowBuilder; @@ -1018,11 +1069,12 @@ mod tests { fn to_arrow() { let ndim = 10; let func = NormalLogp::new(ndim, 3.); + let mut math = CpuMath::new(func); let settings = SamplerArgs::default(); let mut rng = thread_rng(); - let mut chain = new_sampler(func, settings, 0, &mut rng); + let mut chain = new_sampler(math, settings, 0, &mut rng); let mut builder = chain.stats_builder(ndim, &settings); diff --git a/src/cpu_potential.rs b/src/potential.rs similarity index 50% rename from src/cpu_potential.rs rename to src/potential.rs index 942bca0..225fbb9 100644 --- a/src/cpu_potential.rs +++ b/src/potential.rs @@ -1,48 +1,35 @@ use std::fmt::Debug; +use std::marker::PhantomData; #[cfg(feature = "arrow")] use arrow2::array::{MutableArray, MutablePrimitiveArray, StructArray}; #[cfg(feature = "arrow")] use arrow2::datatypes::{DataType, Field}; -use crate::cpu_state::{State, StatePool}; use crate::mass_matrix::MassMatrix; +use crate::math_base::Math; use crate::nuts::{Collector, Direction, DivergenceInfo, Hamiltonian, LogpError, NutsError}; +use crate::state::{State, StatePool}; #[cfg(feature = "arrow")] use crate::SamplerArgs; #[cfg(feature = "arrow")] use crate::nuts::{ArrowBuilder, ArrowRow}; -/// Compute the unnormalized log probability density of the posterior -/// -/// This needs to be implemnted by users of the library to define -/// what distribution the users wants to sample from. -/// -/// Errors during that computation can be recoverable or non-recoverable. -/// If a non-recoverable error occurs during sampling, the sampler will -/// stop and return an error. 
-pub trait CpuLogpFunc { - type Err: Debug + Send + Sync + LogpError + 'static; - - fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result; - fn dim(&self) -> usize; -} - -pub(crate) struct EuclideanPotential { - logp: F, - pub(crate) mass_matrix: M, +pub(crate) struct EuclideanPotential> { + pub(crate) mass_matrix: Mass, max_energy_error: f64, pub(crate) step_size: f64, + _phantom: PhantomData, } -impl EuclideanPotential { - pub(crate) fn new(logp: F, mass_matrix: M, max_energy_error: f64, step_size: f64) -> Self { +impl> EuclideanPotential { + pub(crate) fn new(mass_matrix: Mass, max_energy_error: f64, step_size: f64) -> Self { EuclideanPotential { - logp, mass_matrix, max_energy_error, step_size, + _phantom: PhantomData::default(), } } } @@ -83,20 +70,20 @@ impl ArrowRow for PotentialStats { } } -impl Hamiltonian for EuclideanPotential { - type State = State; - type LogpError = F::Err; +impl> Hamiltonian for EuclideanPotential { + type LogpError = M::LogpErr; type Stats = PotentialStats; - fn leapfrog>( + fn leapfrog>( &mut self, - pool: &mut StatePool, - start: &Self::State, + math: &mut M, + pool: &mut StatePool, + start: &State, dir: Direction, initial_energy: f64, collector: &mut C, - ) -> Result, NutsError> { - let mut out = pool.new_state(); + ) -> Result, DivergenceInfo>, NutsError> { + let mut out = pool.new_state(math); let sign = match dir { Direction::Forward => 1, @@ -105,87 +92,89 @@ impl Hamiltonian for EuclideanPotential { let epsilon = (sign as f64) * self.step_size; - start.first_momentum_halfstep(&mut out, epsilon); - self.update_velocity(&mut out); + start.first_momentum_halfstep(math, &mut out, epsilon); + self.update_velocity(math, &mut out); - start.position_step(&mut out, epsilon); - if let Err(logp_error) = self.update_potential_gradient(&mut out) { + start.position_step(math, &mut out, epsilon); + if let Err(logp_error) = self.update_potential_gradient(math, &mut out) { if !logp_error.is_recoverable() { return Err(NutsError::LogpFailure(Box::new(logp_error))); } let div_info = DivergenceInfo { logp_function_error: Some(Box::new(logp_error)), - start_location: Some(start.q.clone()), - start_gradient: Some(start.grad.clone()), - start_momentum: Some(start.p.clone()), + start_location: Some(math.box_array(&start.q)), + start_gradient: Some(math.box_array(&start.grad)), + start_momentum: Some(math.box_array(&start.p)), end_location: None, start_idx_in_trajectory: Some(start.idx_in_trajectory), end_idx_in_trajectory: None, energy_error: None, }; - collector.register_leapfrog(start, &out, Some(&div_info)); + collector.register_leapfrog(math, start, &out, Some(&div_info)); return Ok(Err(div_info)); } - out.second_momentum_halfstep(epsilon); + out.second_momentum_halfstep(math, epsilon); - self.update_velocity(&mut out); - self.update_kinetic_energy(&mut out); + self.update_velocity(math, &mut out); + self.update_kinetic_energy(math, &mut out); *out.index_in_trajectory_mut() = start.index_in_trajectory() + sign; - start.set_psum(&mut out, dir); + start.set_psum(math, &mut out, dir); - let energy_error = { - use crate::nuts::State; - out.energy() - initial_energy - }; + let energy_error = { out.energy() - initial_energy }; if (energy_error > self.max_energy_error) | !energy_error.is_finite() { let divergence_info = DivergenceInfo { logp_function_error: None, - start_location: Some(start.q.clone()), - start_gradient: Some(start.grad.clone()), - end_location: Some(out.q.clone()), - start_momentum: Some(out.p.clone()), + start_location: 
Some(math.box_array(&start.q)), + start_gradient: Some(math.box_array(&start.grad)), + end_location: Some(math.box_array(&out.q)), + start_momentum: Some(math.box_array(&out.p)), start_idx_in_trajectory: Some(start.index_in_trajectory()), end_idx_in_trajectory: Some(out.index_in_trajectory()), energy_error: Some(energy_error), }; - collector.register_leapfrog(start, &out, Some(&divergence_info)); + collector.register_leapfrog(math, start, &out, Some(&divergence_info)); return Ok(Err(divergence_info)); } - collector.register_leapfrog(start, &out, None); + collector.register_leapfrog(math, start, &out, None); Ok(Ok(out)) } - fn init_state(&mut self, pool: &mut StatePool, init: &[f64]) -> Result { - let mut state = pool.new_state(); + fn init_state( + &mut self, + math: &mut M, + pool: &mut StatePool, + init: &[f64], + ) -> Result, NutsError> { + let mut state = pool.new_state(math); { let inner = state.try_mut_inner().expect("State already in use"); - inner.q.copy_from_slice(init); - inner.p_sum.fill(0.); + math.read_from_slice(&mut inner.q, init); + math.fill_array(&mut inner.p_sum, 0.); } - self.update_potential_gradient(&mut state) + self.update_potential_gradient(math, &mut state) .map_err(|e| NutsError::LogpFailure(Box::new(e)))?; - if state - .grad - .iter() - .cloned() - .any(|val| (val == 0f64) || (!val.is_finite())) - { + if !math.array_all_finite_and_nonzero(&state.grad) { Err(NutsError::BadInitGrad()) } else { Ok(state) } } - fn randomize_momentum(&self, state: &mut Self::State, rng: &mut R) { + fn randomize_momentum( + &self, + math: &mut M, + state: &mut State, + rng: &mut R, + ) { let inner = state.try_mut_inner().unwrap(); - self.mass_matrix.randomize_momentum(inner, rng); - self.mass_matrix.update_velocity(inner); - self.mass_matrix.update_kinetic_energy(inner); + self.mass_matrix.randomize_momentum(math, inner, rng); + self.mass_matrix.update_velocity(math, inner); + self.mass_matrix.update_kinetic_energy(math, inner); } fn current_stats(&self) -> Self::Stats { @@ -194,24 +183,24 @@ impl Hamiltonian for EuclideanPotential { } } - fn new_empty_state(&mut self, pool: &mut StatePool) -> Self::State { - pool.new_state() - } - - fn new_pool(&mut self, _capacity: usize) -> StatePool { - StatePool::new(self.dim()) + fn new_empty_state(&mut self, math: &mut M, pool: &mut StatePool) -> State { + pool.new_state(math) } - fn dim(&self) -> usize { - self.logp.dim() + fn new_pool(&mut self, math: &mut M, capacity: usize) -> StatePool { + StatePool::new(math, capacity) } } -impl EuclideanPotential { - fn update_potential_gradient(&mut self, state: &mut State) -> Result<(), F::Err> { +impl> EuclideanPotential { + fn update_potential_gradient( + &mut self, + math: &mut M, + state: &mut State, + ) -> Result<(), M::LogpErr> { let logp = { let inner = state.try_mut_inner().unwrap(); - self.logp.logp(&inner.q, &mut inner.grad) + math.logp_array(&inner.q, &mut inner.grad) }?; let inner = state.try_mut_inner().unwrap(); @@ -219,13 +208,13 @@ impl EuclideanPotential { Ok(()) } - fn update_velocity(&mut self, state: &mut State) { + fn update_velocity(&mut self, math: &mut M, state: &mut State) { self.mass_matrix - .update_velocity(state.try_mut_inner().expect("State already in us")) + .update_velocity(math, state.try_mut_inner().expect("State already in us")) } - fn update_kinetic_energy(&mut self, state: &mut State) { + fn update_kinetic_energy(&mut self, math: &mut M, state: &mut State) { self.mass_matrix - .update_kinetic_energy(state.try_mut_inner().expect("State already in us")) + 
.update_kinetic_energy(math, state.try_mut_inner().expect("State already in us")) } } diff --git a/src/cpu_sampler.rs b/src/sampler.rs similarity index 55% rename from src/cpu_sampler.rs rename to src/sampler.rs index 87ca787..147840a 100644 --- a/src/cpu_sampler.rs +++ b/src/sampler.rs @@ -1,14 +1,11 @@ use rand::{Rng, SeedableRng}; -use rayon::prelude::*; -use std::thread::JoinHandle; -use thiserror::Error; use crate::{ adapt_strategy::{GradDiagOptions, GradDiagStrategy}, - cpu_potential::EuclideanPotential, mass_matrix::DiagMassMatrix, + math_base::Math, nuts::{Chain, NutsChain, NutsError, NutsOptions, SampleStats}, - CpuLogpFunc, + potential::EuclideanPotential, }; /// Settings for the NUTS sampler @@ -57,142 +54,26 @@ pub trait InitPointFunc { fn new_init_point(&mut self, rng: &mut R, out: &mut [f64]); } -#[non_exhaustive] -#[derive(Error, Debug)] -pub enum ParallelSamplingError { - #[error("Could not send sample to controller thread")] - ChannelClosed(), - #[error("Nuts failed because of unrecoverable logp function error: {source}")] - NutsError { - #[from] - source: NutsError, - }, - #[error("Initialization of first point failed")] - InitError { source: NutsError }, - #[error("Timeout occured while waiting for next sample")] - Timeout, - #[error("Drawing sample paniced")] - Panic, - #[error("Creating a logp function failed")] - LogpFuncCreation { - #[from] - source: anyhow::Error, - }, -} - -pub type ParallelChainResult = Result<(), ParallelSamplingError>; +pub trait MathMaker: Send + Sync { + type Math: Math; -pub trait CpuLogpFuncMaker: Send + Sync -where - Func: CpuLogpFunc, -{ - fn make_logp_func(&self, chain: usize) -> Result; + fn make_math(&self, id: usize) -> Result::Err>; fn dim(&self) -> usize; } -/// Sample several chains in parallel and return all of the samples live in a channel -pub fn sample_parallel< - M: CpuLogpFuncMaker + 'static, - F: CpuLogpFunc, - I: InitPointFunc, - R: Rng + ?Sized, ->( - logp_func_maker: M, - init_point_func: &mut I, - settings: SamplerArgs, - n_chains: u64, - rng: &mut R, - n_try_init: u64, -) -> Result< - ( - JoinHandle>, - crossbeam::channel::Receiver<(Box<[f64]>, Box)>, - ), - ParallelSamplingError, -> { - let ndim = logp_func_maker.dim(); - let mut func = logp_func_maker.make_logp_func(0)?; - assert!(ndim == func.dim()); - let draws = settings.num_tune + settings.num_draws; - //let mut rng = StdRng::from_rng(rng).expect("Could not seed rng"); - let mut rng = rand_chacha::ChaCha8Rng::from_rng(rng).unwrap(); - - let mut points: Vec, Box<[f64]>), ::Err>> = (0..n_chains) - .map(|_| { - let mut position = vec![0.; ndim]; - let mut grad = vec![0.; ndim]; - init_point_func.new_init_point(&mut rng, &mut position); - - let mut error = None; - for _ in 0..n_try_init { - match func.logp(&position, &mut grad) { - Err(e) => error = Some(e), - Ok(_) => { - error = None; - break; - } - } - } - match error { - Some(e) => Err(e), - None => Ok((position.into(), grad.into())), - } - }) - .collect(); - - let points: Result, Box<[f64]>)>, _> = points.drain(..).collect(); - let points = points.map_err(|e| ParallelSamplingError::InitError { - source: NutsError::LogpFailure(Box::new(e)), - })?; - - let (sender, receiver) = crossbeam::channel::bounded(128); - - let handle = std::thread::spawn(move || { - let rng = rng.clone(); - let results: Vec> = points - .into_par_iter() - .with_max_len(1) - .enumerate() - .map_with(sender, |sender, (chain, point)| { - let func = logp_func_maker.make_logp_func(chain)?; - let mut rng = rng.clone(); - rng.set_stream(chain as 
u64); - let mut sampler = new_sampler( - func, - settings, - chain as u64, - //seed.wrapping_add(chain as u64), - &mut rng, - ); - sampler.set_position(&point.0)?; - for _ in 0..draws { - let (point2, info) = sampler.draw()?; - sender - .send((point2, Box::new(info) as Box)) - .map_err(|_| ParallelSamplingError::ChannelClosed())?; - } - Ok(()) - }) - .collect(); - results - }); - - Ok((handle, receiver)) -} - /// Create a new sampler -pub fn new_sampler( - logp: F, +pub fn new_sampler<'math, M: Math + 'math, R: Rng + ?Sized>( + mut math: M, settings: SamplerArgs, chain: u64, rng: &mut R, -) -> impl Chain { +) -> impl Chain + 'math { use crate::nuts::AdaptStrategy; let num_tune = settings.num_tune; - let strategy = GradDiagStrategy::new(settings.mass_matrix_adapt, num_tune, logp.dim()); - let mass_matrix = DiagMassMatrix::new(logp.dim()); + let strategy = GradDiagStrategy::new(&mut math, settings.mass_matrix_adapt, num_tune); + let mass_matrix = DiagMassMatrix::new(&mut math); let max_energy_error = settings.max_energy_error; - let potential = EuclideanPotential::new(logp, mass_matrix, max_energy_error, 1f64); + let potential = EuclideanPotential::new(mass_matrix, max_energy_error, 1f64); let options = NutsOptions { maxdepth: settings.maxdepth, @@ -202,18 +83,18 @@ pub fn new_sampler( let rng = rand::rngs::SmallRng::from_rng(rng).expect("Could not seed rng"); - NutsChain::new(potential, strategy, options, rng, chain) + NutsChain::new(math, potential, strategy, options, rng, chain) } -pub fn sample_sequentially( - logp: F, +pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>( + math: M, settings: SamplerArgs, start: &[f64], draws: u64, chain: u64, rng: &mut R, -) -> Result, impl SampleStats), NutsError>>, NutsError> { - let mut sampler = new_sampler(logp, settings, chain, rng); +) -> Result, impl SampleStats + 'math), NutsError>> + 'math, NutsError> { + let mut sampler = new_sampler(math, settings, chain, rng); sampler.set_position(start)?; Ok((0..draws).map(move |_| sampler.draw())) } @@ -250,19 +131,29 @@ impl InitPointFunc for JitterInitFunc { } pub mod test_logps { - use crate::{cpu_potential::CpuLogpFunc, nuts::LogpError, CpuLogpFuncMaker}; + use crate::{ + cpu_math::{CpuLogpFunc, CpuMath}, + math_base::Math, + nuts::LogpError, + MathMaker, + }; use multiversion::multiversion; use thiserror::Error; - #[derive(Clone)] + #[derive(Clone, Debug)] pub struct NormalLogp { dim: usize, mu: f64, } - impl CpuLogpFuncMaker for NormalLogp { - fn make_logp_func(&self, _chain: usize) -> Result { - Ok(self.clone()) + impl MathMaker for NormalLogp { + type Math = CpuMath; + + fn make_math( + &self, + _chain: usize, + ) -> Result, ::Err> { + Ok(CpuMath::new(self.clone())) } fn dim(&self) -> usize { @@ -285,7 +176,7 @@ pub mod test_logps { } impl CpuLogpFunc for NormalLogp { - type Err = NormalLogpError; + type LogpError = NormalLogpError; fn dim(&self) -> usize { self.dim @@ -355,8 +246,7 @@ pub mod test_logps { #[cfg(test)] mod tests { use crate::{ - sample_parallel, sample_sequentially, test_logps::NormalLogp, JitterInitFunc, SampleStats, - SamplerArgs, + cpu_math::CpuMath, sample_sequentially, test_logps::NormalLogp, SampleStats, SamplerArgs, }; use itertools::Itertools; @@ -366,6 +256,7 @@ mod tests { #[test] fn sample_seq() { let logp = NormalLogp::new(10, 0.1); + let math = CpuMath::new(logp); let mut settings = SamplerArgs::default(); settings.num_tune = 100; settings.num_draws = 100; @@ -373,7 +264,7 @@ mod tests { let mut rng = StdRng::seed_from_u64(42); - let chain = 
sample_sequentially(logp.clone(), settings, &start, 200, 1, &mut rng).unwrap(); + let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap(); let mut draws = chain.collect_vec(); assert_eq!(draws.len(), 200); @@ -382,17 +273,5 @@ mod tests { assert_eq!(vals.len(), 10); assert_eq!(stats.chain(), 1); assert_eq!(stats.draw(), 100); - - let maker = logp; - - let (handles, chains) = - sample_parallel(maker, &mut JitterInitFunc::new(), settings, 4, &mut rng, 10).unwrap(); - let mut draws = chains.iter().collect_vec(); - assert_eq!(draws.len(), 800); - assert!(handles.join().is_ok()); - - let draw0 = draws.remove(100); - let (vals, _) = draw0; - assert_eq!(vals.len(), 10); } } diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..489301b --- /dev/null +++ b/src/state.rs @@ -0,0 +1,254 @@ +use std::{ + cell::RefCell, + fmt::Debug, + marker::PhantomData, + ops::{Deref, DerefMut}, + rc::{Rc, Weak}, +}; + +use crate::math_base::Math; + +struct StateStorage { + free_states: RefCell>>>, +} + +impl StateStorage { + fn new(_math: &mut M, capacity: usize) -> StateStorage { + StateStorage { + free_states: RefCell::new(Vec::with_capacity(capacity)), + } + } +} + +pub(crate) struct StatePool { + storage: Rc> +} + +impl StatePool { + pub(crate) fn new(math: &mut M, capacity: usize) -> StatePool { + StatePool { + storage: Rc::new(StateStorage::new(math, capacity)), + } + } + + pub(crate) fn new_state(&self, math: &mut M) -> State { + let inner = match self.storage.free_states.borrow_mut().pop() { + Some(inner) => inner, + None => { + Rc::new(InnerStateReusable::new(math, &self)) + } + }; + State { + inner: std::mem::ManuallyDrop::new(inner), + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct InnerState { + pub(crate) p: M::Array, + pub(crate) q: M::Array, + pub(crate) v: M::Array, + pub(crate) p_sum: M::Array, + pub(crate) grad: M::Array, + pub(crate) idx_in_trajectory: i64, + pub(crate) kinetic_energy: f64, + pub(crate) potential_energy: f64, + _phantom_todo: PhantomData, +} + +pub(crate) struct InnerStateReusable { + inner: InnerState, + reuser: Weak>, +} + +impl<'pool, M: Math> InnerStateReusable { + fn new(math: &mut M, owner: &StatePool) -> InnerStateReusable { + InnerStateReusable { + inner: InnerState { + p: math.new_array(), + q: math.new_array(), + v: math.new_array(), + p_sum: math.new_array(), + grad: math.new_array(), + idx_in_trajectory: 0, + kinetic_energy: 0., + potential_energy: 0., + _phantom_todo: PhantomData::default(), + }, + reuser: Rc::downgrade(&Rc::clone(&owner.storage)), + } + } +} + +pub(crate) struct State { + inner: std::mem::ManuallyDrop>>, +} + +impl Deref for State { + type Target = InnerState; + + fn deref(&self) -> &Self::Target { + &self.inner.inner + } +} + +#[derive(Debug)] +pub(crate) struct StateInUse {} + +type Result = std::result::Result; + +impl State { + pub(crate) fn try_mut_inner(&mut self) -> Result<&mut InnerState> { + match Rc::get_mut(&mut self.inner) { + Some(val) => Ok(&mut val.inner), + None => Err(StateInUse {}), + } + } +} + +impl Drop for State { + fn drop(&mut self) { + let rc = unsafe { std::mem::ManuallyDrop::take(&mut self.inner) }; + if (Rc::strong_count(&rc) == 1) & (Rc::weak_count(&rc) == 0) { + if let Some(storage) = rc.reuser.upgrade() { + storage.free_states.borrow_mut().push(rc); + } + } + } +} + +impl Clone for State { + fn clone(&self) -> Self { + State { + inner: self.inner.clone(), + } + } +} + +impl State { + pub(crate) fn is_turning(&self, math: &mut M, other: &Self) -> bool { + let 
(start, end) = if self.idx_in_trajectory < other.idx_in_trajectory { + (self, other) + } else { + (other, self) + }; + + let a = start.idx_in_trajectory; + let b = end.idx_in_trajectory; + + assert!(a < b); + let (turn1, turn2) = if (a >= 0) & (b >= 0) { + math.scalar_prods3(&end.p_sum, &start.p_sum, &start.p, &end.v, &start.v) + } else if (b >= 0) & (a < 0) { + math.scalar_prods2(&end.p_sum, &start.p_sum, &end.v, &start.v) + } else { + assert!((a < 0) & (b < 0)); + math.scalar_prods3(&start.p_sum, &end.p_sum, &end.p, &end.v, &start.v) + }; + + (turn1 < 0.) | (turn2 < 0.) + } + + pub(crate) fn write_position(&self, math: &mut M, out: &mut [f64]) { + math.write_to_slice(&self.q, out) + } + + pub(crate) fn write_gradient(&self, math: &mut M, out: &mut [f64]) { + math.write_to_slice(&self.grad, out) + } + + pub(crate) fn energy(&self) -> f64 { + self.kinetic_energy + self.potential_energy + } + + pub(crate) fn index_in_trajectory(&self) -> i64 { + self.idx_in_trajectory + } + + pub(crate) fn make_init_point(&mut self, math: &mut M) { + let inner = self.try_mut_inner().unwrap(); + inner.idx_in_trajectory = 0; + math.copy_into(&inner.p, &mut inner.p_sum); + } + + pub(crate) fn potential_energy(&self) -> f64 { + self.potential_energy + } + + pub(crate) fn first_momentum_halfstep(&self, math: &mut M, out: &mut Self, epsilon: f64) { + math.axpy_out( + &self.grad, + &self.p, + epsilon / 2., + &mut out.try_mut_inner().expect("State already in use").p, + ); + } + + pub(crate) fn position_step(&self, math: &mut M, out: &mut Self, epsilon: f64) { + let out = out.try_mut_inner().expect("State already in use"); + math.axpy_out(&out.v, &self.q, epsilon, &mut out.q); + } + + pub(crate) fn second_momentum_halfstep(&mut self, math: &mut M, epsilon: f64) { + let inner = self.try_mut_inner().expect("State already in use"); + math.axpy(&inner.grad, &mut inner.p, epsilon / 2.); + } + + pub(crate) fn set_psum(&self, math: &mut M, target: &mut Self, _dir: crate::nuts::Direction) { + let out = target.try_mut_inner().expect("State already in use"); + + assert!(out.idx_in_trajectory != 0); + + if out.idx_in_trajectory == -1 { + math.copy_into(&out.p, &mut out.p_sum); + } else { + math.axpy_out(&out.p, &self.p_sum, 1., &mut out.p_sum); + } + } + + pub(crate) fn index_in_trajectory_mut(&mut self) -> &mut i64 { + &mut self + .try_mut_inner() + .expect("State already in use") + .idx_in_trajectory + } +} + +#[cfg(test)] +mod tests { + use crate::{cpu_math::CpuMath, test_logps::NormalLogp}; + + use super::*; + + #[test] + fn crate_pool() { + let logp = NormalLogp::new(10, 0.2); + let mut math = CpuMath::new(logp); + let mut pool = StatePool::new(&mut math, 10); + let mut state = pool.new_state(&mut math); + assert!(state.p.nrows() == 10); + assert!(state.p.ncols() == 1); + state.try_mut_inner().unwrap(); + + let mut state2 = state.clone(); + assert!(state.try_mut_inner().is_err()); + assert!(state2.try_mut_inner().is_err()); + } + + #[test] + fn make_state() { + let dim = 10; + let logp = NormalLogp::new(dim, 0.2); + let mut math = CpuMath::new(logp); + let mut pool = StatePool::new(&mut math, 10); + let a = pool.new_state(&mut math); + + assert_eq!(a.idx_in_trajectory, 0); + assert!(a.p_sum.col_ref(0).iter().all(|&x| x == 0f64)); + assert_eq!(a.p_sum.col_ref(0).len(), dim); + assert_eq!(a.grad.col_ref(0).len(), dim); + assert_eq!(a.q.col_ref(0).len(), dim); + assert_eq!(a.p.col_ref(0).len(), dim); + } +} diff --git a/src/stepsize.rs b/src/stepsize.rs index 2a139f4..b37c52f 100644 --- a/src/stepsize.rs +++ 
b/src/stepsize.rs @@ -1,7 +1,9 @@ use std::marker::PhantomData; use crate::{ - nuts::{Collector, NutsOptions, State}, + math_base::Math, + nuts::{Collector, NutsOptions}, + state::State, DivergenceInfo, }; @@ -101,15 +103,15 @@ impl RunningMean { } } -pub(crate) struct AcceptanceRateCollector { +pub(crate) struct AcceptanceRateCollector { initial_energy: f64, pub(crate) mean: RunningMean, pub(crate) mean_sym: RunningMean, - phantom: PhantomData, + phantom: PhantomData, } -impl AcceptanceRateCollector { - pub(crate) fn new() -> AcceptanceRateCollector { +impl AcceptanceRateCollector { + pub(crate) fn new() -> AcceptanceRateCollector { AcceptanceRateCollector { initial_energy: 0., mean: RunningMean::new(), @@ -119,13 +121,12 @@ impl AcceptanceRateCollector { } } -impl Collector for AcceptanceRateCollector { - type State = S; - +impl Collector for AcceptanceRateCollector { fn register_leapfrog( &mut self, - _start: &Self::State, - end: &Self::State, + _math: &mut M, + _start: &State, + end: &State, divergence_info: Option<&DivergenceInfo>, ) { match divergence_info { @@ -145,7 +146,7 @@ impl Collector for AcceptanceRateCollector { }; } - fn register_init(&mut self, state: &Self::State, _options: &NutsOptions) { + fn register_init(&mut self, _math: &mut M, state: &State, _options: &NutsOptions) { self.initial_energy = state.energy(); self.mean.reset(); self.mean_sym.reset();
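The new src/state.rs above replaces per-state allocation with a small object pool: a State hands its inner buffers back to the pool's free list when the last clone is dropped, via the Weak reference to StateStorage. The walk-through below is a sketch built from the code in this patch, not part of the diff; all items are crate-internal, so it would sit next to the existing tests in src/state.rs, and the pool capacity and NormalLogp parameters are arbitrary:

// Sketch only: exercises the recycling behaviour of StatePool/State from src/state.rs.
fn sketch_state_pool() {
    let logp = NormalLogp::new(10, 0.2);
    let mut math = CpuMath::new(logp);

    let pool = StatePool::new(&mut math, 4);
    let mut state = pool.new_state(&mut math);   // free list is empty, so this allocates

    // A clone shares the same Rc, so exclusive access is refused for both handles.
    let clone = state.clone();
    assert!(state.try_mut_inner().is_err());
    drop(clone);
    assert!(state.try_mut_inner().is_ok());      // sole owner again

    drop(state);                                 // last strong reference: buffers return to free_states
    let _reused = pool.new_state(&mut math);     // pops the recycled buffers instead of allocating
}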