
Commit d7ac79a

josephnowak, pre-commit-ci[bot], and dcherian authored
Fix the push method when the limit parameter is bigger than the chunk… (#9940)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <[email protected]>
1 parent 5fdceff commit d7ac79a

File tree

3 files changed: +43 -55 lines changed

doc/whats-new.rst (+3)
xarray/core/dask_array_ops.py (+13 -37)
xarray/tests/test_duck_array_ops.py (+27 -18)

doc/whats-new.rst (+3)

@@ -95,6 +95,9 @@ Deprecations
 
 Bug fixes
 ~~~~~~~~~
+- Fix :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill` and
+  :py:meth:`Dataset.bfill` when the limit is bigger than the chunksize (:issue:`9939`).
+  By `Joseph Nowak <https://github.com/josephnowak>`_.
 - Fix issues related to Pandas v3 ("us" vs. "ns" for python datetime, copy on write) and handling of 0d-numpy arrays in datetime/timedelta decoding (:pull:`9953`).
   By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
 - Remove dask-expr from CI runs, add "pyarrow" dask dependency to windows CI runs, fix related tests (:issue:`9962`, :pull:`9971`).

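For context on the user-facing behavior this entry describes: forward/backward filling a dask-backed object with a limit larger than the chunk size. Below is a minimal sketch with made-up values, assuming dask and bottleneck are installed; it is not taken from the commit itself.

import numpy as np
import xarray as xr

# Hypothetical data: four consecutive NaNs, chunked into pieces of two,
# so limit=3 has to propagate across chunk boundaries.
arr = xr.DataArray([1.0, np.nan, np.nan, np.nan, np.nan, 6.0], dims="x").chunk({"x": 2})

filled = arr.ffill(dim="x", limit=3).compute()
print(filled.values)  # [ 1.  1.  1.  1. nan  6.]  (the fourth NaN exceeds the limit)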
xarray/core/dask_array_ops.py (+13 -37)

@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import math
-from functools import partial
 
 from xarray.core import dtypes, nputils
 
@@ -92,31 +91,6 @@ def _dtype_push(a, axis, dtype=None):
     return _push(a, axis=axis)
 
 
-def _reset_cumsum(a, axis, dtype=None):
-    import numpy as np
-
-    cumsum = np.cumsum(a, axis=axis)
-    reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis)
-    return cumsum - reset_points
-
-
-def _last_reset_cumsum(a, axis, keepdims=None):
-    import numpy as np
-
-    # Take the last cumulative sum taking into account the reset
-    # This is useful for blelloch method
-    return np.take(_reset_cumsum(a, axis=axis), axis=axis, indices=[-1])
-
-
-def _combine_reset_cumsum(a, b, axis):
-    import numpy as np
-
-    # It is going to sum the previous result until the first
-    # non nan value
-    bitmask = np.cumprod(b != 0, axis=axis)
-    return np.where(bitmask, b + a, b)
-
-
 def push(array, n, axis, method="blelloch"):
     """
     Dask-aware bottleneck.push
@@ -145,16 +119,18 @@ def push(array, n, axis, method="blelloch"):
     )
 
     if n is not None and 0 < n < array.shape[axis] - 1:
-        valid_positions = da.reductions.cumreduction(
-            func=_reset_cumsum,
-            binop=partial(_combine_reset_cumsum, axis=axis),
-            ident=0,
-            x=da.isnan(array, dtype=int),
-            axis=axis,
-            dtype=int,
-            method=method,
-            preop=_last_reset_cumsum,
-        )
-        pushed_array = da.where(valid_positions <= n, pushed_array, np.nan)
+        # The idea is to calculate a cumulative sum of a bitmask
+        # created from the isnan method, but every time a False is found the sum
+        # must be restarted, and the final result indicates the amount of contiguous
+        # nan values found in the original array on every position
+        nan_bitmask = da.isnan(array, dtype=int)
+        cumsum_nan = nan_bitmask.cumsum(axis=axis, method=method)
+        valid_positions = da.where(nan_bitmask == 0, cumsum_nan, np.nan)
+        valid_positions = push(valid_positions, None, axis, method=method)
+        # All the NaNs at the beginning are converted to 0
+        valid_positions = da.nan_to_num(valid_positions)
+        valid_positions = cumsum_nan - valid_positions
+        valid_positions = valid_positions <= n
+        pushed_array = da.where(valid_positions, pushed_array, np.nan)
 
     return pushed_array

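To make the new approach in push easier to follow, here is a plain-NumPy sketch of the same counting idea for a 1-D array. The helper name contiguous_nan_count is hypothetical, and np.maximum.accumulate stands in for the recursive push(valid_positions, None, axis) forward fill used above; that substitution works here because the cumulative sum is non-decreasing.

import numpy as np

def contiguous_nan_count(arr):
    # For each position, count how many consecutive NaNs have occurred
    # since the last non-NaN value (0 at valid positions).
    nan_bitmask = np.isnan(arr).astype(int)
    cumsum_nan = np.cumsum(nan_bitmask)
    # Cumulative NaN count observed at the last valid position; NaN elsewhere.
    at_valid = np.where(nan_bitmask == 0, cumsum_nan, np.nan)
    # Forward-fill via a running maximum; leading NaNs become 0.
    filled = np.maximum.accumulate(np.nan_to_num(at_valid))
    return cumsum_nan - filled

arr = np.array([np.nan, 1.0, np.nan, np.nan, np.nan, 5.0, np.nan])
print(contiguous_nan_count(arr))  # [1. 0. 1. 2. 3. 0. 1.]

Masking the pushed array wherever this count exceeds n is what enforces the limit independently of how the array is chunked.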
xarray/tests/test_duck_array_ops.py (+27 -18)

@@ -1025,31 +1025,40 @@ def test_least_squares(use_dask, skipna):
 @requires_dask
 @requires_bottleneck
 @pytest.mark.parametrize("method", ["sequential", "blelloch"])
-def test_push_dask(method):
+@pytest.mark.parametrize(
+    "arr",
+    [
+        [np.nan, 1, 2, 3, np.nan, np.nan, np.nan, np.nan, 4, 5, np.nan, 6],
+        [
+            np.nan,
+            np.nan,
+            np.nan,
+            2,
+            np.nan,
+            np.nan,
+            np.nan,
+            9,
+            np.nan,
+            np.nan,
+            np.nan,
+            np.nan,
+        ],
+    ],
+)
+def test_push_dask(method, arr):
     import bottleneck
-    import dask.array
+    import dask.array as da
 
-    array = np.array([np.nan, 1, 2, 3, np.nan, np.nan, np.nan, np.nan, 4, 5, np.nan, 6])
+    arr = np.array(arr)
+    chunks = list(range(1, 11)) + [(1, 2, 3, 2, 2, 1, 1)]
 
     for n in [None, 1, 2, 3, 4, 5, 11]:
-        expected = bottleneck.push(array, axis=0, n=n)
-        for c in range(1, 11):
+        expected = bottleneck.push(arr, axis=0, n=n)
+        for c in chunks:
             with raise_if_dask_computes():
-                actual = push(
-                    dask.array.from_array(array, chunks=c), axis=0, n=n, method=method
-                )
+                actual = push(da.from_array(arr, chunks=c), axis=0, n=n, method=method)
             np.testing.assert_equal(actual, expected)
 
-        # some chunks of size-1 with NaN
-        with raise_if_dask_computes():
-            actual = push(
-                dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)),
-                axis=0,
-                n=n,
-                method=method,
-            )
-            np.testing.assert_equal(actual, expected)
-
 
 def test_extension_array_equality(categorical1, int1):
     int_duck_array = PandasExtensionArray(int1)

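The test above compares the dask result against bottleneck for many chunkings; the same check can be reproduced in isolation. Below is a hedged sketch with hypothetical values, assuming dask and bottleneck are installed and using the internal helper xarray.core.dask_array_ops.push shown in the diff above.

import numpy as np
import dask.array as da
import bottleneck

from xarray.core.dask_array_ops import push

# Three consecutive NaNs but chunks of size two, so n=3 crosses a chunk
# boundary, which is exactly the case this commit fixes.
arr = np.array([1.0, np.nan, np.nan, np.nan, 5.0, np.nan])
chunked = da.from_array(arr, chunks=2)

expected = bottleneck.push(arr, axis=0, n=3)
actual = push(chunked, n=3, axis=0).compute()
np.testing.assert_equal(actual, expected)  # both give [1., 1., 1., 1., 5., 5.]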
0 commit comments