From bdf58a65a03647c13bce7ac670dc9a3adadddbbf Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Mon, 17 Jun 2024 16:46:19 +0200 Subject: [PATCH 01/36] make codec pipeline implementation configurable --- src/zarr/abc/codec.py | 10 ++++++++++ src/zarr/codecs/sharding.py | 6 +++--- src/zarr/config.py | 37 ++++++++++++++++++++++++++++++++++--- src/zarr/metadata.py | 8 ++++---- tests/v3/test_config.py | 30 ++++++++++++++++++++++++++++-- 5 files changed, 79 insertions(+), 12 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index 1f452159ed..282ee5dcf7 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -14,6 +14,7 @@ from typing_extensions import Self from zarr.array_spec import ArraySpec + from zarr.common import JSON from zarr.indexing import SelectorTuple from zarr.metadata import ArrayMetadata @@ -384,6 +385,15 @@ async def write( """ ... + @classmethod + def from_dict(cls, data: Iterable[JSON | Codec]) -> Self: + """ + Create an instance of the model from a dictionary + """ + ... + + return cls(**data) + async def batching_helper( func: Callable[[CodecInput, ArraySpec], Awaitable[CodecOutput | None]], diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 74ad5ac44f..2f309a4eea 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -23,7 +23,6 @@ from zarr.chunk_grids import RegularChunkGrid from zarr.codecs.bytes import BytesCodec from zarr.codecs.crc32c_ import Crc32cCodec -from zarr.codecs.pipeline import BatchedCodecPipeline from zarr.codecs.registry import register_codec from zarr.common import ( ChunkCoords, @@ -33,6 +32,7 @@ parse_shapelike, product, ) +from zarr.config import config from zarr.indexing import BasicIndexer, SelectorTuple, c_order_iter, get_indexer, morton_order_iter from zarr.metadata import ArrayMetadata, parse_codecs @@ -314,12 +314,12 @@ def __init__( codecs_parsed = ( parse_codecs(codecs) if codecs is not None - else BatchedCodecPipeline.from_list([BytesCodec()]) + else config.codec_pipeline_class.from_list([BytesCodec()]) ) index_codecs_parsed = ( parse_codecs(index_codecs) if index_codecs is not None - else BatchedCodecPipeline.from_list([BytesCodec(), Crc32cCodec()]) + else config.codec_pipeline_class.from_list([BytesCodec(), Crc32cCodec()]) ) index_location_parsed = ( parse_index_location(index_location) diff --git a/src/zarr/config.py b/src/zarr/config.py index 7c5b48a16c..09860acdd1 100644 --- a/src/zarr/config.py +++ b/src/zarr/config.py @@ -1,8 +1,39 @@ from __future__ import annotations -from typing import Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast + +from donfig import Config as DConfig + +if TYPE_CHECKING: + from zarr.abc.codec import CodecPipeline + + +class BadConfigError(ValueError): + _msg = "bad Config: %r" + + +class Config(DConfig): # type: ignore[misc] + @property + def codec_pipeline_class(self) -> type[CodecPipeline]: + from zarr.abc.codec import CodecPipeline + + name = self.get("codec_pipeline.name") + name_camel_case = name.replace("_", " ").title().replace(" ", "") + selected_pipelines = [ + p for p in CodecPipeline.__subclasses__() if p.__name__ in (name, name_camel_case) + ] + + if not selected_pipelines: + raise BadConfigError( + f'No subclass of CodecPipeline with name "{name}" or "{name_camel_case}" found.' + ) + if len(selected_pipelines) > 1: + raise BadConfigError( + f'Multiple subclasses of CodecPipeline with name "{name}" or ' + f'"{name_camel_case}" found: {selected_pipelines}.' + ) + return selected_pipelines[0] -from donfig import Config config = Config( "zarr", @@ -10,7 +41,7 @@ { "array": {"order": "C"}, "async": {"concurrency": None, "timeout": None}, - "codec_pipeline": {"batch_size": 1}, + "codec_pipeline": {"name": "batched_codec_pipeline", "batch_size": 1}, } ], ) diff --git a/src/zarr/metadata.py b/src/zarr/metadata.py index 8329bd9200..ec3a89b360 100644 --- a/src/zarr/metadata.py +++ b/src/zarr/metadata.py @@ -373,9 +373,9 @@ def chunks(self) -> ChunkCoords: @property def codec_pipeline(self) -> CodecPipeline: - from zarr.codecs import BatchedCodecPipeline + from zarr.config import config - return BatchedCodecPipeline.from_list( + return config.codec_pipeline_class.from_list( [V2Filters(self.filters or []), V2Compressor(self.compressor)] ) @@ -501,8 +501,8 @@ def parse_v2_metadata(data: ArrayV2Metadata) -> ArrayV2Metadata: def parse_codecs(data: Iterable[Codec | JSON]) -> CodecPipeline: - from zarr.codecs import BatchedCodecPipeline + from zarr.config import config if not isinstance(data, Iterable): raise TypeError(f"Expected iterable, got {type(data)}") - return BatchedCodecPipeline.from_dict(data) + return config.codec_pipeline_class.from_dict(data) diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index aed9775d17..8b0b870c7b 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -1,4 +1,8 @@ -from zarr.config import config +import pytest + +from zarr.abc.codec import CodecPipeline +from zarr.codecs import BatchedCodecPipeline +from zarr.config import BadConfigError, config def test_config_defaults_set(): @@ -7,7 +11,7 @@ def test_config_defaults_set(): { "array": {"order": "C"}, "async": {"concurrency": None, "timeout": None}, - "codec_pipeline": {"batch_size": 1}, + "codec_pipeline": {"name": "batched_codec_pipeline", "batch_size": 1}, } ] assert config.get("array.order") == "C" @@ -17,3 +21,25 @@ def test_config_defaults_can_be_overridden(): assert config.get("array.order") == "C" with config.set({"array.order": "F"}): assert config.get("array.order") == "F" + + +def test_config_codec_pipeline_class(): + # has default value + assert config.codec_pipeline_class.__name__ != "" + + config.set({"codec_pipeline.name": "batched_codec_pipeline"}) + assert config.codec_pipeline_class == BatchedCodecPipeline + + class MockCodecPipeline(CodecPipeline): + pass + + config.set({"codec_pipeline.name": "mock_codec_pipeline"}) + assert config.codec_pipeline_class == MockCodecPipeline + + with pytest.raises(BadConfigError): + config.set({"codec_pipeline.name": "wrong_name"}) + config.codec_pipeline_class + + # Camel case works, too + config.set({"codec_pipeline.name": "MockCodecPipeline"}) + assert config.codec_pipeline_class == MockCodecPipeline From e2a5e11f071bb83831abb93ea824d6aa530abfd8 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Mon, 17 Jun 2024 17:25:38 +0200 Subject: [PATCH 02/36] add test_config_codec_pipeline_class_in_env --- src/zarr/config.py | 14 +++++++++++++- tests/v3/test_config.py | 11 +++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/zarr/config.py b/src/zarr/config.py index 09860acdd1..70e9ef97d5 100644 --- a/src/zarr/config.py +++ b/src/zarr/config.py @@ -13,6 +13,18 @@ class BadConfigError(ValueError): class Config(DConfig): # type: ignore[misc] + """ Will collect configuration from config files and environment variables + + Example environment variables: + Grabs environment variables of the form "DASK_FOO__BAR_BAZ=123" and + turns these into config variables of the form ``{"foo": {"bar-baz": 123}}`` + It transforms the key and value in the following way: + + - Lower-cases the key text + - Treats ``__`` (double-underscore) as nested access + - Calls ``ast.literal_eval`` on the value + + """ @property def codec_pipeline_class(self) -> type[CodecPipeline]: from zarr.abc.codec import CodecPipeline @@ -36,7 +48,7 @@ def codec_pipeline_class(self) -> type[CodecPipeline]: config = Config( - "zarr", + "zarr_python", defaults=[ { "array": {"order": "C"}, diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index 8b0b870c7b..6cf1899629 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -1,3 +1,5 @@ +import os + import pytest from zarr.abc.codec import CodecPipeline @@ -43,3 +45,12 @@ class MockCodecPipeline(CodecPipeline): # Camel case works, too config.set({"codec_pipeline.name": "MockCodecPipeline"}) assert config.codec_pipeline_class == MockCodecPipeline + + +def test_config_codec_pipeline_class_in_env(): + class MockEnvCodecPipeline(CodecPipeline): + pass + + os.environ[("ZARR_PYTHON_CODEC_PIPELINE__NAME")] = "mock_env_codec_pipeline" + config.refresh() + assert config.codec_pipeline_class == MockEnvCodecPipeline From 311cfcce09ca53d1a120db575f9b26001ebd41a1 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Tue, 18 Jun 2024 13:19:30 +0200 Subject: [PATCH 03/36] make codec implementation configurable --- src/zarr/abc/codec.py | 12 +++++++---- src/zarr/codecs/registry.py | 24 +++++++++++++++++++-- src/zarr/config.py | 35 +++++++++++++++++++++---------- tests/v3/test_config.py | 42 +++++++++++++++++++++++++++---------- 4 files changed, 85 insertions(+), 28 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index 282ee5dcf7..8e28fa1201 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -8,7 +8,6 @@ from zarr.abc.store import ByteGetter, ByteSetter from zarr.buffer import Buffer, NDBuffer from zarr.common import concurrent_map -from zarr.config import config if TYPE_CHECKING: from typing_extensions import Self @@ -22,6 +21,11 @@ CodecOutput = TypeVar("CodecOutput", bound=NDBuffer | Buffer) +def get_config(): + from zarr.config import config + + return config + class _Codec(Generic[CodecInput, CodecOutput], Metadata): """Generic base class for codecs. Please use ArrayArrayCodec, ArrayBytesCodec or BytesBytesCodec for subclassing. @@ -186,7 +190,7 @@ async def decode_partial( for byte_getter, selection, chunk_spec in batch_info ], self._decode_partial_single, - config.get("async.concurrency"), + get_config().get("async.concurrency"), ) @@ -226,7 +230,7 @@ async def encode_partial( for byte_setter, chunk_array, selection, chunk_spec in batch_info ], self._encode_partial_single, - config.get("async.concurrency"), + get_config().get("async.concurrency"), ) @@ -402,7 +406,7 @@ async def batching_helper( return await concurrent_map( [(chunk_array, chunk_spec) for chunk_array, chunk_spec in batch_info], noop_for_none(func), - config.get("async.concurrency"), + get_config().get("async.concurrency"), ) diff --git a/src/zarr/codecs/registry.py b/src/zarr/codecs/registry.py index 2f2b09499f..fcd9642802 100644 --- a/src/zarr/codecs/registry.py +++ b/src/zarr/codecs/registry.py @@ -8,9 +8,13 @@ from importlib.metadata import EntryPoint from importlib.metadata import entry_points as get_entry_points +from zarr.config import camel_case, config + __codec_registry: dict[str, type[Codec]] = {} __lazy_load_codecs: dict[str, EntryPoint] = {} +all_codecs = {} + def _collect_entrypoints() -> dict[str, EntryPoint]: entry_points = get_entry_points() @@ -19,11 +23,27 @@ def _collect_entrypoints() -> dict[str, EntryPoint]: return __lazy_load_codecs +def _reload_config() -> None: + config.refresh() + for codec_cls, key in all_codecs.items(): + register_codec(key, codec_cls) + + def register_codec(key: str, codec_cls: type[Codec]) -> None: - __codec_registry[key] = codec_cls + all_codecs[codec_cls] = key + + selected_codec = config.get("codecs", {}).get(key) + if selected_codec is None: + raise ValueError(f"Codec '{key}' not found in config.") + name = selected_codec.get("name") + name_camel_case = camel_case(name) + if codec_cls.__name__ in (name, name_camel_case): + __codec_registry[key] = codec_cls -def get_codec_class(key: str) -> type[Codec]: +def get_codec_class(key: str, reload_config=False) -> type[Codec]: + if reload_config: + _reload_config() item = __codec_registry.get(key) if item is None: if key in __lazy_load_codecs: diff --git a/src/zarr/config.py b/src/zarr/config.py index 70e9ef97d5..5c10ae6087 100644 --- a/src/zarr/config.py +++ b/src/zarr/config.py @@ -13,24 +13,24 @@ class BadConfigError(ValueError): class Config(DConfig): # type: ignore[misc] - """ Will collect configuration from config files and environment variables + """Will collect configuration from config files and environment variables - Example environment variables: - Grabs environment variables of the form "DASK_FOO__BAR_BAZ=123" and - turns these into config variables of the form ``{"foo": {"bar-baz": 123}}`` - It transforms the key and value in the following way: + Example environment variables: + Grabs environment variables of the form "DASK_FOO__BAR_BAZ=123" and + turns these into config variables of the form ``{"foo": {"bar-baz": 123}}`` + It transforms the key and value in the following way: - - Lower-cases the key text - - Treats ``__`` (double-underscore) as nested access - - Calls ``ast.literal_eval`` on the value + - Lower-cases the key text + - Treats ``__`` (double-underscore) as nested access + - Calls ``ast.literal_eval`` on the value - """ + """ @property def codec_pipeline_class(self) -> type[CodecPipeline]: from zarr.abc.codec import CodecPipeline name = self.get("codec_pipeline.name") - name_camel_case = name.replace("_", " ").title().replace(" ", "") + name_camel_case = camel_case(name) selected_pipelines = [ p for p in CodecPipeline.__subclasses__() if p.__name__ in (name, name_camel_case) ] @@ -53,7 +53,17 @@ def codec_pipeline_class(self) -> type[CodecPipeline]: { "array": {"order": "C"}, "async": {"concurrency": None, "timeout": None}, - "codec_pipeline": {"name": "batched_codec_pipeline", "batch_size": 1}, + "codec_pipeline": {"name": "BatchedCodecPipeline", "batch_size": 1}, + "codecs": { + "blosc": {"name": "BloscCodec"}, + "gzip": {"name": "GzipCodec"}, + "zstd": {"name": "ZstdCodec"}, + "bytes": {"name": "BytesCodec"}, + "endian": {"name": "BytesCodec"}, # compatibility with earlier versions of ZEP1 + "crc32c": {"name": "Crc32cCodec"}, + "sharding_indexed": {"name": "ShardingCodec"}, + "transpose": {"name": "TransposeCodec"}, + }, } ], ) @@ -64,3 +74,6 @@ def parse_indexing_order(data: Any) -> Literal["C", "F"]: return cast(Literal["C", "F"], data) msg = f"Expected one of ('C', 'F'), got {data} instead." raise ValueError(msg) + +def camel_case(string: str) -> str: + return string.replace("_", " ").title().replace(" ", "") \ No newline at end of file diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index 6cf1899629..f6c7047bb0 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -3,7 +3,8 @@ import pytest from zarr.abc.codec import CodecPipeline -from zarr.codecs import BatchedCodecPipeline +from zarr.codecs import BatchedCodecPipeline, BloscCodec +from zarr.codecs.registry import get_codec_class, register_codec from zarr.config import BadConfigError, config @@ -13,7 +14,17 @@ def test_config_defaults_set(): { "array": {"order": "C"}, "async": {"concurrency": None, "timeout": None}, - "codec_pipeline": {"name": "batched_codec_pipeline", "batch_size": 1}, + "codec_pipeline": {"name": "BatchedCodecPipeline", "batch_size": 1}, + "codecs": { + "blosc": {"name": "BloscCodec"}, + "gzip": {"name": "GzipCodec"}, + "zstd": {"name": "ZstdCodec"}, + "bytes": {"name": "BytesCodec"}, + "endian": {"name": "BytesCodec"}, # compatibility with earlier versions of ZEP1 + "crc32c": {"name": "Crc32cCodec"}, + "sharding_indexed": {"name": "ShardingCodec"}, + "transpose": {"name": "TransposeCodec"}, + }, } ] assert config.get("array.order") == "C" @@ -29,28 +40,37 @@ def test_config_codec_pipeline_class(): # has default value assert config.codec_pipeline_class.__name__ != "" - config.set({"codec_pipeline.name": "batched_codec_pipeline"}) + config.set({"codec_pipeline.name": "BatchedCodecPipeline"}) assert config.codec_pipeline_class == BatchedCodecPipeline class MockCodecPipeline(CodecPipeline): pass - config.set({"codec_pipeline.name": "mock_codec_pipeline"}) + config.set({"codec_pipeline.name": "MockCodecPipeline"}) assert config.codec_pipeline_class == MockCodecPipeline with pytest.raises(BadConfigError): config.set({"codec_pipeline.name": "wrong_name"}) config.codec_pipeline_class - # Camel case works, too - config.set({"codec_pipeline.name": "MockCodecPipeline"}) - assert config.codec_pipeline_class == MockCodecPipeline - - -def test_config_codec_pipeline_class_in_env(): class MockEnvCodecPipeline(CodecPipeline): pass - os.environ[("ZARR_PYTHON_CODEC_PIPELINE__NAME")] = "mock_env_codec_pipeline" + os.environ[("ZARR_PYTHON_CODEC_PIPELINE__NAME")] = "MockEnvCodecPipeline" config.refresh() assert config.codec_pipeline_class == MockEnvCodecPipeline + + +def test_config_codec_implementation(): + assert get_codec_class("blosc").__name__ == config.defaults[0]["codecs"]["blosc"]["name"] + + class MockBloscCodec(BloscCodec): + pass + + config.set({"codecs.blosc.name": "MockBloscCodec"}) + register_codec("blosc", MockBloscCodec) + assert get_codec_class("blosc") == MockBloscCodec + + os.environ[("ZARR_PYTHON_CODECS__BLOSC__NAME")] = "BloscCodec" + assert get_codec_class("blosc", reload_config=True) == BloscCodec + From 364c4034c20c6a861ea00d617e97d2b9f661ae4e Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Tue, 18 Jun 2024 13:21:25 +0200 Subject: [PATCH 04/36] remove snake case support for class names in config --- src/zarr/codecs/registry.py | 5 ++--- src/zarr/config.py | 16 ++++------------ tests/v3/test_config.py | 4 ++-- 3 files changed, 8 insertions(+), 17 deletions(-) diff --git a/src/zarr/codecs/registry.py b/src/zarr/codecs/registry.py index fcd9642802..1a5eff34d0 100644 --- a/src/zarr/codecs/registry.py +++ b/src/zarr/codecs/registry.py @@ -8,7 +8,7 @@ from importlib.metadata import EntryPoint from importlib.metadata import entry_points as get_entry_points -from zarr.config import camel_case, config +from zarr.config import config __codec_registry: dict[str, type[Codec]] = {} __lazy_load_codecs: dict[str, EntryPoint] = {} @@ -36,8 +36,7 @@ def register_codec(key: str, codec_cls: type[Codec]) -> None: if selected_codec is None: raise ValueError(f"Codec '{key}' not found in config.") name = selected_codec.get("name") - name_camel_case = camel_case(name) - if codec_cls.__name__ in (name, name_camel_case): + if codec_cls.__name__ == name: __codec_registry[key] = codec_cls diff --git a/src/zarr/config.py b/src/zarr/config.py index 5c10ae6087..e305efa046 100644 --- a/src/zarr/config.py +++ b/src/zarr/config.py @@ -25,24 +25,19 @@ class Config(DConfig): # type: ignore[misc] - Calls ``ast.literal_eval`` on the value """ + @property def codec_pipeline_class(self) -> type[CodecPipeline]: from zarr.abc.codec import CodecPipeline name = self.get("codec_pipeline.name") - name_camel_case = camel_case(name) - selected_pipelines = [ - p for p in CodecPipeline.__subclasses__() if p.__name__ in (name, name_camel_case) - ] + selected_pipelines = [p for p in CodecPipeline.__subclasses__() if p.__name__ == name] if not selected_pipelines: - raise BadConfigError( - f'No subclass of CodecPipeline with name "{name}" or "{name_camel_case}" found.' - ) + raise BadConfigError(f'No subclass of CodecPipeline with name "{name}" found.') if len(selected_pipelines) > 1: raise BadConfigError( - f'Multiple subclasses of CodecPipeline with name "{name}" or ' - f'"{name_camel_case}" found: {selected_pipelines}.' + f'Multiple subclasses of CodecPipeline with name "{name}" found: {selected_pipelines}.' ) return selected_pipelines[0] @@ -74,6 +69,3 @@ def parse_indexing_order(data: Any) -> Literal["C", "F"]: return cast(Literal["C", "F"], data) msg = f"Expected one of ('C', 'F'), got {data} instead." raise ValueError(msg) - -def camel_case(string: str) -> str: - return string.replace("_", " ").title().replace(" ", "") \ No newline at end of file diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index f6c7047bb0..4c6b5e5410 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -56,7 +56,7 @@ class MockCodecPipeline(CodecPipeline): class MockEnvCodecPipeline(CodecPipeline): pass - os.environ[("ZARR_PYTHON_CODEC_PIPELINE__NAME")] = "MockEnvCodecPipeline" + os.environ["ZARR_PYTHON_CODEC_PIPELINE__NAME"] = "MockEnvCodecPipeline" config.refresh() assert config.codec_pipeline_class == MockEnvCodecPipeline @@ -71,6 +71,6 @@ class MockBloscCodec(BloscCodec): register_codec("blosc", MockBloscCodec) assert get_codec_class("blosc") == MockBloscCodec - os.environ[("ZARR_PYTHON_CODECS__BLOSC__NAME")] = "BloscCodec" + os.environ["ZARR_PYTHON_CODECS__BLOSC__NAME"] = "BloscCodec" assert get_codec_class("blosc", reload_config=True) == BloscCodec From 9e94e368fe4c7431d1c38ef43a99fd92ae00d100 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Tue, 18 Jun 2024 14:51:29 +0200 Subject: [PATCH 05/36] use registry for codec pipeline config --- src/zarr/codecs/pipeline.py | 5 ++- src/zarr/codecs/registry.py | 64 +++++++++++++++++++----------- src/zarr/codecs/sharding.py | 7 ++-- src/zarr/config.py | 15 ------- src/zarr/metadata.py | 10 ++--- tests/v3/test_codec_entrypoints.py | 9 ++++- tests/v3/test_config.py | 37 ++++++++++------- 7 files changed, 82 insertions(+), 65 deletions(-) diff --git a/src/zarr/codecs/pipeline.py b/src/zarr/codecs/pipeline.py index acef311a8c..0eb2a8455f 100644 --- a/src/zarr/codecs/pipeline.py +++ b/src/zarr/codecs/pipeline.py @@ -17,7 +17,7 @@ ) from zarr.abc.store import ByteGetter, ByteSetter from zarr.buffer import Buffer, BufferPrototype, NDBuffer -from zarr.codecs.registry import get_codec_class +from zarr.codecs.registry import get_codec_class, register_pipeline from zarr.common import JSON, concurrent_map, parse_named_configuration from zarr.config import config from zarr.indexing import SelectorTuple, is_scalar, is_total_slice @@ -509,3 +509,6 @@ async def write( self.write_batch, config.get("async.concurrency"), ) + + +register_pipeline(BatchedCodecPipeline) diff --git a/src/zarr/codecs/registry.py b/src/zarr/codecs/registry.py index 1a5eff34d0..f44b35f6fd 100644 --- a/src/zarr/codecs/registry.py +++ b/src/zarr/codecs/registry.py @@ -1,19 +1,19 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING if TYPE_CHECKING: - from zarr.abc.codec import Codec + from zarr.abc.codec import Codec, CodecPipeline from importlib.metadata import EntryPoint from importlib.metadata import entry_points as get_entry_points -from zarr.config import config +from zarr.config import BadConfigError, config -__codec_registry: dict[str, type[Codec]] = {} +__codec_registry: dict[str, dict[str, type[Codec]]] = {} __lazy_load_codecs: dict[str, EntryPoint] = {} - -all_codecs = {} +__pipeline_registry: dict[str, type[CodecPipeline]] = {} def _collect_entrypoints() -> dict[str, EntryPoint]: @@ -25,33 +25,51 @@ def _collect_entrypoints() -> dict[str, EntryPoint]: def _reload_config() -> None: config.refresh() - for codec_cls, key in all_codecs.items(): - register_codec(key, codec_cls) def register_codec(key: str, codec_cls: type[Codec]) -> None: - all_codecs[codec_cls] = key + registered_codecs = __codec_registry.get(key, {}) + registered_codecs[codec_cls.__name__] = codec_cls + __codec_registry[key] = registered_codecs + + +def register_pipeline(pipe_cls: type[CodecPipeline]) -> None: + __pipeline_registry[pipe_cls.__name__] = pipe_cls - selected_codec = config.get("codecs", {}).get(key) - if selected_codec is None: - raise ValueError(f"Codec '{key}' not found in config.") - name = selected_codec.get("name") - if codec_cls.__name__ == name: - __codec_registry[key] = codec_cls + +def get_pipeline_class(reload_config=False) -> type[CodecPipeline]: + if reload_config: + _reload_config() + name = config.get("codec_pipeline.name") + pipeline_class = __pipeline_registry.get(name) + if pipeline_class: + return pipeline_class + raise BadConfigError( + f"Pipeline class '{name}' not found in registered pipelines: {list(__pipeline_registry.keys())}." + ) def get_codec_class(key: str, reload_config=False) -> type[Codec]: if reload_config: _reload_config() - item = __codec_registry.get(key) - if item is None: - if key in __lazy_load_codecs: - # logger.debug("Auto loading codec '%s' from entrypoint", codec_id) - cls = __lazy_load_codecs[key].load() - register_codec(key, cls) - item = __codec_registry.get(key) - if item: - return item + + if key in __lazy_load_codecs: + # logger.debug("Auto loading codec '%s' from entrypoint", codec_id) + cls = __lazy_load_codecs[key].load() + register_codec(key, cls) + + codec_classes = __codec_registry.get(key) + + config_entry = config.get("codecs", {}).get(key) + if config_entry is None: + warnings.warn(f"Codec '{key}' not configured in config. Selecting any implementation.") + return codec_classes.values()[-1] + + name = config_entry.get("name") + selected_codec_cls = codec_classes[name] + + if selected_codec_cls: + return selected_codec_cls raise KeyError(key) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 2f309a4eea..c9e767bc52 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -23,7 +23,7 @@ from zarr.chunk_grids import RegularChunkGrid from zarr.codecs.bytes import BytesCodec from zarr.codecs.crc32c_ import Crc32cCodec -from zarr.codecs.registry import register_codec +from zarr.codecs.registry import get_pipeline_class, register_codec from zarr.common import ( ChunkCoords, ChunkCoordsLike, @@ -32,7 +32,6 @@ parse_shapelike, product, ) -from zarr.config import config from zarr.indexing import BasicIndexer, SelectorTuple, c_order_iter, get_indexer, morton_order_iter from zarr.metadata import ArrayMetadata, parse_codecs @@ -314,12 +313,12 @@ def __init__( codecs_parsed = ( parse_codecs(codecs) if codecs is not None - else config.codec_pipeline_class.from_list([BytesCodec()]) + else get_pipeline_class().from_list([BytesCodec()]) ) index_codecs_parsed = ( parse_codecs(index_codecs) if index_codecs is not None - else config.codec_pipeline_class.from_list([BytesCodec(), Crc32cCodec()]) + else get_pipeline_class().from_list([BytesCodec(), Crc32cCodec()]) ) index_location_parsed = ( parse_index_location(index_location) diff --git a/src/zarr/config.py b/src/zarr/config.py index e305efa046..b9e6bb89a6 100644 --- a/src/zarr/config.py +++ b/src/zarr/config.py @@ -26,21 +26,6 @@ class Config(DConfig): # type: ignore[misc] """ - @property - def codec_pipeline_class(self) -> type[CodecPipeline]: - from zarr.abc.codec import CodecPipeline - - name = self.get("codec_pipeline.name") - selected_pipelines = [p for p in CodecPipeline.__subclasses__() if p.__name__ == name] - - if not selected_pipelines: - raise BadConfigError(f'No subclass of CodecPipeline with name "{name}" found.') - if len(selected_pipelines) > 1: - raise BadConfigError( - f'Multiple subclasses of CodecPipeline with name "{name}" found: {selected_pipelines}.' - ) - return selected_pipelines[0] - config = Config( "zarr_python", diff --git a/src/zarr/metadata.py b/src/zarr/metadata.py index ec3a89b360..a39797a106 100644 --- a/src/zarr/metadata.py +++ b/src/zarr/metadata.py @@ -16,6 +16,7 @@ from zarr.chunk_grids import ChunkGrid, RegularChunkGrid from zarr.chunk_key_encodings import ChunkKeyEncoding, parse_separator from zarr.codecs._v2 import V2Compressor, V2Filters +from zarr.codecs.registry import get_pipeline_class if TYPE_CHECKING: from typing_extensions import Self @@ -39,7 +40,6 @@ # For type checking _bool = bool - __all__ = ["ArrayMetadata"] @@ -373,9 +373,7 @@ def chunks(self) -> ChunkCoords: @property def codec_pipeline(self) -> CodecPipeline: - from zarr.config import config - - return config.codec_pipeline_class.from_list( + return get_pipeline_class().from_list( [V2Filters(self.filters or []), V2Compressor(self.compressor)] ) @@ -501,8 +499,6 @@ def parse_v2_metadata(data: ArrayV2Metadata) -> ArrayV2Metadata: def parse_codecs(data: Iterable[Codec | JSON]) -> CodecPipeline: - from zarr.config import config - if not isinstance(data, Iterable): raise TypeError(f"Expected iterable, got {type(data)}") - return config.codec_pipeline_class.from_dict(data) + return get_pipeline_class().from_dict(data) diff --git a/tests/v3/test_codec_entrypoints.py b/tests/v3/test_codec_entrypoints.py index 6b5c221f4d..3b8f0ded92 100644 --- a/tests/v3/test_codec_entrypoints.py +++ b/tests/v3/test_codec_entrypoints.py @@ -1,9 +1,11 @@ import os.path import sys +import warnings import pytest import zarr.codecs.registry +from zarr import config here = os.path.abspath(os.path.dirname(__file__)) @@ -20,5 +22,10 @@ def set_path(): @pytest.mark.usefixtures("set_path") def test_entrypoint_codec(): + with pytest.raises(UserWarning): + cls = zarr.codecs.registry.get_codec_class("test") + assert cls.__name__ == "TestCodec" + + config.set({"codecs.test.name": "TestCodec"}) cls = zarr.codecs.registry.get_codec_class("test") - assert cls.__name__ == "TestCodec" + assert cls.__name__ == "TestCodec" \ No newline at end of file diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index 4c6b5e5410..becb1ab590 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -1,10 +1,16 @@ import os +from unittest import mock import pytest from zarr.abc.codec import CodecPipeline from zarr.codecs import BatchedCodecPipeline, BloscCodec -from zarr.codecs.registry import get_codec_class, register_codec +from zarr.codecs.registry import ( + get_codec_class, + get_pipeline_class, + register_codec, + register_pipeline, +) from zarr.config import BadConfigError, config @@ -36,41 +42,44 @@ def test_config_defaults_can_be_overridden(): assert config.get("array.order") == "F" -def test_config_codec_pipeline_class(): +def test_config_codec_pipeline_class(reset_config): # has default value - assert config.codec_pipeline_class.__name__ != "" + assert get_pipeline_class().__name__ != "" config.set({"codec_pipeline.name": "BatchedCodecPipeline"}) - assert config.codec_pipeline_class == BatchedCodecPipeline + assert get_pipeline_class() == BatchedCodecPipeline class MockCodecPipeline(CodecPipeline): pass + register_pipeline(MockCodecPipeline) + config.set({"codec_pipeline.name": "MockCodecPipeline"}) - assert config.codec_pipeline_class == MockCodecPipeline + assert get_pipeline_class() == MockCodecPipeline with pytest.raises(BadConfigError): config.set({"codec_pipeline.name": "wrong_name"}) - config.codec_pipeline_class + get_pipeline_class() class MockEnvCodecPipeline(CodecPipeline): pass - os.environ["ZARR_PYTHON_CODEC_PIPELINE__NAME"] = "MockEnvCodecPipeline" - config.refresh() - assert config.codec_pipeline_class == MockEnvCodecPipeline + register_pipeline(MockEnvCodecPipeline) + + with mock.patch.dict(os.environ, {"ZARR_PYTHON_CODEC_PIPELINE__NAME": "MockEnvCodecPipeline"}): + assert get_pipeline_class(reload_config=True) == MockEnvCodecPipeline -def test_config_codec_implementation(): +def test_config_codec_implementation(reset_config): assert get_codec_class("blosc").__name__ == config.defaults[0]["codecs"]["blosc"]["name"] class MockBloscCodec(BloscCodec): pass - config.set({"codecs.blosc.name": "MockBloscCodec"}) register_codec("blosc", MockBloscCodec) - assert get_codec_class("blosc") == MockBloscCodec - os.environ["ZARR_PYTHON_CODECS__BLOSC__NAME"] = "BloscCodec" - assert get_codec_class("blosc", reload_config=True) == BloscCodec + config.set({"codecs.blosc.name": "MockBloscCodec"}) + assert get_codec_class("blosc") == MockBloscCodec + with mock.patch.dict(os.environ, {"ZARR_PYTHON_CODECS__BLOSC__NAME": "BloscCodec"}): + assert get_codec_class("blosc", reload_config=True) == BloscCodec From 11f184df8b7ee48ba4be617c5d4a1ad0defd56fa Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Tue, 18 Jun 2024 15:04:19 +0200 Subject: [PATCH 06/36] typing --- src/zarr/abc/codec.py | 3 ++- src/zarr/codecs/registry.py | 10 ++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index 8e28fa1201..4dc3e3bb4d 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -8,6 +8,7 @@ from zarr.abc.store import ByteGetter, ByteSetter from zarr.buffer import Buffer, NDBuffer from zarr.common import concurrent_map +from zarr.config import Config if TYPE_CHECKING: from typing_extensions import Self @@ -21,7 +22,7 @@ CodecOutput = TypeVar("CodecOutput", bound=NDBuffer | Buffer) -def get_config(): +def get_config() -> Config: from zarr.config import config return config diff --git a/src/zarr/codecs/registry.py b/src/zarr/codecs/registry.py index f44b35f6fd..06b826994b 100644 --- a/src/zarr/codecs/registry.py +++ b/src/zarr/codecs/registry.py @@ -37,7 +37,7 @@ def register_pipeline(pipe_cls: type[CodecPipeline]) -> None: __pipeline_registry[pipe_cls.__name__] = pipe_cls -def get_pipeline_class(reload_config=False) -> type[CodecPipeline]: +def get_pipeline_class(reload_config: bool = False) -> type[CodecPipeline]: if reload_config: _reload_config() name = config.get("codec_pipeline.name") @@ -49,7 +49,7 @@ def get_pipeline_class(reload_config=False) -> type[CodecPipeline]: ) -def get_codec_class(key: str, reload_config=False) -> type[Codec]: +def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: if reload_config: _reload_config() @@ -58,12 +58,14 @@ def get_codec_class(key: str, reload_config=False) -> type[Codec]: cls = __lazy_load_codecs[key].load() register_codec(key, cls) - codec_classes = __codec_registry.get(key) + codec_classes = __codec_registry[key] + if not codec_classes: + raise KeyError(key) config_entry = config.get("codecs", {}).get(key) if config_entry is None: warnings.warn(f"Codec '{key}' not configured in config. Selecting any implementation.") - return codec_classes.values()[-1] + return list(codec_classes.values())[-1] name = config_entry.get("name") selected_codec_cls = codec_classes[name] From 216a5d441d04e319f42518e4eb8439f9f4a9e275 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Tue, 18 Jun 2024 16:10:43 +0200 Subject: [PATCH 07/36] load codec pipeline from entrypoints --- src/zarr/abc/codec.py | 1 + src/zarr/codecs/registry.py | 16 ++++++++++++--- src/zarr/config.py | 9 +++++---- .../entry_points.txt | 2 ++ tests/v3/package_with_entrypoint/__init__.py | 13 +++++++++++- tests/v3/test_codec_entrypoints.py | 20 +++++++++++-------- tests/v3/test_config.py | 4 ++-- 7 files changed, 47 insertions(+), 18 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index 4dc3e3bb4d..8bc75010ea 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -27,6 +27,7 @@ def get_config() -> Config: return config + class _Codec(Generic[CodecInput, CodecOutput], Metadata): """Generic base class for codecs. Please use ArrayArrayCodec, ArrayBytesCodec or BytesBytesCodec for subclassing. diff --git a/src/zarr/codecs/registry.py b/src/zarr/codecs/registry.py index 06b826994b..7c123d63c4 100644 --- a/src/zarr/codecs/registry.py +++ b/src/zarr/codecs/registry.py @@ -14,13 +14,17 @@ __codec_registry: dict[str, dict[str, type[Codec]]] = {} __lazy_load_codecs: dict[str, EntryPoint] = {} __pipeline_registry: dict[str, type[CodecPipeline]] = {} +__lazy_load_pipelines: list[EntryPoint] = [] -def _collect_entrypoints() -> dict[str, EntryPoint]: +def _collect_entrypoints() -> tuple[dict[str, EntryPoint], list[EntryPoint]]: entry_points = get_entry_points() for e in entry_points.select(group="zarr.codecs"): __lazy_load_codecs[e.name] = e - return __lazy_load_codecs + for e in entry_points.select(group="zarr"): + if e.name == "codec_pipeline": + __lazy_load_pipelines.append(e) + return __lazy_load_codecs, __lazy_load_pipelines def _reload_config() -> None: @@ -40,6 +44,10 @@ def register_pipeline(pipe_cls: type[CodecPipeline]) -> None: def get_pipeline_class(reload_config: bool = False) -> type[CodecPipeline]: if reload_config: _reload_config() + for e in __lazy_load_pipelines: + __lazy_load_pipelines.remove(e) + register_pipeline(e.load()) + name = config.get("codec_pipeline.name") pipeline_class = __pipeline_registry.get(name) if pipeline_class: @@ -64,7 +72,9 @@ def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: config_entry = config.get("codecs", {}).get(key) if config_entry is None: - warnings.warn(f"Codec '{key}' not configured in config. Selecting any implementation.") + warnings.warn( + f"Codec '{key}' not configured in config. Selecting any implementation.", stacklevel=2 + ) return list(codec_classes.values())[-1] name = config_entry.get("name") diff --git a/src/zarr/config.py b/src/zarr/config.py index b9e6bb89a6..2209b5efc3 100644 --- a/src/zarr/config.py +++ b/src/zarr/config.py @@ -1,12 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import Any, Literal, cast from donfig import Config as DConfig -if TYPE_CHECKING: - from zarr.abc.codec import CodecPipeline - class BadConfigError(ValueError): _msg = "bad Config: %r" @@ -26,6 +23,10 @@ class Config(DConfig): # type: ignore[misc] """ + def reset(self) -> None: + self.clear() + self.refresh() + config = Config( "zarr_python", diff --git a/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt b/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt index 2c9dc375de..93eaebd0f4 100644 --- a/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt +++ b/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt @@ -1,2 +1,4 @@ [zarr.codecs] test = package_with_entrypoint:TestCodec +[zarr] +codec_pipeline = package_with_entrypoint:TestCodecPipeline diff --git a/tests/v3/package_with_entrypoint/__init__.py b/tests/v3/package_with_entrypoint/__init__.py index 6368e5b236..63e51a52db 100644 --- a/tests/v3/package_with_entrypoint/__init__.py +++ b/tests/v3/package_with_entrypoint/__init__.py @@ -1,6 +1,6 @@ from numpy import ndarray -from zarr.abc.codec import ArrayBytesCodec +from zarr.abc.codec import ArrayBytesCodec, CodecPipeline from zarr.array_spec import ArraySpec from zarr.common import BytesLike @@ -24,3 +24,14 @@ async def decode( def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int: return input_byte_length + + +class TestCodecPipeline(CodecPipeline): + def __init__(self, batch_size: int = 1): + pass + + async def encode(self, chunk_array: ndarray, chunk_spec: ArraySpec) -> BytesLike: + pass + + async def decode(self, chunk_bytes: BytesLike, chunk_spec: ArraySpec) -> ndarray: + pass diff --git a/tests/v3/test_codec_entrypoints.py b/tests/v3/test_codec_entrypoints.py index 3b8f0ded92..584d15e4c7 100644 --- a/tests/v3/test_codec_entrypoints.py +++ b/tests/v3/test_codec_entrypoints.py @@ -1,6 +1,5 @@ import os.path import sys -import warnings import pytest @@ -16,16 +15,21 @@ def set_path(): zarr.codecs.registry._collect_entrypoints() yield sys.path.remove(here) - entry_points = zarr.codecs.registry._collect_entrypoints() - entry_points.pop("test") + lazy_load_codecs, lazy_load_pipelines = zarr.codecs.registry._collect_entrypoints() + lazy_load_codecs.pop("test") + lazy_load_pipelines.clear() + config.reset() @pytest.mark.usefixtures("set_path") def test_entrypoint_codec(): - with pytest.raises(UserWarning): - cls = zarr.codecs.registry.get_codec_class("test") - assert cls.__name__ == "TestCodec" - config.set({"codecs.test.name": "TestCodec"}) cls = zarr.codecs.registry.get_codec_class("test") - assert cls.__name__ == "TestCodec" \ No newline at end of file + assert cls.__name__ == "TestCodec" + + +@pytest.mark.usefixtures("set_path") +def test_entrypoint_pipeline(): + config.set({"codec_pipeline.name": "TestCodecPipeline"}) + cls = zarr.codecs.registry.get_pipeline_class() + assert cls.__name__ == "TestCodecPipeline" diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index becb1ab590..b67716e10a 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -42,7 +42,7 @@ def test_config_defaults_can_be_overridden(): assert config.get("array.order") == "F" -def test_config_codec_pipeline_class(reset_config): +def test_config_codec_pipeline_class(): # has default value assert get_pipeline_class().__name__ != "" @@ -70,7 +70,7 @@ class MockEnvCodecPipeline(CodecPipeline): assert get_pipeline_class(reload_config=True) == MockEnvCodecPipeline -def test_config_codec_implementation(reset_config): +def test_config_codec_implementation(): assert get_codec_class("blosc").__name__ == config.defaults[0]["codecs"]["blosc"]["name"] class MockBloscCodec(BloscCodec): From 2a3b7ea9b083792a9b1ef917e91f703d5bdea354 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Tue, 18 Jun 2024 17:18:45 +0200 Subject: [PATCH 08/36] test if configured codec implementation and codec pipeline is used --- tests/v3/conftest.py | 8 +++++ tests/v3/test_config.py | 72 +++++++++++++++++++++++++++++++++++------ 2 files changed, 70 insertions(+), 10 deletions(-) diff --git a/tests/v3/conftest.py b/tests/v3/conftest.py index 6b58cce412..22753f0a7d 100644 --- a/tests/v3/conftest.py +++ b/tests/v3/conftest.py @@ -4,6 +4,7 @@ from types import ModuleType from typing import TYPE_CHECKING +from zarr import config from zarr.common import ZarrFormat from zarr.group import AsyncGroup @@ -90,3 +91,10 @@ def xp(request: pytest.FixtureRequest) -> Iterator[ModuleType]: """Fixture to parametrize over numpy-like libraries""" yield pytest.importorskip(request.param) + + +@pytest.fixture(autouse=True) +def reset_config(): + config.reset() + yield + config.reset() diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index b67716e10a..e0d1bee1c8 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -1,10 +1,16 @@ import os +from collections.abc import Iterable from unittest import mock +from unittest.mock import Mock import pytest -from zarr.abc.codec import CodecPipeline -from zarr.codecs import BatchedCodecPipeline, BloscCodec +from zarr import Array +from zarr.abc.codec import CodecInput, CodecOutput, CodecPipeline +from zarr.abc.store import ByteSetter +from zarr.array_spec import ArraySpec +from zarr.buffer import NDBuffer +from zarr.codecs import BatchedCodecPipeline, BloscCodec, BytesCodec from zarr.codecs.registry import ( get_codec_class, get_pipeline_class, @@ -12,6 +18,14 @@ register_pipeline, ) from zarr.config import BadConfigError, config +from zarr.indexing import SelectorTuple + + +@pytest.fixture() +def reset_config(): + config.reset() + yield + config.reset() def test_config_defaults_set(): @@ -42,21 +56,41 @@ def test_config_defaults_can_be_overridden(): assert config.get("array.order") == "F" -def test_config_codec_pipeline_class(): +@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) +def test_config_codec_pipeline_class(store): # has default value assert get_pipeline_class().__name__ != "" config.set({"codec_pipeline.name": "BatchedCodecPipeline"}) assert get_pipeline_class() == BatchedCodecPipeline - class MockCodecPipeline(CodecPipeline): - pass + _mock = Mock() - register_pipeline(MockCodecPipeline) + class MockCodecPipeline(BatchedCodecPipeline): + async def write( + self, + batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]], + value: NDBuffer, + drop_axes: tuple[int, ...] = (), + ) -> None: + _mock.call() + register_pipeline(MockCodecPipeline) config.set({"codec_pipeline.name": "MockCodecPipeline"}) assert get_pipeline_class() == MockCodecPipeline + # test if codec is used + arr = Array.create( + store=store, + shape=(100,), + chunks=(10,), + zarr_format=3, + dtype="i4", + ) + arr[:] = range(100) + + _mock.call.assert_called() + with pytest.raises(BadConfigError): config.set({"codec_pipeline.name": "wrong_name"}) get_pipeline_class() @@ -70,16 +104,34 @@ class MockEnvCodecPipeline(CodecPipeline): assert get_pipeline_class(reload_config=True) == MockEnvCodecPipeline -def test_config_codec_implementation(): +@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) +def test_config_codec_implementation(store): + # has default value assert get_codec_class("blosc").__name__ == config.defaults[0]["codecs"]["blosc"]["name"] - class MockBloscCodec(BloscCodec): - pass + _mock = Mock() - register_codec("blosc", MockBloscCodec) + class MockBloscCodec(BloscCodec): + async def _encode_single( + self, chunk_data: CodecInput, chunk_spec: ArraySpec + ) -> CodecOutput | None: + _mock.call() config.set({"codecs.blosc.name": "MockBloscCodec"}) + register_codec("blosc", MockBloscCodec) assert get_codec_class("blosc") == MockBloscCodec + # test if codec is used + arr = Array.create( + store=store, + shape=(100,), + chunks=(10,), + zarr_format=3, + dtype="i4", + codecs=[BytesCodec(), {"name": "blosc", "configuration": {}}], + ) + arr[:] = range(100) + _mock.call.assert_called() + with mock.patch.dict(os.environ, {"ZARR_PYTHON_CODECS__BLOSC__NAME": "BloscCodec"}): assert get_codec_class("blosc", reload_config=True) == BloscCodec From 02d1f6ef813edc94cac9e806bc54b3305736eb0a Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Thu, 20 Jun 2024 14:14:41 +0200 Subject: [PATCH 09/36] make ndbuffer implementation configurable --- src/zarr/array.py | 51 ++++++++++++++++----- src/zarr/buffer.py | 15 ++++++- src/zarr/codecs/registry.py | 71 ++++++++++++++++++++++++------ src/zarr/codecs/sharding.py | 8 ++-- src/zarr/config.py | 2 + src/zarr/metadata.py | 2 +- src/zarr/store/core.py | 4 +- src/zarr/store/remote.py | 5 ++- src/zarr/testing/buffer.py | 62 ++++++++++++++++++++++++++ src/zarr/testing/store.py | 8 ++-- tests/v3/test_buffer.py | 60 +------------------------ tests/v3/test_config.py | 29 ++++++++++++ tests/v3/test_store/test_remote.py | 2 +- 13 files changed, 222 insertions(+), 97 deletions(-) create mode 100644 src/zarr/testing/buffer.py diff --git a/src/zarr/array.py b/src/zarr/array.py index 9ac1ce41ec..2d26b4eaf5 100644 --- a/src/zarr/array.py +++ b/src/zarr/array.py @@ -460,8 +460,13 @@ async def _get_selection( return out_buffer.as_ndarray_like() async def getitem( - self, selection: Selection, *, prototype: BufferPrototype = default_buffer_prototype + self, + selection: Selection, + *, + prototype: BufferPrototype | None = None, ) -> NDArrayLike: + if prototype is None: + prototype = default_buffer_prototype() indexer = BasicIndexer( selection, shape=self.metadata.shape, @@ -522,8 +527,10 @@ async def setitem( self, selection: Selection, value: NDArrayLike, - prototype: BufferPrototype = default_buffer_prototype, + prototype: BufferPrototype | None = None, ) -> None: + if prototype is None: + prototype = default_buffer_prototype() indexer = BasicIndexer( selection, shape=self.metadata.shape, @@ -724,9 +731,11 @@ def get_basic_selection( selection: BasicSelection = Ellipsis, *, out: NDBuffer | None = None, - prototype: BufferPrototype = default_buffer_prototype, + prototype: BufferPrototype | None = None, fields: Fields | None = None, ) -> NDArrayLike: + if prototype is None: + prototype = default_buffer_prototype() if self.shape == (): raise NotImplementedError else: @@ -745,8 +754,10 @@ def set_basic_selection( value: NDArrayLike, *, fields: Fields | None = None, - prototype: BufferPrototype = default_buffer_prototype, + prototype: BufferPrototype | None = None, ) -> None: + if prototype is None: + prototype = default_buffer_prototype() indexer = BasicIndexer(selection, self.shape, self.metadata.chunk_grid) sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) @@ -756,8 +767,10 @@ def get_orthogonal_selection( *, out: NDBuffer | None = None, fields: Fields | None = None, - prototype: BufferPrototype = default_buffer_prototype, + prototype: BufferPrototype | None = None, ) -> NDArrayLike: + if prototype is None: + prototype = default_buffer_prototype() indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid) return sync( self._async_array._get_selection( @@ -771,8 +784,10 @@ def set_orthogonal_selection( value: NDArrayLike, *, fields: Fields | None = None, - prototype: BufferPrototype = default_buffer_prototype, + prototype: BufferPrototype | None = None, ) -> None: + if prototype is None: + prototype = default_buffer_prototype() indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid) return sync( self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype) @@ -784,8 +799,10 @@ def get_mask_selection( *, out: NDBuffer | None = None, fields: Fields | None = None, - prototype: BufferPrototype = default_buffer_prototype, + prototype: BufferPrototype | None = None, ) -> NDArrayLike: + if prototype is None: + prototype = default_buffer_prototype() indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid) return sync( self._async_array._get_selection( @@ -799,8 +816,10 @@ def set_mask_selection( value: NDArrayLike, *, fields: Fields | None = None, - prototype: BufferPrototype = default_buffer_prototype, + prototype: BufferPrototype | None = None, ) -> None: + if prototype is None: + prototype = default_buffer_prototype() indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid) sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) @@ -810,8 +829,10 @@ def get_coordinate_selection( *, out: NDBuffer | None = None, fields: Fields | None = None, - prototype: BufferPrototype = default_buffer_prototype, + prototype: BufferPrototype | None = None, ) -> NDArrayLike: + if prototype is None: + prototype = default_buffer_prototype() indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid) out_array = sync( self._async_array._get_selection( @@ -829,8 +850,10 @@ def set_coordinate_selection( value: NDArrayLike, *, fields: Fields | None = None, - prototype: BufferPrototype = default_buffer_prototype, + prototype: BufferPrototype | None = None, ) -> None: + if prototype is None: + prototype = default_buffer_prototype() # setup indexer indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid) @@ -854,8 +877,10 @@ def get_block_selection( *, out: NDBuffer | None = None, fields: Fields | None = None, - prototype: BufferPrototype = default_buffer_prototype, + prototype: BufferPrototype | None = None, ) -> NDArrayLike: + if prototype is None: + prototype = default_buffer_prototype() indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid) return sync( self._async_array._get_selection( @@ -869,8 +894,10 @@ def set_block_selection( value: NDArrayLike, *, fields: Fields | None = None, - prototype: BufferPrototype = default_buffer_prototype, + prototype: BufferPrototype | None = None, ) -> None: + if prototype is None: + prototype = default_buffer_prototype() indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid) sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) diff --git a/src/zarr/buffer.py b/src/zarr/buffer.py index 1a34d9f290..cc79d45f65 100644 --- a/src/zarr/buffer.py +++ b/src/zarr/buffer.py @@ -15,6 +15,12 @@ import numpy as np import numpy.typing as npt +from zarr.codecs.registry import ( + get_buffer_class, + get_ndbuffer_class, + register_buffer, + register_ndbuffer, +) from zarr.common import ChunkCoords if TYPE_CHECKING: @@ -326,7 +332,7 @@ def from_numpy_array(cls, array_like: npt.ArrayLike) -> Self: ------- New buffer representing `array_like` """ - return cls.from_ndarray_like(np.asanyarray(array_like)) + return cls.from_ndarray_like(np.asanyarray(array_like)) # TODO def as_ndarray_like(self) -> NDArrayLike: """Returns the underlying array (host or device memory) of this buffer @@ -454,4 +460,9 @@ class BufferPrototype(NamedTuple): # The default buffer prototype used throughout the Zarr codebase. -default_buffer_prototype = BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer) +def default_buffer_prototype() -> BufferPrototype: + return BufferPrototype(buffer=get_buffer_class(), nd_buffer=get_ndbuffer_class()) + + +register_buffer(Buffer) +register_ndbuffer(NDBuffer) diff --git a/src/zarr/codecs/registry.py b/src/zarr/codecs/registry.py index 7c123d63c4..1722a4c5e2 100644 --- a/src/zarr/codecs/registry.py +++ b/src/zarr/codecs/registry.py @@ -5,6 +5,7 @@ if TYPE_CHECKING: from zarr.abc.codec import Codec, CodecPipeline + from zarr.buffer import Buffer, NDBuffer from importlib.metadata import EntryPoint from importlib.metadata import entry_points as get_entry_points @@ -15,6 +16,10 @@ __lazy_load_codecs: dict[str, EntryPoint] = {} __pipeline_registry: dict[str, type[CodecPipeline]] = {} __lazy_load_pipelines: list[EntryPoint] = [] +__buffer_registry: dict[str, type[Buffer]] = {} +__lazy_load_buffer: list[EntryPoint] = [] +__ndbuffer_registry: dict[str, type[NDBuffer]] = {} +__lazy_load_ndbuffer: list[EntryPoint] = [] def _collect_entrypoints() -> tuple[dict[str, EntryPoint], list[EntryPoint]]: @@ -41,20 +46,12 @@ def register_pipeline(pipe_cls: type[CodecPipeline]) -> None: __pipeline_registry[pipe_cls.__name__] = pipe_cls -def get_pipeline_class(reload_config: bool = False) -> type[CodecPipeline]: - if reload_config: - _reload_config() - for e in __lazy_load_pipelines: - __lazy_load_pipelines.remove(e) - register_pipeline(e.load()) +def register_ndbuffer(cls: type[NDBuffer]) -> None: + __ndbuffer_registry[cls.__name__] = cls - name = config.get("codec_pipeline.name") - pipeline_class = __pipeline_registry.get(name) - if pipeline_class: - return pipeline_class - raise BadConfigError( - f"Pipeline class '{name}' not found in registered pipelines: {list(__pipeline_registry.keys())}." - ) + +def register_buffer(cls: type[Buffer]) -> None: + __buffer_registry[cls.__name__] = cls def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: @@ -85,4 +82,52 @@ def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: raise KeyError(key) +def get_pipeline_class(reload_config: bool = False) -> type[CodecPipeline]: + if reload_config: + _reload_config() + for e in __lazy_load_pipelines: + __lazy_load_pipelines.remove(e) + register_pipeline(e.load()) + + name = config.get("codec_pipeline.name") + pipeline_class = __pipeline_registry.get(name) + if pipeline_class: + return pipeline_class + raise BadConfigError( + f"Pipeline class '{name}' not found in registered pipelines: {list(__pipeline_registry.keys())}." + ) + + +def get_buffer_class(reload_config: bool = False) -> type[Buffer]: + if reload_config: + _reload_config() + for e in __lazy_load_buffer: + __lazy_load_buffer.remove(e) + register_buffer(e.load()) + + name = config.get("buffer.name") + buffer_class = __buffer_registry.get(name) + if buffer_class: + return buffer_class + raise BadConfigError( + f"Buffer class '{name}' not found in registered buffers: {list(__buffer_registry.keys())}." + ) + + +def get_ndbuffer_class(reload_config: bool = False) -> type[NDBuffer]: + if reload_config: + _reload_config() + for e in __lazy_load_ndbuffer: + __lazy_load_ndbuffer.remove(e) + register_ndbuffer(e.load()) + + name = config.get("ndbuffer.name") + ndbuffer_class = __ndbuffer_registry.get(name) + if ndbuffer_class: + return ndbuffer_class + raise BadConfigError( + f"NDBuffer class '{name}' not found in registered buffers: {list(__ndbuffer_registry.keys())}." + ) + + _collect_entrypoints() diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index c9e767bc52..78a18c98d2 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -66,7 +66,7 @@ async def get( ) -> Buffer | None: assert byte_range is None, "byte_range is not supported within shards" assert ( - prototype is default_buffer_prototype + prototype is default_buffer_prototype() ), "prototype is not supported within shards currently" return self.shard_dict.get(self.chunk_coords) @@ -621,7 +621,7 @@ def _get_index_chunk_spec(self, chunks_per_shard: ChunkCoords) -> ArraySpec: dtype=np.dtype(" ArraySpec: @@ -649,11 +649,11 @@ async def _load_shard_index_maybe( shard_index_size = self._shard_index_size(chunks_per_shard) if self.index_location == ShardingCodecIndexLocation.start: index_bytes = await byte_getter.get( - prototype=default_buffer_prototype, byte_range=(0, shard_index_size) + prototype=default_buffer_prototype(), byte_range=(0, shard_index_size) ) else: index_bytes = await byte_getter.get( - prototype=default_buffer_prototype, byte_range=(-shard_index_size, None) + prototype=default_buffer_prototype(), byte_range=(-shard_index_size, None) ) if index_bytes is not None: return await self._decode_shard_index(index_bytes, chunks_per_shard) diff --git a/src/zarr/config.py b/src/zarr/config.py index 2209b5efc3..4ab38c7280 100644 --- a/src/zarr/config.py +++ b/src/zarr/config.py @@ -45,6 +45,8 @@ def reset(self) -> None: "sharding_indexed": {"name": "ShardingCodec"}, "transpose": {"name": "TransposeCodec"}, }, + "buffer": {"name": "Buffer"}, + "ndbuffer": {"name": "NDBuffer"}, } ], ) diff --git a/src/zarr/metadata.py b/src/zarr/metadata.py index a39797a106..8cef88a558 100644 --- a/src/zarr/metadata.py +++ b/src/zarr/metadata.py @@ -200,7 +200,7 @@ def __init__( dtype=data_type_parsed, fill_value=fill_value_parsed, order="C", # TODO: order is not needed here. - prototype=default_buffer_prototype, # TODO: prototype is not needed here. + prototype=default_buffer_prototype(), # TODO: prototype is not needed here. ) codecs_parsed = parse_codecs(codecs).evolve_from_array_spec(array_spec) diff --git a/src/zarr/store/core.py b/src/zarr/store/core.py index 512c8383eb..8afc87424d 100644 --- a/src/zarr/store/core.py +++ b/src/zarr/store/core.py @@ -29,9 +29,11 @@ def __init__(self, store: Store, path: str | None = None): async def get( self, - prototype: BufferPrototype = default_buffer_prototype, + prototype: BufferPrototype | None = None, byte_range: tuple[int, int | None] | None = None, ) -> Buffer | None: + if prototype is None: + prototype = default_buffer_prototype() return await self.store.get(self.path, prototype=prototype, byte_range=byte_range) async def set(self, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index db826f456d..7cfbad5d2f 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -79,11 +79,12 @@ def __repr__(self) -> str: async def get( self, key: str, - prototype: BufferPrototype = default_buffer_prototype, + prototype: BufferPrototype | None = None, byte_range: tuple[int | None, int | None] | None = None, ) -> Buffer | None: path = _dereference_path(self.path, key) - + if prototype is None: + prototype = default_buffer_prototype() try: if byte_range: # fsspec uses start/end, not start/length diff --git a/src/zarr/testing/buffer.py b/src/zarr/testing/buffer.py new file mode 100644 index 0000000000..a56a0ced64 --- /dev/null +++ b/src/zarr/testing/buffer.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Literal + +import numpy as np +import numpy.typing as npt + +from zarr.buffer import Buffer, BufferPrototype, NDBuffer +from zarr.store import MemoryStore + +if TYPE_CHECKING: + from typing_extensions import Self + + +class MyNDArrayLike(np.ndarray): + """An example of a ndarray-like class""" + + +class MyBuffer(Buffer): + """Example of a custom Buffer that handles ArrayLike""" + + +class MyNDBuffer(NDBuffer): + """Example of a custom NDBuffer that handles MyNDArrayLike""" + + @classmethod + def create( + cls, + *, + shape: Iterable[int], + dtype: npt.DTypeLike, + order: Literal["C", "F"] = "C", + fill_value: Any | None = None, + ) -> Self: + """Overwrite `NDBuffer.create` to create an MyNDArrayLike instance""" + ret = cls(MyNDArrayLike(shape=shape, dtype=dtype, order=order)) + if fill_value is not None: + ret.fill(fill_value) + return ret + + +class MyStore(MemoryStore): + """Example of a custom Store that expect MyBuffer for all its non-metadata + + We assume that keys containing "json" is metadata + """ + + async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: + if "json" not in key: + assert isinstance(value, MyBuffer) + await super().set(key, value, byte_range) + + async def get( + self, + key: str, + prototype: BufferPrototype, + byte_range: tuple[int, int | None] | None = None, + ) -> Buffer | None: + if "json" not in key: + assert prototype.buffer is MyBuffer + return await super().get(key, byte_range) diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 9c37ce0434..69786838e2 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -91,7 +91,7 @@ async def test_get( """ data_buf = Buffer.from_bytes(data) self.set(store, key, data_buf) - observed = await store.get(key, prototype=default_buffer_prototype, byte_range=byte_range) + observed = await store.get(key, prototype=default_buffer_prototype(), byte_range=byte_range) start, length = _normalize_interval_index(data_buf, interval=byte_range) expected = data_buf[start : start + length] assert_bytes_equal(observed, expected) @@ -126,7 +126,7 @@ async def test_get_partial_values( # read back just part of it observed_maybe = await store.get_partial_values( - prototype=default_buffer_prototype, key_ranges=key_ranges + prototype=default_buffer_prototype(), key_ranges=key_ranges ) observed: list[Buffer] = [] @@ -138,7 +138,9 @@ async def test_get_partial_values( for idx in range(len(observed)): key, byte_range = key_ranges[idx] - result = await store.get(key, prototype=default_buffer_prototype, byte_range=byte_range) + result = await store.get( + key, prototype=default_buffer_prototype(), byte_range=byte_range + ) assert result is not None expected.append(result) diff --git a/tests/v3/test_buffer.py b/tests/v3/test_buffer.py index e814afef15..ebc2c332ec 100644 --- a/tests/v3/test_buffer.py +++ b/tests/v3/test_buffer.py @@ -1,14 +1,10 @@ from __future__ import annotations -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Literal - import numpy as np -import numpy.typing as npt import pytest from zarr.array import AsyncArray -from zarr.buffer import ArrayLike, Buffer, BufferPrototype, NDArrayLike, NDBuffer +from zarr.buffer import ArrayLike, BufferPrototype, NDArrayLike from zarr.codecs.blosc import BloscCodec from zarr.codecs.bytes import BytesCodec from zarr.codecs.crc32c_ import Crc32cCodec @@ -16,59 +12,7 @@ from zarr.codecs.transpose import TransposeCodec from zarr.codecs.zstd import ZstdCodec from zarr.store.core import StorePath -from zarr.store.memory import MemoryStore - -if TYPE_CHECKING: - from typing_extensions import Self - - -class MyNDArrayLike(np.ndarray): - """An example of a ndarray-like class""" - - -class MyBuffer(Buffer): - """Example of a custom Buffer that handles ArrayLike""" - - -class MyNDBuffer(NDBuffer): - """Example of a custom NDBuffer that handles MyNDArrayLike""" - - @classmethod - def create( - cls, - *, - shape: Iterable[int], - dtype: npt.DTypeLike, - order: Literal["C", "F"] = "C", - fill_value: Any | None = None, - ) -> Self: - """Overwrite `NDBuffer.create` to create an MyNDArrayLike instance""" - ret = cls(MyNDArrayLike(shape=shape, dtype=dtype, order=order)) - if fill_value is not None: - ret.fill(fill_value) - return ret - - -class MyStore(MemoryStore): - """Example of a custom Store that expect MyBuffer for all its non-metadata - - We assume that keys containing "json" is metadata - """ - - async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: - if "json" not in key: - assert isinstance(value, MyBuffer) - await super().set(key, value, byte_range) - - async def get( - self, - key: str, - prototype: BufferPrototype, - byte_range: tuple[int, int | None] | None = None, - ) -> Buffer | None: - if "json" not in key: - assert prototype.buffer is MyBuffer - return await super().get(key, byte_range) +from zarr.testing.buffer import MyBuffer, MyNDArrayLike, MyNDBuffer, MyStore def test_nd_array_like(xp): diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index e0d1bee1c8..9f19dffe32 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -13,12 +13,15 @@ from zarr.codecs import BatchedCodecPipeline, BloscCodec, BytesCodec from zarr.codecs.registry import ( get_codec_class, + get_ndbuffer_class, get_pipeline_class, register_codec, + register_ndbuffer, register_pipeline, ) from zarr.config import BadConfigError, config from zarr.indexing import SelectorTuple +from zarr.testing.buffer import MyNDArrayLike, MyNDBuffer @pytest.fixture() @@ -135,3 +138,29 @@ async def _encode_single( with mock.patch.dict(os.environ, {"ZARR_PYTHON_CODECS__BLOSC__NAME": "BloscCodec"}): assert get_codec_class("blosc", reload_config=True) == BloscCodec + + +@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) +def test_config_ndbuffer_implementation(store): + # has default value + assert get_ndbuffer_class().__name__ == config.defaults[0]["ndbuffer"]["name"] + + # set custom ndbuffer with MyNDArrayLike implementation + register_ndbuffer(MyNDBuffer) + config.set({"ndbuffer.name": "MyNDBuffer"}) + assert get_ndbuffer_class() == MyNDBuffer + arr = Array.create( + store=store, + shape=(100,), + chunks=(10,), + zarr_format=3, + dtype="i4", + ) + got = arr[:] + print(type(got)) + assert isinstance(got, MyNDArrayLike) + + +@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) +def test_config_buffer_implementation(store): + pass diff --git a/tests/v3/test_store/test_remote.py b/tests/v3/test_store/test_remote.py index 936cf206d9..161df214d9 100644 --- a/tests/v3/test_store/test_remote.py +++ b/tests/v3/test_store/test_remote.py @@ -73,7 +73,7 @@ async def test_basic(): assert await store.exists("foo") assert (await store.get("foo")).to_bytes() == data out = await store.get_partial_values( - prototype=default_buffer_prototype, key_ranges=[("foo", (1, None))] + prototype=default_buffer_prototype(), key_ranges=[("foo", (1, None))] ) assert out[0].to_bytes() == data[1:] From 6467149781bc269bb72ba45a602bc583ab46b1ab Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Thu, 20 Jun 2024 14:16:19 +0200 Subject: [PATCH 10/36] fix circular import --- src/zarr/buffer.py | 4 ++-- src/zarr/codecs/blosc.py | 2 +- src/zarr/codecs/bytes.py | 2 +- src/zarr/codecs/crc32c_.py | 2 +- src/zarr/codecs/gzip.py | 2 +- src/zarr/codecs/pipeline.py | 2 +- src/zarr/codecs/sharding.py | 2 +- src/zarr/codecs/transpose.py | 2 +- src/zarr/codecs/zstd.py | 2 +- src/zarr/metadata.py | 2 +- src/zarr/{codecs => }/registry.py | 0 tests/v3/test_codec_entrypoints.py | 10 +++++----- tests/v3/test_config.py | 6 +++--- 13 files changed, 19 insertions(+), 19 deletions(-) rename src/zarr/{codecs => }/registry.py (100%) diff --git a/src/zarr/buffer.py b/src/zarr/buffer.py index cc79d45f65..444381aad8 100644 --- a/src/zarr/buffer.py +++ b/src/zarr/buffer.py @@ -15,13 +15,13 @@ import numpy as np import numpy.typing as npt -from zarr.codecs.registry import ( +from zarr.common import ChunkCoords +from zarr.registry import ( get_buffer_class, get_ndbuffer_class, register_buffer, register_ndbuffer, ) -from zarr.common import ChunkCoords if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/zarr/codecs/blosc.py b/src/zarr/codecs/blosc.py index e577d18fb2..c061e7fea6 100644 --- a/src/zarr/codecs/blosc.py +++ b/src/zarr/codecs/blosc.py @@ -11,8 +11,8 @@ from zarr.abc.codec import BytesBytesCodec from zarr.array_spec import ArraySpec from zarr.buffer import Buffer, as_numpy_array_wrapper -from zarr.codecs.registry import register_codec from zarr.common import JSON, parse_enum, parse_named_configuration, to_thread +from zarr.registry import register_codec if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/zarr/codecs/bytes.py b/src/zarr/codecs/bytes.py index 0b9a5c089e..2be4a896f5 100644 --- a/src/zarr/codecs/bytes.py +++ b/src/zarr/codecs/bytes.py @@ -10,8 +10,8 @@ from zarr.abc.codec import ArrayBytesCodec from zarr.array_spec import ArraySpec from zarr.buffer import Buffer, NDArrayLike, NDBuffer -from zarr.codecs.registry import register_codec from zarr.common import JSON, parse_enum, parse_named_configuration +from zarr.registry import register_codec if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/zarr/codecs/crc32c_.py b/src/zarr/codecs/crc32c_.py index b670b25429..5c94558b00 100644 --- a/src/zarr/codecs/crc32c_.py +++ b/src/zarr/codecs/crc32c_.py @@ -9,8 +9,8 @@ from zarr.abc.codec import BytesBytesCodec from zarr.array_spec import ArraySpec from zarr.buffer import Buffer -from zarr.codecs.registry import register_codec from zarr.common import JSON, parse_named_configuration +from zarr.registry import register_codec if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/zarr/codecs/gzip.py b/src/zarr/codecs/gzip.py index 0ad97c1207..915ae79832 100644 --- a/src/zarr/codecs/gzip.py +++ b/src/zarr/codecs/gzip.py @@ -8,8 +8,8 @@ from zarr.abc.codec import BytesBytesCodec from zarr.array_spec import ArraySpec from zarr.buffer import Buffer, as_numpy_array_wrapper -from zarr.codecs.registry import register_codec from zarr.common import JSON, parse_named_configuration, to_thread +from zarr.registry import register_codec if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/zarr/codecs/pipeline.py b/src/zarr/codecs/pipeline.py index 0eb2a8455f..7a9fd47f84 100644 --- a/src/zarr/codecs/pipeline.py +++ b/src/zarr/codecs/pipeline.py @@ -17,11 +17,11 @@ ) from zarr.abc.store import ByteGetter, ByteSetter from zarr.buffer import Buffer, BufferPrototype, NDBuffer -from zarr.codecs.registry import get_codec_class, register_pipeline from zarr.common import JSON, concurrent_map, parse_named_configuration from zarr.config import config from zarr.indexing import SelectorTuple, is_scalar, is_total_slice from zarr.metadata import ArrayMetadata +from zarr.registry import get_codec_class, register_pipeline if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 78a18c98d2..de92917b37 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -23,7 +23,6 @@ from zarr.chunk_grids import RegularChunkGrid from zarr.codecs.bytes import BytesCodec from zarr.codecs.crc32c_ import Crc32cCodec -from zarr.codecs.registry import get_pipeline_class, register_codec from zarr.common import ( ChunkCoords, ChunkCoordsLike, @@ -34,6 +33,7 @@ ) from zarr.indexing import BasicIndexer, SelectorTuple, c_order_iter, get_indexer, morton_order_iter from zarr.metadata import ArrayMetadata, parse_codecs +from zarr.registry import get_pipeline_class, register_codec if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterator diff --git a/src/zarr/codecs/transpose.py b/src/zarr/codecs/transpose.py index 33dab21fb6..6953c2c129 100644 --- a/src/zarr/codecs/transpose.py +++ b/src/zarr/codecs/transpose.py @@ -9,8 +9,8 @@ from zarr.abc.codec import ArrayArrayCodec from zarr.array_spec import ArraySpec from zarr.buffer import NDBuffer -from zarr.codecs.registry import register_codec from zarr.common import JSON, ChunkCoordsLike, parse_named_configuration +from zarr.registry import register_codec if TYPE_CHECKING: from typing import TYPE_CHECKING diff --git a/src/zarr/codecs/zstd.py b/src/zarr/codecs/zstd.py index 4c5afba00b..b244ee703a 100644 --- a/src/zarr/codecs/zstd.py +++ b/src/zarr/codecs/zstd.py @@ -9,8 +9,8 @@ from zarr.abc.codec import BytesBytesCodec from zarr.array_spec import ArraySpec from zarr.buffer import Buffer, as_numpy_array_wrapper -from zarr.codecs.registry import register_codec from zarr.common import JSON, parse_named_configuration, to_thread +from zarr.registry import register_codec if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/zarr/metadata.py b/src/zarr/metadata.py index 8cef88a558..f62d024dff 100644 --- a/src/zarr/metadata.py +++ b/src/zarr/metadata.py @@ -16,7 +16,7 @@ from zarr.chunk_grids import ChunkGrid, RegularChunkGrid from zarr.chunk_key_encodings import ChunkKeyEncoding, parse_separator from zarr.codecs._v2 import V2Compressor, V2Filters -from zarr.codecs.registry import get_pipeline_class +from zarr.registry import get_pipeline_class if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/zarr/codecs/registry.py b/src/zarr/registry.py similarity index 100% rename from src/zarr/codecs/registry.py rename to src/zarr/registry.py diff --git a/tests/v3/test_codec_entrypoints.py b/tests/v3/test_codec_entrypoints.py index 584d15e4c7..f9e0480951 100644 --- a/tests/v3/test_codec_entrypoints.py +++ b/tests/v3/test_codec_entrypoints.py @@ -3,7 +3,7 @@ import pytest -import zarr.codecs.registry +import zarr.registry from zarr import config here = os.path.abspath(os.path.dirname(__file__)) @@ -12,10 +12,10 @@ @pytest.fixture() def set_path(): sys.path.append(here) - zarr.codecs.registry._collect_entrypoints() + zarr.registry._collect_entrypoints() yield sys.path.remove(here) - lazy_load_codecs, lazy_load_pipelines = zarr.codecs.registry._collect_entrypoints() + lazy_load_codecs, lazy_load_pipelines = zarr.registry._collect_entrypoints() lazy_load_codecs.pop("test") lazy_load_pipelines.clear() config.reset() @@ -24,12 +24,12 @@ def set_path(): @pytest.mark.usefixtures("set_path") def test_entrypoint_codec(): config.set({"codecs.test.name": "TestCodec"}) - cls = zarr.codecs.registry.get_codec_class("test") + cls = zarr.registry.get_codec_class("test") assert cls.__name__ == "TestCodec" @pytest.mark.usefixtures("set_path") def test_entrypoint_pipeline(): config.set({"codec_pipeline.name": "TestCodecPipeline"}) - cls = zarr.codecs.registry.get_pipeline_class() + cls = zarr.registry.get_pipeline_class() assert cls.__name__ == "TestCodecPipeline" diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index 9f19dffe32..6596bd7270 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -11,7 +11,9 @@ from zarr.array_spec import ArraySpec from zarr.buffer import NDBuffer from zarr.codecs import BatchedCodecPipeline, BloscCodec, BytesCodec -from zarr.codecs.registry import ( +from zarr.config import BadConfigError, config +from zarr.indexing import SelectorTuple +from zarr.registry import ( get_codec_class, get_ndbuffer_class, get_pipeline_class, @@ -19,8 +21,6 @@ register_ndbuffer, register_pipeline, ) -from zarr.config import BadConfigError, config -from zarr.indexing import SelectorTuple from zarr.testing.buffer import MyNDArrayLike, MyNDBuffer From 460f853c2a4f5371d72b1f9f532bce43f005a9b6 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Thu, 20 Jun 2024 14:38:38 +0200 Subject: [PATCH 11/36] change class method calls on NDBuffer to use get_ndbuffer_class() --- src/zarr/buffer.py | 2 +- src/zarr/codecs/_v2.py | 7 ++++--- src/zarr/codecs/sharding.py | 4 ++-- tests/v3/test_indexing.py | 17 +++++++++-------- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/zarr/buffer.py b/src/zarr/buffer.py index 444381aad8..ef5c4aa7a2 100644 --- a/src/zarr/buffer.py +++ b/src/zarr/buffer.py @@ -332,7 +332,7 @@ def from_numpy_array(cls, array_like: npt.ArrayLike) -> Self: ------- New buffer representing `array_like` """ - return cls.from_ndarray_like(np.asanyarray(array_like)) # TODO + return cls.from_ndarray_like(np.asanyarray(array_like)) def as_ndarray_like(self) -> NDArrayLike: """Returns the underlying array (host or device memory) of this buffer diff --git a/src/zarr/codecs/_v2.py b/src/zarr/codecs/_v2.py index c43a087a94..22fc486360 100644 --- a/src/zarr/codecs/_v2.py +++ b/src/zarr/codecs/_v2.py @@ -9,6 +9,7 @@ from zarr.array_spec import ArraySpec from zarr.buffer import Buffer, NDBuffer from zarr.common import JSON, to_thread +from zarr.registry import get_ndbuffer_class @dataclass(frozen=True) @@ -34,7 +35,7 @@ async def _decode_single( if str(chunk_numpy_array.dtype) != chunk_spec.dtype: chunk_numpy_array = chunk_numpy_array.view(chunk_spec.dtype) - return NDBuffer.from_numpy_array(chunk_numpy_array) + return get_ndbuffer_class().from_numpy_array(chunk_numpy_array) async def _encode_single( self, @@ -86,7 +87,7 @@ async def _decode_single( order=chunk_spec.order, ) - return NDBuffer.from_ndarray_like(chunk_ndarray) + return get_ndbuffer_class().from_ndarray_like(chunk_ndarray) async def _encode_single( self, @@ -99,7 +100,7 @@ async def _encode_single( filter = numcodecs.get_codec(filter_metadata) chunk_ndarray = await to_thread(filter.encode, chunk_ndarray) - return NDBuffer.from_ndarray_like(chunk_ndarray) + return get_ndbuffer_class().from_ndarray_like(chunk_ndarray) def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int: raise NotImplementedError diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index de92917b37..e6d228abc7 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -33,7 +33,7 @@ ) from zarr.indexing import BasicIndexer, SelectorTuple, c_order_iter, get_indexer, morton_order_iter from zarr.metadata import ArrayMetadata, parse_codecs -from zarr.registry import get_pipeline_class, register_codec +from zarr.registry import get_ndbuffer_class, get_pipeline_class, register_codec if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterator @@ -599,7 +599,7 @@ async def _encode_shard_index(self, index: _ShardIndex) -> Buffer: await self.index_codecs.encode( [ ( - NDBuffer.from_numpy_array(index.offsets_and_lengths), + get_ndbuffer_class().from_numpy_array(index.offsets_and_lengths), self._get_index_chunk_spec(index.chunks_per_shard), ) ], diff --git a/tests/v3/test_indexing.py b/tests/v3/test_indexing.py index 00ea947b49..731dac8bc4 100644 --- a/tests/v3/test_indexing.py +++ b/tests/v3/test_indexing.py @@ -12,7 +12,7 @@ import zarr from zarr.abc.store import Store -from zarr.buffer import BufferPrototype, NDBuffer +from zarr.buffer import BufferPrototype from zarr.common import ChunkCoords from zarr.indexing import ( make_slice_selection, @@ -21,6 +21,7 @@ oindex_set, replace_ellipsis, ) +from zarr.registry import get_ndbuffer_class from zarr.store.core import StorePath from zarr.store.memory import MemoryStore @@ -123,7 +124,7 @@ def test_get_basic_selection_0d(store: StorePath): assert 42 == z[()] # test out param - b = NDBuffer.from_numpy_array(np.zeros_like(a)) + b = get_ndbuffer_class().from_numpy_array(np.zeros_like(a)) z.get_basic_selection(Ellipsis, out=b) assert_array_equal(a, b) @@ -141,10 +142,10 @@ def test_get_basic_selection_0d(store: StorePath): assert a[["foo", "bar"]] == z.get_basic_selection((), fields=["foo", "bar"]) assert a[["foo", "bar"]] == z["foo", "bar"] # test out param - b = NDBuffer.from_numpy_array(np.zeros_like(a)) + b = get_ndbuffer_class().from_numpy_array(np.zeros_like(a)) z.get_basic_selection(Ellipsis, out=b) assert_array_equal(a, b) - c = NDBuffer.from_numpy_array(np.zeros_like(a[["foo", "bar"]])) + c = get_ndbuffer_class().from_numpy_array(np.zeros_like(a[["foo", "bar"]])) z.get_basic_selection(Ellipsis, out=c, fields=["foo", "bar"]) assert_array_equal(a[["foo", "bar"]], c) @@ -232,7 +233,7 @@ def _test_get_basic_selection(a, z, selection): assert_array_equal(expect, actual) # test out param - b = NDBuffer.from_numpy_array(np.empty(shape=expect.shape, dtype=expect.dtype)) + b = get_ndbuffer_class().from_numpy_array(np.empty(shape=expect.shape, dtype=expect.dtype)) z.get_basic_selection(selection, out=b) assert_array_equal(expect, b.as_numpy_array()) @@ -1382,7 +1383,7 @@ def test_get_selection_out(store: StorePath): ] for selection in selections: expect = a[selection] - out = NDBuffer.from_numpy_array(np.empty(expect.shape)) + out = get_ndbuffer_class().from_numpy_array(np.empty(expect.shape)) z.get_basic_selection(selection, out=out) assert_array_equal(expect, out.as_numpy_array()[:]) @@ -1412,7 +1413,7 @@ def test_get_selection_out(store: StorePath): ] for selection in selections: expect = oindex(a, selection) - out = NDBuffer.from_numpy_array(np.zeros(expect.shape, dtype=expect.dtype)) + out = get_ndbuffer_class().from_numpy_array(np.zeros(expect.shape, dtype=expect.dtype)) z.get_orthogonal_selection(selection, out=out) assert_array_equal(expect, out.as_numpy_array()[:]) @@ -1434,7 +1435,7 @@ def test_get_selection_out(store: StorePath): ] for selection in selections: expect = a[selection] - out = NDBuffer.from_numpy_array(np.zeros(expect.shape, dtype=expect.dtype)) + out = get_ndbuffer_class().from_numpy_array(np.zeros(expect.shape, dtype=expect.dtype)) z.get_coordinate_selection(selection, out=out) assert_array_equal(expect, out.as_numpy_array()[:]) From acc7f1703de73d6c90601d138a90b1da3b2e94ce Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Thu, 20 Jun 2024 15:50:39 +0200 Subject: [PATCH 12/36] make buffer implementation configurable --- src/zarr/array.py | 2 +- src/zarr/codecs/_v2.py | 4 ++-- src/zarr/codecs/sharding.py | 16 +++++++++---- src/zarr/group.py | 14 ++++++------ src/zarr/metadata.py | 16 ++++++++----- src/zarr/testing/buffer.py | 2 +- src/zarr/testing/store.py | 4 +--- tests/v3/test_buffer.py | 6 ++--- tests/v3/test_config.py | 45 ++++++++++++++++++++++++++++++++----- 9 files changed, 76 insertions(+), 33 deletions(-) diff --git a/src/zarr/array.py b/src/zarr/array.py index 2d26b4eaf5..4fcd6dd432 100644 --- a/src/zarr/array.py +++ b/src/zarr/array.py @@ -475,7 +475,7 @@ async def getitem( return await self._get_selection(indexer, prototype=prototype) async def _save_metadata(self, metadata: ArrayMetadata) -> None: - to_save = metadata.to_buffer_dict() + to_save = metadata.to_buffer_dict(default_buffer_prototype()) awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()] await gather(*awaitables) diff --git a/src/zarr/codecs/_v2.py b/src/zarr/codecs/_v2.py index 22fc486360..60854bee34 100644 --- a/src/zarr/codecs/_v2.py +++ b/src/zarr/codecs/_v2.py @@ -7,7 +7,7 @@ from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec from zarr.array_spec import ArraySpec -from zarr.buffer import Buffer, NDBuffer +from zarr.buffer import Buffer, NDBuffer, default_buffer_prototype from zarr.common import JSON, to_thread from zarr.registry import get_ndbuffer_class @@ -56,7 +56,7 @@ async def _encode_single( else: encoded_chunk_bytes = ensure_bytes(chunk_numpy_array) - return Buffer.from_bytes(encoded_chunk_bytes) + return default_buffer_prototype().buffer.from_bytes(encoded_chunk_bytes) def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int: raise NotImplementedError diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index e6d228abc7..ef81c66de4 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -169,10 +169,14 @@ async def from_bytes( return obj @classmethod - def create_empty(cls, chunks_per_shard: ChunkCoords) -> _ShardReader: + def create_empty( + cls, chunks_per_shard: ChunkCoords, buffer_prototype: BufferPrototype | None = None + ) -> _ShardReader: + if buffer_prototype is None: + buffer_prototype = default_buffer_prototype() index = _ShardIndex.create_empty(chunks_per_shard) obj = cls() - obj.buf = Buffer.create_zero_length() + obj.buf = buffer_prototype.buffer.create_zero_length() obj.index = index return obj @@ -215,9 +219,13 @@ def merge_with_morton_order( return obj @classmethod - def create_empty(cls, chunks_per_shard: ChunkCoords) -> _ShardBuilder: + def create_empty( + cls, chunks_per_shard: ChunkCoords, buffer_prototype: BufferPrototype | None = None + ) -> _ShardBuilder: + if buffer_prototype is None: + buffer_prototype = default_buffer_prototype() obj = cls() - obj.buf = Buffer.create_zero_length() + obj.buf = buffer_prototype.buffer.create_zero_length() obj.index = _ShardIndex.create_empty(chunks_per_shard) return obj diff --git a/src/zarr/group.py b/src/zarr/group.py index 4bb4b6b4dd..4f6a1dd721 100644 --- a/src/zarr/group.py +++ b/src/zarr/group.py @@ -14,7 +14,7 @@ from zarr.abc.store import set_or_delete from zarr.array import Array, AsyncArray from zarr.attributes import Attributes -from zarr.buffer import Buffer +from zarr.buffer import Buffer, BufferPrototype, default_buffer_prototype from zarr.chunk_key_encodings import ChunkKeyEncoding from zarr.common import ( JSON, @@ -78,15 +78,15 @@ class GroupMetadata(Metadata): zarr_format: ZarrFormat = 3 node_type: Literal["group"] = field(default="group", init=False) - def to_buffer_dict(self) -> dict[str, Buffer]: + def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: if self.zarr_format == 3: - return {ZARR_JSON: Buffer.from_bytes(json.dumps(self.to_dict()).encode())} + return {ZARR_JSON: prototype.buffer.from_bytes(json.dumps(self.to_dict()).encode())} else: return { - ZGROUP_JSON: Buffer.from_bytes( + ZGROUP_JSON: prototype.buffer.from_bytes( json.dumps({"zarr_format": self.zarr_format}).encode() ), - ZATTRS_JSON: Buffer.from_bytes(json.dumps(self.attributes).encode()), + ZATTRS_JSON: prototype.buffer.from_bytes(json.dumps(self.attributes).encode()), } def __init__(self, attributes: dict[str, Any] | None = None, zarr_format: ZarrFormat = 3): @@ -266,7 +266,7 @@ async def delitem(self, key: str) -> None: raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}") async def _save_metadata(self) -> None: - to_save = self.metadata.to_buffer_dict() + to_save = self.metadata.to_buffer_dict(default_buffer_prototype()) awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()] await asyncio.gather(*awaitables) @@ -529,7 +529,7 @@ async def update_attributes_async(self, new_attributes: dict[str, Any]) -> Group new_metadata = replace(self.metadata, attributes=new_attributes) # Write new metadata - to_save = new_metadata.to_buffer_dict() + to_save = new_metadata.to_buffer_dict(default_buffer_prototype()) awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()] await asyncio.gather(*awaitables) diff --git a/src/zarr/metadata.py b/src/zarr/metadata.py index f62d024dff..2719a01e53 100644 --- a/src/zarr/metadata.py +++ b/src/zarr/metadata.py @@ -147,7 +147,7 @@ def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str: pass @abstractmethod - def to_buffer_dict(self) -> dict[str, Buffer]: + def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: pass @abstractmethod @@ -259,7 +259,7 @@ def get_chunk_spec( def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str: return self.chunk_key_encoding.encode_chunk_key(chunk_coords) - def to_buffer_dict(self) -> dict[str, Buffer]: + def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: def _json_convert(o: np.dtype[Any] | Enum | Codec) -> str | dict[str, Any]: if isinstance(o, np.dtype): return str(o) @@ -273,7 +273,9 @@ def _json_convert(o: np.dtype[Any] | Enum | Codec) -> str | dict[str, Any]: raise TypeError return { - ZARR_JSON: Buffer.from_bytes(json.dumps(self.to_dict(), default=_json_convert).encode()) + ZARR_JSON: prototype.buffer.from_bytes( + json.dumps(self.to_dict(), default=_json_convert).encode() + ) } @classmethod @@ -377,7 +379,7 @@ def codec_pipeline(self) -> CodecPipeline: [V2Filters(self.filters or []), V2Compressor(self.compressor)] ) - def to_buffer_dict(self) -> dict[str, Buffer]: + def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: def _json_convert( o: np.dtype[Any], ) -> str | list[tuple[str, str] | tuple[str, str, tuple[int, ...]]]: @@ -393,8 +395,10 @@ def _json_convert( zattrs_dict = zarray_dict.pop("attributes", {}) assert isinstance(zattrs_dict, dict) return { - ZARRAY_JSON: Buffer.from_bytes(json.dumps(zarray_dict, default=_json_convert).encode()), - ZATTRS_JSON: Buffer.from_bytes(json.dumps(zattrs_dict).encode()), + ZARRAY_JSON: prototype.buffer.from_bytes( + json.dumps(zarray_dict, default=_json_convert).encode() + ), + ZATTRS_JSON: prototype.buffer.from_bytes(json.dumps(zattrs_dict).encode()), } @classmethod diff --git a/src/zarr/testing/buffer.py b/src/zarr/testing/buffer.py index a56a0ced64..0a7b0565ad 100644 --- a/src/zarr/testing/buffer.py +++ b/src/zarr/testing/buffer.py @@ -40,7 +40,7 @@ def create( return ret -class MyStore(MemoryStore): +class StoreExpectingMyBuffer(MemoryStore): """Example of a custom Store that expect MyBuffer for all its non-metadata We assume that keys containing "json" is metadata diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 69786838e2..03925ab1d6 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -125,9 +125,7 @@ async def test_get_partial_values( self.set(store, key, Buffer.from_bytes(bytes(key, encoding="utf-8"))) # read back just part of it - observed_maybe = await store.get_partial_values( - prototype=default_buffer_prototype(), key_ranges=key_ranges - ) + observed_maybe = await store.get_partial_values(prototype=Buffer, key_ranges=key_ranges) observed: list[Buffer] = [] expected: list[Buffer] = [] diff --git a/tests/v3/test_buffer.py b/tests/v3/test_buffer.py index ebc2c332ec..5f6a948b17 100644 --- a/tests/v3/test_buffer.py +++ b/tests/v3/test_buffer.py @@ -12,7 +12,7 @@ from zarr.codecs.transpose import TransposeCodec from zarr.codecs.zstd import ZstdCodec from zarr.store.core import StorePath -from zarr.testing.buffer import MyBuffer, MyNDArrayLike, MyNDBuffer, MyStore +from zarr.testing.buffer import MyBuffer, MyNDArrayLike, MyNDBuffer, StoreExpectingMyBuffer def test_nd_array_like(xp): @@ -27,7 +27,7 @@ async def test_async_array_prototype(): expect = np.zeros((9, 9), dtype="uint16", order="F") a = await AsyncArray.create( - StorePath(MyStore(mode="w")) / "test_async_array_prototype", + StorePath(StoreExpectingMyBuffer(mode="w")) / "test_async_array_prototype", shape=expect.shape, chunk_shape=(5, 5), dtype=expect.dtype, @@ -51,7 +51,7 @@ async def test_async_array_prototype(): async def test_codecs_use_of_prototype(): expect = np.zeros((10, 10), dtype="uint16", order="F") a = await AsyncArray.create( - StorePath(MyStore(mode="w")) / "test_codecs_use_of_prototype", + StorePath(StoreExpectingMyBuffer(mode="w")) / "test_codecs_use_of_prototype", shape=expect.shape, chunk_shape=(5, 5), dtype=expect.dtype, diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index 6596bd7270..a733bc2e22 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -3,25 +3,28 @@ from unittest import mock from unittest.mock import Mock +import numpy as np import pytest -from zarr import Array +from zarr import Array, zeros from zarr.abc.codec import CodecInput, CodecOutput, CodecPipeline from zarr.abc.store import ByteSetter from zarr.array_spec import ArraySpec from zarr.buffer import NDBuffer -from zarr.codecs import BatchedCodecPipeline, BloscCodec, BytesCodec +from zarr.codecs import BatchedCodecPipeline, BloscCodec, BytesCodec, Crc32cCodec, ShardingCodec from zarr.config import BadConfigError, config from zarr.indexing import SelectorTuple from zarr.registry import ( + get_buffer_class, get_codec_class, get_ndbuffer_class, get_pipeline_class, + register_buffer, register_codec, register_ndbuffer, register_pipeline, ) -from zarr.testing.buffer import MyNDArrayLike, MyNDBuffer +from zarr.testing.buffer import MyBuffer, MyNDArrayLike, MyNDBuffer, StoreExpectingMyBuffer @pytest.fixture() @@ -161,6 +164,36 @@ def test_config_ndbuffer_implementation(store): assert isinstance(got, MyNDArrayLike) -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) -def test_config_buffer_implementation(store): - pass +def test_config_buffer_implementation(): + # has default value + assert get_buffer_class().__name__ == config.defaults[0]["buffer"]["name"] + + arr = zeros(shape=(100), store=StoreExpectingMyBuffer(mode="w")) + + # AssertionError of StoreExpectingMyBuffer when not using my buffer + with pytest.raises(AssertionError): + arr[:] = np.arange(100) + + register_buffer(MyBuffer) + config.set({"buffer.name": "MyBuffer"}) + assert get_buffer_class() == MyBuffer + + # no error using MyBuffer + arr[:] = np.arange(100) + + arr_sharding = zeros( + shape=(100, 10), + store=StoreExpectingMyBuffer(mode="w"), + codecs=[ShardingCodec(chunk_shape=(10, 10))], + ) + arr_sharding[:] = np.arange(1000).reshape(100, 10) + + arr_Crc32c = zeros( + shape=(100, 10), + store=StoreExpectingMyBuffer(mode="w"), + codecs=[BytesCodec(), Crc32cCodec()], + ) + arr_Crc32c[:] = np.arange(1000).reshape(100, 10) + + +pass From b060722cc4c0991a13cc56978c9da21550a99b40 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Thu, 20 Jun 2024 16:38:13 +0200 Subject: [PATCH 13/36] format --- src/zarr/config.py | 1 - src/zarr/metadata.py | 6 ++++-- tests/v3/test_config.py | 6 +----- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/zarr/config.py b/src/zarr/config.py index c4e1292e3a..fc7afaa4b3 100644 --- a/src/zarr/config.py +++ b/src/zarr/config.py @@ -48,7 +48,6 @@ def reset(self) -> None: }, "buffer": {"name": "Buffer"}, "ndbuffer": {"name": "NDBuffer"}, - "codec_pipeline": {"batch_size": 1}, } ], ) diff --git a/src/zarr/metadata.py b/src/zarr/metadata.py index ced962479b..5b0c6f5c9d 100644 --- a/src/zarr/metadata.py +++ b/src/zarr/metadata.py @@ -16,8 +16,8 @@ from zarr.chunk_grids import ChunkGrid, RegularChunkGrid from zarr.chunk_key_encodings import ChunkKeyEncoding, parse_separator from zarr.codecs._v2 import V2Compressor, V2Filters -from zarr.registry import get_pipeline_class from zarr.config import config +from zarr.registry import get_pipeline_class if TYPE_CHECKING: from typing_extensions import Self @@ -401,7 +401,9 @@ def _json_convert( ZARRAY_JSON: prototype.buffer.from_bytes( json.dumps(zarray_dict, default=_json_convert, indent=json_indent).encode() ), - ZATTRS_JSON: prototype.buffer.from_bytes(json.dumps(zattrs_dict, indent=json_indent).encode()), + ZATTRS_JSON: prototype.buffer.from_bytes( + json.dumps(zattrs_dict, indent=json_indent).encode() + ), } @classmethod diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index eb1fd23728..6386ddfd2e 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -1,5 +1,6 @@ import os from collections.abc import Iterable +from typing import Any from unittest import mock from unittest.mock import Mock @@ -26,11 +27,6 @@ ) from zarr.testing.buffer import MyBuffer, MyNDArrayLike, MyNDBuffer, StoreExpectingMyBuffer -from typing import Any - -import pytest - -from zarr.config import config @pytest.fixture() def reset_config(): From 1ca197db5d93f5fa6e93f1bd9a97e865745eddca Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Thu, 20 Jun 2024 16:49:57 +0200 Subject: [PATCH 14/36] fix tests --- src/zarr/codecs/sharding.py | 2 +- src/zarr/testing/store.py | 4 +++- tests/v3/test_config.py | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index ef81c66de4..06e0ff48e4 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -66,7 +66,7 @@ async def get( ) -> Buffer | None: assert byte_range is None, "byte_range is not supported within shards" assert ( - prototype is default_buffer_prototype() + prototype == default_buffer_prototype() ), "prototype is not supported within shards currently" return self.shard_dict.get(self.chunk_coords) diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 03925ab1d6..69786838e2 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -125,7 +125,9 @@ async def test_get_partial_values( self.set(store, key, Buffer.from_bytes(bytes(key, encoding="utf-8"))) # read back just part of it - observed_maybe = await store.get_partial_values(prototype=Buffer, key_ranges=key_ranges) + observed_maybe = await store.get_partial_values( + prototype=default_buffer_prototype(), key_ranges=key_ranges + ) observed: list[Buffer] = [] expected: list[Buffer] = [] diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index 6386ddfd2e..9dc714fdd0 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -43,6 +43,8 @@ def test_config_defaults_set() -> None: "async": {"concurrency": None, "timeout": None}, "json_indent": 2, "codec_pipeline": {"name": "BatchedCodecPipeline", "batch_size": 1}, + "buffer": {"name": "Buffer"}, + "ndbuffer": {"name": "NDBuffer"}, "codecs": { "blosc": {"name": "BloscCodec"}, "gzip": {"name": "GzipCodec"}, From 26329f6c605c15d732da90fb2f249fde220a0c29 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Thu, 20 Jun 2024 17:02:12 +0200 Subject: [PATCH 15/36] ignore mypy in tests --- src/zarr/testing/buffer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/zarr/testing/buffer.py b/src/zarr/testing/buffer.py index 0a7b0565ad..0f38b6975d 100644 --- a/src/zarr/testing/buffer.py +++ b/src/zarr/testing/buffer.py @@ -1,3 +1,4 @@ +# mypy: ignore-errors from __future__ import annotations from collections.abc import Iterable From 9601a6f84089e44f6d4561c22159c348a0a60c62 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Fri, 21 Jun 2024 11:06:14 +0200 Subject: [PATCH 16/36] add test to lazy load (nd)buffer from entrypoint --- src/zarr/registry.py | 14 ++++++---- .../entry_points.txt | 2 ++ tests/v3/package_with_entrypoint/__init__.py | 27 ++++++++++++++----- tests/v3/test_codec_entrypoints.py | 13 ++++++--- 4 files changed, 41 insertions(+), 15 deletions(-) diff --git a/src/zarr/registry.py b/src/zarr/registry.py index 1722a4c5e2..7906ffb699 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -22,14 +22,21 @@ __lazy_load_ndbuffer: list[EntryPoint] = [] -def _collect_entrypoints() -> tuple[dict[str, EntryPoint], list[EntryPoint]]: +def _collect_entrypoints() -> ( + tuple[dict[str, EntryPoint], list[EntryPoint], list[EntryPoint], list[EntryPoint]] +): entry_points = get_entry_points() for e in entry_points.select(group="zarr.codecs"): __lazy_load_codecs[e.name] = e for e in entry_points.select(group="zarr"): if e.name == "codec_pipeline": __lazy_load_pipelines.append(e) - return __lazy_load_codecs, __lazy_load_pipelines + if e.name == "buffer": + __lazy_load_buffer.append(e) + if e.name == "ndbuffer": + __lazy_load_ndbuffer.append(e) + + return __lazy_load_codecs, __lazy_load_pipelines, __lazy_load_buffer, __lazy_load_ndbuffer def _reload_config() -> None: @@ -88,7 +95,6 @@ def get_pipeline_class(reload_config: bool = False) -> type[CodecPipeline]: for e in __lazy_load_pipelines: __lazy_load_pipelines.remove(e) register_pipeline(e.load()) - name = config.get("codec_pipeline.name") pipeline_class = __pipeline_registry.get(name) if pipeline_class: @@ -104,7 +110,6 @@ def get_buffer_class(reload_config: bool = False) -> type[Buffer]: for e in __lazy_load_buffer: __lazy_load_buffer.remove(e) register_buffer(e.load()) - name = config.get("buffer.name") buffer_class = __buffer_registry.get(name) if buffer_class: @@ -120,7 +125,6 @@ def get_ndbuffer_class(reload_config: bool = False) -> type[NDBuffer]: for e in __lazy_load_ndbuffer: __lazy_load_ndbuffer.remove(e) register_ndbuffer(e.load()) - name = config.get("ndbuffer.name") ndbuffer_class = __ndbuffer_registry.get(name) if ndbuffer_class: diff --git a/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt b/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt index 93eaebd0f4..8457b61d9d 100644 --- a/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt +++ b/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt @@ -2,3 +2,5 @@ test = package_with_entrypoint:TestCodec [zarr] codec_pipeline = package_with_entrypoint:TestCodecPipeline +buffer = package_with_entrypoint:TestBuffer +ndbuffer = package_with_entrypoint:TestNDBuffer \ No newline at end of file diff --git a/tests/v3/package_with_entrypoint/__init__.py b/tests/v3/package_with_entrypoint/__init__.py index 63e51a52db..cbd70470f1 100644 --- a/tests/v3/package_with_entrypoint/__init__.py +++ b/tests/v3/package_with_entrypoint/__init__.py @@ -1,7 +1,10 @@ +from collections.abc import Iterable + from numpy import ndarray -from zarr.abc.codec import ArrayBytesCodec, CodecPipeline +from zarr.abc.codec import ArrayBytesCodec, CodecInput, CodecPipeline from zarr.array_spec import ArraySpec +from zarr.buffer import Buffer, NDBuffer from zarr.common import BytesLike @@ -10,15 +13,13 @@ class TestCodec(ArrayBytesCodec): async def encode( self, - chunk_array: ndarray, - chunk_spec: ArraySpec, + chunks_and_specs: Iterable[tuple[CodecInput | None, ArraySpec]], ) -> BytesLike | None: pass async def decode( self, - chunk_bytes: BytesLike, - chunk_spec: ArraySpec, + chunks_and_specs: Iterable[tuple[CodecInput | None, ArraySpec]], ) -> ndarray: pass @@ -30,8 +31,20 @@ class TestCodecPipeline(CodecPipeline): def __init__(self, batch_size: int = 1): pass - async def encode(self, chunk_array: ndarray, chunk_spec: ArraySpec) -> BytesLike: + async def encode( + self, chunks_and_specs: Iterable[tuple[CodecInput | None, ArraySpec]] + ) -> BytesLike: pass - async def decode(self, chunk_bytes: BytesLike, chunk_spec: ArraySpec) -> ndarray: + async def decode( + self, chunks_and_specs: Iterable[tuple[CodecInput | None, ArraySpec]] + ) -> ndarray: pass + + +class TestBuffer(Buffer): + pass + + +class TestNDBuffer(NDBuffer): + pass diff --git a/tests/v3/test_codec_entrypoints.py b/tests/v3/test_codec_entrypoints.py index f9e0480951..280dfae2b4 100644 --- a/tests/v3/test_codec_entrypoints.py +++ b/tests/v3/test_codec_entrypoints.py @@ -15,9 +15,9 @@ def set_path(): zarr.registry._collect_entrypoints() yield sys.path.remove(here) - lazy_load_codecs, lazy_load_pipelines = zarr.registry._collect_entrypoints() - lazy_load_codecs.pop("test") - lazy_load_pipelines.clear() + lazy_load_lists = zarr.registry._collect_entrypoints() + for lazy_load_list in lazy_load_lists: + lazy_load_list.clear() config.reset() @@ -33,3 +33,10 @@ def test_entrypoint_pipeline(): config.set({"codec_pipeline.name": "TestCodecPipeline"}) cls = zarr.registry.get_pipeline_class() assert cls.__name__ == "TestCodecPipeline" + + +@pytest.mark.usefixtures("set_path") +def test_entrypoint_buffer(): + config.set({"buffer.name": "TestBuffer", "ndbuffer.name": "TestNDBuffer"}) + assert zarr.registry.get_buffer_class().__name__ == "TestBuffer" + assert zarr.registry.get_ndbuffer_class().__name__ == "TestNDBuffer" From ffe583285ccac94ec8c2ba5a315d31bec98d6c2e Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Mon, 24 Jun 2024 17:31:36 +0200 Subject: [PATCH 17/36] better assertion message --- src/zarr/codecs/sharding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 06e0ff48e4..0ff3151197 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -67,7 +67,7 @@ async def get( assert byte_range is None, "byte_range is not supported within shards" assert ( prototype == default_buffer_prototype() - ), "prototype is not supported within shards currently" + ), f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}" return self.shard_dict.get(self.chunk_coords) From d07a127aca4e3ac1e4784d44f8f41cf60a807fa7 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Mon, 24 Jun 2024 19:03:41 +0200 Subject: [PATCH 18/36] fix merge --- src/zarr/array.py | 8 ++++---- src/zarr/codecs/pipeline.py | 1 - src/zarr/codecs/sharding.py | 22 ++++++++++++++-------- src/zarr/codecs/transpose.py | 1 - src/zarr/metadata.py | 3 +-- tests/v3/test_indexing.py | 2 +- 6 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/zarr/array.py b/src/zarr/array.py index 42ae3e908b..aea0fbd756 100644 --- a/src/zarr/array.py +++ b/src/zarr/array.py @@ -25,7 +25,7 @@ from zarr.chunk_key_encodings import ChunkKeyEncoding, DefaultChunkKeyEncoding, V2ChunkKeyEncoding from zarr.codecs import BytesCodec from zarr.codecs._v2 import V2Compressor, V2Filters -from zarr.codecs.pipeline import BatchedCodecPipeline +from zarr.registry import get_pipeline_class from zarr.common import ( JSON, ZARR_JSON, @@ -76,11 +76,11 @@ def parse_array_metadata(data: Any) -> ArrayV2Metadata | ArrayV3Metadata: raise TypeError -def create_codec_pipeline(metadata: ArrayV2Metadata | ArrayV3Metadata) -> BatchedCodecPipeline: +def create_codec_pipeline(metadata: ArrayV2Metadata | ArrayV3Metadata) -> CodecPipeline: if isinstance(metadata, ArrayV3Metadata): - return BatchedCodecPipeline.from_list(metadata.codecs) + return get_pipeline_class().from_list(metadata.codecs) elif isinstance(metadata, ArrayV2Metadata): - return BatchedCodecPipeline.from_list( + return get_pipeline_class().from_list( [V2Filters(metadata.filters or []), V2Compressor(metadata.compressor)] ) else: diff --git a/src/zarr/codecs/pipeline.py b/src/zarr/codecs/pipeline.py index 51f7c16106..a0e392da86 100644 --- a/src/zarr/codecs/pipeline.py +++ b/src/zarr/codecs/pipeline.py @@ -20,7 +20,6 @@ from zarr.abc.store import ByteGetter, ByteSetter from zarr.buffer import Buffer, BufferPrototype, NDBuffer from zarr.chunk_grids import ChunkGrid -from zarr.codecs.registry import get_codec_class from zarr.common import JSON, ChunkCoords, concurrent_map, parse_named_configuration from zarr.config import config from zarr.indexing import SelectorTuple, is_scalar, is_total_slice diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 7167cc14b2..4553efc521 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -14,7 +14,7 @@ ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin, - Codec, + Codec, CodecPipeline, ) from zarr.abc.store import ByteGetter, ByteSetter from zarr.array_spec import ArraySpec @@ -32,7 +32,7 @@ ) from zarr.indexing import BasicIndexer, SelectorTuple, c_order_iter, get_indexer, morton_order_iter from zarr.metadata import parse_codecs -from zarr.registry import get_ndbuffer_class, register_codec +from zarr.registry import get_ndbuffer_class, get_pipeline_class, register_codec if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterator @@ -337,8 +337,8 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: return cls(**configuration_parsed) # type: ignore[arg-type] @property - def codec_pipeline(self) -> BatchedCodecPipeline: - return BatchedCodecPipeline.from_list(self.codecs) + def codec_pipeline(self) -> CodecPipeline: + return get_pipeline_class().from_list(self.codecs) def to_dict(self) -> dict[str, JSON]: return { @@ -584,7 +584,7 @@ async def _decode_shard_index( ) -> _ShardIndex: index_array = next( iter( - await BatchedCodecPipeline.from_list(self.index_codecs).decode( + await get_pipeline_class().from_list(self.index_codecs).decode( [(index_bytes, self._get_index_chunk_spec(chunks_per_shard))], ) ) @@ -595,7 +595,9 @@ async def _decode_shard_index( async def _encode_shard_index(self, index: _ShardIndex) -> Buffer: index_bytes = next( iter( - await BatchedCodecPipeline.from_list(self.index_codecs).encode( + await get_pipeline_class() + .from_list(self.index_codecs) + .encode( [ ( get_ndbuffer_class().from_numpy_array(index.offsets_and_lengths), @@ -610,8 +612,12 @@ async def _encode_shard_index(self, index: _ShardIndex) -> Buffer: return index_bytes def _shard_index_size(self, chunks_per_shard: ChunkCoords) -> int: - return BatchedCodecPipeline.from_list(self.index_codecs).compute_encoded_size( - 16 * product(chunks_per_shard), self._get_index_chunk_spec(chunks_per_shard) + return ( + get_pipeline_class() + .from_list(self.index_codecs) + .compute_encoded_size( + 16 * product(chunks_per_shard), self._get_index_chunk_spec(chunks_per_shard) + ) ) def _get_index_chunk_spec(self, chunks_per_shard: ChunkCoords) -> ArraySpec: diff --git a/src/zarr/codecs/transpose.py b/src/zarr/codecs/transpose.py index 9fc522ee26..2501a2dd31 100644 --- a/src/zarr/codecs/transpose.py +++ b/src/zarr/codecs/transpose.py @@ -10,7 +10,6 @@ from zarr.array_spec import ArraySpec from zarr.buffer import NDBuffer from zarr.chunk_grids import ChunkGrid -from zarr.codecs.registry import register_codec from zarr.common import JSON, ChunkCoordsLike, parse_named_configuration from zarr.registry import register_codec diff --git a/src/zarr/metadata.py b/src/zarr/metadata.py index a09a97757c..72f2bb991f 100644 --- a/src/zarr/metadata.py +++ b/src/zarr/metadata.py @@ -15,9 +15,8 @@ from zarr.buffer import Buffer, BufferPrototype, default_buffer_prototype from zarr.chunk_grids import ChunkGrid, RegularChunkGrid from zarr.chunk_key_encodings import ChunkKeyEncoding, parse_separator -from zarr.codecs.registry import get_codec_class from zarr.config import config -from zarr.registry import get_pipeline_class +from zarr.registry import get_pipeline_class, get_codec_class if TYPE_CHECKING: from typing_extensions import Self diff --git a/tests/v3/test_indexing.py b/tests/v3/test_indexing.py index 0eab610f91..51a2375300 100644 --- a/tests/v3/test_indexing.py +++ b/tests/v3/test_indexing.py @@ -12,7 +12,7 @@ import zarr from zarr.abc.store import Store -from zarr.buffer import BufferPrototype +from zarr.buffer import BufferPrototype, NDBuffer from zarr.common import ChunkCoords from zarr.indexing import ( make_slice_selection, From 7448f36ad01c62aca0da25956da27d381b2752c9 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Mon, 24 Jun 2024 19:03:41 +0200 Subject: [PATCH 19/36] fix merge --- src/zarr/abc/codec.py | 2 +- src/zarr/array.py | 8 ++++---- src/zarr/codecs/pipeline.py | 1 - src/zarr/codecs/sharding.py | 22 ++++++++++++++-------- src/zarr/codecs/transpose.py | 1 - src/zarr/metadata.py | 3 +-- tests/v3/test_indexing.py | 2 +- 7 files changed, 21 insertions(+), 18 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index cc5e47f3b6..727c6c23c1 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -265,7 +265,7 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: @classmethod @abstractmethod - def from_list(cls, codecs: list[Codec]) -> Self: + def from_list(cls, codecs: Iterable[Codec]) -> Self: """Creates a codec pipeline from a list of codecs. Parameters diff --git a/src/zarr/array.py b/src/zarr/array.py index 42ae3e908b..aea0fbd756 100644 --- a/src/zarr/array.py +++ b/src/zarr/array.py @@ -25,7 +25,7 @@ from zarr.chunk_key_encodings import ChunkKeyEncoding, DefaultChunkKeyEncoding, V2ChunkKeyEncoding from zarr.codecs import BytesCodec from zarr.codecs._v2 import V2Compressor, V2Filters -from zarr.codecs.pipeline import BatchedCodecPipeline +from zarr.registry import get_pipeline_class from zarr.common import ( JSON, ZARR_JSON, @@ -76,11 +76,11 @@ def parse_array_metadata(data: Any) -> ArrayV2Metadata | ArrayV3Metadata: raise TypeError -def create_codec_pipeline(metadata: ArrayV2Metadata | ArrayV3Metadata) -> BatchedCodecPipeline: +def create_codec_pipeline(metadata: ArrayV2Metadata | ArrayV3Metadata) -> CodecPipeline: if isinstance(metadata, ArrayV3Metadata): - return BatchedCodecPipeline.from_list(metadata.codecs) + return get_pipeline_class().from_list(metadata.codecs) elif isinstance(metadata, ArrayV2Metadata): - return BatchedCodecPipeline.from_list( + return get_pipeline_class().from_list( [V2Filters(metadata.filters or []), V2Compressor(metadata.compressor)] ) else: diff --git a/src/zarr/codecs/pipeline.py b/src/zarr/codecs/pipeline.py index 51f7c16106..a0e392da86 100644 --- a/src/zarr/codecs/pipeline.py +++ b/src/zarr/codecs/pipeline.py @@ -20,7 +20,6 @@ from zarr.abc.store import ByteGetter, ByteSetter from zarr.buffer import Buffer, BufferPrototype, NDBuffer from zarr.chunk_grids import ChunkGrid -from zarr.codecs.registry import get_codec_class from zarr.common import JSON, ChunkCoords, concurrent_map, parse_named_configuration from zarr.config import config from zarr.indexing import SelectorTuple, is_scalar, is_total_slice diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 7167cc14b2..4553efc521 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -14,7 +14,7 @@ ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin, - Codec, + Codec, CodecPipeline, ) from zarr.abc.store import ByteGetter, ByteSetter from zarr.array_spec import ArraySpec @@ -32,7 +32,7 @@ ) from zarr.indexing import BasicIndexer, SelectorTuple, c_order_iter, get_indexer, morton_order_iter from zarr.metadata import parse_codecs -from zarr.registry import get_ndbuffer_class, register_codec +from zarr.registry import get_ndbuffer_class, get_pipeline_class, register_codec if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterator @@ -337,8 +337,8 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: return cls(**configuration_parsed) # type: ignore[arg-type] @property - def codec_pipeline(self) -> BatchedCodecPipeline: - return BatchedCodecPipeline.from_list(self.codecs) + def codec_pipeline(self) -> CodecPipeline: + return get_pipeline_class().from_list(self.codecs) def to_dict(self) -> dict[str, JSON]: return { @@ -584,7 +584,7 @@ async def _decode_shard_index( ) -> _ShardIndex: index_array = next( iter( - await BatchedCodecPipeline.from_list(self.index_codecs).decode( + await get_pipeline_class().from_list(self.index_codecs).decode( [(index_bytes, self._get_index_chunk_spec(chunks_per_shard))], ) ) @@ -595,7 +595,9 @@ async def _decode_shard_index( async def _encode_shard_index(self, index: _ShardIndex) -> Buffer: index_bytes = next( iter( - await BatchedCodecPipeline.from_list(self.index_codecs).encode( + await get_pipeline_class() + .from_list(self.index_codecs) + .encode( [ ( get_ndbuffer_class().from_numpy_array(index.offsets_and_lengths), @@ -610,8 +612,12 @@ async def _encode_shard_index(self, index: _ShardIndex) -> Buffer: return index_bytes def _shard_index_size(self, chunks_per_shard: ChunkCoords) -> int: - return BatchedCodecPipeline.from_list(self.index_codecs).compute_encoded_size( - 16 * product(chunks_per_shard), self._get_index_chunk_spec(chunks_per_shard) + return ( + get_pipeline_class() + .from_list(self.index_codecs) + .compute_encoded_size( + 16 * product(chunks_per_shard), self._get_index_chunk_spec(chunks_per_shard) + ) ) def _get_index_chunk_spec(self, chunks_per_shard: ChunkCoords) -> ArraySpec: diff --git a/src/zarr/codecs/transpose.py b/src/zarr/codecs/transpose.py index 9fc522ee26..2501a2dd31 100644 --- a/src/zarr/codecs/transpose.py +++ b/src/zarr/codecs/transpose.py @@ -10,7 +10,6 @@ from zarr.array_spec import ArraySpec from zarr.buffer import NDBuffer from zarr.chunk_grids import ChunkGrid -from zarr.codecs.registry import register_codec from zarr.common import JSON, ChunkCoordsLike, parse_named_configuration from zarr.registry import register_codec diff --git a/src/zarr/metadata.py b/src/zarr/metadata.py index a09a97757c..72f2bb991f 100644 --- a/src/zarr/metadata.py +++ b/src/zarr/metadata.py @@ -15,9 +15,8 @@ from zarr.buffer import Buffer, BufferPrototype, default_buffer_prototype from zarr.chunk_grids import ChunkGrid, RegularChunkGrid from zarr.chunk_key_encodings import ChunkKeyEncoding, parse_separator -from zarr.codecs.registry import get_codec_class from zarr.config import config -from zarr.registry import get_pipeline_class +from zarr.registry import get_pipeline_class, get_codec_class if TYPE_CHECKING: from typing_extensions import Self diff --git a/tests/v3/test_indexing.py b/tests/v3/test_indexing.py index 0eab610f91..51a2375300 100644 --- a/tests/v3/test_indexing.py +++ b/tests/v3/test_indexing.py @@ -12,7 +12,7 @@ import zarr from zarr.abc.store import Store -from zarr.buffer import BufferPrototype +from zarr.buffer import BufferPrototype, NDBuffer from zarr.common import ChunkCoords from zarr.indexing import ( make_slice_selection, From 57ad3b4536676c46b6dc42ab11dfc0e1832d778e Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Mon, 24 Jun 2024 19:10:09 +0200 Subject: [PATCH 20/36] formatting --- src/zarr/array.py | 3 +-- src/zarr/codecs/sharding.py | 7 +++++-- src/zarr/metadata.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/zarr/array.py b/src/zarr/array.py index aea0fbd756..4f9f9193ab 100644 --- a/src/zarr/array.py +++ b/src/zarr/array.py @@ -1,7 +1,6 @@ from __future__ import annotations import json - # Notes on what I've changed here: # 1. Split Array into AsyncArray and Array # 3. Added .size and .attrs methods @@ -25,7 +24,6 @@ from zarr.chunk_key_encodings import ChunkKeyEncoding, DefaultChunkKeyEncoding, V2ChunkKeyEncoding from zarr.codecs import BytesCodec from zarr.codecs._v2 import V2Compressor, V2Filters -from zarr.registry import get_pipeline_class from zarr.common import ( JSON, ZARR_JSON, @@ -61,6 +59,7 @@ pop_fields, ) from zarr.metadata import ArrayMetadata, ArrayV2Metadata, ArrayV3Metadata +from zarr.registry import get_pipeline_class from zarr.store import StoreLike, StorePath, make_store_path from zarr.sync import sync diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 4553efc521..af61317c3e 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -14,7 +14,8 @@ ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin, - Codec, CodecPipeline, + Codec, + CodecPipeline, ) from zarr.abc.store import ByteGetter, ByteSetter from zarr.array_spec import ArraySpec @@ -584,7 +585,9 @@ async def _decode_shard_index( ) -> _ShardIndex: index_array = next( iter( - await get_pipeline_class().from_list(self.index_codecs).decode( + await get_pipeline_class() + .from_list(self.index_codecs) + .decode( [(index_bytes, self._get_index_chunk_spec(chunks_per_shard))], ) ) diff --git a/src/zarr/metadata.py b/src/zarr/metadata.py index 72f2bb991f..23b5d860df 100644 --- a/src/zarr/metadata.py +++ b/src/zarr/metadata.py @@ -16,7 +16,7 @@ from zarr.chunk_grids import ChunkGrid, RegularChunkGrid from zarr.chunk_key_encodings import ChunkKeyEncoding, parse_separator from zarr.config import config -from zarr.registry import get_pipeline_class, get_codec_class +from zarr.registry import get_codec_class, get_pipeline_class if TYPE_CHECKING: from typing_extensions import Self From 0b2cf9ab6658a32d310b000bc839e7bca7284a06 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Mon, 24 Jun 2024 19:20:12 +0200 Subject: [PATCH 21/36] fix mypy --- src/zarr/indexing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/zarr/indexing.py b/src/zarr/indexing.py index 74cbbe8c6b..d3d791759c 100644 --- a/src/zarr/indexing.py +++ b/src/zarr/indexing.py @@ -1016,6 +1016,7 @@ def __init__(self, selection: CoordinateSelection, shape: ChunkCoords, chunk_gri # broadcast selection - this will raise error if array dimensions don't match selection_broadcast = tuple(np.broadcast_arrays(*selection_normalized)) chunks_multi_index_broadcast = np.broadcast_arrays(*chunks_multi_index) + cast(list[np.ndarray[Any, np.dtype[Any]]], chunks_multi_index_broadcast) # remember shape of selection, because we will flatten indices for processing sel_shape = selection_broadcast[0].shape if selection_broadcast[0].shape else (1,) From 4f6d690422f40d999b15affa4669e4352207537c Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Mon, 24 Jun 2024 19:29:56 +0200 Subject: [PATCH 22/36] fix ruff formatting --- src/zarr/array.py | 363 +++++++++++++++++++++++----------------------- 1 file changed, 181 insertions(+), 182 deletions(-) diff --git a/src/zarr/array.py b/src/zarr/array.py index 4f9f9193ab..37593f9362 100644 --- a/src/zarr/array.py +++ b/src/zarr/array.py @@ -94,10 +94,10 @@ class AsyncArray: order: Literal["C", "F"] def __init__( - self, - metadata: ArrayMetadata, - store_path: StorePath, - order: Literal["C", "F"] | None = None, + self, + metadata: ArrayMetadata, + store_path: StorePath, + order: Literal["C", "F"] | None = None, ): metadata_parsed = parse_array_metadata(metadata) order_parsed = parse_indexing_order(order or config.get("array.order")) @@ -109,33 +109,33 @@ def __init__( @classmethod async def create( - cls, - store: StoreLike, - *, - # v2 and v3 - shape: ChunkCoords, - dtype: npt.DTypeLike, - zarr_format: ZarrFormat = 3, - fill_value: Any | None = None, - attributes: dict[str, JSON] | None = None, - # v3 only - chunk_shape: ChunkCoords | None = None, - chunk_key_encoding: ( - ChunkKeyEncoding - | tuple[Literal["default"], Literal[".", "/"]] - | tuple[Literal["v2"], Literal[".", "/"]] - | None - ) = None, - codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: Iterable[str] | None = None, - # v2 only - chunks: ChunkCoords | None = None, - dimension_separator: Literal[".", "/"] | None = None, - order: Literal["C", "F"] | None = None, - filters: list[dict[str, JSON]] | None = None, - compressor: dict[str, JSON] | None = None, - # runtime - exists_ok: bool = False, + cls, + store: StoreLike, + *, + # v2 and v3 + shape: ChunkCoords, + dtype: npt.DTypeLike, + zarr_format: ZarrFormat = 3, + fill_value: Any | None = None, + attributes: dict[str, JSON] | None = None, + # v3 only + chunk_shape: ChunkCoords | None = None, + chunk_key_encoding: ( + ChunkKeyEncoding + | tuple[Literal["default"], Literal[".", "/"]] + | tuple[Literal["v2"], Literal[".", "/"]] + | None + ) = None, + codecs: Iterable[Codec | dict[str, JSON]] | None = None, + dimension_names: Iterable[str] | None = None, + # v2 only + chunks: ChunkCoords | None = None, + dimension_separator: Literal[".", "/"] | None = None, + order: Literal["C", "F"] | None = None, + filters: list[dict[str, JSON]] | None = None, + compressor: dict[str, JSON] | None = None, + # runtime + exists_ok: bool = False, ) -> AsyncArray: store_path = make_store_path(store) @@ -204,23 +204,23 @@ async def create( @classmethod async def _create_v3( - cls, - store_path: StorePath, - *, - shape: ChunkCoords, - dtype: npt.DTypeLike, - chunk_shape: ChunkCoords, - fill_value: Any | None = None, - chunk_key_encoding: ( - ChunkKeyEncoding - | tuple[Literal["default"], Literal[".", "/"]] - | tuple[Literal["v2"], Literal[".", "/"]] - | None - ) = None, - codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: Iterable[str] | None = None, - attributes: dict[str, JSON] | None = None, - exists_ok: bool = False, + cls, + store_path: StorePath, + *, + shape: ChunkCoords, + dtype: npt.DTypeLike, + chunk_shape: ChunkCoords, + fill_value: Any | None = None, + chunk_key_encoding: ( + ChunkKeyEncoding + | tuple[Literal["default"], Literal[".", "/"]] + | tuple[Literal["v2"], Literal[".", "/"]] + | None + ) = None, + codecs: Iterable[Codec | dict[str, JSON]] | None = None, + dimension_names: Iterable[str] | None = None, + attributes: dict[str, JSON] | None = None, + exists_ok: bool = False, ) -> AsyncArray: if not exists_ok: assert not await (store_path / ZARR_JSON).exists() @@ -262,19 +262,19 @@ async def _create_v3( @classmethod async def _create_v2( - cls, - store_path: StorePath, - *, - shape: ChunkCoords, - dtype: npt.DTypeLike, - chunks: ChunkCoords, - dimension_separator: Literal[".", "/"] | None = None, - fill_value: None | int | float = None, - order: Literal["C", "F"] | None = None, - filters: list[dict[str, JSON]] | None = None, - compressor: dict[str, JSON] | None = None, - attributes: dict[str, JSON] | None = None, - exists_ok: bool = False, + cls, + store_path: StorePath, + *, + shape: ChunkCoords, + dtype: npt.DTypeLike, + chunks: ChunkCoords, + dimension_separator: Literal[".", "/"] | None = None, + fill_value: None | int | float = None, + order: Literal["C", "F"] | None = None, + filters: list[dict[str, JSON]] | None = None, + compressor: dict[str, JSON] | None = None, + attributes: dict[str, JSON] | None = None, + exists_ok: bool = False, ) -> AsyncArray: import numcodecs @@ -310,9 +310,9 @@ async def _create_v2( @classmethod def from_dict( - cls, - store_path: StorePath, - data: dict[str, JSON], + cls, + store_path: StorePath, + data: dict[str, JSON], ) -> AsyncArray: metadata = parse_array_metadata(data) async_array = cls(metadata=metadata, store_path=store_path) @@ -320,9 +320,9 @@ def from_dict( @classmethod async def open( - cls, - store: StoreLike, - zarr_format: ZarrFormat | None = 3, + cls, + store: StoreLike, + zarr_format: ZarrFormat | None = 3, ) -> AsyncArray: store_path = make_store_path(store) @@ -428,12 +428,12 @@ def basename(self) -> str | None: return None async def _get_selection( - self, - indexer: Indexer, - *, - prototype: BufferPrototype, - out: NDBuffer | None = None, - fields: Fields | None = None, + self, + indexer: Indexer, + *, + prototype: BufferPrototype, + out: NDBuffer | None = None, + fields: Fields | None = None, ) -> NDArrayLike: # check fields are sensible out_dtype = check_fields(fields, self.dtype) @@ -473,10 +473,10 @@ async def _get_selection( return out_buffer.as_ndarray_like() async def getitem( - self, - selection: BasicSelection, - *, - prototype: BufferPrototype | None = None, + self, + selection: BasicSelection, + *, + prototype: BufferPrototype | None = None, ) -> NDArrayLike: if prototype is None: prototype = default_buffer_prototype() @@ -493,12 +493,12 @@ async def _save_metadata(self, metadata: ArrayMetadata) -> None: await gather(*awaitables) async def _set_selection( - self, - indexer: Indexer, - value: npt.ArrayLike, - *, - prototype: BufferPrototype, - fields: Fields | None = None, + self, + indexer: Indexer, + value: npt.ArrayLike, + *, + prototype: BufferPrototype, + fields: Fields | None = None, ) -> None: # check fields are sensible check_fields(fields, self.dtype) @@ -537,10 +537,10 @@ async def _set_selection( ) async def setitem( - self, - selection: BasicSelection, - value: npt.ArrayLike, - prototype: BufferPrototype | None = None, + self, + selection: BasicSelection, + value: npt.ArrayLike, + prototype: BufferPrototype | None = None, ) -> None: if prototype is None: prototype = default_buffer_prototype() @@ -552,7 +552,7 @@ async def setitem( return await self._set_selection(indexer, value, prototype=prototype) async def resize( - self, new_shape: ChunkCoords, delete_outside_chunks: bool = True + self, new_shape: ChunkCoords, delete_outside_chunks: bool = True ) -> AsyncArray: assert len(new_shape) == len(self.metadata.shape) new_metadata = self.metadata.update_shape(new_shape) @@ -562,7 +562,6 @@ async def resize( new_chunk_coords = set(self.metadata.chunk_grid.all_chunk_coords(new_shape)) if delete_outside_chunks: - async def _delete_key(key: str) -> None: await (self.store_path / key).delete() @@ -599,33 +598,33 @@ class Array: @classmethod def create( - cls, - store: StoreLike, - *, - # v2 and v3 - shape: ChunkCoords, - dtype: npt.DTypeLike, - zarr_format: ZarrFormat = 3, - fill_value: Any | None = None, - attributes: dict[str, JSON] | None = None, - # v3 only - chunk_shape: ChunkCoords | None = None, - chunk_key_encoding: ( - ChunkKeyEncoding - | tuple[Literal["default"], Literal[".", "/"]] - | tuple[Literal["v2"], Literal[".", "/"]] - | None - ) = None, - codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: Iterable[str] | None = None, - # v2 only - chunks: ChunkCoords | None = None, - dimension_separator: Literal[".", "/"] | None = None, - order: Literal["C", "F"] | None = None, - filters: list[dict[str, JSON]] | None = None, - compressor: dict[str, JSON] | None = None, - # runtime - exists_ok: bool = False, + cls, + store: StoreLike, + *, + # v2 and v3 + shape: ChunkCoords, + dtype: npt.DTypeLike, + zarr_format: ZarrFormat = 3, + fill_value: Any | None = None, + attributes: dict[str, JSON] | None = None, + # v3 only + chunk_shape: ChunkCoords | None = None, + chunk_key_encoding: ( + ChunkKeyEncoding + | tuple[Literal["default"], Literal[".", "/"]] + | tuple[Literal["v2"], Literal[".", "/"]] + | None + ) = None, + codecs: Iterable[Codec | dict[str, JSON]] | None = None, + dimension_names: Iterable[str] | None = None, + # v2 only + chunks: ChunkCoords | None = None, + dimension_separator: Literal[".", "/"] | None = None, + order: Literal["C", "F"] | None = None, + filters: list[dict[str, JSON]] | None = None, + compressor: dict[str, JSON] | None = None, + # runtime + exists_ok: bool = False, ) -> Array: async_array = sync( AsyncArray.create( @@ -651,17 +650,17 @@ def create( @classmethod def from_dict( - cls, - store_path: StorePath, - data: dict[str, JSON], + cls, + store_path: StorePath, + data: dict[str, JSON], ) -> Array: async_array = AsyncArray.from_dict(store_path=store_path, data=data) return cls(async_array) @classmethod def open( - cls, - store: StoreLike, + cls, + store: StoreLike, ) -> Array: async_array = sync(AsyncArray.open(store)) return cls(async_array) @@ -722,7 +721,7 @@ def read_only(self) -> bool: return self._async_array.read_only def __array__( - self, dtype: npt.DTypeLike | None = None, copy: bool | None = None + self, dtype: npt.DTypeLike | None = None, copy: bool | None = None ) -> NDArrayLike: """ This method is used by numpy when converting zarr.Array into a numpy array. @@ -990,12 +989,12 @@ def __setitem__(self, selection: Selection, value: npt.ArrayLike) -> None: self.set_basic_selection(cast(BasicSelection, pure_selection), value, fields=fields) def get_basic_selection( - self, - selection: BasicSelection = Ellipsis, - *, - out: NDBuffer | None = None, - prototype: BufferPrototype | None = None, - fields: Fields | None = None, + self, + selection: BasicSelection = Ellipsis, + *, + out: NDBuffer | None = None, + prototype: BufferPrototype | None = None, + fields: Fields | None = None, ) -> NDArrayLike: """Retrieve data for an item or region of the array. @@ -1113,12 +1112,12 @@ def get_basic_selection( ) def set_basic_selection( - self, - selection: BasicSelection, - value: npt.ArrayLike, - *, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + selection: BasicSelection, + value: npt.ArrayLike, + *, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> None: """Modify data for an item or region of the array. @@ -1208,12 +1207,12 @@ def set_basic_selection( sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) def get_orthogonal_selection( - self, - selection: OrthogonalSelection, - *, - out: NDBuffer | None = None, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + selection: OrthogonalSelection, + *, + out: NDBuffer | None = None, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> NDArrayLike: """Retrieve data by making a selection for each dimension of the array. For example, if an array has 2 dimensions, allows selecting specific rows and/or @@ -1332,12 +1331,12 @@ def get_orthogonal_selection( ) def set_orthogonal_selection( - self, - selection: OrthogonalSelection, - value: npt.ArrayLike, - *, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + selection: OrthogonalSelection, + value: npt.ArrayLike, + *, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> None: """Modify data via a selection for each dimension of the array. @@ -1442,12 +1441,12 @@ def set_orthogonal_selection( ) def get_mask_selection( - self, - mask: MaskSelection, - *, - out: NDBuffer | None = None, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + mask: MaskSelection, + *, + out: NDBuffer | None = None, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> NDArrayLike: """Retrieve a selection of individual items, by providing a Boolean array of the same shape as the array against which the selection is being made, where True @@ -1524,12 +1523,12 @@ def get_mask_selection( ) def set_mask_selection( - self, - mask: MaskSelection, - value: npt.ArrayLike, - *, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + mask: MaskSelection, + value: npt.ArrayLike, + *, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> None: """Modify a selection of individual items, by providing a Boolean array of the same shape as the array against which the selection is being made, where True @@ -1602,12 +1601,12 @@ def set_mask_selection( sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) def get_coordinate_selection( - self, - selection: CoordinateSelection, - *, - out: NDBuffer | None = None, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + selection: CoordinateSelection, + *, + out: NDBuffer | None = None, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> NDArrayLike: """Retrieve a selection of individual items, by providing the indices (coordinates) for each selected item. @@ -1691,12 +1690,12 @@ def get_coordinate_selection( return out_array def set_coordinate_selection( - self, - selection: CoordinateSelection, - value: npt.ArrayLike, - *, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + selection: CoordinateSelection, + value: npt.ArrayLike, + *, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> None: """Modify a selection of individual items, by providing the indices (coordinates) for each item to be modified. @@ -1780,12 +1779,12 @@ def set_coordinate_selection( sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) def get_block_selection( - self, - selection: BasicSelection, - *, - out: NDBuffer | None = None, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + selection: BasicSelection, + *, + out: NDBuffer | None = None, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> NDArrayLike: """Retrieve a selection of individual items, by providing the indices (coordinates) for each selected item. @@ -1878,12 +1877,12 @@ def get_block_selection( ) def set_block_selection( - self, - selection: BasicSelection, - value: npt.ArrayLike, - *, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + selection: BasicSelection, + value: npt.ArrayLike, + *, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> None: """Modify a selection of individual blocks, by providing the chunk indices (coordinates) for each block to be modified. From 96676b70114e938f9f43518dde2fa51cd8cd249c Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Wed, 26 Jun 2024 13:35:13 +0200 Subject: [PATCH 23/36] fix merge --- src/zarr/array.py | 364 +++++++++++++++++++------------------ src/zarr/indexing.py | 3 +- src/zarr/testing/buffer.py | 5 +- tests/v3/test_buffer.py | 56 ------ 4 files changed, 188 insertions(+), 240 deletions(-) diff --git a/src/zarr/array.py b/src/zarr/array.py index 37593f9362..3662ba1658 100644 --- a/src/zarr/array.py +++ b/src/zarr/array.py @@ -1,6 +1,7 @@ from __future__ import annotations import json + # Notes on what I've changed here: # 1. Split Array into AsyncArray and Array # 3. Added .size and .attrs methods @@ -94,10 +95,10 @@ class AsyncArray: order: Literal["C", "F"] def __init__( - self, - metadata: ArrayMetadata, - store_path: StorePath, - order: Literal["C", "F"] | None = None, + self, + metadata: ArrayMetadata, + store_path: StorePath, + order: Literal["C", "F"] | None = None, ): metadata_parsed = parse_array_metadata(metadata) order_parsed = parse_indexing_order(order or config.get("array.order")) @@ -109,33 +110,33 @@ def __init__( @classmethod async def create( - cls, - store: StoreLike, - *, - # v2 and v3 - shape: ChunkCoords, - dtype: npt.DTypeLike, - zarr_format: ZarrFormat = 3, - fill_value: Any | None = None, - attributes: dict[str, JSON] | None = None, - # v3 only - chunk_shape: ChunkCoords | None = None, - chunk_key_encoding: ( - ChunkKeyEncoding - | tuple[Literal["default"], Literal[".", "/"]] - | tuple[Literal["v2"], Literal[".", "/"]] - | None - ) = None, - codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: Iterable[str] | None = None, - # v2 only - chunks: ChunkCoords | None = None, - dimension_separator: Literal[".", "/"] | None = None, - order: Literal["C", "F"] | None = None, - filters: list[dict[str, JSON]] | None = None, - compressor: dict[str, JSON] | None = None, - # runtime - exists_ok: bool = False, + cls, + store: StoreLike, + *, + # v2 and v3 + shape: ChunkCoords, + dtype: npt.DTypeLike, + zarr_format: ZarrFormat = 3, + fill_value: Any | None = None, + attributes: dict[str, JSON] | None = None, + # v3 only + chunk_shape: ChunkCoords | None = None, + chunk_key_encoding: ( + ChunkKeyEncoding + | tuple[Literal["default"], Literal[".", "/"]] + | tuple[Literal["v2"], Literal[".", "/"]] + | None + ) = None, + codecs: Iterable[Codec | dict[str, JSON]] | None = None, + dimension_names: Iterable[str] | None = None, + # v2 only + chunks: ChunkCoords | None = None, + dimension_separator: Literal[".", "/"] | None = None, + order: Literal["C", "F"] | None = None, + filters: list[dict[str, JSON]] | None = None, + compressor: dict[str, JSON] | None = None, + # runtime + exists_ok: bool = False, ) -> AsyncArray: store_path = make_store_path(store) @@ -204,23 +205,23 @@ async def create( @classmethod async def _create_v3( - cls, - store_path: StorePath, - *, - shape: ChunkCoords, - dtype: npt.DTypeLike, - chunk_shape: ChunkCoords, - fill_value: Any | None = None, - chunk_key_encoding: ( - ChunkKeyEncoding - | tuple[Literal["default"], Literal[".", "/"]] - | tuple[Literal["v2"], Literal[".", "/"]] - | None - ) = None, - codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: Iterable[str] | None = None, - attributes: dict[str, JSON] | None = None, - exists_ok: bool = False, + cls, + store_path: StorePath, + *, + shape: ChunkCoords, + dtype: npt.DTypeLike, + chunk_shape: ChunkCoords, + fill_value: Any | None = None, + chunk_key_encoding: ( + ChunkKeyEncoding + | tuple[Literal["default"], Literal[".", "/"]] + | tuple[Literal["v2"], Literal[".", "/"]] + | None + ) = None, + codecs: Iterable[Codec | dict[str, JSON]] | None = None, + dimension_names: Iterable[str] | None = None, + attributes: dict[str, JSON] | None = None, + exists_ok: bool = False, ) -> AsyncArray: if not exists_ok: assert not await (store_path / ZARR_JSON).exists() @@ -262,19 +263,19 @@ async def _create_v3( @classmethod async def _create_v2( - cls, - store_path: StorePath, - *, - shape: ChunkCoords, - dtype: npt.DTypeLike, - chunks: ChunkCoords, - dimension_separator: Literal[".", "/"] | None = None, - fill_value: None | int | float = None, - order: Literal["C", "F"] | None = None, - filters: list[dict[str, JSON]] | None = None, - compressor: dict[str, JSON] | None = None, - attributes: dict[str, JSON] | None = None, - exists_ok: bool = False, + cls, + store_path: StorePath, + *, + shape: ChunkCoords, + dtype: npt.DTypeLike, + chunks: ChunkCoords, + dimension_separator: Literal[".", "/"] | None = None, + fill_value: None | int | float = None, + order: Literal["C", "F"] | None = None, + filters: list[dict[str, JSON]] | None = None, + compressor: dict[str, JSON] | None = None, + attributes: dict[str, JSON] | None = None, + exists_ok: bool = False, ) -> AsyncArray: import numcodecs @@ -310,9 +311,9 @@ async def _create_v2( @classmethod def from_dict( - cls, - store_path: StorePath, - data: dict[str, JSON], + cls, + store_path: StorePath, + data: dict[str, JSON], ) -> AsyncArray: metadata = parse_array_metadata(data) async_array = cls(metadata=metadata, store_path=store_path) @@ -320,9 +321,9 @@ def from_dict( @classmethod async def open( - cls, - store: StoreLike, - zarr_format: ZarrFormat | None = 3, + cls, + store: StoreLike, + zarr_format: ZarrFormat | None = 3, ) -> AsyncArray: store_path = make_store_path(store) @@ -428,12 +429,12 @@ def basename(self) -> str | None: return None async def _get_selection( - self, - indexer: Indexer, - *, - prototype: BufferPrototype, - out: NDBuffer | None = None, - fields: Fields | None = None, + self, + indexer: Indexer, + *, + prototype: BufferPrototype, + out: NDBuffer | None = None, + fields: Fields | None = None, ) -> NDArrayLike: # check fields are sensible out_dtype = check_fields(fields, self.dtype) @@ -473,10 +474,10 @@ async def _get_selection( return out_buffer.as_ndarray_like() async def getitem( - self, - selection: BasicSelection, - *, - prototype: BufferPrototype | None = None, + self, + selection: BasicSelection, + *, + prototype: BufferPrototype | None = None, ) -> NDArrayLike: if prototype is None: prototype = default_buffer_prototype() @@ -493,12 +494,12 @@ async def _save_metadata(self, metadata: ArrayMetadata) -> None: await gather(*awaitables) async def _set_selection( - self, - indexer: Indexer, - value: npt.ArrayLike, - *, - prototype: BufferPrototype, - fields: Fields | None = None, + self, + indexer: Indexer, + value: npt.ArrayLike, + *, + prototype: BufferPrototype, + fields: Fields | None = None, ) -> None: # check fields are sensible check_fields(fields, self.dtype) @@ -537,10 +538,10 @@ async def _set_selection( ) async def setitem( - self, - selection: BasicSelection, - value: npt.ArrayLike, - prototype: BufferPrototype | None = None, + self, + selection: BasicSelection, + value: npt.ArrayLike, + prototype: BufferPrototype | None = None, ) -> None: if prototype is None: prototype = default_buffer_prototype() @@ -552,7 +553,7 @@ async def setitem( return await self._set_selection(indexer, value, prototype=prototype) async def resize( - self, new_shape: ChunkCoords, delete_outside_chunks: bool = True + self, new_shape: ChunkCoords, delete_outside_chunks: bool = True ) -> AsyncArray: assert len(new_shape) == len(self.metadata.shape) new_metadata = self.metadata.update_shape(new_shape) @@ -562,6 +563,7 @@ async def resize( new_chunk_coords = set(self.metadata.chunk_grid.all_chunk_coords(new_shape)) if delete_outside_chunks: + async def _delete_key(key: str) -> None: await (self.store_path / key).delete() @@ -598,33 +600,33 @@ class Array: @classmethod def create( - cls, - store: StoreLike, - *, - # v2 and v3 - shape: ChunkCoords, - dtype: npt.DTypeLike, - zarr_format: ZarrFormat = 3, - fill_value: Any | None = None, - attributes: dict[str, JSON] | None = None, - # v3 only - chunk_shape: ChunkCoords | None = None, - chunk_key_encoding: ( - ChunkKeyEncoding - | tuple[Literal["default"], Literal[".", "/"]] - | tuple[Literal["v2"], Literal[".", "/"]] - | None - ) = None, - codecs: Iterable[Codec | dict[str, JSON]] | None = None, - dimension_names: Iterable[str] | None = None, - # v2 only - chunks: ChunkCoords | None = None, - dimension_separator: Literal[".", "/"] | None = None, - order: Literal["C", "F"] | None = None, - filters: list[dict[str, JSON]] | None = None, - compressor: dict[str, JSON] | None = None, - # runtime - exists_ok: bool = False, + cls, + store: StoreLike, + *, + # v2 and v3 + shape: ChunkCoords, + dtype: npt.DTypeLike, + zarr_format: ZarrFormat = 3, + fill_value: Any | None = None, + attributes: dict[str, JSON] | None = None, + # v3 only + chunk_shape: ChunkCoords | None = None, + chunk_key_encoding: ( + ChunkKeyEncoding + | tuple[Literal["default"], Literal[".", "/"]] + | tuple[Literal["v2"], Literal[".", "/"]] + | None + ) = None, + codecs: Iterable[Codec | dict[str, JSON]] | None = None, + dimension_names: Iterable[str] | None = None, + # v2 only + chunks: ChunkCoords | None = None, + dimension_separator: Literal[".", "/"] | None = None, + order: Literal["C", "F"] | None = None, + filters: list[dict[str, JSON]] | None = None, + compressor: dict[str, JSON] | None = None, + # runtime + exists_ok: bool = False, ) -> Array: async_array = sync( AsyncArray.create( @@ -650,17 +652,17 @@ def create( @classmethod def from_dict( - cls, - store_path: StorePath, - data: dict[str, JSON], + cls, + store_path: StorePath, + data: dict[str, JSON], ) -> Array: async_array = AsyncArray.from_dict(store_path=store_path, data=data) return cls(async_array) @classmethod def open( - cls, - store: StoreLike, + cls, + store: StoreLike, ) -> Array: async_array = sync(AsyncArray.open(store)) return cls(async_array) @@ -721,7 +723,7 @@ def read_only(self) -> bool: return self._async_array.read_only def __array__( - self, dtype: npt.DTypeLike | None = None, copy: bool | None = None + self, dtype: npt.DTypeLike | None = None, copy: bool | None = None ) -> NDArrayLike: """ This method is used by numpy when converting zarr.Array into a numpy array. @@ -989,12 +991,12 @@ def __setitem__(self, selection: Selection, value: npt.ArrayLike) -> None: self.set_basic_selection(cast(BasicSelection, pure_selection), value, fields=fields) def get_basic_selection( - self, - selection: BasicSelection = Ellipsis, - *, - out: NDBuffer | None = None, - prototype: BufferPrototype | None = None, - fields: Fields | None = None, + self, + selection: BasicSelection = Ellipsis, + *, + out: NDBuffer | None = None, + prototype: BufferPrototype | None = None, + fields: Fields | None = None, ) -> NDArrayLike: """Retrieve data for an item or region of the array. @@ -1112,12 +1114,12 @@ def get_basic_selection( ) def set_basic_selection( - self, - selection: BasicSelection, - value: npt.ArrayLike, - *, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + selection: BasicSelection, + value: npt.ArrayLike, + *, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> None: """Modify data for an item or region of the array. @@ -1207,12 +1209,12 @@ def set_basic_selection( sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) def get_orthogonal_selection( - self, - selection: OrthogonalSelection, - *, - out: NDBuffer | None = None, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + selection: OrthogonalSelection, + *, + out: NDBuffer | None = None, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> NDArrayLike: """Retrieve data by making a selection for each dimension of the array. For example, if an array has 2 dimensions, allows selecting specific rows and/or @@ -1331,12 +1333,12 @@ def get_orthogonal_selection( ) def set_orthogonal_selection( - self, - selection: OrthogonalSelection, - value: npt.ArrayLike, - *, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + selection: OrthogonalSelection, + value: npt.ArrayLike, + *, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> None: """Modify data via a selection for each dimension of the array. @@ -1441,12 +1443,12 @@ def set_orthogonal_selection( ) def get_mask_selection( - self, - mask: MaskSelection, - *, - out: NDBuffer | None = None, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + mask: MaskSelection, + *, + out: NDBuffer | None = None, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> NDArrayLike: """Retrieve a selection of individual items, by providing a Boolean array of the same shape as the array against which the selection is being made, where True @@ -1523,12 +1525,12 @@ def get_mask_selection( ) def set_mask_selection( - self, - mask: MaskSelection, - value: npt.ArrayLike, - *, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + mask: MaskSelection, + value: npt.ArrayLike, + *, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> None: """Modify a selection of individual items, by providing a Boolean array of the same shape as the array against which the selection is being made, where True @@ -1601,12 +1603,12 @@ def set_mask_selection( sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) def get_coordinate_selection( - self, - selection: CoordinateSelection, - *, - out: NDBuffer | None = None, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + selection: CoordinateSelection, + *, + out: NDBuffer | None = None, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> NDArrayLike: """Retrieve a selection of individual items, by providing the indices (coordinates) for each selected item. @@ -1690,12 +1692,12 @@ def get_coordinate_selection( return out_array def set_coordinate_selection( - self, - selection: CoordinateSelection, - value: npt.ArrayLike, - *, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + selection: CoordinateSelection, + value: npt.ArrayLike, + *, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> None: """Modify a selection of individual items, by providing the indices (coordinates) for each item to be modified. @@ -1779,12 +1781,12 @@ def set_coordinate_selection( sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)) def get_block_selection( - self, - selection: BasicSelection, - *, - out: NDBuffer | None = None, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + selection: BasicSelection, + *, + out: NDBuffer | None = None, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> NDArrayLike: """Retrieve a selection of individual items, by providing the indices (coordinates) for each selected item. @@ -1877,12 +1879,12 @@ def get_block_selection( ) def set_block_selection( - self, - selection: BasicSelection, - value: npt.ArrayLike, - *, - fields: Fields | None = None, - prototype: BufferPrototype | None = None, + self, + selection: BasicSelection, + value: npt.ArrayLike, + *, + fields: Fields | None = None, + prototype: BufferPrototype | None = None, ) -> None: """Modify a selection of individual blocks, by providing the chunk indices (coordinates) for each block to be modified. diff --git a/src/zarr/indexing.py b/src/zarr/indexing.py index bf019e8f1b..a3253a68c3 100644 --- a/src/zarr/indexing.py +++ b/src/zarr/indexing.py @@ -1016,14 +1016,13 @@ def __init__(self, selection: CoordinateSelection, shape: ChunkCoords, chunk_gri # broadcast selection - this will raise error if array dimensions don't match selection_broadcast = tuple(np.broadcast_arrays(*selection_normalized)) chunks_multi_index_broadcast = np.broadcast_arrays(*chunks_multi_index) - cast(list[np.ndarray[Any, np.dtype[Any]]], chunks_multi_index_broadcast) # remember shape of selection, because we will flatten indices for processing sel_shape = selection_broadcast[0].shape if selection_broadcast[0].shape else (1,) # flatten selection selection_broadcast = tuple(dim_sel.reshape(-1) for dim_sel in selection_broadcast) - chunks_multi_index_broadcast = tuple( + chunks_multi_index_broadcast = list( [dim_chunks.reshape(-1) for dim_chunks in chunks_multi_index_broadcast] ) diff --git a/src/zarr/testing/buffer.py b/src/zarr/testing/buffer.py index 0f38b6975d..73dda61714 100644 --- a/src/zarr/testing/buffer.py +++ b/src/zarr/testing/buffer.py @@ -60,4 +60,7 @@ async def get( ) -> Buffer | None: if "json" not in key: assert prototype.buffer is MyBuffer - return await super().get(key, byte_range) + ret = await super().get(key=key, prototype=prototype, byte_range=byte_range) + if ret is not None: + assert isinstance(ret, prototype.buffer) + return ret diff --git a/tests/v3/test_buffer.py b/tests/v3/test_buffer.py index 0af4835673..5f6a948b17 100644 --- a/tests/v3/test_buffer.py +++ b/tests/v3/test_buffer.py @@ -12,64 +12,8 @@ from zarr.codecs.transpose import TransposeCodec from zarr.codecs.zstd import ZstdCodec from zarr.store.core import StorePath -from zarr.store.memory import MemoryStore from zarr.testing.buffer import MyBuffer, MyNDArrayLike, MyNDBuffer, StoreExpectingMyBuffer -if TYPE_CHECKING: - from typing_extensions import Self - - -class MyNDArrayLike(np.ndarray): - """An example of a ndarray-like class""" - - -class MyBuffer(Buffer): - """Example of a custom Buffer that handles ArrayLike""" - - -class MyNDBuffer(NDBuffer): - """Example of a custom NDBuffer that handles MyNDArrayLike""" - - @classmethod - def create( - cls, - *, - shape: Iterable[int], - dtype: npt.DTypeLike, - order: Literal[C, F] = "C", - fill_value: Any | None = None, - ) -> Self: - """Overwrite `NDBuffer.create` to create an MyNDArrayLike instance""" - ret = cls(MyNDArrayLike(shape=shape, dtype=dtype, order=order)) - if fill_value is not None: - ret.fill(fill_value) - return ret - - -class MyStore(MemoryStore): - """Example of a custom Store that expect MyBuffer for all its non-metadata - - We assume that keys containing "json" is metadata - """ - - async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: - if "json" not in key: - assert isinstance(value, MyBuffer) - await super().set(key, value, byte_range) - - async def get( - self, - key: str, - prototype: BufferPrototype, - byte_range: tuple[int, int | None] | None = None, - ) -> Buffer | None: - if "json" not in key: - assert prototype.buffer is MyBuffer - ret = await super().get(key=key, prototype=prototype, byte_range=byte_range) - if ret is not None: - assert isinstance(ret, prototype.buffer) - return ret - def test_nd_array_like(xp): ary = xp.arange(10) From c098d9a43c3624794cbff82c5143976618d4ad82 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Thu, 27 Jun 2024 15:41:28 +0200 Subject: [PATCH 24/36] fix mypy --- src/zarr/indexing.py | 4 ++-- tests/v3/test_store/test_remote.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/zarr/indexing.py b/src/zarr/indexing.py index dcf1b34cd1..bb5ed660cf 100644 --- a/src/zarr/indexing.py +++ b/src/zarr/indexing.py @@ -1022,8 +1022,8 @@ def __init__(self, selection: CoordinateSelection, shape: ChunkCoords, chunk_gri # flatten selection selection_broadcast = tuple(dim_sel.reshape(-1) for dim_sel in selection_broadcast) - chunks_multi_index_broadcast = list( - dim_chunks.reshape(-1) for dim_chunks in chunks_multi_index_broadcast + chunks_multi_index_broadcast = tuple( + [dim_chunks.reshape(-1) for dim_chunks in chunks_multi_index_broadcast] ) # ravel chunk indices diff --git a/tests/v3/test_store/test_remote.py b/tests/v3/test_store/test_remote.py index 7f6cced05d..fa6fb3a5b7 100644 --- a/tests/v3/test_store/test_remote.py +++ b/tests/v3/test_store/test_remote.py @@ -88,7 +88,7 @@ async def test_basic(): data = b"hello" await store.set("foo", Buffer.from_bytes(data)) assert await store.exists("foo") - assert (await store.get("foo", prototype=default_buffer_prototype)).to_bytes() == data + assert (await store.get("foo", prototype=default_buffer_prototype())).to_bytes() == data out = await store.get_partial_values( prototype=default_buffer_prototype(), key_ranges=[("foo", (1, None))] ) From 97e004bfc2a8306b412b854dd2e8db8ba6f2fa8d Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Mon, 1 Jul 2024 13:28:22 +0200 Subject: [PATCH 25/36] use numpy_buffer_prototype for reading shard index --- src/zarr/buffer.py | 5 +++++ src/zarr/codecs/sharding.py | 14 ++++++++++---- tests/v3/test_buffer.py | 9 ++++++++- tests/v3/test_config.py | 12 +++++++----- 4 files changed, 30 insertions(+), 10 deletions(-) diff --git a/src/zarr/buffer.py b/src/zarr/buffer.py index 774a38900f..9c75ba1410 100644 --- a/src/zarr/buffer.py +++ b/src/zarr/buffer.py @@ -489,5 +489,10 @@ def default_buffer_prototype() -> BufferPrototype: return BufferPrototype(buffer=get_buffer_class(), nd_buffer=get_ndbuffer_class()) +# The numpy prototype used for E.g. when reading the shard index +def numpy_buffer_prototype() -> BufferPrototype: + return BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer) + + register_buffer(Buffer) register_ndbuffer(NDBuffer) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index af61317c3e..efe71cd630 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -19,7 +19,13 @@ ) from zarr.abc.store import ByteGetter, ByteSetter from zarr.array_spec import ArraySpec -from zarr.buffer import Buffer, BufferPrototype, NDBuffer, default_buffer_prototype +from zarr.buffer import ( + Buffer, + BufferPrototype, + NDBuffer, + default_buffer_prototype, + numpy_buffer_prototype, +) from zarr.chunk_grids import ChunkGrid, RegularChunkGrid from zarr.codecs.bytes import BytesCodec from zarr.codecs.crc32c_ import Crc32cCodec @@ -629,7 +635,7 @@ def _get_index_chunk_spec(self, chunks_per_shard: ChunkCoords) -> ArraySpec: dtype=np.dtype(" ArraySpec: @@ -657,11 +663,11 @@ async def _load_shard_index_maybe( shard_index_size = self._shard_index_size(chunks_per_shard) if self.index_location == ShardingCodecIndexLocation.start: index_bytes = await byte_getter.get( - prototype=default_buffer_prototype(), byte_range=(0, shard_index_size) + prototype=numpy_buffer_prototype(), byte_range=(0, shard_index_size) ) else: index_bytes = await byte_getter.get( - prototype=default_buffer_prototype(), byte_range=(-shard_index_size, None) + prototype=numpy_buffer_prototype(), byte_range=(-shard_index_size, None) ) if index_bytes is not None: return await self._decode_shard_index(index_bytes, chunks_per_shard) diff --git a/tests/v3/test_buffer.py b/tests/v3/test_buffer.py index 5f6a948b17..c1a4904525 100644 --- a/tests/v3/test_buffer.py +++ b/tests/v3/test_buffer.py @@ -4,7 +4,7 @@ import pytest from zarr.array import AsyncArray -from zarr.buffer import ArrayLike, BufferPrototype, NDArrayLike +from zarr.buffer import ArrayLike, BufferPrototype, NDArrayLike, numpy_buffer_prototype from zarr.codecs.blosc import BloscCodec from zarr.codecs.bytes import BytesCodec from zarr.codecs.crc32c_ import Crc32cCodec @@ -77,3 +77,10 @@ async def test_codecs_use_of_prototype(): got = await a.getitem(selection=(slice(0, 10), slice(0, 10)), prototype=my_prototype) assert isinstance(got, MyNDArrayLike) assert np.array_equal(expect, got) + + +def test_numpy_buffer_prototype(): + buffer = numpy_buffer_prototype().buffer.create_zero_length() + ndbuffer = numpy_buffer_prototype().nd_buffer.create(shape=(1, 2), dtype=np.dtype("int64")) + assert isinstance(buffer.as_array_like(), np.ndarray) + assert isinstance(ndbuffer.as_ndarray_like(), np.ndarray) diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index 9dc714fdd0..bc762578ea 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -191,21 +191,23 @@ def test_config_buffer_implementation(): assert get_buffer_class() == MyBuffer # no error using MyBuffer + data = np.arange(100) arr[:] = np.arange(100) + assert np.array_equal(arr[:], data) + data2d = np.arange(1000).reshape(100, 10) arr_sharding = zeros( shape=(100, 10), store=StoreExpectingMyBuffer(mode="w"), codecs=[ShardingCodec(chunk_shape=(10, 10))], ) - arr_sharding[:] = np.arange(1000).reshape(100, 10) + arr_sharding[:] = data2d + assert np.array_equal(arr_sharding[:], data2d) arr_Crc32c = zeros( shape=(100, 10), store=StoreExpectingMyBuffer(mode="w"), codecs=[BytesCodec(), Crc32cCodec()], ) - arr_Crc32c[:] = np.arange(1000).reshape(100, 10) - - -pass + arr_Crc32c[:] = data2d + assert np.array_equal(arr_Crc32c[:], data2d) From 01ab484f0d777ff24fffd81c0d8267fba897c54b Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Thu, 4 Jul 2024 16:33:18 +0200 Subject: [PATCH 26/36] rename buffer and entrypoint test-classes --- src/zarr/abc/codec.py | 12 +++---- src/zarr/config.py | 2 +- src/zarr/testing/buffer.py | 16 ++++----- .../entry_points.txt | 8 ++--- tests/v3/package_with_entrypoint/__init__.py | 8 ++--- tests/v3/test_buffer.py | 19 +++++++---- tests/v3/test_codec_entrypoints.py | 14 ++++---- tests/v3/test_config.py | 34 +++++++++---------- 8 files changed, 56 insertions(+), 57 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index db4deb0807..9223019fab 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -11,7 +11,7 @@ from zarr.buffer import Buffer, NDBuffer from zarr.chunk_grids import ChunkGrid from zarr.common import ChunkCoords, concurrent_map -from zarr.config import Config, config +from zarr.config import config if TYPE_CHECKING: from typing_extensions import Self @@ -24,10 +24,6 @@ CodecOutput = TypeVar("CodecOutput", bound=NDBuffer | Buffer) -def get_config() -> Config: - return config - - class _Codec(Generic[CodecInput, CodecOutput], Metadata): """Generic base class for codecs. Please use ArrayArrayCodec, ArrayBytesCodec or BytesBytesCodec for subclassing. @@ -194,7 +190,7 @@ async def decode_partial( return await concurrent_map( list(batch_info), self._decode_partial_single, - get_config().get("async.concurrency"), + config.get("async.concurrency"), ) @@ -231,7 +227,7 @@ async def encode_partial( await concurrent_map( list(batch_info), self._encode_partial_single, - get_config().get("async.concurrency"), + config.get("async.concurrency"), ) @@ -412,7 +408,7 @@ async def batching_helper( return await concurrent_map( list(batch_info), noop_for_none(func), - get_config().get("async.concurrency"), + config.get("async.concurrency"), ) diff --git a/src/zarr/config.py b/src/zarr/config.py index fc7afaa4b3..214dfc794c 100644 --- a/src/zarr/config.py +++ b/src/zarr/config.py @@ -13,7 +13,7 @@ class Config(DConfig): # type: ignore[misc] """Will collect configuration from config files and environment variables Example environment variables: - Grabs environment variables of the form "DASK_FOO__BAR_BAZ=123" and + Grabs environment variables of the form "ZARR_PYTHON_FOO__BAR_BAZ=123" and turns these into config variables of the form ``{"foo": {"bar-baz": 123}}`` It transforms the key and value in the following way: diff --git a/src/zarr/testing/buffer.py b/src/zarr/testing/buffer.py index 73dda61714..d2da1c5a6e 100644 --- a/src/zarr/testing/buffer.py +++ b/src/zarr/testing/buffer.py @@ -14,15 +14,15 @@ from typing_extensions import Self -class MyNDArrayLike(np.ndarray): +class TestNDArrayLike(np.ndarray): """An example of a ndarray-like class""" -class MyBuffer(Buffer): +class TestBuffer(Buffer): """Example of a custom Buffer that handles ArrayLike""" -class MyNDBuffer(NDBuffer): +class NDBufferUsingTestNDArrayLike(NDBuffer): """Example of a custom NDBuffer that handles MyNDArrayLike""" @classmethod @@ -34,14 +34,14 @@ def create( order: Literal["C", "F"] = "C", fill_value: Any | None = None, ) -> Self: - """Overwrite `NDBuffer.create` to create an MyNDArrayLike instance""" - ret = cls(MyNDArrayLike(shape=shape, dtype=dtype, order=order)) + """Overwrite `NDBuffer.create` to create an TestNDArrayLike instance""" + ret = cls(TestNDArrayLike(shape=shape, dtype=dtype, order=order)) if fill_value is not None: ret.fill(fill_value) return ret -class StoreExpectingMyBuffer(MemoryStore): +class StoreExpectingTestBuffer(MemoryStore): """Example of a custom Store that expect MyBuffer for all its non-metadata We assume that keys containing "json" is metadata @@ -49,7 +49,7 @@ class StoreExpectingMyBuffer(MemoryStore): async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: if "json" not in key: - assert isinstance(value, MyBuffer) + assert isinstance(value, TestBuffer) await super().set(key, value, byte_range) async def get( @@ -59,7 +59,7 @@ async def get( byte_range: tuple[int, int | None] | None = None, ) -> Buffer | None: if "json" not in key: - assert prototype.buffer is MyBuffer + assert prototype.buffer is TestBuffer ret = await super().get(key=key, prototype=prototype, byte_range=byte_range) if ret is not None: assert isinstance(ret, prototype.buffer) diff --git a/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt b/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt index 8457b61d9d..cab9aedd7a 100644 --- a/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt +++ b/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt @@ -1,6 +1,6 @@ [zarr.codecs] -test = package_with_entrypoint:TestCodec +test = package_with_entrypoint:TestEntrypointCodec [zarr] -codec_pipeline = package_with_entrypoint:TestCodecPipeline -buffer = package_with_entrypoint:TestBuffer -ndbuffer = package_with_entrypoint:TestNDBuffer \ No newline at end of file +codec_pipeline = package_with_entrypoint:TestEntrypointCodecPipeline +buffer = package_with_entrypoint:TestEntrypointBuffer +ndbuffer = package_with_entrypoint:TestEntrypointNDBuffer \ No newline at end of file diff --git a/tests/v3/package_with_entrypoint/__init__.py b/tests/v3/package_with_entrypoint/__init__.py index cbd70470f1..0b3adf508c 100644 --- a/tests/v3/package_with_entrypoint/__init__.py +++ b/tests/v3/package_with_entrypoint/__init__.py @@ -8,7 +8,7 @@ from zarr.common import BytesLike -class TestCodec(ArrayBytesCodec): +class TestEntrypointCodec(ArrayBytesCodec): is_fixed_size = True async def encode( @@ -27,7 +27,7 @@ def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> return input_byte_length -class TestCodecPipeline(CodecPipeline): +class TestEntrypointCodecPipeline(CodecPipeline): def __init__(self, batch_size: int = 1): pass @@ -42,9 +42,9 @@ async def decode( pass -class TestBuffer(Buffer): +class TestEntrypointBuffer(Buffer): pass -class TestNDBuffer(NDBuffer): +class TestEntrypointNDBuffer(NDBuffer): pass diff --git a/tests/v3/test_buffer.py b/tests/v3/test_buffer.py index c1a4904525..d53e98d42d 100644 --- a/tests/v3/test_buffer.py +++ b/tests/v3/test_buffer.py @@ -12,7 +12,12 @@ from zarr.codecs.transpose import TransposeCodec from zarr.codecs.zstd import ZstdCodec from zarr.store.core import StorePath -from zarr.testing.buffer import MyBuffer, MyNDArrayLike, MyNDBuffer, StoreExpectingMyBuffer +from zarr.testing.buffer import ( + NDBufferUsingTestNDArrayLike, + StoreExpectingTestBuffer, + TestBuffer, + TestNDArrayLike, +) def test_nd_array_like(xp): @@ -27,7 +32,7 @@ async def test_async_array_prototype(): expect = np.zeros((9, 9), dtype="uint16", order="F") a = await AsyncArray.create( - StorePath(StoreExpectingMyBuffer(mode="w")) / "test_async_array_prototype", + StorePath(StoreExpectingTestBuffer(mode="w")) / "test_async_array_prototype", shape=expect.shape, chunk_shape=(5, 5), dtype=expect.dtype, @@ -35,7 +40,7 @@ async def test_async_array_prototype(): ) expect[1:4, 3:6] = np.ones((3, 3)) - my_prototype = BufferPrototype(buffer=MyBuffer, nd_buffer=MyNDBuffer) + my_prototype = BufferPrototype(buffer=TestBuffer, nd_buffer=NDBufferUsingTestNDArrayLike) await a.setitem( selection=(slice(1, 4), slice(3, 6)), @@ -43,7 +48,7 @@ async def test_async_array_prototype(): prototype=my_prototype, ) got = await a.getitem(selection=(slice(0, 9), slice(0, 9)), prototype=my_prototype) - assert isinstance(got, MyNDArrayLike) + assert isinstance(got, TestNDArrayLike) assert np.array_equal(expect, got) @@ -51,7 +56,7 @@ async def test_async_array_prototype(): async def test_codecs_use_of_prototype(): expect = np.zeros((10, 10), dtype="uint16", order="F") a = await AsyncArray.create( - StorePath(StoreExpectingMyBuffer(mode="w")) / "test_codecs_use_of_prototype", + StorePath(StoreExpectingTestBuffer(mode="w")) / "test_codecs_use_of_prototype", shape=expect.shape, chunk_shape=(5, 5), dtype=expect.dtype, @@ -67,7 +72,7 @@ async def test_codecs_use_of_prototype(): ) expect[:] = np.arange(100).reshape(10, 10) - my_prototype = BufferPrototype(buffer=MyBuffer, nd_buffer=MyNDBuffer) + my_prototype = BufferPrototype(buffer=TestBuffer, nd_buffer=NDBufferUsingTestNDArrayLike) await a.setitem( selection=(slice(0, 10), slice(0, 10)), @@ -75,7 +80,7 @@ async def test_codecs_use_of_prototype(): prototype=my_prototype, ) got = await a.getitem(selection=(slice(0, 10), slice(0, 10)), prototype=my_prototype) - assert isinstance(got, MyNDArrayLike) + assert isinstance(got, TestNDArrayLike) assert np.array_equal(expect, got) diff --git a/tests/v3/test_codec_entrypoints.py b/tests/v3/test_codec_entrypoints.py index 280dfae2b4..539a188706 100644 --- a/tests/v3/test_codec_entrypoints.py +++ b/tests/v3/test_codec_entrypoints.py @@ -23,20 +23,20 @@ def set_path(): @pytest.mark.usefixtures("set_path") def test_entrypoint_codec(): - config.set({"codecs.test.name": "TestCodec"}) + config.set({"codecs.test.name": "TestEntrypointCodec"}) cls = zarr.registry.get_codec_class("test") - assert cls.__name__ == "TestCodec" + assert cls.__name__ == "TestEntrypointCodec" @pytest.mark.usefixtures("set_path") def test_entrypoint_pipeline(): - config.set({"codec_pipeline.name": "TestCodecPipeline"}) + config.set({"codec_pipeline.name": "TestEntrypointCodecPipeline"}) cls = zarr.registry.get_pipeline_class() - assert cls.__name__ == "TestCodecPipeline" + assert cls.__name__ == "TestEntrypointCodecPipeline" @pytest.mark.usefixtures("set_path") def test_entrypoint_buffer(): - config.set({"buffer.name": "TestBuffer", "ndbuffer.name": "TestNDBuffer"}) - assert zarr.registry.get_buffer_class().__name__ == "TestBuffer" - assert zarr.registry.get_ndbuffer_class().__name__ == "TestNDBuffer" + config.set({"buffer.name": "TestEntrypointBuffer", "ndbuffer.name": "TestEntrypointNDBuffer"}) + assert zarr.registry.get_buffer_class().__name__ == "TestEntrypointBuffer" + assert zarr.registry.get_ndbuffer_class().__name__ == "TestEntrypointNDBuffer" diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index bc762578ea..ffd6bdaa23 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -25,14 +25,12 @@ register_ndbuffer, register_pipeline, ) -from zarr.testing.buffer import MyBuffer, MyNDArrayLike, MyNDBuffer, StoreExpectingMyBuffer - - -@pytest.fixture() -def reset_config(): - config.reset() - yield - config.reset() +from zarr.testing.buffer import ( + NDBufferUsingTestNDArrayLike, + StoreExpectingTestBuffer, + TestBuffer, + TestNDArrayLike, +) def test_config_defaults_set() -> None: @@ -161,9 +159,9 @@ def test_config_ndbuffer_implementation(store): assert get_ndbuffer_class().__name__ == config.defaults[0]["ndbuffer"]["name"] # set custom ndbuffer with MyNDArrayLike implementation - register_ndbuffer(MyNDBuffer) - config.set({"ndbuffer.name": "MyNDBuffer"}) - assert get_ndbuffer_class() == MyNDBuffer + register_ndbuffer(NDBufferUsingTestNDArrayLike) + config.set({"ndbuffer.name": "NDBufferUsingTestNDArrayLike"}) + assert get_ndbuffer_class() == NDBufferUsingTestNDArrayLike arr = Array.create( store=store, shape=(100,), @@ -173,22 +171,22 @@ def test_config_ndbuffer_implementation(store): ) got = arr[:] print(type(got)) - assert isinstance(got, MyNDArrayLike) + assert isinstance(got, TestNDArrayLike) def test_config_buffer_implementation(): # has default value assert get_buffer_class().__name__ == config.defaults[0]["buffer"]["name"] - arr = zeros(shape=(100), store=StoreExpectingMyBuffer(mode="w")) + arr = zeros(shape=(100), store=StoreExpectingTestBuffer(mode="w")) # AssertionError of StoreExpectingMyBuffer when not using my buffer with pytest.raises(AssertionError): arr[:] = np.arange(100) - register_buffer(MyBuffer) - config.set({"buffer.name": "MyBuffer"}) - assert get_buffer_class() == MyBuffer + register_buffer(TestBuffer) + config.set({"buffer.name": "TestBuffer"}) + assert get_buffer_class() == TestBuffer # no error using MyBuffer data = np.arange(100) @@ -198,7 +196,7 @@ def test_config_buffer_implementation(): data2d = np.arange(1000).reshape(100, 10) arr_sharding = zeros( shape=(100, 10), - store=StoreExpectingMyBuffer(mode="w"), + store=StoreExpectingTestBuffer(mode="w"), codecs=[ShardingCodec(chunk_shape=(10, 10))], ) arr_sharding[:] = data2d @@ -206,7 +204,7 @@ def test_config_buffer_implementation(): arr_Crc32c = zeros( shape=(100, 10), - store=StoreExpectingMyBuffer(mode="w"), + store=StoreExpectingTestBuffer(mode="w"), codecs=[BytesCodec(), Crc32cCodec()], ) arr_Crc32c[:] = data2d From 96271571c82cc87c09907318dbfc33215a282da0 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Thu, 4 Jul 2024 17:37:39 +0200 Subject: [PATCH 27/36] document interaction registry and config --- src/zarr/config.py | 11 +++++++++++ src/zarr/registry.py | 7 +++++++ 2 files changed, 18 insertions(+) diff --git a/src/zarr/config.py b/src/zarr/config.py index 214dfc794c..fdc91991f9 100644 --- a/src/zarr/config.py +++ b/src/zarr/config.py @@ -28,6 +28,17 @@ def reset(self) -> None: self.refresh() +""" +The config module is responsible for managing the configuration of zarr and is based on the Donfig python library. +For selecting custom implementations of codecs, pipelines, buffers and ndbuffers, first register the implementations +in the registry and then select them in the config. +e.g. an implementation of the bytes codec in a class "NewBytesCodec", requires the value of codecs.bytes.name to be +"NewBytesCodec". +Donfig can be configured programmatically, by environment variables, or from YAML files in standard locations +e.g. export ZARR_PYTHON_CODECS__BYTES__NAME="NewBytesCodec" +(for more information see github.com/pytroll/donfig) +Default values below point to the standard implementations of zarr-python +""" config = Config( "zarr_python", defaults=[ diff --git a/src/zarr/registry.py b/src/zarr/registry.py index 7906ffb699..56fd22e17e 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -21,10 +21,17 @@ __ndbuffer_registry: dict[str, type[NDBuffer]] = {} __lazy_load_ndbuffer: list[EntryPoint] = [] +""" +The registry module is responsible for managing implementations of codecs, pipelines, buffers and ndbuffers and +collecting them from entrypoints. +The implementation used is determined by the config +""" + def _collect_entrypoints() -> ( tuple[dict[str, EntryPoint], list[EntryPoint], list[EntryPoint], list[EntryPoint]] ): + """Collects codecs, pipelines, buffers and ndbuffers from entrypoints""" entry_points = get_entry_points() for e in entry_points.select(group="zarr.codecs"): __lazy_load_codecs[e.name] = e From 5e8300206389317120967e817288df0241c3223f Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Mon, 8 Jul 2024 13:46:16 +0200 Subject: [PATCH 28/36] change config prefix from zarr_python to zarr --- src/zarr/config.py | 6 +++--- tests/v3/test_config.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/zarr/config.py b/src/zarr/config.py index fdc91991f9..44c62ec47d 100644 --- a/src/zarr/config.py +++ b/src/zarr/config.py @@ -13,7 +13,7 @@ class Config(DConfig): # type: ignore[misc] """Will collect configuration from config files and environment variables Example environment variables: - Grabs environment variables of the form "ZARR_PYTHON_FOO__BAR_BAZ=123" and + Grabs environment variables of the form "ZARR_FOO__BAR_BAZ=123" and turns these into config variables of the form ``{"foo": {"bar-baz": 123}}`` It transforms the key and value in the following way: @@ -35,12 +35,12 @@ def reset(self) -> None: e.g. an implementation of the bytes codec in a class "NewBytesCodec", requires the value of codecs.bytes.name to be "NewBytesCodec". Donfig can be configured programmatically, by environment variables, or from YAML files in standard locations -e.g. export ZARR_PYTHON_CODECS__BYTES__NAME="NewBytesCodec" +e.g. export ZARR_CODECS__BYTES__NAME="NewBytesCodec" (for more information see github.com/pytroll/donfig) Default values below point to the standard implementations of zarr-python """ config = Config( - "zarr_python", + "zarr", defaults=[ { "array": {"order": "C"}, diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index ffd6bdaa23..ebe5a1ebd0 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -116,7 +116,7 @@ class MockEnvCodecPipeline(CodecPipeline): register_pipeline(MockEnvCodecPipeline) - with mock.patch.dict(os.environ, {"ZARR_PYTHON_CODEC_PIPELINE__NAME": "MockEnvCodecPipeline"}): + with mock.patch.dict(os.environ, {"ZARR_CODEC_PIPELINE__NAME": "MockEnvCodecPipeline"}): assert get_pipeline_class(reload_config=True) == MockEnvCodecPipeline @@ -149,7 +149,7 @@ async def _encode_single( arr[:] = range(100) _mock.call.assert_called() - with mock.patch.dict(os.environ, {"ZARR_PYTHON_CODECS__BLOSC__NAME": "BloscCodec"}): + with mock.patch.dict(os.environ, {"ZARR_CODECS__BLOSC__NAME": "BloscCodec"}): assert get_codec_class("blosc", reload_config=True) == BloscCodec From cc5f93c2e75a46df7e6f6f94999fa92c99cfa228 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Mon, 8 Jul 2024 16:19:27 +0200 Subject: [PATCH 29/36] use fully_qualified_name for implementation config --- src/zarr/config.py | 25 ++++++----- src/zarr/registry.py | 37 +++++++++------- tests/v3/test_codec_entrypoints.py | 11 +++-- tests/v3/test_config.py | 69 +++++++++++++++++++----------- 4 files changed, 86 insertions(+), 56 deletions(-) diff --git a/src/zarr/config.py b/src/zarr/config.py index 44c62ec47d..ec78747a6b 100644 --- a/src/zarr/config.py +++ b/src/zarr/config.py @@ -46,19 +46,22 @@ def reset(self) -> None: "array": {"order": "C"}, "async": {"concurrency": None, "timeout": None}, "json_indent": 2, - "codec_pipeline": {"name": "BatchedCodecPipeline", "batch_size": 1}, + "codec_pipeline": { + "path": "zarr.codecs.pipeline.BatchedCodecPipeline", + "batch_size": 1, + }, "codecs": { - "blosc": {"name": "BloscCodec"}, - "gzip": {"name": "GzipCodec"}, - "zstd": {"name": "ZstdCodec"}, - "bytes": {"name": "BytesCodec"}, - "endian": {"name": "BytesCodec"}, # compatibility with earlier versions of ZEP1 - "crc32c": {"name": "Crc32cCodec"}, - "sharding_indexed": {"name": "ShardingCodec"}, - "transpose": {"name": "TransposeCodec"}, + "blosc": "zarr.codecs.blosc.BloscCodec", + "gzip": "zarr.codecs.gzip.GzipCodec", + "zstd": "zarr.codecs.zstd.ZstdCodec", + "bytes": "zarr.codecs.bytes.BytesCodec", + "endian": "zarr.codecs.bytes.BytesCodec", # compatibility with earlier versions of ZEP1 + "crc32c": "zarr.codecs.crc32c_.Crc32cCodec", + "sharding_indexed": "zarr.codecs.sharding.ShardingCodec", + "transpose": "zarr.codecs.transpose.TransposeCodec", }, - "buffer": {"name": "Buffer"}, - "ndbuffer": {"name": "NDBuffer"}, + "buffer": "zarr.buffer.Buffer", + "ndbuffer": "zarr.buffer.NDBuffer", } ], ) diff --git a/src/zarr/registry.py b/src/zarr/registry.py index 56fd22e17e..4ffb18d07b 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -50,22 +50,27 @@ def _reload_config() -> None: config.refresh() +def fully_qualified_name(cls: type) -> str: + module = cls.__module__ + return module + "." + cls.__qualname__ + + def register_codec(key: str, codec_cls: type[Codec]) -> None: registered_codecs = __codec_registry.get(key, {}) - registered_codecs[codec_cls.__name__] = codec_cls + registered_codecs[fully_qualified_name(codec_cls)] = codec_cls __codec_registry[key] = registered_codecs def register_pipeline(pipe_cls: type[CodecPipeline]) -> None: - __pipeline_registry[pipe_cls.__name__] = pipe_cls + __pipeline_registry[fully_qualified_name(pipe_cls)] = pipe_cls def register_ndbuffer(cls: type[NDBuffer]) -> None: - __ndbuffer_registry[cls.__name__] = cls + __ndbuffer_registry[fully_qualified_name(cls)] = cls def register_buffer(cls: type[Buffer]) -> None: - __buffer_registry[cls.__name__] = cls + __buffer_registry[fully_qualified_name(cls)] = cls def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: @@ -87,9 +92,9 @@ def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: f"Codec '{key}' not configured in config. Selecting any implementation.", stacklevel=2 ) return list(codec_classes.values())[-1] - - name = config_entry.get("name") - selected_codec_cls = codec_classes[name] + print(f"{codec_classes=}") + print(f"{config_entry=}") + selected_codec_cls = codec_classes[config_entry] if selected_codec_cls: return selected_codec_cls @@ -102,12 +107,12 @@ def get_pipeline_class(reload_config: bool = False) -> type[CodecPipeline]: for e in __lazy_load_pipelines: __lazy_load_pipelines.remove(e) register_pipeline(e.load()) - name = config.get("codec_pipeline.name") - pipeline_class = __pipeline_registry.get(name) + path = config.get("codec_pipeline.path") + pipeline_class = __pipeline_registry.get(path) if pipeline_class: return pipeline_class raise BadConfigError( - f"Pipeline class '{name}' not found in registered pipelines: {list(__pipeline_registry.keys())}." + f"Pipeline class '{path}' not found in registered pipelines: {list(__pipeline_registry.keys())}." ) @@ -117,12 +122,12 @@ def get_buffer_class(reload_config: bool = False) -> type[Buffer]: for e in __lazy_load_buffer: __lazy_load_buffer.remove(e) register_buffer(e.load()) - name = config.get("buffer.name") - buffer_class = __buffer_registry.get(name) + path = config.get("buffer") + buffer_class = __buffer_registry.get(path) if buffer_class: return buffer_class raise BadConfigError( - f"Buffer class '{name}' not found in registered buffers: {list(__buffer_registry.keys())}." + f"Buffer class '{path}' not found in registered buffers: {list(__buffer_registry.keys())}." ) @@ -132,12 +137,12 @@ def get_ndbuffer_class(reload_config: bool = False) -> type[NDBuffer]: for e in __lazy_load_ndbuffer: __lazy_load_ndbuffer.remove(e) register_ndbuffer(e.load()) - name = config.get("ndbuffer.name") - ndbuffer_class = __ndbuffer_registry.get(name) + path = config.get("ndbuffer") + ndbuffer_class = __ndbuffer_registry.get(path) if ndbuffer_class: return ndbuffer_class raise BadConfigError( - f"NDBuffer class '{name}' not found in registered buffers: {list(__ndbuffer_registry.keys())}." + f"NDBuffer class '{path}' not found in registered buffers: {list(__ndbuffer_registry.keys())}." ) diff --git a/tests/v3/test_codec_entrypoints.py b/tests/v3/test_codec_entrypoints.py index 539a188706..af4193db17 100644 --- a/tests/v3/test_codec_entrypoints.py +++ b/tests/v3/test_codec_entrypoints.py @@ -23,20 +23,25 @@ def set_path(): @pytest.mark.usefixtures("set_path") def test_entrypoint_codec(): - config.set({"codecs.test.name": "TestEntrypointCodec"}) + config.set({"codecs.test": "package_with_entrypoint.TestEntrypointCodec"}) cls = zarr.registry.get_codec_class("test") assert cls.__name__ == "TestEntrypointCodec" @pytest.mark.usefixtures("set_path") def test_entrypoint_pipeline(): - config.set({"codec_pipeline.name": "TestEntrypointCodecPipeline"}) + config.set({"codec_pipeline.path": "package_with_entrypoint.TestEntrypointCodecPipeline"}) cls = zarr.registry.get_pipeline_class() assert cls.__name__ == "TestEntrypointCodecPipeline" @pytest.mark.usefixtures("set_path") def test_entrypoint_buffer(): - config.set({"buffer.name": "TestEntrypointBuffer", "ndbuffer.name": "TestEntrypointNDBuffer"}) + config.set( + { + "buffer": "package_with_entrypoint.TestEntrypointBuffer", + "ndbuffer": "package_with_entrypoint.TestEntrypointNDBuffer", + } + ) assert zarr.registry.get_buffer_class().__name__ == "TestEntrypointBuffer" assert zarr.registry.get_ndbuffer_class().__name__ == "TestEntrypointNDBuffer" diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index ebe5a1ebd0..8e7b868520 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -7,6 +7,7 @@ import numpy as np import pytest +import zarr from zarr import Array, zeros from zarr.abc.codec import CodecInput, CodecOutput, CodecPipeline from zarr.abc.store import ByteSetter @@ -16,6 +17,7 @@ from zarr.config import BadConfigError, config from zarr.indexing import SelectorTuple from zarr.registry import ( + fully_qualified_name, get_buffer_class, get_codec_class, get_ndbuffer_class, @@ -40,18 +42,21 @@ def test_config_defaults_set() -> None: "array": {"order": "C"}, "async": {"concurrency": None, "timeout": None}, "json_indent": 2, - "codec_pipeline": {"name": "BatchedCodecPipeline", "batch_size": 1}, - "buffer": {"name": "Buffer"}, - "ndbuffer": {"name": "NDBuffer"}, + "codec_pipeline": { + "path": "zarr.codecs.pipeline.BatchedCodecPipeline", + "batch_size": 1, + }, + "buffer": "zarr.buffer.Buffer", + "ndbuffer": "zarr.buffer.NDBuffer", "codecs": { - "blosc": {"name": "BloscCodec"}, - "gzip": {"name": "GzipCodec"}, - "zstd": {"name": "ZstdCodec"}, - "bytes": {"name": "BytesCodec"}, - "endian": {"name": "BytesCodec"}, # compatibility with earlier versions of ZEP1 - "crc32c": {"name": "Crc32cCodec"}, - "sharding_indexed": {"name": "ShardingCodec"}, - "transpose": {"name": "TransposeCodec"}, + "blosc": "zarr.codecs.blosc.BloscCodec", + "gzip": "zarr.codecs.gzip.GzipCodec", + "zstd": "zarr.codecs.zstd.ZstdCodec", + "bytes": "zarr.codecs.bytes.BytesCodec", + "endian": "zarr.codecs.bytes.BytesCodec", + "crc32c": "zarr.codecs.crc32c_.Crc32cCodec", + "sharding_indexed": "zarr.codecs.sharding.ShardingCodec", + "transpose": "zarr.codecs.transpose.TransposeCodec", }, } ] @@ -72,13 +77,22 @@ def test_config_defaults_can_be_overridden(key: str, old_val: Any, new_val: Any) assert config.get(key) == new_val +def test_fully_qualified_name(): + class MockClass: + pass + + assert "v3.test_config.test_fully_qualified_name..MockClass" == fully_qualified_name( + MockClass + ) + + @pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) def test_config_codec_pipeline_class(store): # has default value assert get_pipeline_class().__name__ != "" - config.set({"codec_pipeline.name": "BatchedCodecPipeline"}) - assert get_pipeline_class() == BatchedCodecPipeline + config.set({"codec_pipeline.name": "zarr.codecs.pipeline.BatchedCodecPipeline"}) + assert get_pipeline_class() == zarr.codecs.pipeline.BatchedCodecPipeline _mock = Mock() @@ -92,7 +106,8 @@ async def write( _mock.call() register_pipeline(MockCodecPipeline) - config.set({"codec_pipeline.name": "MockCodecPipeline"}) + config.set({"codec_pipeline.path": fully_qualified_name(MockCodecPipeline)}) + assert get_pipeline_class() == MockCodecPipeline # test if codec is used @@ -108,7 +123,7 @@ async def write( _mock.call.assert_called() with pytest.raises(BadConfigError): - config.set({"codec_pipeline.name": "wrong_name"}) + config.set({"codec_pipeline.path": "wrong_name"}) get_pipeline_class() class MockEnvCodecPipeline(CodecPipeline): @@ -116,14 +131,16 @@ class MockEnvCodecPipeline(CodecPipeline): register_pipeline(MockEnvCodecPipeline) - with mock.patch.dict(os.environ, {"ZARR_CODEC_PIPELINE__NAME": "MockEnvCodecPipeline"}): + with mock.patch.dict( + os.environ, {"ZARR_CODEC_PIPELINE__PATH": fully_qualified_name(MockEnvCodecPipeline)} + ): assert get_pipeline_class(reload_config=True) == MockEnvCodecPipeline @pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) def test_config_codec_implementation(store): # has default value - assert get_codec_class("blosc").__name__ == config.defaults[0]["codecs"]["blosc"]["name"] + assert fully_qualified_name(get_codec_class("blosc")) == config.defaults[0]["codecs"]["blosc"] _mock = Mock() @@ -133,7 +150,7 @@ async def _encode_single( ) -> CodecOutput | None: _mock.call() - config.set({"codecs.blosc.name": "MockBloscCodec"}) + config.set({"codecs.blosc": fully_qualified_name(MockBloscCodec)}) register_codec("blosc", MockBloscCodec) assert get_codec_class("blosc") == MockBloscCodec @@ -149,18 +166,18 @@ async def _encode_single( arr[:] = range(100) _mock.call.assert_called() - with mock.patch.dict(os.environ, {"ZARR_CODECS__BLOSC__NAME": "BloscCodec"}): + with mock.patch.dict(os.environ, {"ZARR_CODECS__BLOSC": fully_qualified_name(BloscCodec)}): assert get_codec_class("blosc", reload_config=True) == BloscCodec @pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) def test_config_ndbuffer_implementation(store): # has default value - assert get_ndbuffer_class().__name__ == config.defaults[0]["ndbuffer"]["name"] + assert fully_qualified_name(get_ndbuffer_class()) == config.defaults[0]["ndbuffer"] - # set custom ndbuffer with MyNDArrayLike implementation + # set custom ndbuffer with TestNDArrayLike implementation register_ndbuffer(NDBufferUsingTestNDArrayLike) - config.set({"ndbuffer.name": "NDBufferUsingTestNDArrayLike"}) + config.set({"ndbuffer": fully_qualified_name(NDBufferUsingTestNDArrayLike)}) assert get_ndbuffer_class() == NDBufferUsingTestNDArrayLike arr = Array.create( store=store, @@ -176,19 +193,19 @@ def test_config_ndbuffer_implementation(store): def test_config_buffer_implementation(): # has default value - assert get_buffer_class().__name__ == config.defaults[0]["buffer"]["name"] + assert fully_qualified_name(get_buffer_class()) == config.defaults[0]["buffer"] arr = zeros(shape=(100), store=StoreExpectingTestBuffer(mode="w")) - # AssertionError of StoreExpectingMyBuffer when not using my buffer + # AssertionError of StoreExpectingTestBuffer when not using my buffer with pytest.raises(AssertionError): arr[:] = np.arange(100) register_buffer(TestBuffer) - config.set({"buffer.name": "TestBuffer"}) + config.set({"buffer": fully_qualified_name(TestBuffer)}) assert get_buffer_class() == TestBuffer - # no error using MyBuffer + # no error using TestBuffer data = np.arange(100) arr[:] = np.arange(100) assert np.array_equal(arr[:], data) From ae1023ce5112903b1a863a97e0a36c17efbeead3 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Mon, 8 Jul 2024 18:23:35 +0200 Subject: [PATCH 30/36] refactor registry dicts --- src/zarr/registry.py | 72 +++++++++++++++++------------- tests/v3/test_codec_entrypoints.py | 6 +-- 2 files changed, 43 insertions(+), 35 deletions(-) diff --git a/src/zarr/registry.py b/src/zarr/registry.py index 4ffb18d07b..001e9424c4 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -1,7 +1,8 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Generic, TypeVar if TYPE_CHECKING: from zarr.abc.codec import Codec, CodecPipeline @@ -12,14 +13,19 @@ from zarr.config import BadConfigError, config -__codec_registry: dict[str, dict[str, type[Codec]]] = {} -__lazy_load_codecs: dict[str, EntryPoint] = {} -__pipeline_registry: dict[str, type[CodecPipeline]] = {} -__lazy_load_pipelines: list[EntryPoint] = [] -__buffer_registry: dict[str, type[Buffer]] = {} -__lazy_load_buffer: list[EntryPoint] = [] -__ndbuffer_registry: dict[str, type[NDBuffer]] = {} -__lazy_load_ndbuffer: list[EntryPoint] = [] +T = TypeVar("T") + + +class Registry(Generic[T], dict[str, type[T]]): + def __init__(self): + super().__init__() + self.lazy_load_list: list[EntryPoint] = [] + + +__codec_registries: dict[str, Registry[Codec]] = defaultdict(Registry) +__pipeline_registry: Registry[CodecPipeline] = Registry() +__buffer_registry: Registry[Buffer] = Registry() +__ndbuffer_registry: Registry[NDBuffer] = Registry() """ The registry module is responsible for managing implementations of codecs, pipelines, buffers and ndbuffers and @@ -28,22 +34,25 @@ """ -def _collect_entrypoints() -> ( - tuple[dict[str, EntryPoint], list[EntryPoint], list[EntryPoint], list[EntryPoint]] -): +def _collect_entrypoints() -> list[Registry[Any]]: """Collects codecs, pipelines, buffers and ndbuffers from entrypoints""" entry_points = get_entry_points() for e in entry_points.select(group="zarr.codecs"): - __lazy_load_codecs[e.name] = e + __codec_registries[e.name].lazy_load_list.append(e) for e in entry_points.select(group="zarr"): if e.name == "codec_pipeline": - __lazy_load_pipelines.append(e) + __pipeline_registry.lazy_load_list.append(e) if e.name == "buffer": - __lazy_load_buffer.append(e) + __buffer_registry.lazy_load_list.append(e) if e.name == "ndbuffer": - __lazy_load_ndbuffer.append(e) + __ndbuffer_registry.lazy_load_list.append(e) - return __lazy_load_codecs, __lazy_load_pipelines, __lazy_load_buffer, __lazy_load_ndbuffer + return [ + *__codec_registries.values(), + __pipeline_registry, + __buffer_registry, + __ndbuffer_registry, + ] def _reload_config() -> None: @@ -56,9 +65,9 @@ def fully_qualified_name(cls: type) -> str: def register_codec(key: str, codec_cls: type[Codec]) -> None: - registered_codecs = __codec_registry.get(key, {}) - registered_codecs[fully_qualified_name(codec_cls)] = codec_cls - __codec_registry[key] = registered_codecs + registry = __codec_registries.get(key, Registry()) + registry[fully_qualified_name(codec_cls)] = codec_cls + __codec_registries[key] = registry def register_pipeline(pipe_cls: type[CodecPipeline]) -> None: @@ -77,12 +86,13 @@ def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: if reload_config: _reload_config() - if key in __lazy_load_codecs: + if key in __codec_registries: # logger.debug("Auto loading codec '%s' from entrypoint", codec_id) - cls = __lazy_load_codecs[key].load() - register_codec(key, cls) + for lazy_load_item in __codec_registries[key].lazy_load_list: + cls = lazy_load_item.load() + register_codec(key, cls) - codec_classes = __codec_registry[key] + codec_classes = __codec_registries[key] if not codec_classes: raise KeyError(key) @@ -92,8 +102,6 @@ def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: f"Codec '{key}' not configured in config. Selecting any implementation.", stacklevel=2 ) return list(codec_classes.values())[-1] - print(f"{codec_classes=}") - print(f"{config_entry=}") selected_codec_cls = codec_classes[config_entry] if selected_codec_cls: @@ -104,8 +112,8 @@ def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: def get_pipeline_class(reload_config: bool = False) -> type[CodecPipeline]: if reload_config: _reload_config() - for e in __lazy_load_pipelines: - __lazy_load_pipelines.remove(e) + for e in __pipeline_registry.lazy_load_list: + __pipeline_registry.lazy_load_list.remove(e) register_pipeline(e.load()) path = config.get("codec_pipeline.path") pipeline_class = __pipeline_registry.get(path) @@ -119,8 +127,8 @@ def get_pipeline_class(reload_config: bool = False) -> type[CodecPipeline]: def get_buffer_class(reload_config: bool = False) -> type[Buffer]: if reload_config: _reload_config() - for e in __lazy_load_buffer: - __lazy_load_buffer.remove(e) + for e in __buffer_registry.lazy_load_list: + __buffer_registry.lazy_load_list.remove(e) register_buffer(e.load()) path = config.get("buffer") buffer_class = __buffer_registry.get(path) @@ -134,8 +142,8 @@ def get_buffer_class(reload_config: bool = False) -> type[Buffer]: def get_ndbuffer_class(reload_config: bool = False) -> type[NDBuffer]: if reload_config: _reload_config() - for e in __lazy_load_ndbuffer: - __lazy_load_ndbuffer.remove(e) + for e in __ndbuffer_registry.lazy_load_list: + __ndbuffer_registry.lazy_load_list.remove(e) register_ndbuffer(e.load()) path = config.get("ndbuffer") ndbuffer_class = __ndbuffer_registry.get(path) diff --git a/tests/v3/test_codec_entrypoints.py b/tests/v3/test_codec_entrypoints.py index af4193db17..ff81019199 100644 --- a/tests/v3/test_codec_entrypoints.py +++ b/tests/v3/test_codec_entrypoints.py @@ -15,9 +15,9 @@ def set_path(): zarr.registry._collect_entrypoints() yield sys.path.remove(here) - lazy_load_lists = zarr.registry._collect_entrypoints() - for lazy_load_list in lazy_load_lists: - lazy_load_list.clear() + registries = zarr.registry._collect_entrypoints() + for registry in registries: + registry.lazy_load_list.clear() config.reset() From 2d89931904bf5eb3cb4d458d832046312d19af76 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Mon, 8 Jul 2024 18:23:56 +0200 Subject: [PATCH 31/36] fix default_buffer_prototype access in tests --- src/zarr/registry.py | 2 +- tests/v3/test_codecs/test_blosc.py | 4 +-- tests/v3/test_codecs/test_codecs.py | 38 +++++++++++++------------- tests/v3/test_codecs/test_endian.py | 4 +-- tests/v3/test_codecs/test_sharding.py | 4 +-- tests/v3/test_codecs/test_transpose.py | 4 +-- 6 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/zarr/registry.py b/src/zarr/registry.py index 001e9424c4..4e04aa607a 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -17,7 +17,7 @@ class Registry(Generic[T], dict[str, type[T]]): - def __init__(self): + def __init__(self) -> None: super().__init__() self.lazy_load_list: list[EntryPoint] = [] diff --git a/tests/v3/test_codecs/test_blosc.py b/tests/v3/test_codecs/test_blosc.py index 04c4c671c8..33ca9eba77 100644 --- a/tests/v3/test_codecs/test_blosc.py +++ b/tests/v3/test_codecs/test_blosc.py @@ -26,7 +26,7 @@ async def test_blosc_evolve(store: Store, dtype: str) -> None: ) zarr_json = json.loads( - (await store.get(f"{path}/zarr.json", prototype=default_buffer_prototype)).to_bytes() + (await store.get(f"{path}/zarr.json", prototype=default_buffer_prototype())).to_bytes() ) blosc_configuration_json = zarr_json["codecs"][1]["configuration"] assert blosc_configuration_json["typesize"] == typesize @@ -47,7 +47,7 @@ async def test_blosc_evolve(store: Store, dtype: str) -> None: ) zarr_json = json.loads( - (await store.get(f"{path2}/zarr.json", prototype=default_buffer_prototype)).to_bytes() + (await store.get(f"{path2}/zarr.json", prototype=default_buffer_prototype())).to_bytes() ) blosc_configuration_json = zarr_json["codecs"][0]["configuration"]["codecs"][1]["configuration"] assert blosc_configuration_json["typesize"] == typesize diff --git a/tests/v3/test_codecs/test_codecs.py b/tests/v3/test_codecs/test_codecs.py index 1104805d4b..a2b459f60d 100644 --- a/tests/v3/test_codecs/test_codecs.py +++ b/tests/v3/test_codecs/test_codecs.py @@ -127,7 +127,7 @@ async def test_order( ) z[:, :] = data assert_bytes_equal( - await store.get(f"{path}/0.0", prototype=default_buffer_prototype), z._store["0.0"] + await store.get(f"{path}/0.0", prototype=default_buffer_prototype()), z._store["0.0"] ) @@ -249,7 +249,7 @@ async def test_delete_empty_chunks(store: Store) -> None: await _AsyncArrayProxy(a)[:16, :16].set(np.zeros((16, 16))) await _AsyncArrayProxy(a)[:16, :16].set(data) assert np.array_equal(await _AsyncArrayProxy(a)[:16, :16].get(), data) - assert await store.get(f"{path}/c0/0", prototype=default_buffer_prototype) is None + assert await store.get(f"{path}/c0/0", prototype=default_buffer_prototype()) is None @pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) @@ -280,16 +280,16 @@ async def test_zarr_compat(store: Store) -> None: assert np.array_equal(data, z2[:16, :18]) assert_bytes_equal( - z2._store["0.0"], await store.get(f"{path}/0.0", prototype=default_buffer_prototype) + z2._store["0.0"], await store.get(f"{path}/0.0", prototype=default_buffer_prototype()) ) assert_bytes_equal( - z2._store["0.1"], await store.get(f"{path}/0.1", prototype=default_buffer_prototype) + z2._store["0.1"], await store.get(f"{path}/0.1", prototype=default_buffer_prototype()) ) assert_bytes_equal( - z2._store["1.0"], await store.get(f"{path}/1.0", prototype=default_buffer_prototype) + z2._store["1.0"], await store.get(f"{path}/1.0", prototype=default_buffer_prototype()) ) assert_bytes_equal( - z2._store["1.1"], await store.get(f"{path}/1.1", prototype=default_buffer_prototype) + z2._store["1.1"], await store.get(f"{path}/1.1", prototype=default_buffer_prototype()) ) @@ -323,16 +323,16 @@ async def test_zarr_compat_F(store: Store) -> None: assert np.array_equal(data, z2[:16, :18]) assert_bytes_equal( - z2._store["0.0"], await store.get(f"{path}/0.0", prototype=default_buffer_prototype) + z2._store["0.0"], await store.get(f"{path}/0.0", prototype=default_buffer_prototype()) ) assert_bytes_equal( - z2._store["0.1"], await store.get(f"{path}/0.1", prototype=default_buffer_prototype) + z2._store["0.1"], await store.get(f"{path}/0.1", prototype=default_buffer_prototype()) ) assert_bytes_equal( - z2._store["1.0"], await store.get(f"{path}/1.0", prototype=default_buffer_prototype) + z2._store["1.0"], await store.get(f"{path}/1.0", prototype=default_buffer_prototype()) ) assert_bytes_equal( - z2._store["1.1"], await store.get(f"{path}/1.1", prototype=default_buffer_prototype) + z2._store["1.1"], await store.get(f"{path}/1.1", prototype=default_buffer_prototype()) ) @@ -365,7 +365,7 @@ async def test_dimension_names(store: Store) -> None: ) assert (await AsyncArray.open(spath2)).metadata.dimension_names is None - zarr_json_buffer = await store.get(f"{path2}/zarr.json", prototype=default_buffer_prototype) + zarr_json_buffer = await store.get(f"{path2}/zarr.json", prototype=default_buffer_prototype()) assert zarr_json_buffer is not None assert "dimension_names" not in json.loads(zarr_json_buffer.to_bytes()) @@ -473,14 +473,14 @@ async def test_resize(store: Store) -> None: ) await _AsyncArrayProxy(a)[:16, :18].set(data) - assert await store.get(f"{path}/1.1", prototype=default_buffer_prototype) is not None - assert await store.get(f"{path}/0.0", prototype=default_buffer_prototype) is not None - assert await store.get(f"{path}/0.1", prototype=default_buffer_prototype) is not None - assert await store.get(f"{path}/1.0", prototype=default_buffer_prototype) is not None + assert await store.get(f"{path}/1.1", prototype=default_buffer_prototype()) is not None + assert await store.get(f"{path}/0.0", prototype=default_buffer_prototype()) is not None + assert await store.get(f"{path}/0.1", prototype=default_buffer_prototype()) is not None + assert await store.get(f"{path}/1.0", prototype=default_buffer_prototype()) is not None a = await a.resize((10, 12)) assert a.metadata.shape == (10, 12) - assert await store.get(f"{path}/0.0", prototype=default_buffer_prototype) is not None - assert await store.get(f"{path}/0.1", prototype=default_buffer_prototype) is not None - assert await store.get(f"{path}/1.0", prototype=default_buffer_prototype) is None - assert await store.get(f"{path}/1.1", prototype=default_buffer_prototype) is None + assert await store.get(f"{path}/0.0", prototype=default_buffer_prototype()) is not None + assert await store.get(f"{path}/0.1", prototype=default_buffer_prototype()) is not None + assert await store.get(f"{path}/1.0", prototype=default_buffer_prototype()) is None + assert await store.get(f"{path}/1.1", prototype=default_buffer_prototype()) is None diff --git a/tests/v3/test_codecs/test_endian.py b/tests/v3/test_codecs/test_endian.py index 8301a424b9..6f3e1c9482 100644 --- a/tests/v3/test_codecs/test_endian.py +++ b/tests/v3/test_codecs/test_endian.py @@ -44,7 +44,7 @@ async def test_endian(store: Store, endian: Literal["big", "little"]) -> None: ) z[:, :] = data assert_bytes_equal( - await store.get(f"{path}/0.0", prototype=default_buffer_prototype), z._store["0.0"] + await store.get(f"{path}/0.0", prototype=default_buffer_prototype()), z._store["0.0"] ) @@ -83,5 +83,5 @@ async def test_endian_write( ) z[:, :] = data assert_bytes_equal( - await store.get(f"{path}/0.0", prototype=default_buffer_prototype), z._store["0.0"] + await store.get(f"{path}/0.0", prototype=default_buffer_prototype()), z._store["0.0"] ) diff --git a/tests/v3/test_codecs/test_sharding.py b/tests/v3/test_codecs/test_sharding.py index f0031349cb..27667ca9dd 100644 --- a/tests/v3/test_codecs/test_sharding.py +++ b/tests/v3/test_codecs/test_sharding.py @@ -314,8 +314,8 @@ async def test_delete_empty_shards(store: Store) -> None: data = np.ones((16, 16), dtype="uint16") data[:8, :8] = 0 assert np.array_equal(data, await _AsyncArrayProxy(a)[:, :].get()) - assert await store.get(f"{path}/c/1/0", prototype=default_buffer_prototype) is None - chunk_bytes = await store.get(f"{path}/c/0/0", prototype=default_buffer_prototype) + assert await store.get(f"{path}/c/1/0", prototype=default_buffer_prototype()) is None + chunk_bytes = await store.get(f"{path}/c/0/0", prototype=default_buffer_prototype()) assert chunk_bytes is not None and len(chunk_bytes) == 16 * 2 + 8 * 8 * 2 + 4 diff --git a/tests/v3/test_codecs/test_transpose.py b/tests/v3/test_codecs/test_transpose.py index 3fd4350299..bea7435122 100644 --- a/tests/v3/test_codecs/test_transpose.py +++ b/tests/v3/test_codecs/test_transpose.py @@ -79,8 +79,8 @@ async def test_transpose( ) z[:, :] = data assert await store.get( - "transpose/0.0", prototype=default_buffer_prototype - ) == await store.get("transpose_zarr/0.0", default_buffer_prototype) + "transpose/0.0", prototype=default_buffer_prototype() + ) == await store.get("transpose_zarr/0.0", default_buffer_prototype()) @pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) From 168efff698b677198e521d12d3fdc3dffc56a5d1 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Tue, 9 Jul 2024 13:35:17 +0200 Subject: [PATCH 32/36] allow multiple implementations per entry_point --- src/zarr/registry.py | 16 +++++++++------- .../entry_points.txt | 6 ++++-- tests/v3/package_with_entrypoint/__init__.py | 4 ++++ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/zarr/registry.py b/src/zarr/registry.py index 4e04aa607a..2f629956ed 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -37,16 +37,18 @@ def __init__(self) -> None: def _collect_entrypoints() -> list[Registry[Any]]: """Collects codecs, pipelines, buffers and ndbuffers from entrypoints""" entry_points = get_entry_points() - for e in entry_points.select(group="zarr.codecs"): - __codec_registries[e.name].lazy_load_list.append(e) - for e in entry_points.select(group="zarr"): - if e.name == "codec_pipeline": + for e in entry_points: + if e.matches(group="zarr", name="codec_pipeline") or e.matches(group="zarr.codec_pipeline"): __pipeline_registry.lazy_load_list.append(e) - if e.name == "buffer": + if e.matches(group="zarr", name="buffer") or e.matches(group="zarr.buffer"): __buffer_registry.lazy_load_list.append(e) - if e.name == "ndbuffer": + if e.matches(group="zarr", name="ndbuffer") or e.matches(group="zarr.ndbuffer"): __ndbuffer_registry.lazy_load_list.append(e) - + if e.matches(group="zarr.codecs"): + __codec_registries[e.name].lazy_load_list.append(e) + if e.group.startswith("zarr.codecs."): + codec_name = e.group.split(".")[2] + __codec_registries[codec_name].lazy_load_list.append(e) return [ *__codec_registries.values(), __pipeline_registry, diff --git a/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt b/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt index cab9aedd7a..76b98f2283 100644 --- a/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt +++ b/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt @@ -2,5 +2,7 @@ test = package_with_entrypoint:TestEntrypointCodec [zarr] codec_pipeline = package_with_entrypoint:TestEntrypointCodecPipeline -buffer = package_with_entrypoint:TestEntrypointBuffer -ndbuffer = package_with_entrypoint:TestEntrypointNDBuffer \ No newline at end of file +ndbuffer = package_with_entrypoint:TestEntrypointNDBuffer +[zarr.buffer] +test_buffer = package_with_entrypoint:TestEntrypointBuffer +another_buffer = package_with_entrypoint:AnotherTestEntrypointBuffer \ No newline at end of file diff --git a/tests/v3/package_with_entrypoint/__init__.py b/tests/v3/package_with_entrypoint/__init__.py index 0b3adf508c..e12cb948b7 100644 --- a/tests/v3/package_with_entrypoint/__init__.py +++ b/tests/v3/package_with_entrypoint/__init__.py @@ -46,5 +46,9 @@ class TestEntrypointBuffer(Buffer): pass +class AnotherTestEntrypointBuffer(Buffer): + pass + + class TestEntrypointNDBuffer(NDBuffer): pass From a13e7deb71ebc4f99144ca31f7e8d3e02e5de695 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Tue, 9 Jul 2024 15:15:21 +0200 Subject: [PATCH 33/36] add tests for multiple implementations per entry_point --- src/zarr/registry.py | 44 +++++++++---------- .../entry_points.txt | 10 ++++- tests/v3/package_with_entrypoint/__init__.py | 17 +++++-- tests/v3/test_codec_entrypoints.py | 16 ++++--- 4 files changed, 53 insertions(+), 34 deletions(-) diff --git a/src/zarr/registry.py b/src/zarr/registry.py index 2f629956ed..75ea01024f 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -2,13 +2,13 @@ import warnings from collections import defaultdict -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast if TYPE_CHECKING: from zarr.abc.codec import Codec, CodecPipeline from zarr.buffer import Buffer, NDBuffer -from importlib.metadata import EntryPoint +from importlib.metadata import EntryPoint, EntryPoints from importlib.metadata import entry_points as get_entry_points from zarr.config import BadConfigError, config @@ -21,6 +21,14 @@ def __init__(self) -> None: super().__init__() self.lazy_load_list: list[EntryPoint] = [] + def lazy_load(self) -> None: + for e in self.lazy_load_list: + self.register(e.load()) + self.lazy_load_list.clear() + + def register(self, cls: type[T]) -> None: + self[fully_qualified_name(cls)] = cls + __codec_registries: dict[str, Registry[Codec]] = defaultdict(Registry) __pipeline_registry: Registry[CodecPipeline] = Registry() @@ -36,8 +44,7 @@ def __init__(self) -> None: def _collect_entrypoints() -> list[Registry[Any]]: """Collects codecs, pipelines, buffers and ndbuffers from entrypoints""" - entry_points = get_entry_points() - for e in entry_points: + for e in cast(EntryPoints, get_entry_points()): if e.matches(group="zarr", name="codec_pipeline") or e.matches(group="zarr.codec_pipeline"): __pipeline_registry.lazy_load_list.append(e) if e.matches(group="zarr", name="buffer") or e.matches(group="zarr.buffer"): @@ -67,21 +74,21 @@ def fully_qualified_name(cls: type) -> str: def register_codec(key: str, codec_cls: type[Codec]) -> None: - registry = __codec_registries.get(key, Registry()) - registry[fully_qualified_name(codec_cls)] = codec_cls - __codec_registries[key] = registry + if key not in __codec_registries.keys(): + __codec_registries[key] = Registry() + __codec_registries[key].register(codec_cls) def register_pipeline(pipe_cls: type[CodecPipeline]) -> None: - __pipeline_registry[fully_qualified_name(pipe_cls)] = pipe_cls + __pipeline_registry.register(pipe_cls) def register_ndbuffer(cls: type[NDBuffer]) -> None: - __ndbuffer_registry[fully_qualified_name(cls)] = cls + __ndbuffer_registry.register(cls) def register_buffer(cls: type[Buffer]) -> None: - __buffer_registry[fully_qualified_name(cls)] = cls + __buffer_registry.register(cls) def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: @@ -90,9 +97,7 @@ def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: if key in __codec_registries: # logger.debug("Auto loading codec '%s' from entrypoint", codec_id) - for lazy_load_item in __codec_registries[key].lazy_load_list: - cls = lazy_load_item.load() - register_codec(key, cls) + __codec_registries[key].lazy_load() codec_classes = __codec_registries[key] if not codec_classes: @@ -114,9 +119,7 @@ def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: def get_pipeline_class(reload_config: bool = False) -> type[CodecPipeline]: if reload_config: _reload_config() - for e in __pipeline_registry.lazy_load_list: - __pipeline_registry.lazy_load_list.remove(e) - register_pipeline(e.load()) + __pipeline_registry.lazy_load() path = config.get("codec_pipeline.path") pipeline_class = __pipeline_registry.get(path) if pipeline_class: @@ -129,9 +132,8 @@ def get_pipeline_class(reload_config: bool = False) -> type[CodecPipeline]: def get_buffer_class(reload_config: bool = False) -> type[Buffer]: if reload_config: _reload_config() - for e in __buffer_registry.lazy_load_list: - __buffer_registry.lazy_load_list.remove(e) - register_buffer(e.load()) + __buffer_registry.lazy_load() + path = config.get("buffer") buffer_class = __buffer_registry.get(path) if buffer_class: @@ -144,9 +146,7 @@ def get_buffer_class(reload_config: bool = False) -> type[Buffer]: def get_ndbuffer_class(reload_config: bool = False) -> type[NDBuffer]: if reload_config: _reload_config() - for e in __ndbuffer_registry.lazy_load_list: - __ndbuffer_registry.lazy_load_list.remove(e) - register_ndbuffer(e.load()) + __ndbuffer_registry.lazy_load() path = config.get("ndbuffer") ndbuffer_class = __ndbuffer_registry.get(path) if ndbuffer_class: diff --git a/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt b/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt index 76b98f2283..16abbca0b8 100644 --- a/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt +++ b/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt @@ -3,6 +3,12 @@ test = package_with_entrypoint:TestEntrypointCodec [zarr] codec_pipeline = package_with_entrypoint:TestEntrypointCodecPipeline ndbuffer = package_with_entrypoint:TestEntrypointNDBuffer +buffer = package_with_entrypoint:TestEntrypointBuffer +[zarr.codecs.test] +another_codec = package_with_entrypoint:TestEntrypointGroup.Codec [zarr.buffer] -test_buffer = package_with_entrypoint:TestEntrypointBuffer -another_buffer = package_with_entrypoint:AnotherTestEntrypointBuffer \ No newline at end of file +another_buffer = package_with_entrypoint:TestEntrypointGroup.Buffer +[zarr.ndbuffer] +another_ndbuffer = package_with_entrypoint:TestEntrypointGroup.NDBuffer +[zarr.codec_pipeline] +another_pipeline = package_with_entrypoint:TestEntrypointGroup.Pipeline diff --git a/tests/v3/package_with_entrypoint/__init__.py b/tests/v3/package_with_entrypoint/__init__.py index e12cb948b7..4d626808d8 100644 --- a/tests/v3/package_with_entrypoint/__init__.py +++ b/tests/v3/package_with_entrypoint/__init__.py @@ -5,6 +5,7 @@ from zarr.abc.codec import ArrayBytesCodec, CodecInput, CodecPipeline from zarr.array_spec import ArraySpec from zarr.buffer import Buffer, NDBuffer +from zarr.codecs import BytesCodec from zarr.common import BytesLike @@ -46,9 +47,19 @@ class TestEntrypointBuffer(Buffer): pass -class AnotherTestEntrypointBuffer(Buffer): +class TestEntrypointNDBuffer(NDBuffer): pass -class TestEntrypointNDBuffer(NDBuffer): - pass +class TestEntrypointGroup: + class Codec(BytesCodec): + pass + + class Buffer(Buffer): + pass + + class NDBuffer(NDBuffer): + pass + + class Pipeline(CodecPipeline): + pass diff --git a/tests/v3/test_codec_entrypoints.py b/tests/v3/test_codec_entrypoints.py index ff81019199..9e2932fdd5 100644 --- a/tests/v3/test_codec_entrypoints.py +++ b/tests/v3/test_codec_entrypoints.py @@ -22,10 +22,11 @@ def set_path(): @pytest.mark.usefixtures("set_path") -def test_entrypoint_codec(): - config.set({"codecs.test": "package_with_entrypoint.TestEntrypointCodec"}) - cls = zarr.registry.get_codec_class("test") - assert cls.__name__ == "TestEntrypointCodec" +@pytest.mark.parametrize("codec_name", ["TestEntrypointCodec", "TestEntrypointGroup.Codec"]) +def test_entrypoint_codec(codec_name): + config.set({"codecs.test": "package_with_entrypoint." + codec_name}) + cls_test = zarr.registry.get_codec_class("test") + assert cls_test.__qualname__ == codec_name @pytest.mark.usefixtures("set_path") @@ -36,12 +37,13 @@ def test_entrypoint_pipeline(): @pytest.mark.usefixtures("set_path") -def test_entrypoint_buffer(): +@pytest.mark.parametrize("buffer_name", ["TestEntrypointBuffer", "TestEntrypointGroup.Buffer"]) +def test_entrypoint_buffer(buffer_name): config.set( { - "buffer": "package_with_entrypoint.TestEntrypointBuffer", + "buffer": "package_with_entrypoint." + buffer_name, "ndbuffer": "package_with_entrypoint.TestEntrypointNDBuffer", } ) - assert zarr.registry.get_buffer_class().__name__ == "TestEntrypointBuffer" + assert zarr.registry.get_buffer_class().__qualname__ == buffer_name assert zarr.registry.get_ndbuffer_class().__name__ == "TestEntrypointNDBuffer" From 56335e41313508aa98daf503c96757b6cdec91a8 Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Wed, 10 Jul 2024 11:21:48 +0200 Subject: [PATCH 34/36] fix DeprecationWarning: SelectableGroups in registry.py --- src/zarr/registry.py | 52 +++++++++++++------ .../entry_points.txt | 4 +- 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/src/zarr/registry.py b/src/zarr/registry.py index 75ea01024f..36003720bf 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -2,13 +2,13 @@ import warnings from collections import defaultdict -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, TypeVar if TYPE_CHECKING: from zarr.abc.codec import Codec, CodecPipeline from zarr.buffer import Buffer, NDBuffer -from importlib.metadata import EntryPoint, EntryPoints +from importlib.metadata import EntryPoint from importlib.metadata import entry_points as get_entry_points from zarr.config import BadConfigError, config @@ -43,19 +43,41 @@ def register(self, cls: type[T]) -> None: def _collect_entrypoints() -> list[Registry[Any]]: - """Collects codecs, pipelines, buffers and ndbuffers from entrypoints""" - for e in cast(EntryPoints, get_entry_points()): - if e.matches(group="zarr", name="codec_pipeline") or e.matches(group="zarr.codec_pipeline"): - __pipeline_registry.lazy_load_list.append(e) - if e.matches(group="zarr", name="buffer") or e.matches(group="zarr.buffer"): - __buffer_registry.lazy_load_list.append(e) - if e.matches(group="zarr", name="ndbuffer") or e.matches(group="zarr.ndbuffer"): - __ndbuffer_registry.lazy_load_list.append(e) - if e.matches(group="zarr.codecs"): - __codec_registries[e.name].lazy_load_list.append(e) - if e.group.startswith("zarr.codecs."): - codec_name = e.group.split(".")[2] - __codec_registries[codec_name].lazy_load_list.append(e) + """ + Collects codecs, pipelines, buffers and ndbuffers from entrypoints. + Allowed syntax for entry_points.txt is e.g. + + [zarr.codecs] + gzip = package:EntrypointGzipCodec1 + [zarr.codecs.gzip] + some_name = package:EntrypointGzipCodec2 + another = package:EntrypointGzipCodec3 + + [zarr] + buffer = package:TestBuffer1 + [zarr.buffer] + xyz = package:TestBuffer2 + abc = package:TestBuffer3 + ... + """ + entry_points = get_entry_points() + __pipeline_registry.lazy_load_list.extend( + entry_points.select(group="zarr", name="codec_pipeline") + + entry_points.select(group="zarr.codec_pipeline") + ) + __buffer_registry.lazy_load_list.extend( + entry_points.select(group="zarr", name="buffer") + entry_points.select(group="zarr.buffer") + ) + __ndbuffer_registry.lazy_load_list.extend( + entry_points.select(group="zarr", name="ndbuffer") + + entry_points.select(group="zarr.ndbuffer") + ) + for e in entry_points.select(group="zarr.codecs"): + __codec_registries[e.name].lazy_load_list.append(e) + for group in entry_points.groups: + if group.startswith("zarr.codecs."): + codec_name = group.split(".")[2] + __codec_registries[codec_name].lazy_load_list.extend(entry_points.select(group=group)) return [ *__codec_registries.values(), __pipeline_registry, diff --git a/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt b/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt index 16abbca0b8..eee724c912 100644 --- a/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt +++ b/tests/v3/package_with_entrypoint-0.1.dist-info/entry_points.txt @@ -1,11 +1,11 @@ [zarr.codecs] test = package_with_entrypoint:TestEntrypointCodec +[zarr.codecs.test] +another_codec = package_with_entrypoint:TestEntrypointGroup.Codec [zarr] codec_pipeline = package_with_entrypoint:TestEntrypointCodecPipeline ndbuffer = package_with_entrypoint:TestEntrypointNDBuffer buffer = package_with_entrypoint:TestEntrypointBuffer -[zarr.codecs.test] -another_codec = package_with_entrypoint:TestEntrypointGroup.Codec [zarr.buffer] another_buffer = package_with_entrypoint:TestEntrypointGroup.Buffer [zarr.ndbuffer] From ca27b1da38638c877999c435917cf42bb7254efe Mon Sep 17 00:00:00 2001 From: brokkoli71 Date: Wed, 10 Jul 2024 11:29:39 +0200 Subject: [PATCH 35/36] fix DeprecationWarning: EntryPoints list interface in registry.py --- src/zarr/registry.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/zarr/registry.py b/src/zarr/registry.py index 36003720bf..605b4ee085 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -61,16 +61,14 @@ def _collect_entrypoints() -> list[Registry[Any]]: ... """ entry_points = get_entry_points() + + __buffer_registry.lazy_load_list.extend(entry_points.select(group="zarr.buffer")) + __buffer_registry.lazy_load_list.extend(entry_points.select(group="zarr", name="buffer")) + __ndbuffer_registry.lazy_load_list.extend(entry_points.select(group="zarr.ndbuffer")) + __ndbuffer_registry.lazy_load_list.extend(entry_points.select(group="zarr", name="ndbuffer")) + __pipeline_registry.lazy_load_list.extend(entry_points.select(group="zarr.codec_pipeline")) __pipeline_registry.lazy_load_list.extend( entry_points.select(group="zarr", name="codec_pipeline") - + entry_points.select(group="zarr.codec_pipeline") - ) - __buffer_registry.lazy_load_list.extend( - entry_points.select(group="zarr", name="buffer") + entry_points.select(group="zarr.buffer") - ) - __ndbuffer_registry.lazy_load_list.extend( - entry_points.select(group="zarr", name="ndbuffer") - + entry_points.select(group="zarr.ndbuffer") ) for e in entry_points.select(group="zarr.codecs"): __codec_registries[e.name].lazy_load_list.append(e) From d470ec69647a75095e23500d633e01e039e8187b Mon Sep 17 00:00:00 2001 From: Hannes Spitz <44113112+brokkoli71@users.noreply.github.com> Date: Wed, 10 Jul 2024 17:38:23 +0200 Subject: [PATCH 36/36] clarify _collect_entrypoints docstring Co-authored-by: Norman Rzepka --- src/zarr/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/zarr/registry.py b/src/zarr/registry.py index 605b4ee085..ac373f401d 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -45,6 +45,7 @@ def register(self, cls: type[T]) -> None: def _collect_entrypoints() -> list[Registry[Any]]: """ Collects codecs, pipelines, buffers and ndbuffers from entrypoints. + Entry points can either be single items or groups of items. Allowed syntax for entry_points.txt is e.g. [zarr.codecs]