From 394925eed494718a23da639303fa1ca18fed2c80 Mon Sep 17 00:00:00 2001
From: dcherian
Date: Wed, 19 Oct 2022 12:06:23 -0600
Subject: [PATCH 1/2] Support reindexing in simple_combine
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

For 1D combine, this is a great improvement for cohorts-type reductions.
It uses more memory but takes similar time for map-reduce.

Note that the map-reduce intermediates are a worst case where there are
no shared groups between the chunks being combined. This case is
actually optimized in _grouped_combine, where reindexing is skipped for
reducing along a single axis.

[ 68.75%] ··· =========== ========= =========
              --                combine
              ----------- -------------------
                  kind     grouped   combine
              =========== ========= =========
                cohorts      760M      631M
               mapreduce     981M     1.81G
              =========== ========= =========

[ 75.00%] ··· =========== ========== ===========
              --                 combine
              ----------- ----------------------
                  kind     grouped     combine
              =========== ========== ===========
                cohorts    393±10ms    137±10ms
               mapreduce   652±10ms    611±400ms
              =========== ========== ===========

Fix bug in unique
---
 asv_bench/benchmarks/combine.py |  29 +++++----
 flox/core.py                    | 103 ++++++++++++++++++--------------
 tests/test_core.py              |  71 ++++++++++++----------
 3 files changed, 118 insertions(+), 85 deletions(-)

diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py
index 2da0b1392..085c27468 100644
--- a/asv_bench/benchmarks/combine.py
+++ b/asv_bench/benchmarks/combine.py
@@ -1,3 +1,5 @@
+from functools import partial
+
 import numpy as np
 
 import flox
@@ -7,26 +9,31 @@
 N = 1000
 
 
+def _get_combine(combine):
+    if combine == "grouped":
+        return partial(flox.core._grouped_combine, engine="numpy")
+    else:
+        return partial(flox.core._simple_combine, reindex=False)
+
+
 class Combine:
     def setup(self, *args, **kwargs):
         raise NotImplementedError
 
