From bd4051a7920df2f75c3c65b48c65dfda88494d8a Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Wed, 9 Apr 2025 10:29:34 -0600
Subject: [PATCH] Better handling of RangeIndex for sparse reindexing

---
 flox/core.py       | 16 +++++++++++++++-
 tests/test_core.py | 31 +++++++++++++++++++++----------
 2 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/flox/core.py b/flox/core.py
index 96555d4f..acb3e0c6 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -762,7 +762,19 @@ def reindex_pydata_sparse_coo(array, from_: pd.Index, to: pd.Index, fill_value,
 
     assert axis == -1
 
-    needs_reindex = (from_.get_indexer(to) == -1).any()
+    # Are there any elements in `to` that are not in `from_`?
+    if isinstance(to, pd.RangeIndex) and len(to) > len(from_):
+        # 1. pandas optimizes set difference between two RangeIndexes only
+        # 2. We want to avoid realizing a very large numpy array in memory.
+        #    This happens in the `else` clause.
+        # There are potentially other tricks we can play, but this is a simple
+        # and effective one. If a user is reindexing to sparse, then len(to) is
+        # almost guaranteed to be > len(from_). If len(to) <= len(from_), then realizing
+        # another array of the same shape should be fine.
+        needs_reindex = True
+    else:
+        needs_reindex = (from_.get_indexer(to) == -1).any()
+
     if needs_reindex and fill_value is None:
         raise ValueError("Filling is required. fill_value cannot be None.")
 
@@ -2315,6 +2327,8 @@ def _factorize_multiple(
     if any_by_dask:
         import dask.array
 
+        from . import dask_array_ops  # noqa
+
         # unifying chunks will make sure all arrays in `by` are dask arrays
         # with compatible chunks, even if there was originally a numpy array
         inds = tuple(range(by[0].ndim))
diff --git a/tests/test_core.py b/tests/test_core.py
index 39367764..773e4fd8 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -2085,12 +2085,14 @@ def test_datetime_timedelta_first_last(engine, func) -> None:
 
 @requires_dask
 @requires_sparse
-def test_reindex_sparse():
+@pytest.mark.xdist_group(name="sparse-group")
+@pytest.mark.parametrize("size", [2**62 - 1, 11])
+def test_reindex_sparse(size):
     import sparse
 
     array = dask.array.ones((2, 12), chunks=(-1, 3))
     func = "sum"
-    expected_groups = pd.Index(np.arange(11))
+    expected_groups = pd.RangeIndex(size)
     by = dask.array.from_array(np.repeat(np.arange(6) * 2, 2), chunks=(3,))
     dense = np.zeros((2, 11))
     dense[..., np.arange(6) * 2] = 2
@@ -2110,14 +2112,23 @@ def mocked_reindex(*args, **kwargs):
         assert isinstance(res, sparse.COO)
         return res
 
-    with patch("flox.core.reindex_") as mocked_func:
-        mocked_func.side_effect = mocked_reindex
-        actual, *_ = groupby_reduce(
-            array, by, func=func, reindex=reindex, expected_groups=expected_groups, fill_value=0
-        )
-        assert_equal(actual, expected)
-        # once during graph construction, 10 times afterward
-        assert mocked_func.call_count > 1
+    # Define the error-raising property
+    def raise_error(self):
+        raise AttributeError("Access to '_data' is not allowed.")
+
+    with patch("flox.core.reindex_") as mocked_reindex_func:
+        with patch.object(pd.RangeIndex, "_data", property(raise_error)):
+            mocked_reindex_func.side_effect = mocked_reindex
+            actual, *_ = groupby_reduce(
+                array, by, func=func, reindex=reindex, expected_groups=expected_groups, fill_value=0
+            )
+            if size == 11:
+                assert_equal(actual, expected)
+            else:
+                actual.compute()  # just compute
+
+            # once during graph construction, 10 times afterward
+            assert mocked_reindex_func.call_count > 1
 
 
 def test_sparse_errors():
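
Note (illustrative, not part of the patch): a minimal sketch of the memory concern behind the RangeIndex short-circuit above, assuming only pandas and numpy. The names from_, to_small, and to_huge are invented for the example; pd.RangeIndex and pd.Index.get_indexer are the actual APIs the patched code path uses.

import numpy as np
import pandas as pd

# Small set of observed group labels, mirroring the test above.
from_ = pd.Index(np.arange(6) * 2)

# Modest target index: the generic membership check is cheap and exact.
to_small = pd.RangeIndex(11)
needs_reindex = (from_.get_indexer(to_small) == -1).any()
print(needs_reindex)  # True: to_small contains labels missing from from_

# Huge target index: from_.get_indexer(to_huge) would try to allocate an
# int64 result with ~4.6e18 entries. The patch instead assumes a reindex is
# needed whenever `to` is a RangeIndex longer than `from_`, so the huge index
# is never materialized.
to_huge = pd.RangeIndex(2**62 - 1)
if isinstance(to_huge, pd.RangeIndex) and len(to_huge) > len(from_):
    needs_reindex = True
print(needs_reindex)  # True, decided from lengths alone

The test's patch.object(pd.RangeIndex, "_data", property(raise_error)) guard enforces the same idea: if any code path tried to realize the huge RangeIndex as a numpy array, accessing its backing data would raise immediately.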