diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index fd938238ae4..52d2175621f 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -20,6 +20,7 @@ ) from xarray.backends.store import StoreBackendEntrypoint from xarray.core import indexing +from xarray.core.treenode import NodePath from xarray.core.types import ZarrWriteModes from xarray.core.utils import ( FrozenDict, @@ -33,6 +34,8 @@ if TYPE_CHECKING: from io import BufferedIOBase + from zarr import Group as ZarrGroup + from xarray.backends.common import AbstractDataStore from xarray.core.dataset import Dataset from xarray.core.datatree import DataTree @@ -1218,66 +1221,86 @@ def open_datatree( zarr_version=None, **kwargs, ) -> DataTree: - from xarray.backends.api import open_dataset from xarray.core.datatree import DataTree + + filename_or_obj = _normalize_path(filename_or_obj) + groups_dict = self.open_groups_as_dict(filename_or_obj, **kwargs) + + return DataTree.from_dict(groups_dict) + + def open_groups_as_dict( + self, + filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables: str | Iterable[str] | None = None, + use_cftime=None, + decode_timedelta=None, + group: str | Iterable[str] | Callable | None = None, + mode="r", + synchronizer=None, + consolidated=None, + chunk_store=None, + storage_options=None, + stacklevel=3, + zarr_version=None, + **kwargs, + ) -> dict[str, Dataset]: + from xarray.core.treenode import NodePath filename_or_obj = _normalize_path(filename_or_obj) + + # Check for a group and make it a parent if it exists if group: - parent = NodePath("/") / NodePath(group) - stores = ZarrStore.open_store( - filename_or_obj, - group=parent, - mode=mode, - synchronizer=synchronizer, - consolidated=consolidated, - consolidate_on_close=False, - chunk_store=chunk_store, - storage_options=storage_options, - stacklevel=stacklevel + 1, - 
zarr_version=zarr_version, - ) - if not stores: - ds = open_dataset( - filename_or_obj, group=parent, engine="zarr", **kwargs - ) - return DataTree.from_dict({str(parent): ds}) + parent = str(NodePath("/") / NodePath(group)) else: - parent = NodePath("/") - stores = ZarrStore.open_store( - filename_or_obj, - group=parent, - mode=mode, - synchronizer=synchronizer, - consolidated=consolidated, - consolidate_on_close=False, - chunk_store=chunk_store, - storage_options=storage_options, - stacklevel=stacklevel + 1, - zarr_version=zarr_version, - ) - ds = open_dataset(filename_or_obj, group=parent, engine="zarr", **kwargs) - tree_root = DataTree.from_dict({str(parent): ds}) + parent = str(NodePath("/")) + + stores = ZarrStore.open_store( + filename_or_obj, + group=parent, + mode=mode, + synchronizer=synchronizer, + consolidated=consolidated, + consolidate_on_close=False, + chunk_store=chunk_store, + storage_options=storage_options, + stacklevel=stacklevel + 1, + zarr_version=zarr_version, + ) + + groups_dict = {} + for path_group, store in stores.items(): - ds = open_dataset( - filename_or_obj, store=store, group=path_group, engine="zarr", **kwargs - ) - new_node = DataTree(name=NodePath(path_group).name, dataset=ds) - tree_root._set_item( - path_group, - new_node, - allow_overwrite=False, - new_nodes_along_path=True, - ) - return tree_root + store_entrypoint = StoreBackendEntrypoint() + + with close_on_error(store): + group_ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + group_name = str(NodePath(path_group)) + groups_dict[group_name] = group_ds + + return groups_dict -def _iter_zarr_groups(root, parent="/"): - from xarray.core.treenode import NodePath +def _iter_zarr_groups(root: ZarrGroup, parent: str = "/") -> Iterable[str]: - parent = 
NodePath(parent) + parent_nodepath = NodePath(parent) + yield str(parent_nodepath) for path, group in root.groups(): - gpath = parent / path + gpath = parent_nodepath / path yield str(gpath) yield from _iter_zarr_groups(group, parent=gpath) diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index d9490385b7d..8ca4711acad 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING, cast import numpy as np @@ -24,6 +25,88 @@ pass +@pytest.fixture(scope="module") +def unaligned_datatree_nc(tmp_path_factory): + """Creates a test netCDF4 file with the following unaligned structure, writes it to a /tmp directory + and returns the file path of the netCDF4 file. + + Group: / + │ Dimensions: (lat: 1, lon: 2) + │ Dimensions without coordinates: lat, lon + │ Data variables: + │ root_variable (lat, lon) float64 16B ... + └── Group: /Group1 + │ Dimensions: (lat: 1, lon: 2) + │ Dimensions without coordinates: lat, lon + │ Data variables: + │ group_1_var (lat, lon) float64 16B ... + └── Group: /Group1/subgroup1 + Dimensions: (lat: 2, lon: 2) + Dimensions without coordinates: lat, lon + Data variables: + subgroup1_var (lat, lon) float64 32B ... 
+ """ + filepath = tmp_path_factory.mktemp("data") / "unaligned_subgroups.nc" + with nc4.Dataset(filepath, "w", format="NETCDF4") as root_group: + group_1 = root_group.createGroup("/Group1") + subgroup_1 = group_1.createGroup("/subgroup1") + + root_group.createDimension("lat", 1) + root_group.createDimension("lon", 2) + root_group.createVariable("root_variable", np.float64, ("lat", "lon")) + + group_1_var = group_1.createVariable("group_1_var", np.float64, ("lat", "lon")) + group_1_var[:] = np.array([[0.1, 0.2]]) + group_1_var.units = "K" + group_1_var.long_name = "air_temperature" + + subgroup_1.createDimension("lat", 2) + + subgroup1_var = subgroup_1.createVariable( + "subgroup1_var", np.float64, ("lat", "lon") + ) + subgroup1_var[:] = np.array([[0.1, 0.2]]) + + yield filepath + + +@pytest.fixture(scope="module") +def unaligned_datatree_zarr(tmp_path_factory): + """Creates a zarr store with the following unaligned group hierarchy: + Group: / + │ Dimensions: (y: 3, x: 2) + │ Dimensions without coordinates: y, x + │ Data variables: + │ a (y) int64 24B ... + │ set0 (x) int64 16B ... + ├── Group: /Group1 + │ │ Dimensions: () + │ │ Data variables: + │ │ a int64 8B ... + │ │ b int64 8B ... + │ └── Group: /Group1/subgroup1 + │ Dimensions: () + │ Data variables: + │ a int64 8B ... + │ b int64 8B ... + └── Group: /Group2 + Dimensions: (y: 2, x: 2) + Dimensions without coordinates: y, x + Data variables: + a (y) int64 16B ... + b (x) float64 16B ...
+ """ + filepath = tmp_path_factory.mktemp("data") / "unaligned_simple_datatree.zarr" + root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) + set1_data = xr.Dataset({"a": 0, "b": 1}) + set2_data = xr.Dataset({"a": ("y", [2, 3]), "b": ("x", [0.1, 0.2])}) + root_data.to_zarr(filepath) + set1_data.to_zarr(filepath, group="/Group1", mode="a") + set2_data.to_zarr(filepath, group="/Group2", mode="a") + set1_data.to_zarr(filepath, group="/Group1/subgroup1", mode="a") + yield filepath + + class DatatreeIOBase: engine: T_DataTreeNetcdfEngine | None = None @@ -73,91 +156,23 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree): class TestNetCDF4DatatreeIO(DatatreeIOBase): engine: T_DataTreeNetcdfEngine | None = "netcdf4" - def test_open_datatree(self, tmpdir) -> None: - """Create a test netCDF4 file with this unaligned structure: - Group: / - │ Dimensions: (lat: 1, lon: 2) - │ Dimensions without coordinates: lat, lon - │ Data variables: - │ root_variable (lat, lon) float64 16B ... - └── Group: /Group1 - │ Dimensions: (lat: 1, lon: 2) - │ Dimensions without coordinates: lat, lon - │ Data variables: - │ group_1_var (lat, lon) float64 16B ... - └── Group: /Group1/subgroup1 - Dimensions: (lat: 2, lon: 2) - Dimensions without coordinates: lat, lon - Data variables: - subgroup1_var (lat, lon) float64 32B ... 
- """ - filepath = tmpdir + "/unaligned_subgroups.nc" - with nc4.Dataset(filepath, "w", format="NETCDF4") as root_group: - group_1 = root_group.createGroup("/Group1") - subgroup_1 = group_1.createGroup("/subgroup1") - - root_group.createDimension("lat", 1) - root_group.createDimension("lon", 2) - root_group.createVariable("root_variable", np.float64, ("lat", "lon")) - - group_1_var = group_1.createVariable( - "group_1_var", np.float64, ("lat", "lon") - ) - group_1_var[:] = np.array([[0.1, 0.2]]) - group_1_var.units = "K" - group_1_var.long_name = "air_temperature" - - subgroup_1.createDimension("lat", 2) - - subgroup1_var = subgroup_1.createVariable( - "subgroup1_var", np.float64, ("lat", "lon") - ) - subgroup1_var[:] = np.array([[0.1, 0.2]]) - with pytest.raises(ValueError): - open_datatree(filepath) - - def test_open_groups(self, tmpdir) -> None: - """Test `open_groups` with netCDF4 file with the same unaligned structure: - Group: / - │ Dimensions: (lat: 1, lon: 2) - │ Dimensions without coordinates: lat, lon - │ Data variables: - │ root_variable (lat, lon) float64 16B ... - └── Group: /Group1 - │ Dimensions: (lat: 1, lon: 2) - │ Dimensions without coordinates: lat, lon - │ Data variables: - │ group_1_var (lat, lon) float64 16B ... - └── Group: /Group1/subgroup1 - Dimensions: (lat: 2, lon: 2) - Dimensions without coordinates: lat, lon - Data variables: - subgroup1_var (lat, lon) float64 32B ... 
- """ - filepath = tmpdir + "/unaligned_subgroups.nc" - with nc4.Dataset(filepath, "w", format="NETCDF4") as root_group: - group_1 = root_group.createGroup("/Group1") - subgroup_1 = group_1.createGroup("/subgroup1") - - root_group.createDimension("lat", 1) - root_group.createDimension("lon", 2) - root_group.createVariable("root_variable", np.float64, ("lat", "lon")) + def test_open_datatree(self, unaligned_datatree_nc) -> None: + """Test if `open_datatree` fails to open a netCDF4 with an unaligned group hierarchy.""" - group_1_var = group_1.createVariable( - "group_1_var", np.float64, ("lat", "lon") - ) - group_1_var[:] = np.array([[0.1, 0.2]]) - group_1_var.units = "K" - group_1_var.long_name = "air_temperature" - - subgroup_1.createDimension("lat", 2) - - subgroup1_var = subgroup_1.createVariable( - "subgroup1_var", np.float64, ("lat", "lon") - ) - subgroup1_var[:] = np.array([[0.1, 0.2]]) + with pytest.raises( + ValueError, + match=( + re.escape( + "group '/Group1/subgroup1' is not aligned with its parents:\nGroup:\n" + ) + + ".*" + ), + ): + open_datatree(unaligned_datatree_nc) - unaligned_dict_of_datasets = open_groups(filepath) + def test_open_groups(self, unaligned_datatree_nc) -> None: + """Test `open_groups` with a netCDF4 file with an unaligned group hierarchy.""" + unaligned_dict_of_datasets = open_groups(unaligned_datatree_nc) # Check that group names are keys in the dictionary of `xr.Datasets` assert "/" in unaligned_dict_of_datasets.keys() @@ -165,19 +180,20 @@ def test_open_groups(self, tmpdir) -> None: assert "/Group1/subgroup1" in unaligned_dict_of_datasets.keys() # Check that group name returns the correct datasets assert_identical( - unaligned_dict_of_datasets["/"], xr.open_dataset(filepath, group="/") + unaligned_dict_of_datasets["/"], + xr.open_dataset(unaligned_datatree_nc, group="/"), ) assert_identical( unaligned_dict_of_datasets["/Group1"], - xr.open_dataset(filepath, group="Group1"), + xr.open_dataset(unaligned_datatree_nc, 
group="Group1"), ) assert_identical( unaligned_dict_of_datasets["/Group1/subgroup1"], - xr.open_dataset(filepath, group="/Group1/subgroup1"), + xr.open_dataset(unaligned_datatree_nc, group="/Group1/subgroup1"), ) def test_open_groups_to_dict(self, tmpdir) -> None: - """Create a an aligned netCDF4 with the following structure to test `open_groups` + """Create an aligned netCDF4 with the following structure to test `open_groups` and `DataTree.from_dict`. Group: / │ Dimensions: (lat: 1, lon: 2) @@ -305,3 +321,53 @@ def test_to_zarr_inherited_coords(self, tmpdir): assert_equal(original_dt, roundtrip_dt) subtree = cast(DataTree, roundtrip_dt["/sub"]) assert "x" not in subtree.to_dataset(inherited=False).coords + + def test_open_groups_round_trip(self, tmpdir, simple_datatree) -> None: + """Test `open_groups` opens a zarr store with the `simple_datatree` structure.""" + filepath = tmpdir / "test.zarr" + original_dt = simple_datatree + original_dt.to_zarr(filepath) + + roundtrip_dict = open_groups(filepath, engine="zarr") + roundtrip_dt = DataTree.from_dict(roundtrip_dict) + + assert open_datatree(filepath, engine="zarr").identical(roundtrip_dt) + + def test_open_datatree(self, unaligned_datatree_zarr) -> None: + """Test if `open_datatree` fails to open a zarr store with an unaligned group hierarchy.""" + with pytest.raises( + ValueError, + match=( + re.escape("group '/Group2' is not aligned with its parents:") + ".*" + ), + ): + open_datatree(unaligned_datatree_zarr, engine="zarr") + + def test_open_groups(self, unaligned_datatree_zarr) -> None: + """Test `open_groups` with a zarr store of an unaligned group hierarchy.""" + + unaligned_dict_of_datasets = open_groups(unaligned_datatree_zarr, engine="zarr") + + assert "/" in unaligned_dict_of_datasets.keys() + assert "/Group1" in unaligned_dict_of_datasets.keys() + assert "/Group1/subgroup1" in unaligned_dict_of_datasets.keys() + assert "/Group2" in unaligned_dict_of_datasets.keys() + # Check that group name returns the 
correct datasets + assert_identical( + unaligned_dict_of_datasets["/"], + xr.open_dataset(unaligned_datatree_zarr, group="/", engine="zarr"), + ) + assert_identical( + unaligned_dict_of_datasets["/Group1"], + xr.open_dataset(unaligned_datatree_zarr, group="Group1", engine="zarr"), + ) + assert_identical( + unaligned_dict_of_datasets["/Group1/subgroup1"], + xr.open_dataset( + unaligned_datatree_zarr, group="/Group1/subgroup1", engine="zarr" + ), + ) + assert_identical( + unaligned_dict_of_datasets["/Group2"], + xr.open_dataset(unaligned_datatree_zarr, group="/Group2", engine="zarr"), + )