From ad71e6cb21bb5bd2c7da2a6d4c020100d12289b4 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Wed, 13 Nov 2019 23:45:57 +0900 Subject: [PATCH 1/6] Added a test --- xarray/tests/test_dataarray.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 42fae2c9dd4..1fe80085f35 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1182,6 +1182,15 @@ def test_selection_multiindex_remove_unused(self): expected = expected.set_index(xy=["x", "y"]).unstack() assert_identical(expected, actual) + def test_selection_multiindex_from_level(self): + #GH: issue 3512 + da = DataArray([0, 1], dims=['x'], coords={'x': [0, 1], 'y': 'a'}) + db = DataArray([2, 3], dims=['x'], coords={'x': [0, 1], 'y': 'b'}) + data = xr.concat([da, db], dim='x').set_index(xy=['x', 'y']) + actual = data.sel(y='a') + expected = data.isel(xy=[0, 1]).unstack('xy').drop('y') + assert_identical(actual, expected) + def test_virtual_default_coords(self): array = DataArray(np.zeros((5,)), dims="x") expected = DataArray(range(5), dims="x", name="x") From d942e575744bb2d78f2c15b4263193985a184787 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Thu, 14 Nov 2019 01:05:24 +0900 Subject: [PATCH 2/6] Fix set_index --- xarray/core/dataarray.py | 7 +++---- xarray/core/dataset.py | 13 +++++++++++-- xarray/tests/test_dataarray.py | 7 ++++--- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index a192fe08cee..2124159a531 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1601,10 +1601,9 @@ def set_index( -------- DataArray.reset_index """ - _check_inplace(inplace) - indexes = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index") - coords, _ = merge_indexes(indexes, self._coords, set(), append=append) - return self._replace(coords=coords) + ds = self._to_temp_dataset().set_index(indexes, append=append, + inplace=inplace, 
**indexes_kwargs) + return self._from_temp_dataset(ds) def reset_index( self, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index fe8abdc4b95..8f65e91f5b6 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -204,6 +204,7 @@ def merge_indexes( """ vars_to_replace: Dict[Hashable, Variable] = {} vars_to_remove: List[Hashable] = [] + dims_to_replace: Dict[Hashable, Variable] = {} error_msg = "{} is not the name of an existing variable." for dim, var_names in indexes.items(): @@ -244,7 +245,7 @@ def merge_indexes( if not len(names) and len(var_names) == 1: idx = pd.Index(variables[var_names[0]].values) - else: + else: # MultiIndex for n in var_names: try: var = variables[n] @@ -256,15 +257,23 @@ def merge_indexes( levels.append(cat.categories) idx = pd.MultiIndex(levels, codes, names=names) + for n in names: + dims_to_replace[n] = dim vars_to_replace[dim] = IndexVariable(dim, idx) vars_to_remove.extend(var_names) new_variables = {k: v for k, v in variables.items() if k not in vars_to_remove} new_variables.update(vars_to_replace) + + # update dimensions if necessary GH: 3512 + for k, v in new_variables.items(): + if any(d in dims_to_replace for d in v.dims): + new_dims = [dims_to_replace.get(d, d) for d in v.dims] + new_variables[k] = type(v)(new_dims, v.data, attrs=v.attrs, + encoding=v.encoding, fastpath=True) new_coord_names = coord_names | set(vars_to_replace) new_coord_names -= set(vars_to_remove) - return new_variables, new_coord_names diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 1fe80085f35..a4f09774126 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1183,13 +1183,14 @@ def test_selection_multiindex_remove_unused(self): assert_identical(expected, actual) def test_selection_multiindex_from_level(self): - #GH: issue 3512 + #GH: 3512 da = DataArray([0, 1], dims=['x'], coords={'x': [0, 1], 'y': 'a'}) db = DataArray([2, 3], dims=['x'], coords={'x': [0, 1], 'y': 
'b'}) data = xr.concat([da, db], dim='x').set_index(xy=['x', 'y']) + assert data.dims == ('xy', ) actual = data.sel(y='a') - expected = data.isel(xy=[0, 1]).unstack('xy').drop('y') - assert_identical(actual, expected) + expected = data.isel(xy=[0, 1]).unstack('xy').squeeze('y').drop('y') + assert_equal(actual, expected) def test_virtual_default_coords(self): array = DataArray(np.zeros((5,)), dims="x") From f9b6e6bfd46368d090559f92a1a610793709246f Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Thu, 14 Nov 2019 01:09:57 +0900 Subject: [PATCH 3/6] lint --- xarray/core/dataarray.py | 2 +- xarray/tests/test_dataarray.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 2124159a531..9b109b0f60a 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -48,7 +48,7 @@ assert_coordinate_consistent, remap_label_indexers, ) -from .dataset import Dataset, merge_indexes, split_indexes +from .dataset import Dataset, split_indexes from .formatting import format_item from .indexes import Indexes, default_indexes from .merge import PANDAS_TYPES diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index a4f09774126..56c3a01c7a0 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1183,7 +1183,7 @@ def test_selection_multiindex_remove_unused(self): assert_identical(expected, actual) def test_selection_multiindex_from_level(self): - #GH: 3512 + # GH: 3512 da = DataArray([0, 1], dims=['x'], coords={'x': [0, 1], 'y': 'a'}) db = DataArray([2, 3], dims=['x'], coords={'x': [0, 1], 'y': 'b'}) data = xr.concat([da, db], dim='x').set_index(xy=['x', 'y']) From 8a034c0f23f82f85546f97fcbe4a98c5ff60e316 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Thu, 14 Nov 2019 07:36:08 +0900 Subject: [PATCH 4/6] black / mypy --- xarray/core/dataarray.py | 5 +++-- xarray/core/dataset.py | 7 ++++--- xarray/tests/test_dataarray.py | 12 ++++++------ 3 files 
changed, 13 insertions(+), 11 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 9b109b0f60a..55e73478260 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1601,8 +1601,9 @@ def set_index( -------- DataArray.reset_index """ - ds = self._to_temp_dataset().set_index(indexes, append=append, - inplace=inplace, **indexes_kwargs) + ds = self._to_temp_dataset().set_index( + indexes, append=append, inplace=inplace, **indexes_kwargs + ) return self._from_temp_dataset(ds) def reset_index( diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 8f65e91f5b6..264b76e2315 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -204,7 +204,7 @@ def merge_indexes( """ vars_to_replace: Dict[Hashable, Variable] = {} vars_to_remove: List[Hashable] = [] - dims_to_replace: Dict[Hashable, Variable] = {} + dims_to_replace: Dict[Hashable, Hashable] = {} error_msg = "{} is not the name of an existing variable." for dim, var_names in indexes.items(): @@ -270,8 +270,9 @@ def merge_indexes( for k, v in new_variables.items(): if any(d in dims_to_replace for d in v.dims): new_dims = [dims_to_replace.get(d, d) for d in v.dims] - new_variables[k] = type(v)(new_dims, v.data, attrs=v.attrs, - encoding=v.encoding, fastpath=True) + new_variables[k] = type(v)( + new_dims, v.data, attrs=v.attrs, encoding=v.encoding, fastpath=True + ) new_coord_names = coord_names | set(vars_to_replace) new_coord_names -= set(vars_to_remove) return new_variables, new_coord_names diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 56c3a01c7a0..1288dcd54ba 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1184,12 +1184,12 @@ def test_selection_multiindex_remove_unused(self): def test_selection_multiindex_from_level(self): # GH: 3512 - da = DataArray([0, 1], dims=['x'], coords={'x': [0, 1], 'y': 'a'}) - db = DataArray([2, 3], dims=['x'], coords={'x': [0, 1], 'y': 'b'}) - data = 
xr.concat([da, db], dim='x').set_index(xy=['x', 'y']) - assert data.dims == ('xy', ) - actual = data.sel(y='a') - expected = data.isel(xy=[0, 1]).unstack('xy').squeeze('y').drop('y') + da = DataArray([0, 1], dims=["x"], coords={"x": [0, 1], "y": "a"}) + db = DataArray([2, 3], dims=["x"], coords={"x": [0, 1], "y": "b"}) + data = xr.concat([da, db], dim="x").set_index(xy=["x", "y"]) + assert data.dims == ("xy",) + actual = data.sel(y="a") + expected = data.isel(xy=[0, 1]).unstack("xy").squeeze("y").drop("y") assert_equal(actual, expected) def test_virtual_default_coords(self): From 758264685c29e6ef9f6724499fa19cc2383ec2c1 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Thu, 14 Nov 2019 17:12:49 +0900 Subject: [PATCH 5/6] Use _replace method --- xarray/core/dataset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 99e88061998..de713b830f2 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -270,9 +270,7 @@ def merge_indexes( for k, v in new_variables.items(): if any(d in dims_to_replace for d in v.dims): new_dims = [dims_to_replace.get(d, d) for d in v.dims] - new_variables[k] = type(v)( - new_dims, v.data, attrs=v.attrs, encoding=v.encoding, fastpath=True - ) + new_variables[k] = v._replace(dims=new_dims) new_coord_names = coord_names | set(vars_to_replace) new_coord_names -= set(vars_to_remove) return new_variables, new_coord_names From 18fa5ec46da318d76488ea2994e9654e9683bce9 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Thu, 14 Nov 2019 17:15:14 +0900 Subject: [PATCH 6/6] whats new --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b8fb1f8f58e..abd94779435 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -79,6 +79,8 @@ New Features Bug fixes ~~~~~~~~~ +- Fix a bug in `set_index` in case that an existing dimension becomes a level variable of MultiIndex. (:pull:`3520`) + By `Keisuke Fujii `_. 
- Harmonize `_FillValue`, `missing_value` during encoding and decoding steps. (:pull:`3502`) By `Anderson Banihirwe <https://github.com/andersy005>`_. - Fix regression introduced in v0.14.0 that would cause a crash if dask is installed