Skip to content

Commit d98baf4

Browse files
authored
Consolidate _get_alpha func (#8465)
* Consolidate `_get_alpha` func Am changing this a bit so starting with consolidating it rather than converting twice
1 parent bb8511e commit d98baf4

File tree

1 file changed

+12
-29
lines changed

1 file changed

+12
-29
lines changed

xarray/core/rolling_exp.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,51 +25,34 @@ def _get_alpha(
2525
span: float | None = None,
2626
halflife: float | None = None,
2727
alpha: float | None = None,
28-
) -> float:
29-
# pandas defines in terms of com (converting to alpha in the algo)
30-
# so use its function to get a com and then convert to alpha
31-
32-
com = _get_center_of_mass(com, span, halflife, alpha)
33-
return 1 / (1 + com)
34-
35-
36-
def _get_center_of_mass(
37-
comass: float | None,
38-
span: float | None,
39-
halflife: float | None,
40-
alpha: float | None,
4128
) -> float:
4229
"""
43-
Vendored from pandas.core.window.common._get_center_of_mass
44-
45-
See licenses/PANDAS_LICENSE for the function's license
30+
Convert com, span, halflife to alpha.
4631
"""
47-
valid_count = count_not_none(comass, span, halflife, alpha)
32+
valid_count = count_not_none(com, span, halflife, alpha)
4833
if valid_count > 1:
49-
raise ValueError("comass, span, halflife, and alpha are mutually exclusive")
34+
raise ValueError("com, span, halflife, and alpha are mutually exclusive")
5035

51-
# Convert to center of mass; domain checks ensure 0 < alpha <= 1
52-
if comass is not None:
53-
if comass < 0:
54-
raise ValueError("comass must satisfy: comass >= 0")
36+
# Convert to alpha
37+
if com is not None:
38+
if com < 0:
39+
raise ValueError("commust satisfy: com>= 0")
40+
return 1 / (com + 1)
5541
elif span is not None:
5642
if span < 1:
5743
raise ValueError("span must satisfy: span >= 1")
58-
comass = (span - 1) / 2.0
44+
return 2 / (span + 1)
5945
elif halflife is not None:
6046
if halflife <= 0:
6147
raise ValueError("halflife must satisfy: halflife > 0")
62-
decay = 1 - np.exp(np.log(0.5) / halflife)
63-
comass = 1 / decay - 1
48+
return 1 - np.exp(np.log(0.5) / halflife)
6449
elif alpha is not None:
65-
if alpha <= 0 or alpha > 1:
50+
if not 0 < alpha <= 1:
6651
raise ValueError("alpha must satisfy: 0 < alpha <= 1")
67-
comass = (1.0 - alpha) / alpha
52+
return alpha
6853
else:
6954
raise ValueError("Must pass one of comass, span, halflife, or alpha")
7055

71-
return float(comass)
72-
7356

7457
class RollingExp(Generic[T_DataWithCoords]):
7558
"""

0 commit comments

Comments
 (0)