Skip to content

Commit c878da2

Browse files
Make Group.arrays, groups compatible with v2 (#2213)
Defines a set of array / group iterators:
- .groups / .arrays: iterate over (name, value) pairs
- .group_keys / .array_keys: iterate over keys
- .group_values / .array_values: iterate over values

Co-authored-by: Joe Hamman <[email protected]>
1 parent 32540b4 commit c878da2

File tree

2 files changed: 59 additions and 52 deletions.

src/zarr/core/group.py

Lines changed: 40 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
3333
from zarr.store.common import ensure_no_existing_node
3434

3535
if TYPE_CHECKING:
36-
from collections.abc import AsyncGenerator, Iterable, Iterator
36+
from collections.abc import AsyncGenerator, Generator, Iterable, Iterator
3737
from typing import Any
3838

3939
from zarr.abc.codec import Codec
@@ -678,29 +678,31 @@ async def contains(self, member: str) -> bool:
678678
else:
679679
return True
680680

681-
# todo: decide if this method should be separate from `groups`
682-
async def group_keys(self) -> AsyncGenerator[str, None]:
683-
async for key, value in self.members():
681+
async def groups(self) -> AsyncGenerator[tuple[str, AsyncGroup], None]:
682+
async for name, value in self.members():
684683
if isinstance(value, AsyncGroup):
685-
yield key
684+
yield name, value
686685

687-
# todo: decide if this method should be separate from `group_keys`
688-
async def groups(self) -> AsyncGenerator[AsyncGroup, None]:
689-
async for _, value in self.members():
690-
if isinstance(value, AsyncGroup):
691-
yield value
686+
async def group_keys(self) -> AsyncGenerator[str, None]:
687+
async for key, _ in self.groups():
688+
yield key
692689

693-
# todo: decide if this method should be separate from `arrays`
694-
async def array_keys(self) -> AsyncGenerator[str, None]:
690+
async def group_values(self) -> AsyncGenerator[AsyncGroup, None]:
691+
async for _, group in self.groups():
692+
yield group
693+
694+
async def arrays(self) -> AsyncGenerator[tuple[str, AsyncArray], None]:
695695
async for key, value in self.members():
696696
if isinstance(value, AsyncArray):
697-
yield key
697+
yield key, value
698698

699-
# todo: decide if this method should be separate from `array_keys`
700-
async def arrays(self) -> AsyncGenerator[AsyncArray, None]:
701-
async for _, value in self.members():
702-
if isinstance(value, AsyncArray):
703-
yield value
699+
async def array_keys(self) -> AsyncGenerator[str, None]:
700+
async for key, _ in self.arrays():
701+
yield key
702+
703+
async def array_values(self) -> AsyncGenerator[AsyncArray, None]:
704+
async for _, array in self.arrays():
705+
yield array
704706

705707
async def tree(self, expand: bool = False, level: int | None = None) -> Any:
706708
raise NotImplementedError
@@ -861,18 +863,29 @@ def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group],
861863
def __contains__(self, member: str) -> bool:
862864
return self._sync(self._async_group.contains(member))
863865

864-
def group_keys(self) -> tuple[str, ...]:
865-
return tuple(self._sync_iter(self._async_group.group_keys()))
866+
def groups(self) -> Generator[tuple[str, Group], None]:
867+
for name, async_group in self._sync_iter(self._async_group.groups()):
868+
yield name, Group(async_group)
869+
870+
def group_keys(self) -> Generator[str, None]:
871+
for name, _ in self.groups():
872+
yield name
873+
874+
def group_values(self) -> Generator[Group, None]:
875+
for _, group in self.groups():
876+
yield group
866877

867-
def groups(self) -> tuple[Group, ...]:
868-
# TODO: in v2 this was a generator that return key: Group
869-
return tuple(Group(obj) for obj in self._sync_iter(self._async_group.groups()))
878+
def arrays(self) -> Generator[tuple[str, Array], None]:
879+
for name, async_array in self._sync_iter(self._async_group.arrays()):
880+
yield name, Array(async_array)
870881

