diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 2469a31a3d9..cdc74e06882 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -10,13 +10,20 @@ from .common import AbstractWritableDataStore, BackendArray, _encode_variable_name # need some special secret attributes to tell us the dimensions -_DIMENSION_KEY = "_ARRAY_DIMENSIONS" +DIMENSION_KEY = "_ARRAY_DIMENSIONS" -# zarr attributes have to be serializable as json -# many xarray datasets / variables have numpy arrays and values -# these functions handle encoding / decoding of such items -def _encode_zarr_attr_value(value): +def encode_zarr_attr_value(value): + """ + Encode an attribute value as something that can be serialized as json + + Many xarray datasets / variables have numpy arrays and values. This + function handles encoding / decoding of such items. + + ndarray -> list + scalar array -> scalar + other -> other (no change) + """ if isinstance(value, np.ndarray): encoded = value.tolist() # this checks if it's a scalar number @@ -170,7 +177,20 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key): return dimensions, attributes -def _extract_zarr_variable_encoding(variable, raise_on_invalid=False): +def extract_zarr_variable_encoding(variable, raise_on_invalid=False): + """ + Extract zarr encoding dictionary from xarray Variable + + Parameters + ---------- + variable : xarray.Variable + raise_on_invalid : bool, optional + + Returns + ------- + encoding : dict + Zarr encoding for `variable` + """ encoding = variable.encoding.copy() valid_encodings = {"chunks", "compressor", "filters", "cache_metadata"} @@ -271,7 +291,7 @@ def __init__(self, zarr_group, consolidate_on_close=False): def open_store_variable(self, name, zarr_array): data = indexing.LazilyOuterIndexedArray(ZarrArrayWrapper(name, self)) - dimensions, attributes = _get_zarr_dims_and_attrs(zarr_array, _DIMENSION_KEY) + dimensions, attributes = _get_zarr_dims_and_attrs(zarr_array, DIMENSION_KEY) attributes = 
dict(attributes) encoding = { "chunks": zarr_array.chunks, @@ -298,7 +318,7 @@ def get_dimensions(self): dimensions = {} for k, v in self.ds.arrays(): try: - for d, s in zip(v.attrs[_DIMENSION_KEY], v.shape): + for d, s in zip(v.attrs[DIMENSION_KEY], v.shape): if d in dimensions and dimensions[d] != s: raise ValueError( "found conflicting lengths for dimension %s " @@ -310,7 +330,7 @@ def get_dimensions(self): raise KeyError( "Zarr object is missing the attribute `%s`, " "which is required for xarray to determine " - "variable dimensions." % (_DIMENSION_KEY) + "variable dimensions." % (DIMENSION_KEY) ) return dimensions @@ -328,7 +348,7 @@ def encode_variable(self, variable): return variable def encode_attribute(self, a): - return _encode_zarr_attr_value(a) + return encode_zarr_attr_value(a) def store( self, @@ -433,10 +453,10 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No writer.add(v.data, zarr_array, region=tuple(new_region)) else: # new variable - encoding = _extract_zarr_variable_encoding(v, raise_on_invalid=check) + encoding = extract_zarr_variable_encoding(v, raise_on_invalid=check) encoded_attrs = {} # the magic for storing the hidden dimension data - encoded_attrs[_DIMENSION_KEY] = dims + encoded_attrs[DIMENSION_KEY] = dims for k2, v2 in attrs.items(): encoded_attrs[k2] = self.encode_attribute(v2) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index a4585985bdc..82fe1b38149 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -4498,3 +4498,50 @@ def test_invalid_netcdf_raises(engine): data = create_test_data() with raises_regex(ValueError, "unrecognized option 'invalid_netcdf'"): data.to_netcdf("foo.nc", engine=engine, invalid_netcdf=True) + + +@requires_zarr +def test_encode_zarr_attr_value(): + # array -> list + arr = np.array([1, 2, 3]) + expected = [1, 2, 3] + actual = backends.zarr.encode_zarr_attr_value(arr) + assert isinstance(actual, list) + assert actual 
== expected + + # scalar array -> scalar + sarr = np.array(1)[()] + expected = 1 + actual = backends.zarr.encode_zarr_attr_value(sarr) + assert isinstance(actual, int) + assert actual == expected + + # string -> string (no change) + expected = "foo" + actual = backends.zarr.encode_zarr_attr_value(expected) + assert isinstance(actual, str) + assert actual == expected + + +@requires_zarr +def test_extract_zarr_variable_encoding(): + + var = xr.Variable("x", [1, 2]) + actual = backends.zarr.extract_zarr_variable_encoding(var) + assert "chunks" in actual + assert actual["chunks"] is None + + var = xr.Variable("x", [1, 2], encoding={"chunks": (1,)}) + actual = backends.zarr.extract_zarr_variable_encoding(var) + assert actual["chunks"] == (1,) + + # does not raise on invalid + var = xr.Variable("x", [1, 2], encoding={"foo": (1,)}) + actual = backends.zarr.extract_zarr_variable_encoding(var) + + # raises on invalid + var = xr.Variable("x", [1, 2], encoding={"foo": (1,)}) + with raises_regex(ValueError, "unexpected encoding parameters"): + actual = backends.zarr.extract_zarr_variable_encoding( + var, raise_on_invalid=True + )