Support reindexing in simple_combine #177

Merged · 3 commits · Oct 25, 2022
Changes from all commits
29 changes: 19 additions & 10 deletions asv_bench/benchmarks/combine.py
@@ -1,3 +1,5 @@
+from functools import partial
+
 import numpy as np
 
 import flox
@@ -7,26 +9,31 @@
 N = 1000
 
 
+def _get_combine(combine):
+    if combine == "grouped":
+        return partial(flox.core._grouped_combine, engine="numpy")
+    else:
+        return partial(flox.core._simple_combine, reindex=False)
+
+
 class Combine:
     def setup(self, *args, **kwargs):
         raise NotImplementedError
 
-    @parameterized("kind", ("cohorts", "mapreduce"))
-    def time_combine(self, kind):
-        flox.core._grouped_combine(
+    @parameterized(("kind", "combine"), (("reindexed", "not_reindexed"), ("grouped", "simple")))
+    def time_combine(self, kind, combine):
+        _get_combine(combine)(
             getattr(self, f"x_chunk_{kind}"),
             **self.kwargs,
             keepdims=True,
-            engine="numpy",
         )
 
-    @parameterized("kind", ("cohorts", "mapreduce"))
-    def peakmem_combine(self, kind):
-        flox.core._grouped_combine(
+    @parameterized(("kind", "combine"), (("reindexed", "not_reindexed"), ("grouped", "simple")))
+    def peakmem_combine(self, kind, combine):
+        _get_combine(combine)(
             getattr(self, f"x_chunk_{kind}"),
             **self.kwargs,
             keepdims=True,
-            engine="numpy",
         )


@@ -47,7 +54,7 @@ def construct_member(groups):
         }
 
         # motivated by
-        self.x_chunk_mapreduce = [
+        self.x_chunk_not_reindexed = [
             construct_member(groups)
             for groups in [
                 np.array((1, 2, 3, 4)),
@@ -57,5 +64,7 @@ def construct_member(groups):
             * 2
         ]
 
-        self.x_chunk_cohorts = [construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4]
+        self.x_chunk_reindexed = [
+            construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4
+        ]
         self.kwargs = {"agg": flox.aggregations.mean, "axis": (3,)}
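Note: as a rough sketch of how the new parameterization expands (standalone, not part of the PR; it duplicates the `_get_combine` helper above), each `(kind, combine)` pair picks a chunk layout and a combine path:

```python
# Standalone sketch: which callable each benchmark parameter resolves to.
from functools import partial

import flox.core


def _get_combine(combine):
    if combine == "grouped":
        return partial(flox.core._grouped_combine, engine="numpy")
    else:
        return partial(flox.core._simple_combine, reindex=False)


for kind in ("reindexed", "not_reindexed"):
    for combine in ("grouped", "simple"):
        fn = _get_combine(combine)
        # asv would call fn(getattr(self, f"x_chunk_{kind}"), **self.kwargs, keepdims=True)
        print(kind, combine, fn.func.__name__, fn.keywords)
```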
105 changes: 60 additions & 45 deletions flox/core.py
@@ -136,7 +136,7 @@ def _get_optimal_chunks_for_groups(chunks, labels):
     return tuple(newchunks)
 
 
-def _unique(a: np.ndarray):
+def _unique(a: np.ndarray) -> np.ndarray:
     """Much faster to use pandas unique and sort the results.
     np.unique sorts before uniquifying and is slow."""
     return np.sort(pd.unique(a.reshape(-1)))
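A throwaway snippet (not part of the PR) to sanity-check the docstring's claim: `pd.unique` returns values in order of appearance, so sorting afterwards matches `np.unique`:

```python
# Check: sorted pd.unique matches np.unique on flattened input.
import numpy as np
import pandas as pd

a = np.random.randint(0, 10, size=(100, 4))
assert (np.sort(pd.unique(a.reshape(-1))) == np.unique(a)).all()
```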
@@ -816,8 +816,25 @@ def _expand_dims(results: IntermediateDict) -> IntermediateDict:
     return results
 
 
+def _find_unique_groups(x_chunk) -> np.ndarray:
+    from dask.base import flatten
+    from dask.utils import deepmap
+
+    unique_groups = _unique(np.asarray(tuple(flatten(deepmap(listify_groups, x_chunk)))))
+    unique_groups = unique_groups[~isnull(unique_groups)]
+
+    if len(unique_groups) == 0:
+        unique_groups = np.array([np.nan])
+    return unique_groups
+
+
 def _simple_combine(
-    x_chunk, agg: Aggregation, axis: T_Axes, keepdims: bool, is_aggregate: bool = False
+    x_chunk,
+    agg: Aggregation,
+    axis: T_Axes,
+    keepdims: bool,
+    reindex: bool,
+    is_aggregate: bool = False,
 ) -> IntermediateDict:
     """
     'Simple' combination of blockwise results.
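To see what the new `_find_unique_groups` computes, here is a toy equivalent on a flat list of blockwise results (assumed input shape; the real function walks arbitrarily nested lists via dask's `deepmap`/`flatten`):

```python
# Toy equivalent: gather every block's "groups", uniquify, drop NaN labels.
import numpy as np
import pandas as pd

x_chunk = [
    {"groups": np.array([1.0, 2.0, np.nan])},
    {"groups": np.array([2.0, 4.0])},
]
flat = np.concatenate([block["groups"] for block in x_chunk])
unique_groups = np.sort(pd.unique(flat))
unique_groups = unique_groups[~pd.isnull(unique_groups)]
print(unique_groups)  # [1. 2. 4.]
```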
@@ -830,8 +847,19 @@ def _simple_combine(
     4. At the final aggregate step, we squeeze out DUMMY_AXIS
     """
     from dask.array.core import deepfirst
+    from dask.utils import deepmap
 
+    if not reindex:
+        # We didn't reindex at the blockwise step
+        # So now reindex before combining by reducing along DUMMY_AXIS
+        unique_groups = _find_unique_groups(x_chunk)
+        x_chunk = deepmap(
+            partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
+        )
+    else:
+        unique_groups = deepfirst(x_chunk)["groups"]
+
-    results: IntermediateDict = {"groups": deepfirst(x_chunk)["groups"]}
+    results: IntermediateDict = {"groups": unique_groups}
     results["intermediates"] = []
     axis_ = axis[:-1] + (DUMMY_AXIS,)
     for idx, combine in enumerate(agg.combine):
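In the `not reindex` branch above, each block's intermediates are aligned onto the union of groups before reducing along `DUMMY_AXIS`. A toy version of that alignment with made-up values (`reindex_intermediates` handles this per-aggregation in flox):

```python
# Expand one block's per-group sums onto the union of groups, filling
# groups the block never saw with NaN.
import numpy as np
import pandas as pd

unique_groups = pd.Index([1, 2, 3, 4])   # union across all blocks
block_groups = np.array([2, 4])          # groups this block actually saw
block_sums = np.array([10.0, 7.0])

aligned = np.full(len(unique_groups), np.nan)
aligned[unique_groups.get_indexer(block_groups)] = block_sums
print(aligned)  # [nan 10. nan  7.]
```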
@@ -886,7 +914,6 @@ def _grouped_combine(
     sort: bool = True,
 ) -> IntermediateDict:
     """Combine intermediates step of tree reduction."""
-    from dask.base import flatten
     from dask.utils import deepmap
 
     if isinstance(x_chunk, dict):
@@ -897,11 +924,7 @@
         # when there's only a single axis of reduction, we can just concatenate later,
         # reindexing is unnecessary
         # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
-        unique_groups = _unique(np.array(tuple(flatten(deepmap(listify_groups, x_chunk)))))
-        unique_groups = unique_groups[~isnull(unique_groups)]
-        if len(unique_groups) == 0:
-            unique_groups = [np.nan]
-
+        unique_groups = _find_unique_groups(x_chunk)
         x_chunk = deepmap(
             partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
         )
@@ -1216,7 +1239,8 @@ def dask_groupby_agg(
     # This allows us to discover groups at compute time, support argreductions, lower intermediate
     # memory usage (but method="cohorts" would also work to reduce memory in some cases)
 
-    do_simple_combine = method != "blockwise" and reindex and not _is_arg_reduction(agg)
+    do_simple_combine = not _is_arg_reduction(agg)
+
     if method == "blockwise":
         # use the "non dask" code path, but applied blockwise
         blockwise_method = partial(
@@ -1268,31 +1292,32 @@ def dask_groupby_agg(
     if method in ["map-reduce", "cohorts"]:
         combine: Callable[..., IntermediateDict]
         if do_simple_combine:
-            combine = _simple_combine
+            combine = partial(_simple_combine, reindex=reindex)
+            combine_name = "simple-combine"
         else:
             combine = partial(_grouped_combine, engine=engine, sort=sort)
+            combine_name = "grouped-combine"
 
+        # Each chunk of `reduced` is really a dict mapping
+        # 1. reduction name to array
+        # 2. "groups" to an array of group labels
+        # Note: it does not make sense to interpret axis relative to
+        # shape of intermediate results after the blockwise call
         tree_reduce = partial(
             dask.array.reductions._tree_reduce,
+            combine=partial(combine, agg=agg),
-            name=f"{name}-reduce-{method}",
+            name=f"{name}-reduce-{method}-{combine_name}",
             dtype=array.dtype,
             axis=axis,
             keepdims=True,
             concatenate=False,
         )
-        aggregate = partial(
-            _aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex
-        )
+        aggregate = partial(_aggregate, combine=combine, agg=agg, fill_value=fill_value)
 
-        # Each chunk of `reduced` is really a dict mapping
-        # 1. reduction name to array
-        # 2. "groups" to an array of group labels
-        # Note: it does not make sense to interpret axis relative to
-        # shape of intermediate results after the blockwise call
         if method == "map-reduce":
             reduced = tree_reduce(
                 intermediate,
-                aggregate=partial(aggregate, expected_groups=expected_groups),
-                combine=partial(combine, agg=agg),
+                aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
             )
             if is_duck_dask_array(by_input) and expected_groups is None:
                 groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
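The comment in this hunk describes the per-chunk payload of `reduced`; spelled out, it looks roughly like the dict below (shapes and values are illustrative, for a mean aggregation that tracks sums and counts):

```python
# Illustrative payload for one chunk of `reduced` (mean = sum + count).
import numpy as np

chunk_payload = {
    "groups": np.array([1, 2, 3]),      # group labels seen by this chunk
    "intermediates": [
        np.array([[10.0, 2.0, 5.0]]),   # per-group running sums
        np.array([[4, 1, 2]]),          # per-group counts
    ],
}
```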
@@ -1310,23 +1335,17 @@
         reduced_ = []
         groups_ = []
         for blks, cohort in chunks_cohorts.items():
+            index = pd.Index(cohort)
             subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
-            if do_simple_combine:
-                # reindex so that reindex can be set to True later
-                reindexed = dask.array.map_blocks(
-                    reindex_intermediates,
-                    subset,
-                    agg=agg,
-                    unique_groups=cohort,
-                    meta=subset._meta,
-                )
-            else:
-                reindexed = subset
-
+            reindexed = dask.array.map_blocks(
+                reindex_intermediates, subset, agg=agg, unique_groups=index, meta=subset._meta
+            )
+            # now that we have reindexed, we can set reindex=True explicitly
             reduced_.append(
                 tree_reduce(
                     reindexed,
+                    combine=partial(combine, agg=agg, reindex=True),
-                    aggregate=partial(aggregate, expected_groups=cohort, reindex=reindex),
+                    aggregate=partial(aggregate, expected_groups=index, reindex=True),
                 )
             )
             groups_.append(cohort)
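For orientation, `chunks_cohorts` maps block indices to the group labels those blocks contain, so each cohort's subset is reindexed to its own `pd.Index` before the tree reduction. A toy shape (assumed; the real mapping comes from flox's cohort detection):

```python
# Assumed shape of chunks_cohorts: block indices -> group labels in those blocks.
chunks_cohorts = {
    (0, 1): [1, 2],  # blocks 0 and 1 only contain groups 1 and 2
    (2,): [3],       # block 2 only contains group 3
}
for blks, cohort in chunks_cohorts.items():
    print(blks, "->", cohort)
```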
@@ -1382,28 +1401,24 @@ def _validate_reindex(
     if reindex is True:
         if _is_arg_reduction(func):
             raise NotImplementedError
-        if method == "blockwise":
-            raise NotImplementedError
+        if method in ["blockwise", "cohorts"]:
+            raise ValueError(
+                "reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
+            )
 
     if reindex is None:
         if method == "blockwise" or _is_arg_reduction(func):
             reindex = False
 
         elif expected_groups is not None:
             reindex = True
 
-        elif method in ["split-reduce", "cohorts"]:
-            reindex = True
+        elif method == "cohorts":
+            reindex = False
 
         elif method == "map-reduce":
             if expected_groups is None and by_is_dask:
                 reindex = False
             else:
                 reindex = True
 
-    if method in ["split-reduce", "cohorts"] and reindex is False:
-        raise NotImplementedError
-
     assert isinstance(reindex, bool)
     return reindex
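The resolution order can be hard to scan in diff form; below is a condensed, standalone restatement of the new logic (simplified: a hypothetical helper with boolean flags, not the actual `_validate_reindex` signature):

```python
# Condensed restatement of the new default-resolution logic (simplified).
def resolve_reindex(reindex, method, is_arg_reduction, has_expected_groups, by_is_dask):
    if reindex is True:
        if is_arg_reduction:
            raise NotImplementedError
        if method in ("blockwise", "cohorts"):
            raise ValueError(f"reindex=True is not valid for method={method!r}")
    if reindex is None:
        if method == "blockwise" or is_arg_reduction:
            reindex = False
        elif has_expected_groups:
            reindex = True
        elif method == "cohorts":
            reindex = False
        elif method == "map-reduce":
            reindex = not by_is_dask  # expected_groups is None on this branch
    assert isinstance(reindex, bool)
    return reindex
```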
