diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 60364a87fd0..331c3baf89e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,6 +38,8 @@ New Features By `Stephan Hoyer <https://github.com/shoyer>`_. - Added typehints in :py:func:`align` to reflect that the same type received in ``objects`` arg will be returned (:pull:`4522`). By `Michal Baumgartner <https://github.com/m1so>`_. +- :py:meth:`Dataset.weighted` and :py:meth:`DataArray.weighted` now execute value checks lazily if weights are provided as dask arrays (:issue:`4541`, :pull:`4559`). + By `Julius Busecke <https://github.com/jbusecke>`_. Bug fixes ~~~~~~~~~ diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 96b4c79f245..ab4a0958866 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,7 +1,9 @@ from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload +from . import duck_array_ops from .computation import dot from .options import _get_keep_attrs +from .pycompat import is_duck_dask_array if TYPE_CHECKING: from .dataarray import DataArray, Dataset @@ -100,12 +102,24 @@ def __init__(self, obj, weights): if not isinstance(weights, DataArray): raise ValueError("`weights` must be a DataArray") - if weights.isnull().any(): - raise ValueError( - "`weights` cannot contain missing values. " - "Missing values can be replaced by `weights.fillna(0)`." + def _weight_check(w): + # Ref https://github.com/pydata/xarray/pull/4559/files#r515968670 + if duck_array_ops.isnull(w).any(): + raise ValueError( + "`weights` cannot contain missing values. " + "Missing values can be replaced by `weights.fillna(0)`." 
+ ) + return w + + if is_duck_dask_array(weights.data): + # assign to copy - else the check is not triggered + weights = weights.copy( + data=weights.data.map_blocks(_weight_check, dtype=weights.dtype) ) + else: + _weight_check(weights.data) + self.obj = obj self.weights = weights diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py index 48fad296664..2366b982cec 100644 --- a/xarray/tests/test_weighted.py +++ b/xarray/tests/test_weighted.py @@ -5,6 +5,8 @@ from xarray import DataArray from xarray.tests import assert_allclose, assert_equal, raises_regex +from . import raise_if_dask_computes, requires_dask + @pytest.mark.parametrize("as_dataset", (True, False)) def test_weighted_non_DataArray_weights(as_dataset): @@ -29,6 +31,24 @@ def test_weighted_weights_nan_raises(as_dataset, weights): data.weighted(DataArray(weights)) +@requires_dask +@pytest.mark.parametrize("as_dataset", (True, False)) +@pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan])) +def test_weighted_weights_nan_raises_dask(as_dataset, weights): + + data = DataArray([1, 2]).chunk({"dim_0": -1}) + if as_dataset: + data = data.to_dataset(name="data") + + weights = DataArray(weights).chunk({"dim_0": -1}) + + with raise_if_dask_computes(): + weighted = data.weighted(weights) + + with pytest.raises(ValueError, match="`weights` cannot contain missing values."): + weighted.sum().load() + + @pytest.mark.parametrize( ("weights", "expected"), (([1, 2], 3), ([2, 0], 2), ([0, 0], np.nan), ([-1, 1], np.nan)),