Skip to content

Fix generic typing in zarr.codecs #1845

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into the base branch from the source branch
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,6 @@ check_untyped_defs = false
module = [
"zarr.v2.*",
"zarr.abc.codec",
"zarr.codecs.bytes",
"zarr.codecs.pipeline",
"zarr.codecs.sharding",
"zarr.codecs.transpose",
"zarr.array_v2",
"zarr.array",
"zarr.sync",
Expand Down
9 changes: 5 additions & 4 deletions src/zarr/codecs/bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from enum import Enum
import sys

from typing import TYPE_CHECKING, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import numpy as np
import numpy.typing as npt

from zarr.abc.codec import ArrayBytesCodec
from zarr.codecs.registry import register_codec
Expand Down Expand Up @@ -60,7 +61,7 @@ def evolve(self, array_spec: ArraySpec) -> Self:
)
return self

def _get_byteorder(self, array: np.ndarray) -> Endian:
def _get_byteorder(self, array: npt.NDArray[Any]) -> Endian:
if array.dtype.byteorder == "<":
return Endian.little
elif array.dtype.byteorder == ">":
Expand All @@ -73,7 +74,7 @@ async def decode(
chunk_bytes: BytesLike,
chunk_spec: ArraySpec,
_runtime_configuration: RuntimeConfiguration,
) -> np.ndarray:
) -> npt.NDArray[Any]:
if chunk_spec.dtype.itemsize > 0:
if self.endian == Endian.little:
prefix = "<"
Expand All @@ -93,7 +94,7 @@ async def decode(

async def encode(
self,
chunk_array: np.ndarray,
chunk_array: npt.NDArray[Any],
_chunk_spec: ArraySpec,
_runtime_configuration: RuntimeConfiguration,
) -> Optional[BytesLike]:
Expand Down
12 changes: 6 additions & 6 deletions src/zarr/codecs/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Iterable
import numpy as np
from typing import TYPE_CHECKING, Any, Iterable
import numpy.typing as npt
from dataclasses import dataclass
from warnings import warn

Expand Down Expand Up @@ -152,7 +152,7 @@ async def decode(
chunk_bytes: BytesLike,
array_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
) -> np.ndarray:
) -> npt.NDArray[Any]:
(
aa_codecs_with_spec,
ab_codec_with_spec,
Expand All @@ -176,7 +176,7 @@ async def decode_partial(
selection: SliceSelection,
chunk_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
) -> Optional[np.ndarray]:
) -> Optional[npt.NDArray[Any]]:
assert self.supports_partial_decode
assert isinstance(self.array_bytes_codec, ArrayBytesCodecPartialDecodeMixin)
return await self.array_bytes_codec.decode_partial(
Expand All @@ -185,7 +185,7 @@ async def decode_partial(

async def encode(
self,
chunk_array: np.ndarray,
chunk_array: npt.NDArray[Any],
array_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
) -> Optional[BytesLike]:
Expand Down Expand Up @@ -222,7 +222,7 @@ async def encode(
async def encode_partial(
self,
store_path: StorePath,
chunk_array: np.ndarray,
chunk_array: npt.NDArray[Any],
selection: SliceSelection,
chunk_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
Expand Down
32 changes: 18 additions & 14 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Iterable, Mapping, NamedTuple, Union
from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Union, Optional
from dataclasses import dataclass, replace
from functools import lru_cache


import numpy as np
import numpy.typing as npt
from zarr.abc.codec import (
Codec,
ArrayBytesCodec,
Expand All @@ -18,7 +19,9 @@
from zarr.codecs.registry import register_codec
from zarr.common import (
ArraySpec,
BytesLike,
ChunkCoordsLike,
ChunkCoords,
concurrent_map,
parse_enum,
parse_named_configuration,
Expand All @@ -39,14 +42,12 @@
)

if TYPE_CHECKING:
from typing import Awaitable, Callable, Dict, Iterator, List, Optional, Set, Tuple
from typing import Awaitable, Callable, Dict, Iterator, List, Set, Tuple
from typing_extensions import Self

from zarr.store import StorePath
from zarr.common import (
JSON,
ChunkCoords,
BytesLike,
SliceSelection,
)
from zarr.config import RuntimeConfiguration
Expand All @@ -65,7 +66,7 @@ def parse_index_location(data: JSON) -> ShardingCodecIndexLocation:

class _ShardIndex(NamedTuple):
# dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)
offsets_and_lengths: np.ndarray
offsets_and_lengths: npt.NDArray[np.uint64]

@property
def chunks_per_shard(self) -> ChunkCoords:
Expand Down Expand Up @@ -126,7 +127,10 @@ def create_empty(cls, chunks_per_shard: ChunkCoords) -> _ShardIndex:
return cls(offsets_and_lengths)


class _ShardProxy(Mapping):
_ShardMapping = Mapping[ChunkCoords, Optional[BytesLike]]


class _ShardProxy(_ShardMapping):
index: _ShardIndex
buf: BytesLike

Expand Down Expand Up @@ -175,7 +179,7 @@ def merge_with_morton_order(
cls,
chunks_per_shard: ChunkCoords,
tombstones: Set[ChunkCoords],
*shard_dicts: Mapping[ChunkCoords, BytesLike],
*shard_dicts: _ShardMapping,
) -> _ShardBuilder:
obj = cls.create_empty(chunks_per_shard)
for chunk_coords in morton_order_iter(chunks_per_shard):
Expand Down Expand Up @@ -303,7 +307,7 @@ async def decode(
shard_bytes: BytesLike,
shard_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
) -> np.ndarray:
) -> npt.NDArray[Any]:
# print("decode")
shard_shape = shard_spec.shape
chunk_shape = self.chunk_shape
Expand Down Expand Up @@ -353,7 +357,7 @@ async def decode_partial(
selection: SliceSelection,
shard_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
) -> Optional[np.ndarray]:
) -> Optional[npt.NDArray[Any]]:
shard_shape = shard_spec.shape
chunk_shape = self.chunk_shape
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
Expand All @@ -375,7 +379,7 @@ async def decode_partial(
all_chunk_coords = set(chunk_coords for chunk_coords, _, _ in indexed_chunks)

# reading bytes of all requested chunks
shard_dict: Mapping[ChunkCoords, BytesLike] = {}
shard_dict: _ShardMapping = {}
if self._is_total_shard(all_chunk_coords, chunks_per_shard):
# read entire shard
shard_dict_maybe = await self._load_full_shard_maybe(store_path, chunks_per_shard)
Expand Down Expand Up @@ -423,7 +427,7 @@ async def _read_chunk(
out_selection: SliceSelection,
shard_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
out: np.ndarray,
out: npt.NDArray[Any],
) -> None:
chunk_spec = self._get_chunk_spec(shard_spec)
chunk_bytes = shard_dict.get(chunk_coords, None)
Expand All @@ -436,7 +440,7 @@ async def _read_chunk(

async def encode(
self,
shard_array: np.ndarray,
shard_array: npt.NDArray[Any],
shard_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
) -> Optional[BytesLike]:
Expand All @@ -453,7 +457,7 @@ async def encode(
)

async def _write_chunk(
shard_array: np.ndarray,
shard_array: npt.NDArray[Any],
chunk_coords: ChunkCoords,
chunk_selection: SliceSelection,
out_selection: SliceSelection,
Expand Down Expand Up @@ -498,7 +502,7 @@ async def _write_chunk(
async def encode_partial(
self,
store_path: StorePath,
shard_array: np.ndarray,
shard_array: npt.NDArray[Any],
selection: SliceSelection,
shard_spec: ArraySpec,
runtime_configuration: RuntimeConfiguration,
Expand Down
12 changes: 6 additions & 6 deletions src/zarr/codecs/transpose.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Dict, Iterable, Union, cast
from typing import TYPE_CHECKING, Any, Dict, Iterable, Union, cast

from dataclasses import dataclass, replace

Expand All @@ -10,7 +10,7 @@
from typing import TYPE_CHECKING, Optional, Tuple
from typing_extensions import Self

import numpy as np
import numpy.typing as npt

from zarr.abc.codec import ArrayArrayCodec
from zarr.codecs.registry import register_codec
Expand Down Expand Up @@ -75,10 +75,10 @@ def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:

async def decode(
self,
chunk_array: np.ndarray,
chunk_array: npt.NDArray[Any],
chunk_spec: ArraySpec,
_runtime_configuration: RuntimeConfiguration,
) -> np.ndarray:
) -> npt.NDArray[Any]:
inverse_order = [0] * chunk_spec.ndim
for x, i in enumerate(self.order):
inverse_order[x] = i
Expand All @@ -87,10 +87,10 @@ async def decode(

async def encode(
self,
chunk_array: np.ndarray,
chunk_array: npt.NDArray[Any],
chunk_spec: ArraySpec,
_runtime_configuration: RuntimeConfiguration,
) -> Optional[np.ndarray]:
) -> Optional[npt.NDArray[Any]]:
chunk_array = chunk_array.transpose(self.order)
return chunk_array

Expand Down
Loading