Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 60 additions & 7 deletions src/zarr/testing/stateful.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import builtins
from typing import Any
import functools
from collections.abc import Callable
from typing import Any, TypeVar, cast

import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
Expand All @@ -24,15 +26,43 @@
from zarr.testing.strategies import (
basic_indices,
chunk_paths,
dimension_names,
key_ranges,
node_names,
np_array_and_chunks,
numpy_arrays,
orthogonal_indices,
)
from zarr.testing.strategies import keys as zarr_keys

MAX_BINARY_SIZE = 100

F = TypeVar("F", bound=Callable[..., Any])


def with_frequency(frequency: float) -> Callable[[F], F]:
    """Gate a state-machine rule so it fires only a fraction of the time.

    Rather than random sampling, a per-function call counter drives the gate
    deterministically — this needs to be deterministic so hypothesis can
    replay failing examples reliably.

    Parameters
    ----------
    frequency:
        Target fraction of attempts allowed through (e.g. ``0.25`` lets
        roughly one in four invocations run).
    """

    def decorator(func: F) -> F:
        # Counter attribute is namespaced by the wrapped function's name so
        # multiple decorated rules on the same machine don't collide.
        counter_attr = f"__{func.__name__}_counter"

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            return func(*args, **kwargs)

        @precondition
        def frequency_check(f: Any) -> Any:
            # ``f`` is presumably the RuleBasedStateMachine instance that
            # hypothesis passes to precondition predicates — TODO confirm.
            # Lazily create the counter on first use.
            if not hasattr(f, counter_attr):
                setattr(f, counter_attr, 0)

            current_count = getattr(f, counter_attr) + 1
            setattr(f, counter_attr, current_count)

            # Deterministic gate: ``(count * frequency) mod 1`` sweeps [0, 1)
            # evenly, so the comparison passes ~``frequency`` of the time.
            return (current_count * frequency) % 1.0 >= (1.0 - frequency)

        # After ``@precondition``, ``frequency_check`` is itself a decorator;
        # applying it to ``wrapper`` attaches the gate to the rule.
        return cast(F, frequency_check(wrapper))

    return decorator


def split_prefix_name(path: str) -> tuple[str, str]:
split = path.rsplit("/", maxsplit=1)
Expand Down Expand Up @@ -90,11 +120,7 @@ def add_group(self, name: str, data: DataObject) -> None:
zarr.group(store=self.store, path=path)
zarr.group(store=self.model, path=path)

