|
3 | 3 |
|
4 | 4 | import pytest
|
5 | 5 |
|
6 |
| -import zarr.api.asynchronous |
7 | 6 | from zarr.abc.store import AccessMode, Store
|
8 | 7 | from zarr.core.buffer import Buffer, default_buffer_prototype
|
| 8 | +from zarr.core.sync import _collect_aiterator |
9 | 9 | from zarr.store._utils import _normalize_interval_index
|
10 | 10 | from zarr.testing.utils import assert_bytes_equal
|
11 | 11 |
|
@@ -123,6 +123,18 @@ async def test_set(self, store: S, key: str, data: bytes) -> None:
|
123 | 123 | observed = self.get(store, key)
|
124 | 124 | assert_bytes_equal(observed, data_buf)
|
125 | 125 |
|
| 126 | + async def test_set_many(self, store: S) -> None: |
| 127 | + """ |
| 128 | + Test that a dict of key : value pairs can be inserted into the store via the |
| 129 | + `_set_many` method. |
| 130 | + """ |
| 131 | + keys = ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"] |
| 132 | + data_buf = [self.buffer_cls.from_bytes(k.encode()) for k in keys] |
| 133 | + store_dict = dict(zip(keys, data_buf, strict=True)) |
| 134 | + await store._set_many(store_dict.items()) |
| 135 | + for k, v in store_dict.items(): |
| 136 | + assert self.get(store, k).to_bytes() == v.to_bytes() |
| 137 | + |
126 | 138 | @pytest.mark.parametrize(
|
127 | 139 | "key_ranges",
|
128 | 140 | (
|
@@ -185,76 +197,57 @@ async def test_clear(self, store: S) -> None:
|
185 | 197 | assert await store.empty()
|
186 | 198 |
|
187 | 199 | async def test_list(self, store: S) -> None:
|
188 |
| - assert [k async for k in store.list()] == [] |
189 |
| - await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar")) |
190 |
| - keys = [k async for k in store.list()] |
191 |
| - assert keys == ["foo/zarr.json"], keys |
192 |
| - |
193 |
| - expected = ["foo/zarr.json"] |
194 |
| - for i in range(10): |
195 |
| - key = f"foo/c/{i}" |
196 |
| - expected.append(key) |
197 |
| - await store.set( |
198 |
| - f"foo/c/{i}", self.buffer_cls.from_bytes(i.to_bytes(length=3, byteorder="little")) |
199 |
| - ) |
| 200 | + assert await _collect_aiterator(store.list()) == () |
| 201 | + prefix = "foo" |
| 202 | + data = self.buffer_cls.from_bytes(b"") |
| 203 | + store_dict = { |
| 204 | + prefix + "/zarr.json": data, |
| 205 | + **{prefix + f"/c/{idx}": data for idx in range(10)}, |
| 206 | + } |
| 207 | + await store._set_many(store_dict.items()) |
| 208 | + expected_sorted = sorted(store_dict.keys()) |
| 209 | + observed = await _collect_aiterator(store.list()) |
| 210 | + observed_sorted = sorted(observed) |
| 211 | + assert observed_sorted == expected_sorted |
200 | 212 |
|
201 |
| - @pytest.mark.xfail |
202 | 213 | async def test_list_prefix(self, store: S) -> None:
|
203 |
| - # TODO: we currently don't use list_prefix anywhere |
204 |
| - raise NotImplementedError |
| 214 | + """ |
| 215 | + Test that the `list_prefix` method works as intended. Given a prefix, it should return |
| 216 | + all the keys in storage that start with this prefix. Keys should be returned with the shared |
| 217 | + prefix removed. |
| 218 | + """ |
| 219 | + prefixes = ("", "a/", "a/b/", "a/b/c/") |
| 220 | + data = self.buffer_cls.from_bytes(b"") |
| 221 | + fname = "zarr.json" |
| 222 | + store_dict = {p + fname: data for p in prefixes} |
| 223 | + |
| 224 | + await store._set_many(store_dict.items()) |
| 225 | + |
| 226 | + for prefix in prefixes: |
| 227 | + observed = tuple(sorted(await _collect_aiterator(store.list_prefix(prefix)))) |
| 228 | + expected: tuple[str, ...] = () |
| 229 | + for key in store_dict.keys(): |
| 230 | + if key.startswith(prefix): |
| 231 | + expected += (key.removeprefix(prefix),) |
| 232 | + expected = tuple(sorted(expected)) |
| 233 | + assert observed == expected |
205 | 234 |
|
206 | 235 | async def test_list_dir(self, store: S) -> None:
|
207 |
| - out = [k async for k in store.list_dir("")] |
208 |
| - assert out == [] |
209 |
| - assert [k async for k in store.list_dir("foo")] == [] |
210 |
| - await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar")) |
211 |
| - await store.set("group-0/zarr.json", self.buffer_cls.from_bytes(b"\x01")) # group |
212 |
| - await store.set("group-0/group-1/zarr.json", self.buffer_cls.from_bytes(b"\x01")) # group |
213 |
| - await store.set("group-0/group-1/a1/zarr.json", self.buffer_cls.from_bytes(b"\x01")) |
214 |
| - await store.set("group-0/group-1/a2/zarr.json", self.buffer_cls.from_bytes(b"\x01")) |
215 |
| - await store.set("group-0/group-1/a3/zarr.json", self.buffer_cls.from_bytes(b"\x01")) |
216 |
| - |
217 |
| - keys_expected = ["foo", "group-0"] |
218 |
| - keys_observed = [k async for k in store.list_dir("")] |
219 |
| - assert set(keys_observed) == set(keys_expected) |
220 |
| - |
221 |
| - keys_expected = ["zarr.json"] |
222 |
| - keys_observed = [k async for k in store.list_dir("foo")] |
223 |
| - |
224 |
| - assert len(keys_observed) == len(keys_expected), keys_observed |
225 |
| - assert set(keys_observed) == set(keys_expected), keys_observed |
226 |
| - |
227 |
| - keys_observed = [k async for k in store.list_dir("foo/")] |
228 |
| - assert len(keys_expected) == len(keys_observed), keys_observed |
229 |
| - assert set(keys_observed) == set(keys_expected), keys_observed |
230 |
| - |
231 |
| - keys_observed = [k async for k in store.list_dir("group-0")] |
232 |
| - keys_expected = ["zarr.json", "group-1"] |
233 |
| - |
234 |
| - assert len(keys_observed) == len(keys_expected), keys_observed |
235 |
| - assert set(keys_observed) == set(keys_expected), keys_observed |
236 |
| - |
237 |
| - keys_observed = [k async for k in store.list_dir("group-0/")] |
238 |
| - assert len(keys_expected) == len(keys_observed), keys_observed |
239 |
| - assert set(keys_observed) == set(keys_expected), keys_observed |
| 236 | + root = "foo" |
| 237 | + store_dict = { |
| 238 | + root + "/zarr.json": self.buffer_cls.from_bytes(b"bar"), |
| 239 | + root + "/c/1": self.buffer_cls.from_bytes(b"\x01"), |
| 240 | + } |
240 | 241 |
|
241 |
| - keys_observed = [k async for k in store.list_dir("group-0/group-1")] |
242 |
| - keys_expected = ["zarr.json", "a1", "a2", "a3"] |
| 242 | + assert await _collect_aiterator(store.list_dir("")) == () |
| 243 | + assert await _collect_aiterator(store.list_dir(root)) == () |
243 | 244 |
|
244 |
| - assert len(keys_observed) == len(keys_expected), keys_observed |
245 |
| - assert set(keys_observed) == set(keys_expected), keys_observed |
| 245 | + await store._set_many(store_dict.items()) |
246 | 246 |
|
247 |
| - keys_observed = [k async for k in store.list_dir("group-0/group-1")] |
248 |
| - assert len(keys_expected) == len(keys_observed), keys_observed |
249 |
| - assert set(keys_observed) == set(keys_expected), keys_observed |
| 247 | + keys_observed = await _collect_aiterator(store.list_dir(root)) |
| 248 | + keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict.keys()} |
250 | 249 |
|
251 |
| - async def test_set_get(self, store_kwargs: dict[str, Any]) -> None: |
252 |
| - kwargs = {**store_kwargs, **{"mode": "w"}} |
253 |
| - store = self.store_cls(**kwargs) |
254 |
| - await zarr.api.asynchronous.open_array(store=store, path="a", mode="w", shape=(4,)) |
255 |
| - keys = [x async for x in store.list()] |
256 |
| - assert keys == ["a/zarr.json"] |
| 250 | + assert sorted(keys_observed) == sorted(keys_expected) |
257 | 251 |
|
258 |
| - # no errors |
259 |
| - await zarr.api.asynchronous.open_array(store=store, path="a", mode="r") |
260 |
| - await zarr.api.asynchronous.open_array(store=store, path="a", mode="a") |
| 252 | + keys_observed = await _collect_aiterator(store.list_dir(root + "/")) |
| 253 | + assert sorted(keys_expected) == sorted(keys_observed) |
0 commit comments