Skip to content

Commit 7046255

Browse files
authored
Support alternative names for the root node in DataTree.from_dict (#9638)
1 parent de3fce8 commit 7046255

File tree

2 files changed

+64
-11
lines changed

2 files changed

+64
-11
lines changed

xarray/core/datatree.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,10 +1104,12 @@ def from_dict(
11041104
d : dict-like
11051105
A mapping from path names to xarray.Dataset or DataTree objects.
11061106
1107-
Path names are to be given as unix-like path. If path names containing more than one
1108-
part are given, new tree nodes will be constructed as necessary.
1107+
Path names are to be given as unix-like path. If path names
1108+
containing more than one part are given, new tree nodes will be
1109+
constructed as necessary.
11091110
1110-
To assign data to the root node of the tree use "/" as the path.
1111+
To assign data to the root node of the tree use "", ".", "/" or "./"
1112+
as the path.
11111113
name : Hashable | None, optional
11121114
Name for the root node of the tree. Default is None.
11131115
@@ -1119,17 +1121,27 @@ def from_dict(
11191121
-----
11201122
If your dictionary is nested you will need to flatten it before using this method.
11211123
"""
1122-
1123-
# First create the root node
1124+
# Find any values corresponding to the root
11241125
d_cast = dict(d)
1125-
root_data = d_cast.pop("/", None)
1126+
root_data = None
1127+
for key in ("", ".", "/", "./"):
1128+
if key in d_cast:
1129+
if root_data is not None:
1130+
raise ValueError(
1131+
"multiple entries found corresponding to the root node"
1132+
)
1133+
root_data = d_cast.pop(key)
1134+
1135+
# Create the root node
11261136
if isinstance(root_data, DataTree):
11271137
obj = root_data.copy()
1138+
obj.name = name
11281139
elif root_data is None or isinstance(root_data, Dataset):
11291140
obj = cls(name=name, dataset=root_data, children=None)
11301141
else:
11311142
raise TypeError(
1132-
f'root node data (at "/") must be a Dataset or DataTree, got {type(root_data)}'
1143+
f'root node data (at "", ".", "/" or "./") must be a Dataset '
1144+
f"or DataTree, got {type(root_data)}"
11331145
)
11341146

11351147
def depth(item) -> int:
@@ -1141,11 +1153,10 @@ def depth(item) -> int:
11411153
# Sort keys by depth so as to insert nodes from root first (see GH issue #9276)
11421154
for path, data in sorted(d_cast.items(), key=depth):
11431155
# Create and set new node
1144-
node_name = NodePath(path).name
11451156
if isinstance(data, DataTree):
11461157
new_node = data.copy()
11471158
elif isinstance(data, Dataset) or data is None:
1148-
new_node = cls(name=node_name, dataset=data)
1159+
new_node = cls(dataset=data)
11491160
else:
11501161
raise TypeError(f"invalid values: {data}")
11511162
obj._set_item(
@@ -1683,7 +1694,7 @@ def reduce(
16831694
numeric_only=numeric_only,
16841695
**kwargs,
16851696
)
1686-
path = "/" if node is self else node.relative_to(self)
1697+
path = node.relative_to(self)
16871698
result[path] = node_result
16881699
return type(self).from_dict(result, name=self.name)
16891700

@@ -1718,7 +1729,7 @@ def _selective_indexing(
17181729
# with a scalar) can also create scalar coordinates, which
17191730
# need to be explicitly removed.
17201731
del node_result.coords[k]
1721-
path = "/" if node is self else node.relative_to(self)
1732+
path = node.relative_to(self)
17221733
result[path] = node_result
17231734
return type(self).from_dict(result, name=self.name)
17241735

xarray/tests/test_datatree.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,48 @@ def test_array_values(self) -> None:
883883
with pytest.raises(TypeError):
884884
DataTree.from_dict(data) # type: ignore[arg-type]
885885

886+
def test_relative_paths(self) -> None:
887+
tree = DataTree.from_dict({".": None, "foo": None, "./bar": None, "x/y": None})
888+
paths = [node.path for node in tree.subtree]
889+
assert paths == [
890+
"/",
891+
"/foo",
892+
"/bar",
893+
"/x",
894+
"/x/y",
895+
]
896+
897+
def test_root_keys(self):
898+
ds = Dataset({"x": 1})
899+
expected = DataTree(dataset=ds)
900+
901+
actual = DataTree.from_dict({"": ds})
902+
assert_identical(actual, expected)
903+
904+
actual = DataTree.from_dict({".": ds})
905+
assert_identical(actual, expected)
906+
907+
actual = DataTree.from_dict({"/": ds})
908+
assert_identical(actual, expected)
909+
910+
actual = DataTree.from_dict({"./": ds})
911+
assert_identical(actual, expected)
912+
913+
with pytest.raises(
914+
ValueError, match="multiple entries found corresponding to the root node"
915+
):
916+
DataTree.from_dict({"": ds, "/": ds})
917+
918+
def test_name(self):
919+
tree = DataTree.from_dict({"/": None}, name="foo")
920+
assert tree.name == "foo"
921+
922+
tree = DataTree.from_dict({"/": DataTree()}, name="foo")
923+
assert tree.name == "foo"
924+
925+
tree = DataTree.from_dict({"/": DataTree(name="bar")}, name="foo")
926+
assert tree.name == "foo"
927+
886928

887929
class TestDatasetView:
888930
def test_view_contents(self) -> None:

0 commit comments

Comments
 (0)