Skip to content

Propagate attrs with unary, binary functions #4195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Oct 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ New Features
- :py:func:`open_dataset` and :py:func:`open_mfdataset`
now works with ``engine="zarr"`` (:issue:`3668`, :pull:`4003`, :pull:`4187`).
By `Miguel Jimenez <https://github.com/Mikejmnez>`_ and `Wei Ji Leong <https://github.com/weiji14>`_.
- Unary & binary operations follow the ``keep_attrs`` flag (:issue:`3490`, :issue:`4065`, :issue:`3433`, :issue:`3595`, :pull:`4195`).
By `Deepak Cherian <https://github.com/dcherian>`_.

Bug fixes
~~~~~~~~~
Expand Down
3 changes: 2 additions & 1 deletion xarray/core/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from .options import OPTIONS
from .options import OPTIONS, _get_keep_attrs
from .pycompat import dask_array_type
from .utils import not_implemented

Expand Down Expand Up @@ -77,6 +77,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
dataset_fill_value=np.nan,
kwargs=kwargs,
dask="allowed",
keep_attrs=_get_keep_attrs(default=True),
)

# this has no runtime function - these are listed so IDEs know these
Expand Down
38 changes: 30 additions & 8 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@
_JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"})


def _first_of_type(args, kind):
""" Return either first object of type 'kind' or raise if not found. """
for arg in args:
if isinstance(arg, kind):
return arg
raise ValueError("This should be unreachable.")


class _UFuncSignature:
"""Core dimensions signature for a given function.

Expand Down Expand Up @@ -252,8 +260,9 @@ def apply_dataarray_vfunc(
args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
)

if keep_attrs and hasattr(args[0], "name"):
name = args[0].name
if keep_attrs:
first_obj = _first_of_type(args, DataArray)
name = first_obj.name
else:
name = result_name(args)
result_coords = build_output_coords(args, signature, exclude_dims)
Expand All @@ -270,6 +279,14 @@ def apply_dataarray_vfunc(
(coords,) = result_coords
out = DataArray(result_var, coords, name=name, fastpath=True)

if keep_attrs:
if isinstance(out, tuple):
for da in out:
# This is adding attrs in place
da._copy_attrs_from(first_obj)
else:
out._copy_attrs_from(first_obj)

return out


Expand Down Expand Up @@ -390,15 +407,16 @@ def apply_dataset_vfunc(
"""
from .dataset import Dataset

first_obj = args[0] # we'll copy attrs from this in case keep_attrs=True

if dataset_join not in _JOINS_WITHOUT_FILL_VALUES and fill_value is _NO_FILL_VALUE:
raise TypeError(
"to apply an operation to datasets with different "
"data variables with apply_ufunc, you must supply the "
"dataset_fill_value argument."
)

if keep_attrs:
first_obj = _first_of_type(args, Dataset)