871-
def array_keys(self) -> tuple[str, ...]:
872-
return tuple(self._sync_iter(self._async_group.array_keys()))
882+
def array_keys(self) -> Generator[str, None]:
883+
for name, _ in self.arrays():
884+
yield name
873885

874-
def arrays(self) -> tuple[Array, ...]:
875-
return tuple(Array(obj) for obj in self._sync_iter(self._async_group.arrays()))
886+
def array_values(self) -> Generator[Array, None]:
887+
for _, array in self.arrays():
888+
yield array
876889

877890
def tree(self, expand: bool = False, level: int | None = None) -> Any:
878891
return self._sync(self._async_group.tree(expand=expand, level=level))

tests/v3/test_group.py

Lines changed: 19 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -301,34 +301,28 @@ def test_group_contains(store: Store, zarr_format: ZarrFormat) -> None:
301301
assert "foo" in group
302302

303303

304-
def test_group_subgroups(store: Store, zarr_format: ZarrFormat) -> None:
305-
"""
306-
Test the behavior of `Group` methods for accessing subgroups, namely `Group.group_keys` and `Group.groups`
307-
"""
304+
def test_group_child_iterators(store: Store, zarr_format: ZarrFormat):
308305
group = Group.create(store, zarr_format=zarr_format)
309-
keys = ("foo", "bar")
310-
subgroups_expected = tuple(group.create_group(k) for k in keys)
311-
# create a sub-array as well
312-
_ = group.create_array("array", shape=(10,))
313-
subgroups_observed = group.groups()
314-
assert set(group.group_keys()) == set(keys)
315-
assert len(subgroups_observed) == len(subgroups_expected)
316-
assert all(a in subgroups_observed for a in subgroups_expected)
306+
expected_group_keys = ["g0", "g1"]
307+
expected_group_values = [group.create_group(name=name) for name in expected_group_keys]
308+
expected_groups = list(zip(expected_group_keys, expected_group_values, strict=False))
317309

310+
expected_group_values[0].create_group("subgroup")
311+
expected_group_values[0].create_array("subarray", shape=(1,))
318312

319-
def test_group_subarrays(store: Store, zarr_format: ZarrFormat) -> None:
320-
"""
321-
Test the behavior of `Group` methods for accessing subgroups, namely `Group.group_keys` and `Group.groups`
322-
"""
323-
group = Group.create(store, zarr_format=zarr_format)
324-
keys = ("foo", "bar")
325-
subarrays_expected = tuple(group.create_array(k, shape=(10,)) for k in keys)
326-
# create a sub-group as well
327-
_ = group.create_group("group")
328-
subarrays_observed = group.arrays()
329-
assert set(group.array_keys()) == set(keys)
330-
assert len(subarrays_observed) == len(subarrays_expected)
331-
assert all(a in subarrays_observed for a in subarrays_expected)
313+
expected_array_keys = ["a0", "a1"]
314+
expected_array_values = [
315+
group.create_array(name=name, shape=(1,)) for name in expected_array_keys
316+
]
317+
expected_arrays = list(zip(expected_array_keys, expected_array_values, strict=False))
318+
319+
assert sorted(group.groups(), key=lambda x: x[0]) == expected_groups
320+
assert sorted(group.group_keys()) == expected_group_keys
321+
assert sorted(group.group_values(), key=lambda x: x.name) == expected_group_values
322+
323+
assert sorted(group.arrays(), key=lambda x: x[0]) == expected_arrays
324+
assert sorted(group.array_keys()) == expected_array_keys
325+
assert sorted(group.array_values(), key=lambda x: x.name) == expected_array_values
332326

333327

334328
def test_group_update_attributes(store: Store, zarr_format: ZarrFormat) -> None:

0 commit comments

Comments
 (0)