diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index c8160cefef3..04d7dd41966 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -25,51 +25,34 @@ def _get_alpha( span: float | None = None, halflife: float | None = None, alpha: float | None = None, -) -> float: - # pandas defines in terms of com (converting to alpha in the algo) - # so use its function to get a com and then convert to alpha - - com = _get_center_of_mass(com, span, halflife, alpha) - return 1 / (1 + com) - - -def _get_center_of_mass( - comass: float | None, - span: float | None, - halflife: float | None, - alpha: float | None, ) -> float: """ - Vendored from pandas.core.window.common._get_center_of_mass - - See licenses/PANDAS_LICENSE for the function's license + Convert com, span, halflife to alpha. """ - valid_count = count_not_none(comass, span, halflife, alpha) + valid_count = count_not_none(com, span, halflife, alpha) if valid_count > 1: - raise ValueError("comass, span, halflife, and alpha are mutually exclusive") + raise ValueError("com, span, halflife, and alpha are mutually exclusive") - # Convert to center of mass; domain checks ensure 0 < alpha <= 1 - if comass is not None: - if comass < 0: - raise ValueError("comass must satisfy: comass >= 0") + # Convert to alpha + if com is not None: + if com < 0: + raise ValueError("commust satisfy: com>= 0") + return 1 / (com + 1) elif span is not None: if span < 1: raise ValueError("span must satisfy: span >= 1") - comass = (span - 1) / 2.0 + return 2 / (span + 1) elif halflife is not None: if halflife <= 0: raise ValueError("halflife must satisfy: halflife > 0") - decay = 1 - np.exp(np.log(0.5) / halflife) - comass = 1 / decay - 1 + return 1 - np.exp(np.log(0.5) / halflife) elif alpha is not None: - if alpha <= 0 or alpha > 1: + if not 0 < alpha <= 1: raise ValueError("alpha must satisfy: 0 < alpha <= 1") - comass = (1.0 - alpha) / alpha + return alpha else: raise ValueError("Must pass one of comass, span, halflife, or alpha") - return float(comass) - class RollingExp(Generic[T_DataWithCoords]): """