-
-
Notifications
You must be signed in to change notification settings - Fork 330
test: check that store, array, and group classes are serializable #2006
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
17ed022
275d6fa
4e2fdbc
938a0bb
0408af6
1aab1fa
4fa132e
c834853
767df05
339ce1c
1e0c4a2
c80ee2b
43b482f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,19 @@ | ||
from __future__ import annotations | ||
|
||
import pickle | ||
from typing import TYPE_CHECKING, Any, Literal, cast | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
import zarr.api.asynchronous | ||
from zarr import Array, AsyncArray, AsyncGroup, Group | ||
from zarr.abc.store import Store | ||
from zarr.api.synchronous import open_group | ||
from zarr.core.buffer import default_buffer_prototype | ||
from zarr.core.common import JSON, ZarrFormat | ||
from zarr.core.group import GroupMetadata | ||
from zarr.core.sync import sync | ||
from zarr.errors import ContainsArrayError, ContainsGroupError | ||
from zarr.store import LocalStore, MemoryStore, StorePath | ||
from zarr.store import LocalStore, StorePath | ||
from zarr.store.common import make_store_path | ||
|
||
from .conftest import parse_store | ||
|
@@ -681,152 +680,22 @@ async def test_asyncgroup_update_attributes(store: Store, zarr_format: ZarrForma | |
assert agroup_new_attributes.attrs == attributes_new | ||
|
||
|
||
async def test_group_members_async(store: LocalStore | MemoryStore) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jhamman was this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not deliberately removed. Must have been a bad merge conflict resolution. I'll bring it back today. Sorry! |
||
group = AsyncGroup( | ||
GroupMetadata(), | ||
store_path=StorePath(store=store, path="root"), | ||
) | ||
a0 = await group.create_array("a0", shape=(1,)) | ||
g0 = await group.create_group("g0") | ||
a1 = await g0.create_array("a1", shape=(1,)) | ||
g1 = await g0.create_group("g1") | ||
a2 = await g1.create_array("a2", shape=(1,)) | ||
g2 = await g1.create_group("g2") | ||
|
||
# immediate children | ||
children = sorted([x async for x in group.members()], key=lambda x: x[0]) | ||
assert children == [ | ||
("a0", a0), | ||
("g0", g0), | ||
] | ||
|
||
nmembers = await group.nmembers() | ||
assert nmembers == 2 | ||
|
||
# partial | ||
children = sorted([x async for x in group.members(max_depth=1)], key=lambda x: x[0]) | ||
expected = [ | ||
("a0", a0), | ||
("g0", g0), | ||
("g0/a1", a1), | ||
("g0/g1", g1), | ||
] | ||
assert children == expected | ||
nmembers = await group.nmembers(max_depth=1) | ||
assert nmembers == 4 | ||
|
||
# all children | ||
all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0]) | ||
expected = [ | ||
("a0", a0), | ||
("g0", g0), | ||
("g0/a1", a1), | ||
("g0/g1", g1), | ||
("g0/g1/a2", a2), | ||
("g0/g1/g2", g2), | ||
] | ||
assert all_children == expected | ||
|
||
nmembers = await group.nmembers(max_depth=None) | ||
assert nmembers == 6 | ||
|
||
with pytest.raises(ValueError, match="max_depth"): | ||
[x async for x in group.members(max_depth=-1)] | ||
|
||
|
||
async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: | ||
root = await AsyncGroup.create(store=store, zarr_format=zarr_format) | ||
|
||
# create foo group | ||
_ = await root.create_group("foo", attributes={"foo": 100}) | ||
|
||
# test that we can get the group using require_group | ||
foo_group = await root.require_group("foo") | ||
assert foo_group.attrs == {"foo": 100} | ||
|
||
# test that we can get the group using require_group and overwrite=True | ||
foo_group = await root.require_group("foo", overwrite=True) | ||
|
||
_ = await foo_group.create_array( | ||
"bar", shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100} | ||
@pytest.mark.parametrize("store", ("local",), indirect=["store"]) | ||
@pytest.mark.parametrize("zarr_format", (2, 3)) | ||
async def test_serializable_async_group(store: LocalStore, zarr_format: ZarrFormat) -> None: | ||
expected = await AsyncGroup.create( | ||
store=store, attributes={"foo": 999}, zarr_format=zarr_format | ||
) | ||
p = pickle.dumps(expected) | ||
actual = pickle.loads(p) | ||
assert actual == expected | ||
|
||
# test that overwriting a group w/ children fails | ||
# TODO: figure out why ensure_no_existing_node is not catching the foo.bar array | ||
# | ||
# with pytest.raises(ContainsArrayError): | ||
# await root.require_group("foo", overwrite=True) | ||
|
||
# test that requiring a group where an array is fails | ||
with pytest.raises(TypeError): | ||
await foo_group.require_group("bar") | ||
|
||
|
||
async def test_require_groups(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: | ||
root = await AsyncGroup.create(store=store, zarr_format=zarr_format) | ||
# create foo group | ||
_ = await root.create_group("foo", attributes={"foo": 100}) | ||
# create bar group | ||
_ = await root.create_group("bar", attributes={"bar": 200}) | ||
|
||
foo_group, bar_group = await root.require_groups("foo", "bar") | ||
assert foo_group.attrs == {"foo": 100} | ||
assert bar_group.attrs == {"bar": 200} | ||
|
||
# get a mix of existing and new groups | ||
foo_group, spam_group = await root.require_groups("foo", "spam") | ||
assert foo_group.attrs == {"foo": 100} | ||
assert spam_group.attrs == {} | ||
|
||
# no names | ||
no_group = await root.require_groups() | ||
assert no_group == () | ||
|
||
|
||
async def test_create_dataset(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: | ||
root = await AsyncGroup.create(store=store, zarr_format=zarr_format) | ||
with pytest.warns(DeprecationWarning): | ||
foo = await root.create_dataset("foo", shape=(10,), dtype="uint8") | ||
assert foo.shape == (10,) | ||
|
||
with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning): | ||
await root.create_dataset("foo", shape=(100,), dtype="int8") | ||
|
||
_ = await root.create_group("bar") | ||
with pytest.raises(ContainsGroupError), pytest.warns(DeprecationWarning): | ||
await root.create_dataset("bar", shape=(100,), dtype="int8") | ||
|
||
|
||
async def test_require_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: | ||
root = await AsyncGroup.create(store=store, zarr_format=zarr_format) | ||
foo1 = await root.require_array("foo", shape=(10,), dtype="i8", attributes={"foo": 101}) | ||
assert foo1.attrs == {"foo": 101} | ||
foo2 = await root.require_array("foo", shape=(10,), dtype="i8") | ||
assert foo2.attrs == {"foo": 101} | ||
|
||
# exact = False | ||
_ = await root.require_array("foo", shape=10, dtype="f8") | ||
|
||
# errors w/ exact True | ||
with pytest.raises(TypeError, match="Incompatible dtype"): | ||
await root.require_array("foo", shape=(10,), dtype="f8", exact=True) | ||
|
||
with pytest.raises(TypeError, match="Incompatible shape"): | ||
await root.require_array("foo", shape=(100, 100), dtype="i8") | ||
|
||
with pytest.raises(TypeError, match="Incompatible dtype"): | ||
await root.require_array("foo", shape=(10,), dtype="f4") | ||
|
||
_ = await root.create_group("bar") | ||
with pytest.raises(TypeError, match="Incompatible object"): | ||
await root.require_array("bar", shape=(10,), dtype="int8") | ||
|
||
|
||
async def test_open_mutable_mapping(): | ||
group = await zarr.api.asynchronous.open_group(store={}, mode="w") | ||
assert isinstance(group.store_path.store, MemoryStore) | ||
|
||
@pytest.mark.parametrize("store", ("local",), indirect=["store"]) | ||
@pytest.mark.parametrize("zarr_format", (2, 3)) | ||
def test_serializable_sync_group(store: LocalStore, zarr_format: ZarrFormat) -> None: | ||
expected = Group.create(store=store, attributes={"foo": 999}, zarr_format=zarr_format) | ||
p = pickle.dumps(expected) | ||
actual = pickle.loads(p) | ||
|
||
def test_open_mutable_mapping_sync(): | ||
group = open_group(store={}, mode="w") | ||
assert isinstance(group.store_path.store, MemoryStore) | ||
assert actual == expected |
Uh oh!
There was an error while loading. Please reload this page.