diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d0e2ef3bd59..a7ce49320df 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -91,6 +91,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- Properly support :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill`, :py:meth:`Dataset.bfill` along chunked dimensions. + (:issue:`2699`).By `Deepak Cherian `_. - Fix 2d plot failure for certain combinations of dimensions when `x` is 1d and `y` is 2d (:issue:`5097`, :pull:`5099`). By `John Omotani `_. - Ensure standard calendar times encoded with large values (i.e. greater than approximately 292 years), can be decoded correctly without silently overflowing (:pull:`5050`). This was a regression in xarray 0.17.0. By `Zeb Nicholls `_. diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 25f082ec3c5..87f67028862 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -51,3 +51,24 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): # See issue dask/dask#6516 coeffs, residuals, _, _ = da.linalg.lstsq(lhs_da, rhs) return coeffs, residuals + + +def push(array, n, axis): + """ + Dask-aware bottleneck.push + """ + from bottleneck import push + + if len(array.chunks[axis]) > 1 and n is not None and n < array.shape[axis]: + raise NotImplementedError( + "Cannot fill along a chunked axis when limit is not None." + "Either rechunk to a single chunk along this axis or call .compute() or .load() first." + ) + if all(c == 1 for c in array.chunks[axis]): + array = array.rechunk({axis: 2}) + pushed = array.map_blocks(push, axis=axis, n=n) + if len(array.chunks[axis]) > 1: + pushed = pushed.map_overlap( + push, axis=axis, n=n, depth={axis: (1, 0)}, boundary="none" + ) + return pushed diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 7a65c2b6494..635d4fe9e80 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2515,7 +2515,8 @@ def ffill(self, dim: Hashable, limit: int = None) -> "DataArray": The maximum number of consecutive NaN values to forward fill. In other words, if there is a gap with more than this number of consecutive NaNs, it will only be partially filled. Must be greater - than 0 or None for no limit. + than 0 or None for no limit. Must be None or greater than or equal + to axis length if filling along chunked axes (dimensions). Returns ------- @@ -2539,7 +2540,8 @@ def bfill(self, dim: Hashable, limit: int = None) -> "DataArray": The maximum number of consecutive NaN values to backward fill. In other words, if there is a gap with more than this number of consecutive NaNs, it will only be partially filled. Must be greater - than 0 or None for no limit. + than 0 or None for no limit. Must be None or greater than or equal + to axis length if filling along chunked axes (dimensions). Returns ------- diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 223e21a82e6..6638834f9d8 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4654,7 +4654,8 @@ def ffill(self, dim: Hashable, limit: int = None) -> "Dataset": The maximum number of consecutive NaN values to forward fill. In other words, if there is a gap with more than this number of consecutive NaNs, it will only be partially filled. Must be greater - than 0 or None for no limit. + than 0 or None for no limit. Must be None or greater than or equal + to axis length if filling along chunked axes (dimensions). Returns ------- @@ -4679,7 +4680,8 @@ def bfill(self, dim: Hashable, limit: int = None) -> "Dataset": The maximum number of consecutive NaN values to backward fill. In other words, if there is a gap with more than this number of consecutive NaNs, it will only be partially filled. Must be greater - than 0 or None for no limit. + than 0 or None for no limit. Must be None or greater than or equal + to axis length if filling along chunked axes (dimensions). Returns ------- diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 9dcd7906ef7..491f0925d73 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -631,3 +631,12 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): return dask_array_ops.least_squares(lhs, rhs, rcond=rcond, skipna=skipna) else: return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna) + + +def push(array, n, axis): + from bottleneck import push + + if is_duck_dask_array(array): + return dask_array_ops.push(array, n, axis) + else: + return push(array, n, axis) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index e6dd8b537a0..1407107a7be 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -11,7 +11,7 @@ from . import utils from .common import _contains_datetime_like_objects, ones_like from .computation import apply_ufunc -from .duck_array_ops import datetime_to_numeric, timedelta_to_numeric +from .duck_array_ops import datetime_to_numeric, push, timedelta_to_numeric from .options import _get_keep_attrs from .pycompat import is_duck_dask_array from .utils import OrderedSet, is_scalar @@ -390,12 +390,10 @@ def func_interpolate_na(interpolator, y, x, **kwargs): def _bfill(arr, n=None, axis=-1): """inverse of ffill""" - import bottleneck as bn - arr = np.flip(arr, axis=axis) # fill - arr = bn.push(arr, axis=axis, n=n) + arr = push(arr, axis=axis, n=n) # reverse back to original return np.flip(arr, axis=axis) @@ -403,17 +401,15 @@ def _bfill(arr, n=None, axis=-1): def ffill(arr, dim=None, limit=None): """forward fill missing values""" - import bottleneck as bn - axis = arr.get_axis_num(dim) # work around for bottleneck 178 _limit = limit if limit is not None else arr.shape[axis] return apply_ufunc( - bn.push, + push, arr, - dask="parallelized", + dask="allowed", keep_attrs=True, output_dtypes=[arr.dtype], kwargs=dict(n=_limit, axis=axis), @@ -430,7 +426,7 @@ def bfill(arr, dim=None, limit=None): return apply_ufunc( _bfill, arr, - dask="parallelized", + dask="allowed", keep_attrs=True, output_dtypes=[arr.dtype], kwargs=dict(n=_limit, axis=axis), diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 1dd26bab6b6..d1ee1c14052 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -20,6 +20,7 @@ mean, np_timedelta64_to_float, pd_timedelta_to_float, + push, py_timedelta_to_float, stack, timedelta_to_numeric, @@ -34,6 +35,7 @@ has_dask, has_scipy, raise_if_dask_computes, + requires_bottleneck, requires_cftime, requires_dask, ) @@ -858,3 +860,26 @@ def test_least_squares(use_dask, skipna): np.testing.assert_allclose(coeffs, [1.5, 1.25]) np.testing.assert_allclose(residuals, [2.0]) + + +@requires_dask +@requires_bottleneck +def test_push_dask(): + import bottleneck + import dask.array + + array = np.array([np.nan, np.nan, np.nan, 1, 2, 3, np.nan, np.nan, 4, 5, np.nan, 6]) + expected = bottleneck.push(array, axis=0) + for c in range(1, 11): + with raise_if_dask_computes(): + actual = push(dask.array.from_array(array, chunks=c), axis=0, n=None) + np.testing.assert_equal(actual, expected) + + # some chunks of size-1 with NaN + with raise_if_dask_computes(): + actual = push( + dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), + axis=0, + n=None, + ) + np.testing.assert_equal(actual, expected) diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index 202cf99715d..e4c74b40ec0 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -17,6 +17,7 @@ assert_allclose, assert_array_equal, assert_equal, + raise_if_dask_computes, requires_bottleneck, requires_cftime, requires_dask, @@ -393,37 +394,39 @@ def test_ffill(): @requires_bottleneck @requires_dask -def test_ffill_dask(): +@pytest.mark.parametrize("method", ["ffill", "bfill"]) +def test_ffill_bfill_dask(method): da, _ = make_interpolate_example_data((40, 40), 0.5) da = da.chunk({"x": 5}) - actual = da.ffill("time") - expected = da.load().ffill("time") - assert isinstance(actual.data, dask_array_type) - assert_equal(actual, expected) - # with limit - da = da.chunk({"x": 5}) - actual = da.ffill("time", limit=3) - expected = da.load().ffill("time", limit=3) - assert isinstance(actual.data, dask_array_type) + dask_method = getattr(da, method) + numpy_method = getattr(da.compute(), method) + # unchunked axis + with raise_if_dask_computes(): + actual = dask_method("time") + expected = numpy_method("time") assert_equal(actual, expected) - -@requires_bottleneck -@requires_dask -def test_bfill_dask(): - da, _ = make_interpolate_example_data((40, 40), 0.5) - da = da.chunk({"x": 5}) - actual = da.bfill("time") - expected = da.load().bfill("time") - assert isinstance(actual.data, dask_array_type) + # chunked axis + with raise_if_dask_computes(): + actual = dask_method("x") + expected = numpy_method("x") assert_equal(actual, expected) # with limit - da = da.chunk({"x": 5}) - actual = da.bfill("time", limit=3) - expected = da.load().bfill("time", limit=3) - assert isinstance(actual.data, dask_array_type) + with raise_if_dask_computes(): + actual = dask_method("time", limit=3) + expected = numpy_method("time", limit=3) + assert_equal(actual, expected) + + # limit < axis size + with pytest.raises(NotImplementedError): + actual = dask_method("x", limit=2) + + # limit > axis size + with raise_if_dask_computes(): + actual = dask_method("x", limit=41) + expected = numpy_method("x", limit=41) assert_equal(actual, expected)