From 4f028b3a4307c3ec91bf8547ac7ef1ac8e8d32db Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 7 Oct 2022 11:44:56 -0600 Subject: [PATCH 1/6] Refactor --- flox/core.py | 82 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 48 insertions(+), 34 deletions(-) diff --git a/flox/core.py b/flox/core.py index a73a97398..ab8311265 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1068,6 +1068,30 @@ def _reduce_blockwise( return result +def _extract_unknown_groups(reduced, groups_token, group_chunks, dtype): + import dask.array + from dask.highlevelgraph import HighLevelGraph + + layer: dict[tuple, tuple] = {} + # we've used keepdims=True, so _tree_reduce preserves some dummy dimensions + first_block = reduced.ndim * (0,) + layer[(groups_token, *first_block)] = ( + operator.getitem, + (reduced.name, *first_block), + "groups", + ) + groups: tuple[np.ndarray | DaskArray] = ( + dask.array.Array( + HighLevelGraph.from_collections(groups_token, layer, dependencies=[reduced]), + groups_token, + chunks=group_chunks, + meta=np.array([], dtype=dtype), + ), + ) + + return groups + + def dask_groupby_agg( array: DaskArray, by: DaskArray | np.ndarray, @@ -1189,13 +1213,12 @@ def dask_groupby_agg( group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),) if method == "map-reduce": - # these are negative axis indices useful for concatenating the intermediates - neg_axis = tuple(range(-len(axis), 0)) - combine: Callable[..., IntermediateDict] if do_simple_combine: combine = _simple_combine else: + # these are negative axis indices useful for concatenating the intermediates + neg_axis = tuple(range(-len(axis), 0)) combine = partial(_grouped_combine, engine=engine, neg_axis=neg_axis, sort=sort) # reduced is really a dict mapping reduction name to array @@ -1219,10 +1242,24 @@ def dask_groupby_agg( keepdims=True, concatenate=False, ) - output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks + + if 
is_duck_dask_array(by_input) and expected_groups is None: + groups = _extract_unknown_groups( + reduced, + groups_token=f"groups-{name}-{token}", + group_chunks=group_chunks, + dtype=by.dtype, + ) + else: + if expected_groups is None: + expected_groups_ = _get_expected_groups(by_input, sort=sort) + else: + expected_groups_ = expected_groups + groups = (expected_groups_.to_numpy(),) + elif method == "blockwise": reduced = intermediate - # Here one input chunk → one output chunka + # Here one input chunk → one output chunk # find number of groups in each chunk, this is needed for output chunks # along the reduced axis slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis)) @@ -1235,41 +1272,17 @@ def dask_groupby_agg( groups_in_block = tuple( np.intersect1d(by_input[slc], expected_groups) for slc in slices ) + groups = (np.concatenate(groups_in_block),) + ngroups_per_block = tuple(len(grp) for grp in groups_in_block) - output_chunks = reduced.chunks[: -(len(axis))] + (ngroups_per_block,) + group_chunks = (ngroups_per_block,) + else: raise ValueError(f"Unknown method={method}.") # extract results from the dict - layer: dict[tuple, tuple] = {} + output_chunks = reduced.chunks[: -(len(axis) + int(split_out > 1))] + group_chunks ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks) - if is_duck_dask_array(by_input) and expected_groups is None: - groups_name = f"groups-{name}-{token}" - # we've used keepdims=True, so _tree_reduce preserves some dummy dimensions - first_block = len(ochunks) * (0,) - layer[(groups_name, *first_block)] = ( - operator.getitem, - (reduced.name, *first_block), - "groups", - ) - groups: tuple[np.ndarray | DaskArray] = ( - dask.array.Array( - HighLevelGraph.from_collections(groups_name, layer, dependencies=[reduced]), - groups_name, - chunks=group_chunks, - dtype=by.dtype, - ), - ) - else: - if method == "map-reduce": - if expected_groups is None: - expected_groups_ = _get_expected_groups(by_input, sort=sort) - else: -
expected_groups_ = expected_groups - groups = (expected_groups_.to_numpy(),) - else: - groups = (np.concatenate(groups_in_block),) - layer2: dict[tuple, tuple] = {} agg_name = f"{name}-{token}" for ochunk in itertools.product(*ochunks): @@ -1624,6 +1637,7 @@ def groupby_reduce( f"\n\n Received: {func}" ) + # TODO: just do this in dask_groupby_agg # we always need some fill_value (see above) so choose the default if needed if kwargs["fill_value"] is None: kwargs["fill_value"] = agg.fill_value[agg.name] From d44acfe256ef7e9c95252e992e50509a47d14c99 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 7 Oct 2022 12:13:50 -0600 Subject: [PATCH 2/6] minor refactor --- flox/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flox/core.py b/flox/core.py index ab8311265..cda5d7c37 100644 --- a/flox/core.py +++ b/flox/core.py @@ -880,7 +880,6 @@ def _grouped_combine( agg: Aggregation, axis: T_Axes, keepdims: bool, - neg_axis: T_Axes, engine: T_Engine, is_aggregate: bool = False, sort: bool = True, @@ -906,6 +905,9 @@ def _grouped_combine( partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk ) + # these are negative axis indices useful for concatenating the intermediates + neg_axis = tuple(range(-len(axis), 0)) + groups = _conc2(x_chunk, "groups", axis=neg_axis) if agg.reduction_type == "argreduce": @@ -1217,9 +1219,7 @@ def dask_groupby_agg( if do_simple_combine: combine = _simple_combine else: - # these are negative axis indices useful for concatenating the intermediates - neg_axis = tuple(range(-len(axis), 0)) - combine = partial(_grouped_combine, engine=engine, neg_axis=neg_axis, sort=sort) + combine = partial(_grouped_combine, engine=engine, axis=axis, sort=sort) # reduced is really a dict mapping reduction name to array # and "groups" to an array of group labels From ed59091a96249d5f7ade38410ca8128f8f384af9 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 7 Oct 2022 12:22:43 -0600 Subject: [PATCH 3/6] Small
changes. --- flox/core.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/flox/core.py b/flox/core.py index cda5d7c37..6cf5fb77d 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1070,19 +1070,19 @@ def _reduce_blockwise( return result -def _extract_unknown_groups(reduced, groups_token, group_chunks, dtype): +def _extract_unknown_groups(reduced, groups_token, group_chunks, dtype) -> tuple[DaskArray]: import dask.array from dask.highlevelgraph import HighLevelGraph layer: dict[tuple, tuple] = {} - # we've used keepdims=True, so _tree_reduce preserves some dummy dimensions + groups_token = f"group-{reduced.name}" first_block = reduced.ndim * (0,) layer[(groups_token, *first_block)] = ( operator.getitem, (reduced.name, *first_block), "groups", ) - groups: tuple[np.ndarray | DaskArray] = ( + groups: tuple[DaskArray] = ( dask.array.Array( HighLevelGraph.from_collections(groups_token, layer, dependencies=[reduced]), groups_token, @@ -1244,12 +1244,7 @@ def dask_groupby_agg( ) if is_duck_dask_array(by_input) and expected_groups is None: - groups = _extract_unknown_groups( - reduced, - groups_token=f"groups-{name}-{token}", - group_chunks=group_chunks, - dtype=by.dtype, - ) + groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype) else: if expected_groups is None: expected_groups_ = _get_expected_groups(by_input, sort=sort) From 79b3f0ba2068755c4385ddc4bd9ec0ece9c3efa2 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 7 Oct 2022 12:49:59 -0600 Subject: [PATCH 4/6] fix --- flox/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flox/core.py b/flox/core.py index 6cf5fb77d..1f59e999b 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1070,7 +1070,7 @@ def _reduce_blockwise( return result -def _extract_unknown_groups(reduced, groups_token, group_chunks, dtype) -> tuple[DaskArray]: +def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]: import dask.array from 
dask.highlevelgraph import HighLevelGraph From 8601096ba75843c3e46becd7da139844be96ff50 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 7 Oct 2022 12:56:37 -0600 Subject: [PATCH 5/6] Fix benchmarks --- asv_bench/benchmarks/combine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py index 8d86f56db..2da0b1392 100644 --- a/asv_bench/benchmarks/combine.py +++ b/asv_bench/benchmarks/combine.py @@ -58,4 +58,4 @@ def construct_member(groups): ] self.x_chunk_cohorts = [construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4] - self.kwargs = {"agg": flox.aggregations.mean, "axis": (3,), "neg_axis": (-1,)} + self.kwargs = {"agg": flox.aggregations.mean, "axis": (3,)} From e647dbf569d2563c8fb6eee353b938babe9db439 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 7 Oct 2022 13:40:50 -0600 Subject: [PATCH 6/6] fix --- flox/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flox/core.py b/flox/core.py index 1f59e999b..10a9197a9 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1219,7 +1219,7 @@ def dask_groupby_agg( if do_simple_combine: combine = _simple_combine else: - combine = partial(_grouped_combine, engine=engine, axis=axis, sort=sort) + combine = partial(_grouped_combine, engine=engine, sort=sort) # reduced is really a dict mapping reduction name to array # and "groups" to an array of group labels