diff --git a/flox/core.py b/flox/core.py
index fe6bbe475..24480341d 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -12,6 +12,7 @@
     Callable,
     Dict,
     Iterable,
+    Literal,
     Mapping,
     Sequence,
     Union,
@@ -36,6 +37,17 @@
 if TYPE_CHECKING:
     import dask.array.Array as DaskArray
 
+    T_Func = Union[str, Callable]
+    T_Funcs = Union[T_Func, Sequence[T_Func]]
+    T_Axis = int
+    T_Axes = tuple[T_Axis, ...]
+    T_AxesOpt = Union[T_Axis, T_Axes, None]
+    T_Dtypes = Union[np.typing.DTypeLike, Sequence[np.typing.DTypeLike], None]
+    T_FillValues = Union[np.typing.ArrayLike, Sequence[np.typing.ArrayLike], None]
+    T_Engine = Literal["flox", "numpy", "numba"]
+    T_Method = Literal["map-reduce", "blockwise", "cohorts", "split-reduce"]
+    T_IsBins = Union[bool, Sequence[bool]]
+
 IntermediateDict = Dict[Union[str, Callable], Any]
 FinalResultsDict = Dict[str, Union["DaskArray", np.ndarray]]
 
@@ -1042,14 +1054,14 @@ def dask_groupby_agg(
     by: DaskArray | np.ndarray,
     agg: Aggregation,
     expected_groups: pd.Index | None,
-    axis: Sequence = None,
+    axis: T_Axes = (),
     split_out: int = 1,
     fill_value: Any = None,
-    method: str = "map-reduce",
+    method: T_Method = "map-reduce",
     reindex: bool = False,
-    engine: str = "numpy",
+    engine: T_Engine = "numpy",
     sort: bool = True,
-) -> tuple[DaskArray, np.ndarray | DaskArray]:
+) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]:
     import dask.array
     from dask.array.core import slices_from_chunks
 
@@ -1161,11 +1173,11 @@
     # these are negative axis indices useful for concatenating the intermediates
     neg_axis = tuple(range(-len(axis), 0))
 
-    combine = (
-        _simple_combine
-        if do_simple_combine
-        else partial(_grouped_combine, engine=engine, neg_axis=neg_axis, sort=sort)
-    )
+    combine: Callable[..., IntermediateDict]
+    if do_simple_combine:
+        combine = _simple_combine
+    else:
+        combine = partial(_grouped_combine, engine=engine, neg_axis=neg_axis, sort=sort)
 
     # reduced is really a dict mapping reduction name to array
     # and "groups" to an array of group labels
@@ -1204,13 +1216,12 @@
                 groups_in_block = tuple(
                     np.intersect1d(by_input[slc], expected_groups) for slc in slices
                 )
-            ngroups_per_block = tuple(len(groups) for groups in groups_in_block)
+            ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
             output_chunks = reduced.chunks[: -(len(axis))] + (ngroups_per_block,)
     else:
         raise ValueError(f"Unknown method={method}.")
 
     # extract results from the dict
-    result: dict = {}
     layer: dict[tuple, tuple] = {}
     ochunks = tuple(range(len(chunks_v)) for chunks_v in output_chunks)
     if is_duck_dask_array(by_input) and expected_groups is None:
@@ -1222,7 +1233,7 @@
             (reduced.name, *first_block),
             "groups",
         )
-        groups = (
+        groups: tuple[np.ndarray | DaskArray] = (
             dask.array.Array(
                 HighLevelGraph.from_collections(groups_name, layer, dependencies=[reduced]),
                 groups_name,
@@ -1233,12 +1244,14 @@
     else:
         if method == "map-reduce":
             if expected_groups is None:
-                expected_groups = _get_expected_groups(by_input, sort=sort)
-            groups = (expected_groups.to_numpy(),)
+                expected_groups_ = _get_expected_groups(by_input, sort=sort)
+            else:
+                expected_groups_ = expected_groups
+            groups = (expected_groups_.to_numpy(),)
         else:
             groups = (np.concatenate(groups_in_block),)
 
-    layer: dict[tuple, tuple] = {}  # type: ignore
+    layer2: dict[tuple, tuple] = {}
     agg_name = f"{name}-{token}"
     for ochunk in itertools.product(*ochunks):
         if method == "blockwise":
@@ -1249,16 +1262,16 @@
                 inchunk = ochunk[:-1] + np.unravel_index(ochunk[-1], nblocks)
         else:
             inchunk = ochunk[:-1] + (0,) * len(axis) + (ochunk[-1],) * int(split_out > 1)
-        layer[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)
+        layer2[(agg_name, *ochunk)] = (operator.getitem, (reduced.name, *inchunk), agg.name)
 
     result = dask.array.Array(
-        HighLevelGraph.from_collections(agg_name, layer, dependencies=[reduced]),
+        HighLevelGraph.from_collections(agg_name, layer2, dependencies=[reduced]),
         agg_name,
         chunks=output_chunks,
         dtype=agg.dtype[agg.name],
     )
 
-    return (result, *groups)
+    return (result, groups)
 
 
 def _validate_reindex(reindex: bool, func, method, expected_groups) -> bool:
@@ -1358,13 +1371,13 @@
     func: str | Aggregation,
     expected_groups: Sequence | np.ndarray | None = None,
     sort: bool = True,
-    isbin: bool = False,
-    axis=None,
+    isbin: T_IsBins = False,
+    axis: T_AxesOpt = None,
     fill_value=None,
     min_count: int | None = None,
     split_out: int = 1,
-    method: str = "map-reduce",
-    engine: str = "numpy",
+    method: T_Method = "map-reduce",
+    engine: T_Engine = "numpy",
     reindex: bool | None = None,
     finalize_kwargs: Mapping | None = None,
 ) -> tuple[DaskArray, np.ndarray | DaskArray]:
@@ -1469,9 +1482,9 @@
     )
     reindex = _validate_reindex(reindex, func, method, expected_groups)
 
-    by: tuple = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
-    nby = len(by)
-    by_is_dask = any(is_duck_dask_array(b) for b in by)
+    bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
+    nby = len(bys)
+    by_is_dask = any(is_duck_dask_array(b) for b in bys)
 
     if method in ["split-reduce", "cohorts"] and by_is_dask:
         raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
@@ -1480,54 +1493,58 @@
         array = np.asarray(array)
     array = array.astype(int) if np.issubdtype(array.dtype, bool) else array
 
-    if isinstance(isbin, bool):
-        isbin = (isbin,) * len(by)
+    if isinstance(isbin, Sequence):
+        isbins = isbin
+    else:
+        isbins = (isbin,) * nby
     if expected_groups is None:
-        expected_groups = (None,) * len(by)
+        expected_groups = (None,) * nby
 
-    _assert_by_is_aligned(array.shape, by)
+    _assert_by_is_aligned(array.shape, bys)
 
-    if len(by) == 1 and not isinstance(expected_groups, tuple):
+    if nby == 1 and not isinstance(expected_groups, tuple):
         expected_groups = (np.asarray(expected_groups),)
-    elif len(expected_groups) != len(by):
+    elif len(expected_groups) != nby:
         raise ValueError(
             f"Must have same number of `expected_groups` (received {len(expected_groups)}) "
-            f" and variables to group by (received {len(by)})."
+            f"and variables to group by (received {nby})."
         )
 
     # We convert to pd.Index since that lets us know if we are binning or not
     # (pd.IntervalIndex or not)
-    expected_groups = _convert_expected_groups_to_index(expected_groups, isbin, sort)
+    expected_groups = _convert_expected_groups_to_index(expected_groups, isbins, sort)
 
     # TODO: could restrict this to dask-only
     factorize_early = (nby > 1) or (
-        any(isbin) and method in ["split-reduce", "cohorts"] and is_duck_dask_array(array)
+        any(isbins) and method in ["split-reduce", "cohorts"] and is_duck_dask_array(array)
     )
     if factorize_early:
-        by, final_groups, grp_shape = _factorize_multiple(
-            by, expected_groups, by_is_dask=by_is_dask, reindex=reindex
+        bys, final_groups, grp_shape = _factorize_multiple(
+            bys, expected_groups, by_is_dask=by_is_dask, reindex=reindex
        )
         expected_groups = (pd.RangeIndex(math.prod(grp_shape)),)
 
-    assert len(by) == 1
-    by = by[0]
+    assert len(bys) == 1
+    by_ = bys[0]
     expected_groups = expected_groups[0]
 
     if axis is None:
-        axis = tuple(array.ndim + np.arange(-by.ndim, 0))
+        axis_ = tuple(array.ndim + np.arange(-by_.ndim, 0))
     else:
-        axis = np.core.numeric.normalize_axis_tuple(axis, array.ndim)  # type: ignore
+        # TODO: why doesn't mypy find this function? (missing from numpy's stubs?)
+        axis_ = np.core.numeric.normalize_axis_tuple(axis, array.ndim)  # type: ignore
+    nax = len(axis_)
 
-    if method in ["blockwise", "cohorts", "split-reduce"] and len(axis) != by.ndim:
+    if method in ["blockwise", "cohorts", "split-reduce"] and nax != by_.ndim:
         raise NotImplementedError(
             "Must reduce along all dimensions of `by` when method != 'map-reduce'. "
             f"Received method={method!r}"
         )
 
     # TODO: make sure expected_groups is unique
-    if len(axis) == 1 and by.ndim > 1 and expected_groups is None:
+    if nax == 1 and by_.ndim > 1 and expected_groups is None:
         if not by_is_dask:
-            expected_groups = _get_expected_groups(by, sort)
+            expected_groups = _get_expected_groups(by_, sort)
         else:
             # When we reduce along all axes, we are guaranteed to see all
             # groups in the final combine stage, so everything works.
@@ -1540,13 +1557,14 @@
                 "Please provide ``expected_groups`` when not reducing along all axes."
             )
 
-    assert len(axis) <= by.ndim
-    if len(axis) < by.ndim:
-        by = _move_reduce_dims_to_end(by, -array.ndim + np.array(axis) + by.ndim)
-        array = _move_reduce_dims_to_end(array, axis)
-        axis = tuple(array.ndim + np.arange(-len(axis), 0))
+    assert nax <= by_.ndim
+    if nax < by_.ndim:
+        by_ = _move_reduce_dims_to_end(by_, tuple(-array.ndim + ax + by_.ndim for ax in axis_))
+        array = _move_reduce_dims_to_end(array, axis_)
+        axis_ = tuple(array.ndim + np.arange(-nax, 0))
+        nax = len(axis_)
 
-    has_dask = is_duck_dask_array(array) or is_duck_dask_array(by)
+    has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
 
     # When axis is a subset of possible values; then npg will
     # apply it to groups that don't exist along a particular axis, for example.
@@ -1555,7 +1573,7 @@
     # The only way to do this consistently is mask out using min_count
     # Consider np.sum([np.nan]) = np.nan, np.nansum([np.nan]) = 0
     if min_count is None:
-        if len(axis) < by.ndim or fill_value is not None:
+        if nax < by_.ndim or fill_value is not None:
             min_count = 1
 
     # TODO: set in xarray?
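Note on `np.core.numeric.normalize_axis_tuple`, used above behind a `# type: ignore`: it is a real NumPy helper that canonicalizes the user-facing `axis` argument, and mypy most likely flags it only because it is absent from numpy's type stubs. A minimal standalone sketch of the behavior the patch relies on (illustrative only, not part of the patch):

    from numpy.core.numeric import normalize_axis_tuple

    # A scalar or negative axis becomes a tuple of non-negative ints, so the
    # code above can treat `axis_` as `tuple[int, ...]` from here on.
    assert normalize_axis_tuple(-1, 3) == (2,)
    assert normalize_axis_tuple((0, -1), 3) == (0, 2)
    # Out-of-range axes raise np.AxisError; duplicates raise ValueError.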
@@ -1564,20 +1582,23 @@
         # overwrite than when min_count is set
         fill_value = np.nan
 
-    kwargs = dict(axis=axis, fill_value=fill_value, engine=engine)
+    kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
     agg = _initialize_aggregation(func, array.dtype, fill_value, min_count, finalize_kwargs)
 
     if not has_dask:
         results = _reduce_blockwise(
-            array, by, agg, expected_groups=expected_groups, reindex=reindex, sort=sort, **kwargs
+            array, by_, agg, expected_groups=expected_groups, reindex=reindex, sort=sort, **kwargs
         )
         groups = (results["groups"],)
         result = results[agg.name]
 
     else:
+        if TYPE_CHECKING:
+            assert isinstance(array, DaskArray)  # TODO: How else to narrow that .chunks is there?
+
         if agg.chunk[0] is None and method != "blockwise":
             raise NotImplementedError(
-                f"Aggregation {func.name!r} is only implemented for dask arrays when method='blockwise'."
+                f"Aggregation {agg.name!r} is only implemented for dask arrays when method='blockwise'."
                 f"\n\n Received: {func}"
             )
@@ -1589,25 +1610,25 @@
 
         if method in ["split-reduce", "cohorts"]:
             cohorts = find_group_cohorts(
-                by, [array.chunks[ax] for ax in axis], merge=True, method=method
+                by_, [array.chunks[ax] for ax in axis_], merge=True, method=method
             )
 
-            results = []
+            results_ = []
             groups_ = []
             for cohort in cohorts:
                 cohort = sorted(cohort)
                 # equivalent of xarray.DataArray.where(mask, drop=True)
-                mask = np.isin(by, cohort)
+                mask = np.isin(by_, cohort)
                 indexer = [np.unique(v) for v in np.nonzero(mask)]
                 array_subset = array
-                for ax, idxr in zip(range(-by.ndim, 0), indexer):
+                for ax, idxr in zip(range(-by_.ndim, 0), indexer):
                     array_subset = np.take(array_subset, idxr, axis=ax)
-                numblocks = math.prod([len(array_subset.chunks[ax]) for ax in axis])
+                numblocks = math.prod([len(array_subset.chunks[ax]) for ax in axis_])
 
                 # get final result for these groups
                 r, *g = partial_agg(
                     array_subset,
-                    by[np.ix_(*indexer)],
+                    by_[np.ix_(*indexer)],
                     expected_groups=pd.Index(cohort),
                     # First deep copy because we might be doing blockwise,
                     # which sets agg.finalize=None, then map-reduce (GH102)
@@ -1620,22 +1641,22 @@
                     sort=False,
                     # if only a single block along axis, we can just work blockwise
                     # inspired by https://github.com/dask/dask/issues/8361
-                    method="blockwise" if numblocks == 1 and len(axis) == by.ndim else "map-reduce",
+                    method="blockwise" if numblocks == 1 and nax == by_.ndim else "map-reduce",
                 )
-                results.append(r)
+                results_.append(r)
                 groups_.append(cohort)
 
             # concatenate results together,
             # sort to make sure we match expected output
             groups = (np.hstack(groups_),)
-            result = np.concatenate(results, axis=-1)
+            result = np.concatenate(results_, axis=-1)
 
         else:
-            if method == "blockwise" and by.ndim == 1:
-                array = rechunk_for_blockwise(array, axis=-1, labels=by)
+            if method == "blockwise" and by_.ndim == 1:
+                array = rechunk_for_blockwise(array, axis=-1, labels=by_)
 
-            result, *groups = partial_agg(
+            result, groups = partial_agg(
                 array,
-                by,
+                by_,
                 expected_groups=None if method == "blockwise" else expected_groups,
                 agg=agg,
                 reindex=reindex,
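The switch from `return (result, *groups)` to `return (result, groups)` in `dask_groupby_agg` is what makes the new `tuple[DaskArray, tuple[np.ndarray | DaskArray]]` return annotation expressible: star-unpacking splices a runtime-dependent number of group arrays into the returned tuple, which a type checker cannot describe precisely. A self-contained sketch, with hypothetical `old_style`/`new_style` stand-ins:

    def old_style() -> tuple:  # (result, label_array_0, label_array_1, ...)
        groups = ("labels",)
        return ("result", *groups)  # flattened; length unknown statically

    def new_style() -> tuple[str, tuple[str]]:  # (result, (labels,))
        groups = ("labels",)
        return ("result", groups)  # nested; always exactly two elements

    res, *grps = old_style()   # grps is a list whose length mypy cannot know
    res2, grps2 = new_style()  # grps2 is tuple[str]; no star needed

This is also why the caller at the end of `groupby_reduce` changes from `result, *groups = partial_agg(...)` to `result, groups = partial_agg(...)`.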
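And with `engine` and `method` now typed via the `Literal[...]` aliases, a misspelled keyword becomes a type error instead of a runtime failure. A hedged usage sketch (assumes flox is installed; plain numpy inputs):

    import numpy as np
    from flox.core import groupby_reduce

    array = np.array([1.0, 2.0, 3.0, 4.0])
    by = np.array(["a", "b", "a", "b"])

    # engine/method must match the Literal values; e.g. engine="nunba" now fails mypy
    result, groups = groupby_reduce(array, by, func="sum", engine="numpy", method="map-reduce")
    print(result, groups)  # [4. 6.] ['a' 'b']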