|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import pickle |
3 | 4 | from typing import TYPE_CHECKING, Any, Literal, cast
|
4 | 5 |
|
5 | 6 | import numpy as np
|
6 | 7 | import pytest
|
7 | 8 |
|
8 |
| -import zarr.api.asynchronous |
9 | 9 | from zarr import Array, AsyncArray, AsyncGroup, Group
|
10 | 10 | from zarr.abc.store import Store
|
11 |
| -from zarr.api.synchronous import open_group |
12 | 11 | from zarr.core.buffer import default_buffer_prototype
|
13 | 12 | from zarr.core.common import JSON, ZarrFormat
|
14 | 13 | from zarr.core.group import GroupMetadata
|
15 | 14 | from zarr.core.sync import sync
|
16 | 15 | from zarr.errors import ContainsArrayError, ContainsGroupError
|
17 |
| -from zarr.store import LocalStore, MemoryStore, StorePath |
| 16 | +from zarr.store import LocalStore, StorePath |
18 | 17 | from zarr.store.common import make_store_path
|
19 | 18 |
|
20 | 19 | from .conftest import parse_store
|
@@ -681,152 +680,22 @@ async def test_asyncgroup_update_attributes(store: Store, zarr_format: ZarrForma
|
681 | 680 | assert agroup_new_attributes.attrs == attributes_new
|
682 | 681 |
|
683 | 682 |
|
684 |
| -async def test_group_members_async(store: LocalStore | MemoryStore) -> None: |
685 |
| - group = AsyncGroup( |
686 |
| - GroupMetadata(), |
687 |
| - store_path=StorePath(store=store, path="root"), |
688 |
| - ) |
689 |
| - a0 = await group.create_array("a0", shape=(1,)) |
690 |
| - g0 = await group.create_group("g0") |
691 |
| - a1 = await g0.create_array("a1", shape=(1,)) |
692 |
| - g1 = await g0.create_group("g1") |
693 |
| - a2 = await g1.create_array("a2", shape=(1,)) |
694 |
| - g2 = await g1.create_group("g2") |
695 |
| - |
696 |
| - # immediate children |
697 |
| - children = sorted([x async for x in group.members()], key=lambda x: x[0]) |
698 |
| - assert children == [ |
699 |
| - ("a0", a0), |
700 |
| - ("g0", g0), |
701 |
| - ] |
702 |
| - |
703 |
| - nmembers = await group.nmembers() |
704 |
| - assert nmembers == 2 |
705 |
| - |
706 |
| - # partial |
707 |
| - children = sorted([x async for x in group.members(max_depth=1)], key=lambda x: x[0]) |
708 |
| - expected = [ |
709 |
| - ("a0", a0), |
710 |
| - ("g0", g0), |
711 |
| - ("g0/a1", a1), |
712 |
| - ("g0/g1", g1), |
713 |
| - ] |
714 |
| - assert children == expected |
715 |
| - nmembers = await group.nmembers(max_depth=1) |
716 |
| - assert nmembers == 4 |
717 |
| - |
718 |
| - # all children |
719 |
| - all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0]) |
720 |
| - expected = [ |
721 |
| - ("a0", a0), |
722 |
| - ("g0", g0), |
723 |
| - ("g0/a1", a1), |
724 |
| - ("g0/g1", g1), |
725 |
| - ("g0/g1/a2", a2), |
726 |
| - ("g0/g1/g2", g2), |
727 |
| - ] |
728 |
| - assert all_children == expected |
729 |
| - |
730 |
| - nmembers = await group.nmembers(max_depth=None) |
731 |
| - assert nmembers == 6 |
732 |
| - |
733 |
| - with pytest.raises(ValueError, match="max_depth"): |
734 |
| - [x async for x in group.members(max_depth=-1)] |
735 |
| - |
736 |
| - |
737 |
| -async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: |
738 |
| - root = await AsyncGroup.create(store=store, zarr_format=zarr_format) |
739 |
| - |
740 |
| - # create foo group |
741 |
| - _ = await root.create_group("foo", attributes={"foo": 100}) |
742 |
| - |
743 |
| - # test that we can get the group using require_group |
744 |
| - foo_group = await root.require_group("foo") |
745 |
| - assert foo_group.attrs == {"foo": 100} |
746 |
| - |
747 |
| - # test that we can get the group using require_group and overwrite=True |
748 |
| - foo_group = await root.require_group("foo", overwrite=True) |
749 |
| - |
750 |
| - _ = await foo_group.create_array( |
751 |
| - "bar", shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100} |
| 683 | +@pytest.mark.parametrize("store", ("local",), indirect=["store"]) |
| 684 | +@pytest.mark.parametrize("zarr_format", (2, 3)) |
| 685 | +async def test_serializable_async_group(store: LocalStore, zarr_format: ZarrFormat) -> None: |
| 686 | + expected = await AsyncGroup.create( |
| 687 | + store=store, attributes={"foo": 999}, zarr_format=zarr_format |
752 | 688 | )
|
| 689 | + p = pickle.dumps(expected) |
| 690 | + actual = pickle.loads(p) |
| 691 | + assert actual == expected |
753 | 692 |
|
754 |
| - # test that overwriting a group w/ children fails |
755 |
| - # TODO: figure out why ensure_no_existing_node is not catching the foo.bar array |
756 |
| - # |
757 |
| - # with pytest.raises(ContainsArrayError): |
758 |
| - # await root.require_group("foo", overwrite=True) |
759 |
| - |
760 |
| - # test that requiring a group where an array is fails |
761 |
| - with pytest.raises(TypeError): |
762 |
| - await foo_group.require_group("bar") |
763 |
| - |
764 |
| - |
765 |
| -async def test_require_groups(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: |
766 |
| - root = await AsyncGroup.create(store=store, zarr_format=zarr_format) |
767 |
| - # create foo group |
768 |
| - _ = await root.create_group("foo", attributes={"foo": 100}) |
769 |
| - # create bar group |
770 |
| - _ = await root.create_group("bar", attributes={"bar": 200}) |
771 |
| - |
772 |
| - foo_group, bar_group = await root.require_groups("foo", "bar") |
773 |
| - assert foo_group.attrs == {"foo": 100} |
774 |
| - assert bar_group.attrs == {"bar": 200} |
775 |
| - |
776 |
| - # get a mix of existing and new groups |
777 |
| - foo_group, spam_group = await root.require_groups("foo", "spam") |
778 |
| - assert foo_group.attrs == {"foo": 100} |
779 |
| - assert spam_group.attrs == {} |
780 |
| - |
781 |
| - # no names |
782 |
| - no_group = await root.require_groups() |
783 |
| - assert no_group == () |
784 |
| - |
785 |
| - |
786 |
| -async def test_create_dataset(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: |
787 |
| - root = await AsyncGroup.create(store=store, zarr_format=zarr_format) |
788 |
| - with pytest.warns(DeprecationWarning): |
789 |
| - foo = await root.create_dataset("foo", shape=(10,), dtype="uint8") |
790 |
| - assert foo.shape == (10,) |
791 |
| - |
792 |
| - with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning): |
793 |
| - await root.create_dataset("foo", shape=(100,), dtype="int8") |
794 |
| - |
795 |
| - _ = await root.create_group("bar") |
796 |
| - with pytest.raises(ContainsGroupError), pytest.warns(DeprecationWarning): |
797 |
| - await root.create_dataset("bar", shape=(100,), dtype="int8") |
798 |
| - |
799 |
| - |
800 |
| -async def test_require_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: |
801 |
| - root = await AsyncGroup.create(store=store, zarr_format=zarr_format) |
802 |
| - foo1 = await root.require_array("foo", shape=(10,), dtype="i8", attributes={"foo": 101}) |
803 |
| - assert foo1.attrs == {"foo": 101} |
804 |
| - foo2 = await root.require_array("foo", shape=(10,), dtype="i8") |
805 |
| - assert foo2.attrs == {"foo": 101} |
806 |
| - |
807 |
| - # exact = False |
808 |
| - _ = await root.require_array("foo", shape=10, dtype="f8") |
809 |
| - |
810 |
| - # errors w/ exact True |
811 |
| - with pytest.raises(TypeError, match="Incompatible dtype"): |
812 |
| - await root.require_array("foo", shape=(10,), dtype="f8", exact=True) |
813 |
| - |
814 |
| - with pytest.raises(TypeError, match="Incompatible shape"): |
815 |
| - await root.require_array("foo", shape=(100, 100), dtype="i8") |
816 |
| - |
817 |
| - with pytest.raises(TypeError, match="Incompatible dtype"): |
818 |
| - await root.require_array("foo", shape=(10,), dtype="f4") |
819 |
| - |
820 |
| - _ = await root.create_group("bar") |
821 |
| - with pytest.raises(TypeError, match="Incompatible object"): |
822 |
| - await root.require_array("bar", shape=(10,), dtype="int8") |
823 |
| - |
824 |
| - |
825 |
| -async def test_open_mutable_mapping(): |
826 |
| - group = await zarr.api.asynchronous.open_group(store={}, mode="w") |
827 |
| - assert isinstance(group.store_path.store, MemoryStore) |
828 | 693 |
|
| 694 | +@pytest.mark.parametrize("store", ("local",), indirect=["store"]) |
| 695 | +@pytest.mark.parametrize("zarr_format", (2, 3)) |
| 696 | +def test_serializable_sync_group(store: LocalStore, zarr_format: ZarrFormat) -> None: |
| 697 | + expected = Group.create(store=store, attributes={"foo": 999}, zarr_format=zarr_format) |
| 698 | + p = pickle.dumps(expected) |
| 699 | + actual = pickle.loads(p) |
829 | 700 |
|
830 |
| -def test_open_mutable_mapping_sync(): |
831 |
| - group = open_group(store={}, mode="w") |
832 |
| - assert isinstance(group.store_path.store, MemoryStore) |
| 701 | + assert actual == expected |
0 commit comments