Skip to content

Commit

Permalink
Add beanmachine tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zaxtax committed Dec 5, 2022
1 parent de681df commit 5035989
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 0 deletions.
92 changes: 92 additions & 0 deletions arviz/tests/external_tests/test_data_beanmachine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# pylint: disable=no-member, invalid-name, redefined-outer-name
import numpy as np
import packaging
import pytest

from ...data.io_beanmachine import from_beanmachine # pylint: disable=wrong-import-position
from ..helpers import ( # pylint: disable=unused-import, wrong-import-position
chains,
check_multiple_attrs,
draws,
eight_schools_params,
importorskip,
load_cached_models,
)

# Skip all tests if pyro or pytorch not installed
torch = importorskip("torch")
bm = importorskip("beanmachine.ppl")
dist = torch.distributions


class TestDataBeanMachine:
@pytest.fixture(scope="class")
def data(self, eight_schools_params, draws, chains):
class Data:
model, obj = load_cached_models(eight_schools_params, draws, chains, "beanmachine")["beanmachine"]

return Data

@pytest.fixture(scope="class")
def predictions_params(self):
"""Predictions data for eight schools."""
return {
"J": 8,
"sigma": torch.tensor([5.0, 7.0, 12.0, 4.0, 6.0, 10.0, 3.0, 9.0]),
}

@pytest.fixture(scope="class")
def predictions_data(self, data, predictions_params):
"""Generate predictions for predictions_params"""
posterior_samples = data.obj
model = data.model
predictions = bm.inference.predictive.simulate([model.obs()], posterior_samples)
return predictions

def get_inference_data(self, data, eight_schools_params, predictions_data):
posterior_samples = data.obj
model = data.model
predictions = predictions_data
return from_beanmachine(
sampler=predictions,
coords={
"school": np.arange(eight_schools_params["J"]),
"school_pred": np.arange(eight_schools_params["J"]),
},
)

def test_inference_data(self, data, eight_schools_params, predictions_data):
inference_data = self.get_inference_data(data, eight_schools_params, predictions_data)
mu = data.model.mu()
tau = data.model.tau()
eta = data.model.eta()
obs = data.model.obs()

assert mu in inference_data.posterior
assert tau in inference_data.posterior
assert eta in inference_data.posterior
assert obs in inference_data.posterior_predictive

def test_inference_data_has_log_likelihood_and_observed_data(self, data):
idata = from_beanmachine(data.obj)
obs = data.model.obs()

assert obs in idata.log_likelihood
assert obs in idata.observed_data

def test_inference_data_no_posterior(
self, data, eight_schools_params, predictions_data, predictions_params
):
posterior_samples = data.obj
model = data.model
prior = bm.GlobalNoUTurnSampler().infer(
queries=[model.mu(), model.tau(), model.eta()],
observations={},
num_samples=100,
num_adaptive_samples=100,
num_chains=2,
)
# only prior
inference_data = from_beanmachine(prior)
assert not model.obs() in inference_data.posterior
assert "observed_data" not in inference_data
38 changes: 38 additions & 0 deletions arviz/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,43 @@ def pystan_noncentered_schools(data, draws, chains):
return stan_model, fit


def bm_schools_model(data, draws, chains):
import beanmachine.ppl as bm
import torch
import torch.distributions as dist

class EightSchools:
@bm.random_variable
def mu(self):
return dist.Normal(0, 5)

@bm.random_variable
def tau(self):
return dist.HalfCauchy(5)

@bm.random_variable
def eta(self):
return dist.Normal(0, 1).expand((data["J"],))

@bm.functional
def theta(self):
return self.mu() + self.tau() * self.eta()

@bm.random_variable
def obs(self):
return dist.Normal(self.theta(), torch.from_numpy(data["sigma"]).float())

model = EightSchools()
posterior = bm.GlobalNoUTurnSampler().infer(
queries=[model.mu(), model.tau(), model.eta()],
observations={model.obs(): torch.from_numpy(data["y"]).float()},
num_samples=draws,
num_adaptive_samples=500,
num_chains=chains,
)
return model, posterior


def library_handle(library):
"""Import a library and return the handle."""
if library == "pystan":
Expand All @@ -506,6 +543,7 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
("emcee", emcee_schools_model),
("pyro", pyro_noncentered_schools),
("numpyro", numpyro_schools_model),
("beanmachine", bm_schools_model),
)
data_directory = os.path.join(here, "saved_models")
models = {}
Expand Down

0 comments on commit 5035989

Please sign in to comment.