Skip to content

Commit c0ef2f6

Browse files
authored
Fix set_index when an existing dimension becomes a level (pydata#3520)
* Added a test * Fix set_index * lint * black / mypy * Use _replace method * whats new
1 parent 8b24037 commit c0ef2f6

File tree

4 files changed

+27
-7
lines changed

4 files changed

+27
-7
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ New Features
7979

8080
Bug fixes
8181
~~~~~~~~~
82+
- Fix a bug in `set_index` in case that an existing dimension becomes a level variable of MultiIndex. (:pull:`3520`)
83+
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
8284
- Harmonize `_FillValue`, `missing_value` during encoding and decoding steps. (:pull:`3502`)
8385
By `Anderson Banihirwe <https://github.com/andersy005>`_.
8486
- Fix regression introduced in v0.14.0 that would cause a crash if dask is installed

xarray/core/dataarray.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
assert_coordinate_consistent,
4949
remap_label_indexers,
5050
)
51-
from .dataset import Dataset, merge_indexes, split_indexes
51+
from .dataset import Dataset, split_indexes
5252
from .formatting import format_item
5353
from .indexes import Indexes, default_indexes
5454
from .merge import PANDAS_TYPES
@@ -1601,10 +1601,10 @@ def set_index(
16011601
--------
16021602
DataArray.reset_index
16031603
"""
1604-
_check_inplace(inplace)
1605-
indexes = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index")
1606-
coords, _ = merge_indexes(indexes, self._coords, set(), append=append)
1607-
return self._replace(coords=coords)
1604+
ds = self._to_temp_dataset().set_index(
1605+
indexes, append=append, inplace=inplace, **indexes_kwargs
1606+
)
1607+
return self._from_temp_dataset(ds)
16081608

16091609
def reset_index(
16101610
self,

xarray/core/dataset.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ def merge_indexes(
204204
"""
205205
vars_to_replace: Dict[Hashable, Variable] = {}
206206
vars_to_remove: List[Hashable] = []
207+
dims_to_replace: Dict[Hashable, Hashable] = {}
207208
error_msg = "{} is not the name of an existing variable."
208209

209210
for dim, var_names in indexes.items():
@@ -244,7 +245,7 @@ def merge_indexes(
244245
if not len(names) and len(var_names) == 1:
245246
idx = pd.Index(variables[var_names[0]].values)
246247

247-
else:
248+
else: # MultiIndex
248249
for n in var_names:
249250
try:
250251
var = variables[n]
@@ -256,15 +257,22 @@ def merge_indexes(
256257
levels.append(cat.categories)
257258

258259
idx = pd.MultiIndex(levels, codes, names=names)
260+
for n in names:
261+
dims_to_replace[n] = dim
259262

260263
vars_to_replace[dim] = IndexVariable(dim, idx)
261264
vars_to_remove.extend(var_names)
262265

263266
new_variables = {k: v for k, v in variables.items() if k not in vars_to_remove}
264267
new_variables.update(vars_to_replace)
268+
269+
# update dimensions if necessary GH: 3512
270+
for k, v in new_variables.items():
271+
if any(d in dims_to_replace for d in v.dims):
272+
new_dims = [dims_to_replace.get(d, d) for d in v.dims]
273+
new_variables[k] = v._replace(dims=new_dims)
265274
new_coord_names = coord_names | set(vars_to_replace)
266275
new_coord_names -= set(vars_to_remove)
267-
268276
return new_variables, new_coord_names
269277

270278

xarray/tests/test_dataarray.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,6 +1182,16 @@ def test_selection_multiindex_remove_unused(self):
11821182
expected = expected.set_index(xy=["x", "y"]).unstack()
11831183
assert_identical(expected, actual)
11841184

1185+
def test_selection_multiindex_from_level(self):
1186+
# GH: 3512
1187+
da = DataArray([0, 1], dims=["x"], coords={"x": [0, 1], "y": "a"})
1188+
db = DataArray([2, 3], dims=["x"], coords={"x": [0, 1], "y": "b"})
1189+
data = xr.concat([da, db], dim="x").set_index(xy=["x", "y"])
1190+
assert data.dims == ("xy",)
1191+
actual = data.sel(y="a")
1192+
expected = data.isel(xy=[0, 1]).unstack("xy").squeeze("y").drop("y")
1193+
assert_equal(actual, expected)
1194+
11851195
def test_virtual_default_coords(self):
11861196
array = DataArray(np.zeros((5,)), dims="x")
11871197
expected = DataArray(range(5), dims="x", name="x")

0 commit comments

Comments
 (0)