Skip to content

Commit fa5ed42

Browse files
committed
Merge branch 'main' into numbagg
* main: Fix benchmarks Fix engine='numba' (#73) Switch to pre-commit.ci (#71)
2 parents 551a719 + 8e715f0 commit fa5ed42

File tree

4 files changed

+39
-37
lines changed

4 files changed

+39
-37
lines changed

.github/workflows/linting.yaml

Lines changed: 0 additions & 16 deletions
This file was deleted.

benchmarks/combine.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def setup(self, *args, **kwargs):
1313

1414
@parameterized("kind", ("cohorts", "mapreduce"))
1515
def time_combine(self, kind):
16-
flox.core._npg_combine(
16+
flox.core._grouped_combine(
1717
getattr(self, f"x_chunk_{kind}"),
1818
**self.kwargs,
1919
keepdims=True,
@@ -22,7 +22,7 @@ def time_combine(self, kind):
2222

2323
@parameterized("kind", ("cohorts", "mapreduce"))
2424
def peakmem_combine(self, kind):
25-
flox.core._npg_combine(
25+
flox.core._grouped_combine(
2626
getattr(self, f"x_chunk_{kind}"),
2727
**self.kwargs,
2828
keepdims=True,
@@ -58,4 +58,4 @@ def construct_member(groups):
5858
]
5959

6060
self.x_chunk_cohorts = [construct_member(groups) for groups in [np.array((1, 2, 3, 4))] * 4]
61-
self.kwargs = {"agg": flox.aggregations.mean, "axis": (3,), "group_ndim": 1}
61+
self.kwargs = {"agg": flox.aggregations.mean, "axis": (3,), "neg_axis": (-1,)}

flox/aggregate_npg.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22
import numpy_groupies as npg
33

44

5+
def _get_aggregate(engine):
6+
return npg.aggregate_numpy if engine == "numpy" else npg.aggregate_numba
7+
8+
59
def sum_of_squares(
6-
group_idx, array, *, axis=-1, func="sum", size=None, fill_value=None, dtype=None
10+
group_idx, array, engine, *, axis=-1, func="sum", size=None, fill_value=None, dtype=None
711
):
812

9-
return npg.aggregate_numpy.aggregate(
13+
return _get_aggregate(engine).aggregate(
1014
group_idx,
1115
array**2,
1216
axis=axis,
@@ -17,12 +21,12 @@ def sum_of_squares(
1721
)
1822

1923

20-
def nansum(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
24+
def nansum(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
2125
# npg takes out NaNs before calling np.bincount
2226
# This means that all NaN groups are equivalent to absent groups
2327
# This behaviour does not work for xarray
2428

25-
return npg.aggregate_numpy.aggregate(
29+
return _get_aggregate(engine).aggregate(
2630
group_idx,
2731
np.where(np.isnan(array), 0, array),
2832
axis=axis,
@@ -33,12 +37,12 @@ def nansum(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None)
3337
)
3438

3539

36-
def nanprod(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
40+
def nanprod(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
3741
# npg takes out NaNs before calling np.bincount
3842
# This means that all NaN groups are equivalent to absent groups
3943
# This behaviour does not work for xarray
4044

41-
return npg.aggregate_numpy.aggregate(
45+
return _get_aggregate(engine).aggregate(
4246
group_idx,
4347
np.where(np.isnan(array), 1, array),
4448
axis=axis,
@@ -49,7 +53,14 @@ def nanprod(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None
4953
)
5054

5155

52-
def nansum_of_squares(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
56+
def nansum_of_squares(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
5357
return sum_of_squares(
54-
group_idx, array, func="nansum", size=size, fill_value=fill_value, axis=axis, dtype=dtype
58+
group_idx,
59+
array,
60+
engine=engine,
61+
func="nansum",
62+
size=size,
63+
fill_value=fill_value,
64+
axis=axis,
65+
dtype=dtype,
5566
)

flox/aggregations.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,31 +10,38 @@
1010

1111

1212
def generic_aggregate(
13-
group_idx, array, *, engine, func, axis=-1, size=None, fill_value=None, dtype=None, **kwargs
13+
group_idx,
14+
array,
15+
*,
16+
engine: str,
17+
func: str,
18+
axis=-1,
19+
size=None,
20+
fill_value=None,
21+
dtype=None,
22+
**kwargs,
1423
):
1524
if engine == "flox":
1625
try:
1726
method = getattr(aggregate_flox, func)
1827
except AttributeError:
1928
method = partial(npg.aggregate_numpy.aggregate, func=func)
29+
2030
elif engine == "numbagg":
2131
from . import aggregate_numbagg
2232

2333
try:
2434
method = getattr(aggregate_numbagg, func)
2535
except AttributeError:
2636
method = partial(npg.aggregate_numpy.aggregate, func=func)
27-
elif engine == "numpy":
28-
try:
29-
# TODO: fix numba here
30-
method = getattr(aggregate_npg, func)
31-
except AttributeError:
32-
method = partial(npg.aggregate_np, func=func)
33-
elif engine == "numba":
37+
38+
elif engine in ["numpy", "numba"]:
3439
try:
35-
method = getattr(aggregate_npg, f"{func}")
40+
method_ = getattr(aggregate_npg, func)
41+
method = partial(method_, engine=engine)
3642
except AttributeError:
37-
method = partial(npg.aggregate_nb, func=func)
43+
aggregate = npg.aggregate_np if engine == "numpy" else npg.aggregate_nb
44+
method = partial(aggregate, func=func)
3845
else:
3946
raise ValueError(
4047
f"Expected engine to be one of ['flox', 'numpy', 'numba', 'numbagg']. Received {engine} instead."

0 commit comments

Comments
 (0)