diff --git a/CHANGELOG.md b/CHANGELOG.md index a774d96..a07cf6b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ All notable changes to this project will be documented in this file. -## [0.12.0] - 2024-07-05 +## [0.13.0] - 2024-10-23 ### Bug Fixes @@ -10,6 +10,8 @@ All notable changes to this project will be documented in this file. - Append missing values for non-diverging draws (Adrian Seyboldt) +- Fix bug where step size stats were not updated after tuning (Adrian Seyboldt) + ### Features @@ -21,6 +23,14 @@ All notable changes to this project will be documented in this file. - Add low-rank modified mass matrix adaptation (Adrian Seyboldt) +- Make cpu_math parallelization configurable (Adrian Seyboldt) + +- Add transforming adaptation (Adrian Seyboldt) + +- Improve error info for BadInitGrad (Adrian Seyboldt) + +- Do not report invalid gradients for transform adapt (Adrian Seyboldt) + ### Miscellaneous Tasks @@ -30,11 +40,19 @@ All notable changes to this project will be documented in this file. - Update changelog (Adrian Seyboldt) +- Prepare release (Adrian Seyboldt) + +- Prepare release (Adrian Seyboldt) + +- Update dependencies (Adrian Seyboldt) + ### Refactor - Switch to arrow-rs (Adrian Seyboldt) +- Refactor mass matrix adaptation traits (Adrian Seyboldt) + ### Styling diff --git a/Cargo.toml b/Cargo.toml index e97786b..f430478 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,9 +1,9 @@ [package] name = "nuts-rs" -version = "0.12.1" +version = "0.13.0" authors = [ - "Adrian Seyboldt ", - "PyMC Developers ", + "Adrian Seyboldt ", + "PyMC Developers ", ] edition = "2021" license = "MIT" @@ -23,10 +23,10 @@ rand_distr = "0.4.3" multiversion = "0.7.2" itertools = "0.13.0" thiserror = "1.0.43" -arrow = { version = "52.0.0", default-features = false, features = ["ffi"] } +arrow = { version = "53.1.0", default-features = false, features = ["ffi"] } rand_chacha = "0.3.1" anyhow = "1.0.72" -faer = "0.19.1" +faer = { version = "0.19.4", default-features = false, features = ["std"] } pulp = "0.18.21" rayon = "1.10.0" @@ -36,7 +36,7 @@ pretty_assertions = "1.4.0" criterion = "0.5.1" nix = "0.29.0" approx = "0.5.1" -ndarray = "0.15.6" +ndarray = "0.16.1" [[bench]] name = "sample" diff --git a/src/adapt_strategy.rs b/src/adapt_strategy.rs index 655e899..f1a26e7 100644 --- a/src/adapt_strategy.rs +++ b/src/adapt_strategy.rs @@ -5,25 +5,27 @@ use itertools::Itertools; use rand::Rng; use crate::{ + chain::AdaptStrategy, + euclidean_hamiltonian::EuclideanHamiltonian, + hamiltonian::{DivergenceInfo, Hamiltonian, Point}, mass_matrix_adapt::MassMatrixAdaptStrategy, math_base::Math, - nuts::{AdaptStats, AdaptStrategy, Collector, NutsOptions}, + nuts::{Collector, NutsOptions}, sampler::Settings, + sampler_stats::{SamplerStats, StatTraceBuilder}, state::State, stepsize::AcceptanceRateCollector, stepsize_adapt::{ DualAverageSettings, Stats as StepSizeStats, StatsBuilder as StepSizeStatsBuilder, Strategy as StepSizeStrategy, }, - DivergenceInfo, + NutsError, }; -use crate::nuts::{SamplerStats, StatTraceBuilder}; - pub struct GlobalStrategy> { step_size: StepSizeStrategy, mass_matrix: A, - options: AdaptOptions, + options: EuclideanAdaptOptions, num_tune: u64, // The number of draws in the early window early_end: u64, @@ -36,7 +38,7 @@ pub struct GlobalStrategy> { } #[derive(Debug, Clone, Copy)] -pub struct AdaptOptions { +pub struct EuclideanAdaptOptions { pub dual_average_options: DualAverageSettings, pub mass_matrix_options: S, pub early_window: f64, @@ -46,7 +48,7 @@
pub mass_matrix_update_freq: u64, } -impl Default for AdaptOptions { +impl Default for EuclideanAdaptOptions { fn default() -> Self { Self { dual_average_options: DualAverageSettings::default(), @@ -79,18 +81,17 @@ impl> SamplerStats for GlobalStrategy< } } -impl> AdaptStats for GlobalStrategy { - fn num_grad_evals(stats: &Self::Stats) -> usize { - stats.stats1.n_steps as usize - } -} - impl> AdaptStrategy for GlobalStrategy { - type Potential = A::Potential; - type Collector = CombinedCollector; - type Options = AdaptOptions; - - fn new(math: &mut M, options: Self::Options, num_tune: u64) -> Self { + type Hamiltonian = EuclideanHamiltonian; + type Collector = CombinedCollector< + M, + >::Point, + AcceptanceRateCollector, + A::Collector, + >; + type Options = EuclideanAdaptOptions; + + fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self { let num_tune_f = num_tune as f64; let step_size_window = (options.step_size_window * num_tune_f) as u64; let early_end = (options.early_window * num_tune_f) as u64; @@ -100,7 +101,7 @@ impl> AdaptStrategy for GlobalStrategy Self { step_size: StepSizeStrategy::new(options.dual_average_options), - mass_matrix: A::new(math, options.mass_matrix_options, num_tune), + mass_matrix: A::new(math, options.mass_matrix_options, num_tune, chain), options, num_tune, early_end, @@ -115,29 +116,38 @@ impl> AdaptStrategy for GlobalStrategy &mut self, math: &mut M, options: &mut NutsOptions, - potential: &mut Self::Potential, - state: &State, + hamiltonian: &mut Self::Hamiltonian, + position: &[f64], rng: &mut R, - ) { - self.mass_matrix.init(math, options, potential, state, rng); - self.step_size.init(math, options, potential, state, rng); + ) -> Result<(), NutsError> { + let state = hamiltonian.init_state(math, position)?; + self.mass_matrix.init( + math, + options, + &mut hamiltonian.mass_matrix, + state.point(), + rng, + )?; + self.step_size + .init(math, options, hamiltonian, position, rng)?; + Ok(()) } fn adapt( &mut self, math: &mut M, options: &mut NutsOptions, - potential: &mut Self::Potential, + hamiltonian: &mut Self::Hamiltonian, draw: u64, collector: &Self::Collector, - state: &State, + state: &State>::Point>, rng: &mut R, - ) { + ) -> Result<(), NutsError> { self.step_size.update(&collector.collector1); if draw >= self.num_tune { self.tuning = false; - return; + return Ok(()); } if draw < self.final_step_size_window { @@ -165,7 +175,7 @@ impl> AdaptStrategy for GlobalStrategy let did_change = if force_update | (draw - self.last_update >= self.options.mass_matrix_update_freq) { - self.mass_matrix.update_potential(math, potential) + self.mass_matrix.adapt(math, &mut hamiltonian.mass_matrix) } else { false }; @@ -183,24 +193,26 @@ impl> AdaptStrategy for GlobalStrategy // First time we change the mass matrix if did_change & self.has_initial_mass_matrix { self.has_initial_mass_matrix = false; - self.step_size.init(math, options, potential, state, rng); + let position = math.box_array(state.point().position()); + self.step_size + .init(math, options, hamiltonian, &position, rng)?; } else { - self.step_size.update_stepsize(potential, false) + self.step_size.update_stepsize(hamiltonian, false) } - return; + return Ok(()); } self.step_size.update_estimator_late(); let is_last = draw == self.num_tune - 1; - self.step_size.update_stepsize(potential, is_last); + self.step_size.update_stepsize(hamiltonian, is_last); + Ok(()) } fn new_collector(&self, math: &mut M) -> Self::Collector { - CombinedCollector { - collector1: 
self.step_size.new_collector(), - collector2: self.mass_matrix.new_collector(math), - _phantom: PhantomData, - } + Self::Collector::new( + self.step_size.new_collector(), + self.mass_matrix.new_collector(math), + ) } fn is_tuning(&self) -> bool { @@ -216,8 +228,8 @@ pub struct CombinedStats { #[derive(Clone)] pub struct CombinedStatsBuilder { - stats1: B1, - stats2: B2, + pub stats1: B1, + pub stats2: B2, } impl StatTraceBuilder> for CombinedStatsBuilder @@ -277,22 +289,48 @@ where } } -pub struct CombinedCollector, C2: Collector> { - collector1: C1, - collector2: C2, +pub struct CombinedCollector +where + M: Math, + P: Point, + C1: Collector, + C2: Collector, +{ + pub collector1: C1, + pub collector2: C2, _phantom: PhantomData, + _phantom2: PhantomData
, } -impl Collector for CombinedCollector +impl CombinedCollector where - C1: Collector, - C2: Collector, + M: Math, + P: Point, + C1: Collector, + C2: Collector, +{ + pub fn new(collector1: C1, collector2: C2) -> Self { + CombinedCollector { + collector1, + collector2, + _phantom: PhantomData, + _phantom2: PhantomData, + } + } +} + +impl Collector for CombinedCollector +where + M: Math, + P: Point, + C1: Collector, + C2: Collector, { fn register_leapfrog( &mut self, math: &mut M, - start: &State, - end: &State, + start: &State, + end: &State, divergence_info: Option<&DivergenceInfo>, ) { self.collector1 @@ -301,7 +339,7 @@ where .register_leapfrog(math, start, end, divergence_info); } - fn register_draw(&mut self, math: &mut M, state: &State, info: &crate::nuts::SampleInfo) { + fn register_draw(&mut self, math: &mut M, state: &State, info: &crate::nuts::SampleInfo) { self.collector1.register_draw(math, state, info); self.collector2.register_draw(math, state, info); } @@ -309,7 +347,7 @@ where fn register_init( &mut self, math: &mut M, - state: &State, + state: &State, options: &crate::nuts::NutsOptions, ) { self.collector1.register_init(math, state, options); @@ -319,7 +357,7 @@ where #[cfg(test)] pub mod test_logps { - use crate::{cpu_math::CpuLogpFunc, nuts::LogpError}; + use crate::{cpu_math::CpuLogpFunc, math_base::LogpError}; use thiserror::Error; #[derive(Clone, Debug)] @@ -344,6 +382,7 @@ pub mod test_logps { impl CpuLogpFunc for NormalLogp { type LogpError = NormalLogpError; + type TransformParams = (); fn dim(&self) -> usize { self.dim @@ -360,6 +399,66 @@ pub mod test_logps { } Ok(logp) } + + fn inv_transform_normalize( + &mut self, + _params: &Self::TransformParams, + _untransformed_position: &[f64], + _untransformed_gradient: &[f64], + _transformed_position: &mut [f64], + _transformed_gradient: &mut [f64], + ) -> Result { + unimplemented!() + } + + fn init_from_transformed_position( + &mut self, + _params: &Self::TransformParams, + _untransformed_position: &mut [f64], + _untransformed_gradient: &mut [f64], + _transformed_position: &[f64], + _transformed_gradient: &mut [f64], + ) -> Result<(f64, f64), Self::LogpError> { + unimplemented!() + } + + fn init_from_untransformed_position( + &mut self, + _params: &Self::TransformParams, + _untransformed_position: &[f64], + _untransformed_gradient: &mut [f64], + _transformed_position: &mut [f64], + _transformed_gradient: &mut [f64], + ) -> Result<(f64, f64), Self::LogpError> { + unimplemented!() + } + + fn update_transformation<'a, R: rand::Rng + ?Sized>( + &'a mut self, + _rng: &mut R, + _untransformed_positions: impl ExactSizeIterator, + _untransformed_gradients: impl ExactSizeIterator, + _params: &'a mut Self::TransformParams, + ) -> Result<(), Self::LogpError> { + unimplemented!() + } + + fn new_transformation( + &mut self, + _rng: &mut R, + _untransformed_position: &[f64], + _untransformed_gradient: &[f64], + _chain: u64, + ) -> Result { + unimplemented!() + } + + fn transformation_id( + &self, + _params: &Self::TransformParams, + ) -> Result { + unimplemented!() + } } } @@ -368,11 +467,8 @@ mod test { use super::test_logps::NormalLogp; use super::*; use crate::{ - cpu_math::CpuMath, - mass_matrix::DiagMassMatrix, - nuts::{AdaptStrategy, Chain, NutsChain, NutsOptions}, - potential::EuclideanPotential, - DiagAdaptExpSettings, + chain::NutsChain, cpu_math::CpuMath, euclidean_hamiltonian::EuclideanHamiltonian, + mass_matrix::DiagMassMatrix, Chain, DiagAdaptExpSettings, }; #[test] @@ -383,14 +479,15 @@ mod test { let func = NormalLogp::new(ndim, 3.);
let mut math = CpuMath::new(func); let num_tune = 100; - let options = AdaptOptions::::default(); - let strategy = GlobalStrategy::<_, Strategy<_>>::new(&mut math, options, num_tune); + let options = EuclideanAdaptOptions::::default(); + let strategy = GlobalStrategy::<_, Strategy<_>>::new(&mut math, options, num_tune, 0u64); let mass_matrix = DiagMassMatrix::new(&mut math, true); let max_energy_error = 1000f64; let step_size = 0.1f64; - let potential = EuclideanPotential::new(mass_matrix, max_energy_error, step_size); + let hamiltonian = + EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, step_size); let options = NutsOptions { maxdepth: 10u64, store_gradient: true, @@ -405,7 +502,7 @@ mod test { }; let chain = 0u64; - let mut sampler = NutsChain::new(math, potential, strategy, options, rng, chain); + let mut sampler = NutsChain::new(math, hamiltonian, strategy, options, rng, chain); sampler.set_position(&vec![1.5f64; ndim]).unwrap(); for _ in 0..200 { sampler.draw().unwrap(); diff --git a/src/chain.rs b/src/chain.rs new file mode 100644 index 0000000..c3914a1 --- /dev/null +++ b/src/chain.rs @@ -0,0 +1,221 @@ +use std::fmt::Debug; + +use rand::Rng; + +use crate::{ + hamiltonian::{Hamiltonian, Point}, + nuts::{draw, Collector, NutsOptions, NutsSampleStats, NutsStatsBuilder}, + sampler_stats::SamplerStats, + state::State, + Math, NutsError, Settings, +}; + +use anyhow::Result; + +/// Draw samples from the posterior distribution using Hamiltonian MCMC. +pub trait Chain: SamplerStats { + type AdaptStrategy: AdaptStrategy; + + /// Initialize the sampler to a position. This should be called + /// before calling draw. + /// + /// This fails if the logp function returns an error. + fn set_position(&mut self, position: &[f64]) -> Result<()>; + + /// Draw a new sample and return the position and some diagnostic information. + fn draw(&mut self) -> Result<(Box<[f64]>, Self::Stats)>; + + /// The dimensionality of the posterior.
+ fn dim(&self) -> usize; +} + +pub struct NutsChain +where + M: Math, + R: rand::Rng, + A: AdaptStrategy, +{ + hamiltonian: A::Hamiltonian, + collector: A::Collector, + options: NutsOptions, + rng: R, + init: State>::Point>, + chain: u64, + draw_count: u64, + strategy: A, + math: M, + stats: Option>::Stats, A::Stats>>, +} + +impl NutsChain +where + M: Math, + R: rand::Rng, + A: AdaptStrategy, +{ + pub fn new( + mut math: M, + mut hamiltonian: A::Hamiltonian, + strategy: A, + options: NutsOptions, + rng: R, + chain: u64, + ) -> Self { + let init = hamiltonian.pool().new_state(&mut math); + let collector = strategy.new_collector(&mut math); + NutsChain { + hamiltonian, + collector, + options, + rng, + init, + chain, + draw_count: 0, + strategy, + math, + stats: None, + } + } +} + +pub trait AdaptStrategy: SamplerStats { + type Hamiltonian: Hamiltonian; + type Collector: Collector>::Point>; + type Options: Copy + Send + Debug + Default; + + fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self; + + fn init( + &mut self, + math: &mut M, + options: &mut NutsOptions, + hamiltonian: &mut Self::Hamiltonian, + position: &[f64], + rng: &mut R, + ) -> Result<(), NutsError>; + + #[allow(clippy::too_many_arguments)] + fn adapt( + &mut self, + math: &mut M, + options: &mut NutsOptions, + hamiltonian: &mut Self::Hamiltonian, + draw: u64, + collector: &Self::Collector, + state: &State>::Point>, + rng: &mut R, + ) -> Result<(), NutsError>; + + fn new_collector(&self, math: &mut M) -> Self::Collector; + fn is_tuning(&self) -> bool; +} + +impl SamplerStats for NutsChain +where + M: Math, + R: rand::Rng, + A: AdaptStrategy, +{ + type Builder = NutsStatsBuilder< + >::Builder, + >::Builder, + >; + type Stats = + NutsSampleStats<>::Stats, >::Stats>; + + fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder { + NutsStatsBuilder::new_with_capacity( + settings, + &self.hamiltonian, + &self.strategy, + dim, + &self.options, + ) + } + + fn current_stats(&self, _math: &mut M) -> Self::Stats { + self.stats.as_ref().expect("No stats available").clone() + } +} + +impl Chain for NutsChain +where + M: Math, + R: rand::Rng, + A: AdaptStrategy, +{ + type AdaptStrategy = A; + + fn set_position(&mut self, position: &[f64]) -> Result<()> { + self.strategy.init( + &mut self.math, + &mut self.options, + &mut self.hamiltonian, + position, + &mut self.rng, + )?; + self.init = self.hamiltonian.init_state(&mut self.math, position)?; + Ok(()) + } + + fn draw(&mut self) -> Result<(Box<[f64]>, Self::Stats)> { + let (state, info) = draw( + &mut self.math, + &mut self.init, + &mut self.rng, + &mut self.hamiltonian, + &self.options, + &mut self.collector, + )?; + let mut position: Box<[f64]> = vec![0f64; self.math.dim()].into(); + state.write_position(&mut self.math, &mut position); + + let stats = NutsSampleStats { + depth: info.depth, + maxdepth_reached: info.reached_maxdepth, + idx_in_trajectory: state.index_in_trajectory(), + logp: state.point().logp(), + energy: state.point().energy(), + energy_error: info.draw_energy - info.initial_energy, + divergence_info: info.divergence_info, + chain: self.chain, + draw: self.draw_count, + potential_stats: self.hamiltonian.current_stats(&mut self.math), + strategy_stats: self.strategy.current_stats(&mut self.math), + gradient: if self.options.store_gradient { + let mut gradient: Box<[f64]> = vec![0f64; self.math.dim()].into(); + state.write_gradient(&mut self.math, &mut gradient); + Some(gradient) + } else { + None + }, + unconstrained: if 
self.options.store_unconstrained { + let mut unconstrained: Box<[f64]> = vec![0f64; self.math.dim()].into(); + state.write_position(&mut self.math, &mut unconstrained); + Some(unconstrained) + } else { + None + }, + tuning: self.strategy.is_tuning(), + }; + + self.strategy.adapt( + &mut self.math, + &mut self.options, + &mut self.hamiltonian, + self.draw_count, + &self.collector, + &state, + &mut self.rng, + )?; + + self.draw_count += 1; + + self.init = state; + Ok((position, stats)) + } + + fn dim(&self) -> usize { + self.math.dim() + } +} diff --git a/src/cpu_math.rs b/src/cpu_math.rs index abb1221..c617640 100644 --- a/src/cpu_math.rs +++ b/src/cpu_math.rs @@ -6,20 +6,34 @@ use thiserror::Error; use crate::{ math::{axpy, axpy_out, multiply, scalar_prods2, scalar_prods3, vector_dot}, - math_base::Math, - LogpError, + math_base::{LogpError, Math}, }; #[derive(Debug)] pub struct CpuMath { logp_func: F, arch: pulp::Arch, + parallel: faer::Parallelism<'static>, } impl CpuMath { pub fn new(logp_func: F) -> Self { let arch = pulp::Arch::new(); - Self { logp_func, arch } + let parallel = faer::Parallelism::None; + Self { + logp_func, + arch, + parallel, + } + } + + pub fn with_parallel(logp_func: F, parallel: faer::Parallelism<'static>) -> Self { + let arch = pulp::Arch::new(); + Self { + logp_func, + arch, + parallel, + } + } } @@ -36,8 +50,9 @@ impl Math for CpuMath { type EigValues = Col; type LogpErr = F::LogpError; type Err = CpuMathError; + type TransformParams = F::TransformParams; - fn new_array(&self) -> Self::Vector { + fn new_array(&mut self) -> Self::Vector { Col::zeros(self.dim()) } @@ -327,11 +342,154 @@ impl Math for CpuMath { fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]> { source.as_slice().to_vec().into() } + + fn inv_transform_normalize( + &mut self, + params: &Self::TransformParams, + untransformed_position: &Self::Vector, + untransformed_gradient: &Self::Vector, + transformed_position: &mut Self::Vector, + transformed_gradient: &mut Self::Vector, + ) -> Result { + self.logp_func.inv_transform_normalize( + params, + untransformed_position.as_slice(), + untransformed_gradient.as_slice(), + transformed_position.as_slice_mut(), + transformed_gradient.as_slice_mut(), + ) + } + + fn init_from_untransformed_position( + &mut self, + params: &Self::TransformParams, + untransformed_position: &Self::Vector, + untransformed_gradient: &mut Self::Vector, + transformed_position: &mut Self::Vector, + transformed_gradient: &mut Self::Vector, + ) -> Result<(f64, f64), Self::LogpErr> { + self.logp_func.init_from_untransformed_position( + params, + untransformed_position.as_slice(), + untransformed_gradient.as_slice_mut(), + transformed_position.as_slice_mut(), + transformed_gradient.as_slice_mut(), + ) + } + + fn init_from_transformed_position( + &mut self, + params: &Self::TransformParams, + untransformed_position: &mut Self::Vector, + untransformed_gradient: &mut Self::Vector, + transformed_position: &Self::Vector, + transformed_gradient: &mut Self::Vector, + ) -> Result<(f64, f64), Self::LogpErr> { + self.logp_func.init_from_transformed_position( + params, + untransformed_position.as_slice_mut(), + untransformed_gradient.as_slice_mut(), + transformed_position.as_slice(), + transformed_gradient.as_slice_mut(), + ) + } + + fn update_transformation<'a, R: rand::Rng + ?Sized>( + &'a mut self, + rng: &mut R, + untransformed_positions: impl ExactSizeIterator, + untransformed_gradients: impl ExactSizeIterator, + params: &'a mut Self::TransformParams, + ) -> Result<(),
Self::LogpErr> { + self.logp_func.update_transformation( + rng, + untransformed_positions.map(|x| x.as_slice()), + untransformed_gradients.map(|x| x.as_slice()), + params, + ) + } + + fn new_transformation( + &mut self, + rng: &mut R, + untransformed_position: &Self::Vector, + untransformed_gradient: &Self::Vector, + chain: u64, + ) -> Result { + self.logp_func.new_transformation( + rng, + untransformed_position.as_slice(), + untransformed_gradient.as_slice(), + chain, + ) + } + + fn transformation_id(&self, params: &Self::TransformParams) -> Result { + self.logp_func.transformation_id(params) + } } pub trait CpuLogpFunc { type LogpError: Debug + Send + Sync + Error + LogpError + 'static; + type TransformParams; fn dim(&self) -> usize; fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result; + + fn inv_transform_normalize( + &mut self, + _params: &Self::TransformParams, + _untransformed_position: &[f64], + _untransformed_gradient: &[f64], + _transformed_position: &mut [f64], + _transformed_gradient: &mut [f64], + ) -> Result { + unimplemented!() + } + + fn init_from_untransformed_position( + &mut self, + _params: &Self::TransformParams, + _untransformed_position: &[f64], + _untransformed_gradient: &mut [f64], + _transformed_position: &mut [f64], + _transformed_gradient: &mut [f64], + ) -> Result<(f64, f64), Self::LogpError> { + unimplemented!() + } + + fn init_from_transformed_position( + &mut self, + _params: &Self::TransformParams, + _untransformed_position: &mut [f64], + _untransformed_gradient: &mut [f64], + _transformed_position: &[f64], + _transformed_gradient: &mut [f64], + ) -> Result<(f64, f64), Self::LogpError> { + unimplemented!() + } + + fn update_transformation<'a, R: rand::Rng + ?Sized>( + &'a mut self, + _rng: &mut R, + _untransformed_positions: impl ExactSizeIterator, + _untransformed_gradients: impl ExactSizeIterator, + _params: &'a mut Self::TransformParams, + ) -> Result<(), Self::LogpError> { + unimplemented!() + } + + fn new_transformation( + &mut self, + _rng: &mut R, + _untransformed_position: &[f64], + _untransformed_gradient: &[f64], + _chain: u64, + ) -> Result { + unimplemented!() + } + + fn transformation_id(&self, _params: &Self::TransformParams) -> Result { + unimplemented!() + } } diff --git a/src/euclidean_hamiltonian.rs b/src/euclidean_hamiltonian.rs new file mode 100644 index 0000000..d47ac4a --- /dev/null +++ b/src/euclidean_hamiltonian.rs @@ -0,0 +1,425 @@ +use std::fmt::Debug; +use std::marker::PhantomData; +use std::sync::Arc; + +use arrow::array::{ArrayBuilder, Float64Builder, StructArray}; +use arrow::datatypes::{DataType, Field}; + +use crate::hamiltonian::{Direction, DivergenceInfo, Hamiltonian, LeapfrogResult, Point}; +use crate::mass_matrix::MassMatrix; +use crate::math_base::Math; +use crate::nuts::{Collector, NutsError}; +use crate::sampler::Settings; +use crate::sampler_stats::{SamplerStats, StatTraceBuilder}; +use crate::state::{State, StatePool}; +use crate::LogpError; + +pub struct EuclideanHamiltonian> { + pub(crate) mass_matrix: Mass, + max_energy_error: f64, + step_size: f64, + pool: StatePool>, + _phantom: PhantomData, +} + +impl> EuclideanHamiltonian { + pub(crate) fn new( + math: &mut M, + mass_matrix: Mass, + max_energy_error: f64, + step_size: f64, + ) -> Self { + let pool = StatePool::new(math, 10); + EuclideanHamiltonian { + mass_matrix, + max_energy_error, + step_size, + pool, + _phantom: PhantomData, + } + } +} + +pub struct EuclideanPoint { + pub position: M::Vector, + pub velocity: M::Vector, + pub gradient:
M::Vector, + pub momentum: M::Vector, + pub kinetic_energy: f64, + pub potential_energy: f64, + pub index_in_trajectory: i64, + pub p_sum: M::Vector, + pub initial_energy: f64, +} + +impl EuclideanPoint { + fn is_turning(&self, math: &mut M, other: &Self) -> bool { + let (start, end) = if self.index_in_trajectory() < other.index_in_trajectory() { + (self, other) + } else { + (other, self) + }; + + let a = start.index_in_trajectory(); + let b = end.index_in_trajectory(); + + assert!(a < b); + let (turn1, turn2) = if (a >= 0) & (b >= 0) { + math.scalar_prods3( + &end.p_sum, + &start.p_sum, + &start.momentum, + &end.velocity, + &start.velocity, + ) + } else if (b >= 0) & (a < 0) { + math.scalar_prods2(&end.p_sum, &start.p_sum, &end.velocity, &start.velocity) + } else { + assert!((a < 0) & (b < 0)); + math.scalar_prods3( + &start.p_sum, + &end.p_sum, + &end.momentum, + &end.velocity, + &start.velocity, + ) + }; + + (turn1 < 0.) | (turn2 < 0.) + } + + fn first_momentum_halfstep(&self, math: &mut M, out: &mut Self, epsilon: f64) { + math.axpy_out( + &self.gradient, + &self.momentum, + epsilon / 2., + &mut out.momentum, + ); + } + + fn position_step(&self, math: &mut M, out: &mut Self, epsilon: f64) { + math.axpy_out(&out.velocity, &self.position, epsilon, &mut out.position); + } + + fn second_momentum_halfstep(&mut self, math: &mut M, epsilon: f64) { + math.axpy(&self.gradient, &mut self.momentum, epsilon / 2.); + } + + fn set_psum(&self, math: &mut M, out: &mut Self, _dir: Direction) { + assert!(out.index_in_trajectory != 0); + + if out.index_in_trajectory == -1 { + math.copy_into(&out.momentum, &mut out.p_sum); + } else { + math.axpy_out(&out.momentum, &self.p_sum, 1., &mut out.p_sum); + } + } + + fn update_potential_gradient(&mut self, math: &mut M) -> Result<(), M::LogpErr> { + let logp = { math.logp_array(&self.position, &mut self.gradient) }?; + self.potential_energy = -logp; + Ok(()) + } +} + +impl Point for EuclideanPoint { + fn position(&self) -> &::Vector { + &self.position + } + + fn gradient(&self) -> &::Vector { + &self.gradient + } + + fn energy(&self) -> f64 { + self.potential_energy + self.kinetic_energy + } + + fn initial_energy(&self) -> f64 { + self.initial_energy + } + + fn new(math: &mut M) -> Self { + Self { + position: math.new_array(), + velocity: math.new_array(), + gradient: math.new_array(), + momentum: math.new_array(), + kinetic_energy: 0f64, + potential_energy: 0f64, + index_in_trajectory: 0, + p_sum: math.new_array(), + initial_energy: 0f64, + } + } + + fn index_in_trajectory(&self) -> i64 { + self.index_in_trajectory + } + + fn logp(&self) -> f64 { + -self.potential_energy + } + + fn copy_into(&self, math: &mut M, other: &mut Self) { + let Self { + position, + velocity, + gradient, + momentum, + kinetic_energy, + potential_energy, + index_in_trajectory, + p_sum, + initial_energy, + } = self; + math.copy_into(position, &mut other.position); + math.copy_into(velocity, &mut other.velocity); + math.copy_into(gradient, &mut other.gradient); + math.copy_into(momentum, &mut other.momentum); + math.copy_into(p_sum, &mut other.p_sum); + other.kinetic_energy = *kinetic_energy; + other.potential_energy = *potential_energy; + other.initial_energy = *initial_energy; + other.index_in_trajectory = *index_in_trajectory; + } +} + +#[derive(Copy, Clone, Debug)] +pub struct PotentialStats { + mass_matrix_stats: S, + pub step_size: f64, +} + +pub struct PotentialStatsBuilder { + mass_matrix: B, + step_size: Float64Builder, +} + +impl> StatTraceBuilder> + for PotentialStatsBuilder +{ 
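+ // Descriptive note on the builder below: each draw appends the current step
+ // size into a Float64 column, while the mass matrix stats go into a nested
+ // builder whose columns are flattened next to step_size when the trace is
+ // finalized or inspected.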
+ fn append_value(&mut self, value: PotentialStats) { + let PotentialStats { + mass_matrix_stats, + step_size, + } = value; + + self.mass_matrix.append_value(mass_matrix_stats); + self.step_size.append_value(step_size); + } + + fn finalize(self) -> Option { + let Self { + mass_matrix, + mut step_size, + } = self; + + let mut fields = vec![Field::new("step_size", DataType::Float64, false)]; + let mut arrays = vec![ArrayBuilder::finish(&mut step_size)]; + + if let Some(mass_matrix) = mass_matrix.finalize() { + let (m_fields, m_data, m_bitmap) = mass_matrix.into_parts(); + assert!(m_bitmap.is_none()); + fields.extend( + m_fields + .into_iter() + .map(|v| Arc::unwrap_or_clone(v.to_owned())), + ); + arrays.extend(m_data); + } + + Some(StructArray::new(fields.into(), arrays, None)) + } + + fn inspect(&self) -> Option { + let Self { + mass_matrix, + step_size, + } = self; + + let mut fields = vec![Field::new("step_size", DataType::Float64, false)]; + let mut arrays = vec![ArrayBuilder::finish_cloned(step_size)]; + + if let Some(mass_matrix) = mass_matrix.inspect() { + let (m_fields, m_data, m_bitmap) = mass_matrix.into_parts(); + assert!(m_bitmap.is_none()); + fields.extend( + m_fields + .into_iter() + .map(|v| Arc::unwrap_or_clone(v.to_owned())), + ); + arrays.extend(m_data); + } + + Some(StructArray::new(fields.into(), arrays, None)) + } +} + +impl> SamplerStats for EuclideanHamiltonian { + type Builder = PotentialStatsBuilder; + type Stats = PotentialStats; + + fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder { + Self::Builder { + mass_matrix: self.mass_matrix.new_builder(settings, dim), + step_size: Float64Builder::with_capacity( + settings.hint_num_draws() + settings.hint_num_tune(), + ), + } + } + + fn current_stats(&self, math: &mut M) -> Self::Stats { + PotentialStats { + mass_matrix_stats: self.mass_matrix.current_stats(math), + step_size: self.step_size, + } + } +} + +impl> Hamiltonian for EuclideanHamiltonian { + type Point = EuclideanPoint; + + fn leapfrog>( + &mut self, + math: &mut M, + start: &State, + dir: Direction, + collector: &mut C, + ) -> LeapfrogResult { + let mut out = self.pool().new_state(math); + let out_point = out.try_point_mut().expect("New point has other references"); + + out_point.initial_energy = start.point().initial_energy(); + + let sign = match dir { + Direction::Forward => 1, + Direction::Backward => -1, + }; + + let epsilon = (sign as f64) * self.step_size; + + start + .point() + .first_momentum_halfstep(math, out_point, epsilon); + self.mass_matrix.update_velocity(math, out_point); + + start.point().position_step(math, out_point, epsilon); + if let Err(logp_error) = out_point.update_potential_gradient(math) { + if !logp_error.is_recoverable() { + return LeapfrogResult::Err(logp_error); + } + let div_info = DivergenceInfo { + logp_function_error: Some(Arc::new(Box::new(logp_error))), + start_location: Some(math.box_array(start.point().position())), + start_gradient: Some(math.box_array(&start.point().gradient)), + start_momentum: Some(math.box_array(&start.point().momentum)), + end_location: None, + start_idx_in_trajectory: Some(start.point().index_in_trajectory()), + end_idx_in_trajectory: None, + energy_error: None, + }; + collector.register_leapfrog(math, start, &out, Some(&div_info)); + return LeapfrogResult::Divergence(div_info); + } + + out_point.second_momentum_halfstep(math, epsilon); + + self.mass_matrix.update_velocity(math, out_point); + self.mass_matrix.update_kinetic_energy(math, out_point); + + 
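+ // Bookkeeping that completes the leapfrog step below: the trajectory index
+ // advances by `sign`, the running momentum sum feeds the U-turn criterion,
+ // and the energy error H(proposal) - H(initial) is reported as a divergence
+ // when it exceeds max_energy_error or is not finite.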
out_point.index_in_trajectory = start.index_in_trajectory() + sign; + + start.point().set_psum(math, out_point, dir); + + let energy_error = out_point.energy_error(); + if (energy_error > self.max_energy_error) | !energy_error.is_finite() { + let divergence_info = DivergenceInfo { + logp_function_error: None, + start_location: Some(math.box_array(start.point().position())), + start_gradient: Some(math.box_array(start.point().gradient())), + end_location: Some(math.box_array(&out_point.position)), + start_momentum: Some(math.box_array(&out_point.momentum)), + start_idx_in_trajectory: Some(start.index_in_trajectory()), + end_idx_in_trajectory: Some(out.index_in_trajectory()), + energy_error: Some(energy_error), + }; + collector.register_leapfrog(math, start, &out, Some(&divergence_info)); + return LeapfrogResult::Divergence(divergence_info); + } + + collector.register_leapfrog(math, start, &out, None); + + LeapfrogResult::Ok(out) + } + + fn init_state( + &mut self, + math: &mut M, + init: &[f64], + ) -> Result, NutsError> { + let mut state = self.pool().new_state(math); + let point = state.try_point_mut().expect("State already in use"); + math.read_from_slice(&mut point.position, init); + math.fill_array(&mut point.p_sum, 0.); + + point + .update_potential_gradient(math) + .map_err(|e| NutsError::LogpFailure(Box::new(e)))?; + if !math.array_all_finite_and_nonzero(&point.gradient) { + Err(NutsError::BadInitGrad( + anyhow::anyhow!("Invalid initial point").into(), + )) + } else { + Ok(state) + } + } + + fn initialize_trajectory( + &self, + math: &mut M, + state: &mut State, + rng: &mut R, + ) -> Result<(), NutsError> { + let inner = state.try_point_mut().expect("State has other references"); + self.mass_matrix.randomize_momentum(math, inner, rng); + self.mass_matrix.update_velocity(math, inner); + self.mass_matrix.update_kinetic_energy(math, inner); + inner.index_in_trajectory = 0; + inner.initial_energy = inner.energy(); + math.copy_into(&inner.momentum, &mut inner.p_sum); + Ok(()) + } + + fn is_turning( + &self, + math: &mut M, + state1: &State, + state2: &State, + ) -> bool { + state1.point().is_turning(math, state2.point()) + } + + fn copy_state(&mut self, math: &mut M, state: &State) -> State { + let mut new_state = self.pool().new_state(math); + state.point().copy_into( + math, + new_state + .try_point_mut() + .expect("New point should not have other references"), + ); + new_state + } + + fn pool(&mut self) -> &mut StatePool { + &mut self.pool + } + + fn step_size(&self) -> f64 { + self.step_size + } + + fn step_size_mut(&mut self) -> &mut f64 { + &mut self.step_size + } +} diff --git a/src/hamiltonian.rs b/src/hamiltonian.rs new file mode 100644 index 0000000..fd34c8f --- /dev/null +++ b/src/hamiltonian.rs @@ -0,0 +1,117 @@ +use std::sync::Arc; + +use crate::{ + nuts::Collector, + sampler_stats::SamplerStats, + state::{State, StatePool}, + Math, NutsError, +}; + +/// Details about a divergence that might have occurred during sampling +/// +/// There are two reasons why we might observe a divergence: +/// - The integration error of the Hamiltonian is larger than +/// a cutoff value or nan.
+/// - The logp function caused a recoverable error (eg if an ODE solver +/// failed) +#[derive(Debug, Clone)] +pub struct DivergenceInfo { + pub start_momentum: Option>, + pub start_location: Option>, + pub start_gradient: Option>, + pub end_location: Option>, + pub energy_error: Option, + pub end_idx_in_trajectory: Option, + pub start_idx_in_trajectory: Option, + pub logp_function_error: Option>, +} + +#[derive(Debug, Copy, Clone)] +pub enum Direction { + Forward, + Backward, +} + +impl rand::distributions::Distribution for rand::distributions::Standard { + fn sample(&self, rng: &mut R) -> Direction { + if rng.gen::() { + Direction::Forward + } else { + Direction::Backward + } + } +} + +pub enum LeapfrogResult> { + Ok(State), + Divergence(DivergenceInfo), + Err(M::LogpErr), +} + +pub trait Point: Sized { + fn position(&self) -> &M::Vector; + fn gradient(&self) -> &M::Vector; + fn index_in_trajectory(&self) -> i64; + fn energy(&self) -> f64; + fn logp(&self) -> f64; + + fn energy_error(&self) -> f64 { + self.energy() - self.initial_energy() + } + + fn initial_energy(&self) -> f64; + + fn new(math: &mut M) -> Self; + fn copy_into(&self, math: &mut M, other: &mut Self); +} + +/// The hamiltonian defined by the potential energy and the kinetic energy +pub trait Hamiltonian: SamplerStats + Sized { + /// The type that stores a point in phase space, together + /// with some information about the location inside the + /// integration trajectory. + type Point: Point; + + /// Perform one leapfrog step. + /// + /// Return either an unrecoverable error, a new state or a divergence. + fn leapfrog>( + &mut self, + math: &mut M, + start: &State, + dir: Direction, + collector: &mut C, + ) -> LeapfrogResult; + + fn is_turning( + &self, + math: &mut M, + state1: &State, + state2: &State, + ) -> bool; + + /// Initialize a state at a new location. + /// + /// The momentum should be initialized to some arbitrary invalid number, + /// it will later be set using Self::randomize_momentum. + fn init_state( + &mut self, + math: &mut M, + init: &[f64], + ) -> Result, NutsError>; + + /// Randomize the momentum part of a state + fn initialize_trajectory( + &self, + math: &mut M, + state: &mut State, + rng: &mut R, + ) -> Result<(), NutsError>; + + fn pool(&mut self) -> &mut StatePool; + + fn copy_state(&mut self, math: &mut M, state: &State) -> State; + + fn step_size(&self) -> f64; + fn step_size_mut(&mut self) -> &mut f64; +} diff --git a/src/lib.rs b/src/lib.rs index 1514295..b86b0f3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,6 +31,9 @@ //! impl CpuLogpFunc for PosteriorDensity { //! type LogpError = PosteriorLogpError; //! +//! // Only used for transforming adaptation. +//! type TransformParams = (); +//! //! // We define a 10 dimensional normal distribution //! fn dim(&self) -> usize { 10 } //! @@ -84,29 +87,37 @@ //! somewhat. 
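+//!
+//! Relatedly, the parallelism of the linear algebra backend is now
+//! configurable. A sketch (the `PosteriorDensity` value is the logp type from
+//! the example above, and `faer::Parallelism::Rayon(0)` is assumed to mean
+//! "use all available rayon threads"):
+//!
+//! let math = CpuMath::with_parallel(PosteriorDensity {}, faer::Parallelism::Rayon(0));
+//!
+//! `CpuMath::new` keeps the previous single-threaded behavior
+//! (`faer::Parallelism::None`).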
mod adapt_strategy; +mod chain; mod cpu_math; +mod euclidean_hamiltonian; +mod hamiltonian; mod low_rank_mass_matrix; mod mass_matrix; mod mass_matrix_adapt; mod math; mod math_base; mod nuts; -mod potential; mod sampler; +mod sampler_stats; mod state; mod stepsize; mod stepsize_adapt; +mod transform_adapt_strategy; +mod transformed_hamiltonian; -pub use adapt_strategy::AdaptOptions; +pub use adapt_strategy::EuclideanAdaptOptions; +pub use chain::Chain; pub use cpu_math::{CpuLogpFunc, CpuMath}; -pub use math_base::Math; -pub use nuts::{Chain, DivergenceInfo, LogpError, NutsError, SampleStats}; +pub use hamiltonian::DivergenceInfo; +pub use math_base::{LogpError, Math}; +pub use nuts::{NutsError, SampleStats}; pub use sampler::{ sample_sequentially, ChainOutput, ChainProgress, DiagGradNutsSettings, DrawStorage, LowRankNutsSettings, Model, NutsSettings, ProgressCallback, Sampler, SamplerWaitResult, - Settings, Trace, + Settings, Trace, TransformedNutsSettings, }; pub use low_rank_mass_matrix::LowRankSettings; pub use mass_matrix_adapt::DiagAdaptExpSettings; pub use stepsize_adapt::DualAverageSettings; +pub use transform_adapt_strategy::TransformedSettings; diff --git a/src/low_rank_mass_matrix.rs b/src/low_rank_mass_matrix.rs index 1327246..b4edfa3 100644 --- a/src/low_rank_mass_matrix.rs +++ b/src/low_rank_mass_matrix.rs @@ -8,12 +8,12 @@ use faer::{Col, Mat, Scale}; use itertools::Itertools; use crate::{ + euclidean_hamiltonian::EuclideanPoint, + hamiltonian::Point, mass_matrix::{DrawGradCollector, MassMatrix}, mass_matrix_adapt::MassMatrixAdaptStrategy, - nuts::{AdaptStats, AdaptStrategy, SamplerStats, StatTraceBuilder}, - potential::EuclideanPotential, - state::State, - Math, + sampler_stats::{SamplerStats, StatTraceBuilder}, + Math, NutsError, }; #[derive(Debug)] @@ -300,33 +300,39 @@ impl SamplerStats for LowRankMassMatrix { } impl MassMatrix for LowRankMassMatrix { - fn update_velocity(&self, math: &mut M, state: &mut crate::state::InnerState) { + fn update_velocity(&self, math: &mut M, state: &mut EuclideanPoint) { let Some(inner) = self.inner.as_ref() else { - math.array_mult(&self.variance, &state.p, &mut state.v); + math.array_mult(&self.variance, &state.momentum, &mut state.velocity); return; }; - math.array_mult_eigs(&self.stds, &state.p, &mut state.v, &inner.vecs, &inner.vals); + math.array_mult_eigs( + &self.stds, + &state.momentum, + &mut state.velocity, + &inner.vecs, + &inner.vals, + ); } - fn update_kinetic_energy(&self, math: &mut M, state: &mut crate::state::InnerState) { - state.kinetic_energy = 0.5 * math.array_vector_dot(&state.p, &state.v); + fn update_kinetic_energy(&self, math: &mut M, state: &mut EuclideanPoint) { + state.kinetic_energy = 0.5 * math.array_vector_dot(&state.momentum, &state.velocity); } fn randomize_momentum( &self, math: &mut M, - state: &mut crate::state::InnerState, + state: &mut EuclideanPoint, rng: &mut R, ) { let Some(inner) = self.inner.as_ref() else { - math.array_gaussian(rng, &mut state.p, &self.inv_stds); + math.array_gaussian(rng, &mut state.momentum, &self.inv_stds); return; }; math.array_gaussian_eigs( rng, - &mut state.p, + &mut state.momentum, &self.inv_stds, &inner.vals_sqrt_inv, &inner.vecs, @@ -384,12 +390,12 @@ impl LowRankMassMatrixStrategy { } } - pub fn add_draw(&mut self, math: &mut M, state: &State) { + pub fn add_draw(&mut self, math: &mut M, point: &impl Point) { assert!(math.dim() == self.ndim); let mut draw = vec![0f64; self.ndim]; - math.write_to_slice(&state.q, &mut draw); + math.write_to_slice(point.position(), &mut 
draw); let mut grad = vec![0f64; self.ndim]; - math.write_to_slice(&state.grad, &mut grad); + math.write_to_slice(point.gradient(), &mut grad); self.draws.push_back(draw); self.grads.push_back(grad); @@ -548,15 +554,8 @@ fn spd_mean(cov_draws: Mat, cov_grads: Mat) -> Mat { (&grads_inv_sqrt) * m_sqrt * grads_inv_sqrt } -impl AdaptStats for LowRankMassMatrixStrategy { - fn num_grad_evals(_stats: &Self::Stats) -> usize { - unimplemented!() - } -} - impl SamplerStats for LowRankMassMatrixStrategy { type Stats = Stats; - type Builder = Builder; fn new_builder(&self, _settings: &impl crate::Settings, _dim: usize) -> Self::Builder { @@ -568,14 +567,12 @@ impl SamplerStats for LowRankMassMatrixStrategy { } } -impl AdaptStrategy for LowRankMassMatrixStrategy { - type Potential = EuclideanPotential>; - +impl MassMatrixAdaptStrategy for LowRankMassMatrixStrategy { + type MassMatrix = LowRankMassMatrix; type Collector = DrawGradCollector; - type Options = LowRankSettings; - fn new(math: &mut M, options: Self::Options, _num_tune: u64) -> Self { + fn new(math: &mut M, options: Self::Options, _num_tune: u64, _chain: u64) -> Self { Self::new(math.dim(), options) } @@ -583,40 +580,19 @@ impl AdaptStrategy for LowRankMassMatrixStrategy { &mut self, math: &mut M, _options: &mut crate::nuts::NutsOptions, - potential: &mut Self::Potential, - state: &State, + mass_matrix: &mut Self::MassMatrix, + point: &impl Point, _rng: &mut R, - ) { - self.add_draw(math, state); - potential - .mass_matrix - .update_from_grad(math, &state.grad, 1f64, (1e-20, 1e20)) - } - - fn adapt( - &mut self, - _math: &mut M, - _options: &mut crate::nuts::NutsOptions, - _potential: &mut Self::Potential, - _draw: u64, - _collector: &Self::Collector, - _state: &State, - _rng: &mut R, - ) { + ) -> Result<(), NutsError> { + self.add_draw(math, point); + mass_matrix.update_from_grad(math, point.gradient(), 1f64, (1e-20, 1e20)); + Ok(()) } fn new_collector(&self, math: &mut M) -> Self::Collector { DrawGradCollector::new(math) } - fn is_tuning(&self) -> bool { - unreachable!() - } -} - -impl MassMatrixAdaptStrategy for LowRankMassMatrixStrategy { - type MassMatrix = LowRankMassMatrix; - fn update_estimators(&mut self, math: &mut M, collector: &Self::Collector) { if collector.is_good { let mut draw = vec![0f64; self.ndim]; @@ -646,11 +622,11 @@ impl MassMatrixAdaptStrategy for LowRankMassMatrixStrategy { self.draws.len().checked_sub(self.background_split).unwrap() as u64 } - fn update_potential(&self, math: &mut M, potential: &mut Self::Potential) -> bool { + fn adapt(&self, math: &mut M, mass_matrix: &mut Self::MassMatrix) -> bool { if >::current_count(self) < 3 { return false; } - self.update(math, &mut potential.mass_matrix); + self.update(math, mass_matrix); true } @@ -662,7 +638,7 @@ mod test { use faer::{assert_matrix_eq, Col, Mat}; use rand::{rngs::SmallRng, Rng, SeedableRng}; - use rand_distr::{num_traits::ToBytes, StandardNormal}; + use rand_distr::StandardNormal; use super::{estimate_mass_matrix, spd_mean}; diff --git a/src/mass_matrix.rs b/src/mass_matrix.rs index 96ca8d3..d6d5a8c 100644 --- a/src/mass_matrix.rs +++ b/src/mass_matrix.rs @@ -4,28 +4,29 @@ use arrow::{ }; use crate::{ + euclidean_hamiltonian::EuclideanPoint, + hamiltonian::Point, math_base::Math, nuts::Collector, - nuts::SamplerStats, - nuts::StatTraceBuilder, sampler::Settings, - state::{InnerState, State}, + sampler_stats::{SamplerStats, StatTraceBuilder}, + state::State, }; pub trait MassMatrix: SamplerStats { - fn update_velocity(&self, math: &mut M, state: &mut 
InnerState); - fn update_kinetic_energy(&self, math: &mut M, state: &mut InnerState); + fn update_velocity(&self, math: &mut M, state: &mut EuclideanPoint); + fn update_kinetic_energy(&self, math: &mut M, state: &mut EuclideanPoint); fn randomize_momentum( &self, math: &mut M, - state: &mut InnerState, + point: &mut EuclideanPoint, rng: &mut R, ); } pub struct NullCollector {} -impl Collector for NullCollector {} +impl> Collector for NullCollector {} #[derive(Debug)] pub struct DiagMassMatrix { @@ -178,21 +179,21 @@ impl DiagMassMatrix { } impl MassMatrix for DiagMassMatrix { - fn update_velocity(&self, math: &mut M, state: &mut InnerState) { - math.array_mult(&self.variance, &state.p, &mut state.v); + fn update_velocity(&self, math: &mut M, point: &mut EuclideanPoint) { + math.array_mult(&self.variance, &point.momentum, &mut point.velocity); } - fn update_kinetic_energy(&self, math: &mut M, state: &mut InnerState) { - state.kinetic_energy = 0.5 * math.array_vector_dot(&state.p, &state.v); + fn update_kinetic_energy(&self, math: &mut M, point: &mut EuclideanPoint) { + point.kinetic_energy = 0.5 * math.array_vector_dot(&point.momentum, &point.velocity); } fn randomize_momentum( &self, math: &mut M, - state: &mut InnerState, + point: &mut EuclideanPoint, rng: &mut R, ) { - math.array_gaussian(rng, &mut state.p, &self.inv_stds); + math.array_gaussian(rng, &mut point.momentum, &self.inv_stds); } } @@ -253,10 +254,15 @@ impl DrawGradCollector { } } -impl Collector for DrawGradCollector { - fn register_draw(&mut self, math: &mut M, state: &State, info: &crate::nuts::SampleInfo) { - math.copy_into(&state.q, &mut self.draw); - math.copy_into(&state.grad, &mut self.grad); +impl Collector> for DrawGradCollector { + fn register_draw( + &mut self, + math: &mut M, + state: &State>, + info: &crate::nuts::SampleInfo, + ) { + math.copy_into(state.point().position(), &mut self.draw); + math.copy_into(state.point().gradient(), &mut self.grad); let idx = state.index_in_trajectory(); if info.divergence_info.is_some() { self.is_good = idx.abs() > 4; diff --git a/src/mass_matrix_adapt.rs b/src/mass_matrix_adapt.rs index c542b59..06a04aa 100644 --- a/src/mass_matrix_adapt.rs +++ b/src/mass_matrix_adapt.rs @@ -3,11 +3,12 @@ use std::marker::PhantomData; use rand::Rng; use crate::{ + euclidean_hamiltonian::EuclideanPoint, + hamiltonian::Point, mass_matrix::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance}, - nuts::{AdaptStats, AdaptStrategy, NutsOptions, SamplerStats}, - potential::EuclideanPotential, - state::State, - Math, Settings, + nuts::{Collector, NutsOptions}, + sampler_stats::SamplerStats, + Math, NutsError, Settings, }; const LOWER_LIMIT: f64 = 1e-20f64; const UPPER_LIMIT: f64 = 1e20f64; @@ -40,8 +41,10 @@ pub struct Strategy { _phantom: PhantomData, } -pub trait MassMatrixAdaptStrategy: AdaptStrategy { +pub trait MassMatrixAdaptStrategy: SamplerStats { type MassMatrix: MassMatrix; + type Collector: Collector>; + type Options: std::fmt::Debug + Default + Clone + Send + Sync + Copy; fn update_estimators(&mut self, math: &mut M, collector: &Self::Collector); @@ -52,11 +55,26 @@ pub trait MassMatrixAdaptStrategy: AdaptStrategy { fn background_count(&self) -> u64; /// Give the opportunity to update the potential and return if it was changed - fn update_potential(&self, math: &mut M, potential: &mut Self::Potential) -> bool; + fn adapt(&self, math: &mut M, mass_matrix: &mut Self::MassMatrix) -> bool; + + fn new(math: &mut M, options: Self::Options, _num_tune: u64, _chain: u64) -> Self; + + fn 
init( + &mut self, + math: &mut M, + _options: &mut NutsOptions, + mass_matrix: &mut Self::MassMatrix, + point: &impl Point, + _rng: &mut R, + ) -> Result<(), NutsError>; + + fn new_collector(&self, math: &mut M) -> Self::Collector; } impl MassMatrixAdaptStrategy for Strategy { type MassMatrix = DiagMassMatrix; + type Collector = DrawGradCollector; + type Options = DiagAdaptExpSettings; fn update_estimators(&mut self, math: &mut M, collector: &DrawGradCollector) { if collector.is_good { @@ -85,11 +103,7 @@ impl MassMatrixAdaptStrategy for Strategy { } /// Give the opportunity to update the potential and return if it was changed - fn update_potential( - &self, - math: &mut M, - potential: &mut EuclideanPotential, - ) -> bool { + fn adapt(&self, math: &mut M, mass_matrix: &mut DiagMassMatrix) -> bool { if self.current_count() < 3 { return false; } @@ -99,7 +113,7 @@ impl MassMatrixAdaptStrategy for Strategy { assert!(draw_scale == grad_scale); if self._settings.use_grad_based_estimate { - potential.mass_matrix.update_diag_draw_grad( + mass_matrix.update_diag_draw_grad( math, draw_var, grad_var, @@ -108,44 +122,13 @@ impl MassMatrixAdaptStrategy for Strategy { ); } else { let scale = (self.exp_variance_draw.count() as f64).recip(); - potential.mass_matrix.update_diag_draw( - math, - draw_var, - scale, - None, - (LOWER_LIMIT, UPPER_LIMIT), - ); + mass_matrix.update_diag_draw(math, draw_var, scale, None, (LOWER_LIMIT, UPPER_LIMIT)); } true } -} - -pub type Stats = (); -pub type StatsBuilder = (); - -impl SamplerStats for Strategy { - type Builder = Stats; - type Stats = StatsBuilder; - - fn new_builder(&self, _settings: &impl Settings, _dim: usize) -> Self::Builder {} - - fn current_stats(&self, _math: &mut M) -> Self::Stats {} -} - -impl AdaptStats for Strategy { - // This is never called - fn num_grad_evals(_stats: &Self::Stats) -> usize { - unimplemented!() - } -} -impl AdaptStrategy for Strategy { - type Potential = EuclideanPotential>; - type Collector = DrawGradCollector; - type Options = DiagAdaptExpSettings; - - fn new(math: &mut M, options: Self::Options, _num_tune: u64) -> Self { + fn new(math: &mut M, options: Self::Options, _num_tune: u64, _chain: u64) -> Self { Self { exp_variance_draw: RunningVariance::new(math), exp_variance_grad: RunningVariance::new(math), @@ -160,41 +143,37 @@ impl AdaptStrategy for Strategy { &mut self, math: &mut M, _options: &mut NutsOptions, - potential: &mut Self::Potential, - state: &State, + mass_matrix: &mut Self::MassMatrix, + point: &impl Point, _rng: &mut R, - ) { - self.exp_variance_draw.add_sample(math, &state.q); - self.exp_variance_draw_bg.add_sample(math, &state.q); - self.exp_variance_grad.add_sample(math, &state.grad); - self.exp_variance_grad_bg.add_sample(math, &state.grad); + ) -> Result<(), NutsError> { + self.exp_variance_draw.add_sample(math, point.position()); + self.exp_variance_draw_bg.add_sample(math, point.position()); + self.exp_variance_grad.add_sample(math, point.gradient()); + self.exp_variance_grad_bg.add_sample(math, point.gradient()); - potential.mass_matrix.update_diag_grad( + mass_matrix.update_diag_grad( math, - &state.grad, + point.gradient(), 1f64, (INIT_LOWER_LIMIT, INIT_UPPER_LIMIT), ); - } - - fn adapt( - &mut self, - _math: &mut M, - _options: &mut NutsOptions, - _potential: &mut Self::Potential, - _draw: u64, - _collector: &Self::Collector, - _state: &State, - _rng: &mut R, - ) { - // Must be controlled from a different meta strategy + Ok(()) } fn new_collector(&self, math: &mut M) -> Self::Collector { 
DrawGradCollector::new(math) } +} - - fn is_tuning(&self) -> bool { - unreachable!() - } +pub type Stats = (); +pub type StatsBuilder = (); + +impl SamplerStats for Strategy { + type Builder = Stats; + type Stats = StatsBuilder; + + fn new_builder(&self, _settings: &impl Settings, _dim: usize) -> Self::Builder {} + + fn current_stats(&self, _math: &mut M) -> Self::Stats {} } diff --git a/src/math_base.rs b/src/math_base.rs index 3fcc228..e4a5e07 100644 --- a/src/math_base.rs +++ b/src/math_base.rs @@ -1,15 +1,27 @@ use std::{error::Error, fmt::Debug}; -use crate::LogpError; +/// Errors that happen when we evaluate the logp and gradient function +pub trait LogpError: std::error::Error + Send { + /// Unrecoverable errors during logp computation stop sampling, + /// recoverable errors are seen as divergences. + fn is_recoverable(&self) -> bool; +} pub trait Math { type Vector: Debug; type EigVectors: Debug; type EigValues: Debug; - type LogpErr: Debug + Send + Sync + LogpError + 'static; + type LogpErr: Debug + Send + Sync + LogpError + Sized + 'static; type Err: Debug + Send + Sync + Error + 'static; + type TransformParams; + + fn new_array(&mut self) -> Self::Vector; - fn new_array(&self) -> Self::Vector; + fn copy_array(&mut self, array: &Self::Vector) -> Self::Vector { + let mut copy = self.new_array(); + self.copy_into(array, &mut copy); + copy + } fn new_eig_vectors<'a>( &'a mut self, @@ -127,4 +139,49 @@ pub trait Math { fill_invalid: f64, clamp: (f64, f64), ); + + fn inv_transform_normalize( + &mut self, + params: &Self::TransformParams, + untransformed_position: &Self::Vector, + untransformed_gradient: &Self::Vector, + transformed_position: &mut Self::Vector, + transformed_gradient: &mut Self::Vector, + ) -> Result; + + fn init_from_untransformed_position( + &mut self, + params: &Self::TransformParams, + untransformed_position: &Self::Vector, + untransformed_gradient: &mut Self::Vector, + transformed_position: &mut Self::Vector, + transformed_gradient: &mut Self::Vector, + ) -> Result<(f64, f64), Self::LogpErr>; + + fn init_from_transformed_position( + &mut self, + params: &Self::TransformParams, + untransformed_position: &mut Self::Vector, + untransformed_gradient: &mut Self::Vector, + transformed_position: &Self::Vector, + transformed_gradient: &mut Self::Vector, + ) -> Result<(f64, f64), Self::LogpErr>; + + fn update_transformation<'a, R: rand::Rng + ?Sized>( + &'a mut self, + rng: &mut R, + untransformed_positions: impl ExactSizeIterator, + untransformed_gradients: impl ExactSizeIterator, + params: &'a mut Self::TransformParams, + ) -> Result<(), Self::LogpErr>; + + fn new_transformation( + &mut self, + rng: &mut R, + untransformed_position: &Self::Vector, + untransformed_gradient: &Self::Vector, + chain: u64, + ) -> Result; + + fn transformation_id(&self, params: &Self::TransformParams) -> Result; } diff --git a/src/nuts.rs b/src/nuts.rs index 62411ad..62df0b5 100644 --- a/src/nuts.rs +++ b/src/nuts.rs @@ -3,148 +3,51 @@ use arrow::array::{ StructArray, }; use arrow::datatypes::{DataType, Field, Fields, Float64Type, Int64Type, UInt64Type}; -use rand::Rng; use thiserror::Error; use std::ops::Deref; use std::sync::Arc; use std::{fmt::Debug, marker::PhantomData}; +use crate::chain::AdaptStrategy; +use crate::hamiltonian::{Direction, DivergenceInfo, Hamiltonian, LeapfrogResult, Point}; use crate::math::logaddexp; use crate::sampler::Settings; -use crate::state::{State, StatePool}; +use crate::sampler_stats::StatTraceBuilder; +use crate::state::State; use crate::math_base::Math;
#[non_exhaustive] #[derive(Error, Debug)] pub enum NutsError { - #[error("Logp function returned error: {0}")] + #[error("Logp function returned error: {0:?}")] LogpFailure(Box), #[error("Could not serialize sample stats")] SerializeFailure(), - #[error("Could not initialize state because of bad initial gradient.")] - BadInitGrad(), + #[error("Could not initialize state because of bad initial gradient: {0:?}")] + BadInitGrad(Box), } pub type Result = std::result::Result; -/// Details about a divergence that might have occured during sampling -/// -/// There are two reasons why we might observe a divergence: -/// - The integration error of the Hamiltonian is larger than -/// a cutoff value or nan. -/// - The logp function caused a recoverable error (eg if an ODE solver -/// failed) -#[derive(Debug, Clone)] -pub struct DivergenceInfo { - pub start_momentum: Option>, - pub start_location: Option>, - pub start_gradient: Option>, - pub end_location: Option>, - pub energy_error: Option, - pub end_idx_in_trajectory: Option, - pub start_idx_in_trajectory: Option, - pub logp_function_error: Option>, -} - -#[derive(Debug, Copy, Clone)] -pub enum Direction { - Forward, - Backward, -} - -impl rand::distributions::Distribution for rand::distributions::Standard { - fn sample(&self, rng: &mut R) -> Direction { - if rng.gen::() { - Direction::Forward - } else { - Direction::Backward - } - } -} - /// Callbacks for various events during a Nuts sampling step. /// /// Collectors can compute statistics like the mean acceptance rate /// or collect data for mass matrix adaptation. -pub trait Collector { +pub trait Collector> { fn register_leapfrog( &mut self, _math: &mut M, - _start: &State, - _end: &State, + _start: &State, + _end: &State, _divergence_info: Option<&DivergenceInfo>, ) { } - fn register_draw(&mut self, _math: &mut M, _state: &State, _info: &SampleInfo) {} - fn register_init(&mut self, _math: &mut M, _state: &State, _options: &NutsOptions) {} -} - -/// Errors that happen when we evaluate the logp and gradient function -pub trait LogpError: std::error::Error { - /// Unrecoverable errors during logp computation stop sampling, - /// recoverable errors are seen as divergences. - fn is_recoverable(&self) -> bool; -} - -pub trait HamiltonianStats: SamplerStats { - fn stat_step_size(stats: &Self::Stats) -> f64; -} - -/// The hamiltonian defined by the potential energy and the kinetic energy -pub trait Hamiltonian: HamiltonianStats -where - M: Math, -{ - /// The type that stores a point in phase space - //type State: State; - /// Errors that happen during logp evaluation - type LogpError: LogpError + Send; - - /// Perform one leapfrog step. - /// - /// Return either an unrecoverable error, a new state or a divergence. - fn leapfrog>( - &mut self, - math: &mut M, - pool: &mut StatePool, - start: &State, - dir: Direction, - initial_energy: f64, - collector: &mut C, - ) -> Result, DivergenceInfo>>; - - /// Initialize a state at a new location. - /// - /// The momentum should be initialized to some arbitrary invalid number, - /// it will later be set using Self::randomize_momentum. - fn init_state( - &mut self, - math: &mut M, - pool: &mut StatePool, - init: &[f64], - ) -> Result>; - - /// Randomize the momentum part of a state - fn randomize_momentum( - &self, - math: &mut M, - state: &mut State, - rng: &mut R, - ); - - fn new_empty_state(&mut self, math: &mut M, pool: &mut StatePool) -> State; - - /// Crate a new state pool that can be used to crate new states. 
- fn new_pool(&mut self, math: &mut M, capacity: usize) -> StatePool; - - fn copy_state(&mut self, math: &mut M, pool: &mut StatePool, state: &State) -> State; - - fn stepsize_mut(&mut self) -> &mut f64; - fn stepsize(&self) -> f64; + fn register_draw(&mut self, _math: &mut M, _state: &State, _info: &SampleInfo) {} + fn register_init(&mut self, _math: &mut M, _state: &State, _options: &NutsOptions) {} } /// Information about a draw, exported as part of the sampler stats @@ -166,29 +69,27 @@ pub struct SampleInfo { } /// A part of the trajectory tree during NUTS sampling. -struct NutsTree, C: Collector> { +struct NutsTree, C: Collector> { /// The left position of the tree. /// /// The left side always has the smaller index_in_trajectory. /// Leapfrogs in backward direction will replace the left. - left: State, - right: State, + left: State, + right: State, /// A draw from the trajectory between left and right using /// multinomial sampling. - draw: State, + draw: State, log_size: f64, depth: u64, - initial_energy: f64, /// A tree is the main tree if it contains the initial point /// of the trajectory. is_main: bool, - _phantom: PhantomData, _phantom2: PhantomData, } -enum ExtendResult, C: Collector> { +enum ExtendResult, C: Collector> { /// The tree extension succeeded properly, and the termination /// criterion was not reached. Ok(NutsTree), @@ -201,18 +102,15 @@ enum ExtendResult, C: Collector> { Diverging(NutsTree, DivergenceInfo), } -impl, C: Collector> NutsTree { - fn new(state: State) -> NutsTree { - let initial_energy = state.energy(); +impl, C: Collector> NutsTree { + fn new(state: State) -> NutsTree { NutsTree { right: state.clone(), left: state.clone(), draw: state, depth: 0, log_size: 0., - initial_energy, is_main: true, - _phantom: PhantomData, _phantom2: PhantomData, } } @@ -222,9 +120,8 @@ impl, C: Collector> NutsTree { fn extend( mut self, math: &mut M, - pool: &mut StatePool, rng: &mut R, - potential: &mut H, + hamiltonian: &mut H, direction: Direction, collector: &mut C, options: &NutsOptions, @@ -233,7 +130,7 @@ impl, C: Collector> NutsTree { H: Hamiltonian, R: rand::Rng + ?Sized, { - let mut other = match self.single_step(math, pool, potential, direction, collector) { + let mut other = match self.single_step(math, hamiltonian, direction, collector) { Ok(Ok(tree)) => tree, Ok(Err(info)) => return ExtendResult::Diverging(self, info), Err(err) => return ExtendResult::Err(err), @@ -241,7 +138,7 @@ impl, C: Collector> NutsTree { while other.depth < self.depth { use ExtendResult::*; - other = match other.extend(math, pool, rng, potential, direction, collector, options) { + other = match other.extend(math, rng, hamiltonian, direction, collector, options) { Ok(tree) => tree, Turning(_) => { return Turning(self); @@ -261,13 +158,13 @@ impl, C: Collector> NutsTree { }; let turning = if options.check_turning { - let mut turning = first.is_turning(math, last); + let mut turning = hamiltonian.is_turning(math, first, last); if self.depth > 0 { if !turning { - turning = self.right.is_turning(math, &other.right); + turning = hamiltonian.is_turning(math, &self.right, &other.right); } if !turning { - turning = self.left.is_turning(math, &other.left); + turning = hamiltonian.is_turning(math, &self.left, &other.left); } } turning @@ -324,8 +221,7 @@ impl, C: Collector> NutsTree { fn single_step( &self, math: &mut M, - pool: &mut StatePool, - potential: &mut H, + hamiltonian: &mut H, direction: Direction, collector: &mut C, ) -> Result, DivergenceInfo>> { @@ -333,29 +229,20 @@ impl, C: 
Collector> NutsTree { Direction::Forward => &self.right, Direction::Backward => &self.left, }; - let end = match potential.leapfrog( - math, - pool, - start, - direction, - self.initial_energy, - collector, - ) { - Ok(Ok(end)) => end, - Ok(Err(info)) => return Ok(Err(info)), - Err(error) => return Err(error), + let end = match hamiltonian.leapfrog(math, start, direction, collector) { + LeapfrogResult::Divergence(info) => return Ok(Err(info)), + LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())), + LeapfrogResult::Ok(end) => end, }; - let log_size = self.initial_energy - end.energy(); + let log_size = -end.point().energy_error(); Ok(Ok(NutsTree { right: end.clone(), left: end.clone(), draw: end, depth: 0, log_size, - initial_energy: self.initial_energy, is_main: false, - _phantom: PhantomData, _phantom2: PhantomData, })) } @@ -365,7 +252,7 @@ impl, C: Collector> NutsTree { depth: self.depth, divergence_info, reached_maxdepth: maxdepth, - initial_energy: self.initial_energy, + initial_energy: self.draw.point().initial_energy(), draw_energy: self.draw.energy(), } } @@ -379,29 +266,27 @@ pub struct NutsOptions { pub store_divergences: bool, } -pub(crate) fn draw( +pub(crate) fn draw( math: &mut M, - pool: &mut StatePool, - init: &mut State, + init: &mut State, rng: &mut R, - potential: &mut P, + hamiltonian: &mut H, options: &NutsOptions, collector: &mut C, -) -> Result<(State, SampleInfo)> +) -> Result<(State, SampleInfo)> where M: Math, - P: Hamiltonian, + H: Hamiltonian, R: rand::Rng + ?Sized, - C: Collector, + C: Collector, { - potential.randomize_momentum(math, init, rng); - init.make_init_point(math); + hamiltonian.initialize_trajectory(math, init, rng)?; collector.register_init(math, init, options); let mut tree = NutsTree::new(init.clone()); while tree.depth < options.maxdepth { let direction: Direction = rng.gen(); - tree = match tree.extend(math, pool, rng, potential, direction, collector, options) { + tree = match tree.extend(math, rng, hamiltonian, direction, collector, options) { ExtendResult::Ok(tree) => tree, ExtendResult::Turning(tree) => { let info = tree.info(false, None); @@ -423,48 +308,23 @@ where Ok((tree.draw, info)) } -pub trait SamplerStats { - type Stats: Send + Debug + Clone; - type Builder: StatTraceBuilder; - - fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder; - fn current_stats(&self, math: &mut M) -> Self::Stats; -} - -impl StatTraceBuilder<()> for () { - fn append_value(&mut self, _value: ()) {} - - fn finalize(self) -> Option { - None - } - - fn inspect(&self) -> Option { - None - } -} - -pub trait StatTraceBuilder: Send { - fn append_value(&mut self, value: T); - fn finalize(self) -> Option; - fn inspect(&self) -> Option; -} - #[derive(Debug, Clone)] +#[non_exhaustive] pub struct NutsSampleStats { - depth: u64, - maxdepth_reached: bool, - idx_in_trajectory: i64, - logp: f64, - energy: f64, - energy_error: f64, - pub(crate) divergence_info: Option, - pub(crate) chain: u64, - pub(crate) draw: u64, - gradient: Option>, - unconstrained: Option>, - pub(crate) potential_stats: HStats, - pub(crate) strategy_stats: AdaptStats, - pub(crate) tuning: bool, + pub depth: u64, + pub maxdepth_reached: bool, + pub idx_in_trajectory: i64, + pub logp: f64, + pub energy: f64, + pub energy_error: f64, + pub divergence_info: Option, + pub chain: u64, + pub draw: u64, + pub gradient: Option>, + pub unconstrained: Option>, + pub potential_stats: HStats, + pub strategy_stats: AdaptStats, + pub tuning: bool, } #[derive(Debug, Clone)] 
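// ----------------------------------------------------------------------------
// [Editor's note — not part of the diff] Each leaf produced by `single_step`
// above enters the trajectory with multinomial weight
// log_size = H(start of trajectory) - H(leaf) = -energy_error, and merging two
// subtrees adds their weights in log space (the file imports `logaddexp` from
// src/math.rs for this). The merge rule, written out as a sketch:
//
//     fn merged_log_size(log_size_a: f64, log_size_b: f64) -> f64 {
//         // log(exp(a) + exp(b)), computed stably
//         let max = log_size_a.max(log_size_b);
//         max + ((log_size_a - max).exp() + (log_size_b - max).exp()).ln()
//     }
//
// The merged tree then keeps the new subtree's draw with probability
// exp(log_size_b - merged_log_size), which is what makes the final draw a
// multinomial sample over the whole trajectory.
// ----------------------------------------------------------------------------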
@@ -474,8 +334,8 @@ pub struct SampleStats { pub chain: u64, pub diverging: bool, pub tuning: bool, - pub num_steps: usize, pub step_size: f64, + pub num_steps: u64, } pub struct NutsStatsBuilder { @@ -501,7 +361,7 @@ pub struct NutsStatsBuilder { } impl NutsStatsBuilder { - fn new_with_capacity< + pub fn new_with_capacity< M: Math, H: Hamiltonian, A: AdaptStrategy, @@ -913,246 +773,6 @@ where } } -/// Draw samples from the posterior distribution using Hamiltonian MCMC. -pub trait Chain: SamplerStats { - type Hamiltonian: Hamiltonian; - type AdaptStrategy: AdaptStrategy; - - /// Initialize the sampler to a position. This should be called - /// before calling draw. - /// - /// This fails if the logp function returns an error. - fn set_position(&mut self, position: &[f64]) -> Result<()>; - - /// Draw a new sample and return the position and some diagnosic information. - fn draw(&mut self) -> Result<(Box<[f64]>, Self::Stats)>; - - // Extract a summary of the sample stats - fn stats_summary(stats: &Self::Stats) -> SampleStats; - - /// The dimensionality of the posterior. - fn dim(&self) -> usize; -} - -pub struct NutsChain -where - M: Math, - P: Hamiltonian, - R: rand::Rng, - A: AdaptStrategy, -{ - pool: StatePool, - potential: P, - collector: A::Collector, - options: NutsOptions, - rng: R, - init: State, - chain: u64, - draw_count: u64, - strategy: A, - math: M, - stats: Option>, -} - -impl NutsChain -where - M: Math, - P: Hamiltonian, - R: rand::Rng, - A: AdaptStrategy, -{ - pub fn new( - mut math: M, - mut potential: P, - strategy: A, - options: NutsOptions, - rng: R, - chain: u64, - ) -> Self { - let pool_size: usize = options.maxdepth.checked_mul(2).unwrap().try_into().unwrap(); - let mut pool = potential.new_pool(&mut math, pool_size); - let init = potential.new_empty_state(&mut math, &mut pool); - let collector = strategy.new_collector(&mut math); - NutsChain { - pool, - potential, - collector, - options, - rng, - init, - chain, - draw_count: 0, - strategy, - math, - stats: None, - } - } -} - -pub trait AdaptStats: SamplerStats { - fn num_grad_evals(stats: &Self::Stats) -> usize; -} - -pub trait AdaptStrategy: AdaptStats { - type Potential: Hamiltonian; - type Collector: Collector; - type Options: Copy + Send + Debug + Default; - - fn new(math: &mut M, options: Self::Options, num_tune: u64) -> Self; - - fn init( - &mut self, - math: &mut M, - options: &mut NutsOptions, - potential: &mut Self::Potential, - state: &State, - rng: &mut R, - ); - - #[allow(clippy::too_many_arguments)] - fn adapt( - &mut self, - math: &mut M, - options: &mut NutsOptions, - potential: &mut Self::Potential, - draw: u64, - collector: &Self::Collector, - state: &State, - rng: &mut R, - ); - - fn new_collector(&self, math: &mut M) -> Self::Collector; - fn is_tuning(&self) -> bool; -} - -impl SamplerStats for NutsChain -where - M: Math, - H: Hamiltonian + SamplerStats, - R: rand::Rng, - A: AdaptStrategy, -{ - type Builder = NutsStatsBuilder; - type Stats = NutsSampleStats; - - fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder { - NutsStatsBuilder::new_with_capacity( - settings, - &self.potential, - &self.strategy, - dim, - &self.options, - ) - } - - fn current_stats(&self, _math: &mut M) -> Self::Stats { - self.stats.as_ref().expect("No stats available").clone() - } -} - -impl Chain for NutsChain -where - M: Math, - H: Hamiltonian, - R: rand::Rng, - A: AdaptStrategy, -{ - type Hamiltonian = H; - type AdaptStrategy = A; - - fn set_position(&mut self, position: &[f64]) -> Result<()> { - let 
state = self - .potential - .init_state(&mut self.math, &mut self.pool, position)?; - self.init = state; - self.strategy.init( - &mut self.math, - &mut self.options, - &mut self.potential, - &self.init, - &mut self.rng, - ); - Ok(()) - } - - fn draw(&mut self) -> Result<(Box<[f64]>, Self::Stats)> { - let (state, info) = draw( - &mut self.math, - &mut self.pool, - &mut self.init, - &mut self.rng, - &mut self.potential, - &self.options, - &mut self.collector, - )?; - let mut position: Box<[f64]> = vec![0f64; self.math.dim()].into(); - state.write_position(&mut self.math, &mut position); - - let stats = NutsSampleStats { - depth: info.depth, - maxdepth_reached: info.reached_maxdepth, - idx_in_trajectory: state.index_in_trajectory(), - logp: -state.potential_energy(), - energy: state.energy(), - energy_error: info.draw_energy - info.initial_energy, - divergence_info: info.divergence_info, - chain: self.chain, - draw: self.draw_count, - potential_stats: self.potential.current_stats(&mut self.math), - strategy_stats: self.strategy.current_stats(&mut self.math), - gradient: if self.options.store_gradient { - let mut gradient: Box<[f64]> = vec![0f64; self.math.dim()].into(); - state.write_gradient(&mut self.math, &mut gradient); - Some(gradient) - } else { - None - }, - unconstrained: if self.options.store_unconstrained { - let mut unconstrained: Box<[f64]> = vec![0f64; self.math.dim()].into(); - state.write_position(&mut self.math, &mut unconstrained); - Some(unconstrained) - } else { - None - }, - tuning: self.strategy.is_tuning(), - }; - - self.strategy.adapt( - &mut self.math, - &mut self.options, - &mut self.potential, - self.draw_count, - &self.collector, - &state, - &mut self.rng, - ); - - self.draw_count += 1; - - self.init = state; - Ok((position, stats)) - } - - fn dim(&self) -> usize { - self.math.dim() - } - - fn stats_summary(stats: &Self::Stats) -> SampleStats { - let pot_stats = &stats.potential_stats; - let step_size = H::stat_step_size(pot_stats); - let adapt_stats = &stats.strategy_stats; - let num_steps = A::num_grad_evals(adapt_stats); - SampleStats { - draw: stats.draw, - chain: stats.chain, - diverging: stats.divergence_info.is_some(), - tuning: stats.tuning, - num_steps, - step_size, - } - } -} - #[cfg(test)] mod tests { use rand::thread_rng; @@ -1160,13 +780,11 @@ mod tests { use crate::{ adapt_strategy::test_logps::NormalLogp, cpu_math::CpuMath, - nuts::{Chain, SamplerStats}, sampler::DiagGradNutsSettings, - Settings, + sampler_stats::{SamplerStats, StatTraceBuilder}, + Chain, Settings, }; - use super::StatTraceBuilder; - #[test] fn to_arrow() { let ndim = 10; diff --git a/src/potential.rs b/src/potential.rs deleted file mode 100644 index bc3b25d..0000000 --- a/src/potential.rs +++ /dev/null @@ -1,283 +0,0 @@ -use std::fmt::Debug; -use std::marker::PhantomData; -use std::sync::Arc; - -use arrow::array::{ArrayBuilder, PrimitiveBuilder, StructArray}; -use arrow::datatypes::{DataType, Field, Float64Type}; - -use crate::mass_matrix::MassMatrix; -use crate::math_base::Math; -use crate::nuts::{ - Collector, Direction, DivergenceInfo, Hamiltonian, HamiltonianStats, LogpError, NutsError, -}; -use crate::nuts::{SamplerStats, StatTraceBuilder}; -use crate::sampler::Settings; -use crate::state::{State, StatePool}; - -pub struct EuclideanPotential> { - pub(crate) mass_matrix: Mass, - max_energy_error: f64, - pub(crate) step_size: f64, - _phantom: PhantomData, -} - -impl> EuclideanPotential { - pub(crate) fn new(mass_matrix: Mass, max_energy_error: f64, step_size: f64) -> Self { - 
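// ----------------------------------------------------------------------------
// [Editor's example — not part of the diff] The `Chain`/`NutsChain` code
// removed here moves to src/chain.rs; the calling pattern stays the same.
// A sketch of driving any `Chain` implementation (the function is
// hypothetical; `set_position`, `draw`, and `dim` are the trait methods
// deleted above):
//
//     fn run<M: Math, C: Chain<M>>(chain: &mut C, init: &[f64], n: u64) -> Result<()> {
//         chain.set_position(init)?; // fails if logp errors at `init`
//         for _ in 0..n {
//             let (position, _stats) = chain.draw()?;
//             assert_eq!(position.len(), chain.dim());
//         }
//         Ok(())
//     }
// ----------------------------------------------------------------------------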
EuclideanPotential { - mass_matrix, - max_energy_error, - step_size, - _phantom: PhantomData, - } - } -} - -#[derive(Copy, Clone, Debug)] -pub struct PotentialStats { - step_size: f64, - mass_matrix_stats: S, -} - -pub struct PotentialStatsBuilder { - step_size: PrimitiveBuilder, - mass_matrix: B, -} - -impl> StatTraceBuilder> - for PotentialStatsBuilder -{ - fn append_value(&mut self, value: PotentialStats) { - let PotentialStats { - step_size, - mass_matrix_stats, - } = value; - - self.step_size.append_value(step_size); - self.mass_matrix.append_value(mass_matrix_stats) - } - - fn finalize(self) -> Option { - let Self { - mut step_size, - mass_matrix, - } = self; - - let mut fields = vec![Field::new("step_size", DataType::Float64, false)]; - - let mut arrays = vec![ArrayBuilder::finish(&mut step_size)]; - if let Some(mass_matrix) = mass_matrix.finalize() { - let (m_fields, m_data, m_bitmap) = mass_matrix.into_parts(); - assert!(m_bitmap.is_none()); - fields.extend( - m_fields - .into_iter() - .map(|v| Arc::unwrap_or_clone(v.to_owned())), - ); - arrays.extend(m_data); - } - - Some(StructArray::new(fields.into(), arrays, None)) - } - - fn inspect(&self) -> Option { - let Self { - step_size, - mass_matrix, - } = self; - - let mut fields = vec![Field::new("step_size", DataType::Float64, false)]; - - let mut arrays = vec![ArrayBuilder::finish_cloned(step_size)]; - if let Some(mass_matrix) = mass_matrix.inspect() { - let (m_fields, m_data, m_bitmap) = mass_matrix.into_parts(); - assert!(m_bitmap.is_none()); - fields.extend( - m_fields - .into_iter() - .map(|v| Arc::unwrap_or_clone(v.to_owned())), - ); - arrays.extend(m_data); - } - - Some(StructArray::new(fields.into(), arrays, None)) - } -} - -impl> SamplerStats for EuclideanPotential { - type Builder = PotentialStatsBuilder; - type Stats = PotentialStats; - - fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder { - Self::Builder { - step_size: PrimitiveBuilder::new(), - mass_matrix: self.mass_matrix.new_builder(settings, dim), - } - } - - fn current_stats(&self, math: &mut M) -> Self::Stats { - PotentialStats { - step_size: self.step_size, - mass_matrix_stats: self.mass_matrix.current_stats(math), - } - } -} - -impl> HamiltonianStats for EuclideanPotential { - fn stat_step_size(stats: &Self::Stats) -> f64 { - stats.step_size - } -} - -impl> Hamiltonian for EuclideanPotential { - type LogpError = M::LogpErr; - - fn leapfrog>( - &mut self, - math: &mut M, - pool: &mut StatePool, - start: &State, - dir: Direction, - initial_energy: f64, - collector: &mut C, - ) -> Result, DivergenceInfo>, NutsError> { - let mut out = pool.new_state(math); - - let sign = match dir { - Direction::Forward => 1, - Direction::Backward => -1, - }; - - let epsilon = (sign as f64) * self.step_size; - - start.first_momentum_halfstep(math, &mut out, epsilon); - self.update_velocity(math, &mut out); - - start.position_step(math, &mut out, epsilon); - if let Err(logp_error) = self.update_potential_gradient(math, &mut out) { - if !logp_error.is_recoverable() { - return Err(NutsError::LogpFailure(Box::new(logp_error))); - } - let div_info = DivergenceInfo { - logp_function_error: Some(Arc::new(Box::new(logp_error))), - start_location: Some(math.box_array(&start.q)), - start_gradient: Some(math.box_array(&start.grad)), - start_momentum: Some(math.box_array(&start.p)), - end_location: None, - start_idx_in_trajectory: Some(start.idx_in_trajectory), - end_idx_in_trajectory: None, - energy_error: None, - }; - collector.register_leapfrog(math, start, &out, 
Some(&div_info)); - return Ok(Err(div_info)); - } - - out.second_momentum_halfstep(math, epsilon); - - self.update_velocity(math, &mut out); - self.update_kinetic_energy(math, &mut out); - - *out.index_in_trajectory_mut() = start.index_in_trajectory() + sign; - - start.set_psum(math, &mut out, dir); - - let energy_error = { out.energy() - initial_energy }; - if (energy_error > self.max_energy_error) | !energy_error.is_finite() { - let divergence_info = DivergenceInfo { - logp_function_error: None, - start_location: Some(math.box_array(&start.q)), - start_gradient: Some(math.box_array(&start.grad)), - end_location: Some(math.box_array(&out.q)), - start_momentum: Some(math.box_array(&out.p)), - start_idx_in_trajectory: Some(start.index_in_trajectory()), - end_idx_in_trajectory: Some(out.index_in_trajectory()), - energy_error: Some(energy_error), - }; - collector.register_leapfrog(math, start, &out, Some(&divergence_info)); - return Ok(Err(divergence_info)); - } - - collector.register_leapfrog(math, start, &out, None); - - Ok(Ok(out)) - } - - fn init_state( - &mut self, - math: &mut M, - pool: &mut StatePool, - init: &[f64], - ) -> Result, NutsError> { - let mut state = pool.new_state(math); - { - let inner = state.try_mut_inner().expect("State already in use"); - math.read_from_slice(&mut inner.q, init); - math.fill_array(&mut inner.p_sum, 0.); - } - self.update_potential_gradient(math, &mut state) - .map_err(|e| NutsError::LogpFailure(Box::new(e)))?; - if !math.array_all_finite_and_nonzero(&state.grad) { - Err(NutsError::BadInitGrad()) - } else { - Ok(state) - } - } - - fn randomize_momentum( - &self, - math: &mut M, - state: &mut State, - rng: &mut R, - ) { - let inner = state.try_mut_inner().unwrap(); - self.mass_matrix.randomize_momentum(math, inner, rng); - self.mass_matrix.update_velocity(math, inner); - self.mass_matrix.update_kinetic_energy(math, inner); - } - - fn new_empty_state(&mut self, math: &mut M, pool: &mut StatePool) -> State { - pool.new_state(math) - } - - fn new_pool(&mut self, math: &mut M, capacity: usize) -> StatePool { - StatePool::new(math, capacity) - } - - fn copy_state(&mut self, math: &mut M, pool: &mut StatePool, state: &State) -> State { - pool.copy_state(math, state) - } - - fn stepsize_mut(&mut self) -> &mut f64 { - &mut self.step_size - } - - fn stepsize(&self) -> f64 { - self.step_size - } -} - -impl> EuclideanPotential { - fn update_potential_gradient( - &mut self, - math: &mut M, - state: &mut State, - ) -> Result<(), M::LogpErr> { - let logp = { - let inner = state.try_mut_inner().unwrap(); - math.logp_array(&inner.q, &mut inner.grad) - }?; - - let inner = state.try_mut_inner().unwrap(); - inner.potential_energy = -logp; - Ok(()) - } - - fn update_velocity(&mut self, math: &mut M, state: &mut State) { - self.mass_matrix - .update_velocity(math, state.try_mut_inner().expect("State already in us")) - } - - fn update_kinetic_energy(&mut self, math: &mut M, state: &mut State) { - self.mass_matrix - .update_kinetic_energy(math, state.try_mut_inner().expect("State already in us")) - } -} diff --git a/src/sampler.rs b/src/sampler.rs index e6278df..44084c9 100644 --- a/src/sampler.rs +++ b/src/sampler.rs @@ -17,17 +17,18 @@ use std::{ }; use crate::{ - adapt_strategy::{AdaptOptions, GlobalStrategy}, + adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy}, + chain::{AdaptStrategy, Chain, NutsChain}, + euclidean_hamiltonian::EuclideanHamiltonian, low_rank_mass_matrix::{LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings}, mass_matrix::DiagMassMatrix, 
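// ----------------------------------------------------------------------------
// [Editor's note — not part of the diff] The deleted `leapfrog` above (its
// replacement lives in src/euclidean_hamiltonian.rs) is the standard
// velocity-Verlet update. For a unit mass matrix in one dimension, where the
// velocity equals the momentum, it reduces to:
//
//     fn leapfrog_1d(q: f64, p: f64, eps: f64, grad: impl Fn(f64) -> f64) -> (f64, f64) {
//         let p_half = p + 0.5 * eps * grad(q);         // first momentum half-step
//         let q_new = q + eps * p_half;                 // full position step
//         let p_new = p_half + 0.5 * eps * grad(q_new); // second momentum half-step
//         (q_new, p_new)
//     }
//
// where `grad` is the gradient of logp. A step counts as a divergence when the
// resulting energy error exceeds `max_energy_error` or is not finite.
// ----------------------------------------------------------------------------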
mass_matrix_adapt::Strategy as DiagMassMatrixStrategy, math_base::Math, - nuts::{ - AdaptStats, Chain, HamiltonianStats, NutsChain, NutsOptions, SampleStats, SamplerStats, - StatTraceBuilder, - }, - potential::EuclideanPotential, - DiagAdaptExpSettings, + nuts::NutsOptions, + sampler_stats::{SamplerStats, StatTraceBuilder}, + transform_adapt_strategy::{TransformAdaptation, TransformedSettings}, + transformed_hamiltonian::TransformedHamiltonian, + DiagAdaptExpSettings, SampleStats, }; /// All sampler configurations implement this trait @@ -54,13 +55,15 @@ pub trait Settings: private::Sealed + Clone + Copy + Default + Sync + Send + 'st mod private { use crate::DiagGradNutsSettings; - use super::LowRankNutsSettings; + use super::{LowRankNutsSettings, TransformedNutsSettings}; pub trait Sealed {} impl Sealed for DiagGradNutsSettings {} impl Sealed for LowRankNutsSettings {} + + impl Sealed for TransformedNutsSettings {} } /// Settings for the NUTS sampler @@ -83,15 +86,16 @@ pub struct NutsSettings { /// Store detailed information about each divergence in the sampler stats pub store_divergences: bool, /// Settings for mass matrix adaptation. - pub adapt_options: AdaptOptions, + pub adapt_options: A, pub check_turning: bool, pub num_chains: usize, pub seed: u64, } -pub type DiagGradNutsSettings = NutsSettings; -pub type LowRankNutsSettings = NutsSettings; +pub type DiagGradNutsSettings = NutsSettings>; +pub type LowRankNutsSettings = NutsSettings>; +pub type TransformedNutsSettings = NutsSettings; impl Default for DiagGradNutsSettings { fn default() -> Self { @@ -103,7 +107,7 @@ impl Default for DiagGradNutsSettings { store_gradient: false, store_unconstrained: false, store_divergences: false, - adapt_options: AdaptOptions::default(), + adapt_options: EuclideanAdaptOptions::default(), check_turning: true, seed: 0, num_chains: 6, @@ -121,7 +125,7 @@ impl Default for LowRankNutsSettings { store_gradient: false, store_unconstrained: false, store_divergences: false, - adapt_options: AdaptOptions::default(), + adapt_options: EuclideanAdaptOptions::default(), check_turning: true, seed: 0, num_chains: 6, @@ -131,19 +135,29 @@ impl Default for LowRankNutsSettings { } } -type DiagGradNutsChain = NutsChain< - M, - EuclideanPotential>, - SmallRng, - GlobalStrategy>, ->; +impl Default for TransformedNutsSettings { + fn default() -> Self { + Self { + num_tune: 4000, + num_draws: 1000, + maxdepth: 8, + max_energy_error: 1000f64, + store_gradient: false, + store_unconstrained: false, + store_divergences: false, + adapt_options: Default::default(), + check_turning: true, + seed: 0, + num_chains: 6, + } + } +} + +type DiagGradNutsChain = NutsChain>>; + +type LowRankNutsChain = NutsChain>; -type LowRankNutsChain = NutsChain< - M, - EuclideanPotential>, - SmallRng, - GlobalStrategy, ->; +type TransformingNutsChain = NutsChain; impl Settings for LowRankNutsSettings { type Chain = LowRankNutsChain; @@ -154,12 +168,11 @@ impl Settings for LowRankNutsSettings { mut math: M, rng: &mut R, ) -> Self::Chain { - use crate::nuts::AdaptStrategy; let num_tune = self.num_tune; - let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune); + let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain); let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options); let max_energy_error = self.max_energy_error; - let potential = EuclideanPotential::new(mass_matrix, max_energy_error, 1f64); + let potential = EuclideanHamiltonian::new(&mut math, mass_matrix, 
max_energy_error, 1f64); let options = NutsOptions { maxdepth: self.maxdepth, @@ -178,10 +191,8 @@ impl Settings for LowRankNutsSettings { &self, stats: & as SamplerStats>::Stats, ) -> SampleStats { - let step_size = - as Chain>::Hamiltonian::stat_step_size(&stats.potential_stats); - let num_steps = - as Chain>::AdaptStrategy::num_grad_evals(&stats.strategy_stats); + let step_size = stats.potential_stats.step_size; + let num_steps = stats.strategy_stats.stats1.n_steps; SampleStats { chain: stats.chain, draw: stats.draw, @@ -218,15 +229,14 @@ impl Settings for DiagGradNutsSettings { mut math: M, rng: &mut R, ) -> Self::Chain { - use crate::nuts::AdaptStrategy; let num_tune = self.num_tune; - let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune); + let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain); let mass_matrix = DiagMassMatrix::new( &mut math, self.adapt_options.mass_matrix_options.store_mass_matrix, ); let max_energy_error = self.max_energy_error; - let potential = EuclideanPotential::new(mass_matrix, max_energy_error, 1f64); + let potential = EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, 1f64); let options = NutsOptions { maxdepth: self.maxdepth, @@ -245,10 +255,68 @@ impl Settings for DiagGradNutsSettings { &self, stats: & as SamplerStats>::Stats, ) -> SampleStats { - let step_size = - as Chain>::Hamiltonian::stat_step_size(&stats.potential_stats); - let num_steps = - as Chain>::AdaptStrategy::num_grad_evals(&stats.strategy_stats); + let step_size = stats.potential_stats.step_size; + let num_steps = stats.strategy_stats.stats1.n_steps; + SampleStats { + chain: stats.chain, + draw: stats.draw, + diverging: stats.divergence_info.is_some(), + tuning: stats.tuning, + step_size, + num_steps, + } + } + + fn hint_num_tune(&self) -> usize { + self.num_tune as _ + } + + fn hint_num_draws(&self) -> usize { + self.num_draws as _ + } + + fn num_chains(&self) -> usize { + self.num_chains + } + + fn seed(&self) -> u64 { + self.seed + } +} + +impl Settings for TransformedNutsSettings { + type Chain = TransformingNutsChain; + + fn new_chain( + &self, + chain: u64, + mut math: M, + rng: &mut R, + ) -> Self::Chain { + let num_tune = self.num_tune; + let max_energy_error = self.max_energy_error; + + let strategy = TransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain); + let hamiltonian = TransformedHamiltonian::new(&mut math, max_energy_error); + + let options = NutsOptions { + maxdepth: self.maxdepth, + store_gradient: self.store_gradient, + store_divergences: self.store_divergences, + store_unconstrained: self.store_unconstrained, + check_turning: self.check_turning, + }; + + let rng = rand::rngs::SmallRng::from_rng(rng).expect("Could not seed rng"); + NutsChain::new(math, hamiltonian, strategy, options, rng, chain) + } + + fn sample_stats( + &self, + stats: & as SamplerStats>::Stats, + ) -> SampleStats { + let step_size = stats.potential_stats.step_size; + let num_steps = stats.strategy_stats.step_size.n_steps; SampleStats { chain: stats.chain, draw: stats.draw, @@ -290,7 +358,6 @@ pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>( sampler .draw() .map(|(point, info)| (point, settings.sample_stats::(&info))) - .map_err(|e| e.into()) })) } @@ -357,8 +424,8 @@ impl ChainProgress { self.finished_draws += 1; self.tuning = stats.tuning; - self.latest_num_steps = stats.num_steps; - self.total_num_steps += stats.num_steps; + self.latest_num_steps = stats.num_steps as usize; + 
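// ----------------------------------------------------------------------------
// [Editor's example — not part of the diff] With the three settings aliases
// above, picking a sampler variant is a type choice. Illustrative values only,
// and this assumes `TransformedNutsSettings` is re-exported from the crate
// root like the other two aliases:
//
//     let mut settings = nuts_rs::DiagGradNutsSettings::default();
//     settings.num_tune = 1000;
//     settings.num_draws = 2000;
//     settings.adapt_options.mass_matrix_options.store_mass_matrix = true;
//
//     let mut t_settings = nuts_rs::TransformedNutsSettings::default();
//     t_settings.adapt_options.transform_update_freq = 64;
//
// Both implement `Settings`, so downstream code that is generic over
// `S: Settings` works with either.
// ----------------------------------------------------------------------------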
self.total_num_steps += stats.num_steps as usize; self.step_size = stats.step_size; self.runtime += draw_duration; } @@ -375,13 +442,9 @@ enum ChainCommand { Pause, } -type Builder<'model, M, S> = <::Chain<::Math<'model>> as SamplerStats< - ::Math<'model>, ->>::Builder; - struct ChainTrace<'model, M: Model + 'model, S: Settings> { draws_builder: M::DrawStorage<'model, S>, - stats_builder: Builder<'model, M, S>, + stats_builder: > as SamplerStats>>::Builder, chain_id: u64, } @@ -417,7 +480,7 @@ where progress: Arc>, } -impl<'scope, M: Model + 'scope, S: Settings> ChainProcess<'scope, M, S> { +impl<'scope, M: Model, S: Settings> ChainProcess<'scope, M, S> { fn finalize_many(chains: Vec) -> Vec>> { chains .into_iter() @@ -521,7 +584,6 @@ impl<'scope, M: Model + 'scope, S: Settings> ChainProcess<'scope, M, S> { } if let Some(error) = error { - let error: anyhow::Error = error.into(); return Err(error.context("All initialization points failed")); } @@ -860,7 +922,7 @@ pub mod test_logps { use crate::{ cpu_math::{CpuLogpFunc, CpuMath}, - nuts::LogpError, + math_base::LogpError, Settings, }; use anyhow::Result; @@ -889,6 +951,7 @@ pub mod test_logps { impl<'a> CpuLogpFunc for &'a NormalLogp { type LogpError = NormalLogpError; + type TransformParams = (); fn dim(&self) -> usize { self.dim @@ -952,6 +1015,66 @@ pub mod test_logps { Ok(logp) } + + fn inv_transform_normalize( + &mut self, + _params: &Self::TransformParams, + _untransformed_position: &[f64], + _untransofrmed_gradient: &[f64], + _transformed_position: &mut [f64], + _transformed_gradient: &mut [f64], + ) -> std::result::Result { + unimplemented!() + } + + fn init_from_untransformed_position( + &mut self, + _params: &Self::TransformParams, + _untransformed_position: &[f64], + _untransformed_gradient: &mut [f64], + _transformed_position: &mut [f64], + _transformed_gradient: &mut [f64], + ) -> std::result::Result<(f64, f64), Self::LogpError> { + unimplemented!() + } + + fn init_from_transformed_position( + &mut self, + _params: &Self::TransformParams, + _untransformed_position: &mut [f64], + _untransformed_gradient: &mut [f64], + _transformed_position: &[f64], + _transformed_gradient: &mut [f64], + ) -> std::result::Result<(f64, f64), Self::LogpError> { + unimplemented!() + } + + fn update_transformation<'b, R: rand::Rng + ?Sized>( + &'b mut self, + _rng: &mut R, + _untransformed_positions: impl Iterator, + _untransformed_gradients: impl Iterator, + _params: &'b mut Self::TransformParams, + ) -> std::result::Result<(), Self::LogpError> { + unimplemented!() + } + + fn new_transformation( + &mut self, + _rng: &mut R, + _untransformed_position: &[f64], + _untransfogmed_gradient: &[f64], + _chain: u64, + ) -> std::result::Result { + unimplemented!() + } + + fn transformation_id( + &self, + _params: &Self::TransformParams, + ) -> std::result::Result { + unimplemented!() + } } pub struct SimpleDrawStorage { diff --git a/src/sampler_stats.rs b/src/sampler_stats.rs new file mode 100644 index 0000000..224a9a7 --- /dev/null +++ b/src/sampler_stats.rs @@ -0,0 +1,31 @@ +use std::fmt::Debug; + +use arrow::array::StructArray; + +use crate::{Math, Settings}; + +pub trait SamplerStats { + type Stats: Send + Debug + Clone; + type Builder: StatTraceBuilder; + + fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder; + fn current_stats(&self, math: &mut M) -> Self::Stats; +} + +impl StatTraceBuilder<()> for () { + fn append_value(&mut self, _value: ()) {} + + fn finalize(self) -> Option { + None + } + + fn inspect(&self) -> Option { + 
None + } +} + +pub trait StatTraceBuilder: Send { + fn append_value(&mut self, value: T); + fn finalize(self) -> Option; + fn inspect(&self) -> Option; +} diff --git a/src/state.rs b/src/state.rs index 1efa08a..94f277d 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,36 +1,35 @@ use std::{ cell::RefCell, fmt::Debug, - ops::Deref, rc::{Rc, Weak}, }; -use crate::math_base::Math; +use crate::{hamiltonian::Point, math_base::Math}; -struct StateStorage { - free_states: RefCell>>>, +struct StateStorage> { + free_states: RefCell>>>, } -impl StateStorage { - fn new(_math: &mut M, capacity: usize) -> StateStorage { +impl> StateStorage { + fn new(_math: &mut M, capacity: usize) -> StateStorage { StateStorage { free_states: RefCell::new(Vec::with_capacity(capacity)), } } } -pub struct StatePool { - storage: Rc>, +pub struct StatePool> { + storage: Rc>, } -impl StatePool { - pub fn new(math: &mut M, capacity: usize) -> StatePool { +impl> StatePool { + pub fn new(math: &mut M, capacity: usize) -> StatePool { StatePool { storage: Rc::new(StateStorage::new(math, capacity)), } } - pub fn new_state(&self, math: &mut M) -> State { + pub fn new_state(&self, math: &mut M) -> State { let inner = match self.storage.free_states.borrow_mut().pop() { Some(inner) => inner, None => Rc::new(InnerStateReusable::new(math, self)), @@ -40,80 +39,32 @@ impl StatePool { } } - pub fn copy_state(&self, math: &mut M, state: &State) -> State { + pub fn copy_state(&self, math: &mut M, state: &State) -> State { let mut new_state = self.new_state(math); - - let InnerState { - q, - p, - p_sum, - grad, - v, - idx_in_trajectory, - kinetic_energy, - potential_energy, - } = new_state - .try_mut_inner() + let new_point = new_state + .try_point_mut() .expect("New state should not have references"); - - math.copy_into(&state.q, q); - math.copy_into(&state.p, p); - math.copy_into(&state.p_sum, p_sum); - math.copy_into(&state.grad, grad); - math.copy_into(&state.v, v); - *idx_in_trajectory = state.idx_in_trajectory; - *kinetic_energy = state.kinetic_energy; - *potential_energy = state.potential_energy; - + state.point().copy_into(math, new_point); new_state } } -#[derive(Debug, Clone)] -pub struct InnerState { - pub(crate) p: M::Vector, - pub(crate) q: M::Vector, - pub(crate) v: M::Vector, - pub(crate) p_sum: M::Vector, - pub(crate) grad: M::Vector, - pub(crate) idx_in_trajectory: i64, - pub(crate) kinetic_energy: f64, - pub(crate) potential_energy: f64, +pub(crate) struct InnerStateReusable> { + inner: P, + reuser: Weak>, } -pub(crate) struct InnerStateReusable { - inner: InnerState, - reuser: Weak>, -} - -impl InnerStateReusable { - fn new(math: &mut M, owner: &StatePool) -> InnerStateReusable { +impl> InnerStateReusable { + fn new(math: &mut M, owner: &StatePool) -> InnerStateReusable { InnerStateReusable { - inner: InnerState { - p: math.new_array(), - q: math.new_array(), - v: math.new_array(), - p_sum: math.new_array(), - grad: math.new_array(), - idx_in_trajectory: 0, - kinetic_energy: 0., - potential_energy: 0., - }, + inner: P::new(math), reuser: Rc::downgrade(&Rc::clone(&owner.storage)), } } } -pub struct State { - inner: std::mem::ManuallyDrop>>, -} - -impl Deref for State { - type Target = InnerState; - - fn deref(&self) -> &Self::Target { - &self.inner.inner - } +pub struct State> { + inner: std::mem::ManuallyDrop>>, } #[derive(Debug)] @@ -121,16 +72,36 @@ pub struct StateInUse {} type Result = std::result::Result; -impl State { - pub(crate) fn try_mut_inner(&mut self) -> Result<&mut InnerState> { +impl> State { + pub fn 
point(&self) -> &P { + &self.inner.inner + } + + pub fn try_point_mut(&mut self) -> Result<&mut P> { match Rc::get_mut(&mut self.inner) { Some(val) => Ok(&mut val.inner), None => Err(StateInUse {}), } } + + pub fn index_in_trajectory(&self) -> i64 { + self.inner.inner.index_in_trajectory() + } + + pub fn write_position(&self, math: &mut M, out: &mut [f64]) { + math.write_to_slice(self.point().position(), out) + } + + pub fn write_gradient(&self, math: &mut M, out: &mut [f64]) { + math.write_to_slice(self.point().gradient(), out) + } + + pub fn energy(&self) -> f64 { + self.point().energy() + } } -impl Drop for State { +impl> Drop for State { fn drop(&mut self) { let rc = unsafe { std::mem::ManuallyDrop::take(&mut self.inner) }; if (Rc::strong_count(&rc) == 1) & (Rc::weak_count(&rc) == 0) { @@ -141,7 +112,7 @@ impl Drop for State { } } -impl Clone for State { +impl> Clone for State { fn clone(&self) -> Self { State { inner: self.inner.clone(), @@ -149,98 +120,11 @@ impl Clone for State { } } -impl State { - pub(crate) fn is_turning(&self, math: &mut M, other: &Self) -> bool { - let (start, end) = if self.idx_in_trajectory < other.idx_in_trajectory { - (self, other) - } else { - (other, self) - }; - - let a = start.idx_in_trajectory; - let b = end.idx_in_trajectory; - - assert!(a < b); - let (turn1, turn2) = if (a >= 0) & (b >= 0) { - math.scalar_prods3(&end.p_sum, &start.p_sum, &start.p, &end.v, &start.v) - } else if (b >= 0) & (a < 0) { - math.scalar_prods2(&end.p_sum, &start.p_sum, &end.v, &start.v) - } else { - assert!((a < 0) & (b < 0)); - math.scalar_prods3(&start.p_sum, &end.p_sum, &end.p, &end.v, &start.v) - }; - - (turn1 < 0.) | (turn2 < 0.) - } - - pub(crate) fn write_position(&self, math: &mut M, out: &mut [f64]) { - math.write_to_slice(&self.q, out) - } - - pub(crate) fn write_gradient(&self, math: &mut M, out: &mut [f64]) { - math.write_to_slice(&self.grad, out) - } - - pub(crate) fn energy(&self) -> f64 { - self.kinetic_energy + self.potential_energy - } - - pub(crate) fn index_in_trajectory(&self) -> i64 { - self.idx_in_trajectory - } - - pub(crate) fn make_init_point(&mut self, math: &mut M) { - let inner = self.try_mut_inner().unwrap(); - inner.idx_in_trajectory = 0; - math.copy_into(&inner.p, &mut inner.p_sum); - } - - pub(crate) fn potential_energy(&self) -> f64 { - self.potential_energy - } - - pub(crate) fn first_momentum_halfstep(&self, math: &mut M, out: &mut Self, epsilon: f64) { - math.axpy_out( - &self.grad, - &self.p, - epsilon / 2., - &mut out.try_mut_inner().expect("State already in use").p, - ); - } - - pub(crate) fn position_step(&self, math: &mut M, out: &mut Self, epsilon: f64) { - let out = out.try_mut_inner().expect("State already in use"); - math.axpy_out(&out.v, &self.q, epsilon, &mut out.q); - } - - pub(crate) fn second_momentum_halfstep(&mut self, math: &mut M, epsilon: f64) { - let inner = self.try_mut_inner().expect("State already in use"); - math.axpy(&inner.grad, &mut inner.p, epsilon / 2.); - } - - pub(crate) fn set_psum(&self, math: &mut M, target: &mut Self, _dir: crate::nuts::Direction) { - let out = target.try_mut_inner().expect("State already in use"); - - assert!(out.idx_in_trajectory != 0); - - if out.idx_in_trajectory == -1 { - math.copy_into(&out.p, &mut out.p_sum); - } else { - math.axpy_out(&out.p, &self.p_sum, 1., &mut out.p_sum); - } - } - - pub(crate) fn index_in_trajectory_mut(&mut self) -> &mut i64 { - &mut self - .try_mut_inner() - .expect("State already in use") - .idx_in_trajectory - } -} - #[cfg(test)] mod tests { - use 
crate::{cpu_math::CpuMath, sampler::test_logps::NormalLogp}; + use crate::{ + cpu_math::CpuMath, euclidean_hamiltonian::EuclideanPoint, sampler::test_logps::NormalLogp, + }; use super::*; @@ -248,15 +132,13 @@ mod tests { fn crate_pool() { let logp = NormalLogp { dim: 10, mu: 0.2 }; let mut math = CpuMath::new(&logp); - let pool = StatePool::new(&mut math, 10); + let pool: StatePool<_, EuclideanPoint<_>> = StatePool::new(&mut math, 10); let mut state = pool.new_state(&mut math); - assert!(state.p.nrows() == 10); - assert!(state.p.ncols() == 1); - state.try_mut_inner().unwrap(); + state.try_point_mut().unwrap(); let mut state2 = state.clone(); - assert!(state.try_mut_inner().is_err()); - assert!(state2.try_mut_inner().is_err()); + assert!(state.try_point_mut().is_err()); + assert!(state2.try_point_mut().is_err()); } #[test] @@ -264,14 +146,8 @@ mod tests { let dim = 10; let logp = NormalLogp { dim, mu: 0.2 }; let mut math = CpuMath::new(&logp); - let pool = StatePool::new(&mut math, 10); + let pool: StatePool<_, EuclideanPoint<_>> = StatePool::new(&mut math, 10); let a = pool.new_state(&mut math); - - assert_eq!(a.idx_in_trajectory, 0); - assert!(a.p_sum.as_slice().iter().all(|&x| x == 0f64)); - assert_eq!(a.p_sum.as_slice().len(), dim); - assert_eq!(a.grad.as_slice().len(), dim); - assert_eq!(a.q.as_slice().len(), dim); - assert_eq!(a.p.as_slice().len(), dim); + assert_eq!(a.index_in_trajectory(), 0); } } diff --git a/src/stepsize.rs b/src/stepsize.rs index 122737d..2556d8f 100644 --- a/src/stepsize.rs +++ b/src/stepsize.rs @@ -1,8 +1,8 @@ use crate::{ + hamiltonian::{DivergenceInfo, Point}, math_base::Math, nuts::{Collector, NutsOptions}, state::State, - DivergenceInfo, }; /// Settings for step size adaptation @@ -117,12 +117,12 @@ impl AcceptanceRateCollector { } } -impl Collector for AcceptanceRateCollector { +impl> Collector for AcceptanceRateCollector { fn register_leapfrog( &mut self, _math: &mut M, - _start: &State, - end: &State, + _start: &State, + end: &State, divergence_info: Option<&DivergenceInfo>, ) { match divergence_info { @@ -142,7 +142,7 @@ impl Collector for AcceptanceRateCollector { }; } - fn register_init(&mut self, _math: &mut M, state: &State, _options: &NutsOptions) { + fn register_init(&mut self, _math: &mut M, state: &State, _options: &NutsOptions) { self.initial_energy = state.energy(); self.mean.reset(); self.mean_sym.reset(); diff --git a/src/stepsize_adapt.rs b/src/stepsize_adapt.rs index 114ff69..16211c5 100644 --- a/src/stepsize_adapt.rs +++ b/src/stepsize_adapt.rs @@ -5,12 +5,11 @@ use arrow::{ use rand::Rng; use crate::{ - nuts::{ - AdaptStats, Collector, Direction, Hamiltonian, NutsOptions, SamplerStats, StatTraceBuilder, - }, - state::State, + hamiltonian::{Direction, Hamiltonian, LeapfrogResult, Point}, + nuts::{Collector, NutsOptions}, + sampler_stats::{SamplerStats, StatTraceBuilder}, stepsize::{AcceptanceRateCollector, DualAverage, DualAverageOptions}, - Math, Settings, + Math, NutsError, Settings, }; pub struct Strategy { @@ -32,40 +31,27 @@ impl Strategy { } } - pub fn init( + pub fn init>( &mut self, math: &mut M, options: &mut NutsOptions, - potential: &mut impl Hamiltonian, - state: &State, + hamiltonian: &mut impl Hamiltonian, + position: &[f64], rng: &mut R, - ) { - let mut pool = potential.new_pool(math, 1); - - let mut state = potential.copy_state(math, &mut pool, state); - state - .try_mut_inner() - .expect("New state should have only one reference") - .idx_in_trajectory = 0; - potential.randomize_momentum(math, &mut state, rng); + ) -> 
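// ----------------------------------------------------------------------------
// [Editor's note — not part of the diff] The pooling contract exercised by the
// tests above: a `State` is mutable through `try_point_mut` only while its Rc
// is unique; once every clone is dropped, mutation becomes possible again, and
// dropping the last handle returns the allocation to the pool. Sketch:
//
//     let mut state = pool.new_state(&mut math);
//     assert!(state.try_point_mut().is_ok());
//     let copy = state.clone();
//     assert!(state.try_point_mut().is_err()); // shared -> mutation refused
//     drop(copy);
//     assert!(state.try_point_mut().is_ok());  // unique again
// ----------------------------------------------------------------------------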
Result<(), NutsError> { + let mut state = hamiltonian.init_state(math, position)?; + hamiltonian.initialize_trajectory(math, &mut state, rng)?; let mut collector = AcceptanceRateCollector::new(); collector.register_init(math, &state, options); - *potential.stepsize_mut() = self.options.initial_step; + *hamiltonian.step_size_mut() = self.options.initial_step; - let state_next = potential.leapfrog( - math, - &mut pool, - &state, - Direction::Forward, - state.energy(), - &mut collector, - ); + let state_next = hamiltonian.leapfrog(math, &state, Direction::Forward, &mut collector); - let Ok(_) = state_next else { - return; + let LeapfrogResult::Ok(_) = state_next else { + return Ok(()); }; let accept_stat = collector.mean.current(); @@ -78,35 +64,37 @@ impl Strategy { for _ in 0..100 { let mut collector = AcceptanceRateCollector::new(); collector.register_init(math, &state, options); - let state_next = - potential.leapfrog(math, &mut pool, &state, dir, state.energy(), &mut collector); - let Ok(_) = state_next else { - *potential.stepsize_mut() = self.options.initial_step; - return; + let state_next = hamiltonian.leapfrog(math, &state, dir, &mut collector); + let LeapfrogResult::Ok(_) = state_next else { + *hamiltonian.step_size_mut() = self.options.initial_step; + return Ok(()); }; let accept_stat = collector.mean.current(); match dir { Direction::Forward => { - if (accept_stat <= self.options.target_accept) | (potential.stepsize() > 1e5) { + if (accept_stat <= self.options.target_accept) | (hamiltonian.step_size() > 1e5) + { self.step_size_adapt = - DualAverage::new(self.options.params, potential.stepsize()); - return; + DualAverage::new(self.options.params, hamiltonian.step_size()); + return Ok(()); } - *potential.stepsize_mut() *= 2.; + *hamiltonian.step_size_mut() *= 2.; } Direction::Backward => { - if (accept_stat >= self.options.target_accept) | (potential.stepsize() < 1e-10) + if (accept_stat >= self.options.target_accept) + | (hamiltonian.step_size() < 1e-10) { self.step_size_adapt = - DualAverage::new(self.options.params, potential.stepsize()); - return; + DualAverage::new(self.options.params, hamiltonian.step_size()); + return Ok(()); } - *potential.stepsize_mut() /= 2.; + *hamiltonian.step_size_mut() /= 2.; } } } // If we don't find something better, use the specified initial value - *potential.stepsize_mut() = self.options.initial_step; + *hamiltonian.step_size_mut() = self.options.initial_step; + Ok(()) } pub fn update(&mut self, collector: &AcceptanceRateCollector) { @@ -134,9 +122,9 @@ impl Strategy { use_best_guess: bool, ) { if use_best_guess { - *potential.stepsize_mut() = self.step_size_adapt.current_step_size_adapted(); + *potential.step_size_mut() = self.step_size_adapt.current_step_size_adapted(); } else { - *potential.stepsize_mut() = self.step_size_adapt.current_step_size(); + *potential.step_size_mut() = self.step_size_adapt.current_step_size(); } } @@ -243,12 +231,6 @@ impl SamplerStats for Strategy { } } -impl AdaptStats for Strategy { - fn num_grad_evals(stats: &Self::Stats) -> usize { - stats.n_steps as usize - } -} - #[derive(Debug, Clone, Copy)] pub struct DualAverageSettings { pub target_accept: f64, diff --git a/src/transform_adapt_strategy.rs b/src/transform_adapt_strategy.rs new file mode 100644 index 0000000..7d5c7fd --- /dev/null +++ b/src/transform_adapt_strategy.rs @@ -0,0 +1,247 @@ +use arrow::array::StructArray; + +use crate::adapt_strategy::CombinedCollector; +use crate::chain::AdaptStrategy; +use crate::hamiltonian::{Hamiltonian, Point}; +use 
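// ----------------------------------------------------------------------------
// [Editor's note — not part of the diff] The step-size `init` heuristic above,
// reduced to its shape: take one leapfrog step at `initial_step`; if the
// acceptance statistic is above target keep doubling, otherwise keep halving,
// until the target is crossed or the bounds 1e-10 / 1e5 are hit (hypothetical
// pure-function sketch):
//
//     fn find_initial_step(mut step: f64, accept_at: impl Fn(f64) -> f64, target: f64) -> f64 {
//         if accept_at(step) > target {
//             while accept_at(step) > target && step <= 1e5 { step *= 2.0; }
//         } else {
//             while accept_at(step) < target && step >= 1e-10 { step /= 2.0; }
//         }
//         step
//     }
//
// The real code additionally caps the search at 100 iterations and falls back
// to `initial_step` if a leapfrog step fails outright.
// ----------------------------------------------------------------------------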
crate::nuts::{Collector, NutsOptions, SampleInfo}; +use crate::sampler_stats::{SamplerStats, StatTraceBuilder}; +use crate::state::State; +use crate::stepsize::AcceptanceRateCollector; +use crate::stepsize_adapt::{ + Stats as StepSizeStats, StatsBuilder as StepSizeStatsBuilder, Strategy as StepSizeStrategy, +}; +use crate::transformed_hamiltonian::TransformedHamiltonian; +use crate::{DualAverageSettings, Math, NutsError, Settings}; + +#[derive(Clone, Copy, Debug)] +pub struct TransformedSettings { + pub step_size_window: f64, + pub transform_update_freq: u64, + pub use_orbit_for_training: bool, + pub dual_average_options: DualAverageSettings, + pub transform_train_max_energy_error: f64, +} + +impl Default for TransformedSettings { + fn default() -> Self { + Self { + step_size_window: 0.1f64, + transform_update_freq: 50, + use_orbit_for_training: true, + transform_train_max_energy_error: 50f64, + dual_average_options: Default::default(), + } + } +} + +pub struct TransformAdaptation { + step_size: StepSizeStrategy, + options: TransformedSettings, + num_tune: u64, + final_window_size: u64, + tuning: bool, + chain: u64, +} + +#[derive(Clone, Debug)] +pub struct Stats { + pub step_size: StepSizeStats, +} + +pub struct Builder { + step_size: StepSizeStatsBuilder, +} + +impl StatTraceBuilder for Builder { + fn append_value(&mut self, value: Stats) { + let Stats { step_size } = value; + self.step_size.append_value(step_size); + } + + fn finalize(self) -> Option { + let Self { step_size } = self; + step_size.finalize() + } + + fn inspect(&self) -> Option { + let Self { step_size } = self; + step_size.inspect() + } +} + +impl SamplerStats for TransformAdaptation { + type Stats = Stats; + type Builder = Builder; + + fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder { + let step_size = SamplerStats::::new_builder(&self.step_size, settings, dim); + Builder { step_size } + } + + fn current_stats(&self, math: &mut M) -> Self::Stats { + let step_size = self.step_size.current_stats(math); + Stats { step_size } + } +} + +pub struct DrawCollector { + draws: Vec, + grads: Vec, + collect_orbit: bool, + max_energy_error: f64, +} + +impl DrawCollector { + fn new(_math: &mut M, collect_orbit: bool, max_energy_error: f64) -> Self { + Self { + draws: vec![], + grads: vec![], + collect_orbit, + max_energy_error, + } + } +} + +impl> Collector for DrawCollector { + fn register_leapfrog( + &mut self, + math: &mut M, + _start: &State, + end: &State, + _divergence_info: Option<&crate::DivergenceInfo>, + ) { + if self.collect_orbit { + let point = end.point(); + let energy_error = point.energy_error(); + if energy_error.abs() < self.max_energy_error { + if !math.array_all_finite(point.position()) { + return; + } + if !math.array_all_finite(point.gradient()) { + return; + } + self.draws.push(math.copy_array(point.position())); + self.grads.push(math.copy_array(point.gradient())); + } + } + } + + fn register_draw(&mut self, math: &mut M, state: &State, _info: &SampleInfo) { + if !self.collect_orbit { + let point = state.point(); + let energy_error = point.energy_error(); + if energy_error.abs() < self.max_energy_error { + if !math.array_all_finite(point.position()) { + return; + } + if !math.array_all_finite(point.gradient()) { + return; + } + self.draws.push(math.copy_array(point.position())); + self.grads.push(math.copy_array(point.gradient())); + } + } + } +} + +impl AdaptStrategy for TransformAdaptation { + type Hamiltonian = TransformedHamiltonian; + + type Collector = CombinedCollector< + M, 
+ >::Point, + AcceptanceRateCollector, + DrawCollector, + >; + + type Options = TransformedSettings; + + fn new(_math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self { + let step_size = StepSizeStrategy::new(options.dual_average_options); + let final_window_size = + ((num_tune as f64) * (1f64 - options.step_size_window)).floor() as u64; + Self { + step_size, + options, + num_tune, + final_window_size, + tuning: true, + chain, + } + } + + fn init( + &mut self, + math: &mut M, + options: &mut NutsOptions, + hamiltonian: &mut Self::Hamiltonian, + position: &[f64], + rng: &mut R, + ) -> Result<(), NutsError> { + hamiltonian.init_transformation(rng, math, position, self.chain)?; + self.step_size + .init(math, options, hamiltonian, position, rng)?; + Ok(()) + } + + fn adapt( + &mut self, + math: &mut M, + _options: &mut NutsOptions, + hamiltonian: &mut Self::Hamiltonian, + draw: u64, + collector: &Self::Collector, + _state: &State>::Point>, + rng: &mut R, + ) -> Result<(), NutsError> { + self.step_size.update(&collector.collector1); + + if draw >= self.num_tune { + self.tuning = false; + return Ok(()); + } + + if draw < self.final_window_size { + if draw < 100 { + if (draw > 0) & (draw % 10 == 0) { + hamiltonian.update_params( + math, + rng, + collector.collector2.draws.iter(), + collector.collector2.grads.iter(), + )?; + } + } else if (draw > 0) & (draw % self.options.transform_update_freq == 0) { + hamiltonian.update_params( + math, + rng, + collector.collector2.draws.iter(), + collector.collector2.grads.iter(), + )?; + } + self.step_size.update_estimator_early(); + self.step_size.update_stepsize(hamiltonian, false); + return Ok(()); + } + + self.step_size.update_estimator_late(); + let is_last = draw == self.num_tune - 1; + self.step_size.update_stepsize(hamiltonian, is_last); + Ok(()) + } + + fn new_collector(&self, math: &mut M) -> Self::Collector { + Self::Collector::new( + self.step_size.new_collector(), + DrawCollector::new( + math, + self.options.use_orbit_for_training, + self.options.transform_train_max_energy_error, + ), + ) + } + + fn is_tuning(&self) -> bool { + self.tuning + } +} diff --git a/src/transformed_hamiltonian.rs b/src/transformed_hamiltonian.rs new file mode 100644 index 0000000..79e4a83 --- /dev/null +++ b/src/transformed_hamiltonian.rs @@ -0,0 +1,470 @@ +use std::{marker::PhantomData, sync::Arc}; + +use arrow::{ + array::{ArrayBuilder, Float64Builder, StructArray}, + datatypes::{DataType, Field}, +}; + +use crate::{ + hamiltonian::{Direction, Hamiltonian, LeapfrogResult, Point}, + sampler_stats::{SamplerStats, StatTraceBuilder}, + state::{State, StatePool}, + DivergenceInfo, LogpError, Math, NutsError, Settings, +}; + +pub struct TransformedPoint { + untransformed_position: M::Vector, + untransformed_gradient: M::Vector, + transformed_position: M::Vector, + transformed_gradient: M::Vector, + velocity: M::Vector, + index_in_trajectory: i64, + logp: f64, + logdet: f64, + kinetic_energy: f64, + initial_energy: f64, + transform_id: i64, +} + +impl TransformedPoint { + fn first_velocity_halfstep(&self, math: &mut M, out: &mut Self, epsilon: f64) { + math.axpy_out( + &self.transformed_gradient, + &self.velocity, + epsilon / 2., + &mut out.velocity, + ); + } + + fn position_step(&self, math: &mut M, out: &mut Self, epsilon: f64) { + math.axpy_out( + &out.velocity, + &self.transformed_position, + epsilon, + &mut out.transformed_position, + ); + } + + fn second_velocity_halfstep(&mut self, math: &mut M, epsilon: f64) { + math.axpy(&self.transformed_gradient, 
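// ----------------------------------------------------------------------------
// [Editor's note — not part of the diff] The retraining schedule implemented
// in `adapt` above, as a stand-alone predicate (hypothetical helper): the
// transformation is refit every 10 draws during the first 100 tuning draws,
// then every `transform_update_freq` draws, and not at all during the final
// `step_size_window` fraction of tuning, which adapts only the step size.
//
//     fn should_refit(draw: u64, update_freq: u64, final_window_start: u64) -> bool {
//         if draw == 0 || draw >= final_window_start {
//             return false;
//         }
//         let freq = if draw < 100 { 10 } else { update_freq };
//         draw % freq == 0
//     }
// ----------------------------------------------------------------------------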
&mut self.velocity, epsilon / 2.); + } + + fn update_kinetic_energy(&mut self, math: &mut M) { + self.kinetic_energy = 0.5 * math.array_vector_dot(&self.velocity, &self.velocity); + } + + fn init_from_untransformed_position( + &mut self, + hamiltonian: &TransformedHamiltonian, + math: &mut M, + ) -> Result<(), M::LogpErr> { + let (logp, logdet) = { + math.init_from_untransformed_position( + hamiltonian.params.as_ref().expect("No transformation set"), + &self.untransformed_position, + &mut self.untransformed_gradient, + &mut self.transformed_position, + &mut self.transformed_gradient, + ) + }?; + self.logp = logp; + self.logdet = logdet; + Ok(()) + } + + fn init_from_transformed_position( + &mut self, + hamiltonian: &TransformedHamiltonian, + math: &mut M, + ) -> Result<(), M::LogpErr> { + let (logp, logdet) = { + math.init_from_transformed_position( + hamiltonian.params.as_ref().expect("No transformation set"), + &mut self.untransformed_position, + &mut self.untransformed_gradient, + &self.transformed_position, + &mut self.transformed_gradient, + ) + }?; + self.logp = logp; + self.logdet = logdet; + Ok(()) + } + + fn is_valid(&self, math: &mut M) -> bool { + if !math.array_all_finite(&self.transformed_position) { + return false; + } + if !math.array_all_finite_and_nonzero(&self.transformed_gradient) { + return false; + } + if !math.array_all_finite(&self.untransformed_gradient) { + return false; + } + if !math.array_all_finite(&self.untransformed_position) { + return false; + } + + true + } +} + +impl Point for TransformedPoint { + fn position(&self) -> &::Vector { + &self.untransformed_position + } + + fn gradient(&self) -> &::Vector { + &self.untransformed_gradient + } + + fn index_in_trajectory(&self) -> i64 { + self.index_in_trajectory + } + + fn energy(&self) -> f64 { + self.kinetic_energy - (self.logp + self.logdet) + } + + fn initial_energy(&self) -> f64 { + self.initial_energy + } + + fn logp(&self) -> f64 { + self.logp + } + + fn new(math: &mut M) -> Self { + Self { + untransformed_position: math.new_array(), + untransformed_gradient: math.new_array(), + transformed_position: math.new_array(), + transformed_gradient: math.new_array(), + velocity: math.new_array(), + index_in_trajectory: 0, + logp: 0f64, + logdet: 0f64, + kinetic_energy: 0f64, + transform_id: -1, + initial_energy: 0f64, + } + } + + fn copy_into(&self, math: &mut M, other: &mut Self) { + let Self { + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + velocity, + index_in_trajectory, + logp, + logdet, + kinetic_energy, + transform_id, + initial_energy, + } = self; + + other.index_in_trajectory = *index_in_trajectory; + other.logp = *logp; + other.logdet = *logdet; + other.kinetic_energy = *kinetic_energy; + other.transform_id = *transform_id; + other.initial_energy = *initial_energy; + math.copy_into(untransformed_position, &mut other.untransformed_position); + math.copy_into(untransformed_gradient, &mut other.untransformed_gradient); + math.copy_into(transformed_position, &mut other.transformed_position); + math.copy_into(transformed_gradient, &mut other.transformed_gradient); + math.copy_into(velocity, &mut other.velocity); + } +} + +pub struct TransformedHamiltonian { + ones: M::Vector, + zeros: M::Vector, + step_size: f64, + params: Option, + max_energy_error: f64, + _phantom: PhantomData, + pool: StatePool>, +} + +impl TransformedHamiltonian { + pub fn new(math: &mut M, max_energy_error: f64) -> Self { + let mut ones = math.new_array(); + math.fill_array(&mut 
ones, 1f64); + let mut zeros = math.new_array(); + math.fill_array(&mut zeros, 0f64); + let pool = StatePool::new(math, 10); + Self { + step_size: 0f64, + ones, + zeros, + params: None, + max_energy_error, + _phantom: Default::default(), + pool, + } + } + + pub fn init_transformation( + &mut self, + rng: &mut R, + math: &mut M, + position: &[f64], + chain: u64, + ) -> Result<(), NutsError> { + let mut gradient_array = math.new_array(); + let mut position_array = math.new_array(); + math.read_from_slice(&mut position_array, position); + let _ = math + .logp_array(&position_array, &mut gradient_array) + .map_err(|e| NutsError::BadInitGrad(Box::new(e)))?; + let params = math + .new_transformation(rng, &position_array, &gradient_array, chain) + .map_err(|e| NutsError::BadInitGrad(Box::new(e)))?; + self.params = Some(params); + Ok(()) + } + + pub fn update_params<'a, R: rand::Rng + ?Sized>( + &'a mut self, + math: &'a mut M, + rng: &mut R, + draws: impl ExactSizeIterator, + grads: impl ExactSizeIterator, + ) -> Result<(), NutsError> { + math.update_transformation( + rng, + draws, + grads, + self.params.as_mut().expect("Transformation was empty"), + ) + .map_err(|e| NutsError::BadInitGrad(Box::new(e)))?; + Ok(()) + } +} + +#[derive(Debug, Clone, Default)] +pub struct Stats { + pub step_size: f64, +} + +pub struct Builder { + step_size: Float64Builder, +} + +impl StatTraceBuilder for Builder { + fn append_value(&mut self, value: Stats) { + let Stats { step_size } = value; + self.step_size.append_value(step_size); + } + + fn finalize(self) -> Option { + let Self { mut step_size } = self; + + let fields = vec![Field::new("step_size", DataType::Float64, false)]; + let arrays = vec![ArrayBuilder::finish(&mut step_size)]; + + Some(StructArray::new(fields.into(), arrays, None)) + } + + fn inspect(&self) -> Option { + let Self { step_size } = self; + + let fields = vec![Field::new("step_size", DataType::Float64, false)]; + let arrays = vec![ArrayBuilder::finish_cloned(step_size)]; + + Some(StructArray::new(fields.into(), arrays, None)) + } +} + +impl SamplerStats for TransformedHamiltonian { + type Stats = Stats; + type Builder = Builder; + + fn new_builder(&self, settings: &impl Settings, _dim: usize) -> Self::Builder { + Builder { + step_size: Float64Builder::with_capacity( + settings.hint_num_draws() + settings.hint_num_tune(), + ), + } + } + + fn current_stats(&self, _math: &mut M) -> Self::Stats { + Stats { + step_size: self.step_size, + } + } +} + +impl Hamiltonian for TransformedHamiltonian { + type Point = TransformedPoint; + + fn leapfrog>( + &mut self, + math: &mut M, + start: &State, + dir: Direction, + collector: &mut C, + ) -> LeapfrogResult { + let mut out = self.pool().new_state(math); + let out_point = out.try_point_mut().expect("New point has other references"); + + out_point.initial_energy = start.point().initial_energy(); + out_point.transform_id = start.point().transform_id; + + let sign = match dir { + Direction::Forward => 1, + Direction::Backward => -1, + }; + + let epsilon = (sign as f64) * self.step_size; + + start + .point() + .first_velocity_halfstep(math, out_point, epsilon); + + start.point().position_step(math, out_point, epsilon); + if let Err(logp_error) = out_point.init_from_transformed_position(self, math) { + if !logp_error.is_recoverable() { + return LeapfrogResult::Err(logp_error); + } + let div_info = DivergenceInfo { + logp_function_error: Some(Arc::new(Box::new(logp_error))), + start_location: Some(math.box_array(start.point().position())), + start_gradient: 
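// ----------------------------------------------------------------------------
// [Editor's note — not part of the diff] Energy bookkeeping for
// `TransformedPoint` above: with velocity v living in the transformed space,
//
//     H(x, v) = 0.5 * <v, v> - (logp(x) + logdet J(x))
//
// so `energy()` returns kinetic energy minus the Jacobian-corrected log
// density, and `energy_error()` is H at the current point minus
// `initial_energy`, the H recorded when the trajectory was initialized.
// A scalar sketch of the formula:
//
//     fn energy(kinetic: f64, logp: f64, logdet: f64) -> f64 {
//         kinetic - (logp + logdet)
//     }
// ----------------------------------------------------------------------------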
+impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
+    type Point = TransformedPoint<M>;
+
+    fn leapfrog<C: Collector<M, Self::Point>>(
+        &mut self,
+        math: &mut M,
+        start: &State<M, Self::Point>,
+        dir: Direction,
+        collector: &mut C,
+    ) -> LeapfrogResult<M, Self::Point> {
+        let mut out = self.pool().new_state(math);
+        let out_point = out.try_point_mut().expect("New point has other references");
+
+        out_point.initial_energy = start.point().initial_energy();
+        out_point.transform_id = start.point().transform_id;
+
+        let sign = match dir {
+            Direction::Forward => 1,
+            Direction::Backward => -1,
+        };
+
+        let epsilon = (sign as f64) * self.step_size;
+
+        start
+            .point()
+            .first_velocity_halfstep(math, out_point, epsilon);
+
+        start.point().position_step(math, out_point, epsilon);
+        if let Err(logp_error) = out_point.init_from_transformed_position(self, math) {
+            if !logp_error.is_recoverable() {
+                return LeapfrogResult::Err(logp_error);
+            }
+            let div_info = DivergenceInfo {
+                logp_function_error: Some(Arc::new(Box::new(logp_error))),
+                start_location: Some(math.box_array(start.point().position())),
+                start_gradient: Some(math.box_array(start.point().gradient())),
+                start_momentum: None,
+                end_location: None,
+                start_idx_in_trajectory: Some(start.point().index_in_trajectory()),
+                end_idx_in_trajectory: None,
+                energy_error: None,
+            };
+            collector.register_leapfrog(math, start, &out, Some(&div_info));
+            return LeapfrogResult::Divergence(div_info);
+        }
+
+        out_point.second_velocity_halfstep(math, epsilon);
+
+        out_point.update_kinetic_energy(math);
+        out_point.index_in_trajectory = start.index_in_trajectory() + sign;
+
+        let energy_error = out_point.energy_error();
+        if (energy_error > self.max_energy_error) | !energy_error.is_finite() {
+            let divergence_info = DivergenceInfo {
+                logp_function_error: None,
+                start_location: Some(math.box_array(start.point().position())),
+                start_gradient: Some(math.box_array(start.point().gradient())),
+                end_location: Some(math.box_array(out_point.position())),
+                start_momentum: None,
+                start_idx_in_trajectory: Some(start.index_in_trajectory()),
+                end_idx_in_trajectory: Some(out.index_in_trajectory()),
+                energy_error: Some(energy_error),
+            };
+            collector.register_leapfrog(math, start, &out, Some(&divergence_info));
+            return LeapfrogResult::Divergence(divergence_info);
+        }
+
+        collector.register_leapfrog(math, start, &out, None);
+
+        LeapfrogResult::Ok(out)
+    }
+
+    fn is_turning(
+        &self,
+        math: &mut M,
+        state1: &State<M, Self::Point>,
+        state2: &State<M, Self::Point>,
+    ) -> bool {
+        let (start, end) = if state1.index_in_trajectory() < state2.index_in_trajectory() {
+            (state1, state2)
+        } else {
+            (state2, state1)
+        };
+
+        let (turn1, turn2) = math.scalar_prods3(
+            &end.point().transformed_position,
+            &start.point().transformed_position,
+            &self.zeros,
+            &start.point().velocity,
+            &end.point().velocity,
+        );
+
+        (turn1 < 0f64) | (turn2 < 0f64)
+    }
+
+    fn init_state(
+        &mut self,
+        math: &mut M,
+        init: &[f64],
+    ) -> Result<State<M, Self::Point>, NutsError> {
+        let mut state = self.pool().new_state(math);
+        let point = state.try_point_mut().expect("State already in use");
+        math.read_from_slice(&mut point.untransformed_position, init);
+
+        point
+            .init_from_untransformed_position(self, math)
+            .map_err(|e| NutsError::LogpFailure(Box::new(e)))?;
+
+        if !point.is_valid(math) {
+            Err(NutsError::BadInitGrad(
+                anyhow::anyhow!("Invalid initial point").into(),
+            ))
+        } else {
+            Ok(state)
+        }
+    }
+
+    fn initialize_trajectory<R: rand::Rng + ?Sized>(
+        &self,
+        math: &mut M,
+        state: &mut State<M, Self::Point>,
+        rng: &mut R,
+    ) -> Result<(), NutsError> {
+        let point = state.try_point_mut().expect("State has other references");
+        math.array_gaussian(rng, &mut point.velocity, &self.ones);
+        let current_transform_id = math
+            .transformation_id(self.params.as_ref().expect("No transformation set"))
+            .map_err(|e| NutsError::LogpFailure(Box::new(e)))?;
+        if current_transform_id != point.transform_id {
+            let logdet = math
+                .inv_transform_normalize(
+                    self.params.as_ref().expect("No transformation set"),
+                    &point.untransformed_position,
+                    &point.untransformed_gradient,
+                    &mut point.transformed_position,
+                    &mut point.transformed_gradient,
+                )
+                .map_err(|e| NutsError::LogpFailure(Box::new(e)))?;
+            point.logdet = logdet;
+            point.transform_id = current_transform_id;
+        }
+        point.update_kinetic_energy(math);
+        point.index_in_trajectory = 0;
+        point.initial_energy = point.energy();
+        Ok(())
+    }
+
+    fn pool(&mut self) -> &mut StatePool<M, Self::Point> {
+        &mut self.pool
+    }
+
+    fn copy_state(&mut self, math: &mut M, state: &State<M, Self::Point>) -> State<M, Self::Point> {
+        let mut new_state = self.pool.new_state(math);
+        state.point().copy_into(
+            math,
+            new_state
+                .try_point_mut()
+                .expect("New point should not have other references"),
+        );
+        new_state
+    }
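+
+    // Accessors for the current leapfrog step size, so the step size
+    // adaptation can read and update it in place.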
other references"), + ); + new_state + } + + fn step_size(&self) -> f64 { + self.step_size + } + + fn step_size_mut(&mut self) -> &mut f64 { + &mut self.step_size + } +} diff --git a/tests/sample_normal.rs b/tests/sample_normal.rs index d1289e2..895d220 100644 --- a/tests/sample_normal.rs +++ b/tests/sample_normal.rs @@ -8,8 +8,9 @@ use arrow::{ datatypes::Float64Type, }; use nuts_rs::{ - AdaptOptions, CpuLogpFunc, CpuMath, DiagAdaptExpSettings, DiagGradNutsSettings, DrawStorage, - LogpError, LowRankNutsSettings, Model, Sampler, SamplerWaitResult, Settings, Trace, + CpuLogpFunc, CpuMath, DiagAdaptExpSettings, DiagGradNutsSettings, DrawStorage, + EuclideanAdaptOptions, LogpError, LowRankNutsSettings, Model, Sampler, SamplerWaitResult, + Settings, Trace, }; use rand::prelude::Rng; use rand_distr::{Distribution, StandardNormal}; @@ -31,6 +32,7 @@ impl LogpError for NormalLogpError { impl<'a> CpuLogpFunc for NormalLogp<'a> { type LogpError = NormalLogpError; + type TransformParams = (); fn dim(&self) -> usize { self.dim @@ -52,6 +54,63 @@ impl<'a> CpuLogpFunc for NormalLogp<'a> { }); Ok(logp) } + + fn inv_transform_normalize( + &mut self, + _params: &Self::TransformParams, + _untransformed_position: &[f64], + _untransofrmed_gradient: &[f64], + _transformed_position: &mut [f64], + _transformed_gradient: &mut [f64], + ) -> Result { + todo!() + } + + fn init_from_transformed_position( + &mut self, + _params: &Self::TransformParams, + _untransformed_position: &mut [f64], + _untransformed_gradient: &mut [f64], + _transformed_position: &[f64], + _transformed_gradient: &mut [f64], + ) -> Result<(f64, f64), Self::LogpError> { + todo!() + } + + fn init_from_untransformed_position( + &mut self, + _params: &Self::TransformParams, + _untransformed_position: &[f64], + _untransformed_gradient: &mut [f64], + _transformed_position: &mut [f64], + _transformed_gradient: &mut [f64], + ) -> Result<(f64, f64), Self::LogpError> { + todo!() + } + + fn update_transformation<'b, R: rand::Rng + ?Sized>( + &'b mut self, + _rng: &mut R, + _untransformed_positions: impl Iterator, + _untransformed_gradients: impl Iterator, + _params: &'b mut Self::TransformParams, + ) -> Result<(), Self::LogpError> { + todo!() + } + + fn new_transformation( + &mut self, + _rng: &mut R, + _untransformed_position: &[f64], + _untransfogmed_gradient: &[f64], + _chain: u64, + ) -> Result { + todo!() + } + + fn transformation_id(&self, _params: &Self::TransformParams) -> Result { + todo!() + } } struct Storage { @@ -93,11 +152,13 @@ impl NormalModel { } impl Model for NormalModel { - type Math<'model> = CpuMath> + type Math<'model> + = CpuMath> where Self: 'model; - type DrawStorage<'model, S: Settings> = Storage + type DrawStorage<'model, S: Settings> + = Storage where Self: 'model; @@ -158,7 +219,7 @@ fn sample_debug_stats() -> anyhow::Result { store_gradient: true, store_divergences: true, store_unconstrained: true, - adapt_options: AdaptOptions { + adapt_options: EuclideanAdaptOptions { mass_matrix_options: DiagAdaptExpSettings { store_mass_matrix: true, use_grad_based_estimate: true, @@ -192,7 +253,7 @@ fn sample_eigs_debug_stats() -> anyhow::Result { store_gradient: true, store_divergences: true, store_unconstrained: true, - adapt_options: AdaptOptions { + adapt_options: EuclideanAdaptOptions { mass_matrix_options: nuts_rs::LowRankSettings { store_mass_matrix: false, ..Default::default()