From 9cf68322291c2b2aa9a92dddf98336e519584270 Mon Sep 17 00:00:00 2001 From: Colin Carroll Date: Tue, 14 Nov 2023 15:26:31 -0500 Subject: [PATCH] Make `from_dict` more flexible, and add `from_pytree` --- CHANGELOG.md | 1 + arviz/data/__init__.py | 3 +- arviz/data/base.py | 50 ++++++++++++++++++++- arviz/data/converters.py | 4 ++ arviz/data/io_dict.py | 3 ++ arviz/plots/backends/matplotlib/pairplot.py | 2 +- arviz/tests/base_tests/test_data.py | 13 ++++++ requirements.txt | 1 + 8 files changed, 73 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 957adfbfdf..27d166fb4c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### New features - Add filter_vars functionality to `InfereceData.to_dataframe`method ([2277](https://github.com/arviz-devs/arviz/pull/2277)) +- Support for `pytree`s and robust to nested dictionaries. (2291) ### Maintenance and fixes diff --git a/arviz/data/__init__.py b/arviz/data/__init__.py index 742fece161..f2545ddafa 100644 --- a/arviz/data/__init__.py +++ b/arviz/data/__init__.py @@ -7,7 +7,7 @@ from .io_cmdstan import from_cmdstan from .io_cmdstanpy import from_cmdstanpy from .io_datatree import from_datatree, to_datatree -from .io_dict import from_dict +from .io_dict import from_dict, from_pytree from .io_emcee import from_emcee from .io_json import from_json, to_json from .io_netcdf import from_netcdf, to_netcdf @@ -38,6 +38,7 @@ "from_cmdstanpy", "from_datatree", "from_dict", + "from_pytree", "from_json", "from_pyro", "from_numpyro", diff --git a/arviz/data/base.py b/arviz/data/base.py index cf0f281e87..91c248fec5 100644 --- a/arviz/data/base.py +++ b/arviz/data/base.py @@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union import numpy as np +import tree import xarray as xr try: @@ -66,6 +67,46 @@ def wrapped(cls: RequiresArgTypeT) -> Optional[RequiresReturnTypeT]: return wrapped +def _yield_flat_up_to(shallow_tree, input_tree, path=()): + """Yields (path, value) pairs of input_tree flattened up to shallow_tree. + + Adapted from dm-tree (https://github.com/google-deepmind/tree) to allow + lists as leaves. + + Args: + shallow_tree: Nested structure. Traverse no further than its leaf nodes. + input_tree: Nested structure. Return the paths and values from this tree. + Must have the same upper structure as shallow_tree. + path: Tuple. Optional argument, only used when recursing. The path from the + root of the original shallow_tree, down to the root of the shallow_tree + arg of this recursive call. + + Yields: + Pairs of (path, value), where path the tuple path of a leaf node in + shallow_tree, and value is the value of the corresponding node in + input_tree. + """ + # pylint: disable=protected-access + if (isinstance(shallow_tree, tree._TEXT_OR_BYTES) or + not (isinstance(shallow_tree, tree.collections_abc.Mapping) or + tree._is_namedtuple(shallow_tree) or + tree._is_attrs(shallow_tree))): + yield (path, input_tree) + else: + input_tree = dict(tree._yield_sorted_items(input_tree)) + for shallow_key, shallow_subtree in tree._yield_sorted_items(shallow_tree): + subpath = path + (shallow_key,) + input_subtree = input_tree[shallow_key] + for leaf_path, leaf_value in _yield_flat_up_to(shallow_subtree, + input_subtree, + path=subpath): + yield (leaf_path, leaf_value) + # pylint: enable=protected-access + + +def _flatten_with_path(structure): + return list(_yield_flat_up_to(structure, structure)) + def generate_dims_coords( shape, @@ -255,7 +296,7 @@ def numpy_to_data_array( return xr.DataArray(ary, coords=coords, dims=dims) -def dict_to_dataset( +def pytree_to_dataset( data, *, attrs=None, @@ -266,7 +307,7 @@ def dict_to_dataset( index_origin=None, skip_event_dims=None, ): - """Convert a dictionary of numpy arrays to an xarray.Dataset. + """Convert a pytree of numpy arrays to an xarray.Dataset. Parameters ---------- @@ -302,6 +343,10 @@ def dict_to_dataset( """ if dims is None: dims = {} + try: + data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)} + except TypeError: # probably unsortable keys -- the function will still work if + pass # it is an honest dictionary. data_vars = { key: numpy_to_data_array( @@ -317,6 +362,7 @@ def dict_to_dataset( } return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library)) +dict_to_dataset = pytree_to_dataset def make_attrs(attrs=None, library=None): """Make standard attributes to attach to xarray datasets. diff --git a/arviz/data/converters.py b/arviz/data/converters.py index 2961f0aaf1..a8f34bc490 100644 --- a/arviz/data/converters.py +++ b/arviz/data/converters.py @@ -1,5 +1,6 @@ """High level conversion functions.""" import numpy as np +import tree import xarray as xr from .base import dict_to_dataset @@ -105,6 +106,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None, dataset = obj.to_dataset() elif isinstance(obj, dict): dataset = dict_to_dataset(obj, coords=coords, dims=dims) + elif tree.is_nested(obj) and not isinstance(obj, (list, tuple)): + dataset = dict_to_dataset(obj, coords=coords, dims=dims) elif isinstance(obj, np.ndarray): dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims) elif isinstance(obj, (list, tuple)) and isinstance(obj[0], str) and obj[0].endswith(".csv"): @@ -118,6 +121,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None, "xarray dataarray", "xarray dataset", "dict", + "pytree", "netcdf filename", "numpy array", "pystan fit", diff --git a/arviz/data/io_dict.py b/arviz/data/io_dict.py index 4d34157ddc..d76a7511c6 100644 --- a/arviz/data/io_dict.py +++ b/arviz/data/io_dict.py @@ -458,3 +458,6 @@ def from_dict( attrs=attrs, **kwargs, ).to_inference_data() + + +from_pytree = from_dict diff --git a/arviz/plots/backends/matplotlib/pairplot.py b/arviz/plots/backends/matplotlib/pairplot.py index 7e50b43d8e..f17f51e19b 100644 --- a/arviz/plots/backends/matplotlib/pairplot.py +++ b/arviz/plots/backends/matplotlib/pairplot.py @@ -333,7 +333,7 @@ def plot_pair( if reference_values: x_name = flat_var_names[i] y_name = flat_var_names[j + not_marginals] - if x_name and y_name not in difference: + if (x_name not in difference) and (y_name not in difference): ax[j, i].plot( reference_values_copy[x_name], reference_values_copy[y_name], diff --git a/arviz/tests/base_tests/test_data.py b/arviz/tests/base_tests/test_data.py index ae8c3e4cad..8680e0dbc4 100644 --- a/arviz/tests/base_tests/test_data.py +++ b/arviz/tests/base_tests/test_data.py @@ -1074,6 +1074,19 @@ def test_dict_to_dataset(): assert set(dataset.a.coords) == {"chain", "draw"} assert set(dataset.b.coords) == {"chain", "draw", "c"} +def test_nested_dict_to_dataset(): + datadict = {"top": { + "a": np.random.randn(100), + "b": np.random.randn(1, 100, 10)}, + "d": np.random.randn(100)} + dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={("top", "b"): ["c"]}) + assert set(dataset.data_vars) == {("top", "a"), ("top", "b"), "d"} + assert set(dataset.coords) == {"chain", "draw", "c"} + + assert set(dataset[("top", "a")].coords) == {"chain", "draw"} + assert set(dataset[("top", "b")].coords) == {"chain", "draw", "c"} + assert set(dataset.d.coords) == {"chain", "draw"} + def test_dict_to_dataset_event_dims_error(): datadict = {"a": np.random.randn(1, 100, 10)} diff --git a/requirements.txt b/requirements.txt index d764477be1..549d4fa07b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ numpy>=1.22.0,<2.0 scipy>=1.8.0 packaging pandas>=1.4.0 +dm-tree>=0.1.8 xarray>=0.21.0 h5netcdf>=1.0.2 typing_extensions>=4.1.0