Commit 398a6fb

Merge branch 'main' into multiple-groupers-3

* main:
  More consistent fill_value handling.
  isnull instead of isnan
  Add funding badge (#77)

2 parents af86cea + 35dd38d

5 files changed: +123 -68 lines changed


README.md

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-[![GitHub Workflow CI Status](https://img.shields.io/github/workflow/status/dcherian/flox/CI?logo=github&style=flat)](https://github.com/dcherian/flox/actions)[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/dcherian/flox/main.svg)](https://results.pre-commit.ci/latest/github/dcherian/flox/main)[![image](https://img.shields.io/codecov/c/github/dcherian/flox.svg?style=flat)](https://codecov.io/gh/dcherian/flox)[![PyPI](https://img.shields.io/pypi/v/flox.svg?style=flat)](https://pypi.org/project/flox/)[![Conda-forge](https://img.shields.io/conda/vn/conda-forge/flox.svg?style=flat)](https://anaconda.org/conda-forge/flox)[![Documentation Status](https://readthedocs.org/projects/flox/badge/?version=latest)](https://flox.readthedocs.io/en/latest/?badge=latest)
+[![GitHub Workflow CI Status](https://img.shields.io/github/workflow/status/dcherian/flox/CI?logo=github&style=flat)](https://github.com/dcherian/flox/actions)[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/dcherian/flox/main.svg)](https://results.pre-commit.ci/latest/github/dcherian/flox/main)[![image](https://img.shields.io/codecov/c/github/dcherian/flox.svg?style=flat)](https://codecov.io/gh/dcherian/flox)[![PyPI](https://img.shields.io/pypi/v/flox.svg?style=flat)](https://pypi.org/project/flox/)[![Conda-forge](https://img.shields.io/conda/vn/conda-forge/flox.svg?style=flat)](https://anaconda.org/conda-forge/flox)[![Documentation Status](https://readthedocs.org/projects/flox/badge/?version=latest)](https://flox.readthedocs.io/en/latest/?badge=latest)[![NASA-80NSSC18M0156](https://img.shields.io/badge/NASA-80NSSC18M0156-blue)](https://earthdata.nasa.gov/esds/competitive-programs/access/pangeo-ml)
 
 # flox
 

flox/aggregate_npg.py

Lines changed: 22 additions & 0 deletions

@@ -1,3 +1,5 @@
+from functools import partial
+
 import numpy as np
 import numpy_groupies as npg
 
@@ -64,3 +66,23 @@ def nansum_of_squares(group_idx, array, engine, *, axis=-1, size=None, fill_valu
         axis=axis,
         dtype=dtype,
     )
+
+
+def _len(group_idx, array, engine, *, func, axis=-1, size=None, fill_value=None, dtype=None):
+    result = _get_aggregate(engine).aggregate(
+        group_idx,
+        array,
+        axis=axis,
+        func=func,
+        size=size,
+        fill_value=0,
+        dtype=np.int64,
+    )
+    if fill_value is not None:
+        result = result.astype(np.array([fill_value]).dtype)
+        result[result == 0] = fill_value
+    return result
+
+
+len = partial(_len, func="len")
+nanlen = partial(_len, func="nanlen")
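
The new `_len` helper counts group members with numpy_groupies and then swaps a caller-supplied fill_value into empty slots. A minimal sketch of that behavior, calling numpy_groupies directly; the data and `size=3` here are illustrative, not from the commit:

```python
import numpy as np
import numpy_groupies as npg

# Hypothetical data: group 0 contains one NaN; group 2 has no members at all.
group_idx = np.array([0, 0, 0, 1])
array = np.array([1.0, np.nan, 3.0, 4.0])

# "nanlen" counts non-NaN members per group; absent groups get 0.
counts = npg.aggregate(group_idx, array, func="nanlen", size=3, fill_value=0, dtype=np.int64)
print(counts)  # [2 1 0]

# Mimic _len's post-processing: cast to the fill_value's dtype,
# then mark zero-count groups with the fill_value.
fill_value = np.nan
result = counts.astype(np.array([fill_value]).dtype)
result[result == 0] = fill_value
print(result)  # [ 2.  1. nan]
```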

flox/aggregations.py

Lines changed: 35 additions & 4 deletions

@@ -107,7 +107,7 @@ def __init__(
         self.preprocess = preprocess
         # Use "chunk_reduce" or "chunk_argreduce"
         self.reduction_type = reduction_type
-        self.numpy = numpy if numpy else self.name
+        self.numpy = (numpy,) if numpy else (self.name,)
         # initialize blockwise reduction
         self.chunk = _atleast_1d(chunk)
         # how to aggregate results after first round of reduction
@@ -163,6 +163,7 @@ def __repr__(self):
                 f"combine: {self.combine}",
                 f"aggregate: {self.aggregate}",
                 f"finalize: {self.finalize}",
+                f"min_count: {self.min_count}",
             )
         )
 
@@ -265,9 +266,9 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
 
 
 min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF)
-nanmin = Aggregation("nanmin", chunk="nanmin", combine="min", fill_value=dtypes.INF)
+nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan)
 max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF)
-nanmax = Aggregation("nanmax", chunk="nanmax", combine="max", fill_value=dtypes.NINF)
+nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan)
 
 
 def argreduce_preprocess(array, axis):
@@ -409,7 +410,13 @@ def _zip_index(array_, idx_):
 }
 
 
-def _initialize_aggregation(func: str | Aggregation, array_dtype, fill_value) -> Aggregation:
+def _initialize_aggregation(
+    func: str | Aggregation,
+    array_dtype,
+    fill_value,
+    min_count: int,
+    finalize_kwargs,
+) -> Aggregation:
     if not isinstance(func, Aggregation):
         try:
             # TODO: need better interface
@@ -425,6 +432,7 @@ def _initialize_aggregation(func: str | Aggregation, array_dtype, fill_value) ->
         raise ValueError("Bad type for func. Expected str or Aggregation")
 
     agg.dtype[func] = _normalize_dtype(agg.dtype[func], array_dtype, fill_value)
+    agg.dtype["numpy"] = (agg.dtype[func],)
     agg.dtype["intermediate"] = [
         _normalize_dtype(dtype, array_dtype) for dtype in agg.dtype["intermediate"]
     ]
@@ -435,4 +443,27 @@ def _initialize_aggregation(func: str | Aggregation, array_dtype, fill_value) ->
         for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
     )
     agg.fill_value[func] = _get_fill_value(agg.dtype[func], agg.fill_value[func])
+
+    fv = fill_value if fill_value is not None else agg.fill_value[agg.name]
+    agg.fill_value["numpy"] = (fv,)
+
+    if finalize_kwargs is not None:
+        assert isinstance(finalize_kwargs, dict)
+        agg.finalize_kwargs = finalize_kwargs
+
+    # This is needed for the dask pathway.
+    # Because we use intermediate fill_value since a group could be
+    # absent in one block, but present in another block
+    # We set it for numpy to get nansum, nanprod tests to pass
+    # where the identity element is 0, 1
+    if min_count is not None:
+        agg.min_count = min_count
+        agg.chunk += ("nanlen",)
+        agg.numpy += ("nanlen",)
+        agg.combine += ("sum",)
+        agg.fill_value["intermediate"] += (0,)
+        agg.fill_value["numpy"] += (0,)
+        agg.dtype["intermediate"] += (np.intp,)
+        agg.dtype["numpy"] += (np.intp,)
+
     return agg
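
The min_count branch above appends a "nanlen" intermediate so that all-NaN (or absent) groups can be masked afterwards. A numpy-only sketch of the underlying idea, with made-up data; this is not flox's actual code path:

```python
import numpy as np

# Why min_count needs a "nanlen" intermediate: nansum over an all-NaN
# group returns 0, which is indistinguishable from a genuine sum of zeros.
# Carrying counts alongside lets the finalize step mask such groups.
values = np.array([np.nan, np.nan, 1.0, -1.0])
group_idx = np.array([0, 0, 1, 1])

sums = np.zeros(2)
np.add.at(sums, group_idx, np.nan_to_num(values))  # per-group nansum -> [0., 0.]

counts = np.zeros(2, dtype=np.intp)
np.add.at(counts, group_idx, ~np.isnan(values))    # per-group nanlen -> [0, 2]

min_count = 1
result = np.where(counts >= min_count, sums, np.nan)
print(result)  # [nan  0.] -- the all-NaN group is masked; the true zero survives
```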

flox/core.py

Lines changed: 34 additions & 49 deletions

@@ -162,7 +162,7 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
     labels = np.asarray(labels)
 
     if method == "split-reduce":
-        return pd.unique(labels.ravel()).reshape(-1, 1).tolist()
+        return _get_expected_groups(labels, sort=False).values.reshape(-1, 1).tolist()
 
     # Build an array with the shape of labels, but where every element is the "chunk number"
     # 1. First subset the array appropriately
@@ -180,7 +180,7 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
     # We always drop NaN; np.unique also considers every NaN to be different so
     # it's really important we get rid of them.
     raveled = labels.ravel()
-    unique_labels = np.unique(raveled[~np.isnan(raveled)])
+    unique_labels = np.unique(raveled[~isnull(raveled)])
     # these are chunks where a label is present
     label_chunks = {lab: tuple(np.unique(which_chunk[raveled == lab])) for lab in unique_labels}
     # These invert the label_chunks mapping so we know which labels occur together.
@@ -361,7 +361,7 @@ def reindex_(
         raise ValueError("Filling is required. fill_value cannot be None.")
     indexer[axis] = idx == -1
     # This allows us to match xarray's type promotion rules
-    if fill_value is xrdtypes.NA or np.isnan(fill_value):
+    if fill_value is xrdtypes.NA or isnull(fill_value):
         new_dtype, fill_value = xrdtypes.maybe_promote(reindexed.dtype)
         reindexed = reindexed.astype(new_dtype, copy=False)
     reindexed[tuple(indexer)] = fill_value
@@ -429,7 +429,7 @@ def factorize_(
         else:
             sorter = None
         idx = np.searchsorted(expect, groupvar.ravel(), sorter=sorter)
-        mask = np.isnan(groupvar.ravel())
+        mask = isnull(groupvar.ravel())
         # TODO: optimize?
         idx[mask] = -1
         if not sort:
@@ -510,7 +510,7 @@ def chunk_argreduce(
         engine=engine,
         sort=sort,
     )
-    if not np.isnan(results["groups"]).all():
+    if not isnull(results["groups"]).all():
         # will not work for empty groups...
         # glorious
         idx = np.broadcast_to(idx, array.shape)
@@ -639,6 +639,8 @@ def chunk_reduce(
     # counts are needed for the final result as well as for masking
     # optimize that out.
     previous_reduction = None
+    for param in (fill_value, kwargs, dtype):
+        assert len(param) >= len(func)
     for reduction, fv, kw, dt in zip(func, fill_value, kwargs, dtype):
         if empty:
             result = np.full(shape=final_array_shape, fill_value=fv)
@@ -842,7 +844,7 @@ def _grouped_combine(
         # reindexing is unnecessary
         # I bet we can minimize the amount of reindexing for mD reductions too, but it's complicated
         unique_groups = np.unique(tuple(flatten(deepmap(listify_groups, x_chunk))))
-        unique_groups = unique_groups[~np.isnan(unique_groups)]
+        unique_groups = unique_groups[~isnull(unique_groups)]
         if len(unique_groups) == 0:
             unique_groups = [np.nan]
 
@@ -962,13 +964,10 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, engi
     Blockwise groupby reduction that produces the final result. This code path is
     also used for non-dask array aggregations.
     """
-
     # for pure numpy grouping, we just use npg directly and avoid "finalizing"
     # (agg.finalize = None). We still need to do the reindexing step in finalize
     # so that everything matches the dask version.
     agg.finalize = None
-    # xarray's count is npg's nanlen
-    func: tuple[str] = (agg.numpy, "nanlen")
 
     assert agg.finalize_kwargs is not None
     finalize_kwargs = agg.finalize_kwargs
@@ -979,14 +978,14 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, engi
     results = chunk_reduce(
         array,
         by,
-        func=func,
+        func=agg.numpy,
         axis=axis,
         expected_groups=expected_groups,
         # This fill_value should only apply to groups that only contain NaN observations
         # BUT there is funkiness when axis is a subset of all possible values
         # (see below)
-        fill_value=(agg.fill_value[agg.name], 0),
-        dtype=(agg.dtype[agg.name], np.intp),
+        fill_value=agg.fill_value["numpy"],
+        dtype=agg.dtype["numpy"],
         kwargs=finalize_kwargs,
         engine=engine,
         sort=sort,
@@ -998,36 +997,20 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, engi
         # so replace -1 with 0; unravel; then replace 0 with -1
         # UGH!
         idx = results["intermediates"][0]
-        mask = idx == -1
+        mask = idx == agg.fill_value["numpy"][0]
         idx[mask] = 0
         # Fix npg bug where argmax with nD array, 1D group_idx, axis=-1
         # will return wrong indices
         idx = np.unravel_index(idx, array.shape)[-1]
-        idx[mask] = -1
+        idx[mask] = agg.fill_value["numpy"][0]
         results["intermediates"][0] = idx
     elif agg.name in ["nanvar", "nanstd"]:
-        # Fix npg bug where all-NaN rows are 0 instead of NaN
+        # TODO: Fix npg bug where all-NaN rows are 0 instead of NaN
         value, counts = results["intermediates"]
         mask = counts <= 0
        value[mask] = np.nan
         results["intermediates"][0] = value
 
-    # When axis is a subset of possible values; then npg will
-    # apply it to groups that don't exist along a particular axis (for e.g.)
-    # since these count as a group that is absent. thoo!
-    # TODO: the "count" bit is a hack to make tests pass.
-    if len(axis) < by.ndim and agg.min_count is None and agg.name != "count":
-        agg.min_count = 1
-
-    # This fill_value applies to members of expected_groups not seen in groups
-    # or when the min_count threshold is not satisfied
-    # Use xarray's dtypes.NA to match type promotion rules
-    if fill_value is None:
-        if agg.name in ["any", "all"]:
-            fill_value = False
-        elif not _is_arg_reduction(agg):
-            fill_value = xrdtypes.NA
-
     result = _finalize_results(results, agg, axis, expected_groups, fill_value=fill_value)
     return result
 
@@ -1519,20 +1502,33 @@ def groupby_reduce(
     array = _move_reduce_dims_to_end(array, axis)
     axis = tuple(array.ndim + np.arange(-len(axis), 0))
 
+    has_dask = is_duck_dask_array(array) or is_duck_dask_array(by)
+
+    # When axis is a subset of possible values; then npg will
+    # apply it to groups that don't exist along a particular axis (for e.g.)
+    # since these count as a group that is absent. thoo!
+    # fill_value applies to all-NaN groups as well as labels in expected_groups that are not found.
+    # The only way to do this consistently is mask out using min_count
+    # Consider np.sum([np.nan]) = np.nan, np.nansum([np.nan]) = 0
+    if min_count is None:
+        if (
+            len(axis) < by.ndim
+            or fill_value is not None
+            # TODO: Fix npg bug where all-NaN rows are 0 instead of NaN
+            or (not has_dask and isinstance(func, str) and func in ["nanvar", "nanstd"])
+        ):
+            min_count = 1
+
+    # TODO: set in xarray?
     if min_count is not None and func in ["nansum", "nanprod"] and fill_value is None:
         # nansum, nanprod have fill_value=0, 1
         # overwrite than when min_count is set
         fill_value = np.nan
 
-    agg = _initialize_aggregation(func, array.dtype, fill_value)
-    agg.min_count = min_count
-    if finalize_kwargs is not None:
-        assert isinstance(finalize_kwargs, dict)
-        agg.finalize_kwargs = finalize_kwargs
-
     kwargs = dict(axis=axis, fill_value=fill_value, engine=engine, sort=sort)
+    agg = _initialize_aggregation(func, array.dtype, fill_value, min_count, finalize_kwargs)
 
-    if not is_duck_dask_array(array) and not is_duck_dask_array(by):
+    if not has_dask:
         results = _reduce_blockwise(array, by, agg, expected_groups=expected_groups, **kwargs)
         groups = (results["groups"],)
         result = results[agg.name]
@@ -1541,21 +1537,10 @@ def groupby_reduce(
         if agg.chunk is None:
             raise NotImplementedError(f"{func} not implemented for dask arrays")
 
-        if agg.min_count is None:
-            # This is needed for the dask pathway.
-            # Because we use intermediate fill_value since a group could be
-            # absent in one block, but present in another block
-            agg.min_count = 1
-
         # we always need some fill_value (see above) so choose the default if needed
         if kwargs["fill_value"] is None:
            kwargs["fill_value"] = agg.fill_value[agg.name]
 
-        agg.chunk += ("nanlen",)
-        agg.combine += ("sum",)
-        agg.fill_value["intermediate"] += (0,)
-        agg.dtype["intermediate"] += (np.intp,)
-
         partial_agg = partial(dask_groupby_agg, agg=agg, split_out=split_out, **kwargs)
 
         if method in ["split-reduce", "cohorts"]:
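
The repeated `np.isnan` to `isnull` swaps matter because `np.isnan` raises on non-float group labels, while a pandas-style `isnull` also handles object and datetime dtypes. A quick illustration with `pd.isnull` standing in for flox's internal helper (an assumption; flox defines its own `isnull`):

```python
import numpy as np
import pandas as pd

labels = np.array(["a", "b", None], dtype=object)

try:
    np.isnan(labels)
except TypeError as err:  # the isnan ufunc has no loop for object dtype
    print("np.isnan failed:", err)

print(pd.isnull(labels))  # [False False  True]
```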
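End to end, the new defaulting means a user-supplied fill_value now implies min_count=1, so an all-NaN group is filled rather than reported as nansum's identity 0. A hedged usage sketch, assuming flox's public `groupby_reduce` signature and return convention around this commit, with made-up data:

```python
import numpy as np
from flox.core import groupby_reduce

array = np.array([1.0, 2.0, np.nan, np.nan])
by = np.array([0, 0, 1, 1])

# fill_value is not None -> min_count defaults to 1, so the all-NaN
# group 1 is masked and filled instead of coming back as 0.
result, groups = groupby_reduce(array, by, "nansum", fill_value=np.nan)
print(groups)  # [0 1]
print(result)  # [ 3. nan]
```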
