-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Zarr: Optimize region="auto"
detection
#8997
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,7 +27,6 @@ | |
_normalize_path, | ||
) | ||
from xarray.backends.locks import _get_scheduler | ||
from xarray.backends.zarr import open_zarr | ||
from xarray.core import indexing | ||
from xarray.core.combine import ( | ||
_infer_concat_order_from_positions, | ||
|
@@ -1522,92 +1521,6 @@ def save_mfdataset( | |
) | ||
|
||
|
||
def _auto_detect_region(ds_new, ds_orig, dim): | ||
# Create a mapping array of coordinates to indices on the original array | ||
coord = ds_orig[dim] | ||
da_map = DataArray(np.arange(coord.size), coords={dim: coord}) | ||
|
||
try: | ||
da_idxs = da_map.sel({dim: ds_new[dim]}) | ||
except KeyError as e: | ||
if "not all values found" in str(e): | ||
raise KeyError( | ||
f"Not all values of coordinate '{dim}' in the new array were" | ||
" found in the original store. Writing to a zarr region slice" | ||
" requires that no dimensions or metadata are changed by the write." | ||
) | ||
else: | ||
raise e | ||
|
||
if (da_idxs.diff(dim) != 1).any(): | ||
raise ValueError( | ||
f"The auto-detected region of coordinate '{dim}' for writing new data" | ||
" to the original store had non-contiguous indices. Writing to a zarr" | ||
" region slice requires that the new data constitute a contiguous subset" | ||
" of the original store." | ||
) | ||
|
||
dim_slice = slice(da_idxs.values[0], da_idxs.values[-1] + 1) | ||
|
||
return dim_slice | ||
|
||
|
||
def _auto_detect_regions(ds, region, open_kwargs): | ||
ds_original = open_zarr(**open_kwargs) | ||
for key, val in region.items(): | ||
if val == "auto": | ||
region[key] = _auto_detect_region(ds, ds_original, key) | ||
return region | ||
|
||
|
||
def _validate_and_autodetect_region(ds, region, mode, open_kwargs) -> dict[str, slice]: | ||
if region == "auto": | ||
region = {dim: "auto" for dim in ds.dims} | ||
|
||
if not isinstance(region, dict): | ||
raise TypeError(f"``region`` must be a dict, got {type(region)}") | ||
|
||
if any(v == "auto" for v in region.values()): | ||
if mode != "r+": | ||
raise ValueError( | ||
f"``mode`` must be 'r+' when using ``region='auto'``, got {mode}" | ||
) | ||
region = _auto_detect_regions(ds, region, open_kwargs) | ||
|
||
for k, v in region.items(): | ||
if k not in ds.dims: | ||
raise ValueError( | ||
f"all keys in ``region`` are not in Dataset dimensions, got " | ||
f"{list(region)} and {list(ds.dims)}" | ||
) | ||
if not isinstance(v, slice): | ||
raise TypeError( | ||
"all values in ``region`` must be slice objects, got " | ||
f"region={region}" | ||
) | ||
if v.step not in {1, None}: | ||
raise ValueError( | ||
"step on all slices in ``region`` must be 1 or None, got " | ||
f"region={region}" | ||
) | ||
|
||
non_matching_vars = [ | ||
k for k, v in ds.variables.items() if not set(region).intersection(v.dims) | ||
] | ||
if non_matching_vars: | ||
raise ValueError( | ||
f"when setting `region` explicitly in to_zarr(), all " | ||
f"variables in the dataset to write must have at least " | ||
f"one dimension in common with the region's dimensions " | ||
f"{list(region.keys())}, but that is not " | ||
f"the case for some variables here. To drop these variables " | ||
f"from this dataset before exporting to zarr, write: " | ||
f".drop_vars({non_matching_vars!r})" | ||
) | ||
|
||
return region | ||
|
||
|
||
def _validate_datatypes_for_zarr_append(zstore, dataset): | ||
"""If variable exists in the store, confirm dtype of the data to append is compatible with | ||
existing dtype. | ||
|
@@ -1768,24 +1681,6 @@ def to_zarr( | |
# validate Dataset keys, DataArray names | ||
_validate_dataset_names(dataset) | ||
|
||
if region is not None: | ||
open_kwargs = dict( | ||
store=store, | ||
synchronizer=synchronizer, | ||
group=group, | ||
consolidated=consolidated, | ||
storage_options=storage_options, | ||
zarr_version=zarr_version, | ||
) | ||
region = _validate_and_autodetect_region(dataset, region, mode, open_kwargs) | ||
# can't modify indexed with region writes | ||
dataset = dataset.drop_vars(dataset.indexes) | ||
if append_dim is not None and append_dim in region: | ||
raise ValueError( | ||
f"cannot list the same dimension in both ``append_dim`` and " | ||
f"``region`` with to_zarr(), got {append_dim} in both" | ||
) | ||
|
||
if zarr_version is None: | ||
# default to 2 if store doesn't specify it's version (e.g. a path) | ||
zarr_version = int(getattr(store, "_store_version", 2)) | ||
|
@@ -1815,6 +1710,16 @@ def to_zarr( | |
write_empty=write_empty_chunks, | ||
) | ||
|
||
if region is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved down so we only open the Zarr store once. |
||
zstore._validate_and_autodetect_region(dataset) | ||
# can't modify indexed with region writes | ||
dataset = dataset.drop_vars(dataset.indexes) | ||
if append_dim is not None and append_dim in region: | ||
raise ValueError( | ||
f"cannot list the same dimension in both ``append_dim`` and " | ||
f"``region`` with to_zarr(), got {append_dim} in both" | ||
) | ||
|
||
if mode in ["a", "a-", "r+"]: | ||
_validate_datatypes_for_zarr_append(zstore, dataset) | ||
if append_dim is not None: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
from typing import TYPE_CHECKING, Any | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from xarray import coding, conventions | ||
from xarray.backends.common import ( | ||
|
@@ -509,7 +510,9 @@ def ds(self): | |
# TODO: consider deprecating this in favor of zarr_group | ||
return self.zarr_group | ||
|
||
def open_store_variable(self, name, zarr_array): | ||
def open_store_variable(self, name, zarr_array=None): | ||
if zarr_array is None: | ||
zarr_array = self.zarr_group[name] | ||
data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array)) | ||
try_nczarr = self._mode == "r" | ||
dimensions, attributes = _get_zarr_dims_and_attrs( | ||
|
@@ -623,11 +626,7 @@ def store( | |
# avoid needing to load index variables into memory. | ||
# TODO: consider making loading indexes lazy again? | ||
existing_vars, _, _ = conventions.decode_cf_variables( | ||
{ | ||
k: v | ||
for k, v in self.get_variables().items() | ||
if k in existing_variable_names | ||
}, | ||
{k: self.open_store_variable(name=k) for k in existing_variable_names}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just open the needed variables instead of opening all of them. |
||
self.get_attrs(), | ||
) | ||
# Modified variables must use the same encoding as the store. | ||
|
@@ -796,10 +795,93 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No | |
region = tuple(write_region[dim] for dim in dims) | ||
writer.add(v.data, zarr_array, region) | ||
|
||
def close(self): | ||
def close(self) -> None: | ||
if self._close_store_on_close: | ||
self.zarr_group.store.close() | ||
|
||
def _auto_detect_regions(self, ds, region): | ||
for dim, val in region.items(): | ||
if val != "auto": | ||
continue | ||
|
||
if dim not in ds._variables: | ||
# unindexed dimension | ||
region[dim] = slice(0, ds.sizes[dim]) | ||
continue | ||
|
||
variable = conventions.decode_cf_variable( | ||
dim, self.open_store_variable(dim).compute() | ||
) | ||
assert variable.dims == (dim,) | ||
index = pd.Index(variable.data) | ||
idxs = index.get_indexer(ds[dim].data) | ||
Comment on lines
+812
to
+817
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines 812-817: This is the main logic change. |
||
if any(idxs == -1): | ||
raise KeyError( | ||
f"Not all values of coordinate '{dim}' in the new array were" | ||
" found in the original store. Writing to a zarr region slice" | ||
" requires that no dimensions or metadata are changed by the write." | ||
) | ||
|
||
if (np.diff(idxs) != 1).any(): | ||
raise ValueError( | ||
f"The auto-detected region of coordinate '{dim}' for writing new data" | ||
" to the original store had non-contiguous indices. Writing to a zarr" | ||
" region slice requires that the new data constitute a contiguous subset" | ||
" of the original store." | ||
) | ||
region[dim] = slice(idxs[0], idxs[-1] + 1) | ||
return region | ||
|
||
def _validate_and_autodetect_region(self, ds) -> None: | ||
region = self._write_region | ||
|
||
if region == "auto": | ||
region = {dim: "auto" for dim in ds.dims} | ||
|
||
if not isinstance(region, dict): | ||
raise TypeError(f"``region`` must be a dict, got {type(region)}") | ||
if any(v == "auto" for v in region.values()): | ||
if self._mode != "r+": | ||
raise ValueError( | ||
f"``mode`` must be 'r+' when using ``region='auto'``, got {self._mode!r}" | ||
) | ||
region = self._auto_detect_regions(ds, region) | ||
|
||
# validate before attempting to auto-detect since the auto-detection | ||
# should always return a valid slice. | ||
for k, v in region.items(): | ||
if k not in ds.dims: | ||
raise ValueError( | ||
f"all keys in ``region`` are not in Dataset dimensions, got " | ||
f"{list(region)} and {list(ds.dims)}" | ||
) | ||
if not isinstance(v, slice): | ||
raise TypeError( | ||
"all values in ``region`` must be slice objects, got " | ||
f"region={region}" | ||
) | ||
if v.step not in {1, None}: | ||
raise ValueError( | ||
"step on all slices in ``region`` must be 1 or None, got " | ||
f"region={region}" | ||
) | ||
|
||
non_matching_vars = [ | ||
k for k, v in ds.variables.items() if not set(region).intersection(v.dims) | ||
] | ||
if non_matching_vars: | ||
raise ValueError( | ||
f"when setting `region` explicitly in to_zarr(), all " | ||
f"variables in the dataset to write must have at least " | ||
f"one dimension in common with the region's dimensions " | ||
f"{list(region.keys())}, but that is not " | ||
f"the case for some variables here. To drop these variables " | ||
f"from this dataset before exporting to zarr, write: " | ||
f".drop_vars({non_matching_vars!r})" | ||
) | ||
|
||
self._write_region = region | ||
|
||
|
||
def open_zarr( | ||
store, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This last line does not do what it looks like it's doing if there are no indexes!