Skip to content

Fix sparse reindexing some more. #437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/user-stories/large-zonal-stats.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@
" blockwise=False,\n",
" array_type=ReindexArrayType.SPARSE_COO,\n",
" ),\n",
" fill_value=0,\n",
")\n",
"result"
]
Expand Down
60 changes: 49 additions & 11 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]

HAS_NUMBAGG = module_available("numbagg", minversion="0.3.0")
HAS_SPARSE = module_available("sparse")

if TYPE_CHECKING:
try:
Expand Down Expand Up @@ -255,6 +256,12 @@ def _is_bool_supported_reduction(func: T_Agg) -> bool:
)


def _is_sparse_supported_reduction(func: T_Agg) -> bool:
if isinstance(func, Aggregation):
func = func.name
return HAS_SPARSE and all(f not in func for f in ["first", "last", "prod", "var", "std"])


def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
if is_duck_dask_array(by):
raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
Expand Down Expand Up @@ -736,12 +743,12 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
return array.rechunk({axis: newchunks})


def reindex_numpy(array, from_, to, fill_value, dtype, axis):
def reindex_numpy(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int):
idx = from_.get_indexer(to)
indexer = [slice(None, None)] * array.ndim
indexer[axis] = idx
reindexed = array[tuple(indexer)]
if any(idx == -1):
if (idx == -1).any():
if fill_value is None:
raise ValueError("Filling is required. fill_value cannot be None.")
indexer[axis] = idx == -1
Expand All @@ -750,25 +757,43 @@ def reindex_numpy(array, from_, to, fill_value, dtype, axis):
return reindexed


def reindex_pydata_sparse_coo(array, from_, to, fill_value, dtype, axis):
def reindex_pydata_sparse_coo(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int):
import sparse

assert axis == -1

if fill_value is None:
needs_reindex = (from_.get_indexer(to) == -1).any()
if needs_reindex and fill_value is None:
raise ValueError("Filling is required. fill_value cannot be None.")

idx = to.get_indexer(from_)
assert (idx != -1).all() # FIXME
mask = idx != -1 # indices along last axis to keep
if mask.all():
mask = slice(None)
shape = array.shape
ranges = np.broadcast_arrays(*np.ix_(*(tuple(np.arange(size) for size in shape[:axis]) + (idx,))))
coords = np.stack(ranges, axis=0).reshape(array.ndim, -1)

data = array.data if isinstance(array, sparse.COO) else array.reshape(-1)
if isinstance(array, sparse.COO):
subset = array[..., mask]
data = subset.data
coords = subset.coords
if subset.nnz > 0:
coords[-1, :] = idx[mask][coords[-1, :]]
if fill_value is None:
# no reindexing is actually needed (dense case)
# preserve the fill_value
fill_value = array.fill_value
else:
ranges = np.broadcast_arrays(
*np.ix_(*(tuple(np.arange(size) for size in shape[:axis]) + (idx[mask],)))
)
coords = np.stack(ranges, axis=0).reshape(array.ndim, -1)
data = array[..., mask].reshape(-1)

reindexed = sparse.COO(
coords=coords,
data=data.astype(dtype, copy=False),
shape=(*array.shape[:axis], to.size),
fill_value=fill_value,
)

