diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 59984c5afa3..3a3fb19daa4 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -423,7 +423,6 @@ class DataTree( def __init__( self, data: Dataset | DataArray | None = None, - parent: DataTree | None = None, children: Mapping[str, DataTree] | None = None, name: str | None = None, ): @@ -439,8 +438,6 @@ def __init__( data : Dataset, DataArray, or None, optional Data to store under the .ds attribute of this node. DataArrays will be promoted to Datasets. Default is None. - parent : DataTree, optional - Parent node to this node. Default is None. children : Mapping[str, DataTree], optional Any child nodes of this node. Default is None. name : str, optional @@ -459,8 +456,9 @@ def __init__( super().__init__(name=name) self._set_node_data(_coerce_to_dataset(data)) - self.parent = parent - self.children = children + + # shallow copy to avoid modifying arguments in-place (see GH issue #9196) + self.children = {name: child.copy() for name, child in children.items()} def _set_node_data(self, ds: Dataset): data_vars, coord_vars = _collect_data_and_coord_variables(ds) @@ -497,17 +495,6 @@ def _dims(self) -> ChainMap[Hashable, int]: def _indexes(self) -> ChainMap[Hashable, Index]: return ChainMap(self._node_indexes, *(p._node_indexes for p in self.parents)) - @property - def parent(self: DataTree) -> DataTree | None: - """Parent of this node.""" - return self._parent - - @parent.setter - def parent(self: DataTree, new_parent: DataTree) -> None: - if new_parent and self.name is None: - raise ValueError("Cannot set an unnamed node as a child of another node") - self._set_parent(new_parent, self.name) - def _to_dataset_view(self, rebuild_dims: bool) -> DatasetView: variables = dict(self._data_variables) variables |= self._coord_variables @@ -896,7 +883,7 @@ def _set(self, key: str, val: DataTree | CoercibleValue) -> None: # create and assign a shallow copy here so as not to alter original name of node in grafted tree new_node = val.copy(deep=False) new_node.name = key - new_node.parent = self + new_node._set_parent(new_parent=self, child_name=key) else: if not isinstance(val, DataArray | Variable): # accommodate other types that can be coerced into Variables @@ -1097,7 +1084,7 @@ def from_dict( obj = root_data.copy() obj.orphan() else: - obj = cls(name=name, data=root_data, parent=None, children=None) + obj = cls(name=name, data=root_data, children=None) def depth(item) -> int: pathstr, _ = item diff --git a/xarray/core/datatree_render.py b/xarray/core/datatree_render.py index 98cb4f91495..47e0358588d 100644 --- a/xarray/core/datatree_render.py +++ b/xarray/core/datatree_render.py @@ -51,11 +51,16 @@ def __init__(self): >>> from xarray.core.datatree import DataTree >>> from xarray.core.datatree_render import RenderDataTree - >>> root = DataTree(name="root") - >>> s0 = DataTree(name="sub0", parent=root) - >>> s0b = DataTree(name="sub0B", parent=s0) - >>> s0a = DataTree(name="sub0A", parent=s0) - >>> s1 = DataTree(name="sub1", parent=root) + >>> root = DataTree.from_dict( + ... { + ... "/": None, + ... "/sub0": None, + ... "/sub0/sub0B": None, + ... "/sub0/sub0A": None, + ... "/sub1": None, + ... }, + ... name="root", + ... ) >>> print(RenderDataTree(root)) Group: / @@ -98,11 +103,16 @@ def __init__( >>> from xarray import Dataset >>> from xarray.core.datatree import DataTree >>> from xarray.core.datatree_render import RenderDataTree - >>> root = DataTree(name="root", data=Dataset({"a": 0, "b": 1})) - >>> s0 = DataTree(name="sub0", parent=root, data=Dataset({"c": 2, "d": 3})) - >>> s0b = DataTree(name="sub0B", parent=s0, data=Dataset({"e": 4})) - >>> s0a = DataTree(name="sub0A", parent=s0, data=Dataset({"f": 5, "g": 6})) - >>> s1 = DataTree(name="sub1", parent=root, data=Dataset({"h": 7})) + >>> root = DataTree.from_dict( + ... { + ... "/": Dataset({"a": 0, "b": 1}), + ... "/sub0": Dataset({"c": 2, "d": 3}), + ... "/sub0/sub0B": Dataset({"e": 4}), + ... "/sub0/sub0A": Dataset({"f": 5, "g": 6}), + ... "/sub1": Dataset({"h": 7}), + ... }, + ... name="root", + ... ) # Simple one line: @@ -208,17 +218,16 @@ def by_attr(self, attrname: str = "name") -> str: >>> from xarray import Dataset >>> from xarray.core.datatree import DataTree >>> from xarray.core.datatree_render import RenderDataTree - >>> root = DataTree(name="root") - >>> s0 = DataTree(name="sub0", parent=root) - >>> s0b = DataTree( - ... name="sub0B", parent=s0, data=Dataset({"foo": 4, "bar": 109}) + >>> root = DataTree.from_dict( + ... { + ... "/sub0/sub0B": Dataset({"foo": 4, "bar": 109}), + ... "/sub0/sub0A": None, + ... "/sub1/sub1A": None, + ... "/sub1/sub1B": Dataset({"bar": 8}), + ... "/sub1/sub1C/sub1Ca": None, + ... }, + ... name="root", ... ) - >>> s0a = DataTree(name="sub0A", parent=s0) - >>> s1 = DataTree(name="sub1", parent=root) - >>> s1a = DataTree(name="sub1A", parent=s1) - >>> s1b = DataTree(name="sub1B", parent=s1, data=Dataset({"bar": 8})) - >>> s1c = DataTree(name="sub1C", parent=s1) - >>> s1ca = DataTree(name="sub1Ca", parent=s1c) >>> print(RenderDataTree(root).by_attr("name")) root ├── sub0 diff --git a/xarray/core/iterators.py b/xarray/core/iterators.py index 1ba3c6f1675..eeaeb35aa9c 100644 --- a/xarray/core/iterators.py +++ b/xarray/core/iterators.py @@ -28,15 +28,9 @@ class LevelOrderIter(Iterator): -------- >>> from xarray.core.datatree import DataTree >>> from xarray.core.iterators import LevelOrderIter - >>> f = DataTree(name="f") - >>> b = DataTree(name="b", parent=f) - >>> a = DataTree(name="a", parent=b) - >>> d = DataTree(name="d", parent=b) - >>> c = DataTree(name="c", parent=d) - >>> e = DataTree(name="e", parent=d) - >>> g = DataTree(name="g", parent=f) - >>> i = DataTree(name="i", parent=g) - >>> h = DataTree(name="h", parent=i) + >>> f = DataTree.from_dict( + ... {"/b/a": None, "/b/d/c": None, "/b/d/e": None, "/g/h/i": None}, name="f" + ... ) >>> print(f) Group: / @@ -46,19 +40,19 @@ class LevelOrderIter(Iterator): │ ├── Group: /b/d/c │ └── Group: /b/d/e └── Group: /g - └── Group: /g/i - └── Group: /g/i/h + └── Group: /g/h + └── Group: /g/h/i >>> [node.name for node in LevelOrderIter(f)] - ['f', 'b', 'g', 'a', 'd', 'i', 'c', 'e', 'h'] + ['f', 'b', 'g', 'a', 'd', 'h', 'c', 'e', 'i'] >>> [node.name for node in LevelOrderIter(f, maxlevel=3)] - ['f', 'b', 'g', 'a', 'd', 'i'] + ['f', 'b', 'g', 'a', 'd', 'h'] >>> [ ... node.name ... for node in LevelOrderIter(f, filter_=lambda n: n.name not in ("e", "g")) ... ] - ['f', 'b', 'a', 'd', 'i', 'c', 'h'] + ['f', 'b', 'a', 'd', 'h', 'c', 'i'] >>> [node.name for node in LevelOrderIter(f, stop=lambda n: n.name == "d")] - ['f', 'b', 'g', 'a', 'i', 'h'] + ['f', 'b', 'g', 'a', 'h', 'i'] """ def __init__( diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 77e7ed23a51..9dfd346508a 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -86,6 +86,12 @@ def parent(self) -> Tree | None: """Parent of this node.""" return self._parent + @parent.setter + def parent(self: Tree, new_parent: Tree) -> None: + raise AttributeError( + "Cannot set parent attribute directly, you must modify the children of the other node instead using dict-like syntax" + ) + def _set_parent( self, new_parent: Tree | None, child_name: str | None = None ) -> None: diff --git a/xarray/tests/conftest.py b/xarray/tests/conftest.py index 85f968b19a6..065e49a372a 100644 --- a/xarray/tests/conftest.py +++ b/xarray/tests/conftest.py @@ -196,14 +196,17 @@ def _create_test_datatree(modify=lambda ds: ds): set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})) root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})) - # Avoid using __init__ so we can independently test it - root: DataTree = DataTree(data=root_data) - set1: DataTree = DataTree(name="set1", parent=root, data=set1_data) - DataTree(name="set1", parent=set1) - DataTree(name="set2", parent=set1) - set2: DataTree = DataTree(name="set2", parent=root, data=set2_data) - DataTree(name="set1", parent=set2) - DataTree(name="set3", parent=root) + root = DataTree.from_dict( + { + "/": root_data, + "/set1": set1_data, + "/set1/set1": None, + "/set1/set2": None, + "/set2": set2_data, + "/set2/set1": None, + "/set3": None, + } + ) return root diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 9a15376a1f8..1a840dbb9e4 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -35,72 +35,88 @@ def test_bad_names(self): class TestFamilyTree: - def test_setparent_unnamed_child_node_fails(self): - john: DataTree = DataTree(name="john") - with pytest.raises(ValueError, match="unnamed"): - DataTree(parent=john) + def test_dont_modify_children_inplace(self): + # GH issue 9196 + child: DataTree = DataTree() + DataTree(children={"child": child}) + assert child.parent is None def test_create_two_children(self): root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) set1_data = xr.Dataset({"a": 0, "b": 1}) - - root: DataTree = DataTree(data=root_data) - set1: DataTree = DataTree(name="set1", parent=root, data=set1_data) - DataTree(name="set1", parent=root) - DataTree(name="set2", parent=set1) + root = DataTree.from_dict( + {"/": root_data, "/set1": set1_data, "/set1/set2": None} + ) + assert root["/set1"].name == "set1" + assert root["/set1/set2"].name == "set2" def test_create_full_tree(self, simple_datatree): - root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) - set1_data = xr.Dataset({"a": 0, "b": 1}) - set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}) + d = simple_datatree.to_dict() + d_keys = list(d.keys()) - root: DataTree = DataTree(data=root_data) - set1: DataTree = DataTree(name="set1", parent=root, data=set1_data) - DataTree(name="set1", parent=set1) - DataTree(name="set2", parent=set1) - set2: DataTree = DataTree(name="set2", parent=root, data=set2_data) - DataTree(name="set1", parent=set2) - DataTree(name="set3", parent=root) + expected_keys = [ + "/", + "/set1", + "/set2", + "/set3", + "/set1/set1", + "/set1/set2", + "/set2/set1", + ] - expected = simple_datatree - assert root.identical(expected) + assert d_keys == expected_keys class TestNames: def test_child_gets_named_on_attach(self): sue: DataTree = DataTree() mary: DataTree = DataTree(children={"Sue": sue}) # noqa - assert sue.name == "Sue" + assert mary.children["Sue"].name == "Sue" class TestPaths: def test_path_property(self): - sue: DataTree = DataTree() - mary: DataTree = DataTree(children={"Sue": sue}) - john: DataTree = DataTree(children={"Mary": mary}) - assert sue.path == "/Mary/Sue" + john = DataTree.from_dict( + { + "/Mary/Sue": DataTree(), + } + ) + assert john["/Mary/Sue"].path == "/Mary/Sue" assert john.path == "/" def test_path_roundtrip(self): - sue: DataTree = DataTree() - mary: DataTree = DataTree(children={"Sue": sue}) - john: DataTree = DataTree(children={"Mary": mary}) - assert john[sue.path] is sue + john = DataTree.from_dict( + { + "/Mary/Sue": DataTree(), + } + ) + assert john["/Mary/Sue"].name == "Sue" def test_same_tree(self): - mary: DataTree = DataTree() - kate: DataTree = DataTree() - john: DataTree = DataTree(children={"Mary": mary, "Kate": kate}) # noqa - assert mary.same_tree(kate) + john = DataTree.from_dict( + { + "/Mary": DataTree(), + "/Kate": DataTree(), + } + ) + assert john["/Mary"].same_tree(john["/Kate"]) def test_relative_paths(self): - sue: DataTree = DataTree() - mary: DataTree = DataTree(children={"Sue": sue}) - annie: DataTree = DataTree() - john: DataTree = DataTree(children={"Mary": mary, "Annie": annie}) + john: DataTree = DataTree.from_dict( + { + "/Mary/Sue": DataTree(), + "/Annie": DataTree(), + } + ) + sue_result = john["Mary/Sue"] + if isinstance(sue_result, DataTree): + sue: DataTree = sue_result + + annie_result = john["Annie"] + if isinstance(annie_result, DataTree): + annie: DataTree = annie_result - result = sue.relative_to(john) - assert result == "Mary/Sue" + assert sue.relative_to(john) == "Mary/Sue" assert john.relative_to(sue) == "../.." assert annie.relative_to(sue) == "../../Annie" assert sue.relative_to(annie) == "../Mary/Sue" @@ -121,7 +137,7 @@ def test_create_with_data(self): assert_identical(john.to_dataset(), dat) with pytest.raises(TypeError): - DataTree(name="mary", parent=john, data="junk") # type: ignore[arg-type] + DataTree(name="mary", data="junk") # type: ignore[arg-type] def test_set_data(self): john: DataTree = DataTree(name="john") @@ -168,13 +184,21 @@ def test_to_dataset(self): class TestVariablesChildrenNameCollisions: def test_parent_already_has_variable_with_childs_name(self): - dt: DataTree = DataTree(data=xr.Dataset({"a": [0], "b": 1})) with pytest.raises(KeyError, match="already contains a variable named a"): - DataTree(name="a", data=None, parent=dt) + DataTree.from_dict({"/": xr.Dataset({"a": [0], "b": 1}), "/a": None}) + + def test_parent_already_has_variable_with_childs_name_update(self): + dt: DataTree = DataTree(data=xr.Dataset({"a": [0], "b": 1})) + with pytest.raises(ValueError, match="already contains a variable named a"): + dt.update({"a": DataTree()}) def test_assign_when_already_child_with_variables_name(self): - dt: DataTree = DataTree(data=None) - DataTree(name="a", data=None, parent=dt) + dt = DataTree.from_dict( + { + "/a": DataTree(), + } + ) + with pytest.raises(ValueError, match="node already contains a variable"): dt.ds = xr.Dataset({"a": 0}) # type: ignore[assignment] @@ -190,11 +214,14 @@ class TestGet: ... class TestGetItem: def test_getitem_node(self): - folder1: DataTree = DataTree(name="folder1") - results: DataTree = DataTree(name="results", parent=folder1) - highres: DataTree = DataTree(name="highres", parent=results) - assert folder1["results"] is results - assert folder1["results/highres"] is highres + folder1 = DataTree.from_dict( + { + "/results/highres": DataTree(), + } + ) + + assert folder1["results"].name == "results" + assert folder1["results/highres"].name == "highres" def test_getitem_self(self): dt: DataTree = DataTree() @@ -207,14 +234,15 @@ def test_getitem_single_data_variable(self): def test_getitem_single_data_variable_from_node(self): data = xr.Dataset({"temp": [0, 50]}) - folder1: DataTree = DataTree(name="folder1") - results: DataTree = DataTree(name="results", parent=folder1) - DataTree(name="highres", parent=results, data=data) + folder1 = DataTree.from_dict( + { + "/results/highres": data, + } + ) assert_identical(folder1["results/highres/temp"], data["temp"]) def test_getitem_nonexistent_node(self): - folder1: DataTree = DataTree(name="folder1") - DataTree(name="results", parent=folder1) + folder1: DataTree = DataTree.from_dict({"/results": DataTree()}, name="folder1") with pytest.raises(KeyError): folder1["results/highres"] @@ -390,14 +418,13 @@ def test_setitem_unnamed_child_node_becomes_named(self): assert john2["sonny"].name == "sonny" def test_setitem_new_grandchild_node(self): - john: DataTree = DataTree(name="john") - mary: DataTree = DataTree(name="mary", parent=john) - rose: DataTree = DataTree(name="rose") - john["mary/rose"] = rose + john = DataTree.from_dict({"/Mary/Rose": DataTree()}) + new_rose: DataTree = DataTree(data=xr.Dataset({"x": 0})) + john["Mary/Rose"] = new_rose - grafted_rose = john["mary/rose"] - assert grafted_rose.parent is mary - assert grafted_rose.name == "rose" + grafted_rose = john["Mary/Rose"] + assert grafted_rose.parent is john["/Mary"] + assert grafted_rose.name == "Rose" def test_grafted_subtree_retains_name(self): subtree: DataTree = DataTree(name="original_subtree_name") @@ -413,10 +440,10 @@ def test_setitem_new_empty_node(self): assert_identical(mary.to_dataset(), xr.Dataset()) def test_setitem_overwrite_data_in_node_with_none(self): - john: DataTree = DataTree(name="john") - mary: DataTree = DataTree(name="mary", parent=john, data=xr.Dataset()) + john: DataTree = DataTree.from_dict({"/mary": xr.Dataset()}, name="john") + john["mary"] = DataTree() - assert_identical(mary.to_dataset(), xr.Dataset()) + assert_identical(john["mary"].to_dataset(), xr.Dataset()) john.ds = xr.Dataset() # type: ignore[assignment] with pytest.raises(ValueError, match="has no name"): @@ -598,8 +625,13 @@ def test_view_contents(self): def test_immutability(self): # See issue https://github.com/xarray-contrib/datatree/issues/38 - dt: DataTree = DataTree(name="root", data=None) - DataTree(name="a", data=None, parent=dt) + dt = DataTree.from_dict( + { + "/": None, + "/a": None, + }, + name="root", + ) with pytest.raises( AttributeError, match="Mutation of the DatasetView is not allowed" @@ -1052,44 +1084,51 @@ def test_filter(self): class TestDSMethodInheritance: def test_dataset_method(self): ds = xr.Dataset({"a": ("x", [1, 2, 3])}) - dt: DataTree = DataTree(data=ds) - DataTree(name="results", parent=dt, data=ds) + dt = DataTree.from_dict( + { + "/": ds, + "/results": ds, + } + ) - expected: DataTree = DataTree(data=ds.isel(x=1)) - DataTree(name="results", parent=expected, data=ds.isel(x=1)) + expected = DataTree.from_dict( + { + "/": ds.isel(x=1), + "/results": ds.isel(x=1), + } + ) result = dt.isel(x=1) assert_equal(result, expected) def test_reduce_method(self): ds = xr.Dataset({"a": ("x", [False, True, False])}) - dt: DataTree = DataTree(data=ds) - DataTree(name="results", parent=dt, data=ds) + dt = DataTree.from_dict({"/": ds, "/results": ds}) - expected: DataTree = DataTree(data=ds.any()) - DataTree(name="results", parent=expected, data=ds.any()) + expected = DataTree.from_dict({"/": ds.any(), "/results": ds.any()}) result = dt.any() assert_equal(result, expected) def test_nan_reduce_method(self): ds = xr.Dataset({"a": ("x", [1, 2, 3])}) - dt: DataTree = DataTree(data=ds) - DataTree(name="results", parent=dt, data=ds) + dt = DataTree.from_dict({"/": ds, "/results": ds}) - expected: DataTree = DataTree(data=ds.mean()) - DataTree(name="results", parent=expected, data=ds.mean()) + expected = DataTree.from_dict({"/": ds.mean(), "/results": ds.mean()}) result = dt.mean() assert_equal(result, expected) def test_cum_method(self): ds = xr.Dataset({"a": ("x", [1, 2, 3])}) - dt: DataTree = DataTree(data=ds) - DataTree(name="results", parent=dt, data=ds) + dt = DataTree.from_dict({"/": ds, "/results": ds}) - expected: DataTree = DataTree(data=ds.cumsum()) - DataTree(name="results", parent=expected, data=ds.cumsum()) + expected = DataTree.from_dict( + { + "/": ds.cumsum(), + "/results": ds.cumsum(), + } + ) result = dt.cumsum() assert_equal(result, expected) @@ -1099,11 +1138,9 @@ class TestOps: def test_binary_op_on_int(self): ds1 = xr.Dataset({"a": [5], "b": [3]}) ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) - dt: DataTree = DataTree(data=ds1) - DataTree(name="subnode", data=ds2, parent=dt) + dt = DataTree.from_dict({"/": ds1, "/subnode": ds2}) - expected: DataTree = DataTree(data=ds1 * 5) - DataTree(name="subnode", data=ds2 * 5, parent=expected) + expected = DataTree.from_dict({"/": ds1 * 5, "/subnode": ds2 * 5}) # TODO: Remove ignore when ops.py is migrated? result: DataTree = dt * 5 # type: ignore[assignment,operator] @@ -1112,12 +1149,21 @@ def test_binary_op_on_int(self): def test_binary_op_on_dataset(self): ds1 = xr.Dataset({"a": [5], "b": [3]}) ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) - dt: DataTree = DataTree(data=ds1) - DataTree(name="subnode", data=ds2, parent=dt) + dt = DataTree.from_dict( + { + "/": ds1, + "/subnode": ds2, + } + ) + other_ds = xr.Dataset({"z": ("z", [0.1, 0.2])}) - expected: DataTree = DataTree(data=ds1 * other_ds) - DataTree(name="subnode", data=ds2 * other_ds, parent=expected) + expected = DataTree.from_dict( + { + "/": ds1 * other_ds, + "/subnode": ds2 * other_ds, + } + ) result = dt * other_ds assert_equal(result, expected) @@ -1125,11 +1171,10 @@ def test_binary_op_on_dataset(self): def test_binary_op_on_datatree(self): ds1 = xr.Dataset({"a": [5], "b": [3]}) ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) - dt: DataTree = DataTree(data=ds1) - DataTree(name="subnode", data=ds2, parent=dt) - expected: DataTree = DataTree(data=ds1 * ds1) - DataTree(name="subnode", data=ds2 * ds2, parent=expected) + dt = DataTree.from_dict({"/": ds1, "/subnode": ds2}) + + expected = DataTree.from_dict({"/": ds1 * ds1, "/subnode": ds2 * ds2}) # TODO: Remove ignore when ops.py is migrated? result: DataTree = dt * dt # type: ignore[operator] diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index b8b55613c4a..62fada4ab4f 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -74,10 +74,9 @@ def test_checking_from_root(self, create_test_datatree): dt1 = create_test_datatree() dt2 = create_test_datatree() real_root: DataTree = DataTree(name="real root") - dt2.name = "not_real_root" - dt2.parent = real_root + real_root["not_real_root"] = dt2 with pytest.raises(TreeIsomorphismError): - check_isomorphic(dt1, dt2, check_from_root=True) + check_isomorphic(dt1, real_root, check_from_root=True) class TestMapOverSubTree: diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 34aee38a1e3..e7076151314 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -657,8 +657,11 @@ def test_datatree_print_node_with_data(self): def test_datatree_printout_nested_node(self): dat = xr.Dataset({"a": [0, 2]}) - root: DataTree = DataTree(name="root") - DataTree(name="results", data=dat, parent=root) + root = DataTree.from_dict( + { + "/results": dat, + } + ) printout = str(root) assert printout.splitlines()[3].startswith(" ") diff --git a/xarray/tests/test_treenode.py b/xarray/tests/test_treenode.py index a7de2e39af6..3996edba659 100644 --- a/xarray/tests/test_treenode.py +++ b/xarray/tests/test_treenode.py @@ -55,6 +55,15 @@ def test_parent_swap(self): assert steve.children["Mary"] is mary assert "Mary" not in john.children + def test_forbid_setting_parent_directly(self): + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() + + with pytest.raises( + AttributeError, match="Cannot set parent attribute directly" + ): + mary.parent = john + def test_multi_child_family(self): mary: TreeNode = TreeNode() kate: TreeNode = TreeNode()