Skip to content

Commit ceb3b36

Browse files
jhammand-v-b
andauthored
test: check that store, array, and group classes are serializable (#2006)
* test: check that store, array, and group classes are serializable w/ pickle and can be dependably roundtripped * raise if MemoryStore is pickled * Apply suggestions from code review Co-authored-by: Davis Bennett <[email protected]> * fix typos * new buffer __eq__ * pickle support for zip store --------- Co-authored-by: Davis Bennett <[email protected]>
1 parent f1bd703 commit ceb3b36

File tree

10 files changed

+127
-194
lines changed

10 files changed

+127
-194
lines changed

src/zarr/abc/store.py

+5
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ def _check_writable(self) -> None:
8383
if self.mode.readonly:
8484
raise ValueError("store mode does not support writing")
8585

86+
@abstractmethod
87+
def __eq__(self, value: object) -> bool:
88+
"""Equality comparison."""
89+
...
90+
8691
@abstractmethod
8792
async def get(
8893
self,

src/zarr/core/buffer/core.py

+6
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,12 @@ def __add__(self, other: Buffer) -> Self:
281281
"""Concatenate two buffers"""
282282
...
283283

284+
def __eq__(self, other: object) -> bool:
285+
# Another Buffer class can override this to choose a more efficient path
286+
return isinstance(other, Buffer) and np.array_equal(
287+
self.as_numpy_array(), other.as_numpy_array()
288+
)
289+
284290

285291
class NDBuffer:
286292
"""An n-dimensional memory block

src/zarr/store/memory.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from collections.abc import AsyncGenerator, MutableMapping
4+
from typing import TYPE_CHECKING, Any
45

56
from zarr.abc.store import Store
67
from zarr.core.buffer import Buffer, gpu
@@ -47,6 +48,19 @@ def __str__(self) -> str:
4748
def __repr__(self) -> str:
4849
return f"MemoryStore({str(self)!r})"
4950

51+
def __eq__(self, other: object) -> bool:
52+
return (
53+
isinstance(other, type(self))
54+
and self._store_dict == other._store_dict
55+
and self.mode == other.mode
56+
)
57+
58+
def __setstate__(self, state: Any) -> None:
59+
raise NotImplementedError(f"{type(self)} cannot be pickled")
60+
61+
def __getstate__(self) -> None:
62+
raise NotImplementedError(f"{type(self)} cannot be pickled")
63+
5064
async def get(
5165
self,
5266
key: str,

src/zarr/store/remote.py

+10
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(
5151
this must not be used.
5252
"""
5353
super().__init__(mode=mode)
54+
self._storage_options = storage_options
5455
if isinstance(url, str):
5556
self._url = url.rstrip("/")
5657
self._fs, _path = fsspec.url_to_fs(url, **storage_options)
@@ -91,6 +92,15 @@ def __str__(self) -> str:
9192
def __repr__(self) -> str:
9293
return f"<RemoteStore({type(self._fs).__name__}, {self.path})>"
9394

95+
def __eq__(self, other: object) -> bool:
96+
return (
97+
isinstance(other, type(self))
98+
and self.path == other.path
99+
and self.mode == other.mode
100+
and self._url == other._url
101+
# and self._storage_options == other._storage_options # FIXME: this isn't working for some reason
102+
)
103+
94104
async def get(
95105
self,
96106
key: str,

src/zarr/store/zip.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
import zipfile
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, Literal
8+
from typing import TYPE_CHECKING, Any, Literal
99

1010
from zarr.abc.store import Store
1111
from zarr.core.buffer import Buffer, BufferPrototype
@@ -68,7 +68,7 @@ def __init__(
6868
self.compression = compression
6969
self.allowZip64 = allowZip64
7070

71-
async def _open(self) -> None:
71+
def _sync_open(self) -> None:
7272
if self._is_open:
7373
raise ValueError("store is already open")
7474

@@ -83,6 +83,17 @@ async def _open(self) -> None:
8383

8484
self._is_open = True
8585

86+
async def _open(self) -> None:
87+
self._sync_open()
88+
89+
def __getstate__(self) -> tuple[Path, ZipStoreAccessModeLiteral, int, bool]:
90+
return self.path, self._zmode, self.compression, self.allowZip64
91+
92+
def __setstate__(self, state: Any) -> None:
93+
self.path, self._zmode, self.compression, self.allowZip64 = state
94+
self._is_open = False
95+
self._sync_open()
96+
8697
def close(self) -> None:
8798
super().close()
8899
with self._lock:

src/zarr/testing/store.py

+14
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pickle
12
from typing import Any, Generic, TypeVar
23

34
import pytest
@@ -48,6 +49,19 @@ def test_store_type(self, store: S) -> None:
4849
assert isinstance(store, Store)
4950
assert isinstance(store, self.store_cls)
5051

52+
def test_store_eq(self, store: S, store_kwargs: dict[str, Any]) -> None:
53+
# check self equality
54+
assert store == store
55+
56+
# check store equality with same inputs
57+
# asserting this is important for being able to compare (de)serialized stores
58+
store2 = self.store_cls(**store_kwargs)
59+
assert store == store2
60+
61+
def test_serizalizable_store(self, store: S) -> None:
62+
foo = pickle.dumps(store)
63+
assert pickle.loads(foo) == store
64+
5165
def test_store_mode(self, store: S, store_kwargs: dict[str, Any]) -> None:
5266
assert store.mode == AccessMode.from_literal("r+")
5367
assert not store.mode.readonly

tests/v3/test_array.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import pickle
12
from typing import Literal
23

34
import numpy as np
45
import pytest
56

6-
from zarr import Array, Group
7+
from zarr import Array, AsyncArray, Group
78
from zarr.core.common import ZarrFormat
89
from zarr.errors import ContainsArrayError, ContainsGroupError
910
from zarr.store import LocalStore, MemoryStore
@@ -135,3 +136,36 @@ def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str
135136

136137
assert arr.fill_value == np.dtype(dtype_str).type(fill_value)
137138
assert arr.fill_value.dtype == arr.dtype
139+
140+
141+
@pytest.mark.parametrize("store", ("local",), indirect=["store"])
142+
@pytest.mark.parametrize("zarr_format", (2, 3))
143+
async def test_serializable_async_array(
144+
store: LocalStore | MemoryStore, zarr_format: ZarrFormat
145+
) -> None:
146+
expected = await AsyncArray.create(
147+
store=store, shape=(100,), chunks=(10,), zarr_format=zarr_format, dtype="i4"
148+
)
149+
# await expected.setitems(list(range(100)))
150+
151+
p = pickle.dumps(expected)
152+
actual = pickle.loads(p)
153+
154+
assert actual == expected
155+
# np.testing.assert_array_equal(await actual.getitem(slice(None)), await expected.getitem(slice(None)))
156+
# TODO: uncomment the parts of this test that will be impacted by the config/prototype changes in flight
157+
158+
159+
@pytest.mark.parametrize("store", ("local",), indirect=["store"])
160+
@pytest.mark.parametrize("zarr_format", (2, 3))
161+
def test_serializable_sync_array(store: LocalStore, zarr_format: ZarrFormat) -> None:
162+
expected = Array.create(
163+
store=store, shape=(100,), chunks=(10,), zarr_format=zarr_format, dtype="i4"
164+
)
165+
expected[:] = list(range(100))
166+
167+
p = pickle.dumps(expected)
168+
actual = pickle.loads(p)
169+
170+
assert actual == expected
171+
np.testing.assert_array_equal(actual[:], expected[:])

tests/v3/test_group.py

+17-148
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
from __future__ import annotations
22

3+
import pickle
34
from typing import TYPE_CHECKING, Any, Literal, cast
45

56
import numpy as np
67
import pytest
78

8-
import zarr.api.asynchronous
99
from zarr import Array, AsyncArray, AsyncGroup, Group
1010
from zarr.abc.store import Store
11-
from zarr.api.synchronous import open_group
1211
from zarr.core.buffer import default_buffer_prototype
1312
from zarr.core.common import JSON, ZarrFormat
1413
from zarr.core.group import GroupMetadata
1514
from zarr.core.sync import sync
1615
from zarr.errors import ContainsArrayError, ContainsGroupError
17-
from zarr.store import LocalStore, MemoryStore, StorePath
16+
from zarr.store import LocalStore, StorePath
1817
from zarr.store.common import make_store_path
1918

2019
from .conftest import parse_store
@@ -681,152 +680,22 @@ async def test_asyncgroup_update_attributes(store: Store, zarr_format: ZarrForma
681680
assert agroup_new_attributes.attrs == attributes_new
682681

683682

684-
async def test_group_members_async(store: LocalStore | MemoryStore) -> None:
685-
group = AsyncGroup(
686-
GroupMetadata(),
687-
store_path=StorePath(store=store, path="root"),
688-
)
689-
a0 = await group.create_array("a0", shape=(1,))
690-
g0 = await group.create_group("g0")
691-
a1 = await g0.create_array("a1", shape=(1,))
692-
g1 = await g0.create_group("g1")
693-
a2 = await g1.create_array("a2", shape=(1,))
694-
g2 = await g1.create_group("g2")
695-
696-
# immediate children
697-
children = sorted([x async for x in group.members()], key=lambda x: x[0])
698-
assert children == [
699-
("a0", a0),
700-
("g0", g0),
701-
]
702-
703-
nmembers = await group.nmembers()
704-
assert nmembers == 2
705-
706-
# partial
707-
children = sorted([x async for x in group.members(max_depth=1)], key=lambda x: x[0])
708-
expected = [
709-
("a0", a0),
710-
("g0", g0),
711-
("g0/a1", a1),
712-
("g0/g1", g1),
713-
]
714-
assert children == expected
715-
nmembers = await group.nmembers(max_depth=1)
716-
assert nmembers == 4
717-
718-
# all children
719-
all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0])
720-
expected = [
721-
("a0", a0),
722-
("g0", g0),
723-
("g0/a1", a1),
724-
("g0/g1", g1),
725-
("g0/g1/a2", a2),
726-
("g0/g1/g2", g2),
727-
]
728-
assert all_children == expected
729-
730-
nmembers = await group.nmembers(max_depth=None)
731-
assert nmembers == 6
732-
733-
with pytest.raises(ValueError, match="max_depth"):
734-
[x async for x in group.members(max_depth=-1)]
735-
736-
737-
async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
738-
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
739-
740-
# create foo group
741-
_ = await root.create_group("foo", attributes={"foo": 100})
742-
743-
# test that we can get the group using require_group
744-
foo_group = await root.require_group("foo")
745-
assert foo_group.attrs == {"foo": 100}
746-
747-
# test that we can get the group using require_group and overwrite=True
748-
foo_group = await root.require_group("foo", overwrite=True)
749-
750-
_ = await foo_group.create_array(
751-
"bar", shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100}
683+
@pytest.mark.parametrize("store", ("local",), indirect=["store"])
684+
@pytest.mark.parametrize("zarr_format", (2, 3))
685+
async def test_serializable_async_group(store: LocalStore, zarr_format: ZarrFormat) -> None:
686+
expected = await AsyncGroup.create(
687+
store=store, attributes={"foo": 999}, zarr_format=zarr_format
752688
)
689+
p = pickle.dumps(expected)
690+
actual = pickle.loads(p)
691+
assert actual == expected
753692

754-
# test that overwriting a group w/ children fails
755-
# TODO: figure out why ensure_no_existing_node is not catching the foo.bar array
756-
#
757-
# with pytest.raises(ContainsArrayError):
758-
# await root.require_group("foo", overwrite=True)
759-
760-
# test that requiring a group where an array is fails
761-
with pytest.raises(TypeError):
762-
await foo_group.require_group("bar")
763-
764-
765-
async def test_require_groups(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
766-
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
767-
# create foo group
768-
_ = await root.create_group("foo", attributes={"foo": 100})
769-
# create bar group
770-
_ = await root.create_group("bar", attributes={"bar": 200})
771-
772-
foo_group, bar_group = await root.require_groups("foo", "bar")
773-
assert foo_group.attrs == {"foo": 100}
774-
assert bar_group.attrs == {"bar": 200}
775-
776-
# get a mix of existing and new groups
777-
foo_group, spam_group = await root.require_groups("foo", "spam")
778-
assert foo_group.attrs == {"foo": 100}
779-
assert spam_group.attrs == {}
780-
781-
# no names
782-
no_group = await root.require_groups()
783-
assert no_group == ()
784-
785-
786-
async def test_create_dataset(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
787-
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
788-
with pytest.warns(DeprecationWarning):
789-
foo = await root.create_dataset("foo", shape=(10,), dtype="uint8")
790-
assert foo.shape == (10,)
791-
792-
with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning):
793-
await root.create_dataset("foo", shape=(100,), dtype="int8")
794-
795-
_ = await root.create_group("bar")
796-
with pytest.raises(ContainsGroupError), pytest.warns(DeprecationWarning):
797-
await root.create_dataset("bar", shape=(100,), dtype="int8")
798-
799-
800-
async def test_require_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
801-
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
802-
foo1 = await root.require_array("foo", shape=(10,), dtype="i8", attributes={"foo": 101})
803-
assert foo1.attrs == {"foo": 101}
804-
foo2 = await root.require_array("foo", shape=(10,), dtype="i8")
805-
assert foo2.attrs == {"foo": 101}
806-
807-
# exact = False
808-
_ = await root.require_array("foo", shape=10, dtype="f8")
809-
810-
# errors w/ exact True
811-
with pytest.raises(TypeError, match="Incompatible dtype"):
812-
await root.require_array("foo", shape=(10,), dtype="f8", exact=True)
813-
814-
with pytest.raises(TypeError, match="Incompatible shape"):
815-
await root.require_array("foo", shape=(100, 100), dtype="i8")
816-
817-
with pytest.raises(TypeError, match="Incompatible dtype"):
818-
await root.require_array("foo", shape=(10,), dtype="f4")
819-
820-
_ = await root.create_group("bar")
821-
with pytest.raises(TypeError, match="Incompatible object"):
822-
await root.require_array("bar", shape=(10,), dtype="int8")
823-
824-
825-
async def test_open_mutable_mapping():
826-
group = await zarr.api.asynchronous.open_group(store={}, mode="w")
827-
assert isinstance(group.store_path.store, MemoryStore)
828693

694+
@pytest.mark.parametrize("store", ("local",), indirect=["store"])
695+
@pytest.mark.parametrize("zarr_format", (2, 3))
696+
def test_serializable_sync_group(store: LocalStore, zarr_format: ZarrFormat) -> None:
697+
expected = Group.create(store=store, attributes={"foo": 999}, zarr_format=zarr_format)
698+
p = pickle.dumps(expected)
699+
actual = pickle.loads(p)
829700

830-
def test_open_mutable_mapping_sync():
831-
group = open_group(store={}, mode="w")
832-
assert isinstance(group.store_path.store, MemoryStore)
701+
assert actual == expected

0 commit comments

Comments
 (0)