Skip to content

Use blockwise to extract final result #182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 27, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 50 additions & 20 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,6 @@ def dask_groupby_agg(

import dask.array
from dask.array.core import slices_from_chunks
from dask.highlevelgraph import HighLevelGraph

# I think _tree_reduce expects this
assert isinstance(axis, Sequence)
Expand Down Expand Up @@ -1268,6 +1267,9 @@ def dask_groupby_agg(
engine=engine,
sort=sort,
),
# output indices are the same as input indices
# Unlike xhistogram, we don't always know what the size of the group
# dimension will be unless reindex=True
inds,
array,
inds,
Expand All @@ -1277,7 +1279,7 @@ def dask_groupby_agg(
dtype=array.dtype, # this is purely for show
meta=array._meta,
align_arrays=False,
token=f"{name}-chunk-{token}",
name=f"{name}-chunk-{token}",
)

if expected_groups is None:
Expand Down Expand Up @@ -1364,35 +1366,63 @@ def dask_groupby_agg(
groups = (np.concatenate(groups_in_block),)
ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
group_chunks = (ngroups_per_block,)

else:
raise ValueError(f"Unknown method={method}.")

# extract results from the dict
out_inds = inds[: -len(axis)] + (inds[-1],)
output_chunks = reduced.chunks[: -len(axis)] + group_chunks
if method == "blockwise" and len(axis) > 1:
# The final results are available but the blocks along axes
# need to be reshaped to axis=-1
# I don't know that this is possible with blockwise
# All other code paths benefit from an unmaterialized Blockwise layer
reduced = _collapse_blocks_along_axes(reduced, axis, group_chunks)

# Can't use map_blocks because it forces concatenate=True along drop_axes,
result = dask.array.blockwise(
_extract_result,
out_inds,
reduced,
inds,
adjust_chunks=dict(zip(out_inds, output_chunks)),
dtype=agg.dtype[agg.name],
key=agg.name,
name=f"{name}-{token}",
concatenate=False,
)

return (result, groups)


def _collapse_blocks_along_axes(reduced, axis, group_chunks):
import dask.array
from dask.highlevelgraph import HighLevelGraph

nblocks = tuple(reduced.numblocks[ax] for ax in axis)
output_chunks = reduced.chunks[: -len(axis)] + ((1,) * (len(axis) - 1),) + group_chunks

# extract results from the dict
ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
layer2: dict[tuple, tuple] = {}
agg_name = f"{name}-{token}"
for ochunk in itertools.product(*ochunks):
if method == "blockwise":
if len(axis) == 1:
inchunk = ochunk
else:
nblocks = tuple(len(array.chunks[ax]) for ax in axis)
inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
else:
inchunk = ochunk[:-1] + (0,) * (len(axis) - 1) + (ochunk[-1],)
name = f"reshape-{reduced.name}"

layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)
for ochunk in itertools.product(*ochunks):
inchunk = ochunk[: -len(axis)] + np.unravel_index(ochunk[-1], nblocks)
layer2[(name, *ochunk)] = (reduced.name, *inchunk)

result = dask.array.Array(
HighLevelGraph.from_collections(agg_name, layer2, dependencies=[reduced]),
agg_name,
return dask.array.Array(
HighLevelGraph.from_collections(name, layer2, dependencies=[reduced]),
name,
chunks=output_chunks,
dtype=agg.dtype[agg.name],
dtype=reduced.dtype,
)

return (result, groups)

def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray:
from dask.array.core import deepfirst

# deepfirst should be not be needed here but sometimes we receive a list of dict?
return deepfirst(result_dict)[key]


def _validate_reindex(
Expand Down