diff --git a/flox/aggregations.py b/flox/aggregations.py index ca6ff3996..e85c0699d 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -184,6 +184,8 @@ def __init__( self.chunk: FuncTuple = _atleast_1d(chunk) # how to aggregate results after first round of reduction self.combine: FuncTuple = _atleast_1d(combine) + # simpler reductions used with the "simple combine" algorithm + self.simple_combine = None # final aggregation self.aggregate: Callable | str = aggregate if aggregate else self.combine[0] # finalize results (see mean) @@ -577,4 +579,16 @@ def _initialize_aggregation( else: agg.min_count = 0 + simple_combine = [] + for combine in agg.combine: + if isinstance(combine, str): + if combine in ["nanfirst", "nanlast"]: + simple_combine.append(getattr(xrutils, combine)) + else: + simple_combine.append(getattr(np, combine)) + else: + simple_combine.append(combine) + + agg.simple_combine = tuple(simple_combine) + return agg diff --git a/flox/core.py b/flox/core.py index c908f6de5..2444df8e3 100644 --- a/flox/core.py +++ b/flox/core.py @@ -94,6 +94,10 @@ def _is_minmax_reduction(func: T_Agg) -> bool: ) +def _is_first_last_reduction(func: T_Agg) -> bool: + return isinstance(func, str) and func in ["nanfirst", "nanlast", "first", "last"] + + def _get_expected_groups(by: T_By, sort: bool) -> pd.Index: if is_duck_dask_array(by): raise ValueError("Please provide expected_groups if not grouping by a numpy array.") @@ -954,13 +958,13 @@ def _simple_combine( results: IntermediateDict = {"groups": unique_groups} results["intermediates"] = [] axis_ = axis[:-1] + (DUMMY_AXIS,) - for idx, combine in enumerate(agg.combine): + for idx, combine in enumerate(agg.simple_combine): array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis_) assert array.ndim >= 2 with warnings.catch_warnings(): warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") - assert isinstance(combine, str) - result = getattr(np, combine)(array, axis=axis_, keepdims=True) + assert callable(combine) + result = combine(array, axis=axis_, keepdims=True) if is_aggregate: # squeeze out DUMMY_AXIS if this is the last step i.e. called from _aggregate result = result.squeeze(axis=DUMMY_AXIS) @@ -1534,11 +1538,17 @@ def _validate_reindex( raise ValueError( "reindex=True is not a valid choice for method='blockwise' or method='cohorts'." ) + if func in ["first", "last"]: + raise ValueError("reindex must be None or False when func is 'first' or 'last.") if reindex is None: if all_numpy: return True + if func in ["first", "last"]: + # have to do the grouped_combine since there's no good fill_value + reindex = False + if method == "blockwise" or _is_arg_reduction(func): reindex = False @@ -1552,6 +1562,7 @@ def _validate_reindex( reindex = True assert isinstance(reindex, bool) + return reindex @@ -1875,6 +1886,21 @@ def groupby_reduce( axis_ = np.core.numeric.normalize_axis_tuple(axis, array.ndim) # type: ignore nax = len(axis_) + has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_) + + if _is_first_last_reduction(func): + if has_dask and nax != 1: + raise ValueError( + "For dask arrays: first, last, nanfirst, nanlast reductions are " + "only supported along a single axis. Please reshape appropriately." + ) + + elif nax not in [1, by_.ndim]: + raise ValueError( + "first, last, nanfirst, nanlast reductions are only supported " + "along a single axis or when reducing across all dimensions of `by`." + ) + # TODO: make sure expected_groups is unique if nax == 1 and by_.ndim > 1 and expected_groups is None: if not any_by_dask: @@ -1898,8 +1924,6 @@ def groupby_reduce( axis_ = tuple(array.ndim + np.arange(-nax, 0)) nax = len(axis_) - has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_) - # When axis is a subset of possible values; then npg will # apply it to groups that don't exist along a particular axis (for e.g.) # since these count as a group that is absent. thoo! @@ -1986,6 +2010,6 @@ def groupby_reduce( ).reshape(result.shape[:-1] + grp_shape) groups = final_groups - if _is_minmax_reduction(func) and is_bool_array: + if is_bool_array and (_is_minmax_reduction(func) or _is_first_last_reduction(func)): result = result.astype(bool) return (result, *groups) # type: ignore[return-value] # Unpack not in mypy yet diff --git a/flox/xrutils.py b/flox/xrutils.py index 73b2023aa..45cf45eec 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -1,12 +1,12 @@ # The functions defined here were copied based on the source code # defined in xarray - import datetime from typing import Any, Iterable import numpy as np import pandas as pd +from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined] try: import cftime @@ -283,3 +283,36 @@ def _contains_cftime_datetimes(array) -> bool: return isinstance(sample, cftime.datetime) else: return False + + +def _select_along_axis(values, idx, axis): + other_ind = np.ix_(*[np.arange(s) for s in idx.shape]) + sl = other_ind[:axis] + (idx,) + other_ind[axis:] + return values[sl] + + +def nanfirst(values, axis, keepdims=False): + if isinstance(axis, tuple): + (axis,) = axis + values = np.asarray(values) + axis = normalize_axis_index(axis, values.ndim) + idx_first = np.argmax(~pd.isnull(values), axis=axis) + result = _select_along_axis(values, idx_first, axis) + if keepdims: + return np.expand_dims(result, axis=axis) + else: + return result + + +def nanlast(values, axis, keepdims=False): + if isinstance(axis, tuple): + (axis,) = axis + values = np.asarray(values) + axis = normalize_axis_index(axis, values.ndim) + rev = (slice(None),) * axis + (slice(None, None, -1),) + idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis) + result = _select_along_axis(values, idx_last, axis) + if keepdims: + return np.expand_dims(result, axis=axis) + else: + return result diff --git a/tests/test_core.py b/tests/test_core.py index 783e3a8bd..7c152fd10 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -3,13 +3,14 @@ import itertools import warnings from functools import partial, reduce -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable import numpy as np import pandas as pd import pytest from numpy_groupies.aggregate_numpy import aggregate +from flox import xrutils from flox.aggregations import Aggregation from flox.core import ( _convert_expected_groups_to_index, @@ -53,6 +54,7 @@ def dask_array_ones(*args): "sum", "nansum", "argmax", + "nanfirst", pytest.param("nanargmax", marks=(pytest.mark.skip,)), "prod", "nanprod", @@ -70,6 +72,7 @@ def dask_array_ones(*args): pytest.param("nanargmin", marks=(pytest.mark.skip,)), "any", "all", + "nanlast", pytest.param("median", marks=(pytest.mark.skip,)), pytest.param("nanmedian", marks=(pytest.mark.skip,)), ) @@ -78,6 +81,21 @@ def dask_array_ones(*args): from flox.core import T_Engine, T_ExpectedGroupsOpt, T_Func2 +def _get_array_func(func: str) -> Callable: + if func == "count": + + def npfunc(x): + x = np.asarray(x) + return (~np.isnan(x)).sum() + + elif func in ["nanfirst", "nanlast"]: + npfunc = getattr(xrutils, func) + else: + npfunc = getattr(np, func) + + return npfunc + + def test_alignment_error(): da = np.ones((12,)) labels = np.ones((5,)) @@ -217,6 +235,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): if "arg" in func and add_nan_by: array_[..., nanmask] = np.nan expected = getattr(np, "nan" + func)(array_, axis=-1, **kwargs) + # elif func in ["first", "last"]: + # expected = getattr(xrutils, f"nan{func}")(array_[..., ~nanmask], axis=-1, **kwargs) + elif func in ["nanfirst", "nanlast"]: + expected = getattr(xrutils, func)(array_[..., ~nanmask], axis=-1, **kwargs) else: expected = getattr(np, func)(array_[..., ~nanmask], axis=-1, **kwargs) for _ in range(nby): @@ -241,7 +263,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): call = partial( groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs ) - if "arg" in func and reindex is True: + if ("arg" in func or func in ["first", "last"]) and reindex is True: # simple_combine with argreductions not supported right now with pytest.raises(NotImplementedError): call() @@ -486,6 +508,28 @@ def test_dask_reduce_axis_subset(): ) +@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"]) +@pytest.mark.parametrize("axis", [(0, 1)]) +def test_first_last_disallowed(axis, func): + with pytest.raises(ValueError): + groupby_reduce(np.empty((2, 3, 2)), np.ones((2, 3, 2)), func=func, axis=axis) + + +@requires_dask +@pytest.mark.parametrize("func", ["nanfirst", "nanlast"]) +@pytest.mark.parametrize("axis", [None, (0, 1, 2)]) +def test_nanfirst_nanlast_disallowed_dask(axis, func): + with pytest.raises(ValueError): + groupby_reduce(dask.array.empty((2, 3, 2)), np.ones((2, 3, 2)), func=func, axis=axis) + + +@requires_dask +@pytest.mark.parametrize("func", ["first", "last"]) +def test_first_last_disallowed_dask(func): + with pytest.raises(NotImplementedError): + groupby_reduce(dask.array.empty((2, 3, 2)), np.ones((2, 3, 2)), func=func, axis=-1) + + @requires_dask @pytest.mark.parametrize("func", ALL_FUNCS) @pytest.mark.parametrize( @@ -495,8 +539,34 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine): if "arg" in func and engine == "flox": pytest.skip() - if not isinstance(axis, int) and "arg" in func and (axis is None or len(axis) > 1): - pytest.skip() + if not isinstance(axis, int): + if "arg" in func and (axis is None or len(axis) > 1): + pytest.skip() + if ("first" in func or "last" in func) and (axis is not None and len(axis) not in [1, 3]): + pytest.skip() + + if func in ["all", "any"]: + fill_value = False + else: + fill_value = 123 + + if "var" in func or "std" in func: + tolerance = {"rtol": 1e-14, "atol": 1e-16} + else: + tolerance = None + # tests against the numpy output to make sure dask compute matches + by = np.broadcast_to(labels2d, (3, *labels2d.shape)) + rng = np.random.default_rng(12345) + array = rng.random(by.shape) + kwargs = dict( + func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value, engine=engine + ) + expected, _ = groupby_reduce(array, by, **kwargs) + if engine == "flox": + kwargs.pop("engine") + expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy") + assert_equal(expected_npg, expected) + if func in ["all", "any"]: fill_value = False else: @@ -513,17 +583,23 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine): kwargs = dict( func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value, engine=engine ) + expected, _ = groupby_reduce(array, by, **kwargs) + if engine == "flox": + kwargs.pop("engine") + expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy") + assert_equal(expected_npg, expected) + + if ("first" in func or "last" in func) and ( + axis is None or (not isinstance(axis, int) and len(axis) != 1) + ): + return + with raise_if_dask_computes(): actual, _ = groupby_reduce( da.from_array(array, chunks=(-1, 2, 3)), da.from_array(by, chunks=(-1, 2, 2)), **kwargs, ) - expected, _ = groupby_reduce(array, by, **kwargs) - if engine == "flox": - kwargs.pop("engine") - expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy") - assert_equal(expected_npg, expected) assert_equal(actual, expected, tolerance) @@ -751,15 +827,7 @@ def test_fill_value_behaviour(func, chunks, fill_value, engine): if chunks is not None and not has_dask: pytest.skip() - if func == "count": - - def npfunc(x): - x = np.asarray(x) - return (~np.isnan(x)).sum() - - else: - npfunc = getattr(np, func) - + npfunc = _get_array_func(func) by = np.array([1, 2, 3, 1, 2, 3]) array = np.array([np.nan, 1, 1, np.nan, 1, 1]) if chunks: @@ -767,7 +835,9 @@ def npfunc(x): actual, _ = groupby_reduce( array, by, func=func, engine=engine, fill_value=fill_value, expected_groups=[0, 1, 2, 3] ) - expected = np.array([fill_value, fill_value, npfunc([1.0, 1.0]), npfunc([1.0, 1.0])]) + expected = np.array( + [fill_value, fill_value, npfunc([1.0, 1.0], axis=0), npfunc([1.0, 1.0], axis=0)] + ) assert_equal(actual, expected) @@ -832,6 +902,8 @@ def test_cohorts_nd_by(func, method, axis, engine): if axis is not None and method != "map-reduce": pytest.xfail() + if axis is None and ("first" in func or "last" in func): + pytest.skip() kwargs = dict(func=func, engine=engine, method=method, axis=axis, fill_value=fill_value) actual, groups = groupby_reduce(array, by, **kwargs) @@ -897,7 +969,8 @@ def test_bool_reductions(func, engine): pytest.skip() groups = np.array([1, 1, 1]) data = np.array([True, True, False]) - expected = np.expand_dims(getattr(np, func)(data), -1) + npfunc = _get_array_func(func) + expected = np.expand_dims(npfunc(data, axis=0), -1) actual, _ = groupby_reduce(data, groups, func=func, engine=engine) assert_equal(expected, actual)