if len(args) > 1:
args = deep_align(
args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
Expand All @@ -417,9 +435,11 @@ def apply_dataset_vfunc(
(coord_vars,) = list_of_coords
out = _fast_dataset(result_vars, coord_vars)

if keep_attrs and isinstance(first_obj, Dataset):
if keep_attrs:
if isinstance(out, tuple):
out = tuple(ds._copy_attrs_from(first_obj) for ds in out)
for ds in out:
# This is adding attrs in place
ds._copy_attrs_from(first_obj)
else:
out._copy_attrs_from(first_obj)
return out
Expand Down Expand Up @@ -595,6 +615,8 @@ def apply_variable_ufunc(
"""Apply a ndarray level function over Variable and/or ndarray objects."""
from .variable import Variable, as_compatible_data

first_obj = _first_of_type(args, Variable)

dim_sizes = unified_dim_sizes(
(a for a in args if hasattr(a, "dims")), exclude_dims=exclude_dims
)
Expand Down Expand Up @@ -734,8 +756,8 @@ def func(*arrays):
)
)

if keep_attrs and isinstance(args[0], Variable):
var.attrs.update(args[0].attrs)
if keep_attrs:
var.attrs.update(first_obj.attrs)
output.append(var)

if signature.num_outputs == 1:
Expand Down
10 changes: 8 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from .indexes import Indexes, default_indexes, propagate_indexes
from .indexing import is_fancy_indexer
from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords
from .options import OPTIONS
from .options import OPTIONS, _get_keep_attrs
from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs
from .variable import (
IndexVariable,
Expand Down Expand Up @@ -2734,13 +2734,19 @@ def __rmatmul__(self, other):
def _unary_op(f: Callable[..., Any]) -> Callable[..., "DataArray"]:
@functools.wraps(f)
def func(self, *args, **kwargs):
keep_attrs = kwargs.pop("keep_attrs", None)
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
warnings.filterwarnings(
"ignore", r"Mean of empty slice", category=RuntimeWarning
)
with np.errstate(all="ignore"):
return self.__array_wrap__(f(self.variable.data, *args, **kwargs))
da = self.__array_wrap__(f(self.variable.data, *args, **kwargs))
if keep_attrs:
da.attrs = self.attrs
return da

return func

Expand Down
18 changes: 13 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4394,12 +4394,15 @@ def map(
foo (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 0.9773
bar (x) float64 1.0 2.0
"""
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)
variables = {
k: maybe_wrap_array(v, func(v, *args, **kwargs))
for k, v in self.data_vars.items()
}
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)
if keep_attrs:
for k, v in variables.items():
v._copy_attrs_from(self.data_vars[k])
attrs = self.attrs if keep_attrs else None
return type(self)(variables, attrs=attrs)

Expand Down Expand Up @@ -4930,15 +4933,20 @@ def from_dict(cls, d):
return obj

@staticmethod
def _unary_op(f, keep_attrs=False):
def _unary_op(f):
@functools.wraps(f)
def func(self, *args, **kwargs):
variables = {}
keep_attrs = kwargs.pop("keep_attrs", None)
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)
for k, v in self._variables.items():
if k in self._coord_names:
variables[k] = v
else:
variables[k] = f(v, *args, **kwargs)
if keep_attrs:
variables[k].attrs = v._attrs
attrs = self._attrs if keep_attrs else None
return self._replace_with_new_dims(variables, attrs=attrs)

Expand Down Expand Up @@ -5677,11 +5685,11 @@ def _integrate_one(self, coord, datetime_unit=None):

@property
def real(self):
return self._unary_op(lambda x: x.real, keep_attrs=True)(self)
return self.map(lambda x: x.real, keep_attrs=True)

@property
def imag(self):
return self._unary_op(lambda x: x.imag, keep_attrs=True)(self)
return self.map(lambda x: x.imag, keep_attrs=True)

plot = utils.UncachedAccessor(_Dataset_PlotMethods)

Expand Down
2 changes: 1 addition & 1 deletion xarray/core/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _get_keep_attrs(default):
return global_choice
else:
raise ValueError(
"The global option keep_attrs must be one of" " True, False or 'default'."
"The global option keep_attrs must be one of True, False or 'default'."
)


Expand Down
8 changes: 7 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2111,8 +2111,14 @@ def __array_wrap__(self, obj, context=None):
def _unary_op(f):
@functools.wraps(f)
def func(self, *args, **kwargs):
keep_attrs = kwargs.pop("keep_attrs", None)
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)
with np.errstate(all="ignore"):
return self.__array_wrap__(f(self.data, *args, **kwargs))
result = self.__array_wrap__(f(self.data, *args, **kwargs))
if keep_attrs:
result.attrs = self.attrs
return result

return func

Expand Down
25 changes: 24 additions & 1 deletion xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@
import pytest

import xarray as xr
from xarray import DataArray, Dataset, IndexVariable, Variable, align, broadcast
from xarray import (
DataArray,
Dataset,
IndexVariable,
Variable,
align,
broadcast,
set_options,
)
from xarray.coding.times import CFDatetimeCoder
from xarray.convert import from_cdms2
from xarray.core import dtypes
Expand Down Expand Up @@ -2486,6 +2494,21 @@ def test_assign_attrs(self):
assert_identical(new_actual, expected)
assert actual.attrs == {"a": 1, "b": 2}

@pytest.mark.parametrize(
"func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs]
)
def test_propagate_attrs(self, func):
da = DataArray(self.va)

# test defaults
assert func(da).attrs == da.attrs

with set_options(keep_attrs=False):
assert func(da).attrs == {}

with set_options(keep_attrs=True):
assert func(da).attrs == da.attrs

def test_fillna(self):
a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x")
actual = a.fillna(-1)
Expand Down
22 changes: 22 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4473,6 +4473,28 @@ def test_fillna(self):
assert actual.a.name == "a"
assert actual.a.attrs == ds.a.attrs

@pytest.mark.parametrize(
"func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs]
)
def test_propagate_attrs(self, func):

da = DataArray(range(5), name="a", attrs={"attr": "da"})
ds = Dataset({"a": da}, attrs={"attr": "ds"})

# test defaults
assert func(ds).attrs == ds.attrs
with set_options(keep_attrs=False):
assert func(ds).attrs != ds.attrs
assert func(ds).a.attrs != ds.a.attrs

with set_options(keep_attrs=False):
assert func(ds).attrs != ds.attrs
assert func(ds).a.attrs != ds.a.attrs

with set_options(keep_attrs=True):
assert func(ds).attrs == ds.attrs
assert func(ds).a.attrs == ds.a.attrs

def test_where(self):
ds = Dataset({"a": ("x", range(5))})
expected = Dataset({"a": ("x", [np.nan, np.nan, 2, 3, 4])})
Expand Down
3 changes: 2 additions & 1 deletion xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,8 @@ def test_1d_math(self):
assert_array_equal(y - v, 1 - v)
# verify attributes are dropped
v2 = self.cls(["x"], x, {"units": "meters"})
assert_identical(base_v, +v2)
with set_options(keep_attrs=False):
assert_identical(base_v, +v2)
# binary ops with all variables
assert_array_equal(v + v, 2 * v)
w = self.cls(["x"], y, {"foo": "bar"})
Expand Down
1 change: 0 additions & 1 deletion xarray/tests/test_weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ def test_weighted_operations_keep_attr(operation, as_dataset, keep_attrs):
assert not result.attrs


@pytest.mark.xfail(reason="xr.Dataset.map does not copy attrs of DataArrays GH: 3595")
@pytest.mark.parametrize("operation", ("sum", "mean"))
def test_weighted_operations_keep_attr_da_in_ds(operation):
# GH #3595
Expand Down