diff --git a/src/adapt_strategy.rs b/src/adapt_strategy.rs
index 91a138f..0f34a23 100644
--- a/src/adapt_strategy.rs
+++ b/src/adapt_strategy.rs
@@ -90,7 +90,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy<M, A>
     >;
     type Options = EuclideanAdaptOptions<A::Options>;

-    fn new(math: &mut M, options: Self::Options, num_tune: u64) -> Self {
+    fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self {
         let num_tune_f = num_tune as f64;
         let step_size_window = (options.step_size_window * num_tune_f) as u64;
         let early_end = (options.early_window * num_tune_f) as u64;
@@ -100,7 +100,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy<M, A>

         Self {
             step_size: StepSizeStrategy::new(options.dual_average_options),
-            mass_matrix: A::new(math, options.mass_matrix_options, num_tune),
+            mass_matrix: A::new(math, options.mass_matrix_options, num_tune, chain),
             options,
             num_tune,
             early_end,
@@ -116,13 +116,13 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy<M, A>
         math: &mut M,
         options: &mut NutsOptions,
         hamiltonian: &mut Self::Hamiltonian,
-        state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
+        position: &[f64],
         rng: &mut R,
     ) -> Result<(), NutsError> {
         self.mass_matrix
-            .init(math, options, hamiltonian, state, rng)?;
+            .init(math, options, hamiltonian, position, rng)?;
         self.step_size
-            .init(math, options, hamiltonian, state, rng)?;
+            .init(math, options, hamiltonian, position, rng)?;
         Ok(())
     }

@@ -186,8 +186,9 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy<M, A>
             // First time we change the mass matrix
             if did_change & self.has_initial_mass_matrix {
                 self.has_initial_mass_matrix = false;
+                let position = math.box_array(state.point().position());
                 self.step_size
-                    .init(math, options, hamiltonian, state, rng)?;
+                    .init(math, options, hamiltonian, &position, rng)?;
             } else {
                 self.step_size.update_stepsize(hamiltonian, false)
             }
@@ -403,7 +404,18 @@ pub mod test_logps {
             unimplemented!()
         }

-        fn transformed_logp(
+        fn init_from_transformed_position(
+            &mut self,
+            _params: &Self::TransformParams,
+            _untransformed_position: &mut [f64],
+            _untransformed_gradient: &mut [f64],
+            _transformed_position: &[f64],
+            _transformed_gradient: &mut [f64],
+        ) -> Result<(f64, f64), Self::LogpError> {
+            unimplemented!()
+        }
+
+        fn init_from_untransformed_position(
             &mut self,
             _params: &Self::TransformParams,
             _untransformed_position: &[f64],
@@ -424,15 +436,19 @@ pub mod test_logps {
             unimplemented!()
         }

-        fn new_transformation(
+        fn new_transformation<R: rand::Rng + ?Sized>(
             &mut self,
+            _rng: &mut R,
             _untransformed_position: &[f64],
             _untransfogmed_gradient: &[f64],
+            _chain: u64,
         ) -> Result<Self::TransformParams, Self::LogpError> {
             unimplemented!()
         }

-        fn transformation_id(&self, _params: &Self::TransformParams) -> i64 {
+        fn transformation_id(
+            &self,
+            _params: &Self::TransformParams,
+        ) -> Result<i64, Self::LogpError> {
             unimplemented!()
         }
     }
@@ -462,7 +478,8 @@ mod test {
         let max_energy_error = 1000f64;
         let step_size = 0.1f64;

-        let hamiltonian = EuclideanHamiltonian::new(mass_matrix, max_energy_error, step_size);
+        let hamiltonian =
+            EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, step_size);
         let options = NutsOptions {
             maxdepth: 10u64,
             store_gradient: true,
diff --git a/src/chain.rs b/src/chain.rs
index d4fbe4a..c3914a1 100644
--- a/src/chain.rs
+++ b/src/chain.rs
@@ -6,7 +6,7 @@ use crate::{
     hamiltonian::{Hamiltonian, Point},
     nuts::{draw, Collector, NutsOptions, NutsSampleStats, NutsStatsBuilder},
     sampler_stats::SamplerStats,
-    state::{State, StatePool},
+    state::State,
     Math, NutsError, Settings,
 };

@@ -35,7 +35,6 @@ where
     R: rand::Rng,
     A: AdaptStrategy<M>,
 {
-    pool: StatePool<M, <A::Hamiltonian as Hamiltonian<M>>::Point>,
     hamiltonian: A::Hamiltonian,
     collector: A::Collector,
     options: NutsOptions,
@@ -56,18 +55,15 @@ where
 {
     pub fn new(
         mut math: M,
-        hamiltonian: A::Hamiltonian,
+        mut hamiltonian: A::Hamiltonian,
         strategy: A,
         options: NutsOptions,
         rng: R,
         chain: u64,
     ) -> Self {
-        let pool_size: usize = options.maxdepth.checked_mul(2).unwrap().try_into().unwrap();
-        let pool = hamiltonian.new_pool(&mut math, pool_size);
-        let init = pool.new_state(&mut math);
+        let init = hamiltonian.pool().new_state(&mut math);
         let collector = strategy.new_collector(&mut math);
         NutsChain {
-            pool,
             hamiltonian,
             collector,
             options,
@@ -87,14 +83,14 @@ pub trait AdaptStrategy<M: Math>: SamplerStats<M> {
     type Collector: Collector<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>;
     type Options: Copy + Send + Debug + Default;

-    fn new(math: &mut M, options: Self::Options, num_tune: u64) -> Self;
+    fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self;

     fn init<R: rand::Rng + ?Sized>(
         &mut self,
         math: &mut M,
         options: &mut NutsOptions,
         hamiltonian: &mut Self::Hamiltonian,
-        state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
+        position: &[f64],
         rng: &mut R,
     ) -> Result<(), NutsError>;

@@ -151,24 +147,20 @@ where
     type AdaptStrategy = A;

     fn set_position(&mut self, position: &[f64]) -> Result<()> {
-        let state = self
-            .hamiltonian
-            .init_state(&mut self.math, &mut self.pool, position)?;
-        self.init = state;
         self.strategy.init(
             &mut self.math,
             &mut self.options,
             &mut self.hamiltonian,
-            &self.init,
+            position,
             &mut self.rng,
         )?;
+        self.init = self.hamiltonian.init_state(&mut self.math, position)?;
         Ok(())
     }

     fn draw(&mut self) -> Result<(Box<[f64]>, Self::Stats)> {
         let (state, info) = draw(
             &mut self.math,
-            &mut self.pool,
             &mut self.init,
             &mut self.rng,
             &mut self.hamiltonian,
diff --git a/src/cpu_math.rs b/src/cpu_math.rs
index 94fa2b7..c617640 100644
--- a/src/cpu_math.rs
+++ b/src/cpu_math.rs
@@ -360,7 +360,7 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
         )
     }

-    fn transformed_logp(
+    fn init_from_untransformed_position(
         &mut self,
         params: &Self::TransformParams,
         untransformed_position: &Self::Vector,
@@ -368,7 +368,7 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
         transformed_position: &mut Self::Vector,
         transformed_gradient: &mut Self::Vector,
     ) -> Result<(f64, f64), Self::LogpErr> {
-        self.logp_func.transformed_logp(
+        self.logp_func.init_from_untransformed_position(
             params,
             untransformed_position.as_slice(),
             untransformed_gradient.as_slice_mut(),
@@ -377,11 +377,28 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
         )
     }

+    fn init_from_transformed_position(
+        &mut self,
+        params: &Self::TransformParams,
+        untransformed_position: &mut Self::Vector,
+        untransformed_gradient: &mut Self::Vector,
+        transformed_position: &Self::Vector,
+        transformed_gradient: &mut Self::Vector,
+    ) -> Result<(f64, f64), Self::LogpErr> {
+        self.logp_func.init_from_transformed_position(
+            params,
+            untransformed_position.as_slice_mut(),
+            untransformed_gradient.as_slice_mut(),
+            transformed_position.as_slice(),
+            transformed_gradient.as_slice_mut(),
+        )
+    }
+
     fn update_transformation<'a, R: rand::Rng + ?Sized>(
         &'a mut self,
         rng: &mut R,
-        untransformed_positions: impl Iterator<Item = &'a Self::Vector>,
-        untransformed_gradients: impl Iterator<Item = &'a Self::Vector>,
+        untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
+        untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
         params: &'a mut Self::TransformParams,
     ) -> Result<(), Self::LogpErr> {
         self.logp_func.update_transformation(
@@ -392,18 +409,22 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
         )
     }

-    fn new_transformation(
+    fn new_transformation<R: rand::Rng + ?Sized>(
         &mut self,
+        rng: &mut R,
         untransformed_position: &Self::Vector,
         untransfogmed_gradient: &Self::Vector,
+        chain: u64,
     ) -> Result<Self::TransformParams, Self::LogpErr> {
         self.logp_func.new_transformation(
+            rng,
             untransformed_position.as_slice(),
             untransfogmed_gradient.as_slice(),
+            chain,
         )
     }

-    fn transformation_id(&self, params: &Self::TransformParams) -> i64 {
+    fn transformation_id(&self, params: &Self::TransformParams) -> Result<i64, Self::LogpErr> {
         self.logp_func.transformation_id(params)
     }
 }
@@ -417,35 +438,58 @@ pub trait CpuLogpFunc {

     fn inv_transform_normalize(
         &mut self,
-        params: &Self::TransformParams,
-        untransformed_position: &[f64],
-        untransofrmed_gradient: &[f64],
-        transformed_position: &mut [f64],
-        transformed_gradient: &mut [f64],
-    ) -> Result<f64, Self::LogpError>;
+        _params: &Self::TransformParams,
+        _untransformed_position: &[f64],
+        _untransformed_gradient: &[f64],
+        _transformed_position: &mut [f64],
+        _transformed_gradient: &mut [f64],
+    ) -> Result<f64, Self::LogpError> {
+        unimplemented!()
+    }

-    fn transformed_logp(
+    fn init_from_untransformed_position(
         &mut self,
-        params: &Self::TransformParams,
-        untransformed_position: &[f64],
-        untransformed_gradient: &mut [f64],
-        transformed_position: &mut [f64],
-        transformed_gradient: &mut [f64],
-    ) -> Result<(f64, f64), Self::LogpError>;
+        _params: &Self::TransformParams,
+        _untransformed_position: &[f64],
+        _untransformed_gradient: &mut [f64],
+        _transformed_position: &mut [f64],
+        _transformed_gradient: &mut [f64],
+    ) -> Result<(f64, f64), Self::LogpError> {
+        unimplemented!()
+    }
+
+    fn init_from_transformed_position(
+        &mut self,
+        _params: &Self::TransformParams,
+        _untransformed_position: &mut [f64],
+        _untransformed_gradient: &mut [f64],
+        _transformed_position: &[f64],
+        _transformed_gradient: &mut [f64],
+    ) -> Result<(f64, f64), Self::LogpError> {
+        unimplemented!()
+    }

     fn update_transformation<'a, R: rand::Rng + ?Sized>(
         &'a mut self,
-        rng: &mut R,
-        untransformed_positions: impl Iterator<Item = &'a [f64]>,
-        untransformed_gradients: impl Iterator<Item = &'a [f64]>,
-        params: &'a mut Self::TransformParams,
-    ) -> Result<(), Self::LogpError>;
+        _rng: &mut R,
+        _untransformed_positions: impl ExactSizeIterator<Item = &'a [f64]>,
+        _untransformed_gradients: impl ExactSizeIterator<Item = &'a [f64]>,
+        _params: &'a mut Self::TransformParams,
+    ) -> Result<(), Self::LogpError> {
+        unimplemented!()
+    }

-    fn new_transformation(
+    fn new_transformation<R: rand::Rng + ?Sized>(
         &mut self,
-        untransformed_position: &[f64],
-        untransfogmed_gradient: &[f64],
-    ) -> Result<Self::TransformParams, Self::LogpError>;
+        _rng: &mut R,
+        _untransformed_position: &[f64],
+        _untransformed_gradient: &[f64],
+        _chain: u64,
+    ) -> Result<Self::TransformParams, Self::LogpError> {
+        unimplemented!()
+    }

-    fn transformation_id(&self, params: &Self::TransformParams) -> i64;
+    fn transformation_id(&self, _params: &Self::TransformParams) -> Result<i64, Self::LogpError> {
+        unimplemented!()
+    }
 }
diff --git a/src/euclidean_hamiltonian.rs b/src/euclidean_hamiltonian.rs
index a43307e..fd7b49e 100644
--- a/src/euclidean_hamiltonian.rs
+++ b/src/euclidean_hamiltonian.rs
@@ -18,15 +18,23 @@ pub struct EuclideanHamiltonian<M: Math, Mass: MassMatrix<M>> {
     pub(crate) mass_matrix: Mass,
     max_energy_error: f64,
     step_size: f64,
+    pool: StatePool<M, EuclideanPoint<M>>,
     _phantom: PhantomData<M>,
 }

 impl<M: Math, Mass: MassMatrix<M>> EuclideanHamiltonian<M, Mass> {
-    pub(crate) fn new(mass_matrix: Mass, max_energy_error: f64, step_size: f64) -> Self {
+    pub(crate) fn new(
+        math: &mut M,
+        mass_matrix: Mass,
+        max_energy_error: f64,
+        step_size: f64,
+    ) -> Self {
+        let pool = StatePool::new(math, 10);
         EuclideanHamiltonian {
             mass_matrix,
             max_energy_error,
             step_size,
+            pool,
             _phantom: PhantomData,
         }
     }
@@ -127,8 +135,8 @@ impl<M: Math> Point<M> for EuclideanPoint<M> {
         self.potential_energy + self.kinetic_energy
     }

-    fn energy_error(&self) -> f64 {
-        self.energy() - self.initial_energy
+    fn initial_energy(&self) -> f64 {
+        self.initial_energy
     }

     fn new(math: &mut M) -> Self {
@@ -275,14 +283,15 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Mass>
     fn leapfrog<C: Collector<M, Self::Point>>(
         &mut self,
         math: &mut M,
-        pool: &mut StatePool<M, Self::Point>,
         start: &State<M, Self::Point>,
         dir: Direction,
         collector: &mut C,
     ) -> LeapfrogResult<M, Self::Point> {
-        let mut out = pool.new_state(math);
+        let mut out = self.pool().new_state(math);
         let out_point = out.try_point_mut().expect("New point has other references");

+        out_point.initial_energy = start.point().initial_energy();
+
         let sign = match dir {
             Direction::Forward => 1,
             Direction::Backward => -1,
@@ -323,7 +332,7 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Mass>
         out_point.update_kinetic_energy(math);
         out_point.index_in_trajectory = start.index_in_trajectory() + sign;

-        let energy_error = { out_point.energy() - start.point().initial_energy };
+        let energy_error = out_point.energy_error();
         if (energy_error > self.max_energy_error) | !energy_error.is_finite() {
             let divergence_info = DivergenceInfo {
                 logp_function_error: None,
@@ -347,10 +356,9 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Mass>
     fn init_state(
         &mut self,
         math: &mut M,
-        pool: &mut StatePool<M, Self::Point>,
         init: &[f64],
     ) -> Result<State<M, Self::Point>, NutsError> {
-        let mut state = pool.new_state(math);
+        let mut state = self.pool().new_state(math);
         let point = state.try_point_mut().expect("State already in use");
         math.read_from_slice(&mut point.position, init);
         math.fill_array(&mut point.p_sum, 0.);
@@ -390,13 +398,8 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Mass>
         Ok(state)
     }

-    fn copy_state(
-        &self,
-        math: &mut M,
-        pool: &mut StatePool<M, Self::Point>,
-        state: &State<M, Self::Point>,
-    ) -> State<M, Self::Point> {
-        let mut new_state = pool.new_state(math);
+    fn copy_state(&mut self, math: &mut M, state: &State<M, Self::Point>) -> State<M, Self::Point> {
+        let mut new_state = self.pool().new_state(math);
         state.point().copy_into(
             math,
             new_state
@@ -406,6 +409,10 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Mass>
         new_state
     }

+    fn pool(&mut self) -> &mut StatePool<M, Self::Point> {
+        &mut self.pool
+    }
+
     fn step_size(&self) -> f64 {
         self.step_size
     }
diff --git a/src/hamiltonian.rs b/src/hamiltonian.rs
index 9047c0b..fd34c8f 100644
--- a/src/hamiltonian.rs
+++ b/src/hamiltonian.rs
@@ -54,7 +54,12 @@ pub trait Point<M: Math>: Sized {
     fn index_in_trajectory(&self) -> i64;
     fn energy(&self) -> f64;
     fn logp(&self) -> f64;
-    fn energy_error(&self) -> f64;
+
+    fn energy_error(&self) -> f64 {
+        self.energy() - self.initial_energy()
+    }
+
+    fn initial_energy(&self) -> f64;

     fn new(math: &mut M) -> Self;
     fn copy_into(&self, math: &mut M, other: &mut Self);
@@ -73,7 +78,6 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
     fn leapfrog<C: Collector<M, Self::Point>>(
         &mut self,
         math: &mut M,
-        pool: &mut StatePool<M, Self::Point>,
         start: &State<M, Self::Point>,
         dir: Direction,
         collector: &mut C,
@@ -93,7 +97,6 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
     fn init_state(
         &mut self,
         math: &mut M,
-        pool: &mut StatePool<M, Self::Point>,
         init: &[f64],
     ) -> Result<State<M, Self::Point>, NutsError>;
@@ -105,16 +108,9 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
         rng: &mut R,
     ) -> Result<(), NutsError>;

-    fn new_pool(&self, math: &mut M, pool_size: usize) -> StatePool<M, Self::Point> {
-        StatePool::new(math, pool_size)
-    }
+    fn pool(&mut self) -> &mut StatePool<M, Self::Point>;

-    fn copy_state(
-        &self,
-        math: &mut M,
-        pool: &mut StatePool<M, Self::Point>,
-        state: &State<M, Self::Point>,
-    ) -> State<M, Self::Point>;
+    fn copy_state(&mut self, math: &mut M, state: &State<M, Self::Point>) -> State<M, Self::Point>;

     fn step_size(&self) -> f64;
     fn step_size_mut(&mut self) -> &mut f64;
diff --git a/src/lib.rs b/src/lib.rs
index 51c4171..678f6a8 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -30,6 +30,7 @@
 //!
 //! impl CpuLogpFunc for PosteriorDensity {
 //!     type LogpError = PosteriorLogpError;
+//!     type TransformParams = ();
 //!
 //!     // We define a 10 dimensional normal distribution
 //!     fn dim(&self) -> usize { 10 }
@@ -111,9 +112,10 @@ pub use nuts::{NutsError, SampleStats};
 pub use sampler::{
     sample_sequentially, ChainOutput, ChainProgress, DiagGradNutsSettings, DrawStorage,
     LowRankNutsSettings, Model, NutsSettings, ProgressCallback, Sampler, SamplerWaitResult,
-    Settings, Trace,
+    Settings, Trace, TransformedNutsSettings,
 };

 pub use low_rank_mass_matrix::LowRankSettings;
 pub use mass_matrix_adapt::DiagAdaptExpSettings;
 pub use stepsize_adapt::DualAverageSettings;
+pub use transform_adapt_strategy::TransformedSettings;
diff --git a/src/low_rank_mass_matrix.rs b/src/low_rank_mass_matrix.rs
index 7cd4e50..3343c34 100644
--- a/src/low_rank_mass_matrix.rs
+++ b/src/low_rank_mass_matrix.rs
@@ -10,7 +10,7 @@ use itertools::Itertools;
 use crate::{
     chain::AdaptStrategy,
     euclidean_hamiltonian::{EuclideanHamiltonian, EuclideanPoint},
-    hamiltonian::Point,
+    hamiltonian::{Hamiltonian, Point},
     mass_matrix::{DrawGradCollector, MassMatrix},
     mass_matrix_adapt::MassMatrixAdaptStrategy,
     sampler_stats::{SamplerStats, StatTraceBuilder},
@@ -574,7 +574,7 @@ impl<M: Math> AdaptStrategy<M> for LowRankMassMatrixStrategy {
     type Collector = DrawGradCollector<M>;
     type Options = LowRankSettings;

-    fn new(math: &mut M, options: Self::Options, _num_tune: u64) -> Self {
+    fn new(math: &mut M, options: Self::Options, _num_tune: u64, _chain: u64) -> Self {
         Self::new(math.dim(), options)
     }

@@ -582,14 +582,18 @@ impl<M: Math> AdaptStrategy<M> for LowRankMassMatrixStrategy {
         &mut self,
         math: &mut M,
         _options: &mut crate::nuts::NutsOptions,
-        potential: &mut Self::Hamiltonian,
-        state: &State<M, EuclideanPoint<M>>,
+        hamiltonian: &mut Self::Hamiltonian,
+        position: &[f64],
         _rng: &mut R,
     ) -> Result<(), NutsError> {
-        self.add_draw(math, state);
-        potential
-            .mass_matrix
-            .update_from_grad(math, state.point().gradient(), 1f64, (1e-20, 1e20));
+        let state = hamiltonian.init_state(math, position)?;
+        self.add_draw(math, &state);
+        hamiltonian.mass_matrix.update_from_grad(
+            math,
+            state.point().gradient(),
+            1f64,
+            (1e-20, 1e20),
+        );
         Ok(())
     }
diff --git a/src/mass_matrix_adapt.rs b/src/mass_matrix_adapt.rs
index 5822122..f82ee2d 100644
--- a/src/mass_matrix_adapt.rs
+++ b/src/mass_matrix_adapt.rs
@@ -5,7 +5,7 @@ use rand::Rng;
 use crate::{
     chain::AdaptStrategy,
     euclidean_hamiltonian::{EuclideanHamiltonian, EuclideanPoint},
-    hamiltonian::Point,
+    hamiltonian::{Hamiltonian, Point},
     mass_matrix::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance},
     nuts::NutsOptions,
     sampler_stats::SamplerStats,
@@ -141,7 +141,7 @@ impl<M: Math> AdaptStrategy<M> for Strategy<M> {
     type Collector = DrawGradCollector<M>;
     type Options = DiagAdaptExpSettings;

-    fn new(math: &mut M, options: Self::Options, _num_tune: u64) -> Self {
+    fn new(math: &mut M, options: Self::Options, _num_tune: u64, _chain: u64) -> Self {
         Self {
             exp_variance_draw: RunningVariance::new(math),
             exp_variance_grad: RunningVariance::new(math),
@@ -156,10 +156,12 @@ impl<M: Math> AdaptStrategy<M> for Strategy<M> {
         &mut self,
         math: &mut M,
         _options: &mut NutsOptions,
-        potential: &mut Self::Hamiltonian,
-        state: &State<M, EuclideanPoint<M>>,
+        hamiltonian: &mut Self::Hamiltonian,
+        position: &[f64],
         _rng: &mut R,
     ) -> Result<(), NutsError> {
+        let state = hamiltonian.init_state(math, position)?;
+
         self.exp_variance_draw
             .add_sample(math, state.point().position());
         self.exp_variance_draw_bg
@@ -169,7 +171,7 @@ impl<M: Math> AdaptStrategy<M> for Strategy<M> {
         self.exp_variance_grad_bg
             .add_sample(math, state.point().gradient());

-        potential.mass_matrix.update_diag_grad(
+        hamiltonian.mass_matrix.update_diag_grad(
             math,
             state.point().gradient(),
             1f64,
diff --git a/src/math_base.rs b/src/math_base.rs
index df6f126..e4a5e07 100644
--- a/src/math_base.rs
+++ b/src/math_base.rs
@@ -149,7 +149,7 @@ pub trait Math {
         transformed_gradient: &mut Self::Vector,
     ) -> Result<f64, Self::LogpErr>;

-    fn transformed_logp(
+    fn init_from_untransformed_position(
         &mut self,
         params: &Self::TransformParams,
         untransformed_position: &Self::Vector,
@@ -158,19 +158,30 @@ pub trait Math {
         transformed_gradient: &mut Self::Vector,
     ) -> Result<(f64, f64), Self::LogpErr>;

+    fn init_from_transformed_position(
+        &mut self,
+        params: &Self::TransformParams,
+        untransformed_position: &mut Self::Vector,
+        untransformed_gradient: &mut Self::Vector,
+        transformed_position: &Self::Vector,
+        transformed_gradient: &mut Self::Vector,
+    ) -> Result<(f64, f64), Self::LogpErr>;
+
     fn update_transformation<'a, R: rand::Rng + ?Sized>(
         &'a mut self,
         rng: &mut R,
-        untransformed_positions: impl Iterator<Item = &'a Self::Vector>,
-        untransformed_gradients: impl Iterator<Item = &'a Self::Vector>,
+        untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
+        untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
         params: &'a mut Self::TransformParams,
     ) -> Result<(), Self::LogpErr>;

-    fn new_transformation(
+    fn new_transformation<R: rand::Rng + ?Sized>(
         &mut self,
+        rng: &mut R,
         untransformed_position: &Self::Vector,
         untransfogmed_gradient: &Self::Vector,
+        chain: u64,
     ) -> Result<Self::TransformParams, Self::LogpErr>;

-    fn transformation_id(&self, params: &Self::TransformParams) -> i64;
+    fn transformation_id(&self, params: &Self::TransformParams) -> Result<i64, Self::LogpErr>;
 }
diff --git a/src/nuts.rs b/src/nuts.rs
index 8455bb8..5b33c3f 100644
--- a/src/nuts.rs
+++ b/src/nuts.rs
@@ -14,7 +14,7 @@ use crate::hamiltonian::{Direction, DivergenceInfo, Hamiltonian, LeapfrogResult, Point};
 use crate::math::logaddexp;
 use crate::sampler::Settings;
 use crate::sampler_stats::StatTraceBuilder;
-use crate::state::{State, StatePool};
+use crate::state::State;

 use crate::math_base::Math;

@@ -82,7 +82,6 @@ struct NutsTree<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
     draw: State<M, H::Point>,
     log_size: f64,
     depth: u64,
-    initial_energy: f64,

     /// A tree is the main tree if it contains the initial point
     /// of the trajectory.
@@ -105,14 +104,12 @@ enum ExtendResult<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
 impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
     fn new(state: State<M, H::Point>) -> NutsTree<M, H, C> {
-        let initial_energy = state.energy();
         NutsTree {
             right: state.clone(),
             left: state.clone(),
             draw: state,
             depth: 0,
             log_size: 0.,
-            initial_energy,
             is_main: true,
             _phantom2: PhantomData,
         }
     }
@@ -123,7 +120,6 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
     fn extend<R>(
         mut self,
         math: &mut M,
-        pool: &mut StatePool<M, H::Point>,
         rng: &mut R,
         hamiltonian: &mut H,
         direction: Direction,
@@ -134,7 +130,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
         H: Hamiltonian<M>,
         R: rand::Rng + ?Sized,
     {
-        let mut other = match self.single_step(math, pool, hamiltonian, direction, collector) {
+        let mut other = match self.single_step(math, hamiltonian, direction, collector) {
             Ok(Ok(tree)) => tree,
             Ok(Err(info)) => return ExtendResult::Diverging(self, info),
             Err(err) => return ExtendResult::Err(err),
@@ -142,8 +138,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {

         while other.depth < self.depth {
             use ExtendResult::*;
-            other = match other.extend(math, pool, rng, hamiltonian, direction, collector, options)
-            {
+            other = match other.extend(math, rng, hamiltonian, direction, collector, options) {
                 Ok(tree) => tree,
                 Turning(_) => {
                     return Turning(self);
@@ -226,7 +221,6 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
     fn single_step(
         &self,
         math: &mut M,
-        pool: &mut StatePool<M, H::Point>,
         hamiltonian: &mut H,
         direction: Direction,
         collector: &mut C,
@@ -235,20 +229,20 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
             Direction::Forward => &self.right,
             Direction::Backward => &self.left,
         };
-        let end = match hamiltonian.leapfrog(math, pool, start, direction, collector) {
+        let end = match hamiltonian.leapfrog(math, start, direction, collector) {
             LeapfrogResult::Divergence(info) => return Ok(Err(info)),
             LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
             LeapfrogResult::Ok(end) => end,
         };
-        let log_size = self.initial_energy - end.energy();
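+        // Sign is unchanged: energy_error() is energy() - initial_energy(),
+        // and leapfrog copies initial_energy over from the start state, so
+        // -energy_error() equals the old `self.initial_energy - end.energy()`.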
+        let log_size = -end.point().energy_error();
         Ok(Ok(NutsTree {
             right: end.clone(),
             left: end.clone(),
             draw: end,
             depth: 0,
             log_size,
-            initial_energy: self.initial_energy,
             is_main: false,
             _phantom2: PhantomData,
         }))
@@ -259,7 +253,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
             depth: self.depth,
             divergence_info,
             reached_maxdepth: maxdepth,
-            initial_energy: self.initial_energy,
+            initial_energy: self.draw.point().initial_energy(),
             draw_energy: self.draw.energy(),
         }
     }
@@ -275,7 +269,6 @@ pub struct NutsOptions {

 pub(crate) fn draw<M, H, R, C>(
     math: &mut M,
-    pool: &mut StatePool<M, H::Point>,
     init: &mut State<M, H::Point>,
     rng: &mut R,
     hamiltonian: &mut H,
@@ -294,7 +287,7 @@ where
     let mut tree = NutsTree::new(init.clone());
     while tree.depth < options.maxdepth {
         let direction: Direction = rng.gen();
-        tree = match tree.extend(math, pool, rng, hamiltonian, direction, collector, options) {
+        tree = match tree.extend(math, rng, hamiltonian, direction, collector, options) {
             ExtendResult::Ok(tree) => tree,
             ExtendResult::Turning(tree) => {
                 let info = tree.info(false, None);
diff --git a/src/sampler.rs b/src/sampler.rs
index ab5d93b..ad2557c 100644
--- a/src/sampler.rs
+++ b/src/sampler.rs
@@ -169,10 +169,10 @@ impl Settings for LowRankNutsSettings {
         rng: &mut R,
     ) -> Self::Chain {
         let num_tune = self.num_tune;
-        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune);
+        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
         let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
         let max_energy_error = self.max_energy_error;
-        let potential = EuclideanHamiltonian::new(mass_matrix, max_energy_error, 1f64);
+        let potential = EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, 1f64);

         let options = NutsOptions {
             maxdepth: self.maxdepth,
@@ -230,13 +230,13 @@ impl Settings for DiagGradNutsSettings {
         rng: &mut R,
     ) -> Self::Chain {
         let num_tune = self.num_tune;
-        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune);
+        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
         let mass_matrix = DiagMassMatrix::new(
             &mut math,
             self.adapt_options.mass_matrix_options.store_mass_matrix,
         );
         let max_energy_error = self.max_energy_error;
-        let potential = EuclideanHamiltonian::new(mass_matrix, max_energy_error, 1f64);
+        let potential = EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, 1f64);

         let options = NutsOptions {
             maxdepth: self.maxdepth,
@@ -296,7 +296,7 @@ impl Settings for TransformedNutsSettings {
         let num_tune = self.num_tune;
         let max_energy_error = self.max_energy_error;

-        let strategy = TransformAdaptation::new(&mut math, self.adapt_options, num_tune);
+        let strategy = TransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
         let hamiltonian = TransformedHamiltonian::new(&mut math, max_energy_error);

         let options = NutsOptions {
@@ -1028,7 +1028,7 @@ pub mod test_logps {
             unimplemented!()
         }

-        fn transformed_logp(
+        fn init_from_untransformed_position(
             &mut self,
             _params: &Self::TransformParams,
             _untransformed_position: &[f64],
@@ -1039,6 +1039,17 @@ pub mod test_logps {
             unimplemented!()
         }

+        fn init_from_transformed_position(
+            &mut self,
+            _params: &Self::TransformParams,
+            _untransformed_position: &mut [f64],
+            _untransformed_gradient: &mut [f64],
+            _transformed_position: &[f64],
+            _transformed_gradient: &mut [f64],
+        ) -> std::result::Result<(f64, f64), Self::LogpError> {
+            unimplemented!()
+        }
+
         fn update_transformation<'b, R: rand::Rng + ?Sized>(
             &'b mut self,
             _rng: &mut R,
@@ -1049,15 +1060,19 @@ pub mod test_logps {
             unimplemented!()
         }

-        fn new_transformation(
+        fn new_transformation<R: rand::Rng + ?Sized>(
             &mut self,
+            _rng: &mut R,
             _untransformed_position: &[f64],
             _untransfogmed_gradient: &[f64],
+            _chain: u64,
         ) -> std::result::Result<Self::TransformParams, Self::LogpError> {
             unimplemented!()
         }

-        fn transformation_id(&self, _params: &Self::TransformParams) -> i64 {
+        fn transformation_id(
+            &self,
+            _params: &Self::TransformParams,
+        ) -> std::result::Result<i64, Self::LogpError> {
             unimplemented!()
         }
     }
diff --git a/src/stepsize_adapt.rs b/src/stepsize_adapt.rs
index e07b2be..16211c5 100644
--- a/src/stepsize_adapt.rs
+++ b/src/stepsize_adapt.rs
@@ -8,7 +8,6 @@ use crate::{
     hamiltonian::{Direction, Hamiltonian, LeapfrogResult, Point},
     nuts::{Collector, NutsOptions},
     sampler_stats::{SamplerStats, StatTraceBuilder},
-    state::State,
     stepsize::{AcceptanceRateCollector, DualAverage, DualAverageOptions},
     Math, NutsError, Settings,
 };
@@ -37,12 +36,10 @@ impl Strategy {
         math: &mut M,
         options: &mut NutsOptions,
         hamiltonian: &mut impl Hamiltonian<M>,
-        state: &State<M, impl Point<M>>,
+        position: &[f64],
         rng: &mut R,
     ) -> Result<(), NutsError> {
-        let mut pool = hamiltonian.new_pool(math, 1);
-
-        let mut state = hamiltonian.copy_state(math, &mut pool, state);
+        let mut state = hamiltonian.init_state(math, position)?;
         hamiltonian.initialize_trajectory(math, &mut state, rng)?;

         let mut collector = AcceptanceRateCollector::new();
@@ -51,8 +48,7 @@ impl Strategy {

         *hamiltonian.step_size_mut() = self.options.initial_step;

-        let state_next =
-            hamiltonian.leapfrog(math, &mut pool, &state, Direction::Forward, &mut collector);
+        let state_next = hamiltonian.leapfrog(math, &state, Direction::Forward, &mut collector);

         let LeapfrogResult::Ok(_) = state_next else {
             return Ok(());
@@ -68,7 +64,7 @@ impl Strategy {
         for _ in 0..100 {
             let mut collector = AcceptanceRateCollector::new();
             collector.register_init(math, &state, options);
-            let state_next = hamiltonian.leapfrog(math, &mut pool, &state, dir, &mut collector);
+            let state_next = hamiltonian.leapfrog(math, &state, dir, &mut collector);
             let LeapfrogResult::Ok(_) = state_next else {
                 *hamiltonian.step_size_mut() = self.options.initial_step;
                 return Ok(());
diff --git a/src/transform_adapt_strategy.rs b/src/transform_adapt_strategy.rs
index 3875e04..7c26f64 100644
--- a/src/transform_adapt_strategy.rs
+++ b/src/transform_adapt_strategy.rs
@@ -13,11 +13,11 @@ use crate::{DualAverageSettings, Math, NutsError, Settings};

 #[derive(Clone, Copy, Debug)]
 pub struct TransformedSettings {
-    step_size_window: f64,
-    transform_update_freq: u64,
-    use_orbit_for_training: bool,
-    dual_average_options: DualAverageSettings,
-    transform_train_max_energy_error: f64,
+    pub step_size_window: f64,
+    pub transform_update_freq: u64,
+    pub use_orbit_for_training: bool,
+    pub dual_average_options: DualAverageSettings,
+    pub transform_train_max_energy_error: f64,
 }

 impl Default for TransformedSettings {
@@ -25,7 +25,7 @@ impl Default for TransformedSettings {
         Self {
             step_size_window: 0.1f64,
             transform_update_freq: 50,
-            use_orbit_for_training: false,
+            use_orbit_for_training: true,
             transform_train_max_energy_error: 50f64,
             dual_average_options: Default::default(),
         }
@@ -38,6 +38,7 @@ pub struct TransformAdaptation {
     num_tune: u64,
     final_window_size: u64,
     tuning: bool,
+    chain: u64,
 }

 #[derive(Clone, Copy, Default, Debug)]
@@ -103,11 +104,10 @@ impl<M: Math> Collector<M, TransformedPoint<M>> for DrawCollector<M> {
             if self.collect_orbit {
                 let point = end.point();
                 let energy_error = point.energy_error();
-                if energy_error.abs() > self.max_energy_error {
-                    return;
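+                // Keep the draw and gradient for transformation training only
+                // while the energy error stays within bounds.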
+                if energy_error.abs() < self.max_energy_error {
+                    self.draws.push(math.copy_array(point.position()));
+                    self.grads.push(math.copy_array(point.gradient()));
                 }
-                self.draws.push(math.copy_array(point.position()));
-                self.grads.push(math.copy_array(point.gradient()));
             }
         }

@@ -115,11 +115,10 @@ impl<M: Math> Collector<M, TransformedPoint<M>> for DrawCollector<M> {
             if !self.collect_orbit {
                 let point = state.point();
                 let energy_error = point.energy_error();
-                if energy_error.abs() > self.max_energy_error {
-                    return;
+                if energy_error.abs() < self.max_energy_error {
+                    self.draws.push(math.copy_array(point.position()));
+                    self.grads.push(math.copy_array(point.gradient()));
                 }
-                self.draws.push(math.copy_array(point.position()));
-                self.grads.push(math.copy_array(point.gradient()));
             }
         }
     }
@@ -136,15 +135,17 @@ impl<M: Math> AdaptStrategy<M> for TransformAdaptation {
     type Options = TransformedSettings;

-    fn new(_math: &mut M, options: Self::Options, num_tune: u64) -> Self {
+    fn new(_math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self {
         let step_size = StepSizeStrategy::new(options.dual_average_options);
-        let final_window_size = ((num_tune as f64) * options.step_size_window).floor() as u64;
+        let final_window_size =
+            ((num_tune as f64) * (1f64 - options.step_size_window)).floor() as u64;
         Self {
             step_size,
             options,
             num_tune,
             final_window_size,
-            tuning: false,
+            tuning: true,
+            chain,
         }
     }

@@ -153,12 +154,12 @@ impl<M: Math> AdaptStrategy<M> for TransformAdaptation {
         math: &mut M,
         options: &mut NutsOptions,
         hamiltonian: &mut Self::Hamiltonian,
-        state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
+        position: &[f64],
         rng: &mut R,
     ) -> Result<(), NutsError> {
+        hamiltonian.init_transformation(rng, math, position, self.chain)?;
         self.step_size
-            .init(math, options, hamiltonian, state, rng)?;
-        hamiltonian.init_transformation(math, state.point())?;
+            .init(math, options, hamiltonian, position, rng)?;
         Ok(())
     }

@@ -180,7 +181,16 @@ impl<M: Math> AdaptStrategy<M> for TransformAdaptation {
         }

         if draw < self.final_window_size {
-            if (draw + 1) % self.options.transform_update_freq == 0 {
+            if draw < 100 {
+                if (draw > 0) & (draw % 10 == 0) {
+                    hamiltonian.update_params(
+                        math,
+                        rng,
+                        collector.collector2.draws.iter(),
+                        collector.collector2.grads.iter(),
+                    )?;
+                }
+            } else if (draw > 0) & (draw % self.options.transform_update_freq == 0) {
                 hamiltonian.update_params(
                     math,
                     rng,
                     collector.collector2.draws.iter(),
                     collector.collector2.grads.iter(),
                 )?;
             }
+            self.step_size.update_estimator_early();
+            self.step_size.update_stepsize(hamiltonian, false);
             return Ok(());
         }
diff --git a/src/transformed_hamiltonian.rs b/src/transformed_hamiltonian.rs
index f4bfb2c..f1c1e66 100644
--- a/src/transformed_hamiltonian.rs
+++ b/src/transformed_hamiltonian.rs
@@ -47,16 +47,16 @@ impl<M: Math> TransformedPoint<M> {
     }

     fn update_kinetic_energy(&mut self, math: &mut M) {
-        self.kinetic_energy = math.array_vector_dot(&self.velocity, &self.velocity);
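+        // Kinetic energy of a standard-normal momentum is v·v / 2; the factor
+        // of one half was previously missing here.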
+        self.kinetic_energy = 0.5 * math.array_vector_dot(&self.velocity, &self.velocity);
     }

-    fn update_gradient(
+    fn init_from_untransformed_position(
         &mut self,
         hamiltonian: &TransformedHamiltonian<M>,
         math: &mut M,
     ) -> Result<(), M::LogpErr> {
         let (logp, logdet) = {
-            math.transformed_logp(
+            math.init_from_untransformed_position(
                 hamiltonian.params.as_ref().expect("No transformation set"),
                 &self.untransformed_position,
                 &mut self.untransformed_gradient,
@@ -69,6 +69,25 @@ impl<M: Math> TransformedPoint<M> {
         Ok(())
     }

+    fn init_from_transformed_position(
+        &mut self,
+        hamiltonian: &TransformedHamiltonian<M>,
+        math: &mut M,
+    ) -> Result<(), M::LogpErr> {
+        let (logp, logdet) = {
+            math.init_from_transformed_position(
+                hamiltonian.params.as_ref().expect("No transformation set"),
+                &mut self.untransformed_position,
+                &mut self.untransformed_gradient,
+                &self.transformed_position,
+                &mut self.transformed_gradient,
+            )
+        }?;
+        self.logp = logp;
+        self.logdet = logdet;
+        Ok(())
+    }
+
     fn is_valid(&self, math: &mut M) -> bool {
         if !math.array_all_finite(&self.transformed_position) {
             return false;
@@ -101,11 +120,11 @@ impl<M: Math> Point<M> for TransformedPoint<M> {
     }

     fn energy(&self) -> f64 {
-        self.kinetic_energy - self.logp - self.logdet
+        self.kinetic_energy - (self.logp + self.logdet)
     }

-    fn energy_error(&self) -> f64 {
-        self.energy() - self.initial_energy
+    fn initial_energy(&self) -> f64 {
+        self.initial_energy
     }

     fn logp(&self) -> f64 {
@@ -159,32 +178,47 @@ impl<M: Math> Point<M> for TransformedPoint<M> {

 pub struct TransformedHamiltonian<M: Math> {
     ones: M::Vector,
+    zeros: M::Vector,
     step_size: f64,
     params: Option<M::TransformParams>,
     max_energy_error: f64,
     _phantom: PhantomData<M>,
+    pool: StatePool<M, TransformedPoint<M>>,
 }

 impl<M: Math> TransformedHamiltonian<M> {
     pub fn new(math: &mut M, max_energy_error: f64) -> Self {
         let mut ones = math.new_array();
         math.fill_array(&mut ones, 1f64);
+        let mut zeros = math.new_array();
+        math.fill_array(&mut zeros, 0f64);
+        let pool = StatePool::new(math, 10);
         Self {
             step_size: 0f64,
             ones,
+            zeros,
             params: None,
             max_energy_error,
             _phantom: Default::default(),
+            pool,
         }
     }

-    pub fn init_transformation(
+    pub fn init_transformation<R: rand::Rng + ?Sized>(
         &mut self,
+        rng: &mut R,
         math: &mut M,
-        state: &TransformedPoint<M>,
+        position: &[f64],
+        chain: u64,
     ) -> Result<(), NutsError> {
+        let mut gradient_array = math.new_array();
+        let mut position_array = math.new_array();
+        math.read_from_slice(&mut position_array, position);
+        let _ = math
+            .logp_array(&position_array, &mut gradient_array)
+            .map_err(|_| NutsError::BadInitGrad())?;
         let params = math
-            .new_transformation(state.position(), state.gradient())
+            .new_transformation(rng, &position_array, &gradient_array, chain)
             .map_err(|_| NutsError::BadInitGrad())?;
         self.params = Some(params);
         Ok(())
@@ -194,8 +228,8 @@ impl<M: Math> TransformedHamiltonian<M> {
         &'a mut self,
         math: &'a mut M,
         rng: &mut R,
-        draws: impl Iterator<Item = &'a M::Vector>,
-        grads: impl Iterator<Item = &'a M::Vector>,
+        draws: impl ExactSizeIterator<Item = &'a M::Vector>,
+        grads: impl ExactSizeIterator<Item = &'a M::Vector>,
     ) -> Result<(), NutsError> {
         math.update_transformation(
             rng,
@@ -249,14 +283,16 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
     fn leapfrog<C: Collector<M, Self::Point>>(
         &mut self,
         math: &mut M,
-        pool: &mut StatePool<M, Self::Point>,
         start: &State<M, Self::Point>,
         dir: Direction,
         collector: &mut C,
     ) -> LeapfrogResult<M, Self::Point> {
-        let mut out = pool.new_state(math);
+        let mut out = self.pool().new_state(math);
         let out_point = out.try_point_mut().expect("New point has other references");

+        out_point.initial_energy = start.point().initial_energy();
+        out_point.transform_id = start.point().transform_id;
+
         let sign = match dir {
             Direction::Forward => 1,
             Direction::Backward => -1,
@@ -269,7 +305,7 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
             .first_velocity_halfstep(math, out_point, epsilon);

         start.point().position_step(math, out_point, epsilon);
-        if let Err(logp_error) = out_point.update_gradient(self, math) {
+        if let Err(logp_error) = out_point.init_from_transformed_position(self, math) {
             if !logp_error.is_recoverable() {
                 return LeapfrogResult::Err(logp_error);
             }
@@ -292,7 +328,7 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
         out_point.update_kinetic_energy(math);
         out_point.index_in_trajectory = start.index_in_trajectory() + sign;

-        let energy_error = { out_point.energy() - start.point().initial_energy };
+        let energy_error = out_point.energy_error();
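+        // Same value as before: energy_error() subtracts initial_energy, which
+        // leapfrog copied from `start` into out_point above.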
         if (energy_error > self.max_energy_error) | !energy_error.is_finite() {
             let divergence_info = DivergenceInfo {
                 logp_function_error: None,
@@ -325,52 +361,28 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
             (state2, state1)
         };

-        let a = start.index_in_trajectory();
-        let b = end.index_in_trajectory();
-
-        assert!(a < b);
-        // TODO double check
-        let (turn1, turn2) = if (a >= 0) & (b >= 0) {
-            math.scalar_prods3(
-                &end.point().transformed_position,
-                &start.point().transformed_position,
-                &start.point().velocity,
-                &end.point().velocity,
-                &start.point().velocity,
-            )
-        } else if (b >= 0) & (a < 0) {
-            math.scalar_prods2(
-                &end.point().transformed_position,
-                &start.point().transformed_position,
-                &end.point().velocity,
-                &start.point().velocity,
-            )
-        } else {
-            assert!((a < 0) & (b < 0));
-            math.scalar_prods3(
-                &start.point().transformed_position,
-                &end.point().transformed_position,
-                &end.point().velocity,
-                &end.point().velocity,
-                &start.point().velocity,
-            )
-        };
+        let (turn1, turn2) = math.scalar_prods3(
+            &end.point().transformed_position,
+            &start.point().transformed_position,
+            &self.zeros,
+            &start.point().velocity,
+            &end.point().velocity,
+        );

-        (turn1 < 0.) | (turn2 < 0.)
+        (turn1 < 0f64) | (turn2 < 0f64)
     }

     fn init_state(
         &mut self,
         math: &mut M,
-        pool: &mut StatePool<M, Self::Point>,
         init: &[f64],
     ) -> Result<State<M, Self::Point>, NutsError> {
-        let mut state = pool.new_state(math);
+        let mut state = self.pool().new_state(math);
         let point = state.try_point_mut().expect("State already in use");
         math.read_from_slice(&mut point.untransformed_position, init);

         point
-            .update_gradient(self, math)
+            .init_from_untransformed_position(self, math)
             .map_err(|e| NutsError::LogpFailure(Box::new(e)))?;

         if !point.is_valid(math) {
@@ -388,9 +400,10 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
     ) -> Result<(), NutsError> {
         let point = state.try_point_mut().expect("State has other references");
         math.array_gaussian(rng, &mut point.velocity, &self.ones);
-        if math.transformation_id(self.params.as_ref().expect("No transformation set"))
-            != point.transform_id
-        {
+        let current_transform_id = math
+            .transformation_id(self.params.as_ref().expect("No transformation set"))
+            .map_err(|e| NutsError::LogpFailure(Box::new(e)))?;
+        if current_transform_id != point.transform_id {
             let logdet = math
                 .inv_transform_normalize(
                     self.params.as_ref().expect("No transformation set"),
@@ -401,19 +414,20 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
                 )
                 .map_err(|e| NutsError::LogpFailure(Box::new(e)))?;
             point.logdet = logdet;
+            point.transform_id = current_transform_id;
         }
         point.update_kinetic_energy(math);
         point.index_in_trajectory = 0;
+        point.initial_energy = point.energy();
         Ok(())
     }

-    fn copy_state(
-        &self,
-        math: &mut M,
-        pool: &mut StatePool<M, Self::Point>,
-        state: &State<M, Self::Point>,
-    ) -> State<M, Self::Point> {
-        let mut new_state = pool.new_state(math);
+    fn pool(&mut self) -> &mut StatePool<M, Self::Point> {
+        &mut self.pool
+    }
+
+    fn copy_state(&mut self, math: &mut M, state: &State<M, Self::Point>) -> State<M, Self::Point> {
+        let mut new_state = self.pool.new_state(math);
         state.point().copy_into(
             math,
             new_state
diff --git a/tests/sample_normal.rs b/tests/sample_normal.rs
index 07d773a..0a55aaa 100644
--- a/tests/sample_normal.rs
+++ b/tests/sample_normal.rs
@@ -57,45 +57,57 @@ impl<'a> CpuLogpFunc for NormalLogp<'a> {

     fn inv_transform_normalize(
         &mut self,
-        params: &Self::TransformParams,
-        untransformed_position: &[f64],
-        untransofrmed_gradient: &[f64],
-        transformed_position: &mut [f64],
-        transformed_gradient: &mut [f64],
+        _params: &Self::TransformParams,
+        _untransformed_position: &[f64],
+        _untransformed_gradient: &[f64],
+        _transformed_position: &mut [f64],
+        _transformed_gradient: &mut [f64],
     ) -> Result<f64, Self::LogpError> {
         todo!()
     }

-    fn transformed_logp(
+    fn init_from_transformed_position(
         &mut self,
-        params: &Self::TransformParams,
-        untransformed_position: &[f64],
-        untransformed_gradient: &mut [f64],
-        transformed_position: &mut [f64],
-        transformed_gradient: &mut [f64],
+        _params: &Self::TransformParams,
+        _untransformed_position: &mut [f64],
+        _untransformed_gradient: &mut [f64],
+        _transformed_position: &[f64],
+        _transformed_gradient: &mut [f64],
+    ) -> Result<(f64, f64), Self::LogpError> {
+        todo!()
+    }
+
+    fn init_from_untransformed_position(
+        &mut self,
+        _params: &Self::TransformParams,
+        _untransformed_position: &[f64],
+        _untransformed_gradient: &mut [f64],
+        _transformed_position: &mut [f64],
+        _transformed_gradient: &mut [f64],
     ) -> Result<(f64, f64), Self::LogpError> {
         todo!()
     }

     fn update_transformation<'b, R: rand::Rng + ?Sized>(
         &'b mut self,
-        rng: &mut R,
-        untransformed_positions: impl Iterator<Item = &'b [f64]>,
-        untransformed_gradients: impl Iterator<Item = &'b [f64]>,
-        params: &'b mut Self::TransformParams,
+        _rng: &mut R,
+        _untransformed_positions: impl ExactSizeIterator<Item = &'b [f64]>,
+        _untransformed_gradients: impl ExactSizeIterator<Item = &'b [f64]>,
+        _params: &'b mut Self::TransformParams,
     ) -> Result<(), Self::LogpError> {
         todo!()
     }

-    fn new_transformation(
+    fn new_transformation<R: rand::Rng + ?Sized>(
         &mut self,
-        untransformed_position: &[f64],
-        untransfogmed_gradient: &[f64],
+        _rng: &mut R,
+        _untransformed_position: &[f64],
+        _untransformed_gradient: &[f64],
+        _chain: u64,
     ) -> Result<Self::TransformParams, Self::LogpError> {
         todo!()
     }

-    fn transformation_id(&self, params: &Self::TransformParams) -> i64 {
+    fn transformation_id(&self, _params: &Self::TransformParams) -> Result<i64, Self::LogpError> {
         todo!()
     }
 }