From 1bec337925544e5dc4f508b4c80090e373943070 Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Fri, 18 Sep 2020 02:21:40 +0200 Subject: [PATCH 1/9] edit groups in idata.extend --- arviz/data/inference_data.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 99ec363a9d..ab4a755933 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -820,6 +820,10 @@ def extend(self, other, join="left"): ) dataset = getattr(other, group) setattr(self, group, dataset) + if group.startswith(WARMUP_TAG): + self._groups_warmup.append(group) + else: + self._groups.append(group) set_index = _extend_xr_method(xr.Dataset.set_index) get_index = _extend_xr_method(xr.Dataset.get_index) From 1121820c067b9356a282e322e137bd1168730c08 Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Sat, 19 Sep 2020 04:20:21 +0200 Subject: [PATCH 2/9] add test --- arviz/tests/base_tests/test_data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/arviz/tests/base_tests/test_data.py b/arviz/tests/base_tests/test_data.py index b5e5a0112c..ccdf23e27f 100644 --- a/arviz/tests/base_tests/test_data.py +++ b/arviz/tests/base_tests/test_data.py @@ -651,6 +651,7 @@ def test_extend(self, data_random): idata = data_random idata2 = create_data_random(groups=["prior", "prior_predictive", "observed_data"], seed=7) idata.extend(idata2) + assert "prior" in idata._groups_all assert hasattr(idata, "prior") assert hasattr(idata, "prior_predictive") assert idata.prior.equals(idata2.prior) From 6ab22b726485740d067d9e6c1ab33978b25a94a5 Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Sat, 19 Sep 2020 04:34:34 +0200 Subject: [PATCH 3/9] update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 38ab4593af..7bd94b5f42 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ * Added `circ_var_names` argument to `plot_trace` allowing for circular traceplot (Matplotlib) ([1336](https://github.com/arviz-devs/arviz/pull/1336)) * Ridgeplot is hdi aware. By default displays truncated densities at the specified `hdi_prop` level ([1348](https://github.com/arviz-devs/arviz/pull/1348)) * Added `plot_separation` ([1359](https://github.com/arviz-devs/arviz/pull/1359)) +* Extended methods from `xr.Dataset` to `InferenceData` ([1254](https://github.com/arviz-devs/arviz/pull/1254)) +* Add `extend` and `add_groups` to `InferenceData` ([1300](https://github.com/arviz-devs/arviz/pull/1300) and [1386](https://github.com/arviz-devs/arviz/pull/1386)) ### Maintenance and fixes From b90afb6affe6d955222f31a53ab5737effd6f742 Mon Sep 17 00:00:00 2001 From: Oriol Abril-Pla Date: Mon, 21 Sep 2020 18:16:03 +0200 Subject: [PATCH 4/9] fix indentation --- arviz/data/inference_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index ab4a755933..23d606e742 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -822,7 +822,7 @@ def extend(self, other, join="left"): setattr(self, group, dataset) if group.startswith(WARMUP_TAG): self._groups_warmup.append(group) - else: + else: self._groups.append(group) set_index = _extend_xr_method(xr.Dataset.set_index) From bb93accdce53ddab92b883daf073c6df6fcb007c Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Tue, 22 Sep 2020 00:57:06 +0200 Subject: [PATCH 5/9] fix lint --- arviz/data/inference_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 23d606e742..cc6c838928 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -821,9 +821,9 @@ def extend(self, other, join="left"): dataset = getattr(other, group) setattr(self, group, dataset) if group.startswith(WARMUP_TAG): - self._groups_warmup.append(group) + self._groups_warmup.append(group) # pylint: disable=protected-access else: - self._groups.append(group) + self._groups.append(group) # pylint: disable=protected-access set_index = _extend_xr_method(xr.Dataset.set_index) get_index = _extend_xr_method(xr.Dataset.get_index) From acb0af93d6e9ab9074c946c4ab9a2cc4f20bb5cd Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Tue, 22 Sep 2020 05:50:08 +0200 Subject: [PATCH 6/9] fix lint x2 --- arviz/data/base.py | 3 ++- arviz/data/inference_data.py | 4 ++-- arviz/tests/base_tests/test_data.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/arviz/data/base.py b/arviz/data/base.py index a7c5ab4dc0..a3729e12f5 100644 --- a/arviz/data/base.py +++ b/arviz/data/base.py @@ -245,6 +245,7 @@ def make_attrs(attrs=None, library=None): def _extend_xr_method(func): """Make wrapper to extend methods from xr.Dataset to InferenceData Class.""" + # pydocstyle requires a non empty line @functools.wraps(func) def wrapped(self, *args, **kwargs): @@ -280,7 +281,7 @@ def wrapped(self, *args, **kwargs): metagroup names. A la `pandas.filter`. inplace: bool, optional If ``True``, modify the InferenceData object inplace, - otherwise, return the modified copy. + otherwise, return the modified copy. """ see_also = """ See Also diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index cc6c838928..c1f06d98e5 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -821,9 +821,9 @@ def extend(self, other, join="left"): dataset = getattr(other, group) setattr(self, group, dataset) if group.startswith(WARMUP_TAG): - self._groups_warmup.append(group) # pylint: disable=protected-access + self._groups_warmup.append(group) else: - self._groups.append(group) # pylint: disable=protected-access + self._groups.append(group) set_index = _extend_xr_method(xr.Dataset.set_index) get_index = _extend_xr_method(xr.Dataset.get_index) diff --git a/arviz/tests/base_tests/test_data.py b/arviz/tests/base_tests/test_data.py index ccdf23e27f..8bff7e5f38 100644 --- a/arviz/tests/base_tests/test_data.py +++ b/arviz/tests/base_tests/test_data.py @@ -651,7 +651,7 @@ def test_extend(self, data_random): idata = data_random idata2 = create_data_random(groups=["prior", "prior_predictive", "observed_data"], seed=7) idata.extend(idata2) - assert "prior" in idata._groups_all + assert "prior" in idata._groups_all # pylint: disable=protected-access assert hasattr(idata, "prior") assert hasattr(idata, "prior_predictive") assert idata.prior.equals(idata2.prior) From 9d34062fe1a2f079e016dfe4c2a0150a7ebd2587 Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Tue, 22 Sep 2020 12:58:24 +0200 Subject: [PATCH 7/9] extend tests --- arviz/tests/base_tests/test_data.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/arviz/tests/base_tests/test_data.py b/arviz/tests/base_tests/test_data.py index 8bff7e5f38..7bdc1b6b18 100644 --- a/arviz/tests/base_tests/test_data.py +++ b/arviz/tests/base_tests/test_data.py @@ -624,10 +624,10 @@ def test_add_groups(self, data_random): assert isinstance(idata.prior, xr.Dataset) assert hasattr(idata, "prior") - idata.add_groups(posterior_warmup={"a": data[..., 0], "b": data}) - assert "posterior_warmup" in idata._groups # pylint: disable=protected-access - assert isinstance(idata.posterior_warmup, xr.Dataset) - assert hasattr(idata, "posterior_warmup") + idata.add_groups(warmup_posterior={"a": data[..., 0], "b": data}) + assert "warmup_posterior" in idata._groups # pylint: disable=protected-access + assert isinstance(idata.warmup_posterior, xr.Dataset) + assert hasattr(idata, "warmup_posterior") def test_add_groups_warning(self, data_random): data = np.random.normal(size=(4, 500, 8)) @@ -649,9 +649,12 @@ def test_add_groups_error(self, data_random): def test_extend(self, data_random): idata = data_random - idata2 = create_data_random(groups=["prior", "prior_predictive", "observed_data"], seed=7) + idata2 = create_data_random( + groups=["prior", "prior_predictive", "observed_data", "warmup_posterior"], seed=7 + ) idata.extend(idata2) assert "prior" in idata._groups_all # pylint: disable=protected-access + assert "warmup_posterior" in idata._groups_all # pylint: disable=protected-access assert hasattr(idata, "prior") assert hasattr(idata, "prior_predictive") assert idata.prior.equals(idata2.prior) From 20b07fbf879149d89e3ad8577339c58317a783b4 Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Tue, 22 Sep 2020 17:49:31 +0200 Subject: [PATCH 8/9] fix io_dict --- CHANGELOG.md | 2 ++ arviz/data/io_dict.py | 38 +++++++++++++++-------------- arviz/tests/base_tests/test_data.py | 2 +- arviz/tests/helpers.py | 1 + 4 files changed, 24 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7bd94b5f42..25c5d21ce4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,8 @@ * Add more `plot_parallel` examples ([#1380](https://github.com/arviz-devs/arviz/pull/1380)) * Bump minimum xarray version to 0.16.1 ([#1389](https://github.com/arviz-devs/arviz/pull/1389) * Fix multi rope for `plot_forest` ([#1390](https://github.com/arviz-devs/arviz/pull/1390)) +* Bump minimum xarray version to 0.16.1 ([#1389](https://github.com/arviz-devs/arviz/pull/1389)) +* `from_dict` will now store warmup groups even with the main group missing ([1386](https://github.com/arviz-devs/arviz/pull/1386)) ### Deprecation diff --git a/arviz/data/io_dict.py b/arviz/data/io_dict.py index fb37ab7b6d..b6b74732df 100644 --- a/arviz/data/io_dict.py +++ b/arviz/data/io_dict.py @@ -6,7 +6,7 @@ from .. import utils from ..rcparams import rcParams from .base import dict_to_dataset, generate_dims_coords, make_attrs, requires -from .inference_data import InferenceData +from .inference_data import InferenceData, WARMUP_TAG # pylint: disable=too-many-instance-attributes @@ -69,11 +69,15 @@ def __init__( self.attrs.pop("created_at", None) self.attrs.pop("arviz_version", None) - @requires("posterior") + def _init_dict(self, attr_name): + dict_or_none = getattr(self, attr_name, None) + return {} if dict_or_none is None else dict_or_none + + @requires(["posterior", f"{WARMUP_TAG}posterior"]) def posterior_to_xarray(self): """Convert posterior samples to xarray.""" - data = self.posterior - data_warmup = self.warmup_posterior if self.warmup_posterior is not None else {} + data = self._init_dict("posterior") + data_warmup = self._init_dict(f"{WARMUP_TAG}posterior") if not isinstance(data, dict): raise TypeError("DictConverter.posterior is not a dictionary") if not isinstance(data_warmup, dict): @@ -95,11 +99,11 @@ def posterior_to_xarray(self): ), ) - @requires("sample_stats") + @requires(["sample_stats", f"{WARMUP_TAG}sample_stats"]) def sample_stats_to_xarray(self): """Convert sample_stats samples to xarray.""" - data = self.sample_stats - data_warmup = self.warmup_sample_stats if self.warmup_sample_stats is not None else {} + data = self._init_dict("sample_stats") + data_warmup = self._init_dict(f"{WARMUP_TAG}sample_stats") if not isinstance(data, dict): raise TypeError("DictConverter.sample_stats is not a dictionary") if not isinstance(data_warmup, dict): @@ -122,11 +126,11 @@ def sample_stats_to_xarray(self): ), ) - @requires("log_likelihood") + @requires(["log_likelihood", f"{WARMUP_TAG}log_likelihood"]) def log_likelihood_to_xarray(self): """Convert log_likelihood samples to xarray.""" - data = self.log_likelihood - data_warmup = self.warmup_log_likelihood if self.warmup_log_likelihood is not None else {} + data = self._init_dict("log_likelihood") + data_warmup = self._init_dict(f"{WARMUP_TAG}log_likelihood") if not isinstance(data, dict): raise TypeError("DictConverter.log_likelihood is not a dictionary") if not isinstance(data_warmup, dict): @@ -141,13 +145,11 @@ def log_likelihood_to_xarray(self): ), ) - @requires("posterior_predictive") + @requires(["posterior_predictive", f"{WARMUP_TAG}posterior_predictive"]) def posterior_predictive_to_xarray(self): """Convert posterior_predictive samples to xarray.""" - data = self.posterior_predictive - data_warmup = ( - self.warmup_posterior_predictive if self.warmup_posterior_predictive is not None else {} - ) + data = self._init_dict("posterior_predictive") + data_warmup = self._init_dict(f"{WARMUP_TAG}posterior_predictive") if not isinstance(data, dict): raise TypeError("DictConverter.posterior_predictive is not a dictionary") if not isinstance(data_warmup, dict): @@ -162,11 +164,11 @@ def posterior_predictive_to_xarray(self): ), ) - @requires("predictions") + @requires(["predictions", f"{WARMUP_TAG}predictions"]) def predictions_to_xarray(self): """Convert predictions to xarray.""" - data = self.predictions - data_warmup = self.warmup_predictions if self.warmup_predictions is not None else {} + data = self._init_dict("predictions") + data_warmup = self._init_dict(f"{WARMUP_TAG}predictions") if not isinstance(data, dict): raise TypeError("DictConverter.predictions is not a dictionary") if not isinstance(data_warmup, dict): diff --git a/arviz/tests/base_tests/test_data.py b/arviz/tests/base_tests/test_data.py index 7bdc1b6b18..08a1448e41 100644 --- a/arviz/tests/base_tests/test_data.py +++ b/arviz/tests/base_tests/test_data.py @@ -625,7 +625,7 @@ def test_add_groups(self, data_random): assert hasattr(idata, "prior") idata.add_groups(warmup_posterior={"a": data[..., 0], "b": data}) - assert "warmup_posterior" in idata._groups # pylint: disable=protected-access + assert "warmup_posterior" in idata._groups_all # pylint: disable=protected-access assert isinstance(idata.warmup_posterior, xr.Dataset) assert hasattr(idata, "warmup_posterior") diff --git a/arviz/tests/helpers.py b/arviz/tests/helpers.py index e547eccfa9..4d80acb169 100644 --- a/arviz/tests/helpers.py +++ b/arviz/tests/helpers.py @@ -157,6 +157,7 @@ def create_data_random(groups=None, seed=10): prior_predictive={"a": data[..., 0], "b": data}, warmup_posterior={"a": data[..., 0], "b": data}, warmup_posterior_predictive={"a": data[..., 0], "b": data}, + warmup_prior={"a": data[..., 0], "b": data}, ) idata = from_dict( **{group: ary for group, ary in idata_dict.items() if group in groups}, save_warmup=True From 69be1575bb69fb3e758355b543809d46cd68ebae Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Tue, 22 Sep 2020 19:51:48 +0200 Subject: [PATCH 9/9] modify default --- arviz/data/io_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/data/io_dict.py b/arviz/data/io_dict.py index b6b74732df..2b335eed94 100644 --- a/arviz/data/io_dict.py +++ b/arviz/data/io_dict.py @@ -70,7 +70,7 @@ def __init__( self.attrs.pop("arviz_version", None) def _init_dict(self, attr_name): - dict_or_none = getattr(self, attr_name, None) + dict_or_none = getattr(self, attr_name, {}) return {} if dict_or_none is None else dict_or_none @requires(["posterior", f"{WARMUP_TAG}posterior"])