Skip to content

Commit 566e524

Browse files
committed
Add codecs, shards to array strategy
1 parent feeb08f commit 566e524

File tree

1 file changed

+61
-22
lines changed

1 file changed

+61
-22
lines changed

src/zarr/testing/strategies.py

Lines changed: 61 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sys
2-
from typing import Any
2+
from typing import Any, Literal
33

44
import hypothesis.extra.numpy as npst
55
import hypothesis.strategies as st
@@ -86,11 +86,30 @@ def safe_unicode_for_dtype(dtype: np.dtype[np.str_]) -> st.SearchStrategy[str]:
8686
# i.e. stores.examples() will always return the same object per Store class.
8787
# So we map a clear to reset the store.
8888
stores = st.builds(MemoryStore, st.just({})).map(lambda x: sync(x.clear()))
89-
compressors = st.sampled_from([None, "default"])
9089
zarr_formats: st.SearchStrategy[ZarrFormat] = st.sampled_from([2, 3])
9190
array_shapes = npst.array_shapes(max_dims=4, min_side=0)
9291

9392

93+
@st.composite  # type: ignore[misc]
def codecs(
    draw: st.DrawFn,
    *,
    zarr_formats: st.SearchStrategy[Literal[2, 3]] = zarr_formats,
    dtypes: st.SearchStrategy[np.dtype] | None = None,
) -> Any:
    """Draw a dict of codec-related keyword arguments.

    The returned dict has two keys, ``"filters"`` and ``"compressors"``.

    Parameters
    ----------
    zarr_formats
        Strategy producing the zarr format (2 or 3) the codecs are drawn for.
    dtypes
        Currently unused by this strategy.

    Returns
    -------
    dict
        ``{"filters": ..., "compressors": ...}``. ``filters`` is ``None`` or
        ``()``. ``compressors`` is ``None`` or ``()`` for format 2; for
        format 3 it may additionally be a ``ZstdCodec`` with a level in 0..9.
    """
    fmt = draw(zarr_formats)

    # "no codec" can be spelled either as None or as an empty tuple.
    unset = st.none() | st.just(())
    zstd = st.one_of(
        st.builds(zarr.codecs.ZstdCodec, level=st.integers(min_value=0, max_value=9)),
        # TODO: other codecs
    )

    # Draw order matters for reproducibility: filters first, then compressors.
    filters = draw(unset)
    # Only format 3 gets an actual codec instance as a compressor here.
    compressor_strategy = unset if fmt == 2 else unset | zstd
    return {"filters": filters, "compressors": draw(compressor_strategy)}
