Skip to content

Commit db4f03e

Browse files
dcherianmax-sixty
andauthored
Propagate attrs with unary, binary functions (#4195)
* Propagate attrs with unary, binary functions Closes #3490 Closes #4065 Closes #3433 Closes #3595 * Un xfail test * bugfix * Some progress. Still need keep_attrs in DataArray._unary_op * Fix dataset attrs * whats-new * small fix * Fix imag, real * fix variable tests * fix multiple return variables. * review comments * Update doc/whats-new.rst * Propagate attrs with DataArray unary ops * More tests * Small cleanup * Review comments. * Fix duplication Co-authored-by: Maximilian Roos <[email protected]>
1 parent 92e49f9 commit db4f03e

File tree

10 files changed

+110
-20
lines changed

10 files changed

+110
-20
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ New Features
3030
- :py:func:`open_dataset` and :py:func:`open_mfdataset`
3131
now works with ``engine="zarr"`` (:issue:`3668`, :pull:`4003`, :pull:`4187`).
3232
By `Miguel Jimenez <https://github.com/Mikejmnez>`_ and `Wei Ji Leong <https://github.com/weiji14>`_.
33+
- Unary & binary operations follow the ``keep_attrs`` flag (:issue:`3490`, :issue:`4065`, :issue:`3433`, :issue:`3595`, :pull:`4195`).
34+
By `Deepak Cherian <https://github.com/dcherian>`_.
3335

3436
Bug fixes
3537
~~~~~~~~~

xarray/core/arithmetic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import numpy as np
55

6-
from .options import OPTIONS
6+
from .options import OPTIONS, _get_keep_attrs
77
from .pycompat import dask_array_type
88
from .utils import not_implemented
99

@@ -77,6 +77,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
7777
dataset_fill_value=np.nan,
7878
kwargs=kwargs,
7979
dask="allowed",
80+
keep_attrs=_get_keep_attrs(default=True),
8081
)
8182

8283
# this has no runtime function - these are listed so IDEs know these

xarray/core/computation.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@
4242
_JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"})
4343

4444

45+
def _first_of_type(args, kind):
46+
""" Return either first object of type 'kind' or raise if not found. """
47+
for arg in args:
48+
if isinstance(arg, kind):
49+
return arg
50+
raise ValueError("This should be unreachable.")
51+
52+
4553
class _UFuncSignature:
4654
"""Core dimensions signature for a given function.
4755
@@ -252,8 +260,9 @@ def apply_dataarray_vfunc(
252260
args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
253261
)
254262

255-
if keep_attrs and hasattr(args[0], "name"):
256-
name = args[0].name
263+
if keep_attrs:
264+
first_obj = _first_of_type(args, DataArray)
265+
name = first_obj.name
257266
else:
258267
name = result_name(args)
259268
result_coords = build_output_coords(args, signature, exclude_dims)
@@ -270,6 +279,14 @@ def apply_dataarray_vfunc(
270279
(coords,) = result_coords
271280
out = DataArray(result_var, coords, name=name, fastpath=True)
272281

282+
if keep_attrs:
283+
if isinstance(out, tuple):
284+
for da in out:
285+
# This is adding attrs in place
286+
da._copy_attrs_from(first_obj)
287+
else:
288+
out._copy_attrs_from(first_obj)
289+
273290
return out
274291

275292

@@ -390,15 +407,16 @@ def apply_dataset_vfunc(
390407
"""
391408
from .dataset import Dataset
392409

393-
first_obj = args[0] # we'll copy attrs from this in case keep_attrs=True
394-
395410
if dataset_join not in _JOINS_WITHOUT_FILL_VALUES and fill_value is _NO_FILL_VALUE:
396411
raise TypeError(
397412
"to apply an operation to datasets with different "
398413
"data variables with apply_ufunc, you must supply the "
399414
"dataset_fill_value argument."
400415
)
401416

