Allow setting (or skipping) new indexes in open_dataset #8051


Open · wants to merge 10 commits into base: main
60 changes: 57 additions & 3 deletions xarray/backends/api.py
@@ -36,6 +36,7 @@
from xarray.backends.locks import _get_scheduler
from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder
from xarray.core import indexing
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
@@ -379,6 +380,18 @@ def _chunk_ds(
return backend_ds._replace(variables)


def _create_default_indexes(ds, create_default_indexes):
if not create_default_indexes:
return ds

to_index = {
name: coord.variable
for name, coord in ds.coords.items()
if coord.dims == (name,) and name not in ds.xindexes
}
return ds.assign_coords(Coordinates(to_index))
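For context, this helper leans on the fact that a `Coordinates` object built *without* an `indexes` argument creates default pandas indexes for dimension coordinates, while `indexes={}` suppresses them. A minimal standalone sketch of that mechanism, outside the backend machinery:

```python
import xarray as xr

# Coordinates built with indexes={} carry no index at all.
coords = xr.Coordinates({"x": ("x", [10, 20])}, indexes={})
ds = xr.Dataset({"a": ("x", [1, 2])}, coords=coords)
assert "x" not in ds.xindexes

# Re-wrapping the dimension coordinate in a plain Coordinates object
# (no `indexes` argument) makes xarray build the default pandas index,
# which is what _create_default_indexes does via assign_coords.
indexed = ds.assign_coords(xr.Coordinates({"x": ds["x"].variable}))
assert "x" in indexed.xindexes
```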


def _dataset_from_backend_dataset(
backend_ds,
filename_or_obj,
@@ -389,6 +402,7 @@ def _dataset_from_backend_dataset(
inline_array,
chunked_array_type,
from_array_kwargs,
create_default_indexes,
**extra_tokens,
):
if not isinstance(chunks, int | dict) and chunks not in {None, "auto"}:
@@ -397,11 +411,14 @@
)

_protect_dataset_variables_inplace(backend_ds, cache)

indexed = _create_default_indexes(backend_ds, create_default_indexes)

if chunks is None:
ds = backend_ds
ds = indexed
else:
ds = _chunk_ds(
backend_ds,
indexed,
filename_or_obj,
engine,
chunks,
@@ -434,6 +451,7 @@ def _datatree_from_backend_datatree(
inline_array,
chunked_array_type,
from_array_kwargs,
create_default_indexes,
**extra_tokens,
):
if not isinstance(chunks, int | dict) and chunks not in {None, "auto"}:
@@ -448,7 +466,7 @@
tree = DataTree.from_dict(
{
path: _chunk_ds(
node.dataset,
node.dataset.pipe(_create_default_indexes, create_default_indexes),
filename_or_obj,
engine,
chunks,
@@ -497,6 +515,7 @@ def open_dataset(
concat_characters: bool | Mapping[str, bool] | None = None,
decode_coords: Literal["coordinates", "all"] | bool | None = None,
drop_variables: str | Iterable[str] | None = None,
create_default_indexes: bool = True,
inline_array: bool = False,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
@@ -610,6 +629,13 @@
A variable or list of variables to exclude from being parsed from the
dataset. This may be useful to drop variables with problems or
inconsistent values.
create_default_indexes : bool, default: True
If True, create pandas indexes for :term:`dimension coordinates <dimension coordinate>`,
which loads the coordinate data into memory. Set it to False if you want to avoid loading
data into memory.

Note that backends can still choose to create other indexes. If you want to control that,
please refer to the backend's documentation.
inline_array: bool, default: False
How to include the array in the dask task graph.
By default(``inline_array=False``) the array is included in a task by
@@ -702,6 +728,7 @@
chunked_array_type,
from_array_kwargs,
drop_variables=drop_variables,
create_default_indexes=create_default_indexes,
**decoders,
**kwargs,
)
@@ -725,6 +752,7 @@ def open_dataarray(
concat_characters: bool | None = None,
decode_coords: Literal["coordinates", "all"] | bool | None = None,
drop_variables: str | Iterable[str] | None = None,
create_default_indexes: bool = True,
inline_array: bool = False,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
@@ -833,6 +861,13 @@
A variable or list of variables to exclude from being parsed from the
dataset. This may be useful to drop variables with problems or
inconsistent values.
create_default_indexes : bool, default: True
If True, create pandas indexes for :term:`dimension coordinates <dimension coordinate>`,
which loads the coordinate data into memory. Set it to False if you want to avoid loading
data into memory.

Note that backends can still choose to create other indexes. If you want to control that,
please refer to the backend's documentation.
inline_array: bool, default: False
How to include the array in the dask task graph.
By default(``inline_array=False``) the array is included in a task by
@@ -890,6 +925,7 @@ def open_dataarray(
chunks=chunks,
cache=cache,
drop_variables=drop_variables,
create_default_indexes=create_default_indexes,
inline_array=inline_array,
chunked_array_type=chunked_array_type,
from_array_kwargs=from_array_kwargs,
@@ -946,6 +982,7 @@ def open_datatree(
concat_characters: bool | Mapping[str, bool] | None = None,
decode_coords: Literal["coordinates", "all"] | bool | None = None,
drop_variables: str | Iterable[str] | None = None,
create_default_indexes: bool = True,
inline_array: bool = False,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
@@ -1055,6 +1092,13 @@
A variable or list of variables to exclude from being parsed from the
dataset. This may be useful to drop variables with problems or
inconsistent values.
create_default_indexes : bool, default: True
If True, create pandas indexes for :term:`dimension coordinates <dimension coordinate>`,
which loads the coordinate data into memory. Set it to False if you want to avoid loading
data into memory.

Note that backends can still choose to create other indexes. If you want to control that,
please refer to the backend's documentation.
inline_array: bool, default: False
How to include the array in the dask task graph.
By default(``inline_array=False``) the array is included in a task by
@@ -1148,6 +1192,7 @@ def open_datatree(
chunked_array_type,
from_array_kwargs,
drop_variables=drop_variables,
create_default_indexes=create_default_indexes,
**decoders,
**kwargs,
)
@@ -1175,6 +1220,7 @@ def open_groups(
concat_characters: bool | Mapping[str, bool] | None = None,
decode_coords: Literal["coordinates", "all"] | bool | None = None,
drop_variables: str | Iterable[str] | None = None,
create_default_indexes: bool = True,
inline_array: bool = False,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
@@ -1286,6 +1332,13 @@
A variable or list of variables to exclude from being parsed from the
dataset. This may be useful to drop variables with problems or
inconsistent values.
create_default_indexes : bool, default: True
If True, create pandas indexes for :term:`dimension coordinates <dimension coordinate>`,
which loads the coordinate data into memory. Set it to False if you want to avoid loading
data into memory.

Note that backends can still choose to create other indexes. If you want to control that,
please refer to the backend's documentation.
inline_array: bool, default: False
How to include the array in the dask task graph.
By default(``inline_array=False``) the array is included in a task by
@@ -1381,6 +1434,7 @@ def open_groups(
chunked_array_type,
from_array_kwargs,
drop_variables=drop_variables,
create_default_indexes=create_default_indexes,
**decoders,
**kwargs,
)
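Taken together, the api.py changes let callers opt out of eager index creation. A minimal usage sketch (the file path and the `time` coordinate are hypothetical):

```python
import xarray as xr

# Skip building pandas indexes, so no coordinate data is loaded eagerly.
ds = xr.open_dataset("example.nc", create_default_indexes=False)  # hypothetical file

# Label-based selection needs an index, so build one only where required:
ds = ds.set_xindex("time")  # assumes the file has a "time" dimension coordinate
```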
17 changes: 15 additions & 2 deletions xarray/backends/store.py
@@ -9,6 +9,7 @@
AbstractDataStore,
BackendEntrypoint,
)
from xarray.core.coordinates import Coordinates
from xarray.core.dataset import Dataset

if TYPE_CHECKING:
@@ -36,6 +37,7 @@ def open_dataset(
concat_characters=True,
decode_coords=True,
drop_variables: str | Iterable[str] | None = None,
set_indexes: bool = True,
use_cftime=None,
decode_timedelta=None,
) -> Dataset:
@@ -56,8 +58,19 @@
decode_timedelta=decode_timedelta,
)

ds = Dataset(vars, attrs=attrs)
ds = ds.set_coords(coord_names.intersection(vars))
# split data and coordinate variables (promote dimension coordinates)
data_vars = {}
coord_vars = {}
for name, var in vars.items():
if name in coord_names or var.dims == (name,):
coord_vars[name] = var
else:
data_vars[name] = var

# explicit Coordinates object with no index passed
coords = Coordinates(coord_vars, indexes={})

ds = Dataset(data_vars, coords=coords, attrs=attrs)
ds.set_close(filename_or_obj.close)
ds.encoding = encoding

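The split rule above promotes any variable whose only dimension shares its name; everything else stays a data variable. A standalone sketch of the same logic, under the assumption that no indexes should be created at this stage:

```python
import xarray as xr

variables = {
    "x": xr.Variable(("x",), [0, 1]),    # dims == (name,): promoted to coordinate
    "a": xr.Variable(("x",), [10, 20]),  # stays a data variable
}
coord_names: set[str] = set()  # names the file format marked as coordinates

coord_vars = {
    name: var
    for name, var in variables.items()
    if name in coord_names or var.dims == (name,)
}
data_vars = {k: v for k, v in variables.items() if k not in coord_vars}

# indexes={} keeps the backend from building pandas indexes here;
# index creation is deferred to _create_default_indexes in api.py.
ds = xr.Dataset(data_vars, coords=xr.Coordinates(coord_vars, indexes={}))
assert not ds.xindexes
```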
9 changes: 9 additions & 0 deletions xarray/backends/zarr.py
@@ -1347,6 +1347,7 @@ def open_zarr(
use_zarr_fill_value_as_mask=None,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
create_default_indexes=True,
**kwargs,
):
"""Load and decode a dataset from a Zarr store.
@@ -1457,6 +1458,13 @@
chunked arrays, via whichever chunk manager is specified through the ``chunked_array_type`` kwarg.
Defaults to ``{'manager': 'dask'}``, meaning additional kwargs will be passed eventually to
:py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
create_default_indexes : bool, default: True
If True, create pandas indexes for :term:`dimension coordinates <dimension coordinate>`,
which loads the coordinate data into memory. Set it to False if you want to avoid loading
data into memory.

Note that backends can still choose to create other indexes. If you want to control that,
please refer to the backend's documentation.

Returns
-------
@@ -1513,6 +1521,7 @@
engine="zarr",
chunks=chunks,
drop_variables=drop_variables,
create_default_indexes=create_default_indexes,
chunked_array_type=chunked_array_type,
from_array_kwargs=from_array_kwargs,
backend_kwargs=backend_kwargs,
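Since `open_zarr` forwards the flag to `open_dataset`, usage is the same (a minimal sketch; the store path is hypothetical):

```python
import xarray as xr

# Hypothetical Zarr store path, for illustration only.
ds = xr.open_zarr("example.zarr", create_default_indexes=False)

# Unless the backend itself created indexes, none are present yet:
print(ds.xindexes)
```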
36 changes: 36 additions & 0 deletions xarray/tests/test_backends_api.py
@@ -201,3 +201,39 @@ def test_join_chunks(self, shape, pref_chunks, req_chunks):
chunks=dict(zip(initial[self.var_name].dims, req_chunks, strict=True)),
)
self.check_dataset(initial, final, explicit_chunks(req_chunks, shape))

@pytest.mark.parametrize("create_default_indexes", [True, False])
def test_default_indexes(self, create_default_indexes):
"""Create default indexes if the backend does not create them."""
coords = xr.Coordinates({"x": ("x", [0, 1]), "y": list("abc")}, indexes={})
initial = xr.Dataset({"a": ("x", [1, 2])}, coords=coords)

with assert_no_warnings():
final = xr.open_dataset(
initial,
engine=PassThroughBackendEntrypoint,
create_default_indexes=create_default_indexes,
)

if create_default_indexes:
assert all(name in final.xindexes for name in ["x", "y"])
else:
assert not final.xindexes

@pytest.mark.parametrize("create_default_indexes", [True, False])
def test_default_indexes_passthrough(self, create_default_indexes):
"""Allow creating indexes in the backend."""

initial = xr.Dataset(
{"a": (["x", "y"], [[1, 2, 3], [4, 5, 6]])},
coords={"x": ("x", [0, 1]), "y": ("y", list("abc"))},
).stack(z=["x", "y"])

with assert_no_warnings():
final = xr.open_dataset(
initial,
engine=PassThroughBackendEntrypoint,
create_default_indexes=create_default_indexes,
)

assert initial.coords.equals(final.coords)