Skip to content

Commit

Permalink
feat: Add sampler stats for points
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Oct 25, 2024
1 parent c416c42 commit 2a55602
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 31 deletions.
18 changes: 15 additions & 3 deletions src/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@ where
draw_count: u64,
strategy: A,
math: M,
stats: Option<NutsSampleStats<<A::Hamiltonian as SamplerStats<M>>::Stats, A::Stats>>,
stats: Option<
NutsSampleStats<
<<A::Hamiltonian as Hamiltonian<M>>::Point as SamplerStats<M>>::Stats,
<A::Hamiltonian as SamplerStats<M>>::Stats,
A::Stats,
>,
>,
}

impl<M, R, A> NutsChain<M, R, A>
Expand Down Expand Up @@ -117,17 +123,22 @@ where
A: AdaptStrategy<M>,
{
type Builder = NutsStatsBuilder<
<<A::Hamiltonian as Hamiltonian<M>>::Point as SamplerStats<M>>::Builder,
<A::Hamiltonian as SamplerStats<M>>::Builder,
<A as SamplerStats<M>>::Builder,
>;
type Stats =
NutsSampleStats<<A::Hamiltonian as SamplerStats<M>>::Stats, <A as SamplerStats<M>>::Stats>;
type Stats = NutsSampleStats<
<<A::Hamiltonian as Hamiltonian<M>>::Point as SamplerStats<M>>::Stats,
<A::Hamiltonian as SamplerStats<M>>::Stats,
<A as SamplerStats<M>>::Stats,
>;

fn new_builder(&self, settings: &impl Settings, dim: usize) -> Self::Builder {
NutsStatsBuilder::new_with_capacity(
settings,
&self.hamiltonian,
&self.strategy,
self.init.point(),
dim,
&self.options,
)
Expand Down Expand Up @@ -182,6 +193,7 @@ where
draw: self.draw_count,
potential_stats: self.hamiltonian.current_stats(&mut self.math),
strategy_stats: self.strategy.current_stats(&mut self.math),
point_stats: state.point().current_stats(&mut self.math),
gradient: if self.options.store_gradient {
let mut gradient: Box<[f64]> = vec![0f64; self.math.dim()].into();
state.write_gradient(&mut self.math, &mut gradient);
Expand Down
8 changes: 8 additions & 0 deletions src/cpu_math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
)
}

fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64 {
x.as_slice()
.iter()
.zip(y.as_slice())
.map(|(&x, &y)| (x + y) * (x + y))
.sum()
}

fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]) {
dest.as_slice_mut().copy_from_slice(source);
}
Expand Down
34 changes: 34 additions & 0 deletions src/euclidean_hamiltonian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,40 @@ pub struct EuclideanPoint<M: Math> {
pub initial_energy: f64,
}

#[derive(Clone, Debug)]
pub struct PointStats {}

pub struct PointStatsBuilder {}

impl StatTraceBuilder<PointStats> for PointStatsBuilder {
fn append_value(&mut self, value: PointStats) {
let PointStats {} = value;
}

fn finalize(self) -> Option<StructArray> {
let Self {} = self;
None
}

fn inspect(&self) -> Option<StructArray> {
let Self {} = self;
None
}
}

impl<M: Math> SamplerStats<M> for EuclideanPoint<M> {
type Stats = PointStats;
type Builder = PointStatsBuilder;

fn new_builder(&self, _settings: &impl Settings, _dim: usize) -> Self::Builder {
Self::Builder {}
}

fn current_stats(&self, _math: &mut M) -> Self::Stats {
PointStats {}
}
}

impl<M: Math> EuclideanPoint<M> {
fn is_turning(&self, math: &mut M, other: &Self) -> bool {
let (start, end) = if self.index_in_trajectory() < other.index_in_trajectory() {
Expand Down
2 changes: 1 addition & 1 deletion src/hamiltonian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pub enum LeapfrogResult<M: Math, P: Point<M>> {
Err(M::LogpErr),
}

pub trait Point<M: Math>: Sized {
pub trait Point<M: Math>: Sized + SamplerStats<M> {
fn position(&self) -> &M::Vector;
fn gradient(&self) -> &M::Vector;
fn index_in_trajectory(&self) -> i64;
Expand Down
2 changes: 2 additions & 0 deletions src/math_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ pub trait Math {
y: &Self::Vector,
) -> (f64, f64);

fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64;

fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]);
fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]);
fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]>;
Expand Down
30 changes: 24 additions & 6 deletions src/nuts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,11 @@ where

