Skip to content

Commit e5cc1e5

Browse files
committed
Allow top- and mid-level assignment to DataFrames with MultiIndex columns
1 parent e659627 commit e5cc1e5

File tree

4 files changed

+130
-5
lines changed

4 files changed

+130
-5
lines changed

pandas/core/frame.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3047,10 +3047,15 @@ def __setitem__(self, key, value):
30473047
# to a slice for partial-string date indexing
30483048
return self._setitem_slice(indexer, value)
30493049

3050+
# mimic getitem behavior
3051+
is_single_key = isinstance(key, tuple) or not is_list_like(key)
3052+
30503053
if isinstance(key, DataFrame) or getattr(key, "ndim", None) == 2:
30513054
self._setitem_frame(key, value)
30523055
elif isinstance(key, (Series, np.ndarray, list, Index)):
30533056
self._setitem_array(key, value)
3057+
elif is_single_key and self.columns.nlevels > 1:
3058+
return self._setitem_multilevel(key, value)
30543059
else:
30553060
# set column
30563061
self._set_item(key, value)
@@ -3075,10 +3080,18 @@ def _setitem_array(self, key, value):
30753080
self.iloc._setitem_with_indexer(indexer, value)
30763081
else:
30773082
if isinstance(value, DataFrame):
3078-
if len(value.columns) != len(key):
3079-
raise ValueError("Columns must be same length as key")
3080-
for k1, k2 in zip(key, value.columns):
3081-
self[k1] = value[k2]
3083+
columns = value.columns
3084+
if len(columns) == len(key):
3085+
for k1, k2 in zip(key, columns):
3086+
self[k1] = value[k2]
3087+
elif columns.nlevels > 1 and len(columns.levels[0]) == len(key):
3088+
for k1, k2 in zip(key, columns.levels[0]):
3089+
self[k1] = value[k2]
3090+
else:
3091+
raise ValueError(
3092+
"Key must be same length as columns or top level of "
3093+
"MultiIndex"
3094+
)
30823095
else:
30833096
self.loc._ensure_listlike_indexer(key, axis=1)
30843097
indexer = self.loc._get_listlike_indexer(
@@ -3104,6 +3117,30 @@ def _setitem_frame(self, key, value):
31043117
self._check_setitem_copy()
31053118
self._where(-key, value, inplace=True)
31063119

3120+
def _setitem_multilevel(self, key, value):
    """
    Set a column, or a group of columns, when ``self.columns`` is a MultiIndex.

    Parameters
    ----------
    key : scalar or tuple
        A single, possibly partial, column key. A scalar is treated as a
        one-element tuple.
    value : object
        Data to store. A DataFrame is attached underneath ``key``, with its
        own column labels supplying the remaining levels. Anything else is
        stored as a single column whose missing trailing levels are padded
        with empty strings.

    Raises
    ------
    TypeError
        If ``value`` is a DataFrame and ``len(key)`` plus the number of its
        column levels differs from the number of column levels of ``self``.
    ValueError
        If ``value`` is not a DataFrame and ``key`` has more entries than
        ``self.columns`` has levels.
    """
    # self.columns is a MultiIndex
    if key in self.columns:
        # The key already resolves against the existing columns, so the
        # regular single-item setter handles it.
        self._set_item(key, value)
        return

    nlevels = self.columns.nlevels
    if not isinstance(key, tuple):
        key = (key,)

    if isinstance(value, DataFrame):
        if len(key) + value.columns.nlevels != nlevels:
            raise TypeError(
                "Must pass key/value pair that conforms with number of column "
                "levels"
            )
        if value.columns.nlevels > 1:
            # Labels of a multi-level value are already tuples.
            items = MultiIndex.from_tuples([key + i for i in value.columns])
        else:
            items = MultiIndex.from_tuples([key + (i,) for i in value.columns])
    else:
        if len(key) > nlevels:
            # A longer key would build a label with more levels than the
            # existing columns; reject it before any state is modified.
            raise ValueError(
                f"Key length ({len(key)}) exceeds number of column levels "
                f"({nlevels})"
            )
        if len(key) < nlevels:
            # Pad the unspecified trailing levels with empty strings.
            key = key + ("",) * (nlevels - len(key))
        items = MultiIndex.from_tuples([key])

    value = self._sanitize_column(key, value)
    self._mgr.append_block(items, value)
3143+
31073144
def _iset_item(self, loc: int, value):
31083145
self._ensure_valid_index(value)
31093146

pandas/core/internals/managers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,6 +1175,26 @@ def insert(self, loc: int, item: Label, value, allow_duplicates: bool = False):
11751175
if len(self.blocks) > 100:
11761176
self._consolidate_inplace()
11771177

1178+
def append_block(self, items, values):
    """
    Append one new block holding ``values`` after the current items.

    Parameters
    ----------
    items : index-like
        Labels for the new items; passed to ``self.items.append``.
    values : object
        Data handed to ``make_block`` for the new block.
        NOTE(review): presumably holds one row per entry of ``items`` --
        confirm against callers.
    """
    n_existing = len(self.items)
    n_new = len(items)

    # Build the extended axis and the new block first; manager state is
    # only modified afterwards.
    extended_items = self.items.append(items)
    new_block = make_block(
        values=values,
        ndim=self.ndim,
        placement=slice(n_existing, n_existing + n_new),
    )

    # The new block gets the next block number, and its items occupy
    # positions 0..n_new-1 inside that block.
    new_blk_no = len(self.blocks)
    self._blklocs = np.append(self.blklocs, range(n_new))
    self._blknos = np.append(self.blknos, n_new * (new_blk_no,))

    self.axes[0] = extended_items
    self.blocks += (new_block,)

    self._known_consolidated = False

    # Same block-count threshold as the check at the end of ``insert``.
    if len(self.blocks) > 100:
        self._consolidate_inplace()
1197+
11781198
def reindex_axis(
11791199
self,
11801200
new_index,

pandas/tests/frame/indexing/test_indexing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def test_setitem_list(self, float_frame):
157157
tm.assert_series_equal(float_frame["B"], data["A"], check_names=False)
158158
tm.assert_series_equal(float_frame["A"], data["B"], check_names=False)
159159

160-
msg = "Columns must be same length as key"
160+
msg = "Key must be same length as columns or top level of MultiIndex"
161161
with pytest.raises(ValueError, match=msg):
162162
data[["A"]] = float_frame[["A", "B"]]
163163
newcolumndata = range(len(data.index) - 1)

pandas/tests/indexing/multiindex/test_multiindex.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,71 @@ def test_multiindex_get_loc_list_raises(self):
9191
msg = "unhashable type"
9292
with pytest.raises(TypeError, match=msg):
9393
idx.get_loc([])
94+
95+
def test_multiindex_frame_assign(self):
    """Assign new columns at each level of a frame with three column levels."""

    def nest(label, obj):
        # Place ``obj`` one column level deeper, underneath ``label``.
        return pd.concat({label: obj}, axis=1)

    leaf = pd.DataFrame({"a": [0, 1, 2, 3], "b": [3, 4, 5, 6]})
    mid = pd.concat({"x": leaf, "y": leaf}, axis=1)
    full = pd.concat({"q": mid, "r": mid}, axis=1)
    # Top-level groups present in every expected frame below.
    kept = {"q": mid, "r": mid}

    # level one assign
    result = full.copy()
    result["m"] = result["q"] + result["r"]
    expected = pd.concat({**kept, "m": 2 * mid}, axis=1)
    tm.assert_frame_equal(result, expected)

    # level one assign - multiple
    result = full.copy()
    result[["m", "n"]] = 2 * result[["q", "r"]]
    expected = pd.concat({**kept, "m": 2 * mid, "n": 2 * mid}, axis=1)
    tm.assert_frame_equal(result, expected)

    # level two assign
    result = full.copy()
    result["m", "x"] = full["q", "x"] + full["q", "y"]
    expected = pd.concat({**kept, "m": nest("x", 2 * leaf)}, axis=1)
    tm.assert_frame_equal(result, expected)

    # level two assign - multiple
    # (original author's note: getitem seems not to be caught up here)
    result = full.copy()
    result[[("m", "x"), ("n", "y")]] = 2 * full["q"]
    expected = pd.concat(
        {**kept, "m": nest("x", 2 * leaf), "n": nest("y", 2 * leaf)},
        axis=1,
    )
    tm.assert_frame_equal(result, expected)

    # level three assign
    result = full.copy()
    result["m", "x", "a"] = full["q", "x", "a"] + full["q", "x", "b"]
    expected = pd.concat(
        {**kept, "m": nest("x", nest("a", leaf["a"] + leaf["b"]))},
        axis=1,
    )
    tm.assert_frame_equal(result, expected)

    # level three assign - multiple
    result = full.copy()
    result[[("m", "x", "a"), ("n", "y", "b")]] = 2 * full["q", "x"]
    expected = pd.concat(
        {
            **kept,
            "m": nest("x", nest("a", 2 * leaf["a"])),
            "n": nest("y", nest("b", 2 * leaf["b"])),
        },
        axis=1,
    )
    tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)