diff --git a/DATATREE_MIGRATION_GUIDE.md b/DATATREE_MIGRATION_GUIDE.md index 312cb842d84..d91707e8c86 100644 --- a/DATATREE_MIGRATION_GUIDE.md +++ b/DATATREE_MIGRATION_GUIDE.md @@ -36,7 +36,8 @@ A number of other API changes have been made, which should only require minor mo - The top-level import has changed, from `from datatree import DataTree, open_datatree` to `from xarray import DataTree, open_datatree`. Alternatively you can now just use the `import xarray as xr` namespace convention for everything datatree-related. - The `DataTree.ds` property has been changed to `DataTree.dataset`, though `DataTree.ds` remains as an alias for `DataTree.dataset`. - Similarly the `ds` kwarg in the `DataTree.__init__` constructor has been replaced by `dataset`, i.e. use `DataTree(dataset=)` instead of `DataTree(ds=...)`. -- The method `DataTree.to_dataset()` still exists but now has different options for controlling which variables are present on the resulting `Dataset`, e.g. `inherited=True/False`. +- The method `DataTree.to_dataset()` still exists but now has different options for controlling which variables are present on the resulting `Dataset`, e.g. `inherit=True/False`. +- `DataTree.copy()` also has a new `inherit` keyword argument for controlling whether or not coordinates defined on parents are copied (only relevant when copying a non-root node). - The `DataTree.parent` property is now read-only. To assign a ancestral relationships directly you must instead use the `.children` property on the parent node, which remains settable. - Similarly the `parent` kwarg has been removed from the `DataTree.__init__` constuctor. - DataTree objects passed to the `children` kwarg in `DataTree.__init__` are now shallow-copied. 
diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 83a2f47c38d..452a1af2d3b 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -826,19 +826,13 @@ def _replace_node( self.children = children - def _copy_node( - self: DataTree, - deep: bool = False, - ) -> DataTree: - """Copy just one node of a tree""" - - new_node = super()._copy_node() - - data = self._to_dataset_view(rebuild_dims=False, inherit=False) + def _copy_node(self, inherit: bool, deep: bool = False) -> Self: + """Copy just one node of a tree.""" + new_node = super()._copy_node(inherit=inherit, deep=deep) + data = self._to_dataset_view(rebuild_dims=False, inherit=inherit) if deep: data = data.copy(deep=True) new_node._set_node_data(data) - return new_node def get( # type: ignore[override] @@ -1159,7 +1153,9 @@ def depth(item) -> int: new_nodes_along_path=True, ) - return obj + # TODO: figure out why mypy is raising an error here, likely something + # to do with the return type of Dataset.copy() + return obj # type: ignore[return-value] def to_dict(self) -> dict[str, Dataset]: """ diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 604eb274aa9..102b92dc03b 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -10,6 +10,7 @@ TypeVar, ) +from xarray.core.types import Self from xarray.core.utils import Frozen, is_dict_like if TYPE_CHECKING: @@ -238,10 +239,7 @@ def _post_attach_children(self: Tree, children: Mapping[str, Tree]) -> None: """Method call after attaching `children`.""" pass - def copy( - self: Tree, - deep: bool = False, - ) -> Tree: + def copy(self, *, inherit: bool = True, deep: bool = False) -> Self: """ Returns a copy of this subtree. @@ -254,7 +252,12 @@ def copy( Parameters ---------- - deep : bool, default: False + inherit : bool + Whether inherited coordinates defined on parents of this node should + also be copied onto the new tree. 
Only relevant if the `parent` of + this node is not None, and "Inherited coordinates" appear in its + repr. + deep : bool Whether each component variable is loaded into memory and copied onto the new object. Default is False. @@ -269,35 +272,27 @@ def copy( xarray.Dataset.copy pandas.DataFrame.copy """ - return self._copy_subtree(deep=deep) + return self._copy_subtree(inherit=inherit, deep=deep) - def _copy_subtree( - self: Tree, - deep: bool = False, - memo: dict[int, Any] | None = None, - ) -> Tree: + def _copy_subtree(self, inherit: bool, deep: bool = False) -> Self: """Copy entire subtree recursively.""" - - new_tree = self._copy_node(deep=deep) + new_tree = self._copy_node(inherit=inherit, deep=deep) for name, child in self.children.items(): # TODO use `.children[name] = ...` once #9477 is implemented - new_tree._set(name, child._copy_subtree(deep=deep)) - + new_tree._set(name, child._copy_subtree(inherit=False, deep=deep)) return new_tree - def _copy_node( - self: Tree, - deep: bool = False, - ) -> Tree: + def _copy_node(self, inherit: bool, deep: bool = False) -> Self: """Copy just one node of a tree""" new_empty_node = type(self)() return new_empty_node - def __copy__(self: Tree) -> Tree: - return self._copy_subtree(deep=False) + def __copy__(self) -> Self: + return self._copy_subtree(inherit=True, deep=False) - def __deepcopy__(self: Tree, memo: dict[int, Any] | None = None) -> Tree: - return self._copy_subtree(deep=True, memo=memo) + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: + del memo # nodes cannot be reused in a DataTree + return self._copy_subtree(inherit=True, deep=True) def _iter_parents(self: Tree) -> Iterator[Tree]: """Iterate up the tree, starting from the current node's parent.""" @@ -693,17 +688,14 @@ def __str__(self) -> str: name_repr = repr(self.name) if self.name is not None else "" return f"NamedNode({name_repr})" - def _post_attach(self: AnyNamedNode, parent: AnyNamedNode, name: str) -> None: + def 
_post_attach(self, parent: Self, name: str) -> None: """Ensures child has name attribute corresponding to key under which it has been stored.""" _validate_name(name) # is this check redundant? self._name = name - def _copy_node( - self: AnyNamedNode, - deep: bool = False, - ) -> AnyNamedNode: + def _copy_node(self, inherit: bool, deep: bool = False) -> Self: """Copy just one node of a tree""" - new_node = super()._copy_node() + new_node = super()._copy_node(inherit=inherit, deep=deep) new_node._name = self.name return new_node diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index e20785bb3cf..c43f53a22c1 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -414,10 +414,18 @@ def test_copy_coord_inheritance(self) -> None: tree = DataTree.from_dict( {"/": xr.Dataset(coords={"x": [0, 1]}), "/c": DataTree()} ) - tree2 = tree.copy() - node_ds = tree2.children["c"].to_dataset(inherit=False) + actual = tree.copy() + node_ds = actual.children["c"].to_dataset(inherit=False) assert_identical(node_ds, xr.Dataset()) + actual = tree.children["c"].copy() + expected = DataTree(Dataset(coords={"x": [0, 1]}), name="c") + assert_identical(expected, actual) + + actual = tree.children["c"].copy(inherit=False) + expected = DataTree(name="c") + assert_identical(expected, actual) + def test_deepcopy(self, create_test_datatree): dt = create_test_datatree()