Skip to content
Merged
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ Deprecations
Bug fixes
~~~~~~~~~

- Partial writes to existing chunks with ``region`` will now raise an error
(unless ``safe_chunks=False``); previously an error would only be raised on
new variables. (:pull:`8459`, :issue:`8371`)
By `Maximilian Roos <https://github.com/max-sixty>`_.
- Port `bug fix from pandas <https://github.com/pandas-dev/pandas/pull/55283>`_
to eliminate the adjustment of resample bin edges in the case that the
resampling frequency has units of days and is greater than one day
Expand Down
16 changes: 12 additions & 4 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
f"Writing this array in parallel with dask could lead to corrupted data."
)
if safe_chunks:
raise NotImplementedError(
raise ValueError(
base_error
+ " Consider either rechunking using `chunk()`, deleting "
"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
Expand Down Expand Up @@ -679,6 +679,17 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
if v.encoding == {"_FillValue": None} and fill_value is None:
v.encoding = {}

# We need to do this for both new and existing variables to ensure we're not
# writing to a partial chunk, even though we don't use the `encoding` value
# when writing to an existing variable. See
# https://github.com/pydata/xarray/issues/8371 for details.
encoding = extract_zarr_variable_encoding(
v,
raise_on_invalid=check,
name=vn,
safe_chunks=self._safe_chunks,
)

if name in self.zarr_group:
# existing variable
# TODO: if mode="a", consider overriding the existing variable
Expand Down Expand Up @@ -709,9 +720,6 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
zarr_array = self.zarr_group[name]
else:
# new variable
encoding = extract_zarr_variable_encoding(
v, raise_on_invalid=check, name=vn, safe_chunks=self._safe_chunks
)
encoded_attrs = {}
# the magic for storing the hidden dimension data
encoded_attrs[DIMENSION_KEY] = dims
Expand Down
28 changes: 23 additions & 5 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2100,7 +2100,7 @@ def test_chunk_encoding_with_dask(self) -> None:
# should fail if encoding["chunks"] clashes with dask_chunks
badenc = ds.chunk({"x": 4})
badenc.var1.encoding["chunks"] = (6,)
with pytest.raises(NotImplementedError, match=r"named 'var1' would overlap"):
with pytest.raises(ValueError, match=r"named 'var1' would overlap"):
with self.roundtrip(badenc) as actual:
pass

Expand Down Expand Up @@ -2138,9 +2138,7 @@ def test_chunk_encoding_with_dask(self) -> None:
# but itermediate unaligned chunks are bad
badenc = ds.chunk({"x": (3, 5, 3, 1)})
badenc.var1.encoding["chunks"] = (3,)
with pytest.raises(
NotImplementedError, match=r"would overlap multiple dask chunks"
):
with pytest.raises(ValueError, match=r"would overlap multiple dask chunks"):
with self.roundtrip(badenc) as actual:
pass

Expand All @@ -2154,7 +2152,7 @@ def test_chunk_encoding_with_dask(self) -> None:
# TODO: remove this failure once synchronized overlapping writes are
# supported by xarray
ds_chunk4["var1"].encoding.update({"chunks": 5})
with pytest.raises(NotImplementedError, match=r"named 'var1' would overlap"):
with pytest.raises(ValueError, match=r"named 'var1' would overlap"):
with self.roundtrip(ds_chunk4) as actual:
pass
# override option
Expand Down Expand Up @@ -5442,3 +5440,23 @@ def test_zarr_region_transpose(tmp_path):
ds_region.to_zarr(
tmp_path / "test.zarr", region={"x": slice(0, 1), "y": slice(0, 1)}
)


@requires_zarr
@requires_dask
def test_zarr_region_chunk_partial(tmp_path):
"""
Check that writing to partial chunks with `region` fails, assuming `safe_chunks=False`.
"""
ds = (
xr.DataArray(np.arange(120).reshape(4, 3, -1), dims=list("abc"))
.rename("var1")
.to_dataset()
)

ds.chunk(5).to_zarr(tmp_path / "foo.zarr", compute=False, mode="w")
with pytest.raises(ValueError):
for r in range(ds.sizes["a"]):
ds.chunk(3).isel(a=[r]).to_zarr(
tmp_path / "foo.zarr", region=dict(a=slice(r, r + 1))
)