From 73d53d73dd1fc200cac54b1583828f4da0a354f9 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Sun, 25 Aug 2024 13:32:24 -0500 Subject: [PATCH 01/10] Fixed MemoryStore.list_dir Ensures that nested children are listed properly. --- src/zarr/store/memory.py | 14 +++++++++++--- src/zarr/testing/store.py | 7 +++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index 999d750755..d474f18b28 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -117,6 +117,14 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: for key in keys_unique: yield key else: - for key in self._store_dict: - if key.startswith(prefix + "/") and key != prefix: - yield key.removeprefix(prefix + "/").split("/")[0] + # Our dictionary doesn't contain directory markers, but we want to include + # a pseudo directory when there's a nested item and we're listing an + # intermediate level. + n = prefix.count("/") + 2 + keys_unique = { + "/".join(k.split("/", n)[:n]) + for k in self._store_dict + if k.startswith(prefix + "/") + } + for key in keys_unique: + yield key.removeprefix(prefix + "/").split("/")[0] diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index ebef4824f7..e263cb38fd 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -192,6 +192,13 @@ async def test_list_dir(self, store: S) -> None: assert [k async for k in store.list_dir("foo")] == [] await store.set("foo/zarr.json", Buffer.from_bytes(b"bar")) await store.set("foo/c/1", Buffer.from_bytes(b"\x01")) + await store.set("foo/c/d/1", Buffer.from_bytes(b"\x01")) + await store.set("foo/c/d/2", Buffer.from_bytes(b"\x01")) + await store.set("foo/c/d/3", Buffer.from_bytes(b"\x01")) + + keys_expected = ["foo"] + keys_observed = [k async for k in store.list_dir("")] + assert set(keys_observed) == set(keys_expected), keys_observed keys_expected = ["zarr.json", "c"] keys_observed = [k async for k in 
store.list_dir("foo")] From 90940a0e01366a7339bf4abae5caca3fdcb73e30 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Sun, 25 Aug 2024 13:39:16 -0500 Subject: [PATCH 02/10] fixup s3 --- src/zarr/store/remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index f5ea694b0a..83393e4dac 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -202,7 +202,7 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: except FileNotFoundError: return for onefile in (a.replace(prefix + "/", "") for a in allfiles): - yield onefile + yield onefile.removeprefix(self.path).removeprefix("/") async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: for onefile in await self._fs._ls(prefix, detail=False): From 8ee89f4be622ca5450814eca56d6c1025992e970 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Sun, 25 Aug 2024 11:14:46 -0500 Subject: [PATCH 03/10] recursive Group.members This PR adds a recursive=True flag to Group.members, for recursively listing the members of some hierarchy. This is useful for Consolidated Metadata, which needs to recursively inspect children. IMO, it's useful (and simple) enough to include in the public API. --- src/zarr/core/group.py | 53 ++++++++++++++++++++++++++++++++++-------- tests/v3/test_group.py | 52 +++++++++++++++++++++++++++++++++++++---- 2 files changed, 91 insertions(+), 14 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 86d27e3a97..4becaf940b 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -424,20 +424,43 @@ async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup: def __repr__(self) -> str: return f"" - async def nmembers(self) -> int: + async def nmembers(self, recursive: bool = False) -> int: + """ + Count the number of members in this group. 
+ + Parameters + ---------- + recursive : bool, default False + Whether to recursively count arrays and groups in child groups of + this Group. By default, just immediate child array and group members + are counted. + + Returns + ------- + count : int + """ # TODO: consider using aioitertools.builtins.sum for this # return await aioitertools.builtins.sum((1 async for _ in self.members()), start=0) n = 0 - async for _ in self.members(): + async for _ in self.members(recursive=recursive): n += 1 return n - async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]: + async def members( + self, recursive: bool = False + ) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]: """ Returns an AsyncGenerator over the arrays and groups contained in this group. This method requires that `store_path.store` supports directory listing. The results are not guaranteed to be ordered. + + Parameters + ---------- + recursive : bool, default False + Whether to recursively include arrays and groups in child groups of + this Group. By default, just immediate child array and group members + are included. 
""" if not self.store_path.store.supports_listing: msg = ( @@ -456,7 +479,19 @@ async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], N if key in _skip_keys: continue try: - yield (key, await self.getitem(key)) + obj = await self.getitem(key) + yield (key, obj) + + if ( + recursive + and hasattr(obj.metadata, "node_type") + and obj.metadata.node_type == "group" + ): + # the assert is just for mypy to know that `obj.metadata.node_type` + # implies an AsyncGroup, not an AsyncArray + assert isinstance(obj, AsyncGroup) + async for child_key, val in obj.members(recursive=recursive): + yield "/".join([key, child_key]), val except KeyError: # keyerror is raised when `key` names an object (in the object storage sense), # as opposed to a prefix, in the store under the prefix associated with this group @@ -628,17 +663,15 @@ def update_attributes(self, new_attributes: dict[str, Any]) -> Group: self._sync(self._async_group.update_attributes(new_attributes)) return self - @property - def nmembers(self) -> int: - return self._sync(self._async_group.nmembers()) + def nmembers(self, recursive: bool = False) -> int: + return self._sync(self._async_group.nmembers(recursive=recursive)) - @property - def members(self) -> tuple[tuple[str, Array | Group], ...]: + def members(self, recursive: bool = False) -> tuple[tuple[str, Array | Group], ...]: """ Return the sub-arrays and sub-groups of this group as a tuple of (name, array | group) pairs """ - _members = self._sync_iter(self._async_group.members()) + _members = self._sync_iter(self._async_group.members(recursive=recursive)) result = tuple(map(lambda kv: (kv[0], _parse_async_node(kv[1])), _members)) return result diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index 39921c26d8..efc53352fe 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -88,7 +88,7 @@ def test_group_members(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) members_expected["subgroup"] = 
group.create_group("subgroup") # make a sub-sub-subgroup, to ensure that the children calculation doesn't go # too deep in the hierarchy - _ = members_expected["subgroup"].create_group("subsubgroup") # type: ignore + subsubgroup = members_expected["subgroup"].create_group("subsubgroup") # type: ignore members_expected["subarray"] = group.create_array( "subarray", shape=(100,), dtype="uint8", chunk_shape=(10,), exists_ok=True @@ -101,7 +101,13 @@ def test_group_members(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) # this creates a directory with a random key in it # this should not show up as a member sync(store.set(f"{path}/extra_directory/extra_object-2", Buffer.from_bytes(b"000000"))) - members_observed = group.members + members_observed = group.members() + # members are not guaranteed to be ordered, so sort before comparing + assert sorted(dict(members_observed)) == sorted(members_expected) + + # recursive=True + members_observed = group.members(recursive=True) + members_expected["subgroup/subsubgroup"] = subsubgroup # members are not guaranteed to be ordered, so sort before comparing assert sorted(dict(members_observed)) == sorted(members_expected) @@ -349,7 +355,8 @@ def test_group_create_array( if method == "create_array": array = group.create_array(name="array", shape=shape, dtype=dtype, data=data) elif method == "array": - array = group.array(name="array", shape=shape, dtype=dtype, data=data) + with pytest.warns(DeprecationWarning): + array = group.array(name="array", shape=shape, dtype=dtype, data=data) else: raise AssertionError @@ -358,7 +365,7 @@ def test_group_create_array( with pytest.raises(ContainsArrayError): group.create_array(name="array", shape=shape, dtype=dtype, data=data) elif method == "array": - with pytest.raises(ContainsArrayError): + with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning): group.array(name="array", shape=shape, dtype=dtype, data=data) assert array.shape == shape assert array.dtype == 
np.dtype(dtype) @@ -653,3 +660,40 @@ async def test_asyncgroup_update_attributes( agroup_new_attributes = await agroup.update_attributes(attributes_new) assert agroup_new_attributes.attrs == attributes_new + + +async def test_group_members_async(store: LocalStore | MemoryStore): + group = AsyncGroup( + GroupMetadata(), + store_path=StorePath(store=store, path="root"), + ) + a0 = await group.create_array("a0", (1,)) + g0 = await group.create_group("g0") + a1 = await g0.create_array("a1", (1,)) + g1 = await g0.create_group("g1") + a2 = await g1.create_array("a2", (1,)) + g2 = await g1.create_group("g2") + + # immediate children + children = sorted([x async for x in group.members()], key=lambda x: x[0]) + assert children == [ + ("a0", a0), + ("g0", g0), + ] + + nmembers = await group.nmembers() + assert nmembers == 2 + + all_children = sorted([x async for x in group.members(recursive=True)], key=lambda x: x[0]) + expected = [ + ("a0", a0), + ("g0", g0), + ("g0/a1", a1), + ("g0/g1", g1), + ("g0/g1/a2", a2), + ("g0/g1/g2", g2), + ] + assert all_children == expected + + nmembers = await group.nmembers(recursive=True) + assert nmembers == 6 From e54c755069dbaf7372afe54f2f6432b46aa6d882 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Sun, 25 Aug 2024 14:27:47 -0500 Subject: [PATCH 04/10] trigger ci From 8bd52d09d8667b07acf68fb417969ba9bc15cf6a Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Sun, 25 Aug 2024 15:22:38 -0500 Subject: [PATCH 05/10] fixed datetime serialization --- src/zarr/core/metadata.py | 22 +++++++++++++++------- src/zarr/testing/strategies.py | 10 +++++++++- tests/v3/test_metadata/test_v3.py | 23 +++++++++++++++++++++++ 3 files changed, 47 insertions(+), 8 deletions(-) diff --git a/src/zarr/core/metadata.py b/src/zarr/core/metadata.py index d541e43205..72172a2673 100644 --- a/src/zarr/core/metadata.py +++ b/src/zarr/core/metadata.py @@ -256,13 +256,21 @@ def _json_convert(o: Any) -> Any: if isinstance(o, np.dtype): return str(o) if np.isscalar(o): 
- # convert numpy scalar to python type, and pass - # python types through - out = getattr(o, "item", lambda: o)() - if isinstance(out, complex): - # python complex types are not JSON serializable, so we use the - # serialization defined in the zarr v3 spec - return [out.real, out.imag] + out: Any + if hasattr(o, "dtype") and o.dtype.kind == "M" and hasattr(o, "view"): + # https://github.com/zarr-developers/zarr-python/issues/2119 + # `.item()` on a datetime type might or might not return an + # integer, depending on the value. + # Explicitly cast to an int first, and then grab .item() + out = o.view("i8").item() + else: + # convert numpy scalar to python type, and pass + # python types through + out = getattr(o, "item", lambda: o)() + if isinstance(out, complex): + # python complex types are not JSON serializable, so we use the + # serialization defined in the zarr v3 spec + return [out.real, out.imag] return out if isinstance(o, Enum): return o.name diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index d2e41c6290..00d8a37736 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -1,3 +1,4 @@ +import re from typing import Any import hypothesis.extra.numpy as npst @@ -101,7 +102,14 @@ def arrays( root = Group.create(store) fill_value_args: tuple[Any, ...] = tuple() if nparray.dtype.kind == "M": - fill_value_args = ("ns",) + m = re.search("\[(.+)\]", nparray.dtype.str) + if not m: + raise ValueError(f"Couldn't find precision for dtype '{nparray.dtype}.") + + fill_value_args = ( + # e.g. 
ns, D + m.groups()[0], + ) a = root.create_array( array_path, diff --git a/tests/v3/test_metadata/test_v3.py b/tests/v3/test_metadata/test_v3.py index eedcdf6234..1a0c5b94d7 100644 --- a/tests/v3/test_metadata/test_v3.py +++ b/tests/v3/test_metadata/test_v3.py @@ -1,10 +1,12 @@ from __future__ import annotations +import json import re from typing import TYPE_CHECKING, Literal from zarr.abc.codec import Codec from zarr.codecs.bytes import BytesCodec +from zarr.core.buffer import default_buffer_prototype from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding if TYPE_CHECKING: @@ -230,3 +232,24 @@ def test_metadata_to_dict( observed.pop("chunk_key_encoding") expected.pop("chunk_key_encoding") assert observed == expected + + +@pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897]) +@pytest.mark.parametrize("precision", ["ns", "D"]) +async def test_datetime_metadata(fill_value: int, precision: str): + metadata_dict = { + "zarr_format": 3, + "node_type": "array", + "shape": (1,), + "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, + "data_type": f" Date: Mon, 26 Aug 2024 09:13:30 -0500 Subject: [PATCH 06/10] fixup --- tests/v3/test_properties.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/v3/test_properties.py b/tests/v3/test_properties.py index 9204cdc523..4f0bcf9192 100644 --- a/tests/v3/test_properties.py +++ b/tests/v3/test_properties.py @@ -32,6 +32,7 @@ def test_basic_indexing(data): @given(data=st.data()) +@pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_vindex(data): zarray = data.draw(arrays()) nparray = zarray[:] From ad4fa348a7d530d90f50364624c1c99c73056695 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Mon, 26 Aug 2024 09:35:08 -0500 Subject: [PATCH 07/10] fixed invalid escape sequence --- src/zarr/testing/strategies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index 
00d8a37736..3a460d4fff 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -102,7 +102,7 @@ def arrays( root = Group.create(store) fill_value_args: tuple[Any, ...] = tuple() if nparray.dtype.kind == "M": - m = re.search("\[(.+)\]", nparray.dtype.str) + m = re.search(r"\[(.+)\]", nparray.dtype.str) if not m: raise ValueError(f"Couldn't find precision for dtype '{nparray.dtype}.") From c3d3855e1cda52094c03c6723696be3cbca90ad6 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Mon, 26 Aug 2024 09:43:49 -0500 Subject: [PATCH 08/10] fixup --- tests/v3/test_properties.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/v3/test_properties.py b/tests/v3/test_properties.py index 4f0bcf9192..c978187cf1 100644 --- a/tests/v3/test_properties.py +++ b/tests/v3/test_properties.py @@ -18,6 +18,11 @@ def test_roundtrip(data): @given(data=st.data()) +# The filter warning here is to silence an occasional warning in NDBuffer.all_equal +# See https://github.com/zarr-developers/zarr-python/pull/2118#issuecomment-2310280899 +# Uncomment the next line to reproduce the original failure. +# @reproduce_failure('6.111.2', b'AXicY2FgZGRAB/8/ndR2z7nkDZEDADWpBL4=') +@pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_basic_indexing(data): zarray = data.draw(arrays()) nparray = zarray[:] @@ -32,6 +37,10 @@ def test_basic_indexing(data): @given(data=st.data()) +# The filter warning here is to silence an occasional warning in NDBuffer.all_equal +# See https://github.com/zarr-developers/zarr-python/pull/2118#issuecomment-2310280899 +# Uncomment the next line to reproduce the original failure. 
+# @reproduce_failure('6.111.2', b'AXicY2FgZGRAB/8/eLmF7qr/C5EDADZUBRM=') @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_vindex(data): zarray = data.draw(arrays()) From b1c6627793b0f0957add9e04b5797292e3b94b57 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Tue, 27 Aug 2024 12:36:04 -0500 Subject: [PATCH 09/10] max_depth --- src/zarr/core/group.py | 53 ++++++++++++++++++++++++++++-------------- tests/v3/test_group.py | 41 +++++++++++++++++++++++++------- 2 files changed, 68 insertions(+), 26 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 4becaf940b..1c4237aeb5 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -424,16 +424,21 @@ async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup: def __repr__(self) -> str: return f"" - async def nmembers(self, recursive: bool = False) -> int: + async def nmembers( + self, + max_depth: int | None = 0, + ) -> int: """ Count the number of members in this group. Parameters ---------- - recursive : bool, default False - Whether to recursively count arrays and groups in child groups of - this Group. By default, just immediate child array and group members - are counted. + max_depth : int, default 0 + The maximum number of levels of the hierarchy to include. By + default, (``max_depth=0``) only immediate children are included. Set + ``max_depth=None`` or ``max_depth=-1`` to include all nodes, and + some positive integer to consider children within that many levels + of the root Group. 
Returns ------- @@ -442,12 +447,13 @@ async def nmembers(self, recursive: bool = False) -> int: # TODO: consider using aioitertools.builtins.sum for this # return await aioitertools.builtins.sum((1 async for _ in self.members()), start=0) n = 0 - async for _ in self.members(recursive=recursive): + async for _ in self.members(max_depth=max_depth): n += 1 return n async def members( - self, recursive: bool = False + self, + max_depth: int | None = 0, ) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]: """ Returns an AsyncGenerator over the arrays and groups contained in this group. @@ -457,11 +463,22 @@ async def members( Parameters ---------- - recursive : bool, default False - Whether to recursively include arrays and groups in child groups of - this Group. By default, just immediate child array and group members - are included. + max_depth : int, default 0 + The maximum number of levels of the hierarchy to include. By + default, (``max_depth=0``) only immediate children are included. Set + ``max_depth=None`` or ``max_depth=-1`` to include all nodes, and + some positive integer to consider children within that many levels + of the root Group. 
+ """ + if max_depth is None: + max_depth = -1 + async for item in self._members(max_depth=max_depth, current_depth=0): + yield item + + async def _members( + self, max_depth: int, current_depth: int + ) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]: if not self.store_path.store.supports_listing: msg = ( f"The store associated with this group ({type(self.store_path.store)}) " @@ -483,14 +500,16 @@ async def members( yield (key, obj) if ( - recursive + ((current_depth < max_depth) or (max_depth < 0)) and hasattr(obj.metadata, "node_type") and obj.metadata.node_type == "group" ): # the assert is just for mypy to know that `obj.metadata.node_type` # implies an AsyncGroup, not an AsyncArray assert isinstance(obj, AsyncGroup) - async for child_key, val in obj.members(recursive=recursive): + async for child_key, val in obj._members( + max_depth=max_depth, current_depth=current_depth + 1 + ): yield "/".join([key, child_key]), val except KeyError: # keyerror is raised when `key` names an object (in the object storage sense), @@ -663,15 +682,15 @@ def update_attributes(self, new_attributes: dict[str, Any]) -> Group: self._sync(self._async_group.update_attributes(new_attributes)) return self - def nmembers(self, recursive: bool = False) -> int: - return self._sync(self._async_group.nmembers(recursive=recursive)) + def nmembers(self, max_depth: int | None = 0) -> int: + return self._sync(self._async_group.nmembers(max_depth=max_depth)) - def members(self, recursive: bool = False) -> tuple[tuple[str, Array | Group], ...]: + def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group], ...]: """ Return the sub-arrays and sub-groups of this group as a tuple of (name, array | group) pairs """ - _members = self._sync_iter(self._async_group.members(recursive=recursive)) + _members = self._sync_iter(self._async_group.members(max_depth=max_depth)) result = tuple(map(lambda kv: (kv[0], _parse_async_node(kv[1])), _members)) return result diff --git 
a/tests/v3/test_group.py b/tests/v3/test_group.py index efc53352fe..9767a8b39d 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -89,6 +89,7 @@ def test_group_members(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) # make a sub-sub-subgroup, to ensure that the children calculation doesn't go # too deep in the hierarchy subsubgroup = members_expected["subgroup"].create_group("subsubgroup") # type: ignore + subsubsubgroup = subsubgroup.create_group("subsubsubgroup") # type: ignore members_expected["subarray"] = group.create_array( "subarray", shape=(100,), dtype="uint8", chunk_shape=(10,), exists_ok=True @@ -105,12 +106,18 @@ def test_group_members(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) # members are not guaranteed to be ordered, so sort before comparing assert sorted(dict(members_observed)) == sorted(members_expected) - # recursive=True - members_observed = group.members(recursive=True) + # partial + members_observed = group.members(max_depth=1) members_expected["subgroup/subsubgroup"] = subsubgroup # members are not guaranteed to be ordered, so sort before comparing assert sorted(dict(members_observed)) == sorted(members_expected) + # total + members_observed = group.members(max_depth=-1) + members_expected["subgroup/subsubgroup/subsubsubgroup"] = subsubsubgroup + # members are not guaranteed to be ordered, so sort before comparing + assert sorted(dict(members_observed)) == sorted(members_expected) + def test_group(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) -> None: """ @@ -684,16 +691,32 @@ async def test_group_members_async(store: LocalStore | MemoryStore): nmembers = await group.nmembers() assert nmembers == 2 - all_children = sorted([x async for x in group.members(recursive=True)], key=lambda x: x[0]) + # partial + children = sorted([x async for x in group.members(max_depth=1)], key=lambda x: x[0]) expected = [ ("a0", a0), ("g0", g0), ("g0/a1", a1), ("g0/g1", g1), - ("g0/g1/a2", a2), - ("g0/g1/g2", 
g2), ] - assert all_children == expected - - nmembers = await group.nmembers(recursive=True) - assert nmembers == 6 + assert children == expected + nmembers = await group.nmembers(max_depth=1) + assert nmembers == 4 + + # all children + for max_depth in [-1, None]: + all_children = sorted( + [x async for x in group.members(max_depth=max_depth)], key=lambda x: x[0] + ) + expected = [ + ("a0", a0), + ("g0", g0), + ("g0/a1", a1), + ("g0/g1", g1), + ("g0/g1/a2", a2), + ("g0/g1/g2", g2), + ] + assert all_children == expected + + nmembers = await group.nmembers(max_depth=max_depth) + assert nmembers == 6 From 82853b1ab5f161969ee9b519f29c05260079921d Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 28 Aug 2024 12:39:43 -0500 Subject: [PATCH 10/10] max_depth=None --- src/zarr/core/group.py | 18 ++++++++---------- tests/v3/test_group.py | 37 ++++++++++++++++++++----------------- 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 1c4237aeb5..2c26cac3b1 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -436,9 +436,8 @@ async def nmembers( max_depth : int, default 0 The maximum number of levels of the hierarchy to include. By default, (``max_depth=0``) only immediate children are included. Set - ``max_depth=None`` or ``max_depth=-1`` to include all nodes, and - some positive integer to consider children within that many levels - of the root Group. + ``max_depth=None`` to include all nodes, and some positive integer + to consider children within that many levels of the root Group. Returns ------- @@ -466,18 +465,17 @@ async def members( max_depth : int, default 0 The maximum number of levels of the hierarchy to include. By default, (``max_depth=0``) only immediate children are included. Set - ``max_depth=None`` or ``max_depth=-1`` to include all nodes, and - some positive integer to consider children within that many levels - of the root Group. 
+ ``max_depth=None`` to include all nodes, and some positive integer + to consider children within that many levels of the root Group. """ - if max_depth is None: - max_depth = -1 + if max_depth is not None and max_depth < 0: + raise ValueError(f"max_depth must be None or >= 0. Got '{max_depth}' instead") async for item in self._members(max_depth=max_depth, current_depth=0): yield item async def _members( - self, max_depth: int, current_depth: int + self, max_depth: int | None, current_depth: int ) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]: if not self.store_path.store.supports_listing: msg = ( @@ -500,7 +498,7 @@ async def _members( yield (key, obj) if ( - ((current_depth < max_depth) or (max_depth < 0)) + ((max_depth is None) or (current_depth < max_depth)) and hasattr(obj.metadata, "node_type") and obj.metadata.node_type == "group" ): diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index 9767a8b39d..eb7b1f30dd 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -113,11 +113,14 @@ def test_group_members(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) assert sorted(dict(members_observed)) == sorted(members_expected) # total - members_observed = group.members(max_depth=-1) + members_observed = group.members(max_depth=None) members_expected["subgroup/subsubgroup/subsubsubgroup"] = subsubsubgroup # members are not guaranteed to be ordered, so sort before comparing assert sorted(dict(members_observed)) == sorted(members_expected) + with pytest.raises(ValueError, match="max_depth"): + members_observed = group.members(max_depth=-1) + def test_group(store: MemoryStore | LocalStore, zarr_format: ZarrFormat) -> None: """ @@ -704,19 +707,19 @@ async def test_group_members_async(store: LocalStore | MemoryStore): assert nmembers == 4 # all children - for max_depth in [-1, None]: - all_children = sorted( - [x async for x in group.members(max_depth=max_depth)], key=lambda x: x[0] - ) - expected = [ - ("a0", a0), - ("g0", 
g0), - ("g0/a1", a1), - ("g0/g1", g1), - ("g0/g1/a2", a2), - ("g0/g1/g2", g2), - ] - assert all_children == expected - - nmembers = await group.nmembers(max_depth=max_depth) - assert nmembers == 6 + all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0]) + expected = [ + ("a0", a0), + ("g0", g0), + ("g0/a1", a1), + ("g0/g1", g1), + ("g0/g1/a2", a2), + ("g0/g1/g2", g2), + ] + assert all_children == expected + + nmembers = await group.nmembers(max_depth=None) + assert nmembers == 6 + + with pytest.raises(ValueError, match="max_depth"): + [x async for x in group.members(max_depth=-1)]