return reindexed
Expand All @@ -795,7 +820,11 @@ def reindex_(

if array.shape[axis] == 0:
# all groups were NaN
reindexed = np.full(array.shape[:-1] + (len(to),), fill_value, dtype=array.dtype)
shape = array.shape[:-1] + (len(to),)
if array_type in (ReindexArrayType.AUTO, ReindexArrayType.NUMPY):
reindexed = np.full(shape, fill_value, dtype=array.dtype)
else:
raise NotImplementedError
return reindexed

from_ = pd.Index(from_)
Expand Down Expand Up @@ -1044,7 +1073,7 @@ def chunk_argreduce(
sort=sort,
user_dtype=user_dtype,
)
if not isnull(results["groups"]).all():
if not all(isnull(results["groups"])):
idx = np.broadcast_to(idx, array.shape)

# array, by get flattened to 1D before passing to npg
Expand Down Expand Up @@ -1288,7 +1317,7 @@ def _finalize_results(
fill_value = agg.fill_value["user"]
if min_count > 0:
count_mask = counts < min_count
if count_mask.any():
if count_mask.any() or reindex.array_type is ReindexArrayType.SPARSE_COO:
# For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
# necessary
if fill_value is None:
Expand Down Expand Up @@ -2815,6 +2844,15 @@ def groupby_reduce(
array.dtype,
)

if reindex.array_type is ReindexArrayType.SPARSE_COO:
if not HAS_SPARSE:
raise ImportError("Package 'sparse' must be installed to reindex to a sparse.COO array.")
if not _is_sparse_supported_reduction(func):
raise NotImplementedError(
f"Aggregation {func=!r} is not supported when reindexing to a sparse array. "
"Please raise an issue"
)

if TYPE_CHECKING:
assert isinstance(reindex, ReindexStrategy)
assert method is not None
Expand Down
16 changes: 10 additions & 6 deletions flox/xrutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ def notnull(data):
return out


def isnull(data):
def isnull(data: Any):
if data is None:
return False
if not is_duck_array(data):
data = np.asarray(data)
scalar_type = data.dtype.type
Expand All @@ -177,7 +179,7 @@ def isnull(data):
else:
# at this point, array should have dtype=object
if isinstance(data, (np.ndarray, dask_array_type)): # noqa
return pd.isnull(data)
return pd.isnull(data) # type: ignore[arg-type]
else:
# Not reachable yet, but intended for use with other duck array
# types. For full consistency with pandas, we should accept None as
Expand Down Expand Up @@ -374,9 +376,10 @@ def _select_along_axis(values, idx, axis):
def nanfirst(values, axis, keepdims=False):
if isinstance(axis, tuple):
(axis,) = axis
values = np.asarray(values)
if not is_duck_array(values):
values = np.asarray(values)
axis = normalize_axis_index(axis, values.ndim)
idx_first = np.argmax(~pd.isnull(values), axis=axis)
idx_first = np.argmax(~isnull(values), axis=axis)
result = _select_along_axis(values, idx_first, axis)
if keepdims:
return np.expand_dims(result, axis=axis)
Expand All @@ -387,10 +390,11 @@ def nanfirst(values, axis, keepdims=False):
def nanlast(values, axis, keepdims=False):
if isinstance(axis, tuple):
(axis,) = axis
values = np.asarray(values)
if not is_duck_array(values):
values = np.asarray(values)
axis = normalize_axis_index(axis, values.ndim)
rev = (slice(None),) * axis + (slice(None, None, -1),)
idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis)
idx_last = -1 - np.argmax(~isnull(values)[rev], axis=axis)
result = _select_along_axis(values, idx_last, axis)
if keepdims:
return np.expand_dims(result, axis=axis)
Expand Down
56 changes: 52 additions & 4 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
_choose_engine,
_convert_expected_groups_to_index,
_get_optimal_chunks_for_groups,
_is_sparse_supported_reduction,
_normalize_indexes,
_validate_reindex,
factorize_,
Expand All @@ -43,6 +44,7 @@
assert_equal_tuple,
has_cubed,
has_dask,
has_sparse,
raise_if_dask_computes,
requires_cubed,
requires_dask,
Expand Down Expand Up @@ -74,6 +76,10 @@ def dask_array_ones(*args):


DEFAULT_QUANTILE = 0.9
REINDEX_SPARSE_STRAT = ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO)
REINDEX_SPARSE_PARAM = pytest.param(
REINDEX_SPARSE_STRAT, marks=(requires_dask, pytest.mark.skipif(not has_sparse, reason="no sparse"))
)

