
Add API methods in class definition #19

Merged 1 commit on Aug 24, 2021
114 changes: 52 additions & 62 deletions datatree/datatree.py
@@ -12,7 +12,6 @@
from xarray.core.variable import Variable
from xarray.core.combine import merge
from xarray.core import dtypes, utils
from xarray.core._typed_ops import DatasetOpsMixin

from .treenode import TreeNode, PathType, _init_single_treenode

@@ -188,7 +187,7 @@ def imag(self):
else:
raise AttributeError("property is not defined for a node with no data")

# TODO .loc
# TODO .loc, __contains__, __iter__, __array__, '__len__',

dims.__doc__ = Dataset.dims.__doc__
variables.__doc__ = Dataset.variables.__doc__
@@ -207,68 +206,71 @@ def imag(self):
"See the `map_over_subtree` decorator for more details.", width=117)


def _expose_methods_wrapped_to_map_over_subtree(obj, method_name, method):
def _wrap_then_attach_to_cls(cls_dict, methods_to_expose, wrap_func=None):
"""
Expose given method on node object, but wrapped to map over whole subtree, not just that node object.

Result is like having written this in obj's class definition:
Attach given methods on a class, and optionally wrap each method first. (i.e. with map_over_subtree)

Result is like having written this in the classes' definition:
```
@map_over_subtree
@wrap_func
def method_name(self, *args, **kwargs):
return self.method(*args, **kwargs)
```
"""

# Expose Dataset method, but wrapped to map over whole subtree when called
# TODO should we be using functools.partialmethod here instead?
mapped_over_tree = functools.partial(map_over_subtree(method), obj)
setattr(obj, method_name, mapped_over_tree)

# TODO do we really need this for ops like __add__?
# Add a line to the method's docstring explaining how it's been mapped
method_docstring = method.__doc__
if method_docstring is not None:
updated_method_docstring = method_docstring.replace('\n', _MAPPED_DOCSTRING_ADDENDUM, 1)
obj_method = getattr(obj, method_name)
setattr(obj_method, '__doc__', updated_method_docstring)

Parameters
----------
cls_dict
The __dict__ attribute of a class, which can also be accessed by calling vars() from within that classes'
definition.
methods_to_expose : Iterable[Tuple[str, callable]]
The method names and definitions supplied as a list of (method_name_string, method) pairs.\
This format matches the output of inspect.getmembers().
"""
for method_name, method in methods_to_expose:
wrapped_method = wrap_func(method) if wrap_func is not None else method
cls_dict[method_name] = wrapped_method

# TODO equals, broadcast_equals etc.
# TODO do dask-related private methods need to be exposed?
_DATASET_DASK_METHODS_TO_EXPOSE = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks']
_DATASET_METHODS_TO_EXPOSE = ['copy', 'as_numpy', '__copy__', '__deepcopy__', '__contains__', '__len__',
'__bool__', '__iter__', '__array__', 'set_coords', 'reset_coords', 'info',
'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like',
'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars',
'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack',
'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims',
'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first',
'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank',
'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit',
'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit']
_DATASET_OPS_TO_EXPOSE = ['_unary_op', '_binary_op', '_inplace_binary_op']
_ALL_DATASET_METHODS_TO_EXPOSE = _DATASET_DASK_METHODS_TO_EXPOSE + _DATASET_METHODS_TO_EXPOSE + _DATASET_OPS_TO_EXPOSE

# TODO methods which should not or cannot act over the whole tree, such as .to_array

# TODO do we really need this for ops like __add__?
# Add a line to the method's docstring explaining how it's been mapped
method_docstring = method.__doc__
if method_docstring is not None:
updated_method_docstring = method_docstring.replace('\n', _MAPPED_DOCSTRING_ADDENDUM, 1)
setattr(cls_dict[method_name], '__doc__', updated_method_docstring)

class DatasetMethodsMixin:
"""Mixin to add Dataset methods like .mean(), but wrapped to map over all nodes in the subtree."""

# TODO is there a way to put this code in the class definition so we don't have to specifically call this method?
def _add_dataset_methods(self):
methods_to_expose = [(method_name, getattr(Dataset, method_name))
for method_name in _ALL_DATASET_METHODS_TO_EXPOSE]
class MappedDatasetMethodsMixin:
"""
Mixin to add Dataset methods like .mean(), but wrapped to map over all nodes in the subtree.

for method_name, method in methods_to_expose:
_expose_methods_wrapped_to_map_over_subtree(self, method_name, method)
Every method wrapped here needs to have a return value of Dataset or DataArray in order to construct a new tree.
"""
# TODO equals, broadcast_equals etc.
# TODO do dask-related private methods need to be exposed?
_DATASET_DASK_METHODS_TO_EXPOSE = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks']
_DATASET_METHODS_TO_EXPOSE = ['copy', 'as_numpy', '__copy__', '__deepcopy__', 'set_coords', 'reset_coords', 'info',
'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like',
'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars',
'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack',
'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims',
'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first',
'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank',
'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit',
'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit']
# TODO unsure if these are called by external functions or not?
_DATASET_OPS_TO_EXPOSE = ['_unary_op', '_binary_op', '_inplace_binary_op']
_ALL_DATASET_METHODS_TO_EXPOSE = _DATASET_DASK_METHODS_TO_EXPOSE + _DATASET_METHODS_TO_EXPOSE + _DATASET_OPS_TO_EXPOSE

# TODO methods which should not or cannot act over the whole tree, such as .to_array

methods_to_expose = [(method_name, getattr(Dataset, method_name))
for method_name in _ALL_DATASET_METHODS_TO_EXPOSE]
_wrap_then_attach_to_cls(vars(), methods_to_expose, wrap_func=map_over_subtree)


# TODO implement ArrayReduce type methods


class DataTree(TreeNode, DatasetPropertiesMixin, DatasetMethodsMixin):
class DataTree(TreeNode, DatasetPropertiesMixin, MappedDatasetMethodsMixin):
"""
A tree-like hierarchical collection of xarray objects.

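The heart of this change is the `_wrap_then_attach_to_cls` helper above: instead of patching methods onto each instance from `__init__` (as the removed `_expose_methods_wrapped_to_map_over_subtree` did), it writes the wrapped `Dataset` methods directly into the class namespace while the class body is executing, via `vars()`. A minimal, self-contained sketch of the same pattern, with hypothetical names and a trivial wrapper standing in for `map_over_subtree`:

```python
import functools


def attach_wrapped(cls_dict, methods, wrap_func=None):
    # Hypothetical stand-in for _wrap_then_attach_to_cls: put each
    # (name, method) pair into the class namespace, wrapping it first
    # if a wrap_func is given.
    for name, method in methods:
        cls_dict[name] = wrap_func(method) if wrap_func is not None else method


def shout(method):
    # Trivial stand-in for map_over_subtree: upper-case the result.
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        return method(self, *args, **kwargs).upper()
    return wrapper


class Source:
    def greet(self):
        return "hello"


class Mixin:
    # vars() inside a class body is the namespace being built, so the
    # attached functions become ordinary methods of Mixin.
    attach_wrapped(vars(), [("greet", Source.greet)], wrap_func=shout)


print(Mixin().greet())  # -> "HELLO"
```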
@@ -339,15 +341,6 @@ def __init__(
new_node = self.get_node(path)
new_node[path] = data

# TODO this has to be
self._add_all_dataset_api()

def _add_all_dataset_api(self):
# Add methods like .isel(), but wrapped to map over subtrees
self._add_dataset_methods()

# TODO add dataset ops here

@property
def ds(self) -> Dataset:
return self._ds
@@ -396,9 +389,6 @@ def _init_single_datatree_node(
obj = object.__new__(cls)
obj = _init_single_treenode(obj, name=name, parent=parent, children=children)
obj.ds = data

obj._add_all_dataset_api()

return obj

def __str__(self):
@@ -435,7 +425,7 @@ def _single_node_repr(self):
def __repr__(self):
"""Information about this node, including its relationships to other nodes."""
# TODO redo this to look like the Dataset repr, but just with child and parent info
parent = self.parent.name if self.parent else "None"
parent = self.parent.name if self.parent is not None else "None"
node_str = f"DataNode(name='{self.name}', parent='{parent}', children={[c.name for c in self.children]},"

if self.has_data:
@@ -554,7 +544,7 @@ def __setitem__(
except anytree.resolver.ResolverError:
existing_node = None

if existing_node:
if existing_node is not None:
if isinstance(value, Dataset):
# replace whole dataset
existing_node.ds = Dataset
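For users of `DataTree`, the practical effect is that plain `Dataset` methods become tree-wide operations. A hedged usage sketch follows; the constructor arguments and path layout are assumptions for illustration, since this diff only shows fragments of `__init__` and `__setitem__`:

```python
import numpy as np
import xarray as xr
from datatree import DataTree

# Constructor signature assumed for illustration only.
dt = DataTree(data_objects={
    "site1": xr.Dataset({"temperature": ("time", np.arange(5.0))}),
    "site2": xr.Dataset({"temperature": ("time", np.arange(5.0) + 10.0)}),
})

# .isel comes from MappedDatasetMethodsMixin: it is Dataset.isel wrapped with
# map_over_subtree, so it runs at every node that holds data and the results
# are reassembled into a new DataTree.
first_step = dt.isel(time=0)
```

Every name in `_ALL_DATASET_METHODS_TO_EXPOSE` follows the same shape: the node-level `Dataset` method must return a `Dataset` or `DataArray`, which is what allows a new tree to be constructed from the per-node results.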