diff --git a/datatree/__init__.py b/datatree/__init__.py index d799dc02..58b65aec 100644 --- a/datatree/__init__.py +++ b/datatree/__init__.py @@ -6,7 +6,7 @@ # import public API from .datatree import DataTree from .io import open_datatree -from .mapping import map_over_subtree +from .mapping import TreeIsomorphismError, map_over_subtree try: __version__ = get_distribution(__name__).version diff --git a/datatree/datatree.py b/datatree/datatree.py index af666927..7936b4b5 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -1,5 +1,7 @@ from __future__ import annotations +import copy +import itertools from collections import OrderedDict from html import escape from typing import ( @@ -8,17 +10,28 @@ Callable, Dict, Generic, + Hashable, Iterable, + Iterator, Mapping, MutableMapping, Optional, + Set, Tuple, Union, + cast, + overload, ) -from xarray import DataArray, Dataset +import pandas as pd from xarray.core import utils +from xarray.core.coordinates import DatasetCoordinates +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset, DataVariables +from xarray.core.indexes import Index, Indexes +from xarray.core.merge import dataset_update_method from xarray.core.options import OPTIONS as XR_OPTS +from xarray.core.utils import Default, Frozen, _default, hashable from xarray.core.variable import Variable from . import formatting, formatting_html @@ -31,6 +44,12 @@ from .render import RenderTree from .treenode import NodePath, Tree, TreeNode +try: + from xarray.core.variable import calculate_dimensions +except ImportError: + # for xarray versions 2022.03.0 and earlier + from xarray.core.dataset import calculate_dimensions + if TYPE_CHECKING: from xarray.core.merge import CoercibleValue @@ -41,7 +60,7 @@ # the entire API of `xarray.Dataset`, but with certain methods decorated to instead map the dataset function over every # node in the tree. 
As this API is copied without directly subclassing `xarray.Dataset` we instead create various Mixin # classes (in ops.py) which each define part of `xarray.Dataset`'s extensive API. - +# # Some of these methods must be wrapped to map over all nodes in the subtree. Others are fine to inherit unaltered # (normally because they (a) only call dataset properties and (b) don't return a dataset that should be nested into a new # tree) and some will get overridden by the class definition of DataTree. @@ -51,12 +70,187 @@ T_Path = Union[str, NodePath] +def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: + if isinstance(data, DataArray): + ds = data.to_dataset() + elif isinstance(data, Dataset): + ds = data + elif data is None: + ds = Dataset() + else: + raise TypeError( + f"data object is not an xarray Dataset, DataArray, or None, it is of type {type(data)}" + ) + return ds + + +def _check_for_name_collisions( + children: Iterable[str], variables: Iterable[Hashable] +) -> None: + colliding_names = set(children).intersection(set(variables)) + if colliding_names: + raise KeyError( + f"Some names would collide between variables and children: {list(colliding_names)}" + ) + + +class DatasetView(Dataset): + """ + An immutable Dataset-like view onto the data in a single DataTree node. + + In-place operations modifying this object should raise an AttributeError. + + Operations returning a new result will return a new xarray.Dataset object. + This includes all API on Dataset, which will be inherited. + + This requires overriding all inherited private constructors + + This object also keeps a reference to its wrapping tree node, to allow getting + items from elsewhere in the tree. + """ + + # TODO what happens if user alters (in-place) a DataArray they extracted from this object? 
+ + _wrapping_node: DataTree + + __slots__ = ["_wrapping_node"] + + def __init__( + self, + data_vars: Mapping[Any, Any] = None, + coords: Mapping[Any, Any] = None, + attrs: Mapping[Any, Any] = None, + ): + raise AttributeError("DatasetView objects are not to be initialized directly") + + @classmethod + def _from_node( + cls, + wrapping_node: DataTree, + ) -> DatasetView: + """Constructor, using dataset attributes from wrapping node""" + + obj: DatasetView = object.__new__(cls) + obj._wrapping_node = wrapping_node + obj._variables = wrapping_node._variables + obj._coord_names = wrapping_node._coord_names + obj._dims = wrapping_node._dims + obj._indexes = wrapping_node._indexes + obj._attrs = wrapping_node._attrs + obj._close = wrapping_node._close + obj._encoding = wrapping_node._encoding + + return obj + + def __setitem__(self, key, val) -> None: + raise AttributeError( + "Mutation of the DatasetView is not allowed, please use __setitem__ on the wrapping DataTree node, " + "or use `DataTree.to_dataset()` if you want a mutable dataset" + ) + + def update(self, other) -> None: + raise AttributeError( + "Mutation of the DatasetView is not allowed, please use .update on the wrapping DataTree node, " + "or use `DataTree.to_dataset()` if you want a mutable dataset" + ) + + # FIXME https://github.com/python/mypy/issues/7328 + @overload + def __getitem__(self, key: Mapping) -> Dataset: # type: ignore[misc] + ... + + @overload + def __getitem__(self, key: Hashable) -> DataArray: # type: ignore[misc] + ... + + @overload + def __getitem__(self, key: Any) -> Dataset: + ... 
+ + def __getitem__(self, key): + + # Copy the internals of Dataset.__getitem__ + if utils.is_dict_like(key): + return self.isel(**cast(Mapping, key)) + if hashable(key): + # allow path-like access to contents of other nodes + path = NodePath(key) + da = self._wrapping_node._get_item(path) + if isinstance(da, DataArray): + return da + else: + raise KeyError( + "DatasetView is only allowed to return variables, not entire DataTree nodes" + ) + else: + return self._copy_listed(key) + + @classmethod + def _construct_direct( + cls, + variables: dict[Any, Variable], + coord_names: set[Hashable], + dims: dict[Any, int] = None, + attrs: dict = None, + indexes: dict[Any, Index] = None, + encoding: dict = None, + close: Callable[[], None] = None, + ) -> Dataset: + """ + Overriding this method (along with ._replace) and modifying it to return a Dataset object + should hopefully ensure that the return type of any method on this object is a Dataset. + """ + if dims is None: + dims = calculate_dimensions(variables) + if indexes is None: + indexes = {} + obj = object.__new__(Dataset) + obj._variables = variables + obj._coord_names = coord_names + obj._dims = dims + obj._indexes = indexes + obj._attrs = attrs + obj._close = close + obj._encoding = encoding + return obj + + def _replace( + self, + variables: dict[Hashable, Variable] = None, + coord_names: set[Hashable] = None, + dims: dict[Any, int] = None, + attrs: dict[Hashable, Any] | None | Default = _default, + indexes: dict[Hashable, Index] = None, + encoding: dict | None | Default = _default, + inplace: bool = False, + ) -> Dataset: + """ + Overriding this method (along with ._construct_direct) and modifying it to return a Dataset object + should hopefully ensure that the return type of any method on this object is a Dataset. 
+ """ + + if inplace: + raise AttributeError("In-place mutation of the DatasetView is not allowed") + + return Dataset._replace( + self, + variables=variables, + coord_names=coord_names, + dims=dims, + attrs=attrs, + indexes=indexes, + encoding=encoding, + inplace=inplace, + ) + + class DataTree( TreeNode, MappedDatasetMethodsMixin, MappedDataWithCoords, DataTreeArithmeticMixin, Generic[Tree], + Mapping, ): """ A tree-like hierarchical collection of xarray objects. @@ -80,10 +274,23 @@ class DataTree( # TODO .loc, __contains__, __iter__, __array__, __len__ + # TODO a lot of properties like .variables could be defined in a DataMapping class which both Dataset and DataTree inherit from + + # TODO __slots__ + + # TODO all groupby classes + _name: Optional[str] - _parent: Optional[Tree] - _children: OrderedDict[str, Tree] - _ds: Dataset + _parent: Optional[DataTree] + _children: OrderedDict[str, DataTree] + _attrs: Optional[Dict[Hashable, Any]] + _cache: Dict[str, Any] + _coord_names: Set[Hashable] + _dims: Dict[Hashable, int] + _encoding: Optional[Dict[Hashable, Any]] + _close: Optional[Callable[[], None]] + _indexes: Dict[Hashable, Index] + _variables: Dict[Hashable, Variable] def __init__( self, @@ -97,15 +304,15 @@ def __init__( Parameters ---------- - data : Dataset, DataArray, Variable or None, optional - Data to store under the .ds attribute of this node. DataArrays and Variables will be promoted to Datasets. + data : Dataset, DataArray, or None, optional + Data to store under the .ds attribute of this node. DataArrays will be promoted to Datasets. Default is None. parent : DataTree, optional Parent node to this node. Default is None. children : Mapping[str, DataTree], optional Any child nodes of this node. Default is None. name : str, optional - Name for the root node of the tree. + Name for the root node of the tree. Default is None. 
Returns ------- @@ -116,10 +323,28 @@ def __init__( DataTree.from_dict """ + # validate input + if children is None: + children = {} + ds = _coerce_to_dataset(data) + _check_for_name_collisions(children, ds.variables) + + # set tree attributes super().__init__(children=children) self.name = name self.parent = parent - self.ds = data # type: ignore[assignment] + + # set data attributes + self._replace( + inplace=True, + variables=ds._variables, + coord_names=ds._coord_names, + dims=ds._dims, + indexes=ds._indexes, + attrs=ds._attrs, + encoding=ds._encoding, + ) + self._close = ds._close @property def name(self) -> str | None: @@ -128,6 +353,8 @@ def name(self) -> str | None: @name.setter def name(self, name: str | None) -> None: + if not isinstance(name, str) and name is not None: + raise TypeError("name must either be a string or None") self._name = name @property @@ -142,55 +369,141 @@ def parent(self: DataTree, new_parent: DataTree) -> None: self._set_parent(new_parent, self.name) @property - def ds(self) -> Dataset: - """The data in this node, returned as a Dataset.""" - return self._ds + def ds(self) -> DatasetView: + """ + An immutable Dataset-like view onto the data in this node. + If you want a mutable Dataset containing the same data as in this node, + use `.to_dataset()` instead. 
+ """ + return DatasetView._from_node(self) @ds.setter def ds(self, data: Union[Dataset, DataArray] = None) -> None: - if not isinstance(data, (Dataset, DataArray)) and data is not None: - raise TypeError( - f"{type(data)} object is not an xarray Dataset, DataArray, or None" - ) - if isinstance(data, DataArray): - data = data.to_dataset() - elif data is None: - data = Dataset() + ds = _coerce_to_dataset(data) - for var in list(data.variables): - if var in self.children: - raise KeyError( - f"Cannot add variable named {var}: node already has a child named {var}" - ) + _check_for_name_collisions(self.children, ds.variables) + + self._replace( + inplace=True, + variables=ds._variables, + coord_names=ds._coord_names, + dims=ds._dims, + indexes=ds._indexes, + attrs=ds._attrs, + encoding=ds._encoding, + ) + self._close = ds._close - self._ds = data + def _pre_attach(self: DataTree, parent: DataTree) -> None: + """ + Method which superclass calls before setting parent, here used to prevent having two + children with duplicate names (or a data variable with the same name as a child). + """ + super()._pre_attach(parent) + if self.name in list(parent.ds.variables): + raise KeyError( + f"parent {parent.name} already contains a data variable named {self.name}" + ) + + def to_dataset(self) -> Dataset: + """Return the data in this node as a new xarray.Dataset object.""" + return Dataset._construct_direct( + self._variables, + self._coord_names, + self._dims, + self._attrs, + self._indexes, + self._encoding, + self._close, + ) @property - def has_data(self) -> bool: + def has_data(self): """Whether or not there are any data variables in this node.""" - return len(self.ds.variables) > 0 + return len(self._variables) > 0 @property def has_attrs(self) -> bool: """Whether or not there are any metadata attributes in this node.""" - return len(self.ds.attrs.keys()) > 0 + return len(self.attrs.keys()) > 0 @property def is_empty(self) -> bool: """False if node contains any data or attrs. 
Does not look at children.""" return not (self.has_data or self.has_attrs) - def _pre_attach(self: DataTree, parent: DataTree) -> None: + @property + def variables(self) -> Mapping[Hashable, Variable]: + """Low level interface to Dataset contents as dict of Variable objects. + + This ordered dictionary is frozen to prevent mutation that could + violate Dataset invariants. It contains all variable objects + constituting the Dataset, including both data variables and + coordinates. """ - Method which superclass calls before setting parent, here used to prevent having two - children with duplicate names (or a data variable with the same name as a child). + return Frozen(self._variables) + + @property + def attrs(self) -> Dict[Hashable, Any]: + """Dictionary of global attributes on this dataset""" + if self._attrs is None: + self._attrs = {} + return self._attrs + + @attrs.setter + def attrs(self, value: Mapping[Any, Any]) -> None: + self._attrs = dict(value) + + @property + def encoding(self) -> Dict: + """Dictionary of global encoding attributes on this dataset""" + if self._encoding is None: + self._encoding = {} + return self._encoding + + @encoding.setter + def encoding(self, value: Mapping) -> None: + self._encoding = dict(value) + + @property + def dims(self) -> Mapping[Hashable, int]: + """Mapping from dimension names to lengths. + + Cannot be modified directly, but is updated when adding new variables. + + Note that type of this object differs from `DataArray.dims`. + See `Dataset.sizes` and `DataArray.sizes` for consistently named + properties. """ - super()._pre_attach(parent) - if parent.has_data and self.name in list(parent.ds.variables): - raise KeyError( - f"parent {parent.name} already contains a data variable named {self.name}" - ) + return Frozen(self._dims) + + @property + def sizes(self) -> Mapping[Hashable, int]: + """Mapping from dimension names to lengths. + + Cannot be modified directly, but is updated when adding new variables. 
+ + This is an alias for `Dataset.dims` provided for the benefit of + consistency with `DataArray.sizes`. + + See Also + -------- + DataArray.sizes + """ + return self.dims + + def __contains__(self, key: object) -> bool: + """The 'in' operator will return true or false depending on whether + 'key' is either an array stored in the datatree or a child node, or neither. + """ + return key in self.variables or key in self.children + + def __bool__(self) -> bool: + return bool(self.ds.data_vars) or bool(self.children) + + def __iter__(self) -> Iterator[Hashable]: + return itertools.chain(self.ds.data_vars, self.children) def __repr__(self) -> str: return formatting.datatree_repr(self) @@ -204,6 +517,122 @@ def _repr_html_(self): return f"
{escape(repr(self))}" return formatting_html.datatree_repr(self) + @classmethod + def _construct_direct( + cls, + variables: dict[Any, Variable], + coord_names: set[Hashable], + dims: dict[Any, int] = None, + attrs: dict = None, + indexes: dict[Any, Index] = None, + encoding: dict = None, + name: str | None = None, + parent: DataTree | None = None, + children: OrderedDict[str, DataTree] = None, + close: Callable[[], None] = None, + ) -> DataTree: + """Shortcut around __init__ for internal use when we want to skip + costly validation + """ + + # data attributes + if dims is None: + dims = calculate_dimensions(variables) + if indexes is None: + indexes = {} + if children is None: + children = OrderedDict() + + obj: DataTree = object.__new__(cls) + obj._variables = variables + obj._coord_names = coord_names + obj._dims = dims + obj._indexes = indexes + obj._attrs = attrs + obj._close = close + obj._encoding = encoding + + # tree attributes + obj._name = name + obj._children = children + obj._parent = parent + + return obj + + def _replace( + self: DataTree, + variables: dict[Hashable, Variable] = None, + coord_names: set[Hashable] = None, + dims: dict[Any, int] = None, + attrs: dict[Hashable, Any] | None | Default = _default, + indexes: dict[Hashable, Index] = None, + encoding: dict | None | Default = _default, + name: str | None | Default = _default, + parent: DataTree | None = _default, + children: OrderedDict[str, DataTree] = None, + inplace: bool = False, + ) -> DataTree: + """ + Fastpath constructor for internal use. + + Returns an object with optionally replaced attributes. + + Explicitly passed arguments are *not* copied when placed on the new + datatree. It is up to the caller to ensure that they have the right type + and are not used elsewhere. 
+ """ + if inplace: + if variables is not None: + self._variables = variables + if coord_names is not None: + self._coord_names = coord_names + if dims is not None: + self._dims = dims + if attrs is not _default: + self._attrs = attrs + if indexes is not None: + self._indexes = indexes + if encoding is not _default: + self._encoding = encoding + if name is not _default: + self._name = name + if parent is not _default: + self._parent = parent + if children is not None: + self._children = children + obj = self + else: + if variables is None: + variables = self._variables.copy() + if coord_names is None: + coord_names = self._coord_names.copy() + if dims is None: + dims = self._dims.copy() + if attrs is _default: + attrs = copy.copy(self._attrs) + if indexes is None: + indexes = self._indexes.copy() + if encoding is _default: + encoding = copy.copy(self._encoding) + if name is _default: + name = self._name # no need to copy str objects or None + if parent is _default: + parent = copy.copy(self._parent) + if children is _default: + children = copy.copy(self._children) + obj = self._construct_direct( + variables, + coord_names, + dims, + attrs, + indexes, + encoding, + name, + parent, + children, + ) + return obj + def get( self: DataTree, key: str, default: Optional[DataTree | DataArray] = None ) -> Optional[DataTree | DataArray]: @@ -221,8 +650,8 @@ def get( """ if key in self.children: return self.children[key] - elif key in self.ds: - return self.ds[key] + elif key in self.variables: + return self.to_dataset()._construct_dataarray(key) else: return default @@ -267,7 +696,7 @@ def _set(self, key: str, val: DataTree | CoercibleValue) -> None: val.parent = self elif isinstance(val, (DataArray, Variable)): # TODO this should also accomodate other types that can be coerced into Variables - self.ds[key] = val + self.update({key: val}) else: raise TypeError(f"Type {type(val)} cannot be assigned to a DataTree") @@ -311,8 +740,12 @@ def update(self, other: Dataset | 
Mapping[str, DataTree | DataArray]) -> None: else: raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree") - super().update(new_children) - self.ds.update(new_variables) + vars_merge_result = dataset_update_method(self.to_dataset(), new_variables) + # TODO are there any subtleties with preserving order of children like this? + merged_children = OrderedDict(**self.children, **new_children) + self._replace( + inplace=True, children=merged_children, **vars_merge_result._asdict() + ) @classmethod def from_dict( @@ -360,6 +793,7 @@ def from_dict( allow_overwrite=False, new_nodes_along_path=True, ) + return obj def to_dict(self) -> Dict[str, Any]: @@ -374,14 +808,38 @@ def to_dict(self) -> Dict[str, Any]: @property def nbytes(self) -> int: - return sum(node.ds.nbytes if node.has_data else 0 for node in self.subtree) + return sum(node.to_dataset().nbytes for node in self.subtree) def __len__(self) -> int: - if self.children: - n_children = len(self.children) - else: - n_children = 0 - return n_children + len(self.ds) + return len(self.children) + len(self.data_vars) + + @property + def indexes(self) -> Indexes[pd.Index]: + """Mapping of pandas.Index objects used for label based indexing. + Raises an error if this DataTree node has indexes that cannot be coerced + to pandas.Index objects. 
+ See Also / + -------- + DataTree.xindexes + """ + return self.xindexes.to_pandas_indexes() + + @property + def xindexes(self) -> Indexes[Index]: + """Mapping of xarray Index objects used for label based indexing.""" + return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes}) + + @property + def coords(self) -> DatasetCoordinates: + """Dictionary of xarray.DataArray objects corresponding to coordinate + variables + """ + return DatasetCoordinates(self.to_dataset()) + + @property + def data_vars(self) -> DataVariables: + """Dictionary of DataArray objects corresponding to data variables""" + return DataVariables(self.to_dataset()) def isomorphic( self, diff --git a/datatree/ops.py b/datatree/ops.py index ee55ccfe..bdc931c9 100644 --- a/datatree/ops.py +++ b/datatree/ops.py @@ -30,8 +30,8 @@ "map_blocks", ] _DATASET_METHODS_TO_MAP = [ - "copy", "as_numpy", + "copy", "__copy__", "__deepcopy__", "set_coords", @@ -57,7 +57,6 @@ "reorder_levels", "stack", "unstack", - "update", "merge", "drop_vars", "drop_sel", @@ -245,7 +244,6 @@ class MappedDataWithCoords: """ # TODO add mapped versions of groupby, weighted, rolling, rolling_exp, coarsen, resample - # TODO re-implement AttrsAccessMixin stuff so that it includes access to child nodes _wrap_then_attach_to_cls( target_cls_dict=vars(), source_cls=Dataset, diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index 5232f3e5..cfe67da0 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -1,7 +1,12 @@ +from copy import copy, deepcopy + +import numpy as np import pytest import xarray as xr import xarray.testing as xrt +from xarray.tests import create_test_data, source_ndarray +import datatree.testing as dtt from datatree import DataTree @@ -40,15 +45,16 @@ def create_test_datatree(modify=lambda ds: ds): root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})) # Avoid using __init__ so we can independently test it - root = 
DataTree(data=root_data) - set1 = DataTree(name="set1", parent=root, data=set1_data) - DataTree(name="set1", parent=set1) - DataTree(name="set2", parent=set1) - set2 = DataTree(name="set2", parent=root, data=set2_data) - DataTree(name="set1", parent=set2) - DataTree(name="set3", parent=root) - - return root + d = { + "/": root_data, + "/set1": set1_data, + "/set1/set1": None, + "/set1/set2": None, + "/set2": set2_data, + "/set2/set1": None, + "/set3": None, + } + return DataTree.from_dict(d) class TestTreeCreation: @@ -57,7 +63,7 @@ def test_empty(self): assert dt.name == "root" assert dt.parent is None assert dt.children == {} - xrt.assert_identical(dt.ds, xr.Dataset()) + xrt.assert_identical(dt.to_dataset(), xr.Dataset()) def test_unnamed(self): dt = DataTree() @@ -70,12 +76,37 @@ def test_setparent_unnamed_child_node_fails(self): with pytest.raises(ValueError, match="unnamed"): DataTree(parent=john) + def test_create_two_children(self): + root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) + set1_data = xr.Dataset({"a": 0, "b": 1}) + + root = DataTree(data=root_data) + set1 = DataTree(name="set1", parent=root, data=set1_data) + DataTree(name="set1", parent=root) + DataTree(name="set2", parent=set1) + + def test_create_full_tree(self): + root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) + set1_data = xr.Dataset({"a": 0, "b": 1}) + set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}) + + root = DataTree(data=root_data) + set1 = DataTree(name="set1", parent=root, data=set1_data) + DataTree(name="set1", parent=set1) + DataTree(name="set2", parent=set1) + set2 = DataTree(name="set2", parent=root, data=set2_data) + DataTree(name="set1", parent=set2) + DataTree(name="set3", parent=root) + + expected = create_test_datatree() + assert root.identical(expected) + class TestStoreDatasets: def test_create_with_data(self): dat = xr.Dataset({"a": 0}) john = DataTree(name="john", data=dat) - assert john.ds is dat + 
xrt.assert_identical(john.to_dataset(), dat) with pytest.raises(TypeError): DataTree(name="mary", parent=john, data="junk") # noqa @@ -84,7 +115,7 @@ def test_set_data(self): john = DataTree(name="john") dat = xr.Dataset({"a": 0}) john.ds = dat - assert john.ds is dat + xrt.assert_identical(john.to_dataset(), dat) with pytest.raises(TypeError): john.ds = "junk" @@ -105,20 +136,13 @@ def test_parent_already_has_variable_with_childs_name(self): def test_assign_when_already_child_with_variables_name(self): dt = DataTree(data=None) DataTree(name="a", data=None, parent=dt) - with pytest.raises(KeyError, match="already has a child named a"): + with pytest.raises(KeyError, match="names would collide"): dt.ds = xr.Dataset({"a": 0}) dt.ds = xr.Dataset() - with pytest.raises(KeyError, match="already has a child named a"): - dt.ds = dt.ds.assign(a=xr.DataArray(0)) - - @pytest.mark.xfail - def test_update_when_already_child_with_variables_name(self): - # See issue #38 - dt = DataTree(name="root", data=None) - DataTree(name="a", data=None, parent=dt) - with pytest.raises(KeyError, match="already has a child named a"): - dt.ds["a"] = xr.DataArray(0) + new_ds = dt.to_dataset().assign(a=xr.DataArray(0)) + with pytest.raises(KeyError, match="names would collide"): + dt.ds = new_ds class TestGet: @@ -175,7 +199,87 @@ def test_getitem_dict_like_selection_access_to_dataset(self): class TestUpdate: - ... 
+ def test_update_new_named_dataarray(self): + da = xr.DataArray(name="temp", data=[0, 50]) + folder1 = DataTree(name="folder1") + folder1.update({"results": da}) + expected = da.rename("results") + xrt.assert_equal(folder1["results"], expected) + + +class TestCopy: + def test_copy(self): + dt = create_test_datatree() + + for node in dt.root.subtree: + node.attrs["Test"] = [1, 2, 3] + + for copied in [dt.copy(deep=False), copy(dt)]: + dtt.assert_identical(dt, copied) + + for node, copied_node in zip(dt.root.subtree, copied.root.subtree): + + assert node.encoding == copied_node.encoding + # Note: IndexVariable objects with string dtype are always + # copied because of xarray.core.util.safe_cast_to_index. + # Limiting the test to data variables. + # TODO use .data_vars once that property is available + data_vars = [v for v in node.variables if v not in node._coord_names] + for k in data_vars: + v0 = node.variables[k] + v1 = copied_node.variables[k] + assert source_ndarray(v0.data) is source_ndarray(v1.data) + copied_node["foo"] = xr.DataArray(data=np.arange(5), dims="z") + assert "foo" not in node + + copied_node.attrs["foo"] = "bar" + assert "foo" not in node.attrs + assert node.attrs["Test"] is copied_node.attrs["Test"] + + @pytest.mark.xfail(reason="unresolved bug with deepcopying") + def test_deepcopy(self): + dt = create_test_datatree() + + for node in dt.root.subtree: + node.attrs["Test"] = [1, 2, 3] + + for copied in [dt.copy(deep=True), deepcopy(dt)]: + dtt.assert_identical(dt, copied) + + for node, copied_node in zip(dt.root.subtree, copied.root.subtree): + assert node.encoding == copied_node.encoding + # Note: IndexVariable objects with string dtype are always + # copied because of xarray.core.util.safe_cast_to_index. + # Limiting the test to data variables. 
+ # TODO use .data_vars once that property is available + data_vars = [v for v in node.variables if v not in node._coord_names] + for k in data_vars: + v0 = node.variables[k] + v1 = copied_node.variables[k] + assert source_ndarray(v0.data) is source_ndarray(v1.data) + copied_node["foo"] = xr.DataArray(data=np.arange(5), dims="z") + assert "foo" not in node + + copied_node.attrs["foo"] = "bar" + assert "foo" not in node.attrs + assert node.attrs["Test"] is copied_node.attrs["Test"] + + @pytest.mark.xfail(reason="data argument not yet implemented") + def test_copy_with_data(self): + orig = create_test_datatree() + # TODO use .data_vars once that property is available + data_vars = { + k: v for k, v in orig.variables.items() if k not in orig._coord_names + } + new_data = {k: np.random.randn(*v.shape) for k, v in data_vars.items()} + actual = orig.copy(data=new_data) + + expected = orig.copy() + for k, v in new_data.items(): + expected[k].data = v + dtt.assert_identical(expected, actual) + + # TODO test parents and children? 
class TestSetItem: @@ -209,13 +313,13 @@ def test_setitem_new_empty_node(self): john["mary"] = DataTree() mary = john["mary"] assert isinstance(mary, DataTree) - xrt.assert_identical(mary.ds, xr.Dataset()) + xrt.assert_identical(mary.to_dataset(), xr.Dataset()) def test_setitem_overwrite_data_in_node_with_none(self): john = DataTree(name="john") mary = DataTree(name="mary", parent=john, data=xr.Dataset()) john["mary"] = DataTree() - xrt.assert_identical(mary.ds, xr.Dataset()) + xrt.assert_identical(mary.to_dataset(), xr.Dataset()) john.ds = xr.Dataset() with pytest.raises(ValueError, match="has no name"): @@ -226,27 +330,27 @@ def test_setitem_dataset_on_this_node(self): data = xr.Dataset({"temp": [0, 50]}) results = DataTree(name="results") results["."] = data - assert results.ds is data + xrt.assert_identical(results.to_dataset(), data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node(self): data = xr.Dataset({"temp": [0, 50]}) folder1 = DataTree(name="folder1") folder1["results"] = data - assert folder1["results"].ds is data + xrt.assert_identical(folder1["results"].to_dataset(), data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node_requiring_intermediate_nodes(self): data = xr.Dataset({"temp": [0, 50]}) folder1 = DataTree(name="folder1") folder1["results/highres"] = data - assert folder1["results/highres"].ds is data + xrt.assert_identical(folder1["results/highres"].to_dataset(), data) def test_setitem_named_dataarray(self): - data = xr.DataArray(name="temp", data=[0, 50]) + da = xr.DataArray(name="temp", data=[0, 50]) folder1 = DataTree(name="folder1") - folder1["results"] = data - expected = data.rename("results") + folder1["results"] = da + expected = da.rename("results") xrt.assert_equal(folder1["results"], expected) def test_setitem_unnamed_dataarray(self): @@ -275,7 +379,7 @@ def 
test_setitem_dataarray_replace_existing_node(self): p = xr.DataArray(data=[2, 3]) results["pressure"] = p expected = t.assign(pressure=p) - xrt.assert_identical(results.ds, expected) + xrt.assert_identical(results.to_dataset(), expected) class TestDictionaryInterface: @@ -289,16 +393,16 @@ def test_data_in_root(self): assert dt.name is None assert dt.parent is None assert dt.children == {} - assert dt.ds is dat + xrt.assert_identical(dt.to_dataset(), dat) def test_one_layer(self): dat1, dat2 = xr.Dataset({"a": 1}), xr.Dataset({"b": 2}) dt = DataTree.from_dict({"run1": dat1, "run2": dat2}) - xrt.assert_identical(dt.ds, xr.Dataset()) + xrt.assert_identical(dt.to_dataset(), xr.Dataset()) assert dt.name is None - assert dt["run1"].ds is dat1 + xrt.assert_identical(dt["run1"].to_dataset(), dat1) assert dt["run1"].children == {} - assert dt["run2"].ds is dat2 + xrt.assert_identical(dt["run2"].to_dataset(), dat2) assert dt["run2"].children == {} def test_two_layers(self): @@ -307,13 +411,13 @@ def test_two_layers(self): assert "highres" in dt.children assert "lowres" in dt.children highres_run = dt["highres/run"] - assert highres_run.ds is dat1 + xrt.assert_identical(highres_run.to_dataset(), dat1) def test_nones(self): dt = DataTree.from_dict({"d": None, "d/e": None}) assert [node.name for node in dt.subtree] == [None, "d", "e"] assert [node.path for node in dt.subtree] == ["/", "/d", "/d/e"] - xrt.assert_equal(dt["d/e"].ds, xr.Dataset()) + xrt.assert_identical(dt["d/e"].to_dataset(), xr.Dataset()) def test_full(self): dt = create_test_datatree() @@ -343,8 +447,57 @@ def test_roundtrip_unnamed_root(self): assert roundtrip.equals(dt) -class TestBrowsing: - ... 
+class TestDatasetView: + def test_view_contents(self): + ds = create_test_data() + dt = DataTree(data=ds) + assert ds.identical( + dt.ds + ) # this only works because Dataset.identical doesn't check types + assert isinstance(dt.ds, xr.Dataset) + + def test_immutability(self): + # See issue #38 + dt = DataTree(name="root", data=None) + DataTree(name="a", data=None, parent=dt) + + with pytest.raises( + AttributeError, match="Mutation of the DatasetView is not allowed" + ): + dt.ds["a"] = xr.DataArray(0) + + with pytest.raises( + AttributeError, match="Mutation of the DatasetView is not allowed" + ): + dt.ds.update({"a": 0}) + + # TODO are there any other ways you can normally modify state (in-place)? + # (not attribute-like assignment because that doesn't work on Dataset anyway) + + def test_methods(self): + ds = create_test_data() + dt = DataTree(data=ds) + assert ds.mean().identical(dt.ds.mean()) + assert type(dt.ds.mean()) == xr.Dataset + + def test_getitem(self): + dt = create_test_datatree() + da = dt.ds["a"] + print(da) + expected = xr.DataArray(dims=("y",), data=[6, 7, 8]) + print(expected) + xrt.assert_equal(da, expected) + + # @pytest.mark.xfail + def test_get_items_in_other_nodes(self): + dt = create_test_datatree() + node_dsv = dt["set1"].ds + + xrt.assert_identical(node_dsv["../a"], dt["a"]) + + # don't allow retrieving other DataTree nodes + with pytest.raises(KeyError): + node_dsv["../set2"] class TestRestructuring: diff --git a/datatree/treenode.py b/datatree/treenode.py index c3effd91..04c3e68c 100644 --- a/datatree/treenode.py +++ b/datatree/treenode.py @@ -134,7 +134,9 @@ def _detach(self, parent: Tree | None) -> None: def _attach(self, parent: Tree | None, child_name: str = None) -> None: if parent is not None: if child_name is None: - raise ValueError("Cannot directly assign a parent to an unnamed node") + raise ValueError( + "To directly set parent, child needs a name, but child is unnamed" + ) self._pre_attach(parent) parentchildren = 
parent._children diff --git a/docs/source/api.rst b/docs/source/api.rst index 5cd16466..5cc45679 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -18,6 +18,8 @@ Creating a DataTree Tree Attributes --------------- +Attributes relating to the recursive tree-like structure of a ``DataTree``. + .. autosummary:: :toctree: generated/ @@ -34,34 +36,40 @@ Tree Attributes DataTree.ancestors DataTree.groups -Data Attributes ---------------- +Data Contents +------------- + +Interface to the data objects (optionally) stored inside a single ``DataTree`` node. +This interface echoes that of ``xarray.Dataset``. .. autosummary:: :toctree: generated/ DataTree.dims - DataTree.variables - DataTree.encoding DataTree.sizes + DataTree.data_vars + DataTree.coords DataTree.attrs + DataTree.encoding DataTree.indexes - DataTree.xindexes - DataTree.coords DataTree.chunks + DataTree.nbytes DataTree.ds + DataTree.to_dataset DataTree.has_data DataTree.has_attrs DataTree.is_empty .. - Missing - DataTree.chunksizes + Missing: + ``DataTree.chunksizes`` Dictionary interface -------------------- +``DataTree`` objects also have a dict-like interface mapping keys to either ``xarray.DataArray``s or to child ``DataTree`` nodes. + .. autosummary:: :toctree: generated/ @@ -70,16 +78,14 @@ Dictionary interface DataTree.__delitem__ DataTree.update DataTree.get - -.. - - Missing DataTree.items DataTree.keys DataTree.values -Tree Manipulation Methods -------------------------- +Tree Manipulation +----------------- + +For manipulating, traversing, navigating, or mapping over the tree structure. .. autosummary:: :toctree: generated/ @@ -89,127 +95,181 @@ Tree Manipulation Methods DataTree.relative_to DataTree.iter_lineage DataTree.find_common_ancestor + map_over_subtree -Tree Manipulation Utilities ---------------------------- +DataTree Contents +----------------- + +Manipulate the contents of all nodes in a tree simultaneously. .. 
autosummary:: :toctree: generated/ - map_over_subtree + DataTree.copy + DataTree.assign + DataTree.assign_coords + DataTree.merge + DataTree.rename + DataTree.rename_vars + DataTree.rename_dims + DataTree.swap_dims + DataTree.expand_dims + DataTree.drop_vars + DataTree.drop_dims + DataTree.set_coords + DataTree.reset_coords -Methods -------- -.. +DataTree Node Contents +---------------------- - TODO divide these up into "Dataset contents", "Indexing", "Computation" etc. +Manipulate the contents of a single DataTree node. + +Comparisons +=========== + +Compare one ``DataTree`` object to another. + +.. autosummary:: + :toctree: generated/ + + DataTree.isomorphic + DataTree.equals + DataTree.identical + +Indexing +======== + +Index into all nodes in the subtree simultaneously. .. autosummary:: :toctree: generated/ - DataTree.load - DataTree.compute - DataTree.persist - DataTree.unify_chunks - DataTree.chunk - DataTree.map_blocks - DataTree.copy - DataTree.as_numpy - DataTree.__copy__ - DataTree.__deepcopy__ - DataTree.set_coords - DataTree.reset_coords - DataTree.info DataTree.isel DataTree.sel + DataTree.drop_sel + DataTree.drop_isel DataTree.head DataTree.tail DataTree.thin - DataTree.broadcast_like - DataTree.reindex_like - DataTree.reindex + DataTree.squeeze DataTree.interp DataTree.interp_like - DataTree.rename - DataTree.rename_dims - DataTree.rename_vars - DataTree.swap_dims - DataTree.expand_dims + DataTree.reindex + DataTree.reindex_like DataTree.set_index DataTree.reset_index DataTree.reorder_levels - DataTree.stack - DataTree.unstack - DataTree.update - DataTree.merge - DataTree.drop_vars - DataTree.drop_sel - DataTree.drop_isel - DataTree.drop_dims - DataTree.isomorphic - DataTree.equals - DataTree.identical - DataTree.transpose + DataTree.query + +.. + + Missing: + ``DataTree.loc`` + + +Missing Value Handling +====================== + +.. 
autosummary:: + :toctree: generated/ + + DataTree.isnull + DataTree.notnull + DataTree.combine_first DataTree.dropna DataTree.fillna - DataTree.interpolate_na DataTree.ffill DataTree.bfill - DataTree.combine_first - DataTree.reduce - DataTree.map - DataTree.assign - DataTree.diff - DataTree.shift - DataTree.roll - DataTree.sortby - DataTree.quantile - DataTree.rank - DataTree.differentiate - DataTree.integrate - DataTree.cumulative_integrate - DataTree.filter_by_attrs - DataTree.polyfit - DataTree.pad - DataTree.idxmin - DataTree.idxmax - DataTree.argmin - DataTree.argmax - DataTree.query - DataTree.curvefit - DataTree.squeeze - DataTree.clip - DataTree.assign_coords + DataTree.interpolate_na DataTree.where - DataTree.close - DataTree.isnull - DataTree.notnull DataTree.isin - DataTree.astype -Comparisons +Computation +=========== + +Apply a computation to the data in all nodes in the subtree simultaneously. + +.. autosummary:: + :toctree: generated/ + + Dataset.map + Dataset.reduce + Dataset.diff + Dataset.quantile + Dataset.differentiate + Dataset.integrate + Dataset.map_blocks + Dataset.polyfit + Dataset.curvefit + +Aggregation =========== +Aggregate data in all nodes in the subtree simultaneously. + .. autosummary:: :toctree: generated/ - testing.assert_isomorphic - testing.assert_equal - testing.assert_identical + Dataset.all + Dataset.any + Dataset.argmax + Dataset.argmin + Dataset.idxmax + Dataset.idxmin + Dataset.max + Dataset.min + Dataset.mean + Dataset.median + Dataset.prod + Dataset.sum + Dataset.std + Dataset.var + Dataset.cumsum + Dataset.cumprod ndarray methods ---------------- +=============== + +Methods copied from `np.ndarray` objects, here applying to the data in all nodes in the subtree. .. 
autosummary:: :toctree: generated/ - DataTree.nbytes - DataTree.real + DataTree.argsort + DataTree.astype + DataTree.clip + DataTree.conj + DataTree.conjugate DataTree.imag + DataTree.round + DataTree.real + DataTree.rank + +Reshaping and reorganising +========================== + +Reshape or reorganise the data in all nodes in the subtree. + +.. autosummary:: + :toctree: generated/ + + DataTree.transpose + DataTree.stack + DataTree.unstack + DataTree.shift + DataTree.roll + DataTree.pad + DataTree.sortby + DataTree.broadcast_like + +Plotting +======== I/O === +Create or + .. autosummary:: :toctree: generated/ @@ -221,14 +281,46 @@ I/O .. - Missing - open_mfdatatree + Missing: + ``open_mfdatatree`` + +Tutorial +======== + +Testing +======= + +Test that two DataTree objects are similar. + +.. autosummary:: + :toctree: generated/ + + testing.assert_isomorphic + testing.assert_equal + testing.assert_identical Exceptions ========== +Exceptions raised when manipulating trees. + .. autosummary:: :toctree: generated/ - TreeError TreeIsomorphismError + +Advanced API +============ + +Relatively advanced API for users or developers looking to understand the internals, or extend functionality. + +.. autosummary:: + :toctree: generated/ + + DataTree.variables + +.. + + Missing: + ``DataTree.set_close`` + ``register_datatree_accessor``