diff --git a/datatree/datatree.py b/datatree/datatree.py
index 79257343..e39b0c05 100644
--- a/datatree/datatree.py
+++ b/datatree/datatree.py
@@ -424,10 +424,12 @@ def __init__(
         else:
             node_path, node_name = "/", path

+        relative_path = node_path.replace(self.name, "")
+
         # Create and set new node
         new_node = DataNode(name=node_name, data=data)
         self.set_node(
-            node_path,
+            relative_path,
             new_node,
             allow_overwrite=False,
             new_nodes_along_path=True,
diff --git a/datatree/mapping.py b/datatree/mapping.py
index b0ff2b22..94b17ac0 100644
--- a/datatree/mapping.py
+++ b/datatree/mapping.py
@@ -1,6 +1,8 @@
 import functools
+from itertools import repeat

 from anytree.iterators import LevelOrderIter
+from xarray import DataArray, Dataset

 from .treenode import TreeNode
@@ -43,11 +45,11 @@ def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False):

     if not isinstance(subtree_a, TreeNode):
         raise TypeError(
-            f"Argument `subtree_a is not a tree, it is of type {type(subtree_a)}"
+            f"Argument `subtree_a` is not a tree, it is of type {type(subtree_a)}"
         )
     if not isinstance(subtree_b, TreeNode):
         raise TypeError(
-            f"Argument `subtree_b is not a tree, it is of type {type(subtree_b)}"
+            f"Argument `subtree_b` is not a tree, it is of type {type(subtree_b)}"
         )

     # Walking nodes in "level-order" fashion means walking down from the root breadth-first.
@@ -83,57 +85,195 @@ def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False):

 def map_over_subtree(func):
     """
-    Decorator which turns a function which acts on (and returns) single Datasets into one which acts on DataTrees.
+    Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees.

-    Applies a function to every dataset in this subtree, returning a new tree which stores the results.
+    Applies a function to every dataset in one or more subtrees, returning new trees which store the results.

-    The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the
-    descendant nodes. The returned tree will have the same structure as the original subtree.
+    The function will be applied to any dataset stored in any of the nodes in the trees. The returned trees will have
+    the same structure as the supplied trees.

-    func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each
-    result will be assigned to its respective node of new tree via `DataTree.__setitem__`.
+    `func` needs to return one or more Datasets, DataArrays, or None in order to be able to rebuild the subtrees
+    after mapping, as each result will be assigned to its respective node of a new tree via `DataTree.__setitem__`.
+    Each returned value that is one of these types will be assembled into its own tree before all of the trees are returned.
+
+    The trees passed to the resulting function must all be isomorphic to one another. Their nodes need not be named
+    similarly, but all the output trees will have nodes named in the same way as the first tree passed.

     Parameters
     ----------
     func : callable
         Function to apply to datasets with signature:
-        `func(node.ds, *args, **kwargs) -> Dataset`.
+        `func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.
+
+        (i.e. func must accept at least one Dataset and return at least one Dataset.)

         Function will not be applied to any nodes without datasets.
     *args : tuple, optional
-        Positional arguments passed on to `func`.
+        Positional arguments passed on to `func`. Any DataTree arguments will have their data-containing nodes
+        converted to Datasets via `.ds`.
     **kwargs : Any
-        Keyword arguments passed on to `func`.
+        Keyword arguments passed on to `func`. Any DataTree arguments will have their data-containing nodes
+        converted to Datasets via `.ds`.

     Returns
     -------
     mapped : callable
-        Wrapped function which returns tree created from results of applying ``func`` to the dataset at each node.
+        Wrapped function which returns one or more trees created from the results of applying ``func`` to the dataset
+        at each node.

     See also
     --------
     DataTree.map_over_subtree
     DataTree.map_over_subtree_inplace
+    DataTree.subtree
     """

+    # TODO examples in the docstring
+
+    # TODO inspect function to work out immediately if the wrong number of arguments were passed for it?
+
     @functools.wraps(func)
-    def _map_over_subtree(tree, *args, **kwargs):
+    def _map_over_subtree(*args, **kwargs):
         """Internal function which maps func over every node in tree, returning a tree of the results."""
