async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
    """
    Insert multiple (key, value) pairs into storage.

    Every pair is written with ``self.set``. The individual writes are handed to
    ``asyncio.gather`` together, so they are awaited as a group rather than one
    after another.

    Parameters
    ----------
    values : Iterable[tuple[str, Buffer]]
        The (key, value) pairs to write. The iterable is consumed exactly once.

    Raises
    ------
    Exception
        Whatever ``self.set`` raises; ``gather`` is used with its default
        settings, so the first failure propagates to the caller.
    """
    # No explicit ``return None``: falling off the end already returns None.
    await gather(*(self.set(key, value) for key, value in values))
async def _collect_aiterator(data: AsyncIterator[T]) -> tuple[T, ...]:
    """
    Collect an entire async iterator into a tuple.

    Parameters
    ----------
    data : AsyncIterator[T]
        The async iterator to exhaust.

    Returns
    -------
    tuple[T, ...]
        Every item produced by ``data``, in iteration order. Empty if ``data``
        yields nothing.
    """
    # An async comprehension replaces the manual append loop.
    return tuple([item async for item in data])


def collect_aiterator(data: AsyncIterator[T]) -> tuple[T, ...]:
    """
    Synchronously collect an entire async iterator into a tuple.

    This passes the ``_collect_aiterator`` coroutine to ``sync`` and returns its
    result.

    Parameters
    ----------
    data : AsyncIterator[T]
        The async iterator to exhaust.

    Returns
    -------
    tuple[T, ...]
        Every item produced by ``data``, in iteration order.
    """
    return sync(_collect_aiterator(data))
async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
    """
    Retrieve all keys in the store that begin with a given prefix. Keys are returned with the
    common leading prefix removed.

    Parameters
    ----------
    prefix : str
        The leading string a key must start with to be listed.

    Returns
    -------
    AsyncGenerator[str, None]
        The matching keys, each with ``prefix`` stripped from its start.
    """
    # Iterate over a snapshot of the keys. This is an async generator, so the
    # consumer may write to or delete from the store between two yields;
    # iterating the live dict would then raise
    # "RuntimeError: dictionary changed size during iteration".
    for key in list(self._store_dict):
        if key.startswith(prefix):
            yield key.removeprefix(prefix)
async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
    """
    List the keys that start with ``prefix``, yielding each one with that prefix stripped.

    Parameters
    ----------
    prefix : str
        The leading string a key must start with to be listed.

    Returns
    -------
    AsyncGenerator[str, None]
        The matching keys, in the order ``self.list()`` produces them, without the prefix.
    """
    cut = len(prefix)
    async for name in self.list():
        if not name.startswith(prefix):
            continue
        # The name is known to start with the prefix, so slicing strips exactly it.
        yield name[cut:]
async def test_list(self, store: S) -> None:
    """
    Test that ``list`` yields nothing for an empty store and, after a batch of keys has been
    written, yields exactly those keys (in any order).
    """
    assert await _collect_aiterator(store.list()) == ()

    prefix = "foo"
    data = self.buffer_cls.from_bytes(b"")
    store_dict = {prefix + "/zarr.json": data}
    store_dict.update({prefix + f"/c/{idx}": data for idx in range(10)})
    await store._set_many(store_dict.items())

    # Listing order is not part of the contract being checked, so compare sorted keys.
    listed = await _collect_aiterator(store.list())
    assert sorted(listed) == sorted(store_dict.keys())
async def test_list_dir(self, store: S) -> None:
    """
    Test that ``list_dir`` yields nothing for an empty store and, once keys exist under a
    root, yields the first path component below that root for each key. The result must be
    the same whether or not the root is given with a trailing slash.
    """
    root = "foo"
    store_dict = {
        root + "/zarr.json": self.buffer_cls.from_bytes(b"bar"),
        root + "/c/1": self.buffer_cls.from_bytes(b"\x01"),
    }

    # Nothing has been written yet, so both listings are empty.
    for empty_query in ("", root):
        assert await _collect_aiterator(store.list_dir(empty_query)) == ()

    await store._set_many(store_dict.items())

    # Expected children: the first path segment of each key once the root is stripped.
    keys_expected = set()
    for k in store_dict.keys():
        keys_expected.add(k.removeprefix(root + "/").split("/")[0])

    keys_observed = await _collect_aiterator(store.list_dir(root))
    assert sorted(keys_observed) == sorted(keys_expected)

    keys_observed = await _collect_aiterator(store.list_dir(root + "/"))
    assert sorted(keys_expected) == sorted(keys_observed)
test_store_supports_partial_writes(self, store: LocalStore) -> None: def test_store_supports_listing(self, store: LocalStore) -> None: assert store.supports_listing - - def test_list_prefix(self, store: LocalStore) -> None: - assert True diff --git a/tests/v3/test_store/test_remote.py b/tests/v3/test_store/test_remote.py index afa991209f..e2c3070198 100644 --- a/tests/v3/test_store/test_remote.py +++ b/tests/v3/test_store/test_remote.py @@ -1,13 +1,21 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Generator + + import botocore.client + import os -from collections.abc import Generator -import botocore.client import fsspec import pytest +from botocore.session import Session from upath import UPath from zarr.core.buffer import Buffer, cpu, default_buffer_prototype -from zarr.core.sync import sync +from zarr.core.sync import _collect_aiterator, sync from zarr.store import RemoteStore from zarr.testing.store import StoreTests @@ -40,8 +48,6 @@ def s3_base() -> Generator[None, None, None]: def get_boto3_client() -> botocore.client.BaseClient: - from botocore.session import Session - # NB: we use the sync botocore client for setup session = Session() return session.create_client("s3", endpoint_url=endpoint_url) @@ -87,7 +93,7 @@ async def test_basic() -> None: store = await RemoteStore.open( f"s3://{test_bucket_name}", mode="w", endpoint_url=endpoint_url, anon=False ) - assert not await alist(store.list()) + assert await _collect_aiterator(store.list()) == () assert not await store.exists("foo") data = b"hello" await store.set("foo", cpu.Buffer.from_bytes(data)) @@ -104,7 +110,7 @@ class TestRemoteStoreS3(StoreTests[RemoteStore, cpu.Buffer]): buffer_cls = cpu.Buffer @pytest.fixture(scope="function", params=("use_upath", "use_str")) - def store_kwargs(self, request) -> dict[str, str | bool]: + def store_kwargs(self, request: pytest.FixtureRequest) -> dict[str, str | bool | UPath]: # type: ignore 
url = f"s3://{test_bucket_name}" anon = False mode = "r+" @@ -116,8 +122,8 @@ def store_kwargs(self, request) -> dict[str, str | bool]: raise AssertionError @pytest.fixture(scope="function") - def store(self, store_kwargs: dict[str, str | bool]) -> RemoteStore: - url = store_kwargs["url"] + async def store(self, store_kwargs: dict[str, str | bool | UPath]) -> RemoteStore: + url: str | UPath = store_kwargs["url"] mode = store_kwargs["mode"] if isinstance(url, UPath): out = self.store_cls(url=url, mode=mode)