Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions src/zarr/abc/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@

from abc import abstractmethod
from collections.abc import Awaitable, Callable, Iterable
from typing import TYPE_CHECKING, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar

import numpy as np

from zarr.abc.metadata import Metadata
from zarr.abc.store import ByteGetter, ByteSetter
from zarr.buffer import Buffer, NDBuffer
from zarr.common import concurrent_map
from zarr.chunk_grids import ChunkGrid
from zarr.common import ChunkCoords, concurrent_map
from zarr.config import config

if TYPE_CHECKING:
from typing_extensions import Self

from zarr.array_spec import ArraySpec
from zarr.indexing import SelectorTuple
from zarr.metadata import ArrayMetadata

CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer)
CodecOutput = TypeVar("CodecOutput", bound=NDBuffer | Buffer)
Expand Down Expand Up @@ -75,13 +77,18 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
"""
return self

def validate(self, array_metadata: ArrayMetadata) -> None:
def validate(self, *, shape: ChunkCoords, dtype: np.dtype[Any], chunk_grid: ChunkGrid) -> None:
"""Validates that the codec configuration is compatible with the array metadata.
Raises errors when the codec configuration is not compatible.

Parameters
----------
array_metadata : ArrayMetadata
shape: ChunkCoords
The array shape
dtype: np.dtype[Any]
The array data type
chunk_grid: ChunkGrid
The array chunk grid
"""
...

Expand Down Expand Up @@ -275,13 +282,18 @@ def supports_partial_decode(self) -> bool: ...
def supports_partial_encode(self) -> bool: ...

@abstractmethod
def validate(self, array_metadata: ArrayMetadata) -> None:
def validate(self, *, shape: ChunkCoords, dtype: np.dtype[Any], chunk_grid: ChunkGrid) -> None:
"""Validates that all codec configurations are compatible with the array metadata.
Raises errors when a codec configuration is not compatible.

