Skip to content

Commit 8c16be3

Browse files
committed
BUG: groupby sum, mean, var should always be floats
1 parent 714f7d7 commit 8c16be3

20 files changed

+219
-81
lines changed

doc/source/whatsnew/v1.3.0.rst

+31
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,36 @@ Preserve dtypes in :meth:`~pandas.DataFrame.combine_first`
298298
299299
combined.dtypes
300300
301+
Group by methods agg and transform no longer change return dtype for callables
302+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
303+
304+
Previously the methods :meth:`.DataFrameGroupBy.aggregate`,
305+
:meth:`.SeriesGroupBy.aggregate`, :meth:`.DataFrameGroupBy.transform`, and
306+
:meth:`.SeriesGroupBy.transform` might cast the result dtype when the argument ``func``
307+
is callable, possibly leading to undesirable results (:issue:`21240`). The cast would
308+
occur if the result is numeric and casting back to the input dtype does not change any
309+
values as measured by ``np.allclose``. Now no such casting occurs.
310+
311+
.. ipython:: python
312+
313+
df = pd.DataFrame({'key': [1, 1], 'a': [True, False], 'b': [True, True]})
314+
df
315+
316+
*pandas 1.2.x*
317+
318+
.. code-block:: ipython
319+
320+
In [5]: df.groupby('key').agg(lambda x: x.sum())
321+
Out[5]:
322+
a b
323+
key
324+
1 True 2
325+
326+
*pandas 1.3.0*
327+
328+
.. ipython:: python
329+
330+
df.groupby('key').agg(lambda x: x.sum())
301331
302332
Try operating inplace when setting values with ``loc`` and ``iloc``
303333
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -847,6 +877,7 @@ Groupby/resample/rolling
847877
- Bug in :meth:`GroupBy.cummin` and :meth:`GroupBy.cummax` incorrectly rounding integer values near the ``int64`` implementations bounds (:issue:`40767`)
848878
- Bug in :meth:`.GroupBy.rank` with nullable dtypes incorrectly raising ``TypeError`` (:issue:`41010`)
849879
- Bug in :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` computing wrong result with nullable data types too large to roundtrip when casting to float (:issue:`37493`)
880+
- Bug in :meth:`.GroupBy.mean`, :meth:`.GroupBy.median`, and :meth:`.GroupBy.var` where an integer dtype would be returned if the result happened to be an integer; these methods now always return floats (:issue:`41137`)
850881

851882
Reshaping
852883
^^^^^^^^^

pandas/_libs/lib.pyx

+10-12
Original file line numberDiff line numberDiff line change
@@ -2233,7 +2233,7 @@ def maybe_convert_objects(ndarray[object] objects, bint try_float=False,
22332233
Array of converted object values to more specific dtypes if applicable.
22342234
"""
22352235
cdef:
2236-
Py_ssize_t i, n, itemsize_max = 0
2236+
Py_ssize_t i, n, itemsize = 0
22372237
ndarray[float64_t] floats
22382238
ndarray[complex128_t] complexes
22392239
ndarray[int64_t] ints
@@ -2266,10 +2266,12 @@ def maybe_convert_objects(ndarray[object] objects, bint try_float=False,
22662266

22672267
for i in range(n):
22682268
val = objects[i]
2269-
if itemsize_max != -1:
2270-
itemsize = get_itemsize(val)
2271-
if itemsize > itemsize_max or itemsize == -1:
2272-
itemsize_max = itemsize
2269+
if (
2270+
hasattr(val, "dtype")
2271+
and hasattr(val.dtype, "itemsize")
2272+
and val.dtype.itemsize > itemsize
2273+
):
2274+
itemsize = val.dtype.itemsize
22732275

22742276
if val is None:
22752277
seen.null_ = True
@@ -2458,13 +2460,9 @@ def maybe_convert_objects(ndarray[object] objects, bint try_float=False,
24582460
result = ints
24592461
elif seen.is_bool and not seen.nan_:
24602462
result = bools.view(np.bool_)
2461-
2462-
if result is uints or result is ints or result is floats or result is complexes:
2463-
# cast to the largest itemsize when all values are NumPy scalars
2464-
if itemsize_max > 0 and itemsize_max != result.dtype.itemsize:
2465-
result = result.astype(result.dtype.kind + str(itemsize_max))
2466-
return result
2467-
elif result is not None:
2463+
if result is not None:
2464+
if itemsize > 0 and itemsize != result.dtype.itemsize:
2465+
result = result.astype(result.dtype.kind + str(itemsize))
24682466
return result
24692467

24702468
return objects

pandas/core/dtypes/cast.py

+39
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,45 @@ def maybe_cast_pointwise_result(
406406
return result
407407

408408

409+
def maybe_cast_result_dtype(dtype: DtypeObj, how: str) -> DtypeObj:
410+
"""
411+
Get the desired dtype of a result based on the
412+
input dtype and how it was computed.
413+
414+
Parameters
415+
----------
416+
dtype : DtypeObj
417+
Input dtype.
418+
how : str
419+
How the result was computed.
420+
421+
Returns
422+
-------
423+
DtypeObj
424+
The desired dtype of the result.
425+
"""
426+
from pandas.core.arrays.boolean import BooleanDtype
427+
from pandas.core.arrays.floating import Float64Dtype
428+
from pandas.core.arrays.integer import (
429+
Int64Dtype,
430+
_IntegerDtype,
431+
)
432+
433+
if how in ["add", "cumsum", "sum", "prod"]:
434+
if dtype == np.dtype(bool):
435+
return np.dtype(np.int64)
436+
elif isinstance(dtype, (BooleanDtype, _IntegerDtype)):
437+
return Int64Dtype()
438+
elif how in ["mean", "median", "var"]:
439+
if isinstance(dtype, (BooleanDtype, _IntegerDtype)):
440+
return Float64Dtype()
441+
elif is_float_dtype(dtype):
442+
return dtype
443+
elif is_numeric_dtype(dtype):
444+
return np.dtype(np.float64)
445+
return dtype
446+
447+
409448
def maybe_cast_to_extension_array(
410449
cls: type[ExtensionArray], obj: ArrayLike, dtype: ExtensionDtype | None = None
411450
) -> ArrayLike:

pandas/core/groupby/generic.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,6 @@
4444
doc,
4545
)
4646

47-
from pandas.core.dtypes.cast import (
48-
find_common_type,
49-
maybe_downcast_numeric,
50-
)
5147
from pandas.core.dtypes.common import (
5248
ensure_int64,
5349
is_bool,
@@ -588,8 +584,9 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
588584

589585
def _transform_general(self, func, *args, **kwargs):
590586
"""
591-
Transform with a non-str `func`.
587+
Transform with a callable `func`.
592588
"""
589+
assert callable(func)
593590
klass = type(self._selected_obj)
594591

595592
results = []
@@ -613,10 +610,6 @@ def _transform_general(self, func, *args, **kwargs):
613610
# we will only try to coerce the result type if
614611
# we have a numeric dtype, as these are *always* user-defined funcs
615612
# the cython take a different path (and casting)
616-
if is_numeric_dtype(result.dtype):
617-
common_dtype = find_common_type([self._selected_obj.dtype, result.dtype])
618-
if common_dtype is result.dtype:
619-
result = maybe_downcast_numeric(result, self._selected_obj.dtype)
620613

621614
result.name = self._selected_obj.name
622615
return result

pandas/core/groupby/groupby.py

-3
Original file line numberDiff line numberDiff line change
@@ -1241,9 +1241,6 @@ def _python_agg_general(self, func, *args, **kwargs):
12411241
assert result is not None
12421242
key = base.OutputKey(label=name, position=idx)
12431243

1244-
if is_numeric_dtype(obj.dtype):
1245-
result = maybe_downcast_numeric(result, obj.dtype)
1246-
12471244
if self.grouper._filter_empty_groups:
12481245
mask = counts.ravel() > 0
12491246

pandas/core/groupby/ops.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -290,10 +290,13 @@ def get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
290290
return np.dtype(np.int64)
291291
elif isinstance(dtype, (BooleanDtype, _IntegerDtype)):
292292
return Int64Dtype()
293-
elif how in ["mean", "median", "var"] and isinstance(
294-
dtype, (BooleanDtype, _IntegerDtype)
295-
):
296-
return Float64Dtype()
293+
elif how in ["mean", "median", "var"]:
294+
if isinstance(dtype, (BooleanDtype, _IntegerDtype)):
295+
return Float64Dtype()
296+
elif is_float_dtype(dtype):
297+
return dtype
298+
elif is_numeric_dtype(dtype):
299+
return np.dtype(np.float64)
297300
return dtype
298301

299302
def uses_mask(self) -> bool:

pandas/tests/extension/base/groupby.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping):
2525
_, index = pd.factorize(data_for_grouping, sort=True)
2626

2727
index = pd.Index(index, name="B")
28-
expected = pd.Series([3, 1, 4], index=index, name="A")
28+
expected = pd.Series([3.0, 1.0, 4.0], index=index, name="A")
2929
if as_index:
3030
self.assert_series_equal(result, expected)
3131
else:
@@ -54,7 +54,7 @@ def test_groupby_extension_no_sort(self, data_for_grouping):
5454
_, index = pd.factorize(data_for_grouping, sort=False)
5555

5656
index = pd.Index(index, name="B")
57-
expected = pd.Series([1, 3, 4], index=index, name="A")
57+
expected = pd.Series([1.0, 3.0, 4.0], index=index, name="A")
5858
self.assert_series_equal(result, expected)
5959

6060
def test_groupby_extension_transform(self, data_for_grouping):

pandas/tests/extension/test_boolean.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping):
272272
_, index = pd.factorize(data_for_grouping, sort=True)
273273

274274
index = pd.Index(index, name="B")
275-
expected = pd.Series([3, 1], index=index, name="A")
275+
expected = pd.Series([3.0, 1.0], index=index, name="A")
276276
if as_index:
277277
self.assert_series_equal(result, expected)
278278
else:
@@ -301,7 +301,7 @@ def test_groupby_extension_no_sort(self, data_for_grouping):
301301
_, index = pd.factorize(data_for_grouping, sort=False)
302302

303303
index = pd.Index(index, name="B")
304-
expected = pd.Series([1, 3], index=index, name="A")
304+
expected = pd.Series([1.0, 3.0], index=index, name="A")
305305
self.assert_series_equal(result, expected)
306306

307307
def test_groupby_extension_transform(self, data_for_grouping):

pandas/tests/groupby/aggregate/test_aggregate.py

+61-5
Original file line numberDiff line numberDiff line change
@@ -234,11 +234,10 @@ def test_aggregate_item_by_item(df):
234234
K = len(result.columns)
235235

236236
# GH5782
237-
# odd comparisons can result here, so cast to make easy
238-
exp = Series(np.array([foo] * K), index=list("BCD"), dtype=np.float64, name="foo")
237+
exp = Series(np.array([foo] * K), index=list("BCD"), name="foo")
239238
tm.assert_series_equal(result.xs("foo"), exp)
240239

241-
exp = Series(np.array([bar] * K), index=list("BCD"), dtype=np.float64, name="bar")
240+
exp = Series(np.array([bar] * K), index=list("BCD"), name="bar")
242241
tm.assert_almost_equal(result.xs("bar"), exp)
243242

244243
def aggfun(ser):
@@ -442,6 +441,57 @@ def test_bool_agg_dtype(op):
442441
assert is_integer_dtype(result)
443442

444443

444+
@pytest.mark.parametrize(
445+
"keys, agg_index",
446+
[
447+
(["a"], Index([1], name="a")),
448+
(["a", "b"], MultiIndex([[1], [2]], [[0], [0]], names=["a", "b"])),
449+
],
450+
)
451+
@pytest.mark.parametrize(
452+
"input_dtype", ["bool", "int32", "int64", "float32", "float64"]
453+
)
454+
@pytest.mark.parametrize(
455+
"result_dtype", ["bool", "int32", "int64", "float32", "float64"]
456+
)
457+
@pytest.mark.parametrize("method", ["apply", "aggregate", "transform"])
458+
def test_callable_result_dtype_frame(
459+
keys, agg_index, input_dtype, result_dtype, method
460+
):
461+
# GH 21240
462+
df = DataFrame({"a": [1], "b": [2], "c": [True]})
463+
df["c"] = df["c"].astype(input_dtype)
464+
op = getattr(df.groupby(keys)[["c"]], method)
465+
result = op(lambda x: x.astype(result_dtype).iloc[0])
466+
expected_index = pd.RangeIndex(0, 1) if method == "transform" else agg_index
467+
expected = DataFrame({"c": [df["c"].iloc[0]]}, index=expected_index).astype(
468+
result_dtype
469+
)
470+
if method == "apply":
471+
expected.columns.names = [0]
472+
tm.assert_frame_equal(result, expected)
473+
474+
475+
@pytest.mark.parametrize(
476+
"keys, agg_index",
477+
[
478+
(["a"], Index([1], name="a")),
479+
(["a", "b"], MultiIndex([[1], [2]], [[0], [0]], names=["a", "b"])),
480+
],
481+
)
482+
@pytest.mark.parametrize("input", [True, 1, 1.0])
483+
@pytest.mark.parametrize("dtype", [bool, int, float])
484+
@pytest.mark.parametrize("method", ["apply", "aggregate", "transform"])
485+
def test_callable_result_dtype_series(keys, agg_index, input, dtype, method):
486+
# GH 21240
487+
df = DataFrame({"a": [1], "b": [2], "c": [input]})
488+
op = getattr(df.groupby(keys)["c"], method)
489+
result = op(lambda x: x.astype(dtype).iloc[0])
490+
expected_index = pd.RangeIndex(0, 1) if method == "transform" else agg_index
491+
expected = Series([df["c"].iloc[0]], index=expected_index, name="c").astype(dtype)
492+
tm.assert_series_equal(result, expected)
493+
494+
445495
def test_order_aggregate_multiple_funcs():
446496
# GH 25692
447497
df = DataFrame({"A": [1, 1, 2, 2], "B": [1, 2, 3, 4]})
@@ -462,7 +512,9 @@ def test_uint64_type_handling(dtype, how):
462512
expected = df.groupby("y").agg({"x": how})
463513
df.x = df.x.astype(dtype)
464514
result = df.groupby("y").agg({"x": how})
465-
result.x = result.x.astype(np.int64)
515+
if how not in ("mean", "median"):
516+
# mean and median always result in floats
517+
result.x = result.x.astype(np.int64)
466518
tm.assert_frame_equal(result, expected, check_exact=True)
467519

468520

@@ -849,7 +901,11 @@ def test_multiindex_custom_func(func):
849901
data = [[1, 4, 2], [5, 7, 1]]
850902
df = DataFrame(data, columns=MultiIndex.from_arrays([[1, 1, 2], [3, 4, 3]]))
851903
result = df.groupby(np.array([0, 1])).agg(func)
852-
expected_dict = {(1, 3): {0: 1, 1: 5}, (1, 4): {0: 4, 1: 7}, (2, 3): {0: 2, 1: 1}}
904+
expected_dict = {
905+
(1, 3): {0: 1.0, 1: 5.0},
906+
(1, 4): {0: 4.0, 1: 7.0},
907+
(2, 3): {0: 2.0, 1: 1.0},
908+
}
853909
expected = DataFrame(expected_dict)
854910
tm.assert_frame_equal(result, expected)
855911

pandas/tests/groupby/test_categorical.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -285,8 +285,6 @@ def test_apply(ordered):
285285
result = grouped.apply(lambda x: np.mean(x))
286286
tm.assert_frame_equal(result, expected)
287287

288-
# we coerce back to ints
289-
expected = expected.astype("int")
290288
result = grouped.mean()
291289
tm.assert_frame_equal(result, expected)
292290

@@ -371,7 +369,7 @@ def test_observed(observed, using_array_manager):
371369
result = groups_double_key.agg("mean")
372370
expected = DataFrame(
373371
{
374-
"val": [10, 30, 20, 40],
372+
"val": [10.0, 30, 20, 40],
375373
"cat": Categorical(
376374
["a", "a", "b", "b"], categories=["a", "b", "c"], ordered=True
377375
),
@@ -418,7 +416,9 @@ def test_observed_codes_remap(observed):
418416
groups_double_key = df.groupby([values, "C2"], observed=observed)
419417

420418
idx = MultiIndex.from_arrays([values, [1, 2, 3, 4]], names=["cat", "C2"])
421-
expected = DataFrame({"C1": [3, 3, 4, 5], "C3": [10, 100, 200, 34]}, index=idx)
419+
expected = DataFrame(
420+
{"C1": [3.0, 3.0, 4.0, 5.0], "C3": [10.0, 100.0, 200.0, 34.0]}, index=idx
421+
)
422422
if not observed:
423423
expected = cartesian_product_for_groupers(
424424
expected, [values.values, [1, 2, 3, 4]], ["cat", "C2"]
@@ -1505,7 +1505,9 @@ def test_read_only_category_no_sort():
15051505
df = DataFrame(
15061506
{"a": [1, 3, 5, 7], "b": Categorical([1, 1, 2, 2], categories=Index(cats))}
15071507
)
1508-
expected = DataFrame(data={"a": [2, 6]}, index=CategoricalIndex([1, 2], name="b"))
1508+
expected = DataFrame(
1509+
data={"a": [2.0, 6.0]}, index=CategoricalIndex([1, 2], name="b")
1510+
)
15091511
result = df.groupby("b", sort=False).mean()
15101512
tm.assert_frame_equal(result, expected)
15111513

@@ -1597,7 +1599,7 @@ def test_aggregate_categorical_with_isnan():
15971599
index = MultiIndex.from_arrays([[1, 1], [1, 2]], names=("A", "B"))
15981600
expected = DataFrame(
15991601
data={
1600-
"numerical_col": [1.0, 0.0],
1602+
"numerical_col": [1, 0],
16011603
"object_col": [0, 0],
16021604
"categorical_col": [0, 0],
16031605
},

pandas/tests/groupby/test_function.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,8 @@ def test_median_empty_bins(observed):
408408

409409
result = df.groupby(bins, observed=observed).median()
410410
expected = df.groupby(bins, observed=observed).agg(lambda x: x.median())
411-
tm.assert_frame_equal(result, expected)
411+
# TODO: GH 41137
412+
tm.assert_frame_equal(result, expected, check_dtype=False)
412413

413414

414415
@pytest.mark.parametrize(
@@ -588,7 +589,7 @@ def test_ops_general(op, targop):
588589
df = DataFrame(np.random.randn(1000))
589590
labels = np.random.randint(0, 50, size=1000).astype(float)
590591

591-
result = getattr(df.groupby(labels), op)().astype(float)
592+
result = getattr(df.groupby(labels), op)()
592593
expected = df.groupby(labels).agg(targop)
593594
tm.assert_frame_equal(result, expected)
594595

0 commit comments

Comments
 (0)