1
+ from collections import defaultdict
1
2
from functools import reduce
2
- from itertools import product
3
- from typing import Any , Iterable , Iterator , Optional , Tuple
3
+ import math
4
+ from typing import Any , Dict , Iterable , Iterator , List , Tuple , Union
4
5
5
6
import numpy as np
6
7
@@ -16,7 +17,7 @@ def _cum_prod(x: Iterable[int]) -> Iterable[int]:
16
17
yield prod
17
18
18
19
19
- class ShardedStore (Store ):
20
+ class MortonOrderShardedStore (Store ):
20
21
"""This class should not be used directly,
21
22
but is added to an Array as a wrapper when needed automatically."""
22
23
@@ -32,59 +33,97 @@ def __init__(
32
33
) -> None :
33
34
self ._store : BaseStore = BaseStore ._ensure_store (store )
34
35
self ._shards = shards
35
- # This defines C/F-order
36
- self ._shard_strides = tuple (_cum_prod (shards ))
37
36
self ._num_chunks_per_shard = reduce (lambda x , y : x * y , shards , 1 )
38
37
self ._dimension_separator = dimension_separator
39
- # TODO: add jumptable for compressed data
38
+
40
39
chunk_has_constant_size = not are_chunks_compressed and not dtype == object
41
40
assert chunk_has_constant_size , "Currently only uncompressed, fixed-length data can be used."
42
41
self ._chunk_has_constant_size = chunk_has_constant_size
43
42
if chunk_has_constant_size :
44
43
binary_fill_value = np .full (1 , fill_value = fill_value or 0 , dtype = dtype ).tobytes ()
45
44
self ._fill_chunk = binary_fill_value * chunk_size
46
- else :
47
- self ._fill_chunk = None
45
+ self ._emtpy_meta = b"\x00 " * math .ceil (self ._num_chunks_per_shard / 8 )
46
+
47
+ # unused when using Morton order
48
+ self ._shard_strides = tuple (_cum_prod (shards ))
48
49
49
50
# TODO: add warnings for ineffective reads/writes:
50
51
# * warn if partial reads are not available
51
52
# * optionally warn on unaligned writes if no partial writes are available
52
-
53
- def __key_to_sharded__ (self , key : str ) -> Tuple [str , int ]:
53
+
54
def __get_meta__(self, shard_content: Union[bytes, bytearray]) -> int:
    """Decode the chunk-presence bitmask from the trailing bytes of a shard.

    The last ``len(self._emtpy_meta)`` bytes of ``shard_content`` hold a
    big-endian integer in which bit ``i`` marks whether chunk ``i`` of the
    shard has been written.
    """
    # NOTE(review): the attribute name "_emtpy_meta" is a typo ("empty"),
    # kept as-is because other methods reference it by this spelling.
    meta_length = len(self._emtpy_meta)
    return int.from_bytes(shard_content[-meta_length:], byteorder="big")
56
+
57
def __set_meta__(self, shard_content: bytearray, meta: int) -> None:
    """Write the chunk-presence bitmask into the trailing bytes of a shard.

    Mutates ``shard_content`` in place; ``meta`` is stored big-endian in the
    final ``len(self._emtpy_meta)`` bytes.
    """
    meta_length = len(self._emtpy_meta)
    shard_content[-meta_length:] = meta.to_bytes(meta_length, byteorder="big")
