Skip to content

Commit 913cfab

Browse files
authored
BUG: DataFrameGroupBy.__getitem__ with non-unique columns (#41427)
1 parent 7a90824 commit 913cfab

File tree

4 files changed

+43
-8
lines changed

4 files changed

+43
-8
lines changed

doc/source/whatsnew/v1.3.0.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,8 @@ Groupby/resample/rolling
895895
- Bug in :meth:`SeriesGroupBy.agg` failing to retain ordered :class:`CategoricalDtype` on order-preserving aggregations (:issue:`41147`)
896896
- Bug in :meth:`DataFrameGroupBy.min` and :meth:`DataFrameGroupBy.max` with multiple object-dtype columns and ``numeric_only=False`` incorrectly raising ``ValueError`` (:issue:41111`)
897897
- Bug in :meth:`DataFrameGroupBy.rank` with the GroupBy object's ``axis=0`` and the ``rank`` method's keyword ``axis=1`` (:issue:`41320`)
898+
- Bug in :meth:`DataFrameGroupBy.__getitem__` with non-unique columns incorrectly returning a malformed :class:`SeriesGroupBy` instead of :class:`DataFrameGroupBy` (:issue:`41427`)
899+
- Bug in :meth:`DataFrameGroupBy.transform` with non-unique columns incorrectly raising ``AttributeError`` (:issue:`41427`)
898900

899901
Reshaping
900902
^^^^^^^^^

pandas/core/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def ndim(self) -> int:
214214
@cache_readonly
215215
def _obj_with_exclusions(self):
216216
if self._selection is not None and isinstance(self.obj, ABCDataFrame):
217-
return self.obj.reindex(columns=self._selection_list)
217+
return self.obj[self._selection_list]
218218

219219
if len(self.exclusions) > 0:
220220
return self.obj.drop(self.exclusions, axis=1)
@@ -239,7 +239,9 @@ def __getitem__(self, key):
239239
else:
240240
if key not in self.obj:
241241
raise KeyError(f"Column not found: {key}")
242-
return self._gotitem(key, ndim=1)
242+
subset = self.obj[key]
243+
ndim = subset.ndim
244+
return self._gotitem(key, ndim=ndim, subset=subset)
243245

244246
def _gotitem(self, key, ndim: int, subset=None):
245247
"""

pandas/core/groupby/generic.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,12 +1417,19 @@ def _choose_path(self, fast_path: Callable, slow_path: Callable, group: DataFram
14171417
return path, res
14181418

14191419
def _transform_item_by_item(self, obj: DataFrame, wrapper) -> DataFrame:
1420-
# iterate through columns
1420+
# iterate through columns, see test_transform_exclude_nuisance
14211421
output = {}
14221422
inds = []
14231423
for i, col in enumerate(obj):
1424+
subset = obj.iloc[:, i]
1425+
sgb = SeriesGroupBy(
1426+
subset,
1427+
selection=col,
1428+
grouper=self.grouper,
1429+
exclusions=self.exclusions,
1430+
)
14241431
try:
1425-
output[col] = self[col].transform(wrapper)
1432+
output[i] = sgb.transform(wrapper)
14261433
except TypeError:
14271434
# e.g. trying to call nanmean with string values
14281435
pass
@@ -1434,7 +1441,9 @@ def _transform_item_by_item(self, obj: DataFrame, wrapper) -> DataFrame:
14341441

14351442
columns = obj.columns.take(inds)
14361443

1437-
return self.obj._constructor(output, index=obj.index, columns=columns)
1444+
result = self.obj._constructor(output, index=obj.index)
1445+
result.columns = columns
1446+
return result
14381447

14391448
def filter(self, func, dropna=True, *args, **kwargs):
14401449
"""
@@ -1504,7 +1513,7 @@ def filter(self, func, dropna=True, *args, **kwargs):
15041513

15051514
return self._apply_filter(indices, dropna)
15061515

1507-
def __getitem__(self, key):
1516+
def __getitem__(self, key) -> DataFrameGroupBy | SeriesGroupBy:
15081517
if self.axis == 1:
15091518
# GH 37725
15101519
raise ValueError("Cannot subset columns when using axis=1")

pandas/tests/groupby/transform/test_transform.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
date_range,
2121
)
2222
import pandas._testing as tm
23+
from pandas.core.groupby.generic import (
24+
DataFrameGroupBy,
25+
SeriesGroupBy,
26+
)
2327
from pandas.core.groupby.groupby import DataError
2428

2529

@@ -391,13 +395,31 @@ def test_transform_select_columns(df):
391395
tm.assert_frame_equal(result, expected)
392396

393397

394-
def test_transform_exclude_nuisance(df):
398+
@pytest.mark.parametrize("duplicates", [True, False])
399+
def test_transform_exclude_nuisance(df, duplicates):
400+
# case that goes through _transform_item_by_item
401+
402+
if duplicates:
403+
# make sure we work with duplicate columns GH#41427
404+
df.columns = ["A", "C", "C", "D"]
395405

396406
# this also tests orderings in transform between
397407
# series/frame to make sure it's consistent
398408
expected = {}
399409
grouped = df.groupby("A")
400-
expected["C"] = grouped["C"].transform(np.mean)
410+
411+
gbc = grouped["C"]
412+
expected["C"] = gbc.transform(np.mean)
413+
if duplicates:
414+
# squeeze 1-column DataFrame down to Series
415+
expected["C"] = expected["C"]["C"]
416+
417+
assert isinstance(gbc.obj, DataFrame)
418+
assert isinstance(gbc, DataFrameGroupBy)
419+
else:
420+
assert isinstance(gbc, SeriesGroupBy)
421+
assert isinstance(gbc.obj, Series)
422+
401423
expected["D"] = grouped["D"].transform(np.mean)
402424
expected = DataFrame(expected)
403425
result = df.groupby("A").transform(np.mean)

0 commit comments

Comments
 (0)