diff --git a/Cargo.toml b/Cargo.toml index aa72132..e97786b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nuts-rs" -version = "0.12.0" +version = "0.12.1" authors = [ "Adrian Seyboldt ", "PyMC Developers ", diff --git a/src/adapt_strategy.rs b/src/adapt_strategy.rs index 46f256d..655e899 100644 --- a/src/adapt_strategy.rs +++ b/src/adapt_strategy.rs @@ -21,7 +21,7 @@ use crate::{ use crate::nuts::{SamplerStats, StatTraceBuilder}; pub struct GlobalStrategy> { - step_size: StepSizeStrategy, + step_size: StepSizeStrategy, mass_matrix: A, options: AdaptOptions, num_tune: u64, @@ -73,7 +73,7 @@ impl> SamplerStats for GlobalStrategy< fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder { CombinedStatsBuilder { - stats1: self.step_size.new_builder(settings, dim), + stats1: SamplerStats::::new_builder(&self.step_size, settings, dim), stats2: self.mass_matrix.new_builder(settings, dim), } } @@ -87,7 +87,7 @@ impl> AdaptStats for GlobalStrategy> AdaptStrategy for GlobalStrategy { type Potential = A::Potential; - type Collector = CombinedCollector, A::Collector>; + type Collector = CombinedCollector; type Options = AdaptOptions; fn new(math: &mut M, options: Self::Options, num_tune: u64) -> Self { @@ -99,7 +99,7 @@ impl> AdaptStrategy for GlobalStrategy assert!(early_end < num_tune); Self { - step_size: StepSizeStrategy::new(math, options.dual_average_options, num_tune), + step_size: StepSizeStrategy::new(options.dual_average_options), mass_matrix: A::new(math, options.mass_matrix_options, num_tune), options, num_tune, @@ -121,7 +121,6 @@ impl> AdaptStrategy for GlobalStrategy ) { self.mass_matrix.init(math, options, potential, state, rng); self.step_size.init(math, options, potential, state, rng); - self.step_size.enable(); } fn adapt( @@ -134,6 +133,8 @@ impl> AdaptStrategy for GlobalStrategy state: &State, rng: &mut R, ) { + self.step_size.update(&collector.collector1); + if draw >= self.num_tune { self.tuning = false; return; @@ -172,44 +173,31 @@ impl> AdaptStrategy for GlobalStrategy if did_change { self.last_update = draw; } + if is_late { - self.step_size.use_mean_sym(); + self.step_size.update_estimator_late(); + } else { + self.step_size.update_estimator_early(); } + // First time we change the mass matrix if did_change & self.has_initial_mass_matrix { self.has_initial_mass_matrix = false; self.step_size.init(math, options, potential, state, rng); } else { - self.step_size.adapt( - math, - options, - potential, - draw, - &collector.collector1, - state, - rng, - ); + self.step_size.update_stepsize(potential, false) } return; } - if draw == self.num_tune - 1 { - self.step_size.finalize(); - } - self.step_size.adapt( - math, - options, - potential, - draw, - &collector.collector1, - state, - rng, - ); + self.step_size.update_estimator_late(); + let is_last = draw == self.num_tune - 1; + self.step_size.update_stepsize(potential, is_last); } fn new_collector(&self, math: &mut M) -> Self::Collector { CombinedCollector { - collector1: self.step_size.new_collector(math), + collector1: self.step_size.new_collector(), collector2: self.mass_matrix.new_collector(math), _phantom: PhantomData, } diff --git a/src/stepsize.rs b/src/stepsize.rs index 06a2657..122737d 100644 --- a/src/stepsize.rs +++ b/src/stepsize.rs @@ -1,5 +1,3 @@ -use std::marker::PhantomData; - use crate::{ math_base::Math, nuts::{Collector, NutsOptions}, @@ -103,25 +101,23 @@ impl RunningMean { } } -pub struct AcceptanceRateCollector { +pub struct AcceptanceRateCollector { initial_energy: f64, pub(crate) mean: RunningMean, pub(crate) mean_sym: RunningMean, - phantom: PhantomData, } -impl AcceptanceRateCollector { - pub(crate) fn new() -> AcceptanceRateCollector { +impl AcceptanceRateCollector { + pub(crate) fn new() -> AcceptanceRateCollector { AcceptanceRateCollector { initial_energy: 0., mean: RunningMean::new(), mean_sym: RunningMean::new(), - phantom: PhantomData, } } } -impl Collector for AcceptanceRateCollector { +impl Collector for AcceptanceRateCollector { fn register_leapfrog( &mut self, _math: &mut M, diff --git a/src/stepsize_adapt.rs b/src/stepsize_adapt.rs index 5a209ea..114ff69 100644 --- a/src/stepsize_adapt.rs +++ b/src/stepsize_adapt.rs @@ -1,5 +1,3 @@ -use std::marker::PhantomData; - use arrow::{ array::{ArrayBuilder, PrimitiveBuilder, StructArray}, datatypes::{DataType, Field, Float64Type, UInt64Type}, @@ -7,40 +5,143 @@ use arrow::{ use rand::Rng; use crate::{ - mass_matrix_adapt::MassMatrixAdaptStrategy, nuts::{ - AdaptStats, AdaptStrategy, Collector, Direction, Hamiltonian, NutsOptions, SamplerStats, - StatTraceBuilder, + AdaptStats, Collector, Direction, Hamiltonian, NutsOptions, SamplerStats, StatTraceBuilder, }, state::State, stepsize::{AcceptanceRateCollector, DualAverage, DualAverageOptions}, Math, Settings, }; -pub struct Strategy { +pub struct Strategy { step_size_adapt: DualAverage, options: DualAverageSettings, - enabled: bool, - use_mean_sym: bool, - finalized: bool, last_mean_tree_accept: f64, last_sym_mean_tree_accept: f64, last_n_steps: u64, - _phantom1: PhantomData, - _phantom2: PhantomData, } -impl Strategy { - pub fn enable(&mut self) { - self.enabled = true; +impl Strategy { + pub fn new(options: DualAverageSettings) -> Self { + Self { + options, + step_size_adapt: DualAverage::new(options.params, options.initial_step), + last_n_steps: 0, + last_sym_mean_tree_accept: 0.0, + last_mean_tree_accept: 0.0, + } + } + + pub fn init( + &mut self, + math: &mut M, + options: &mut NutsOptions, + potential: &mut impl Hamiltonian, + state: &State, + rng: &mut R, + ) { + let mut pool = potential.new_pool(math, 1); + + let mut state = potential.copy_state(math, &mut pool, state); + state + .try_mut_inner() + .expect("New state should have only one reference") + .idx_in_trajectory = 0; + potential.randomize_momentum(math, &mut state, rng); + + let mut collector = AcceptanceRateCollector::new(); + + collector.register_init(math, &state, options); + + *potential.stepsize_mut() = self.options.initial_step; + + let state_next = potential.leapfrog( + math, + &mut pool, + &state, + Direction::Forward, + state.energy(), + &mut collector, + ); + + let Ok(_) = state_next else { + return; + }; + + let accept_stat = collector.mean.current(); + let dir = if accept_stat > self.options.target_accept { + Direction::Forward + } else { + Direction::Backward + }; + + for _ in 0..100 { + let mut collector = AcceptanceRateCollector::new(); + collector.register_init(math, &state, options); + let state_next = + potential.leapfrog(math, &mut pool, &state, dir, state.energy(), &mut collector); + let Ok(_) = state_next else { + *potential.stepsize_mut() = self.options.initial_step; + return; + }; + let accept_stat = collector.mean.current(); + match dir { + Direction::Forward => { + if (accept_stat <= self.options.target_accept) | (potential.stepsize() > 1e5) { + self.step_size_adapt = + DualAverage::new(self.options.params, potential.stepsize()); + return; + } + *potential.stepsize_mut() *= 2.; + } + Direction::Backward => { + if (accept_stat >= self.options.target_accept) | (potential.stepsize() < 1e-10) + { + self.step_size_adapt = + DualAverage::new(self.options.params, potential.stepsize()); + return; + } + *potential.stepsize_mut() /= 2.; + } + } + } + // If we don't find something better, use the specified initial value + *potential.stepsize_mut() = self.options.initial_step; + } + + pub fn update(&mut self, collector: &AcceptanceRateCollector) { + let mean_sym = collector.mean_sym.current(); + let mean = collector.mean.current(); + let n_steps = collector.mean.count(); + self.last_mean_tree_accept = mean; + self.last_sym_mean_tree_accept = mean_sym; + self.last_n_steps = n_steps; + } + + pub fn update_estimator_early(&mut self) { + self.step_size_adapt + .advance(self.last_mean_tree_accept, self.options.target_accept); } - pub fn finalize(&mut self) { - self.finalized = true; + pub fn update_estimator_late(&mut self) { + self.step_size_adapt + .advance(self.last_sym_mean_tree_accept, self.options.target_accept); } - pub fn use_mean_sym(&mut self) { - self.use_mean_sym = true; + pub fn update_stepsize( + &mut self, + potential: &mut impl Hamiltonian, + use_best_guess: bool, + ) { + if use_best_guess { + *potential.stepsize_mut() = self.step_size_adapt.current_step_size_adapted(); + } else { + *potential.stepsize_mut() = self.step_size_adapt.current_step_size(); + } + } + + pub fn new_collector(&self) -> AcceptanceRateCollector { + AcceptanceRateCollector::new() } } @@ -119,7 +220,7 @@ impl StatTraceBuilder for StatsBuilder { } } -impl> SamplerStats for Strategy { +impl SamplerStats for Strategy { type Builder = StatsBuilder; type Stats = Stats; @@ -142,7 +243,7 @@ impl> SamplerStats for Strategy> AdaptStats for Strategy { +impl AdaptStats for Strategy { fn num_grad_evals(stats: &Self::Stats) -> usize { stats.n_steps as usize } @@ -164,141 +265,3 @@ impl Default for DualAverageSettings { } } } - -impl> AdaptStrategy for Strategy { - type Potential = Mass::Potential; - type Collector = AcceptanceRateCollector; - type Options = DualAverageSettings; - - fn new(_math: &mut M, options: Self::Options, _num_tune: u64) -> Self { - Self { - options, - enabled: true, - step_size_adapt: DualAverage::new(options.params, options.initial_step), - finalized: false, - use_mean_sym: false, - last_n_steps: 0, - last_sym_mean_tree_accept: 0.0, - last_mean_tree_accept: 0.0, - _phantom1: PhantomData, - _phantom2: PhantomData, - } - } - - fn init( - &mut self, - math: &mut M, - options: &mut NutsOptions, - potential: &mut Self::Potential, - state: &State, - rng: &mut R, - ) { - let mut pool = potential.new_pool(math, 1); - - let mut state = potential.copy_state(math, &mut pool, state); - state - .try_mut_inner() - .expect("New state should have only one reference") - .idx_in_trajectory = 0; - potential.randomize_momentum(math, &mut state, rng); - - let mut collector = AcceptanceRateCollector::new(); - - collector.register_init(math, &state, options); - - *potential.stepsize_mut() = self.options.initial_step; - - let state_next = potential.leapfrog( - math, - &mut pool, - &state, - Direction::Forward, - state.energy(), - &mut collector, - ); - - let Ok(_) = state_next else { - return; - }; - - let accept_stat = collector.mean.current(); - let dir = if accept_stat > self.options.target_accept { - Direction::Forward - } else { - Direction::Backward - }; - - for _ in 0..100 { - let mut collector = AcceptanceRateCollector::new(); - collector.register_init(math, &state, options); - let state_next = - potential.leapfrog(math, &mut pool, &state, dir, state.energy(), &mut collector); - let Ok(_) = state_next else { - *potential.stepsize_mut() = self.options.initial_step; - return; - }; - let accept_stat = collector.mean.current(); - match dir { - Direction::Forward => { - if (accept_stat <= self.options.target_accept) | (potential.stepsize() > 1e5) { - self.step_size_adapt = - DualAverage::new(self.options.params, potential.stepsize()); - return; - } - *potential.stepsize_mut() *= 2.; - } - Direction::Backward => { - if (accept_stat >= self.options.target_accept) | (potential.stepsize() < 1e-10) - { - self.step_size_adapt = - DualAverage::new(self.options.params, potential.stepsize()); - return; - } - *potential.stepsize_mut() /= 2.; - } - } - } - // If we don't find something better, use the specified initial value - *potential.stepsize_mut() = self.options.initial_step; - } - - fn adapt( - &mut self, - _math: &mut M, - _options: &mut NutsOptions, - potential: &mut Self::Potential, - _draw: u64, - collector: &Self::Collector, - _state: &State, - _rng: &mut R, - ) { - let mean_sym = collector.mean_sym.current(); - let mean = collector.mean.current(); - let n_steps = collector.mean.count(); - self.last_mean_tree_accept = mean; - self.last_sym_mean_tree_accept = mean_sym; - self.last_n_steps = n_steps; - - let current = if self.use_mean_sym { mean_sym } else { mean }; - if self.finalized { - self.step_size_adapt - .advance(current, self.options.target_accept); - *potential.stepsize_mut() = self.step_size_adapt.current_step_size_adapted(); - return; - } - if !self.enabled { - return; - } - self.step_size_adapt - .advance(current, self.options.target_accept); - *potential.stepsize_mut() = self.step_size_adapt.current_step_size() - } - - fn new_collector(&self, _math: &mut M) -> Self::Collector { - AcceptanceRateCollector::new() - } - - fn is_tuning(&self) -> bool { - self.enabled - } -}