Skip to content

Commit 27a4e9a

Browse files
authored
Try cleaning up some expected_groups logic (#175)
* Try cleaning up some expected_groups logic
* Fix _extract_unknown_groups
* Fixes
1 parent a70c5dd commit 27a4e9a

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

flox/core.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,7 +1164,7 @@ def subset_to_blocks(
11641164
return dask.array.Array(graph, name, chunks, meta=array)
11651165

11661166

1167-
def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
1167+
def _extract_unknown_groups(reduced, dtype) -> tuple[DaskArray]:
11681168
import dask.array
11691169
from dask.highlevelgraph import HighLevelGraph
11701170

@@ -1180,7 +1180,7 @@ def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
11801180
dask.array.Array(
11811181
HighLevelGraph.from_collections(groups_token, layer, dependencies=[reduced]),
11821182
groups_token,
1183-
chunks=group_chunks,
1183+
chunks=((np.nan,),),
11841184
meta=np.array([], dtype=dtype),
11851185
),
11861186
)
@@ -1293,14 +1293,7 @@ def dask_groupby_agg(
12931293
name=f"{name}-chunk-{token}",
12941294
)
12951295

1296-
if expected_groups is None:
1297-
if is_duck_dask_array(by_input):
1298-
expected_groups = None
1299-
else:
1300-
expected_groups = _get_expected_groups(by_input, sort=sort)
1301-
group_chunks: tuple[tuple[Union[int, float], ...]] = (
1302-
(len(expected_groups),) if expected_groups is not None else (np.nan,),
1303-
)
1296+
group_chunks: tuple[tuple[Union[int, float], ...]]
13041297

13051298
if method in ["map-reduce", "cohorts"]:
13061299
combine: Callable[..., IntermediateDict]
@@ -1333,13 +1326,13 @@ def dask_groupby_agg(
13331326
aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
13341327
)
13351328
if is_duck_dask_array(by_input) and expected_groups is None:
1336-
groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
1329+
groups = _extract_unknown_groups(reduced, dtype=by.dtype)
1330+
group_chunks = ((np.nan,),)
13371331
else:
13381332
if expected_groups is None:
1339-
expected_groups_ = _get_expected_groups(by_input, sort=sort)
1340-
else:
1341-
expected_groups_ = expected_groups
1342-
groups = (expected_groups_.to_numpy(),)
1333+
expected_groups = _get_expected_groups(by_input, sort=sort)
1334+
groups = (expected_groups.to_numpy(),)
1335+
group_chunks = ((len(expected_groups),),)
13431336

13441337
elif method == "cohorts":
13451338
chunks_cohorts = find_group_cohorts(

0 commit comments

Comments
 (0)