#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct NutsSampleStats<HStats: Send + Debug + Clone, AdaptStats: Send + Debug + Clone> {
pub struct NutsSampleStats<
PointStats: Send + Debug + Clone,
HStats: Send + Debug + Clone,
AdaptStats: Send + Debug + Clone,
> {
pub depth: u64,
pub maxdepth_reached: bool,
pub idx_in_trajectory: i64,
Expand All @@ -324,6 +328,7 @@ pub struct NutsSampleStats<HStats: Send + Debug + Clone, AdaptStats: Send + Debu
pub unconstrained: Option<Box<[f64]>>,
pub potential_stats: HStats,
pub strategy_stats: AdaptStats,
pub point_stats: PointStats,
pub tuning: bool,
}

Expand All @@ -338,7 +343,7 @@ pub struct SampleStats {
pub num_steps: u64,
}

pub struct NutsStatsBuilder<H, A> {
pub struct NutsStatsBuilder<P, H, A> {
depth: PrimitiveBuilder<UInt64Type>,
maxdepth_reached: BooleanBuilder,
index_in_trajectory: PrimitiveBuilder<Int64Type>,
Expand All @@ -351,6 +356,7 @@ pub struct NutsStatsBuilder<H, A> {
gradient: Option<FixedSizeListBuilder<PrimitiveBuilder<Float64Type>>>,
hamiltonian: H,
adapt: A,
point: P,
diverging: BooleanBuilder,
divergence_start: Option<FixedSizeListBuilder<PrimitiveBuilder<Float64Type>>>,
divergence_start_grad: Option<FixedSizeListBuilder<PrimitiveBuilder<Float64Type>>>,
Expand All @@ -360,15 +366,17 @@ pub struct NutsStatsBuilder<H, A> {
n_dim: usize,
}

impl<HB, AB> NutsStatsBuilder<HB, AB> {
impl<PB, HB, AB> NutsStatsBuilder<PB, HB, AB> {
pub fn new_with_capacity<
M: Math,
H: Hamiltonian<M, Builder = HB>,
P: Point<M, Builder = PB>,
H: Hamiltonian<M, Builder = HB, Point = P>,
A: AdaptStrategy<M, Builder = AB>,
>(
settings: &impl Settings,
hamiltonian: &H,
adapt: &A,
point: &P,
dim: usize,
options: &NutsOptions,
) -> Self {
Expand Down Expand Up @@ -430,6 +438,7 @@ impl<HB, AB> NutsStatsBuilder<HB, AB> {
unconstrained,
hamiltonian: hamiltonian.new_builder(settings, dim),
adapt: adapt.new_builder(settings, dim),
point: point.new_builder(settings, dim),
diverging: BooleanBuilder::with_capacity(capacity),
divergence_start: div_start,
divergence_start_grad: div_start_grad,
Expand All @@ -441,14 +450,17 @@ impl<HB, AB> NutsStatsBuilder<HB, AB> {
}
}

impl<HS, AS, HB, AB> StatTraceBuilder<NutsSampleStats<HS, AS>> for NutsStatsBuilder<HB, AB>
impl<PS, HS, AS, PB, HB, AB> StatTraceBuilder<NutsSampleStats<PS, HS, AS>>
for NutsStatsBuilder<PB, HB, AB>
where
HB: StatTraceBuilder<HS>,
AB: StatTraceBuilder<AS>,
PB: StatTraceBuilder<PS>,
HS: Clone + Send + Debug,
AS: Clone + Send + Debug,
PS: Clone + Send + Debug,
{
fn append_value(&mut self, value: NutsSampleStats<HS, AS>) {
fn append_value(&mut self, value: NutsSampleStats<PS, HS, AS>) {
let NutsSampleStats {
depth,
maxdepth_reached,
Expand All @@ -463,6 +475,7 @@ where
unconstrained,
potential_stats,
strategy_stats,
point_stats,
tuning,
} = value;

Expand Down Expand Up @@ -532,6 +545,7 @@ where

self.hamiltonian.append_value(potential_stats);
self.adapt.append_value(strategy_stats);
self.point.append_value(point_stats);
}

fn finalize(self) -> Option<StructArray> {
Expand All @@ -548,6 +562,7 @@ where
gradient,
hamiltonian,
adapt,
point,
mut diverging,
divergence_start,
divergence_start_grad,
Expand Down Expand Up @@ -615,6 +630,7 @@ where

merge_into(hamiltonian, &mut arrays, &mut fields);
merge_into(adapt, &mut arrays, &mut fields);
merge_into(point, &mut arrays, &mut fields);

add_field(gradient, "gradient", &mut arrays, &mut fields);
add_field(
Expand Down Expand Up @@ -667,6 +683,7 @@ where
gradient,
hamiltonian,
adapt,
point,
diverging,
divergence_start,
divergence_start_grad,
Expand Down Expand Up @@ -734,6 +751,7 @@ where

merge_into(hamiltonian, &mut arrays, &mut fields);
merge_into(adapt, &mut arrays, &mut fields);
merge_into(point, &mut arrays, &mut fields);

add_field(gradient, "gradient", &mut arrays, &mut fields);
add_field(
Expand Down
4 changes: 2 additions & 2 deletions src/sampler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ pub mod test_logps {
}
}

impl<'a> CpuLogpFunc for &'a NormalLogp {
impl CpuLogpFunc for &NormalLogp {
type LogpError = NormalLogpError;
type TransformParams = ();

Expand Down Expand Up @@ -981,7 +981,7 @@ pub mod test_logps {
for (p, g) in pos.chunks_exact(4).zip(grad.chunks_exact_mut(4)) {
let p = f64x4::from_slice(p);
let val = mu_splat - p;
logp = logp - val * val * f64x4::splat(0.5);
logp = val * val * f64x4::splat(0.5);
g.copy_from_slice(&val.to_array());
}

Expand Down
56 changes: 37 additions & 19 deletions src/transform_adapt_strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,38 +108,56 @@ impl<M: Math, P: Point<M>> Collector<M, P> for DrawCollector<M> {
math: &mut M,
_start: &State<M, P>,
end: &State<M, P>,
_divergence_info: Option<&crate::DivergenceInfo>,
divergence_info: Option<&crate::DivergenceInfo>,
) {
if divergence_info.is_some() {
return;
}

if self.collect_orbit {
let point = end.point();
let energy_error = point.energy_error();
if energy_error.abs() < self.max_energy_error {
if !math.array_all_finite(point.position()) {
return;
}
if !math.array_all_finite(point.gradient()) {
return;
}
self.draws.push(math.copy_array(point.position()));
self.grads.push(math.copy_array(point.gradient()));
if !energy_error.is_finite() {
return;
}

if energy_error > self.max_energy_error {
return;
}

if !math.array_all_finite(point.position()) {
return;
}
if !math.array_all_finite(point.gradient()) {
return;
}

self.draws.push(math.copy_array(point.position()));
self.grads.push(math.copy_array(point.gradient()));
}
}

fn register_draw(&mut self, math: &mut M, state: &State<M, P>, _info: &SampleInfo) {
if !self.collect_orbit {
let point = state.point();
let energy_error = point.energy_error();
if energy_error.abs() < self.max_energy_error {
if !math.array_all_finite(point.position()) {
return;
}
if !math.array_all_finite(point.gradient()) {
return;
}
self.draws.push(math.copy_array(point.position()));
self.grads.push(math.copy_array(point.gradient()));
if !energy_error.is_finite() {
return;
}

if energy_error > self.max_energy_error {
return;
}

if !math.array_all_finite(point.position()) {
return;
}
if !math.array_all_finite(point.gradient()) {
return;
}

self.draws.push(math.copy_array(point.position()));
self.grads.push(math.copy_array(point.gradient()));
}
}
}
Expand Down
Loading

0 comments on commit 2a55602

Please sign in to comment.