if TYPE_CHECKING:
from flox.core import T_Agg, T_Engine, T_ExpectedGroupsOpt, T_Method
Expand Down Expand Up @@ -320,13 +326,20 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
if not has_dask or chunks is None or func in BLOCKWISE_FUNCS:
continue

params = list(itertools.product(["map-reduce"], [True, False, None]))
params = list(
itertools.product(
["map-reduce"],
[True, False, None, REINDEX_SPARSE_STRAT],
)
)
params.extend(itertools.product(["cohorts"], [False, None]))
if chunks == -1:
params.extend([("blockwise", None)])

combine_error = RuntimeError("This combine should not have been called.")
for method, reindex in params:
if isinstance(reindex, ReindexStrategy) and not _is_sparse_supported_reduction(func):
continue
call = partial(
groupby_reduce,
array,
Expand Down Expand Up @@ -360,6 +373,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
assert_equal(actual_group, expect, tolerance)
if "arg" in func:
assert actual.dtype.kind == "i"
if isinstance(reindex, ReindexStrategy):
import sparse

expected = sparse.COO.from_numpy(expected)
assert_equal(actual, expected, tolerance)


Expand Down Expand Up @@ -447,7 +464,7 @@ def test_numpy_reduce_nd_md():


@requires_dask
@pytest.mark.parametrize("reindex", [None, False, True])
@pytest.mark.parametrize("reindex", [None, False, True, REINDEX_SPARSE_PARAM])
@pytest.mark.parametrize("func", ALL_FUNCS)
@pytest.mark.parametrize("add_nan", [False, True])
@pytest.mark.parametrize("dtype", (float,))
Expand All @@ -470,6 +487,9 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp
if "arg" in func and (engine in ["flox", "numbagg"] or reindex):
pytest.skip()

if isinstance(reindex, ReindexStrategy) and not _is_sparse_supported_reduction(func):
pytest.skip()

rng = np.random.default_rng(12345)
array = dask.array.from_array(rng.random(shape), chunks=array_chunks).astype(dtype)
array = dask.array.ones(shape, chunks=array_chunks)
Expand Down Expand Up @@ -775,6 +795,7 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
(None, None),
pytest.param(False, (2, 2, 3), marks=requires_dask),
pytest.param(True, (2, 2, 3), marks=requires_dask),
pytest.param(REINDEX_SPARSE_PARAM, (2, 2, 3), marks=requires_dask),
],
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -821,7 +842,13 @@ def _maybe_chunk(arr):
@requires_dask
@pytest.mark.parametrize(
"expected_groups, reindex",
[(None, None), (None, False), ([0, 1, 2], True), ([0, 1, 2], False)],
[
(None, None),
(None, False),
([0, 1, 2], True),
([0, 1, 2], False),
pytest.param([0, 1, 2], REINDEX_SPARSE_PARAM),
],
)
def test_groupby_all_nan_blocks_dask(expected_groups, reindex, engine):
labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
Expand Down Expand Up @@ -2085,7 +2112,28 @@ def mocked_reindex(*args, **kwargs):

with patch("flox.core.reindex_") as mocked_func:
mocked_func.side_effect = mocked_reindex
actual, *_ = groupby_reduce(array, by, func=func, reindex=reindex, expected_groups=expected_groups)
actual, *_ = groupby_reduce(
array, by, func=func, reindex=reindex, expected_groups=expected_groups, fill_value=0
)
assert_equal(actual, expected)
# once during graph construction, 10 times afterward
assert mocked_func.call_count > 1


def test_sparse_errors():
call = partial(
groupby_reduce,
[1, 2, 3],
[0, 1, 1],
reindex=REINDEX_SPARSE_STRAT,
fill_value=0,
expected_groups=[0, 1, 2],
)

if not has_sparse:
with pytest.raises(ImportError):
call(func="sum")

else:
with pytest.raises(ValueError):
call(func="first")
Loading