-    @parameterized("kind", ("cohorts", "mapreduce"))
-    def time_combine(self, kind):
-        flox.core._grouped_combine(
+    @parameterized(("kind", "combine"), (("reindexed", "not_reindexed"), ("grouped", "simple")))
+    def time_combine(self, kind, combine):
+        _get_combine(combine)(
             getattr(self, f"x_chunk_{kind}"),
             **self.kwargs,
             keepdims=True,
-            engine="numpy",
         )
 
-    @parameterized("kind", ("cohorts", "mapreduce"))
-    def peakmem_combine(self, kind):
-        flox.core._grouped_combine(
+    @parameterized(("kind", "combine"), (("reindexed", "not_reindexed"), ("grouped", "simple")))
+    def peakmem_combine(self, kind, combine):
+        _get_combine(combine)(
             getattr(self, f"x_chunk_{kind}"),
             **self.kwargs,
             keepdims=True,
-            engine="numpy",
         )
 
 
@@ -47,7 +54,7 @@ def construct_member(groups):
             }
 
         # motivated by
-        self.x_chunk_mapreduce = [
+        self.x_chunk_not_reindexed = [
             construct_member(groups)
             for groups in [
                 np.array((1, 2, 3, 4)),
@@ -57,5 +64,7 @@ def construct_member(groups):
             * 2
         ]
 
-        self.x_chunk_cohorts = [construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4]
+        self.x_chunk_reindexed = [
+            construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4
+        ]
         self.kwargs = {"agg": flox.aggregations.mean, "axis": (3,)}
diff --git a/flox/core.py b/flox/core.py
index a4efa6153..68afd5c6a 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -816,8 +816,25 @@ def _expand_dims(results: IntermediateDict) -> IntermediateDict:
     return results
 
 
+def _find_unique_groups(x_chunk):
+    from dask.base import flatten
+    from dask.utils import deepmap
+
+    unique_groups = _unique(np.asarray(tuple(flatten(deepmap(listify_groups, x_chunk)))))
+    unique_groups = unique_groups[~isnull(unique_groups)]
+
+    if len(unique_groups) == 0:
+        unique_groups = [np.nan]
+    return unique_groups
+
+
 def _simple_combine(
-    x_chunk, agg: Aggregation, axis: T_Axes, keepdims: bool, is_aggregate: bool = False
+    x_chunk,
+    agg: Aggregation,
+    axis: T_Axes,
+    keepdims: bool,
+    reindex: bool,
+    is_aggregate: bool = False,
 ) -> IntermediateDict:
     """
     'Simple' combination of blockwise results.
@@ -830,8 +847,19 @@ def _simple_combine(
     4. At the final agggregate step, we squeeze out DUMMY_AXIS
     """
     from dask.array.core import deepfirst
+    from dask.utils import deepmap
+
+    if not reindex:
+        # We didn't reindex at the blockwise step
+        # So now reindex before combining by reducing along DUMMY_AXIS
+        unique_groups = _find_unique_groups(x_chunk)
+        x_chunk = deepmap(
+            partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
+        )
+    else:
+        unique_groups = deepfirst(x_chunk)["groups"]
 
-    results: IntermediateDict = {"groups": deepfirst(x_chunk)["groups"]}
+    results: IntermediateDict = {"groups": unique_groups}
     results["intermediates"] = []
     axis_ = axis[:-1] + (DUMMY_AXIS,)
     for idx, combine in enumerate(agg.combine):
@@ -886,7 +914,6 @@ def _grouped_combine(
     sort: bool = True,
 ) -> IntermediateDict:
     """Combine intermediates step of tree reduction."""
-    from dask.base import flatten
     from dask.utils import deepmap
 
     if isinstance(x_chunk, dict):
@@ -897,11 +924,7 @@ def _grouped_combine(
         # when there's only a single axis of reduction, we can just concatenate later,
         # reindexing is unnecessary
         # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
-        unique_groups = _unique(np.array(tuple(flatten(deepmap(listify_groups, x_chunk)))))
-        unique_groups = unique_groups[~isnull(unique_groups)]
-        if len(unique_groups) == 0:
-            unique_groups = [np.nan]
-
+        unique_groups = _find_unique_groups(x_chunk)
         x_chunk = deepmap(
             partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
         )
@@ -1216,7 +1239,8 @@ def dask_groupby_agg(
 
     # This allows us to discover groups at compute time, support argreductions, lower intermediate
    # memory usage (but method="cohorts" would also work to reduce memory in some cases)
-    do_simple_combine = method != "blockwise" and reindex and not _is_arg_reduction(agg)
+    do_simple_combine = not _is_arg_reduction(agg)
+
     if method == "blockwise":
         # use the "non dask" code path, but applied blockwise
         blockwise_method = partial(
@@ -1268,31 +1292,32 @@ def dask_groupby_agg(
     if method in ["map-reduce", "cohorts"]:
         combine: Callable[..., IntermediateDict]
         if do_simple_combine:
-            combine = _simple_combine
+            combine = partial(_simple_combine, reindex=reindex)
+            combine_name = "simple-combine"
         else:
             combine = partial(_grouped_combine, engine=engine, sort=sort)
+            combine_name = "grouped-combine"
 
-        # Each chunk of `reduced`` is really a dict mapping
-        # 1. reduction name to array
-        # 2. "groups" to an array of group labels
-        # Note: it does not make sense to interpret axis relative to
-        # shape of intermediate results after the blockwise call
         tree_reduce = partial(
             dask.array.reductions._tree_reduce,
-            combine=partial(combine, agg=agg),
-            name=f"{name}-reduce-{method}",
+            name=f"{name}-reduce-{method}-{combine_name}",
             dtype=array.dtype,
             axis=axis,
             keepdims=True,
             concatenate=False,
         )
-        aggregate = partial(
-            _aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex
-        )
+        aggregate = partial(_aggregate, combine=combine, agg=agg, fill_value=fill_value)
+
+        # Each chunk of `reduced`` is really a dict mapping
+        # 1. reduction name to array
+        # 2. "groups" to an array of group labels
+        # Note: it does not make sense to interpret axis relative to
+        # shape of intermediate results after the blockwise call
         if method == "map-reduce":
             reduced = tree_reduce(
                 intermediate,
-                aggregate=partial(aggregate, expected_groups=expected_groups),
+                combine=partial(combine, agg=agg),
+                aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
             )
             if is_duck_dask_array(by_input) and expected_groups is None:
                 groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
@@ -1310,23 +1335,17 @@ def dask_groupby_agg(
             reduced_ = []
             groups_ = []
             for blks, cohort in chunks_cohorts.items():
+                index = pd.Index(cohort)
                 subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
-                if do_simple_combine:
-                    # reindex so that reindex can be set to True later
-                    reindexed = dask.array.map_blocks(
-                        reindex_intermediates,
-                        subset,
-                        agg=agg,
-                        unique_groups=cohort,
-                        meta=subset._meta,
-                    )
-                else:
-                    reindexed = subset
-
+                reindexed = dask.array.map_blocks(
+                    reindex_intermediates, subset, agg=agg, unique_groups=index, meta=subset._meta
+                )
+                # now that we have reindexed, we can set reindex=True explicitly
                 reduced_.append(
                     tree_reduce(
                         reindexed,
-                        aggregate=partial(aggregate, expected_groups=cohort, reindex=reindex),
+                        combine=partial(combine, agg=agg, reindex=True),
+                        aggregate=partial(aggregate, expected_groups=index, reindex=True),
                     )
                 )
                 groups_.append(cohort)
@@ -1382,18 +1401,17 @@ def _validate_reindex(
     if reindex is True:
         if _is_arg_reduction(func):
             raise NotImplementedError
-        if method == "blockwise":
-            raise NotImplementedError
+        if method in ["blockwise", "cohorts"]:
+            raise ValueError(
+                "reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
+            )
 
     if reindex is None:
         if method == "blockwise" or _is_arg_reduction(func):
             reindex = False
 
-        elif expected_groups is not None:
-            reindex = True
-
-        elif method in ["split-reduce", "cohorts"]:
-            reindex = True
+        elif method == "cohorts":
+            reindex = False
 
         elif method == "map-reduce":
             if expected_groups is None and by_is_dask:
@@ -1401,9 +1419,6 @@ def _validate_reindex(
             else:
                 reindex = True
 
-    if method in ["split-reduce", "cohorts"] and reindex is False:
-        raise NotImplementedError
-
     assert isinstance(reindex, bool)
     return reindex
 
diff --git a/tests/test_core.py b/tests/test_core.py
index ee929633a..4254b84a5 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import itertools
 from functools import partial, reduce
 from typing import TYPE_CHECKING
 
@@ -219,29 +220,31 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
             assert actual.dtype.kind == "i"
         assert_equal(actual, expected, tolerance)
 
-        if not has_dask:
+        if not has_dask or chunks is None:
             continue
-        for method in ["map-reduce", "cohorts", "split-reduce"]:
-            if method == "map-reduce":
-                reindexes = [True, False, None]
+
+        params = list(itertools.product(["map-reduce"], [True, False, None]))
+        params.extend(itertools.product(["cohorts"], [False, None]))
+        for method, reindex in params:
+            call = partial(
+                groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
+            )
+            if "arg" in func and reindex is True:
+                # simple_combine with argreductions not supported right now
+                with pytest.raises(NotImplementedError):
+                    call()
+                continue
+            actual, *groups = call()
+            if "arg" not in func:
+                # make sure we use simple combine
+                assert any("simple-combine" in key for key in actual.dask.layers.keys())
             else:
-                reindexes = [None]
-            for reindex in reindexes:
-                call = partial(
-                    groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
-                )
-                if "arg" in func:
-                    if method != "map-reduce" or reindex is True:
-                        with pytest.raises(NotImplementedError):
-                            call()
-                        continue
-
-                actual, *groups = call()
-                for actual_group, expect in zip(groups, expected_groups):
-                    assert_equal(actual_group, expect, tolerance)
-                if "arg" in func:
-                    assert actual.dtype.kind == "i"
-                assert_equal(actual, expected, tolerance)
+                assert any("grouped-combine" in key for key in actual.dask.layers.keys())
+            for actual_group, expect in zip(groups, expected_groups):
+                assert_equal(actual_group, expect, tolerance)
+            if "arg" in func:
+                assert actual.dtype.kind == "i"
+            assert_equal(actual, expected, tolerance)
 
 
 @requires_dask
@@ -1140,7 +1143,6 @@ def test_subset_block_2d(flatblocks, expectidx):
     assert_equal(subset, array.compute()[expectidx])
 
 
-@pytest.mark.parametrize("method", ["map-reduce", "cohorts"])
 @pytest.mark.parametrize(
     "expected, reindex, func, expected_groups, by_is_dask",
     [
@@ -1158,13 +1160,20 @@ def test_subset_block_2d(flatblocks, expectidx):
         [True, None, "sum", ([1], None), True],
     ],
 )
-def test_validate_reindex(expected, reindex, func, method, expected_groups, by_is_dask):
-    if by_is_dask and method == "cohorts":
-        # This should error elsewhere
-        pytest.skip()
-    call = partial(_validate_reindex, reindex, func, method, expected_groups, by_is_dask)
-    if "arg" in func and method == "cohorts":
+def test_validate_reindex_map_reduce(expected, reindex, func, expected_groups, by_is_dask):
+    actual = _validate_reindex(reindex, func, "map-reduce", expected_groups, by_is_dask)
+    assert actual == expected
+
+
+def test_validate_reindex():
+    for method in ["map-reduce", "cohorts"]:
         with pytest.raises(NotImplementedError):
-            call()
-    else:
-        assert call() == expected
+            _validate_reindex(True, "argmax", method, expected_groups=None, by_is_dask=False)
+
+    for method in ["blockwise", "cohorts"]:
+        with pytest.raises(ValueError):
+            _validate_reindex(True, "sum", method, expected_groups=None, by_is_dask=False)
+
+        for func in ["sum", "argmax"]:
+            actual = _validate_reindex(None, func, method, expected_groups=None, by_is_dask=False)
+            assert actual is False

From 9af06927db06af0b40dd0f3640488ba8a170e061 Mon Sep 17 00:00:00 2001
From: dcherian
Date: Tue, 25 Oct 2022 12:43:38 -0600
Subject: [PATCH 2/2] Fix bug with all NaN blocks

---
 flox/core.py       |  6 +++---
 tests/test_core.py | 14 ++++++++++----
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/flox/core.py b/flox/core.py
index 68afd5c6a..933ef8c22 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -136,7 +136,7 @@ def _get_optimal_chunks_for_groups(chunks, labels):
     return tuple(newchunks)
 
 
-def _unique(a: np.ndarray):
+def _unique(a: np.ndarray) -> np.ndarray:
     """Much faster to use pandas unique and sort the results.
     np.unique sorts before uniquifying and is slow."""
     return np.sort(pd.unique(a.reshape(-1)))
@@ -816,7 +816,7 @@ def _expand_dims(results: IntermediateDict) -> IntermediateDict:
     return results
 
 
-def _find_unique_groups(x_chunk):
+def _find_unique_groups(x_chunk) -> np.ndarray:
     from dask.base import flatten
     from dask.utils import deepmap
 
@@ -824,7 +824,7 @@ def _find_unique_groups(x_chunk) -> np.ndarray:
     unique_groups = unique_groups[~isnull(unique_groups)]
 
     if len(unique_groups) == 0:
-        unique_groups = [np.nan]
+        unique_groups = np.array([np.nan])
     return unique_groups
 
 
diff --git a/tests/test_core.py b/tests/test_core.py
index 4254b84a5..d1db6fab9 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -517,7 +517,7 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
     assert_equal(actual, expected, tolerance)
 
 
-@pytest.mark.parametrize("chunks", [None, (2, 2, 3)])
+@pytest.mark.parametrize("reindex,chunks", [(None, None), (False, (2, 2, 3)), (True, (2, 2, 3))])
 @pytest.mark.parametrize(
     "axis, groups, expected_shape",
     [
@@ -526,7 +526,7 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
         (None, [0], (1,)),  # global reduction; 0 shaped group axis; 1 group
     ],
 )
-def test_groupby_reduce_nans(chunks, axis, groups, expected_shape, engine):
+def test_groupby_reduce_nans(reindex, chunks, axis, groups, expected_shape, engine):
     def _maybe_chunk(arr):
         if chunks:
             if not has_dask:
@@ -549,6 +549,7 @@ def _maybe_chunk(arr):
         axis=axis,
         fill_value=0,
         engine=engine,
+        reindex=reindex,
     )
     assert_equal(result, np.zeros(expected_shape, dtype=np.intp))
 
@@ -561,7 +562,10 @@
 
 
 @requires_dask
-def test_groupby_all_nan_blocks(engine):
+@pytest.mark.parametrize(
+    "expected_groups, reindex", [(None, None), ([0, 1, 2], True), ([0, 1, 2], False)]
+)
+def test_groupby_all_nan_blocks(expected_groups, reindex, engine):
     labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
     nan_labels = labels.astype(float)  # copy
     nan_labels[:5] = np.nan
@@ -576,8 +580,10 @@ def test_groupby_all_nan_blocks(engine):
         da.from_array(array, chunks=(1, 3)),
         da.from_array(by, chunks=(1, 3)),
         func="sum",
-        expected_groups=None,
+        expected_groups=expected_groups,
         engine=engine,
+        reindex=reindex,
+        method="map-reduce",
     )
 
     assert_equal(actual, expected)