diff --git a/datatree/datatree.py b/datatree/datatree.py index a3df42d1..ae516fd3 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -874,10 +874,44 @@ def groups(self): """Return all netCDF4 groups in the tree, given as a tuple of path-like strings.""" return tuple(node.pathstr for node in self.subtree_nodes) - def to_netcdf(self, filename: str): + def to_netcdf( + self, filepath, mode: str = "w", encoding=None, unlimited_dims=None, **kwargs + ): + """ + Write datatree contents to a netCDF file. + + Paramters + --------- + filepath : str or Path + Path to which to save this datatree. + mode : {"w", "a"}, default: "w" + Write ('w') or append ('a') mode. If mode='w', any existing file at + this location will be overwritten. If mode='a', existing variables + will be overwritten. Only appies to the root group. + encoding : dict, optional + Nested dictionary with variable names as keys and dictionaries of + variable specific encodings as values, e.g., + ``{"root/set1": {"my_variable": {"dtype": "int16", "scale_factor": 0.1, + "zlib": True}, ...}, ...}``. See ``xarray.Dataset.to_netcdf`` for available + options. + unlimited_dims : dict, optional + Mapping of unlimited dimensions per group that that should be serialized as unlimited dimensions. + By default, no dimensions are treated as unlimited dimensions. + Note that unlimited_dims may also be set via + ``dataset.encoding["unlimited_dims"]``. 
kwargs : + Additional keyword arguments to be passed to ``xarray.Dataset.to_netcdf``
NotImplementedError("compute=False has not been implemented yet") + + if encoding is None: + encoding = {} + + if unlimited_dims is None: + unlimited_dims = {} + + ds = dt.ds + group_path = dt.pathstr.replace(dt.root.pathstr, "") + if ds is None: + _create_empty_group(filepath, group_path, mode) + else: + ds.to_netcdf( + filepath, + group=group_path, + mode=mode, + encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), + unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), + **kwargs + ) + mode = "a" + + for node in dt.descendants: + ds = node.ds + group_path = node.pathstr.replace(dt.root.pathstr, "") + if ds is None: + _create_empty_group(filepath, group_path, mode) + else: + ds.to_netcdf( + filepath, + group=group_path, + mode=mode, + encoding=_maybe_extract_group_kwargs(encoding, dt.pathstr), + unlimited_dims=_maybe_extract_group_kwargs(unlimited_dims, dt.pathstr), + **kwargs + ) diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index a1266af0..91412e6b 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -4,6 +4,7 @@ from xarray.testing import assert_identical from datatree import DataNode, DataTree +from datatree.io import open_datatree def create_test_datatree(): @@ -31,14 +32,14 @@ def create_test_datatree(): | Dimensions: (x: 2, y: 3) | Data variables: | a (y) int64 6, 7, 8 - | set1 (x) int64 9, 10 + | set0 (x) int64 9, 10 The structure has deliberately repeated names of tags, variables, and dimensions in order to better check for bugs caused by name conflicts. 
""" set1_data = xr.Dataset({"a": 0, "b": 1}) set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}) - root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set1": ("x", [9, 10])}) + root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) # Avoid using __init__ so we can independently test it # TODO change so it has a DataTree at the bottom @@ -297,4 +298,21 @@ def test_repr_of_node_with_data(self): class TestIO: - ... + def test_to_netcdf(self, tmpdir): + filepath = str( + tmpdir / "test.nc" + ) # casting to str avoids a pathlib bug in xarray + original_dt = create_test_datatree() + original_dt.to_netcdf(filepath, engine="netcdf4") + + roundtrip_dt = open_datatree(filepath) + + original_dt.name == roundtrip_dt.name + assert original_dt.ds.identical(roundtrip_dt.ds) + for a, b in zip(original_dt.descendants, roundtrip_dt.descendants): + assert a.name == b.name + assert a.pathstr == b.pathstr + if a.has_data: + assert a.ds.identical(b.ds) + else: + assert a.ds is b.ds