Skip to content

Commit

Permalink
feat: Add transforming adaptation
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Oct 17, 2024
1 parent 0f246f2 commit 1752c63
Show file tree
Hide file tree
Showing 15 changed files with 347 additions and 230 deletions.
37 changes: 27 additions & 10 deletions src/adapt_strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
>;
type Options = EuclideanAdaptOptions<A::Options>;

fn new(math: &mut M, options: Self::Options, num_tune: u64) -> Self {
fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self {
let num_tune_f = num_tune as f64;
let step_size_window = (options.step_size_window * num_tune_f) as u64;
let early_end = (options.early_window * num_tune_f) as u64;
Expand All @@ -100,7 +100,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy

Self {
step_size: StepSizeStrategy::new(options.dual_average_options),
mass_matrix: A::new(math, options.mass_matrix_options, num_tune),
mass_matrix: A::new(math, options.mass_matrix_options, num_tune, chain),
options,
num_tune,
early_end,
Expand All @@ -116,13 +116,13 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
math: &mut M,
options: &mut NutsOptions,
hamiltonian: &mut Self::Hamiltonian,
state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
position: &[f64],
rng: &mut R,
) -> Result<(), NutsError> {
self.mass_matrix
.init(math, options, hamiltonian, state, rng)?;
.init(math, options, hamiltonian, position, rng)?;
self.step_size
.init(math, options, hamiltonian, state, rng)?;
.init(math, options, hamiltonian, position, rng)?;
Ok(())
}

Expand Down Expand Up @@ -186,8 +186,9 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
// First time we change the mass matrix
if did_change & self.has_initial_mass_matrix {
self.has_initial_mass_matrix = false;
let position = math.box_array(state.point().position());
self.step_size
.init(math, options, hamiltonian, state, rng)?;
.init(math, options, hamiltonian, &position, rng)?;
} else {
self.step_size.update_stepsize(hamiltonian, false)
}
Expand Down Expand Up @@ -403,7 +404,18 @@ pub mod test_logps {
unimplemented!()
}

fn transformed_logp(
fn init_from_transformed_position(
&mut self,
_params: &Self::TransformParams,
_untransformed_position: &mut [f64],
_untransformed_gradient: &mut [f64],
_transformed_position: &[f64],
_transformed_gradient: &mut [f64],
) -> Result<(f64, f64), Self::LogpError> {
unimplemented!()
}

fn init_from_untransformed_position(
&mut self,
_params: &Self::TransformParams,
_untransformed_position: &[f64],
Expand All @@ -424,15 +436,19 @@ pub mod test_logps {
unimplemented!()
}

fn new_transformation(
fn new_transformation<R: rand::Rng + ?Sized>(
&mut self,

Check failure on line 440 in src/adapt_strategy.rs

View workflow job for this annotation

GitHub Actions / Test Suite (stable)

method `new_transformation` has 4 parameters but the declaration in trait `CpuLogpFunc::new_transformation` has 5

Check failure on line 440 in src/adapt_strategy.rs

View workflow job for this annotation

GitHub Actions / Test Suite (nightly)

method `new_transformation` has 4 parameters but the declaration in trait `CpuLogpFunc::new_transformation` has 5
_rng: &mut R,
_untransformed_position: &[f64],
_untransfogmed_gradient: &[f64],
) -> Result<Self::TransformParams, Self::LogpError> {
unimplemented!()
}

fn transformation_id(&self, _params: &Self::TransformParams) -> i64 {
fn transformation_id(
&self,
_params: &Self::TransformParams,
) -> Result<i64, Self::LogpError> {
unimplemented!()
}
}
Expand Down Expand Up @@ -462,7 +478,8 @@ mod test {
let max_energy_error = 1000f64;
let step_size = 0.1f64;

let hamiltonian = EuclideanHamiltonian::new(mass_matrix, max_energy_error, step_size);
let hamiltonian =
EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, step_size);
let options = NutsOptions {
maxdepth: 10u64,
store_gradient: true,
Expand Down
22 changes: 7 additions & 15 deletions src/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
hamiltonian::{Hamiltonian, Point},
nuts::{draw, Collector, NutsOptions, NutsSampleStats, NutsStatsBuilder},
sampler_stats::SamplerStats,
state::{State, StatePool},
state::State,
Math, NutsError, Settings,
};

Expand Down Expand Up @@ -35,7 +35,6 @@ where
R: rand::Rng,
A: AdaptStrategy<M>,
{
pool: StatePool<M, <A::Hamiltonian as Hamiltonian<M>>::Point>,
hamiltonian: A::Hamiltonian,
collector: A::Collector,
options: NutsOptions,
Expand All @@ -56,18 +55,15 @@ where
{
pub fn new(
mut math: M,
hamiltonian: A::Hamiltonian,
mut hamiltonian: A::Hamiltonian,
strategy: A,
options: NutsOptions,
rng: R,
chain: u64,
) -> Self {
let pool_size: usize = options.maxdepth.checked_mul(2).unwrap().try_into().unwrap();
let pool = hamiltonian.new_pool(&mut math, pool_size);
let init = pool.new_state(&mut math);
let init = hamiltonian.pool().new_state(&mut math);
let collector = strategy.new_collector(&mut math);
NutsChain {
pool,
hamiltonian,
collector,
options,
Expand All @@ -87,14 +83,14 @@ pub trait AdaptStrategy<M: Math>: SamplerStats<M> {
type Collector: Collector<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>;
type Options: Copy + Send + Debug + Default;

fn new(math: &mut M, options: Self::Options, num_tune: u64) -> Self;
fn new(math: &mut M, options: Self::Options, num_tune: u64, chain: u64) -> Self;

fn init<R: Rng + ?Sized>(
&mut self,
math: &mut M,
options: &mut NutsOptions,
hamiltonian: &mut Self::Hamiltonian,
state: &State<M, <Self::Hamiltonian as Hamiltonian<M>>::Point>,
position: &[f64],
rng: &mut R,
) -> Result<(), NutsError>;

Expand Down Expand Up @@ -151,24 +147,20 @@ where
type AdaptStrategy = A;

fn set_position(&mut self, position: &[f64]) -> Result<()> {
let state = self
.hamiltonian
.init_state(&mut self.math, &mut self.pool, position)?;
self.init = state;
self.strategy.init(
&mut self.math,
&mut self.options,
&mut self.hamiltonian,
&self.init,
position,
&mut self.rng,
)?;
self.init = self.hamiltonian.init_state(&mut self.math, position)?;
Ok(())
}

fn draw(&mut self) -> Result<(Box<[f64]>, Self::Stats)> {
let (state, info) = draw(
&mut self.math,
&mut self.pool,
&mut self.init,
&mut self.rng,
&mut self.hamiltonian,
Expand Down
102 changes: 73 additions & 29 deletions src/cpu_math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,15 +360,15 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
)
}

fn transformed_logp(
fn init_from_untransformed_position(
&mut self,
params: &Self::TransformParams,
untransformed_position: &Self::Vector,
untransformed_gradient: &mut Self::Vector,
transformed_position: &mut Self::Vector,
transformed_gradient: &mut Self::Vector,
) -> Result<(f64, f64), Self::LogpErr> {
self.logp_func.transformed_logp(
self.logp_func.init_from_untransformed_position(
params,
untransformed_position.as_slice(),
untransformed_gradient.as_slice_mut(),
Expand All @@ -377,11 +377,28 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
)
}

fn init_from_transformed_position(
&mut self,
params: &Self::TransformParams,
untransformed_position: &mut Self::Vector,
untransformed_gradient: &mut Self::Vector,
transformed_position: &Self::Vector,
transformed_gradient: &mut Self::Vector,
) -> Result<(f64, f64), Self::LogpErr> {
self.logp_func.init_from_transformed_position(
params,
untransformed_position.as_slice_mut(),
untransformed_gradient.as_slice_mut(),
transformed_position.as_slice(),
transformed_gradient.as_slice_mut(),
)
}

fn update_transformation<'a, R: rand::Rng + ?Sized>(
&'a mut self,
rng: &mut R,
untransformed_positions: impl Iterator<Item = &'a Self::Vector>,
untransformed_gradients: impl Iterator<Item = &'a Self::Vector>,
untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
params: &'a mut Self::TransformParams,
) -> Result<(), Self::LogpErr> {
self.logp_func.update_transformation(
Expand All @@ -392,18 +409,22 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
)
}

fn new_transformation(
fn new_transformation<R: rand::Rng + ?Sized>(
&mut self,
rng: &mut R,
untransformed_position: &Self::Vector,
untransfogmed_gradient: &Self::Vector,
chain: u64,
) -> Result<Self::TransformParams, Self::LogpErr> {
self.logp_func.new_transformation(
rng,
untransformed_position.as_slice(),
untransfogmed_gradient.as_slice(),
chain,
)
}

fn transformation_id(&self, params: &Self::TransformParams) -> i64 {
fn transformation_id(&self, params: &Self::TransformParams) -> Result<i64, Self::LogpErr> {
self.logp_func.transformation_id(params)
}
}
Expand All @@ -417,35 +438,58 @@ pub trait CpuLogpFunc {

fn inv_transform_normalize(
&mut self,
params: &Self::TransformParams,
untransformed_position: &[f64],
untransofrmed_gradient: &[f64],
transformed_position: &mut [f64],
transformed_gradient: &mut [f64],
) -> Result<f64, Self::LogpError>;
_params: &Self::TransformParams,
_untransformed_position: &[f64],
_untransformed_gradient: &[f64],
_transformed_position: &mut [f64],
_transformed_gradient: &mut [f64],
) -> Result<f64, Self::LogpError> {
unimplemented!()
}

fn transformed_logp(
fn init_from_untransformed_position(
&mut self,
params: &Self::TransformParams,
untransformed_position: &[f64],
untransformed_gradient: &mut [f64],
transformed_position: &mut [f64],
transformed_gradient: &mut [f64],
) -> Result<(f64, f64), Self::LogpError>;
_params: &Self::TransformParams,
_untransformed_position: &[f64],
_untransformed_gradient: &mut [f64],
_transformed_position: &mut [f64],
_transformed_gradient: &mut [f64],
) -> Result<(f64, f64), Self::LogpError> {
unimplemented!()
}

fn init_from_transformed_position(
&mut self,
_params: &Self::TransformParams,
_untransformed_position: &mut [f64],
_untransformed_gradient: &mut [f64],
_transformed_position: &[f64],
_transformed_gradient: &mut [f64],
) -> Result<(f64, f64), Self::LogpError> {
unimplemented!()
}

fn update_transformation<'a, R: rand::Rng + ?Sized>(
&'a mut self,
rng: &mut R,
untransformed_positions: impl Iterator<Item = &'a [f64]>,
untransformed_gradients: impl Iterator<Item = &'a [f64]>,
params: &'a mut Self::TransformParams,
) -> Result<(), Self::LogpError>;
_rng: &mut R,
_untransformed_positions: impl ExactSizeIterator<Item = &'a [f64]>,
_untransformed_gradients: impl ExactSizeIterator<Item = &'a [f64]>,
_params: &'a mut Self::TransformParams,
) -> Result<(), Self::LogpError> {
unimplemented!()
}

fn new_transformation(
fn new_transformation<R: rand::Rng + ?Sized>(
&mut self,
untransformed_position: &[f64],
untransfogmed_gradient: &[f64],
) -> Result<Self::TransformParams, Self::LogpError>;
_rng: &mut R,
_untransformed_position: &[f64],
_untransformed_gradient: &[f64],
_chain: u64,
) -> Result<Self::TransformParams, Self::LogpError> {
unimplemented!()
}

fn transformation_id(&self, params: &Self::TransformParams) -> i64;
fn transformation_id(&self, _params: &Self::TransformParams) -> Result<i64, Self::LogpError> {
unimplemented!()
}
}
Loading

0 comments on commit 1752c63

Please sign in to comment.