Skip to content

Commit fecaa85

Browse files
authored
Remove duplicate coordinates with indexes on sub-trees (#9531)
1 parent 22e93e7 commit fecaa85

File tree

3 files changed

+71
-3
lines changed

3 files changed

+71
-3
lines changed

xarray/core/datatree.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,34 @@ def check_alignment(
157157
check_alignment(child_path, child_ds, base_ds, child.children)
158158

159159

160+
def _deduplicate_inherited_coordinates(child: DataTree, parent: DataTree) -> None:
161+
# This method removes repeated indexes (and corresponding coordinates)
162+
# that are repeated between a DataTree and its parents.
163+
#
164+
# TODO(shoyer): Decide how to handle repeated coordinates *without* an
165+
# index. Should these be allowed, in which case we probably want to
166+
# exclude them from inheritance, or should they be automatically
167+
# dropped?
168+
# https://github.com/pydata/xarray/issues/9475#issuecomment-2357004264
169+
removed_something = False
170+
for name in parent._indexes:
171+
if name in child._node_indexes:
172+
# Indexes on a Dataset always have a corresponding coordinate.
173+
# We already verified that these coordinates match in the
174+
# check_alignment() call from _pre_attach().
175+
del child._node_indexes[name]
176+
del child._node_coord_variables[name]
177+
removed_something = True
178+
179+
if removed_something:
180+
child._node_dims = calculate_dimensions(
181+
child._data_variables | child._node_coord_variables
182+
)
183+
184+
for grandchild in child._children.values():
185+
_deduplicate_inherited_coordinates(grandchild, child)
186+
187+
160188
def _check_for_slashes_in_names(variables: Iterable[Hashable]) -> None:
161189
offending_variable_names = [
162190
name for name in variables if isinstance(name, str) and "/" in name
@@ -375,7 +403,7 @@ def map( # type: ignore[override]
375403

376404

377405
class DataTree(
378-
NamedNode,
406+
NamedNode["DataTree"],
379407
MappedDatasetMethodsMixin,
380408
MappedDataWithCoords,
381409
DataTreeArithmeticMixin,
@@ -486,6 +514,7 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
486514
node_ds = self.to_dataset(inherited=False)
487515
parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True)
488516
check_alignment(path, node_ds, parent_ds, self.children)
517+
_deduplicate_inherited_coordinates(self, parent)
489518

490519
@property
491520
def _coord_variables(self) -> ChainMap[Hashable, Variable]:
@@ -1353,7 +1382,7 @@ def map_over_subtree(
13531382
func: Callable,
13541383
*args: Iterable[Any],
13551384
**kwargs: Any,
1356-
) -> DataTree | tuple[DataTree]:
1385+
) -> DataTree | tuple[DataTree, ...]:
13571386
"""
13581387
Apply a function to every dataset in this subtree, returning a new tree which stores the results.
13591388

xarray/tests/test_datatree.py

+28
Original file line numberDiff line numberDiff line change
@@ -1295,6 +1295,22 @@ def test_inherited_coords_override(self):
12951295
xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords))
12961296
xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords))
12971297

1298+
def test_inherited_coords_with_index_are_deduplicated(self):
1299+
dt = DataTree.from_dict(
1300+
{
1301+
"/": xr.Dataset(coords={"x": [1, 2]}),
1302+
"/b": xr.Dataset(coords={"x": [1, 2]}),
1303+
}
1304+
)
1305+
child_dataset = dt.children["b"].to_dataset(inherited=False)
1306+
expected = xr.Dataset()
1307+
assert_identical(child_dataset, expected)
1308+
1309+
dt["/c"] = xr.Dataset({"foo": ("x", [4, 5])}, coords={"x": [1, 2]})
1310+
child_dataset = dt.children["c"].to_dataset(inherited=False)
1311+
expected = xr.Dataset({"foo": ("x", [4, 5])})
1312+
assert_identical(child_dataset, expected)
1313+
12981314
def test_inconsistent_dims(self):
12991315
expected_msg = _exact_match(
13001316
"""
@@ -1639,6 +1655,18 @@ def test_binary_op_on_datatree(self):
16391655
result = dt * dt # type: ignore[operator]
16401656
assert_equal(result, expected)
16411657

1658+
def test_arithmetic_inherited_coords(self):
1659+
tree = DataTree(xr.Dataset(coords={"x": [1, 2, 3]}))
1660+
tree["/foo"] = DataTree(xr.Dataset({"bar": ("x", [4, 5, 6])}))
1661+
actual: DataTree = 2 * tree # type: ignore[assignment,operator]
1662+
1663+
actual_dataset = actual.children["foo"].to_dataset(inherited=False)
1664+
assert "x" not in actual_dataset.coords
1665+
1666+
expected = tree.copy()
1667+
expected["/foo/bar"].data = np.array([8, 10, 12])
1668+
assert_identical(actual, expected)
1669+
16421670

16431671
class TestUFuncs:
16441672
def test_tree(self, create_test_datatree):

xarray/tests/test_datatree_mapping.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
check_isomorphic,
88
map_over_subtree,
99
)
10-
from xarray.testing import assert_equal
10+
from xarray.testing import assert_equal, assert_identical
1111

1212
empty = xr.Dataset()
1313

@@ -306,6 +306,17 @@ def fail_on_specific_node(ds):
306306
):
307307
dt.map_over_subtree(fail_on_specific_node)
308308

309+
def test_inherited_coordinates_with_index(self):
310+
root = xr.Dataset(coords={"x": [1, 2]})
311+
child = xr.Dataset({"foo": ("x", [0, 1])}) # no coordinates
312+
tree = xr.DataTree.from_dict({"/": root, "/child": child})
313+
actual = tree.map_over_subtree(lambda ds: ds) # identity
314+
assert isinstance(actual, xr.DataTree)
315+
assert_identical(tree, actual)
316+
317+
actual_child = actual.children["child"].to_dataset(inherited=False)
318+
assert_identical(actual_child, child)
319+
309320

310321
class TestMutableOperations:
311322
def test_construct_using_type(self):

0 commit comments

Comments
 (0)