Skip to content

More flexible index variables #8124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion xarray/core/alignment.py
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@
Indexes,
PandasIndex,
PandasMultiIndex,
create_index_variables,
indexes_all_equal,
safe_cast_to_index,
)
@@ -425,7 +426,7 @@ def align_indexes(self) -> None:
elif self.join == "right":
joined_index_vars = matching_index_vars[-1]
else:
joined_index_vars = joined_index.create_variables()
joined_index_vars = create_index_variables(joined_index)
else:
joined_index = matching_indexes[0]
joined_index_vars = matching_index_vars[0]
4 changes: 2 additions & 2 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
from xarray.core import dtypes, utils
from xarray.core.alignment import align, reindex_variables
from xarray.core.duck_array_ops import lazy_array_equiv
from xarray.core.indexes import Index, PandasIndex
from xarray.core.indexes import Index, PandasIndex, create_index_variables
from xarray.core.merge import (
_VALID_COMPAT,
collect_variables_and_indexes,
@@ -619,7 +619,7 @@ def get_indexes(name):
# index created from a scalar coordinate
idx_vars = {name: datasets[0][name].variable}
result_indexes.update({k: combined_idx for k in idx_vars})
combined_idx_vars = combined_idx.create_variables(idx_vars)
combined_idx_vars = create_index_variables(combined_idx, idx_vars)
for k, v in combined_idx_vars.items():
v.attrs = merge_attrs(
[ds.variables[k].attrs for ds in datasets],
7 changes: 5 additions & 2 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
@@ -293,7 +293,7 @@ def __init__(
else:
variables = {}
for name, data in coords.items():
var = as_variable(data, name=name)
var = as_variable(data, name=name, auto_convert=False)
if var.dims == (name,) and indexes is None:
index, index_vars = create_default_index_implicit(var, list(coords))
default_indexes.update({k: index for k in index_vars})
@@ -930,9 +930,12 @@ def create_coords_with_default_indexes(
if isinstance(obj, DataArray):
dataarray_coords.append(obj.coords)

variable = as_variable(obj, name=name)
variable = as_variable(obj, name=name, auto_convert=False)

if variable.dims == (name,):
# still needed to convert to IndexVariable first due to some
# pandas multi-index edge cases.
variable = variable.to_index_variable()
idx, idx_vars = create_default_index_implicit(variable, all_variables)
indexes.update({k: idx for k in idx_vars})
variables.update(idx_vars)
10 changes: 7 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
@@ -162,7 +162,9 @@ def _infer_coords_and_dims(
dims = list(coords.keys())
else:
for n, (dim, coord) in enumerate(zip(dims, coords)):
coord = as_variable(coord, name=dims[n]).to_index_variable()
coord = as_variable(
coord, name=dims[n], auto_convert=False
).to_index_variable()
dims[n] = coord.name
dims = tuple(dims)
elif len(dims) != len(shape):
@@ -183,10 +185,12 @@ def _infer_coords_and_dims(
new_coords = {}
if utils.is_dict_like(coords):
for k, v in coords.items():
new_coords[k] = as_variable(v, name=k)
new_coords[k] = as_variable(v, name=k, auto_convert=False)
if new_coords[k].dims == (k,):
new_coords[k] = new_coords[k].to_index_variable()
elif coords is not None:
for dim, coord in zip(dims, coords):
var = as_variable(coord, name=dim)
var = as_variable(coord, name=dim, auto_convert=False)
var.dims = (dim,)
new_coords[dim] = var.to_index_variable()

14 changes: 8 additions & 6 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
@@ -65,6 +65,7 @@
PandasMultiIndex,
assert_no_index_corrupted,
create_default_index_implicit,
create_index_variables,
filter_indexes_from_coords,
isel_indexes,
remove_unused_levels_categories,
@@ -4106,11 +4107,12 @@ def _rename_indexes(
new_index = index.rename(name_dict, dims_dict)
new_coord_names = [name_dict.get(k, k) for k in coord_names]
indexes.update({k: new_index for k in new_coord_names})
new_index_vars = new_index.create_variables(
new_index_vars = create_index_variables(
new_index,
{
new: self._variables[old]
for old, new in zip(coord_names, new_coord_names)
}
},
)
variables.update(new_index_vars)

@@ -4940,7 +4942,7 @@ def set_xindex(

index = index_cls.from_variables(coord_vars, options=options)

new_coord_vars = index.create_variables(coord_vars)
new_coord_vars = create_index_variables(index, coord_vars)

# special case for setting a pandas multi-index from level coordinates
# TODO: remove it once we deprecate pandas multi-index dimension (tuple
@@ -5134,7 +5136,7 @@ def _stack_once(
idx = index_cls.stack(product_vars, new_dim)
new_indexes[new_dim] = idx
new_indexes.update({k: idx for k in product_vars})
idx_vars = idx.create_variables(product_vars)
idx_vars = create_index_variables(idx, product_vars)
# keep consistent multi-index coordinate order
for k in idx_vars:
new_variables.pop(k, None)
@@ -5326,7 +5328,7 @@ def _unstack_once(
indexes.update(new_indexes)

for name, idx in new_indexes.items():
variables.update(idx.create_variables(index_vars))
variables.update(create_index_variables(idx, index_vars))

for name, var in self.variables.items():
if name not in index_vars:
@@ -5367,7 +5369,7 @@ def _unstack_full_reindex(

new_index_variables = {}
for name, idx in new_indexes.items():
new_index_variables.update(idx.create_variables(index_vars))
new_index_variables.update(create_index_variables(idx, index_vars))

new_dim_sizes = {k: v.size for k, v in new_index_variables.items()}
variables.update(new_index_variables)
36 changes: 33 additions & 3 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@

from xarray.core import formatting, nputils, utils
from xarray.core.indexing import (
IndexedCoordinateArray,
IndexSelResult,
PandasIndexingAdapter,
PandasMultiIndexingAdapter,
@@ -1332,6 +1333,29 @@ def rename(self, name_dict, dims_dict):
)


def create_index_variables(
    index: Index, variables: Mapping[Any, Variable] | None = None
) -> IndexVars:
    """Build the coordinate variables of ``index`` and guard their data.

    The variables are obtained from ``index.create_variables(variables)``.
    The data of each returned variable is then wrapped in an
    ``IndexedCoordinateArray`` so that its values cannot be modified
    in-place, unless it needs no wrapping (see the comments below).

    Parameters
    ----------
    index : Index
        The index whose coordinate variables are created.
    variables : mapping of Variable, optional
        Forwarded unchanged to ``index.create_variables``.

    Returns
    -------
    IndexVars
        The mapping returned by ``index.create_variables``; the data of its
        variables is wrapped in place where needed.
    """
    # Imported locally, as in the original block.
    from xarray.core.variable import IndexVariable

    created = index.create_variables(variables)

    for coord_var in created.values():
        # IndexVariable already has safety guards that prevent updating its
        # values (it is a special case for PandasIndex that will likely be
        # removed, eventually), so it is left untouched.
        if isinstance(coord_var, IndexVariable):
            continue
        # Already wrapped: avoid stacking a second wrapper on top.
        if isinstance(coord_var._data, IndexedCoordinateArray):
            continue
        # Plain Variable: wrap its data.
        coord_var._data = IndexedCoordinateArray(coord_var._data)

    return created


def create_default_index_implicit(
dim_variable: Variable,
all_variables: Mapping | Iterable[Hashable] | None = None,
@@ -1349,7 +1373,13 @@ def create_default_index_implicit(
all_variables = {k: None for k in all_variables}

name = dim_variable.dims[0]
array = getattr(dim_variable._data, "array", None)
data = dim_variable._data
if isinstance(data, PandasIndexingAdapter):
array = data.array
elif isinstance(data, IndexedCoordinateArray):
array = getattr(data.array, "array", None)
else:
array = None
index: PandasIndex

if isinstance(array, pd.MultiIndex):
@@ -1631,7 +1661,7 @@ def copy_indexes(
convert_new_idx = False

new_idx = idx._copy(deep=deep, memo=memo)
idx_vars = idx.create_variables(coords)
idx_vars = create_index_variables(idx, coords)

if convert_new_idx:
new_idx = cast(PandasIndex, new_idx).index
@@ -1779,7 +1809,7 @@ def _apply_indexes(
new_index = getattr(index, func)(index_args)
if new_index is not None:
new_indexes.update({k: new_index for k in index_vars})
new_index_vars = new_index.create_variables(index_vars)
new_index_vars = create_index_variables(new_index, index_vars)
new_index_variables.update(new_index_vars)
else:
for k in index_vars:
29 changes: 29 additions & 0 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
@@ -1467,6 +1467,35 @@ def transpose(self, order):
return self.array.transpose(order)


class IndexedCoordinateArray(ExplicitlyIndexedNDArrayMixin):
    """Wrapper around the data of an Xarray indexed coordinate.

    It keeps the coordinate data in sync with its index by rejecting
    in-place assignment: ``__setitem__`` always raises ``TypeError``.
    """

    __slots__ = ("array",)

    def __init__(self, array):
        # Store the wrapped data as returned by ``as_indexable``.
        self.array = as_indexable(array)

    def __setitem__(self, key, value):
        # Writing through this wrapper is never allowed.
        message = (
            "cannot modify the values of an indexed coordinate in-place "
            "as it may corrupt its index. "
            "Please use DataArray.assign_coords, Dataset.assign_coords "
            "or Dataset.assign as appropriate."
        )
        raise TypeError(message)

    def __getitem__(self, key):
        # Re-wrap the indexing result in the same wrapper type.
        selected = self.array[key]
        return type(self)(_wrap_numpy_scalars(selected))

    def get_duck_array(self):
        # Delegate to the wrapped array.
        return self.array.get_duck_array()

    def transpose(self, order):
        # Delegates to the wrapped array; the result is returned as-is
        # (it is not re-wrapped in IndexedCoordinateArray).
        return self.array.transpose(order)


class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap a pandas.Index to preserve dtypes and handle explicit indexing."""

2 changes: 1 addition & 1 deletion xarray/core/merge.py
Original file line number Diff line number Diff line change
@@ -355,7 +355,7 @@ def append_all(variables, indexes):
indexes_.pop(name, None)
append_all(coords_, indexes_)

variable = as_variable(variable, name=name)
variable = as_variable(variable, name=name, auto_convert=False)
if name in indexes:
append(name, variable, indexes[name])
elif variable.dims == (name,):
30 changes: 25 additions & 5 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@
from xarray.core.common import AbstractArray
from xarray.core.indexing import (
BasicIndexer,
IndexedCoordinateArray,
OuterIndexer,
PandasIndexingAdapter,
VectorizedIndexer,
@@ -45,6 +46,7 @@
decode_numpy_dict_values,
drop_dims_from_indexers,
either_dict_or_kwargs,
emit_user_level_warning,
ensure_us_time_resolution,
infix_dims,
is_duck_array,
@@ -86,7 +88,7 @@ class MissingDimensionsError(ValueError):
# TODO: move this to an xarray.exceptions module?


def as_variable(obj, name=None) -> Variable | IndexVariable:
def as_variable(obj, name=None, auto_convert=True) -> Variable | IndexVariable:
"""Convert an object into a Variable.

Parameters
@@ -106,6 +108,9 @@ def as_variable(obj, name=None) -> Variable | IndexVariable:
along a dimension of this given name.
- Variables with name matching one of their dimensions are converted
into `IndexVariable` objects.
auto_convert : bool, optional
For internal use only! If True, convert a "dimension" variable into
an IndexVariable object (deprecated).

Returns
-------
@@ -156,9 +161,15 @@ def as_variable(obj, name=None) -> Variable | IndexVariable:
f"explicit list of dimensions: {obj!r}"
)

if name is not None and name in obj.dims and obj.ndim == 1:
# automatically convert the Variable into an Index
obj = obj.to_index_variable()
if auto_convert:
if name is not None and name in obj.dims and obj.ndim == 1:
# automatically convert the Variable into an Index
emit_user_level_warning(
f"variable {name!r} with name matching its dimension will not be "
"automatically converted into an `IndexVariable` object in the future.",
FutureWarning,
)
obj = obj.to_index_variable()

return obj

@@ -430,6 +441,13 @@ def data(self) -> Any:

@data.setter
def data(self, data):
if isinstance(self._data, IndexedCoordinateArray):
raise ValueError(
"Cannot assign to the .data attribute of an indexed coordinate "
"as it may corrupt its index. "
"Please use DataArray.assign_coords, Dataset.assign_coords or Dataset.assign as appropriate."
)

data = as_compatible_data(data)
if data.shape != self.shape:
raise ValueError(
@@ -829,8 +847,10 @@ def _broadcast_indexes_vectorized(self, key):
variable = (
value
if isinstance(value, Variable)
else as_variable(value, name=dim)
else as_variable(value, name=dim, auto_convert=False)
)
if variable.dims == (dim,):
variable = variable.to_index_variable()
if variable.dtype.kind == "b": # boolean indexing case
(variable,) = variable._nonzero()

18 changes: 13 additions & 5 deletions xarray/testing.py
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex, default_indexes
from xarray.core.indexing import IndexedCoordinateArray
from xarray.core.variable import IndexVariable, Variable

__all__ = (
@@ -272,9 +273,11 @@ def _assert_indexes_invariants_checks(
}

index_vars = {
k for k, v in possible_coord_variables.items() if isinstance(v, IndexVariable)
k
for k, v in possible_coord_variables.items()
if isinstance(v, IndexVariable) or isinstance(v._data, IndexedCoordinateArray)
}
assert indexes.keys() <= index_vars, (set(indexes), index_vars)
assert indexes.keys() == index_vars, (set(indexes), index_vars)

# check pandas index wrappers vs. coordinate data adapters
for k, index in indexes.items():
@@ -340,9 +343,14 @@ def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool):
da.dims,
{k: v.dims for k, v in da._coords.items()},
)
assert all(
isinstance(v, IndexVariable) for (k, v) in da._coords.items() if v.dims == (k,)
), {k: type(v) for k, v in da._coords.items()}

if check_default_indexes:
assert all(
isinstance(v, IndexVariable)
for (k, v) in da._coords.items()
if v.dims == (k,)
), {k: type(v) for k, v in da._coords.items()}

for k, v in da._coords.items():
_assert_variable_invariants(v, k)