Skip to content

Commit

Permalink
update io_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Jun 3, 2020
1 parent 2f72bdb commit 9175645
Showing 1 changed file with 40 additions and 26 deletions.
66 changes: 40 additions & 26 deletions arviz/data/io_dict.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""Dictionary specific conversion code."""
import warnings
import xarray as xr
from typing import Optional

from .inference_data import InferenceData
from .base import requires, dict_to_dataset, generate_dims_coords, make_attrs
from .. import utils
from .base import requires, dict_to_dataset


# pylint: disable=too-many-instance-attributes
Expand All @@ -25,6 +24,7 @@ def __init__(
observed_data=None,
constant_data=None,
predictions_constant_data=None,
index_origin=None,
coords=None,
dims=None
):
Expand All @@ -39,6 +39,7 @@ def __init__(
self.observed_data = observed_data
self.constant_data = constant_data
self.predictions_constant_data = predictions_constant_data
self.index_origin = index_origin
self.coords = coords
self.dims = dims

Expand All @@ -56,7 +57,9 @@ def posterior_to_xarray(self):
UserWarning,
)

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, index_origin=self.index_origin
)

@requires("sample_stats")
def sample_stats_to_xarray(self):
Expand All @@ -73,7 +76,9 @@ def sample_stats_to_xarray(self):
PendingDeprecationWarning,
)

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, index_origin=self.index_origin
)

@requires("log_likelihood")
def log_likelihood_to_xarray(self):
Expand All @@ -82,7 +87,9 @@ def log_likelihood_to_xarray(self):
if not isinstance(data, dict):
raise TypeError("DictConverter.log_likelihood is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, index_origin=self.index_origin
)

@requires("posterior_predictive")
def posterior_predictive_to_xarray(self):
Expand All @@ -91,7 +98,9 @@ def posterior_predictive_to_xarray(self):
if not isinstance(data, dict):
raise TypeError("DictConverter.posterior_predictive is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, index_origin=self.index_origin
)

@requires("predictions")
def predictions_to_xarray(self):
Expand All @@ -100,7 +109,9 @@ def predictions_to_xarray(self):
if not isinstance(data, dict):
raise TypeError("DictConverter.predictions is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, index_origin=self.index_origin
)

@requires("prior")
def prior_to_xarray(self):
Expand All @@ -109,7 +120,9 @@ def prior_to_xarray(self):
if not isinstance(data, dict):
raise TypeError("DictConverter.prior is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, index_origin=self.index_origin
)

@requires("sample_stats_prior")
def sample_stats_prior_to_xarray(self):
Expand All @@ -118,7 +131,9 @@ def sample_stats_prior_to_xarray(self):
if not isinstance(data, dict):
raise TypeError("DictConverter.sample_stats_prior is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, index_origin=self.index_origin
)

@requires("prior_predictive")
def prior_predictive_to_xarray(self):
Expand All @@ -127,26 +142,22 @@ def prior_predictive_to_xarray(self):
if not isinstance(data, dict):
raise TypeError("DictConverter.prior_predictive is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, index_origin=self.index_origin
)

def data_to_xarray(self, dct, group):
def data_to_xarray(self, data, group):
"""Convert data to xarray."""
data = dct
if not isinstance(data, dict):
raise TypeError("DictConverter.{} is not a dictionary".format(group))
if self.dims is None:
dims = {}
else:
dims = self.dims
new_data = dict()
for key, vals in data.items():
vals = utils.one_de(vals)
val_dims = dims.get(key)
val_dims, coords = generate_dims_coords(
vals.shape, key, dims=val_dims, coords=self.coords
)
new_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=new_data, attrs=make_attrs(library=None))
return dict_to_dataset(
data,
library=None,
coords=self.coords,
dims=self.dims,
default_dims=[],
index_origin=self.index_origin,
)

@requires("observed_data")
def observed_data_to_xarray(self):
Expand Down Expand Up @@ -202,6 +213,7 @@ def from_dict(
observed_data=None,
constant_data=None,
predictions_constant_data=None,
index_origin: Optional[int] = None,
coords=None,
dims=None
):
Expand All @@ -223,6 +235,7 @@ def from_dict(
observed_data : dict
constant_data : dict
predictions_constant_data: dict
index_origin : int, optional
coords : dict[str, iterable]
A dictionary containing the values that are used as index. The key
is the name of the dimension, the values are the index values.
Expand All @@ -245,6 +258,7 @@ def from_dict(
observed_data=observed_data,
constant_data=constant_data,
predictions_constant_data=predictions_constant_data,
index_origin=index_origin,
coords=coords,
dims=dims,
).to_inference_data()

0 comments on commit 9175645

Please sign in to comment.