Skip to content

Commit 783f916

Browse files
authored
Merge branch 'v3' into generalize-stateful-store
2 parents 7e8c1c2 + 06e3215 commit 783f916

File tree

11 files changed

+151
-91
lines changed

11 files changed

+151
-91
lines changed

src/zarr/abc/store.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from abc import ABC, abstractmethod
2-
from collections.abc import AsyncGenerator
2+
from asyncio import gather
3+
from collections.abc import AsyncGenerator, Iterable
34
from typing import Any, NamedTuple, Protocol, runtime_checkable
45

56
from typing_extensions import Self
@@ -158,6 +159,13 @@ async def set(self, key: str, value: Buffer) -> None:
158159
"""
159160
...
160161

162+
async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
163+
"""
164+
Insert multiple (key, value) pairs into storage.
165+
"""
166+
await gather(*(self.set(key, value) for key, value in values))
167+
return None
168+
161169
@property
162170
@abstractmethod
163171
def supports_deletes(self) -> bool:
@@ -211,7 +219,9 @@ def list(self) -> AsyncGenerator[str, None]:
211219

212220
@abstractmethod
213221
def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
214-
"""Retrieve all keys in the store with a given prefix.
222+
"""
223+
Retrieve all keys in the store that begin with a given prefix. Keys are returned with the
224+
common leading prefix removed.
215225
216226
Parameters
217227
----------

src/zarr/core/array.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ async def _create_v2(
313313
chunks=chunks,
314314
order=order,
315315
dimension_separator=dimension_separator,
316-
fill_value=fill_value,
316+
fill_value=0 if fill_value is None else fill_value,
317317
compressor=compressor,
318318
filters=filters,
319319
attributes=attributes,

src/zarr/core/common.py

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ZGROUP_JSON = ".zgroup"
2828
ZATTRS_JSON = ".zattrs"
2929

30+
ByteRangeRequest = tuple[int | None, int | None]
3031
BytesLike = bytes | bytearray | memoryview
3132
ShapeLike = tuple[int, ...] | int
3233
ChunkCoords = tuple[int, ...]

src/zarr/core/sync.py

+17
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,23 @@ def _get_loop() -> asyncio.AbstractEventLoop:
113113
return loop[0]
114114

115115

116+
async def _collect_aiterator(data: AsyncIterator[T]) -> tuple[T, ...]:
117+
"""
118+
Collect an entire async iterator into a tuple.
119+
"""
120+
result = []
121+
async for x in data:
122+
result.append(x)
123+
return tuple(result)
124+
125+
126+
def collect_aiterator(data: AsyncIterator[T]) -> tuple[T, ...]:
127+
"""
128+
Synchronously collect an entire async iterator into a tuple.
129+
"""
130+
return sync(_collect_aiterator(data))
131+
132+
116133
class SyncMixin:
117134
def _sync(self, coroutine: Coroutine[Any, Any, T]) -> T:
118135
# TODO: refactor this to take *args and **kwargs and pass those to the method

src/zarr/store/local.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,9 @@ async def list(self) -> AsyncGenerator[str, None]:
191191
yield str(p).replace(to_strip, "")
192192

193193
async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
194-
"""Retrieve all keys in the store with a given prefix.
194+
"""
195+
Retrieve all keys in the store that begin with a given prefix. Keys are returned with the
196+
common leading prefix removed.
195197
196198
Parameters
197199
----------
@@ -201,14 +203,10 @@ async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
201203
-------
202204
AsyncGenerator[str, None]
203205
"""
206+
to_strip = os.path.join(str(self.root / prefix))
204207
for p in (self.root / prefix).rglob("*"):
205208
if p.is_file():
206-
yield str(p)
207-
208-
to_strip = str(self.root) + "/"
209-
for p in (self.root / prefix).rglob("*"):
210-
if p.is_file():
211-
yield str(p).replace(to_strip, "")
209+
yield str(p.relative_to(to_strip))
212210

