Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 63: Implementing SBC #68

Merged
merged 27 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
e17bff6
implementing sbc
SamuelBrand1 Feb 10, 2025
2b7e681
sbc method
SamuelBrand1 Feb 11, 2025
fda44bc
pass number of draws to plotting
SamuelBrand1 Feb 11, 2025
66fb72e
Update sbc.py
SamuelBrand1 Feb 11, 2025
62ada9c
Merge branch 'main' into spcb-sbc
SamuelBrand1 Feb 11, 2025
fa45689
lint
SamuelBrand1 Feb 11, 2025
e8b4648
unit tests
SamuelBrand1 Feb 11, 2025
0f38bc2
Tidying up and extending example
SamuelBrand1 Feb 11, 2025
cb6e2b5
update note
SamuelBrand1 Feb 11, 2025
972c32e
Merge remote-tracking branch 'origin/main' into spcb-sbc
AFg6K7h4fhy2 Feb 11, 2025
fb7b16d
pre-commit fixes
AFg6K7h4fhy2 Feb 11, 2025
a6b9e88
light docstring formatting
AFg6K7h4fhy2 Feb 11, 2025
9e83506
revert then modify seed param edit
AFg6K7h4fhy2 Feb 11, 2025
e798c85
Update forecasttools/sbc_plots.py
SamuelBrand1 Feb 12, 2025
da3750b
Update sbc_model_checking.qmd
SamuelBrand1 Feb 12, 2025
a1adf7b
update about rank statistics
SamuelBrand1 Feb 12, 2025
cd242ea
type-hint change
SamuelBrand1 Feb 12, 2025
241f31c
change axes iterator label
SamuelBrand1 Feb 12, 2025
ef77270
Update forecasttools/sbc_plots.py
SamuelBrand1 Feb 12, 2025
fad6539
docstrings for UniformCDF
SamuelBrand1 Feb 12, 2025
6da1696
typehint for _get_prior_predictive_samples return
SamuelBrand1 Feb 12, 2025
d266670
type hints for _get_posterior_samples
SamuelBrand1 Feb 12, 2025
501ba7c
typehint for run_simulations
SamuelBrand1 Feb 12, 2025
8346deb
refactor SBC inner loop
SamuelBrand1 Feb 14, 2025
9efc9bb
Add unit test and docstring
SamuelBrand1 Feb 14, 2025
250e49d
add inspection mode
SamuelBrand1 Feb 14, 2025
7ee709c
remove hand rolled uniformCDF class
SamuelBrand1 Feb 14, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 253 additions & 0 deletions forecasttools/sbc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
from typing import Any

import arviz as az
import jax.numpy as jnp
import numpyro
from jax import random
from numpyro.infer import MCMC
from numpyro.infer.mcmc import MCMCKernel
from tqdm import tqdm

from forecasttools.sbc_plots import plot_results


