diff --git a/flox/core.py b/flox/core.py index 58b89bf17..0ddf608ed 100644 --- a/flox/core.py +++ b/flox/core.py @@ -54,11 +54,9 @@ def _is_arg_reduction(func: str | Aggregation) -> bool: return False -def _get_expected_groups(by, sort, *, raise_if_dask=True) -> pd.Index | None: +def _get_expected_groups(by, sort: bool) -> pd.Index: if is_duck_dask_array(by): - if raise_if_dask: - raise ValueError("Please provide expected_groups if not grouping by a numpy array.") - return None + raise ValueError("Please provide expected_groups if not grouping by a numpy array.") flatby = by.reshape(-1) expected = pd.unique(flatby[~isnull(flatby)]) return _convert_expected_groups_to_index((expected,), isbin=(False,), sort=sort)[0] @@ -1152,7 +1150,10 @@ def dask_groupby_agg( else: intermediate = applied if expected_groups is None: - expected_groups = _get_expected_groups(by_input, sort=sort, raise_if_dask=False) + if is_duck_dask_array(by_input): + expected_groups = None + else: + expected_groups = _get_expected_groups(by_input, sort=sort) group_chunks = ((len(expected_groups),) if expected_groups is not None else (np.nan,),) if method == "map-reduce": diff --git a/flox/xarray.py b/flox/xarray.py index c02959485..5f87bafe6 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -313,7 +313,7 @@ def xarray_reduce( f"Please provided bin edges for group variable {idx} " f"named {group_name} in expected_groups." ) - expect_ = _get_expected_groups(b_.data, sort=sort, raise_if_dask=True) + expect_ = _get_expected_groups(b_.data, sort=sort) else: expect_ = expect expect_index = _convert_expected_groups_to_index((expect_,), (isbin_,), sort=sort)[0]