+        from .datatree import DataTree
+
+        all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [
+            a for a in kwargs.values() if isinstance(a, DataTree)
+        ]
+
+        if len(all_tree_inputs) > 0:
+            first_tree, *other_trees = all_tree_inputs
+        else:
+            raise TypeError("Must pass at least one tree object")
+
+        for other_tree in other_trees:
+            # isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic
+            _check_isomorphic(first_tree, other_tree, require_names_equal=False)
+
+        # Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
+        # We don't know which arguments are DataTrees so we zip all arguments together as iterables
+        # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
+        out_data_objects = {}
+        args_as_tree_length_iterables = [
+            a.subtree if isinstance(a, DataTree) else repeat(a) for a in args
+        ]
+        n_args = len(args_as_tree_length_iterables)
+        kwargs_as_tree_length_iterables = {
+            k: v.subtree if isinstance(v, DataTree) else repeat(v)
+            for k, v in kwargs.items()
+        }
+        for node_of_first_tree, *all_node_args in zip(
+            first_tree.subtree,
+            *args_as_tree_length_iterables,
+            *list(kwargs_as_tree_length_iterables.values()),
+        ):
+            node_args_as_datasets = [
+                a.ds if isinstance(a, DataTree) else a for a in all_node_args[:n_args]
+            ]
+            node_kwargs_as_datasets = dict(
+                zip(
+                    [k for k in kwargs_as_tree_length_iterables.keys()],
+                    [
+                        v.ds if isinstance(v, DataTree) else v
+                        for v in all_node_args[n_args:]
+                    ],
+                )
+            )

-        # Recreate and act on root node
-        from .datatree import DataNode
+            # Now we can call func on the data in this particular set of corresponding nodes
+            results = (
+                func(*node_args_as_datasets, **node_kwargs_as_datasets)
+                if node_of_first_tree.has_data
+                else None
+            )

-        out_tree = DataNode(name=tree.name, data=tree.ds)
-        if out_tree.has_data:
-            out_tree.ds = func(out_tree.ds, *args, **kwargs)
+            # TODO implement mapping over multiple trees in-place using if conditions from here on?
+            out_data_objects[node_of_first_tree.pathstr] = results
+
+        # Find out how many return values we received
+        num_return_values = _check_all_return_values(out_data_objects)
+
+        # Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees
+        result_trees = []
+        for i in range(num_return_values):
+            out_tree_contents = {}
+            for n in first_tree.subtree:
+                p = n.pathstr
+                if p in out_data_objects.keys():
+                    if isinstance(out_data_objects[p], tuple):
+                        output_node_data = out_data_objects[p][i]
+                    else:
+                        output_node_data = out_data_objects[p]
+                else:
+                    output_node_data = None
+                out_tree_contents[p] = output_node_data
+
+            new_tree = DataTree(name=first_tree.name, data_objects=out_tree_contents)
+            result_trees.append(new_tree)
+
+        # If only one result then don't wrap it in a tuple
+        if len(result_trees) == 1:
+            return result_trees[0]
+        else:
+            return tuple(result_trees)

-        # Act on every other node in the tree, and rebuild from results
-        for node in tree.descendants:
-            # TODO make a proper relative_path method
-            relative_path = node.pathstr.replace(tree.pathstr, "")
-            result = func(node.ds, *args, **kwargs) if node.has_data else None
-            out_tree[relative_path] = result
+    return _map_over_subtree

-        return out_tree

-    return _map_over_subtree
+def _check_single_set_return_values(path_to_node, obj):
+    """Check types returned from single evaluation of func, and return number of return values received from func."""
+    if isinstance(obj, (Dataset, DataArray)):
+        return 1
+    elif isinstance(obj, tuple):
+        for r in obj:
+            if not isinstance(r, (Dataset, DataArray)):
+                raise TypeError(
+                    f"One of the results of calling func on datasets on the nodes at position {path_to_node} is "
+                    f"of type {type(r)}, not Dataset or DataArray."
+                )
+        return len(obj)
+    else:
+        raise TypeError(
+            f"The result of calling func on the node at position {path_to_node} is of type {type(obj)}, not "
+            f"Dataset or DataArray, nor a tuple of such types."
+        )
+
+
+def _check_all_return_values(returned_objects):
+    """Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types."""
+
+    if all(r is None for r in returned_objects.values()):
+        raise TypeError(
+            "Called supplied function on all nodes but found a return value of None for "
+            "all of them."
+        )
+
+    result_data_objects = [
+        (path_to_node, r)
+        for path_to_node, r in returned_objects.items()
+        if r is not None
+    ]
+
+    if len(result_data_objects) == 1:
+        # Only one node in the tree: no need to check consistency of results between nodes
+        path_to_node, result = result_data_objects[0]
+        num_return_values = _check_single_set_return_values(path_to_node, result)
+    else:
+        prev_path, _ = result_data_objects[0]
+        prev_num_return_values, num_return_values = None, None
+        for path_to_node, obj in result_data_objects[1:]:
+            num_return_values = _check_single_set_return_values(path_to_node, obj)
+
+            if (
+                num_return_values != prev_num_return_values
+                and prev_num_return_values is not None
+            ):
+                raise TypeError(
+                    f"Calling func on the nodes at position {path_to_node} returns {num_return_values} separate return "
+                    f"values, whereas calling func on the nodes at position {prev_path} instead returns "
+                    f"{prev_num_return_values} separate return values."
+                )
+
+            prev_path, prev_num_return_values = path_to_node, num_return_values
+
+    return num_return_values
diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py
index 4592643b..6ce51851 100644
--- a/datatree/tests/test_datatree.py
+++ b/datatree/tests/test_datatree.py
@@ -8,11 +8,9 @@


 def assert_tree_equal(dt_a, dt_b):