class SBC:
    def __init__(
        self,
        mcmc_kernel: MCMCKernel,
        *args,
        observed_vars: dict[str, str],
        num_simulations=10,
        sample_kwargs=None,
        seed=None,
        inspection_mode=False,
        **kwargs,
    ) -> None:
        """
        Set up class for doing SBC.
        Based on simulation based calibration (Talts et. al. 2018) in PyMC.

        Parameters
        ----------
        mcmc_kernel : numpyro.infer.mcmc.MCMCKernel
            An instance of a numpyro MCMC kernel object.
        observed_vars : dict[str, str]
            A dictionary mapping observed/response variable name as a kwarg to
            the numpyro model to the corresponding variable name sampled using
            `numpyro.sample`.
        args : tuple
            Positional arguments passed to the numpyro model.
        num_simulations : int
            How many simulations to run for SBC.
        sample_kwargs : dict[str, Any]
            Arguments passed to `numpyro.infer.MCMC`. Defaults to
            `dict(num_warmup=500, num_samples=100, progress_bar = False)`.
            Which assumes a MCMC sampler e.g. NUTS.
        seed : random.PRNGKey
            Random seed. Defaults to `random.PRNGKey(1234)`.
        inspection_mode : bool
            If True, store the prior/prior predictive draws and every
            fitted `InferenceData` object for later inspection.
        kwargs : dict[str, Any]
            Keyword arguments passed to the numpyro model.

        Raises
        ------
        ValueError
            If `mcmc_kernel` has no `model` attribute, or if an observed
            variable is given a non-None value in `kwargs`.
        """
        if sample_kwargs is None:
            sample_kwargs = dict(
                num_warmup=500, num_samples=100, progress_bar=False
            )
        if seed is None:
            seed = random.PRNGKey(1234)
        self.mcmc_kernel = mcmc_kernel
        if not hasattr(mcmc_kernel, "model"):
            raise ValueError(
                "The `mcmc_kernel` must have a 'model' attribute."
            )

        self.model = mcmc_kernel.model
        self.args = args
        self.kwargs = kwargs
        self.observed_vars = observed_vars

        # Observed variables must be left unset so the model samples them
        # from the prior predictive distribution.
        for key in self.observed_vars:
            if key in self.kwargs and self.kwargs[key] is not None:
                raise ValueError(
                    f"The value for '{key}' in kwargs must be None for this to"
                    " be a prior predictive check."
                )

        self.num_simulations = num_simulations
        self.sample_kwargs = sample_kwargs
        # Initialize the simulations and random seeds
        self.simulations = {}
        self._simulations_complete = 0
        prior_pred_rng, sampler_rng = random.split(seed)
        self._prior_pred_rng = prior_pred_rng
        self._sampler_rng = sampler_rng
        # Set on the first posterior fit; used to check consistency across
        # simulations and to scale the diagnostic plots.
        self.num_samples = None
        # Set the inspection mode
        # if in inspection mode, store all idata objects from fitting
        self.inspection_mode = inspection_mode
        if inspection_mode:
            self.idatas = []

    def _get_prior_predictive_samples(
        self,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Generate samples to use for the simulations by prior predictive
        sampling. Then splits between observed and unobserved variables based
        on the `observed_vars` attribute.

        Returns
        -------
        tuple[dict[str, Any], dict[str, Any]]
            The prior and prior predictive samples.
        """
        prior_predictive_fn = numpyro.infer.Predictive(
            self.mcmc_kernel.model, num_samples=self.num_simulations
        )
        prior_predictions = prior_predictive_fn(
            self._prior_pred_rng, *self.args, **self.kwargs
        )
        # Split the sampled sites into observed (response) variables,
        # keyed by their model kwarg name, and unobserved (parameter)
        # variables.
        prior_pred = {
            k: prior_predictions[v] for k, v in self.observed_vars.items()
        }
        prior = {
            k: v
            for k, v in prior_predictions.items()
            if k not in self.observed_vars.values()
        }
        return prior, prior_pred

    def _get_posterior_samples(
        self, seed: random.PRNGKey, prior_predictive_draw: dict[str, Any]
    ) -> az.InferenceData:
        """
        Generate posterior samples conditioned to a prior predictive sample.

        Also records the number of posterior samples on the first call and
        checks that every subsequent fit returns the same number; the count
        is used in scaling plotting.

        Parameters
        ----------
        seed : random.PRNGKey
            Random seed for MCMC sampling.
        prior_predictive_draw : dict[str, Any]
            Prior predictive samples.

        Returns
        -------
        az.InferenceData
            Posterior samples as an arviz InferenceData object.

        Raises
        ------
        ValueError
            If the number of posterior samples differs across simulations.
        """
        mcmc = MCMC(self.mcmc_kernel, **self.sample_kwargs)
        # Condition the model on this prior predictive draw (overriding any
        # None placeholders in the stored kwargs).
        obs_vars = {**self.kwargs, **prior_predictive_draw}
        mcmc.run(seed, *self.args, **obs_vars)
        num_samples = mcmc.num_samples
        # Check that the number of samples is consistent
        if self.num_samples is None:
            self.num_samples = num_samples
        if self.num_samples != num_samples:
            raise ValueError(
                "The number of samples from the posterior is not consistent."
            )
        idata = az.from_numpyro(mcmc)
        return idata

    def _increment_rank_statistics(self, prior_draw, posterior) -> None:
        """
        Increment the rank statistics for each parameter in the prior draw.

        This method updates the `self.simulations` dictionary with the rank
        statistics for each parameter in the `prior_draw` compared to the
        `posterior`.

        Parameters
        ----------
        prior_draw : dict[str, Any]
            A single draw from the prior for each parameter.
        posterior : xarray.Dataset
            The posterior group of the fitted InferenceData object.

        Returns
        -------
        None
        """
        for name in prior_draw:
            # Rank of the prior draw among the posterior samples: the count
            # of posterior samples below the prior draw. Reducing over the
            # draw axis (axis 0 after selecting chain 0) handles scalar and
            # array-valued parameters alike, so no branching on ndim needed.
            rank_statistics = (
                (posterior[name].sel(chain=0) < prior_draw[name])
                .sum(axis=0)
                .values
            )
            self.simulations[name].append(rank_statistics)

    def run_simulations(self) -> None:
        """
        The main method of `SBC` class that runs the simulations for
        simulation based calibration and fills the `simulations` attribute
        with the results.
        """
        prior, prior_pred = self._get_prior_predictive_samples()
        sampler_seeds = random.split(self._sampler_rng, self.num_simulations)
        # Only initialize the rank-statistic store when it is empty, so a
        # rerun after an interruption resumes from `_simulations_complete`
        # instead of discarding completed simulations (the prior draws are
        # deterministic given the stored rng key, so indices still line up).
        if not self.simulations:
            self.simulations = {name: [] for name in prior}
        progress = tqdm(
            initial=self._simulations_complete,
            total=self.num_simulations,
        )
        if self.inspection_mode:
            self.prior = prior
            self.prior_pred = prior_pred
        try:
            while self._simulations_complete < self.num_simulations:
                idx = self._simulations_complete
                prior_draw = {k: v[idx] for k, v in prior.items()}
                prior_predictive_draw = {
                    k: v[idx] for k, v in prior_pred.items()
                }
                idata = self._get_posterior_samples(
                    sampler_seeds[idx], prior_predictive_draw
                )
                if self.inspection_mode:
                    self.idatas.append(idata)
                self._increment_rank_statistics(prior_draw, idata["posterior"])
                self._simulations_complete += 1
                progress.update()
        finally:
            # If interrupted mid-run, truncate to the completed simulations
            # so partially-filled entries are never reported.
            self.simulations = {
                k: v[: self._simulations_complete]
                for k, v in self.simulations.items()
            }
            progress.close()

    def plot_results(self, kind="ecdf", var_names=None, color="C0"):
        """
        Visual diagnostic for SBC.

        Currently it supports two options: `ecdf` for the empirical CDF
        plots of the difference between prior and posterior. `hist` for the
        rank histogram.

        Parameters
        ----------
        kind : str
            What kind of plot to make. Supported values are 'ecdf' (default)
            and 'hist'
        var_names : list[str]
            Variables to plot (defaults to all)
        color : str
            Color to use for the eCDF or histogram

        Returns
        -------
        fig, axes
            matplotlib figure and axes
        """
        return plot_results(
            self.simulations,
            self.num_samples,
            kind=kind,
            var_names=var_names,
            color=color,
        )
137 changes: 137 additions & 0 deletions forecasttools/sbc_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
Plots for the simulation based calibration
"""

import itertools

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import numpyro.distributions as dist
from scipy.special import bdtrik
SamuelBrand1 marked this conversation as resolved.
Show resolved Hide resolved


def plot_results(
    simulations, ndraws, kind="ecdf", var_names=None, figsize=None, color="C0"
):
    """
    Visual diagnostic for SBC.

    Currently it supports two options: `ecdf` for the empirical CDF plots
    of the difference between prior and posterior. `hist` for the rank
    histogram.

    Parameters
    ----------
    simulations : dict[str, Any]
        The SBC.simulations dictionary, mapping variable names to the
        per-simulation rank statistics (scalars or arrays).
    ndraws : int
        Number of draws in each posterior predictive sample
    kind : str
        What kind of plot to make. Supported values are 'ecdf' (default)
        and 'hist'
    var_names : list[str]
        Variables to plot (defaults to all)
    figsize : tuple
        Figure size for the plot. If None, it will be defined automatically.
    color : str
        Color to use for the eCDF or histogram

    Returns
    -------
    fig, axes
        matplotlib figure and axes

    Raises
    ------
    ValueError
        If `kind` is not 'ecdf' or 'hist'.
    """

    if kind not in ["ecdf", "hist"]:
        raise ValueError(f"kind must be 'ecdf' or 'hist', not {kind}")

    if var_names is None:
        var_names = list(simulations.keys())

    # Promote each variable's rank statistics to at least 2 dimensions
    # (simulation index first) so scalar- and array-valued parameters are
    # handled uniformly below.
    sims = {}
    for k in var_names:
        ary = np.array(simulations[k])
        while ary.ndim < 2:
            ary = np.expand_dims(ary, -1)
        sims[k] = ary

    # One subplot per scalar component across all variables.
    n_plots = sum(np.prod(v.shape[1:]) for v in sims.values())

    if n_plots > 1:
        if figsize is None:
            figsize = (8, n_plots * 1.0)

        fig, axes = plt.subplots(
            nrows=(n_plots + 1) // 2, ncols=2, figsize=figsize, sharex=True
        )
        axes = axes.flatten()
    else:
        if figsize is None:
            figsize = (8, 1.5)

        fig, axes = plt.subplots(nrows=1, ncols=1, figsize=figsize)
        axes = [axes]

    if kind == "ecdf":
        # Reference CDF for ranks under perfect calibration: discrete
        # uniform on {0, ..., ndraws}.
        cdf = dist.DiscreteUniform(high=ndraws).cdf

    idx = 0
    for var_name, var_data in sims.items():
        # All index combinations over the non-simulation dimensions.
        plot_idxs = list(
            itertools.product(*(np.arange(s) for s in var_data.shape[1:]))
        )

        for indices in plot_idxs:
            if len(plot_idxs) > 1:  # has dims
                dim_label = f"{var_name}[{']['.join(map(str, indices))}]"
            else:
                dim_label = var_name
            ax = axes[idx]
            ary = var_data[(...,) + indices]
            if kind == "ecdf":
                az.plot_ecdf(
                    ary,
                    cdf=cdf,
                    difference=True,
                    pit=True,
                    confidence_bands="auto",
                    plot_kwargs={"color": color},
                    fill_kwargs={"color": color},
                    ax=ax,
                )
            else:
                hist(ary, color=color, ax=ax)
            ax.set_title(dim_label)
            ax.set_yticks([])
            idx += 1

    # Remove the unused trailing axes when n_plots is odd.
    for extra_ax in range(n_plots, len(axes)):
        fig.delaxes(axes[extra_ax])

    return fig, axes


def hist(ary, color, ax):
    """
    Plot a rank histogram for SBC rank statistics on the given axes.

    Draws the histogram bars along with the expected-count reference line
    and a 95% interval band for the counts under the assumption of
    uniformly distributed ranks.

    Parameters
    ----------
    ary : array-like
        Rank statistics for a single variable (one value per simulation).
    color : str
        Color for the histogram bars.
    ax : matplotlib.axes.Axes
        Axes to draw on.

    Returns
    -------
    None
    """
    # `counts` rather than `hist`: avoid shadowing this function's own name.
    counts, bins = np.histogram(ary, bins="auto")
    bin_centers = 0.5 * (bins[:-1] + bins[1:])
    max_rank = np.ceil(bins[-1])
    len_bins = len(bins)
    n_sims = len(ary)

    # Quantiles (2.5%, 50%, 97.5%) of the per-bin count under uniformity,
    # via the inverse binomial CDF: Binomial(n_sims, 1 / len_bins).
    band = np.ceil(bdtrik([0.025, 0.5, 0.975], n_sims, 1 / len_bins))
    ax.bar(
        bin_centers,
        counts,
        width=bins[1] - bins[0],
        color=color,
        edgecolor="black",
    )
    # Dashed line at the expected (median) count.
    ax.axhline(band[1], color="0.5", ls="--")
    # Shaded 95% band across the full rank range.
    ax.fill_between(
        np.linspace(0, max_rank, len_bins),
        band[0],
        band[2],
        color="0.5",
        alpha=0.5,
    )
Loading