@@ -137,10 +137,10 @@ def _get_optimal_chunks_for_groups(chunks, labels):
    return tuple(newchunks)


-def _unique(a):
+def _unique(a: np.ndarray):
    """Much faster to use pandas unique and sort the results.
    np.unique sorts before uniquifying and is slow."""
-    return np.sort(pd.unique(a))
+    return np.sort(pd.unique(a.reshape(-1)))


@memoize
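A quick stand-alone sketch of the speedup `_unique` relies on: `pd.unique` hashes in a single pass and returns labels in first-seen order, so sorting its (usually small) result afterwards is cheaper than `np.unique`, which sorts the full array before deduplicating. The `reshape(-1)` is needed because `pd.unique` only accepts 1-D input.

import numpy as np
import pandas as pd

a = np.array([[3, 1, 3], [2, 1, 2]])

flat = a.reshape(-1)             # pd.unique requires 1-D input
print(pd.unique(flat))           # [3 1 2] -- first-seen order, unsorted
print(np.sort(pd.unique(flat)))  # [1 2 3] -- what _unique returns
print(np.unique(a))              # [1 2 3] -- same answer, but sorts the full array first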
@@ -819,8 +819,25 @@ def _expand_dims(results: IntermediateDict) -> IntermediateDict:
    return results


+def _find_unique_groups(x_chunk):
+    from dask.base import flatten
+    from dask.utils import deepmap
+
+    unique_groups = _unique(np.asarray(tuple(flatten(deepmap(listify_groups, x_chunk)))))
+    unique_groups = unique_groups[~isnull(unique_groups)]
+
+    if len(unique_groups) == 0:
+        unique_groups = [np.nan]
+    return unique_groups
+
+
def _simple_combine(
-    x_chunk, agg: Aggregation, axis: T_Axes, keepdims: bool, is_aggregate: bool = False
+    x_chunk,
+    agg: Aggregation,
+    axis: T_Axes,
+    keepdims: bool,
+    reindex: bool,
+    is_aggregate: bool = False,
) -> IntermediateDict:
    """
    'Simple' combination of blockwise results.
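To see what `_find_unique_groups` does, here is a rough stand-alone sketch: `dask.utils.deepmap` applies a function at the leaves of the nested lists dask hands to a combine function, and `dask.base.flatten` unrolls the result. The `listify_groups` stand-in below is an assumption about the flox helper (it pulls the "groups" labels out of one intermediate dict); the real one lives elsewhere in this module.

import numpy as np
import pandas as pd
from dask.base import flatten
from dask.utils import deepmap

def listify_groups(x):
    # hypothetical stand-in: extract the group labels from one intermediate dict
    return list(np.atleast_1d(x["groups"].squeeze()))

# two blockwise intermediates that saw overlapping label subsets
x_chunk = [[{"groups": np.array([[1, 2]])}, {"groups": np.array([[2, 3]])}]]

labels = np.asarray(tuple(flatten(deepmap(listify_groups, x_chunk))))
print(np.sort(pd.unique(labels)))  # [1 2 3]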
@@ -833,8 +850,19 @@ def _simple_combine(
    4. At the final aggregate step, we squeeze out DUMMY_AXIS
    """
    from dask.array.core import deepfirst
+    from dask.utils import deepmap
+
+    if not reindex:
+        # We didn't reindex at the blockwise step,
+        # so reindex now, before combining by reducing along DUMMY_AXIS
+        unique_groups = _find_unique_groups(x_chunk)
+        x_chunk = deepmap(
+            partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
+        )
+    else:
+        unique_groups = deepfirst(x_chunk)["groups"]

-    results: IntermediateDict = {"groups": deepfirst(x_chunk)["groups"]}
+    results: IntermediateDict = {"groups": unique_groups}
    results["intermediates"] = []
    axis_ = axis[:-1] + (DUMMY_AXIS,)
    for idx, combine in enumerate(agg.combine):
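The point of reindexing before `_simple_combine`'s reduction loop above: once every block's intermediates share one group order, combining reduces to a plain reduction over the stacked DUMMY_AXIS. A toy illustration with made-up per-block sums over groups [1, 2, 3]:

import numpy as np

# hypothetical per-block partial sums, already reindexed to groups [1, 2, 3]
# (groups a block never saw hold the reduction's identity, 0 for sum)
block_a = np.array([10.0, 0.0, 5.0])
block_b = np.array([0.0, 7.0, 2.0])

stacked = np.stack([block_a, block_b], axis=-1)  # the dummy axis
print(stacked.sum(axis=-1))                      # [10. 7. 7.]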
@@ -889,7 +917,6 @@ def _grouped_combine(
    sort: bool = True,
) -> IntermediateDict:
    """Combine intermediates step of tree reduction."""
-    from dask.base import flatten
    from dask.utils import deepmap

    if isinstance(x_chunk, dict):
@@ -900,11 +927,7 @@ def _grouped_combine(
        # when there's only a single axis of reduction, we can just concatenate later,
        # reindexing is unnecessary
        # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
-        unique_groups = _unique(tuple(flatten(deepmap(listify_groups, x_chunk))))
-        unique_groups = unique_groups[~isnull(unique_groups)]
-        if len(unique_groups) == 0:
-            unique_groups = [np.nan]
-
+        unique_groups = _find_unique_groups(x_chunk)
        x_chunk = deepmap(
            partial(reindex_intermediates, agg=agg, unique_groups=unique_groups), x_chunk
        )
@@ -1219,7 +1242,8 @@ def dask_groupby_agg(
    # This allows us to discover groups at compute time, support argreductions, lower intermediate
    # memory usage (but method="cohorts" would also work to reduce memory in some cases)

-    do_simple_combine = method != "blockwise" and reindex and not _is_arg_reduction(agg)
+    do_simple_combine = not _is_arg_reduction(agg)
+
    if method == "blockwise":
        # use the "non dask" code path, but applied blockwise
        blockwise_method = partial(
@@ -1271,31 +1295,32 @@ def dask_groupby_agg(
    if method in ["map-reduce", "cohorts", "split-reduce"]:
        combine: Callable[..., IntermediateDict]
        if do_simple_combine:
-            combine = _simple_combine
+            combine = partial(_simple_combine, reindex=reindex)
+            combine_name = "simple-combine"
        else:
            combine = partial(_grouped_combine, engine=engine, sort=sort)
+            combine_name = "grouped-combine"

-        # Each chunk of `reduced` is really a dict mapping
-        # 1. reduction name to array
-        # 2. "groups" to an array of group labels
-        # Note: it does not make sense to interpret axis relative to
-        # shape of intermediate results after the blockwise call
        tree_reduce = partial(
            dask.array.reductions._tree_reduce,
-            combine=partial(combine, agg=agg),
-            name=f"{name}-reduce-{method}",
+            name=f"{name}-reduce-{method}-{combine_name}",
            dtype=array.dtype,
            axis=axis,
            keepdims=True,
            concatenate=False,
        )
-        aggregate = partial(
-            _aggregate, combine=combine, agg=agg, fill_value=fill_value, reindex=reindex
-        )
+        aggregate = partial(_aggregate, combine=combine, agg=agg, fill_value=fill_value)
+
+        # Each chunk of `reduced` is really a dict mapping
+        # 1. reduction name to array
+        # 2. "groups" to an array of group labels
+        # Note: it does not make sense to interpret axis relative to
+        # shape of intermediate results after the blockwise call
        if method == "map-reduce":
            reduced = tree_reduce(
                intermediate,
+                combine=partial(combine, agg=agg),
-                aggregate=partial(aggregate, expected_groups=expected_groups),
+                aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
            )
            if is_duck_dask_array(by_input) and expected_groups is None:
                groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
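For concreteness, here is a hypothetical single chunk of `reduced` for a mean, matching the comment above: one array per intermediate reduction plus the group labels (the exact shapes are illustrative).

import numpy as np

chunk = {
    "intermediates": [
        np.array([[10.0, 7.0, 7.0]]),  # per-group sums
        np.array([[2, 1, 3]]),         # per-group counts
    ],
    "groups": np.array([[1, 2, 3]]),
}
# the final aggregate step turns these into the result, e.g. mean = sum / count
print(chunk["intermediates"][0] / chunk["intermediates"][1])  # [[5. 7. 2.33...]]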
@@ -1313,23 +1338,17 @@ def dask_groupby_agg(
            reduced_ = []
            groups_ = []
            for blks, cohort in chunks_cohorts.items():
+                index = pd.Index(cohort)
                subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
-                if do_simple_combine:
-                    # reindex so that reindex can be set to True later
-                    reindexed = dask.array.map_blocks(
-                        reindex_intermediates,
-                        subset,
-                        agg=agg,
-                        unique_groups=cohort,
-                        meta=subset._meta,
-                    )
-                else:
-                    reindexed = subset
-
+                reindexed = dask.array.map_blocks(
+                    reindex_intermediates, subset, agg=agg, unique_groups=index, meta=subset._meta
+                )
+                # now that we have reindexed, we can set reindex=True explicitly
                reduced_.append(
                    tree_reduce(
                        reindexed,
-                        aggregate=partial(aggregate, expected_groups=cohort, reindex=reindex),
+                        combine=partial(combine, agg=agg, reindex=True),
+                        aggregate=partial(aggregate, expected_groups=index, reindex=True),
                    )
                )
                groups_.append(cohort)
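Why each block can now be combined with reindex=True: reindexing every block against the shared `pd.Index(cohort)` gives all blocks identical shape and group order. A toy version of that scatter, using `Index.get_indexer`:

import numpy as np
import pandas as pd

cohort = pd.Index([1, 2, 3])

# hypothetical block that only saw groups 2 and 3
block_groups = np.array([2, 3])
block_sums = np.array([7.0, 2.0])

out = np.zeros(len(cohort))                          # fill with the reduction's identity
out[cohort.get_indexer(block_groups)] = block_sums   # scatter into the cohort-wide layout
print(out)  # [0. 7. 2.]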
@@ -1394,28 +1413,24 @@ def _validate_reindex(
    if reindex is True:
        if _is_arg_reduction(func):
            raise NotImplementedError
-        if method == "blockwise":
-            raise NotImplementedError
+        if method in ["blockwise", "cohorts"]:
+            raise ValueError(
+                "reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
+            )

    if reindex is None:
        if method == "blockwise" or _is_arg_reduction(func):
            reindex = False

-        elif expected_groups is not None:
-            reindex = True
-
-        elif method in ["split-reduce", "cohorts"]:
-            reindex = True
+        elif method == "cohorts":
+            reindex = False

        elif method == "map-reduce":
            if expected_groups is None and by_is_dask:
                reindex = False
            else:
                reindex = True

-    if method in ["split-reduce", "cohorts"] and reindex is False:
-        raise NotImplementedError
-
    assert isinstance(reindex, bool)
    return reindex
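As a sanity check, a small sketch mirroring the new default logic above (not the flox function itself; argument names are abbreviated):

def pick_reindex(reindex, method, is_arg_reduction, expected_groups, by_is_dask):
    # mirrors _validate_reindex after this change
    if reindex is True and method in ["blockwise", "cohorts"]:
        raise ValueError("reindex=True is not valid for these methods")
    if reindex is None:
        if method == "blockwise" or is_arg_reduction:
            reindex = False
        elif method == "cohorts":
            reindex = False
        elif method == "map-reduce":
            reindex = not (expected_groups is None and by_is_dask)
    return reindex

assert pick_reindex(None, "cohorts", False, None, True) is False
assert pick_reindex(None, "map-reduce", False, None, True) is False
assert pick_reindex(None, "map-reduce", False, [0, 1], True) is True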