Skip to content

Commit 6b6cc3a

Browse files
committed
fix: replace tests that went missing in zarr-developers#2006
1 parent b1ecdd5 commit 6b6cc3a

File tree

2 files changed

+191
-3
lines changed

2 files changed

+191
-3
lines changed

tests/v3/test_group.py

+153-1
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
import numpy as np
77
import pytest
88

9+
import zarr
910
from zarr import Array, AsyncArray, AsyncGroup, Group
1011
from zarr.abc.store import Store
1112
from zarr.core.buffer import default_buffer_prototype
1213
from zarr.core.common import JSON, ZarrFormat
1314
from zarr.core.group import GroupMetadata
1415
from zarr.core.sync import sync
1516
from zarr.errors import ContainsArrayError, ContainsGroupError
16-
from zarr.store import LocalStore, StorePath
17+
from zarr.store import LocalStore, MemoryStore, StorePath
1718
from zarr.store.common import make_store_path
1819

1920
from .conftest import parse_store
@@ -699,3 +700,154 @@ def test_serializable_sync_group(store: LocalStore, zarr_format: ZarrFormat) ->
699700
actual = pickle.loads(p)
700701

701702
assert actual == expected
703+
704+
705+
async def test_group_members_async(store: LocalStore | MemoryStore) -> None:
706+
group = AsyncGroup(
707+
GroupMetadata(),
708+
store_path=StorePath(store=store, path="root"),
709+
)
710+
a0 = await group.create_array("a0", shape=(1,))
711+
g0 = await group.create_group("g0")
712+
a1 = await g0.create_array("a1", shape=(1,))
713+
g1 = await g0.create_group("g1")
714+
a2 = await g1.create_array("a2", shape=(1,))
715+
g2 = await g1.create_group("g2")
716+
717+
# immediate children
718+
children = sorted([x async for x in group.members()], key=lambda x: x[0])
719+
assert children == [
720+
("a0", a0),
721+
("g0", g0),
722+
]
723+
724+
nmembers = await group.nmembers()
725+
assert nmembers == 2
726+
727+
# partial
728+
children = sorted([x async for x in group.members(max_depth=1)], key=lambda x: x[0])
729+
expected = [
730+
("a0", a0),
731+
("g0", g0),
732+
("g0/a1", a1),
733+
("g0/g1", g1),
734+
]
735+
assert children == expected
736+
nmembers = await group.nmembers(max_depth=1)
737+
assert nmembers == 4
738+
739+
# all children
740+
all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0])
741+
expected = [
742+
("a0", a0),
743+
("g0", g0),
744+
("g0/a1", a1),
745+
("g0/g1", g1),
746+
("g0/g1/a2", a2),
747+
("g0/g1/g2", g2),
748+
]
749+
assert all_children == expected
750+
751+
nmembers = await group.nmembers(max_depth=None)
752+
assert nmembers == 6
753+
754+
with pytest.raises(ValueError, match="max_depth"):
755+
[x async for x in group.members(max_depth=-1)]
756+
757+
758+
async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
759+
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
760+
761+
# create foo group
762+
_ = await root.create_group("foo", attributes={"foo": 100})
763+
764+
# test that we can get the group using require_group
765+
foo_group = await root.require_group("foo")
766+
assert foo_group.attrs == {"foo": 100}
767+
768+
# test that we can get the group using require_group and overwrite=True
769+
foo_group = await root.require_group("foo", overwrite=True)
770+
771+
_ = await foo_group.create_array(
772+
"bar", shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100}
773+
)
774+
775+
# test that overwriting a group w/ children fails
776+
# TODO: figure out why ensure_no_existing_node is not catching the foo.bar array
777+
#
778+
# with pytest.raises(ContainsArrayError):
779+
# await root.require_group("foo", overwrite=True)
780+
781+
# test that requiring a group where an array is fails
782+
with pytest.raises(TypeError):
783+
await foo_group.require_group("bar")
784+
785+
786+
async def test_require_groups(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
787+
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
788+
# create foo group
789+
_ = await root.create_group("foo", attributes={"foo": 100})
790+
# create bar group
791+
_ = await root.create_group("bar", attributes={"bar": 200})
792+
793+
foo_group, bar_group = await root.require_groups("foo", "bar")
794+
assert foo_group.attrs == {"foo": 100}
795+
assert bar_group.attrs == {"bar": 200}
796+
797+
# get a mix of existing and new groups
798+
foo_group, spam_group = await root.require_groups("foo", "spam")
799+
assert foo_group.attrs == {"foo": 100}
800+
assert spam_group.attrs == {}
801+
802+
# no names
803+
no_group = await root.require_groups()
804+
assert no_group == ()
805+
806+
807+
async def test_create_dataset(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
808+
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
809+
with pytest.warns(DeprecationWarning):
810+
foo = await root.create_dataset("foo", shape=(10,), dtype="uint8")
811+
assert foo.shape == (10,)
812+
813+
with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning):
814+
await root.create_dataset("foo", shape=(100,), dtype="int8")
815+
816+
_ = await root.create_group("bar")
817+
with pytest.raises(ContainsGroupError), pytest.warns(DeprecationWarning):
818+
await root.create_dataset("bar", shape=(100,), dtype="int8")
819+
820+
821+
async def test_require_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
822+
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
823+
foo1 = await root.require_array("foo", shape=(10,), dtype="i8", attributes={"foo": 101})
824+
assert foo1.attrs == {"foo": 101}
825+
foo2 = await root.require_array("foo", shape=(10,), dtype="i8")
826+
assert foo2.attrs == {"foo": 101}
827+
828+
# exact = False
829+
_ = await root.require_array("foo", shape=10, dtype="f8")
830+
831+
# errors w/ exact True
832+
with pytest.raises(TypeError, match="Incompatible dtype"):
833+
await root.require_array("foo", shape=(10,), dtype="f8", exact=True)
834+
835+
with pytest.raises(TypeError, match="Incompatible shape"):
836+
await root.require_array("foo", shape=(100, 100), dtype="i8")
837+
838+
with pytest.raises(TypeError, match="Incompatible dtype"):
839+
await root.require_array("foo", shape=(10,), dtype="f4")
840+
841+
_ = await root.create_group("bar")
842+
with pytest.raises(TypeError, match="Incompatible object"):
843+
await root.require_array("bar", shape=(10,), dtype="int8")
844+
845+
846+
async def test_open_mutable_mapping():
847+
group = await zarr.api.asynchronous.open_group(store={}, mode="w")
848+
assert isinstance(group.store_path.store, MemoryStore)
849+
850+
851+
def test_open_mutable_mapping_sync():
852+
group = zarr.open_group(store={}, mode="w")
853+
assert isinstance(group.store_path.store, MemoryStore)

tests/v3/test_store/test_memory.py

+38-2
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44

55
import pytest
66

7-
from zarr.core.buffer import Buffer, cpu
8-
from zarr.store.memory import MemoryStore
7+
from zarr.core.buffer import Buffer, cpu, gpu
8+
from zarr.store.memory import GpuMemoryStore, MemoryStore
99
from zarr.testing.store import StoreTests
10+
from zarr.testing.utils import gpu_test
1011

1112

1213
class TestMemoryStore(StoreTests[MemoryStore, cpu.Buffer]):
@@ -56,3 +57,38 @@ def test_serizalizable_store(self, store: MemoryStore) -> None:
5657

5758
with pytest.raises(NotImplementedError):
5859
pickle.dumps(store)
60+
61+
62+
@gpu_test
63+
class TestGpuMemoryStore(StoreTests[GpuMemoryStore, gpu.Buffer]):
64+
store_cls = GpuMemoryStore
65+
buffer_cls = gpu.Buffer
66+
67+
def set(self, store: GpuMemoryStore, key: str, value: Buffer) -> None:
68+
store._store_dict[key] = value
69+
70+
def get(self, store: MemoryStore, key: str) -> Buffer:
71+
return store._store_dict[key]
72+
73+
@pytest.fixture(scope="function", params=[None, {}])
74+
def store_kwargs(self, request) -> dict[str, str | None | dict[str, Buffer]]:
75+
return {"store_dict": request.param, "mode": "r+"}
76+
77+
@pytest.fixture(scope="function")
78+
def store(self, store_kwargs: str | None | dict[str, gpu.Buffer]) -> GpuMemoryStore:
79+
return self.store_cls(**store_kwargs)
80+
81+
def test_store_repr(self, store: GpuMemoryStore) -> None:
82+
assert str(store) == f"gpumemory://{id(store._store_dict)}"
83+
84+
def test_store_supports_writes(self, store: GpuMemoryStore) -> None:
85+
assert store.supports_writes
86+
87+
def test_store_supports_listing(self, store: GpuMemoryStore) -> None:
88+
assert store.supports_listing
89+
90+
def test_store_supports_partial_writes(self, store: GpuMemoryStore) -> None:
91+
assert store.supports_partial_writes
92+
93+
def test_list_prefix(self, store: GpuMemoryStore) -> None:
94+
assert True

0 commit comments

Comments
 (0)