Fix mypy errors in core.py - groupby_reduce #153


Closed
153 changes: 87 additions & 66 deletions flox/core.py
@@ -12,6 +12,7 @@
Callable,
Dict,
Iterable,
Literal,
Mapping,
Sequence,
Union,
@@ -36,6 +37,17 @@
if TYPE_CHECKING:
import dask.array.Array as DaskArray

T_Func = Union[str, Callable]
T_Funcs = Union[T_Func, Sequence[T_Func]]
T_Axis = int
T_Axiss = tuple[T_Axis, ...] # TODO: var name grammar?
T_AxissOpt = Union[T_Axis, T_Axiss, None]
T_Dtypes = Union[np.typing.DTypeLike, Sequence[np.typing.DTypeLike], None]
T_FillValues = Union[np.typing.ArrayLike, Sequence[np.typing.ArrayLike], None]
T_Engine = Literal["flox", "numpy", "numba"]
T_Method = Literal["map-reduce", "blockwise", "cohorts", "split-reduce"]
T_IsBins = Union[bool | Sequence[bool]]


IntermediateDict = Dict[Union[str, Callable], Any]
FinalResultsDict = Dict[str, Union["DaskArray", np.ndarray]]
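The aliases added above centralize the recurring argument types. As a standalone illustration (not flox code, with a made-up `reduce_stub` function), `Literal`-based aliases like `T_Engine` let mypy reject misspelled options at check time:

```python
# Standalone sketch mirroring the alias names above; not flox code.
from typing import Callable, Literal, Union

T_Func = Union[str, Callable]
T_Engine = Literal["flox", "numpy", "numba"]

def reduce_stub(func: T_Func, engine: T_Engine = "numpy") -> str:
    # Stand-in for a signature like groupby_reduce's.
    return f"{func}-{engine}"

reduce_stub("sum", engine="numba")   # accepted by mypy
reduce_stub("sum", engine="nunba")   # rejected by mypy: not a T_Engine literal
```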
@@ -1042,14 +1054,14 @@ def dask_groupby_agg(
by: DaskArray | np.ndarray,
agg: Aggregation,
expected_groups: pd.Index | None,
axis: Sequence = None,
axis: T_Axiss = (),
split_out: int = 1,
fill_value: Any = None,
method: str = "map-reduce",
method: T_Method = "map-reduce",
reindex: bool = False,
engine: str = "numpy",
engine: T_Engine = "numpy",
sort: bool = True,
) -> tuple[DaskArray, np.ndarray | DaskArray]:
) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]:

import dask.array
from dask.array.core import slices_from_chunks
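A minimal standalone sketch (made-up function names, not flox code) of why the `axis` default changes from `None` to `()` in the hunk above: when implicit Optional is disabled in mypy, a `None` default requires an explicitly Optional annotation, whereas an empty tuple already satisfies `tuple[int, ...]`.

```python
# Standalone sketch of the axis-default change; not flox code.
from typing import Sequence

T_Axiss = tuple[int, ...]

def before(axis: Sequence = None):  # flagged by mypy when implicit Optional is disabled
    ...

def after(axis: T_Axiss = ()) -> int:
    return len(axis)
```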
@@ -1161,11 +1173,11 @@ def dask_groupby_agg(
# these are negative axis indices useful for concatenating the intermediates
neg_axis = tuple(range(-len(axis), 0))

combine = (
_simple_combine
if do_simple_combine
else partial(_grouped_combine, engine=engine, neg_axis=neg_axis, sort=sort)
)
combine: Callable[..., IntermediateDict]
if do_simple_combine:
combine = _simple_combine
else:
combine = partial(_grouped_combine, engine=engine, neg_axis=neg_axis, sort=sort)

# reduced is really a dict mapping reduction name to array
# and "groups" to an array of group labels
@@ -1204,13 +1216,12 @@ def dask_groupby_agg(
groups_in_block = tuple(
np.intersect1d(by_input[slc], expected_groups) for slc in slices
)
ngroups_per_block = tuple(len(groups) for groups in groups_in_block)
ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
output_chunks = reduced.chunks[: -(len(axis))] + (ngroups_per_block,)
else:
raise ValueError(f"Unknown method={method}.")

# extract results from the dict
result: dict = {}
layer: dict[tuple, tuple] = {}
ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
if is_duck_dask_array(by_input) and expected_groups is None:
@@ -1222,7 +1233,7 @@
(reduced.name, *first_block),
"groups",
)
groups = (
groups: tuple[np.ndarray | DaskArray] = (
dask.array.Array(
HighLevelGraph.from_collections(groups_name, layer, dependencies=[reduced]),
groups_name,
@@ -1233,12 +1244,14 @@
else:
if method == "map-reduce":
if expected_groups is None:
expected_groups = _get_expected_groups(by_input, sort=sort)
groups = (expected_groups.to_numpy(),)
expected_groups_ = _get_expected_groups(by_input, sort=sort)
else:
expected_groups_ = expected_groups
groups = (expected_groups_.to_numpy(),)
else:
groups = (np.concatenate(groups_in_block),)
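Just above, the concrete index goes into the fresh name `expected_groups_` rather than back into the Optional parameter. A standalone sketch (made-up function) of that pattern, which this PR also applies to `by_`, `axis_`, and `isbins`: the parameter keeps its declared Optional type, and the new name carries the concrete type for the rest of the function.

```python
# Standalone sketch of the fresh-name pattern; not flox code.
from typing import Optional
import pandas as pd

def summarize(expected_groups: Optional[pd.Index]) -> int:
    if expected_groups is None:
        expected_groups_ = pd.Index([0, 1, 2])  # placeholder fallback
    else:
        expected_groups_ = expected_groups
    # mypy sees expected_groups_ as pd.Index, never Optional.
    return len(expected_groups_)
```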

layer: dict[tuple, tuple] = {} # type: ignore
layer2: dict[tuple, tuple] = {}
agg_name = f"{name}-{token}"
for ochunk in itertools.product(*ochunks):
if method == "blockwise":
@@ -1249,16 +1262,16 @@
inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
else:
inchunk = ochunk[:-1] + (0,) * len(axis) + (ochunk[-1],) * int(split_out > 1)
layer[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)
layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)

result = dask.array.Array(
HighLevelGraph.from_collections(agg_name, layer, dependencies=[reduced]),
HighLevelGraph.from_collections(agg_name, layer2, dependencies=[reduced]),
agg_name,
chunks=output_chunks,
dtype=agg.dtype[agg.name],
)

return (result, *groups)
return (result, groups)
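Above in this hunk, the second graph-layer dict is built under the new name `layer2`. A minimal standalone sketch of the mypy behaviour that appears to motivate it: re-annotating an existing name is reported as already defined, which the old code silenced with `# type: ignore`.

```python
# Standalone sketch of why a fresh name avoids the ignore comment; not flox code.
layer: dict[tuple, tuple] = {}
layer[("a", 0)] = ("op", ("x", 0))

# A second annotated assignment to `layer` would be a mypy error
# ('Name "layer" already defined'); a new name checks cleanly.
layer2: dict[tuple, tuple] = {}
layer2[("b", 0)] = ("op", ("y", 0))
```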


def _validate_reindex(reindex: bool, func, method, expected_groups) -> bool:
@@ -1358,13 +1371,13 @@ def groupby_reduce(
func: str | Aggregation,
expected_groups: Sequence | np.ndarray | None = None,
sort: bool = True,
isbin: bool = False,
axis=None,
isbin: T_IsBins = False,
axis: T_AxissOpt = None,
fill_value=None,
min_count: int | None = None,
split_out: int = 1,
method: str = "map-reduce",
engine: str = "numpy",
method: T_Method = "map-reduce",
engine: T_Engine = "numpy",
reindex: bool | None = None,
finalize_kwargs: Mapping | None = None,
) -> tuple[DaskArray, np.ndarray | DaskArray]:
@@ -1469,9 +1482,9 @@
)
reindex = _validate_reindex(reindex, func, method, expected_groups)

by: tuple = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
nby = len(by)
by_is_dask = any(is_duck_dask_array(b) for b in by)
bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
nby = len(bys)
Collaborator: It'd be nice to get these "renaming" changes in a separate PR to make it easier to review.

by_is_dask = any(is_duck_dask_array(b) for b in bys)

if method in ["split-reduce", "cohorts"] and by_is_dask:
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
@@ -1480,54 +1493,58 @@
array = np.asarray(array)
array = array.astype(int) if np.issubdtype(array.dtype, bool) else array

if isinstance(isbin, bool):
isbin = (isbin,) * len(by)
if isinstance(isbin, Sequence):
isbins = isbin
else:
isbins = (isbin,) * nby
if expected_groups is None:
expected_groups = (None,) * len(by)
expected_groups = (None,) * nby
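Just above, the scalar-or-sequence `isbin` argument is expanded into a per-`by` tuple under the new name `isbins`. A standalone sketch (made-up names) of that normalization, which leaves the parameter's union type intact while giving mypy a plain tuple to work with afterwards:

```python
# Standalone sketch of normalizing a bool-or-sequence argument; not flox code.
from typing import Sequence, Union

def normalize_isbin(isbin: Union[bool, Sequence[bool]], nby: int) -> tuple[bool, ...]:
    if isinstance(isbin, Sequence):
        isbins = tuple(isbin)       # one flag per grouping variable
    else:
        isbins = (isbin,) * nby     # broadcast the scalar flag
    return isbins

print(normalize_isbin(True, 2))           # (True, True)
print(normalize_isbin([True, False], 2))  # (True, False)
```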

_assert_by_is_aligned(array.shape, by)
_assert_by_is_aligned(array.shape, bys)

if len(by) == 1 and not isinstance(expected_groups, tuple):
if nby == 1 and not isinstance(expected_groups, tuple):
expected_groups = (np.asarray(expected_groups),)
elif len(expected_groups) != len(by):
elif len(expected_groups) != nby:
raise ValueError(
f"Must have same number of `expected_groups` (received {len(expected_groups)}) "
f" and variables to group by (received {len(by)})."
f" and variables to group by (received {nby})."
)

# We convert to pd.Index since that lets us know if we are binning or not
# (pd.IntervalIndex or not)
expected_groups = _convert_expected_groups_to_index(expected_groups, isbin, sort)
expected_groups = _convert_expected_groups_to_index(expected_groups, isbins, sort)

# TODO: could restrict this to dask-only
factorize_early = (nby > 1) or (
any(isbin) and method in ["split-reduce", "cohorts"] and is_duck_dask_array(array)
any(isbins) and method in ["split-reduce", "cohorts"] and is_duck_dask_array(array)
)
if factorize_early:
by, final_groups, grp_shape = _factorize_multiple(
by, expected_groups, by_is_dask=by_is_dask, reindex=reindex
bys, final_groups, grp_shape = _factorize_multiple(
bys, expected_groups, by_is_dask=by_is_dask, reindex=reindex
)
expected_groups = (pd.RangeIndex(math.prod(grp_shape)),)

assert len(by) == 1
by = by[0]
assert len(bys) == 1
by_ = bys[0]
expected_groups = expected_groups[0]

if axis is None:
axis = tuple(array.ndim + np.arange(-by.ndim, 0))
axis_ = tuple(array.ndim + np.arange(-by_.ndim, 0))
else:
axis = np.core.numeric.normalize_axis_tuple(axis, array.ndim) # type: ignore
# TODO: How come this function doesn't exist according to mypy?
axis_ = np.core.numeric.normalize_axis_tuple(axis, array.ndim) # type: ignore
nax = len(axis_)

if method in ["blockwise", "cohorts", "split-reduce"] and len(axis) != by.ndim:
if method in ["blockwise", "cohorts", "split-reduce"] and nax != by_.ndim:
raise NotImplementedError(
"Must reduce along all dimensions of `by` when method != 'map-reduce'."
f"Received method={method!r}"
)

# TODO: make sure expected_groups is unique
if len(axis) == 1 and by.ndim > 1 and expected_groups is None:
if nax == 1 and by_.ndim > 1 and expected_groups is None:
if not by_is_dask:
expected_groups = _get_expected_groups(by, sort)
expected_groups = _get_expected_groups(by_, sort)
else:
# When we reduce along all axes, we are guaranteed to see all
# groups in the final combine stage, so everything works.
@@ -1540,13 +1557,14 @@
"Please provide ``expected_groups`` when not reducing along all axes."
)

assert len(axis) <= by.ndim
if len(axis) < by.ndim:
by = _move_reduce_dims_to_end(by, -array.ndim + np.array(axis) + by.ndim)
array = _move_reduce_dims_to_end(array, axis)
axis = tuple(array.ndim + np.arange(-len(axis), 0))
assert nax <= by_.ndim
if nax < by_.ndim:
by_ = _move_reduce_dims_to_end(by_, tuple(-array.ndim + ax + by_.ndim for ax in axis_))
array = _move_reduce_dims_to_end(array, axis_)
axis_ = tuple(array.ndim + np.arange(-nax, 0))
nax = len(axis_)

has_dask = is_duck_dask_array(array) or is_duck_dask_array(by)
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)

# When axis is a subset of possible values; then npg will
# apply it to groups that don't exist along a particular axis (for e.g.)
Expand All @@ -1555,7 +1573,7 @@ def groupby_reduce(
# The only way to do this consistently is mask out using min_count
# Consider np.sum([np.nan]) = np.nan, np.nansum([np.nan]) = 0
if min_count is None:
if len(axis) < by.ndim or fill_value is not None:
if nax < by_.ndim or fill_value is not None:
min_count = 1

# TODO: set in xarray?
@@ -1564,20 +1582,23 @@
# overwrite than when min_count is set
fill_value = np.nan

kwargs = dict(axis=axis, fill_value=fill_value, engine=engine)
kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
agg = _initialize_aggregation(func, array.dtype, fill_value, min_count, finalize_kwargs)

if not has_dask:
results = _reduce_blockwise(
array, by, agg, expected_groups=expected_groups, reindex=reindex, sort=sort, **kwargs
array, by_, agg, expected_groups=expected_groups, reindex=reindex, sort=sort, **kwargs
)
groups = (results["groups"],)
result = results[agg.name]

else:
if TYPE_CHECKING:
assert isinstance(array, DaskArray) # TODO: How else to narrow that .chunk is there?

if agg.chunk[0] is None and method != "blockwise":
raise NotImplementedError(
f"Aggregation {func.name!r} is only implemented for dask arrays when method='blockwise'."
f"Aggregation {agg.name!r} is only implemented for dask arrays when method='blockwise'."
f"\n\n Received: {func}"
)
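The `if TYPE_CHECKING: assert isinstance(...)` lines added above narrow `array` for the type checker only. A standalone sketch of the idea (made-up function, and it assumes the caller really does pass a dask array on this path, matching the TODO in the diff):

```python
# Standalone sketch of TYPE_CHECKING-only narrowing; not flox code.
from __future__ import annotations

from typing import TYPE_CHECKING, Union
import numpy as np

if TYPE_CHECKING:
    from dask.array import Array as DaskArray

def nchunks(array: Union[np.ndarray, "DaskArray"]) -> int:
    if TYPE_CHECKING:
        # Never executed at runtime; tells mypy that .chunks exists below.
        assert isinstance(array, DaskArray)
    return len(array.chunks)
```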

@@ -1589,25 +1610,25 @@

if method in ["split-reduce", "cohorts"]:
cohorts = find_group_cohorts(
by, [array.chunks[ax] for ax in axis], merge=True, method=method
by_, [array.chunks[ax] for ax in axis_], merge=True, method=method
)

results = []
results_ = []
groups_ = []
for cohort in cohorts:
cohort = sorted(cohort)
# equivalent of xarray.DataArray.where(mask, drop=True)
mask = np.isin(by, cohort)
mask = np.isin(by_, cohort)
indexer = [np.unique(v) for v in np.nonzero(mask)]
array_subset = array
for ax, idxr in zip(range(-by.ndim, 0), indexer):
for ax, idxr in zip(range(-by_.ndim, 0), indexer):
array_subset = np.take(array_subset, idxr, axis=ax)
numblocks = math.prod([len(array_subset.chunks[ax]) for ax in axis])
numblocks = math.prod([len(array_subset.chunks[ax]) for ax in axis_])

# get final result for these groups
r, *g = partial_agg(
array_subset,
by[np.ix_(*indexer)],
by_[np.ix_(*indexer)],
expected_groups=pd.Index(cohort),
# First deep copy because we might be doing blockwise,
# which sets agg.finalize=None, then map-reduce (GH102)
@@ -1620,22 +1641,22 @@
sort=False,
# if only a single block along axis, we can just work blockwise
# inspired by https://github.com/dask/dask/issues/8361
method="blockwise" if numblocks == 1 and len(axis) == by.ndim else "map-reduce",
method="blockwise" if numblocks == 1 and nax == by_.ndim else "map-reduce",
)
results.append(r)
results_.append(r)
groups_.append(cohort)

# concatenate results together,
# sort to make sure we match expected output
groups = (np.hstack(groups_),)
result = np.concatenate(results, axis=-1)
result = np.concatenate(results_, axis=-1)
else:
if method == "blockwise" and by.ndim == 1:
array = rechunk_for_blockwise(array, axis=-1, labels=by)
if method == "blockwise" and by_.ndim == 1:
array = rechunk_for_blockwise(array, axis=-1, labels=by_)

result, *groups = partial_agg(
result, groups = partial_agg(
array,
by,
by_,
expected_groups=None if method == "blockwise" else expected_groups,
agg=agg,
reindex=reindex,