Parameters
----------
array_metadata : ArrayMetadata
shape: ChunkCoords
The array shape
dtype: np.dtype[Any]
The array data type
chunk_grid: ChunkGrid
The array chunk grid
"""
...

Expand Down
27 changes: 21 additions & 6 deletions src/zarr/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,21 @@
# 1. Was splitting the array into two classes really necessary?
from asyncio import gather
from collections.abc import Iterable
from dataclasses import dataclass, replace
from dataclasses import dataclass, field, replace
from typing import Any, Literal, cast

import numpy as np
import numpy.typing as npt

from zarr.abc.codec import Codec
from zarr.abc.codec import Codec, CodecPipeline
from zarr.abc.store import set_or_delete
from zarr.attributes import Attributes
from zarr.buffer import BufferPrototype, NDArrayLike, NDBuffer, default_buffer_prototype
from zarr.chunk_grids import RegularChunkGrid
from zarr.chunk_key_encodings import ChunkKeyEncoding, DefaultChunkKeyEncoding, V2ChunkKeyEncoding
from zarr.codecs import BytesCodec
from zarr.codecs._v2 import V2Compressor, V2Filters
from zarr.codecs.pipeline import BatchedCodecPipeline
from zarr.common import (
JSON,
ZARR_JSON,
Expand Down Expand Up @@ -63,8 +65,8 @@
from zarr.sync import sync


def parse_array_metadata(data: Any) -> ArrayMetadata:
if isinstance(data, ArrayMetadata):
def parse_array_metadata(data: Any) -> ArrayV2Metadata | ArrayV3Metadata:
if isinstance(data, ArrayV2Metadata | ArrayV3Metadata):
return data
elif isinstance(data, dict):
if data["zarr_format"] == 3:
Expand All @@ -74,10 +76,22 @@ def parse_array_metadata(data: Any) -> ArrayMetadata:
raise TypeError


def create_codec_pipeline(metadata: ArrayV2Metadata | ArrayV3Metadata) -> BatchedCodecPipeline:
if isinstance(metadata, ArrayV3Metadata):
return BatchedCodecPipeline.from_list(metadata.codecs)
elif isinstance(metadata, ArrayV2Metadata):
return BatchedCodecPipeline.from_list(
[V2Filters(metadata.filters or []), V2Compressor(metadata.compressor)]
)
else:
raise AssertionError


@dataclass(frozen=True)
class AsyncArray:
metadata: ArrayMetadata
store_path: StorePath
codec_pipeline: CodecPipeline = field(init=False)
order: Literal["C", "F"]

def __init__(
Expand All @@ -92,6 +106,7 @@ def __init__(
object.__setattr__(self, "metadata", metadata_parsed)
object.__setattr__(self, "store_path", store_path)
object.__setattr__(self, "order", order_parsed)
object.__setattr__(self, "codec_pipeline", create_codec_pipeline(metadata=metadata_parsed))

@classmethod
async def create(
Expand Down Expand Up @@ -443,7 +458,7 @@ async def _get_selection(
)
if product(indexer.shape) > 0:
# reading chunks and decoding them
await self.metadata.codec_pipeline.read(
await self.codec_pipeline.read(
[
(
self.store_path / self.metadata.encode_chunk_key(chunk_coords),
Expand Down Expand Up @@ -503,7 +518,7 @@ async def _set_selection(
value_buffer = prototype.nd_buffer.from_ndarray_like(value)

# merging with existing data and encoding chunks
await self.metadata.codec_pipeline.write(
await self.codec_pipeline.write(
[
(
self.store_path / self.metadata.encode_chunk_key(chunk_coords),
Expand Down
9 changes: 3 additions & 6 deletions src/zarr/codecs/blosc.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,17 +125,14 @@ def to_dict(self) -> dict[str, JSON]:
}

def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
dtype = array_spec.dtype
new_codec = self
if new_codec.typesize is None:
new_codec = replace(new_codec, typesize=array_spec.dtype.itemsize)
new_codec = replace(new_codec, typesize=dtype.itemsize)
if new_codec.shuffle is None:
new_codec = replace(
new_codec,
shuffle=(
BloscShuffle.bitshuffle
if array_spec.dtype.itemsize == 1
else BloscShuffle.shuffle
),
shuffle=(BloscShuffle.bitshuffle if dtype.itemsize == 1 else BloscShuffle.shuffle),
)

return new_codec
Expand Down
124 changes: 72 additions & 52 deletions src/zarr/codecs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from itertools import islice
from typing import TYPE_CHECKING, TypeVar
from itertools import islice, pairwise
from typing import TYPE_CHECKING, Any, TypeVar
from warnings import warn

import numpy as np

from zarr.abc.codec import (
ArrayArrayCodec,
ArrayBytesCodec,
Expand All @@ -17,11 +19,11 @@
)
from zarr.abc.store import ByteGetter, ByteSetter
from zarr.buffer import Buffer, BufferPrototype, NDBuffer
from zarr.chunk_grids import ChunkGrid
from zarr.codecs.registry import get_codec_class
from zarr.common import JSON, concurrent_map, parse_named_configuration
from zarr.common import JSON, ChunkCoords, concurrent_map, parse_named_configuration
from zarr.config import config
from zarr.indexing import SelectorTuple, is_scalar, is_total_slice
from zarr.metadata import ArrayMetadata

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -87,54 +89,11 @@ def to_dict(self) -> JSON:
return [c.to_dict() for c in self]

def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
return type(self).from_list([c.evolve_from_array_spec(array_spec) for c in self])

@staticmethod
def codecs_from_list(
codecs: list[Codec],
) -> tuple[tuple[ArrayArrayCodec, ...], ArrayBytesCodec, tuple[BytesBytesCodec, ...]]:
from zarr.codecs.sharding import ShardingCodec

if not any(isinstance(codec, ArrayBytesCodec) for codec in codecs):
raise ValueError("Exactly one array-to-bytes codec is required.")

prev_codec: Codec | None = None
for codec in codecs:
if prev_codec is not None:
if isinstance(codec, ArrayBytesCodec) and isinstance(prev_codec, ArrayBytesCodec):
raise ValueError(
f"ArrayBytesCodec '{type(codec)}' cannot follow after ArrayBytesCodec '{type(prev_codec)}' because exactly 1 ArrayBytesCodec is allowed."
)
if isinstance(codec, ArrayBytesCodec) and isinstance(prev_codec, BytesBytesCodec):
raise ValueError(
f"ArrayBytesCodec '{type(codec)}' cannot follow after BytesBytesCodec '{type(prev_codec)}'."
)
if isinstance(codec, ArrayArrayCodec) and isinstance(prev_codec, ArrayBytesCodec):
raise ValueError(
f"ArrayArrayCodec '{type(codec)}' cannot follow after ArrayBytesCodec '{type(prev_codec)}'."
)
if isinstance(codec, ArrayArrayCodec) and isinstance(prev_codec, BytesBytesCodec):
raise ValueError(
f"ArrayArrayCodec '{type(codec)}' cannot follow after BytesBytesCodec '{type(prev_codec)}'."
)
prev_codec = codec

if any(isinstance(codec, ShardingCodec) for codec in codecs) and len(codecs) > 1:
warn(
"Combining a `sharding_indexed` codec disables partial reads and "
"writes, which may lead to inefficient performance.",
stacklevel=3,
)

return (
tuple(codec for codec in codecs if isinstance(codec, ArrayArrayCodec)),
next(codec for codec in codecs if isinstance(codec, ArrayBytesCodec)),
tuple(codec for codec in codecs if isinstance(codec, BytesBytesCodec)),
)
return type(self).from_list([c.evolve_from_array_spec(array_spec=array_spec) for c in self])

@classmethod
def from_list(cls, codecs: list[Codec], *, batch_size: int | None = None) -> Self:
array_array_codecs, array_bytes_codec, bytes_bytes_codecs = cls.codecs_from_list(codecs)
def from_list(cls, codecs: Iterable[Codec], *, batch_size: int | None = None) -> Self:
array_array_codecs, array_bytes_codec, bytes_bytes_codecs = codecs_from_list(codecs)

return cls(
array_array_codecs=array_array_codecs,
Expand Down Expand Up @@ -180,9 +139,9 @@ def __iter__(self) -> Iterator[Codec]:
yield self.array_bytes_codec
yield from self.bytes_bytes_codecs

def validate(self, array_metadata: ArrayMetadata) -> None:
def validate(self, *, shape: ChunkCoords, dtype: np.dtype[Any], chunk_grid: ChunkGrid) -> None:
for codec in self:
codec.validate(array_metadata)
codec.validate(shape=shape, dtype=dtype, chunk_grid=chunk_grid)

def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int:
for codec in self:
Expand Down Expand Up @@ -509,3 +468,64 @@ async def write(
self.write_batch,
config.get("async.concurrency"),
)


def codecs_from_list(
codecs: Iterable[Codec],
) -> tuple[tuple[ArrayArrayCodec, ...], ArrayBytesCodec, tuple[BytesBytesCodec, ...]]:
from zarr.codecs.sharding import ShardingCodec

array_array: tuple[ArrayArrayCodec, ...] = ()
array_bytes_maybe: ArrayBytesCodec | None = None
bytes_bytes: tuple[BytesBytesCodec, ...] = ()

if any(isinstance(codec, ShardingCodec) for codec in codecs) and len(tuple(codecs)) > 1:
warn(
"Combining a `sharding_indexed` codec disables partial reads and "
"writes, which may lead to inefficient performance.",
stacklevel=3,
)

for prev_codec, cur_codec in pairwise((None, *codecs)):
if isinstance(cur_codec, ArrayArrayCodec):
if isinstance(prev_codec, ArrayBytesCodec | BytesBytesCodec):
msg = (
f"Invalid codec order. ArrayArrayCodec {cur_codec}"
"must be preceded by another ArrayArrayCodec. "
f"Got {type(prev_codec)} instead."
)
raise ValueError(msg)
array_array += (cur_codec,)

elif isinstance(cur_codec, ArrayBytesCodec):
if isinstance(prev_codec, BytesBytesCodec):
msg = (
f"Invalid codec order. ArrayBytes codec {cur_codec}"
f" must be preceded by an ArrayArrayCodec. Got {type(prev_codec)} instead."
)
raise ValueError(msg)

if array_bytes_maybe is not None:
msg = (
f"Got two instances of ArrayBytesCodec: {array_bytes_maybe} and {cur_codec}. "
"Only one array-to-bytes codec is allowed."
)
raise ValueError(msg)

array_bytes_maybe = cur_codec

elif isinstance(cur_codec, BytesBytesCodec):
if isinstance(prev_codec, ArrayArrayCodec):
msg = (
f"Invalid codec order. BytesBytesCodec {cur_codec}"
"must be preceded by either another BytesBytesCodec, or an ArrayBytesCodec. "
f"Got {type(prev_codec)} instead."
)
bytes_bytes += (cur_codec,)
else:
raise AssertionError

if array_bytes_maybe is None:
raise ValueError("Required ArrayBytesCodec was not found.")
else:
return array_array, array_bytes_maybe, bytes_bytes
Loading