Skip to content

Commit 950570f

Browse files
committed
BUG: groupby.agg with numba and as_index=False
1 parent fcb8b80 commit 950570f

File tree

4 files changed

+25
-24
lines changed

pandas/core/groupby/generic.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -219,16 +219,9 @@ def apply(self, func, *args, **kwargs) -> Series:
219219
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
220220

221221
if maybe_use_numba(engine):
222-
data = self._obj_with_exclusions
223-
result = self._aggregate_with_numba(
224-
data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
222+
return self._aggregate_with_numba(
223+
func, *args, engine_kwargs=engine_kwargs, **kwargs
225224
)
226-
index = self.grouper.result_index
227-
result = self.obj._constructor(result.ravel(), index=index, name=data.name)
228-
if not self.as_index:
229-
result = self._insert_inaxis_grouper(result)
230-
result.index = default_index(len(result))
231-
return result
232225

233226
relabeling = func is None
234227
columns = None
@@ -1264,12 +1257,9 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
12641257
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
12651258

12661259
if maybe_use_numba(engine):
1267-
data = self._obj_with_exclusions
1268-
result = self._aggregate_with_numba(
1269-
data, func, *args, engine_kwargs=engine_kwargs, **kwargs
1260+
return self._aggregate_with_numba(
1261+
func, *args, engine_kwargs=engine_kwargs, **kwargs
12701262
)
1271-
index = self.grouper.result_index
1272-
return self.obj._constructor(result, index=index, columns=data.columns)
12731263

12741264
relabeling, func, columns, order = reconstruct_func(func, **kwargs)
12751265
func = maybe_mangle_lambdas(func)

pandas/core/groupby/groupby.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,17 +1270,18 @@ def _transform_with_numba(
12701270
return result.take(np.argsort(sorted_index), axis=0)
12711271

12721272
@final
1273-
def _aggregate_with_numba(
1274-
self, data: DataFrame, func, *args, engine_kwargs=None, **kwargs
1275-
):
1273+
def _aggregate_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
12761274
"""
12771275
Perform groupby aggregation routine with the numba engine.
12781276
12791277
This routine mimics the data splitting routine of the DataSplitter class
12801278
to generate the indices of each group in the sorted data and then passes the
12811279
data and indices into a Numba jitted function.
12821280
"""
1283-
starts, ends, sorted_index, sorted_data = self._numba_prep(data)
1281+
data = self._obj_with_exclusions
1282+
df = data if data.ndim == 2 else data.to_frame()
1283+
1284+
starts, ends, sorted_index, sorted_data = self._numba_prep(df)
12841285
numba_.validate_udf(func)
12851286
numba_agg_func = numba_.generate_numba_agg_func(
12861287
func, **get_jit_arguments(engine_kwargs, kwargs)
@@ -1290,10 +1291,18 @@ def _aggregate_with_numba(
12901291
sorted_index,
12911292
starts,
12921293
ends,
1293-
len(data.columns),
1294+
len(df.columns),
12941295
*args,
12951296
)
1296-
return result
1297+
1298+
index = self.grouper.result_index
1299+
if data.ndim == 1:
1300+
result_kwargs = {"name": data.name}
1301+
result = result.ravel()
1302+
else:
1303+
result_kwargs = {"columns": data.columns}
1304+
result = data._constructor(result, index=index, **result_kwargs)
1305+
return self._wrap_aggregated_output(result)
12971306

12981307
# -----------------------------------------------------------------
12991308
# apply/agg/transform

pandas/tests/groupby/aggregate/test_numba.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def incorrect_function(values, index):
5151
# Filter warnings when parallel=True and the function can't be parallelized by Numba
5252
@pytest.mark.parametrize("jit", [True, False])
5353
@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"])
54-
def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython):
54+
@pytest.mark.parametrize("as_index", [True, False])
55+
def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython, as_index):
5556
def func_numba(values, index):
5657
return np.mean(values) * 2.7
5758

@@ -65,7 +66,7 @@ def func_numba(values, index):
6566
{0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
6667
)
6768
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
68-
grouped = data.groupby(0)
69+
grouped = data.groupby(0, as_index=as_index)
6970
if pandas_obj == "Series":
7071
grouped = grouped[1]
7172

pandas/tests/groupby/transform/test_numba.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ def incorrect_function(values, index):
4848
# Filter warnings when parallel=True and the function can't be parallelized by Numba
4949
@pytest.mark.parametrize("jit", [True, False])
5050
@pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"])
51-
def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython):
51+
@pytest.mark.parametrize("as_index", [True, False])
52+
def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython, as_index):
5253
def func(values, index):
5354
return values + 1
5455

@@ -62,7 +63,7 @@ def func(values, index):
6263
{0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
6364
)
6465
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
65-
grouped = data.groupby(0)
66+
grouped = data.groupby(0, as_index=as_index)
6667
if pandas_obj == "Series":
6768
grouped = grouped[1]
6869

0 commit comments

Comments
 (0)