Skip to content

Commit 4a53e41

Browse files
Fix zarr append dtype checks (#6476)
* fix zarr append dtype check first commit
* use zstore in _validate_datatype
* remove coding.strings.is_unicode_dtype check
* test appending fixed length strings
* test string length mismatch raises for U and S
* add explanatory comment for zarr append dtype checks

Co-authored-by: Maximilian Roos <[email protected]>
1 parent 770e878 commit 4a53e41

File tree

3 files changed

+81
-19
lines changed

3 files changed

+81
-19
lines changed

xarray/backends/api.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import numpy as np
1919

20-
from .. import backends, coding, conventions
20+
from .. import backends, conventions
2121
from ..core import indexing
2222
from ..core.combine import (
2323
_infer_concat_order_from_positions,
@@ -1277,28 +1277,40 @@ def _validate_region(ds, region):
12771277
)
12781278

12791279

1280-
def _validate_datatypes_for_zarr_append(dataset):
1281-
"""DataArray.name and Dataset keys must be a string or None"""
1280+
def _validate_datatypes_for_zarr_append(zstore, dataset):
1281+
"""If variable exists in the store, confirm dtype of the data to append is compatible with
1282+
existing dtype.
1283+
"""
1284+
1285+
existing_vars = zstore.get_variables()
12821286

1283-
def check_dtype(var):
1287+
def check_dtype(vname, var):
12841288
if (
1285-
not np.issubdtype(var.dtype, np.number)
1286-
and not np.issubdtype(var.dtype, np.datetime64)
1287-
and not np.issubdtype(var.dtype, np.bool_)
1288-
and not coding.strings.is_unicode_dtype(var.dtype)
1289-
and not var.dtype == object
1289+
vname not in existing_vars
1290+
or np.issubdtype(var.dtype, np.number)
1291+
or np.issubdtype(var.dtype, np.datetime64)
1292+
or np.issubdtype(var.dtype, np.bool_)
1293+
or var.dtype == object
12901294
):
1291-
# and not re.match('^bytes[1-9]+$', var.dtype.name)):
1295+
# We can skip dtype equality checks under two conditions: (1) if the var to append is
1296+
# new to the dataset, because in this case there is no existing var to compare it to;
1297+
# or (2) if var to append's dtype is known to be easy-to-append, because in this case
1298+
# we can be confident appending won't cause problems. Examples of dtypes which are not
1299+
# easy-to-append include length-specified strings of type `|S*` or `<U*` (where * is a
1300+
# positive integer character length). For these dtypes, appending dissimilar lengths
1301+
# can result in truncation of appended data. Therefore, variables which already exist
1302+
# in the dataset, and with dtypes which are not known to be easy-to-append, necessitate
1303+
# exact dtype equality, as checked below.
1304+
pass
1305+
elif not var.dtype == existing_vars[vname].dtype:
12921306
raise ValueError(
1293-
"Invalid dtype for data variable: {} "
1294-
"dtype must be a subtype of number, "
1295-
"datetime, bool, a fixed sized string, "
1296-
"a fixed size unicode string or an "
1297-
"object".format(var)
1307+
f"Mismatched dtypes for variable {vname} between Zarr store on disk "
1308+
f"and dataset to append. Store has dtype {existing_vars[vname].dtype} but "
1309+
f"dataset to append has dtype {var.dtype}."
12981310
)
12991311

1300-
for k in dataset.data_vars.values():
1301-
check_dtype(k)
1312+
for vname, var in dataset.data_vars.items():
1313+
check_dtype(vname, var)
13021314

13031315

13041316
def to_zarr(
@@ -1403,7 +1415,7 @@ def to_zarr(
14031415
)
14041416

14051417
if mode in ["a", "r+"]:
1406-
_validate_datatypes_for_zarr_append(dataset)
1418+
_validate_datatypes_for_zarr_append(zstore, dataset)
14071419
if append_dim is not None:
14081420
existing_dims = zstore.get_dimensions()
14091421
if append_dim not in existing_dims:

xarray/tests/test_backends.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,11 @@
8282
_NON_STANDARD_CALENDARS,
8383
_STANDARD_CALENDARS,
8484
)
85-
from .test_dataset import create_append_test_data, create_test_data
85+
from .test_dataset import (
86+
create_append_string_length_mismatch_test_data,
87+
create_append_test_data,
88+
create_test_data,
89+
)
8690

8791
try:
8892
import netCDF4 as nc4
@@ -2112,6 +2116,17 @@ def test_append_with_existing_encoding_raises(self):
21122116
encoding={"da": {"compressor": None}},
21132117
)
21142118