-    assert dt_a.name == dt_b.name
     assert dt_a.parent is dt_b.parent
-    assert dt_a.ds.equals(dt_b.ds)

-    for a, b in zip(dt_a.descendants, dt_b.descendants):
+    for a, b in zip(dt_a.subtree, dt_b.subtree):
         assert a.name == b.name
         assert a.pathstr == b.pathstr
         if a.has_data:
@@ -321,7 +319,6 @@ def test_to_netcdf(self, tmpdir):
         original_dt.to_netcdf(filepath, engine="netcdf4")

         roundtrip_dt = open_datatree(filepath)
-
         assert_tree_equal(original_dt, roundtrip_dt)

     def test_to_zarr(self, tmpdir):
@@ -332,5 +329,4 @@ def test_to_zarr(self, tmpdir):
         original_dt.to_zarr(filepath)

         roundtrip_dt = open_datatree(filepath, engine="zarr")
-
         assert_tree_equal(original_dt, roundtrip_dt)
diff --git a/datatree/tests/test_mapping.py b/datatree/tests/test_mapping.py
index da2ad8be..b94840dc 100644
--- a/datatree/tests/test_mapping.py
+++ b/datatree/tests/test_mapping.py
@@ -1,9 +1,8 @@
 import pytest
 import xarray as xr
 from test_datatree import assert_tree_equal, create_test_datatree
-from xarray.testing import assert_equal

-from datatree.datatree import DataNode, DataTree
+from datatree.datatree import DataTree
 from datatree.mapping import TreeIsomorphismError, _check_isomorphic, map_over_subtree
 from datatree.treenode import TreeNode

@@ -91,17 +90,36 @@ def test_not_isomorphic_complex_tree(self):


 class TestMapOverSubTree:
-    @pytest.mark.xfail
     def test_no_trees_passed(self):
-        raise NotImplementedError
+        @map_over_subtree
+        def times_ten(ds):
+            return 10.0 * ds
+
+        with pytest.raises(TypeError, match="Must pass at least one tree"):
+            times_ten("dt")

-    @pytest.mark.xfail
     def test_not_isomorphic(self):
-        raise NotImplementedError
+        dt1 = create_test_datatree()
+        dt2 = create_test_datatree()
+        dt2["set4"] = None
+
+        @map_over_subtree
+        def times_ten(ds1, ds2):
+            return ds1 * ds2
+
+        with pytest.raises(TreeIsomorphismError):
+            times_ten(dt1, dt2)

-    @pytest.mark.xfail
     def test_no_trees_returned(self):
-        raise NotImplementedError
+        dt1 = create_test_datatree()
+        dt2 = create_test_datatree()
+
+        @map_over_subtree
+        def bad_func(ds1, ds2):
+            return None
+
+        with pytest.raises(TypeError, match="return value of None"):
+            bad_func(dt1, dt2)

     def test_single_dt_arg(self):
         dt = create_test_datatree()
@@ -110,8 +128,8 @@ def times_ten(ds):
             return 10.0 * ds

-        result_tree = times_ten(dt)
         expected = create_test_datatree(modify=lambda ds: 10.0 * ds)
+        result_tree = times_ten(dt)
         assert_tree_equal(result_tree, expected)

     def test_single_dt_arg_plus_args_and_kwargs(self):
@@ -119,43 +137,109 @@ def test_single_dt_arg_plus_args_and_kwargs(self):
         dt = create_test_datatree()

         @map_over_subtree
         def multiply_then_add(ds, times, add=0.0):
-            return times * ds + add
+            return (times * ds) + add

-        result_tree = multiply_then_add(dt, 10.0, add=2.0)
         expected = create_test_datatree(modify=lambda ds: (10.0 * ds) + 2.0)
