diff --git a/docs/source/user-stories/large-zonal-stats.ipynb b/docs/source/user-stories/large-zonal-stats.ipynb
index e1203d2a0..8c0c47d3a 100644
--- a/docs/source/user-stories/large-zonal-stats.ipynb
+++ b/docs/source/user-stories/large-zonal-stats.ipynb
@@ -161,6 +161,7 @@
     "        blockwise=False,\n",
     "        array_type=ReindexArrayType.SPARSE_COO,\n",
     "    ),\n",
+    "    fill_value=0,\n",
     ")\n",
     "result"
   ]
diff --git a/flox/core.py b/flox/core.py
index ba96181ca..96555d4f9 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -68,6 +68,7 @@
 from numpy.core.numeric import normalize_axis_tuple  # type: ignore[no-redef]
 
 HAS_NUMBAGG = module_available("numbagg", minversion="0.3.0")
+HAS_SPARSE = module_available("sparse")
 
 if TYPE_CHECKING:
     try:
@@ -255,6 +256,12 @@ def _is_bool_supported_reduction(func: T_Agg) -> bool:
     )
 
 
+def _is_sparse_supported_reduction(func: T_Agg) -> bool:
+    if isinstance(func, Aggregation):
+        func = func.name
+    return HAS_SPARSE and all(f not in func for f in ["first", "last", "prod", "var", "std"])
+
+
 def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
     if is_duck_dask_array(by):
         raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
@@ -736,12 +743,12 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
     return array.rechunk({axis: newchunks})
 
 
-def reindex_numpy(array, from_, to, fill_value, dtype, axis):
+def reindex_numpy(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int):
     idx = from_.get_indexer(to)
     indexer = [slice(None, None)] * array.ndim
     indexer[axis] = idx
     reindexed = array[tuple(indexer)]
-    if any(idx == -1):
+    if (idx == -1).any():
         if fill_value is None:
             raise ValueError("Filling is required. fill_value cannot be None.")
         indexer[axis] = idx == -1
@@ -750,25 +757,43 @@
     return reindexed
 
 
-def reindex_pydata_sparse_coo(array, from_, to, fill_value, dtype, axis):
+def reindex_pydata_sparse_coo(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int):
     import sparse
 
     assert axis == -1
 
-    if fill_value is None:
+    needs_reindex = (from_.get_indexer(to) == -1).any()
+    if needs_reindex and fill_value is None:
         raise ValueError("Filling is required. fill_value cannot be None.")
+
     idx = to.get_indexer(from_)
-    assert (idx != -1).all()  # FIXME
+    mask = idx != -1  # indices along last axis to keep
+    if mask.all():
+        mask = slice(None)
     shape = array.shape
-    ranges = np.broadcast_arrays(*np.ix_(*(tuple(np.arange(size) for size in shape[:axis]) + (idx,))))
-    coords = np.stack(ranges, axis=0).reshape(array.ndim, -1)
-    data = array.data if isinstance(array, sparse.COO) else array.reshape(-1)
+    if isinstance(array, sparse.COO):
+        subset = array[..., mask]
+        data = subset.data
+        coords = subset.coords
+        if subset.nnz > 0:
+            coords[-1, :] = idx[mask][coords[-1, :]]
+        if fill_value is None:
+            # no reindexing is actually needed (dense case)
+            # preserve the fill_value
+            fill_value = array.fill_value
+    else:
+        ranges = np.broadcast_arrays(
+            *np.ix_(*(tuple(np.arange(size) for size in shape[:axis]) + (idx[mask],)))
+        )
+        coords = np.stack(ranges, axis=0).reshape(array.ndim, -1)
+        data = array[..., mask].reshape(-1)
 
     reindexed = sparse.COO(
         coords=coords,
         data=data.astype(dtype, copy=False),
         shape=(*array.shape[:axis], to.size),
+        fill_value=fill_value,
     )
     return reindexed
 
 
@@ -795,7 +820,11 @@ def reindex_(
 
     if array.shape[axis] == 0:
         # all groups were NaN
-        reindexed = np.full(array.shape[:-1] + (len(to),), fill_value, dtype=array.dtype)
+        shape = array.shape[:-1] + (len(to),)
+        if array_type in (ReindexArrayType.AUTO, ReindexArrayType.NUMPY):
+            reindexed = np.full(shape, fill_value, dtype=array.dtype)
+        else:
+            raise NotImplementedError
         return reindexed
 
     from_ = pd.Index(from_)
@@ -1044,7 +1073,7 @@ def chunk_argreduce(
         sort=sort,
         user_dtype=user_dtype,
     )
-    if not isnull(results["groups"]).all():
+    if not all(isnull(results["groups"])):
         idx = np.broadcast_to(idx, array.shape)
 
         # array, by get flattened to 1D before passing to npg
@@ -1288,7 +1317,7 @@ def _finalize_results(
         fill_value = agg.fill_value["user"]
         if min_count > 0:
             count_mask = counts < min_count
-            if count_mask.any():
+            if count_mask.any() or reindex.array_type is ReindexArrayType.SPARSE_COO:
                 # For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
                 # necessary
                 if fill_value is None:
@@ -2815,6 +2844,15 @@ def groupby_reduce(
             array.dtype,
         )
 
+    if reindex.array_type is ReindexArrayType.SPARSE_COO:
+        if not HAS_SPARSE:
+            raise ImportError("Package 'sparse' must be installed to reindex to a sparse.COO array.")
+        if not _is_sparse_supported_reduction(func):
+            raise NotImplementedError(
+                f"Aggregation {func=!r} is not supported when reindexing to a sparse array. "
+                "Please raise an issue"
+            )
+
     if TYPE_CHECKING:
         assert isinstance(reindex, ReindexStrategy)
         assert method is not None
diff --git a/flox/xrutils.py b/flox/xrutils.py
index f201ba60b..7da2f9909 100644
--- a/flox/xrutils.py
+++ b/flox/xrutils.py
@@ -159,7 +159,9 @@ def notnull(data):
     return out
 
 
-def isnull(data):
+def isnull(data: Any):
+    if data is None:
+        return False
     if not is_duck_array(data):
         data = np.asarray(data)
     scalar_type = data.dtype.type
@@ -177,7 +179,7 @@
     else:
         # at this point, array should have dtype=object
         if isinstance(data, (np.ndarray, dask_array_type)):  # noqa
-            return pd.isnull(data)
+            return pd.isnull(data)  # type: ignore[arg-type]
         else:
             # Not reachable yet, but intended for use with other duck array
             # types. For full consistency with pandas, we should accept None as
@@ -374,9 +376,10 @@ def _select_along_axis(values, idx, axis):
 def nanfirst(values, axis, keepdims=False):
     if isinstance(axis, tuple):
         (axis,) = axis
-    values = np.asarray(values)
+    if not is_duck_array(values):
+        values = np.asarray(values)
     axis = normalize_axis_index(axis, values.ndim)
-    idx_first = np.argmax(~pd.isnull(values), axis=axis)
+    idx_first = np.argmax(~isnull(values), axis=axis)
     result = _select_along_axis(values, idx_first, axis)
     if keepdims:
         return np.expand_dims(result, axis=axis)
@@ -387,10 +390,11 @@
 def nanlast(values, axis, keepdims=False):
     if isinstance(axis, tuple):
         (axis,) = axis
-    values = np.asarray(values)
+    if not is_duck_array(values):
+        values = np.asarray(values)
     axis = normalize_axis_index(axis, values.ndim)
     rev = (slice(None),) * axis + (slice(None, None, -1),)
-    idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis)
+    idx_last = -1 - np.argmax(~isnull(values)[rev], axis=axis)
     result = _select_along_axis(values, idx_last, axis)
     if keepdims:
         return np.expand_dims(result, axis=axis)
diff --git a/tests/test_core.py b/tests/test_core.py
index a7398a859..393677643 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -24,6 +24,7 @@
     _choose_engine,
     _convert_expected_groups_to_index,
     _get_optimal_chunks_for_groups,
+    _is_sparse_supported_reduction,
     _normalize_indexes,
     _validate_reindex,
     factorize_,
@@ -43,6 +44,7 @@
     assert_equal_tuple,
     has_cubed,
     has_dask,
+    has_sparse,
     raise_if_dask_computes,
     requires_cubed,
     requires_dask,
@@ -74,6 +76,10 @@ def dask_array_ones(*args):
 
 DEFAULT_QUANTILE = 0.9
 
+REINDEX_SPARSE_STRAT = ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO)
+REINDEX_SPARSE_PARAM = pytest.param(
+    REINDEX_SPARSE_STRAT, marks=(requires_dask, pytest.mark.skipif(not has_sparse, reason="no sparse"))
+)
 
 if TYPE_CHECKING:
     from flox.core import T_Agg, T_Engine, T_ExpectedGroupsOpt, T_Method
@@ -320,13 +326,20 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
         if not has_dask or chunks is None or func in BLOCKWISE_FUNCS:
             continue
 
-        params = list(itertools.product(["map-reduce"], [True, False, None]))
+        params = list(
+            itertools.product(
+                ["map-reduce"],
+                [True, False, None, REINDEX_SPARSE_STRAT],
+            )
+        )
         params.extend(itertools.product(["cohorts"], [False, None]))
         if chunks == -1:
             params.extend([("blockwise", None)])
 
         combine_error = RuntimeError("This combine should not have been called.")
         for method, reindex in params:
+            if isinstance(reindex, ReindexStrategy) and not _is_sparse_supported_reduction(func):
+                continue
             call = partial(
                 groupby_reduce,
                 array,
@@ -360,6 +373,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
                 assert_equal(actual_group, expect, tolerance)
             if "arg" in func:
                 assert actual.dtype.kind == "i"
 
+            if isinstance(reindex, ReindexStrategy):
+                import sparse
+
+                expected = sparse.COO.from_numpy(expected)
             assert_equal(actual, expected, tolerance)
 
@@ -447,7 +464,7 @@ def test_numpy_reduce_nd_md():
 
 
 @requires_dask
-@pytest.mark.parametrize("reindex", [None, False, True])
+@pytest.mark.parametrize("reindex", [None, False, True, REINDEX_SPARSE_PARAM])
 @pytest.mark.parametrize("func", ALL_FUNCS)
 @pytest.mark.parametrize("add_nan", [False, True])
 @pytest.mark.parametrize("dtype", (float,))
@@ -470,6 +487,9 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp
     if "arg" in func and (engine in ["flox", "numbagg"] or reindex):
         pytest.skip()
 
+    if isinstance(reindex, ReindexStrategy) and not _is_sparse_supported_reduction(func):
+        pytest.skip()
+
     rng = np.random.default_rng(12345)
     array = dask.array.from_array(rng.random(shape), chunks=array_chunks).astype(dtype)
     array = dask.array.ones(shape, chunks=array_chunks)
@@ -775,6 +795,7 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
         (None, None),
         pytest.param(False, (2, 2, 3), marks=requires_dask),
         pytest.param(True, (2, 2, 3), marks=requires_dask),
+        pytest.param(REINDEX_SPARSE_PARAM, (2, 2, 3), marks=requires_dask),
     ],
 )
 @pytest.mark.parametrize(
@@ -821,7 +842,13 @@ def _maybe_chunk(arr):
 @requires_dask
 @pytest.mark.parametrize(
     "expected_groups, reindex",
-    [(None, None), (None, False), ([0, 1, 2], True), ([0, 1, 2], False)],
+    [
+        (None, None),
+        (None, False),
+        ([0, 1, 2], True),
+        ([0, 1, 2], False),
+        pytest.param([0, 1, 2], REINDEX_SPARSE_PARAM),
+    ],
 )
 def test_groupby_all_nan_blocks_dask(expected_groups, reindex, engine):
     labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
@@ -2085,7 +2112,28 @@ def mocked_reindex(*args, **kwargs):
 
     with patch("flox.core.reindex_") as mocked_func:
         mocked_func.side_effect = mocked_reindex
-        actual, *_ = groupby_reduce(array, by, func=func, reindex=reindex, expected_groups=expected_groups)
+        actual, *_ = groupby_reduce(
+            array, by, func=func, reindex=reindex, expected_groups=expected_groups, fill_value=0
+        )
         assert_equal(actual, expected)
         # once during graph construction, 10 times afterward
         assert mocked_func.call_count > 1
+
+
+def test_sparse_errors():
+    call = partial(
+        groupby_reduce,
+        [1, 2, 3],
+        [0, 1, 1],
+        reindex=REINDEX_SPARSE_STRAT,
+        fill_value=0,
+        expected_groups=[0, 1, 2],
+    )
+
+    if not has_sparse:
+        with pytest.raises(ImportError):
+            call(func="sum")
+
+    else:
+        with pytest.raises(ValueError):
+            call(func="first")
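
Usage sketch (not part of the patch). A minimal example of the option wired up above, mirroring the notebook change: reindex intermediate map-reduce results to sparse.COO instead of dense numpy. The array, labels, and expected_groups below are invented for illustration, and it assumes dask and sparse are installed.

    import dask.array
    import numpy as np

    from flox.core import ReindexArrayType, ReindexStrategy, groupby_reduce

    array = dask.array.ones((10, 12), chunks=(-1, 3))
    by = np.repeat(np.arange(6), 2)  # one label per column, 6 groups

    result, groups = groupby_reduce(
        array,
        by,
        func="sum",
        expected_groups=np.arange(6),
        method="map-reduce",
        reindex=ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO),
        fill_value=0,  # required when reindexing to sparse
    )
    result.compute()  # computed result is backed by sparse.COO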
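
The core trick in the new reindex_pydata_sparse_coo path, in miniature (also not part of the patch; toy data): instead of materializing a dense array, keep the columns of `from_` that appear in `to` and rewrite the last-axis coordinates of the nonzeros in place.

    import numpy as np
    import pandas as pd
    import sparse

    from_ = pd.Index([1, 2, 3])     # groups present in this block
    to = pd.Index([0, 1, 2, 3, 4])  # full set of expected groups

    array = sparse.COO.from_numpy(np.array([[10, 0, 30]]))
    idx = to.get_indexer(from_)     # [1, 2, 3]: new position of each old column
    coords = array.coords.copy()
    coords[-1, :] = idx[coords[-1, :]]  # remap column coordinates of the nonzeros
    reindexed = sparse.COO(coords=coords, data=array.data, shape=(1, to.size), fill_value=0)
    reindexed.todense()             # [[ 0, 10,  0, 30,  0]]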