2119+
@pytest.mark.parametrize("dtype", ["U", "S"])
2120+
def test_append_string_length_mismatch_raises(self, dtype):
2121+
ds, ds_to_append = create_append_string_length_mismatch_test_data(dtype)
2122+
with self.create_zarr_target() as store_target:
2123+
ds.to_zarr(store_target, mode="w")
2124+
with pytest.raises(ValueError, match="Mismatched dtypes for variable"):
2125+
ds_to_append.to_zarr(
2126+
store_target,
2127+
append_dim="time",
2128+
)
2129+
21152130
def test_check_encoding_is_consistent_after_append(self):
21162131

21172132
ds, ds_to_append, _ = create_append_test_data()

xarray/tests/test_dataset.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ def create_append_test_data(seed=None):
7676
time2 = pd.date_range("2000-02-01", periods=nt2)
7777
string_var = np.array(["ae", "bc", "df"], dtype=object)
7878
string_var_to_append = np.array(["asdf", "asdfg"], dtype=object)
79+
string_var_fixed_length = np.array(["aa", "bb", "cc"], dtype="|S2")
80+
string_var_fixed_length_to_append = np.array(["dd", "ee"], dtype="|S2")
7981
unicode_var = ["áó", "áó", "áó"]
8082
datetime_var = np.array(
8183
["2019-01-01", "2019-01-02", "2019-01-03"], dtype="datetime64[s]"
@@ -94,6 +96,9 @@ def create_append_test_data(seed=None):
9496
dims=["lat", "lon", "time"],
9597
),
9698
"string_var": xr.DataArray(string_var, coords=[time1], dims=["time"]),
99+
"string_var_fixed_length": xr.DataArray(
100+
string_var_fixed_length, coords=[time1], dims=["time"]
101+
),
97102
"unicode_var": xr.DataArray(
98103
unicode_var, coords=[time1], dims=["time"]
99104
).astype(np.unicode_),
@@ -112,6 +117,9 @@ def create_append_test_data(seed=None):
112117
"string_var": xr.DataArray(
113118
string_var_to_append, coords=[time2], dims=["time"]
114119
),
120+
"string_var_fixed_length": xr.DataArray(
121+
string_var_fixed_length_to_append, coords=[time2], dims=["time"]
122+
),
115123
"unicode_var": xr.DataArray(
116124
unicode_var[:nt2], coords=[time2], dims=["time"]
117125
).astype(np.unicode_),
@@ -137,6 +145,33 @@ def create_append_test_data(seed=None):
137145
return ds, ds_to_append, ds_with_new_var
138146

139147

148+
def create_append_string_length_mismatch_test_data(dtype):
149+
def make_datasets(data, data_to_append):
150+
ds = xr.Dataset(
151+
{"temperature": (["time"], data)},
152+
coords={"time": [0, 1, 2]},
153+
)
154+
ds_to_append = xr.Dataset(
155+
{"temperature": (["time"], data_to_append)}, coords={"time": [0, 1, 2]}
156+
)
157+
assert all(objp.data.flags.writeable for objp in ds.variables.values())
158+
assert all(
159+
objp.data.flags.writeable for objp in ds_to_append.variables.values()
160+
)
161+
return ds, ds_to_append
162+
163+
u2_strings = ["ab", "cd", "ef"]
164+
u5_strings = ["abc", "def", "ghijk"]
165+
166+
s2_strings = np.array(["aa", "bb", "cc"], dtype="|S2")
167+
s3_strings = np.array(["aaa", "bbb", "ccc"], dtype="|S3")
168+
169+
if dtype == "U":
170+
return make_datasets(u2_strings, u5_strings)
171+
elif dtype == "S":
172+
return make_datasets(s2_strings, s3_strings)
173+
174+
140175
def create_test_multiindex():
141176
mindex = pd.MultiIndex.from_product(
142177
[["a", "b"], [1, 2]], names=("level_1", "level_2")

0 commit comments

Comments (0)