diff --git a/flox/core.py b/flox/core.py index 07f6d0e69..7cbb5659e 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1195,7 +1195,6 @@ def dask_groupby_agg( import dask.array from dask.array.core import slices_from_chunks - from dask.highlevelgraph import HighLevelGraph # I think _tree_reduce expects this assert isinstance(axis, Sequence) @@ -1268,6 +1267,9 @@ def dask_groupby_agg( engine=engine, sort=sort, ), + # output indices are the same as input indices + # Unlike xhistogram, we don't always know what the size of the group + # dimension will be unless reindex=True inds, array, inds, @@ -1277,7 +1279,7 @@ def dask_groupby_agg( dtype=array.dtype, # this is purely for show meta=array._meta, align_arrays=False, - token=f"{name}-chunk-{token}", + name=f"{name}-chunk-{token}", ) if expected_groups is None: @@ -1364,35 +1366,63 @@ def dask_groupby_agg( groups = (np.concatenate(groups_in_block),) ngroups_per_block = tuple(len(grp) for grp in groups_in_block) group_chunks = (ngroups_per_block,) - else: raise ValueError(f"Unknown method={method}.") - # extract results from the dict + out_inds = inds[: -len(axis)] + (inds[-1],) output_chunks = reduced.chunks[: -len(axis)] + group_chunks + if method == "blockwise" and len(axis) > 1: + # The final results are available but the blocks along axes + # need to be reshaped to axis=-1 + # I don't know that this is possible with blockwise + # All other code paths benefit from an unmaterialized Blockwise layer + reduced = _collapse_blocks_along_axes(reduced, axis, group_chunks) + + # Can't use map_blocks because it forces concatenate=True along drop_axes, + result = dask.array.blockwise( + _extract_result, + out_inds, + reduced, + inds, + adjust_chunks=dict(zip(out_inds, output_chunks)), + dtype=agg.dtype[agg.name], + key=agg.name, + name=f"{name}-{token}", + concatenate=False, + ) + + return (result, groups) + + +def _collapse_blocks_along_axes(reduced, axis, group_chunks): + import dask.array + from dask.highlevelgraph import 
HighLevelGraph + + nblocks = tuple(reduced.numblocks[ax] for ax in axis) + output_chunks = reduced.chunks[: -len(axis)] + ((1,) * (len(axis) - 1),) + group_chunks + + # extract results from the dict ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks) layer2: dict[tuple, tuple] = {} - agg_name = f"{name}-{token}" - for ochunk in itertools.product(*ochunks): - if method == "blockwise": - if len(axis) == 1: - inchunk = ochunk - else: - nblocks = tuple(len(array.chunks[ax]) for ax in axis) - inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks) - else: - inchunk = ochunk[:-1] + (0,) * (len(axis) - 1) + (ochunk[-1],) + name = f"reshape-{reduced.name}" - layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name) + for ochunk in itertools.product(*ochunks): + inchunk = ochunk[: -len(axis)] + np.unravel_index(ochunk[-1], nblocks) + layer2[(name, *ochunk)] = (reduced.name, *inchunk) - result = dask.array.Array( - HighLevelGraph.from_collections(agg_name, layer2, dependencies=[reduced]), - agg_name, + return dask.array.Array( + HighLevelGraph.from_collections(name, layer2, dependencies=[reduced]), + name, chunks=output_chunks, - dtype=agg.dtype[agg.name], + dtype=reduced.dtype, ) - return (result, groups) + +def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray: + from dask.array.core import deepfirst + + # deepfirst should not be needed here but sometimes we receive a list of dict? + return deepfirst(result_dict)[key] def _validate_reindex(