Skip to content

Commit d9a142a

Browse files
authored
Remove zstandard dependency in favor of numcodecs (#1838)
* removes zstandard dependency * fix dependencies * change default level to 3 * revert to 0 default level
1 parent f6de884 commit d9a142a

File tree

3 files changed

+20
-16
lines changed

3 files changed

+20
-16
lines changed

.pre-commit-config.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ repos:
3535
- numcodecs
3636
- numpy
3737
- typing_extensions
38-
- zstandard
3938
# Tests
4039
- pytest
4140
# Zarr v2

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ dependencies = [
2929
'numcodecs>=0.10.0',
3030
'fsspec>2024',
3131
'crc32c',
32-
'zstandard',
3332
'typing_extensions',
3433
'donfig',
3534
]
@@ -85,8 +84,8 @@ docs = [
8584
'pydata-sphinx-theme',
8685
'numpydoc',
8786
'numcodecs[msgpack]',
88-
"msgpack",
89-
"lmdb",
87+
'msgpack',
88+
'lmdb',
9089
]
9190
extra = [
9291
'msgpack',

src/zarr/codecs/zstd.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass
4-
from typing import TYPE_CHECKING, Any
4+
from functools import cached_property
5+
from importlib.metadata import version
6+
from typing import TYPE_CHECKING
57

6-
import numpy.typing as npt
7-
from zstandard import ZstdCompressor, ZstdDecompressor
8+
from numcodecs.zstd import Zstd
89

910
from zarr.abc.codec import BytesBytesCodec
1011
from zarr.array_spec import ArraySpec
@@ -38,6 +39,14 @@ class ZstdCodec(BytesBytesCodec):
3839
checksum: bool = False
3940

4041
def __init__(self, *, level: int = 0, checksum: bool = False) -> None:
42+
# numcodecs 0.13.0 introduces the checksum attribute for the zstd codec
43+
_numcodecs_version = tuple(map(int, version("numcodecs").split(".")))
44+
if _numcodecs_version < (0, 13, 0): # pragma: no cover
45+
raise RuntimeError(
46+
"numcodecs version >= 0.13.0 is required to use the zstd codec. "
47+
f"Version {_numcodecs_version} is currently installed."
48+
)
49+
4150
level_parsed = parse_zstd_level(level)
4251
checksum_parsed = parse_checksum(checksum)
4352

@@ -52,21 +61,18 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
5261
def to_dict(self) -> dict[str, JSON]:
5362
return {"name": "zstd", "configuration": {"level": self.level, "checksum": self.checksum}}
5463

55-
def _compress(self, data: npt.NDArray[Any]) -> bytes:
56-
ctx = ZstdCompressor(level=self.level, write_checksum=self.checksum)
57-
return ctx.compress(data.tobytes())
58-
59-
def _decompress(self, data: npt.NDArray[Any]) -> bytes:
60-
ctx = ZstdDecompressor()
61-
return ctx.decompress(data.tobytes())
64+
@cached_property
65+
def _zstd_codec(self) -> Zstd:
66+
config_dict = {"level": self.level, "checksum": self.checksum}
67+
return Zstd.from_config(config_dict)
6268

6369
async def _decode_single(
6470
self,
6571
chunk_bytes: Buffer,
6672
chunk_spec: ArraySpec,
6773
) -> Buffer:
6874
return await to_thread(
69-
as_numpy_array_wrapper, self._decompress, chunk_bytes, chunk_spec.prototype
75+
as_numpy_array_wrapper, self._zstd_codec.decode, chunk_bytes, chunk_spec.prototype
7076
)
7177

7278
async def _encode_single(
@@ -75,7 +81,7 @@ async def _encode_single(
7581
chunk_spec: ArraySpec,
7682
) -> Buffer | None:
7783
return await to_thread(
78-
as_numpy_array_wrapper, self._compress, chunk_bytes, chunk_spec.prototype
84+
as_numpy_array_wrapper, self._zstd_codec.encode, chunk_bytes, chunk_spec.prototype
7985
)
8086

8187
def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int:

0 commit comments

Comments
 (0)