1
1
from __future__ import annotations
2
2
3
3
from dataclasses import dataclass
4
- from typing import TYPE_CHECKING , Any
4
+ from functools import cached_property
5
+ from importlib .metadata import version
6
+ from typing import TYPE_CHECKING
5
7
6
- import numpy .typing as npt
7
- from zstandard import ZstdCompressor , ZstdDecompressor
8
+ from numcodecs .zstd import Zstd
8
9
9
10
from zarr .abc .codec import BytesBytesCodec
10
11
from zarr .array_spec import ArraySpec
@@ -38,6 +39,14 @@ class ZstdCodec(BytesBytesCodec):
38
39
checksum : bool = False
39
40
40
41
def __init__ (self , * , level : int = 0 , checksum : bool = False ) -> None :
42
+ # numcodecs 0.13.0 introduces the checksum attribute for the zstd codec
43
+ _numcodecs_version = tuple (map (int , version ("numcodecs" ).split ("." )))
44
+ if _numcodecs_version < (0 , 13 , 0 ): # pragma: no cover
45
+ raise RuntimeError (
46
+ "numcodecs version >= 0.13.0 is required to use the zstd codec. "
47
+ f"Version { _numcodecs_version } is currently installed."
48
+ )
49
+
41
50
level_parsed = parse_zstd_level (level )
42
51
checksum_parsed = parse_checksum (checksum )
43
52
@@ -52,21 +61,18 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
52
61
def to_dict (self ) -> dict [str , JSON ]:
53
62
return {"name" : "zstd" , "configuration" : {"level" : self .level , "checksum" : self .checksum }}
54
63
55
- def _compress (self , data : npt .NDArray [Any ]) -> bytes :
56
- ctx = ZstdCompressor (level = self .level , write_checksum = self .checksum )
57
- return ctx .compress (data .tobytes ())
58
-
59
- def _decompress (self , data : npt .NDArray [Any ]) -> bytes :
60
- ctx = ZstdDecompressor ()
61
- return ctx .decompress (data .tobytes ())
64
+ @cached_property
65
+ def _zstd_codec (self ) -> Zstd :
66
+ config_dict = {"level" : self .level , "checksum" : self .checksum }
67
+ return Zstd .from_config (config_dict )
62
68
63
69
async def _decode_single (
64
70
self ,
65
71
chunk_bytes : Buffer ,
66
72
chunk_spec : ArraySpec ,
67
73
) -> Buffer :
68
74
return await to_thread (
69
- as_numpy_array_wrapper , self ._decompress , chunk_bytes , chunk_spec .prototype
75
+ as_numpy_array_wrapper , self ._zstd_codec . decode , chunk_bytes , chunk_spec .prototype
70
76
)
71
77
72
78
async def _encode_single (
@@ -75,7 +81,7 @@ async def _encode_single(
75
81
chunk_spec : ArraySpec ,
76
82
) -> Buffer | None :
77
83
return await to_thread (
78
- as_numpy_array_wrapper , self ._compress , chunk_bytes , chunk_spec .prototype
84
+ as_numpy_array_wrapper , self ._zstd_codec . encode , chunk_bytes , chunk_spec .prototype
79
85
)
80
86
81
87
def compute_encoded_size (self , _input_byte_length : int , _chunk_spec : ArraySpec ) -> int :
0 commit comments