Performance improvements for cohorts detection #172

Merged
8 commits merged on Oct 14, 2022

Changes from all commits

60 changes: 60 additions & 0 deletions asv_bench/benchmarks/cohorts.py
@@ -0,0 +1,60 @@
import dask
import numpy as np
import pandas as pd

import flox


class Cohorts:
    """Time the core reduction function."""

    def setup(self, *args, **kwargs):
        raise NotImplementedError

    def time_find_group_cohorts(self):
        flox.core.find_group_cohorts(self.by, self.array.chunks)
        # The cache clear fails dependably in CI
        # Not sure why
        try:
            flox.cache.cache.clear()
        except AttributeError:
            pass

    def time_graph_construct(self):
        flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis, method="cohorts")

    def track_num_tasks(self):
        result = flox.groupby_reduce(
            self.array, self.by, func="sum", axis=self.axis, method="cohorts"
        )[0]
        return len(result.dask.to_dict())

    track_num_tasks.unit = "tasks"


class NWMMidwest(Cohorts):
    """2D labels, irregular w.r.t. chunk size.
    Mimics National Weather Model, Midwest county groupby."""

    def setup(self, *args, **kwargs):
        x = np.repeat(np.arange(30), 150)
        y = np.repeat(np.arange(30), 60)
        self.by = x[np.newaxis, :] * y[:, np.newaxis]

        self.array = dask.array.ones(self.by.shape, chunks=(350, 350))
        self.axis = (-2, -1)


class ERA5(Cohorts):
    """ERA5"""

    def setup(self, *args, **kwargs):
        time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))

        self.by = time.dt.dayofyear.values
        self.axis = (-1,)

        array = dask.array.random.random((721, 1440, len(time)), chunks=(-1, -1, 48))
        self.array = flox.core.rechunk_for_cohorts(
            array, -1, self.by, force_new_chunk_at=[1], chunksize=48, ignore_old_chunks=True
        )
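
These classes only run under asv, but the code paths they time can be exercised directly. A minimal sketch, not part of the PR, that rebuilds the NWMMidwest setup above and calls the two functions being benchmarked:

import dask.array as da
import numpy as np

import flox
import flox.core

# same labels/array as NWMMidwest.setup: 2D labels, chunks misaligned with groups
x = np.repeat(np.arange(30), 150)
y = np.repeat(np.arange(30), 60)
by = x[np.newaxis, :] * y[:, np.newaxis]  # shape (1800, 4500)
array = da.ones(by.shape, chunks=(350, 350))

# what time_find_group_cohorts measures
cohorts = flox.core.find_group_cohorts(by, array.chunks)

# what time_graph_construct / track_num_tasks measure
result = flox.groupby_reduce(array, by, func="sum", axis=(-2, -1), method="cohorts")[0]
print(len(result.dask.to_dict()))  # number of tasks in the graph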
1 change: 1 addition & 0 deletions flox/__init__.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python
# flake8: noqa
"""Top-level module for flox ."""
+from . import cache
from .aggregations import Aggregation # noqa
from .core import groupby_reduce, rechunk_for_blockwise, rechunk_for_cohorts # noqa

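The one-line addition makes `flox.cache` reachable as an attribute after a plain `import flox`, which is what the new benchmark relies on when resetting the memoization cache between runs. A small sketch of that access pattern (mirroring `time_find_group_cohorts` above):

import flox

# attribute access works because flox/__init__.py now does "from . import cache";
# the benchmark clears the cache defensively:
try:
    flox.cache.cache.clear()
except AttributeError:
    # mirrors the guard in time_find_group_cohorts; .clear() may not exist on
    # every cache implementation
    pass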
1 change: 1 addition & 0 deletions flox/cache.py
@@ -8,4 +8,5 @@
    cache = cachey.Cache(1e6)
    memoize = partial(cache.memoize, key=dask.base.tokenize)
except ImportError:
+    cache = {}
    memoize = lambda x: x  # type: ignore
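
For context, the added `cache = {}` keeps the module's public names consistent whether or not cachey is installed: `flox.cache.cache` always exists and always supports `.clear()`, while `memoize` degrades to a no-op. The surrounding try/except looks roughly like the sketch below (the file's earlier lines are not shown in this diff, so the imports here are an assumption):

from functools import partial

try:
    import cachey
    import dask

    # small in-memory cache, keyed by dask's deterministic tokens
    cache = cachey.Cache(1e6)
    memoize = partial(cache.memoize, key=dask.base.tokenize)
except ImportError:
    cache = {}  # a plain dict still provides .clear() for the benchmarks
    memoize = lambda x: x  # type: ignore  # decorating becomes a no-op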
21 changes: 12 additions & 9 deletions flox/core.py
Expand Up @@ -106,7 +106,7 @@ def _collapse_axis(arr: np.ndarray, naxis: int) -> np.ndarray:
def _get_optimal_chunks_for_groups(chunks, labels):
chunkidx = np.cumsum(chunks) - 1
# what are the groups at chunk boundaries
labels_at_chunk_bounds = np.unique(labels[chunkidx])
labels_at_chunk_bounds = _unique(labels[chunkidx])
# what's the last index of all groups
last_indexes = npg.aggregate_numpy.aggregate(labels, np.arange(len(labels)), func="last")
# what's the last index of groups at the chunk boundaries.
@@ -136,6 +136,12 @@ def _get_optimal_chunks_for_groups(chunks, labels):
    return tuple(newchunks)


+def _unique(a):
+    """Much faster to use pandas unique and sort the results.
+    np.unique sorts before uniquifying and is slow."""
+    return np.sort(pd.unique(a))


@memoize
def find_group_cohorts(labels, chunks, merge: bool = True):
    """
@@ -180,14 +186,11 @@ def find_group_cohorts(labels, chunks, merge: bool = True):
        blocks[idx] = np.full(tuple(block.shape[ax] for ax in axis), idx)
    which_chunk = np.block(blocks.reshape(shape).tolist()).reshape(-1)

-    # We always drop NaN; np.unique also considers every NaN to be different so
-    # it's really important we get rid of them.
    raveled = labels.reshape(-1)
-    unique_labels = np.unique(raveled[~isnull(raveled)])
    # these are chunks where a label is present
-    label_chunks = {lab: tuple(np.unique(which_chunk[raveled == lab])) for lab in unique_labels}
+    label_chunks = pd.Series(which_chunk).groupby(raveled).unique()
    # These invert the label_chunks mapping so we know which labels occur together.
-    chunks_cohorts = tlz.groupby(label_chunks.get, label_chunks.keys())
+    chunks_cohorts = tlz.groupby(lambda x: tuple(label_chunks.get(x)), label_chunks.keys())

    if merge:
        # First sort by number of chunks occupied by cohort
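
To see what the rewritten `label_chunks`/`chunks_cohorts` construction produces, here is a toy, self-contained illustration (all values invented). The pandas groupby builds essentially the same label-to-chunks mapping as the old dict comprehension in one vectorized pass, and it drops NaN labels by default, which is why the explicit NaN filtering above could go; the `tuple(...)` in the `tlz.groupby` key is needed because the values are now NumPy arrays, which are not hashable.

import numpy as np
import pandas as pd
import toolz as tlz

raveled = np.array([0, 0, 1, 1, 2, 2, 0, 1])      # flattened group labels
which_chunk = np.array([0, 0, 0, 0, 1, 1, 1, 1])  # chunk index of each element

# label -> unique chunks containing that label, without a Python loop over labels
label_chunks = pd.Series(which_chunk).groupby(raveled).unique()

# invert the mapping: labels living in exactly the same chunks form one cohort
chunks_cohorts = tlz.groupby(lambda x: tuple(label_chunks.get(x)), label_chunks.keys())
# -> {(0, 1): [0, 1], (1,): [2]} (up to integer types)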
@@ -892,7 +895,7 @@ def _grouped_combine(
        # when there's only a single axis of reduction, we can just concatenate later,
        # reindexing is unnecessary
        # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
-        unique_groups = np.unique(tuple(flatten(deepmap(listify_groups, x_chunk))))
+        unique_groups = _unique(tuple(flatten(deepmap(listify_groups, x_chunk))))
        unique_groups = unique_groups[~isnull(unique_groups)]
        if len(unique_groups) == 0:
            unique_groups = [np.nan]
@@ -1065,7 +1068,7 @@ def subset_to_blocks(
    unraveled = np.unravel_index(flatblocks, blkshape)
    normalized: list[Union[int, np.ndarray, slice]] = []
    for ax, idx in enumerate(unraveled):
-        i = np.unique(idx).squeeze()
+        i = _unique(idx).squeeze()
        if i.ndim == 0:
            normalized.append(i.item())
        else:
@@ -1310,7 +1313,7 @@ def dask_groupby_agg(
        # along the reduced axis
        slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
        if expected_groups is None:
-            groups_in_block = tuple(np.unique(by_input[slc]) for slc in slices)
+            groups_in_block = tuple(_unique(by_input[slc]) for slc in slices)
        else:
            # For cohorts, we could be indexing a block with groups that
            # are not in the cohort (usually for nD `by`)
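
The remaining hunks are all the same mechanical substitution of `np.unique` with the new `_unique` helper. The reasoning in its docstring is easy to sanity-check with a small sketch (array size and label range are arbitrary):

import numpy as np
import pandas as pd

def _unique(a):
    # pd.unique deduplicates by hashing (no full sort of the input);
    # only the much smaller set of unique values is sorted afterwards
    return np.sort(pd.unique(a))

labels = np.random.randint(0, 366, size=1_000_000)  # dayofyear-style labels, many repeats
assert np.array_equal(_unique(labels), np.unique(labels))
# np.unique sorts all 1_000_000 elements before deduplicating,
# so _unique wins whenever the number of distinct labels is small.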