From 657633b78483b1d34518cb9ada25af958f846b12 Mon Sep 17 00:00:00 2001 From: Charles Stern <62192187+cisaacstern@users.noreply.github.com> Date: Mon, 11 Apr 2022 17:26:23 -0700 Subject: [PATCH 1/6] fix zarr append dtype check first commit --- xarray/backends/api.py | 38 +++++++++++++++++++---------------- xarray/tests/test_backends.py | 16 ++++++++++++++- xarray/tests/test_dataset.py | 16 +++++++++++++++ 3 files changed, 52 insertions(+), 18 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 548b98048ba..1c4c6d74a0f 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1277,28 +1277,32 @@ def _validate_region(ds, region): ) -def _validate_datatypes_for_zarr_append(dataset): - """DataArray.name and Dataset keys must be a string or None""" +def _validate_datatypes_for_zarr_append(store, dataset): + """If variable exists in the store, confirm dtype of the data to append is compatible with + existing dtype. + """ + + existing_ds = backends.zarr.open_zarr(store) - def check_dtype(var): + def check_dtype(vname, var): if ( - not np.issubdtype(var.dtype, np.number) - and not np.issubdtype(var.dtype, np.datetime64) - and not np.issubdtype(var.dtype, np.bool_) - and not coding.strings.is_unicode_dtype(var.dtype) - and not var.dtype == object + vname not in existing_ds.data_vars + or np.issubdtype(var.dtype, np.number) + or np.issubdtype(var.dtype, np.datetime64) + or np.issubdtype(var.dtype, np.bool_) + or coding.strings.is_unicode_dtype(var.dtype) + or var.dtype == object ): - # and not re.match('^bytes[1-9]+$', var.dtype.name)): + pass + elif not var.dtype == existing_ds[vname].dtype: raise ValueError( - "Invalid dtype for data variable: {} " - "dtype must be a subtype of number, " - "datetime, bool, a fixed sized string, " - "a fixed size unicode string or an " - "object".format(var) + f"Mismatched dtypes for variable {vname} between Zarr store on disk " + f"and dataset to append. Store has dtype {existing_ds[vname].dtype} but " + f"dataset to append has dtype {var.dtype}." ) - for k in dataset.data_vars.values(): - check_dtype(k) + for vname, var in dataset.data_vars.items(): + check_dtype(vname, var) def to_zarr( @@ -1403,7 +1407,7 @@ def to_zarr( ) if mode in ["a", "r+"]: - _validate_datatypes_for_zarr_append(dataset) + _validate_datatypes_for_zarr_append(store, dataset) if append_dim is not None: existing_dims = zstore.get_dimensions() if append_dim not in existing_dims: diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 825c6f7130f..f19ed3afa19 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -81,7 +81,11 @@ _NON_STANDARD_CALENDARS, _STANDARD_CALENDARS, ) -from .test_dataset import create_append_test_data, create_test_data +from .test_dataset import ( + create_append_mismatch_test_data, + create_append_test_data, + create_test_data, +) try: import netCDF4 as nc4 @@ -2111,6 +2115,16 @@ def test_append_with_existing_encoding_raises(self): encoding={"da": {"compressor": None}}, ) + def test_append_dtype_mismatch_raises(self): + ds, ds_to_append = create_append_mismatch_test_data() + with self.create_zarr_target() as store_target: + ds.to_zarr(store_target, mode="w") + with pytest.raises(ValueError, match="Mismatched dtypes for variable"): + ds_to_append.to_zarr( + store_target, + append_dim="time", + ) + def test_check_encoding_is_consistent_after_append(self): ds, ds_to_append, _ = create_append_test_data() diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 5f368375fc0..61e7179e057 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -137,6 +137,22 @@ def create_append_test_data(seed=None): return ds, ds_to_append, ds_with_new_var +def create_append_mismatch_test_data(): + fixed_length_strings = ["ab", "cd", "ef"] + variable_length_strings = ["abc", "def", "ghijk"] + + ds = xr.Dataset( + {"temperature": (["time"], fixed_length_strings)}, + coords={"time": [0, 1, 2]}, + ) + ds_to_append = xr.Dataset( + {"temperature": (["time"], variable_length_strings)}, coords={"time": [0, 1, 2]} + ) + assert all(objp.data.flags.writeable for objp in ds.variables.values()) + assert all(objp.data.flags.writeable for objp in ds_to_append.variables.values()) + return ds, ds_to_append + + def create_test_multiindex(): mindex = pd.MultiIndex.from_product( [["a", "b"], [1, 2]], names=("level_1", "level_2") From 09cb09d8dca52171c48d233d71d85a6a08a45c93 Mon Sep 17 00:00:00 2001 From: Charles Stern <62192187+cisaacstern@users.noreply.github.com> Date: Tue, 12 Apr 2022 09:50:40 -0700 Subject: [PATCH 2/6] use zstore in _validate_datatype --- xarray/backends/api.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 70c073158a9..00106fa4d98 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1277,16 +1277,16 @@ def _validate_region(ds, region): ) -def _validate_datatypes_for_zarr_append(store, dataset): +def _validate_datatypes_for_zarr_append(zstore, dataset): """If variable exists in the store, confirm dtype of the data to append is compatible with existing dtype. """ - existing_ds = backends.zarr.open_zarr(store) + existing_vars = zstore.get_variables() def check_dtype(vname, var): if ( - vname not in existing_ds.data_vars + vname not in existing_vars or np.issubdtype(var.dtype, np.number) or np.issubdtype(var.dtype, np.datetime64) or np.issubdtype(var.dtype, np.bool_) @@ -1294,10 +1294,10 @@ def check_dtype(vname, var): or var.dtype == object ): pass - elif not var.dtype == existing_ds[vname].dtype: + elif not var.dtype == existing_vars[vname].dtype: raise ValueError( f"Mismatched dtypes for variable {vname} between Zarr store on disk " - f"and dataset to append. Store has dtype {existing_ds[vname].dtype} but " + f"and dataset to append. Store has dtype {existing_vars[vname].dtype} but " f"dataset to append has dtype {var.dtype}." ) @@ -1407,7 +1407,7 @@ def to_zarr( ) if mode in ["a", "r+"]: - _validate_datatypes_for_zarr_append(store, dataset) + _validate_datatypes_for_zarr_append(zstore, dataset) if append_dim is not None: existing_dims = zstore.get_dimensions() if append_dim not in existing_dims: From 3e945ef062873c47c53b759d2d4dd57dc07ecd46 Mon Sep 17 00:00:00 2001 From: Charles Stern <62192187+cisaacstern@users.noreply.github.com> Date: Tue, 12 Apr 2022 10:19:45 -0700 Subject: [PATCH 3/6] remove coding.strings.is_unicode_dtype check --- xarray/backends/api.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 00106fa4d98..2029f787555 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -17,7 +17,7 @@ import numpy as np -from .. import backends, coding, conventions +from .. import backends, conventions from ..core import indexing from ..core.combine import ( _infer_concat_order_from_positions, @@ -1290,7 +1290,6 @@ def check_dtype(vname, var): or np.issubdtype(var.dtype, np.number) or np.issubdtype(var.dtype, np.datetime64) or np.issubdtype(var.dtype, np.bool_) - or coding.strings.is_unicode_dtype(var.dtype) or var.dtype == object ): pass From 9b1186102e9a49878cebae4f85277defe3b6c777 Mon Sep 17 00:00:00 2001 From: Charles Stern <62192187+cisaacstern@users.noreply.github.com> Date: Tue, 12 Apr 2022 13:17:52 -0700 Subject: [PATCH 4/6] test appending fixed length strings --- xarray/tests/test_backends.py | 6 +++--- xarray/tests/test_dataset.py | 18 +++++++++++++----- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 12139fdb308..c4d85a28750 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -82,7 +82,7 @@ _STANDARD_CALENDARS, ) from .test_dataset import ( - create_append_mismatch_test_data, + create_append_string_length_mismatch_test_data, create_append_test_data, create_test_data, ) @@ -2115,8 +2115,8 @@ def test_append_with_existing_encoding_raises(self): encoding={"da": {"compressor": None}}, ) - def test_append_dtype_mismatch_raises(self): - ds, ds_to_append = create_append_mismatch_test_data() + def test_append_string_length_mismatch_raises(self): + ds, ds_to_append = create_append_string_length_mismatch_test_data() with self.create_zarr_target() as store_target: ds.to_zarr(store_target, mode="w") with pytest.raises(ValueError, match="Mismatched dtypes for variable"): diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index b6bc91c2720..3e9c4052631 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -76,6 +76,8 @@ def create_append_test_data(seed=None): time2 = pd.date_range("2000-02-01", periods=nt2) string_var = np.array(["ae", "bc", "df"], dtype=object) string_var_to_append = np.array(["asdf", "asdfg"], dtype=object) + string_var_fixed_length = np.array(["aa", "bb", "cc"], dtype="|S2") + string_var_fixed_length_to_append = np.array(["dd", "ee"], dtype="|S2") unicode_var = ["áó", "áó", "áó"] datetime_var = np.array( ["2019-01-01", "2019-01-02", "2019-01-03"], dtype="datetime64[s]" @@ -94,6 +96,9 @@ def create_append_test_data(seed=None): dims=["lat", "lon", "time"], ), "string_var": xr.DataArray(string_var, coords=[time1], dims=["time"]), + "string_var_fixed_length": xr.DataArray( + string_var_fixed_length, coords=[time1], dims=["time"] + ), "unicode_var": xr.DataArray( unicode_var, coords=[time1], dims=["time"] ).astype(np.unicode_), @@ -112,6 +117,9 @@ def create_append_test_data(seed=None): "string_var": xr.DataArray( string_var_to_append, coords=[time2], dims=["time"] ), + "string_var_fixed_length": xr.DataArray( + string_var_fixed_length_to_append, coords=[time2], dims=["time"] + ), "unicode_var": xr.DataArray( unicode_var[:nt2], coords=[time2], dims=["time"] ).astype(np.unicode_), @@ -137,16 +145,16 @@ def create_append_test_data(seed=None): return ds, ds_to_append, ds_with_new_var -def create_append_mismatch_test_data(): - fixed_length_strings = ["ab", "cd", "ef"] - variable_length_strings = ["abc", "def", "ghijk"] +def create_append_string_length_mismatch_test_data(): + u2_strings = ["ab", "cd", "ef"] + u5_strings = ["abc", "def", "ghijk"] ds = xr.Dataset( - {"temperature": (["time"], fixed_length_strings)}, + {"temperature": (["time"], u2_strings)}, coords={"time": [0, 1, 2]}, ) ds_to_append = xr.Dataset( - {"temperature": (["time"], variable_length_strings)}, coords={"time": [0, 1, 2]} + {"temperature": (["time"], u5_strings)}, coords={"time": [0, 1, 2]} ) assert all(objp.data.flags.writeable for objp in ds.variables.values()) assert all(objp.data.flags.writeable for objp in ds_to_append.variables.values()) From 4abcaa8ed83350c26c52aa3b8982779e106aaf8e Mon Sep 17 00:00:00 2001 From: Charles Stern <62192187+cisaacstern@users.noreply.github.com> Date: Tue, 12 Apr 2022 14:18:43 -0700 Subject: [PATCH 5/6] test string length mismatch raises for U and S --- xarray/tests/test_backends.py | 5 +++-- xarray/tests/test_dataset.py | 33 ++++++++++++++++++++++----------- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index c4d85a28750..fc1978ddb98 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2115,8 +2115,9 @@ def test_append_with_existing_encoding_raises(self): encoding={"da": {"compressor": None}}, ) - def test_append_string_length_mismatch_raises(self): - ds, ds_to_append = create_append_string_length_mismatch_test_data() + @pytest.mark.parametrize("dtype", ["U", "S"]) + def test_append_string_length_mismatch_raises(self, dtype): + ds, ds_to_append = create_append_string_length_mismatch_test_data(dtype) with self.create_zarr_target() as store_target: ds.to_zarr(store_target, mode="w") with pytest.raises(ValueError, match="Mismatched dtypes for variable"): diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 3e9c4052631..ea732a4b851 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -145,20 +145,31 @@ def create_append_test_data(seed=None): return ds, ds_to_append, ds_with_new_var -def create_append_string_length_mismatch_test_data(): +def create_append_string_length_mismatch_test_data(dtype): + def make_datasets(data, data_to_append): + ds = xr.Dataset( + {"temperature": (["time"], data)}, + coords={"time": [0, 1, 2]}, + ) + ds_to_append = xr.Dataset( + {"temperature": (["time"], data_to_append)}, coords={"time": [0, 1, 2]} + ) + assert all(objp.data.flags.writeable for objp in ds.variables.values()) + assert all( + objp.data.flags.writeable for objp in ds_to_append.variables.values() + ) + return ds, ds_to_append + u2_strings = ["ab", "cd", "ef"] u5_strings = ["abc", "def", "ghijk"] - ds = xr.Dataset( - {"temperature": (["time"], u2_strings)}, - coords={"time": [0, 1, 2]}, - ) - ds_to_append = xr.Dataset( - {"temperature": (["time"], u5_strings)}, coords={"time": [0, 1, 2]} - ) - assert all(objp.data.flags.writeable for objp in ds.variables.values()) - assert all(objp.data.flags.writeable for objp in ds_to_append.variables.values()) - return ds, ds_to_append + s2_strings = np.array(["aa", "bb", "cc"], dtype="|S2") + s3_strings = np.array(["aaa", "bbb", "ccc"], dtype="|S3") + + if dtype == "U": + return make_datasets(u2_strings, u5_strings) + elif dtype == "S": + return make_datasets(s2_strings, s3_strings) def create_test_multiindex(): From 9f97f00636d09454d6b7d2e13fa0775ed7770769 Mon Sep 17 00:00:00 2001 From: Charles Stern <62192187+cisaacstern@users.noreply.github.com> Date: Tue, 10 May 2022 15:39:01 -0700 Subject: [PATCH 6/6] add explanatory comment for zarr append dtype checks --- xarray/backends/api.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 02d23cf15fe..05aa5d04deb 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1292,6 +1292,15 @@ def check_dtype(vname, var): or np.issubdtype(var.dtype, np.bool_) or var.dtype == object ): + # We can skip dtype equality checks under two conditions: (1) if the var to append is + # new to the dataset, because in this case there is no existing var to compare it to; + # or (2) if var to append's dtype is known to be easy-to-append, because in this case + # we can be confident appending won't cause problems. Examples of dtypes which are not + # easy-to-append include length-specified strings of type `|S*` or `