From 0b2913e9562c809509b389d0d74edcd5af2675a5 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 28 Nov 2024 11:53:33 +0100 Subject: [PATCH] refactor: Remove unnecessary stats structs and add some transform stats --- Cargo.toml | 8 +- src/adapt_strategy.rs | 77 +++-- src/chain.rs | 574 ++++++++++++++++++++++++++++---- src/cpu_math.rs | 9 +- src/euclidean_hamiltonian.rs | 64 ++-- src/lib.rs | 8 +- src/low_rank_mass_matrix.rs | 114 +++---- src/mass_matrix.rs | 43 +-- src/mass_matrix_adapt.rs | 17 +- src/math_base.rs | 1 + src/nuts.rs | 508 +--------------------------- src/sampler.rs | 131 ++++---- src/sampler_stats.rs | 30 +- src/stepsize_adapt.rs | 48 ++- src/transform_adapt_strategy.rs | 49 +-- src/transformed_hamiltonian.rs | 183 +++++++--- tests/sample_normal.rs | 1 + 17 files changed, 933 insertions(+), 932 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f430478..55d0340 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,8 +2,8 @@ name = "nuts-rs" version = "0.13.0" authors = [ - "Adrian Seyboldt ", - "PyMC Developers ", + "Adrian Seyboldt ", + "PyMC Developers ", ] edition = "2021" license = "MIT" @@ -22,12 +22,12 @@ rand = { version = "0.8.5", features = ["small_rng"] } rand_distr = "0.4.3" multiversion = "0.7.2" itertools = "0.13.0" -thiserror = "1.0.43" +thiserror = "2.0.3" arrow = { version = "53.1.0", default-features = false, features = ["ffi"] } rand_chacha = "0.3.1" anyhow = "1.0.72" faer = { version = "0.19.4", default-features = false, features = ["std"] } -pulp = "0.18.21" +pulp = "0.19.6" rayon = "1.10.0" [dev-dependencies] diff --git a/src/adapt_strategy.rs b/src/adapt_strategy.rs index f1a26e7..5b14935 100644 --- a/src/adapt_strategy.rs +++ b/src/adapt_strategy.rs @@ -16,8 +16,7 @@ use crate::{ state::State, stepsize::AcceptanceRateCollector, stepsize_adapt::{ - DualAverageSettings, Stats as StepSizeStats, StatsBuilder as StepSizeStatsBuilder, - Strategy as StepSizeStrategy, + DualAverageSettings, StatsBuilder as StepSizeStatsBuilder, 
Strategy as StepSizeStrategy, }, NutsError, }; @@ -63,20 +62,18 @@ impl Default for EuclideanAdaptOptions { } impl> SamplerStats for GlobalStrategy { - type Stats = CombinedStats; - type Builder = CombinedStatsBuilder; + type Builder = GlobalStrategyBuilder; + type StatOptions = >::StatOptions; - fn current_stats(&self, math: &mut M) -> Self::Stats { - CombinedStats { - stats1: self.step_size.current_stats(math), - stats2: self.mass_matrix.current_stats(math), - } - } - - fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder { - CombinedStatsBuilder { - stats1: SamplerStats::::new_builder(&self.step_size, settings, dim), - stats2: self.mass_matrix.new_builder(settings, dim), + fn new_builder( + &self, + options: Self::StatOptions, + settings: &impl Settings, + dim: usize, + ) -> Self::Builder { + GlobalStrategyBuilder { + step_size: SamplerStats::::new_builder(&self.step_size, (), settings, dim), + mass_matrix: self.mass_matrix.new_builder(options, settings, dim), } } } @@ -218,33 +215,37 @@ impl> AdaptStrategy for GlobalStrategy fn is_tuning(&self) -> bool { self.tuning } -} -#[derive(Debug, Clone)] -pub struct CombinedStats { - pub stats1: D1, - pub stats2: D2, + fn last_num_steps(&self) -> u64 { + self.step_size.last_n_steps + } } -#[derive(Clone)] -pub struct CombinedStatsBuilder { - pub stats1: B1, - pub stats2: B2, +pub struct GlobalStrategyBuilder { + pub step_size: StepSizeStatsBuilder, + pub mass_matrix: B, } -impl StatTraceBuilder> for CombinedStatsBuilder +impl StatTraceBuilder> for GlobalStrategyBuilder where - B1: StatTraceBuilder, - B2: StatTraceBuilder, + A: MassMatrixAdaptStrategy, { - fn append_value(&mut self, value: CombinedStats) { - self.stats1.append_value(value.stats1); - self.stats2.append_value(value.stats2); + fn append_value(&mut self, math: Option<&mut M>, value: &GlobalStrategy) { + let math = math.expect("Sampler stats need math"); + self.step_size.append_value(Some(math), &value.step_size); + self.mass_matrix + 
.append_value(Some(math), &value.mass_matrix); } fn finalize(self) -> Option { - let Self { stats1, stats2 } = self; - match (stats1.finalize(), stats2.finalize()) { + let Self { + step_size, + mass_matrix, + } = self; + match ( + StatTraceBuilder::::finalize(step_size), + mass_matrix.finalize(), + ) { (None, None) => None, (Some(stats1), None) => Some(stats1), (None, Some(stats2)) => Some(stats2), @@ -266,8 +267,14 @@ where } fn inspect(&self) -> Option { - let Self { stats1, stats2 } = self; - match (stats1.inspect(), stats2.inspect()) { + let Self { + step_size, + mass_matrix, + } = self; + match ( + StatTraceBuilder::::inspect(step_size), + mass_matrix.inspect(), + ) { (None, None) => None, (Some(stats1), None) => Some(stats1), (None, Some(stats2)) => Some(stats2), @@ -374,6 +381,7 @@ pub mod test_logps { #[derive(Error, Debug)] pub enum NormalLogpError {} + impl LogpError for NormalLogpError { fn is_recoverable(&self) -> bool { false @@ -438,6 +446,7 @@ pub mod test_logps { _rng: &mut R, _untransformed_positions: impl Iterator, _untransformed_gradients: impl Iterator, + _untransformed_logp: impl Iterator, _params: &'a mut Self::TransformParams, ) -> Result<(), Self::LogpError> { unimplemented!() diff --git a/src/chain.rs b/src/chain.rs index 6c2c767..a182283 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -1,11 +1,24 @@ -use std::fmt::Debug; +use std::{ + cell::RefCell, + fmt::Debug, + ops::{Deref, DerefMut}, + sync::Arc, +}; +use arrow::{ + array::{ + Array, ArrayBuilder, BooleanBuilder, FixedSizeListBuilder, PrimitiveBuilder, StringBuilder, + StructArray, + }, + datatypes::{DataType, Field, Fields, Float64Type, Int64Type, UInt64Type}, +}; use rand::Rng; use crate::{ hamiltonian::{Hamiltonian, Point}, - nuts::{draw, Collector, NutsOptions, NutsSampleStats, NutsStatsBuilder}, - sampler_stats::SamplerStats, + nuts::{draw, Collector, NutsOptions, SampleInfo}, + sampler::Progress, + sampler_stats::{SamplerStats, StatTraceBuilder}, state::State, Math, NutsError, 
Settings, }; @@ -23,7 +36,7 @@ pub trait Chain: SamplerStats { fn set_position(&mut self, position: &[f64]) -> Result<()>; /// Draw a new sample and return the position and some diagnosic information. - fn draw(&mut self) -> Result<(Box<[f64]>, Self::Stats)>; + fn draw(&mut self) -> Result<(Box<[f64]>, Progress)>; /// The dimensionality of the posterior. fn dim(&self) -> usize; @@ -39,18 +52,12 @@ where collector: A::Collector, options: NutsOptions, rng: R, - init: State>::Point>, + state: State>::Point>, + last_info: Option, chain: u64, draw_count: u64, strategy: A, - math: M, - stats: Option< - NutsSampleStats< - <>::Point as SamplerStats>::Stats, - >::Stats, - A::Stats, - >, - >, + math: RefCell, } impl NutsChain @@ -74,12 +81,12 @@ where collector, options, rng, - init, + state: init, + last_info: None, chain, draw_count: 0, strategy, - math, - stats: None, + math: math.into(), } } } @@ -114,6 +121,7 @@ pub trait AdaptStrategy: SamplerStats { fn new_collector(&self, math: &mut M) -> Self::Collector; fn is_tuning(&self) -> bool; + fn last_num_steps(&self) -> u64; } impl SamplerStats for NutsChain @@ -122,31 +130,25 @@ where R: rand::Rng, A: AdaptStrategy, { - type Builder = NutsStatsBuilder< - <>::Point as SamplerStats>::Builder, - >::Builder, - >::Builder, - >; - type Stats = NutsSampleStats< - <>::Point as SamplerStats>::Stats, - >::Stats, - >::Stats, - >; - - fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder { + type Builder = NutsStatsBuilder; + type StatOptions = StatOptions; + + fn new_builder( + &self, + options: StatOptions, + settings: &impl Settings, + dim: usize, + ) -> Self::Builder { NutsStatsBuilder::new_with_capacity( + options, settings, &self.hamiltonian, &self.strategy, - self.init.point(), + self.state.point(), dim, &self.options, ) } - - fn current_stats(&self, _math: &mut M) -> Self::Stats { - self.stats.as_ref().expect("No stats available").clone() - } } impl Chain for NutsChain @@ -158,61 +160,44 @@ where type 
AdaptStrategy = A; fn set_position(&mut self, position: &[f64]) -> Result<()> { + let mut math_ = self.math.borrow_mut(); + let math = math_.deref_mut(); self.strategy.init( - &mut self.math, + math, &mut self.options, &mut self.hamiltonian, position, &mut self.rng, )?; - self.init = self.hamiltonian.init_state(&mut self.math, position)?; + self.state = self.hamiltonian.init_state(math, position)?; Ok(()) } - fn draw(&mut self) -> Result<(Box<[f64]>, Self::Stats)> { + fn draw(&mut self) -> Result<(Box<[f64]>, Progress)> { + let mut math_ = self.math.borrow_mut(); + let math = math_.deref_mut(); let (state, info) = draw( - &mut self.math, - &mut self.init, + math, + &mut self.state, &mut self.rng, &mut self.hamiltonian, &self.options, &mut self.collector, )?; - let mut position: Box<[f64]> = vec![0f64; self.math.dim()].into(); - state.write_position(&mut self.math, &mut position); - - let stats = NutsSampleStats { - depth: info.depth, - maxdepth_reached: info.reached_maxdepth, - idx_in_trajectory: state.index_in_trajectory(), - logp: state.point().logp(), - energy: state.point().energy(), - energy_error: info.draw_energy - info.initial_energy, - divergence_info: info.divergence_info, - chain: self.chain, + let mut position: Box<[f64]> = vec![0f64; math.dim()].into(); + state.write_position(math, &mut position); + + let progress = Progress { draw: self.draw_count, - potential_stats: self.hamiltonian.current_stats(&mut self.math), - strategy_stats: self.strategy.current_stats(&mut self.math), - point_stats: state.point().current_stats(&mut self.math), - gradient: if self.options.store_gradient { - let mut gradient: Box<[f64]> = vec![0f64; self.math.dim()].into(); - state.write_gradient(&mut self.math, &mut gradient); - Some(gradient) - } else { - None - }, - unconstrained: if self.options.store_unconstrained { - let mut unconstrained: Box<[f64]> = vec![0f64; self.math.dim()].into(); - state.write_position(&mut self.math, &mut unconstrained); - Some(unconstrained) - } 
else { - None - }, + chain: self.chain, + diverging: info.divergence_info.is_some(), tuning: self.strategy.is_tuning(), + step_size: self.hamiltonian.step_size(), + num_steps: self.strategy.last_num_steps(), }; self.strategy.adapt( - &mut self.math, + math, &mut self.options, &mut self.hamiltonian, self.draw_count, @@ -223,11 +208,458 @@ where self.draw_count += 1; - self.init = state; - Ok((position, stats)) + self.state = state; + self.last_info = Some(info); + Ok((position, progress)) } fn dim(&self) -> usize { - self.math.dim() + self.math.borrow().dim() + } +} + +pub struct NutsStatsBuilder> { + depth: PrimitiveBuilder, + maxdepth_reached: BooleanBuilder, + index_in_trajectory: PrimitiveBuilder, + logp: PrimitiveBuilder, + energy: PrimitiveBuilder, + chain: PrimitiveBuilder, + draw: PrimitiveBuilder, + energy_error: PrimitiveBuilder, + unconstrained: Option>>, + gradient: Option>>, + hamiltonian: >::Builder, + adapt: A::Builder, + point: <>::Point as SamplerStats>::Builder, + diverging: BooleanBuilder, + divergence_start: Option>>, + divergence_start_grad: Option>>, + divergence_end: Option>>, + divergence_momentum: Option>>, + divergence_msg: Option, +} + +pub struct StatOptions> { + pub adapt: A::StatOptions, + pub hamiltonian: >::StatOptions, + pub point: <>::Point as SamplerStats>::StatOptions, +} + +impl> NutsStatsBuilder { + pub fn new_with_capacity( + stat_options: StatOptions, + settings: &impl Settings, + hamiltonian: &A::Hamiltonian, + adapt: &A, + point: &>::Point, + dim: usize, + options: &NutsOptions, + ) -> Self { + let capacity = settings.hint_num_tune() + settings.hint_num_draws(); + + let gradient = if options.store_gradient { + let items = PrimitiveBuilder::with_capacity(capacity); + Some(FixedSizeListBuilder::new(items, dim as i32)) + } else { + None + }; + + let unconstrained = if options.store_unconstrained { + let items = PrimitiveBuilder::with_capacity(capacity); + Some(FixedSizeListBuilder::with_capacity( + items, dim as i32, capacity, 
+ )) + } else { + None + }; + + let (div_start, div_start_grad, div_end, div_mom, div_msg) = if options.store_divergences { + let start_location_prim = PrimitiveBuilder::new(); + let start_location_list = FixedSizeListBuilder::new(start_location_prim, dim as i32); + + let start_grad_prim = PrimitiveBuilder::new(); + let start_grad_list = FixedSizeListBuilder::new(start_grad_prim, dim as i32); + + let end_location_prim = PrimitiveBuilder::new(); + let end_location_list = FixedSizeListBuilder::new(end_location_prim, dim as i32); + + let momentum_location_prim = PrimitiveBuilder::new(); + let momentum_location_list = + FixedSizeListBuilder::new(momentum_location_prim, dim as i32); + + let msg_list = StringBuilder::new(); + + ( + Some(start_location_list), + Some(start_grad_list), + Some(end_location_list), + Some(momentum_location_list), + Some(msg_list), + ) + } else { + (None, None, None, None, None) + }; + + Self { + depth: PrimitiveBuilder::with_capacity(capacity), + maxdepth_reached: BooleanBuilder::with_capacity(capacity), + index_in_trajectory: PrimitiveBuilder::with_capacity(capacity), + logp: PrimitiveBuilder::with_capacity(capacity), + energy: PrimitiveBuilder::with_capacity(capacity), + chain: PrimitiveBuilder::with_capacity(capacity), + draw: PrimitiveBuilder::with_capacity(capacity), + energy_error: PrimitiveBuilder::with_capacity(capacity), + gradient, + unconstrained, + hamiltonian: hamiltonian.new_builder(stat_options.hamiltonian, settings, dim), + adapt: adapt.new_builder(stat_options.adapt, settings, dim), + point: point.new_builder(stat_options.point, settings, dim), + diverging: BooleanBuilder::with_capacity(capacity), + divergence_start: div_start, + divergence_start_grad: div_start_grad, + divergence_end: div_end, + divergence_momentum: div_mom, + divergence_msg: div_msg, + } + } +} + +impl> StatTraceBuilder> + for NutsStatsBuilder +{ + fn append_value(&mut self, _math: Option<&mut M>, value: &NutsChain) { + let mut math_ = 
value.math.borrow_mut(); + let math = math_.deref_mut(); + let Self { + ref mut depth, + ref mut maxdepth_reached, + ref mut index_in_trajectory, + logp, + energy, + chain, + draw, + energy_error, + ref mut unconstrained, + ref mut gradient, + hamiltonian, + adapt, + point, + diverging, + ref mut divergence_start, + divergence_start_grad, + divergence_end, + divergence_momentum, + divergence_msg, + } = self; + + let info = value.last_info.as_ref().expect("Sampler has not started"); + let draw_point = value.state.point(); + + depth.append_value(info.depth); + maxdepth_reached.append_value(info.reached_maxdepth); + index_in_trajectory.append_value(draw_point.index_in_trajectory()); + logp.append_value(draw_point.logp()); + energy.append_value(draw_point.energy()); + chain.append_value(value.chain); + draw.append_value(value.draw_count); + diverging.append_value(info.divergence_info.is_some()); + energy_error.append_value(draw_point.energy_error()); + + fn add_slice>( + store: &mut Option>>, + values: Option, + n_dim: usize, + ) { + let Some(store) = store.as_mut() else { + return; + }; + + if let Some(values) = values.as_ref() { + store.values().append_slice(values.as_ref()); + store.append(true); + } else { + store.values().append_nulls(n_dim); + store.append(false); + } + } + + let n_dim = math.dim(); + add_slice(gradient, Some(math.box_array(draw_point.gradient())), n_dim); + add_slice( + unconstrained, + Some(math.box_array(draw_point.position())), + n_dim, + ); + + let div_info = info.divergence_info.as_ref(); + add_slice( + divergence_start, + div_info.and_then(|info| info.start_location.as_ref()), + n_dim, + ); + add_slice( + divergence_start_grad, + div_info.and_then(|info| info.start_gradient.as_ref()), + n_dim, + ); + add_slice( + divergence_end, + div_info.and_then(|info| info.end_location.as_ref()), + n_dim, + ); + add_slice( + divergence_momentum, + div_info.and_then(|info| info.start_momentum.as_ref()), + n_dim, + ); + + if let Some(div_msg) = 
divergence_msg.as_mut() { + if let Some(err) = div_info.and_then(|info| info.logp_function_error.as_ref()) { + div_msg.append_value(format!("{}", err)); + } else { + div_msg.append_null(); + } + } + + hamiltonian.append_value(Some(math), &value.hamiltonian); + adapt.append_value(Some(math), &value.strategy); + point.append_value(Some(math), draw_point); + } + + fn finalize(self) -> Option { + let Self { + mut depth, + mut maxdepth_reached, + mut index_in_trajectory, + mut logp, + mut energy, + mut chain, + mut draw, + mut energy_error, + unconstrained, + gradient, + hamiltonian, + adapt, + point, + mut diverging, + divergence_start, + divergence_start_grad, + divergence_end, + divergence_momentum, + divergence_msg, + } = self; + + let mut fields = vec![ + Field::new("depth", DataType::UInt64, false), + Field::new("maxdepth_reached", DataType::Boolean, false), + Field::new("index_in_trajectory", DataType::Int64, false), + Field::new("logp", DataType::Float64, false), + Field::new("energy", DataType::Float64, false), + Field::new("chain", DataType::UInt64, false), + Field::new("draw", DataType::UInt64, false), + Field::new("diverging", DataType::Boolean, false), + Field::new("energy_error", DataType::Float64, false), + ]; + + let mut arrays: Vec> = vec![ + ArrayBuilder::finish(&mut depth), + ArrayBuilder::finish(&mut maxdepth_reached), + ArrayBuilder::finish(&mut index_in_trajectory), + ArrayBuilder::finish(&mut logp), + ArrayBuilder::finish(&mut energy), + ArrayBuilder::finish(&mut chain), + ArrayBuilder::finish(&mut draw), + ArrayBuilder::finish(&mut diverging), + ArrayBuilder::finish(&mut energy_error), + ]; + + fn merge_into>( + builder: B, + arrays: &mut Vec>, + fields: &mut Vec, + ) { + let Some(struct_array) = builder.finalize() else { + return; + }; + + let (struct_fields, struct_arrays, bitmap) = struct_array.into_parts(); + assert!(bitmap.is_none()); + arrays.extend(struct_arrays); + fields.extend(struct_fields.into_iter().map(|x| x.deref().clone())); + } + 
+ fn add_field( + mut builder: Option, + name: &str, + arrays: &mut Vec>, + fields: &mut Vec, + ) { + let Some(mut builder) = builder.take() else { + return; + }; + + let array = ArrayBuilder::finish(&mut builder); + fields.push(Field::new(name, array.data_type().clone(), true)); + arrays.push(array); + } + + merge_into(hamiltonian, &mut arrays, &mut fields); + merge_into(adapt, &mut arrays, &mut fields); + merge_into(point, &mut arrays, &mut fields); + + add_field(gradient, "gradient", &mut arrays, &mut fields); + add_field( + unconstrained, + "unconstrained_draw", + &mut arrays, + &mut fields, + ); + add_field( + divergence_start, + "divergence_start", + &mut arrays, + &mut fields, + ); + add_field( + divergence_start_grad, + "divergence_start_gradient", + &mut arrays, + &mut fields, + ); + add_field(divergence_end, "divergence_end", &mut arrays, &mut fields); + add_field( + divergence_momentum, + "divergence_momentum", + &mut arrays, + &mut fields, + ); + add_field( + divergence_msg, + "divergence_message", + &mut arrays, + &mut fields, + ); + + let fields = Fields::from(fields); + Some(StructArray::new(fields, arrays, None)) + } + + fn inspect(&self) -> Option { + let Self { + depth, + maxdepth_reached, + index_in_trajectory, + logp, + energy, + chain, + draw, + energy_error, + unconstrained, + gradient, + hamiltonian, + adapt, + point, + diverging, + divergence_start, + divergence_start_grad, + divergence_end, + divergence_momentum, + divergence_msg, + } = self; + + let mut fields = vec![ + Field::new("depth", DataType::UInt64, false), + Field::new("maxdepth_reached", DataType::Boolean, false), + Field::new("index_in_trajectory", DataType::Int64, false), + Field::new("logp", DataType::Float64, false), + Field::new("energy", DataType::Float64, false), + Field::new("chain", DataType::UInt64, false), + Field::new("draw", DataType::UInt64, false), + Field::new("diverging", DataType::Boolean, false), + Field::new("energy_error", DataType::Float64, false), + ]; + + 
 let mut arrays: Vec> = vec![ + ArrayBuilder::finish_cloned(depth), + ArrayBuilder::finish_cloned(maxdepth_reached), + ArrayBuilder::finish_cloned(index_in_trajectory), + ArrayBuilder::finish_cloned(logp), + ArrayBuilder::finish_cloned(energy), + ArrayBuilder::finish_cloned(chain), + ArrayBuilder::finish_cloned(draw), + ArrayBuilder::finish_cloned(diverging), + ArrayBuilder::finish_cloned(energy_error), + ]; + + fn merge_into>( + builder: &B, + arrays: &mut Vec>, + fields: &mut Vec, + ) { + let Some(struct_array) = builder.inspect() else { + return; + }; + + let (struct_fields, struct_arrays, bitmap) = struct_array.into_parts(); + assert!(bitmap.is_none()); + arrays.extend(struct_arrays); + fields.extend(struct_fields.into_iter().map(|x| x.deref().clone())); + } + + fn add_field( + builder: &Option, + name: &str, + arrays: &mut Vec>, + fields: &mut Vec, + ) { + let Some(builder) = builder.as_ref() else { + return; + }; + + let array = ArrayBuilder::finish_cloned(builder); + fields.push(Field::new(name, array.data_type().clone(), true)); + arrays.push(array); + } + + merge_into(hamiltonian, &mut arrays, &mut fields); + merge_into(adapt, &mut arrays, &mut fields); + merge_into(point, &mut arrays, &mut fields); + + add_field(gradient, "gradient", &mut arrays, &mut fields); + add_field( + unconstrained, + "unconstrained_draw", + &mut arrays, + &mut fields, + ); + add_field( + divergence_start, + "divergence_start", + &mut arrays, + &mut fields, + ); + add_field( + divergence_start_grad, + "divergence_start_gradient", + &mut arrays, + &mut fields, + ); + add_field(divergence_end, "divergence_end", &mut arrays, &mut fields); + add_field( + divergence_momentum, + "divergence_momentum", + &mut arrays, + &mut fields, + ); + add_field( + divergence_msg, + "divergence_message", + &mut arrays, + &mut fields, + ); + + let fields = Fields::from(fields); + Some(StructArray::new(fields, arrays, None)) } } diff --git a/src/cpu_math.rs b/src/cpu_math.rs index 34c21ea..7305c23 100644 
--- a/src/cpu_math.rs +++ b/src/cpu_math.rs @@ -13,7 +13,7 @@ use crate::{ pub struct CpuMath { logp_func: F, arch: pulp::Arch, - parallel: faer::Parallelism<'static>, + _parallel: faer::Parallelism<'static>, } impl CpuMath { @@ -23,7 +23,7 @@ impl CpuMath { Self { logp_func, arch, - parallel, + _parallel: parallel, } } @@ -32,7 +32,7 @@ impl CpuMath { Self { logp_func, arch, - parallel, + _parallel: parallel, } } } @@ -407,12 +407,14 @@ impl Math for CpuMath { rng: &mut R, untransformed_positions: impl ExactSizeIterator, untransformed_gradients: impl ExactSizeIterator, + untransformed_logp: impl ExactSizeIterator, params: &'a mut Self::TransformParams, ) -> Result<(), Self::LogpErr> { self.logp_func.update_transformation( rng, untransformed_positions.map(|x| x.as_slice()), untransformed_gradients.map(|x| x.as_slice()), + untransformed_logp, params, ) } @@ -482,6 +484,7 @@ pub trait CpuLogpFunc { _rng: &mut R, _untransformed_positions: impl ExactSizeIterator, _untransformed_gradients: impl ExactSizeIterator, + _untransformed_logp: impl ExactSizeIterator, _params: &'a mut Self::TransformParams, ) -> Result<(), Self::LogpError> { unimplemented!() diff --git a/src/euclidean_hamiltonian.rs b/src/euclidean_hamiltonian.rs index 84e9bff..ae5be9a 100644 --- a/src/euclidean_hamiltonian.rs +++ b/src/euclidean_hamiltonian.rs @@ -1,4 +1,3 @@ -use std::fmt::Debug; use std::marker::PhantomData; use std::sync::Arc; @@ -52,14 +51,11 @@ pub struct EuclideanPoint { pub initial_energy: f64, } -#[derive(Clone, Debug)] -pub struct PointStats {} - pub struct PointStatsBuilder {} -impl StatTraceBuilder for PointStatsBuilder { - fn append_value(&mut self, value: PointStats) { - let PointStats {} = value; +impl StatTraceBuilder> for PointStatsBuilder { + fn append_value(&mut self, _math: Option<&mut M>, _value: &EuclideanPoint) { + let Self {} = self; } fn finalize(self) -> Option { @@ -74,16 +70,17 @@ impl StatTraceBuilder for PointStatsBuilder { } impl SamplerStats for EuclideanPoint { - 
type Stats = PointStats; type Builder = PointStatsBuilder; + type StatOptions = (); - fn new_builder(&self, _settings: &impl Settings, _dim: usize) -> Self::Builder { + fn new_builder( + &self, + _stat_options: Self::StatOptions, + _settings: &impl Settings, + _dim: usize, + ) -> Self::Builder { Self::Builder {} } - - fn current_stats(&self, _math: &mut M) -> Self::Stats { - PointStats {} - } } impl EuclideanPoint { @@ -219,28 +216,23 @@ impl Point for EuclideanPoint { } } -#[derive(Copy, Clone, Debug)] -pub struct PotentialStats { - mass_matrix_stats: S, - pub step_size: f64, -} - pub struct PotentialStatsBuilder { mass_matrix: B, step_size: Float64Builder, } -impl> StatTraceBuilder> - for PotentialStatsBuilder +impl> StatTraceBuilder> + for PotentialStatsBuilder { - fn append_value(&mut self, value: PotentialStats) { - let PotentialStats { - mass_matrix_stats, + fn append_value(&mut self, math: Option<&mut M>, value: &EuclideanHamiltonian) { + let math = math.expect("Sampler stats needs math"); + let Self { + mass_matrix, step_size, - } = value; + } = self; - self.mass_matrix.append_value(mass_matrix_stats); - self.step_size.append_value(step_size); + mass_matrix.append_value(Some(math), &value.mass_matrix); + step_size.append_value(value.step_size); } fn finalize(self) -> Option { @@ -292,23 +284,21 @@ impl> StatTraceBuilder> SamplerStats for EuclideanHamiltonian { type Builder = PotentialStatsBuilder; - type Stats = PotentialStats; + type StatOptions = Mass::StatOptions; - fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder { + fn new_builder( + &self, + stat_options: Self::StatOptions, + settings: &impl Settings, + dim: usize, + ) -> Self::Builder { Self::Builder { - mass_matrix: self.mass_matrix.new_builder(settings, dim), + mass_matrix: self.mass_matrix.new_builder(stat_options, settings, dim), step_size: Float64Builder::with_capacity( settings.hint_num_draws() + settings.hint_num_tune(), ), } } - - fn current_stats(&self, math: &mut 
M) -> Self::Stats { - PotentialStats { - mass_matrix_stats: self.mass_matrix.current_stats(math), - step_size: self.step_size, - } - } } impl> Hamiltonian for EuclideanHamiltonian { diff --git a/src/lib.rs b/src/lib.rs index b86b0f3..53e0273 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,7 +11,7 @@ //! ## Usage //! //! ``` -//! use nuts_rs::{CpuLogpFunc, CpuMath, LogpError, DiagGradNutsSettings, Chain, SampleStats, +//! use nuts_rs::{CpuLogpFunc, CpuMath, LogpError, DiagGradNutsSettings, Chain, Progress, //! Settings}; //! use thiserror::Error; //! use rand::thread_rng; @@ -110,11 +110,11 @@ pub use chain::Chain; pub use cpu_math::{CpuLogpFunc, CpuMath}; pub use hamiltonian::DivergenceInfo; pub use math_base::{LogpError, Math}; -pub use nuts::{NutsError, SampleStats}; +pub use nuts::NutsError; pub use sampler::{ sample_sequentially, ChainOutput, ChainProgress, DiagGradNutsSettings, DrawStorage, - LowRankNutsSettings, Model, NutsSettings, ProgressCallback, Sampler, SamplerWaitResult, - Settings, Trace, TransformedNutsSettings, + LowRankNutsSettings, Model, NutsSettings, Progress, ProgressCallback, Sampler, + SamplerWaitResult, Settings, Trace, TransformedNutsSettings, }; pub use low_rank_mass_matrix::LowRankSettings; diff --git a/src/low_rank_mass_matrix.rs b/src/low_rank_mass_matrix.rs index b4edfa3..e2f2b6d 100644 --- a/src/low_rank_mass_matrix.rs +++ b/src/low_rank_mass_matrix.rs @@ -118,45 +118,43 @@ impl Default for LowRankSettings { } } -#[derive(Clone, Debug)] -pub struct MatrixStats { - eigenvals: Option>, - stds: Option>, - num_eigenvalues: u64, -} - pub struct MatrixBuilder { eigenvals: Option>>, stds: Option>>, num_eigenvalues: PrimitiveBuilder, } -impl StatTraceBuilder for MatrixBuilder { - fn append_value(&mut self, value: MatrixStats) { - let MatrixStats { +impl StatTraceBuilder> for MatrixBuilder { + fn append_value(&mut self, math: Option<&mut M>, value: &LowRankMassMatrix) { + let math = math.expect("Need reference to math for stats"); + let Self { 
eigenvals, stds, num_eigenvalues, - } = value; + } = self; - if let Some(store) = self.eigenvals.as_mut() { - if let Some(values) = eigenvals.as_ref() { - store.values().append_slice(values); + if let Some(store) = eigenvals { + if let Some(inner) = &value.inner { + store + .values() + .append_slice(&math.eigs_as_array(&inner.vals)); store.append(true); } else { store.append(false); } } - if let Some(store) = self.stds.as_mut() { - if let Some(values) = stds.as_ref() { - store.values().append_slice(values); - store.append(true); - } else { - store.append(false); - } + if let Some(store) = stds { + store.values().append_slice(&math.box_array(&value.stds)); + store.append(true); } - self.num_eigenvalues.append_value(num_eigenvalues); + num_eigenvalues.append_value( + value + .inner + .as_ref() + .map(|inner| inner.num_eigenvalues) + .unwrap_or(0), + ); } fn finalize(self) -> Option { @@ -242,10 +240,15 @@ impl StatTraceBuilder for MatrixBuilder { } impl SamplerStats for LowRankMassMatrix { - type Stats = MatrixStats; type Builder = MatrixBuilder; + type StatOptions = (); - fn new_builder(&self, _settings: &impl crate::Settings, dim: usize) -> Self::Builder { + fn new_builder( + &self, + _stat_options: Self::StatOptions, + _settings: &impl crate::Settings, + dim: usize, + ) -> Self::Builder { let num_eigenvalues = PrimitiveBuilder::new(); if self.settings.store_mass_matrix { let items = PrimitiveBuilder::new(); @@ -267,36 +270,6 @@ impl SamplerStats for LowRankMassMatrix { } } } - - fn current_stats(&self, math: &mut M) -> Self::Stats { - let num_eigenvalues = self - .inner - .as_ref() - .map(|inner| inner.num_eigenvalues) - .unwrap_or(0); - - if self.settings.store_mass_matrix { - let mut stds = vec![0f64; math.dim()].into_boxed_slice(); - math.write_to_slice(&self.stds, &mut stds); - - let vals = self - .inner - .as_ref() - .map(|inner| math.eigs_as_array(&inner.vals)); - - MatrixStats { - stds: Some(stds), - eigenvals: vals, - num_eigenvalues, - } - } else { - 
MatrixStats { - stds: None, - eigenvals: None, - num_eigenvalues, - } - } - } } impl MassMatrix for LowRankMassMatrix { @@ -340,22 +313,24 @@ impl MassMatrix for LowRankMassMatrix { } } +/* #[derive(Debug, Clone)] pub struct Stats { - //foreground_length: u64, - //background_length: u64, - //is_update: bool, - //diag: Box<[f64]>, - //eigvalues: Box<[f64]>, - //eigvectors: Box<[f64]>, + foreground_length: u64, + background_length: u64, + is_update: bool, + diag: Box<[f64]>, + eigvalues: Box<[f64]>, + eigvectors: Box<[f64]>, } +*/ #[derive(Debug)] pub struct Builder {} -impl StatTraceBuilder for Builder { - fn append_value(&mut self, value: Stats) { - let Stats {} = value; +impl StatTraceBuilder for Builder { + fn append_value(&mut self, _math: Option<&mut M>, _value: &LowRankMassMatrixStrategy) { + let Self {} = self; } fn finalize(self) -> Option { @@ -555,16 +530,17 @@ fn spd_mean(cov_draws: Mat, cov_grads: Mat) -> Mat { } impl SamplerStats for LowRankMassMatrixStrategy { - type Stats = Stats; type Builder = Builder; + type StatOptions = (); - fn new_builder(&self, _settings: &impl crate::Settings, _dim: usize) -> Self::Builder { + fn new_builder( + &self, + _stat_options: Self::StatOptions, + _settings: &impl crate::Settings, + _dim: usize, + ) -> Self::Builder { Builder {} } - - fn current_stats(&self, _math: &mut M) -> Self::Stats { - Stats {} - } } impl MassMatrixAdaptStrategy for LowRankMassMatrixStrategy { diff --git a/src/mass_matrix.rs b/src/mass_matrix.rs index d6d5a8c..2f0219c 100644 --- a/src/mass_matrix.rs +++ b/src/mass_matrix.rs @@ -35,26 +35,19 @@ pub struct DiagMassMatrix { store_mass_matrix: bool, } -#[derive(Clone, Debug)] -pub struct DiagMassMatrixStats { - pub mass_matrix_inv: Option>, -} - pub struct DiagMassMatrixStatsBuilder { mass_matrix_inv: Option>>, } -impl StatTraceBuilder for DiagMassMatrixStatsBuilder { - fn append_value(&mut self, value: DiagMassMatrixStats) { - let DiagMassMatrixStats { mass_matrix_inv } = value; +impl 
StatTraceBuilder> for DiagMassMatrixStatsBuilder { + fn append_value(&mut self, math: Option<&mut M>, value: &DiagMassMatrix) { + let math = math.expect("Need reference to math for stats"); + let Self { mass_matrix_inv } = self; - if let Some(store) = self.mass_matrix_inv.as_mut() { - if let Some(values) = mass_matrix_inv.as_ref() { - store.values().append_slice(values); - store.append(true); - } else { - store.append(false); - } + if let Some(store) = mass_matrix_inv { + let values = math.box_array(&value.variance); + store.values().append_slice(&values); + store.append(true); } } @@ -88,9 +81,14 @@ impl StatTraceBuilder for DiagMassMatrixStatsBuilder { impl SamplerStats for DiagMassMatrix { type Builder = DiagMassMatrixStatsBuilder; - type Stats = DiagMassMatrixStats; + type StatOptions = (); - fn new_builder(&self, _settings: &impl Settings, dim: usize) -> Self::Builder { + fn new_builder( + &self, + _stat_options: Self::StatOptions, + _settings: &impl Settings, + dim: usize, + ) -> Self::Builder { if self.store_mass_matrix { let items = PrimitiveBuilder::new(); let values = FixedSizeListBuilder::new(items, dim as _); @@ -103,17 +101,6 @@ impl SamplerStats for DiagMassMatrix { } } } - - fn current_stats(&self, math: &mut M) -> Self::Stats { - let matrix = if self.store_mass_matrix { - Some(math.box_array(&self.variance)) - } else { - None - }; - DiagMassMatrixStats { - mass_matrix_inv: matrix, - } - } } impl DiagMassMatrix { diff --git a/src/mass_matrix_adapt.rs b/src/mass_matrix_adapt.rs index 06a04aa..419243a 100644 --- a/src/mass_matrix_adapt.rs +++ b/src/mass_matrix_adapt.rs @@ -166,14 +166,17 @@ impl MassMatrixAdaptStrategy for Strategy { } } -pub type Stats = (); pub type StatsBuilder = (); impl SamplerStats for Strategy { - type Builder = Stats; - type Stats = StatsBuilder; - - fn new_builder(&self, _settings: &impl Settings, _dim: usize) -> Self::Builder {} - - fn current_stats(&self, _math: &mut M) -> Self::Stats {} + type Builder = StatsBuilder; + type 
StatOptions = (); + + fn new_builder( + &self, + _stat_options: Self::StatOptions, + _settings: &impl Settings, + _dim: usize, + ) -> Self::Builder { + } } diff --git a/src/math_base.rs b/src/math_base.rs index b8751f9..82e5b50 100644 --- a/src/math_base.rs +++ b/src/math_base.rs @@ -174,6 +174,7 @@ pub trait Math { rng: &mut R, untransformed_positions: impl ExactSizeIterator, untransformed_gradients: impl ExactSizeIterator, + untransformed_logps: impl ExactSizeIterator, params: &'a mut Self::TransformParams, ) -> Result<(), Self::LogpErr>; diff --git a/src/nuts.rs b/src/nuts.rs index 246775f..18e782e 100644 --- a/src/nuts.rs +++ b/src/nuts.rs @@ -1,19 +1,9 @@ -use arrow::array::{ - Array, ArrayBuilder, BooleanBuilder, FixedSizeListBuilder, PrimitiveBuilder, StringBuilder, - StructArray, -}; -use arrow::datatypes::{DataType, Field, Fields, Float64Type, Int64Type, UInt64Type}; use thiserror::Error; -use std::ops::Deref; -use std::sync::Arc; use std::{fmt::Debug, marker::PhantomData}; -use crate::chain::AdaptStrategy; use crate::hamiltonian::{Direction, DivergenceInfo, Hamiltonian, LeapfrogResult, Point}; use crate::math::logaddexp; -use crate::sampler::Settings; -use crate::sampler_stats::StatTraceBuilder; use crate::state::State; use crate::math_base::Math; @@ -308,495 +298,13 @@ where Ok((tree.draw, info)) } -#[derive(Debug, Clone)] -#[non_exhaustive] -pub struct NutsSampleStats< - PointStats: Send + Debug + Clone, - HStats: Send + Debug + Clone, - AdaptStats: Send + Debug + Clone, -> { - pub depth: u64, - pub maxdepth_reached: bool, - pub idx_in_trajectory: i64, - pub logp: f64, - pub energy: f64, - pub energy_error: f64, - pub divergence_info: Option, - pub chain: u64, - pub draw: u64, - pub gradient: Option>, - pub unconstrained: Option>, - pub potential_stats: HStats, - pub strategy_stats: AdaptStats, - pub point_stats: PointStats, - pub tuning: bool, -} - -#[derive(Debug, Clone)] -#[non_exhaustive] -pub struct SampleStats { - pub draw: u64, - pub chain: u64, 
- pub diverging: bool, - pub tuning: bool, - pub step_size: f64, - pub num_steps: u64, -} - -pub struct NutsStatsBuilder { - depth: PrimitiveBuilder, - maxdepth_reached: BooleanBuilder, - index_in_trajectory: PrimitiveBuilder, - logp: PrimitiveBuilder, - energy: PrimitiveBuilder, - chain: PrimitiveBuilder, - draw: PrimitiveBuilder, - energy_error: PrimitiveBuilder, - unconstrained: Option>>, - gradient: Option>>, - hamiltonian: H, - adapt: A, - point: P, - diverging: BooleanBuilder, - divergence_start: Option>>, - divergence_start_grad: Option>>, - divergence_end: Option>>, - divergence_momentum: Option>>, - divergence_msg: Option, - n_dim: usize, -} - -impl NutsStatsBuilder { - pub fn new_with_capacity< - M: Math, - P: Point, - H: Hamiltonian, - A: AdaptStrategy, - >( - settings: &impl Settings, - hamiltonian: &H, - adapt: &A, - point: &P, - dim: usize, - options: &NutsOptions, - ) -> Self { - let capacity = settings.hint_num_tune() + settings.hint_num_draws(); - - let gradient = if options.store_gradient { - let items = PrimitiveBuilder::with_capacity(capacity); - Some(FixedSizeListBuilder::new(items, dim as i32)) - } else { - None - }; - - let unconstrained = if options.store_unconstrained { - let items = PrimitiveBuilder::with_capacity(capacity); - Some(FixedSizeListBuilder::with_capacity( - items, dim as i32, capacity, - )) - } else { - None - }; - - let (div_start, div_start_grad, div_end, div_mom, div_msg) = if options.store_divergences { - let start_location_prim = PrimitiveBuilder::new(); - let start_location_list = FixedSizeListBuilder::new(start_location_prim, dim as i32); - - let start_grad_prim = PrimitiveBuilder::new(); - let start_grad_list = FixedSizeListBuilder::new(start_grad_prim, dim as i32); - - let end_location_prim = PrimitiveBuilder::new(); - let end_location_list = FixedSizeListBuilder::new(end_location_prim, dim as i32); - - let momentum_location_prim = PrimitiveBuilder::new(); - let momentum_location_list = - 
FixedSizeListBuilder::new(momentum_location_prim, dim as i32); - - let msg_list = StringBuilder::new(); - - ( - Some(start_location_list), - Some(start_grad_list), - Some(end_location_list), - Some(momentum_location_list), - Some(msg_list), - ) - } else { - (None, None, None, None, None) - }; - - Self { - depth: PrimitiveBuilder::with_capacity(capacity), - maxdepth_reached: BooleanBuilder::with_capacity(capacity), - index_in_trajectory: PrimitiveBuilder::with_capacity(capacity), - logp: PrimitiveBuilder::with_capacity(capacity), - energy: PrimitiveBuilder::with_capacity(capacity), - chain: PrimitiveBuilder::with_capacity(capacity), - draw: PrimitiveBuilder::with_capacity(capacity), - energy_error: PrimitiveBuilder::with_capacity(capacity), - gradient, - unconstrained, - hamiltonian: hamiltonian.new_builder(settings, dim), - adapt: adapt.new_builder(settings, dim), - point: point.new_builder(settings, dim), - diverging: BooleanBuilder::with_capacity(capacity), - divergence_start: div_start, - divergence_start_grad: div_start_grad, - divergence_end: div_end, - divergence_momentum: div_mom, - divergence_msg: div_msg, - n_dim: dim, - } - } -} - -impl StatTraceBuilder> - for NutsStatsBuilder -where - HB: StatTraceBuilder, - AB: StatTraceBuilder, - PB: StatTraceBuilder, - HS: Clone + Send + Debug, - AS: Clone + Send + Debug, - PS: Clone + Send + Debug, -{ - fn append_value(&mut self, value: NutsSampleStats) { - let NutsSampleStats { - depth, - maxdepth_reached, - idx_in_trajectory, - logp, - energy, - energy_error, - divergence_info, - chain, - draw, - gradient, - unconstrained, - potential_stats, - strategy_stats, - point_stats, - tuning, - } = value; - - // We don't need to store tuning explicity - let _ = tuning; - - self.depth.append_value(depth); - self.maxdepth_reached.append_value(maxdepth_reached); - self.index_in_trajectory.append_value(idx_in_trajectory); - self.logp.append_value(logp); - self.energy.append_value(energy); - self.chain.append_value(chain); - 
self.draw.append_value(draw); - self.diverging.append_value(divergence_info.is_some()); - self.energy_error.append_value(energy_error); - - fn add_slice>( - store: &mut Option>>, - values: Option, - n_dim: usize, - ) { - let Some(store) = store.as_mut() else { - return; - }; - - if let Some(values) = values.as_ref() { - store.values().append_slice(values.as_ref()); - store.append(true); - } else { - store.values().append_nulls(n_dim); - store.append(false); - } - } - - add_slice(&mut self.gradient, gradient.as_ref(), self.n_dim); - add_slice(&mut self.unconstrained, unconstrained.as_ref(), self.n_dim); - - let div_info = divergence_info.as_ref(); - add_slice( - &mut self.divergence_start, - div_info.and_then(|info| info.start_location.as_ref()), - self.n_dim, - ); - add_slice( - &mut self.divergence_start_grad, - div_info.and_then(|info| info.start_gradient.as_ref()), - self.n_dim, - ); - add_slice( - &mut self.divergence_end, - div_info.and_then(|info| info.end_location.as_ref()), - self.n_dim, - ); - add_slice( - &mut self.divergence_momentum, - div_info.and_then(|info| info.start_momentum.as_ref()), - self.n_dim, - ); - - if let Some(div_msg) = self.divergence_msg.as_mut() { - if let Some(err) = div_info.and_then(|info| info.logp_function_error.as_ref()) { - div_msg.append_value(format!("{}", err)); - } else { - div_msg.append_null(); - } - } - - self.hamiltonian.append_value(potential_stats); - self.adapt.append_value(strategy_stats); - self.point.append_value(point_stats); - } - - fn finalize(self) -> Option { - let Self { - mut depth, - mut maxdepth_reached, - mut index_in_trajectory, - mut logp, - mut energy, - mut chain, - mut draw, - mut energy_error, - unconstrained, - gradient, - hamiltonian, - adapt, - point, - mut diverging, - divergence_start, - divergence_start_grad, - divergence_end, - divergence_momentum, - divergence_msg, - n_dim, - } = self; - - let _ = n_dim; - - let mut fields = vec![ - Field::new("depth", DataType::UInt64, false), - 
Field::new("maxdepth_reached", DataType::Boolean, false), - Field::new("index_in_trajectory", DataType::Int64, false), - Field::new("logp", DataType::Float64, false), - Field::new("energy", DataType::Float64, false), - Field::new("chain", DataType::UInt64, false), - Field::new("draw", DataType::UInt64, false), - Field::new("diverging", DataType::Boolean, false), - Field::new("energy_error", DataType::Float64, false), - ]; - - let mut arrays: Vec> = vec![ - ArrayBuilder::finish(&mut depth), - ArrayBuilder::finish(&mut maxdepth_reached), - ArrayBuilder::finish(&mut index_in_trajectory), - ArrayBuilder::finish(&mut logp), - ArrayBuilder::finish(&mut energy), - ArrayBuilder::finish(&mut chain), - ArrayBuilder::finish(&mut draw), - ArrayBuilder::finish(&mut diverging), - ArrayBuilder::finish(&mut energy_error), - ]; - - fn merge_into>( - builder: B, - arrays: &mut Vec>, - fields: &mut Vec, - ) { - let Some(struct_array) = builder.finalize() else { - return; - }; - - let (struct_fields, struct_arrays, bitmap) = struct_array.into_parts(); - assert!(bitmap.is_none()); - arrays.extend(struct_arrays); - fields.extend(struct_fields.into_iter().map(|x| x.deref().clone())); - } - - fn add_field( - mut builder: Option, - name: &str, - arrays: &mut Vec>, - fields: &mut Vec, - ) { - let Some(mut builder) = builder.take() else { - return; - }; - - let array = ArrayBuilder::finish(&mut builder); - fields.push(Field::new(name, array.data_type().clone(), true)); - arrays.push(array); - } - - merge_into(hamiltonian, &mut arrays, &mut fields); - merge_into(adapt, &mut arrays, &mut fields); - merge_into(point, &mut arrays, &mut fields); - - add_field(gradient, "gradient", &mut arrays, &mut fields); - add_field( - unconstrained, - "unconstrained_draw", - &mut arrays, - &mut fields, - ); - add_field( - divergence_start, - "divergence_start", - &mut arrays, - &mut fields, - ); - add_field( - divergence_start_grad, - "divergence_start_gradient", - &mut arrays, - &mut fields, - ); - 
add_field(divergence_end, "divergence_end", &mut arrays, &mut fields); - add_field( - divergence_momentum, - "divergence_momentum", - &mut arrays, - &mut fields, - ); - add_field( - divergence_msg, - "divergence_messagem", - &mut arrays, - &mut fields, - ); - - let fields = Fields::from(fields); - Some(StructArray::new(fields, arrays, None)) - } - - fn inspect(&self) -> Option { - let Self { - depth, - maxdepth_reached, - index_in_trajectory, - logp, - energy, - chain, - draw, - energy_error, - unconstrained, - gradient, - hamiltonian, - adapt, - point, - diverging, - divergence_start, - divergence_start_grad, - divergence_end, - divergence_momentum, - divergence_msg, - n_dim, - } = self; - - let _ = n_dim; - - let mut fields = vec![ - Field::new("depth", DataType::UInt64, false), - Field::new("maxdepth_reached", DataType::Boolean, false), - Field::new("index_in_trajectory", DataType::Int64, false), - Field::new("logp", DataType::Float64, false), - Field::new("energy", DataType::Float64, false), - Field::new("chain", DataType::UInt64, false), - Field::new("draw", DataType::UInt64, false), - Field::new("diverging", DataType::Boolean, false), - Field::new("energy_error", DataType::Float64, false), - ]; - - let mut arrays: Vec> = vec![ - ArrayBuilder::finish_cloned(depth), - ArrayBuilder::finish_cloned(maxdepth_reached), - ArrayBuilder::finish_cloned(index_in_trajectory), - ArrayBuilder::finish_cloned(logp), - ArrayBuilder::finish_cloned(energy), - ArrayBuilder::finish_cloned(chain), - ArrayBuilder::finish_cloned(draw), - ArrayBuilder::finish_cloned(diverging), - ArrayBuilder::finish_cloned(energy_error), - ]; - - fn merge_into>( - builder: &B, - arrays: &mut Vec>, - fields: &mut Vec, - ) { - let Some(struct_array) = builder.inspect() else { - return; - }; - - let (struct_fields, struct_arrays, bitmap) = struct_array.into_parts(); - assert!(bitmap.is_none()); - arrays.extend(struct_arrays); - fields.extend(struct_fields.into_iter().map(|x| x.deref().clone())); - } - - 
fn add_field( - builder: &Option, - name: &str, - arrays: &mut Vec>, - fields: &mut Vec, - ) { - let Some(builder) = builder.as_ref() else { - return; - }; - - let array = ArrayBuilder::finish_cloned(builder); - fields.push(Field::new(name, array.data_type().clone(), true)); - arrays.push(array); - } - - merge_into(hamiltonian, &mut arrays, &mut fields); - merge_into(adapt, &mut arrays, &mut fields); - merge_into(point, &mut arrays, &mut fields); - - add_field(gradient, "gradient", &mut arrays, &mut fields); - add_field( - unconstrained, - "unconstrained_draw", - &mut arrays, - &mut fields, - ); - add_field( - divergence_start, - "divergence_start", - &mut arrays, - &mut fields, - ); - add_field( - divergence_start_grad, - "divergence_start_gradient", - &mut arrays, - &mut fields, - ); - add_field(divergence_end, "divergence_end", &mut arrays, &mut fields); - add_field( - divergence_momentum, - "divergence_momentum", - &mut arrays, - &mut fields, - ); - add_field( - divergence_msg, - "divergence_messagem", - &mut arrays, - &mut fields, - ); - - let fields = Fields::from(fields); - Some(StructArray::new(fields, arrays, None)) - } -} - #[cfg(test)] mod tests { - use rand::thread_rng; + use rand::{rngs::ThreadRng, thread_rng}; use crate::{ adapt_strategy::test_logps::NormalLogp, + chain::NutsChain, cpu_math::CpuMath, sampler::DiagGradNutsSettings, sampler_stats::{SamplerStats, StatTraceBuilder}, @@ -814,13 +322,17 @@ mod tests { let mut chain = settings.new_chain(0, math, &mut rng); - let mut builder = chain.new_builder(&settings, ndim); + let opt_settings = settings.stats_options(&chain); + let mut builder = chain.new_builder(opt_settings, &settings, ndim); + let (_, mut progress) = chain.draw().unwrap(); for _ in 0..10 { - let (_, stats) = chain.draw().unwrap(); - builder.append_value(stats); + let (_, prog) = chain.draw().unwrap(); + progress = prog; + builder.append_value(None, &chain); } - builder.finalize(); + assert!(!progress.diverging); + 
StatTraceBuilder::<_, NutsChain<_, ThreadRng, _>>::finalize(builder); } } diff --git a/src/sampler.rs b/src/sampler.rs index 933c053..7481d0b 100644 --- a/src/sampler.rs +++ b/src/sampler.rs @@ -18,7 +18,7 @@ use std::{ use crate::{ adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy}, - chain::{AdaptStrategy, Chain, NutsChain}, + chain::{AdaptStrategy, Chain, NutsChain, StatOptions}, euclidean_hamiltonian::EuclideanHamiltonian, low_rank_mass_matrix::{LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings}, mass_matrix::DiagMassMatrix, @@ -27,8 +27,8 @@ use crate::{ nuts::NutsOptions, sampler_stats::{SamplerStats, StatTraceBuilder}, transform_adapt_strategy::{TransformAdaptation, TransformedSettings}, - transformed_hamiltonian::TransformedHamiltonian, - DiagAdaptExpSettings, SampleStats, + transformed_hamiltonian::{TransformedHamiltonian, TransformedPointStatsOptions}, + DiagAdaptExpSettings, }; /// All sampler configurations implement this trait @@ -42,14 +42,25 @@ pub trait Settings: private::Sealed + Clone + Copy + Default + Sync + Send + 'st rng: &mut R, ) -> Self::Chain; - fn sample_stats( - &self, - stats: & as SamplerStats>::Stats, - ) -> SampleStats; fn hint_num_tune(&self) -> usize; fn hint_num_draws(&self) -> usize; fn num_chains(&self) -> usize; fn seed(&self) -> u64; + fn stats_options( + &self, + chain: &Self::Chain, + ) -> as SamplerStats>::StatOptions; +} + +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct Progress { + pub draw: u64, + pub chain: u64, + pub diverging: bool, + pub tuning: bool, + pub step_size: f64, + pub num_steps: u64, } mod private { @@ -187,22 +198,6 @@ impl Settings for LowRankNutsSettings { NutsChain::new(math, potential, strategy, options, rng, chain) } - fn sample_stats( - &self, - stats: & as SamplerStats>::Stats, - ) -> SampleStats { - let step_size = stats.potential_stats.step_size; - let num_steps = stats.strategy_stats.stats1.n_steps; - SampleStats { - chain: stats.chain, - draw: stats.draw, - diverging: 
stats.divergence_info.is_some(), - tuning: stats.tuning, - step_size, - num_steps, - } - } - fn hint_num_tune(&self) -> usize { self.num_tune as _ } @@ -218,6 +213,17 @@ impl Settings for LowRankNutsSettings { fn seed(&self) -> u64 { self.seed } + + fn stats_options( + &self, + _chain: &Self::Chain, + ) -> as SamplerStats>::StatOptions { + StatOptions { + adapt: (), + hamiltonian: (), + point: (), + } + } } impl Settings for DiagGradNutsSettings { @@ -251,22 +257,6 @@ impl Settings for DiagGradNutsSettings { NutsChain::new(math, potential, strategy, options, rng, chain) } - fn sample_stats( - &self, - stats: & as SamplerStats>::Stats, - ) -> SampleStats { - let step_size = stats.potential_stats.step_size; - let num_steps = stats.strategy_stats.stats1.n_steps; - SampleStats { - chain: stats.chain, - draw: stats.draw, - diverging: stats.divergence_info.is_some(), - tuning: stats.tuning, - step_size, - num_steps, - } - } - fn hint_num_tune(&self) -> usize { self.num_tune as _ } @@ -282,6 +272,17 @@ impl Settings for DiagGradNutsSettings { fn seed(&self) -> u64 { self.seed } + + fn stats_options( + &self, + _chain: &Self::Chain, + ) -> as SamplerStats>::StatOptions { + StatOptions { + adapt: (), + hamiltonian: (), + point: (), + } + } } impl Settings for TransformedNutsSettings { @@ -311,22 +312,6 @@ impl Settings for TransformedNutsSettings { NutsChain::new(math, hamiltonian, strategy, options, rng, chain) } - fn sample_stats( - &self, - stats: & as SamplerStats>::Stats, - ) -> SampleStats { - let step_size = stats.potential_stats.step_size; - let num_steps = stats.strategy_stats.step_size.n_steps; - SampleStats { - chain: stats.chain, - draw: stats.draw, - diverging: stats.divergence_info.is_some(), - tuning: stats.tuning, - step_size, - num_steps, - } - } - fn hint_num_tune(&self) -> usize { self.num_tune as _ } @@ -342,6 +327,21 @@ impl Settings for TransformedNutsSettings { fn seed(&self) -> u64 { self.seed } + + fn stats_options( + &self, + _chain: &Self::Chain, 
+ ) -> as SamplerStats>::StatOptions { + // TODO make extra config + let point = TransformedPointStatsOptions { + store_transformed: self.store_unconstrained, + }; + StatOptions { + adapt: (), + hamiltonian: (), + point, + } + } } pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>( @@ -351,14 +351,10 @@ pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>( draws: u64, chain: u64, rng: &mut R, -) -> Result, SampleStats)>> + 'math> { +) -> Result, Progress)>> + 'math> { let mut sampler = settings.new_chain(chain, math, rng); sampler.set_position(start)?; - Ok((0..draws).map(move |_| { - sampler - .draw() - .map(|(point, info)| (point, settings.sample_stats::(&info))) - })) + Ok((0..draws).map(move |_| sampler.draw())) } pub trait DrawStorage: Send { @@ -416,7 +412,7 @@ impl ChainProgress { } } - fn update(&mut self, stats: &SampleStats, draw_duration: Duration) { + fn update(&mut self, stats: &Progress, draw_duration: Duration) { if stats.diverging & !stats.tuning { self.divergences += 1; self.divergent_draws.push(self.finished_draws); @@ -558,7 +554,8 @@ impl<'scope, M: Model, S: Settings> ChainProcess<'scope, M, S> { let draw_trace = model .new_trace(&mut rng, chain_id, settings) .context("Failed to create trace object")?; - let stats_trace = sampler.new_builder(settings, dim); + let stat_opts = settings.stats_options(&sampler); + let stats_trace = sampler.new_builder(stat_opts, settings, dim); let new_trace = ChainTrace { draws_builder: draw_trace, @@ -618,9 +615,9 @@ impl<'scope, M: Model, S: Settings> ChainProcess<'scope, M, S> { progress .lock() .expect("Poisoned mutex") - .update(&settings.sample_stats(&info), now.elapsed()); + .update(&info, now.elapsed()); DrawStorage::append_value(&mut val.draws_builder, &point)?; - StatTraceBuilder::append_value(&mut val.stats_builder, info); + StatTraceBuilder::append_value(&mut val.stats_builder, None, &sampler); draw += 1; if draw == draws { break; @@ -943,6 +940,7 @@ pub mod test_logps { 
#[derive(Error, Debug)] pub enum NormalLogpError {} + impl LogpError for NormalLogpError { fn is_recoverable(&self) -> bool { false @@ -1054,6 +1052,7 @@ pub mod test_logps { _rng: &mut R, _untransformed_positions: impl Iterator, _untransformed_gradients: impl Iterator, + _untransformed_logp: impl Iterator, _params: &'b mut Self::TransformParams, ) -> std::result::Result<(), Self::LogpError> { unimplemented!() @@ -1180,7 +1179,7 @@ mod tests { let mut chain = settings.new_chain(0, math, &mut rng); let (_draw, info) = chain.draw()?; - assert!(settings.sample_stats::>(&info).tuning); + assert!(info.tuning); assert_eq!(info.draw, 0); let math = CpuMath::new(&logp); diff --git a/src/sampler_stats.rs b/src/sampler_stats.rs index 224a9a7..3b3b457 100644 --- a/src/sampler_stats.rs +++ b/src/sampler_stats.rs @@ -1,19 +1,27 @@ -use std::fmt::Debug; - use arrow::array::StructArray; use crate::{Math, Settings}; pub trait SamplerStats { - type Stats: Send + Debug + Clone; - type Builder: StatTraceBuilder; + type Builder: StatTraceBuilder; + type StatOptions; + + fn new_builder( + &self, + options: Self::StatOptions, + settings: &impl Settings, + dim: usize, + ) -> Self::Builder; +} - fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder; - fn current_stats(&self, math: &mut M) -> Self::Stats; +pub trait StatTraceBuilder: Send { + fn append_value(&mut self, math: Option<&mut M>, value: &T); + fn finalize(self) -> Option; + fn inspect(&self) -> Option; } -impl StatTraceBuilder<()> for () { - fn append_value(&mut self, _value: ()) {} +impl StatTraceBuilder for () { + fn append_value(&mut self, _math: Option<&mut M>, _value: &T) {} fn finalize(self) -> Option { None @@ -23,9 +31,3 @@ impl StatTraceBuilder<()> for () { None } } - -pub trait StatTraceBuilder: Send { - fn append_value(&mut self, value: T); - fn finalize(self) -> Option; - fn inspect(&self) -> Option; -} diff --git a/src/stepsize_adapt.rs b/src/stepsize_adapt.rs index 16211c5..f8323e9 100644 --- 
a/src/stepsize_adapt.rs +++ b/src/stepsize_adapt.rs @@ -15,9 +15,9 @@ use crate::{ pub struct Strategy { step_size_adapt: DualAverage, options: DualAverageSettings, - last_mean_tree_accept: f64, - last_sym_mean_tree_accept: f64, - last_n_steps: u64, + pub last_mean_tree_accept: f64, + pub last_sym_mean_tree_accept: f64, + pub last_n_steps: u64, } impl Strategy { @@ -133,14 +133,6 @@ impl Strategy { } } -#[derive(Debug, Clone, Copy)] -pub struct Stats { - pub step_size_bar: f64, - pub mean_tree_accept: f64, - pub mean_tree_accept_sym: f64, - pub n_steps: u64, -} - pub struct StatsBuilder { step_size_bar: PrimitiveBuilder, mean_tree_accept: PrimitiveBuilder, @@ -148,13 +140,15 @@ pub struct StatsBuilder { n_steps: PrimitiveBuilder, } -impl StatTraceBuilder for StatsBuilder { - fn append_value(&mut self, value: Stats) { - self.step_size_bar.append_value(value.step_size_bar); - self.mean_tree_accept.append_value(value.mean_tree_accept); +impl StatTraceBuilder for StatsBuilder { + fn append_value(&mut self, _math: Option<&mut M>, value: &Strategy) { + self.step_size_bar + .append_value(value.step_size_adapt.current_step_size_adapted()); + self.mean_tree_accept + .append_value(value.last_mean_tree_accept); self.mean_tree_accept_sym - .append_value(value.mean_tree_accept_sym); - self.n_steps.append_value(value.n_steps); + .append_value(value.last_sym_mean_tree_accept); + self.n_steps.append_value(value.last_n_steps); } fn finalize(self) -> Option { @@ -210,9 +204,14 @@ impl StatTraceBuilder for StatsBuilder { impl SamplerStats for Strategy { type Builder = StatsBuilder; - type Stats = Stats; - - fn new_builder(&self, _settings: &impl Settings, _dim: usize) -> Self::Builder { + type StatOptions = (); + + fn new_builder( + &self, + _stat_options: Self::StatOptions, + _settings: &impl Settings, + _dim: usize, + ) -> Self::Builder { Self::Builder { step_size_bar: PrimitiveBuilder::new(), mean_tree_accept: PrimitiveBuilder::new(), @@ -220,15 +219,6 @@ impl SamplerStats for 
Strategy { n_steps: PrimitiveBuilder::new(), } } - - fn current_stats(&self, _math: &mut M) -> Self::Stats { - Stats { - step_size_bar: self.step_size_adapt.current_step_size_adapted(), - mean_tree_accept: self.last_mean_tree_accept, - mean_tree_accept_sym: self.last_sym_mean_tree_accept, - n_steps: self.last_n_steps, - } - } } #[derive(Debug, Clone, Copy)] diff --git a/src/transform_adapt_strategy.rs b/src/transform_adapt_strategy.rs index dbf666d..a551e04 100644 --- a/src/transform_adapt_strategy.rs +++ b/src/transform_adapt_strategy.rs @@ -7,9 +7,7 @@ use crate::nuts::{Collector, NutsOptions, SampleInfo}; use crate::sampler_stats::{SamplerStats, StatTraceBuilder}; use crate::state::State; use crate::stepsize::AcceptanceRateCollector; -use crate::stepsize_adapt::{ - Stats as StepSizeStats, StatsBuilder as StepSizeStatsBuilder, Strategy as StepSizeStrategy, -}; +use crate::stepsize_adapt::{StatsBuilder as StepSizeStatsBuilder, Strategy as StepSizeStrategy}; use crate::transformed_hamiltonian::TransformedHamiltonian; use crate::{DualAverageSettings, Math, NutsError, Settings}; @@ -43,50 +41,46 @@ pub struct TransformAdaptation { chain: u64, } -#[derive(Clone, Debug)] -pub struct Stats { - pub step_size: StepSizeStats, -} - pub struct Builder { step_size: StepSizeStatsBuilder, } -impl StatTraceBuilder for Builder { - fn append_value(&mut self, value: Stats) { - let Stats { step_size } = value; - self.step_size.append_value(step_size); +impl StatTraceBuilder for Builder { + fn append_value(&mut self, math: Option<&mut M>, value: &TransformAdaptation) { + let Self { step_size } = self; + step_size.append_value(math, &value.step_size); } fn finalize(self) -> Option { let Self { step_size } = self; - step_size.finalize() + >::finalize(step_size) } fn inspect(&self) -> Option { let Self { step_size } = self; - step_size.inspect() + >::inspect(step_size) } } impl SamplerStats for TransformAdaptation { - type Stats = Stats; type Builder = Builder; - - fn new_builder(&self, 
settings: &impl Settings, dim: usize) -> Self::Builder { - let step_size = SamplerStats::::new_builder(&self.step_size, settings, dim); + type StatOptions = (); + + fn new_builder( + &self, + _stat_options: Self::StatOptions, + settings: &impl Settings, + dim: usize, + ) -> Self::Builder { + let step_size = SamplerStats::::new_builder(&self.step_size, (), settings, dim); Builder { step_size } } - - fn current_stats(&self, math: &mut M) -> Self::Stats { - let step_size = self.step_size.current_stats(math); - Stats { step_size } - } } pub struct DrawCollector { draws: Vec, grads: Vec, + logps: Vec, collect_orbit: bool, max_energy_error: f64, } @@ -96,6 +90,7 @@ impl DrawCollector { Self { draws: vec![], grads: vec![], + logps: vec![], collect_orbit, max_energy_error, } @@ -134,6 +129,7 @@ impl> Collector for DrawCollector { self.draws.push(math.copy_array(point.position())); self.grads.push(math.copy_array(point.gradient())); + self.logps.push(point.logp()); } } @@ -158,6 +154,7 @@ impl> Collector for DrawCollector { self.draws.push(math.copy_array(point.position())); self.grads.push(math.copy_array(point.gradient())); + self.logps.push(point.logp()); } } } @@ -227,6 +224,7 @@ impl AdaptStrategy for TransformAdaptation { rng, collector.collector2.draws.iter(), collector.collector2.grads.iter(), + collector.collector2.logps.iter(), )?; } } else if (draw > 0) & (draw % self.options.transform_update_freq == 0) { @@ -235,6 +233,7 @@ impl AdaptStrategy for TransformAdaptation { rng, collector.collector2.draws.iter(), collector.collector2.grads.iter(), + collector.collector2.logps.iter(), )?; } self.step_size.update_estimator_early(); @@ -262,4 +261,8 @@ impl AdaptStrategy for TransformAdaptation { fn is_tuning(&self) -> bool { self.tuning } + + fn last_num_steps(&self) -> u64 { + self.step_size.last_n_steps + } } diff --git a/src/transformed_hamiltonian.rs b/src/transformed_hamiltonian.rs index 3518361..93581a7 100644 --- a/src/transformed_hamiltonian.rs +++ 
b/src/transformed_hamiltonian.rs @@ -1,8 +1,11 @@ use std::{marker::PhantomData, sync::Arc}; use arrow::{ - array::{ArrayBuilder, Float64Builder, StructArray}, - datatypes::{DataType, Field}, + array::{ + ArrayBuilder, FixedSizeListBuilder, Float64Builder, Int64Builder, PrimitiveBuilder, + StructArray, + }, + datatypes::{DataType, Field, Float64Type, Int64Type}, }; use crate::{ @@ -26,59 +29,153 @@ pub struct TransformedPoint { transform_id: i64, } -#[derive(Clone, Debug)] -pub struct TransformedPointStats { - pub fisher_distance: f64, -} - pub struct TransformedPointStatsBuilder { fisher_distance: Float64Builder, + transformed_position: Option>>, + transformed_gradient: Option>>, + transformation_index: PrimitiveBuilder, } -impl StatTraceBuilder for TransformedPointStatsBuilder { - fn append_value(&mut self, value: TransformedPointStats) { - let TransformedPointStats { fisher_distance } = value; +impl StatTraceBuilder> for TransformedPointStatsBuilder { + fn append_value(&mut self, math: Option<&mut M>, value: &TransformedPoint) { + let math = math.expect("Transformed point stats need math instance"); + let Self { + fisher_distance, + transformed_position, + transformed_gradient, + transformation_index, + } = self; + + fisher_distance.append_value( + math.sq_norm_sum(&value.transformed_position, &value.transformed_gradient), + ); + transformation_index.append_value(value.transform_id); - self.fisher_distance.append_value(fisher_distance); + if let Some(store) = transformed_position { + store + .values() + .append_slice(&math.box_array(&value.transformed_position)); + store.append(true); + } + if let Some(store) = transformed_gradient { + store + .values() + .append_slice(&math.box_array(&value.transformed_gradient)); + store.append(true); + } } fn finalize(self) -> Option { let Self { mut fisher_distance, + transformed_position, + transformed_gradient, + mut transformation_index, } = self; - let fields = vec![Field::new("fisher_distance", DataType::Float64, 
false)]; - let arrays = vec![ArrayBuilder::finish(&mut fisher_distance)]; + let mut fields = vec![ + Field::new("fisher_distance", DataType::Float64, false), + Field::new("transformation_index", DataType::Int64, false), + ]; + let mut arrays = vec![ + ArrayBuilder::finish(&mut fisher_distance), + ArrayBuilder::finish(&mut transformation_index), + ]; + + if let Some(mut transformed_position) = transformed_position { + let array = ArrayBuilder::finish(&mut transformed_position); + fields.push(Field::new( + "transformed_position", + array.data_type().clone(), + true, + )); + arrays.push(array); + } + + if let Some(mut transformed_gradient) = transformed_gradient { + let array = ArrayBuilder::finish(&mut transformed_gradient); + fields.push(Field::new( + "transformed_gradient", + array.data_type().clone(), + true, + )); + arrays.push(array); + } Some(StructArray::new(fields.into(), arrays, None)) } fn inspect(&self) -> Option { - let Self { fisher_distance } = self; + let Self { + fisher_distance, + transformed_position, + transformed_gradient, + transformation_index, + } = self; - let fields = vec![Field::new("fisher_distance", DataType::Float64, false)]; - let arrays = vec![ArrayBuilder::finish_cloned(fisher_distance)]; + let mut fields = vec![ + Field::new("fisher_distance", DataType::Float64, false), + Field::new("transformation_index", DataType::Int64, false), + ]; + let mut arrays = vec![ + ArrayBuilder::finish_cloned(fisher_distance), + ArrayBuilder::finish_cloned(transformation_index), + ]; + + if let Some(transformed_position) = transformed_position { + let array = ArrayBuilder::finish_cloned(transformed_position); + fields.push(Field::new( + "transformed_position", + array.data_type().clone(), + true, + )); + arrays.push(array); + } + + if let Some(transformed_gradient) = transformed_gradient { + let array = ArrayBuilder::finish_cloned(transformed_gradient); + fields.push(Field::new( + "transformed_gradient", + array.data_type().clone(), + true, + )); + 
arrays.push(array); + } Some(StructArray::new(fields.into(), arrays, None)) } } +#[derive(Debug, Clone, Copy)] +pub struct TransformedPointStatsOptions { + pub store_transformed: bool, +} + impl SamplerStats for TransformedPoint { - type Stats = TransformedPointStats; type Builder = TransformedPointStatsBuilder; + type StatOptions = TransformedPointStatsOptions; - fn new_builder(&self, settings: &impl Settings, _dim: usize) -> Self::Builder { - TransformedPointStatsBuilder { - fisher_distance: Float64Builder::with_capacity( - settings.hint_num_tune() + settings.hint_num_draws(), - ), + fn new_builder( + &self, + stat_options: Self::StatOptions, + settings: &impl Settings, + dim: usize, + ) -> Self::Builder { + let count = settings.hint_num_tune() + settings.hint_num_draws(); + + let mut transformed_position = None; + let mut transformed_gradient = None; + if stat_options.store_transformed { + let items = PrimitiveBuilder::new(); + transformed_position = Some(FixedSizeListBuilder::new(items, dim as _)); + let items = PrimitiveBuilder::new(); + transformed_gradient = Some(FixedSizeListBuilder::new(items, dim as _)); } - } - - fn current_stats(&self, math: &mut M) -> Self::Stats { - TransformedPointStats { - fisher_distance: math - .sq_norm_sum(&self.transformed_position, &self.transformed_gradient), + TransformedPointStatsBuilder { + fisher_distance: Float64Builder::with_capacity(count), + transformation_index: Int64Builder::with_capacity(count), + transformed_gradient, + transformed_position, } } } @@ -290,11 +387,13 @@ impl TransformedHamiltonian { rng: &mut R, draws: impl ExactSizeIterator, grads: impl ExactSizeIterator, + logps: impl ExactSizeIterator, ) -> Result<(), NutsError> { math.update_transformation( rng, draws, grads, + logps, self.params.as_mut().expect("Transformation was empty"), ) .map_err(|e| NutsError::BadInitGrad(Box::new(e)))?; @@ -302,19 +401,14 @@ impl TransformedHamiltonian { } } -#[derive(Debug, Clone, Default)] -pub struct Stats { - pub 
step_size: f64, -} - pub struct Builder { step_size: Float64Builder, } -impl StatTraceBuilder for Builder { - fn append_value(&mut self, value: Stats) { - let Stats { step_size } = value; - self.step_size.append_value(step_size); +impl StatTraceBuilder> for Builder { + fn append_value(&mut self, _math: Option<&mut M>, value: &TransformedHamiltonian) { + let Self { step_size } = self; + step_size.append_value(value.step_size); } fn finalize(self) -> Option { @@ -337,22 +431,21 @@ impl StatTraceBuilder for Builder { } impl SamplerStats for TransformedHamiltonian { - type Stats = Stats; type Builder = Builder; + type StatOptions = (); - fn new_builder(&self, settings: &impl Settings, _dim: usize) -> Self::Builder { + fn new_builder( + &self, + _stat_options: Self::StatOptions, + settings: &impl Settings, + _dim: usize, + ) -> Self::Builder { Builder { step_size: Float64Builder::with_capacity( settings.hint_num_draws() + settings.hint_num_tune(), ), } } - - fn current_stats(&self, _math: &mut M) -> Self::Stats { - Stats { - step_size: self.step_size, - } - } } impl Hamiltonian for TransformedHamiltonian { diff --git a/tests/sample_normal.rs b/tests/sample_normal.rs index 895d220..a6288bc 100644 --- a/tests/sample_normal.rs +++ b/tests/sample_normal.rs @@ -93,6 +93,7 @@ impl<'a> CpuLogpFunc for NormalLogp<'a> { _rng: &mut R, _untransformed_positions: impl Iterator, _untransformed_gradients: impl Iterator, + _logps: impl Iterator, _params: &'b mut Self::TransformParams, ) -> Result<(), Self::LogpError> { todo!()