Skip to content

Commit 8309599

Browse files
josephnowakpre-commit-ci[bot]dcherian
authored
New algorithm for forward filling (#6118)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian <[email protected]>
1 parent be4b980 commit 8309599

File tree

4 files changed

+52
-31
lines changed

4 files changed

+52
-31
lines changed

doc/whats-new.rst

+5
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ New Features
2424
- New top-level function :py:func:`cross`. (:issue:`3279`, :pull:`5365`).
2525
By `Jimmy Westling <https://github.com/illviljan>`_.
2626

27+
- Enable the ``limit`` option for dask arrays in the following methods: :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill` and :py:meth:`Dataset.bfill` (:issue:`6112`).
28+
By `Joseph Nowak <https://github.com/josephnowak>`_.
2729

2830
Breaking changes
2931
~~~~~~~~~~~~~~~~
@@ -45,6 +47,9 @@ Deprecations
4547

4648
Bug fixes
4749
~~~~~~~~~
50+
- Properly support :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill` and :py:meth:`Dataset.bfill` along chunked dimensions (:issue:`6112`).
51+
By `Joseph Nowak <https://github.com/josephnowak>`_.
52+
4853
- Subclasses of ``byte`` and ``str`` (e.g. ``np.str_`` and ``np.bytes_``) will now serialise to disk rather than raising a ``ValueError: unsupported dtype for netCDF4 variable: object`` as they did previously (:pull:`5264`).
4954
By `Zeb Nicholls <https://github.com/znicholls>`_.
5055

xarray/core/dask_array_ops.py

+31-19
Original file line numberDiff line numberDiff line change
@@ -57,24 +57,36 @@ def push(array, n, axis):
5757
"""
5858
Dask-aware bottleneck.push
5959
"""
60-
from bottleneck import push
60+
import bottleneck
61+
import dask.array as da
62+
import numpy as np
6163

62-
if len(array.chunks[axis]) > 1 and n is not None and n < array.shape[axis]:
63-
raise NotImplementedError(
64-
"Cannot fill along a chunked axis when limit is not None."
65-
"Either rechunk to a single chunk along this axis or call .compute() or .load() first."
66-
)
67-
if all(c == 1 for c in array.chunks[axis]):
68-
array = array.rechunk({axis: 2})
69-
pushed = array.map_blocks(push, axis=axis, n=n, dtype=array.dtype, meta=array._meta)
70-
if len(array.chunks[axis]) > 1:
71-
pushed = pushed.map_overlap(
72-
push,
73-
axis=axis,
74-
n=n,
75-
depth={axis: (1, 0)},
76-
boundary="none",
77-
dtype=array.dtype,
78-
meta=array._meta,
64+
def _fill_with_last_one(a, b):
65+
# cumreduction applies the push func over all the blocks first, so the only missing part is filling
66+
# the missing values using the last data of the previous chunk
67+
return np.where(~np.isnan(b), b, a)
68+
69+
if n is not None and 0 < n < array.shape[axis] - 1:
70+
arange = da.broadcast_to(
71+
da.arange(
72+
array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype
73+
).reshape(
74+
tuple(size if i == axis else 1 for i, size in enumerate(array.shape))
75+
),
76+
array.shape,
77+
array.chunks,
7978
)
80-
return pushed
79+
valid_arange = da.where(da.notnull(array), arange, np.nan)
80+
valid_limits = (arange - push(valid_arange, None, axis)) <= n
81+
# omit the forward fills that violate the limit
82+
return da.where(valid_limits, push(array, None, axis), np.nan)
83+
84+
# The method parameter is omitted because it makes the tests for Python 3.7 fail.
85+
return da.reductions.cumreduction(
86+
func=bottleneck.push,
87+
binop=_fill_with_last_one,
88+
ident=np.nan,
89+
x=array,
90+
axis=axis,
91+
dtype=array.dtype,
92+
)

xarray/tests/test_duck_array_ops.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -884,16 +884,18 @@ def test_push_dask():
884884
import bottleneck
885885
import dask.array
886886

887-
array = np.array([np.nan, np.nan, np.nan, 1, 2, 3, np.nan, np.nan, 4, 5, np.nan, 6])
888-
expected = bottleneck.push(array, axis=0)
889-
for c in range(1, 11):
887+
array = np.array([np.nan, 1, 2, 3, np.nan, np.nan, np.nan, np.nan, 4, 5, np.nan, 6])
888+
889+
for n in [None, 1, 2, 3, 4, 5, 11]:
890+
expected = bottleneck.push(array, axis=0, n=n)
891+
for c in range(1, 11):
892+
with raise_if_dask_computes():
893+
actual = push(dask.array.from_array(array, chunks=c), axis=0, n=n)
894+
np.testing.assert_equal(actual, expected)
895+
896+
# some chunks of size-1 with NaN
890897
with raise_if_dask_computes():
891-
actual = push(dask.array.from_array(array, chunks=c), axis=0, n=None)
898+
actual = push(
899+
dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), axis=0, n=n
900+
)
892901
np.testing.assert_equal(actual, expected)
893-
894-
# some chunks of size-1 with NaN
895-
with raise_if_dask_computes():
896-
actual = push(
897-
dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), axis=0, n=None
898-
)
899-
np.testing.assert_equal(actual, expected)

xarray/tests/test_missing.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,10 @@ def test_ffill_bfill_dask(method):
452452
assert_equal(actual, expected)
453453

454454
# limit < axis size
455-
with pytest.raises(NotImplementedError):
455+
with raise_if_dask_computes():
456456
actual = dask_method("x", limit=2)
457+
expected = numpy_method("x", limit=2)
458+
assert_equal(actual, expected)
457459

458460
# limit > axis size
459461
with raise_if_dask_computes():

0 commit comments

Comments
 (0)