
Commit 8aadd15

[v3] Elevate codec pipeline (#1932)
* initial work toward pushing CodecPipeline higher in the stack
* remove CodecPipeline from metadata; add it to AsyncArray and/or create it dynamically
* revert changes to blosc.py
* revert changes to test_codecs.py
* consistent expanded function signature for evolve_from_array_spec
* restore wider function signature for Codec.validate to avoid a self-referential function call
* remove commented code block
* make codec_pipeline a cached property of the sharding codec
* cached_property -> vanilla property
1 parent 65dc4cc commit 8aadd15
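
In short, the codec pipeline now lives on the array object rather than on its metadata. A minimal sketch of the new access pattern, assuming this commit's module layout (zarr.array.AsyncArray, zarr.abc.codec.CodecPipeline); array construction is elided:

    from zarr.abc.codec import CodecPipeline
    from zarr.array import AsyncArray

    def get_pipeline(array: AsyncArray) -> CodecPipeline:
        # previously: array.metadata.codec_pipeline
        # now: AsyncArray builds a BatchedCodecPipeline from its metadata at construction time
        return array.codec_pipeline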

7 files changed, +191 −136 lines changed

src/zarr/abc/codec.py

Lines changed: 19 additions & 7 deletions

@@ -2,20 +2,22 @@

 from abc import abstractmethod
 from collections.abc import Awaitable, Callable, Iterable
-from typing import TYPE_CHECKING, Generic, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, TypeVar
+
+import numpy as np

 from zarr.abc.metadata import Metadata
 from zarr.abc.store import ByteGetter, ByteSetter
 from zarr.buffer import Buffer, NDBuffer
-from zarr.common import concurrent_map
+from zarr.chunk_grids import ChunkGrid
+from zarr.common import ChunkCoords, concurrent_map
 from zarr.config import config

 if TYPE_CHECKING:
     from typing_extensions import Self

     from zarr.array_spec import ArraySpec
     from zarr.indexing import SelectorTuple
-    from zarr.metadata import ArrayMetadata

 CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer)
 CodecOutput = TypeVar("CodecOutput", bound=NDBuffer | Buffer)

@@ -75,13 +77,18 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
         """
         return self

-    def validate(self, array_metadata: ArrayMetadata) -> None:
+    def validate(self, *, shape: ChunkCoords, dtype: np.dtype[Any], chunk_grid: ChunkGrid) -> None:
         """Validates that the codec configuration is compatible with the array metadata.
         Raises errors when the codec configuration is not compatible.

         Parameters
         ----------
-        array_metadata : ArrayMetadata
+        shape: ChunkCoords
+            The array shape
+        dtype: np.dtype[Any]
+            The array data type
+        chunk_grid: ChunkGrid
+            The array chunk grid
         """
         ...

@@ -275,13 +282,18 @@ def supports_partial_decode(self) -> bool: ...
     def supports_partial_encode(self) -> bool: ...

     @abstractmethod
-    def validate(self, array_metadata: ArrayMetadata) -> None:
+    def validate(self, *, shape: ChunkCoords, dtype: np.dtype[Any], chunk_grid: ChunkGrid) -> None:
         """Validates that all codec configurations are compatible with the array metadata.
         Raises errors when a codec configuration is not compatible.

         Parameters
         ----------
-        array_metadata : ArrayMetadata
+        shape: ChunkCoords
+            The array shape
+        dtype: np.dtype[Any]
+            The array data type
+        chunk_grid: ChunkGrid
+            The array chunk grid
         """
         ...
src/zarr/array.py

Lines changed: 21 additions & 6 deletions

@@ -11,19 +11,21 @@
 # 1. Was splitting the array into two classes really necessary?
 from asyncio import gather
 from collections.abc import Iterable
-from dataclasses import dataclass, replace
+from dataclasses import dataclass, field, replace
 from typing import Any, Literal, cast

 import numpy as np
 import numpy.typing as npt

-from zarr.abc.codec import Codec
+from zarr.abc.codec import Codec, CodecPipeline
 from zarr.abc.store import set_or_delete
 from zarr.attributes import Attributes
 from zarr.buffer import BufferPrototype, NDArrayLike, NDBuffer, default_buffer_prototype
 from zarr.chunk_grids import RegularChunkGrid
 from zarr.chunk_key_encodings import ChunkKeyEncoding, DefaultChunkKeyEncoding, V2ChunkKeyEncoding
 from zarr.codecs import BytesCodec
+from zarr.codecs._v2 import V2Compressor, V2Filters
+from zarr.codecs.pipeline import BatchedCodecPipeline
 from zarr.common import (
     JSON,
     ZARR_JSON,

@@ -63,8 +65,8 @@
 from zarr.sync import sync


-def parse_array_metadata(data: Any) -> ArrayMetadata:
-    if isinstance(data, ArrayMetadata):
+def parse_array_metadata(data: Any) -> ArrayV2Metadata | ArrayV3Metadata:
+    if isinstance(data, ArrayV2Metadata | ArrayV3Metadata):
         return data
     elif isinstance(data, dict):
         if data["zarr_format"] == 3:

@@ -74,10 +76,22 @@ def parse_array_metadata(data: Any) -> ArrayMetadata:
     raise TypeError


+def create_codec_pipeline(metadata: ArrayV2Metadata | ArrayV3Metadata) -> BatchedCodecPipeline:
+    if isinstance(metadata, ArrayV3Metadata):
+        return BatchedCodecPipeline.from_list(metadata.codecs)
+    elif isinstance(metadata, ArrayV2Metadata):
+        return BatchedCodecPipeline.from_list(
+            [V2Filters(metadata.filters or []), V2Compressor(metadata.compressor)]
+        )
+    else:
+        raise AssertionError
+
+
 @dataclass(frozen=True)
 class AsyncArray:
     metadata: ArrayMetadata
     store_path: StorePath
+    codec_pipeline: CodecPipeline = field(init=False)
     order: Literal["C", "F"]

     def __init__(

@@ -92,6 +106,7 @@ def __init__(
         object.__setattr__(self, "metadata", metadata_parsed)
         object.__setattr__(self, "store_path", store_path)
         object.__setattr__(self, "order", order_parsed)
+        object.__setattr__(self, "codec_pipeline", create_codec_pipeline(metadata=metadata_parsed))

     @classmethod
     async def create(

@@ -443,7 +458,7 @@ async def _get_selection(
        )
        if product(indexer.shape) > 0:
            # reading chunks and decoding them
-            await self.metadata.codec_pipeline.read(
+            await self.codec_pipeline.read(
                [
                    (
                        self.store_path / self.metadata.encode_chunk_key(chunk_coords),

@@ -503,7 +518,7 @@ async def _set_selection(
            value_buffer = prototype.nd_buffer.from_ndarray_like(value)

        # merging with existing data and encoding chunks
-        await self.metadata.codec_pipeline.write(
+        await self.codec_pipeline.write(
            [
                (
                    self.store_path / self.metadata.encode_chunk_key(chunk_coords),
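
The codec_pipeline attribute above uses a common pattern for frozen dataclasses: declare the derived field with field(init=False) and populate it via object.__setattr__ during initialization, so it is computed from the other fields rather than supplied by the caller. A standalone sketch of the pattern (plain Python, not zarr code):

    from dataclasses import dataclass, field


    @dataclass(frozen=True)
    class Example:
        values: tuple[int, ...]
        total: int = field(init=False)  # derived, not passed by the caller

        def __post_init__(self) -> None:
            # frozen dataclasses forbid normal assignment, so bypass it explicitly
            object.__setattr__(self, "total", sum(self.values))


    assert Example(values=(1, 2, 3)).total == 6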

src/zarr/codecs/blosc.py

Lines changed: 3 additions & 6 deletions

@@ -125,17 +125,14 @@ def to_dict(self) -> dict[str, JSON]:
         }

     def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
+        dtype = array_spec.dtype
         new_codec = self
         if new_codec.typesize is None:
-            new_codec = replace(new_codec, typesize=array_spec.dtype.itemsize)
+            new_codec = replace(new_codec, typesize=dtype.itemsize)
         if new_codec.shuffle is None:
             new_codec = replace(
                 new_codec,
-                shuffle=(
-                    BloscShuffle.bitshuffle
-                    if array_spec.dtype.itemsize == 1
-                    else BloscShuffle.shuffle
-                ),
+                shuffle=(BloscShuffle.bitshuffle if dtype.itemsize == 1 else BloscShuffle.shuffle),
             )

         return new_codec
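
The behavior is unchanged: when shuffle is left unset, it is chosen from the dtype's item size. A standalone illustration of that rule (plain Python, not the zarr API):

    import numpy as np


    def pick_shuffle(dtype: np.dtype) -> str:
        # 1-byte dtypes get bit-level shuffling; wider dtypes get byte-level shuffling
        return "bitshuffle" if dtype.itemsize == 1 else "shuffle"


    assert pick_shuffle(np.dtype("uint8")) == "bitshuffle"
    assert pick_shuffle(np.dtype("float64")) == "shuffle"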

src/zarr/codecs/pipeline.py

Lines changed: 72 additions & 52 deletions

@@ -2,10 +2,12 @@

 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
-from itertools import islice
-from typing import TYPE_CHECKING, TypeVar
+from itertools import islice, pairwise
+from typing import TYPE_CHECKING, Any, TypeVar
 from warnings import warn

+import numpy as np
+
 from zarr.abc.codec import (
     ArrayArrayCodec,
     ArrayBytesCodec,

@@ -17,11 +19,11 @@
 )
 from zarr.abc.store import ByteGetter, ByteSetter
 from zarr.buffer import Buffer, BufferPrototype, NDBuffer
+from zarr.chunk_grids import ChunkGrid
 from zarr.codecs.registry import get_codec_class
-from zarr.common import JSON, concurrent_map, parse_named_configuration
+from zarr.common import JSON, ChunkCoords, concurrent_map, parse_named_configuration
 from zarr.config import config
 from zarr.indexing import SelectorTuple, is_scalar, is_total_slice
-from zarr.metadata import ArrayMetadata

 if TYPE_CHECKING:
     from typing_extensions import Self

@@ -87,54 +89,11 @@ def to_dict(self) -> JSON:
         return [c.to_dict() for c in self]

     def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
-        return type(self).from_list([c.evolve_from_array_spec(array_spec) for c in self])
-
-    @staticmethod
-    def codecs_from_list(
-        codecs: list[Codec],
-    ) -> tuple[tuple[ArrayArrayCodec, ...], ArrayBytesCodec, tuple[BytesBytesCodec, ...]]:
-        from zarr.codecs.sharding import ShardingCodec
-
-        if not any(isinstance(codec, ArrayBytesCodec) for codec in codecs):
-            raise ValueError("Exactly one array-to-bytes codec is required.")
-
-        prev_codec: Codec | None = None
-        for codec in codecs:
-            if prev_codec is not None:
-                if isinstance(codec, ArrayBytesCodec) and isinstance(prev_codec, ArrayBytesCodec):
-                    raise ValueError(
-                        f"ArrayBytesCodec '{type(codec)}' cannot follow after ArrayBytesCodec '{type(prev_codec)}' because exactly 1 ArrayBytesCodec is allowed."
-                    )
-                if isinstance(codec, ArrayBytesCodec) and isinstance(prev_codec, BytesBytesCodec):
-                    raise ValueError(
-                        f"ArrayBytesCodec '{type(codec)}' cannot follow after BytesBytesCodec '{type(prev_codec)}'."
-                    )
-                if isinstance(codec, ArrayArrayCodec) and isinstance(prev_codec, ArrayBytesCodec):
-                    raise ValueError(
-                        f"ArrayArrayCodec '{type(codec)}' cannot follow after ArrayBytesCodec '{type(prev_codec)}'."
-                    )
-                if isinstance(codec, ArrayArrayCodec) and isinstance(prev_codec, BytesBytesCodec):
-                    raise ValueError(
-                        f"ArrayArrayCodec '{type(codec)}' cannot follow after BytesBytesCodec '{type(prev_codec)}'."
-                    )
-            prev_codec = codec
-
-        if any(isinstance(codec, ShardingCodec) for codec in codecs) and len(codecs) > 1:
-            warn(
-                "Combining a `sharding_indexed` codec disables partial reads and "
-                "writes, which may lead to inefficient performance.",
-                stacklevel=3,
-            )
-
-        return (
-            tuple(codec for codec in codecs if isinstance(codec, ArrayArrayCodec)),
-            next(codec for codec in codecs if isinstance(codec, ArrayBytesCodec)),
-            tuple(codec for codec in codecs if isinstance(codec, BytesBytesCodec)),
-        )
+        return type(self).from_list([c.evolve_from_array_spec(array_spec=array_spec) for c in self])

     @classmethod
-    def from_list(cls, codecs: list[Codec], *, batch_size: int | None = None) -> Self:
-        array_array_codecs, array_bytes_codec, bytes_bytes_codecs = cls.codecs_from_list(codecs)
+    def from_list(cls, codecs: Iterable[Codec], *, batch_size: int | None = None) -> Self:
+        array_array_codecs, array_bytes_codec, bytes_bytes_codecs = codecs_from_list(codecs)

         return cls(
             array_array_codecs=array_array_codecs,

@@ -180,9 +139,9 @@ def __iter__(self) -> Iterator[Codec]:
         yield self.array_bytes_codec
         yield from self.bytes_bytes_codecs

-    def validate(self, array_metadata: ArrayMetadata) -> None:
+    def validate(self, *, shape: ChunkCoords, dtype: np.dtype[Any], chunk_grid: ChunkGrid) -> None:
         for codec in self:
-            codec.validate(array_metadata)
+            codec.validate(shape=shape, dtype=dtype, chunk_grid=chunk_grid)

     def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int:
         for codec in self:

@@ -509,3 +468,64 @@ async def write(
             self.write_batch,
             config.get("async.concurrency"),
         )
+
+
+def codecs_from_list(
+    codecs: Iterable[Codec],
+) -> tuple[tuple[ArrayArrayCodec, ...], ArrayBytesCodec, tuple[BytesBytesCodec, ...]]:
+    from zarr.codecs.sharding import ShardingCodec
+
+    array_array: tuple[ArrayArrayCodec, ...] = ()
+    array_bytes_maybe: ArrayBytesCodec | None = None
+    bytes_bytes: tuple[BytesBytesCodec, ...] = ()
+
+    if any(isinstance(codec, ShardingCodec) for codec in codecs) and len(tuple(codecs)) > 1:
+        warn(
+            "Combining a `sharding_indexed` codec disables partial reads and "
+            "writes, which may lead to inefficient performance.",
+            stacklevel=3,
+        )
+
+    for prev_codec, cur_codec in pairwise((None, *codecs)):
+        if isinstance(cur_codec, ArrayArrayCodec):
+            if isinstance(prev_codec, ArrayBytesCodec | BytesBytesCodec):
+                msg = (
+                    f"Invalid codec order. ArrayArrayCodec {cur_codec}"
+                    "must be preceded by another ArrayArrayCodec. "
+                    f"Got {type(prev_codec)} instead."
+                )
+                raise ValueError(msg)
+            array_array += (cur_codec,)
+
+        elif isinstance(cur_codec, ArrayBytesCodec):
+            if isinstance(prev_codec, BytesBytesCodec):
+                msg = (
+                    f"Invalid codec order. ArrayBytes codec {cur_codec}"
+                    f" must be preceded by an ArrayArrayCodec. Got {type(prev_codec)} instead."
+                )
+                raise ValueError(msg)
+
+            if array_bytes_maybe is not None:
+                msg = (
+                    f"Got two instances of ArrayBytesCodec: {array_bytes_maybe} and {cur_codec}. "
+                    "Only one array-to-bytes codec is allowed."
+                )
+                raise ValueError(msg)
+
+            array_bytes_maybe = cur_codec
+
+        elif isinstance(cur_codec, BytesBytesCodec):
+            if isinstance(prev_codec, ArrayArrayCodec):
+                msg = (
+                    f"Invalid codec order. BytesBytesCodec {cur_codec}"
+                    "must be preceded by either another BytesBytesCodec, or an ArrayBytesCodec. "
+                    f"Got {type(prev_codec)} instead."
+                )
+            bytes_bytes += (cur_codec,)
+        else:
+            raise AssertionError
+
+    if array_bytes_maybe is None:
+        raise ValueError("Required ArrayBytesCodec was not found.")
+    else:
+        return array_array, array_bytes_maybe, bytes_bytes
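
The new module-level codecs_from_list partitions a codec list into (array-to-array codecs, the single array-to-bytes codec, bytes-to-bytes codecs) and rejects out-of-order lists. A usage sketch; BytesCodec is confirmed by the diffs above, while TransposeCodec and GzipCodec (and their constructor arguments) are assumed to be exported by zarr.codecs at this commit:

    from zarr.codecs import BytesCodec, GzipCodec, TransposeCodec
    from zarr.codecs.pipeline import codecs_from_list

    # valid order: ArrayArrayCodec(s), then exactly one ArrayBytesCodec, then BytesBytesCodec(s)
    array_array, array_bytes, bytes_bytes = codecs_from_list(
        [TransposeCodec(order=(1, 0)), BytesCodec(), GzipCodec()]
    )

    # a list with no ArrayBytesCodec raises ValueError("Required ArrayBytesCodec was not found.")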
