diff --git a/src/zarr/v3/abc/codec.py b/src/zarr/v3/abc/codec.py
index c81f2c976f..0a7c68784f 100644
--- a/src/zarr/v3/abc/codec.py
+++ b/src/zarr/v3/abc/codec.py
@@ -1,13 +1,3 @@
-# Notes:
-# 1. These are missing methods described in the spec. I expected to see these method definitions:
-#    def compute_encoded_representation_type(self, decoded_representation_type):
-#    def encode(self, decoded_value):
-#    def decode(self, encoded_value, decoded_representation_type):
-#    def partial_decode(self, input_handle, decoded_representation_type, decoded_regions):
-#    def compute_encoded_size(self, input_size):
-# 2. Understand why array metadata is included on all codecs
-
-
 from __future__ import annotations
 
 from abc import abstractmethod, ABC
@@ -20,30 +10,39 @@
 if TYPE_CHECKING:
-    from zarr.v3.metadata import CoreArrayMetadata, CodecMetadata
+    from zarr.v3.metadata import (
+        ArraySpec,
+        ArrayMetadata,
+        DataType,
+        CodecMetadata,
+        RuntimeConfiguration,
+    )
 
 
 class Codec(ABC):
     is_fixed_size: bool
-    array_metadata: CoreArrayMetadata
 
+    @classmethod
     @abstractmethod
-    def compute_encoded_size(self, input_byte_length: int) -> int:
+    def get_metadata_class(cls) -> Type[CodecMetadata]:
         pass
 
-    def resolve_metadata(self) -> CoreArrayMetadata:
-        return self.array_metadata
-
     @classmethod
     @abstractmethod
-    def from_metadata(
-        cls, codec_metadata: "CodecMetadata", array_metadata: CoreArrayMetadata
-    ) -> Codec:
+    def from_metadata(cls, codec_metadata: CodecMetadata) -> Codec:
         pass
 
-    @classmethod
     @abstractmethod
-    def get_metadata_class(cls) -> "Type[CodecMetadata]":
+    def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int:
+        pass
+
+    def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
+        return chunk_spec
+
+    def evolve(self, *, ndim: int, data_type: DataType) -> Codec:
+        return self
+
+    def validate(self, array_metadata: ArrayMetadata) -> None:
         pass
 
@@ -52,6 +51,8 @@ class ArrayArrayCodec(Codec):
     async def decode(
         self,
         chunk_array: np.ndarray,
+        chunk_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
     ) -> np.ndarray:
         pass
 
@@ -59,6 +60,8 @@ async def decode(
     async def encode(
         self,
         chunk_array: np.ndarray,
+        chunk_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
     ) -> Optional[np.ndarray]:
         pass
 
@@ -68,6 +71,8 @@ class ArrayBytesCodec(Codec):
     async def decode(
         self,
         chunk_array: BytesLike,
+        chunk_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
     ) -> np.ndarray:
         pass
 
@@ -75,6 +80,8 @@ async def decode(
     async def encode(
         self,
         chunk_array: np.ndarray,
+        chunk_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
     ) -> Optional[BytesLike]:
         pass
 
@@ -85,6 +92,8 @@ async def decode_partial(
         self,
         store_path: StorePath,
         selection: SliceSelection,
+        chunk_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
     ) -> Optional[np.ndarray]:
         pass
 
@@ -96,6 +105,8 @@ async def encode_partial(
         store_path: StorePath,
         chunk_array: np.ndarray,
         selection: SliceSelection,
+        chunk_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
     ) -> None:
         pass
 
@@ -105,6 +116,8 @@ class BytesBytesCodec(Codec):
     async def decode(
         self,
         chunk_array: BytesLike,
+        chunk_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
     ) -> BytesLike:
         pass
 
@@ -112,5 +125,7 @@ async def decode(
     async def encode(
         self,
         chunk_array: BytesLike,
+        chunk_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
     ) -> Optional[BytesLike]:
         pass
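Note: the reworked `Codec` ABC splits a codec's lifecycle into stateless construction (`from_metadata`), array-aware specialization (`evolve`), validation against the full `ArrayMetadata` (`validate`), and per-stage shape propagation (`resolve_metadata`). A minimal standalone sketch of that lifecycle, using stand-in dataclasses rather than the real zarr.v3 types:

    from __future__ import annotations
    from dataclasses import dataclass, replace
    from typing import Tuple

    @dataclass(frozen=True)
    class ToyArraySpec:          # stands in for zarr.v3.metadata.ArraySpec
        shape: Tuple[int, ...]
        data_type: str
        fill_value: int

    @dataclass(frozen=True)
    class ToyTransposeCodec:     # stands in for an ArrayArrayCodec
        order: Tuple[int, ...]

        def evolve(self, *, ndim: int, **_kwargs) -> "ToyTransposeCodec":
            # Normalize a legacy "C"/"F" order into an explicit permutation,
            # mirroring TransposeCodec.evolve later in this diff.
            if self.order == "C":
                return replace(self, order=tuple(range(ndim)))
            if self.order == "F":
                return replace(self, order=tuple(reversed(range(ndim))))
            return self

        def resolve_metadata(self, chunk_spec: ToyArraySpec) -> ToyArraySpec:
            # The codec reports the spec of its *encoded* representation.
            return replace(
                chunk_spec,
                shape=tuple(chunk_spec.shape[i] for i in self.order),
            )

    codec = ToyTransposeCodec(order="F").evolve(ndim=3)
    spec = ToyArraySpec(shape=(2, 3, 4), data_type="uint8", fill_value=0)
    print(codec.resolve_metadata(spec).shape)  # (4, 3, 2)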
diff --git a/src/zarr/v3/array.py b/src/zarr/v3/array.py
index d55a5aee43..dadde1658a 100644
--- a/src/zarr/v3/array.py
+++ b/src/zarr/v3/array.py
@@ -16,11 +16,10 @@
 import numpy as np
 from attr import evolve, frozen
 
-from zarr.v3.abc.codec import ArrayBytesCodecPartialDecodeMixin
-
 # from zarr.v3.array_v2 import ArrayV2
 from zarr.v3.codecs import CodecMetadata, CodecPipeline, bytes_codec
+from zarr.v3.codecs.registry import get_codec_from_metadata
 from zarr.v3.common import (
     ZARR_JSON,
     ChunkCoords,
@@ -31,6 +30,7 @@
 from zarr.v3.indexing import BasicIndexer, all_chunk_coords, is_total_slice
 from zarr.v3.metadata import (
     ArrayMetadata,
+    ArraySpec,
     DataType,
     DefaultChunkKeyEncodingConfigurationMetadata,
     DefaultChunkKeyEncodingMetadata,
@@ -41,7 +41,6 @@
     V2ChunkKeyEncodingMetadata,
     dtype_to_data_type,
 )
-from zarr.v3.codecs.sharding import ShardingCodec
 from zarr.v3.store import StoreLike, StorePath, make_store_path
 from zarr.v3.sync import sync
 
@@ -118,8 +117,11 @@ async def create(
             metadata=metadata,
             store_path=store_path,
             runtime_configuration=runtime_configuration,
-            codec_pipeline=CodecPipeline.from_metadata(
-                metadata.codecs, metadata.get_core_metadata(runtime_configuration)
+            codec_pipeline=CodecPipeline.create(
+                [
+                    get_codec_from_metadata(codec).evolve(ndim=len(shape), data_type=data_type)
+                    for codec in codecs
+                ]
             ),
         )
 
@@ -134,13 +136,17 @@ def from_json(
         runtime_configuration: RuntimeConfiguration,
     ) -> AsyncArray:
         metadata = ArrayMetadata.from_json(zarr_json)
+        codecs = [
+            get_codec_from_metadata(codec).evolve(
+                ndim=len(metadata.shape), data_type=metadata.data_type
+            )
+            for codec in metadata.codecs
+        ]
         async_array = cls(
             metadata=metadata,
             store_path=store_path,
             runtime_configuration=runtime_configuration,
-            codec_pipeline=CodecPipeline.from_metadata(
-                metadata.codecs, metadata.get_core_metadata(runtime_configuration)
-            ),
+            codec_pipeline=CodecPipeline.create(codecs),
         )
         async_array._validate_metadata()
         return async_array
@@ -240,6 +246,7 @@ def _validate_metadata(self) -> None:
             self.metadata.dimension_names
         ), "`dimension_names` and `shape` need to have the same number of dimensions."
         assert self.metadata.fill_value is not None, "`fill_value` is required."
+        self.codec_pipeline.validate(self.metadata)
 
     async def _read_chunk(
         self,
@@ -248,15 +255,14 @@ async def _read_chunk(
         out_selection: SliceSelection,
         out: np.ndarray,
     ):
+        chunk_spec = self.metadata.get_chunk_spec(chunk_coords)
         chunk_key_encoding = self.metadata.chunk_key_encoding
         chunk_key = chunk_key_encoding.encode_chunk_key(chunk_coords)
         store_path = self.store_path / chunk_key
 
-        if len(self.codec_pipeline.codecs) == 1 and isinstance(
-            self.codec_pipeline.codecs[0], ArrayBytesCodecPartialDecodeMixin
-        ):
-            chunk_array = await self.codec_pipeline.codecs[0].decode_partial(
-                store_path, chunk_selection
+        if self.codec_pipeline.supports_partial_decode:
+            chunk_array = await self.codec_pipeline.decode_partial(
+                store_path, chunk_selection, chunk_spec, self.runtime_configuration
             )
             if chunk_array is not None:
                 out[out_selection] = chunk_array
@@ -265,7 +271,9 @@ async def _read_chunk(
         else:
             chunk_bytes = await store_path.get()
             if chunk_bytes is not None:
-                chunk_array = await self.codec_pipeline.decode(chunk_bytes)
+                chunk_array = await self.codec_pipeline.decode(
+                    chunk_bytes, chunk_spec, self.runtime_configuration
+                )
                 tmp = chunk_array[chunk_selection]
                 out[out_selection] = tmp
             else:
@@ -316,6 +324,7 @@ async def _write_chunk(
         chunk_selection: SliceSelection,
         out_selection: SliceSelection,
     ):
+        chunk_spec = self.metadata.get_chunk_spec(chunk_coords)
         chunk_key_encoding = self.metadata.chunk_key_encoding
         chunk_key = chunk_key_encoding.encode_chunk_key(chunk_coords)
         store_path = self.store_path / chunk_key
@@ -330,17 +339,16 @@ async def _write_chunk(
                 chunk_array.fill(value)
             else:
                 chunk_array = value[out_selection]
-            await self._write_chunk_to_store(store_path, chunk_array)
+            await self._write_chunk_to_store(store_path, chunk_array, chunk_spec)
 
-        elif len(self.codec_pipeline.codecs) == 1 and isinstance(
-            self.codec_pipeline.codecs[0], ShardingCodec
-        ):
-            sharding_codec = self.codec_pipeline.codecs[0]
+        elif self.codec_pipeline.supports_partial_encode:
             # print("encode_partial", chunk_coords, chunk_selection, repr(self))
-            await sharding_codec.encode_partial(
+            await self.codec_pipeline.encode_partial(
                 store_path,
                 value[out_selection],
                 chunk_selection,
+                chunk_spec,
+                self.runtime_configuration,
             )
         else:
             # writing partial chunks
@@ -356,18 +364,24 @@ async def _write_chunk(
                 chunk_array.fill(self.metadata.fill_value)
             else:
                 chunk_array = (
-                    await self.codec_pipeline.decode(chunk_bytes)
+                    await self.codec_pipeline.decode(
+                        chunk_bytes, chunk_spec, self.runtime_configuration
+                    )
                 ).copy()  # make a writable copy
             chunk_array[chunk_selection] = value[out_selection]
 
-            await self._write_chunk_to_store(store_path, chunk_array)
+            await self._write_chunk_to_store(store_path, chunk_array, chunk_spec)
 
-    async def _write_chunk_to_store(self, store_path: StorePath, chunk_array: np.ndarray):
+    async def _write_chunk_to_store(
+        self, store_path: StorePath, chunk_array: np.ndarray, chunk_spec: ArraySpec
+    ):
         if np.all(chunk_array == self.metadata.fill_value):
             # chunks that only contain fill_value will be removed
             await store_path.delete()
         else:
-            chunk_bytes = await self.codec_pipeline.encode(chunk_array)
+            chunk_bytes = await self.codec_pipeline.encode(
+                chunk_array, chunk_spec, self.runtime_configuration
+            )
             if chunk_bytes is None:
                 await store_path.delete()
             else:
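Note: with `supports_partial_decode`/`supports_partial_encode` on the pipeline, `AsyncArray` no longer special-cases `ShardingCodec`. A simplified sketch of the new read-path gating (names from this diff, logic condensed from `_read_chunk`):

    # Hedged sketch, not the literal implementation.
    async def read_chunk(pipeline, store_path, chunk_selection, chunk_spec, runtime_configuration):
        if pipeline.supports_partial_decode:
            # Single ArrayBytesCodec with partial-decode support (e.g. sharding):
            # let the codec fetch only the byte ranges it needs.
            return await pipeline.decode_partial(
                store_path, chunk_selection, chunk_spec, runtime_configuration
            )
        chunk_bytes = await store_path.get()
        if chunk_bytes is None:
            return None  # caller fills with fill_value
        chunk_array = await pipeline.decode(chunk_bytes, chunk_spec, runtime_configuration)
        return chunk_array[chunk_selection]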
diff --git a/src/zarr/v3/array_v2.py b/src/zarr/v3/array_v2.py
index a2f26f01b0..dc4cbebd5e 100644
--- a/src/zarr/v3/array_v2.py
+++ b/src/zarr/v3/array_v2.py
@@ -20,7 +20,7 @@
     to_thread,
 )
 from zarr.v3.indexing import BasicIndexer, all_chunk_coords, is_total_slice
-from zarr.v3.metadata import ArrayV2Metadata, RuntimeConfiguration
+from zarr.v3.metadata import ArrayV2Metadata, CodecMetadata, RuntimeConfiguration
 from zarr.v3.store import StoreLike, StorePath, make_store_path
 from zarr.v3.sync import sync
 
@@ -83,12 +83,14 @@ async def create_async(
             order=order,
             dimension_separator=dimension_separator,
             fill_value=0 if fill_value is None else fill_value,
-            compressor=numcodecs.get_codec(compressor).get_config()
-            if compressor is not None
-            else None,
-            filters=[numcodecs.get_codec(filter).get_config() for filter in filters]
-            if filters is not None
-            else None,
+            compressor=(
+                numcodecs.get_codec(compressor).get_config() if compressor is not None else None
+            ),
+            filters=(
+                [numcodecs.get_codec(filter).get_config() for filter in filters]
+                if filters is not None
+                else None
+            ),
         )
         array = cls(
             metadata=metadata,
@@ -441,22 +443,29 @@ async def convert_to_v3_async(self) -> Array:
         from zarr.v3.common import ZARR_JSON
         from zarr.v3.metadata import (
             ArrayMetadata,
+            DataType,
+            RegularChunkGridConfigurationMetadata,
+            RegularChunkGridMetadata,
+            V2ChunkKeyEncodingConfigurationMetadata,
+            V2ChunkKeyEncodingMetadata,
+            dtype_to_data_type,
+        )
+        from zarr.v3.codecs.blosc import (
             BloscCodecConfigurationMetadata,
             BloscCodecMetadata,
+            blosc_shuffle_int_to_str,
+        )
+        from zarr.v3.codecs.bytes import (
             BytesCodecConfigurationMetadata,
             BytesCodecMetadata,
-            CodecMetadata,
-            DataType,
+        )
+        from zarr.v3.codecs.gzip import (
             GzipCodecConfigurationMetadata,
             GzipCodecMetadata,
-            RegularChunkGridConfigurationMetadata,
-            RegularChunkGridMetadata,
+        )
+        from zarr.v3.codecs.transpose import (
             TransposeCodecConfigurationMetadata,
             TransposeCodecMetadata,
-            V2ChunkKeyEncodingConfigurationMetadata,
-            V2ChunkKeyEncodingMetadata,
-            blosc_shuffle_int_to_str,
-            dtype_to_data_type,
         )
 
         data_type = DataType[dtype_to_data_type[self.metadata.dtype.str]]
@@ -476,7 +485,11 @@ async def convert_to_v3_async(self) -> Array:
 
         if self.metadata.order == "F":
             codecs.append(
-                TransposeCodecMetadata(configuration=TransposeCodecConfigurationMetadata(order="F"))
+                TransposeCodecMetadata(
+                    configuration=TransposeCodecConfigurationMetadata(
+                        order=tuple(reversed(range(self.metadata.ndim)))
+                    )
+                )
             )
         codecs.append(
             BytesCodecMetadata(configuration=BytesCodecConfigurationMetadata(endian=endian))
diff --git a/src/zarr/v3/codecs/__init__.py b/src/zarr/v3/codecs/__init__.py
index 30a42c8ad5..40c71f6807 100644
--- a/src/zarr/v3/codecs/__init__.py
+++ b/src/zarr/v3/codecs/__init__.py
@@ -1,9 +1,9 @@
 from __future__ import annotations
 
-from functools import reduce
 from typing import (
     TYPE_CHECKING,
     Iterable,
+    Iterator,
     List,
     Literal,
     Optional,
@@ -11,17 +11,24 @@
     Union,
 )
 from warnings import warn
+from attr import frozen
 
 import numpy as np
-from attr import frozen
 
-from zarr.v3.abc.codec import Codec, ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec
-from zarr.v3.common import BytesLike
-from zarr.v3.metadata import CodecMetadata, ShardingCodecIndexLocation
-from zarr.v3.codecs.registry import get_codec_class
+from zarr.v3.abc.codec import (
+    ArrayBytesCodecPartialDecodeMixin,
+    ArrayBytesCodecPartialEncodeMixin,
+    Codec,
+    ArrayArrayCodec,
+    ArrayBytesCodec,
+    BytesBytesCodec,
+)
+from zarr.v3.common import BytesLike, SliceSelection
+from zarr.v3.metadata import CodecMetadata, ShardingCodecIndexLocation, RuntimeConfiguration
+from zarr.v3.store import StorePath
 
 if TYPE_CHECKING:
-    from zarr.v3.metadata import CoreArrayMetadata
+    from zarr.v3.metadata import ArrayMetadata, ArraySpec
     from zarr.v3.codecs.sharding import ShardingCodecMetadata
     from zarr.v3.codecs.blosc import BloscCodecMetadata
     from zarr.v3.codecs.bytes import BytesCodecMetadata
@@ -31,27 +38,23 @@
     from zarr.v3.codecs.crc32c_ import Crc32cCodecMetadata
 
 
+def _find_array_bytes_codec(
+    codecs: Iterable[Tuple[Codec, ArraySpec]]
+) -> Tuple[ArrayBytesCodec, ArraySpec]:
+    for codec, array_spec in codecs:
+        if isinstance(codec, ArrayBytesCodec):
+            return (codec, array_spec)
+    raise KeyError
+
+
 @frozen
 class CodecPipeline:
-    codecs: List[Codec]
+    array_array_codecs: List[ArrayArrayCodec]
+    array_bytes_codec: ArrayBytesCodec
+    bytes_bytes_codecs: List[BytesBytesCodec]
 
     @classmethod
-    def from_metadata(
-        cls,
-        codecs_metadata: Iterable[CodecMetadata],
-        array_metadata: CoreArrayMetadata,
-    ) -> CodecPipeline:
-        out: List[Codec] = []
-        for codec_metadata in codecs_metadata or []:
-            codec_cls = get_codec_class(codec_metadata.name)
-            codec = codec_cls.from_metadata(codec_metadata, array_metadata)
-            out.append(codec)
-            array_metadata = codec.resolve_metadata()
-        CodecPipeline._validate_codecs(out, array_metadata)
-        return cls(out)
-
-    @staticmethod
-    def _validate_codecs(codecs: List[Codec], array_metadata: CoreArrayMetadata) -> None:
+    def create(cls, codecs: List[Codec]) -> CodecPipeline:
         from zarr.v3.codecs.sharding import ShardingCodec
 
         assert any(
@@ -86,22 +89,6 @@
                     f"ArrayArrayCodec '{type(codec)}' cannot follow after "
                     + f"BytesBytesCodec '{type(prev_codec)}'."
                 )
-
-            if isinstance(codec, ShardingCodec):
-                assert len(codec.configuration.chunk_shape) == len(array_metadata.shape), (
-                    "The shard's `chunk_shape` and array's `shape` need to have the "
-                    + "same number of dimensions."
-                )
-                assert all(
-                    s % c == 0
-                    for s, c in zip(
-                        array_metadata.chunk_shape,
-                        codec.configuration.chunk_shape,
-                    )
-                ), (
-                    "The array's `chunk_shape` needs to be divisible by the "
-                    + "shard's inner `chunk_shape`."
-                )
             prev_codec = codec
 
         if any(isinstance(codec, ShardingCodec) for codec in codecs) and len(codecs) > 1:
@@ -110,48 +97,150 @@
                 + "writes, which may lead to inefficient performance."
             )
 
-    def _array_array_codecs(self) -> List[ArrayArrayCodec]:
-        return [codec for codec in self.codecs if isinstance(codec, ArrayArrayCodec)]
-
-    def _array_bytes_codec(self) -> ArrayBytesCodec:
-        return next(codec for codec in self.codecs if isinstance(codec, ArrayBytesCodec))
-
-    def _bytes_bytes_codecs(self) -> List[BytesBytesCodec]:
-        return [codec for codec in self.codecs if isinstance(codec, BytesBytesCodec)]
+        return CodecPipeline(
+            array_array_codecs=[codec for codec in codecs if isinstance(codec, ArrayArrayCodec)],
+            array_bytes_codec=[codec for codec in codecs if isinstance(codec, ArrayBytesCodec)][0],
+            bytes_bytes_codecs=[codec for codec in codecs if isinstance(codec, BytesBytesCodec)],
+        )
 
-    async def decode(self, chunk_bytes: BytesLike) -> np.ndarray:
-        for bb_codec in self._bytes_bytes_codecs()[::-1]:
-            chunk_bytes = await bb_codec.decode(chunk_bytes)
+    @property
+    def supports_partial_decode(self) -> bool:
+        return (len(self.array_array_codecs) + len(self.bytes_bytes_codecs)) == 0 and isinstance(
+            self.array_bytes_codec, ArrayBytesCodecPartialDecodeMixin
+        )
 
-        chunk_array = await self._array_bytes_codec().decode(chunk_bytes)
+    @property
+    def supports_partial_encode(self) -> bool:
+        return (len(self.array_array_codecs) + len(self.bytes_bytes_codecs)) == 0 and isinstance(
+            self.array_bytes_codec, ArrayBytesCodecPartialEncodeMixin
+        )
 
-        for aa_codec in self._array_array_codecs()[::-1]:
-            chunk_array = await aa_codec.decode(chunk_array)
+    def __iter__(self) -> Iterator[Codec]:
+        for aa_codec in self.array_array_codecs:
+            yield aa_codec
+
+        yield self.array_bytes_codec
+
+        for bb_codec in self.bytes_bytes_codecs:
+            yield bb_codec
+
+    def validate(self, array_metadata: ArrayMetadata) -> None:
+        for codec in self:
+            codec.validate(array_metadata)
+
+    def _codecs_with_resolved_metadata(
+        self, array_spec: ArraySpec
+    ) -> Tuple[
+        List[Tuple[ArrayArrayCodec, ArraySpec]],
+        Tuple[ArrayBytesCodec, ArraySpec],
+        List[Tuple[BytesBytesCodec, ArraySpec]],
+    ]:
+        aa_codecs_with_spec: List[Tuple[ArrayArrayCodec, ArraySpec]] = []
+        for aa_codec in self.array_array_codecs:
+            aa_codecs_with_spec.append((aa_codec, array_spec))
+            array_spec = aa_codec.resolve_metadata(array_spec)
+
+        ab_codec_with_spec = (self.array_bytes_codec, array_spec)
+        array_spec = self.array_bytes_codec.resolve_metadata(array_spec)
+
+        bb_codecs_with_spec: List[Tuple[BytesBytesCodec, ArraySpec]] = []
+        for bb_codec in self.bytes_bytes_codecs:
+            bb_codecs_with_spec.append((bb_codec, array_spec))
+            array_spec = bb_codec.resolve_metadata(array_spec)
+
+        return (aa_codecs_with_spec, ab_codec_with_spec, bb_codecs_with_spec)
+
+    async def decode(
+        self,
+        chunk_bytes: BytesLike,
+        array_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
+    ) -> np.ndarray:
+        (
+            aa_codecs_with_spec,
+            ab_codec_with_spec,
+            bb_codecs_with_spec,
+        ) = self._codecs_with_resolved_metadata(array_spec)
+
+        for bb_codec, array_spec in bb_codecs_with_spec[::-1]:
+            chunk_bytes = await bb_codec.decode(chunk_bytes, array_spec, runtime_configuration)
+
+        ab_codec, array_spec = ab_codec_with_spec
+        chunk_array = await ab_codec.decode(chunk_bytes, array_spec, runtime_configuration)
+
+        for aa_codec, array_spec in aa_codecs_with_spec[::-1]:
+            chunk_array = await aa_codec.decode(chunk_array, array_spec, runtime_configuration)
 
         return chunk_array
 
-    async def encode(self, chunk_array: np.ndarray) -> Optional[BytesLike]:
-        for aa_codec in self._array_array_codecs():
-            chunk_array_maybe = await aa_codec.encode(chunk_array)
+    async def decode_partial(
+        self,
+        store_path: StorePath,
+        selection: SliceSelection,
+        chunk_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
+    ) -> Optional[np.ndarray]:
+        assert self.supports_partial_decode
+        assert isinstance(self.array_bytes_codec, ArrayBytesCodecPartialDecodeMixin)
+        return await self.array_bytes_codec.decode_partial(
+            store_path, selection, chunk_spec, runtime_configuration
+        )
+
+    async def encode(
+        self,
+        chunk_array: np.ndarray,
+        array_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
+    ) -> Optional[BytesLike]:
+        (
+            aa_codecs_with_spec,
+            ab_codec_with_spec,
+            bb_codecs_with_spec,
+        ) = self._codecs_with_resolved_metadata(array_spec)
+
+        for aa_codec, array_spec in aa_codecs_with_spec:
+            chunk_array_maybe = await aa_codec.encode(
+                chunk_array, array_spec, runtime_configuration
+            )
             if chunk_array_maybe is None:
                 return None
             chunk_array = chunk_array_maybe
 
-        chunk_bytes_maybe = await self._array_bytes_codec().encode(chunk_array)
+        ab_codec, array_spec = ab_codec_with_spec
+        chunk_bytes_maybe = await ab_codec.encode(chunk_array, array_spec, runtime_configuration)
         if chunk_bytes_maybe is None:
             return None
         chunk_bytes = chunk_bytes_maybe
 
-        for bb_codec in self._bytes_bytes_codecs():
-            chunk_bytes_maybe = await bb_codec.encode(chunk_bytes)
+        for bb_codec, array_spec in bb_codecs_with_spec:
+            chunk_bytes_maybe = await bb_codec.encode(
+                chunk_bytes, array_spec, runtime_configuration
+            )
             if chunk_bytes_maybe is None:
                 return None
             chunk_bytes = chunk_bytes_maybe
 
         return chunk_bytes
 
-    def compute_encoded_size(self, byte_length: int) -> int:
-        return reduce(lambda acc, codec: codec.compute_encoded_size(acc), self.codecs, byte_length)
+    async def encode_partial(
+        self,
+        store_path: StorePath,
+        chunk_array: np.ndarray,
+        selection: SliceSelection,
+        chunk_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
+    ) -> None:
+        assert self.supports_partial_encode
+        assert isinstance(self.array_bytes_codec, ArrayBytesCodecPartialEncodeMixin)
+        await self.array_bytes_codec.encode_partial(
+            store_path, chunk_array, selection, chunk_spec, runtime_configuration
+        )
+
+    def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int:
+        for codec in self:
+            byte_length = codec.compute_encoded_size(byte_length, array_spec)
+            array_spec = codec.resolve_metadata(array_spec)
+        return byte_length
 
 
 def blosc_codec(
@@ -217,14 +306,16 @@ def crc32c_codec() -> "Crc32cCodecMetadata":
 
 def sharding_codec(
     chunk_shape: Tuple[int, ...],
-    codecs: Optional[List[CodecMetadata]] = None,
-    index_codecs: Optional[List[CodecMetadata]] = None,
+    codecs: Optional[Iterable[CodecMetadata]] = None,
+    index_codecs: Optional[Iterable[CodecMetadata]] = None,
    index_location: ShardingCodecIndexLocation = ShardingCodecIndexLocation.end,
 ) -> "ShardingCodecMetadata":
     from zarr.v3.codecs.sharding import ShardingCodecMetadata, ShardingCodecConfigurationMetadata
 
-    codecs = codecs or [bytes_codec()]
-    index_codecs = index_codecs or [bytes_codec(), crc32c_codec()]
+    codecs = tuple(codecs) if codecs is not None else (bytes_codec(),)
+    index_codecs = (
+        tuple(index_codecs) if index_codecs is not None else (bytes_codec(), crc32c_codec())
+    )
     return ShardingCodecMetadata(
         configuration=ShardingCodecConfigurationMetadata(
             chunk_shape, codecs, index_codecs, index_location
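Note: `CodecPipeline` now stores the three codec categories separately; decode runs bytes-to-bytes codecs in reverse order, then the single array/bytes codec, then array-to-array codecs in reverse, with encode as the mirror image. A tiny standalone demonstration of the ordering:

    # Toy ordering demo; the real pipeline is async and threads an ArraySpec
    # through each stage via _codecs_with_resolved_metadata.
    aa = ["transpose"]          # array -> array
    ab = "bytes"                # array <-> bytes
    bb = ["gzip", "crc32c"]     # bytes -> bytes

    encode_order = aa + [ab] + bb
    decode_order = bb[::-1] + [ab] + aa[::-1]
    print(encode_order)  # ['transpose', 'bytes', 'gzip', 'crc32c']
    print(decode_order)  # ['crc32c', 'gzip', 'bytes', 'transpose']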
diff --git a/src/zarr/v3/codecs/blosc.py b/src/zarr/v3/codecs/blosc.py
index 8fb32faaa7..efc862e636 100644
--- a/src/zarr/v3/codecs/blosc.py
+++ b/src/zarr/v3/codecs/blosc.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from functools import lru_cache
 
 from typing import (
     TYPE_CHECKING,
@@ -10,16 +11,15 @@
 import numcodecs
 import numpy as np
-from attr import asdict, evolve, frozen, field
+from attr import evolve, frozen, field
 from numcodecs.blosc import Blosc
 
 from zarr.v3.abc.codec import BytesBytesCodec
 from zarr.v3.codecs.registry import register_codec
 from zarr.v3.common import BytesLike, to_thread
-from zarr.v3.metadata import CodecMetadata
 
 if TYPE_CHECKING:
-    from zarr.v3.metadata import CoreArrayMetadata
+    from zarr.v3.metadata import ArraySpec, CodecMetadata, DataType, RuntimeConfiguration
 
 
 BloscShuffle = Literal["noshuffle", "shuffle", "bitshuffle"]
@@ -52,47 +52,55 @@ class BloscCodecMetadata:
 
 @frozen
 class BloscCodec(BytesBytesCodec):
-    array_metadata: CoreArrayMetadata
     configuration: BloscCodecConfigurationMetadata
-    blosc_codec: Blosc
     is_fixed_size = False
 
     @classmethod
-    def from_metadata(
-        cls, codec_metadata: CodecMetadata, array_metadata: CoreArrayMetadata
-    ) -> BloscCodec:
+    def from_metadata(cls, codec_metadata: CodecMetadata) -> BloscCodec:
         assert isinstance(codec_metadata, BloscCodecMetadata)
-        configuration = codec_metadata.configuration
-        if configuration.typesize == 0:
-            configuration = evolve(configuration, typesize=array_metadata.data_type.byte_count)
-        config_dict = asdict(codec_metadata.configuration)
-        config_dict.pop("typesize", None)
-        map_shuffle_str_to_int = {"noshuffle": 0, "shuffle": 1, "bitshuffle": 2}
-        config_dict["shuffle"] = map_shuffle_str_to_int[config_dict["shuffle"]]
-        return cls(
-            array_metadata=array_metadata,
-            configuration=configuration,
-            blosc_codec=Blosc.from_config(config_dict),
-        )
+        return cls(configuration=codec_metadata.configuration)
 
     @classmethod
     def get_metadata_class(cls) -> Type[BloscCodecMetadata]:
         return BloscCodecMetadata
 
+    def evolve(self, *, data_type: DataType, **_kwargs) -> BloscCodec:
+        new_codec = self
+        if new_codec.configuration.typesize == 0:
+            new_configuration = evolve(new_codec.configuration, typesize=data_type.byte_count)
+            new_codec = evolve(new_codec, configuration=new_configuration)
+
+        return new_codec
+
+    @lru_cache
+    def get_blosc_codec(self) -> Blosc:
+        map_shuffle_str_to_int = {"noshuffle": 0, "shuffle": 1, "bitshuffle": 2}
+        config_dict = {
+            "cname": self.configuration.cname,
+            "clevel": self.configuration.clevel,
+            "shuffle": map_shuffle_str_to_int[self.configuration.shuffle],
+            "blocksize": self.configuration.blocksize,
+        }
+        return Blosc.from_config(config_dict)
+
     async def decode(
         self,
         chunk_bytes: bytes,
+        _chunk_spec: ArraySpec,
+        _runtime_configuration: RuntimeConfiguration,
     ) -> BytesLike:
-        return await to_thread(self.blosc_codec.decode, chunk_bytes)
+        return await to_thread(self.get_blosc_codec().decode, chunk_bytes)
 
     async def encode(
         self,
         chunk_bytes: bytes,
+        chunk_spec: ArraySpec,
+        _runtime_configuration: RuntimeConfiguration,
     ) -> Optional[BytesLike]:
-        chunk_array = np.frombuffer(chunk_bytes, dtype=self.array_metadata.dtype)
-        return await to_thread(self.blosc_codec.encode, chunk_array)
+        chunk_array = np.frombuffer(chunk_bytes, dtype=chunk_spec.dtype)
+        return await to_thread(self.get_blosc_codec().encode, chunk_array)
 
-    def compute_encoded_size(self, _input_byte_length: int) -> int:
+    def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int:
         raise NotImplementedError
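Note: because `BloscCodec` is a frozen attrs class (hence hashable), the numcodecs `Blosc` instance can be rebuilt lazily and memoized with `functools.lru_cache` instead of being stored as a field. A minimal sketch of that pattern with a frozen dataclass standing in for the attrs class:

    from dataclasses import dataclass
    from functools import lru_cache

    @dataclass(frozen=True)  # frozen => hashable => usable as a cache key
    class Config:
        clevel: int

        @lru_cache
        def expensive_object(self):
            print("built once")
            return {"clevel": self.clevel}

    c = Config(clevel=5)
    c.expensive_object()   # prints "built once"
    c.expensive_object()   # cached; no print

The usual caveat applies: the cache keeps `self` alive, which is harmless here because codec instances are long-lived anyway.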
diff --git a/src/zarr/v3/codecs/bytes.py b/src/zarr/v3/codecs/bytes.py
index 80a3f155d0..e05ccb6abc 100644
--- a/src/zarr/v3/codecs/bytes.py
+++ b/src/zarr/v3/codecs/bytes.py
@@ -13,15 +13,17 @@
 from zarr.v3.abc.codec import ArrayBytesCodec
 from zarr.v3.codecs.registry import register_codec
 from zarr.v3.common import BytesLike
-from zarr.v3.metadata import CodecMetadata
 
 if TYPE_CHECKING:
-    from zarr.v3.metadata import CoreArrayMetadata
+    from zarr.v3.metadata import CodecMetadata, ArraySpec, ArrayMetadata, RuntimeConfiguration
+
+
+Endian = Literal["big", "little"]
 
 
 @frozen
 class BytesCodecConfigurationMetadata:
-    endian: Optional[Literal["big", "little"]] = "little"
+    endian: Optional[Endian] = "little"
 
 
 @frozen
@@ -32,28 +34,24 @@ class BytesCodecMetadata:
 
 @frozen
 class BytesCodec(ArrayBytesCodec):
-    array_metadata: CoreArrayMetadata
     configuration: BytesCodecConfigurationMetadata
     is_fixed_size = True
 
     @classmethod
-    def from_metadata(
-        cls, codec_metadata: CodecMetadata, array_metadata: CoreArrayMetadata
-    ) -> BytesCodec:
+    def from_metadata(cls, codec_metadata: CodecMetadata) -> BytesCodec:
         assert isinstance(codec_metadata, BytesCodecMetadata)
-        assert (
-            array_metadata.dtype.itemsize == 1 or codec_metadata.configuration.endian is not None
-        ), "The `endian` configuration needs to be specified for multi-byte data types."
-        return cls(
-            array_metadata=array_metadata,
-            configuration=codec_metadata.configuration,
-        )
+        return cls(configuration=codec_metadata.configuration)
 
     @classmethod
     def get_metadata_class(cls) -> Type[BytesCodecMetadata]:
         return BytesCodecMetadata
 
-    def _get_byteorder(self, array: np.ndarray) -> Literal["big", "little"]:
+    def validate(self, array_metadata: ArrayMetadata) -> None:
+        assert (
+            not array_metadata.data_type.has_endianness or self.configuration.endian is not None
+        ), "The `endian` configuration needs to be specified for multi-byte data types."
+
+    def _get_byteorder(self, array: np.ndarray) -> Endian:
         if array.dtype.byteorder == "<":
             return "little"
         elif array.dtype.byteorder == ">":
@@ -66,27 +64,31 @@ def _get_byteorder(self, array: np.ndarray) -> Literal["big", "little"]:
     async def decode(
         self,
         chunk_bytes: BytesLike,
+        chunk_spec: ArraySpec,
+        _runtime_configuration: RuntimeConfiguration,
     ) -> np.ndarray:
-        if self.array_metadata.dtype.itemsize > 0:
+        if chunk_spec.dtype.itemsize > 0:
             if self.configuration.endian == "little":
                 prefix = "<"
             else:
                 prefix = ">"
-            dtype = np.dtype(f"{prefix}{self.array_metadata.data_type.to_numpy_shortname()}")
+            dtype = np.dtype(f"{prefix}{chunk_spec.data_type.to_numpy_shortname()}")
         else:
-            dtype = np.dtype(f"|{self.array_metadata.data_type.to_numpy_shortname()}")
+            dtype = np.dtype(f"|{chunk_spec.data_type.to_numpy_shortname()}")
         chunk_array = np.frombuffer(chunk_bytes, dtype)
 
         # ensure correct chunk shape
-        if chunk_array.shape != self.array_metadata.chunk_shape:
+        if chunk_array.shape != chunk_spec.shape:
             chunk_array = chunk_array.reshape(
-                self.array_metadata.chunk_shape,
+                chunk_spec.shape,
             )
         return chunk_array
 
     async def encode(
         self,
         chunk_array: np.ndarray,
+        _chunk_spec: ArraySpec,
+        _runtime_configuration: RuntimeConfiguration,
     ) -> Optional[BytesLike]:
         if chunk_array.dtype.itemsize > 1:
             byteorder = self._get_byteorder(chunk_array)
@@ -95,7 +97,7 @@ async def encode(
             chunk_array = chunk_array.astype(new_dtype)
         return chunk_array.tobytes()
 
-    def compute_encoded_size(self, input_byte_length: int) -> int:
+    def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
         return input_byte_length
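Note: `BytesCodec` now resolves endianness from the per-chunk spec rather than from construction-time array metadata. A simplified, runnable sketch of the prefix-to-dtype mapping used in `decode` above (the single-byte branch condensed):

    import numpy as np

    def dtype_for(shortname: str, itemsize: int, endian: str) -> np.dtype:
        if itemsize > 1:  # multi-byte types need an explicit byte order
            prefix = "<" if endian == "little" else ">"
            return np.dtype(f"{prefix}{shortname}")
        return np.dtype(f"|{shortname}")  # single-byte types have no endianness

    print(dtype_for("i4", 4, "big"))     # >i4
    print(dtype_for("u1", 1, "little"))  # uint8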
diff --git a/src/zarr/v3/codecs/crc32c_.py b/src/zarr/v3/codecs/crc32c_.py
index c4fab3c9b9..4f8b9c7b0b 100644
--- a/src/zarr/v3/codecs/crc32c_.py
+++ b/src/zarr/v3/codecs/crc32c_.py
@@ -14,10 +14,9 @@
 from zarr.v3.abc.codec import BytesBytesCodec
 from zarr.v3.codecs.registry import register_codec
 from zarr.v3.common import BytesLike
-from zarr.v3.metadata import CodecMetadata
 
 if TYPE_CHECKING:
-    from zarr.v3.metadata import CoreArrayMetadata
+    from zarr.v3.metadata import ArraySpec, CodecMetadata, RuntimeConfiguration
 
 
 @frozen
@@ -27,15 +26,12 @@ class Crc32cCodecMetadata:
 
 @frozen
 class Crc32cCodec(BytesBytesCodec):
-    array_metadata: CoreArrayMetadata
     is_fixed_size = True
 
     @classmethod
-    def from_metadata(
-        cls, codec_metadata: CodecMetadata, array_metadata: CoreArrayMetadata
-    ) -> Crc32cCodec:
+    def from_metadata(cls, codec_metadata: CodecMetadata) -> Crc32cCodec:
         assert isinstance(codec_metadata, Crc32cCodecMetadata)
-        return cls(array_metadata=array_metadata)
+        return cls()
 
     @classmethod
     def get_metadata_class(cls) -> Type[Crc32cCodecMetadata]:
@@ -44,6 +40,8 @@ def get_metadata_class(cls) -> Type[Crc32cCodecMetadata]:
     async def decode(
         self,
         chunk_bytes: bytes,
+        _chunk_spec: ArraySpec,
+        _runtime_configuration: RuntimeConfiguration,
     ) -> BytesLike:
         crc32_bytes = chunk_bytes[-4:]
         inner_bytes = chunk_bytes[:-4]
@@ -54,10 +52,12 @@ async def decode(
     async def encode(
         self,
         chunk_bytes: bytes,
+        _chunk_spec: ArraySpec,
+        _runtime_configuration: RuntimeConfiguration,
     ) -> Optional[BytesLike]:
         return chunk_bytes + np.uint32(crc32c(chunk_bytes)).tobytes()
 
-    def compute_encoded_size(self, input_byte_length: int) -> int:
+    def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
         return input_byte_length + 4
diff --git a/src/zarr/v3/codecs/gzip.py b/src/zarr/v3/codecs/gzip.py
index be1ebcdc9f..a3fafc1382 100644
--- a/src/zarr/v3/codecs/gzip.py
+++ b/src/zarr/v3/codecs/gzip.py
@@ -13,10 +13,9 @@
 from zarr.v3.abc.codec import BytesBytesCodec
 from zarr.v3.codecs.registry import register_codec
 from zarr.v3.common import BytesLike, to_thread
-from zarr.v3.metadata import CodecMetadata
 
 if TYPE_CHECKING:
-    from zarr.v3.metadata import CoreArrayMetadata
+    from zarr.v3.metadata import ArraySpec, CodecMetadata, RuntimeConfiguration
 
 
 @frozen
@@ -32,20 +31,14 @@ class GzipCodecMetadata:
 
 @frozen
 class GzipCodec(BytesBytesCodec):
-    array_metadata: CoreArrayMetadata
     configuration: GzipCodecConfigurationMetadata
     is_fixed_size = True
 
     @classmethod
-    def from_metadata(
-        cls, codec_metadata: CodecMetadata, array_metadata: CoreArrayMetadata
-    ) -> GzipCodec:
+    def from_metadata(cls, codec_metadata: CodecMetadata) -> GzipCodec:
         assert isinstance(codec_metadata, GzipCodecMetadata)
-        return cls(
-            array_metadata=array_metadata,
-            configuration=codec_metadata.configuration,
-        )
+        return cls(configuration=codec_metadata.configuration)
 
     @classmethod
     def get_metadata_class(cls) -> Type[GzipCodecMetadata]:
@@ -54,16 +47,24 @@ def get_metadata_class(cls) -> Type[GzipCodecMetadata]:
     async def decode(
         self,
         chunk_bytes: bytes,
+        _chunk_spec: ArraySpec,
+        _runtime_configuration: RuntimeConfiguration,
     ) -> BytesLike:
         return await to_thread(GZip(self.configuration.level).decode, chunk_bytes)
 
     async def encode(
         self,
         chunk_bytes: bytes,
+        _chunk_spec: ArraySpec,
+        _runtime_configuration: RuntimeConfiguration,
     ) -> Optional[BytesLike]:
         return await to_thread(GZip(self.configuration.level).encode, chunk_bytes)
 
-    def compute_encoded_size(self, _input_byte_length: int) -> int:
+    def compute_encoded_size(
+        self,
+        _input_byte_length: int,
+        _chunk_spec: ArraySpec,
+    ) -> int:
         raise NotImplementedError
diff --git a/src/zarr/v3/codecs/registry.py b/src/zarr/v3/codecs/registry.py
index 642c0feebb..bdd9a5765d 100644
--- a/src/zarr/v3/codecs/registry.py
+++ b/src/zarr/v3/codecs/registry.py
@@ -45,6 +45,11 @@ def _get_codec_item(key: str) -> CodecRegistryItem:
     raise KeyError(key)
 
 
+def get_codec_from_metadata(val: CodecMetadata) -> Codec:
+    key = val.name
+    return _get_codec_item(key).codec_cls.from_metadata(val)
+
+
 def get_codec_metadata_class(key: str) -> Type[CodecMetadata]:
     return _get_codec_item(key).codec_metadata_cls
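Note: `get_codec_from_metadata` is the single entry point that replaces the old two-step `get_codec_class(name)` plus `from_metadata(metadata, array_metadata)` dance. A hypothetical round trip, assuming the branch's codecs are registered at import time:

    from zarr.v3.codecs import crc32c_codec
    from zarr.v3.codecs.registry import get_codec_from_metadata

    metadata = crc32c_codec()                  # Crc32cCodecMetadata
    codec = get_codec_from_metadata(metadata)  # looks up by metadata.name, calls from_metadata
                                               # -> Crc32cCodec(); no array metadata needed anymore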
diff --git a/src/zarr/v3/codecs/sharding.py b/src/zarr/v3/codecs/sharding.py
index 12c84ade29..26020f160f 100644
--- a/src/zarr/v3/codecs/sharding.py
+++ b/src/zarr/v3/codecs/sharding.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from functools import cached_property, lru_cache
 
 from typing import (
     Awaitable,
@@ -23,7 +24,7 @@
 )
 from zarr.v3.codecs import CodecPipeline
-from zarr.v3.codecs.registry import register_codec
+from zarr.v3.codecs.registry import get_codec_from_metadata, register_codec
 from zarr.v3.common import (
     BytesLike,
     ChunkCoords,
@@ -38,10 +39,14 @@
     morton_order_iter,
 )
 from zarr.v3.metadata import (
-    CoreArrayMetadata,
+    ArrayMetadata,
+    ArraySpec,
     DataType,
     CodecMetadata,
+    RegularChunkGridMetadata,
     ShardingCodecIndexLocation,
+    RuntimeConfiguration,
+    runtime_configuration as make_runtime_configuration,
 )
 from zarr.v3.store import StorePath
 
@@ -51,8 +56,8 @@
 @frozen
 class ShardingCodecConfigurationMetadata:
     chunk_shape: ChunkCoords
-    codecs: List["CodecMetadata"]
-    index_codecs: List["CodecMetadata"]
+    codecs: Tuple[CodecMetadata, ...]
+    index_codecs: Tuple[CodecMetadata, ...]
     index_location: ShardingCodecIndexLocation = ShardingCodecIndexLocation.end
 
 
@@ -66,6 +71,10 @@ class _ShardIndex(NamedTuple):
     # dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)
     offsets_and_lengths: np.ndarray
 
+    @property
+    def chunks_per_shard(self) -> ChunkCoords:
+        return self.offsets_and_lengths.shape[0:-1]
+
     def _localize_chunk(self, chunk_coords: ChunkCoords) -> ChunkCoords:
         return tuple(
             chunk_i % shard_i
@@ -126,8 +135,10 @@ class _ShardProxy(Mapping):
     buf: BytesLike
 
     @classmethod
-    async def from_bytes(cls, buf: BytesLike, codec: ShardingCodec) -> _ShardProxy:
-        shard_index_size = codec._shard_index_size()
+    async def from_bytes(
+        cls, buf: BytesLike, codec: ShardingCodec, chunks_per_shard: ChunkCoords
+    ) -> _ShardProxy:
+        shard_index_size = codec._shard_index_size(chunks_per_shard)
         obj = cls()
         obj.buf = memoryview(buf)
         if codec.configuration.index_location == ShardingCodecIndexLocation.start:
@@ -135,7 +146,7 @@ async def from_bytes(cls, buf: BytesLike, codec: ShardingCodec) -> _ShardProxy:
         else:
             shard_index_bytes = obj.buf[-shard_index_size:]
 
-        obj.index = await codec._decode_shard_index(shard_index_bytes)
+        obj.index = await codec._decode_shard_index(shard_index_bytes, chunks_per_shard)
         return obj
 
     @classmethod
@@ -215,67 +226,49 @@ async def finalize(
 class ShardingCodec(
     ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin
 ):
-    array_metadata: CoreArrayMetadata
     configuration: ShardingCodecConfigurationMetadata
-    codec_pipeline: CodecPipeline
-    index_codec_pipeline: CodecPipeline
-    chunks_per_shard: Tuple[int, ...]
 
     @classmethod
     def from_metadata(
         cls,
         codec_metadata: CodecMetadata,
-        array_metadata: CoreArrayMetadata,
     ) -> ShardingCodec:
         assert isinstance(codec_metadata, ShardingCodecMetadata)
-
-        chunks_per_shard = tuple(
-            s // c
-            for s, c in zip(
-                array_metadata.chunk_shape,
-                codec_metadata.configuration.chunk_shape,
-            )
-        )
-        # rewriting the metadata to scope it to the shard
-        shard_metadata = CoreArrayMetadata(
-            shape=array_metadata.chunk_shape,
-            chunk_shape=codec_metadata.configuration.chunk_shape,
-            data_type=array_metadata.data_type,
-            fill_value=array_metadata.fill_value,
-            runtime_configuration=array_metadata.runtime_configuration,
-        )
-        codec_pipeline = CodecPipeline.from_metadata(
-            codec_metadata.configuration.codecs, shard_metadata
-        )
-        index_codec_pipeline = CodecPipeline.from_metadata(
-            codec_metadata.configuration.index_codecs,
-            CoreArrayMetadata(
-                shape=chunks_per_shard + (2,),
-                chunk_shape=chunks_per_shard + (2,),
-                data_type=DataType.uint64,
-                fill_value=MAX_UINT_64,
-                runtime_configuration=array_metadata.runtime_configuration,
-            ),
-        )
-        return cls(
-            array_metadata=array_metadata,
-            configuration=codec_metadata.configuration,
-            codec_pipeline=codec_pipeline,
-            index_codec_pipeline=index_codec_pipeline,
-            chunks_per_shard=chunks_per_shard,
-        )
+        return cls(configuration=codec_metadata.configuration)
 
     @classmethod
     def get_metadata_class(cls) -> Type[ShardingCodecMetadata]:
         return ShardingCodecMetadata
 
+    def validate(self, array_metadata: ArrayMetadata) -> None:
+        assert len(self.configuration.chunk_shape) == array_metadata.ndim, (
+            "The shard's `chunk_shape` and array's `shape` need to have the "
+            + "same number of dimensions."
+        )
+        assert isinstance(
+            array_metadata.chunk_grid, RegularChunkGridMetadata
+        ), "Sharding is only compatible with regular chunk grids."
+        assert all(
+            s % c == 0
+            for s, c in zip(
+                array_metadata.chunk_grid.configuration.chunk_shape,
+                self.configuration.chunk_shape,
+            )
+        ), (
+            "The array's `chunk_shape` needs to be divisible by the "
+            + "shard's inner `chunk_shape`."
+        )
+
     async def decode(
         self,
         shard_bytes: BytesLike,
+        shard_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
     ) -> np.ndarray:
         # print("decode")
-        shard_shape = self.array_metadata.chunk_shape
+        shard_shape = shard_spec.shape
         chunk_shape = self.configuration.chunk_shape
+        chunks_per_shard = self._get_chunks_per_shard(shard_spec)
 
         indexer = BasicIndexer(
             tuple(slice(0, s) for s in shard_shape),
@@ -286,13 +279,13 @@ async def decode(
         # setup output array
         out = np.zeros(
             shard_shape,
-            dtype=self.array_metadata.dtype,
-            order=self.array_metadata.runtime_configuration.order,
+            dtype=shard_spec.dtype,
+            order=runtime_configuration.order,
         )
-        shard_dict = await _ShardProxy.from_bytes(shard_bytes, self)
+        shard_dict = await _ShardProxy.from_bytes(shard_bytes, self, chunks_per_shard)
 
         if shard_dict.index.is_all_empty():
-            out.fill(self.array_metadata.fill_value)
+            out.fill(shard_spec.fill_value)
             return out
 
         # decoding chunks and writing them into the output buffer
@@ -303,12 +296,14 @@ async def decode(
                     chunk_coords,
                     chunk_selection,
                     out_selection,
+                    shard_spec,
+                    runtime_configuration,
                     out,
                 )
                 for chunk_coords, chunk_selection, out_selection in indexer
             ],
             self._read_chunk,
-            self.array_metadata.runtime_configuration.concurrency,
+            runtime_configuration.concurrency,
         )
 
         return out
@@ -317,9 +312,12 @@ async def decode_partial(
         self,
         store_path: StorePath,
         selection: SliceSelection,
+        shard_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
     ) -> Optional[np.ndarray]:
-        shard_shape = self.array_metadata.chunk_shape
+        shard_shape = shard_spec.shape
         chunk_shape = self.configuration.chunk_shape
+        chunks_per_shard = self._get_chunks_per_shard(shard_spec)
 
         indexer = BasicIndexer(
             selection,
@@ -330,8 +328,8 @@ async def decode_partial(
         # setup output array
         out = np.zeros(
             indexer.shape,
-            dtype=self.array_metadata.dtype,
-            order=self.array_metadata.runtime_configuration.order,
+            dtype=shard_spec.dtype,
+            order=runtime_configuration.order,
         )
 
         indexed_chunks = list(indexer)
@@ -339,15 +337,15 @@ async def decode_partial(
 
         # reading bytes of all requested chunks
         shard_dict: Mapping[ChunkCoords, BytesLike] = {}
-        if self._is_total_shard(all_chunk_coords):
+        if self._is_total_shard(all_chunk_coords, chunks_per_shard):
             # read entire shard
-            shard_dict_maybe = await self._load_full_shard_maybe(store_path)
+            shard_dict_maybe = await self._load_full_shard_maybe(store_path, chunks_per_shard)
             if shard_dict_maybe is None:
                 return None
             shard_dict = shard_dict_maybe
         else:
             # read some chunks within the shard
-            shard_index = await self._load_shard_index_maybe(store_path)
+            shard_index = await self._load_shard_index_maybe(store_path, chunks_per_shard)
             if shard_index is None:
                 return None
             shard_dict = {}
@@ -366,12 +364,14 @@ async def decode_partial(
                     chunk_coords,
                     chunk_selection,
                     out_selection,
+                    shard_spec,
+                    runtime_configuration,
                     out,
                 )
                 for chunk_coords, chunk_selection, out_selection in indexed_chunks
             ],
             self._read_chunk,
-            self.array_metadata.runtime_configuration.concurrency,
+            runtime_configuration.concurrency,
         )
 
         return out
@@ -382,22 +382,30 @@ async def _read_chunk(
         chunk_coords: ChunkCoords,
         chunk_selection: SliceSelection,
         out_selection: SliceSelection,
+        shard_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
         out: np.ndarray,
     ):
+        chunk_spec = self._get_chunk_spec(shard_spec)
         chunk_bytes = shard_dict.get(chunk_coords, None)
         if chunk_bytes is not None:
-            chunk_array = await self.codec_pipeline.decode(chunk_bytes)
+            chunk_array = await self._codec_pipeline.decode(
+                chunk_bytes, chunk_spec, runtime_configuration
+            )
            tmp = chunk_array[chunk_selection]
             out[out_selection] = tmp
         else:
-            out[out_selection] = self.array_metadata.fill_value
+            out[out_selection] = chunk_spec.fill_value
 
     async def encode(
         self,
         shard_array: np.ndarray,
+        shard_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
     ) -> Optional[BytesLike]:
-        shard_shape = self.array_metadata.chunk_shape
+        shard_shape = shard_spec.shape
         chunk_shape = self.configuration.chunk_shape
+        chunks_per_shard = self._get_chunks_per_shard(shard_spec)
 
         indexer = list(
             BasicIndexer(
@@ -419,14 +427,17 @@ async def _write_chunk(
                 # handling writing partial chunks
                 chunk_array = np.empty(
                     chunk_shape,
-                    dtype=self.array_metadata.dtype,
+                    dtype=shard_spec.dtype,
                 )
-                chunk_array.fill(self.array_metadata.fill_value)
+                chunk_array.fill(shard_spec.fill_value)
                 chunk_array[chunk_selection] = shard_array[out_selection]
-            if not np.array_equiv(chunk_array, self.array_metadata.fill_value):
+            if not np.array_equiv(chunk_array, shard_spec.fill_value):
+                chunk_spec = self._get_chunk_spec(shard_spec)
                 return (
                     chunk_coords,
-                    await self.codec_pipeline.encode(chunk_array),
+                    await self._codec_pipeline.encode(
+                        chunk_array, chunk_spec, runtime_configuration
+                    ),
                 )
             return (chunk_coords, None)
 
@@ -437,12 +448,12 @@ async def _write_chunk(
                 for chunk_coords, chunk_selection, out_selection in indexer
             ],
             _write_chunk,
-            self.array_metadata.runtime_configuration.concurrency,
+            runtime_configuration.concurrency,
         )
 
         if len(encoded_chunks) == 0:
             return None
 
-        shard_builder = _ShardBuilder.create_empty(self.chunks_per_shard)
+        shard_builder = _ShardBuilder.create_empty(chunks_per_shard)
         for chunk_coords, chunk_bytes in encoded_chunks:
             if chunk_bytes is not None:
                 shard_builder.append(chunk_coords, chunk_bytes)
@@ -456,15 +467,19 @@ async def encode_partial(
         store_path: StorePath,
         shard_array: np.ndarray,
         selection: SliceSelection,
+        shard_spec: ArraySpec,
+        runtime_configuration: RuntimeConfiguration,
     ) -> None:
         # print("encode_partial")
-        shard_shape = self.array_metadata.chunk_shape
+        shard_shape = shard_spec.shape
         chunk_shape = self.configuration.chunk_shape
+        chunks_per_shard = self._get_chunks_per_shard(shard_spec)
+        chunk_spec = self._get_chunk_spec(shard_spec)
 
         old_shard_dict = (
-            await self._load_full_shard_maybe(store_path)
-        ) or _ShardProxy.create_empty(self.chunks_per_shard)
-        new_shard_builder = _ShardBuilder.create_empty(self.chunks_per_shard)
+            await self._load_full_shard_maybe(store_path, chunks_per_shard)
+        ) or _ShardProxy.create_empty(chunks_per_shard)
+        new_shard_builder = _ShardBuilder.create_empty(chunks_per_shard)
         tombstones: Set[ChunkCoords] = set()
 
         indexer = list(
@@ -492,19 +507,23 @@ async def _write_chunk(
                 if chunk_bytes is None:
                     chunk_array = np.empty(
                         self.configuration.chunk_shape,
-                        dtype=self.array_metadata.dtype,
+                        dtype=shard_spec.dtype,
                     )
-                    chunk_array.fill(self.array_metadata.fill_value)
+                    chunk_array.fill(shard_spec.fill_value)
                 else:
                     chunk_array = (
-                        await self.codec_pipeline.decode(chunk_bytes)
+                        await self._codec_pipeline.decode(
+                            chunk_bytes, chunk_spec, runtime_configuration
+                        )
                     ).copy()  # make a writable copy
                 chunk_array[chunk_selection] = shard_array[out_selection]
 
-            if not np.array_equiv(chunk_array, self.array_metadata.fill_value):
+            if not np.array_equiv(chunk_array, shard_spec.fill_value):
                 return (
                     chunk_coords,
-                    await self.codec_pipeline.encode(chunk_array),
+                    await self._codec_pipeline.encode(
+                        chunk_array, chunk_spec, runtime_configuration
+                    ),
                 )
             else:
                 return (chunk_coords, None)
@@ -519,7 +538,7 @@ async def _write_chunk(
                 for chunk_coords, chunk_selection, out_selection in indexer
             ],
             _write_chunk,
-            self.array_metadata.runtime_configuration.concurrency,
+            runtime_configuration.concurrency,
         )
 
         for chunk_coords, chunk_bytes in encoded_chunks:
@@ -529,7 +548,10 @@ async def _write_chunk(
                 tombstones.add(chunk_coords)
 
         shard_builder = _ShardBuilder.merge_with_morton_order(
-            self.chunks_per_shard, tombstones, new_shard_builder, old_shard_dict
+            chunks_per_shard,
+            tombstones,
+            new_shard_builder,
+            old_shard_dict,
         )
 
         if shard_builder.index.is_all_empty():
@@ -542,44 +564,107 @@ async def _write_chunk(
             )
         )
 
-    def _is_total_shard(self, all_chunk_coords: Set[ChunkCoords]) -> bool:
-        return len(all_chunk_coords) == product(self.chunks_per_shard) and all(
-            chunk_coords in all_chunk_coords for chunk_coords in c_order_iter(self.chunks_per_shard)
+    def _is_total_shard(
+        self, all_chunk_coords: Set[ChunkCoords], chunks_per_shard: ChunkCoords
+    ) -> bool:
+        return len(all_chunk_coords) == product(chunks_per_shard) and all(
+            chunk_coords in all_chunk_coords for chunk_coords in c_order_iter(chunks_per_shard)
        )
 
-    async def _decode_shard_index(self, index_bytes: BytesLike) -> _ShardIndex:
-        return _ShardIndex(await self.index_codec_pipeline.decode(index_bytes))
+    async def _decode_shard_index(
+        self, index_bytes: BytesLike, chunks_per_shard: ChunkCoords
+    ) -> _ShardIndex:
+        return _ShardIndex(
+            await self._index_codec_pipeline.decode(
+                index_bytes,
+                self._get_index_chunk_spec(chunks_per_shard),
+                make_runtime_configuration("C"),
+            )
+        )
 
     async def _encode_shard_index(self, index: _ShardIndex) -> BytesLike:
-        index_bytes = await self.index_codec_pipeline.encode(index.offsets_and_lengths)
+        index_bytes = await self._index_codec_pipeline.encode(
+            index.offsets_and_lengths,
+            self._get_index_chunk_spec(index.chunks_per_shard),
+            make_runtime_configuration("C"),
+        )
         assert index_bytes is not None
         return index_bytes
 
-    def _shard_index_size(self) -> int:
-        return self.index_codec_pipeline.compute_encoded_size(16 * product(self.chunks_per_shard))
+    def _shard_index_size(self, chunks_per_shard: ChunkCoords) -> int:
+        return self._index_codec_pipeline.compute_encoded_size(
+            16 * product(chunks_per_shard), self._get_index_chunk_spec(chunks_per_shard)
+        )
 
-    async def _load_shard_index_maybe(self, store_path: StorePath) -> Optional[_ShardIndex]:
-        shard_index_size = self._shard_index_size()
+    @lru_cache
+    def _get_index_chunk_spec(self, chunks_per_shard: ChunkCoords) -> ArraySpec:
+        return ArraySpec(
+            shape=chunks_per_shard + (2,),
+            data_type=DataType.uint64,
+            fill_value=MAX_UINT_64,
+        )
+
+    @lru_cache
+    def _get_chunk_spec(self, shard_spec: ArraySpec) -> ArraySpec:
+        return ArraySpec(
+            shape=self.configuration.chunk_shape,
+            data_type=shard_spec.data_type,
+            fill_value=shard_spec.fill_value,
+        )
+
+    @lru_cache
+    def _get_chunks_per_shard(self, shard_spec: ArraySpec) -> ChunkCoords:
+        return tuple(
+            s // c
+            for s, c in zip(
+                shard_spec.shape,
+                self.configuration.chunk_shape,
+            )
+        )
+
+    @cached_property
+    def _index_codec_pipeline(self) -> CodecPipeline:
+        return CodecPipeline.create(
+            [get_codec_from_metadata(c) for c in self.configuration.index_codecs]
+        )
+
+    @cached_property
+    def _codec_pipeline(self) -> CodecPipeline:
+        return CodecPipeline.create([get_codec_from_metadata(c) for c in self.configuration.codecs])
+
+    async def _load_shard_index_maybe(
+        self, store_path: StorePath, chunks_per_shard: ChunkCoords
+    ) -> Optional[_ShardIndex]:
+        shard_index_size = self._shard_index_size(chunks_per_shard)
         if self.configuration.index_location == ShardingCodecIndexLocation.start:
             index_bytes = await store_path.get((0, shard_index_size))
         else:
             index_bytes = await store_path.get((-shard_index_size, None))
         if index_bytes is not None:
-            return await self._decode_shard_index(index_bytes)
+            return await self._decode_shard_index(index_bytes, chunks_per_shard)
         return None
 
-    async def _load_shard_index(self, store_path: StorePath) -> _ShardIndex:
-        return (await self._load_shard_index_maybe(store_path)) or _ShardIndex.create_empty(
-            self.chunks_per_shard
-        )
+    async def _load_shard_index(
+        self, store_path: StorePath, chunks_per_shard: ChunkCoords
+    ) -> _ShardIndex:
+        return (
+            await self._load_shard_index_maybe(store_path, chunks_per_shard)
+        ) or _ShardIndex.create_empty(chunks_per_shard)
 
-    async def _load_full_shard_maybe(self, store_path: StorePath) -> Optional[_ShardProxy]:
+    async def _load_full_shard_maybe(
+        self, store_path: StorePath, chunks_per_shard: ChunkCoords
+    ) -> Optional[_ShardProxy]:
         shard_bytes = await store_path.get()
-        return await _ShardProxy.from_bytes(shard_bytes, self) if shard_bytes else None
+        return (
+            await _ShardProxy.from_bytes(shard_bytes, self, chunks_per_shard)
+            if shard_bytes
+            else None
+        )
 
-    def compute_encoded_size(self, input_byte_length: int) -> int:
-        return input_byte_length + self._shard_index_size()
+    def compute_encoded_size(self, input_byte_length: int, shard_spec: ArraySpec) -> int:
+        chunks_per_shard = self._get_chunks_per_shard(shard_spec)
+        return input_byte_length + self._shard_index_size(chunks_per_shard)
 
 
 register_codec("sharding_indexed", ShardingCodec)
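Note: the shard geometry is now derived per call from the shard's `ArraySpec` instead of being frozen onto the codec at construction. A worked example of the derivation (mirrors `_get_chunks_per_shard` and the raw size fed to `_shard_index_size`; index-codec overhead ignored):

    shard_shape = (64, 64)         # ArraySpec.shape of one shard (= one outer chunk)
    inner_chunk_shape = (16, 32)   # ShardingCodecConfigurationMetadata.chunk_shape

    chunks_per_shard = tuple(s // c for s, c in zip(shard_shape, inner_chunk_shape))
    print(chunks_per_shard)        # (4, 2)

    # The shard index is a uint64 array of shape chunks_per_shard + (2,) holding
    # (offset, length) per inner chunk, i.e. 16 bytes per entry before index codecs.
    raw_index_bytes = 16 * 4 * 2   # 128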
diff --git a/src/zarr/v3/codecs/transpose.py b/src/zarr/v3/codecs/transpose.py
index d160f2a88d..de6eb0a480 100644
--- a/src/zarr/v3/codecs/transpose.py
+++ b/src/zarr/v3/codecs/transpose.py
@@ -9,14 +9,13 @@
 )
 
 import numpy as np
-from attr import frozen, field
+from attr import evolve, frozen, field
 
 from zarr.v3.abc.codec import ArrayArrayCodec
 from zarr.v3.codecs.registry import register_codec
-from zarr.v3.metadata import CodecMetadata
 
 if TYPE_CHECKING:
-    from zarr.v3.metadata import CoreArrayMetadata
+    from zarr.v3.metadata import ArraySpec, CodecMetadata, RuntimeConfiguration
 
 
 @frozen
@@ -32,69 +31,60 @@ class TransposeCodecMetadata:
 
 @frozen
 class TransposeCodec(ArrayArrayCodec):
-    array_metadata: CoreArrayMetadata
     order: Tuple[int, ...]
     is_fixed_size = True
 
     @classmethod
-    def from_metadata(
-        cls, codec_metadata: CodecMetadata, array_metadata: CoreArrayMetadata
-    ) -> TransposeCodec:
+    def from_metadata(cls, codec_metadata: CodecMetadata) -> TransposeCodec:
         assert isinstance(codec_metadata, TransposeCodecMetadata)
+        return cls(order=codec_metadata.configuration.order)
 
-        configuration = codec_metadata.configuration
+    def evolve(self, *, ndim: int, **_kwargs) -> TransposeCodec:
         # Compatibility with older version of ZEP1
-        if configuration.order == "F":  # type: ignore
-            order = tuple(array_metadata.ndim - x - 1 for x in range(array_metadata.ndim))
+        if self.order == "F":  # type: ignore
+            order = tuple(ndim - x - 1 for x in range(ndim))
 
-        elif configuration.order == "C":  # type: ignore
-            order = tuple(range(array_metadata.ndim))
+        elif self.order == "C":  # type: ignore
+            order = tuple(range(ndim))
 
         else:
-            assert len(configuration.order) == array_metadata.ndim, (
+            assert len(self.order) == ndim, (
                 "The `order` tuple needs to have as many entries as "
-                + f"there are dimensions in the array. Got: {configuration.order}"
+                + f"there are dimensions in the array. Got: {self.order}"
             )
-            assert len(configuration.order) == len(set(configuration.order)), (
-                "There must not be duplicates in the `order` tuple. "
-                + f"Got: {configuration.order}"
+            assert len(self.order) == len(set(self.order)), (
+                "There must not be duplicates in the `order` tuple. " + f"Got: {self.order}"
             )
-            assert all(0 <= x < array_metadata.ndim for x in configuration.order), (
+            assert all(0 <= x < ndim for x in self.order), (
                 "All entries in the `order` tuple must be between 0 and "
-                + f"the number of dimensions in the array. Got: {configuration.order}"
+                + f"the number of dimensions in the array. Got: {self.order}"
             )
-            order = tuple(configuration.order)
+            order = tuple(self.order)
 
-        return cls(
-            array_metadata=array_metadata,
-            order=order,
-        )
+        if order != self.order:
+            return evolve(self, order=order)
+        return self
 
     @classmethod
     def get_metadata_class(cls) -> Type[TransposeCodecMetadata]:
         return TransposeCodecMetadata
 
-    def resolve_metadata(self) -> CoreArrayMetadata:
-        from zarr.v3.metadata import CoreArrayMetadata
-
-        return CoreArrayMetadata(
-            shape=tuple(
-                self.array_metadata.shape[self.order[i]] for i in range(self.array_metadata.ndim)
-            ),
-            chunk_shape=tuple(
-                self.array_metadata.chunk_shape[self.order[i]]
-                for i in range(self.array_metadata.ndim)
-            ),
-            data_type=self.array_metadata.data_type,
-            fill_value=self.array_metadata.fill_value,
-            runtime_configuration=self.array_metadata.runtime_configuration,
+    def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
+        from zarr.v3.metadata import ArraySpec
+
+        return ArraySpec(
+            shape=tuple(chunk_spec.shape[self.order[i]] for i in range(chunk_spec.ndim)),
+            data_type=chunk_spec.data_type,
+            fill_value=chunk_spec.fill_value,
         )
 
     async def decode(
         self,
         chunk_array: np.ndarray,
+        chunk_spec: ArraySpec,
+        _runtime_configuration: RuntimeConfiguration,
     ) -> np.ndarray:
-        inverse_order = [0 for _ in range(self.array_metadata.ndim)]
+        inverse_order = [0] * chunk_spec.ndim
         for x, i in enumerate(self.order):
             inverse_order[x] = i
         chunk_array = chunk_array.transpose(inverse_order)
@@ -103,11 +93,13 @@ async def decode(
     async def encode(
         self,
         chunk_array: np.ndarray,
+        chunk_spec: ArraySpec,
+        _runtime_configuration: RuntimeConfiguration,
     ) -> Optional[np.ndarray]:
         chunk_array = chunk_array.transpose(self.order)
         return chunk_array
 
-    def compute_encoded_size(self, input_byte_length: int) -> int:
+    def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
         return input_byte_length
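Note: `decode` above appears to build `inverse_order` by copying `self.order` verbatim (`inverse_order[x] = i`), which only equals the true inverse for self-inverse permutations such as the reversals produced from "F"/"C" orders. For an arbitrary permutation the inverse is `np.argsort(order)`; a small self-contained check:

    import numpy as np

    order = (1, 2, 0)
    inverse_order = tuple(np.argsort(order))   # (2, 0, 1), not (1, 2, 0)

    x = np.arange(24).reshape(2, 3, 4)
    assert np.array_equal(x.transpose(order).transpose(inverse_order), x)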
diff --git a/src/zarr/v3/codecs/zstd.py b/src/zarr/v3/codecs/zstd.py
index e66d9e0700..59ce1cf088 100644
--- a/src/zarr/v3/codecs/zstd.py
+++ b/src/zarr/v3/codecs/zstd.py
@@ -13,10 +13,9 @@
 from zarr.v3.abc.codec import BytesBytesCodec
 from zarr.v3.codecs.registry import register_codec
 from zarr.v3.common import BytesLike, to_thread
-from zarr.v3.metadata import CodecMetadata
 
 if TYPE_CHECKING:
-    from zarr.v3.metadata import CoreArrayMetadata
+    from zarr.v3.metadata import ArraySpec, CodecMetadata, RuntimeConfiguration
 
 
 @frozen
@@ -33,19 +32,13 @@ class ZstdCodecMetadata:
 
 @frozen
 class ZstdCodec(BytesBytesCodec):
-    array_metadata: CoreArrayMetadata
     configuration: ZstdCodecConfigurationMetadata
     is_fixed_size = True
 
     @classmethod
-    def from_metadata(
-        cls, codec_metadata: CodecMetadata, array_metadata: CoreArrayMetadata
-    ) -> ZstdCodec:
+    def from_metadata(cls, codec_metadata: CodecMetadata) -> ZstdCodec:
         assert isinstance(codec_metadata, ZstdCodecMetadata)
-        return cls(
-            array_metadata=array_metadata,
-            configuration=codec_metadata.configuration,
-        )
+        return cls(configuration=codec_metadata.configuration)
 
     @classmethod
     def get_metadata_class(cls) -> Type[ZstdCodecMetadata]:
@@ -64,16 +57,20 @@ def _decompress(self, data: bytes) -> bytes:
     async def decode(
         self,
         chunk_bytes: bytes,
+        _chunk_spec: ArraySpec,
+        _runtime_configuration: RuntimeConfiguration,
     ) -> BytesLike:
         return await to_thread(self._decompress, chunk_bytes)
 
     async def encode(
         self,
         chunk_bytes: bytes,
+        _chunk_spec: ArraySpec,
+        _runtime_configuration: RuntimeConfiguration,
     ) -> Optional[BytesLike]:
         return await to_thread(self._compress, chunk_bytes)
 
-    def compute_encoded_size(self, _input_byte_length: int) -> int:
+    def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int:
         raise NotImplementedError
diff --git a/src/zarr/v3/metadata.py b/src/zarr/v3/metadata.py
index 53b300d3f8..c6dd9f1f46 100644
--- a/src/zarr/v3/metadata.py
+++ b/src/zarr/v3/metadata.py
@@ -24,6 +24,10 @@ def runtime_configuration(
     return RuntimeConfiguration(order=order, concurrency=concurrency)
 
 
+# For type checking
+_bool = bool
+
+
 class DataType(Enum):
     bool = "bool"
     int8 = "int8"
@@ -54,6 +58,11 @@ def byte_count(self) -> int:
         }
         return data_type_byte_counts[self]
 
+    @property
+    def has_endianness(self) -> _bool:
+        # This might change in the future, e.g. for a complex with 2 8-bit floats
+        return self.byte_count != 1
+
     def to_numpy_shortname(self) -> str:
         data_type_to_numpy = {
             DataType.bool: "bool",
@@ -154,12 +163,10 @@ class ShardingCodecIndexLocation(Enum):
 
 
 @frozen
-class CoreArrayMetadata:
+class ArraySpec:
     shape: ChunkCoords
-    chunk_shape: ChunkCoords
     data_type: DataType
     fill_value: Any
-    runtime_configuration: RuntimeConfiguration
 
     @property
     def dtype(self) -> np.dtype:
@@ -191,13 +198,14 @@ def dtype(self) -> np.dtype:
     def ndim(self) -> int:
         return len(self.shape)
 
-    def get_core_metadata(self, runtime_configuration: RuntimeConfiguration) -> CoreArrayMetadata:
-        return CoreArrayMetadata(
-            shape=self.shape,
-            chunk_shape=self.chunk_grid.configuration.chunk_shape,
+    def get_chunk_spec(self, _chunk_coords: ChunkCoords) -> ArraySpec:
+        assert isinstance(
+            self.chunk_grid, RegularChunkGridMetadata
+        ), "Currently, only regular chunk grid is supported"
+        return ArraySpec(
+            shape=self.chunk_grid.configuration.chunk_shape,
             data_type=self.data_type,
             fill_value=self.fill_value,
-            runtime_configuration=runtime_configuration,
         )
 
     def to_bytes(self) -> bytes:
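Note: `ArraySpec` is the runtime-free successor of `CoreArrayMetadata`: shape, data type, and fill value only, with order/concurrency now travelling separately as `RuntimeConfiguration` arguments. A hedged sketch of what `get_chunk_spec` hands to the pipeline for a regular grid (assumes this branch is importable; `DataType.float32` used as an example member):

    from zarr.v3.metadata import ArraySpec, DataType

    chunk_spec = ArraySpec(
        shape=(16, 16),              # one chunk of the regular grid
        data_type=DataType.float32,
        fill_value=0.0,
    )
    print(chunk_spec.dtype, chunk_spec.ndim)  # float32 2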