|
5 | 5 | import pytest
|
6 | 6 | from numpy_groupies.aggregate_numpy import aggregate
|
7 | 7 |
|
| 8 | +from flox.aggregations import Aggregation |
8 | 9 | from flox.core import (
|
9 | 10 | _convert_expected_groups_to_index,
|
10 | 11 | _get_optimal_chunks_for_groups,
|
@@ -964,3 +965,47 @@ def test_factorize_reindex_sorting_ints():
|
964 | 965 |
|
965 | 966 | expected = factorize_(**kwargs, reindex=True, sort=False)[0]
|
966 | 967 | assert_equal(expected, [6, 4, 6, 3, 2, 0])
|
| 968 | + |
| 969 | + |
@requires_dask
def test_custom_aggregation_blockwise():
    """Exercise a user-defined ``Aggregation`` that has no chunk/combine steps.

    The aggregation is built with ``chunk=None`` and ``combine=None``. The test
    checks that it reduces a numpy array correctly, that the "map-reduce",
    "cohorts" and "split-reduce" methods raise ``NotImplementedError`` on a
    dask array, and that ``method="blockwise"`` on the same dask input
    reproduces the numpy result.
    """

    def grouped_median(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
        # Hand the grouped reduction to ``aggregate`` with np.median as the reducer.
        options = dict(
            func=np.median,
            axis=axis,
            size=size,
            fill_value=fill_value,
            dtype=dtype,
        )
        return aggregate(group_idx, array, **options)

    agg_median = Aggregation(
        name="median",
        numpy=grouped_median,
        fill_value=-1,
        chunk=None,
        combine=None,
    )

    values = np.arange(100, dtype=np.float32).reshape(5, 20)
    # Every label is 1, so each row reduces to a single group.
    labels = np.ones((20,))

    def reduce_on_dask(method):
        # A fresh dask array per call, one chunk per row along the first axis.
        chunked = dask.array.from_array(values, chunks=(1, -1))
        return groupby_reduce(chunked, labels, func=agg_median, axis=-1, method=method)

    # Plain numpy input: compare against np.median directly.
    numpy_result, _ = groupby_reduce(values, labels, func=agg_median, axis=-1)
    expected = np.median(values, axis=-1, keepdims=True)
    assert_equal(expected, numpy_result)

    # These methods are expected to reject the custom aggregation.
    for unsupported in ("map-reduce", "cohorts", "split-reduce"):
        with pytest.raises(NotImplementedError):
            reduce_on_dask(unsupported)

    # Blockwise reduction must agree with the numpy result.
    blockwise_result, _ = reduce_on_dask("blockwise")
    assert_equal(expected, blockwise_result)
0 commit comments