Skip to content

Commit 0e035fb

Browse files
committed
feature(store): list_* -> AsyncGenerators
- add zarr.testing.store module to support downstream use cases - set pytest-asyncio mode to auto -
1 parent 19a28df commit 0e035fb

File tree

8 files changed

+199
-64
lines changed

8 files changed

+199
-64
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ disallow_untyped_calls = false
219219

220220

221221
[tool.pytest.ini_options]
222+
asyncio_mode = "auto"
222223
doctest_optionflags = [
223224
"NORMALIZE_WHITESPACE",
224225
"ELLIPSIS",

src/zarr/abc/store.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from abc import abstractmethod, ABC
22

3+
from collections.abc import AsyncGenerator
34
from typing import List, Tuple, Optional
45

56

@@ -24,7 +25,7 @@ async def get(
2425
@abstractmethod
2526
async def get_partial_values(
2627
self, key_ranges: List[Tuple[str, Tuple[int, int]]]
27-
) -> List[bytes]:
28+
) -> List[Optional[bytes]]:
2829
"""Retrieve possibly partial values from given key_ranges.
2930
3031
Parameters
@@ -106,17 +107,17 @@ def supports_listing(self) -> bool:
106107
...
107108

108109
@abstractmethod
109-
async def list(self) -> List[str]:
110+
def list(self) -> AsyncGenerator[str, None]:
110111
"""Retrieve all keys in the store.
111112
112113
Returns
113114
-------
114-
list[str]
115+
AsyncGenerator[str, None]
115116
"""
116117
...
117118

118119
@abstractmethod
119-
async def list_prefix(self, prefix: str) -> List[str]:
120+
def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
120121
"""Retrieve all keys in the store with a given prefix.
121122
122123
Parameters
@@ -125,12 +126,12 @@ async def list_prefix(self, prefix: str) -> List[str]:
125126
126127
Returns
127128
-------
128-
list[str]
129+
AsyncGenerator[str, None]
129130
"""
130131
...
131132

132133
@abstractmethod
133-
async def list_dir(self, prefix: str) -> List[str]:
134+
def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
134135
"""
135136
Retrieve all keys and prefixes with a given prefix and which do not contain the character
136137
“/” after the given prefix.
@@ -141,6 +142,6 @@ async def list_dir(self, prefix: str) -> List[str]:
141142
142143
Returns
143144
-------
144-
list[str]
145+
AsyncGenerator[str, None]
145146
"""
146147
...

src/zarr/group.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -306,20 +306,20 @@ async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], N
306306
)
307307

308308
raise ValueError(msg)
309-
subkeys = await self.store_path.store.list_dir(self.store_path.path)
310309
# would be nice to make these special keys accessible programmatically,
311310
# and scoped to specific zarr versions
312-
subkeys_filtered = filter(lambda v: v not in ("zarr.json", ".zgroup", ".zattrs"), subkeys)
313-
# is there a better way to schedule this?
314-
for subkey in subkeys_filtered:
311+
_skip_keys = ("zarr.json", ".zgroup", ".zattrs")
312+
async for key in self.store_path.store.list_dir(self.store_path.path):
313+
if key in _skip_keys:
314+
continue
315315
try:
316-
yield (subkey, await self.getitem(subkey))
316+
yield (key, await self.getitem(key))
317317
except KeyError:
318-
# keyerror is raised when `subkey` names an object (in the object storage sense),
318+
# keyerror is raised when `key` names an object (in the object storage sense),
319319
# as opposed to a prefix, in the store under the prefix associated with this group
320-
# in which case `subkey` cannot be the name of a sub-array or sub-group.
320+
# in which case `key` cannot be the name of a sub-array or sub-group.
321321
logger.warning(
322-
"Object at %s is not recognized as a component of a Zarr hierarchy.", subkey
322+
"Object at %s is not recognized as a component of a Zarr hierarchy.", key
323323
)
324324
pass
325325

src/zarr/store/local.py

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import io
44
import shutil
5+
from collections.abc import AsyncGenerator
56
from pathlib import Path
67
from typing import Union, Optional, List, Tuple
78

@@ -10,8 +11,24 @@
1011

1112

1213
def _get(path: Path, byte_range: Optional[Tuple[int, Optional[int]]] = None) -> bytes:
14+
"""
15+
Fetch a contiguous region of bytes from a file.
16+
Parameters
17+
----------
18+
path: Path
19+
The file to read bytes from.
20+
byte_range: Optional[Tuple[int, Optional[int]]] = None
21+
The range of bytes to read. If `byte_range` is `None`, then the entire file will be read.
22+
If `byte_range` is a tuple, the first value specifies the index of the first byte to read,
23+
and the second value specifies the total number of bytes to read. If the total value is
24+
`None`, then the entire file after the first byte will be read.
25+
"""
1326
if byte_range is not None:
14-
start = byte_range[0]
27+
if byte_range[0] is None:
28+
start = 0
29+
else:
30+
start = byte_range[0]
31+
1532
end = (start + byte_range[1]) if byte_range[1] is not None else None
1633
else:
1734
return path.read_bytes()
@@ -84,21 +101,28 @@ async def get(
84101

85102
async def get_partial_values(
86103
self, key_ranges: List[Tuple[str, Tuple[int, int]]]
87-
) -> List[bytes]:
104+
) -> List[Optional[bytes]]:
105+
"""
106+
Read byte ranges from multiple keys.
107+
Parameters
108+
----------
109+
key_ranges: List[Tuple[str, Tuple[int, int]]]
110+
A list of (key, (start, length)) tuples. The first element of the tuple is the name of
111+
the key in storage to fetch bytes from. The second element of the tuple defines the byte
112+
range to retrieve. These values are arguments to `get`, as this method wraps
113+
concurrent invocation of `get`.
114+
"""
88115
args = []
89116
for key, byte_range in key_ranges:
90117
assert isinstance(key, str)
91118
path = self.root / key
92-
if byte_range is not None:
93-
args.append((_get, path, byte_range[0], byte_range[1]))
94-
else:
95-
args.append((_get, path))
119+
args.append((_get, path, byte_range))
96120
return await concurrent_map(args, to_thread, limit=None) # TODO: fix limit
97121

98122
async def set(self, key: str, value: BytesLike) -> None:
99123
assert isinstance(key, str)
100124
path = self.root / key
101-
await to_thread(_put, path, value)
125+
await to_thread(_put, path, value, auto_mkdir=self.auto_mkdir)
102126

103127
async def set_partial_values(self, key_start_values: List[Tuple[str, int, bytes]]) -> None:
104128
args = []
@@ -122,22 +146,19 @@ async def exists(self, key: str) -> bool:
122146
path = self.root / key
123147
return await to_thread(path.is_file)
124148

125-
async def list(self) -> List[str]:
149+
async def list(self) -> AsyncGenerator[str, None]:
    """Retrieve all keys in the store.

    Walks the store's root directory recursively and yields one key per regular file.
    Directories themselves are not keys and are skipped.

    Returns
    -------
    AsyncGenerator[str, None]
        Keys relative to the store root, using "/" as the separator.
    """
    root = self.root
    # Iterate lazily; there is no need to materialise the whole tree first.
    for p in root.rglob("*"):
        if p.is_file():
            # `relative_to` removes only the leading root. A plain string
            # `replace` would also delete later occurrences of the root text
            # inside the key and would not match on non-"/" separators.
            yield p.relative_to(root).as_posix()
132160

133-
# Q: do we want to return strings or Paths?
134-
def _list(root: Path) -> List[str]:
135-
files = [str(p) for p in root.rglob("") if p.is_file()]
136-
return files
137-
138-
return await to_thread(_list, self.root)
139-
140-
async def list_prefix(self, prefix: str) -> List[str]:
161+
async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
141162
"""Retrieve all keys in the store with a given prefix.
142163
143164
Parameters
@@ -146,16 +167,15 @@ async def list_prefix(self, prefix: str) -> List[str]:
146167
147168
Returns
148169
-------
149-
list[str]
170+
AsyncGenerator[str, None]
150171
"""
151172

152-
def _list_prefix(root: Path, prefix: str) -> List[str]:
153-
files = [str(p) for p in (root / prefix).rglob("*") if p.is_file()]
154-
return files
173+
to_strip = str(self.root) + "/"
174+
for p in (self.root / prefix).rglob("*"):
175+
if p.is_file():
176+
yield str(p).replace(to_strip, "")
155177

156-
return await to_thread(_list_prefix, self.root, prefix)
157-
158-
async def list_dir(self, prefix: str) -> List[str]:
178+
async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
159179
"""
160180
Retrieve all keys and prefixes with a given prefix and which do not contain the character
161181
“/” after the given prefix.
@@ -166,15 +186,15 @@ async def list_dir(self, prefix: str) -> List[str]:
166186
167187
Returns
168188
-------
169-
list[str]
189+
AsyncGenerator[str, None]
170190
"""
171191

172-
def _list_dir(root: Path, prefix: str) -> List[str]:
173-
base = root / prefix
174-
to_strip = str(base) + "/"
175-
try:
176-
return [str(key).replace(to_strip, "") for key in base.iterdir()]
177-
except (FileNotFoundError, NotADirectoryError):
178-
return []
192+
base = self.root / prefix
193+
to_strip = str(base) + "/"
179194

180-
return await to_thread(_list_dir, self.root, prefix)
195+
try:
196+
key_iter = base.iterdir()
197+
for key in key_iter:
198+
yield str(key).replace(to_strip, "")
199+
except (FileNotFoundError, NotADirectoryError):
200+
pass

src/zarr/store/memory.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

3+
from collections.abc import AsyncGenerator
34
from typing import Optional, MutableMapping, List, Tuple
45

5-
from zarr.common import BytesLike
6+
from zarr.common import BytesLike, concurrent_map
67
from zarr.abc.store import Store
78

89

@@ -38,8 +39,9 @@ async def get(
3839

3940
async def get_partial_values(
4041
self, key_ranges: List[Tuple[str, Tuple[int, int]]]
41-
) -> List[bytes]:
42-
raise NotImplementedError
42+
) -> List[Optional[BytesLike]]:
43+
vals = await concurrent_map(key_ranges, self.get, limit=None)
44+
return vals
4345

4446
async def exists(self, key: str) -> bool:
4547
return key in self._store_dict
@@ -67,20 +69,23 @@ async def delete(self, key: str) -> None:
6769
async def set_partial_values(self, key_start_values: List[Tuple[str, int, bytes]]) -> None:
6870
raise NotImplementedError
6971

70-
async def list(self) -> List[str]:
71-
return list(self._store_dict.keys())
72+
async def list(self) -> AsyncGenerator[str, None]:
73+
for key in self._store_dict:
74+
yield key
75+
76+
async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
77+
for key in self._store_dict:
78+
if key.startswith(prefix):
79+
yield key
7280

73-
async def list_prefix(self, prefix: str) -> List[str]:
74-
return [key for key in self._store_dict if key.startswith(prefix)]
81+
async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
    """
    Retrieve all keys and prefixes with a given prefix and which do not contain the character
    "/" after the given prefix.

    Parameters
    ----------
    prefix : str
        The "directory" to list. A single trailing "/" is ignored; the empty string lists
        the top level of the store.

    Returns
    -------
    AsyncGenerator[str, None]
        Each distinct immediate child name under `prefix`, yielded once, in the order it is
        first encountered in the underlying mapping.
    """
    if prefix.endswith("/"):
        prefix = prefix[:-1]

    # With an empty prefix every key is a candidate and nothing is cut off.
    base = prefix + "/" if prefix else ""
    offset = len(base)

    # Several keys can share the same first path component (e.g. many chunks
    # under "c/"), so remember what has already been yielded.
    seen = set()
    for key in self._store_dict:
        if not key.startswith(base):
            continue
        # Slice the prefix off. `str.strip` must not be used here: it removes
        # a *set of characters* from both ends, not a leading substring.
        child = key[offset:].split("/", maxsplit=1)[0]
        if child not in seen:
            seen.add(child)
            yield child

src/zarr/testing/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Public entry point for zarr's reusable test helpers.

The helpers depend on pytest, which is an optional dependency. When pytest is
not importable the package still imports cleanly, emits a warning, and exports
nothing.
"""

import importlib.util
import warnings

if importlib.util.find_spec("pytest") is not None:
    from zarr.testing.store import StoreTests

    __all__ = ["StoreTests"]
else:
    warnings.warn("pytest not installed, skipping test suite")

    # Nothing was imported, so nothing may be advertised: listing "StoreTests"
    # here would make `from zarr.testing import *` raise AttributeError.
    __all__ = []

src/zarr/testing/store.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import pytest
2+
3+
from zarr.abc.store import Store
4+
5+
6+
class StoreTests:
    """Reusable conformance tests for ``Store`` implementations.

    Subclass this and set ``store_cls`` to the store type under test; pytest
    will then collect every ``test_*`` method below against that store. The
    default ``store`` fixture instantiates ``store_cls`` with no arguments;
    override the fixture if the store needs constructor arguments.
    """

    # The concrete Store subclass under test; must be set by the subclass.
    store_cls: type[Store]

    @pytest.fixture(scope="function")
    def store(self) -> Store:
        """Return a fresh store instance for each test function."""
        return self.store_cls()

    def test_store_type(self, store: Store) -> None:
        """The fixture yields an instance of both ``Store`` and ``store_cls``."""
        assert isinstance(store, Store)
        assert isinstance(store, self.store_cls)

    def test_store_repr(self, store: Store) -> None:
        """``repr(store)`` is a non-empty string."""
        assert repr(store)

    def test_store_capabilities(self, store: Store) -> None:
        """The store advertises write, partial-write and listing support."""
        assert store.supports_writes
        assert store.supports_partial_writes
        assert store.supports_listing

    @pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"])
    @pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
    async def test_set_get_bytes_roundtrip(self, store: Store, key: str, data: bytes) -> None:
        """Bytes written with ``set`` are returned unchanged by ``get``."""
        await store.set(key, data)
        assert await store.get(key) == data

    @pytest.mark.parametrize("key", ["foo/c/0"])
    @pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
    async def test_get_partial_values(self, store: Store, key: str, data: bytes) -> None:
        """``get_partial_values`` returns the requested byte ranges in order."""
        # put all of the data
        await store.set(key, data)
        # read back just part of it
        vals = await store.get_partial_values([(key, (0, 2))])
        assert vals == [data[0:2]]

        # read back multiple parts of it at once
        vals = await store.get_partial_values([(key, (0, 2)), (key, (2, 4))])
        assert vals == [data[0:2], data[2:4]]

    async def test_exists(self, store: Store) -> None:
        """``exists`` is false before a key is written and true afterwards."""
        assert not await store.exists("foo")
        await store.set("foo/zarr.json", b"bar")
        assert await store.exists("foo/zarr.json")

    async def test_delete(self, store: Store) -> None:
        """A deleted key no longer exists."""
        await store.set("foo/zarr.json", b"bar")
        assert await store.exists("foo/zarr.json")
        await store.delete("foo/zarr.json")
        assert not await store.exists("foo/zarr.json")

    async def test_list(self, store: Store) -> None:
        """``list`` yields exactly the keys that have been written."""
        assert [k async for k in store.list()] == []
        await store.set("foo/zarr.json", b"bar")
        keys = [k async for k in store.list()]
        assert keys == ["foo/zarr.json"], keys

        expected = ["foo/zarr.json"]
        for i in range(10):
            key = f"foo/c/{i}"
            expected.append(key)
            await store.set(key, i.to_bytes(length=3, byteorder="little"))

        # Listing order is store-dependent (e.g. filesystem traversal order),
        # so compare the key sets rather than the raw sequences.
        keys = [k async for k in store.list()]
        assert sorted(keys) == sorted(expected), keys

    async def test_list_prefix(self, store: Store) -> None:
        # TODO: we currently don't use list_prefix anywhere
        pass

    async def test_list_dir(self, store: Store) -> None:
        """``list_dir`` yields the immediate children of a prefix."""
        assert [k async for k in store.list_dir("")] == []
        assert [k async for k in store.list_dir("foo")] == []
        await store.set("foo/zarr.json", b"bar")
        await store.set("foo/c/1", b"\x01")

        # Child order is not part of the contract, so compare sorted results.
        expected = ["c", "zarr.json"]

        keys = [k async for k in store.list_dir("foo")]
        assert sorted(keys) == expected, keys

        # A trailing slash on the prefix must not change the result.
        keys = [k async for k in store.list_dir("foo/")]
        assert sorted(keys) == expected, keys

0 commit comments

Comments
 (0)