Skip to content

Commit a891e22

Browse files
committed
wip refactor set_index
- Refactor reset_index - Improve creating new xarray indexes from pandas indexes with proper propagation of variable metadata (dtype, attrs, encoding)
1 parent 0086e32 commit a891e22

File tree

3 files changed

+142
-104
lines changed

3 files changed

+142
-104
lines changed

xarray/core/dataarray.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from .common import AbstractArray, DataWithCoords
4747
from .computation import unify_chunks
4848
from .coordinates import DataArrayCoordinates, assert_coordinate_consistent
49-
from .dataset import Dataset, split_indexes
49+
from .dataset import Dataset
5050
from .formatting import format_item
5151
from .indexes import Index, Indexes, default_indexes, propagate_indexes
5252
from .indexing import is_fancy_indexer, map_index_queries
@@ -2016,10 +2016,8 @@ def reset_index(
20162016
--------
20172017
DataArray.set_index
20182018
"""
2019-
coords, _ = split_indexes(
2020-
dims_or_levels, self._coords, set(), self._level_coords, drop=drop
2021-
)
2022-
return self._replace(coords=coords)
2019+
ds = self._to_temp_dataset().reset_index(dims_or_levels, drop=drop)
2020+
return self._from_temp_dataset(ds)
20232021

20242022
def reorder_levels(
20252023
self,

xarray/core/dataset.py

Lines changed: 55 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
Any,
1515
Callable,
1616
Collection,
17-
DefaultDict,
1817
Dict,
1918
Hashable,
2019
Iterable,
@@ -194,64 +193,6 @@ def calculate_dimensions(variables: Mapping[Any, Variable]) -> Dict[Hashable, in
194193
return dims
195194

196195

197-
def split_indexes(
198-
dims_or_levels: Union[Hashable, Sequence[Hashable]],
199-
variables: Mapping[Any, Variable],
200-
coord_names: Set[Hashable],
201-
level_coords: Mapping[Any, Hashable],
202-
drop: bool = False,
203-
) -> Tuple[Dict[Hashable, Variable], Set[Hashable]]:
204-
"""Extract (multi-)indexes (levels) as variables.
205-
206-
Not public API. Used in Dataset and DataArray reset_index
207-
methods.
208-
"""
209-
if isinstance(dims_or_levels, str) or not isinstance(dims_or_levels, Sequence):
210-
dims_or_levels = [dims_or_levels]
211-
212-
dim_levels: DefaultDict[Any, List[Hashable]] = defaultdict(list)
213-
dims = []
214-
for k in dims_or_levels:
215-
if k in level_coords:
216-
dim_levels[level_coords[k]].append(k)
217-
else:
218-
dims.append(k)
219-
220-
vars_to_replace = {}
221-
vars_to_create: Dict[Hashable, Variable] = {}
222-
vars_to_remove = []
223-
224-
for d in dims:
225-
index = variables[d].to_index()
226-
if isinstance(index, pd.MultiIndex):
227-
dim_levels[d] = index.names
228-
else:
229-
vars_to_remove.append(d)
230-
if not drop:
231-
vars_to_create[str(d) + "_"] = Variable(d, index, variables[d].attrs)
232-
233-
for d, levs in dim_levels.items():
234-
index = variables[d].to_index()
235-
if len(levs) == index.nlevels:
236-
vars_to_remove.append(d)
237-
else:
238-
vars_to_replace[d] = IndexVariable(d, index.droplevel(levs))
239-
240-
if not drop:
241-
for lev in levs:
242-
idx = index.get_level_values(lev)
243-
vars_to_create[idx.name] = Variable(d, idx, variables[d].attrs)
244-
245-
new_variables = dict(variables)
246-
for v in set(vars_to_remove):
247-
del new_variables[v]
248-
new_variables.update(vars_to_replace)
249-
new_variables.update(vars_to_create)
250-
new_coord_names = (coord_names | set(vars_to_create)) - set(vars_to_remove)
251-
252-
return new_variables, new_coord_names
253-
254-
255196
def _assert_empty(args: tuple, msg: str = "%s") -> None:
256197
if args:
257198
raise ValueError(msg % args)
@@ -3777,14 +3718,61 @@ def reset_index(
37773718
--------
37783719
Dataset.set_index
37793720
"""
3780-
variables, coord_names = split_indexes(
3781-
dims_or_levels,
3782-
self._variables,
3783-
self._coord_names,
3784-
cast(Mapping[Hashable, Hashable], self._level_coords),
3785-
drop=drop,
3786-
)
3787-
return self._replace_vars_and_dims(variables, coord_names=coord_names)
3721+
if isinstance(dims_or_levels, str) or not isinstance(dims_or_levels, Sequence):
3722+
dims_or_levels = [dims_or_levels]
3723+
3724+
invalid_coords = set(dims_or_levels) - set(self.xindexes)
3725+
if invalid_coords:
3726+
raise ValueError(
3727+
f"{tuple(invalid_coords)} are not coordinates with an index"
3728+
)
3729+
3730+
drop_indexes: List[Hashable] = []
3731+
drop_variables: List[Hashable] = []
3732+
replaced_indexes: List[PandasMultiIndex] = []
3733+
new_indexes: Dict[Hashable, Index] = {}
3734+
new_variables: Dict[Hashable, IndexVariable] = {}
3735+
3736+
index_coord_names = {
3737+
k: coord_names
3738+
for _, coord_names in group_coords_by_index(self.xindexes)
3739+
for k in coord_names
3740+
}
3741+
3742+
for name in dims_or_levels:
3743+
index = self.xindexes[name]
3744+
drop_indexes += [k for k in index_coord_names[name]]
3745+
3746+
if isinstance(index, PandasMultiIndex) and name not in self.dims:
3747+
# special case for pd.MultiIndex (name is an index level):
3748+
# replace by a new index with dropped level(s) instead of just drop the index
3749+
# TODO: eventually extend Index API to allow this for custom multi-indexes?
3750+
if index not in replaced_indexes:
3751+
level_names = index.index.names
3752+
level_vars = {
3753+
k: self._variables[k]
3754+
for k in level_names
3755+
if k not in dims_or_levels
3756+
}
3757+
idx, idx_vars = index.keep_levels(level_vars)
3758+
new_indexes.update({k: idx for k in idx_vars})
3759+
new_variables.update(idx_vars)
3760+
replaced_indexes.append(index)
3761+
3762+
if drop:
3763+
drop_variables.append(name)
3764+
3765+
indexes = {k: v for k, v in self.xindexes.items() if k not in drop_indexes}
3766+
indexes.update(new_indexes)
3767+
3768+
variables = {
3769+
k: v for k, v in self._variables.items() if k not in drop_variables
3770+
}
3771+
variables.update(new_variables)
3772+
3773+
coord_names = set(new_variables) | self._coord_names
3774+
3775+
return self._replace(variables, coord_names=coord_names, indexes=indexes)
37883776

37893777
def reorder_levels(
37903778
self,

xarray/core/indexes.py

Lines changed: 84 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Optional,
1212
Tuple,
1313
Union,
14+
cast,
1415
)
1516

1617
import numpy as np
@@ -212,7 +213,10 @@ def from_variables(
212213

213214
@classmethod
214215
def from_pandas_index(
215-
cls, index: pd.Index, dim: Hashable
216+
cls,
217+
index: pd.Index,
218+
dim: Hashable,
219+
var_meta: Optional[Dict[Any, Dict]] = None,
216220
) -> Tuple["PandasIndex", IndexVars]:
217221
from .variable import IndexVariable
218222

@@ -223,10 +227,21 @@ def from_pandas_index(
223227
else:
224228
name = index.name
225229

226-
data = PandasIndexingAdapter(index)
227-
index_var = IndexVariable(dim, data, fastpath=True)
230+
if var_meta is None:
231+
var_meta = {name: {}}
232+
233+
data = PandasIndexingAdapter(index, dtype=var_meta[name].get("dtype"))
234+
index_var = IndexVariable(
235+
dim,
236+
data,
237+
fastpath=True,
238+
attrs=var_meta[name].get("attrs"),
239+
encoding=var_meta[name].get("encoding"),
240+
)
228241

229-
return cls(index, dim), {name: index_var}
242+
return cls(index, dim, coord_dtype=var_meta[name].get("dtype")), {
243+
name: index_var
244+
}
230245

231246
def to_pandas_index(self) -> pd.Index:
232247
return self.index
@@ -297,13 +312,11 @@ def rename(self, name_dict, dims_dict):
297312
return self, {}
298313

299314
new_name = name_dict.get(self.index.name, self.index.name)
300-
pd_idx = self.index.rename(new_name)
315+
index = self.index.rename(new_name)
301316
new_dim = dims_dict.get(self.dim, self.dim)
317+
var_meta = {new_name: {"dtype": self.coord_dtype}}
302318

303-
index, index_vars = self.from_pandas_index(pd_idx, dim=new_dim)
304-
index.coord_dtype = self.coord_dtype
305-
306-
return index, index_vars
319+
return self.from_pandas_index(index, dim=new_dim, var_meta=var_meta)
307320

308321
def copy(self, deep=True):
309322
return self._replace(self.index.copy(deep=deep))
@@ -411,13 +424,12 @@ def from_variables_maybe_expand(
411424
"""Create a new multi-index maybe by expanding an existing one with
412425
new variables as index levels.
413426
414-
the index might be created along a new dimension.
427+
The index and its corresponding coordinates may be created along a new dimension.
415428
"""
416429
names: List[Hashable] = []
417430
codes: List[List[int]] = []
418431
levels: List[List[int]] = []
419432
var_meta: Dict[str, Dict] = {}
420-
level_coords_dtype: Dict[Hashable, Any] = {}
421433

422434
_check_dim_compat({**current_variables, **variables})
423435

@@ -427,20 +439,21 @@ def add_level_var(name, var):
427439
"attrs": var.attrs,
428440
"encoding": var.encoding,
429441
}
430-
level_coords_dtype[name] = var.dtype
431442

432443
if len(current_variables) > 1:
433-
current_index: pd.MultiIndex = next(
434-
iter(current_variables.values())
435-
)._data.array
444+
# expand from an existing multi-index
445+
data = cast(
446+
PandasMultiIndexingAdapter, next(iter(current_variables.values()))._data
447+
)
448+
current_index = data.array
436449
names.extend(current_index.names)
437450
codes.extend(current_index.codes)
438451
levels.extend(current_index.levels)
439452
for name in current_index.names:
440453
add_level_var(name, current_variables[name])
441454

442455
elif len(current_variables) == 1:
443-
# one 1D variable (no multi-index): convert it to an index level
456+
# expand from one 1D variable (no multi-index): convert it to an index level
444457
var = next(iter(current_variables.values()))
445458
new_var_name = f"{dim}_level_0"
446459
names.append(new_var_name)
@@ -457,27 +470,63 @@ def add_level_var(name, var):
457470
add_level_var(name, var)
458471

459472
index = pd.MultiIndex(levels, codes, names=names)
460-
obj = cls(index, dim, level_coords_dtype=level_coords_dtype)
461-
index_vars = _create_variables_from_multiindex(index, dim, var_meta=var_meta)
462473

463-
return obj, index_vars
474+
return cls.from_pandas_index(index, dim, var_meta=var_meta)
475+
476+
def keep_levels(
477+
self, level_variables: Mapping[Any, "Variable"]
478+
) -> Tuple[Union["PandasMultiIndex", PandasIndex], IndexVars]:
479+
"""Keep only the provided levels and return a new multi-index with its
480+
corresponding coordinates.
481+
482+
"""
483+
var_meta: Dict[str, Dict] = {}
484+
485+
for name, var in level_variables.items():
486+
var_meta[name] = {
487+
"dtype": var.dtype,
488+
"attrs": var.attrs,
489+
"encoding": var.encoding,
490+
}
491+
492+
index = self.index.droplevel(
493+
[k for k in self.index.names if k not in level_variables]
494+
)
495+
496+
if isinstance(index, pd.MultiIndex):
497+
return self.from_pandas_index(index, self.dim, var_meta=var_meta)
498+
else:
499+
return PandasIndex.from_pandas_index(index, self.dim, var_meta=var_meta)
464500

465501
@classmethod
466502
def from_pandas_index(
467-
cls, index: pd.MultiIndex, dim: Hashable
503+
cls,
504+
index: pd.MultiIndex,
505+
dim: Hashable,
506+
var_meta: Optional[Dict[Any, Dict]] = None,
468507
) -> Tuple["PandasMultiIndex", IndexVars]:
469-
var_meta = {}
508+
509+
names = []
510+
idx_dtypes = {}
470511
for i, idx in enumerate(index.levels):
471512
name = idx.name or f"{dim}_level_{i}"
472513
if name == dim:
473514
raise ValueError(
474515
f"conflicting multi-index level name {name!r} with dimension {dim!r}"
475516
)
476-
var_meta[name] = {"dtype": idx.dtype}
517+
names.append(name)
518+
idx_dtypes[name] = idx.dtype
519+
520+
if var_meta is None:
521+
var_meta = {k: {} for k in names}
522+
for name, dtype in idx_dtypes.items():
523+
var_meta[name]["dtype"] = var_meta[name].get("dtype", dtype)
524+
525+
level_coords_dtype = {k: var_meta[k]["dtype"] for k in names}
477526

478-
index = index.rename(var_meta.keys())
527+
index = index.rename(names)
479528
index_vars = _create_variables_from_multiindex(index, dim, var_meta=var_meta)
480-
return cls(index, dim), index_vars
529+
return cls(index, dim, level_coords_dtype=level_coords_dtype), index_vars
481530

482531
def query(self, labels, method=None, tolerance=None) -> QueryResult:
483532
if method is not None or tolerance is not None:
@@ -570,15 +619,19 @@ def query(self, labels, method=None, tolerance=None) -> QueryResult:
570619
raise KeyError(f"not all values found in index {coord_name!r}")
571620

572621
if new_index is not None:
622+
# variable(s) attrs and encoding metadata are propagated
623+
# when replacing the indexes in the resulting xarray object
624+
var_meta = {k: {"dtype": v} for k, v in self.level_coords_dtype.items()}
625+
573626
if isinstance(new_index, pd.MultiIndex):
574627
new_index, new_vars = PandasMultiIndex.from_pandas_index(
575-
new_index, self.dim
628+
new_index, self.dim, var_meta=var_meta
576629
)
577630
dims_dict = {}
578631
drop_coords = set(self.index.names) - set(new_index.index.names)
579632
else:
580633
new_index, new_vars = PandasIndex.from_pandas_index(
581-
new_index, new_index.name
634+
new_index, new_index.name, var_meta=var_meta
582635
)
583636
dims_dict = {self.dim: new_index.index.name}
584637
drop_coords = set(self.index.names) - {new_index.index.name} | {
@@ -602,15 +655,14 @@ def rename(self, name_dict, dims_dict):
602655

603656
# pandas 1.3.0: could simply do `self.index.rename(names_dict)`
604657
new_names = [name_dict.get(k, k) for k in self.index.names]
605-
pd_idx = self.index.rename(new_names)
606-
new_dim = dims_dict.get(self.dim, self.dim)
658+
index = self.index.rename(new_names)
607659

608-
index, index_vars = self.from_pandas_index(pd_idx, new_dim)
609-
index.level_coords_dtype = {
610-
k: v for k, v in zip(new_names, self.level_coords_dtype.values())
660+
new_dim = dims_dict.get(self.dim, self.dim)
661+
var_meta = {
662+
k: {"dtype": v} for k, v in zip(new_names, self.level_coords_dtype.values())
611663
}
612664

613-
return index, index_vars
665+
return self.from_pandas_index(index, new_dim, var_meta=var_meta)
614666

615667

616668
def remove_unused_levels_categories(index: pd.Index) -> pd.Index:

0 commit comments

Comments
 (0)