@@ -66,7 +66,7 @@ def _get_expected_groups(by, raise_if_dask=True) -> pd.Index | None:
66
66
return None
67
67
flatby = by .ravel ()
68
68
expected = np .unique (flatby [~ isnull (flatby )])
69
- return _convert_expected_groups_to_index (expected , isbin = False )
69
+ return _convert_expected_groups_to_index (( expected ,), isbin = ( False ,))[ 0 ]
70
70
71
71
72
72
def _get_chunk_reduction (reduction_type : str ) -> Callable :
@@ -385,7 +385,7 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
385
385
return offset , size
386
386
387
387
388
- def factorize_ (by : tuple , axis , expected_groups : tuple [pd .Index , ...] = None , reindex = False ):
388
+ def factorize_ (by : tuple , axis , expected_groups : tuple [pd .Index , ...] = None , reindex = False , fastpath = False ):
389
389
if not isinstance (by , tuple ):
390
390
raise ValueError (f"Expected `by` to be a tuple. Received { type (by )} instead" )
391
391
@@ -420,10 +420,13 @@ def factorize_(by: tuple, axis, expected_groups: tuple[pd.Index, ...] = None, re
420
420
grp_shape = tuple (len (grp ) for grp in found_groups )
421
421
ngroups = np .prod (grp_shape )
422
422
if len (by ) > 1 :
423
- group_idx = np .ravel_multi_index (factorized , grp_shape ). reshape ( by [ 0 ]. shape )
423
+ group_idx = np .ravel_multi_index (factorized , grp_shape )
424
424
else :
425
425
group_idx = factorized [0 ]
426
426
427
+ if fastpath :
428
+ return group_idx , found_groups , grp_shape
429
+
427
430
if np .isscalar (axis ) and groupvar .ndim > 1 :
428
431
# Not reducing along all dimensions of by
429
432
# this is OK because for 3D by and axis=(1,2),
@@ -1243,23 +1246,60 @@ def _assert_by_is_aligned(shape, by):
1243
1246
)
1244
1247
1245
1248
1246
- def _convert_expected_groups_to_index (expected_groups , isbin : bool ) -> pd .Index | None :
1247
- if isinstance (expected_groups , pd .IntervalIndex ) or (
1248
- isinstance (expected_groups , pd .Index ) and not isbin
1249
- ):
1250
- return expected_groups
1251
- if isbin :
1252
- return pd .IntervalIndex .from_arrays (expected_groups [:- 1 ], expected_groups [1 :])
1253
- elif expected_groups is not None :
1254
- return pd .Index (expected_groups )
1255
- return expected_groups
1249
+ def _convert_expected_groups_to_index (expected_groups : tuple , isbin : bool ) -> pd .Index | None :
1250
+ out = []
1251
+ for ex , isbin_ in zip (expected_groups , isbin ):
1252
+ if isinstance (ex , pd .IntervalIndex ) or (isinstance (ex , pd .Index ) and not isbin ):
1253
+ out .append (expected_groups )
1254
+ elif ex is not None :
1255
+ if isbin_ :
1256
+ out .append (pd .IntervalIndex .from_arrays (ex [:- 1 ], ex [1 :]))
1257
+ else :
1258
+ out .append (pd .Index (ex ))
1259
+ else :
1260
+ assert ex is None
1261
+ out .append (None )
1262
+ return tuple (out )
1263
+
1264
+
1265
def _lazy_factorize_wrapper(*by, **kwargs):
    """Per-block adapter around ``factorize_`` for ``dask.array.map_blocks``.

    ``_factorize_multiple`` always passes ``fastpath=True``, and the fastpath
    of ``factorize_`` returns a 3-tuple ``(group_idx, found_groups, grp_shape)``.
    Only the integer ``group_idx`` array is needed per block.
    """
    # BUGFIX: was `group_idx, _ = ...`, which raises
    # "too many values to unpack" against the 3-tuple fastpath return.
    group_idx, *_ = factorize_(by, **kwargs)
    return group_idx
