Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prepare for GPU sampling support #17

Merged
merged 3 commits into from
Oct 25, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
docs: Update README
  • Loading branch information
aseyboldt committed Oct 9, 2023
commit 6031ee0e4bbed9b73144fbb0a06f88507e617674
12 changes: 4 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@ for this sampler is [nutpie](https://github.com/pymc-devs/nutpie).
```rust
use nuts_rs::{CpuLogpFunc, LogpError, new_sampler, SamplerArgs, Chain, SampleStats};
use thiserror::Error;
use rand::thread_rng;

// Define a function that computes the unnormalized posterior density
// and its gradient.
@@ -55,32 +56,27 @@ impl CpuLogpFunc for PosteriorDensity {
let mut sampler_args = SamplerArgs::default();

// and modify as we like
sampler_args.step_size_adapt.target = 0.8;
sampler_args.num_tune = 1000;
sampler_args.maxdepth = 3; // small value just for testing...
sampler_args.mass_matrix_adapt.store_mass_matrix = true;

// We instanciate our posterior density function
let logp_func = PosteriorDensity {};

let chain = 0;
let seed = 42;
let mut sampler = new_sampler(logp_func, sampler_args, chain, seed);
let mut rng = thread_rng();
let mut sampler = new_sampler(logp_func, sampler_args, chain, &mut rng);

// Set to some initial position and start drawing samples.
sampler.set_position(&vec![0f64; 10]).expect("Unrecoverable error during init");
let mut trace = vec![]; // Collection of all draws
let mut stats = vec![]; // Collection of statistics like the acceptance rate for each draw
for _ in 0..2000 {
let (draw, info) = sampler.draw().expect("Unrecoverable error during sampling");
trace.push(draw);
let _info_vec = info.to_vec(); // We can collect the stats in a Vec
// Or get more detailed information about divergences
if let Some(div_info) = info.divergence_info() {
println!("Divergence at position {:?}", div_info.start_location());
println!("Divergence at position {:?}", div_info.start_location);
}
dbg!(&info);
stats.push(info);
}
```

9 changes: 2 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -84,13 +84,8 @@
//! ## Implementation details
//!
//! This crate mostly follows the implementation of NUTS in [Stan](https://mc-stan.org) and
//! [PyMC](https://docs.pymc.io/en/v3/), only tuning of mass matrix and step size differs:
//! In a first window we sample using the identity as mass matrix and adapt the
//! step size using the normal dual averaging algorithm.
//! After `discard_window` draws we start computing a diagonal mass matrix using
//! an exponentially decaying estimate for `sqrt(sample_var / grad_var)`.
//! After `2 * discard_window` draws we switch to the entimated mass mass_matrix
//! and keep adapting it live until `stop_tune_at`.
//! [PyMC](https://docs.pymc.io/en/v3/), only tuning of mass matrix and step size differs
//! somewhat.

pub(crate) mod adapt_strategy;
pub(crate) mod cpu_potential;