diff --git a/arviz/data/io_dict.py b/arviz/data/io_dict.py index 12e1449298..2748e8e8e7 100644 --- a/arviz/data/io_dict.py +++ b/arviz/data/io_dict.py @@ -1,10 +1,9 @@ """Dictionary specific conversion code.""" import warnings -import xarray as xr +from typing import Optional from .inference_data import InferenceData -from .base import requires, dict_to_dataset, generate_dims_coords, make_attrs -from .. import utils +from .base import requires, dict_to_dataset # pylint: disable=too-many-instance-attributes @@ -25,6 +24,7 @@ def __init__( observed_data=None, constant_data=None, predictions_constant_data=None, + index_origin=None, coords=None, dims=None ): @@ -39,6 +39,7 @@ def __init__( self.observed_data = observed_data self.constant_data = constant_data self.predictions_constant_data = predictions_constant_data + self.index_origin = index_origin self.coords = coords self.dims = dims @@ -56,7 +57,9 @@ def posterior_to_xarray(self): UserWarning, ) - return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims) + return dict_to_dataset( + data, library=None, coords=self.coords, dims=self.dims, index_origin=self.index_origin + ) @requires("sample_stats") def sample_stats_to_xarray(self): @@ -73,7 +76,9 @@ def sample_stats_to_xarray(self): PendingDeprecationWarning, ) - return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims) + return dict_to_dataset( + data, library=None, coords=self.coords, dims=self.dims, index_origin=self.index_origin + ) @requires("log_likelihood") def log_likelihood_to_xarray(self): @@ -82,7 +87,9 @@ def log_likelihood_to_xarray(self): if not isinstance(data, dict): raise TypeError("DictConverter.log_likelihood is not a dictionary") - return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims) + return dict_to_dataset( + data, library=None, coords=self.coords, dims=self.dims, index_origin=self.index_origin + ) @requires("posterior_predictive") def posterior_predictive_to_xarray(self): @@ -91,7 +98,9 @@ def posterior_predictive_to_xarray(self): if not isinstance(data, dict): raise TypeError("DictConverter.posterior_predictive is not a dictionary") - return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims) + return dict_to_dataset( + data, library=None, coords=self.coords, dims=self.dims, index_origin=self.index_origin + ) @requires("predictions") def predictions_to_xarray(self): @@ -100,7 +109,9 @@ def predictions_to_xarray(self): if not isinstance(data, dict): raise TypeError("DictConverter.predictions is not a dictionary") - return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims) + return dict_to_dataset( + data, library=None, coords=self.coords, dims=self.dims, index_origin=self.index_origin + ) @requires("prior") def prior_to_xarray(self): @@ -109,7 +120,9 @@ def prior_to_xarray(self): if not isinstance(data, dict): raise TypeError("DictConverter.prior is not a dictionary") - return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims) + return dict_to_dataset( + data, library=None, coords=self.coords, dims=self.dims, index_origin=self.index_origin + ) @requires("sample_stats_prior") def sample_stats_prior_to_xarray(self): @@ -118,7 +131,9 @@ def sample_stats_prior_to_xarray(self): if not isinstance(data, dict): raise TypeError("DictConverter.sample_stats_prior is not a dictionary") - return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims) + return dict_to_dataset( + data, library=None, coords=self.coords, dims=self.dims, index_origin=self.index_origin + ) @requires("prior_predictive") def prior_predictive_to_xarray(self): @@ -127,26 +142,22 @@ def prior_predictive_to_xarray(self): if not isinstance(data, dict): raise TypeError("DictConverter.prior_predictive is not a dictionary") - return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims) + return dict_to_dataset( + data, library=None, coords=self.coords, dims=self.dims, index_origin=self.index_origin + ) - def data_to_xarray(self, dct, group): + def data_to_xarray(self, data, group): """Convert data to xarray.""" - data = dct if not isinstance(data, dict): raise TypeError("DictConverter.{} is not a dictionary".format(group)) - if self.dims is None: - dims = {} - else: - dims = self.dims - new_data = dict() - for key, vals in data.items(): - vals = utils.one_de(vals) - val_dims = dims.get(key) - val_dims, coords = generate_dims_coords( - vals.shape, key, dims=val_dims, coords=self.coords - ) - new_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords) - return xr.Dataset(data_vars=new_data, attrs=make_attrs(library=None)) + return dict_to_dataset( + data, + library=None, + coords=self.coords, + dims=self.dims, + default_dims=[], + index_origin=self.index_origin, + ) @requires("observed_data") def observed_data_to_xarray(self): @@ -202,6 +213,7 @@ def from_dict( observed_data=None, constant_data=None, predictions_constant_data=None, + index_origin: Optional[int] = None, coords=None, dims=None ): @@ -223,6 +235,7 @@ def from_dict( observed_data : dict constant_data : dict predictions_constant_data: dict + index_origin : int, optional coords : dict[str, iterable] A dictionary containing the values that are used as index. The key is the name of the dimension, the values are the index values. @@ -245,6 +258,7 @@ def from_dict( observed_data=observed_data, constant_data=constant_data, predictions_constant_data=predictions_constant_data, + index_origin=index_origin, coords=coords, dims=dims, ).to_inference_data()