Skip to content

Commit dafd726

Browse files
authored
Add min_weight param to rolling_exp functions (#8285)
* Add `min_weight` param to `rolling_exp` functions * whatsnew
1 parent 8f7e8b5 commit dafd726

File tree

2 files changed

+35
-24
lines changed

2 files changed

+35
-24
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ New Features
2626
the ``other`` parameter, passing the object as the only argument. Previously,
2727
this was only valid for the ``cond`` parameter. (:issue:`8255`)
2828
By `Maximilian Roos <https://github.com/max-sixty>`_.
29+
- ``.rolling_exp`` functions can now take a ``min_weight`` parameter, to only
30+
output values when there are sufficient recent non-nan values.
31+
``numbagg>=0.3.1`` is required. (:pull:`8285`)
32+
By `Maximilian Roos <https://github.com/max-sixty>`_.
2933
- :py:meth:`DataArray.sortby` & :py:meth:`Dataset.sortby` accept a callable for
3034
the ``variables`` parameter, passing the object as the only argument.
3135
By `Maximilian Roos <https://github.com/max-sixty>`_.

xarray/core/rolling_exp.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,15 @@
88
from xarray.core.computation import apply_ufunc
99
from xarray.core.options import _get_keep_attrs
1010
from xarray.core.pdcompat import count_not_none
11-
from xarray.core.pycompat import is_duck_dask_array
12-
from xarray.core.types import T_DataWithCoords, T_DuckArray
11+
from xarray.core.types import T_DataWithCoords
12+
13+
try:
14+
import numbagg
15+
from numbagg import move_exp_nanmean, move_exp_nansum
16+
17+
has_numbagg = numbagg.__version__
18+
except ImportError:
19+
has_numbagg = False
1320

1421

1522
def _get_alpha(
@@ -25,26 +32,6 @@ def _get_alpha(
2532
return 1 / (1 + com)
2633

2734

28-
def move_exp_nanmean(array: T_DuckArray, *, axis: int, alpha: float) -> np.ndarray:
29-
if is_duck_dask_array(array):
30-
raise TypeError("rolling_exp is not currently support for dask-like arrays")
31-
import numbagg
32-
33-
# No longer needed in numbag > 0.2.0; remove in time
34-
if axis == ():
35-
return array.astype(np.float64)
36-
else:
37-
return numbagg.move_exp_nanmean(array, axis=axis, alpha=alpha)
38-
39-
40-
def move_exp_nansum(array: T_DuckArray, *, axis: int, alpha: float) -> np.ndarray:
41-
if is_duck_dask_array(array):
42-
raise TypeError("rolling_exp is not currently supported for dask-like arrays")
43-
import numbagg
44-
45-
return numbagg.move_exp_nansum(array, axis=axis, alpha=alpha)
46-
47-
4835
def _get_center_of_mass(
4936
comass: float | None,
5037
span: float | None,
@@ -110,11 +97,31 @@ def __init__(
11097
obj: T_DataWithCoords,
11198
windows: Mapping[Any, int | float],
11299
window_type: str = "span",
100+
min_weight: float = 0.0,
113101
):
102+
if has_numbagg is False:
103+
raise ImportError(
104+
"numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed"
105+
)
106+
elif has_numbagg < "0.2.1":
107+
raise ImportError(
108+
f"numbagg >= 0.2.1 is required for rolling_exp but currently version {has_numbagg} is installed"
109+
)
110+
elif has_numbagg < "0.3.1" and min_weight > 0:
111+
raise ImportError(
112+
f"numbagg >= 0.3.1 is required for `min_weight > 0` but currently version {has_numbagg} is installed"
113+
)
114+
114115
self.obj: T_DataWithCoords = obj
115116
dim, window = next(iter(windows.items()))
116117
self.dim = dim
117118
self.alpha = _get_alpha(**{window_type: window})
119+
self.min_weight = min_weight
120+
# Don't pass min_weight=0 so we can support older versions of numbagg
121+
kwargs = dict(alpha=self.alpha, axis=-1)
122+
if min_weight > 0:
123+
kwargs["min_weight"] = min_weight
124+
self.kwargs = kwargs
118125

119126
def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
120127
"""
@@ -145,7 +152,7 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
145152
move_exp_nanmean,
146153
self.obj,
147154
input_core_dims=[[self.dim]],
148-
kwargs=dict(alpha=self.alpha, axis=-1),
155+
kwargs=self.kwargs,
149156
output_core_dims=[[self.dim]],
150157
keep_attrs=keep_attrs,
151158
on_missing_core_dim="copy",
@@ -181,7 +188,7 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
181188
move_exp_nansum,
182189
self.obj,
183190
input_core_dims=[[self.dim]],
184-
kwargs=dict(alpha=self.alpha, axis=-1),
191+
kwargs=self.kwargs,
185192
output_core_dims=[[self.dim]],
186193
keep_attrs=keep_attrs,
187194
on_missing_core_dim="copy",

0 commit comments

Comments
 (0)