Commit dcf2ac4 (1 parent: 2ad98b1)

Zarr: Optimize region="auto" detection (#8997)

* Zarr: Optimize region detection
* Fix for unindexed dimensions.
* Better example
* small cleanup

3 files changed: 101 additions, 114 deletions
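The headline change: ``to_zarr(..., region="auto")`` now infers the target index slice from the coordinates already on disk through the ``ZarrStore`` itself, instead of re-opening the dataset with ``open_zarr``. A minimal sketch of the user-facing call, with a placeholder path:

    import numpy as np
    import xarray as xr

    path = "path/to/directory.zarr"  # placeholder

    # Initialize the store, including the values of the "x" coordinate.
    ds = xr.Dataset({"foo": ("x", np.arange(30))}, coords={"x": np.arange(30)})
    ds.to_zarr(path, mode="w")

    # The write region is detected from the on-disk "x" coordinate.
    ds.isel(x=slice(10, 20)).to_zarr(path, region="auto")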

doc/user-guide/io.rst
Lines changed: 2 additions & 2 deletions
@@ -874,7 +874,7 @@ and then calling ``to_zarr`` with ``compute=False`` to write only metadata
     # The values of this dask array are entirely irrelevant; only the dtype,
     # shape and chunks are used
     dummies = dask.array.zeros(30, chunks=10)
-    ds = xr.Dataset({"foo": ("x", dummies)})
+    ds = xr.Dataset({"foo": ("x", dummies)}, coords={"x": np.arange(30)})
     path = "path/to/directory.zarr"
     # Now we write the metadata without computing any array values
     ds.to_zarr(path, compute=False)

@@ -890,7 +890,7 @@ where the data should be written (in index space, not label space), e.g.,

     # For convenience, we'll slice a single dataset, but in the real use-case
     # we would create them separately possibly even from separate processes.
-    ds = xr.Dataset({"foo": ("x", np.arange(30))})
+    ds = xr.Dataset({"foo": ("x", np.arange(30))}, coords={"x": np.arange(30)})
     # Any of the following region specifications are valid
     ds.isel(x=slice(0, 10)).to_zarr(path, region="auto")
     ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": "auto"})
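Pieced together, the updated documentation snippets form a runnable workflow; the explicit ``coords`` added above is what makes auto-detection possible, since the detector reads the coordinate back from the store:

    import dask.array
    import numpy as np
    import xarray as xr

    path = "path/to/directory.zarr"  # placeholder, as in the docs

    # Write metadata only: the dummy dask values are never computed, but
    # the eager "x" coordinate is written to the store.
    dummies = dask.array.zeros(30, chunks=10)
    ds = xr.Dataset({"foo": ("x", dummies)}, coords={"x": np.arange(30)})
    ds.to_zarr(path, compute=False)

    # Later (possibly from separate processes), fill disjoint regions.
    ds = xr.Dataset({"foo": ("x", np.arange(30))}, coords={"x": np.arange(30)})
    ds.isel(x=slice(0, 10)).to_zarr(path, region="auto")
    ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": "auto"})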

xarray/backends/api.py
Lines changed: 10 additions & 105 deletions
@@ -27,7 +27,6 @@
     _normalize_path,
 )
 from xarray.backends.locks import _get_scheduler
-from xarray.backends.zarr import open_zarr
 from xarray.core import indexing
 from xarray.core.combine import (
     _infer_concat_order_from_positions,
@@ -1522,92 +1521,6 @@ def save_mfdataset(
     )


-def _auto_detect_region(ds_new, ds_orig, dim):
-    # Create a mapping array of coordinates to indices on the original array
-    coord = ds_orig[dim]
-    da_map = DataArray(np.arange(coord.size), coords={dim: coord})
-
-    try:
-        da_idxs = da_map.sel({dim: ds_new[dim]})
-    except KeyError as e:
-        if "not all values found" in str(e):
-            raise KeyError(
-                f"Not all values of coordinate '{dim}' in the new array were"
-                " found in the original store. Writing to a zarr region slice"
-                " requires that no dimensions or metadata are changed by the write."
-            )
-        else:
-            raise e
-
-    if (da_idxs.diff(dim) != 1).any():
-        raise ValueError(
-            f"The auto-detected region of coordinate '{dim}' for writing new data"
-            " to the original store had non-contiguous indices. Writing to a zarr"
-            " region slice requires that the new data constitute a contiguous subset"
-            " of the original store."
-        )
-
-    dim_slice = slice(da_idxs.values[0], da_idxs.values[-1] + 1)
-
-    return dim_slice
-
-
-def _auto_detect_regions(ds, region, open_kwargs):
-    ds_original = open_zarr(**open_kwargs)
-    for key, val in region.items():
-        if val == "auto":
-            region[key] = _auto_detect_region(ds, ds_original, key)
-    return region
-
-
-def _validate_and_autodetect_region(ds, region, mode, open_kwargs) -> dict[str, slice]:
-    if region == "auto":
-        region = {dim: "auto" for dim in ds.dims}
-
-    if not isinstance(region, dict):
-        raise TypeError(f"``region`` must be a dict, got {type(region)}")
-
-    if any(v == "auto" for v in region.values()):
-        if mode != "r+":
-            raise ValueError(
-                f"``mode`` must be 'r+' when using ``region='auto'``, got {mode}"
-            )
-        region = _auto_detect_regions(ds, region, open_kwargs)
-
-    for k, v in region.items():
-        if k not in ds.dims:
-            raise ValueError(
-                f"all keys in ``region`` are not in Dataset dimensions, got "
-                f"{list(region)} and {list(ds.dims)}"
-            )
-        if not isinstance(v, slice):
-            raise TypeError(
-                "all values in ``region`` must be slice objects, got "
-                f"region={region}"
-            )
-        if v.step not in {1, None}:
-            raise ValueError(
-                "step on all slices in ``region`` must be 1 or None, got "
-                f"region={region}"
-            )
-
-    non_matching_vars = [
-        k for k, v in ds.variables.items() if not set(region).intersection(v.dims)
-    ]
-    if non_matching_vars:
-        raise ValueError(
-            f"when setting `region` explicitly in to_zarr(), all "
-            f"variables in the dataset to write must have at least "
-            f"one dimension in common with the region's dimensions "
-            f"{list(region.keys())}, but that is not "
-            f"the case for some variables here. To drop these variables "
-            f"from this dataset before exporting to zarr, write: "
-            f".drop_vars({non_matching_vars!r})"
-        )
-
-    return region
-
-
 def _validate_datatypes_for_zarr_append(zstore, dataset):
     """If variable exists in the store, confirm dtype of the data to append is compatible with
     existing dtype.
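The removed helper resolved coordinates to indices by building a mapping ``DataArray`` and calling ``.sel()``; its replacement in ``zarr.py`` (below) does the same lookup with a plain pandas index, avoiding a full ``open_zarr`` of the target store. A side-by-side sketch of the two lookups, with made-up coordinate values:

    import numpy as np
    import pandas as pd
    import xarray as xr

    store_x = np.arange(30)    # coordinate values already in the store
    new_x = np.arange(10, 20)  # coordinate of the dataset being written

    # Old: map coordinates to indices through a DataArray and .sel()
    da_map = xr.DataArray(np.arange(store_x.size), coords={"x": store_x})
    old_idxs = da_map.sel(x=new_x).values

    # New: a direct pandas lookup; -1 would flag values not found
    new_idxs = pd.Index(store_x).get_indexer(new_x)

    assert (old_idxs == new_idxs).all()
    region = slice(new_idxs[0], new_idxs[-1] + 1)  # slice(10, 20)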
@@ -1768,24 +1681,6 @@ def to_zarr(
     # validate Dataset keys, DataArray names
     _validate_dataset_names(dataset)

-    if region is not None:
-        open_kwargs = dict(
-            store=store,
-            synchronizer=synchronizer,
-            group=group,
-            consolidated=consolidated,
-            storage_options=storage_options,
-            zarr_version=zarr_version,
-        )
-        region = _validate_and_autodetect_region(dataset, region, mode, open_kwargs)
-        # can't modify indexed with region writes
-        dataset = dataset.drop_vars(dataset.indexes)
-        if append_dim is not None and append_dim in region:
-            raise ValueError(
-                f"cannot list the same dimension in both ``append_dim`` and "
-                f"``region`` with to_zarr(), got {append_dim} in both"
-            )
-
     if zarr_version is None:
         # default to 2 if store doesn't specify it's version (e.g. a path)
         zarr_version = int(getattr(store, "_store_version", 2))
@@ -1815,6 +1710,16 @@
         write_empty=write_empty_chunks,
     )

+    if region is not None:
+        zstore._validate_and_autodetect_region(dataset)
+        # can't modify indexed with region writes
+        dataset = dataset.drop_vars(dataset.indexes)
+        if append_dim is not None and append_dim in region:
+            raise ValueError(
+                f"cannot list the same dimension in both ``append_dim`` and "
+                f"``region`` with to_zarr(), got {append_dim} in both"
+            )
+
     if mode in ["a", "a-", "r+"]:
         _validate_datatypes_for_zarr_append(zstore, dataset)
         if append_dim is not None:
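The ``append_dim`` guard moves along with the rest of the region handling, now running after the store is opened. A hypothetical misuse that trips it, reusing ``ds`` and ``path`` from the earlier sketches:

    # Hypothetical: "x" named as both append_dim and a region key.
    ds.isel(x=slice(0, 10)).to_zarr(path, append_dim="x", region={"x": slice(0, 10)})
    # ValueError: cannot list the same dimension in both ``append_dim`` and
    # ``region`` with to_zarr(), got x in both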

xarray/backends/zarr.py
Lines changed: 89 additions & 7 deletions
@@ -7,6 +7,7 @@
 from typing import TYPE_CHECKING, Any

 import numpy as np
+import pandas as pd

 from xarray import coding, conventions
 from xarray.backends.common import (
@@ -509,7 +510,9 @@ def ds(self):
         # TODO: consider deprecating this in favor of zarr_group
         return self.zarr_group

-    def open_store_variable(self, name, zarr_array):
+    def open_store_variable(self, name, zarr_array=None):
+        if zarr_array is None:
+            zarr_array = self.zarr_group[name]
         data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array))
         try_nczarr = self._mode == "r"
         dimensions, attributes = _get_zarr_dims_and_attrs(
@@ -623,11 +626,7 @@ def store(
         # avoid needing to load index variables into memory.
         # TODO: consider making loading indexes lazy again?
         existing_vars, _, _ = conventions.decode_cf_variables(
-            {
-                k: v
-                for k, v in self.get_variables().items()
-                if k in existing_variable_names
-            },
+            {k: self.open_store_variable(name=k) for k in existing_variable_names},
             self.get_attrs(),
         )
         # Modified variables must use the same encoding as the store.
@@ -796,10 +795,93 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
             region = tuple(write_region[dim] for dim in dims)
             writer.add(v.data, zarr_array, region)

-    def close(self):
+    def close(self) -> None:
         if self._close_store_on_close:
             self.zarr_group.store.close()

+    def _auto_detect_regions(self, ds, region):
+        for dim, val in region.items():
+            if val != "auto":
+                continue
+
+            if dim not in ds._variables:
+                # unindexed dimension
+                region[dim] = slice(0, ds.sizes[dim])
+                continue
+
+            variable = conventions.decode_cf_variable(
+                dim, self.open_store_variable(dim).compute()
+            )
+            assert variable.dims == (dim,)
+            index = pd.Index(variable.data)
+            idxs = index.get_indexer(ds[dim].data)
+            if any(idxs == -1):
+                raise KeyError(
+                    f"Not all values of coordinate '{dim}' in the new array were"
+                    " found in the original store. Writing to a zarr region slice"
+                    " requires that no dimensions or metadata are changed by the write."
+                )
+
+            if (np.diff(idxs) != 1).any():
+                raise ValueError(
+                    f"The auto-detected region of coordinate '{dim}' for writing new data"
+                    " to the original store had non-contiguous indices. Writing to a zarr"
+                    " region slice requires that the new data constitute a contiguous subset"
+                    " of the original store."
+                )
+            region[dim] = slice(idxs[0], idxs[-1] + 1)
+        return region
+
+    def _validate_and_autodetect_region(self, ds) -> None:
+        region = self._write_region
+
+        if region == "auto":
+            region = {dim: "auto" for dim in ds.dims}
+
+        if not isinstance(region, dict):
+            raise TypeError(f"``region`` must be a dict, got {type(region)}")
+        if any(v == "auto" for v in region.values()):
+            if self._mode != "r+":
+                raise ValueError(
+                    f"``mode`` must be 'r+' when using ``region='auto'``, got {self._mode!r}"
+                )
+            region = self._auto_detect_regions(ds, region)
+
+        # validate before attempting to auto-detect since the auto-detection
+        # should always return a valid slice.
+        for k, v in region.items():
+            if k not in ds.dims:
+                raise ValueError(
+                    f"all keys in ``region`` are not in Dataset dimensions, got "
+                    f"{list(region)} and {list(ds.dims)}"
+                )
+            if not isinstance(v, slice):
+                raise TypeError(
+                    "all values in ``region`` must be slice objects, got "
+                    f"region={region}"
+                )
+            if v.step not in {1, None}:
+                raise ValueError(
+                    "step on all slices in ``region`` must be 1 or None, got "
+                    f"region={region}"
+                )
+
+        non_matching_vars = [
+            k for k, v in ds.variables.items() if not set(region).intersection(v.dims)
+        ]
+        if non_matching_vars:
+            raise ValueError(
+                f"when setting `region` explicitly in to_zarr(), all "
+                f"variables in the dataset to write must have at least "
+                f"one dimension in common with the region's dimensions "
+                f"{list(region.keys())}, but that is not "
+                f"the case for some variables here. To drop these variables "
+                f"from this dataset before exporting to zarr, write: "
+                f".drop_vars({non_matching_vars!r})"
+            )
+
+        self._write_region = region
+

 def open_zarr(
     store,
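The behavior of the new ``pd.Index.get_indexer``-based checks can be seen in isolation; a small sketch of the three cases the method above distinguishes:

    import numpy as np
    import pandas as pd

    store_index = pd.Index(np.arange(30))  # coordinate as stored on disk

    # Missing values come back as -1 (raises KeyError above):
    store_index.get_indexer([5, 99])       # array([ 5, -1])

    # Non-contiguous matches fail the np.diff check (raises ValueError above):
    idxs = store_index.get_indexer([3, 7, 8])
    assert (np.diff(idxs) != 1).any()

    # A contiguous run becomes the region slice:
    idxs = store_index.get_indexer([10, 11, 12])
    region = slice(idxs[0], idxs[-1] + 1)  # covers store rows 10..12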
