xarray/backends/api.py (2 changes: 1 addition & 1 deletion)

@@ -304,7 +304,7 @@ def _protect_dataset_variables_inplace(dataset: Dataset, cache: bool) -> None:
 
 def _protect_datatree_variables_inplace(tree: DataTree, cache: bool) -> None:
     for node in tree.subtree:
-        _protect_dataset_variables_inplace(node, cache)
+        _protect_dataset_variables_inplace(node.dataset, cache)
 
 
 def _finalize_store(write, store):
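
The one-line fix above reflects that `DataTree.subtree` yields `DataTree` nodes, while the helper expects a `Dataset`. A minimal sketch of that distinction, assuming a recent xarray where `DataTree` is public (the example tree is made up):

import xarray as xr

# .subtree walks DataTree nodes; each node's data is reached via .dataset
tree = xr.DataTree.from_dict({"/child": xr.Dataset({"a": ("x", [1, 2, 3])})})
for node in tree.subtree:
    print(node.path, type(node).__name__, list(node.dataset.data_vars))
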
xarray/core/datatree.py (42 changes: 34 additions & 8 deletions)

@@ -439,7 +439,7 @@ def map( # type: ignore[override]
 
 
 class DataTree(
-    NamedNode["DataTree"],
+    NamedNode,
     DataTreeAggregations,
     DataTreeOpsMixin,
     TreeAttrAccessMixin,
@@ -559,9 +559,12 @@ def _node_coord_variables_with_index(self) -> Mapping[Hashable, Variable]:
 
     @property
     def _coord_variables(self) -> ChainMap[Hashable, Variable]:
+        # ChainMap is incorrectly typed in typeshed (only the first argument
+        # needs to be mutable)
+        # https://github.com/python/typeshed/issues/8430
         return ChainMap(
             self._node_coord_variables,
-            *(p._node_coord_variables_with_index for p in self.parents),
+            *(p._node_coord_variables_with_index for p in self.parents),  # type: ignore[arg-type]
         )
 
     @property
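
A note on the ChainMap comment added above: writes to a ChainMap only ever touch the first mapping, so the later mappings can be read-only even though typeshed currently annotates them all as mutable. A minimal standalone sketch of that behaviour (not xarray code):

from collections import ChainMap
from types import MappingProxyType

writable: dict[str, int] = {}
read_only = MappingProxyType({"x": 1})  # later maps are only ever read
cm = ChainMap(writable, read_only)  # type: ignore[arg-type]  # typeshed wants every map mutable

cm["y"] = 2                        # writes land in `writable` only
print(cm["x"], cm["y"], writable)  # 1 2 {'y': 2}
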
@@ -1340,7 +1343,7 @@ def equals(self, other: DataTree) -> bool:
         )
 
     def _inherited_coords_set(self) -> set[str]:
-        return set(self.parent.coords if self.parent else [])
+        return set(self.parent.coords if self.parent else [])  # type: ignore[arg-type]
 
     def identical(self, other: DataTree) -> bool:
         """
@@ -1563,9 +1566,33 @@ def match(self, pattern: str) -> DataTree:
         }
         return DataTree.from_dict(matching_nodes, name=self.name)
 
+    @overload
     def map_over_datasets(
         self,
-        func: Callable,
+        func: Callable[..., Dataset | None],
         *args: Any,
         kwargs: Mapping[str, Any] | None = None,
+    ) -> DataTree: ...
+
+    @overload
+    def map_over_datasets(
+        self,
+        func: Callable[..., tuple[Dataset | None, Dataset | None]],
+        *args: Any,
+        kwargs: Mapping[str, Any] | None = None,
+    ) -> tuple[DataTree, DataTree]: ...
+
+    @overload
+    def map_over_datasets(
+        self,
+        func: Callable[..., tuple[Dataset | None, ...]],
+        *args: Any,
+        kwargs: Mapping[str, Any] | None = None,
+    ) -> tuple[DataTree, ...]: ...
+
+    def map_over_datasets(
+        self,
+        func: Callable[..., Dataset | None | tuple[Dataset | None, ...]],
+        *args: Any,
+        kwargs: Mapping[str, Any] | None = None,
     ) -> DataTree | tuple[DataTree, ...]:
@@ -1600,8 +1627,7 @@ def map_over_datasets(
         map_over_datasets
         """
         # TODO this signature means that func has no way to know which node it is being called upon - change?
-        # TODO fix this typing error
-        return map_over_datasets(func, self, *args, kwargs=kwargs)
+        return map_over_datasets(func, self, *args, kwargs=kwargs)  # type: ignore[arg-type]
 
     @overload
     def pipe(
@@ -1695,7 +1721,7 @@ def groups(self):
 
     def _unary_op(self, f, *args, **kwargs) -> DataTree:
         # TODO do we need to any additional work to avoid duplication etc.? (Similar to aggregations)
-        return self.map_over_datasets(functools.partial(f, **kwargs), *args)  # type: ignore[return-value]
+        return self.map_over_datasets(functools.partial(f, **kwargs), *args)
 
     def _binary_op(self, other, f, reflexive=False, join=None) -> DataTree:
         from xarray.core.groupby import GroupBy
@@ -1911,7 +1937,7 @@ def to_zarr(
         )
 
     def _get_all_dims(self) -> set:
-        all_dims = set()
+        all_dims: set[Any] = set()
         for node in self.subtree:
             all_dims.update(node._node_dims)
         return all_dims
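
For context on the new overloads, here is a hedged usage sketch (tree contents are made up; assumes a recent xarray where `DataTree` is public) showing how the return type follows the mapped function: a Dataset-returning function yields a single DataTree, while a function returning a 2-tuple of Datasets yields a pair of DataTrees:

import xarray as xr

tree = xr.DataTree.from_dict({"/a": xr.Dataset({"v": ("x", [1, 2, 3])})})

def bump(ds: xr.Dataset) -> xr.Dataset:
    return ds + 1

def split(ds: xr.Dataset) -> tuple[xr.Dataset, xr.Dataset]:
    return ds + 1, ds * 2

bumped = tree.map_over_datasets(bump)               # first overload  -> DataTree
plus_one, doubled = tree.map_over_datasets(split)   # second overload -> tuple[DataTree, DataTree]
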
xarray/core/datatree_io.py (6 changes: 3 additions & 3 deletions)

@@ -74,13 +74,13 @@ def _datatree_to_netcdf(
         at_root = node is dt
         ds = node.to_dataset(inherit=write_inherited_coords or at_root)
         group_path = None if at_root else "/" + node.relative_to(dt)
-        ds.to_netcdf(
+        ds.to_netcdf(  # type: ignore[misc] # Not all union combinations were tried because there are too many unions
            target,
            group=group_path,
            mode=mode,
            encoding=encoding.get(node.path),
            unlimited_dims=unlimited_dims.get(node.path),
-            engine=engine,
+            engine=engine,  # type: ignore[arg-type]
            format=format,
            compute=compute,
            **kwargs,
@@ -134,7 +134,7 @@ def _datatree_to_zarr(
         at_root = node is dt
         ds = node.to_dataset(inherit=write_inherited_coords or at_root)
         group_path = None if at_root else "/" + node.relative_to(dt)
-        ds.to_zarr(
+        ds.to_zarr(  # type: ignore[call-overload]
             store,
             group=group_path,
             mode=mode,
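
Both call sites above follow the same pattern: each subtree node is written to its own group, named by the node's path relative to the root. A hedged sketch of that layout (file name and tree contents are made up; assumes a netCDF engine is installed):

import xarray as xr

tree = xr.DataTree.from_dict({"/child": xr.Dataset({"a": ("x", [1, 2])})})
for node in tree.subtree:
    group = None if node is tree else "/" + node.relative_to(tree)
    # the first write creates the file, later writes append further groups
    node.to_dataset().to_netcdf("example_tree.nc", group=group, mode="w" if node is tree else "a")
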
xarray/core/datatree_mapping.py (11 changes: 4 additions & 7 deletions)

@@ -13,15 +13,14 @@
 
 @overload
 def map_over_datasets(
-    func: Callable[
-        ...,
-        Dataset | None,
-    ],
+    func: Callable[..., Dataset | None],
     *args: Any,
     kwargs: Mapping[str, Any] | None = None,
 ) -> DataTree: ...
 
 
+# add an explicit overload for the most common case of two return values
+# (python typing does not have a way to match tuple lengths in general)
 @overload
 def map_over_datasets(
     func: Callable[..., tuple[Dataset | None, Dataset | None]],
@@ -30,8 +29,6 @@ def map_over_datasets(
 ) -> tuple[DataTree, DataTree]: ...
 
 
-# add an expect overload for the most common case of two return values
-# (python typing does not have a way to match tuple lengths in general)
 @overload
 def map_over_datasets(
     func: Callable[..., tuple[Dataset | None, ...]],
@@ -41,7 +38,7 @@ def map_over_datasets(
 
 
 def map_over_datasets(
-    func: Callable[..., Dataset | tuple[Dataset | None, ...] | None],
+    func: Callable[..., Dataset | None | tuple[Dataset | None, ...]],
    *args: Any,
    kwargs: Mapping[str, Any] | None = None,
 ) -> DataTree | tuple[DataTree, ...]:
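
The relocated comment explains why a dedicated two-element overload exists at all: Python's typing cannot express "a tuple of the same length as the input", so the common two-return case gets its own fixed-length overload ahead of the variadic one. A standalone sketch of the idea with placeholder types (not xarray code):

from typing import Callable, overload

@overload
def apply(func: Callable[[], tuple[int, int]]) -> tuple[str, str]: ...
@overload
def apply(func: Callable[[], tuple[int, ...]]) -> tuple[str, ...]: ...
def apply(func: Callable[[], tuple[int, ...]]) -> tuple[str, ...]:
    # stringify every element; overload order decides how precise the result type is
    return tuple(str(x) for x in func())

two: tuple[str, str] = apply(lambda: (1, 2))       # matches the fixed-length overload
many: tuple[str, ...] = apply(lambda: (1, 2, 3))   # falls through to the variadic overload
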