Skip to content

Commit c057d13

Browse files
authored
Implement DataTree.isel and DataTree.sel (#9588)
* Implement DataTree.isel and DataTree.sel * add api docs * fix CI failures * add docstrings for DataTree.isel and DataTree.sel * Add comments * add another indexing test
1 parent 4c3c22b commit c057d13

File tree

3 files changed

+274
-21
lines changed

3 files changed

+274
-21
lines changed

doc/api.rst

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -761,16 +761,17 @@ Compare one ``DataTree`` object to another.
761761
DataTree.equals
762762
DataTree.identical
763763

764-
.. Indexing
765-
.. --------
764+
Indexing
765+
--------
766766

767-
.. Index into all nodes in the subtree simultaneously.
767+
Index into all nodes in the subtree simultaneously.
768768

769-
.. .. autosummary::
770-
.. :toctree: generated/
769+
.. autosummary::
770+
:toctree: generated/
771+
772+
DataTree.isel
773+
DataTree.sel
771774

772-
.. DataTree.isel
773-
.. DataTree.sel
774775
.. DataTree.drop_sel
775776
.. DataTree.drop_isel
776777
.. DataTree.head

xarray/core/datatree.py

Lines changed: 186 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,13 @@
3232
from xarray.core.merge import dataset_update_method
3333
from xarray.core.options import OPTIONS as XR_OPTS
3434
from xarray.core.treenode import NamedNode, NodePath
35+
from xarray.core.types import Self
3536
from xarray.core.utils import (
3637
Default,
3738
FilteredMapping,
3839
Frozen,
3940
_default,
41+
drop_dims_from_indexers,
4042
either_dict_or_kwargs,
4143
maybe_wrap_array,
4244
)
@@ -54,7 +56,12 @@
5456

5557
from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes
5658
from xarray.core.merge import CoercibleMapping, CoercibleValue
57-
from xarray.core.types import ErrorOptions, NetcdfWriteModes, ZarrWriteModes
59+
from xarray.core.types import (
60+
ErrorOptions,
61+
ErrorOptionsWithWarn,
62+
NetcdfWriteModes,
63+
ZarrWriteModes,
64+
)
5865

5966
# """
6067
# DEVELOPERS' NOTE
@@ -1081,7 +1088,7 @@ def from_dict(
10811088
d: Mapping[str, Dataset | DataTree | None],
10821089
/,
10831090
name: str | None = None,
1084-
) -> DataTree:
1091+
) -> Self:
10851092
"""
10861093
Create a datatree from a dictionary of data objects, organised by paths into the tree.
10871094
@@ -1601,3 +1608,180 @@ def to_zarr(
16011608
compute=compute,
16021609
**kwargs,
16031610
)
1611+
1612+
def _selective_indexing(
1613+
self,
1614+
func: Callable[[Dataset, Mapping[Any, Any]], Dataset],
1615+
indexers: Mapping[Any, Any],
1616+
missing_dims: ErrorOptionsWithWarn = "raise",
1617+
) -> Self:
1618+
"""Apply an indexing operation over the subtree, handling missing
1619+
dimensions and inherited coordinates gracefully by only applying
1620+
indexing at each node selectively.
1621+
"""
1622+
all_dims = set()
1623+
for node in self.subtree:
1624+
all_dims.update(node._node_dims)
1625+
indexers = drop_dims_from_indexers(indexers, all_dims, missing_dims)
1626+
1627+
result = {}
1628+
for node in self.subtree:
1629+
node_indexers = {k: v for k, v in indexers.items() if k in node.dims}
1630+
node_result = func(node.dataset, node_indexers)
1631+
# Indexing datasets corresponding to each node results in redundant
1632+
# coordinates when indexes from a parent node are inherited.
1633+
# Ideally, we would avoid creating such coordinates in the first
1634+
# place, but that would require implementing indexing operations at
1635+
# the Variable instead of the Dataset level.
1636+
for k in node_indexers:
1637+
if k not in node._node_coord_variables and k in node_result.coords:
1638+
# We remove all inherited coordinates. Coordinates
1639+
# corresponding to an index would be de-duplicated by
1640+
# _deduplicate_inherited_coordinates(), but indexing (e.g.,
1641+
# with a scalar) can also create scalar coordinates, which
1642+
# need to be explicitly removed.
1643+
del node_result.coords[k]
1644+
result[node.path] = node_result
1645+
return type(self).from_dict(result, name=self.name)
1646+
1647+
def isel(
1648+
self,
1649+
indexers: Mapping[Any, Any] | None = None,
1650+
drop: bool = False,
1651+
missing_dims: ErrorOptionsWithWarn = "raise",
1652+
**indexers_kwargs: Any,
1653+
) -> Self:
1654+
"""Returns a new data tree with each array indexed along the specified
1655+
dimension(s).
1656+
1657+
This method selects values from each array using its `__getitem__`
1658+
method, except this method does not require knowing the order of
1659+
each array's dimensions.
1660+
1661+
Parameters
1662+
----------
1663+
indexers : dict, optional
1664+
A dict with keys matching dimensions and values given
1665+
by integers, slice objects or arrays.
1666+
indexer can be a integer, slice, array-like or DataArray.
1667+
If DataArrays are passed as indexers, xarray-style indexing will be
1668+
carried out. See :ref:`indexing` for the details.
1669+
One of indexers or indexers_kwargs must be provided.
1670+
drop : bool, default: False
1671+
If ``drop=True``, drop coordinates variables indexed by integers
1672+
instead of making them scalar.
1673+
missing_dims : {"raise", "warn", "ignore"}, default: "raise"
1674+
What to do if dimensions that should be selected from are not present in the
1675+
Dataset:
1676+
- "raise": raise an exception
1677+
- "warn": raise a warning, and ignore the missing dimensions
1678+
- "ignore": ignore the missing dimensions
1679+
1680+
**indexers_kwargs : {dim: indexer, ...}, optional
1681+
The keyword arguments form of ``indexers``.
1682+
One of indexers or indexers_kwargs must be provided.
1683+
1684+
Returns
1685+
-------
1686+
obj : DataTree
1687+
A new DataTree with the same contents as this data tree, except each
1688+
array and dimension is indexed by the appropriate indexers.
1689+
If indexer DataArrays have coordinates that do not conflict with
1690+
this object, then these coordinates will be attached.
1691+
In general, each array's data will be a view of the array's data
1692+
in this dataset, unless vectorized indexing was triggered by using
1693+
an array indexer, in which case the data will be a copy.
1694+
1695+
See Also
1696+
--------
1697+
DataTree.sel
1698+
Dataset.isel
1699+
"""
1700+
1701+
def apply_indexers(dataset, node_indexers):
1702+
return dataset.isel(node_indexers, drop=drop)
1703+
1704+
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
1705+
return self._selective_indexing(
1706+
apply_indexers, indexers, missing_dims=missing_dims
1707+
)
1708+
1709+
def sel(
1710+
self,
1711+
indexers: Mapping[Any, Any] | None = None,
1712+
method: str | None = None,
1713+
tolerance: int | float | Iterable[int | float] | None = None,
1714+
drop: bool = False,
1715+
**indexers_kwargs: Any,
1716+
) -> Self:
1717+
"""Returns a new data tree with each array indexed by tick labels
1718+
along the specified dimension(s).
1719+
1720+
In contrast to `DataTree.isel`, indexers for this method should use
1721+
labels instead of integers.
1722+
1723+
Under the hood, this method is powered by using pandas's powerful Index
1724+
objects. This makes label based indexing essentially just as fast as
1725+
using integer indexing.
1726+
1727+
It also means this method uses pandas's (well documented) logic for
1728+
indexing. This means you can use string shortcuts for datetime indexes
1729+
(e.g., '2000-01' to select all values in January 2000). It also means
1730+
that slices are treated as inclusive of both the start and stop values,
1731+
unlike normal Python indexing.
1732+
1733+
Parameters
1734+
----------
1735+
indexers : dict, optional
1736+
A dict with keys matching dimensions and values given
1737+
by scalars, slices or arrays of tick labels. For dimensions with
1738+
multi-index, the indexer may also be a dict-like object with keys
1739+
matching index level names.
1740+
If DataArrays are passed as indexers, xarray-style indexing will be
1741+
carried out. See :ref:`indexing` for the details.
1742+
One of indexers or indexers_kwargs must be provided.
1743+
method : {None, "nearest", "pad", "ffill", "backfill", "bfill"}, optional
1744+
Method to use for inexact matches:
1745+
1746+
* None (default): only exact matches
1747+
* pad / ffill: propagate last valid index value forward
1748+
* backfill / bfill: propagate next valid index value backward
1749+
* nearest: use nearest valid index value
1750+
tolerance : optional
1751+
Maximum distance between original and new labels for inexact
1752+
matches. The values of the index at the matching locations must
1753+
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
1754+
drop : bool, optional
1755+
If ``drop=True``, drop coordinates variables in `indexers` instead
1756+
of making them scalar.
1757+
**indexers_kwargs : {dim: indexer, ...}, optional
1758+
The keyword arguments form of ``indexers``.
1759+
One of indexers or indexers_kwargs must be provided.
1760+
1761+
Returns
1762+
-------
1763+
obj : DataTree
1764+
A new DataTree with the same contents as this data tree, except each
1765+
variable and dimension is indexed by the appropriate indexers.
1766+
If indexer DataArrays have coordinates that do not conflict with
1767+
this object, then these coordinates will be attached.
1768+
In general, each array's data will be a view of the array's data
1769+
in this dataset, unless vectorized indexing was triggered by using
1770+
an array indexer, in which case the data will be a copy.
1771+
1772+
See Also
1773+
--------
1774+
DataTree.isel
1775+
Dataset.sel
1776+
"""
1777+
1778+
def apply_indexers(dataset, node_indexers):
1779+
# TODO: reimplement in terms of map_index_queries(), to avoid
1780+
# redundant look-ups of integer positions from labels (via indexes)
1781+
# on child nodes.
1782+
return dataset.sel(
1783+
node_indexers, method=method, tolerance=tolerance, drop=drop
1784+
)
1785+
1786+
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")
1787+
return self._selective_indexing(apply_indexers, indexers)

xarray/tests/test_datatree.py

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,6 @@ def test_ipython_key_completions(self, create_test_datatree):
971971
var_keys = list(dt.variables.keys())
972972
assert all(var_key in key_completions for var_key in var_keys)
973973

974-
@pytest.mark.xfail(reason="sel not implemented yet")
975974
def test_operation_with_attrs_but_no_data(self):
976975
# tests bug from xarray-datatree GH262
977976
xs = xr.Dataset({"testvar": xr.DataArray(np.ones((2, 3)))})
@@ -1561,26 +1560,95 @@ def test_filter(self):
15611560
assert_identical(elders, expected)
15621561

15631562

1564-
class TestDSMethodInheritance:
1565-
@pytest.mark.xfail(reason="isel not implemented yet")
1566-
def test_dataset_method(self):
1567-
ds = xr.Dataset({"a": ("x", [1, 2, 3])})
1568-
dt = DataTree.from_dict(
1563+
class TestIndexing:
1564+
1565+
def test_isel_siblings(self):
1566+
tree = DataTree.from_dict(
15691567
{
1570-
"/": ds,
1571-
"/results": ds,
1568+
"/first": xr.Dataset({"a": ("x", [1, 2])}),
1569+
"/second": xr.Dataset({"b": ("x", [1, 2, 3])}),
15721570
}
15731571
)
15741572

15751573
expected = DataTree.from_dict(
15761574
{
1577-
"/": ds.isel(x=1),
1578-
"/results": ds.isel(x=1),
1575+
"/first": xr.Dataset({"a": 2}),
1576+
"/second": xr.Dataset({"b": 3}),
15791577
}
15801578
)
1579+
actual = tree.isel(x=-1)
1580+
assert_equal(actual, expected)
15811581

1582-
result = dt.isel(x=1)
1583-
assert_equal(result, expected)
1582+
expected = DataTree.from_dict(
1583+
{
1584+
"/first": xr.Dataset({"a": ("x", [1])}),
1585+
"/second": xr.Dataset({"b": ("x", [1])}),
1586+
}
1587+
)
1588+
actual = tree.isel(x=slice(1))
1589+
assert_equal(actual, expected)
1590+
1591+
actual = tree.isel(x=[0])
1592+
assert_equal(actual, expected)
1593+
1594+
actual = tree.isel(x=slice(None))
1595+
assert_equal(actual, tree)
1596+
1597+
def test_isel_inherited(self):
1598+
tree = DataTree.from_dict(
1599+
{
1600+
"/": xr.Dataset(coords={"x": [1, 2]}),
1601+
"/child": xr.Dataset({"foo": ("x", [3, 4])}),
1602+
}
1603+
)
1604+
1605+
expected = DataTree.from_dict(
1606+
{
1607+
"/": xr.Dataset(coords={"x": 2}),
1608+
"/child": xr.Dataset({"foo": 4}),
1609+
}
1610+
)
1611+
actual = tree.isel(x=-1)
1612+
assert_equal(actual, expected)
1613+
1614+
expected = DataTree.from_dict(
1615+
{
1616+
"/child": xr.Dataset({"foo": 4}),
1617+
}
1618+
)
1619+
actual = tree.isel(x=-1, drop=True)
1620+
assert_equal(actual, expected)
1621+
1622+
expected = DataTree.from_dict(
1623+
{
1624+
"/": xr.Dataset(coords={"x": [1]}),
1625+
"/child": xr.Dataset({"foo": ("x", [3])}),
1626+
}
1627+
)
1628+
actual = tree.isel(x=[0])
1629+
assert_equal(actual, expected)
1630+
1631+
actual = tree.isel(x=slice(None))
1632+
assert_equal(actual, tree)
1633+
1634+
def test_sel(self):
1635+
tree = DataTree.from_dict(
1636+
{
1637+
"/first": xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"x": [1, 2, 3]}),
1638+
"/second": xr.Dataset({"b": ("x", [4, 5])}, coords={"x": [2, 3]}),
1639+
}
1640+
)
1641+
expected = DataTree.from_dict(
1642+
{
1643+
"/first": xr.Dataset({"a": 2}, coords={"x": 2}),
1644+
"/second": xr.Dataset({"b": 4}, coords={"x": 2}),
1645+
}
1646+
)
1647+
actual = tree.sel(x=2)
1648+
assert_equal(actual, expected)
1649+
1650+
1651+
class TestDSMethodInheritance:
15841652

15851653
@pytest.mark.xfail(reason="reduce methods not implemented yet")
15861654
def test_reduce_method(self):

0 commit comments

Comments
 (0)