@@ -1164,7 +1164,7 @@ def subset_to_blocks(
1164
1164
return dask .array .Array (graph , name , chunks , meta = array )
1165
1165
1166
1166
1167
- def _extract_unknown_groups (reduced , group_chunks , dtype ) -> tuple [DaskArray ]:
1167
+ def _extract_unknown_groups (reduced , dtype ) -> tuple [DaskArray ]:
1168
1168
import dask .array
1169
1169
from dask .highlevelgraph import HighLevelGraph
1170
1170
@@ -1180,7 +1180,7 @@ def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
1180
1180
dask .array .Array (
1181
1181
HighLevelGraph .from_collections (groups_token , layer , dependencies = [reduced ]),
1182
1182
groups_token ,
1183
- chunks = group_chunks ,
1183
+ chunks = (( np . nan ,),) ,
1184
1184
meta = np .array ([], dtype = dtype ),
1185
1185
),
1186
1186
)
@@ -1293,14 +1293,7 @@ def dask_groupby_agg(
1293
1293
name = f"{ name } -chunk-{ token } " ,
1294
1294
)
1295
1295
1296
- if expected_groups is None :
1297
- if is_duck_dask_array (by_input ):
1298
- expected_groups = None
1299
- else :
1300
- expected_groups = _get_expected_groups (by_input , sort = sort )
1301
- group_chunks : tuple [tuple [Union [int , float ], ...]] = (
1302
- (len (expected_groups ),) if expected_groups is not None else (np .nan ,),
1303
- )
1296
+ group_chunks : tuple [tuple [Union [int , float ], ...]]
1304
1297
1305
1298
if method in ["map-reduce" , "cohorts" ]:
1306
1299
combine : Callable [..., IntermediateDict ]
@@ -1333,13 +1326,13 @@ def dask_groupby_agg(
1333
1326
aggregate = partial (aggregate , expected_groups = expected_groups , reindex = reindex ),
1334
1327
)
1335
1328
if is_duck_dask_array (by_input ) and expected_groups is None :
1336
- groups = _extract_unknown_groups (reduced , group_chunks = group_chunks , dtype = by .dtype )
1329
+ groups = _extract_unknown_groups (reduced , dtype = by .dtype )
1330
+ group_chunks = ((np .nan ,),)
1337
1331
else :
1338
1332
if expected_groups is None :
1339
- expected_groups_ = _get_expected_groups (by_input , sort = sort )
1340
- else :
1341
- expected_groups_ = expected_groups
1342
- groups = (expected_groups_ .to_numpy (),)
1333
+ expected_groups = _get_expected_groups (by_input , sort = sort )
1334
+ groups = (expected_groups .to_numpy (),)
1335
+ group_chunks = ((len (expected_groups ),),)
1343
1336
1344
1337
elif method == "cohorts" :
1345
1338
chunks_cohorts = find_group_cohorts (
0 commit comments