213211
async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
214212
"""

src/zarr/store/memory.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,21 @@ async def list(self) -> AsyncGenerator[str, None]:
124124
async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
125125
for key in self._store_dict:
126126
if key.startswith(prefix):
127-
yield key
127+
yield key.removeprefix(prefix)
128128

129129
async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
130+
"""
131+
Retrieve all keys in the store that begin with a given prefix. Keys are returned with the
132+
common leading prefix removed.
133+
134+
Parameters
135+
----------
136+
prefix : str
137+
138+
Returns
139+
-------
140+
AsyncGenerator[str, None]
141+
"""
130142
if prefix.endswith("/"):
131143
prefix = prefix[:-1]
132144

src/zarr/store/remote.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -216,5 +216,19 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
216216
yield onefile.removeprefix(self.path).removeprefix("/")
217217

218218
async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
219-
for onefile in await self._fs._ls(prefix, detail=False):
220-
yield onefile
219+
"""
220+
Retrieve all keys in the store that begin with a given prefix. Keys are returned with the
221+
common leading prefix removed.
222+
223+
Parameters
224+
----------
225+
prefix : str
226+
227+
Returns
228+
-------
229+
AsyncGenerator[str, None]
230+
"""
231+
232+
find_str = "/".join([self.path, prefix])
233+
for onefile in await self._fs._find(find_str, detail=False, maxdepth=None, withdirs=False):
234+
yield onefile.removeprefix(find_str)

src/zarr/store/zip.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,21 @@ async def list(self) -> AsyncGenerator[str, None]:
209209
yield key
210210

211211
async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
212+
"""
213+
Retrieve all keys in the store that begin with a given prefix. Keys are returned with the
214+
common leading prefix removed.
215+
216+
Parameters
217+
----------
218+
prefix : str
219+
220+
Returns
221+
-------
222+
AsyncGenerator[str, None]
223+
"""
212224
async for key in self.list():
213225
if key.startswith(prefix):
214-
yield key
226+
yield key.removeprefix(prefix)
215227

216228
async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
217229
if prefix.endswith("/"):

src/zarr/testing/store.py

+58-65
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
import pytest
55

6-
import zarr.api.asynchronous
76
from zarr.abc.store import AccessMode, Store
87
from zarr.core.buffer import Buffer, default_buffer_prototype
8+
from zarr.core.sync import _collect_aiterator
99
from zarr.store._utils import _normalize_interval_index
1010
from zarr.testing.utils import assert_bytes_equal
1111

@@ -123,6 +123,18 @@ async def test_set(self, store: S, key: str, data: bytes) -> None:
123123
observed = self.get(store, key)
124124
assert_bytes_equal(observed, data_buf)
125125

126+
async def test_set_many(self, store: S) -> None:
127+
"""
128+
Test that a dict of key : value pairs can be inserted into the store via the
129+
`_set_many` method.
130+
"""
131+
keys = ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"]
132+
data_buf = [self.buffer_cls.from_bytes(k.encode()) for k in keys]
133+
store_dict = dict(zip(keys, data_buf, strict=True))
134+
await store._set_many(store_dict.items())
135+
for k, v in store_dict.items():
136+
assert self.get(store, k).to_bytes() == v.to_bytes()
137+
126138
@pytest.mark.parametrize(
127139
"key_ranges",
128140
(
@@ -185,76 +197,57 @@ async def test_clear(self, store: S) -> None:
185197
assert await store.empty()
186198

187199
async def test_list(self, store: S) -> None:
188-
assert [k async for k in store.list()] == []
189-
await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar"))
190-
keys = [k async for k in store.list()]
191-
assert keys == ["foo/zarr.json"], keys
192-
193-
expected = ["foo/zarr.json"]
194-
for i in range(10):
195-
key = f"foo/c/{i}"
196-
expected.append(key)
197-
await store.set(
198-
f"foo/c/{i}", self.buffer_cls.from_bytes(i.to_bytes(length=3, byteorder="little"))
199-
)
200+
assert await _collect_aiterator(store.list()) == ()
201+
prefix = "foo"
202+
data = self.buffer_cls.from_bytes(b"")
203+
store_dict = {
204+
prefix + "/zarr.json": data,
205+
**{prefix + f"/c/{idx}": data for idx in range(10)},
206+
}
207+
await store._set_many(store_dict.items())
208+
expected_sorted = sorted(store_dict.keys())
209+
observed = await _collect_aiterator(store.list())
210+
observed_sorted = sorted(observed)
211+
assert observed_sorted == expected_sorted
200212

201-
@pytest.mark.xfail
202213
async def test_list_prefix(self, store: S) -> None:
203-
# TODO: we currently don't use list_prefix anywhere
204-
raise NotImplementedError
214+
"""
215+
Test that the `list_prefix` method works as intended. Given a prefix, it should return
216+
all the keys in storage that start with this prefix. Keys should be returned with the shared
217+
prefix removed.
218+
"""
219+
prefixes = ("", "a/", "a/b/", "a/b/c/")
220+
data = self.buffer_cls.from_bytes(b"")
221+
fname = "zarr.json"
222+
store_dict = {p + fname: data for p in prefixes}
223+
224+
await store._set_many(store_dict.items())
225+
226+
for prefix in prefixes:
227+
observed = tuple(sorted(await _collect_aiterator(store.list_prefix(prefix))))
228+
expected: tuple[str, ...] = ()
229+
for key in store_dict.keys():
230+
if key.startswith(prefix):
231+
expected += (key.removeprefix(prefix),)
232+
expected = tuple(sorted(expected))
233+
assert observed == expected
205234

206235
async def test_list_dir(self, store: S) -> None:
207-
out = [k async for k in store.list_dir("")]
208-
assert out == []
209-
assert [k async for k in store.list_dir("foo")] == []
210-
await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar"))
211-
await store.set("group-0/zarr.json", self.buffer_cls.from_bytes(b"\x01")) # group
212-
await store.set("group-0/group-1/zarr.json", self.buffer_cls.from_bytes(b"\x01")) # group
213-
await store.set("group-0/group-1/a1/zarr.json", self.buffer_cls.from_bytes(b"\x01"))
214-
await store.set("group-0/group-1/a2/zarr.json", self.buffer_cls.from_bytes(b"\x01"))
215-
await store.set("group-0/group-1/a3/zarr.json", self.buffer_cls.from_bytes(b"\x01"))
216-
217-
keys_expected = ["foo", "group-0"]
218-
keys_observed = [k async for k in store.list_dir("")]
219-
assert set(keys_observed) == set(keys_expected)
220-
221-
keys_expected = ["zarr.json"]
222-
keys_observed = [k async for k in store.list_dir("foo")]
223-
224-
assert len(keys_observed) == len(keys_expected), keys_observed
225-
assert set(keys_observed) == set(keys_expected), keys_observed
226-
227-
keys_observed = [k async for k in store.list_dir("foo/")]
228-
assert len(keys_expected) == len(keys_observed), keys_observed
229-
assert set(keys_observed) == set(keys_expected), keys_observed
230-
231-
keys_observed = [k async for k in store.list_dir("group-0")]
232-
keys_expected = ["zarr.json", "group-1"]
233-
234-
assert len(keys_observed) == len(keys_expected), keys_observed
235-
assert set(keys_observed) == set(keys_expected), keys_observed
236-
237-
keys_observed = [k async for k in store.list_dir("group-0/")]
238-
assert len(keys_expected) == len(keys_observed), keys_observed
239-
assert set(keys_observed) == set(keys_expected), keys_observed
236+
root = "foo"
237+
store_dict = {
238+
root + "/zarr.json": self.buffer_cls.from_bytes(b"bar"),
239+
root + "/c/1": self.buffer_cls.from_bytes(b"\x01"),
240+
}
240241

241-
keys_observed = [k async for k in store.list_dir("group-0/group-1")]
242-
keys_expected = ["zarr.json", "a1", "a2", "a3"]
242+
assert await _collect_aiterator(store.list_dir("")) == ()
243+
assert await _collect_aiterator(store.list_dir(root)) == ()
243244

244-
assert len(keys_observed) == len(keys_expected), keys_observed
245-
assert set(keys_observed) == set(keys_expected), keys_observed
245+
await store._set_many(store_dict.items())
246246

247-
keys_observed = [k async for k in store.list_dir("group-0/group-1")]
248-
assert len(keys_expected) == len(keys_observed), keys_observed
249-
assert set(keys_observed) == set(keys_expected), keys_observed
247+
keys_observed = await _collect_aiterator(store.list_dir(root))
248+
keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict.keys()}
250249

251-
async def test_set_get(self, store_kwargs: dict[str, Any]) -> None:
252-
kwargs = {**store_kwargs, **{"mode": "w"}}
253-
store = self.store_cls(**kwargs)
254-
await zarr.api.asynchronous.open_array(store=store, path="a", mode="w", shape=(4,))
255-
keys = [x async for x in store.list()]
256-
assert keys == ["a/zarr.json"]
250+
assert sorted(keys_observed) == sorted(keys_expected)
257251

258-
# no errors
259-
await zarr.api.asynchronous.open_array(store=store, path="a", mode="r")
260-
await zarr.api.asynchronous.open_array(store=store, path="a", mode="a")
252+
keys_observed = await _collect_aiterator(store.list_dir(root + "/"))
253+
assert sorted(keys_expected) == sorted(keys_observed)

tests/v3/test_store/test_local.py

-3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,3 @@ def test_store_supports_partial_writes(self, store: LocalStore) -> None:
3535

3636
def test_store_supports_listing(self, store: LocalStore) -> None:
3737
assert store.supports_listing
38-
39-
def test_list_prefix(self, store: LocalStore) -> None:
40-
assert True

tests/v3/test_store/test_remote.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from collections.abc import Generator
7+
8+
import botocore.client
9+
110
import os
2-
from collections.abc import Generator
311

4-
import botocore.client
512
import fsspec
613
import pytest
14+
from botocore.session import Session
715
from upath import UPath
816

917
from zarr.core.buffer import Buffer, cpu, default_buffer_prototype
10-
from zarr.core.sync import sync
18+
from zarr.core.sync import _collect_aiterator, sync
1119
from zarr.store import RemoteStore
1220
from zarr.testing.store import StoreTests
1321

@@ -40,8 +48,6 @@ def s3_base() -> Generator[None, None, None]:
4048

4149

4250
def get_boto3_client() -> botocore.client.BaseClient:
43-
from botocore.session import Session
44-
4551
# NB: we use the sync botocore client for setup
4652
session = Session()
4753
return session.create_client("s3", endpoint_url=endpoint_url)
@@ -87,7 +93,7 @@ async def test_basic() -> None:
8793
store = await RemoteStore.open(
8894
f"s3://{test_bucket_name}", mode="w", endpoint_url=endpoint_url, anon=False
8995
)
90-
assert not await alist(store.list())
96+
assert await _collect_aiterator(store.list()) == ()
9197
assert not await store.exists("foo")
9298
data = b"hello"
9399
await store.set("foo", cpu.Buffer.from_bytes(data))
@@ -104,7 +110,7 @@ class TestRemoteStoreS3(StoreTests[RemoteStore, cpu.Buffer]):
104110
buffer_cls = cpu.Buffer
105111

106112
@pytest.fixture(scope="function", params=("use_upath", "use_str"))
107-
def store_kwargs(self, request) -> dict[str, str | bool]:
113+
def store_kwargs(self, request: pytest.FixtureRequest) -> dict[str, str | bool | UPath]: # type: ignore
108114
url = f"s3://{test_bucket_name}"
109115
anon = False
110116
mode = "r+"
@@ -116,8 +122,8 @@ def store_kwargs(self, request) -> dict[str, str | bool]:
116122
raise AssertionError
117123

118124
@pytest.fixture(scope="function")
119-
def store(self, store_kwargs: dict[str, str | bool]) -> RemoteStore:
120-
url = store_kwargs["url"]
125+
async def store(self, store_kwargs: dict[str, str | bool | UPath]) -> RemoteStore:
126+
url: str | UPath = store_kwargs["url"]
121127
mode = store_kwargs["mode"]
122128
if isinstance(url, UPath):
123129
out = self.store_cls(url=url, mode=mode)

0 commit comments

Comments
 (0)