Skip to content

Commit 754f08d

Browse files
committed
Merge remote-tracking branch 'upstream/v3' into user/tom/fix/v2-compat
2 parents 209591c + 61683be commit 754f08d

File tree

3 files changed

+42
-8
lines changed

3 files changed

+42
-8
lines changed

src/zarr/store/memory.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,15 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
114114

115115
if prefix == "":
116116
keys_unique = set(k.split("/")[0] for k in self._store_dict.keys())
117-
for key in keys_unique:
118-
yield key
119117
else:
120-
for key in self._store_dict:
121-
if key.startswith(prefix + "/") and key != prefix:
122-
yield key.removeprefix(prefix + "/").split("/")[0]
118+
# Our dictionary doesn't contain directory markers, but we want to include
119+
# a pseudo directory when there's a nested item and we're listing an
120+
# intermediate level.
121+
keys_unique = {
122+
key.removeprefix(prefix + "/").split("/")[0]
123+
for key in self._store_dict
124+
if key.startswith(prefix + "/") and key != prefix
125+
}
126+
127+
for key in keys_unique:
128+
yield key

src/zarr/store/remote.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
202202
except FileNotFoundError:
203203
return
204204
for onefile in (a.replace(prefix + "/", "") for a in allfiles):
205-
yield onefile
205+
yield onefile.removeprefix(self.path).removeprefix("/")
206206

207207
async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
208208
for onefile in await self._fs._ls(prefix, detail=False):

src/zarr/testing/store.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,17 @@ async def test_list_dir(self, store: S) -> None:
191191
assert out == []
192192
assert [k async for k in store.list_dir("foo")] == []
193193
await store.set("foo/zarr.json", Buffer.from_bytes(b"bar"))
194-
await store.set("foo/c/1", Buffer.from_bytes(b"\x01"))
194+
await store.set("group-0/zarr.json", Buffer.from_bytes(b"\x01")) # group
195+
await store.set("group-0/group-1/zarr.json", Buffer.from_bytes(b"\x01")) # group
196+
await store.set("group-0/group-1/a1/zarr.json", Buffer.from_bytes(b"\x01"))
197+
await store.set("group-0/group-1/a2/zarr.json", Buffer.from_bytes(b"\x01"))
198+
await store.set("group-0/group-1/a3/zarr.json", Buffer.from_bytes(b"\x01"))
195199

196-
keys_expected = ["zarr.json", "c"]
200+
keys_expected = ["foo", "group-0"]
201+
keys_observed = [k async for k in store.list_dir("")]
202+
assert set(keys_observed) == set(keys_expected)
203+
204+
keys_expected = ["zarr.json"]
197205
keys_observed = [k async for k in store.list_dir("foo")]
198206

199207
assert len(keys_observed) == len(keys_expected), keys_observed
@@ -202,3 +210,23 @@ async def test_list_dir(self, store: S) -> None:
202210
keys_observed = [k async for k in store.list_dir("foo/")]
203211
assert len(keys_expected) == len(keys_observed), keys_observed
204212
assert set(keys_observed) == set(keys_expected), keys_observed
213+
214+
keys_observed = [k async for k in store.list_dir("group-0")]
215+
keys_expected = ["zarr.json", "group-1"]
216+
217+
assert len(keys_observed) == len(keys_expected), keys_observed
218+
assert set(keys_observed) == set(keys_expected), keys_observed
219+
220+
keys_observed = [k async for k in store.list_dir("group-0/")]
221+
assert len(keys_expected) == len(keys_observed), keys_observed
222+
assert set(keys_observed) == set(keys_expected), keys_observed
223+
224+
keys_observed = [k async for k in store.list_dir("group-0/group-1")]
225+
keys_expected = ["zarr.json", "a1", "a2", "a3"]
226+
227+
assert len(keys_observed) == len(keys_expected), keys_observed
228+
assert set(keys_observed) == set(keys_expected), keys_observed
229+
230+
keys_observed = [k async for k in store.list_dir("group-0/group-1")]
231+
assert len(keys_expected) == len(keys_observed), keys_observed
232+
assert set(keys_observed) == set(keys_expected), keys_observed

0 commit comments

Comments
 (0)