@@ -162,7 +162,7 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
     labels = np.asarray(labels)
 
     if method == "split-reduce":
-        return pd.unique(labels.ravel()).reshape(-1, 1).tolist()
+        return _get_expected_groups(labels, sort=False).values.reshape(-1, 1).tolist()
 
     # Build an array with the shape of labels, but where every element is the "chunk number"
     # 1. First subset the array appropriately
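The `split-reduce` change swaps `pd.unique` for the `_get_expected_groups` helper, which also drops missing labels. A minimal sketch of the cohort layout this branch returns, with `np.unique` as a stand-in for the helper:

```python
import numpy as np

# Illustration only: np.unique stands in for _get_expected_groups, which
# additionally drops NaN labels and can skip sorting (sort=False).
labels = np.array([[1, 1, 2], [2, 3, 3]])
cohorts = np.unique(labels).reshape(-1, 1).tolist()
print(cohorts)  # [[1], [2], [3]] -- every label is its own cohort
```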
@@ -180,7 +180,7 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
     # We always drop NaN; np.unique also considers every NaN to be different so
     # it's really important we get rid of them.
     raveled = labels.ravel()
-    unique_labels = np.unique(raveled[~np.isnan(raveled)])
+    unique_labels = np.unique(raveled[~isnull(raveled)])
     # these are chunks where a label is present
     label_chunks = {lab: tuple(np.unique(which_chunk[raveled == lab])) for lab in unique_labels}
     # These invert the label_chunks mapping so we know which labels occur together.
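The switch from `np.isnan` to `isnull` matters for two reasons: `np.isnan` raises on non-float labels (object, string, datetime), and NaN compares unequal to itself, which is why the comment insists on dropping NaNs before `np.unique`. A quick demonstration, with `pd.isnull` standing in for flox's `isnull` helper:

```python
import numpy as np
import pandas as pd

raveled = np.array([1.0, np.nan, 1.0, np.nan, 2.0])
# Dropping nulls before np.unique avoids duplicate NaN entries;
# pd.isnull (unlike np.isnan) also works for object/datetime labels.
print(np.unique(raveled[~pd.isnull(raveled)]))      # [1. 2.]
print(pd.isnull(np.array(["a", None], dtype=object)))  # [False  True]
```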
@@ -361,7 +361,7 @@ def reindex_(
            raise ValueError("Filling is required. fill_value cannot be None.")
        indexer[axis] = idx == -1
        # This allows us to match xarray's type promotion rules
-       if fill_value is xrdtypes.NA or np.isnan(fill_value):
+       if fill_value is xrdtypes.NA or isnull(fill_value):
            new_dtype, fill_value = xrdtypes.maybe_promote(reindexed.dtype)
            reindexed = reindexed.astype(new_dtype, copy=False)
        reindexed[tuple(indexer)] = fill_value
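For context on the promotion branch: an integer array cannot hold a NaN fill, so the dtype must be promoted before filling. A rough sketch of the idea behind `xrdtypes.maybe_promote` (xarray's real rules cover more cases, e.g. datetime to NaT):

```python
import numpy as np

values = np.array([10, 20, 30])
idx = np.array([0, 2, -1])  # -1 marks a group missing from this chunk
# Promote to float so the NaN fill is representable, then fill.
out = values[idx].astype(np.float64)
out[idx == -1] = np.nan
print(out)  # [10. 30. nan]
```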
@@ -429,7 +429,7 @@ def factorize_(
        else:
            sorter = None
        idx = np.searchsorted(expect, groupvar.ravel(), sorter=sorter)
-       mask = np.isnan(groupvar.ravel())
+       mask = isnull(groupvar.ravel())
        # TODO: optimize?
        idx[mask] = -1
        if not sort:
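For reference, the surrounding `factorize_` logic maps each label to its position in the sorted `expect` values and marks nulls with the -1 sentinel so they are dropped downstream; roughly:

```python
import numpy as np

expect = np.array([10, 20, 30])
groupvar = np.array([20.0, 10.0, np.nan, 30.0])
idx = np.searchsorted(expect, groupvar.ravel())
# NaN sorts past the end, so mask nulls back to the -1 sentinel.
# (np.isnan suffices for this float example; flox uses isnull so that
# object and datetime labels work too.)
idx[np.isnan(groupvar.ravel())] = -1
print(idx)  # [ 1  0 -1  2]
```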
@@ -510,7 +510,7 @@ def chunk_argreduce(
        engine=engine,
        sort=sort,
    )
-   if not np.isnan(results["groups"]).all():
+   if not isnull(results["groups"]).all():
        # will not work for empty groups...
        # glorious
        idx = np.broadcast_to(idx, array.shape)
@@ -639,6 +639,8 @@ def chunk_reduce(
    # counts are needed for the final result as well as for masking
    # optimize that out.
    previous_reduction = None
+   for param in (fill_value, kwargs, dtype):
+       assert len(param) >= len(func)
    for reduction, fv, kw, dt in zip(func, fill_value, kwargs, dtype):
        if empty:
            result = np.full(shape=final_array_shape, fill_value=fv)
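The new assertion guards against a silent failure mode: `zip` truncates to its shortest argument, so an under-length `fill_value`, `kwargs`, or `dtype` tuple would quietly skip reductions rather than raise. For example:

```python
func = ("sum", "nanlen")
fill_value = (0,)  # one entry short of func
# Without the length check, the second reduction silently disappears:
print(list(zip(func, fill_value)))  # [('sum', 0)]
```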
@@ -842,7 +844,7 @@ def _grouped_combine(
        # reindexing is unnecessary
        # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
        unique_groups = np.unique(tuple(flatten(deepmap(listify_groups, x_chunk))))
-       unique_groups = unique_groups[~np.isnan(unique_groups)]
+       unique_groups = unique_groups[~isnull(unique_groups)]
        if len(unique_groups) == 0:
            unique_groups = [np.nan]
 
@@ -962,13 +964,10 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, engi
    Blockwise groupby reduction that produces the final result. This code path is
    also used for non-dask array aggregations.
    """
-
    # for pure numpy grouping, we just use npg directly and avoid "finalizing"
    # (agg.finalize = None). We still need to do the reindexing step in finalize
    # so that everything matches the dask version.
    agg.finalize = None
-   # xarray's count is npg's nanlen
-   func: tuple[str] = (agg.numpy, "nanlen")
 
    assert agg.finalize_kwargs is not None
    finalize_kwargs = agg.finalize_kwargs
@@ -979,14 +978,14 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, engi
    results = chunk_reduce(
        array,
        by,
-       func=func,
+       func=agg.numpy,
        axis=axis,
        expected_groups=expected_groups,
        # This fill_value should only apply to groups that only contain NaN observations
        # BUT there is funkiness when axis is a subset of all possible values
        # (see below)
-       fill_value=(agg.fill_value[agg.name], 0),
-       dtype=(agg.dtype[agg.name], np.intp),
+       fill_value=agg.fill_value["numpy"],
+       dtype=agg.dtype["numpy"],
        kwargs=finalize_kwargs,
        engine=engine,
        sort=sort,
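The `"numpy"` keys rely on `_initialize_aggregation` now populating per-backend tuples that line up one-to-one with the functions in `agg.numpy` (including the trailing `nanlen` count that used to be appended here by hand). A hypothetical shape for a mean-like aggregation; the names and values are illustrative, not the library's actual spec:

```python
import numpy as np

# Hypothetical parallel tuples, one entry per reduction in agg.numpy:
agg_numpy = ("nansum", "nanlen")
agg_fill_value = {"numpy": (0, 0)}
agg_dtype = {"numpy": (np.float64, np.intp)}
for f, fv, dt in zip(agg_numpy, agg_fill_value["numpy"], agg_dtype["numpy"]):
    print(f, fv, dt)
```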
@@ -998,36 +997,20 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, engi
        # so replace -1 with 0; unravel; then replace 0 with -1
        # UGH!
        idx = results["intermediates"][0]
-       mask = idx == -1
+       mask = idx == agg.fill_value["numpy"][0]
        idx[mask] = 0
        # Fix npg bug where argmax with nD array, 1D group_idx, axis=-1
        # will return wrong indices
        idx = np.unravel_index(idx, array.shape)[-1]
-       idx[mask] = -1
+       idx[mask] = agg.fill_value["numpy"][0]
        results["intermediates"][0] = idx
    elif agg.name in ["nanvar", "nanstd"]:
-       # Fix npg bug where all-NaN rows are 0 instead of NaN
+       # TODO: Fix npg bug where all-NaN rows are 0 instead of NaN
        value, counts = results["intermediates"]
        mask = counts <= 0
        value[mask] = np.nan
        results["intermediates"][0] = value
 
-   # When axis is a subset of possible values; then npg will
-   # apply it to groups that don't exist along a particular axis (for e.g.)
-   # since these count as a group that is absent. thoo!
-   # TODO: the "count" bit is a hack to make tests pass.
-   if len(axis) < by.ndim and agg.min_count is None and agg.name != "count":
-       agg.min_count = 1
-
-   # This fill_value applies to members of expected_groups not seen in groups
-   # or when the min_count threshold is not satisfied
-   # Use xarray's dtypes.NA to match type promotion rules
-   if fill_value is None:
-       if agg.name in ["any", "all"]:
-           fill_value = False
-       elif not _is_arg_reduction(agg):
-           fill_value = xrdtypes.NA
-
    result = _finalize_results(results, agg, axis, expected_groups, fill_value=fill_value)
    return result
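The replace-unravel-restore dance exists because `np.unravel_index` rejects negative entries, and missing groups are marked with a sentinel. A small reproduction of the trick:

```python
import numpy as np

shape = (2, 3)
idx = np.array([4, -1])        # flat argmax results; -1 = missing group
mask = idx == -1
idx[mask] = 0                  # make the sentinel temporarily valid
idx = np.unravel_index(idx, shape)[-1]  # keep only the last-axis index
idx[mask] = -1                 # restore the sentinel
print(idx)  # [ 1 -1]
```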
@@ -1519,20 +1502,33 @@ def groupby_reduce(
        array = _move_reduce_dims_to_end(array, axis)
        axis = tuple(array.ndim + np.arange(-len(axis), 0))
 
+   has_dask = is_duck_dask_array(array) or is_duck_dask_array(by)
+
+   # When axis is a subset of possible values, npg will apply the reduction
+   # to groups that don't exist along a particular axis,
+   # since these count as absent groups.
+   # fill_value applies to all-NaN groups as well as labels in expected_groups that are not found.
+   #    The only way to do this consistently is mask out using min_count
+   #    Consider np.sum([np.nan]) = np.nan, np.nansum([np.nan]) = 0
+   if min_count is None:
+       if (
+           len(axis) < by.ndim
+           or fill_value is not None
+           # TODO: Fix npg bug where all-NaN rows are 0 instead of NaN
+           or (not has_dask and isinstance(func, str) and func in ["nanvar", "nanstd"])
+       ):
+           min_count = 1
+
+   # TODO: set in xarray?
    if min_count is not None and func in ["nansum", "nanprod"] and fill_value is None:
        # nansum, nanprod have fill_value=0, 1
        # overwrite that when min_count is set
        fill_value = np.nan
 
-   agg = _initialize_aggregation(func, array.dtype, fill_value)
-   agg.min_count = min_count
-   if finalize_kwargs is not None:
-       assert isinstance(finalize_kwargs, dict)
-       agg.finalize_kwargs = finalize_kwargs
-
    kwargs = dict(axis=axis, fill_value=fill_value, engine=engine, sort=sort)
+   agg = _initialize_aggregation(func, array.dtype, fill_value, min_count, finalize_kwargs)
 
-   if not is_duck_dask_array(array) and not is_duck_dask_array(by):
+   if not has_dask:
        results = _reduce_blockwise(array, by, agg, expected_groups=expected_groups, **kwargs)
        groups = (results["groups"],)
        result = results[agg.name]
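The consistency argument in the new comment, in a nutshell: plain and nan-skipping reductions disagree on all-NaN groups, and `min_count=1` is what lets the library mask those groups out and apply `fill_value` uniformly:

```python
import numpy as np

print(np.sum([np.nan]))     # nan  -- plain reduction propagates NaN
print(np.nansum([np.nan]))  # 0.0  -- nan-skipping reduction returns identity
# Masking groups with fewer than min_count non-NaN values makes both
# cases resolve to the user's fill_value instead.
```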
@@ -1541,21 +1537,10 @@ def groupby_reduce(
        if agg.chunk is None:
            raise NotImplementedError(f"{func} not implemented for dask arrays")
 
-       if agg.min_count is None:
-           # This is needed for the dask pathway.
-           # Because we use intermediate fill_value since a group could be
-           # absent in one block, but present in another block
-           agg.min_count = 1
-
        # we always need some fill_value (see above) so choose the default if needed
        if kwargs["fill_value"] is None:
            kwargs["fill_value"] = agg.fill_value[agg.name]
 
-       agg.chunk += ("nanlen",)
-       agg.combine += ("sum",)
-       agg.fill_value["intermediate"] += (0,)
-       agg.dtype["intermediate"] += (np.intp,)
-
        partial_agg = partial(dask_groupby_agg, agg=agg, split_out=split_out, **kwargs)
 
        if method in ["split-reduce", "cohorts"]:
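The deleted block appended the count intermediate (`nanlen` per chunk, combined with `"sum"`) by hand; that bookkeeping now lives in `_initialize_aggregation`. Why counts combine with a plain sum: per-block counts of a group simply add across blocks, and the total feeds the `min_count` mask. Schematically:

```python
import numpy as np

# Per-group non-NaN counts from two blocks of the same dask array:
block1_counts = np.array([2, 0, 1])
block2_counts = np.array([1, 3, 0])
total = block1_counts + block2_counts  # the "sum" combine step
print(total >= 1)  # [ True  True  True] -- min_count=1 mask per group
```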