From fe7fef8ead58223b0681a39c6c404d76c054110c Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Tue, 22 Sep 2020 17:49:31 +0200 Subject: [PATCH] fix io_dict --- CHANGELOG.md | 3 ++- arviz/data/io_dict.py | 38 +++++++++++++++-------------- arviz/tests/base_tests/test_data.py | 2 +- arviz/tests/helpers.py | 1 + 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c307c768ac..007c99f83e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,7 +27,8 @@ * Update diagnostics to be on par with posterior package ([#1366](https://github.com/arviz-devs/arviz/pull/1366)) * Use method="average" in `scipy.stats.rankdata` ([#1380](https://github.com/arviz-devs/arviz/pull/1380)) * Add more `plot_parallel` examples ([#1380](https://github.com/arviz-devs/arviz/pull/1380)) -* Bump minimum xarray version to 0.16.1 ([#1389](https://github.com/arviz-devs/arviz/pull/1389) +* Bump minimum xarray version to 0.16.1 ([#1389](https://github.com/arviz-devs/arviz/pull/1389)) +* `from_dict` will now store warmup groups even with the main group missing ([1386](https://github.com/arviz-devs/arviz/pull/1386)) ### Deprecation diff --git a/arviz/data/io_dict.py b/arviz/data/io_dict.py index fb37ab7b6d..b6b74732df 100644 --- a/arviz/data/io_dict.py +++ b/arviz/data/io_dict.py @@ -6,7 +6,7 @@ from .. import utils from ..rcparams import rcParams from .base import dict_to_dataset, generate_dims_coords, make_attrs, requires -from .inference_data import InferenceData +from .inference_data import InferenceData, WARMUP_TAG # pylint: disable=too-many-instance-attributes @@ -69,11 +69,15 @@ def __init__( self.attrs.pop("created_at", None) self.attrs.pop("arviz_version", None) - @requires("posterior") + def _init_dict(self, attr_name): + dict_or_none = getattr(self, attr_name, None) + return {} if dict_or_none is None else dict_or_none + + @requires(["posterior", f"{WARMUP_TAG}posterior"]) def posterior_to_xarray(self): """Convert posterior samples to xarray.""" - data = self.posterior - data_warmup = self.warmup_posterior if self.warmup_posterior is not None else {} + data = self._init_dict("posterior") + data_warmup = self._init_dict(f"{WARMUP_TAG}posterior") if not isinstance(data, dict): raise TypeError("DictConverter.posterior is not a dictionary") if not isinstance(data_warmup, dict): @@ -95,11 +99,11 @@ def posterior_to_xarray(self): ), ) - @requires("sample_stats") + @requires(["sample_stats", f"{WARMUP_TAG}sample_stats"]) def sample_stats_to_xarray(self): """Convert sample_stats samples to xarray.""" - data = self.sample_stats - data_warmup = self.warmup_sample_stats if self.warmup_sample_stats is not None else {} + data = self._init_dict("sample_stats") + data_warmup = self._init_dict(f"{WARMUP_TAG}sample_stats") if not isinstance(data, dict): raise TypeError("DictConverter.sample_stats is not a dictionary") if not isinstance(data_warmup, dict): @@ -122,11 +126,11 @@ def sample_stats_to_xarray(self): ), ) - @requires("log_likelihood") + @requires(["log_likelihood", f"{WARMUP_TAG}log_likelihood"]) def log_likelihood_to_xarray(self): """Convert log_likelihood samples to xarray.""" - data = self.log_likelihood - data_warmup = self.warmup_log_likelihood if self.warmup_log_likelihood is not None else {} + data = self._init_dict("log_likelihood") + data_warmup = self._init_dict(f"{WARMUP_TAG}log_likelihood") if not isinstance(data, dict): raise TypeError("DictConverter.log_likelihood is not a dictionary") if not isinstance(data_warmup, dict): @@ -141,13 +145,11 @@ def log_likelihood_to_xarray(self): ), ) - @requires("posterior_predictive") + @requires(["posterior_predictive", f"{WARMUP_TAG}posterior_predictive"]) def posterior_predictive_to_xarray(self): """Convert posterior_predictive samples to xarray.""" - data = self.posterior_predictive - data_warmup = ( - self.warmup_posterior_predictive if self.warmup_posterior_predictive is not None else {} - ) + data = self._init_dict("posterior_predictive") + data_warmup = self._init_dict(f"{WARMUP_TAG}posterior_predictive") if not isinstance(data, dict): raise TypeError("DictConverter.posterior_predictive is not a dictionary") if not isinstance(data_warmup, dict): @@ -162,11 +164,11 @@ def posterior_predictive_to_xarray(self): ), ) - @requires("predictions") + @requires(["predictions", f"{WARMUP_TAG}predictions"]) def predictions_to_xarray(self): """Convert predictions to xarray.""" - data = self.predictions - data_warmup = self.warmup_predictions if self.warmup_predictions is not None else {} + data = self._init_dict("predictions") + data_warmup = self._init_dict(f"{WARMUP_TAG}predictions") if not isinstance(data, dict): raise TypeError("DictConverter.predictions is not a dictionary") if not isinstance(data_warmup, dict): diff --git a/arviz/tests/base_tests/test_data.py b/arviz/tests/base_tests/test_data.py index 7bdc1b6b18..08a1448e41 100644 --- a/arviz/tests/base_tests/test_data.py +++ b/arviz/tests/base_tests/test_data.py @@ -625,7 +625,7 @@ def test_add_groups(self, data_random): assert hasattr(idata, "prior") idata.add_groups(warmup_posterior={"a": data[..., 0], "b": data}) - assert "warmup_posterior" in idata._groups # pylint: disable=protected-access + assert "warmup_posterior" in idata._groups_all # pylint: disable=protected-access assert isinstance(idata.warmup_posterior, xr.Dataset) assert hasattr(idata, "warmup_posterior") diff --git a/arviz/tests/helpers.py b/arviz/tests/helpers.py index e547eccfa9..4d80acb169 100644 --- a/arviz/tests/helpers.py +++ b/arviz/tests/helpers.py @@ -157,6 +157,7 @@ def create_data_random(groups=None, seed=10): prior_predictive={"a": data[..., 0], "b": data}, warmup_posterior={"a": data[..., 0], "b": data}, warmup_posterior_predictive={"a": data[..., 0], "b": data}, + warmup_prior={"a": data[..., 0], "b": data}, ) idata = from_dict( **{group: ary for group, ary in idata_dict.items() if group in groups}, save_warmup=True