Skip to content
This repository has been archived by the owner on Oct 24, 2024. It is now read-only.

Commit

Permalink
Map over multiple subtrees (#32)
Browse files Browse the repository at this point in the history
* pseudocode ideas for generalizing map_over_subtree

* pseudocode for a generalized map_over_subtree (still only one return arg) + a new mapping.py file

* pseudocode for mapping but now multiple return values

* pseudocode for mapping but with multiple return values

* check_isomorphism works and has tests

* cleaned up the mapping tests a bit

* tests for mapping over multiple trees

* incorrect pseudocode attempt to map over multiple subtrees

* small improvements

* fixed test

* zipping of multiple arguments

* passes for mapping over a single tree

* successfully maps over multiple trees

* successfully returns multiple trees

* filled out all tests

* checking types now works for trees with only one node

* improved docstring
  • Loading branch information
TomNicholas authored Sep 2, 2021
1 parent 3f68eea commit 84d4814
Show file tree
Hide file tree
Showing 4 changed files with 283 additions and 70 deletions.
4 changes: 3 additions & 1 deletion datatree/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,10 +424,12 @@ def __init__(
else:
node_path, node_name = "/", path

relative_path = node_path.replace(self.name, "")

# Create and set new node
new_node = DataNode(name=node_name, data=data)
self.set_node(
node_path,
relative_path,
new_node,
allow_overwrite=False,
new_nodes_along_path=True,
Expand Down
192 changes: 166 additions & 26 deletions datatree/mapping.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import functools
from itertools import repeat

from anytree.iterators import LevelOrderIter
from xarray import DataArray, Dataset

from .treenode import TreeNode

Expand Down Expand Up @@ -43,11 +45,11 @@ def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False):

if not isinstance(subtree_a, TreeNode):
raise TypeError(
f"Argument `subtree_a is not a tree, it is of type {type(subtree_a)}"
f"Argument `subtree_a` is not a tree, it is of type {type(subtree_a)}"
)
if not isinstance(subtree_b, TreeNode):
raise TypeError(
f"Argument `subtree_b is not a tree, it is of type {type(subtree_b)}"
f"Argument `subtree_b` is not a tree, it is of type {type(subtree_b)}"
)

# Walking nodes in "level-order" fashion means walking down from the root breadth-first.
Expand Down Expand Up @@ -83,57 +85,195 @@ def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False):

def map_over_subtree(func):
"""
Decorator which turns a function which acts on (and returns) single Datasets into one which acts on DataTrees.
Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees.
Applies a function to every dataset in this subtree, returning a new tree which stores the results.
Applies a function to every dataset in one or more subtrees, returning new trees which store the results.
The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the
descendant nodes. The returned tree will have the same structure as the original subtree.
The function will be applied to any dataset stored in any of the nodes in the trees. The returned trees will have
the same structure as the supplied trees.
func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each
result will be assigned to its respective node of new tree via `DataTree.__setitem__`.
`func` needs to return one Datasets, DataArrays, or None in order to be able to rebuild the subtrees after
mapping, as each result will be assigned to its respective node of a new tree via `DataTree.__setitem__`. Any
returned value that is one of these types will be stacked into a separate tree before returning all of them.
The trees passed to the resulting function must all be isomorphic to one another. Their nodes need not be named
similarly, but all the output trees will have nodes named in the same way as the first tree passed.
Parameters
----------
func : callable
Function to apply to datasets with signature:
`func(node.ds, *args, **kwargs) -> Dataset`.
`func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.
(i.e. func must accept at least one Dataset and return at least one Dataset.)
Function will not be applied to any nodes without datasets.
*args : tuple, optional
Positional arguments passed on to `func`.
Positional arguments passed on to `func`. If DataTrees any data-containing nodes will be converted to Datasets \
via .ds .
**kwargs : Any
Keyword arguments passed on to `func`.
Keyword arguments passed on to `func`. If DataTrees any data-containing nodes will be converted to Datasets
via .ds .
Returns
-------
mapped : callable
Wrapped function which returns tree created from results of applying ``func`` to the dataset at each node.
Wrapped function which returns one or more tree(s) created from results of applying ``func`` to the dataset at
each node.
See also
--------
DataTree.map_over_subtree
DataTree.map_over_subtree_inplace
DataTree.subtree
"""

# TODO examples in the docstring

# TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

@functools.wraps(func)
def _map_over_subtree(tree, *args, **kwargs):
def _map_over_subtree(*args, **kwargs):
"""Internal function which maps func over every node in tree, returning a tree of the results."""
from .datatree import DataTree

all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [
a for a in kwargs.values() if isinstance(a, DataTree)
]

if len(all_tree_inputs) > 0:
first_tree, *other_trees = all_tree_inputs
else:
raise TypeError("Must pass at least one tree object")

for other_tree in other_trees:
# isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic
_check_isomorphic(first_tree, other_tree, require_names_equal=False)

# Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
# We don't know which arguments are DataTrees so we zip all arguments together as iterables
# Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
out_data_objects = {}
args_as_tree_length_iterables = [
a.subtree if isinstance(a, DataTree) else repeat(a) for a in args
]
n_args = len(args_as_tree_length_iterables)
kwargs_as_tree_length_iterables = {
k: v.subtree if isinstance(v, DataTree) else repeat(v)
for k, v in kwargs.items()
}
for node_of_first_tree, *all_node_args in zip(
first_tree.subtree,
*args_as_tree_length_iterables,
*list(kwargs_as_tree_length_iterables.values()),
):
node_args_as_datasets = [
a.ds if isinstance(a, DataTree) else a for a in all_node_args[:n_args]
]
node_kwargs_as_datasets = dict(
zip(
[k for k in kwargs_as_tree_length_iterables.keys()],
[
v.ds if isinstance(v, DataTree) else v
for v in all_node_args[n_args:]
],
)
)

# Recreate and act on root node
from .datatree import DataNode
# Now we can call func on the data in this particular set of corresponding nodes
results = (
func(*node_args_as_datasets, **node_kwargs_as_datasets)
if node_of_first_tree.has_data
else None
)

out_tree = DataNode(name=tree.name, data=tree.ds)
if out_tree.has_data:
out_tree.ds = func(out_tree.ds, *args, **kwargs)
# TODO implement mapping over multiple trees in-place using if conditions from here on?
out_data_objects[node_of_first_tree.pathstr] = results

# Find out how many return values we received
num_return_values = _check_all_return_values(out_data_objects)

# Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees
result_trees = []
for i in range(num_return_values):
out_tree_contents = {}
for n in first_tree.subtree:
p = n.pathstr
if p in out_data_objects.keys():
if isinstance(out_data_objects[p], tuple):
output_node_data = out_data_objects[p][i]
else:
output_node_data = out_data_objects[p]
else:
output_node_data = None
out_tree_contents[p] = output_node_data

new_tree = DataTree(name=first_tree.name, data_objects=out_tree_contents)
result_trees.append(new_tree)

# If only one result then don't wrap it in a tuple
if len(result_trees) == 1:
return result_trees[0]
else:
return tuple(result_trees)

# Act on every other node in the tree, and rebuild from results
for node in tree.descendants:
# TODO make a proper relative_path method
relative_path = node.pathstr.replace(tree.pathstr, "")
result = func(node.ds, *args, **kwargs) if node.has_data else None
out_tree[relative_path] = result
return _map_over_subtree

return out_tree

return _map_over_subtree
def _check_single_set_return_values(path_to_node, obj):
"""Check types returned from single evaluation of func, and return number of return values received from func."""
if isinstance(obj, (Dataset, DataArray)):
return 1
elif isinstance(obj, tuple):
for r in obj:
if not isinstance(r, (Dataset, DataArray)):
raise TypeError(
f"One of the results of calling func on datasets on the nodes at position {path_to_node} is "
f"of type {type(r)}, not Dataset or DataArray."
)
return len(obj)
else:
raise TypeError(
f"The result of calling func on the node at position {path_to_node} is of type {type(obj)}, not "
f"Dataset or DataArray, nor a tuple of such types."
)


def _check_all_return_values(returned_objects):
"""Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types."""

if all(r is None for r in returned_objects.values()):
raise TypeError(
"Called supplied function on all nodes but found a return value of None for"
"all of them."
)

result_data_objects = [
(path_to_node, r)
for path_to_node, r in returned_objects.items()
if r is not None
]

if len(result_data_objects) == 1:
# Only one node in the tree: no need to check consistency of results between nodes
path_to_node, result = result_data_objects[0]
num_return_values = _check_single_set_return_values(path_to_node, result)
else:
prev_path, _ = result_data_objects[0]
prev_num_return_values, num_return_values = None, None
for path_to_node, obj in result_data_objects[1:]:
num_return_values = _check_single_set_return_values(path_to_node, obj)

if (
num_return_values != prev_num_return_values
and prev_num_return_values is not None
):
raise TypeError(
f"Calling func on the nodes at position {path_to_node} returns {num_return_values} separate return "
f"values, whereas calling func on the nodes at position {prev_path} instead returns "
f"{prev_num_return_values} separate return values."
)

prev_path, prev_num_return_values = path_to_node, num_return_values

return num_return_values
6 changes: 1 addition & 5 deletions datatree/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@


def assert_tree_equal(dt_a, dt_b):
assert dt_a.name == dt_b.name
assert dt_a.parent is dt_b.parent

assert dt_a.ds.equals(dt_b.ds)
for a, b in zip(dt_a.descendants, dt_b.descendants):
for a, b in zip(dt_a.subtree, dt_b.subtree):
assert a.name == b.name
assert a.pathstr == b.pathstr
if a.has_data:
Expand Down Expand Up @@ -321,7 +319,6 @@ def test_to_netcdf(self, tmpdir):
original_dt.to_netcdf(filepath, engine="netcdf4")

roundtrip_dt = open_datatree(filepath)

assert_tree_equal(original_dt, roundtrip_dt)

def test_to_zarr(self, tmpdir):
Expand All @@ -332,5 +329,4 @@ def test_to_zarr(self, tmpdir):
original_dt.to_zarr(filepath)

roundtrip_dt = open_datatree(filepath, engine="zarr")

assert_tree_equal(original_dt, roundtrip_dt)
Loading

0 comments on commit 84d4814

Please sign in to comment.