diff --git a/ci/environment.yml b/ci/environment.yml
index 93f0e891f..9d5aa6d01 100644
--- a/ci/environment.yml
+++ b/ci/environment.yml
@@ -22,3 +22,4 @@ dependencies:
   - pooch
   - toolz
   - numba
+  - scipy
diff --git a/docs/source/aggregations.md b/docs/source/aggregations.md
index e6c10e4ba..d3591d2dc 100644
--- a/docs/source/aggregations.md
+++ b/docs/source/aggregations.md
@@ -11,8 +11,11 @@ the `func` kwarg:
 - `"std"`, `"nanstd"`
 - `"argmin"`
 - `"argmax"`
-- `"first"`
-- `"last"`
+- `"first"`, `"nanfirst"`
+- `"last"`, `"nanlast"`
+- `"median"`, `"nanmedian"`
+- `"mode"`, `"nanmode"`
+- `"quantile"`, `"nanquantile"`
 
 ```{tip}
 We would like to add support for `cumsum`, `cumprod` ([issue](https://github.com/xarray-contrib/flox/issues/91)). Contributions are welcome!
diff --git a/docs/source/user-stories/custom-aggregations.ipynb b/docs/source/user-stories/custom-aggregations.ipynb
index f191c77e0..8b9be09e9 100644
--- a/docs/source/user-stories/custom-aggregations.ipynb
+++ b/docs/source/user-stories/custom-aggregations.ipynb
@@ -15,8 +15,13 @@
     ">\n",
     "> A = da.groupby(['lon_bins', 'lat_bins']).mode()\n",
     "\n",
-    "This notebook will describe how to accomplish this using a custom `Aggregation`\n",
-    "since `mode` and `median` aren't supported by flox yet.\n"
+    "This notebook will describe how to accomplish this using a custom `Aggregation`.\n",
+    "\n",
+    "\n",
+    "```{tip}\n",
+    "flox now supports `mode`, `nanmode`, `quantile`, `nanquantile`, `median`, `nanmedian` using exactly the same\n",
+    "approach as shown below.\n",
+    "```\n"
    ]
   },
   {
diff --git a/flox/aggregate_npg.py b/flox/aggregate_npg.py
index 30e0eb257..966bd43b8 100644
--- a/flox/aggregate_npg.py
+++ b/flox/aggregate_npg.py
@@ -100,3 +100,84 @@ def _len(group_idx, array, engine, *, func, axis=-1, size=None, fill_value=None, dtype=None):
 
 len = partial(_len, func="len")
 nanlen = partial(_len, func="nanlen")
+
+
+def median(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=np.median,
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+
+def nanmedian(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=np.nanmedian,
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+
+def quantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=partial(np.quantile, q=q),
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+
+def nanquantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=partial(np.nanquantile, q=q),
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+
+def mode_(array, nan_policy, dtype):
+    from scipy.stats import mode
+
+    # npg splits `array` into object arrays for each group
+    # scipy.stats.mode does not like that
+    # here we cast back
+    return mode(array.astype(dtype, copy=False), nan_policy=nan_policy, axis=-1).mode
+
+
+def mode(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=partial(mode_, nan_policy="propagate", dtype=array.dtype),
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+
+def nanmode(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=partial(mode_, nan_policy="omit", dtype=array.dtype),
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
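For context, the wrappers above all rely on numpy_groupies accepting an arbitrary callable as `func`: the callable is applied to each group's values (numpy_groupies gathers them into per-group object arrays first, which is why `mode_` has to cast back to the original dtype). A minimal sketch of that underlying call, with made-up `group_idx` and `array` values:

```python
import numpy as np
import numpy_groupies as npg

# Made-up inputs: two groups (0 and 1) along the last axis.
group_idx = np.array([0, 0, 1, 1, 1])
array = np.array([1.0, 3.0, 2.0, 4.0, 9.0])

# With a callable func, numpy_groupies applies it group by group;
# this is the mechanism the median/quantile/mode wrappers rely on.
result = npg.aggregate_numpy.aggregate(group_idx, array, func=np.median)
print(result)  # [2. 4.] -- median of [1, 3] and of [2, 4, 9]
```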
diff --git a/flox/aggregations.py b/flox/aggregations.py
index e5013032a..c06ef3509 100644
--- a/flox/aggregations.py
+++ b/flox/aggregations.py
@@ -6,7 +6,6 @@
 from typing import TYPE_CHECKING, Any, Callable, TypedDict
 
 import numpy as np
-import numpy_groupies as npg
 from numpy.typing import DTypeLike
 
 from . import aggregate_flox, aggregate_npg, xrutils
@@ -35,6 +34,16 @@ class AggDtype(TypedDict):
     intermediate: tuple[np.dtype | type[np.intp], ...]
 
 
+def get_npg_aggregation(func, *, engine):
+    try:
+        method_ = getattr(aggregate_npg, func)
+        method = partial(method_, engine=engine)
+    except AttributeError:
+        aggregate = aggregate_npg._get_aggregate(engine).aggregate
+        method = partial(aggregate, func=func)
+    return method
+
+
 def generic_aggregate(
     group_idx,
     array,
@@ -51,14 +60,11 @@
         try:
             method = getattr(aggregate_flox, func)
         except AttributeError:
-            method = partial(npg.aggregate_numpy.aggregate, func=func)
+            method = get_npg_aggregation(func, engine="numpy")
+
     elif engine in ["numpy", "numba"]:
-        try:
-            method_ = getattr(aggregate_npg, func)
-            method = partial(method_, engine=engine)
-        except AttributeError:
-            aggregate = aggregate_npg._get_aggregate(engine).aggregate
-            method = partial(aggregate, func=func)
+        method = get_npg_aggregation(func, engine=engine)
+
     else:
         raise ValueError(
             f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead."
@@ -465,10 +471,22 @@ def _pick_second(*x):
     final_dtype=bool,
 )
 
-# numpy_groupies does not support median
-# And the dask version is really hard!
-# median = Aggregation("median", chunk=None, combine=None, fill_value=None)
-# nanmedian = Aggregation("nanmedian", chunk=None, combine=None, fill_value=None)
+# Support statistical quantities only blockwise
+# The parallel versions will be approximate and are hard to implement!
+median = Aggregation(
+    name="median", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
+)
+nanmedian = Aggregation(
+    name="nanmedian", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
+)
+quantile = Aggregation(
+    name="quantile", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
+)
+nanquantile = Aggregation(
+    name="nanquantile", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
+)
+mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
+nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)
 
 aggregations = {
     "any": any_,
@@ -496,6 +514,12 @@ def _pick_second(*x):
     "nanfirst": nanfirst,
     "last": last,
     "nanlast": nanlast,
+    "median": median,
+    "nanmedian": nanmedian,
+    "quantile": quantile,
+    "nanquantile": nanquantile,
+    "mode": mode,
+    "nanmode": nanmode,
 }
diff --git a/flox/core.py b/flox/core.py
index 98c37bc6f..484303295 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -1307,15 +1307,14 @@ def dask_groupby_agg(
     assert isinstance(axis, Sequence)
     assert all(ax >= 0 for ax in axis)
 
-    if method == "blockwise" and not isinstance(by, np.ndarray):
-        raise NotImplementedError
-
     inds = tuple(range(array.ndim))
     name = f"groupby_{agg.name}"
     token = dask.base.tokenize(array, by, agg, expected_groups, axis)
 
     if expected_groups is None and reindex:
         expected_groups = _get_expected_groups(by, sort=sort)
+    if method == "cohorts":
+        assert reindex is False
 
     by_input = by
 
@@ -1349,7 +1348,6 @@ def dask_groupby_agg(
     # b. "_grouped_combine": A more general solution where we tree-reduce the groupby reduction.
     #    This allows us to discover groups at compute time, support argreductions, lower intermediate
     #    memory usage (but method="cohorts" would also work to reduce memory in some cases)
-
     do_simple_combine = not _is_arg_reduction(agg)
 
     if method == "blockwise":
@@ -1375,7 +1373,7 @@
             partial(
                 blockwise_method,
                 axis=axis,
-                expected_groups=None if method == "cohorts" else expected_groups,
+                expected_groups=expected_groups if reindex else None,
                 engine=engine,
                 sort=sort,
             ),
@@ -1468,14 +1466,24 @@
     elif method == "blockwise":
         reduced = intermediate
-        # Here one input chunk → one output chunks
-        # find number of groups in each chunk, this is needed for output chunks
-        # along the reduced axis
-        slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
-        groups_in_block = tuple(_unique(by_input[slc]) for slc in slices)
-        groups = (np.concatenate(groups_in_block),)
-        ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
-        group_chunks = (ngroups_per_block,)
+        if reindex:
+            if TYPE_CHECKING:
+                assert expected_groups is not None
+            # TODO: we could have `expected_groups` be a dask array with appropriate chunks
+            # for now, we have a numpy array that is interpreted as listing all group labels
+            # that are present in every chunk
+            groups = (expected_groups,)
+            group_chunks = ((len(expected_groups),),)
+        else:
+            # Here one input chunk → one output chunk
+            # find the number of groups in each chunk; this is needed for output chunks
+            # along the reduced axis
+            # TODO: this logic is very specialized for the resampling case
+            slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
+            groups_in_block = tuple(_unique(by_input[slc]) for slc in slices)
+            groups = (np.concatenate(groups_in_block),)
+            ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
+            group_chunks = (ngroups_per_block,)
 
     else:
         raise ValueError(f"Unknown method={method}.")
@@ -1547,7 +1555,7 @@
     if reindex is True and not all_numpy:
         if _is_arg_reduction(func):
             raise NotImplementedError
-        if method in ["blockwise", "cohorts"]:
+        if method == "cohorts" or (method == "blockwise" and not any_by_dask):
             raise ValueError(
                 "reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
             )
@@ -1562,7 +1570,11 @@
             # have to do the grouped_combine since there's no good fill_value
             reindex = False
 
-        if method == "blockwise" or _is_arg_reduction(func):
+        if method == "blockwise":
+            # for grouping by dask arrays, we set reindex=True
+            reindex = any_by_dask
+
+        elif _is_arg_reduction(func):
             reindex = False
 
         elif method == "cohorts":
@@ -1767,7 +1779,10 @@
     *by : ndarray or DaskArray
         Array of labels to group over. Must be aligned with ``array`` so that
         ``array.shape[-by.ndim :] == by.shape``
-    func : str or Aggregation
+    func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
+        "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
+        "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
+        "first", "nanfirst", "last", "nanlast"} or Aggregation
         Single function name or an Aggregation instance
     expected_groups : (optional) Sequence
         Expected unique labels.
@@ -1835,7 +1850,7 @@
         boost in computation speed. For cases like time grouping, this may result in large
         intermediates relative to the original block size. Avoid that by using ``method="cohorts"``.
         By default, it is turned off for argreductions.
     finalize_kwargs : dict, optional
-        Kwargs passed to finalize the reduction such as ``ddof`` for var, std.
+        Kwargs passed to finalize the reduction such as ``ddof`` for var, std, or ``q`` for quantile.
 
     Returns
     -------
@@ -1855,6 +1870,9 @@
             "Try engine='numpy' or engine='numba' instead."
         )
 
+    if func == "quantile" and (finalize_kwargs is None or "q" not in finalize_kwargs):
+        raise ValueError("Please pass `q` for quantile calculations.")
+
     bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
     nby = len(bys)
     by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
@@ -2023,7 +2041,7 @@
         result, groups = partial_agg(
             array,
             by_,
-            expected_groups=None if method == "blockwise" else expected_groups,
+            expected_groups=expected_groups,
             agg=agg,
             reindex=reindex,
             method=method,
diff --git a/flox/xarray.py b/flox/xarray.py
index c85ad7113..eb35da387 100644
--- a/flox/xarray.py
+++ b/flox/xarray.py
@@ -88,8 +88,11 @@ def xarray_reduce(
         Xarray object to reduce
     *by : DataArray or iterable of str or iterable of DataArray
         Variables with which to group by ``obj``
-    func : str or Aggregation
-        Reduction method
+    func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
+        "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
+        "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
+        "first", "nanfirst", "last", "nanlast"} or Aggregation
+        Single function name or an Aggregation instance
     expected_groups : str or sequence
         expected group labels corresponding to each `by` variable
     isbin : iterable of bool
@@ -164,7 +167,7 @@
         boost in computation speed. For cases like time grouping, this may result in large
         intermediates relative to the original block size. Avoid that by using method="cohorts".
         By default, it is turned off for arg reductions.
     **finalize_kwargs
-        kwargs passed to the finalize function, like ``ddof`` for var, std.
+        kwargs passed to the finalize function, like ``ddof`` for var, std, or ``q`` for quantile.
 
     Returns
     -------
diff --git a/pyproject.toml b/pyproject.toml
index e41fb17eb..e387657d2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -118,7 +118,8 @@ module=[
     "matplotlib.*",
     "pandas",
     "setuptools",
-    "toolz"
+    "scipy.*",
+    "toolz",
 ]
 ignore_missing_imports = true
diff --git a/tests/__init__.py b/tests/__init__.py
index 3e43e94d6..f1c8ec6bf 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -47,6 +47,7 @@ def LooseVersion(vstring):
 
 has_dask, requires_dask = _importorskip("dask")
 has_numba, requires_numba = _importorskip("numba")
+has_scipy, requires_scipy = _importorskip("scipy")
 has_xarray, requires_xarray = _importorskip("xarray")
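As the aggregations.py changes note, the new statistical reductions are supported only blockwise: with dask inputs, each group must live entirely within one chunk, which the tests below enforce by pairing `BLOCKWISE_FUNCS` with `chunks=-1` and `method="blockwise"`. A sketch of that usage (sample data invented here):

```python
import dask.array as da
import numpy as np
import flox

labels = np.array([0, 0, 1, 1, 1])
# chunks=-1 keeps a single chunk along the grouped axis, so each
# group sits entirely within one block, as blockwise requires.
arr = da.from_array(np.array([1.0, 3.0, 2.0, 4.0, 9.0]), chunks=-1)

result, groups = flox.groupby_reduce(
    arr, labels, func="median", method="blockwise"
)
```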
diff --git a/tests/test_core.py b/tests/test_core.py
index 453958b2d..440431304 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -31,6 +31,7 @@
     has_dask,
     raise_if_dask_computes,
     requires_dask,
+    requires_scipy,
 )
 
 labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
@@ -50,6 +51,9 @@ def dask_array_ones(*args):
         return None
 
 
+DEFAULT_QUANTILE = 0.9
+SCIPY_STATS_FUNCS = ("mode", "nanmode")
+BLOCKWISE_FUNCS = ("median", "nanmedian", "quantile", "nanquantile") + SCIPY_STATS_FUNCS
 ALL_FUNCS = (
     "sum",
     "nansum",
@@ -73,9 +77,11 @@ def dask_array_ones(*args):
     "any",
     "all",
     "nanlast",
-    pytest.param("median", marks=(pytest.mark.skip,)),
-    pytest.param("nanmedian", marks=(pytest.mark.skip,)),
-)
+    "median",
+    "nanmedian",
+    "quantile",
+    "nanquantile",
+) + tuple(pytest.param(func, marks=requires_scipy) for func in SCIPY_STATS_FUNCS)
 
 if TYPE_CHECKING:
     from flox.core import T_Agg, T_Engine, T_ExpectedGroupsOpt, T_Method
@@ -84,12 +90,26 @@ def dask_array_ones(*args):
 def _get_array_func(func: str) -> Callable:
     if func == "count":
 
-        def npfunc(x):
+        def npfunc(x, **kwargs):
             x = np.asarray(x)
             return (~np.isnan(x)).sum()
 
     elif func in ["nanfirst", "nanlast"]:
         npfunc = getattr(xrutils, func)
+
+    elif func in SCIPY_STATS_FUNCS:
+        import scipy.stats
+
+        if "nan" in func:
+            func = func[3:]
+            nan_policy = "omit"
+        else:
+            nan_policy = "propagate"
+
+        def npfunc(x, **kwargs):
+            spfunc = partial(getattr(scipy.stats, func), nan_policy=nan_policy)
+            return getattr(spfunc(x, **kwargs), func)
+
     else:
         npfunc = getattr(np, func)
@@ -205,7 +225,7 @@ def gen_array_by(size, func):
 @pytest.mark.parametrize("add_nan_by", [True, False])
 @pytest.mark.parametrize("func", ALL_FUNCS)
 def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
-    if "arg" in func and engine == "flox":
+    if ("arg" in func and engine == "flox") or (func in BLOCKWISE_FUNCS and chunks != -1):
         pytest.skip()
 
     array, by = gen_array_by(size, func)
@@ -224,6 +244,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
         finalize_kwargs = finalize_kwargs + [{"ddof": 1}, {"ddof": 0}]
         fill_value = np.nan
         tolerance = {"rtol": 1e-14, "atol": 1e-16}
+    elif "quantile" in func:
+        finalize_kwargs = [{"q": DEFAULT_QUANTILE}]
+        fill_value = None
+        tolerance = None
     else:
         fill_value = None
         tolerance = None
@@ -246,15 +270,16 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
                 func_ = f"nan{func}" if "nan" not in func else func
                 array_[..., nanmask] = np.nan
                 expected = getattr(np, func_)(array_, axis=-1, **kwargs)
-            # elif func in ["first", "last"]:
-            #     expected = getattr(xrutils, f"nan{func}")(array_[..., ~nanmask], axis=-1, **kwargs)
-            elif func in ["nanfirst", "nanlast"]:
-                expected = getattr(xrutils, func)(array_[..., ~nanmask], axis=-1, **kwargs)
             else:
-                expected = getattr(np, func)(array_[..., ~nanmask], axis=-1, **kwargs)
+                array_func = _get_array_func(func)
+                expected = array_func(array_[..., ~nanmask], axis=-1, **kwargs)
             for _ in range(nby):
                 expected = np.expand_dims(expected, -1)
 
+            if func in BLOCKWISE_FUNCS:
+                assert chunks == -1
+                flox_kwargs["method"] = "blockwise"
+
             actual, *groups = groupby_reduce(array, *by, **flox_kwargs)
             assert actual.ndim == (array.ndim + nby - 1)
             assert expected.ndim == (array.ndim + nby - 1)
@@ -265,7 +290,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
                 assert actual.dtype.kind == "i"
             assert_equal(actual, expected, tolerance)
 
-            if not has_dask or chunks is None:
+            if not has_dask or chunks is None or func in BLOCKWISE_FUNCS:
                 continue
 
             params = list(itertools.product(["map-reduce"], [True, False, None]))
@@ -396,7 +421,7 @@ def test_numpy_reduce_nd_md():
 def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtype, engine, reindex):
     """Tests groupby_reduce with dask arrays against groupby_reduce with numpy arrays"""
 
-    if func in ["first", "last"]:
+    if func in ["first", "last"] or func in BLOCKWISE_FUNCS:
         pytest.skip()
 
     if "arg" in func and (engine == "flox" or reindex):
@@ -551,7 +576,7 @@ def test_first_last_disallowed_dask(func):
     "axis", [None, (0, 1, 2), (0, 1), (0, 2), (1, 2), 0, 1, 2, (0,), (1,), (2,)]
 )
 def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
-    if "arg" in func and engine == "flox":
+    if ("arg" in func and engine == "flox") or func in BLOCKWISE_FUNCS:
        pytest.skip()

     if not isinstance(axis, int):
@@ -847,7 +872,7 @@ def test_rechunk_for_cohorts(chunk_at, expected):
 def test_fill_value_behaviour(func, chunks, fill_value, engine):
     # fill_value = np.nan tests promotion of int counts to float
     # This is used by xarray
-    if func in ["all", "any"] or "arg" in func:
+    if (func in ["all", "any"] or "arg" in func) or func in BLOCKWISE_FUNCS:
         pytest.skip()
 
     npfunc = _get_array_func(func)
@@ -903,8 +928,17 @@ def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype):
 @requires_dask
 @pytest.mark.parametrize("func", ALL_FUNCS)
 @pytest.mark.parametrize("axis", (-1, None))
-@pytest.mark.parametrize("method", ["blockwise", "cohorts", "map-reduce", "split-reduce"])
+@pytest.mark.parametrize("method", ["blockwise", "cohorts", "map-reduce"])
 def test_cohorts_nd_by(func, method, axis, engine):
+    if (
+        ("arg" in func and (axis is None or engine == "flox"))
+        or (method != "blockwise" and func in BLOCKWISE_FUNCS)
+        or (axis is None and ("first" in func or "last" in func))
+    ):
+        pytest.skip()
+    if axis is not None and method != "map-reduce":
+        pytest.xfail()
+
     o = dask.array.ones((3,), chunks=-1)
     o2 = dask.array.ones((2, 3), chunks=-1)
@@ -915,20 +949,14 @@ def test_cohorts_nd_by(func, method, axis, engine):
     by[0, 4] = 31
     array = np.broadcast_to(array, (2, 3) + array.shape)
 
-    if "arg" in func and (axis is None or engine == "flox"):
-        pytest.skip()
-
     if func in ["any", "all"]:
         fill_value = False
     else:
         fill_value = -123
 
-    if axis is not None and method != "map-reduce":
-        pytest.xfail()
-    if axis is None and ("first" in func or "last" in func):
-        pytest.skip()
-
     kwargs = dict(func=func, engine=engine, method=method, axis=axis, fill_value=fill_value)
+    if "quantile" in func:
+        kwargs["finalize_kwargs"] = {"q": DEFAULT_QUANTILE}
     actual, groups = groupby_reduce(array, by, **kwargs)
     expected, sorted_groups = groupby_reduce(array.compute(), by, **kwargs)
     assert_equal(groups, sorted_groups)
@@ -990,6 +1018,8 @@ def test_datetime_binning():
 def test_bool_reductions(func, engine):
     if "arg" in func and engine == "flox":
         pytest.skip()
+    if "quantile" in func or "mode" in func:
+        pytest.skip()
     groups = np.array([1, 1, 1])
     data = np.array([True, True, False])
     npfunc = _get_array_func(func)
@@ -1242,9 +1272,14 @@ def grouped_median(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
 def test_dtype(func, dtype, engine):
     if "arg" in func or func in ["any", "all"]:
         pytest.skip()
+
+    finalize_kwargs = {"q": DEFAULT_QUANTILE} if "quantile" in func else {}
+
     arr = np.ones((4, 12), dtype=dtype)
     labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
-    actual, _ = groupby_reduce(arr, labels, func=func, dtype=np.float64, engine=engine)
+    actual, _ = groupby_reduce(
+        arr, labels, func=func, dtype=np.float64, engine=engine, finalize_kwargs=finalize_kwargs
+    )
     assert actual.dtype == np.dtype("float64")
@@ -1387,6 +1422,33 @@ def test_validate_reindex() -> None:
         )
         assert actual is False
 
+    with pytest.raises(ValueError):
+        _validate_reindex(
+            True,
+            "sum",
+            method="blockwise",
+            expected_groups=np.array([1, 2, 3]),
+            any_by_dask=False,
+            is_dask_array=True,
+        )
+
+    assert _validate_reindex(
+        True,
+        "sum",
+        method="blockwise",
+        expected_groups=np.array([1, 2, 3]),
+        any_by_dask=True,
+        is_dask_array=True,
+    )
+    assert _validate_reindex(
+        None,
+        "sum",
+        method="blockwise",
+        expected_groups=np.array([1, 2, 3]),
+        any_by_dask=True,
+        is_dask_array=True,
+    )
+
 
 @requires_dask
 def test_1d_blockwise_sort_optimization():
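On the xarray side, `xarray_reduce` forwards extra keyword arguments through `**finalize_kwargs` (per the docstring change above), so `q` can be passed directly. A sketch under that assumption, with an invented DataArray:

```python
import numpy as np
import xarray as xr
from flox.xarray import xarray_reduce

# Invented example: 12 values split into two labeled groups.
arr = xr.DataArray(
    np.arange(12.0),
    dims="x",
    coords={"labels": ("x", np.repeat([0, 1], 6))},
    name="data",
)

# `q` flows through **finalize_kwargs to the quantile aggregation.
out = xarray_reduce(arr, arr.labels, func="quantile", q=0.5)
```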