417+
if keep_attrs:
418+
first_obj = _first_of_type(args, Dataset)
419+
402420
if len(args) > 1:
403421
args = deep_align(
404422
args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
@@ -417,9 +435,11 @@ def apply_dataset_vfunc(
417435
(coord_vars,) = list_of_coords
418436
out = _fast_dataset(result_vars, coord_vars)
419437

420-
if keep_attrs and isinstance(first_obj, Dataset):
438+
if keep_attrs:
421439
if isinstance(out, tuple):
422-
out = tuple(ds._copy_attrs_from(first_obj) for ds in out)
440+
for ds in out:
441+
# This is adding attrs in place
442+
ds._copy_attrs_from(first_obj)
423443
else:
424444
out._copy_attrs_from(first_obj)
425445
return out
@@ -595,6 +615,8 @@ def apply_variable_ufunc(
595615
"""Apply a ndarray level function over Variable and/or ndarray objects."""
596616
from .variable import Variable, as_compatible_data
597617

618+
first_obj = _first_of_type(args, Variable)
619+
598620
dim_sizes = unified_dim_sizes(
599621
(a for a in args if hasattr(a, "dims")), exclude_dims=exclude_dims
600622
)
@@ -734,8 +756,8 @@ def func(*arrays):
734756
)
735757
)
736758

737-
if keep_attrs and isinstance(args[0], Variable):
738-
var.attrs.update(args[0].attrs)
759+
if keep_attrs:
760+
var.attrs.update(first_obj.attrs)
739761
output.append(var)
740762

741763
if signature.num_outputs == 1:

xarray/core/dataarray.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
from .indexes import Indexes, default_indexes, propagate_indexes
5656
from .indexing import is_fancy_indexer
5757
from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords
58-
from .options import OPTIONS
58+
from .options import OPTIONS, _get_keep_attrs
5959
from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs
6060
from .variable import (
6161
IndexVariable,
@@ -2734,13 +2734,19 @@ def __rmatmul__(self, other):
27342734
def _unary_op(f: Callable[..., Any]) -> Callable[..., "DataArray"]:
27352735
@functools.wraps(f)
27362736
def func(self, *args, **kwargs):
2737+
keep_attrs = kwargs.pop("keep_attrs", None)
2738+
if keep_attrs is None:
2739+
keep_attrs = _get_keep_attrs(default=True)
27372740
with warnings.catch_warnings():
27382741
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
27392742
warnings.filterwarnings(
27402743
"ignore", r"Mean of empty slice", category=RuntimeWarning
27412744
)
27422745
with np.errstate(all="ignore"):
2743-
return self.__array_wrap__(f(self.variable.data, *args, **kwargs))
2746+
da = self.__array_wrap__(f(self.variable.data, *args, **kwargs))
2747+
if keep_attrs:
2748+
da.attrs = self.attrs
2749+
return da
27442750

27452751
return func
27462752

xarray/core/dataset.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4403,12 +4403,15 @@ def map(
44034403
foo (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 0.9773
44044404
bar (x) float64 1.0 2.0
44054405
"""
4406+
if keep_attrs is None:
4407+
keep_attrs = _get_keep_attrs(default=False)
44064408
variables = {
44074409
k: maybe_wrap_array(v, func(v, *args, **kwargs))
44084410
for k, v in self.data_vars.items()
44094411
}
4410-
if keep_attrs is None:
4411-
keep_attrs = _get_keep_attrs(default=False)
4412+
if keep_attrs:
4413+
for k, v in variables.items():
4414+
v._copy_attrs_from(self.data_vars[k])
44124415
attrs = self.attrs if keep_attrs else None
44134416
return type(self)(variables, attrs=attrs)
44144417

@@ -4939,15 +4942,20 @@ def from_dict(cls, d):
49394942
return obj
49404943

49414944
@staticmethod
4942-
def _unary_op(f, keep_attrs=False):
4945+
def _unary_op(f):
49434946
@functools.wraps(f)
49444947
def func(self, *args, **kwargs):
49454948
variables = {}
4949+
keep_attrs = kwargs.pop("keep_attrs", None)
4950+
if keep_attrs is None:
4951+
keep_attrs = _get_keep_attrs(default=True)
49464952
for k, v in self._variables.items():
49474953
if k in self._coord_names:
49484954
variables[k] = v
49494955
else:
49504956
variables[k] = f(v, *args, **kwargs)
4957+
if keep_attrs:
4958+
variables[k].attrs = v._attrs
49514959
attrs = self._attrs if keep_attrs else None
49524960
return self._replace_with_new_dims(variables, attrs=attrs)
49534961

@@ -5684,11 +5692,11 @@ def _integrate_one(self, coord, datetime_unit=None):
56845692

56855693
@property
56865694
def real(self):
5687-
return self._unary_op(lambda x: x.real, keep_attrs=True)(self)
5695+
return self.map(lambda x: x.real, keep_attrs=True)
56885696

56895697
@property
56905698
def imag(self):
5691-
return self._unary_op(lambda x: x.imag, keep_attrs=True)(self)
5699+
return self.map(lambda x: x.imag, keep_attrs=True)
56925700

56935701
plot = utils.UncachedAccessor(_Dataset_PlotMethods)
56945702

xarray/core/variable.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2102,8 +2102,14 @@ def __array_wrap__(self, obj, context=None):
21022102
def _unary_op(f):
21032103
@functools.wraps(f)
21042104
def func(self, *args, **kwargs):
2105+
keep_attrs = kwargs.pop("keep_attrs", None)
2106+
if keep_attrs is None:
2107+
keep_attrs = _get_keep_attrs(default=True)
21052108
with np.errstate(all="ignore"):
2106-
return self.__array_wrap__(f(self.data, *args, **kwargs))
2109+
result = self.__array_wrap__(f(self.data, *args, **kwargs))
2110+
if keep_attrs:
2111+
result.attrs = self.attrs
2112+
return result
21072113

21082114
return func
21092115

xarray/tests/test_dataarray.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,15 @@
99
import pytest
1010

1111
import xarray as xr
12-
from xarray import DataArray, Dataset, IndexVariable, Variable, align, broadcast
12+
from xarray import (
13+
DataArray,
14+
Dataset,
15+
IndexVariable,
16+
Variable,
17+
align,
18+
broadcast,
19+
set_options,
20+
)
1321
from xarray.coding.times import CFDatetimeCoder
1422
from xarray.convert import from_cdms2
1523
from xarray.core import dtypes
@@ -2486,6 +2494,21 @@ def test_assign_attrs(self):
24862494
assert_identical(new_actual, expected)
24872495
assert actual.attrs == {"a": 1, "b": 2}
24882496

2497+
@pytest.mark.parametrize(
2498+
"func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs]
2499+
)
2500+
def test_propagate_attrs(self, func):
2501+
da = DataArray(self.va)
2502+
2503+
# test defaults
2504+
assert func(da).attrs == da.attrs
2505+
2506+
with set_options(keep_attrs=False):
2507+
assert func(da).attrs == {}
2508+
2509+
with set_options(keep_attrs=True):
2510+
assert func(da).attrs == da.attrs
2511+
24892512
def test_fillna(self):
24902513
a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x")
24912514
actual = a.fillna(-1)

xarray/tests/test_dataset.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4473,6 +4473,28 @@ def test_fillna(self):
44734473
assert actual.a.name == "a"
44744474
assert actual.a.attrs == ds.a.attrs
44754475

4476+
@pytest.mark.parametrize(
4477+
"func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs]
4478+
)
4479+
def test_propagate_attrs(self, func):
4480+
4481+
da = DataArray(range(5), name="a", attrs={"attr": "da"})
4482+
ds = Dataset({"a": da}, attrs={"attr": "ds"})
4483+
4484+
# test defaults
4485+
assert func(ds).attrs == ds.attrs
4486+
with set_options(keep_attrs=False):
4487+
assert func(ds).attrs != ds.attrs
4488+
assert func(ds).a.attrs != ds.a.attrs
4489+
4490+
with set_options(keep_attrs=False):
4491+
assert func(ds).attrs != ds.attrs
4492+
assert func(ds).a.attrs != ds.a.attrs
4493+
4494+
with set_options(keep_attrs=True):
4495+
assert func(ds).attrs == ds.attrs
4496+
assert func(ds).a.attrs == ds.a.attrs
4497+
44764498
def test_where(self):
44774499
ds = Dataset({"a": ("x", range(5))})
44784500
expected = Dataset({"a": ("x", [np.nan, np.nan, 2, 3, 4])})

xarray/tests/test_variable.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,8 @@ def test_1d_math(self):
342342
assert_array_equal(y - v, 1 - v)
343343
# verify attributes are dropped
344344
v2 = self.cls(["x"], x, {"units": "meters"})
345-
assert_identical(base_v, +v2)
345+
with set_options(keep_attrs=False):
346+
assert_identical(base_v, +v2)
346347
# binary ops with all variables
347348
assert_array_equal(v + v, 2 * v)
348349
w = self.cls(["x"], y, {"foo": "bar"})

xarray/tests/test_weighted.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,6 @@ def test_weighted_operations_keep_attr(operation, as_dataset, keep_attrs):
320320
assert not result.attrs
321321

322322

323-
@pytest.mark.xfail(reason="xr.Dataset.map does not copy attrs of DataArrays GH: 3595")
324323
@pytest.mark.parametrize("operation", ("sum", "mean"))
325324
def test_weighted_operations_keep_attr_da_in_ds(operation):
326325
# GH #3595

0 commit comments

Comments
 (0)