Skip to content

Commit

Permalink
Add None for fields not used in beanmachine converter (#2154)
Browse files Browse the repository at this point in the history
* Add None for fields not used by beanmachine converter

* Add beanmachine tests

* Lint fix

* Preprocess prior

* Update arviz/tests/external_tests/test_data_beanmachine.py

Co-authored-by: Oriol Abril-Pla <oriol.abril.pla@gmail.com>

* Update arviz/tests/external_tests/test_data_beanmachine.py

Co-authored-by: Oriol Abril-Pla <oriol.abril.pla@gmail.com>

* Update CHANGELOG

Co-authored-by: Oriol Abril-Pla <oriol.abril.pla@gmail.com>
  • Loading branch information
zaxtax and OriolAbril authored Dec 6, 2022
1 parent f558290 commit af34307
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

### Maintenance and fixes
- Fix `reloo` outdated usage of `ELPDData` ([2158](https://github.com/arviz-devs/arviz/pull/2158))
- Fix bug when beanmachine objects lack some fields ([2154](https://github.com/arviz-devs/arviz/pull/2154))

### Deprecation

Expand Down
8 changes: 8 additions & 0 deletions arviz/data/io_beanmachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,23 @@ def __init__(

if "posterior" in self.sampler.namespaces:
self.posterior = self.sampler.namespaces["posterior"].samples
else:
self.posterior = None

if "posterior_predictive" in self.sampler.namespaces:
self.posterior_predictive = self.sampler.namespaces["posterior_predictive"].samples
else:
self.posterior_predictive = None

if self.sampler.log_likelihoods is not None:
self.log_likelihoods = self.sampler.log_likelihoods
else:
self.log_likelihoods = None

if self.sampler.observations is not None:
self.observations = self.sampler.observations
else:
self.observations = None

@requires("posterior")
def posterior_to_xarray(self):
Expand Down
76 changes: 76 additions & 0 deletions arviz/tests/external_tests/test_data_beanmachine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# pylint: disable=no-member, invalid-name, redefined-outer-name
import numpy as np
import pytest

from ...data.io_beanmachine import from_beanmachine # pylint: disable=wrong-import-position
from ..helpers import ( # pylint: disable=unused-import, wrong-import-position
chains,
draws,
eight_schools_params,
importorskip,
load_cached_models,
)

# Skip all tests if beanmachine or pytorch not installed
torch = importorskip("torch")
bm = importorskip("beanmachine.ppl")
dist = torch.distributions


class TestDataBeanMachine:
@pytest.fixture(scope="class")
def data(self, eight_schools_params, draws, chains):
class Data:
model, prior, obj = load_cached_models(
eight_schools_params,
draws,
chains,
"beanmachine",
)["beanmachine"]

return Data

@pytest.fixture(scope="class")
def predictions_data(self, data):
"""Generate predictions for predictions_params"""
posterior_samples = data.obj
model = data.model
predictions = bm.inference.predictive.simulate([model.obs()], posterior_samples)
return predictions

def get_inference_data(self, eight_schools_params, predictions_data):
predictions = predictions_data
return from_beanmachine(
sampler=predictions,
coords={
"school": np.arange(eight_schools_params["J"]),
"school_pred": np.arange(eight_schools_params["J"]),
},
)

def test_inference_data(self, data, eight_schools_params, predictions_data):
inference_data = self.get_inference_data(eight_schools_params, predictions_data)
model = data.model
mu = model.mu()
tau = model.tau()
eta = model.eta()
obs = model.obs()

assert mu in inference_data.posterior
assert tau in inference_data.posterior
assert eta in inference_data.posterior
assert obs in inference_data.posterior_predictive

def test_inference_data_has_log_likelihood_and_observed_data(self, data):
idata = from_beanmachine(data.obj)
obs = data.model.obs()

assert obs in idata.log_likelihood
assert obs in idata.observed_data

def test_inference_data_no_posterior(self, data):
model = data.model
# only prior
inference_data = from_beanmachine(data.prior)
assert not model.obs() in inference_data.posterior
assert "observed_data" not in inference_data
47 changes: 47 additions & 0 deletions arviz/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,52 @@ def pystan_noncentered_schools(data, draws, chains):
return stan_model, fit


def bm_schools_model(data, draws, chains):
import beanmachine.ppl as bm
import torch
import torch.distributions as dist

class EightSchools:
@bm.random_variable
def mu(self):
return dist.Normal(0, 5)

@bm.random_variable
def tau(self):
return dist.HalfCauchy(5)

@bm.random_variable
def eta(self):
return dist.Normal(0, 1).expand((data["J"],))

@bm.functional
def theta(self):
return self.mu() + self.tau() * self.eta()

@bm.random_variable
def obs(self):
return dist.Normal(self.theta(), torch.from_numpy(data["sigma"]).float())

model = EightSchools()

prior = bm.GlobalNoUTurnSampler().infer(
queries=[model.mu(), model.tau(), model.eta()],
observations={},
num_samples=draws,
num_adaptive_samples=500,
num_chains=chains,
)

posterior = bm.GlobalNoUTurnSampler().infer(
queries=[model.mu(), model.tau(), model.eta()],
observations={model.obs(): torch.from_numpy(data["y"]).float()},
num_samples=draws,
num_adaptive_samples=500,
num_chains=chains,
)
return model, prior, posterior


def library_handle(library):
"""Import a library and return the handle."""
if library == "pystan":
Expand All @@ -506,6 +552,7 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
("emcee", emcee_schools_model),
("pyro", pyro_noncentered_schools),
("numpyro", numpyro_schools_model),
("beanmachine", bm_schools_model),
)
data_directory = os.path.join(here, "saved_models")
models = {}
Expand Down

0 comments on commit af34307

Please sign in to comment.