8
8
from xarray .core .computation import apply_ufunc
9
9
from xarray .core .options import _get_keep_attrs
10
10
from xarray .core .pdcompat import count_not_none
11
- from xarray .core .pycompat import is_duck_dask_array
12
- from xarray .core .types import T_DataWithCoords , T_DuckArray
11
+ from xarray .core .types import T_DataWithCoords
12
+
13
+ try :
14
+ import numbagg
15
+ from numbagg import move_exp_nanmean , move_exp_nansum
16
+
17
+ has_numbagg = numbagg .__version__
18
+ except ImportError :
19
+ has_numbagg = False
13
20
14
21
15
22
def _get_alpha (
@@ -25,26 +32,6 @@ def _get_alpha(
25
32
return 1 / (1 + com )
26
33
27
34
28
- def move_exp_nanmean (array : T_DuckArray , * , axis : int , alpha : float ) -> np .ndarray :
29
- if is_duck_dask_array (array ):
30
- raise TypeError ("rolling_exp is not currently support for dask-like arrays" )
31
- import numbagg
32
-
33
- # No longer needed in numbag > 0.2.0; remove in time
34
- if axis == ():
35
- return array .astype (np .float64 )
36
- else :
37
- return numbagg .move_exp_nanmean (array , axis = axis , alpha = alpha )
38
-
39
-
40
- def move_exp_nansum (array : T_DuckArray , * , axis : int , alpha : float ) -> np .ndarray :
41
- if is_duck_dask_array (array ):
42
- raise TypeError ("rolling_exp is not currently supported for dask-like arrays" )
43
- import numbagg
44
-
45
- return numbagg .move_exp_nansum (array , axis = axis , alpha = alpha )
46
-
47
-
48
35
def _get_center_of_mass (
49
36
comass : float | None ,
50
37
span : float | None ,
@@ -110,11 +97,31 @@ def __init__(
110
97
obj : T_DataWithCoords ,
111
98
windows : Mapping [Any , int | float ],
112
99
window_type : str = "span" ,
100
+ min_weight : float = 0.0 ,
113
101
):
102
+ if has_numbagg is False :
103
+ raise ImportError (
104
+ "numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed"
105
+ )
106
+ elif has_numbagg < "0.2.1" :
107
+ raise ImportError (
108
+ f"numbagg >= 0.2.1 is required for rolling_exp but currently version { has_numbagg } is installed"
109
+ )
110
+ elif has_numbagg < "0.3.1" and min_weight > 0 :
111
+ raise ImportError (
112
+ f"numbagg >= 0.3.1 is required for `min_weight > 0` but currently version { has_numbagg } is installed"
113
+ )
114
+
114
115
self .obj : T_DataWithCoords = obj
115
116
dim , window = next (iter (windows .items ()))
116
117
self .dim = dim
117
118
self .alpha = _get_alpha (** {window_type : window })
119
+ self .min_weight = min_weight
120
+ # Don't pass min_weight=0 so we can support older versions of numbagg
121
+ kwargs = dict (alpha = self .alpha , axis = - 1 )
122
+ if min_weight > 0 :
123
+ kwargs ["min_weight" ] = min_weight
124
+ self .kwargs = kwargs
118
125
119
126
def mean (self , keep_attrs : bool | None = None ) -> T_DataWithCoords :
120
127
"""
@@ -145,7 +152,7 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
145
152
move_exp_nanmean ,
146
153
self .obj ,
147
154
input_core_dims = [[self .dim ]],
148
- kwargs = dict ( alpha = self .alpha , axis = - 1 ) ,
155
+ kwargs = self .kwargs ,
149
156
output_core_dims = [[self .dim ]],
150
157
keep_attrs = keep_attrs ,
151
158
on_missing_core_dim = "copy" ,
@@ -181,7 +188,7 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
181
188
move_exp_nansum ,
182
189
self .obj ,
183
190
input_core_dims = [[self .dim ]],
184
- kwargs = dict ( alpha = self .alpha , axis = - 1 ) ,
191
+ kwargs = self .kwargs ,
185
192
output_core_dims = [[self .dim ]],
186
193
keep_attrs = keep_attrs ,
187
194
on_missing_core_dim = "copy" ,
0 commit comments