+        result_tree = multiply_then_add(dt, 10.0, add=2.0)
         assert_tree_equal(result_tree, expected)

-    @pytest.mark.xfail
     def test_multiple_dt_args(self):
-        ds = xr.Dataset({"a": ("x", [1, 2, 3])})
-        dt = DataNode("root", data=ds)
-        DataNode("results", data=ds + 0.2, parent=dt)
+        dt1 = create_test_datatree()
+        dt2 = create_test_datatree()

         @map_over_subtree
         def add(ds1, ds2):
             return ds1 + ds2

-        expected = DataNode("root", data=ds * 2)
= DataNode("root", data=ds * 2) - DataNode("results", data=(ds + 0.2) * 2, parent=expected) + expected = create_test_datatree(modify=lambda ds: 2.0 * ds) + result = add(dt1, dt2) + assert_tree_equal(result, expected) - result = add(dt, dt) + def test_dt_as_kwarg(self): + dt1 = create_test_datatree() + dt2 = create_test_datatree() - # dt1 = create_test_datatree() - # dt2 = create_test_datatree() - # expected = create_test_datatree(modify=lambda ds: 2 * ds) + @map_over_subtree + def add(ds1, value=0.0): + return ds1 + value + expected = create_test_datatree(modify=lambda ds: 2.0 * ds) + result = add(dt1, value=dt2) assert_tree_equal(result, expected) - @pytest.mark.xfail - def test_dt_as_kwarg(self): - raise NotImplementedError + def test_return_multiple_dts(self): + dt = create_test_datatree() + + @map_over_subtree + def minmax(ds): + return ds.min(), ds.max() + + dt_min, dt_max = minmax(dt) + expected_min = create_test_datatree(modify=lambda ds: ds.min()) + assert_tree_equal(dt_min, expected_min) + expected_max = create_test_datatree(modify=lambda ds: ds.max()) + assert_tree_equal(dt_max, expected_max) + + def test_return_wrong_type(self): + dt1 = create_test_datatree() + + @map_over_subtree + def bad_func(ds1): + return "string" + + with pytest.raises(TypeError, match="not Dataset or DataArray"): + bad_func(dt1) + + def test_return_tuple_of_wrong_types(self): + dt1 = create_test_datatree() + + @map_over_subtree + def bad_func(ds1): + return xr.Dataset(), "string" + + with pytest.raises(TypeError, match="not Dataset or DataArray"): + bad_func(dt1) @pytest.mark.xfail - def test_return_multiple_dts(self): - raise NotImplementedError + def test_return_inconsistent_number_of_results(self): + dt1 = create_test_datatree() + + @map_over_subtree + def bad_func(ds): + # Datasets in create_test_datatree() have different numbers of dims + # TODO need to instead return different numbers of Dataset objects for this test to catch the intended error + return tuple(ds.dims) + + with pytest.raises(TypeError, match="instead returns"): + bad_func(dt1) + + def test_wrong_number_of_arguments_for_func(self): + dt = create_test_datatree() + + @map_over_subtree + def times_ten(ds): + return 10.0 * ds + + with pytest.raises( + TypeError, match="takes 1 positional argument but 2 were given" + ): + times_ten(dt, dt) + + def test_map_single_dataset_against_whole_tree(self): + dt = create_test_datatree() + + @map_over_subtree + def nodewise_merge(node_ds, fixed_ds): + return xr.merge([node_ds, fixed_ds]) + + other_ds = xr.Dataset({"z": ("z", [0])}) + expected = create_test_datatree(modify=lambda ds: xr.merge([ds, other_ds])) + result_tree = nodewise_merge(dt, other_ds) + assert_tree_equal(result_tree, expected) @pytest.mark.xfail - def test_return_no_dts(self): + def test_trees_with_different_node_names(self): + # TODO test this after I've got good tests for renaming nodes raise NotImplementedError def test_dt_method(self): @@ -164,18 +248,9 @@ def test_dt_method(self): def multiply_then_add(ds, times, add=0.0): return times * ds + add + expected = create_test_datatree(modify=lambda ds: (10.0 * ds) + 2.0) result_tree = dt.map_over_subtree(multiply_then_add, 10.0, add=2.0) - - for ( - result_node, - original_node, - ) in zip(result_tree.subtree, dt.subtree): - assert isinstance(result_node, DataTree) - - if original_node.has_data: - assert_equal(result_node.ds, (original_node.ds * 10.0) + 2.0) - else: - assert not result_node.has_data + assert_tree_equal(result_tree, expected) @pytest.mark.xfail