diff --git a/pandas/core/window/indexers.py b/pandas/core/window/indexers.py index f9a2eead9cb87..a81d712190f79 100644 --- a/pandas/core/window/indexers.py +++ b/pandas/core/window/indexers.py @@ -3,7 +3,10 @@ import numpy as np -from pandas.tseries.offsets import DateOffset +from pandas.core.indexes.datetimes import DatetimeIndex + +from pandas.tseries.frequencies import to_offset +from pandas.tseries.offsets import BusinessMixin, DateOffset BeginEnd = Tuple[np.ndarray, np.ndarray] @@ -231,3 +234,30 @@ def get_window_bounds( end[i] -= 1 return start, end + + +class BusinessWindowIndexer(VariableWindowIndexer): + """Calculate window bounds based on business frequencies.""" + + def __init__(self, index=None, offset=None, keys=None): + super().__init__(index, offset, keys) + self.offset = to_offset(self.offset) + if not isinstance(self.offset, BusinessMixin): + raise ValueError( + "BusinessWindowIndexer only supports Business frequencies." + ) + if not isinstance(self.index, DatetimeIndex): + raise ValueError("BusinessWindowIndexer only supports DatetimeIndexes.") + + def get_window_bounds( + self, + num_values: int = 0, + window_size: int = 0, + min_periods: Optional[int] = None, + center: Optional[bool] = None, + closed: Optional[str] = None, + win_type: Optional[str] = None, + ) -> BeginEnd: + return super().get_window_bounds( + num_values, self.offset, min_periods, center, closed, win_type + ) diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index ec300d5578881..0609bc8e1910a 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -395,6 +395,8 @@ def _get_window_indexer(self, index_as_array): ------- VariableWindowIndexer or FixedWindowIndexer """ + if isinstance(self.window, BaseIndexer): + return self.window if self.is_freq_type: return VariableWindowIndexer(index=index_as_array) return FixedWindowIndexer(index=index_as_array) @@ -436,7 +438,14 @@ def _apply( check_minp = _use_window if window is None: - apply_window = self._get_window(**kwargs) # type: int + apply_window = self._get_window(**kwargs) + else: + apply_window = window + + if isinstance(apply_window, BaseIndexer): + # This value isn't significant for subclasses of BaseIndexer + # but is passed along to other validation checks. + apply_window = 0 blocks, obj = self._create_blocks() block_list = list(blocks) @@ -898,12 +907,8 @@ def _pop_args(win_type, arg_names, kwargs): # GH #15662. `False` makes symmetric window, rather than periodic. return sig.get_window(win_type, window, False).astype(float) elif isinstance(window, BaseIndexer): - return window.get_window_span( - win_type=self.win_type, - min_periods=self.min_periods, - center=self.center, - closed=self.closed, - ) + # Defer calling `.get_window_bounds` until later? + return window def _get_roll_func( self, cfunc: Callable, check_minp: Callable, index: np.ndarray, **kwargs diff --git a/pandas/tests/window/test_custom_indexer.py b/pandas/tests/window/test_custom_indexer.py index 1d32d6619422a..07b78fb8fc395 100644 --- a/pandas/tests/window/test_custom_indexer.py +++ b/pandas/tests/window/test_custom_indexer.py @@ -1,4 +1,6 @@ -from pandas import Series +from pandas import Series, date_range +from pandas.core.window.indexers import BusinessWindowIndexer +import pandas.util.testing as tm def test_custom_indexer_validates( @@ -13,3 +15,11 @@ def test_custom_indexer_validates( min_periods=min_periods, closed=closed, ) + + +def test_rolling_business_day(): + index = date_range("2019-09-06", "2019-09-13", freq="D") + s = Series(range(len(index)), index=index) + result = s.rolling(BusinessWindowIndexer(index=s.index, offset="B")).mean() + expected = Series([0, 1, 1.5, 2, 4, 5, 6, 7], index=index) + tm.assert_series_equal(result, expected)