Commit 7e2768a

make shard_format configurable, add bitmask for uncompressed chunks
1 parent 97a9368 commit 7e2768a

6 files changed, +107 −52 lines changed

chunking_test.py

Lines changed: 5 additions & 4 deletions
@@ -4,17 +4,18 @@
 import zarr
 
 store = zarr.DirectoryStore("data/chunking_test.zarr")
-z = zarr.zeros((20, 3), chunks=(3, 3), shards=(2, 2), store=store, overwrite=True, compressor=None)
-z[...] = 42
+z = zarr.zeros((20, 3), chunks=(3, 2), shards=(2, 2), store=store, overwrite=True, compressor=None)
+z[:10, :] = 42
 z[15, 1] = 389
 z[19, 2] = 1
 z[0, 1] = -4.2
 
+print(store[".zarray"].decode())
 print("ONDISK", sorted(os.listdir("data/chunking_test.zarr")))
 assert json.loads(store[".zarray"].decode())["shards"] == [2, 2]
 
-print("STORE", list(store))
-print("CHUNKSTORE (SHARDED)", list(z.chunk_store))
+print("STORE", sorted(store))
+print("CHUNKSTORE (SHARDED)", sorted(z.chunk_store))
 
 z_reopened = zarr.open("data/chunking_test.zarr")
 assert z_reopened.shards == (2, 2)

zarr/_storage/sharded_store.py

Lines changed: 81 additions & 38 deletions
@@ -1,6 +1,7 @@
+from collections import defaultdict
 from functools import reduce
-from itertools import product
-from typing import Any, Iterable, Iterator, Optional, Tuple
+import math
+from typing import Any, Dict, Iterable, Iterator, List, Tuple, Union
 
 import numpy as np
 
@@ -16,7 +17,7 @@ def _cum_prod(x: Iterable[int]) -> Iterable[int]:
         yield prod
 
 
-class ShardedStore(Store):
+class MortonOrderShardedStore(Store):
     """This class should not be used directly,
     but is added to an Array as a wrapper when needed automatically."""
 
@@ -32,59 +33,97 @@ def __init__(
     ) -> None:
         self._store: BaseStore = BaseStore._ensure_store(store)
         self._shards = shards
-        # This defines C/F-order
-        self._shard_strides = tuple(_cum_prod(shards))
         self._num_chunks_per_shard = reduce(lambda x, y: x*y, shards, 1)
         self._dimension_separator = dimension_separator
-        # TODO: add jumptable for compressed data
+
         chunk_has_constant_size = not are_chunks_compressed and not dtype == object
         assert chunk_has_constant_size, "Currently only uncompressed, fixed-length data can be used."
         self._chunk_has_constant_size = chunk_has_constant_size
         if chunk_has_constant_size:
             binary_fill_value = np.full(1, fill_value=fill_value or 0, dtype=dtype).tobytes()
             self._fill_chunk = binary_fill_value * chunk_size
-        else:
-            self._fill_chunk = None
+        self._emtpy_meta = b"\x00" * math.ceil(self._num_chunks_per_shard / 8)
+
+        # unused when using Morton order
+        self._shard_strides = tuple(_cum_prod(shards))
 
         # TODO: add warnings for ineffective reads/writes:
         # * warn if partial reads are not available
         # * optionally warn on unaligned writes if no partial writes are available
-
-    def __key_to_sharded__(self, key: str) -> Tuple[str, int]:
+
+    def __get_meta__(self, shard_content: Union[bytes, bytearray]) -> int:
+        return int.from_bytes(shard_content[-len(self._emtpy_meta):], byteorder="big")
+
+    def __set_meta__(self, shard_content: bytearray, meta: int) -> None:
+        shard_content[-len(self._emtpy_meta):] = meta.to_bytes(len(self._emtpy_meta), byteorder="big")
+
+    # The following two methods define the order of the chunks in a shard
+    # TODO use morton order
+    def __chunk_key_to_shard_key_and_index__(self, chunk_key: str) -> Tuple[str, int]:
         # TODO: allow to be in a group (aka only use last parts for dimensions)
-        subkeys = map(int, key.split(self._dimension_separator))
+        chunk_subkeys = map(int, chunk_key.split(self._dimension_separator))
 
-        shard_tuple, index_tuple = zip(*((subkey // shard_i, subkey % shard_i) for subkey, shard_i in zip(subkeys, self._shards)))
+        shard_tuple, index_tuple = zip(*((subkey // shard_i, subkey % shard_i) for subkey, shard_i in zip(chunk_subkeys, self._shards)))
         shard_key = self._dimension_separator.join(map(str, shard_tuple))
         index = sum(i * j for i, j in zip(index_tuple, self._shard_strides))
         return shard_key, index
 
-    def __get_chunk_slice__(self, shard_key: str, shard_index: int) -> Tuple[int, int]:
-        # TODO: here we would use the jumptable for compression, which uses shard_key
+    def __shard_key_and_index_to_chunk_key__(self, shard_key_tuple: Tuple[int, ...], shard_index: int) -> str:
+        offset = tuple(shard_index % s2 // s1 for s1, s2 in zip(self._shard_strides, self._shard_strides[1:] + (self._num_chunks_per_shard,)))
+        original_key = (shard_key_i * shards_i + offset_i for shard_key_i, offset_i, shards_i in zip(shard_key_tuple, offset, self._shards))
+        return self._dimension_separator.join(map(str, original_key))
+
+    def __keys_to_shard_groups__(self, keys: Iterable[str]) -> Dict[str, List[Tuple[str, str]]]:
+        shard_indices_per_shard_key = defaultdict(list)
+        for chunk_key in keys:
+            shard_key, shard_index = self.__chunk_key_to_shard_key_and_index__(chunk_key)
+            shard_indices_per_shard_key[shard_key].append((shard_index, chunk_key))
+        return shard_indices_per_shard_key
+
+    def __get_chunk_slice__(self, shard_index: int) -> Tuple[int, int]:
        start = shard_index * len(self._fill_chunk)
        return slice(start, start + len(self._fill_chunk))
 
     def __getitem__(self, key: str) -> bytes:
-        shard_key, shard_index = self.__key_to_sharded__(key)
-        chunk_slice = self.__get_chunk_slice__(shard_key, shard_index)
-        # TODO use partial reads if available
-        full_shard_value = self._store[shard_key]
-        return full_shard_value[chunk_slice]
+        return self.getitems([key])[key]
+
+    def getitems(self, keys: Iterable[str], **kwargs) -> Dict[str, bytes]:
+        result = {}
+        for shard_key, chunks_in_shard in self.__keys_to_shard_groups__(keys).items():
+            # TODO use partial reads if available
+            full_shard_value = self._store[shard_key]
+            # TODO omit items if they don't exist
+            for shard_index, chunk_key in chunks_in_shard:
+                result[chunk_key] = full_shard_value[self.__get_chunk_slice__(shard_index)]
+        return result
 
     def __setitem__(self, key: str, value: bytes) -> None:
-        shard_key, shard_index = self.__key_to_sharded__(key)
-        if shard_key in self._store:
-            full_shard_value = bytearray(self._store[shard_key])
-        else:
-            full_shard_value = bytearray(self._fill_chunk * self._num_chunks_per_shard)
-        chunk_slice = self.__get_chunk_slice__(shard_key, shard_index)
-        # TODO use partial writes if available
-        full_shard_value[chunk_slice] = value
-        self._store[shard_key] = full_shard_value
+        self.setitems({key: value})
+
+    def setitems(self, values: Dict[str, bytes]) -> None:
+        for shard_key, chunks_in_shard in self.__keys_to_shard_groups__(values.keys()).items():
+            if len(chunks_in_shard) == self._num_chunks_per_shard:
+                # TODO shards at a non-dataset-size aligned surface are not captured here yet
+                full_shard_value = b"".join(
+                    values[chunk_key] for _, chunk_key in sorted(chunks_in_shard)
+                ) + b"\xff" * len(self._emtpy_meta)
+                self._store[shard_key] = full_shard_value
+            else:
+                # TODO use partial writes if available
+                try:
+                    full_shard_value = bytearray(self._store[shard_key])
+                except KeyError:
+                    full_shard_value = bytearray(self._fill_chunk * self._num_chunks_per_shard + self._emtpy_meta)
+                chunk_mask = self.__get_meta__(full_shard_value)
+                for shard_index, chunk_key in chunks_in_shard:
+                    chunk_mask |= 1 << shard_index
+                    full_shard_value[self.__get_chunk_slice__(shard_index)] = values[chunk_key]
+                self.__set_meta__(full_shard_value, chunk_mask)
+                self._store[shard_key] = full_shard_value
 
     def __delitem__(self, key) -> None:
-        # TODO not implemented yet
-        # For uncompressed chunks, deleting the "last" chunk might need to be detected.
+        # TODO not implemented yet, also delitems
+        # Deleting the "last" chunk in a shard needs to remove the whole shard
         raise NotImplementedError("Deletion is not yet implemented")
 
     def __iter__(self) -> Iterator[str]:
@@ -94,16 +133,20 @@ def __iter__(self) -> Iterator[str]:
                 yield shard_key
             else:
                 # For each shard key in the wrapped store, all corresponding chunks are yielded.
-                # TODO: For compressed chunks we might yield only the actualy contained chunks by reading the jumptables.
                 # TODO: allow to be in a group (aka only use last parts for dimensions)
-                subkeys = tuple(map(int, shard_key.split(self._dimension_separator)))
-                for offset in product(*(range(i) for i in self._shards)):
-                    original_key = (subkeys_i * shards_i + offset_i for subkeys_i, offset_i, shards_i in zip(subkeys, offset, self._shards))
-                    yield self._dimension_separator.join(map(str, original_key))
+                shard_key_tuple = tuple(map(int, shard_key.split(self._dimension_separator)))
+                mask = self.__get_meta__(self._store[shard_key])
+                for i in range(self._num_chunks_per_shard):
+                    if mask == 0:
+                        break
+                    if mask & 1:
+                        yield self.__shard_key_and_index_to_chunk_key__(shard_key_tuple, i)
+                    mask >>= 1
 
     def __len__(self) -> int:
         return sum(1 for _ in self.keys())
 
-    # TODO: For efficient reads and writes, we need to implement
-    # getitems, setitems & delitems
-    # and combine writes/reads/deletions to the same shard.
+
+SHARDED_STORES = {
+    "morton_order": MortonOrderShardedStore,
+}
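
For orientation, a minimal standalone sketch of the chunk-key arithmetic and the new trailing bitmask (illustration only, not part of the commit). It assumes the default "." dimension separator and that _cum_prod yields the strides (1, 2) for shards=(2, 2); the helper name below is made up and rephrases __chunk_key_to_shard_key_and_index__ and __get_meta__.

import math

shards = (2, 2)                                   # chunks per shard along each dimension
dimension_separator = "."                         # assumed default
num_chunks_per_shard = shards[0] * shards[1]      # 4
strides = (1, shards[0])                          # assumed output of _cum_prod(shards)
meta_size = math.ceil(num_chunks_per_shard / 8)   # 1 bitmask byte appended to each shard

def chunk_key_to_shard_key_and_index(chunk_key):
    # mirrors MortonOrderShardedStore.__chunk_key_to_shard_key_and_index__
    subkeys = tuple(map(int, chunk_key.split(dimension_separator)))
    shard_tuple = tuple(k // s for k, s in zip(subkeys, shards))
    index_tuple = tuple(k % s for k, s in zip(subkeys, shards))
    shard_key = dimension_separator.join(map(str, shard_tuple))
    return shard_key, sum(i * j for i, j in zip(index_tuple, strides))

print(chunk_key_to_shard_key_and_index("5.1"))    # ('2.0', 3): chunk (5, 1) is slot 3 of shard "2.0"

# The last meta_size bytes of a stored shard record which chunk slots hold written data:
# bit i set <=> chunk slot i was written (setitems sets all bits for a fully written shard).
mask = int.from_bytes(bytes([0b00001001]), byteorder="big")
present = [i for i in range(num_chunks_per_shard) if (mask >> i) & 1]
print(present)                                    # [0, 3]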

zarr/core.py

Lines changed: 8 additions & 6 deletions
@@ -11,7 +11,7 @@
 from numcodecs.compat import ensure_bytes, ensure_ndarray
 
 from collections.abc import MutableMapping
-from zarr._storage.sharded_store import ShardedStore
+from zarr._storage.sharded_store import SHARDED_STORES
 
 from zarr.attrs import Attributes
 from zarr.codecs import AsType, get_codec
@@ -219,6 +219,7 @@ def _load_metadata_nosync(self):
         self._shape = meta['shape']
         self._chunks = meta['chunks']
         self._shards = meta.get('shards')
+        self._shard_format = meta.get('shard_format')
         self._dtype = meta['dtype']
         self._fill_value = meta['fill_value']
         self._order = meta['order']
@@ -272,7 +273,8 @@ def _flush_metadata_nosync(self):
         # should the dimension_separator also be included in this dict?
         meta = dict(shape=self._shape, chunks=self._chunks, dtype=self._dtype,
                     compressor=compressor_config, fill_value=self._fill_value,
-                    order=self._order, filters=filters_config, shards=self._shards)
+                    order=self._order, filters=filters_config,
+                    shards=self._shards, shard_format=self._shard_format)
         mkey = self._key_prefix + array_meta_key
         self._store[mkey] = self._store._metadata_class.encode_array_metadata(meta)
 
@@ -324,7 +326,7 @@ def chunk_store(self):
             return chunk_store
         else:
             if self._cached_sharded_store is None:
-                self._cached_sharded_store = ShardedStore(
+                self._cached_sharded_store = SHARDED_STORES[self._shard_format](
                     chunk_store,
                     shards=self._shards,
                     dimension_separator=self._dimension_separator,
@@ -1731,7 +1733,7 @@ def _set_selection(self, indexer, value, fields=None):
         check_array_shape('value', value, sel_shape)
 
         # iterate over chunks in range
-        if not hasattr(self.store, "setitems") or self._synchronizer is not None \
+        if not hasattr(self.chunk_store, "setitems") or self._synchronizer is not None \
                 or any(map(lambda x: x == 0, self.shape)):
             # iterative approach
             for chunk_coords, chunk_selection, out_selection in indexer:
@@ -1974,8 +1976,8 @@ def _chunk_setitems(self, lchunk_coords, lchunk_selection, values, fields=None):
         self.chunk_store.setitems(to_store)
 
     def _chunk_delitems(self, ckeys):
-        if hasattr(self.store, "delitems"):
-            self.store.delitems(ckeys)
+        if hasattr(self.chunk_store, "delitems"):
+            self.chunk_store.delitems(ckeys)
         else:  # pragma: no cover
             # exempting this branch from coverage as there are no extant stores
             # that will trigger this condition, but it's possible that they
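
Array.chunk_store now resolves the wrapper class through the SHARDED_STORES registry keyed by shard_format. A hypothetical sketch of plugging in another format follows; the "custom_order" name and CustomOrderShardedStore class are made up, and only "morton_order" exists in this commit.

from zarr._storage.sharded_store import SHARDED_STORES, MortonOrderShardedStore

class CustomOrderShardedStore(MortonOrderShardedStore):
    """Hypothetical variant that would override the two key-ordering hooks
    (__chunk_key_to_shard_key_and_index__ / __shard_key_and_index_to_chunk_key__)
    to lay chunks out differently inside a shard."""

SHARDED_STORES["custom_order"] = CustomOrderShardedStore

# Array.chunk_store then instantiates the wrapper roughly as:
#     SHARDED_STORES[self._shard_format](chunk_store, shards=..., dimension_separator=..., ...)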

zarr/creation.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@ def create(shape, chunks=True, dtype=None, compressor='default',
            overwrite=False, path=None, chunk_store=None, filters=None,
            cache_metadata=True, cache_attrs=True, read_only=False,
            object_codec=None, dimension_separator=None, write_empty_chunks=True,
-           shards: Union[int, Tuple[int, ...], None]=None, **kwargs):
+           shards: Union[int, Tuple[int, ...], None]=None, shard_format: str="morton_order", **kwargs):
     """Create an array.
 
     Parameters
@@ -147,7 +147,7 @@ def create(shape, chunks=True, dtype=None, compressor='default',
     init_array(store, shape=shape, chunks=chunks, dtype=dtype, compressor=compressor,
                fill_value=fill_value, order=order, overwrite=overwrite, path=path,
                chunk_store=chunk_store, filters=filters, object_codec=object_codec,
-               dimension_separator=dimension_separator, shards=shards)
+               dimension_separator=dimension_separator, shards=shards, shard_format=shard_format)
 
     # instantiate array
     z = Array(store, path=path, chunk_store=chunk_store, synchronizer=synchronizer,
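
A usage sketch of the new keyword, following chunking_test.py (illustration only, not part of the commit). It assumes zarr.zeros forwards keyword arguments to create as usual; shard_format defaults to "morton_order", so spelling it out is optional, and the store path is made up.

import zarr

store = zarr.DirectoryStore("data/example.zarr")
z = zarr.zeros((20, 3), chunks=(3, 2), shards=(2, 2), shard_format="morton_order",
               store=store, overwrite=True, compressor=None)
z[:10, :] = 42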

zarr/meta.py

Lines changed: 6 additions & 1 deletion
@@ -52,6 +52,7 @@ def decode_array_metadata(cls, s: Union[MappingType, str]) -> MappingType[str, A
 
             dimension_separator = meta.get("dimension_separator", None)
             shards = meta.get("shards", None)
+            shard_format = meta.get("shard_format", None)
             fill_value = cls.decode_fill_value(meta['fill_value'], dtype, object_codec)
             meta = dict(
                 zarr_format=meta["zarr_format"],
@@ -67,6 +68,8 @@ def decode_array_metadata(cls, s: Union[MappingType, str]) -> MappingType[str, A
                 meta['dimension_separator'] = dimension_separator
             if shards:
                 meta['shards'] = tuple(shards)
+                assert shard_format is not None
+                meta['shard_format'] = shard_format
         except Exception as e:
             raise MetadataError("error decoding metadata") from e
         else:
@@ -81,6 +84,7 @@ def encode_array_metadata(cls, meta: MappingType[str, Any]) -> bytes:
 
         dimension_separator = meta.get("dimension_separator")
         shards = meta.get("shards")
+        shard_format = meta.get("shard_format")
        if dtype.hasobject:
            import numcodecs
            object_codec = numcodecs.get_codec(meta['filters'][0])
@@ -99,9 +103,10 @@ def encode_array_metadata(cls, meta: MappingType[str, Any]) -> bytes:
         )
         if dimension_separator:
             meta['dimension_separator'] = dimension_separator
-
         if shards:
             meta['shards'] = shards
+            assert shard_format is not None
+            meta['shard_format'] = shard_format
 
         return json_dumps(meta)
 
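
Continuing the usage sketch above: with these metadata changes, the encoded .zarray carries both new keys (a sketch; the remaining fields and their exact order are omitted).

import json

meta = json.loads(store[".zarray"].decode())
assert meta["shards"] == [2, 2]
assert meta["shard_format"] == "morton_order"  # default written by _init_array_metadata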

zarr/storage.py

Lines changed: 5 additions & 1 deletion
@@ -237,6 +237,7 @@ def init_array(
         object_codec=None,
         dimension_separator=None,
         shards: Union[int, Tuple[int, ...], None]=None,
+        shard_format: Optional[str]=None,
 ):
     """Initialize an array store with the given configuration. Note that this is a low-level
     function and there should be no need to call this directly from user code.
@@ -355,7 +356,7 @@ def init_array(
                          chunk_store=chunk_store, filters=filters,
                          object_codec=object_codec,
                          dimension_separator=dimension_separator,
-                         shards=shards)
+                         shards=shards, shard_format=shard_format)
 
 
 def _init_array_metadata(
@@ -373,6 +374,7 @@ def _init_array_metadata(
         object_codec=None,
         dimension_separator=None,
         shards:Union[int, Tuple[int, ...], None] = None,
+        shard_format: Optional[str]=None,
 ):
 
     # guard conditions
@@ -392,6 +394,7 @@
         dtype = dtype.base
     chunks = normalize_chunks(chunks, shape, dtype.itemsize)
     shards = normalize_shards(shards, shape)
+    shard_format = shard_format or "morton_order"
     order = normalize_order(order)
     fill_value = normalize_fill_value(fill_value, dtype)
 
@@ -451,6 +454,7 @@
                 dimension_separator=dimension_separator)
     if shards is not None:
         meta["shards"] = shards
+        meta["shard_format"] = shard_format
     key = _path_to_prefix(path) + array_meta_key
     if hasattr(store, '_metadata_class'):
         store[key] = store._metadata_class.encode_array_metadata(meta)  # type: ignore
