Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -761,16 +761,17 @@ Compare one ``DataTree`` object to another.
DataTree.equals
DataTree.identical

.. Indexing
.. --------
Indexing
--------

.. Index into all nodes in the subtree simultaneously.
Index into all nodes in the subtree simultaneously.

.. .. autosummary::
.. :toctree: generated/
.. autosummary::
:toctree: generated/

DataTree.isel
DataTree.sel

.. DataTree.isel
.. DataTree.sel
.. DataTree.drop_sel
.. DataTree.drop_isel
.. DataTree.head
Expand Down
69 changes: 67 additions & 2 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
from xarray.core.merge import dataset_update_method
from xarray.core.options import OPTIONS as XR_OPTS
from xarray.core.treenode import NamedNode, NodePath
from xarray.core.types import Self
from xarray.core.utils import (
Default,
FilteredMapping,
Frozen,
_default,
drop_dims_from_indexers,
either_dict_or_kwargs,
maybe_wrap_array,
)
Expand All @@ -54,7 +56,12 @@

from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes
from xarray.core.merge import CoercibleMapping, CoercibleValue
from xarray.core.types import ErrorOptions, NetcdfWriteModes, ZarrWriteModes
from xarray.core.types import (
ErrorOptions,
ErrorOptionsWithWarn,
NetcdfWriteModes,
ZarrWriteModes,
)

# """
# DEVELOPERS' NOTE
Expand Down Expand Up @@ -1087,7 +1094,7 @@ def from_dict(
d: Mapping[str, Dataset | DataTree | None],
/,
name: str | None = None,
) -> DataTree:
) -> Self:
"""
Create a datatree from a dictionary of data objects, organised by paths into the tree.

Expand Down Expand Up @@ -1607,3 +1614,61 @@ def to_zarr(
compute=compute,
**kwargs,
)

def _selective_indexing(
self,
func: Callable[[Dataset, Mapping[Any, Any]], Dataset],
indexers: Mapping[Any, Any],
missing_dims: ErrorOptionsWithWarn = "raise",
) -> Self:
all_dims = set()
for node in self.subtree:
all_dims.update(node._node_dims)
indexers = drop_dims_from_indexers(indexers, all_dims, missing_dims)

result = {}
for node in self.subtree:
node_indexers = {k: v for k, v in indexers.items() if k in node.dims}
node_result = func(node.dataset, node_indexers)
for k in node_indexers:
if k not in node._node_coord_variables and k in node_result.coords:
del node_result.coords[k]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is basically saying "if the result of indexing would cause an inherited coordinate to be duplicated, then remove that duplicated coordinate from the result"?

Doesn't this imply an inefficiency, in that we are potentially doing the work to index coordinates that we're only going to drop afterwards?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a comment to explain. It is indeed a small source of inefficiency.

result[node.path] = node_result
return type(self).from_dict(result, name=self.name)

def isel(
self,
indexers: Mapping[Any, Any] | None = None,
drop: bool = False,
missing_dims: ErrorOptionsWithWarn = "raise",
**indexers_kwargs: Any,
) -> Self:
"""Positional indexing."""

def apply_indexers(dataset, node_indexers):
return dataset.isel(node_indexers, drop=drop)

indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
return self._selective_indexing(
apply_indexers, indexers, missing_dims=missing_dims
)

def sel(
self,
indexers: Mapping[Any, Any] | None = None,
method: str | None = None,
tolerance: int | float | Iterable[int | float] | None = None,
drop: bool = False,
**indexers_kwargs: Any,
) -> Self:
"""Label-based indexing."""

def apply_indexers(dataset, node_indexers):
# TODO: reimplement in terms of map_index_queries(), to avoid
# redundant index look-ups on child nodes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this related to my comment above or a separate issue?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a separate issue.

return dataset.sel(
node_indexers, method=method, tolerance=tolerance, drop=drop
)

indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")
return self._selective_indexing(apply_indexers, indexers)
83 changes: 71 additions & 12 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,6 @@ def test_ipython_key_completions(self, create_test_datatree):
var_keys = list(dt.variables.keys())
assert all(var_key in key_completions for var_key in var_keys)

@pytest.mark.xfail(reason="sel not implemented yet")
def test_operation_with_attrs_but_no_data(self):
# tests bug from xarray-datatree GH262
xs = xr.Dataset({"testvar": xr.DataArray(np.ones((2, 3)))})
Expand Down Expand Up @@ -1561,26 +1560,86 @@ def test_filter(self):
assert_identical(elders, expected)


class TestDSMethodInheritance:
@pytest.mark.xfail(reason="isel not implemented yet")
def test_dataset_method(self):
ds = xr.Dataset({"a": ("x", [1, 2, 3])})
dt = DataTree.from_dict(
class TestIndexing:

def test_isel_siblings(self):
tree = DataTree.from_dict(
{
"/": ds,
"/results": ds,
"/first": xr.Dataset({"a": ("x", [1, 2])}),
"/second": xr.Dataset({"b": ("x", [1, 2, 3])}),
}
)

expected = DataTree.from_dict(
{
"/": ds.isel(x=1),
"/results": ds.isel(x=1),
"/first": xr.Dataset({"a": 2}),
"/second": xr.Dataset({"b": 3}),
}
)
actual = tree.isel(x=-1)
assert_equal(actual, expected)

result = dt.isel(x=1)
assert_equal(result, expected)
expected = DataTree.from_dict(
{
"/first": xr.Dataset({"a": ("x", [1])}),
"/second": xr.Dataset({"b": ("x", [1])}),
}
)
actual = tree.isel(x=slice(1))
assert_equal(actual, expected)

actual = tree.isel(x=[0])
assert_equal(actual, expected)

actual = tree.isel(x=slice(None))
assert_equal(actual, tree)

def test_isel_inherited(self):
tree = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": [1, 2]}),
"/child": xr.Dataset({"foo": ("x", [3, 4])}),
}
)

expected = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": 2}),
"/child": xr.Dataset({"foo": 4}),
}
)
actual = tree.isel(x=-1)
assert_equal(actual, expected)
Comment on lines +1605 to +1612
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When this passes the x coordinate gets duplicated doesn't it? As the result is a non-indexed scalar. Does that mean this would have failed if you had used the definition of assert_identical from #9473 to check for duplication?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC this should also pass

actual = tree.isel(x=[0])
assert_identical(actual, expected)

because the result will have inherited instead of duplicated coordinates instead.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote special logic for removing duplicated scalar coordinates to handle just this!

There is no longer a need to implement a different version of inheritance for assert_identical than for assert_equal, because we no longer allow any cases in the DataTree data model where inheritance is ambiguous:

  • Deduplication of coordinates with an index removes them from child nodes.
  • Other coordinates are no longer inherited.


expected = DataTree.from_dict(
{
"/child": xr.Dataset({"foo": 4}),
}
)
actual = tree.isel(x=-1, drop=True)
assert_equal(actual, expected)

actual = tree.isel(x=slice(None))
assert_equal(actual, tree)

def test_sel(self):
tree = DataTree.from_dict(
{
"/first": xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"x": [1, 2, 3]}),
"/second": xr.Dataset({"b": ("x", [4, 5])}, coords={"x": [2, 3]}),
}
)
expected = DataTree.from_dict(
{
"/first": xr.Dataset({"a": 2}, coords={"x": 2}),
"/second": xr.Dataset({"b": 4}, coords={"x": 2}),
}
)
actual = tree.sel(x=2)
assert_equal(actual, expected)


class TestDSMethodInheritance:

@pytest.mark.xfail(reason="reduce methods not implemented yet")
def test_reduce_method(self):
Expand Down
Loading