Skip to content

Commit 993300b

Browse files
rolling.construct: Add sliding_window_kwargs to pipe arguments down to sliding_window_view (#9720)
* sliding_window_view: add new `automatic_rechunk` kwarg Closes #9550 xref #4325 * Switch to ``sliding_window_kwargs`` * Add one more * better docstring * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Rename to sliding_window_view_kwargs --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3f0ddc1 commit 993300b

File tree

6 files changed

+230
-27
lines changed

6 files changed

+230
-27
lines changed

doc/whats-new.rst

+3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ New Features
2929
- Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])``
3030
(:issue:`2852`, :issue:`757`).
3131
By `Deepak Cherian <https://github.com/dcherian>`_.
32+
- Add new ``automatic_rechunk`` kwarg to :py:meth:`DataArrayRolling.construct` and
33+
:py:meth:`DatasetRolling.construct`. This is only useful on ``dask>=2024.11.0``
34+
(:issue:`9550`). By `Deepak Cherian <https://github.com/dcherian>`_.
3235
- Optimize ffill, bfill with dask when limit is specified
3336
(:pull:`9771`).
3437
By `Joseph Nowak <https://github.com/josephnowak>`_, and

xarray/core/dask_array_compat.py

+16
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,19 @@ def reshape_blockwise(
1414
return reshape_blockwise(x, shape=shape, chunks=chunks)
1515
else:
1616
return x.reshape(shape)
17+
18+
19+
def sliding_window_view(
20+
x, window_shape, axis=None, *, automatic_rechunk=True, **kwargs
21+
):
22+
# Backcompat for handling `automatic_rechunk`, delete when dask>=2024.11.0
23+
# Note that subok, writeable are unsupported by dask, so we ignore those in kwargs
24+
from dask.array.lib.stride_tricks import sliding_window_view
25+
26+
if module_available("dask", "2024.11.0"):
27+
return sliding_window_view(
28+
x, window_shape=window_shape, axis=axis, automatic_rechunk=automatic_rechunk
29+
)
30+
else:
31+
# automatic_rechunk is not supported
32+
return sliding_window_view(x, window_shape=window_shape, axis=axis)

xarray/core/duck_array_ops.py

+24-9
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,10 @@
3030
transpose,
3131
unravel_index,
3232
)
33-
from numpy.lib.stride_tricks import sliding_window_view # noqa: F401
3433
from packaging.version import Version
3534
from pandas.api.types import is_extension_array_dtype
3635

37-
from xarray.core import dask_array_ops, dtypes, nputils
36+
from xarray.core import dask_array_compat, dask_array_ops, dtypes, nputils
3837
from xarray.core.options import OPTIONS
3938
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
4039
from xarray.namedarray import pycompat
@@ -92,19 +91,25 @@ def _dask_or_eager_func(
9291
name,
9392
eager_module=np,
9493
dask_module="dask.array",
94+
dask_only_kwargs=tuple(),
95+
numpy_only_kwargs=tuple(),
9596
):
9697
"""Create a function that dispatches to dask for dask array inputs."""
9798

9899
def f(*args, **kwargs):
99-
if any(is_duck_dask_array(a) for a in args):
100+
if dask_available and any(is_duck_dask_array(a) for a in args):
100101
mod = (
101102
import_module(dask_module)
102103
if isinstance(dask_module, str)
103104
else dask_module
104105
)
105106
wrapped = getattr(mod, name)
107+
for kwarg in numpy_only_kwargs:
108+
kwargs.pop(kwarg, None)
106109
else:
107110
wrapped = getattr(eager_module, name)
111+
for kwarg in dask_only_kwargs:
112+
kwargs.pop(kwarg, None)
108113
return wrapped(*args, **kwargs)
109114

110115
return f
@@ -122,6 +127,22 @@ def fail_on_dask_array_input(values, msg=None, func_name=None):
122127
# Requires special-casing because pandas won't automatically dispatch to dask.isnull via NEP-18
123128
pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd, dask_module="dask.array")
124129

130+
# TODO replace with simply np.ma.masked_invalid once numpy/numpy#16022 is fixed
131+
# TODO: replacing breaks iris + dask tests
132+
masked_invalid = _dask_or_eager_func(
133+
"masked_invalid", eager_module=np.ma, dask_module="dask.array.ma"
134+
)
135+
136+
# sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk),
137+
# so we need to hand-code this.
138+
sliding_window_view = _dask_or_eager_func(
139+
"sliding_window_view",
140+
eager_module=np.lib.stride_tricks,
141+
dask_module=dask_array_compat,
142+
dask_only_kwargs=("automatic_rechunk",),
143+
numpy_only_kwargs=("subok", "writeable"),
144+
)
145+
125146

126147
def round(array):
127148
xp = get_array_namespace(array)
@@ -170,12 +191,6 @@ def notnull(data):
170191
return ~isnull(data)
171192

172193

173-
# TODO replace with simply np.ma.masked_invalid once numpy/numpy#16022 is fixed
174-
masked_invalid = _dask_or_eager_func(
175-
"masked_invalid", eager_module=np.ma, dask_module="dask.array.ma"
176-
)
177-
178-
179194
def trapz(y, x, axis):
180195
if axis < 0:
181196
axis = y.ndim + axis

0 commit comments

Comments
 (0)