-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Implement DataTree.isel and DataTree.sel #9588
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
100712d
e312c10
9a4c895
c7fe8fe
93b7813
f3eecf4
06c5844
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,11 +32,13 @@ | |
| from xarray.core.merge import dataset_update_method | ||
| from xarray.core.options import OPTIONS as XR_OPTS | ||
| from xarray.core.treenode import NamedNode, NodePath | ||
| from xarray.core.types import Self | ||
| from xarray.core.utils import ( | ||
| Default, | ||
| FilteredMapping, | ||
| Frozen, | ||
| _default, | ||
| drop_dims_from_indexers, | ||
| either_dict_or_kwargs, | ||
| maybe_wrap_array, | ||
| ) | ||
|
|
@@ -54,7 +56,12 @@ | |
|
|
||
| from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes | ||
| from xarray.core.merge import CoercibleMapping, CoercibleValue | ||
| from xarray.core.types import ErrorOptions, NetcdfWriteModes, ZarrWriteModes | ||
| from xarray.core.types import ( | ||
| ErrorOptions, | ||
| ErrorOptionsWithWarn, | ||
| NetcdfWriteModes, | ||
| ZarrWriteModes, | ||
| ) | ||
|
|
||
| # """ | ||
| # DEVELOPERS' NOTE | ||
|
|
@@ -1087,7 +1094,7 @@ def from_dict( | |
| d: Mapping[str, Dataset | DataTree | None], | ||
| /, | ||
| name: str | None = None, | ||
| ) -> DataTree: | ||
| ) -> Self: | ||
| """ | ||
| Create a datatree from a dictionary of data objects, organised by paths into the tree. | ||
|
|
||
|
|
@@ -1607,3 +1614,61 @@ def to_zarr( | |
| compute=compute, | ||
| **kwargs, | ||
| ) | ||
|
|
||
def _selective_indexing(
    self,
    func: Callable[[Dataset, Mapping[Any, Any]], Dataset],
    indexers: Mapping[Any, Any],
    missing_dims: ErrorOptionsWithWarn = "raise",
) -> Self:
    """Index every node in the subtree using only the indexers that apply to it.

    Parameters
    ----------
    func
        Called as ``func(node.dataset, node_indexers)`` for each node; the
        returned dataset becomes that node's entry in the result.
    indexers
        Mapping from dimension name to indexer.
    missing_dims
        Passed to ``drop_dims_from_indexers`` together with the dimension
        names collected from every node of the subtree.
        NOTE(review): presumably controls what happens to indexers naming a
        dimension found nowhere in the tree -- confirm against the helper.

    Returns
    -------
    A tree of the same type as ``self``, rebuilt with ``from_dict`` from the
    per-node results and keeping ``self.name``.
    """
    # Validate the indexers against the dimensions of the whole subtree, not
    # against any single node.
    subtree_dims = {dim for node in self.subtree for dim in node._node_dims}
    valid_indexers = drop_dims_from_indexers(indexers, subtree_dims, missing_dims)

    indexed = {}
    for node in self.subtree:
        # Hand each node only the indexers whose dimension it actually has.
        applicable = {
            dim: indexer
            for dim, indexer in valid_indexers.items()
            if dim in node.dims
        }
        node_ds = func(node.dataset, applicable)

        # An indexed dimension that still shows up as a coordinate in the
        # result, without being one of this node's own coordinate variables,
        # is removed. Per the review discussion on this change, this keeps an
        # inherited coordinate from being duplicated on the node. It is a
        # small source of inefficiency: the coordinate is indexed only to be
        # dropped again.
        for dim in applicable:
            if dim in node._node_coord_variables:
                continue
            if dim in node_ds.coords:
                del node_ds.coords[dim]

        indexed[node.path] = node_ds

    return type(self).from_dict(indexed, name=self.name)
|
|
||
def isel(
    self,
    indexers: Mapping[Any, Any] | None = None,
    drop: bool = False,
    missing_dims: ErrorOptionsWithWarn = "raise",
    **indexers_kwargs: Any,
) -> Self:
    """Positional indexing.

    Parameters
    ----------
    indexers
        Mapping from dimension name to positional indexer. May be given as
        keyword arguments instead; the two forms are reconciled by
        ``either_dict_or_kwargs``.
    drop
        Forwarded as ``drop=`` to each node's ``Dataset.isel`` call.
    missing_dims
        Forwarded to ``_selective_indexing``.
    **indexers_kwargs
        Keyword form of ``indexers``.

    Returns
    -------
    The tree produced by ``_selective_indexing``.
    """
    positional = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")

    def index_node(dataset, node_indexers):
        # Each node receives only the indexers selected for it.
        return dataset.isel(node_indexers, drop=drop)

    return self._selective_indexing(
        index_node, positional, missing_dims=missing_dims
    )
|
|
||
def sel(
    self,
    indexers: Mapping[Any, Any] | None = None,
    method: str | None = None,
    tolerance: int | float | Iterable[int | float] | None = None,
    drop: bool = False,
    **indexers_kwargs: Any,
) -> Self:
    """Label-based indexing.

    Parameters
    ----------
    indexers
        Mapping from dimension name to label indexer. May be given as keyword
        arguments instead; the two forms are reconciled by
        ``either_dict_or_kwargs``.
    method, tolerance, drop
        Forwarded unchanged to each node's ``Dataset.sel`` call.
    **indexers_kwargs
        Keyword form of ``indexers``.

    Returns
    -------
    The tree produced by ``_selective_indexing``, called with its default
    ``missing_dims``.
    """
    labels = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")

    def select_node(dataset, node_indexers):
        # TODO: reimplement in terms of map_index_queries(), to avoid
        # redundant index look-ups on child nodes
        return dataset.sel(
            node_indexers, method=method, tolerance=tolerance, drop=drop
        )

    return self._selective_indexing(select_node, labels)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -971,7 +971,6 @@ def test_ipython_key_completions(self, create_test_datatree): | |
| var_keys = list(dt.variables.keys()) | ||
| assert all(var_key in key_completions for var_key in var_keys) | ||
|
|
||
| @pytest.mark.xfail(reason="sel not implemented yet") | ||
| def test_operation_with_attrs_but_no_data(self): | ||
| # tests bug from xarray-datatree GH262 | ||
| xs = xr.Dataset({"testvar": xr.DataArray(np.ones((2, 3)))}) | ||
|
|
@@ -1561,26 +1560,86 @@ def test_filter(self): | |
| assert_identical(elders, expected) | ||
|
|
||
|
|
||
| class TestDSMethodInheritance: | ||
| @pytest.mark.xfail(reason="isel not implemented yet") | ||
| def test_dataset_method(self): | ||
| ds = xr.Dataset({"a": ("x", [1, 2, 3])}) | ||
| dt = DataTree.from_dict( | ||
| class TestIndexing: | ||
|
|
||
| def test_isel_siblings(self): | ||
| tree = DataTree.from_dict( | ||
| { | ||
| "/": ds, | ||
| "/results": ds, | ||
| "/first": xr.Dataset({"a": ("x", [1, 2])}), | ||
| "/second": xr.Dataset({"b": ("x", [1, 2, 3])}), | ||
| } | ||
| ) | ||
|
|
||
| expected = DataTree.from_dict( | ||
| { | ||
| "/": ds.isel(x=1), | ||
| "/results": ds.isel(x=1), | ||
| "/first": xr.Dataset({"a": 2}), | ||
| "/second": xr.Dataset({"b": 3}), | ||
| } | ||
| ) | ||
| actual = tree.isel(x=-1) | ||
| assert_equal(actual, expected) | ||
|
|
||
| result = dt.isel(x=1) | ||
| assert_equal(result, expected) | ||
| expected = DataTree.from_dict( | ||
| { | ||
| "/first": xr.Dataset({"a": ("x", [1])}), | ||
| "/second": xr.Dataset({"b": ("x", [1])}), | ||
| } | ||
| ) | ||
| actual = tree.isel(x=slice(1)) | ||
| assert_equal(actual, expected) | ||
|
|
||
| actual = tree.isel(x=[0]) | ||
| assert_equal(actual, expected) | ||
|
|
||
| actual = tree.isel(x=slice(None)) | ||
| assert_equal(actual, tree) | ||
|
|
||
| def test_isel_inherited(self): | ||
| tree = DataTree.from_dict( | ||
| { | ||
| "/": xr.Dataset(coords={"x": [1, 2]}), | ||
| "/child": xr.Dataset({"foo": ("x", [3, 4])}), | ||
| } | ||
| ) | ||
|
|
||
| expected = DataTree.from_dict( | ||
| { | ||
| "/": xr.Dataset(coords={"x": 2}), | ||
| "/child": xr.Dataset({"foo": 4}), | ||
| } | ||
| ) | ||
| actual = tree.isel(x=-1) | ||
| assert_equal(actual, expected) | ||
|
Comment on lines
+1605
to
+1612
Review comment: When this passes the … (the rest of this comment was lost in extraction)

Review comment: IIUC this should also pass:

    actual = tree.isel(x=[0])
    assert_identical(actual, expected)

because the result will have inherited instead of duplicated coordinates.

Reply: I wrote special logic for removing duplicated scalar coordinates to handle just this! There is no longer a need to implement a different version of inheritance for … (the rest of this reply was lost in extraction)
|
||
|
|
||
| expected = DataTree.from_dict( | ||
| { | ||
| "/child": xr.Dataset({"foo": 4}), | ||
| } | ||
| ) | ||
| actual = tree.isel(x=-1, drop=True) | ||
| assert_equal(actual, expected) | ||
|
|
||
| actual = tree.isel(x=slice(None)) | ||
| assert_equal(actual, tree) | ||
|
|
||
def test_sel(self):
    """Selecting ``x=2`` by label on two sibling nodes with different ``x``."""
    # The siblings carry different "x" labels; the label 2 occurs in both.
    first = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"x": [1, 2, 3]})
    second = xr.Dataset({"b": ("x", [4, 5])}, coords={"x": [2, 3]})
    tree = DataTree.from_dict({"/first": first, "/second": second})

    # Each node is expected to reduce to the value stored at its own x == 2.
    expected = DataTree.from_dict(
        {
            "/first": xr.Dataset({"a": 2}, coords={"x": 2}),
            "/second": xr.Dataset({"b": 4}, coords={"x": 2}),
        }
    )

    actual = tree.sel(x=2)
    assert_equal(actual, expected)
|
|
||
|
|
||
| class TestDSMethodInheritance: | ||
|
|
||
| @pytest.mark.xfail(reason="reduce methods not implemented yet") | ||
| def test_reduce_method(self): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.