1268
+
1269
+
1270
def _factorize_multiple(by, expected_groups, by_is_dask):
    """Factorize multiple ``by`` variables into one integer group index.

    Parameters
    ----------
    by : tuple of array-like
        Grouping variables (numpy or dask arrays); broadcast against each other.
    expected_groups : tuple
        One pd.Index (or None) per grouping variable.
    by_is_dask : bool
        True if any element of ``by`` is a dask array; factorization is then
        deferred blockwise via ``map_blocks``.

    Returns
    -------
    (group_idx,), final_groups, grp_shape
        Single-element tuple with the flat group index, the per-variable
        group labels, and the shape of the (conceptual) group grid.

    Raises
    ------
    ValueError
        If a dask grouping variable has no corresponding ``expected_groups``
        entry (its labels cannot be known without computing).
    """
    kwargs = dict(
        expected_groups=expected_groups,
        axis=None,  # always None, we offset later if necessary.
        fastpath=True,
    )
    if by_is_dask:
        import dask.array

        group_idx = dask.array.map_blocks(
            _lazy_factorize_wrapper,
            *np.broadcast_arrays(*by),
            meta=np.array((), dtype=np.int64),
            **kwargs,
        )
        # Labels of a dask variable are unknown without computing; mark None.
        found_groups = tuple(None if is_duck_dask_array(b) else np.unique(b) for b in by)
    else:
        group_idx, found_groups, grp_shape = factorize_(by, **kwargs)

    # Prefer the user-provided expected groups; fall back to discovered labels.
    # BUGFIX: guard `found is None` so a missing dask entry reaches the
    # explicit ValueError below instead of crashing inside pd.Index(None).
    final_groups = tuple(
        expect
        if expect is not None
        else (pd.Index(found) if found is not None else None)
        for found, expect in zip(found_groups, expected_groups)
    )

    if any(grp is None for grp in final_groups):
        # BUGFIX: was a bare `raise` with no active exception (RuntimeError).
        raise ValueError("Please provide `expected_groups` when grouping by a dask array.")

    # BUGFIX: the dask branch never defined grp_shape (NameError on return);
    # derive it from final_groups, which matches factorize_'s value on the
    # eager branch.
    grp_shape = tuple(len(grp) for grp in final_groups)
    return (group_idx,), final_groups, grp_shape
1256
1297
1257
1298
1258
1299
def groupby_reduce (
1259
1300
array : np .ndarray | DaskArray ,
1260
- by : np .ndarray | DaskArray ,
1301
+ * by : np .ndarray | DaskArray ,
1261
1302
func : str | Aggregation ,
1262
- * ,
1263
1303
expected_groups : Sequence | np .ndarray | None = None ,
1264
1304
sort : bool = True ,
1265
1305
isbin : bool = False ,
@@ -1373,6 +1413,7 @@ def groupby_reduce(
1373
1413
reindex = _validate_reindex (reindex , func , method , expected_groups )
1374
1414
1375
1415
by : tuple = tuple (np .asarray (b ) if not is_duck_array (b ) else b for b in by )
1416
+ nby = len (by )
1376
1417
by_is_dask = any (is_duck_dask_array (b ) for b in by )
1377
1418
if not is_duck_array (array ):
1378
1419
array = np .asarray (array )
@@ -1392,6 +1433,20 @@ def groupby_reduce(
1392
1433
# (pd.IntervalIndex or not)
1393
1434
expected_groups = _convert_expected_groups_to_index (expected_groups , isbin )
1394
1435
1436
+ # when grouping by multiple variables, we factorize early.
1437
+ # TODO: could restrict this to dask-only
1438
+ if nby > 1 :
1439
+ by , final_groups , grp_shape = _factorize_multiple (
1440
+ by , expected_groups , by_is_dask = by_is_dask
1441
+ )
1442
+ expected_groups = (pd .RangeIndex (np .prod (grp_shape )),)
1443
+ else :
1444
+ final_groups = expected_groups
1445
+
1446
+ assert len (by ) == 1
1447
+ by = by [0 ]
1448
+ expected_groups = expected_groups [0 ]
1449
+
1395
1450
if axis is None :
1396
1451
axis = tuple (array .ndim + np .arange (- by .ndim , 0 ))
1397
1452
else :
@@ -1404,7 +1459,7 @@ def groupby_reduce(
1404
1459
1405
1460
# TODO: make sure expected_groups is unique
1406
1461
# TODO: hack
1407
- if len (axis ) == 1 and by_ndim > 1 and expected_groups [ 0 ] is None :
1462
+ if len (axis ) == 1 and by . ndim > 1 and expected_groups is None :
1408
1463
# TODO: hack
1409
1464
if not by_is_dask :
1410
1465
expected_groups = _get_expected_groups (by )
@@ -1509,7 +1564,11 @@ def groupby_reduce(
1509
1564
array = rechunk_for_blockwise (array , axis = - 1 , labels = by )
1510
1565
1511
1566
result , * groups = partial_agg (
1512
- array , by , expected_groups = expected_groups , reindex = reindex , method = method
1567
+ array ,
1568
+ by ,
1569
+ expected_groups = None if method == "blockwise" else expected_groups ,
1570
+ reindex = reindex ,
1571
+ method = method ,
1513
1572
)
1514
1573
1515
1574
if sort and method != "map-reduce" :
@@ -1518,4 +1577,7 @@ def groupby_reduce(
1518
1577
result = result [..., sorted_idx ]
1519
1578
groups = (groups [0 ][sorted_idx ],)
1520
1579
1580
+ if nby > 1 :
1581
+ groups = final_groups
1582
+ result = result .reshape (result .shape [:- 1 ] + grp_shape )
1521
1583
return (result , * groups )
0 commit comments