diff --git a/datatree/datatree.py b/datatree/datatree.py index b2cc8602..081e7117 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -1,7 +1,6 @@ from __future__ import annotations import functools import textwrap -import inspect from typing import Mapping, Hashable, Union, List, Any, Callable, Iterable, Dict @@ -12,6 +11,9 @@ from xarray.core.variable import Variable from xarray.core.combine import merge from xarray.core import dtypes, utils +from xarray.core.common import DataWithCoords +from xarray.core.arithmetic import DatasetArithmetic +from xarray.core.ops import REDUCE_METHODS, NAN_REDUCE_METHODS, NAN_CUM_METHODS from .treenode import TreeNode, PathType, _init_single_treenode @@ -46,7 +48,8 @@ def map_over_subtree(func): The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the descendant nodes. The returned tree will have the same structure as the original subtree. - func needs to return a Dataset in order to rebuild the subtree. + func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each + result will be assigned to its respective node of new tree via `DataTree.__setitem__`. Parameters ---------- @@ -132,7 +135,6 @@ def attrs(self): else: raise AttributeError("property is not defined for a node with no data") - @property def nbytes(self) -> int: return sum(node.ds.nbytes for node in self.subtree_nodes) @@ -203,10 +205,33 @@ def imag(self): _MAPPED_DOCSTRING_ADDENDUM = textwrap.fill("This method was copied from xarray.Dataset, but has been altered to " "call the method on the Datasets stored in every node of the subtree. " - "See the `map_over_subtree` decorator for more details.", width=117) - - -def _wrap_then_attach_to_cls(cls_dict, methods_to_expose, wrap_func=None): + "See the `map_over_subtree` function for more details.", width=117) + +# TODO equals, broadcast_equals etc. +# TODO do dask-related private methods need to be exposed? +_DATASET_DASK_METHODS_TO_MAP = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks'] +_DATASET_METHODS_TO_MAP = ['copy', 'as_numpy', '__copy__', '__deepcopy__', 'set_coords', 'reset_coords', 'info', + 'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like', + 'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars', + 'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack', + 'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims', + 'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first', + 'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank', + 'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit', + 'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit'] +# TODO unsure if these are called by external functions or not? +_DATASET_OPS_TO_MAP = ['_unary_op', '_binary_op', '_inplace_binary_op'] +_ALL_DATASET_METHODS_TO_MAP = _DATASET_DASK_METHODS_TO_MAP + _DATASET_METHODS_TO_MAP + _DATASET_OPS_TO_MAP + +_DATA_WITH_COORDS_METHODS_TO_MAP = ['squeeze', 'clip', 'assign_coords', 'where', 'close', 'isnull', 'notnull', + 'isin', 'astype'] + +# TODO NUM_BINARY_OPS apparently aren't defined on DatasetArithmetic, and don't appear to be injected anywhere... +#['__array_ufunc__'] \ +_ARITHMETIC_METHODS_TO_WRAP = REDUCE_METHODS + NAN_REDUCE_METHODS + NAN_CUM_METHODS + + +def _wrap_then_attach_to_cls(target_cls_dict, source_cls, methods_to_set, wrap_func=None): """ Attach given methods on a class, and optionally wrap each method first. (i.e. with map_over_subtree) @@ -219,23 +244,32 @@ def method_name(self, *args, **kwargs): Parameters ---------- - cls_dict - The __dict__ attribute of a class, which can also be accessed by calling vars() from within that classes' - definition. - methods_to_expose : Iterable[Tuple[str, callable]] - The method names and definitions supplied as a list of (method_name_string, method) pairs.\ + target_cls_dict : MappingProxy + The __dict__ attribute of the class which we want the methods to be added to. (The __dict__ attribute can also + be accessed by calling vars() from within that classes' definition.) This will be updated by this function. + source_cls : class + Class object from which we want to copy methods (and optionally wrap them). Should be the actual class object + (or instance), not just the __dict__. + methods_to_set : Iterable[Tuple[str, callable]] + The method names and definitions supplied as a list of (method_name_string, method) pairs. This format matches the output of inspect.getmembers(). + wrap_func : callable, optional + Function to decorate each method with. Must have the same return type as the method. """ - for method_name, method in methods_to_expose: - wrapped_method = wrap_func(method) if wrap_func is not None else method - cls_dict[method_name] = wrapped_method - - # TODO do we really need this for ops like __add__? - # Add a line to the method's docstring explaining how it's been mapped - method_docstring = method.__doc__ - if method_docstring is not None: - updated_method_docstring = method_docstring.replace('\n', _MAPPED_DOCSTRING_ADDENDUM, 1) - setattr(cls_dict[method_name], '__doc__', updated_method_docstring) + for method_name in methods_to_set: + orig_method = getattr(source_cls, method_name) + wrapped_method = wrap_func(orig_method) if wrap_func is not None else orig_method + target_cls_dict[method_name] = wrapped_method + + if wrap_func is map_over_subtree: + # Add a paragraph to the method's docstring explaining how it's been mapped + orig_method_docstring = orig_method.__doc__ + if orig_method_docstring is not None: + if '\n' in orig_method_docstring: + new_method_docstring = orig_method_docstring.replace('\n', _MAPPED_DOCSTRING_ADDENDUM, 1) + else: + new_method_docstring = orig_method_docstring + f"\n\n{_MAPPED_DOCSTRING_ADDENDUM}" + setattr(target_cls_dict[method_name], '__doc__', new_method_docstring) class MappedDatasetMethodsMixin: @@ -244,33 +278,28 @@ class MappedDatasetMethodsMixin: Every method wrapped here needs to have a return value of Dataset or DataArray in order to construct a new tree. """ - # TODO equals, broadcast_equals etc. - # TODO do dask-related private methods need to be exposed? - _DATASET_DASK_METHODS_TO_EXPOSE = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks'] - _DATASET_METHODS_TO_EXPOSE = ['copy', 'as_numpy', '__copy__', '__deepcopy__', 'set_coords', 'reset_coords', 'info', - 'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like', - 'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars', - 'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack', - 'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims', - 'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first', - 'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank', - 'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit', - 'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit'] - # TODO unsure if these are called by external functions or not? - _DATASET_OPS_TO_EXPOSE = ['_unary_op', '_binary_op', '_inplace_binary_op'] - _ALL_DATASET_METHODS_TO_EXPOSE = _DATASET_DASK_METHODS_TO_EXPOSE + _DATASET_METHODS_TO_EXPOSE + _DATASET_OPS_TO_EXPOSE + __slots__ = () + _wrap_then_attach_to_cls(vars(), Dataset, _ALL_DATASET_METHODS_TO_MAP, wrap_func=map_over_subtree) - # TODO methods which should not or cannot act over the whole tree, such as .to_array - methods_to_expose = [(method_name, getattr(Dataset, method_name)) - for method_name in _ALL_DATASET_METHODS_TO_EXPOSE] - _wrap_then_attach_to_cls(vars(), methods_to_expose, wrap_func=map_over_subtree) +class MappedDataWithCoords(DataWithCoords): + # TODO add mapped versions of groupby, weighted, rolling, rolling_exp, coarsen, resample + # TODO re-implement AttrsAccessMixin stuff so that it includes access to child nodes + _wrap_then_attach_to_cls(vars(), DataWithCoords, _DATA_WITH_COORDS_METHODS_TO_MAP, wrap_func=map_over_subtree) -# TODO implement ArrayReduce type methods +class DataTreeArithmetic(DatasetArithmetic): + """ + Mixin to add Dataset methods like __add__ and .mean() + + Some of these methods must be wrapped to map over all nodes in the subtree. Others are fine unaltered (normally + because they (a) only call dataset properties and (b) don't return a dataset that should be nested into a new + tree) and some will get overridden by the class definition of DataTree. + """ + _wrap_then_attach_to_cls(vars(), DatasetArithmetic, _ARITHMETIC_METHODS_TO_WRAP, wrap_func=map_over_subtree) -class DataTree(TreeNode, DatasetPropertiesMixin, MappedDatasetMethodsMixin): +class DataTree(TreeNode, DatasetPropertiesMixin, MappedDatasetMethodsMixin, MappedDataWithCoords, DataTreeArithmetic): """ A tree-like hierarchical collection of xarray objects. @@ -312,6 +341,8 @@ class DataTree(TreeNode, DatasetPropertiesMixin, MappedDatasetMethodsMixin): # TODO currently allows self.ds = None, should we instead always store at least an empty Dataset? + # TODO dataset methods which should not or cannot act over the whole tree, such as .to_array + def __init__( self, data_objects: Dict[PathType, Union[Dataset, DataArray]] = None, diff --git a/datatree/tests/test_dataset_api.py b/datatree/tests/test_dataset_api.py index c6d0f150..20cac079 100644 --- a/datatree/tests/test_dataset_api.py +++ b/datatree/tests/test_dataset_api.py @@ -80,7 +80,6 @@ def test_properties(self): assert dt.sizes == dt.ds.sizes assert dt.variables == dt.ds.variables - def test_no_data_no_properties(self): dt = DataNode('root', data=None) with pytest.raises(AttributeError): @@ -96,24 +95,73 @@ def test_no_data_no_properties(self): class TestDSMethodInheritance: - def test_root(self): + def test_dataset_method(self): + # test root da = xr.DataArray(name='a', data=[1, 2, 3], dims='x') dt = DataNode('root', data=da) expected_ds = da.to_dataset().isel(x=1) result_ds = dt.isel(x=1).ds assert_equal(result_ds, expected_ds) - def test_descendants(self): - da = xr.DataArray(name='a', data=[1, 2, 3], dims='x') - dt = DataNode('root') + # test descendant DataNode('results', parent=dt, data=da) - expected_ds = da.to_dataset().isel(x=1) result_ds = dt.isel(x=1)['results'].ds assert_equal(result_ds, expected_ds) + def test_reduce_method(self): + # test root + da = xr.DataArray(name='a', data=[False, True, False], dims='x') + dt = DataNode('root', data=da) + expected_ds = da.to_dataset().any() + result_ds = dt.any().ds + assert_equal(result_ds, expected_ds) + + # test descendant + DataNode('results', parent=dt, data=da) + result_ds = dt.any()['results'].ds + assert_equal(result_ds, expected_ds) + + def test_nan_reduce_method(self): + # test root + da = xr.DataArray(name='a', data=[1, 2, 3], dims='x') + dt = DataNode('root', data=da) + expected_ds = da.to_dataset().mean() + result_ds = dt.mean().ds + assert_equal(result_ds, expected_ds) + + # test descendant + DataNode('results', parent=dt, data=da) + result_ds = dt.mean()['results'].ds + assert_equal(result_ds, expected_ds) + + def test_cum_method(self): + # test root + da = xr.DataArray(name='a', data=[1, 2, 3], dims='x') + dt = DataNode('root', data=da) + expected_ds = da.to_dataset().cumsum() + result_ds = dt.cumsum().ds + assert_equal(result_ds, expected_ds) + + # test descendant + DataNode('results', parent=dt, data=da) + result_ds = dt.cumsum()['results'].ds + assert_equal(result_ds, expected_ds) + class TestOps: - ... + @pytest.mark.xfail + def test_binary_op(self): + ds1 = xr.Dataset({'a': [5], 'b': [3]}) + ds2 = xr.Dataset({'x': [0.1, 0.2], 'y': [10, 20]}) + dt = DataNode('root', data=ds1) + DataNode('subnode', data=ds2, parent=dt) + + expected_root = DataNode('root', data=ds1*ds1) + expected_descendant = DataNode('subnode', data=ds2*ds2, parent=expected_root) + result = dt * dt + + assert_equal(result.ds, expected_root.ds) + assert_equal(result['subnode'].ds, expected_descendant.ds) @pytest.mark.xfail