Skip to content

Commit e3ad220

Browse files
committed
Add shards to array strategy
1 parent 870265a commit e3ad220

File tree

2 files changed

+38
-16
lines changed

2 files changed

+38
-16
lines changed

changes/2822.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add arbitrary `shards` to the Hypothesis strategy for generating arrays.

src/zarr/testing/strategies.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,32 @@ def numpy_arrays(
110110
return draw(npst.arrays(dtype=dtype, shape=shapes))
111111

112112

113+
@st.composite  # type: ignore[misc]
def chunk_shapes(draw: st.DrawFn, *, shape: tuple[int, ...]) -> tuple[int, ...]:
    """Draw a chunk shape for an array with the given ``shape``.

    The strategy draws a *number of chunks* per axis rather than a chunk
    extent directly. ``st.integers()`` shrinks towards smaller values, so
    failing examples shrink towards arrays split into fewer chunks.

    Parameters
    ----------
    draw
        The Hypothesis draw function supplied by ``st.composite``.
    shape
        Shape of the array to be chunked.

    Returns
    -------
    tuple[int, ...]
        One chunk extent per axis, computed as ``extent // number_of_chunks``.
        A zero-length axis yields a chunk extent of ``0``.
    """
    # Per axis: between 1 and `extent` chunks, except that a zero-length axis
    # can only have 0 chunks.
    count_strategies = [
        st.integers(min_value=1 if extent != 0 else 0, max_value=extent) for extent in shape
    ]
    counts = draw(st.tuples(*count_strategies))

    # Turn the chunk counts into chunk extents (floor division), guarding
    # against dividing by a zero count.
    extents = []
    for extent, count in zip(shape, counts, strict=True):
        extents.append(0 if count <= 0 else extent // count)
    return tuple(extents)
125+
126+
127+
@st.composite  # type: ignore[misc]
def shard_shapes(
    draw: st.DrawFn, *, shape: tuple[int, ...], chunk_shape: tuple[int, ...]
) -> tuple[int, ...]:
    """Draw a shard shape that is a whole multiple of ``chunk_shape`` on every axis.

    A per-axis multiplier is drawn with ``st.integers()``, which shrinks
    towards smaller values. The multiplier is bounded by the number of whole
    chunks that fit along that axis (``extent // chunk_extent``).

    Parameters
    ----------
    draw
        The Hypothesis draw function supplied by ``st.composite``.
    shape
        Shape of the array.
    chunk_shape
        Chunk extent per axis; no entry may be zero.

    Returns
    -------
    tuple[int, ...]
        One shard extent per axis, each equal to ``multiplier * chunk_extent``.
    """
    # Shards must be an integral number of chunks, so zero-sized chunks
    # cannot be sharded.
    assert not any(chunk_extent == 0 for chunk_extent in chunk_shape)

    # Number of whole chunks that fit along each axis.
    chunks_per_axis = [
        extent // chunk_extent for extent, chunk_extent in zip(shape, chunk_shape, strict=True)
    ]
    # Draw how many chunks each shard spans along each axis, in axis order.
    factors = [draw(st.integers(min_value=1, max_value=count)) for count in chunks_per_axis]
    return tuple(
        factor * chunk_extent for factor, chunk_extent in zip(factors, chunk_shape, strict=True)
    )
137+
138+
113139
@st.composite # type: ignore[misc]
114140
def np_array_and_chunks(
115141
draw: st.DrawFn, *, arrays: st.SearchStrategy[np.ndarray] = numpy_arrays
@@ -119,19 +145,7 @@ def np_array_and_chunks(
119145
Returns: a tuple of the array and a suitable random chunking for it.
120146
"""
121147
array = draw(arrays)
122-
# We want this strategy to shrink towards arrays with smaller number of chunks
123-
# 1. st.integers() shrinks towards smaller values. So we use that to generate number of chunks
124-
numchunks = draw(
125-
st.tuples(
126-
*[st.integers(min_value=0 if size == 0 else 1, max_value=size) for size in array.shape]
127-
)
128-
)
129-
# 2. and now generate the chunks tuple
130-
chunks = tuple(
131-
size // nchunks if nchunks > 0 else 0
132-
for size, nchunks in zip(array.shape, numchunks, strict=True)
133-
)
134-
return (array, chunks)
148+
return (array, draw(chunk_shapes(shape=array.shape)))
135149

136150

137151
@st.composite # type: ignore[misc]
@@ -154,7 +168,12 @@ def arrays(
154168
zarr_format = draw(zarr_formats)
155169
if arrays is None:
156170
arrays = numpy_arrays(shapes=shapes, zarr_formats=st.just(zarr_format))
157-
nparray, chunks = draw(np_array_and_chunks(arrays=arrays))
171+
nparray = draw(arrays)
172+
chunk_shape = draw(chunk_shapes(shape=nparray.shape))
173+
if zarr_format == 3 and all(c > 0 for c in chunk_shape):
174+
shard_shape = draw(st.none() | shard_shapes(shape=nparray.shape, chunk_shape=chunk_shape))
175+
else:
176+
shard_shape = None
158177
# test that None works too.
159178
fill_value = draw(st.one_of([st.none(), npst.from_dtype(nparray.dtype)]))
160179
# compressor = draw(compressors)
@@ -167,7 +186,8 @@ def arrays(
167186
a = root.create_array(
168187
array_path,
169188
shape=nparray.shape,
170-
chunks=chunks,
189+
chunks=chunk_shape,
190+
shards=shard_shape,
171191
dtype=nparray.dtype,
172192
attributes=attributes,
173193
# compressor=compressor, # FIXME
@@ -180,7 +200,8 @@ def arrays(
180200
assert a.name is not None
181201
assert isinstance(root[array_path], Array)
182202
assert nparray.shape == a.shape
183-
assert chunks == a.chunks
203+
assert chunk_shape == a.chunks
204+
assert shard_shape == a.shards
184205
assert array_path == a.path, (path, name, array_path, a.name, a.path)
185206
assert a.basename == name, (a.basename, name)
186207
assert dict(a.attrs) == expected_attrs

0 commit comments

Comments
 (0)