Skip to content

Commit

Permalink
allow from_pymc3 args to override model dims and coords (#1249)
Browse files Browse the repository at this point in the history
* allow from_pymc3 args to override model values

* lint

* black

* add docstring
  • Loading branch information
OriolAbril authored Jun 19, 2020
1 parent e95fbe1 commit f2c954e
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 9 deletions.
3 changes: 1 addition & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
* loo-pit plot. The kde is computed over the data interval (this could be shorter than [0, 1]). The HDI is computed analitically (#1215)
* Added `html_repr` of InferenceData objects for jupyter notebooks. (#1217)
* Added support for PyJAGS via the function `from_pyjags` in the module arviz.data.io_pyjags. (#1219)
* `from_pymc3` can now retrieve `coords` and `dims` from model context (#1228
and #1240)
* `from_pymc3` can now retrieve `coords` and `dims` from model context (#1228, #1240 and #1249)

### Maintenance and fixes
* Include data from `MultiObservedRV` to `observed_data` when using
Expand Down
15 changes: 8 additions & 7 deletions arviz/data/io_pymc3.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,14 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
aelem = arbitrary_element(get_from)
self.ndraws = aelem.shape[0]

self.coords = coords
if coords is None and hasattr(self.model, "coords"):
self.coords = self.model.coords

self.dims = dims
if dims is None and hasattr(self.model, "RV_dims"):
self.dims = {k: list(v) for k, v in self.model.RV_dims.items()}
self.coords = {} if coords is None else coords
if hasattr(self.model, "coords"):
self.coords = {**self.model.coords, **self.coords}

self.dims = {} if dims is None else dims
if hasattr(self.model, "RV_dims"):
model_dims = {k: list(v) for k, v in self.model.RV_dims.items()}
self.dims = {**model_dims, **self.dims}

self.observations, self.multi_observations = self.find_observations()

Expand Down
27 changes: 27 additions & 0 deletions arviz/tests/external_tests/test_data_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,33 @@ def test_autodetect_coords_from_model(self, use_context):
np.testing.assert_array_equal(idata.observed_data.coords["date"], coords["date"])
np.testing.assert_array_equal(idata.observed_data.coords["city"], coords["city"])

def test_ovewrite_model_coords_dims(self):
"""Check coords and dims from model object can be partially overwrited."""
dim1 = ["a", "b"]
new_dim1 = ["c", "d"]
coords = {"dim1": dim1, "dim2": ["c1", "c2"]}
x_data = np.arange(4).reshape((2, 2))
y = x_data + np.random.normal(size=(2, 2))
with pm.Model(coords=coords):
x = pm.Data("x", x_data, dims=("dim1", "dim2"))
beta = pm.Normal("beta", 0, 1, dims="dim1")
_ = pm.Normal("obs", x * beta, 1, observed=y, dims=("dim1", "dim2"))
trace = pm.sample(100, tune=100)
idata1 = from_pymc3(trace)
idata2 = from_pymc3(trace, coords={"dim1": new_dim1}, dims={"beta": ["dim2"]})

test_dict = {"posterior": ["beta"], "observed_data": ["obs"], "constant_data": ["x"]}
fails1 = check_multiple_attrs(test_dict, idata1)
assert not fails1
fails2 = check_multiple_attrs(test_dict, idata2)
assert not fails2
assert "dim1" in list(idata1.posterior.beta.dims)
assert "dim2" in list(idata2.posterior.beta.dims)
assert np.all(idata1.constant_data.x.dim1.values == np.array(dim1))
assert np.all(idata1.constant_data.x.dim2.values == np.array(["c1", "c2"]))
assert np.all(idata2.constant_data.x.dim1.values == np.array(new_dim1))
assert np.all(idata2.constant_data.x.dim2.values == np.array(["c1", "c2"]))

def test_missing_data_model(self):
# source pymc3/pymc3/tests/test_missing.py
data = ma.masked_values([1, 2, -1, 4, -1], value=-1)
Expand Down

0 comments on commit f2c954e

Please sign in to comment.