Preprocess prior

zaxtax committed Dec 5, 2022
1 parent bcf56a5 commit d68d9ce

Showing 2 changed files with 12 additions and 10 deletions.
arviz/tests/external_tests/test_data_beanmachine.py (11 changes: 2 additions & 9 deletions)
@@ -22,7 +22,7 @@ class TestDataBeanMachine:
     @pytest.fixture(scope="class")
     def data(self, eight_schools_params, draws, chains):
         class Data:
-            model, obj = load_cached_models(
+            model, prior, obj = load_cached_models(
                 eight_schools_params,
                 draws,
                 chains,
@@ -71,14 +71,7 @@ def test_inference_data_has_log_likelihood_and_observed_data(self, data):
 
     def test_inference_data_no_posterior(self, data):
         model = data.model
-        prior = bm.GlobalNoUTurnSampler().infer(
-            queries=[model.mu(), model.tau(), model.eta()],
-            observations={},
-            num_samples=100,
-            num_adaptive_samples=100,
-            num_chains=2,
-        )
         # only prior
-        inference_data = from_beanmachine(prior)
+        inference_data = from_beanmachine(data.prior)
         assert not model.obs() in inference_data.posterior
         assert "observed_data" not in inference_data
arviz/tests/helpers.py (11 changes: 10 additions & 1 deletion)
@@ -513,14 +513,23 @@ def obs(self):
             return dist.Normal(self.theta(), torch.from_numpy(data["sigma"]).float())
 
     model = EightSchools()
+
+    prior = bm.GlobalNoUTurnSampler().infer(
+        queries=[model.mu(), model.tau(), model.eta()],
+        observations={},
+        num_samples=draws,
+        num_adaptive_samples=500,
+        num_chains=chains,
+    )
+
     posterior = bm.GlobalNoUTurnSampler().infer(
         queries=[model.mu(), model.tau(), model.eta()],
         observations={model.obs(): torch.from_numpy(data["y"]).float()},
         num_samples=draws,
         num_adaptive_samples=500,
         num_chains=chains,
     )
-    return model, posterior
+    return model, prior, posterior
 
 
 def library_handle(library):
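
The helper's return contract grows from (model, posterior) to (model, prior, posterior); in the first file the Data fixture binds these as model, prior, and obj. A short sketch of how the tests consume the pieces after this commit (names as in the diff; the assert mirrors the updated test):

from arviz import from_beanmachine

# Fixture bindings per the first file: model, prior, obj = load_cached_models(...)
idata_prior = from_beanmachine(data.prior)  # prior-only run (observations={})
idata_post = from_beanmachine(data.obj)     # posterior conditioned on the schools data

assert "observed_data" not in idata_prior   # as in test_inference_data_no_posterior
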
