|
6 | 6 | import numpy as np
|
7 | 7 | import pytest
|
8 | 8 |
|
| 9 | +import zarr |
9 | 10 | from zarr import Array, AsyncArray, AsyncGroup, Group
|
10 | 11 | from zarr.abc.store import Store
|
11 | 12 | from zarr.core.buffer import default_buffer_prototype
|
12 | 13 | from zarr.core.common import JSON, ZarrFormat
|
13 | 14 | from zarr.core.group import GroupMetadata
|
14 | 15 | from zarr.core.sync import sync
|
15 | 16 | from zarr.errors import ContainsArrayError, ContainsGroupError
|
16 |
| -from zarr.store import LocalStore, StorePath |
| 17 | +from zarr.store import LocalStore, MemoryStore, StorePath |
17 | 18 | from zarr.store.common import make_store_path
|
18 | 19 |
|
19 | 20 | from .conftest import parse_store
|
@@ -699,3 +700,154 @@ def test_serializable_sync_group(store: LocalStore, zarr_format: ZarrFormat) ->
|
699 | 700 | actual = pickle.loads(p)
|
700 | 701 |
|
701 | 702 | assert actual == expected
|
| 703 | + |
| 704 | + |
| 705 | +async def test_group_members_async(store: LocalStore | MemoryStore) -> None: |
| 706 | + group = AsyncGroup( |
| 707 | + GroupMetadata(), |
| 708 | + store_path=StorePath(store=store, path="root"), |
| 709 | + ) |
| 710 | + a0 = await group.create_array("a0", shape=(1,)) |
| 711 | + g0 = await group.create_group("g0") |
| 712 | + a1 = await g0.create_array("a1", shape=(1,)) |
| 713 | + g1 = await g0.create_group("g1") |
| 714 | + a2 = await g1.create_array("a2", shape=(1,)) |
| 715 | + g2 = await g1.create_group("g2") |
| 716 | + |
| 717 | + # immediate children |
| 718 | + children = sorted([x async for x in group.members()], key=lambda x: x[0]) |
| 719 | + assert children == [ |
| 720 | + ("a0", a0), |
| 721 | + ("g0", g0), |
| 722 | + ] |
| 723 | + |
| 724 | + nmembers = await group.nmembers() |
| 725 | + assert nmembers == 2 |
| 726 | + |
| 727 | + # partial |
| 728 | + children = sorted([x async for x in group.members(max_depth=1)], key=lambda x: x[0]) |
| 729 | + expected = [ |
| 730 | + ("a0", a0), |
| 731 | + ("g0", g0), |
| 732 | + ("g0/a1", a1), |
| 733 | + ("g0/g1", g1), |
| 734 | + ] |
| 735 | + assert children == expected |
| 736 | + nmembers = await group.nmembers(max_depth=1) |
| 737 | + assert nmembers == 4 |
| 738 | + |
| 739 | + # all children |
| 740 | + all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0]) |
| 741 | + expected = [ |
| 742 | + ("a0", a0), |
| 743 | + ("g0", g0), |
| 744 | + ("g0/a1", a1), |
| 745 | + ("g0/g1", g1), |
| 746 | + ("g0/g1/a2", a2), |
| 747 | + ("g0/g1/g2", g2), |
| 748 | + ] |
| 749 | + assert all_children == expected |
| 750 | + |
| 751 | + nmembers = await group.nmembers(max_depth=None) |
| 752 | + assert nmembers == 6 |
| 753 | + |
| 754 | + with pytest.raises(ValueError, match="max_depth"): |
| 755 | + [x async for x in group.members(max_depth=-1)] |
| 756 | + |
| 757 | + |
| 758 | +async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: |
| 759 | + root = await AsyncGroup.create(store=store, zarr_format=zarr_format) |
| 760 | + |
| 761 | + # create foo group |
| 762 | + _ = await root.create_group("foo", attributes={"foo": 100}) |
| 763 | + |
| 764 | + # test that we can get the group using require_group |
| 765 | + foo_group = await root.require_group("foo") |
| 766 | + assert foo_group.attrs == {"foo": 100} |
| 767 | + |
| 768 | + # test that we can get the group using require_group and overwrite=True |
| 769 | + foo_group = await root.require_group("foo", overwrite=True) |
| 770 | + |
| 771 | + _ = await foo_group.create_array( |
| 772 | + "bar", shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100} |
| 773 | + ) |
| 774 | + |
| 775 | + # test that overwriting a group w/ children fails |
| 776 | + # TODO: figure out why ensure_no_existing_node is not catching the foo.bar array |
| 777 | + # |
| 778 | + # with pytest.raises(ContainsArrayError): |
| 779 | + # await root.require_group("foo", overwrite=True) |
| 780 | + |
| 781 | + # test that requiring a group where an array is fails |
| 782 | + with pytest.raises(TypeError): |
| 783 | + await foo_group.require_group("bar") |
| 784 | + |
| 785 | + |
| 786 | +async def test_require_groups(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: |
| 787 | + root = await AsyncGroup.create(store=store, zarr_format=zarr_format) |
| 788 | + # create foo group |
| 789 | + _ = await root.create_group("foo", attributes={"foo": 100}) |
| 790 | + # create bar group |
| 791 | + _ = await root.create_group("bar", attributes={"bar": 200}) |
| 792 | + |
| 793 | + foo_group, bar_group = await root.require_groups("foo", "bar") |
| 794 | + assert foo_group.attrs == {"foo": 100} |
| 795 | + assert bar_group.attrs == {"bar": 200} |
| 796 | + |
| 797 | + # get a mix of existing and new groups |
| 798 | + foo_group, spam_group = await root.require_groups("foo", "spam") |
| 799 | + assert foo_group.attrs == {"foo": 100} |
| 800 | + assert spam_group.attrs == {} |
| 801 | + |
| 802 | + # no names |
| 803 | + no_group = await root.require_groups() |
| 804 | + assert no_group == () |
| 805 | + |
| 806 | + |
| 807 | +async def test_create_dataset(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: |
| 808 | + root = await AsyncGroup.create(store=store, zarr_format=zarr_format) |
| 809 | + with pytest.warns(DeprecationWarning): |
| 810 | + foo = await root.create_dataset("foo", shape=(10,), dtype="uint8") |
| 811 | + assert foo.shape == (10,) |
| 812 | + |
| 813 | + with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning): |
| 814 | + await root.create_dataset("foo", shape=(100,), dtype="int8") |
| 815 | + |
| 816 | + _ = await root.create_group("bar") |
| 817 | + with pytest.raises(ContainsGroupError), pytest.warns(DeprecationWarning): |
| 818 | + await root.create_dataset("bar", shape=(100,), dtype="int8") |
| 819 | + |
| 820 | + |
| 821 | +async def test_require_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: |
| 822 | + root = await AsyncGroup.create(store=store, zarr_format=zarr_format) |
| 823 | + foo1 = await root.require_array("foo", shape=(10,), dtype="i8", attributes={"foo": 101}) |
| 824 | + assert foo1.attrs == {"foo": 101} |
| 825 | + foo2 = await root.require_array("foo", shape=(10,), dtype="i8") |
| 826 | + assert foo2.attrs == {"foo": 101} |
| 827 | + |
| 828 | + # exact = False |
| 829 | + _ = await root.require_array("foo", shape=10, dtype="f8") |
| 830 | + |
| 831 | + # errors w/ exact True |
| 832 | + with pytest.raises(TypeError, match="Incompatible dtype"): |
| 833 | + await root.require_array("foo", shape=(10,), dtype="f8", exact=True) |
| 834 | + |
| 835 | + with pytest.raises(TypeError, match="Incompatible shape"): |
| 836 | + await root.require_array("foo", shape=(100, 100), dtype="i8") |
| 837 | + |
| 838 | + with pytest.raises(TypeError, match="Incompatible dtype"): |
| 839 | + await root.require_array("foo", shape=(10,), dtype="f4") |
| 840 | + |
| 841 | + _ = await root.create_group("bar") |
| 842 | + with pytest.raises(TypeError, match="Incompatible object"): |
| 843 | + await root.require_array("bar", shape=(10,), dtype="int8") |
| 844 | + |
| 845 | + |
| 846 | +async def test_open_mutable_mapping(): |
| 847 | + group = await zarr.api.asynchronous.open_group(store={}, mode="w") |
| 848 | + assert isinstance(group.store_path.store, MemoryStore) |
| 849 | + |
| 850 | + |
| 851 | +def test_open_mutable_mapping_sync(): |
| 852 | + group = zarr.open_group(store={}, mode="w") |
| 853 | + assert isinstance(group.store_path.store, MemoryStore) |
0 commit comments