From 4bd89a6053f2c6d59d18579e722fefbc0005f205 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 14 Aug 2025 11:34:08 -0700 Subject: [PATCH 1/2] Improve DataTree typing --- xarray/backends/api.py | 2 +- xarray/core/datatree.py | 42 +++++++-- xarray/core/datatree_io.py | 6 +- xarray/core/datatree_mapping.py | 11 +-- xarray/core/treenode.py | 119 ++++++++++++------------- xarray/tests/test_backends_datatree.py | 2 +- xarray/tests/test_datatree.py | 6 +- xarray/tests/test_treenode.py | 2 +- 8 files changed, 107 insertions(+), 83 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index af1f5730340..39698a686c0 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -304,7 +304,7 @@ def _protect_dataset_variables_inplace(dataset: Dataset, cache: bool) -> None: def _protect_datatree_variables_inplace(tree: DataTree, cache: bool) -> None: for node in tree.subtree: - _protect_dataset_variables_inplace(node, cache) + _protect_dataset_variables_inplace(node.dataset, cache) def _finalize_store(write, store): diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index af71ec6c23d..3f463f6a773 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -439,7 +439,7 @@ def map( # type: ignore[override] class DataTree( - NamedNode["DataTree"], + NamedNode, DataTreeAggregations, DataTreeOpsMixin, TreeAttrAccessMixin, @@ -559,9 +559,12 @@ def _node_coord_variables_with_index(self) -> Mapping[Hashable, Variable]: @property def _coord_variables(self) -> ChainMap[Hashable, Variable]: + # ChainMap is incorrected typed in typeshed (only the first argument + # needs to be mutable) + # https://github.com/python/typeshed/issues/8430 return ChainMap( self._node_coord_variables, - *(p._node_coord_variables_with_index for p in self.parents), + *(p._node_coord_variables_with_index for p in self.parents), # type: ignore[arg-type] ) @property @@ -1340,7 +1343,7 @@ def equals(self, other: DataTree) -> bool: ) def _inherited_coords_set(self) -> set[str]: - return set(self.parent.coords if self.parent else []) + return set(self.parent.coords if self.parent else []) # type: ignore[arg-type] def identical(self, other: DataTree) -> bool: """ @@ -1563,9 +1566,33 @@ def match(self, pattern: str) -> DataTree: } return DataTree.from_dict(matching_nodes, name=self.name) + @overload def map_over_datasets( self, - func: Callable, + func: Callable[..., Dataset | None], + *args: Any, + kwargs: Mapping[str, Any] | None = None, + ) -> DataTree: ... + + @overload + def map_over_datasets( + self, + func: Callable[..., tuple[Dataset | None, Dataset | None]], + *args: Any, + kwargs: Mapping[str, Any] | None = None, + ) -> tuple[DataTree, DataTree]: ... + + @overload + def map_over_datasets( + self, + func: Callable[..., tuple[Dataset | None, ...]], + *args: Any, + kwargs: Mapping[str, Any] | None = None, + ) -> tuple[DataTree, ...]: ... + + def map_over_datasets( + self, + func: Callable[..., Dataset | None | tuple[Dataset | None, ...]], *args: Any, kwargs: Mapping[str, Any] | None = None, ) -> DataTree | tuple[DataTree, ...]: @@ -1600,8 +1627,7 @@ def map_over_datasets( map_over_datasets """ # TODO this signature means that func has no way to know which node it is being called upon - change? - # TODO fix this typing error - return map_over_datasets(func, self, *args, kwargs=kwargs) + return map_over_datasets(func, self, *args, kwargs=kwargs) # type: ignore[arg-type] @overload def pipe( @@ -1695,7 +1721,7 @@ def groups(self): def _unary_op(self, f, *args, **kwargs) -> DataTree: # TODO do we need to any additional work to avoid duplication etc.? (Similar to aggregations) - return self.map_over_datasets(functools.partial(f, **kwargs), *args) # type: ignore[return-value] + return self.map_over_datasets(functools.partial(f, **kwargs), *args) def _binary_op(self, other, f, reflexive=False, join=None) -> DataTree: from xarray.core.groupby import GroupBy @@ -1911,7 +1937,7 @@ def to_zarr( ) def _get_all_dims(self) -> set: - all_dims = set() + all_dims: set[Any] = set() for node in self.subtree: all_dims.update(node._node_dims) return all_dims diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index c586caaba89..825a7dbb9cb 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -74,13 +74,13 @@ def _datatree_to_netcdf( at_root = node is dt ds = node.to_dataset(inherit=write_inherited_coords or at_root) group_path = None if at_root else "/" + node.relative_to(dt) - ds.to_netcdf( + ds.to_netcdf( # type: ignore[misc] # Not all union combinations were tried because there are too many unions target, group=group_path, mode=mode, encoding=encoding.get(node.path), unlimited_dims=unlimited_dims.get(node.path), - engine=engine, + engine=engine, # type: ignore[arg-type] format=format, compute=compute, **kwargs, @@ -134,7 +134,7 @@ def _datatree_to_zarr( at_root = node is dt ds = node.to_dataset(inherit=write_inherited_coords or at_root) group_path = None if at_root else "/" + node.relative_to(dt) - ds.to_zarr( + ds.to_zarr( # type: ignore[call-overload] store, group=group_path, mode=mode, diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index f9fd5505b66..dc2fc591f44 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -13,15 +13,14 @@ @overload def map_over_datasets( - func: Callable[ - ..., - Dataset | None, - ], + func: Callable[..., Dataset | None], *args: Any, kwargs: Mapping[str, Any] | None = None, ) -> DataTree: ... +# add an explicit overload for the most common case of two return values +# (python typing does not have a way to match tuple lengths in general) @overload def map_over_datasets( func: Callable[..., tuple[Dataset | None, Dataset | None]], @@ -30,8 +29,6 @@ def map_over_datasets( ) -> tuple[DataTree, DataTree]: ... -# add an expect overload for the most common case of two return values -# (python typing does not have a way to match tuple lengths in general) @overload def map_over_datasets( func: Callable[..., tuple[Dataset | None, ...]], @@ -41,7 +38,7 @@ def map_over_datasets( def map_over_datasets( - func: Callable[..., Dataset | tuple[Dataset | None, ...] | None], + func: Callable[..., Dataset | None | tuple[Dataset | None, ...]], *args: Any, kwargs: Mapping[str, Any] | None = None, ) -> DataTree | tuple[DataTree, ...]: diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index a38ec6b2d2b..d15dc51ec33 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -7,7 +7,6 @@ from typing import ( TYPE_CHECKING, Any, - Generic, TypeVar, ) @@ -15,7 +14,7 @@ from xarray.core.utils import Frozen, is_dict_like if TYPE_CHECKING: - from xarray.core.types import T_DataArray + from xarray.core.dataarray import DataArray class InvalidTreeError(Exception): @@ -44,10 +43,7 @@ def __init__(self, *pathsegments): # TODO should we also forbid suffixes to avoid node names with dots in them? -Tree = TypeVar("Tree", bound="TreeNode") - - -class TreeNode(Generic[Tree]): +class TreeNode: """ Base class representing a node of a tree, with methods for traversing and altering the tree. @@ -74,10 +70,10 @@ class TreeNode(Generic[Tree]): """ - _parent: Tree | None - _children: dict[str, Tree] + _parent: Self | None + _children: dict[str, Self] - def __init__(self, children: Mapping[str, Tree] | None = None): + def __init__(self, children: Mapping[str, Self] | None = None): """Create a parentless node.""" self._parent = None self._children = {} @@ -87,18 +83,18 @@ def __init__(self, children: Mapping[str, Tree] | None = None): self.children = {name: child.copy() for name, child in children.items()} @property - def parent(self) -> Tree | None: + def parent(self) -> Self | None: """Parent of this node.""" return self._parent @parent.setter - def parent(self: Tree, new_parent: Tree) -> None: + def parent(self, new_parent: Self) -> None: raise AttributeError( "Cannot set parent attribute directly, you must modify the children of the other node instead using dict-like syntax" ) def _set_parent( - self, new_parent: Tree | None, child_name: str | None = None + self, new_parent: Self | None, child_name: str | None = None ) -> None: # TODO is it possible to refactor in a way that removes this private method? @@ -114,7 +110,7 @@ def _set_parent( self._detach(old_parent) self._attach(new_parent, child_name) - def _check_loop(self, new_parent: Tree | None) -> None: + def _check_loop(self, new_parent: Self | None) -> None: """Checks that assignment of this new parent will not create a cycle.""" if new_parent is not None: if new_parent is self: @@ -127,10 +123,10 @@ def _check_loop(self, new_parent: Tree | None) -> None: "Cannot set parent, as intended parent is already a descendant of this node." ) - def _is_descendant_of(self, node: Tree) -> bool: + def _is_descendant_of(self, node: Self) -> bool: return any(n is self for n in node.parents) - def _detach(self, parent: Tree | None) -> None: + def _detach(self, parent: Self | None) -> None: if parent is not None: self._pre_detach(parent) parents_children = parent.children @@ -142,7 +138,7 @@ def _detach(self, parent: Tree | None) -> None: self._parent = None self._post_detach(parent) - def _attach(self, parent: Tree | None, child_name: str | None = None) -> None: + def _attach(self, parent: Self | None, child_name: str | None = None) -> None: if parent is not None: if child_name is None: raise ValueError( @@ -165,12 +161,12 @@ def orphan(self) -> None: self._set_parent(new_parent=None) @property - def children(self: Tree) -> Mapping[str, Tree]: + def children(self) -> Mapping[str, Self]: """Child nodes of this node, stored under a mapping via their names.""" return Frozen(self._children) @children.setter - def children(self: Tree, children: Mapping[str, Tree]) -> None: + def children(self, children: Mapping[str, Self]) -> None: self._check_children(children) children = {**children} @@ -198,7 +194,7 @@ def children(self) -> None: self._post_detach_children(children) @staticmethod - def _check_children(children: Mapping[str, Tree]) -> None: + def _check_children(children: Mapping[str, TreeNode]) -> None: """Check children for correct types and for any duplicates.""" if not is_dict_like(children): raise TypeError( @@ -224,19 +220,19 @@ def _check_children(children: Mapping[str, Tree]) -> None: def __repr__(self) -> str: return f"TreeNode(children={dict(self._children)})" - def _pre_detach_children(self: Tree, children: Mapping[str, Tree]) -> None: + def _pre_detach_children(self, children: Mapping[str, Self]) -> None: """Method call before detaching `children`.""" pass - def _post_detach_children(self: Tree, children: Mapping[str, Tree]) -> None: + def _post_detach_children(self, children: Mapping[str, Self]) -> None: """Method call after detaching `children`.""" pass - def _pre_attach_children(self: Tree, children: Mapping[str, Tree]) -> None: + def _pre_attach_children(self, children: Mapping[str, Self]) -> None: """Method call before attaching `children`.""" pass - def _post_attach_children(self: Tree, children: Mapping[str, Tree]) -> None: + def _post_attach_children(self, children: Mapping[str, Self]) -> None: """Method call after attaching `children`.""" pass @@ -300,14 +296,14 @@ def __copy__(self) -> Self: def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: return self._copy_subtree(inherit=True, deep=True, memo=memo) - def _iter_parents(self: Tree) -> Iterator[Tree]: + def _iter_parents(self) -> Iterator[Self]: """Iterate up the tree, starting from the current node's parent.""" - node: Tree | None = self.parent + node: Self | None = self.parent while node is not None: yield node node = node.parent - def iter_lineage(self: Tree) -> tuple[Tree, ...]: + def iter_lineage(self) -> tuple[Self, ...]: """Iterate up the tree, starting from the current node.""" from warnings import warn @@ -320,7 +316,7 @@ def iter_lineage(self: Tree) -> tuple[Tree, ...]: return (self, *self.parents) @property - def lineage(self: Tree) -> tuple[Tree, ...]: + def lineage(self) -> tuple[Self, ...]: """All parent nodes and their parent nodes, starting with the closest.""" from warnings import warn @@ -333,12 +329,12 @@ def lineage(self: Tree) -> tuple[Tree, ...]: return self.iter_lineage() @property - def parents(self: Tree) -> tuple[Tree, ...]: + def parents(self) -> tuple[Self, ...]: """All parent nodes and their parent nodes, starting with the closest.""" return tuple(self._iter_parents()) @property - def ancestors(self: Tree) -> tuple[Tree, ...]: + def ancestors(self) -> tuple[Self, ...]: """All parent nodes and their parent nodes, starting with the most distant.""" from warnings import warn @@ -352,7 +348,7 @@ def ancestors(self: Tree) -> tuple[Tree, ...]: return (*reversed(self.parents), self) @property - def root(self: Tree) -> Tree: + def root(self) -> Self: """Root node of the tree""" node = self while node.parent is not None: @@ -374,7 +370,7 @@ def is_leaf(self) -> bool: return self.children == {} @property - def leaves(self: Tree) -> tuple[Tree, ...]: + def leaves(self) -> tuple[Self, ...]: """ All leaf nodes. @@ -383,7 +379,7 @@ def leaves(self: Tree) -> tuple[Tree, ...]: return tuple(node for node in self.subtree if node.is_leaf) @property - def siblings(self: Tree) -> dict[str, Tree]: + def siblings(self) -> dict[str, Self]: """ Nodes with the same parent as this node. """ @@ -397,7 +393,7 @@ def siblings(self: Tree) -> dict[str, Tree]: return {} @property - def subtree(self: Tree) -> Iterator[Tree]: + def subtree(self) -> Iterator[Self]: """ Iterate over all nodes in this tree, including both self and all descendants. @@ -417,7 +413,7 @@ def subtree(self: Tree) -> Iterator[Tree]: queue.extend(node.children.values()) @property - def subtree_with_keys(self: Tree) -> Iterator[tuple[str, Tree]]: + def subtree_with_keys(self) -> Iterator[tuple[str, Self]]: """ Iterate over relative paths and node pairs for all nodes in this tree. @@ -436,7 +432,7 @@ def subtree_with_keys(self: Tree) -> Iterator[tuple[str, Tree]]: queue.extend((path / name, child) for name, child in node.children.items()) @property - def descendants(self: Tree) -> tuple[Tree, ...]: + def descendants(self) -> tuple[Self, ...]: """ Child nodes and all their child nodes. @@ -451,7 +447,7 @@ def descendants(self: Tree) -> tuple[Tree, ...]: return tuple(descendants) @property - def level(self: Tree) -> int: + def level(self) -> int: """ Level of this node. @@ -470,7 +466,7 @@ def level(self: Tree) -> int: return len(self.parents) @property - def depth(self: Tree) -> int: + def depth(self) -> int: """ Maximum level of this tree. @@ -488,7 +484,7 @@ def depth(self: Tree) -> int: return max(node.level for node in self.root.subtree) @property - def width(self: Tree) -> int: + def width(self) -> int: """ Number of nodes at this level in the tree. @@ -505,23 +501,23 @@ def width(self: Tree) -> int: """ return len([node for node in self.root.subtree if node.level == self.level]) - def _pre_detach(self: Tree, parent: Tree) -> None: + def _pre_detach(self, parent: Self) -> None: """Method call before detaching from `parent`.""" pass - def _post_detach(self: Tree, parent: Tree) -> None: + def _post_detach(self, parent: Self) -> None: """Method call after detaching from `parent`.""" pass - def _pre_attach(self: Tree, parent: Tree, name: str) -> None: + def _pre_attach(self, parent: Self, name: str) -> None: """Method call before attaching to `parent`.""" pass - def _post_attach(self: Tree, parent: Tree, name: str) -> None: + def _post_attach(self, parent: Self, name: str) -> None: """Method call after attaching to `parent`.""" pass - def get(self: Tree, key: str, default: Tree | None = None) -> Tree | None: + def get(self, key: str, default: Self | None = None) -> Self | None: """ Return the child node with the specified key. @@ -535,7 +531,7 @@ def get(self: Tree, key: str, default: Tree | None = None) -> Tree | None: # TODO `._walk` method to be called by both `_get_item` and `_set_item` - def _get_item(self: Tree, path: str | NodePath) -> Tree | T_DataArray: + def _get_item(self, path: str | NodePath) -> Self | DataArray: """ Returns the object lying at the given path. @@ -559,13 +555,14 @@ def _get_item(self: Tree, path: str | NodePath) -> Tree | T_DataArray: current_node = current_node.parent elif part in ("", "."): pass - elif current_node.get(part) is None: - raise KeyError(f"Could not find node at {path}") else: - current_node = current_node.get(part) + child = current_node.get(part) + if child is None: + raise KeyError(f"Could not find node at {path}") + current_node = child return current_node - def _set(self: Tree, key: str, val: Tree) -> None: + def _set(self, key: str, val: Any) -> None: """ Set the child node with the specified key to value. @@ -575,9 +572,9 @@ def _set(self: Tree, key: str, val: Tree) -> None: self.children = new_children def _set_item( - self: Tree, + self, path: str | NodePath, - item: Tree | T_DataArray, + item: Any, new_nodes_along_path: bool = False, allow_overwrite: bool = True, ) -> None: @@ -649,7 +646,7 @@ def _set_item( else: current_node._set(name, item) - def __delitem__(self: Tree, key: str) -> None: + def __delitem__(self, key: str) -> None: """Remove a child node from this tree object.""" if key in self.children: child = self._children[key] @@ -658,7 +655,7 @@ def __delitem__(self: Tree, key: str) -> None: else: raise KeyError(key) - def same_tree(self, other: Tree) -> bool: + def same_tree(self, other: Self) -> bool: """True if other node is in the same tree as this node.""" return self.root is other.root @@ -674,7 +671,7 @@ def _validate_name(name: str | None) -> None: raise ValueError("node names cannot contain forward slashes") -class NamedNode(TreeNode, Generic[Tree]): +class NamedNode(TreeNode): """ A TreeNode which knows its own name. @@ -682,10 +679,12 @@ class NamedNode(TreeNode, Generic[Tree]): """ _name: str | None - _parent: Tree | None - _children: dict[str, Tree] - def __init__(self, name=None, children=None): + def __init__( + self, + name: str | None = None, + children: Mapping[str, Self] | None = None, + ): super().__init__(children=children) _validate_name(name) self._name = name @@ -738,9 +737,9 @@ def path(self) -> str: root, *ancestors = tuple(reversed(self.parents)) # don't include name of root because (a) root might not have a name & (b) we want path relative to root. names = [*(node.name for node in ancestors), self.name] - return "/" + "/".join(names) + return "/" + "/".join(names) # type: ignore[arg-type] - def relative_to(self: NamedNode, other: NamedNode) -> str: + def relative_to(self, other: Self) -> str: """ Compute the relative path from this node to node `other`. @@ -761,7 +760,7 @@ def relative_to(self: NamedNode, other: NamedNode) -> str: path_to_common_ancestor / this_path.relative_to(common_ancestor.path) ) - def find_common_ancestor(self, other: NamedNode) -> NamedNode: + def find_common_ancestor(self, other: Self) -> Self: """ Find the first common ancestor of two nodes in the same tree. @@ -779,7 +778,7 @@ def find_common_ancestor(self, other: NamedNode) -> NamedNode: "Cannot find common ancestor because nodes do not lie within the same tree" ) - def _path_to_ancestor(self, ancestor: NamedNode) -> NodePath: + def _path_to_ancestor(self, ancestor: Self) -> NodePath: """Return the relative path from this node to the given ancestor node""" if not self.same_tree(ancestor): diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 2bdae6def27..52909266002 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -724,7 +724,7 @@ def assert_expected_zarr_files_exist( # inherited variables aren't meant to be written to zarr local_node_variables = node.to_dataset(inherit=False).variables for name, var in local_node_variables.items(): - var_dir = storepath / node.path.removeprefix("/") / name + var_dir = storepath / node.path.removeprefix("/") / name # type: ignore[operator] assert_expected_zarr_files_exist( arr_dir=var_dir, diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index efa25386440..c394fee8ef2 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -165,7 +165,9 @@ def test_same_tree(self) -> None: "/Kate": DataTree(), } ) - assert john["/Mary"].same_tree(john["/Kate"]) + mary = john.children["/Mary"] + kate = john.children["/Kate"] + assert mary.same_tree(kate) def test_relative_paths(self) -> None: john = DataTree.from_dict( @@ -1884,7 +1886,7 @@ def test_differently_inherited_coordinates(self): assert not child.identical(new_child) deeper_root = DataTree(children={"root": root}) - grandchild = deeper_root["/root/child"] + grandchild = deeper_root.children["/root/child"] assert child.equals(grandchild) assert child.identical(grandchild) diff --git a/xarray/tests/test_treenode.py b/xarray/tests/test_treenode.py index 5f9c586fce2..7eb715630bb 100644 --- a/xarray/tests/test_treenode.py +++ b/xarray/tests/test_treenode.py @@ -100,7 +100,7 @@ def test_doppelganger_child(self) -> None: john: TreeNode = TreeNode() with pytest.raises(TypeError): - john.children = {"Kate": 666} + john.children = {"Kate": 666} # type: ignore[dict-item] with pytest.raises(InvalidTreeError, match="Cannot add same node"): john.children = {"Kate": kate, "Evil_Kate": kate} From 5e2f9dd442031e012c7f8fbb1b0ad9a0a5a374b0 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 14 Aug 2025 12:09:38 -0700 Subject: [PATCH 2/2] a few more fixes --- xarray/tests/test_datatree.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index c394fee8ef2..7c114d31104 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -165,8 +165,8 @@ def test_same_tree(self) -> None: "/Kate": DataTree(), } ) - mary = john.children["/Mary"] - kate = john.children["/Kate"] + mary = john.children["Mary"] + kate = john.children["Kate"] assert mary.same_tree(kate) def test_relative_paths(self) -> None: @@ -176,14 +176,8 @@ def test_relative_paths(self) -> None: "/Annie": DataTree(), } ) - sue_result = john["Mary/Sue"] - if isinstance(sue_result, DataTree): - sue: DataTree = sue_result - - annie_result = john["Annie"] - if isinstance(annie_result, DataTree): - annie: DataTree = annie_result - + sue = john.children["Mary"].children["Sue"] + annie = john.children["Annie"] assert sue.relative_to(john) == "Mary/Sue" assert john.relative_to(sue) == "../.." assert annie.relative_to(sue) == "../../Annie" @@ -1886,7 +1880,7 @@ def test_differently_inherited_coordinates(self): assert not child.identical(new_child) deeper_root = DataTree(children={"root": root}) - grandchild = deeper_root.children["/root/child"] + grandchild = deeper_root.children["root"].children["child"] assert child.equals(grandchild) assert child.identical(grandchild)