Skip to content

Commit 381b073

Browse files
committed
Preserve EA dtype in DataFrame.stack
1 parent 4cac923 commit 381b073

File tree

4 files changed

+29
-1
lines changed

4 files changed

+29
-1
lines changed

doc/source/whatsnew/v0.24.0.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,8 @@ update the ``ExtensionDtype._metadata`` tuple to match the signature of your
724724
- Updated the ``.type`` attribute for ``PeriodDtype``, ``DatetimeTZDtype``, and ``IntervalDtype`` to be instances of the dtype (``Period``, ``Timestamp``, and ``Interval`` respectively) (:issue:`22938`)
725725
- :func:`ExtensionArray.isna` is allowed to return an ``ExtensionArray`` (:issue:`22325`).
726726
- Support for reduction operations such as ``sum``, ``mean`` via opt-in base class method override (:issue:`22762`)
727+
- :meth:`DataFrame.stack` no longer converts to object dtype for DataFrames where each column has the same extension dtype. The output Series will have the same dtype as the columns (:issue:`23077`).
728+
727729

728730
.. _whatsnew_0240.api.incompatibilities:
729731

pandas/core/reshape/reshape.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,8 +470,15 @@ def factorize(index):
470470
if is_extension_array_dtype(dtype):
471471
arr = dtype.construct_array_type()
472472
new_values = arr._concat_same_type([
473-
col for _, col in frame.iteritems()
473+
col._values for _, col in frame.iteritems()
474474
])
475+
# final take to get the order correct.
476+
# idx is an indexer like
477+
# [c0r0, c1r0, c2r0, ...,
478+
# c0r1, c1r1, c241, ...]
479+
idx = np.arange(N * K).reshape(K, N).T.ravel()
480+
new_values = new_values.take(idx)
481+
475482
else:
476483
# homogeneous, non-EA
477484
new_values = frame.values.ravel()

pandas/tests/extension/base/reshaping.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,11 @@ def test_merge(self, data, na_value):
170170
[data[0], data[0], data[1], data[2], na_value],
171171
dtype=data.dtype)})
172172
self.assert_frame_equal(res, exp[['ext', 'int1', 'key', 'int2']])
173+
174+
def test_stack(self, data):
175+
df = pd.DataFrame({"A": data[:5], "B": data[:5]})
176+
result = df.stack()
177+
assert result.dtype == df.A.dtype
178+
result = result.astype(object)
179+
expected = df.astype(object).stack()
180+
self.assert_series_equal(result, expected)

pandas/tests/frame/test_reshape.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,17 @@ def test_stack_preserve_categorical_dtype(self, ordered, labels):
872872

873873
tm.assert_series_equal(result, expected)
874874

875+
def test_stack_preserve_categorical_dtype_values(self):
876+
# GH-23077
877+
cat = pd.Categorical(['a', 'a', 'b', 'c'])
878+
df = pd.DataFrame({"A": cat, "B": cat})
879+
result = df.stack()
880+
index = pd.MultiIndex.from_product([[0, 1, 2, 3], ['A', 'B']])
881+
expected = pd.Series(pd.Categorical(['a', 'a', 'a', 'a',
882+
'b', 'b', 'c', 'c']),
883+
index=index)
884+
tm.assert_series_equal(result, expected)
885+
875886
@pytest.mark.parametrize("level", [0, 'baz'])
876887
def test_unstack_swaplevel_sortlevel(self, level):
877888
# GH 20994

0 commit comments

Comments
 (0)