111+
112+
94113
@st.composite # type: ignore[misc]
95114
def numpy_arrays(
96115
draw: st.DrawFn,
@@ -110,6 +129,32 @@ def numpy_arrays(
110129
return draw(npst.arrays(dtype=dtype, shape=shapes))
111130

112131

132+
@st.composite  # type: ignore[misc]
def chunk_shapes(draw: st.DrawFn, *, shape: tuple[int, ...]) -> tuple[int, ...]:
    """Draw a chunk shape for an array of the given ``shape``.

    A number of chunks is drawn per axis and the chunk length is derived
    from it by floor division, so that shrinking the drawn integers shrinks
    towards arrays with fewer chunks.

    Parameters
    ----------
    shape
        Shape of the array to be chunked.

    Returns
    -------
    tuple of int
        One chunk length per axis; ``0`` for any axis whose drawn chunk
        count is ``0`` (only possible for zero-length axes).
    """
    # st.integers() shrinks towards smaller values, so draw the *number* of
    # chunks per axis. A zero-length axis can only have zero chunks.
    count_strategies = [
        st.integers(min_value=1 if extent != 0 else 0, max_value=extent) for extent in shape
    ]
    counts = draw(st.tuples(*count_strategies))

    # Turn each chunk count into a chunk length along that axis.
    lengths = []
    for extent, count in zip(shape, counts, strict=True):
        if count > 0:
            lengths.append(extent // count)
        else:
            lengths.append(0)
    return tuple(lengths)
144+
145+
146+
@st.composite  # type: ignore[misc]
def shard_shapes(
    draw: st.DrawFn, *, shape: tuple[int, ...], chunk_shape: tuple[int, ...]
) -> tuple[int, ...]:
    """Draw a shard shape that is a whole multiple of ``chunk_shape``.

    For each axis an integer multiplier between 1 and the number of whole
    chunks fitting in that axis is drawn, so shrinking moves towards arrays
    with a smaller number of shards.

    Parameters
    ----------
    shape
        Shape of the array.
    chunk_shape
        Chunk shape of the array; no entry may be zero.

    Returns
    -------
    tuple of int
        One shard length per axis, each an integer multiple of the
        corresponding chunk length.
    """
    # Zero-length chunks would make the floor division below undefined.
    assert not any(c == 0 for c in chunk_shape)

    # Number of whole chunks that fit along each axis.
    chunks_per_axis = [
        extent // chunk for extent, chunk in zip(shape, chunk_shape, strict=True)
    ]

    # Shards must be an integral number of chunks: draw the multiplier per axis.
    factors = []
    for limit in chunks_per_axis:
        factors.append(draw(st.integers(min_value=1, max_value=limit)))

    return tuple(factor * chunk for factor, chunk in zip(factors, chunk_shape, strict=True))
156+
157+
113158
@st.composite # type: ignore[misc]
114159
def np_array_and_chunks(
115160
draw: st.DrawFn, *, arrays: st.SearchStrategy[np.ndarray] = numpy_arrays
@@ -119,32 +164,20 @@ def np_array_and_chunks(
119164
Returns: a tuple of the array and a suitable random chunking for it.
120165
"""
121166
array = draw(arrays)
122-
# We want this strategy to shrink towards arrays with smaller number of chunks
123-
# 1. st.integers() shrinks towards smaller values. So we use that to generate number of chunks
124-
numchunks = draw(
125-
st.tuples(
126-
*[st.integers(min_value=0 if size == 0 else 1, max_value=size) for size in array.shape]
127-
)
128-
)
129-
# 2. and now generate the chunks tuple
130-
chunks = tuple(
131-
size // nchunks if nchunks > 0 else 0
132-
for size, nchunks in zip(array.shape, numchunks, strict=True)
133-
)
134-
return (array, chunks)
167+
return (array, draw(chunk_shapes, array.shape))
135168

136169

137170
@st.composite # type: ignore[misc]
138171
def arrays(
139172
draw: st.DrawFn,
140173
*,
141174
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
142-
compressors: st.SearchStrategy = compressors,
143175
stores: st.SearchStrategy[StoreLike] = stores,
144176
paths: st.SearchStrategy[str | None] = paths,
145177
array_names: st.SearchStrategy = array_names,
146178
arrays: st.SearchStrategy | None = None,
147179
attrs: st.SearchStrategy = attrs,
180+
codecs: st.SearchStrategy = codecs,
148181
zarr_formats: st.SearchStrategy = zarr_formats,
149182
) -> Array:
150183
store = draw(stores)
@@ -154,24 +187,29 @@ def arrays(
154187
zarr_format = draw(zarr_formats)
155188
if arrays is None:
156189
arrays = numpy_arrays(shapes=shapes, zarr_formats=st.just(zarr_format))
157-
nparray, chunks = draw(np_array_and_chunks(arrays=arrays))
190+
nparray = draw(arrays)
191+
chunk_shape = draw(chunk_shapes(shape=nparray.shape))
192+
if zarr_format == 3 and all(c > 0 for c in chunk_shape):
193+
shard_shape = draw(st.none() | shard_shapes(shape=nparray.shape, chunk_shape=chunk_shape))
194+
else:
195+
shard_shape = None
158196
# test that None works too.
159197
fill_value = draw(st.one_of([st.none(), npst.from_dtype(nparray.dtype)]))
160-
# compressor = draw(compressors)
161198

162199
expected_attrs = {} if attributes is None else attributes
163200

164201
array_path = _dereference_path(path, name)
165202
root = zarr.open_group(store, mode="w", zarr_format=zarr_format)
166-
203+
codec_kwargs = draw(codecs(zarr_formats=st.just(zarr_format), dtypes=st.just(nparray.dtype)))
167204
a = root.create_array(
168205
array_path,
169206
shape=nparray.shape,
170-
chunks=chunks,
207+
chunks=chunk_shape,
208+
shards=shard_shape,
171209
dtype=nparray.dtype,
172210
attributes=attributes,
173-
# compressor=compressor, # FIXME
174211
fill_value=fill_value,
212+
**codec_kwargs,
175213
)
176214

177215
assert isinstance(a, Array)
@@ -180,7 +218,8 @@ def arrays(
180218
assert a.name is not None
181219
assert isinstance(root[array_path], Array)
182220
assert nparray.shape == a.shape
183-
assert chunks == a.chunks
221+
assert chunk_shape == a.chunks
222+
assert shard_shape == a.shards
184223
assert array_path == a.path, (path, name, array_path, a.name, a.path)
185224
assert a.basename == name, (a.basename, name)
186225
assert dict(a.attrs) == expected_attrs

0 commit comments

Comments
 (0)