Skip to content

Commit 2a4e1d4

Browse files
committed
Fix generic typing in zarr.codecs
1 parent c1323c4 commit 2a4e1d4

File tree

5 files changed

+33
-32
lines changed

5 files changed

+33
-32
lines changed

pyproject.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,6 @@ check_untyped_defs = false
187187
[[tool.mypy.overrides]]
188188
module = [
189189
"zarr.abc.codec",
190-
"zarr.codecs.bytes",
191-
"zarr.codecs.pipeline",
192-
"zarr.codecs.sharding",
193-
"zarr.codecs.transpose",
194190
"zarr.array_v2",
195191
"zarr.array",
196192
"zarr.sync",

src/zarr/codecs/bytes.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from enum import Enum
44
import sys
55

6-
from typing import TYPE_CHECKING, Dict, Optional, Union
6+
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
77

88
import numpy as np
9+
import numpy.typing as npt
910

1011
from zarr.abc.codec import ArrayBytesCodec
1112
from zarr.codecs.registry import register_codec
@@ -60,7 +61,7 @@ def evolve(self, array_spec: ArraySpec) -> Self:
6061
)
6162
return self
6263

63-
def _get_byteorder(self, array: np.ndarray) -> Endian:
64+
def _get_byteorder(self, array: npt.NDArray[Any]) -> Endian:
6465
if array.dtype.byteorder == "<":
6566
return Endian.little
6667
elif array.dtype.byteorder == ">":
@@ -73,7 +74,7 @@ async def decode(
7374
chunk_bytes: BytesLike,
7475
chunk_spec: ArraySpec,
7576
_runtime_configuration: RuntimeConfiguration,
76-
) -> np.ndarray:
77+
) -> npt.NDArray[Any]:
7778
if chunk_spec.dtype.itemsize > 0:
7879
if self.endian == Endian.little:
7980
prefix = "<"
@@ -94,7 +95,7 @@ async def decode(
9495

9596
async def encode(
9697
self,
97-
chunk_array: np.ndarray,
98+
chunk_array: npt.NDArray[Any],
9899
_chunk_spec: ArraySpec,
99100
_runtime_configuration: RuntimeConfiguration,
100101
) -> Optional[BytesLike]:

src/zarr/codecs/pipeline.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Iterable
4-
import numpy as np
3+
from typing import TYPE_CHECKING, Any, Iterable
4+
import numpy.typing as npt
55
from dataclasses import dataclass
66
from warnings import warn
77

@@ -152,7 +152,7 @@ async def decode(
152152
chunk_bytes: BytesLike,
153153
array_spec: ArraySpec,
154154
runtime_configuration: RuntimeConfiguration,
155-
) -> np.ndarray:
155+
) -> npt.NDArray[Any]:
156156
(
157157
aa_codecs_with_spec,
158158
ab_codec_with_spec,
@@ -176,7 +176,7 @@ async def decode_partial(
176176
selection: SliceSelection,
177177
chunk_spec: ArraySpec,
178178
runtime_configuration: RuntimeConfiguration,
179-
) -> Optional[np.ndarray]:
179+
) -> Optional[npt.NDArray[Any]]:
180180
assert self.supports_partial_decode
181181
assert isinstance(self.array_bytes_codec, ArrayBytesCodecPartialDecodeMixin)
182182
return await self.array_bytes_codec.decode_partial(
@@ -185,7 +185,7 @@ async def decode_partial(
185185

186186
async def encode(
187187
self,
188-
chunk_array: np.ndarray,
188+
chunk_array: npt.NDArray[Any],
189189
array_spec: ArraySpec,
190190
runtime_configuration: RuntimeConfiguration,
191191
) -> Optional[BytesLike]:
@@ -222,7 +222,7 @@ async def encode(
222222
async def encode_partial(
223223
self,
224224
store_path: StorePath,
225-
chunk_array: np.ndarray,
225+
chunk_array: npt.NDArray[Any],
226226
selection: SliceSelection,
227227
chunk_spec: ArraySpec,
228228
runtime_configuration: RuntimeConfiguration,

src/zarr/codecs/sharding.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from __future__ import annotations
22
from enum import Enum
3-
from typing import TYPE_CHECKING, Iterable, Mapping, NamedTuple, Union
3+
from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Union
44
from dataclasses import dataclass, replace
55
from functools import lru_cache
66

77

88
import numpy as np
9+
import numpy.typing as npt
910
from zarr.abc.codec import (
1011
Codec,
1112
ArrayBytesCodec,
@@ -19,6 +20,7 @@
1920
from zarr.common import (
2021
ArraySpec,
2122
ChunkCoordsLike,
23+
ChunkCoords,
2224
concurrent_map,
2325
parse_enum,
2426
parse_named_configuration,
@@ -45,7 +47,6 @@
4547
from zarr.store import StorePath
4648
from zarr.common import (
4749
JSON,
48-
ChunkCoords,
4950
BytesLike,
5051
SliceSelection,
5152
)
@@ -65,7 +66,7 @@ def parse_index_location(data: JSON) -> ShardingCodecIndexLocation:
6566

6667
class _ShardIndex(NamedTuple):
6768
# dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)
68-
offsets_and_lengths: np.ndarray
69+
offsets_and_lengths: npt.NDArray[np.uint64]
6970

7071
@property
7172
def chunks_per_shard(self) -> ChunkCoords:
@@ -126,7 +127,10 @@ def create_empty(cls, chunks_per_shard: ChunkCoords) -> _ShardIndex:
126127
return cls(offsets_and_lengths)
127128

128129

129-
class _ShardProxy(Mapping):
130+
_ShardMapping = Mapping[ChunkCoords, Optional[BytesLike]]
131+
132+
133+
class _ShardProxy(_ShardMapping):
130134
index: _ShardIndex
131135
buf: BytesLike
132136

@@ -175,7 +179,7 @@ def merge_with_morton_order(
175179
cls,
176180
chunks_per_shard: ChunkCoords,
177181
tombstones: Set[ChunkCoords],
178-
*shard_dicts: Mapping[ChunkCoords, BytesLike],
182+
*shard_dicts: _ShardMapping,
179183
) -> _ShardBuilder:
180184
obj = cls.create_empty(chunks_per_shard)
181185
for chunk_coords in morton_order_iter(chunks_per_shard):
@@ -303,7 +307,7 @@ async def decode(
303307
shard_bytes: BytesLike,
304308
shard_spec: ArraySpec,
305309
runtime_configuration: RuntimeConfiguration,
306-
) -> np.ndarray:
310+
) -> npt.NDArray[Any]:
307311
# print("decode")
308312
shard_shape = shard_spec.shape
309313
chunk_shape = self.chunk_shape
@@ -353,7 +357,7 @@ async def decode_partial(
353357
selection: SliceSelection,
354358
shard_spec: ArraySpec,
355359
runtime_configuration: RuntimeConfiguration,
356-
) -> Optional[np.ndarray]:
360+
) -> Optional[npt.NDArray[Any]]:
357361
shard_shape = shard_spec.shape
358362
chunk_shape = self.chunk_shape
359363
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
@@ -375,7 +379,7 @@ async def decode_partial(
375379
all_chunk_coords = set(chunk_coords for chunk_coords, _, _ in indexed_chunks)
376380

377381
# reading bytes of all requested chunks
378-
shard_dict: Mapping[ChunkCoords, BytesLike] = {}
382+
shard_dict: _ShardMapping = {}
379383
if self._is_total_shard(all_chunk_coords, chunks_per_shard):
380384
# read entire shard
381385
shard_dict_maybe = await self._load_full_shard_maybe(store_path, chunks_per_shard)
@@ -423,7 +427,7 @@ async def _read_chunk(
423427
out_selection: SliceSelection,
424428
shard_spec: ArraySpec,
425429
runtime_configuration: RuntimeConfiguration,
426-
out: np.ndarray,
430+
out: npt.NDArray[Any],
427431
) -> None:
428432
chunk_spec = self._get_chunk_spec(shard_spec)
429433
chunk_bytes = shard_dict.get(chunk_coords, None)
@@ -436,7 +440,7 @@ async def _read_chunk(
436440

437441
async def encode(
438442
self,
439-
shard_array: np.ndarray,
443+
shard_array: npt.NDArray[Any],
440444
shard_spec: ArraySpec,
441445
runtime_configuration: RuntimeConfiguration,
442446
) -> Optional[BytesLike]:
@@ -453,7 +457,7 @@ async def encode(
453457
)
454458

455459
async def _write_chunk(
456-
shard_array: np.ndarray,
460+
shard_array: npt.NDArray[Any],
457461
chunk_coords: ChunkCoords,
458462
chunk_selection: SliceSelection,
459463
out_selection: SliceSelection,
@@ -498,7 +502,7 @@ async def _write_chunk(
498502
async def encode_partial(
499503
self,
500504
store_path: StorePath,
501-
shard_array: np.ndarray,
505+
shard_array: npt.NDArray[Any],
502506
selection: SliceSelection,
503507
shard_spec: ArraySpec,
504508
runtime_configuration: RuntimeConfiguration,

src/zarr/codecs/transpose.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from __future__ import annotations
2-
from typing import TYPE_CHECKING, Dict, Iterable, Union, cast
2+
from typing import TYPE_CHECKING, Any, Dict, Iterable, Union, cast
33

44
from dataclasses import dataclass, replace
55

@@ -10,7 +10,7 @@
1010
from typing import TYPE_CHECKING, Optional, Tuple
1111
from typing_extensions import Self
1212

13-
import numpy as np
13+
import numpy.typing as npt
1414

1515
from zarr.abc.codec import ArrayArrayCodec
1616
from zarr.codecs.registry import register_codec
@@ -75,10 +75,10 @@ def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
7575

7676
async def decode(
7777
self,
78-
chunk_array: np.ndarray,
78+
chunk_array: npt.NDArray[Any],
7979
chunk_spec: ArraySpec,
8080
_runtime_configuration: RuntimeConfiguration,
81-
) -> np.ndarray:
81+
) -> npt.NDArray[Any]:
8282
inverse_order = [0] * chunk_spec.ndim
8383
for x, i in enumerate(self.order):
8484
inverse_order[x] = i
@@ -87,10 +87,10 @@ async def decode(
8787

8888
async def encode(
8989
self,
90-
chunk_array: np.ndarray,
90+
chunk_array: npt.NDArray[Any],
9191
chunk_spec: ArraySpec,
9292
_runtime_configuration: RuntimeConfiguration,
93-
) -> Optional[np.ndarray]:
93+
) -> Optional[npt.NDArray[Any]]:
9494
chunk_array = chunk_array.transpose(self.order)
9595
return chunk_array
9696

0 commit comments

Comments
 (0)