diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py
index 8d86f56db..2da0b1392 100644
--- a/asv_bench/benchmarks/combine.py
+++ b/asv_bench/benchmarks/combine.py
@@ -58,4 +58,4 @@ def construct_member(groups):
         ]
 
         self.x_chunk_cohorts = [construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4]
-        self.kwargs = {"agg": flox.aggregations.mean, "axis": (3,), "neg_axis": (-1,)}
+        self.kwargs = {"agg": flox.aggregations.mean, "axis": (3,)}
diff --git a/flox/core.py b/flox/core.py
index a73a97398..10a9197a9 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -880,7 +880,6 @@ def _grouped_combine(
     agg: Aggregation,
     axis: T_Axes,
     keepdims: bool,
-    neg_axis: T_Axes,
     engine: T_Engine,
     is_aggregate: bool = False,
     sort: bool = True,
@@ -906,6 +905,9 @@ def _grouped_combine(
             partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
         )
 
+    # these are negative axis indices useful for concatenating the intermediates
+    neg_axis = tuple(range(-len(axis), 0))
+
     groups = _conc2(x_chunk, "groups", axis=neg_axis)
 
     if agg.reduction_type == "argreduce":
@@ -1068,6 +1070,30 @@ def _reduce_blockwise(
     return result
 
 
+def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
+    import dask.array
+    from dask.highlevelgraph import HighLevelGraph
+
+    layer: dict[tuple, tuple] = {}
+    groups_token = f"group-{reduced.name}"
+    first_block = reduced.ndim * (0,)
+    layer[(groups_token, *first_block)] = (
+        operator.getitem,
+        (reduced.name, *first_block),
+        "groups",
+    )
+    groups: tuple[DaskArray] = (
+        dask.array.Array(
+            HighLevelGraph.from_collections(groups_token, layer, dependencies=[reduced]),
+            groups_token,
+            chunks=group_chunks,
+            meta=np.array([], dtype=dtype),
+        ),
+    )
+
+    return groups
+
+
 def dask_groupby_agg(
     array: DaskArray,
     by: DaskArray | np.ndarray,
@@ -1189,14 +1215,11 @@ def dask_groupby_agg(
     group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),)
 
     if method == "map-reduce":
-        # these are negative axis indices useful for concatenating the intermediates
-        neg_axis = tuple(range(-len(axis), 0))
-
         combine: Callable[..., IntermediateDict]
         if do_simple_combine:
             combine = _simple_combine
         else:
-            combine = partial(_grouped_combine, engine=engine, neg_axis=neg_axis, sort=sort)
+            combine = partial(_grouped_combine, engine=engine, sort=sort)
 
         # reduced is really a dict mapping reduction name to array
         # and "groups" to an array of group labels
@@ -1219,10 +1242,19 @@ def dask_groupby_agg(
             keepdims=True,
             concatenate=False,
         )
-        output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
+
+        if is_duck_dask_array(by_input) and expected_groups is None:
+            groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
+        else:
+            if expected_groups is None:
+                expected_groups_ = _get_expected_groups(by_input, sort=sort)
+            else:
+                expected_groups_ = expected_groups
+            groups = (expected_groups_.to_numpy(),)
+
     elif method == "blockwise":
         reduced = intermediate
-        # Here one input chunk → one output chunka
+        # Here one input chunk → one output chunks
         # find number of groups in each chunk, this is needed for output chunks
         # along the reduced axis
         slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
@@ -1235,41 +1267,17 @@ def dask_groupby_agg(
             groups_in_block = tuple(
                 np.intersect1d(by_input[slc], expected_groups) for slc in slices
            )
+        groups = (np.concatenate(groups_in_block),)
+
         ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
-        output_chunks = reduced.chunks[: -(len(axis))] + (ngroups_per_block,)
+        group_chunks = (ngroups_per_block,)
+
     else:
         raise ValueError(f"Unknown method={method}.")
 
     # extract results from the dict
-    layer: dict[tuple, tuple] = {}
+    output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks
     ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
-    if is_duck_dask_array(by_input) and expected_groups is None:
-        groups_name = f"groups-{name}-{token}"
-        # we've used keepdims=True, so _tree_reduce preserves some dummy dimensions
-        first_block = len(ochunks) * (0,)
-        layer[(groups_name, *first_block)] = (
-            operator.getitem,
-            (reduced.name, *first_block),
-            "groups",
-        )
-        groups: tuple[np.ndarray | DaskArray] = (
-            dask.array.Array(
-                HighLevelGraph.from_collections(groups_name, layer, dependencies=[reduced]),
-                groups_name,
-                chunks=group_chunks,
-                dtype=by.dtype,
-            ),
-        )
-    else:
-        if method == "map-reduce":
-            if expected_groups is None:
-                expected_groups_ = _get_expected_groups(by_input, sort=sort)
-            else:
-                expected_groups_ = expected_groups
-            groups = (expected_groups_.to_numpy(),)
-        else:
-            groups = (np.concatenate(groups_in_block),)
-
     layer2: dict[tuple, tuple] = {}
     agg_name = f"{name}-{token}"
     for ochunk in itertools.product(*ochunks):
@@ -1624,6 +1632,7 @@ def groupby_reduce(
         f"\n\n Received: {func}"
     )
 
+    # TODO: just do this in dask_groupby_agg
     # we always need some fill_value (see above) so choose the default if needed
    if kwargs["fill_value"] is None:
        kwargs["fill_value"] = agg.fill_value[agg.name]
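
Reviewer note, not part of the patch: the new `_extract_unknown_groups` helper reuses the graph-layer trick from the deleted inline block, so here is a minimal standalone sketch of that trick under stated assumptions. It hand-builds a dask array whose blocks are dicts (mimicking the tree-reduce intermediates) and then attaches a single `operator.getitem` task that lazily pulls the "groups" entry out of the first block; the chunk size is NaN because the number of groups is unknown until compute time. The names `to_intermediates` and `reduced-demo` are invented for illustration and do not exist in flox.

import operator

import dask.array
import numpy as np
from dask.highlevelgraph import HighLevelGraph


def to_intermediates(block):
    # stand-in for the per-block dict of intermediates ("groups" + reductions)
    return {"groups": np.unique(block), "sum": block.sum()}


x = dask.array.from_array(np.array([1, 1, 2, 3]), chunks=2)

# a dict-blocked array, standing in for the output of _tree_reduce
red_name = "reduced-demo"
red_layer = {(red_name, i): (to_intermediates, (x.name, i)) for i in range(x.numblocks[0])}
reduced = dask.array.Array(
    HighLevelGraph.from_collections(red_name, red_layer, dependencies=[x]),
    red_name,
    chunks=x.chunks,
    meta=np.array([], dtype=object),
)

# the trick: one extra task that getitems "groups" out of the first block
token = f"group-{reduced.name}"
first_block = reduced.ndim * (0,)
layer = {(token, *first_block): (operator.getitem, (reduced.name, *first_block), "groups")}
groups = dask.array.Array(
    HighLevelGraph.from_collections(token, layer, dependencies=[reduced]),
    token,
    chunks=((np.nan,),),  # number of unique groups is unknown until compute
    meta=np.array([], dtype=np.int64),
)
print(groups.compute())  # -> [1], the labels present in the first block

In the real code this is safe because, as the deleted comment notes, `_tree_reduce` is called with keepdims=True, so indexing block (0, ..., 0) reads the single dict left after the reduction rather than discarding data.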