From 2567f5393443cc446d8e55dfce52628a27969911 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 25 Mar 2020 22:21:56 -0700 Subject: [PATCH 1/3] expose a few zarr backend functions as semi-public api --- xarray/backends/zarr.py | 44 ++++++++++++++++++++++++---------- xarray/tests/test_backends.py | 45 +++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 12 deletions(-) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 2469a31a3d9..cdc74e06882 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -10,13 +10,20 @@ from .common import AbstractWritableDataStore, BackendArray, _encode_variable_name # need some special secret attributes to tell us the dimensions -_DIMENSION_KEY = "_ARRAY_DIMENSIONS" +DIMENSION_KEY = "_ARRAY_DIMENSIONS" -# zarr attributes have to be serializable as json -# many xarray datasets / variables have numpy arrays and values -# these functions handle encoding / decoding of such items -def _encode_zarr_attr_value(value): +def encode_zarr_attr_value(value): + """ + Encode an attribute value as something that can be serialized as json + + Many xarray datasets / variables have numpy arrays and values. This + function handles encoding / decoding of such items. 
+ + ndarray -> list + scalar array -> scalar + other -> other (no change) + """ if isinstance(value, np.ndarray): encoded = value.tolist() # this checks if it's a scalar number @@ -170,7 +177,20 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key): return dimensions, attributes -def _extract_zarr_variable_encoding(variable, raise_on_invalid=False): +def extract_zarr_variable_encoding(variable, raise_on_invalid=False): + """ + Extract zarr encoding dictionary from xarray Variable + + Parameters + ---------- + variable : xarray.Variable + raise_on_invalid : bool, optional + + Returns + ------- + encoding : dict + Zarr encoding for `variable` + """ encoding = variable.encoding.copy() valid_encodings = {"chunks", "compressor", "filters", "cache_metadata"} @@ -271,7 +291,7 @@ def __init__(self, zarr_group, consolidate_on_close=False): def open_store_variable(self, name, zarr_array): data = indexing.LazilyOuterIndexedArray(ZarrArrayWrapper(name, self)) - dimensions, attributes = _get_zarr_dims_and_attrs(zarr_array, _DIMENSION_KEY) + dimensions, attributes = _get_zarr_dims_and_attrs(zarr_array, DIMENSION_KEY) attributes = dict(attributes) encoding = { "chunks": zarr_array.chunks, @@ -298,7 +318,7 @@ def get_dimensions(self): dimensions = {} for k, v in self.ds.arrays(): try: - for d, s in zip(v.attrs[_DIMENSION_KEY], v.shape): + for d, s in zip(v.attrs[DIMENSION_KEY], v.shape): if d in dimensions and dimensions[d] != s: raise ValueError( "found conflicting lengths for dimension %s " @@ -310,7 +330,7 @@ def get_dimensions(self): raise KeyError( "Zarr object is missing the attribute `%s`, " "which is required for xarray to determine " - "variable dimensions." % (_DIMENSION_KEY) + "variable dimensions." 
% (DIMENSION_KEY) ) return dimensions @@ -328,7 +348,7 @@ def encode_variable(self, variable): return variable def encode_attribute(self, a): - return _encode_zarr_attr_value(a) + return encode_zarr_attr_value(a) def store( self, @@ -433,10 +453,10 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No writer.add(v.data, zarr_array, region=tuple(new_region)) else: # new variable - encoding = _extract_zarr_variable_encoding(v, raise_on_invalid=check) + encoding = extract_zarr_variable_encoding(v, raise_on_invalid=check) encoded_attrs = {} # the magic for storing the hidden dimension data - encoded_attrs[_DIMENSION_KEY] = dims + encoded_attrs[DIMENSION_KEY] = dims for k2, v2 in attrs.items(): encoded_attrs[k2] = self.encode_attribute(v2) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index a4585985bdc..a2956f8ad99 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -4498,3 +4498,48 @@ def test_invalid_netcdf_raises(engine): data = create_test_data() with raises_regex(ValueError, "unrecognized option 'invalid_netcdf'"): data.to_netcdf("foo.nc", engine=engine, invalid_netcdf=True) + + +@requires_zarr +def test_encode_zarr_attr_value(): + # array -> list + arr = np.array([1, 2, 3]) + expected = [1, 2, 3] + actual = backends.zarr.encode_zarr_attr_value(arr) + assert isinstance(actual, list) + assert actual == expected + + # scalar array -> scalar + sarr = np.array(1)[()] + expected = 1 + actual = backends.zarr.encode_zarr_attr_value(sarr) + assert isinstance(actual, int) + assert actual == expected + + # string -> string (no change) + expected = 'foo' + actual = backends.zarr.encode_zarr_attr_value(expected) + assert isinstance(actual, str) + assert actual == expected + + +@requires_zarr +def test_extract_zarr_variable_encoding(): + + var = xr.Variable('x', [1, 2]) + actual = backends.zarr.extract_zarr_variable_encoding(var) + assert 'chunks' in actual + assert actual['chunks'] is None 
+ + var = xr.Variable('x', [1, 2], encoding={'chunks': (1, )}) + actual = backends.zarr.extract_zarr_variable_encoding(var) + assert actual['chunks'] is (1, ) + + # does not raise on invalid + var = xr.Variable('x', [1, 2], encoding={'foo': (1, )}) + actual = backends.zarr.extract_zarr_variable_encoding(var) + + # raises on invalid + var = xr.Variable('x', [1, 2], encoding={'foo': (1, )}) + with raises_regex(ValueError, "unexpected encoding parameters"): + actual = backends.zarr.extract_zarr_variable_encoding(var, raise_on_invalid=True) From 4657673bc9a749afbbad8d9ecb49000433df5c11 Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 25 Mar 2020 22:24:05 -0700 Subject: [PATCH 2/3] black --- xarray/tests/test_backends.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index a2956f8ad99..63b2cb2ff00 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -4517,7 +4517,7 @@ def test_encode_zarr_attr_value(): assert actual == expected # string -> string (no change) - expected = 'foo' + expected = "foo" actual = backends.zarr.encode_zarr_attr_value(expected) assert isinstance(actual, str) assert actual == expected @@ -4526,20 +4526,22 @@ def test_encode_zarr_attr_value(): @requires_zarr def test_extract_zarr_variable_encoding(): - var = xr.Variable('x', [1, 2]) + var = xr.Variable("x", [1, 2]) actual = backends.zarr.extract_zarr_variable_encoding(var) - assert 'chunks' in actual - assert actual['chunks'] is None + assert "chunks" in actual + assert actual["chunks"] is None - var = xr.Variable('x', [1, 2], encoding={'chunks': (1, )}) + var = xr.Variable("x", [1, 2], encoding={"chunks": (1,)}) actual = backends.zarr.extract_zarr_variable_encoding(var) - assert actual['chunks'] is (1, ) + assert actual["chunks"] is (1,) # does not raise on invalid - var = xr.Variable('x', [1, 2], encoding={'foo': (1, )}) + var = xr.Variable("x", [1, 2], 
encoding={"foo": (1,)}) actual = backends.zarr.extract_zarr_variable_encoding(var) # raises on invalid - var = xr.Variable('x', [1, 2], encoding={'foo': (1, )}) + var = xr.Variable("x", [1, 2], encoding={"foo": (1,)}) with raises_regex(ValueError, "unexpected encoding parameters"): - actual = backends.zarr.extract_zarr_variable_encoding(var, raise_on_invalid=True) + actual = backends.zarr.extract_zarr_variable_encoding( + var, raise_on_invalid=True + ) From 3d36382753d8d9c106a50db6e9d1fe0ff4d6832d Mon Sep 17 00:00:00 2001 From: Joseph Hamman Date: Wed, 25 Mar 2020 22:56:40 -0700 Subject: [PATCH 3/3] update equality check for chunks --- xarray/tests/test_backends.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 63b2cb2ff00..82fe1b38149 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -4533,7 +4533,7 @@ def test_extract_zarr_variable_encoding(): var = xr.Variable("x", [1, 2], encoding={"chunks": (1,)}) actual = backends.zarr.extract_zarr_variable_encoding(var) - assert actual["chunks"] is (1,) + assert actual["chunks"] == (1,) # does not raise on invalid var = xr.Variable("x", [1, 2], encoding={"foo": (1,)})