diff --git a/CHANGELOG.md b/CHANGELOG.md index 4761652ed6..46bd28bbe4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ * loo-pit plot. The kde is computed over the data interval (this could be shorter than [0, 1]). The hdi is computed analitically (#1215) * Added `html_repr` of InferenceData objects for jupyter notebooks. (#1217) * Added support for PyJAGS via the function `from_pyjags` in the module arviz.data.io_pyjags. (#1219) +* `from_pymc3` can now retrieve `coords` and `dims` from model context (#1228 + and #1240) ### Maintenance and fixes * Include data from `MultiObservedRV` to `observed_data` when using diff --git a/arviz/data/io_pymc3.py b/arviz/data/io_pymc3.py index 98129a2fec..10a5657d78 100644 --- a/arviz/data/io_pymc3.py +++ b/arviz/data/io_pymc3.py @@ -148,12 +148,12 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray: self.ndraws = aelem.shape[0] self.coords = coords - if coords is None and hasattr(model, "coords"): - self.coords = model.coords + if coords is None and hasattr(self.model, "coords"): + self.coords = self.model.coords self.dims = dims - if dims is None and hasattr(model, "RV_dims"): - self.dims = {k: list(v) for k, v in model.RV_dims.items()} + if dims is None and hasattr(self.model, "RV_dims"): + self.dims = {k: list(v) for k, v in self.model.RV_dims.items()} self.observations, self.multi_observations = self.find_observations() diff --git a/arviz/tests/external_tests/test_data_pymc.py b/arviz/tests/external_tests/test_data_pymc.py index b90f82f575..41194a1ec8 100644 --- a/arviz/tests/external_tests/test_data_pymc.py +++ b/arviz/tests/external_tests/test_data_pymc.py @@ -202,7 +202,8 @@ def test_posterior_predictive_warning(self, data, eight_schools_params, caplog): packaging.version.Version(pm.__version__) < packaging.version.Version("3.9.0"), reason="Requires PyMC3 >= 3.9.0", ) - def test_autodetect_coords_from_model(self): + @pytest.mark.parametrize("use_context", [True, False]) + def test_autodetect_coords_from_model(self, use_context): df_data = pd.DataFrame(columns=["date"]).set_index("date") dates = pd.date_range(start="2020-05-01", end="2020-05-20") for city, mu in {"Berlin": 15, "San Marino": 18, "Paris": 16}.items(): @@ -231,8 +232,14 @@ def test_autodetect_coords_from_model(self): draws=30, step=pm.Metropolis(), ) - idata = from_pymc3(trace=trace, model=model) + if use_context: + idata = from_pymc3(trace=trace) + if not use_context: + idata = from_pymc3(trace=trace, model=model) + assert "city" in list(idata.posterior.dims) + assert "city" in list(idata.observed_data.dims) + assert "date" in list(idata.observed_data.dims) np.testing.assert_array_equal(idata.posterior.coords["city"], coords["city"]) np.testing.assert_array_equal(idata.observed_data.coords["date"], coords["date"]) np.testing.assert_array_equal(idata.observed_data.coords["city"], coords["city"])