Skip to content

Commit 6fe1234

Browse files
authored
Zarr: Optimize appending (pydata#8998)
* Zarr: optimize appending
* Update xarray/backends/zarr.py
* Don't run `encoding` check if it wasn't provided.
* Add regression test
* fix types
* fix test
* Use mock instead
1 parent cd25bfa commit 6fe1234

File tree

4 files changed

+237
-68
lines changed

4 files changed

+237
-68
lines changed

.gitignore

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ nosetests.xml
5050
dask-worker-space/
5151

5252
# asv environments
53-
.asv
53+
asv_bench/.asv
54+
asv_bench/pkgs
5455

5556
# Translations
5657
*.mo
@@ -68,7 +69,7 @@ dask-worker-space/
6869

6970
# xarray specific
7071
doc/_build
71-
generated/
72+
doc/generated/
7273
xarray/tests/data/*.grib.*.idx
7374

7475
# Sync tools

xarray/backends/api.py

Lines changed: 3 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,42 +1521,6 @@ def save_mfdataset(
15211521
)
15221522

15231523

1524-
def _validate_datatypes_for_zarr_append(zstore, dataset):
1525-
"""If variable exists in the store, confirm dtype of the data to append is compatible with
1526-
existing dtype.
1527-
"""
1528-
1529-
existing_vars = zstore.get_variables()
1530-
1531-
def check_dtype(vname, var):
1532-
if (
1533-
vname not in existing_vars
1534-
or np.issubdtype(var.dtype, np.number)
1535-
or np.issubdtype(var.dtype, np.datetime64)
1536-
or np.issubdtype(var.dtype, np.bool_)
1537-
or var.dtype == object
1538-
):
1539-
# We can skip dtype equality checks under two conditions: (1) if the var to append is
1540-
# new to the dataset, because in this case there is no existing var to compare it to;
1541-
# or (2) if var to append's dtype is known to be easy-to-append, because in this case
1542-
# we can be confident appending won't cause problems. Examples of dtypes which are not
1543-
# easy-to-append include length-specified strings of type `|S*` or `<U*` (where * is a
1544-
# positive integer character length). For these dtypes, appending dissimilar lengths
1545-
# can result in truncation of appended data. Therefore, variables which already exist
1546-
# in the dataset, and with dtypes which are not known to be easy-to-append, necessitate
1547-
# exact dtype equality, as checked below.
1548-
pass
1549-
elif not var.dtype == existing_vars[vname].dtype:
1550-
raise ValueError(
1551-
f"Mismatched dtypes for variable {vname} between Zarr store on disk "
1552-
f"and dataset to append. Store has dtype {existing_vars[vname].dtype} but "
1553-
f"dataset to append has dtype {var.dtype}."
1554-
)
1555-
1556-
for vname, var in dataset.data_vars.items():
1557-
check_dtype(vname, var)
1558-
1559-
15601524
# compute=True returns ZarrStore
15611525
@overload
15621526
def to_zarr(
@@ -1712,37 +1676,21 @@ def to_zarr(
17121676

17131677
if region is not None:
17141678
zstore._validate_and_autodetect_region(dataset)
1715-
# can't modify indexed with region writes
1679+
# can't modify indexes with region writes
17161680
dataset = dataset.drop_vars(dataset.indexes)
17171681
if append_dim is not None and append_dim in region:
17181682
raise ValueError(
17191683
f"cannot list the same dimension in both ``append_dim`` and "
17201684
f"``region`` with to_zarr(), got {append_dim} in both"
17211685
)
17221686

1723-
if mode in ["a", "a-", "r+"]:
1724-
_validate_datatypes_for_zarr_append(zstore, dataset)
1725-
if append_dim is not None:
1726-
existing_dims = zstore.get_dimensions()
1727-
if append_dim not in existing_dims:
1728-
raise ValueError(
1729-
f"append_dim={append_dim!r} does not match any existing "
1730-
f"dataset dimensions {existing_dims}"
1731-
)
1687+
if encoding and mode in ["a", "a-", "r+"]:
17321688
existing_var_names = set(zstore.zarr_group.array_keys())
17331689
for var_name in existing_var_names:
1734-
if var_name in encoding.keys():
1690+
if var_name in encoding:
17351691
raise ValueError(
17361692
f"variable {var_name!r} already exists, but encoding was provided"
17371693
)
1738-
if mode == "r+":
1739-
new_names = [k for k in dataset.variables if k not in existing_var_names]
1740-
if new_names:
1741-
raise ValueError(
1742-
f"dataset contains non-pre-existing variables {new_names}, "
1743-
"which is not allowed in ``xarray.Dataset.to_zarr()`` with "
1744-
"mode='r+'. To allow writing new variables, set mode='a'."
1745-
)
17461694

17471695
writer = ArrayWriter()
17481696
# TODO: figure out how to properly handle unlimited_dims

xarray/backends/zarr.py

Lines changed: 70 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,34 @@ def encode_zarr_variable(var, needs_copy=True, name=None):
324324
return var
325325

326326

327+
def _validate_datatypes_for_zarr_append(vname, existing_var, new_var):
328+
"""If variable exists in the store, confirm dtype of the data to append is compatible with
329+
existing dtype.
330+
"""
331+
if (
332+
np.issubdtype(new_var.dtype, np.number)
333+
or np.issubdtype(new_var.dtype, np.datetime64)
334+
or np.issubdtype(new_var.dtype, np.bool_)
335+
or new_var.dtype == object
336+
):
337+
# We can skip the dtype equality check when the dtype of the var to append is known to
338+
# be easy-to-append (numeric, datetime64, bool or object), because in this case we can
339+
# be confident appending won't cause problems. Variables that are new to the store are
340+
# not passed to this function, as there is no existing var to compare them to. Examples
341+
# of dtypes which are not easy-to-append include length-specified strings of type `|S*`
342+
# or `<U*` (where * is a positive integer character length). For these dtypes, appending
343+
# dissimilar lengths can result in truncation of appended data. Therefore, variables
344+
# which already exist in the dataset, and with dtypes which are not known to be
345+
# easy-to-append, necessitate exact dtype equality, as checked below.
346+
pass
347+
elif not new_var.dtype == existing_var.dtype:
348+
raise ValueError(
349+
f"Mismatched dtypes for variable {vname} between Zarr store on disk "
350+
f"and dataset to append. Store has dtype {existing_var.dtype} but "
351+
f"dataset to append has dtype {new_var.dtype}."
352+
)
353+
354+
327355
def _validate_and_transpose_existing_dims(
328356
var_name, new_var, existing_var, region, append_dim
329357
):
@@ -612,26 +640,58 @@ def store(
612640
import zarr
613641

614642
existing_keys = tuple(self.zarr_group.array_keys())
643+
644+
if self._mode == "r+":
645+
new_names = [k for k in variables if k not in existing_keys]
646+
if new_names:
647+
raise ValueError(
648+
f"dataset contains non-pre-existing variables {new_names}, "
649+
"which is not allowed in ``xarray.Dataset.to_zarr()`` with "
650+
"``mode='r+'``. To allow writing new variables, set ``mode='a'``."
651+
)
652+
653+
if self._append_dim is not None and self._append_dim not in existing_keys:
654+
# For dimensions without coordinate values, we must parse
655+
# the _ARRAY_DIMENSIONS attribute on *all* arrays to check if it
656+
# is a valid existing dimension name.
657+
# TODO: This `get_dimensions` method also does shape checking
658+
# which isn't strictly necessary for our check.
659+
existing_dims = self.get_dimensions()
660+
if self._append_dim not in existing_dims:
661+
raise ValueError(
662+
f"append_dim={self._append_dim!r} does not match any existing "
663+
f"dataset dimensions {existing_dims}"
664+
)
665+
615666
existing_variable_names = {
616667
vn for vn in variables if _encode_variable_name(vn) in existing_keys
617668
}
618-
new_variables = set(variables) - existing_variable_names
619-
variables_without_encoding = {vn: variables[vn] for vn in new_variables}
669+
new_variable_names = set(variables) - existing_variable_names
620670
variables_encoded, attributes = self.encode(
621-
variables_without_encoding, attributes
671+
{vn: variables[vn] for vn in new_variable_names}, attributes
622672
)
623673

624674
if existing_variable_names:
625-
# Decode variables directly, without going via xarray.Dataset to
626-
# avoid needing to load index variables into memory.
627-
# TODO: consider making loading indexes lazy again?
675+
# We make sure that values to be appended are encoded *exactly*
676+
# as the current values in the store.
677+
# To do so, we decode variables directly to access the proper encoding,
678+
# without going via xarray.Dataset to avoid needing to load
679+
# index variables into memory.
628680
existing_vars, _, _ = conventions.decode_cf_variables(
629-
{k: self.open_store_variable(name=k) for k in existing_variable_names},
630-
self.get_attrs(),
681+
variables={
682+
k: self.open_store_variable(name=k) for k in existing_variable_names
683+
},
684+
# attributes = {} since we don't care about parsing the global
685+
# "coordinates" attribute
686+
attributes={},
631687
)
632688
# Modified variables must use the same encoding as the store.
633689
vars_with_encoding = {}
634690
for vn in existing_variable_names:
691+
if self._mode in ["a", "a-", "r+"]:
692+
_validate_datatypes_for_zarr_append(
693+
vn, existing_vars[vn], variables[vn]
694+
)
635695
vars_with_encoding[vn] = variables[vn].copy(deep=False)
636696
vars_with_encoding[vn].encoding = existing_vars[vn].encoding
637697
vars_with_encoding, _ = self.encode(vars_with_encoding, {})
@@ -696,7 +756,6 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
696756

697757
for vn, v in variables.items():
698758
name = _encode_variable_name(vn)
699-
check = vn in check_encoding_set
700759
attrs = v.attrs.copy()
701760
dims = v.dims
702761
dtype = v.dtype
@@ -712,7 +771,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
712771
# https://github.com/pydata/xarray/issues/8371 for details.
713772
encoding = extract_zarr_variable_encoding(
714773
v,
715-
raise_on_invalid=check,
774+
raise_on_invalid=vn in check_encoding_set,
716775
name=vn,
717776
safe_chunks=self._safe_chunks,
718777
)
@@ -815,7 +874,7 @@ def _auto_detect_regions(self, ds, region):
815874
assert variable.dims == (dim,)
816875
index = pd.Index(variable.data)
817876
idxs = index.get_indexer(ds[dim].data)
818-
if any(idxs == -1):
877+
if (idxs == -1).any():
819878
raise KeyError(
820879
f"Not all values of coordinate '{dim}' in the new array were"
821880
" found in the original store. Writing to a zarr region slice"

xarray/tests/test_backends.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2990,6 +2990,167 @@ def test_chunked_cftime_datetime(self) -> None:
29902990
assert original.chunks == actual.chunks
29912991

29922992

2993+
@requires_zarr
2994+
@pytest.mark.skipif(not have_zarr_v3, reason="requires zarr version 3")
2995+
class TestInstrumentedZarrStore:
2996+
methods = [
2997+
"__iter__",
2998+
"__contains__",
2999+
"__setitem__",
3000+
"__getitem__",
3001+
"listdir",
3002+
"list_prefix",
3003+
]
3004+
3005+
@contextlib.contextmanager
3006+
def create_zarr_target(self):
3007+
import zarr
3008+
3009+
if Version(zarr.__version__) < Version("2.18.0"):
3010+
pytest.skip("Instrumented tests only work on latest Zarr.")
3011+
3012+
store = KVStoreV3({})
3013+
yield store
3014+
3015+
def make_patches(self, store):
3016+
from unittest.mock import MagicMock
3017+
3018+
return {
3019+
method: MagicMock(
3020+
f"KVStoreV3.{method}",
3021+
side_effect=getattr(store, method),
3022+
autospec=True,
3023+
)
3024+
for method in self.methods
3025+
}
3026+
3027+
def summarize(self, patches):
3028+
summary = {}
3029+
for name, patch_ in patches.items():
3030+
count = 0
3031+
for call in patch_.mock_calls:
3032+
if "zarr.json" not in call.args:
3033+
count += 1
3034+
summary[name.strip("__")] = count
3035+
return summary
3036+
3037+
def check_requests(self, expected, patches):
3038+
summary = self.summarize(patches)
3039+
for k in summary:
3040+
assert summary[k] <= expected[k], (k, summary)
3041+
3042+
def test_append(self) -> None:
3043+
original = Dataset({"foo": ("x", [1])}, coords={"x": [0]})
3044+
modified = Dataset({"foo": ("x", [2])}, coords={"x": [1]})
3045+
with self.create_zarr_target() as store:
3046+
expected = {
3047+
"iter": 2,
3048+
"contains": 9,
3049+
"setitem": 9,
3050+
"getitem": 6,
3051+
"listdir": 2,
3052+
"list_prefix": 2,
3053+
}
3054+
patches = self.make_patches(store)
3055+
with patch.multiple(KVStoreV3, **patches):
3056+
original.to_zarr(store)
3057+
self.check_requests(expected, patches)
3058+
3059+
patches = self.make_patches(store)
3060+
# v2024.03.0: {'iter': 6, 'contains': 2, 'setitem': 5, 'getitem': 10, 'listdir': 6, 'list_prefix': 0}
3061+
# 6057128b: {'iter': 5, 'contains': 2, 'setitem': 5, 'getitem': 10, "listdir": 5, "list_prefix": 0}
3062+
expected = {
3063+
"iter": 2,
3064+
"contains": 2,
3065+
"setitem": 5,
3066+
"getitem": 6,
3067+
"listdir": 2,
3068+
"list_prefix": 0,
3069+
}
3070+
with patch.multiple(KVStoreV3, **patches):
3071+
modified.to_zarr(store, mode="a", append_dim="x")
3072+
self.check_requests(expected, patches)
3073+
3074+
patches = self.make_patches(store)
3075+
expected = {
3076+
"iter": 2,
3077+
"contains": 2,
3078+
"setitem": 5,
3079+
"getitem": 6,
3080+
"listdir": 2,
3081+
"list_prefix": 0,
3082+
}
3083+
with patch.multiple(KVStoreV3, **patches):
3084+
modified.to_zarr(store, mode="a-", append_dim="x")
3085+
self.check_requests(expected, patches)
3086+
3087+
with open_dataset(store, engine="zarr") as actual:
3088+
assert_identical(
3089+
actual, xr.concat([original, modified, modified], dim="x")
3090+
)
3091+
3092+
@requires_dask
3093+
def test_region_write(self) -> None:
3094+
ds = Dataset({"foo": ("x", [1, 2, 3])}, coords={"x": [0, 1, 2]}).chunk()
3095+
with self.create_zarr_target() as store:
3096+
expected = {
3097+
"iter": 2,
3098+
"contains": 7,
3099+
"setitem": 8,
3100+
"getitem": 6,
3101+
"listdir": 2,
3102+
"list_prefix": 4,
3103+
}
3104+
patches = self.make_patches(store)
3105+
with patch.multiple(KVStoreV3, **patches):
3106+
ds.to_zarr(store, mode="w", compute=False)
3107+
self.check_requests(expected, patches)
3108+
3109+
# v2024.03.0: {'iter': 5, 'contains': 2, 'setitem': 1, 'getitem': 6, 'listdir': 5, 'list_prefix': 0}
3110+
# 6057128b: {'iter': 4, 'contains': 2, 'setitem': 1, 'getitem': 5, 'listdir': 4, 'list_prefix': 0}
3111+
expected = {
3112+
"iter": 2,
3113+
"contains": 2,
3114+
"setitem": 1,
3115+
"getitem": 3,
3116+
"listdir": 2,
3117+
"list_prefix": 0,
3118+
}
3119+
patches = self.make_patches(store)
3120+
with patch.multiple(KVStoreV3, **patches):
3121+
ds.to_zarr(store, region={"x": slice(None)})
3122+
self.check_requests(expected, patches)
3123+
3124+
# v2024.03.0: {'iter': 6, 'contains': 4, 'setitem': 1, 'getitem': 11, 'listdir': 6, 'list_prefix': 0}
3125+
# 6057128b: {'iter': 4, 'contains': 2, 'setitem': 1, 'getitem': 7, 'listdir': 4, 'list_prefix': 0}
3126+
expected = {
3127+
"iter": 2,
3128+
"contains": 2,
3129+
"setitem": 1,
3130+
"getitem": 5,
3131+
"listdir": 2,
3132+
"list_prefix": 0,
3133+
}
3134+
patches = self.make_patches(store)
3135+
with patch.multiple(KVStoreV3, **patches):
3136+
ds.to_zarr(store, region="auto")
3137+
self.check_requests(expected, patches)
3138+
3139+
expected = {
3140+
"iter": 1,
3141+
"contains": 2,
3142+
"setitem": 0,
3143+
"getitem": 5,
3144+
"listdir": 1,
3145+
"list_prefix": 0,
3146+
}
3147+
patches = self.make_patches(store)
3148+
with patch.multiple(KVStoreV3, **patches):
3149+
with open_dataset(store, engine="zarr") as actual:
3150+
assert_identical(actual, ds)
3151+
self.check_requests(expected, patches)
3152+
3153+
29933154
@requires_zarr
29943155
class TestZarrDictStore(ZarrBase):
29953156
@contextlib.contextmanager

0 commit comments

Comments
 (0)