@rule(
data=st.data(),
name=node_names,
array_and_chunks=np_array_and_chunks(arrays=numpy_arrays(zarr_formats=st.just(3))),
)
@rule(data=st.data(), name=node_names, array_and_chunks=np_array_and_chunks())
def add_array(
self,
data: DataObject,
Expand Down Expand Up @@ -122,12 +148,17 @@ def add_array(
path=path,
store=store,
fill_value=fill_value,
zarr_format=3,
dimension_names=data.draw(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixes an oversight where we were never setting dimension names on these arrays

dimension_names(ndim=array.ndim), label="dimension names"
),
# Chose bytes codec to avoid wasting time compressing the data being written
codecs=[BytesCodec()],
)
self.all_arrays.add(path)

@rule()
@with_frequency(0.25)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the logs it seems like Hypothesis was clearing the store quite frequently, so I'm reducing the frequency of this rule here.

def clear(self) -> None:
note("clearing")
import zarr
Expand Down Expand Up @@ -192,6 +223,14 @@ def delete_chunk(self, data: DataObject) -> None:
self._sync(self.model.delete(path))
self._sync(self.store.delete(path))

@precondition(lambda self: bool(self.all_arrays))
@rule(data=st.data())
def check_array(self, data: DataObject) -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new check that asserts the model and the tested store have the same array data

path = data.draw(st.sampled_from(sorted(self.all_arrays)))
actual = zarr.open_array(self.store, path=path)[:]
expected = zarr.open_array(self.model, path=path)[:]
np.testing.assert_equal(actual, expected)

@precondition(lambda self: bool(self.all_arrays))
@rule(data=st.data())
def overwrite_array_basic_indexing(self, data: DataObject) -> None:
Expand All @@ -206,6 +245,20 @@ def overwrite_array_basic_indexing(self, data: DataObject) -> None:
model_array[slicer] = new_data
store_array[slicer] = new_data

@precondition(lambda self: bool(self.all_arrays))
@rule(data=st.data())
def overwrite_array_orthogonal_indexing(self, data: DataObject) -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adds a new step modeling a user overwriting existing array data with .oindex

array = data.draw(st.sampled_from(sorted(self.all_arrays)))
model_array = zarr.open_array(path=array, store=self.model)
store_array = zarr.open_array(path=array, store=self.store)
indexer, _ = data.draw(orthogonal_indices(shape=model_array.shape))
note(f"overwriting array orthogonal {indexer=}")
new_data = data.draw(
npst.arrays(shape=model_array.oindex[indexer].shape, dtype=model_array.dtype) # type: ignore[union-attr]
)
model_array.oindex[indexer] = new_data
store_array.oindex[indexer] = new_data

@precondition(lambda self: bool(self.all_arrays))
@rule(data=st.data())
def resize_array(self, data: DataObject) -> None:
Expand Down
68 changes: 37 additions & 31 deletions src/zarr/testing/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def paths(draw: st.DrawFn, *, max_num_nodes: int | None = None) -> str:
return draw(st.just("/") | keys(max_num_nodes=max_num_nodes))


def v3_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
def dtypes() -> st.SearchStrategy[np.dtype[Any]]:
return (
npst.boolean_dtypes()
| npst.integer_dtypes(endianness="=")
Expand All @@ -57,18 +57,12 @@ def v3_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
)


def v3_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
Copy link
Contributor Author

@dcherian dcherian Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are unnecessary now. I guess we could add a DeprecationWarning asking the user to use zarr.testing.strategies.dtypes() instead. I can do this next week.

return dtypes()


def v2_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
return (
npst.boolean_dtypes()
| npst.integer_dtypes(endianness="=")
| npst.unsigned_integer_dtypes(endianness="=")
| npst.floating_dtypes(endianness="=")
| npst.complex_number_dtypes(endianness="=")
| npst.byte_string_dtypes(endianness="=")
| npst.unicode_string_dtypes(endianness="=")
| npst.datetime64_dtypes(endianness="=")
| npst.timedelta64_dtypes(endianness="=")
)
return dtypes()


def safe_unicode_for_dtype(dtype: np.dtype[np.str_]) -> st.SearchStrategy[str]:
Expand Down Expand Up @@ -144,7 +138,7 @@ def array_metadata(
shape = draw(array_shapes())
ndim = len(shape)
chunk_shape = draw(array_shapes(min_dims=ndim, max_dims=ndim))
np_dtype = draw(v3_dtypes())
np_dtype = draw(dtypes())
dtype = get_data_type_from_native_dtype(np_dtype)
fill_value = draw(npst.from_dtype(np_dtype))
if zarr_format == 2:
Expand Down Expand Up @@ -179,14 +173,12 @@ def numpy_arrays(
*,
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
dtype: np.dtype[Any] | None = None,
zarr_formats: st.SearchStrategy[ZarrFormat] = zarr_formats,
) -> npt.NDArray[Any]:
"""
Generate numpy arrays that can be saved in the provided Zarr format.
"""
zarr_format = draw(zarr_formats)
if dtype is None:
dtype = draw(v3_dtypes() if zarr_format == 3 else v2_dtypes())
dtype = draw(dtypes())
if np.issubdtype(dtype, np.str_):
safe_unicode_strings = safe_unicode_for_dtype(dtype)
return draw(npst.arrays(dtype=dtype, shape=shapes, elements=safe_unicode_strings))
Expand Down Expand Up @@ -255,17 +247,24 @@ def arrays(
attrs: st.SearchStrategy = attrs,
zarr_formats: st.SearchStrategy = zarr_formats,
) -> Array:
store = draw(stores)
path = draw(paths)
name = draw(array_names)
attributes = draw(attrs)
zarr_format = draw(zarr_formats)
store = draw(stores, label="store")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just nicer logs.

path = draw(paths, label="array parent")
name = draw(array_names, label="array name")
attributes = draw(attrs, label="attributes")
zarr_format = draw(zarr_formats, label="zarr format")
if arrays is None:
arrays = numpy_arrays(shapes=shapes, zarr_formats=st.just(zarr_format))
nparray = draw(arrays)
chunk_shape = draw(chunk_shapes(shape=nparray.shape))
arrays = numpy_arrays(shapes=shapes)
nparray = draw(arrays, label="array data")
chunk_shape = draw(chunk_shapes(shape=nparray.shape), label="chunk shape")
extra_kwargs = {}
if zarr_format == 3 and all(c > 0 for c in chunk_shape):
shard_shape = draw(st.none() | shard_shapes(shape=nparray.shape, chunk_shape=chunk_shape))
shard_shape = draw(
st.none() | shard_shapes(shape=nparray.shape, chunk_shape=chunk_shape),
label="shard shape",
)
extra_kwargs["dimension_names"] = draw(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, I forgot to set dimension_names earlier.

dimension_names(ndim=nparray.ndim), label="dimension names"
)
else:
shard_shape = None
# test that None works too.
Expand All @@ -286,6 +285,7 @@ def arrays(
attributes=attributes,
# compressor=compressor, # FIXME
fill_value=fill_value,
**extra_kwargs,
)

assert isinstance(a, Array)
Expand Down Expand Up @@ -385,13 +385,19 @@ def orthogonal_indices(
npindexer = []
ndim = len(shape)
for axis, size in enumerate(shape):
val = draw(
Copy link
Contributor Author

@dcherian dcherian Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strategy now works for 0-D

npst.integer_array_indices(
if size != 0:
strategy = npst.integer_array_indices(
shape=(size,), result_shape=npst.array_shapes(min_side=1, max_side=size, max_dims=1)
)
| basic_indices(min_dims=1, shape=(size,), allow_ellipsis=False)
.map(lambda x: (x,) if not isinstance(x, tuple) else x) # bare ints, slices
.filter(bool) # skip empty tuple
) | basic_indices(min_dims=1, shape=(size,), allow_ellipsis=False)
else:
strategy = basic_indices(min_dims=1, shape=(size,), allow_ellipsis=False)

val = draw(
strategy
# bare ints, slices
.map(lambda x: (x,) if not isinstance(x, tuple) else x)
# skip empty tuple
.filter(bool)
)
(idxr,) = val
if isinstance(idxr, int):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ def deep_equal(a: Any, b: Any) -> bool:


@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@given(data=st.data(), zarr_format=zarr_formats)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed now that all dtypes are supported for both versions

def test_array_roundtrip(data: st.DataObject, zarr_format: int) -> None:
nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format)))
zarray = data.draw(arrays(arrays=st.just(nparray), zarr_formats=st.just(zarr_format)))
@given(data=st.data())
def test_array_roundtrip(data: st.DataObject) -> None:
nparray = data.draw(numpy_arrays())
zarray = data.draw(arrays(arrays=st.just(nparray)))
assert_array_equal(nparray, zarray[:])


Expand Down