diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 031b8cf23e1..71b34a1445a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -39,6 +39,8 @@ Bug fixes - :py:meth:`~xarray.Dataset.to_stacked_array` now uses dimensions in order of appearance. This fixes the issue where using :py:meth:`~xarray.Dataset.transpose` before :py:meth:`~xarray.Dataset.to_stacked_array` had no effect. (Mentioned in :issue:`9921`) +- Enable ``keep_attrs`` in ``DatasetView.map`` relevant for :py:func:`map_over_datasets` (:pull:`10219`) + By `Mathias Hauser <https://github.com/mathause>`_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index f963b021eed..734927fd3d1 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -45,6 +45,7 @@ ) from xarray.core.indexes import Index, Indexes from xarray.core.options import OPTIONS as XR_OPTS +from xarray.core.options import _get_keep_attrs from xarray.core.treenode import NamedNode, NodePath, zip_subtrees from xarray.core.types import Self from xarray.core.utils import ( @@ -421,13 +422,18 @@ def map( # type: ignore[override] # Copied from xarray.Dataset so as not to call type(self), which causes problems (see https://github.com/xarray-contrib/datatree/issues/188). # TODO Refactor xarray upstream to avoid needing to overwrite this. 
- # TODO This copied version will drop all attrs - the keep_attrs stuff should be re-instated + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) variables = { k: maybe_wrap_array(v, func(v, *args, **kwargs)) for k, v in self.data_vars.items() } + if keep_attrs: + for k, v in variables.items(): + v._copy_attrs_from(self.data_vars[k]) + attrs = self.attrs if keep_attrs else None # return type(self)(variables, attrs=attrs) - return Dataset(variables) + return Dataset(variables, attrs=attrs) class DataTree( diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 1592d9b0898..bf60bd43bd4 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1019,6 +1019,29 @@ def weighted_mean(ds): weighted_mean(dt.dataset) + def test_map_keep_attrs(self) -> None: + # test DatasetView.map(..., keep_attrs=...) + data = xr.DataArray([1, 2, 3], dims="x", attrs={"da": "attrs"}) + ds = xr.Dataset({"data": data}, attrs={"ds": "attrs"}) + dt = DataTree(ds) + + def func_keep(ds): + # x.mean() drops the attrs of the data_vars; keep_attrs=True must restore them + return ds.map(lambda x: x.mean(), keep_attrs=True) + + result = xr.map_over_datasets(func_keep, dt) + expected = dt.mean(keep_attrs=True) + xr.testing.assert_identical(result, expected) + + # per default DatasetView.map does not keep attrs + def func(ds): + # x.mean() drops the attrs of the data_vars and nothing restores them + return ds.map(lambda x: x.mean()) + + result = xr.map_over_datasets(func, dt) + expected = dt.mean() + xr.testing.assert_identical(result, expected) + class TestAccess: def test_attribute_access(self, create_test_datatree) -> None: