refactor: Refactor mass matrix adaptation traits
aseyboldt committed Oct 17, 2024
1 parent 1752c63 commit fb63fb1
Showing 8 changed files with 132 additions and 148 deletions.
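
In short, mass-matrix adaptation is pulled out of the general `AdaptStrategy` trait into a self-contained `MassMatrixAdaptStrategy` trait that works on a mass matrix and a point directly, instead of reaching through a full `EuclideanHamiltonian`. As a reading aid, the sketch below condenses the new trait surface from the visible `src/mass_matrix_adapt.rs` hunks further down; imports are trimmed to what the trait itself needs, and members hidden in collapsed diff context (the low-rank implementation also calls `current_count`, for example) are omitted.

```rust
use rand::Rng;

use crate::{
    euclidean_hamiltonian::EuclideanPoint,
    hamiltonian::Point,
    mass_matrix::MassMatrix,
    nuts::{Collector, NutsOptions},
    sampler_stats::SamplerStats,
    Math, NutsError,
};

// Condensed from the hunks below; the trait now bounds on SamplerStats<M>
// rather than extending AdaptStrategy<M>.
pub trait MassMatrixAdaptStrategy<M: Math>: SamplerStats<M> {
    type MassMatrix: MassMatrix<M>;
    type Collector: Collector<M, EuclideanPoint<M>>;
    type Options: std::fmt::Debug + Default + Clone + Send + Sync + Copy;

    fn new(math: &mut M, options: Self::Options, _num_tune: u64, _chain: u64) -> Self;

    fn new_collector(&self, math: &mut M) -> Self::Collector;

    // Initialization receives the mass matrix and the first point directly,
    // no longer the whole Hamiltonian.
    fn init<R: Rng + ?Sized>(
        &mut self,
        math: &mut M,
        _options: &mut NutsOptions,
        mass_matrix: &mut Self::MassMatrix,
        point: &impl Point<M>,
        _rng: &mut R,
    ) -> Result<(), NutsError>;

    fn update_estimators(&mut self, math: &mut M, collector: &Self::Collector);

    fn background_count(&self) -> u64;

    // Formerly `update_potential(&self, math, potential) -> bool`; it now
    // mutates the mass matrix it is handed and reports whether it changed.
    fn adapt(&self, math: &mut M, mass_matrix: &mut Self::MassMatrix) -> bool;

    // Members that sit in collapsed diff context (e.g. `current_count`) are
    // not reproduced here.
}
```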
22 changes: 15 additions & 7 deletions src/adapt_strategy.rs
@@ -6,6 +6,7 @@ use rand::Rng;

use crate::{
chain::AdaptStrategy,
euclidean_hamiltonian::EuclideanHamiltonian,
hamiltonian::{DivergenceInfo, Hamiltonian, Point},
mass_matrix_adapt::MassMatrixAdaptStrategy,
math_base::Math,
@@ -81,7 +82,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> SamplerStats<M> for GlobalStrategy<
}

impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy<M, A> {
type Hamiltonian = A::Hamiltonian;
type Hamiltonian = EuclideanHamiltonian<M, A::MassMatrix>;
type Collector = CombinedCollector<
M,
<Self::Hamiltonian as Hamiltonian<M>>::Point,
@@ -119,8 +120,14 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
position: &[f64],
rng: &mut R,
) -> Result<(), NutsError> {
self.mass_matrix
.init(math, options, hamiltonian, position, rng)?;
let state = hamiltonian.init_state(math, position)?;
self.mass_matrix.init(
math,
options,
&mut hamiltonian.mass_matrix,
state.point(),
rng,
)?;
self.step_size
.init(math, options, hamiltonian, position, rng)?;
Ok(())
@@ -168,7 +175,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
let did_change = if force_update
| (draw - self.last_update >= self.options.mass_matrix_update_freq)
{
self.mass_matrix.update_potential(math, hamiltonian)
self.mass_matrix.adapt(math, &mut hamiltonian.mass_matrix)
} else {
false
};
@@ -221,8 +228,8 @@ pub struct CombinedStats<D1, D2> {

#[derive(Clone)]
pub struct CombinedStatsBuilder<B1, B2> {
stats1: B1,
stats2: B2,
pub stats1: B1,
pub stats2: B2,
}

impl<S1, S2, B1, B2> StatTraceBuilder<CombinedStats<S1, S2>> for CombinedStatsBuilder<B1, B2>
@@ -441,6 +448,7 @@ pub mod test_logps {
_rng: &mut R,
_untransformed_position: &[f64],
_untransfogmed_gradient: &[f64],
_chain: u64,
) -> Result<Self::TransformParams, Self::LogpError> {
unimplemented!()
}
@@ -472,7 +480,7 @@ mod test {
let mut math = CpuMath::new(func);
let num_tune = 100;
let options = EuclideanAdaptOptions::<DiagAdaptExpSettings>::default();
let strategy = GlobalStrategy::<_, Strategy<_>>::new(&mut math, options, num_tune);
let strategy = GlobalStrategy::<_, Strategy<_>>::new(&mut math, options, num_tune, 0u64);

let mass_matrix = DiagMassMatrix::new(&mut math, true);
let max_energy_error = 1000f64;
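The `GlobalStrategy` hunks above show the caller side of this split: the euclidean strategy builds a state itself and borrows only `hamiltonian.mass_matrix`, so the adaptation strategy never sees the rest of the Hamiltonian. The toy program below uses simplified stand-in types rather than the real nuts-rs ones to illustrate that borrowing pattern; the field names and update rules are illustrative only.

```rust
// Self-contained toy of the call pattern adopted above: the driver hands the
// strategy `&mut hamiltonian.mass_matrix` plus the current point, never the
// whole Hamiltonian. All types here are simplified stand-ins, not nuts-rs API.

struct DiagMassMatrix {
    inv_diag: Vec<f64>,
}

struct Hamiltonian {
    mass_matrix: DiagMassMatrix,
    // ... step size, logp function, etc. stay private to the Hamiltonian.
}

struct Point {
    gradient: Vec<f64>,
}

struct MassMatrixStrategy {
    draw_variance: Vec<f64>,
    draw_count: u64,
}

impl MassMatrixStrategy {
    // Mirrors `init`: seed the mass matrix from the first gradient.
    fn init(&mut self, mass_matrix: &mut DiagMassMatrix, point: &Point) {
        for (m, g) in mass_matrix.inv_diag.iter_mut().zip(&point.gradient) {
            *m = g.abs().max(1e-10).recip();
        }
        self.draw_count = 1;
    }

    // Mirrors `adapt`: update the mass matrix and report whether it changed.
    fn adapt(&self, mass_matrix: &mut DiagMassMatrix) -> bool {
        if self.draw_count < 3 {
            return false;
        }
        mass_matrix.inv_diag.copy_from_slice(&self.draw_variance);
        true
    }
}

fn main() {
    let dim = 2;
    let mut hamiltonian = Hamiltonian {
        mass_matrix: DiagMassMatrix { inv_diag: vec![1.0; dim] },
    };
    let mut strategy = MassMatrixStrategy {
        draw_variance: vec![1.0; dim],
        draw_count: 0,
    };
    let point = Point { gradient: vec![0.5, -2.0] };

    // Only the mass-matrix field is borrowed mutably; the strategy stays
    // decoupled from the concrete Hamiltonian type.
    strategy.init(&mut hamiltonian.mass_matrix, &point);
    let changed = strategy.adapt(&mut hamiltonian.mass_matrix);
    println!("changed: {changed}, inv_diag: {:?}", hamiltonian.mass_matrix.inv_diag);
}
```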
55 changes: 13 additions & 42 deletions src/low_rank_mass_matrix.rs
@@ -8,13 +8,11 @@ use faer::{Col, Mat, Scale};
use itertools::Itertools;

use crate::{
chain::AdaptStrategy,
euclidean_hamiltonian::{EuclideanHamiltonian, EuclideanPoint},
hamiltonian::{Hamiltonian, Point},
euclidean_hamiltonian::EuclideanPoint,
hamiltonian::Point,
mass_matrix::{DrawGradCollector, MassMatrix},
mass_matrix_adapt::MassMatrixAdaptStrategy,
sampler_stats::{SamplerStats, StatTraceBuilder},
state::State,
Math, NutsError,
};

@@ -392,12 +390,12 @@ impl LowRankMassMatrixStrategy {
}
}

pub fn add_draw<M: Math>(&mut self, math: &mut M, state: &State<M, EuclideanPoint<M>>) {
pub fn add_draw<M: Math>(&mut self, math: &mut M, point: &impl Point<M>) {
assert!(math.dim() == self.ndim);
let mut draw = vec![0f64; self.ndim];
state.write_position(math, &mut draw);
math.write_to_slice(point.position(), &mut draw);
let mut grad = vec![0f64; self.ndim];
state.write_gradient(math, &mut grad);
math.write_to_slice(point.gradient(), &mut grad);

self.draws.push_back(draw);
self.grads.push_back(grad);
@@ -569,8 +567,8 @@ impl<M: Math> SamplerStats<M> for LowRankMassMatrixStrategy {
}
}

impl<M: Math> AdaptStrategy<M> for LowRankMassMatrixStrategy {
type Hamiltonian = EuclideanHamiltonian<M, LowRankMassMatrix<M>>;
impl<M: Math> MassMatrixAdaptStrategy<M> for LowRankMassMatrixStrategy {
type MassMatrix = LowRankMassMatrix<M>;
type Collector = DrawGradCollector<M>;
type Options = LowRankSettings;

@@ -582,46 +580,19 @@ impl<M: Math> AdaptStrategy<M> for LowRankMassMatrixStrategy {
&mut self,
math: &mut M,
_options: &mut crate::nuts::NutsOptions,
hamiltonian: &mut Self::Hamiltonian,
position: &[f64],
_rng: &mut R,
) -> Result<(), NutsError> {
let state = hamiltonian.init_state(math, position)?;
self.add_draw(math, &state);
hamiltonian.mass_matrix.update_from_grad(
math,
state.point().gradient(),
1f64,
(1e-20, 1e20),
);
Ok(())
}

fn adapt<R: rand::Rng + ?Sized>(
&mut self,
_math: &mut M,
_options: &mut crate::nuts::NutsOptions,
_potential: &mut Self::Hamiltonian,
_draw: u64,
_collector: &Self::Collector,
_state: &State<M, EuclideanPoint<M>>,
mass_matrix: &mut Self::MassMatrix,
point: &impl Point<M>,
_rng: &mut R,
) -> Result<(), NutsError> {
self.add_draw(math, point);
mass_matrix.update_from_grad(math, point.gradient(), 1f64, (1e-20, 1e20));
Ok(())
}

fn new_collector(&self, math: &mut M) -> Self::Collector {
DrawGradCollector::new(math)
}

fn is_tuning(&self) -> bool {
unreachable!()
}
}

impl<M: Math> MassMatrixAdaptStrategy<M> for LowRankMassMatrixStrategy {
type MassMatrix = LowRankMassMatrix<M>;

fn update_estimators(&mut self, math: &mut M, collector: &Self::Collector) {
if collector.is_good {
let mut draw = vec![0f64; self.ndim];
@@ -651,11 +622,11 @@ impl<M: Math> MassMatrixAdaptStrategy<M> for LowRankMassMatrixStrategy {
self.draws.len().checked_sub(self.background_split).unwrap() as u64
}

fn update_potential(&self, math: &mut M, potential: &mut Self::Hamiltonian) -> bool {
fn adapt(&self, math: &mut M, mass_matrix: &mut Self::MassMatrix) -> bool {
if <LowRankMassMatrixStrategy as MassMatrixAdaptStrategy<M>>::current_count(self) < 3 {
return false;
}
self.update(math, &mut potential.mass_matrix);
self.update(math, mass_matrix);

true
}
115 changes: 45 additions & 70 deletions src/mass_matrix_adapt.rs
@@ -3,13 +3,11 @@ use std::marker::PhantomData;
use rand::Rng;

use crate::{
chain::AdaptStrategy,
euclidean_hamiltonian::{EuclideanHamiltonian, EuclideanPoint},
hamiltonian::{Hamiltonian, Point},
euclidean_hamiltonian::EuclideanPoint,
hamiltonian::Point,
mass_matrix::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance},
nuts::NutsOptions,
nuts::{Collector, NutsOptions},
sampler_stats::SamplerStats,
state::State,
Math, NutsError, Settings,
};
const LOWER_LIMIT: f64 = 1e-20f64;
@@ -43,8 +41,10 @@ pub struct Strategy<M: Math> {
_phantom: PhantomData<M>,
}

pub trait MassMatrixAdaptStrategy<M: Math>: AdaptStrategy<M> {
pub trait MassMatrixAdaptStrategy<M: Math>: SamplerStats<M> {
type MassMatrix: MassMatrix<M>;
type Collector: Collector<M, EuclideanPoint<M>>;
type Options: std::fmt::Debug + Default + Clone + Send + Sync + Copy;

fn update_estimators(&mut self, math: &mut M, collector: &Self::Collector);

@@ -55,11 +55,26 @@ pub trait MassMatrixAdaptStrategy<M: Math>: AdaptStrategy<M> {
fn background_count(&self) -> u64;

/// Give the opportunity to update the potential and return if it was changed
fn update_potential(&self, math: &mut M, potential: &mut Self::Hamiltonian) -> bool;
fn adapt(&self, math: &mut M, mass_matrix: &mut Self::MassMatrix) -> bool;

fn new(math: &mut M, options: Self::Options, _num_tune: u64, _chain: u64) -> Self;

fn init<R: Rng + ?Sized>(
&mut self,
math: &mut M,
_options: &mut NutsOptions,
mass_matrix: &mut Self::MassMatrix,
point: &impl Point<M>,
_rng: &mut R,
) -> Result<(), NutsError>;

fn new_collector(&self, math: &mut M) -> Self::Collector;
}

impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
type MassMatrix = DiagMassMatrix<M>;
type Collector = DrawGradCollector<M>;
type Options = DiagAdaptExpSettings;

fn update_estimators(&mut self, math: &mut M, collector: &DrawGradCollector<M>) {
if collector.is_good {
@@ -88,11 +103,7 @@ impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
}

/// Give the opportunity to update the potential and return if it was changed
fn update_potential(
&self,
math: &mut M,
potential: &mut EuclideanHamiltonian<M, Self::MassMatrix>,
) -> bool {
fn adapt(&self, math: &mut M, mass_matrix: &mut DiagMassMatrix<M>) -> bool {
if self.current_count() < 3 {
return false;
}
Expand All @@ -102,7 +113,7 @@ impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
assert!(draw_scale == grad_scale);

if self._settings.use_grad_based_estimate {
potential.mass_matrix.update_diag_draw_grad(
mass_matrix.update_diag_draw_grad(
math,
draw_var,
grad_var,
@@ -111,35 +122,11 @@ impl<M: Math> MassMatrixAdaptStrategy<M> for Strategy<M> {
);
} else {
let scale = (self.exp_variance_draw.count() as f64).recip();
potential.mass_matrix.update_diag_draw(
math,
draw_var,
scale,
None,
(LOWER_LIMIT, UPPER_LIMIT),
);
mass_matrix.update_diag_draw(math, draw_var, scale, None, (LOWER_LIMIT, UPPER_LIMIT));
}

true
}
}

pub type Stats = ();
pub type StatsBuilder = ();

impl<M: Math> SamplerStats<M> for Strategy<M> {
type Builder = Stats;
type Stats = StatsBuilder;

fn new_builder(&self, _settings: &impl Settings, _dim: usize) -> Self::Builder {}

fn current_stats(&self, _math: &mut M) -> Self::Stats {}
}

impl<M: Math> AdaptStrategy<M> for Strategy<M> {
type Hamiltonian = EuclideanHamiltonian<M, DiagMassMatrix<M>>;
type Collector = DrawGradCollector<M>;
type Options = DiagAdaptExpSettings;

fn new(math: &mut M, options: Self::Options, _num_tune: u64, _chain: u64) -> Self {
Self {
@@ -156,49 +143,37 @@ impl<M: Math> AdaptStrategy<M> for Strategy<M> {
&mut self,
math: &mut M,
_options: &mut NutsOptions,
hamiltonian: &mut Self::Hamiltonian,
position: &[f64],
mass_matrix: &mut Self::MassMatrix,
point: &impl Point<M>,
_rng: &mut R,
) -> Result<(), NutsError> {
let state = hamiltonian.init_state(math, position)?;

self.exp_variance_draw
.add_sample(math, state.point().position());
self.exp_variance_draw_bg
.add_sample(math, state.point().position());
self.exp_variance_grad
.add_sample(math, state.point().gradient());
self.exp_variance_grad_bg
.add_sample(math, state.point().gradient());

hamiltonian.mass_matrix.update_diag_grad(
self.exp_variance_draw.add_sample(math, point.position());
self.exp_variance_draw_bg.add_sample(math, point.position());
self.exp_variance_grad.add_sample(math, point.gradient());
self.exp_variance_grad_bg.add_sample(math, point.gradient());

mass_matrix.update_diag_grad(
math,
state.point().gradient(),
point.gradient(),
1f64,
(INIT_LOWER_LIMIT, INIT_UPPER_LIMIT),
);
Ok(())
}

fn adapt<R: Rng + ?Sized>(
&mut self,
_math: &mut M,
_options: &mut NutsOptions,
_potential: &mut Self::Hamiltonian,
_draw: u64,
_collector: &Self::Collector,
_state: &State<M, EuclideanPoint<M>>,
_rng: &mut R,
) -> Result<(), NutsError> {
// Must be controlled from a different meta strategy
Ok(())
}

fn new_collector(&self, math: &mut M) -> Self::Collector {
DrawGradCollector::new(math)
}
}

fn is_tuning(&self) -> bool {
unreachable!()
}
pub type Stats = ();
pub type StatsBuilder = ();

impl<M: Math> SamplerStats<M> for Strategy<M> {
type Builder = Stats;
type Stats = StatsBuilder;

fn new_builder(&self, _settings: &impl Settings, _dim: usize) -> Self::Builder {}

fn current_stats(&self, _math: &mut M) -> Self::Stats {}
}
1 change: 0 additions & 1 deletion src/nuts.rs
@@ -235,7 +235,6 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
LeapfrogResult::Ok(end) => end,
};

// TODO sign?
let log_size = -end.point().energy_error();
Ok(Ok(NutsTree {
right: end.clone(),
6 changes: 3 additions & 3 deletions src/sampler.rs
@@ -315,9 +315,8 @@ impl Settings for TransformedNutsSettings {
&self,
stats: &<Self::Chain<M> as SamplerStats<M>>::Stats,
) -> SampleStats {
// TODO
let step_size = 0.;
let num_steps = 0;
let step_size = stats.potential_stats.step_size;
let num_steps = stats.strategy_stats.step_size.n_steps;
SampleStats {
chain: stats.chain,
draw: stats.draw,
@@ -1065,6 +1064,7 @@ pub mod test_logps {
_rng: &mut R,
_untransformed_position: &[f64],
_untransfogmed_gradient: &[f64],
_chain: u64,
) -> std::result::Result<Self::TransformParams, Self::LogpError> {
unimplemented!()
}
(Diffs for the remaining 3 changed files are not shown here.)
