Skip to content

Commit 9d147b9

Browse files
committed
refactor members cache
1 parent 1eaf3ea commit 9d147b9

File tree

2 files changed

+67
-42
lines changed

2 files changed

+67
-42
lines changed

xarray/backends/zarr.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ class ZarrStore(AbstractWritableDataStore):
601601

602602
__slots__ = (
603603
"_append_dim",
604-
"_cache_array_keys",
604+
"_cache_members",
605605
"_close_store_on_close",
606606
"_consolidate_on_close",
607607
"_group",
@@ -634,7 +634,7 @@ def open_store(
634634
zarr_format=None,
635635
use_zarr_fill_value_as_mask=None,
636636
write_empty: bool | None = None,
637-
cache_array_keys: bool = False,
637+
cache_members: bool = False,
638638
):
639639
(
640640
zarr_group,
@@ -666,7 +666,7 @@ def open_store(
666666
write_empty,
667667
close_store_on_close,
668668
use_zarr_fill_value_as_mask,
669-
cache_array_keys=cache_array_keys,
669+
cache_members=cache_members
670670
)
671671
for group in group_paths
672672
}
@@ -748,31 +748,42 @@ def __init__(
748748
self._write_empty = write_empty
749749
self._close_store_on_close = close_store_on_close
750750
self._use_zarr_fill_value_as_mask = use_zarr_fill_value_as_mask
751+
self._cache_members: bool = cache_members
752+
self._members: dict[str, ZarrArray | ZarrGroup] = {}
751753

752-
self._members: tuple[bool, dict[str, ZarrArray | ZarrGroup] | None]
753-
if cache_members:
754-
self._members = (True, None)
755-
else:
756-
self._members = (False, None)
754+
if self._cache_members:
755+
# initialize the cache
756+
self._members = self._fetch_members()
757757

758758
@property
759-
def members(self) -> dict[str, ZarrArray | ZarrGroup]:
759+
def members(self) -> dict[str, ZarrArray]:
760760
"""
761761
Model the arrays and groups contained in self.zarr_group as a dict
762762
"""
763-
do_cache, old_members = self._members
764-
if not do_cache or old_members is None:
765-
# we explicitly only care about the arrays, which saves some IO
766-
# in zarr v2
767-
members = dict(self.zarr_group.arrays())
768-
if do_cache:
769-
self._members = (do_cache, members)
770-
return members
763+
if not self._cache_members:
764+
return self._fetch_members()
771765
else:
772-
return old_members
766+
return self._members
767+
768+
def _fetch_members(self) -> dict[str, ZarrArray]:
769+
"""
770+
Get the arrays and groups defined in the zarr group modelled by this Store
771+
"""
772+
return dict(self.zarr_group.items())
773+
774+
def _update_members(self, data: dict[str, ZarrArray]):
775+
if not self._cache_members:
776+
msg = (
777+
'Updating the members cache is only valid if this object was created '
778+
'with cache_members=True, but this object has `cache_members=False`.'
779+
f'You should update the zarr group directly.'
780+
)
781+
raise ValueError(msg)
782+
else:
783+
self._members = {**self.members, **data}
773784

774785
def array_keys(self) -> tuple[str, ...]:
775-
return tuple(key for (key, _) in self.arrays())
786+
return tuple(key for (key, node) in self.members.items() if isinstance(node, ZarrArray))
776787

777788
def arrays(self) -> tuple[tuple[str, ZarrArray], ...]:
778789
return tuple(
@@ -1047,6 +1058,10 @@ def _open_existing_array(self, *, name) -> ZarrArray:
10471058
else:
10481059
zarr_array = self.zarr_group[name]
10491060

1061+
# update the model of the underlying zarr group
1062+
if self._cache_members:
1063+
self._update_members({name: zarr_array})
1064+
self._update_members({name: zarr_array})
10501065
return zarr_array
10511066

10521067
def _create_new_array(
@@ -1075,6 +1090,9 @@ def _create_new_array(
10751090
**encoding,
10761091
)
10771092
zarr_array = _put_attrs(zarr_array, attrs)
1093+
# update the model of the underlying zarr group
1094+
if self._cache_members:
1095+
self._update_members({name: zarr_array})
10781096
return zarr_array
10791097

10801098
def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None):
@@ -1143,7 +1161,8 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
11431161
zarr_array.resize(new_shape)
11441162

11451163
zarr_shape = zarr_array.shape
1146-
1164+
# update the model of the members of the zarr group
1165+
self.members[name] = zarr_array
11471166
region = tuple(write_region[dim] for dim in dims)
11481167

11491168
# We need to do this for both new and existing variables to ensure we're not

xarray/tests/test_backends.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2275,10 +2275,10 @@ def create_zarr_target(self):
22752275
raise NotImplementedError
22762276

22772277
@contextlib.contextmanager
2278-
def create_store(self):
2278+
def create_store(self, cache_members: bool = True):
22792279
with self.create_zarr_target() as store_target:
22802280
yield backends.ZarrStore.open_group(
2281-
store_target, mode="w", **self.version_kwargs
2281+
store_target, mode="w", cache_members=cache_members, **self.version_kwargs
22822282
)
22832283

22842284
def save(self, dataset, store_target, **kwargs): # type: ignore[override]
@@ -2572,7 +2572,7 @@ def test_hidden_zarr_keys(self) -> None:
25722572
skip_if_zarr_format_3("This test is unnecessary; no hidden Zarr keys")
25732573

25742574
expected = create_test_data()
2575-
with self.create_store() as store:
2575+
with self.create_store(cache_members=False) as store:
25762576
expected.dump_to_store(store)
25772577
zarr_group = store.ds
25782578

@@ -2594,6 +2594,7 @@ def test_hidden_zarr_keys(self) -> None:
25942594

25952595
# put it back and try removing from a variable
25962596
del zarr_group["var2"].attrs[self.DIMENSION_KEY]
2597+
25972598
with pytest.raises(KeyError):
25982599
with xr.decode_cf(store):
25992600
pass
@@ -3258,40 +3259,45 @@ def test_chunked_cftime_datetime(self) -> None:
32583259
assert original[name].chunks == actual_var.chunks
32593260
assert original.chunks == actual.chunks
32603261

3261-
@pytest.mark.parametrize("cache_array_keys", [True, False])
3262-
def test_get_array_keys(self, cache_array_keys: bool) -> None:
3262+
def test_cache_members(self) -> None:
32633263
"""
3264-
Ensure that if `ZarrStore` is created with `cache_array_keys` set to `True`,
3265-
a `ZarrStore.get_array_keys` only invokes the `array_keys` function on the
3266-
`ZarrStore.zarr_group` instance once, and that the results of that call are cached.
3264+
Ensure that if `ZarrStore` is created with `cache_members` set to `True`,
3265+
a `ZarrStore` only inspects the underlying zarr group once,
3266+
and that the results of that inspection are cached.
32673267
3268-
Otherwise, `ZarrStore.get_array_keys` instance should invoke the `array_keys`
3269-
each time it is called.
3268+
Otherwise, `ZarrStore.members` should inspect the underlying zarr group each time it is
3269+
invoked
32703270
"""
32713271
with self.create_zarr_target() as store_target:
3272-
zstore = backends.ZarrStore.open_group(
3273-
store_target, mode="w", cache_members=cache_array_keys
3272+
zstore_mut = backends.ZarrStore.open_group(
3273+
store_target, mode="w", cache_members=False
32743274
)
32753275

32763276
# ensure that the keys are sorted
32773277
array_keys = sorted(("foo", "bar"))
32783278

32793279
# create some arrays
32803280
for ak in array_keys:
3281-
zstore.zarr_group.create(name=ak, shape=(1,), dtype="uint8")
3281+
zstore_mut.zarr_group.create(name=ak, shape=(1,), dtype="uint8")
3282+
3283+
zstore_stat = backends.ZarrStore.open_group(
3284+
store_target, mode="r", cache_members=True
3285+
)
32823286

3283-
observed_keys_0 = sorted(zstore.array_keys())
3287+
observed_keys_0 = sorted(zstore_stat.array_keys())
32843288
assert observed_keys_0 == array_keys
32853289

32863290
# create a new array
32873291
new_key = "baz"
3288-
zstore.zarr_group.create(name=new_key, shape=(1,), dtype="uint8")
3289-
observed_keys_1 = sorted(zstore.array_keys())
3292+
zstore_mut.zarr_group.create(name=new_key, shape=(1,), dtype="uint8")
3293+
3294+
observed_keys_1 = sorted(zstore_stat.array_keys())
3295+
assert observed_keys_1 == array_keys
3296+
3297+
observed_keys_2 = sorted(zstore_mut.array_keys())
3298+
assert observed_keys_2 == sorted(array_keys + [new_key])
3299+
32903300

3291-
if cache_array_keys:
3292-
assert observed_keys_1 == array_keys
3293-
else:
3294-
assert observed_keys_1 == sorted(array_keys + [new_key])
32953301

32963302

32973303
@requires_zarr
@@ -3556,9 +3562,9 @@ def create_zarr_target(self):
35563562
yield tmp
35573563

35583564
@contextlib.contextmanager
3559-
def create_store(self):
3565+
def create_store(self, cache_members: bool = True):
35603566
with self.create_zarr_target() as store_target:
3561-
group = backends.ZarrStore.open_group(store_target, mode="a")
3567+
group = backends.ZarrStore.open_group(store_target, mode="a", cache_members=cache_members)
35623568
yield group
35633569

35643570

0 commit comments

Comments
 (0)