Skip to content

Commit

Permalink
fix io_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Sep 22, 2020
1 parent ac798b9 commit fe7fef8
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 20 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
* Update diagnostics to be on par with posterior package ([#1366](https://github.com/arviz-devs/arviz/pull/1366))
* Use method="average" in `scipy.stats.rankdata` ([#1380](https://github.com/arviz-devs/arviz/pull/1380))
* Add more `plot_parallel` examples ([#1380](https://github.com/arviz-devs/arviz/pull/1380))
* Bump minimum xarray version to 0.16.1 ([#1389](https://github.com/arviz-devs/arviz/pull/1389)
* Bump minimum xarray version to 0.16.1 ([#1389](https://github.com/arviz-devs/arviz/pull/1389))
* `from_dict` will now store warmup groups even with the main group missing ([1386](https://github.com/arviz-devs/arviz/pull/1386))

### Deprecation

Expand Down
38 changes: 20 additions & 18 deletions arviz/data/io_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .. import utils
from ..rcparams import rcParams
from .base import dict_to_dataset, generate_dims_coords, make_attrs, requires
from .inference_data import InferenceData
from .inference_data import InferenceData, WARMUP_TAG


# pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -69,11 +69,15 @@ def __init__(
self.attrs.pop("created_at", None)
self.attrs.pop("arviz_version", None)

@requires("posterior")
def _init_dict(self, attr_name):
dict_or_none = getattr(self, attr_name, None)
return {} if dict_or_none is None else dict_or_none

@requires(["posterior", f"{WARMUP_TAG}posterior"])
def posterior_to_xarray(self):
"""Convert posterior samples to xarray."""
data = self.posterior
data_warmup = self.warmup_posterior if self.warmup_posterior is not None else {}
data = self._init_dict("posterior")
data_warmup = self._init_dict(f"{WARMUP_TAG}posterior")
if not isinstance(data, dict):
raise TypeError("DictConverter.posterior is not a dictionary")
if not isinstance(data_warmup, dict):
Expand All @@ -95,11 +99,11 @@ def posterior_to_xarray(self):
),
)

@requires("sample_stats")
@requires(["sample_stats", f"{WARMUP_TAG}sample_stats"])
def sample_stats_to_xarray(self):
"""Convert sample_stats samples to xarray."""
data = self.sample_stats
data_warmup = self.warmup_sample_stats if self.warmup_sample_stats is not None else {}
data = self._init_dict("sample_stats")
data_warmup = self._init_dict(f"{WARMUP_TAG}sample_stats")
if not isinstance(data, dict):
raise TypeError("DictConverter.sample_stats is not a dictionary")
if not isinstance(data_warmup, dict):
Expand All @@ -122,11 +126,11 @@ def sample_stats_to_xarray(self):
),
)

@requires("log_likelihood")
@requires(["log_likelihood", f"{WARMUP_TAG}log_likelihood"])
def log_likelihood_to_xarray(self):
"""Convert log_likelihood samples to xarray."""
data = self.log_likelihood
data_warmup = self.warmup_log_likelihood if self.warmup_log_likelihood is not None else {}
data = self._init_dict("log_likelihood")
data_warmup = self._init_dict(f"{WARMUP_TAG}log_likelihood")
if not isinstance(data, dict):
raise TypeError("DictConverter.log_likelihood is not a dictionary")
if not isinstance(data_warmup, dict):
Expand All @@ -141,13 +145,11 @@ def log_likelihood_to_xarray(self):
),
)

@requires("posterior_predictive")
@requires(["posterior_predictive", f"{WARMUP_TAG}posterior_predictive"])
def posterior_predictive_to_xarray(self):
"""Convert posterior_predictive samples to xarray."""
data = self.posterior_predictive
data_warmup = (
self.warmup_posterior_predictive if self.warmup_posterior_predictive is not None else {}
)
data = self._init_dict("posterior_predictive")
data_warmup = self._init_dict(f"{WARMUP_TAG}posterior_predictive")
if not isinstance(data, dict):
raise TypeError("DictConverter.posterior_predictive is not a dictionary")
if not isinstance(data_warmup, dict):
Expand All @@ -162,11 +164,11 @@ def posterior_predictive_to_xarray(self):
),
)

@requires("predictions")
@requires(["predictions", f"{WARMUP_TAG}predictions"])
def predictions_to_xarray(self):
"""Convert predictions to xarray."""
data = self.predictions
data_warmup = self.warmup_predictions if self.warmup_predictions is not None else {}
data = self._init_dict("predictions")
data_warmup = self._init_dict(f"{WARMUP_TAG}predictions")
if not isinstance(data, dict):
raise TypeError("DictConverter.predictions is not a dictionary")
if not isinstance(data_warmup, dict):
Expand Down
2 changes: 1 addition & 1 deletion arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ def test_add_groups(self, data_random):
assert hasattr(idata, "prior")

idata.add_groups(warmup_posterior={"a": data[..., 0], "b": data})
assert "warmup_posterior" in idata._groups # pylint: disable=protected-access
assert "warmup_posterior" in idata._groups_all # pylint: disable=protected-access
assert isinstance(idata.warmup_posterior, xr.Dataset)
assert hasattr(idata, "warmup_posterior")

Expand Down
1 change: 1 addition & 0 deletions arviz/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def create_data_random(groups=None, seed=10):
prior_predictive={"a": data[..., 0], "b": data},
warmup_posterior={"a": data[..., 0], "b": data},
warmup_posterior_predictive={"a": data[..., 0], "b": data},
warmup_prior={"a": data[..., 0], "b": data},
)
idata = from_dict(
**{group: ary for group, ary in idata_dict.items() if group in groups}, save_warmup=True
Expand Down

0 comments on commit fe7fef8

Please sign in to comment.