59
+
60
# The following two methods define the order of the chunks in a shard
# TODO use morton order
def __chunk_key_to_shard_key_and_index__(self, chunk_key: str) -> Tuple[str, int]:
    """Map a chunk key (e.g. ``"12.3"``) to ``(shard_key, flat index within the shard)``.

    Each per-dimension chunk coordinate is split into a shard coordinate
    (``// shard_size``) and an in-shard offset (``% shard_size``); the offsets
    are flattened into a single index via the precomputed strides.
    """
    # TODO: allow to be in a group (aka only use last parts for dimensions)
    shard_coords = []
    within_shard_coords = []
    for subkey_str, shard_size in zip(chunk_key.split(self._dimension_separator), self._shards):
        subkey = int(subkey_str)
        shard_coords.append(subkey // shard_size)
        within_shard_coords.append(subkey % shard_size)
    shard_key = self._dimension_separator.join(str(coord) for coord in shard_coords)
    index = sum(coord * stride for coord, stride in zip(within_shard_coords, self._shard_strides))
    return shard_key, index
61
70
62
- def __get_chunk_slice__ (self , shard_key : str , shard_index : int ) -> Tuple [int , int ]:
63
- # TODO: here we would use the jumptable for compression, which uses shard_key
71
def __shard_key_and_index_to_chunk_key__(self, shard_key_tuple: Tuple[int, ...], shard_index: int) -> str:
    """Inverse of ``__chunk_key_to_shard_key_and_index__``.

    Recovers the per-dimension in-shard offsets from the flat ``shard_index``
    using the strides, then combines them with the shard coordinates to
    rebuild the original chunk key string.
    """
    # Stride of the "next" dimension bounds each offset digit.
    next_strides = self._shard_strides[1:] + (self._num_chunks_per_shard,)
    chunk_coords = []
    for stride, next_stride, shard_coord, shard_size in zip(
        self._shard_strides, next_strides, shard_key_tuple, self._shards
    ):
        offset = shard_index % next_stride // stride
        chunk_coords.append(shard_coord * shard_size + offset)
    return self._dimension_separator.join(map(str, chunk_coords))
75
+
76
def __keys_to_shard_groups__(self, keys: Iterable[str]) -> Dict[str, List[Tuple[int, str]]]:
    """Group chunk keys by the shard that contains them.

    Returns a mapping ``shard_key -> [(index within shard, original chunk key), ...]``.

    Fix: the return annotation previously claimed ``List[Tuple[str, str]]``,
    but ``__chunk_key_to_shard_key_and_index__`` returns the index as an
    ``int`` — corrected to ``List[Tuple[int, str]]``.
    """
    shard_indices_per_shard_key: Dict[str, List[Tuple[int, str]]] = defaultdict(list)
    for chunk_key in keys:
        shard_key, shard_index = self.__chunk_key_to_shard_key_and_index__(chunk_key)
        shard_indices_per_shard_key[shard_key].append((shard_index, chunk_key))
    return shard_indices_per_shard_key
82
+
83
def __get_chunk_slice__(self, shard_index: int) -> slice:
    """Byte range of chunk ``shard_index`` inside an uncompressed shard blob.

    Chunks have constant size (asserted in ``__init__``), so the range is
    simply ``index * chunk_byte_length``.

    Fix: the return annotation previously claimed ``Tuple[int, int]`` but the
    method returns a ``slice`` object — corrected to ``slice``.
    """
    chunk_byte_length = len(self._fill_chunk)
    start = shard_index * chunk_byte_length
    return slice(start, start + chunk_byte_length)
66
86
67
87
def __getitem__(self, key: str) -> bytes:
    """Read a single chunk by delegating to the batched ``getitems``."""
    chunk_by_key = self.getitems([key])
    return chunk_by_key[key]
89
+
90
def getitems(self, keys: Iterable[str], **kwargs) -> Dict[str, bytes]:
    """Batched read: fetch each shard once and slice out all requested chunks.

    Chunk keys are grouped per shard so every shard blob is loaded from the
    wrapped store a single time, regardless of how many of its chunks were
    requested.
    """
    result: Dict[str, bytes] = {}
    for shard_key, indexed_chunks in self.__keys_to_shard_groups__(keys).items():
        # TODO use partial reads if available
        shard_bytes = self._store[shard_key]
        # TODO omit items if they don't exist
        result.update(
            (chunk_key, shard_bytes[self.__get_chunk_slice__(shard_index)])
            for shard_index, chunk_key in indexed_chunks
        )
    return result
73
99
74
100
def __setitem__(self, key: str, value: bytes) -> None:
    """Write a single chunk by delegating to the batched ``setitems``."""
    self.setitems({key: value})
102
+
103
def setitems(self, values: Dict[str, bytes]) -> None:
    """Batched write: group chunks by shard and write each shard blob once.

    A shard whose every chunk is being written is rebuilt wholesale (all-ones
    presence mask); otherwise the existing shard (or a fill-value template) is
    read, patched in place, and its presence bitmask updated.
    """
    for shard_key, indexed_chunks in self.__keys_to_shard_groups__(values.keys()).items():
        if len(indexed_chunks) != self._num_chunks_per_shard:
            # Partial shard: read-modify-write, tracking presence bits.
            # TODO use partial writes if available
            try:
                shard_buffer = bytearray(self._store[shard_key])
            except KeyError:
                # Shard does not exist yet: start from fill values and an empty mask.
                shard_buffer = bytearray(self._fill_chunk * self._num_chunks_per_shard + self._emtpy_meta)
            mask = self.__get_meta__(shard_buffer)
            for shard_index, chunk_key in indexed_chunks:
                mask |= 1 << shard_index
                shard_buffer[self.__get_chunk_slice__(shard_index)] = values[chunk_key]
            self.__set_meta__(shard_buffer, mask)
            self._store[shard_key] = shard_buffer
        else:
            # Complete shard: concatenate chunks in index order, mask all ones.
            # TODO shards at a non-dataset-size aligned surface are not captured here yet
            ordered_chunks = [values[chunk_key] for _, chunk_key in sorted(indexed_chunks)]
            self._store[shard_key] = b"".join(ordered_chunks) + b"\xff" * len(self._emtpy_meta)
84
123
85
124
def __delitem__(self, key) -> None:
    """Chunk deletion is unsupported: it would require rewriting the shard.

    Raises:
        NotImplementedError: always.
    """
    # TODO not implemented yet, also delitems
    # Deleting the "last" chunk in a shard needs to remove the whole shard
    raise NotImplementedError("Deletion is not yet implemented")
89
128
90
129
def __iter__ (self ) -> Iterator [str ]:
@@ -94,16 +133,20 @@ def __iter__(self) -> Iterator[str]:
94
133
yield shard_key
95
134
else :
96
135
# For each shard key in the wrapped store, all corresponding chunks are yielded.
97
- # TODO: For compressed chunks we might yield only the actualy contained chunks by reading the jumptables.
98
136
# TODO: allow to be in a group (aka only use last parts for dimensions)
99
- subkeys = tuple (map (int , shard_key .split (self ._dimension_separator )))
100
- for offset in product (* (range (i ) for i in self ._shards )):
101
- original_key = (subkeys_i * shards_i + offset_i for subkeys_i , offset_i , shards_i in zip (subkeys , offset , self ._shards ))
102
- yield self ._dimension_separator .join (map (str , original_key ))
137
+ shard_key_tuple = tuple (map (int , shard_key .split (self ._dimension_separator )))
138
+ mask = self .__get_meta__ (self ._store [shard_key ])
139
+ for i in range (self ._num_chunks_per_shard ):
140
+ if mask == 0 :
141
+ break
142
+ if mask & 1 :
143
+ yield self .__shard_key_and_index_to_chunk_key__ (shard_key_tuple , i )
144
+ mask >>= 1
103
145
104
146
def __len__(self) -> int:
    """Number of chunk keys visible through this store (iterates all shards)."""
    count = 0
    for _ in self.keys():
        count += 1
    return count
106
148
107
- # TODO: For efficient reads and writes, we need to implement
108
- # getitems, setitems & delitems
109
- # and combine writes/reads/deletions to the same shard.
149
+
150
# Registry mapping a sharding-order name to its store implementation.
# NOTE(review): despite the name, in-shard chunk ordering is currently
# row-major — Morton order is still marked TODO in the implementation.
SHARDED_STORES = {
    "morton_order": MortonOrderShardedStore,
}
0 commit comments