diff --git a/CHANGELOG.md b/CHANGELOG.md index dd051f1768..193489e964 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,8 +6,7 @@ * loo-pit plot. The kde is computed over the data interval (this could be shorter than [0, 1]). The HDI is computed analitically (#1215) * Added `html_repr` of InferenceData objects for jupyter notebooks. (#1217) * Added support for PyJAGS via the function `from_pyjags` in the module arviz.data.io_pyjags. (#1219) -* `from_pymc3` can now retrieve `coords` and `dims` from model context (#1228 - and #1240) +* `from_pymc3` can now retrieve `coords` and `dims` from model context (#1228, #1240 and #1249) ### Maintenance and fixes * Include data from `MultiObservedRV` to `observed_data` when using diff --git a/arviz/data/io_pymc3.py b/arviz/data/io_pymc3.py index 10a5657d78..37469c6cbf 100644 --- a/arviz/data/io_pymc3.py +++ b/arviz/data/io_pymc3.py @@ -147,13 +147,14 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray: aelem = arbitrary_element(get_from) self.ndraws = aelem.shape[0] - self.coords = coords - if coords is None and hasattr(self.model, "coords"): - self.coords = self.model.coords - - self.dims = dims - if dims is None and hasattr(self.model, "RV_dims"): - self.dims = {k: list(v) for k, v in self.model.RV_dims.items()} + self.coords = {} if coords is None else coords + if hasattr(self.model, "coords"): + self.coords = {**self.model.coords, **self.coords} + + self.dims = {} if dims is None else dims + if hasattr(self.model, "RV_dims"): + model_dims = {k: list(v) for k, v in self.model.RV_dims.items()} + self.dims = {**model_dims, **self.dims} self.observations, self.multi_observations = self.find_observations() diff --git a/arviz/tests/external_tests/test_data_pymc.py b/arviz/tests/external_tests/test_data_pymc.py index 41194a1ec8..2343152b44 100644 --- a/arviz/tests/external_tests/test_data_pymc.py +++ b/arviz/tests/external_tests/test_data_pymc.py @@ -244,6 +244,33 @@ def test_autodetect_coords_from_model(self, use_context): np.testing.assert_array_equal(idata.observed_data.coords["date"], coords["date"]) np.testing.assert_array_equal(idata.observed_data.coords["city"], coords["city"]) + def test_ovewrite_model_coords_dims(self): + """Check coords and dims from model object can be partially overwrited.""" + dim1 = ["a", "b"] + new_dim1 = ["c", "d"] + coords = {"dim1": dim1, "dim2": ["c1", "c2"]} + x_data = np.arange(4).reshape((2, 2)) + y = x_data + np.random.normal(size=(2, 2)) + with pm.Model(coords=coords): + x = pm.Data("x", x_data, dims=("dim1", "dim2")) + beta = pm.Normal("beta", 0, 1, dims="dim1") + _ = pm.Normal("obs", x * beta, 1, observed=y, dims=("dim1", "dim2")) + trace = pm.sample(100, tune=100) + idata1 = from_pymc3(trace) + idata2 = from_pymc3(trace, coords={"dim1": new_dim1}, dims={"beta": ["dim2"]}) + + test_dict = {"posterior": ["beta"], "observed_data": ["obs"], "constant_data": ["x"]} + fails1 = check_multiple_attrs(test_dict, idata1) + assert not fails1 + fails2 = check_multiple_attrs(test_dict, idata2) + assert not fails2 + assert "dim1" in list(idata1.posterior.beta.dims) + assert "dim2" in list(idata2.posterior.beta.dims) + assert np.all(idata1.constant_data.x.dim1.values == np.array(dim1)) + assert np.all(idata1.constant_data.x.dim2.values == np.array(["c1", "c2"])) + assert np.all(idata2.constant_data.x.dim1.values == np.array(new_dim1)) + assert np.all(idata2.constant_data.x.dim2.values == np.array(["c1", "c2"])) + def test_missing_data_model(self): # source pymc3/pymc3/tests/test_missing.py data = ma.masked_values([1, 2, -1, 4, -1], value=-1)