diff --git a/asv_bench/benchmarks/rolling.py b/asv_bench/benchmarks/rolling.py index 52814ad3481..3f2a38104de 100644 --- a/asv_bench/benchmarks/rolling.py +++ b/asv_bench/benchmarks/rolling.py @@ -8,27 +8,44 @@ from . import parameterized, randn, requires_dask nx = 3000 +long_nx = 30000000 ny = 2000 nt = 1000 window = 20 +randn_xy = randn((nx, ny), frac_nan=0.1) +randn_xt = randn((nx, nt)) +randn_t = randn((nt, )) +randn_long = randn((long_nx, ), frac_nan=0.1) + class Rolling(object): def setup(self, *args, **kwargs): self.ds = xr.Dataset( - {'var1': (('x', 'y'), randn((nx, ny), frac_nan=0.1)), - 'var2': (('x', 't'), randn((nx, nt))), - 'var3': (('t', ), randn(nt))}, + {'var1': (('x', 'y'), randn_xy), + 'var2': (('x', 't'), randn_xt), + 'var3': (('t', ), randn_t)}, coords={'x': np.arange(nx), 'y': np.linspace(0, 1, ny), 't': pd.date_range('1970-01-01', periods=nt, freq='D'), 'x_coords': ('x', np.linspace(1.1, 2.1, nx))}) + self.da_long = xr.DataArray(randn_long, dims='x', + coords={'x': np.arange(long_nx) * 0.1}) @parameterized(['func', 'center'], (['mean', 'count'], [True, False])) def time_rolling(self, func, center): getattr(self.ds.rolling(x=window, center=center), func)() + @parameterized(['func', 'pandas'], + (['mean', 'count'], [True, False])) + def time_rolling_long(self, func, pandas): + if pandas: + se = self.da_long.to_series() + getattr(se.rolling(window=window), func)() + else: + getattr(self.da_long.rolling(x=window), func)() + @parameterized(['window_', 'min_periods'], ([20, 40], [5, None])) def time_rolling_np(self, window_, min_periods): @@ -47,3 +64,4 @@ def setup(self, *args, **kwargs): requires_dask() super(RollingDask, self).setup(**kwargs) self.ds = self.ds.chunk({'x': 100, 'y': 50, 't': 50}) + self.da_long = self.da_long.chunk({'x': 10000}) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 40e206aaa86..1a252b15fd0 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,6 +37,10 @@ Documentation Enhancements ~~~~~~~~~~~~ + - Some 
speed improvement to construct :py:class:`~xarray.DataArrayRolling` + object (:issue:`1993`) + By `Keisuke Fujii <https://github.com/fujiisoup>`_. + Bug fixes ~~~~~~~~~ diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index fb09c9e0df3..079c60f35a7 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -151,34 +151,21 @@ def __init__(self, obj, min_periods=None, center=False, **windows): """ super(DataArrayRolling, self).__init__(obj, min_periods=min_periods, center=center, **windows) - self.window_indices = None - self.window_labels = None - self._setup_windows() + self.window_labels = self.obj[self.dim] def __iter__(self): - for (label, indices) in zip(self.window_labels, self.window_indices): - window = self.obj.isel(**{self.dim: indices}) + stops = np.arange(1, len(self.window_labels) + 1) + starts = stops - int(self.window) + starts[:int(self.window)] = 0 + for (label, start, stop) in zip(self.window_labels, starts, stops): + window = self.obj.isel(**{self.dim: slice(start, stop)}) counts = window.count(dim=self.dim) window = window.where(counts >= self._min_periods) yield (label, window) - def _setup_windows(self): - """ - Find the indices and labels for each window - """ - self.window_labels = self.obj[self.dim] - window = int(self.window) - dim_size = self.obj[self.dim].size - - stops = np.arange(dim_size) + 1 - starts = np.maximum(stops - window, 0) - - self.window_indices = [slice(start, stop) - for start, stop in zip(starts, stops)] - def construct(self, window_dim, stride=1, fill_value=dtypes.NA): """ Convert this rolling object to xr.DataArray,