Skip to content

Commit 03685fb

Browse files
committed
Propagate attrs with unary, binary functions
Closes #3490 Closes #4065 Closes #3433 Closes #3595
1 parent 54b9450 commit 03685fb

File tree

5 files changed

+58
-10
lines changed

5 files changed

+58
-10
lines changed

xarray/core/arithmetic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import numpy as np
55

6-
from .options import OPTIONS
6+
from .options import OPTIONS, _get_keep_attrs
77
from .pycompat import dask_array_type
88
from .utils import not_implemented
99

@@ -77,6 +77,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
7777
dataset_fill_value=np.nan,
7878
kwargs=kwargs,
7979
dask="allowed",
80+
keep_attrs=_get_keep_attrs(default=False),
8081
)
8182

8283
# this has no runtime function - these are listed so IDEs know these

xarray/core/computation.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,13 @@ def apply_dataarray_vfunc(
223223
args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
224224
)
225225

226-
if keep_attrs and hasattr(args[0], "name"):
227-
name = args[0].name
226+
for arg in args:
227+
first_obj = arg
228+
if isinstance(arg, DataArray):
229+
break
230+
231+
if keep_attrs and hasattr(first_obj, "name"):
232+
name = first_obj.name
228233
else:
229234
name = result_name(args)
230235
result_coords = build_output_coords(args, signature, exclude_dims)
@@ -241,6 +246,12 @@ def apply_dataarray_vfunc(
241246
(coords,) = result_coords
242247
out = DataArray(result_var, coords, name=name, fastpath=True)
243248

249+
if keep_attrs and hasattr(first_obj, "attrs"):
250+
if isinstance(out, tuple):
251+
out = tuple(da._copy_attrs_from(first_obj) for da in out)
252+
else:
253+
out._copy_attrs_from(first_obj)
254+
244255
return out
245256

246257

@@ -361,7 +372,10 @@ def apply_dataset_vfunc(
361372
"""
362373
from .dataset import Dataset
363374

364-
first_obj = args[0] # we'll copy attrs from this in case keep_attrs=True
375+
for arg in args:
376+
first_obj = args
377+
if isinstance(first_obj, Dataset):
378+
break
365379

366380
if dataset_join not in _JOINS_WITHOUT_FILL_VALUES and fill_value is _NO_FILL_VALUE:
367381
raise TypeError(
@@ -554,6 +568,11 @@ def apply_variable_ufunc(
554568
"""
555569
from .variable import Variable, as_compatible_data
556570

571+
for arg in args:
572+
first_obj = arg
573+
if isinstance(arg, Variable):
574+
break
575+
557576
dim_sizes = unified_dim_sizes(
558577
(a for a in args if hasattr(a, "dims")), exclude_dims=exclude_dims
559578
)
@@ -639,8 +658,8 @@ def func(*arrays):
639658
)
640659
)
641660

642-
if keep_attrs and isinstance(args[0], Variable):
643-
var.attrs.update(args[0].attrs)
661+
if keep_attrs and isinstance(first_obj, Variable):
662+
var.attrs.update(first_obj.attrs)
644663
output.append(var)
645664

646665
if signature.num_outputs == 1:

xarray/core/dataset.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4368,12 +4368,15 @@ def map(
43684368
foo (dim_0, dim_1) float64 0.3751 1.951 1.945 0.2948 0.711 0.3948
43694369
bar (x) float64 1.0 2.0
43704370
"""
4371+
if keep_attrs is None:
4372+
keep_attrs = _get_keep_attrs(default=False)
43714373
variables = {
43724374
k: maybe_wrap_array(v, func(v, *args, **kwargs))
43734375
for k, v in self.data_vars.items()
43744376
}
4375-
if keep_attrs is None:
4376-
keep_attrs = _get_keep_attrs(default=False)
4377+
if keep_attrs:
4378+
for k, v in variables.items():
4379+
v._copy_attrs_from(self.data_vars[k])
43774380
attrs = self.attrs if keep_attrs else None
43784381
return type(self)(variables, attrs=attrs)
43794382

xarray/core/variable.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2024,7 +2024,9 @@ def imag(self):
20242024
return type(self)(self.dims, self.data.imag, self._attrs)
20252025

20262026
def __array_wrap__(self, obj, context=None):
2027-
return Variable(self.dims, obj)
2027+
keep_attrs = _get_keep_attrs(default=False)
2028+
attrs = self._attrs if keep_attrs else {}
2029+
return Variable(self.dims, obj, attrs)
20282030

20292031
@staticmethod
20302032
def _unary_op(f):

xarray/tests/test_dataarray.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,15 @@
99
import pytest
1010

1111
import xarray as xr
12-
from xarray import DataArray, Dataset, IndexVariable, Variable, align, broadcast
12+
from xarray import (
13+
DataArray,
14+
Dataset,
15+
IndexVariable,
16+
Variable,
17+
align,
18+
broadcast,
19+
set_options,
20+
)
1321
from xarray.coding.times import CFDatetimeCoder
1422
from xarray.convert import from_cdms2
1523
from xarray.core import dtypes
@@ -2458,6 +2466,21 @@ def test_assign_attrs(self):
24582466
assert_identical(new_actual, expected)
24592467
assert actual.attrs == {"a": 1, "b": 2}
24602468

2469+
def test_propagate_attrs(self):
2470+
da = DataArray(self.va)
2471+
2472+
# test defaults
2473+
assert da.clip(0, 1).attrs != da.attrs
2474+
assert (np.float64(1.0) * da).attrs != da.attrs
2475+
assert np.abs(da).attrs != da.attrs
2476+
assert abs(da).attrs != da.attrs
2477+
2478+
with set_options(keep_attrs=True):
2479+
assert da.clip(0, 1).attrs == da.attrs
2480+
assert (np.float64(1.0) * da).attrs == da.attrs
2481+
assert np.abs(da).attrs == da.attrs
2482+
assert abs(da).attrs == da.attrs
2483+
24612484
def test_fillna(self):
24622485
a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x")
24632486
actual = a.fillna(-1)

0 commit comments

Comments
 (0)