Skip to content

Commit 52d6849

Browse files
authored
refactor: split metadata into v2 and v3 modules (#2163)
* refactor: split metadata into v2 and v3 modules * add more explicit typeguards * port fill value normalization from v2 * remove v2 suffix from zarr format parsing * remove v2 suffix from zarr format parsing
1 parent c62294e commit 52d6849

File tree

9 files changed

+448
-385
lines changed

9 files changed

+448
-385
lines changed

src/zarr/api/asynchronous.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from zarr.core.array import Array, AsyncArray
1111
from zarr.core.common import JSON, AccessModeLiteral, ChunkCoords, MemoryOrder, ZarrFormat
1212
from zarr.core.group import AsyncGroup
13-
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
13+
from zarr.core.metadata.v2 import ArrayV2Metadata
14+
from zarr.core.metadata.v3 import ArrayV3Metadata
1415
from zarr.store import (
1516
StoreLike,
1617
make_store_path,

src/zarr/codecs/sharding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
get_indexer,
4545
morton_order_iter,
4646
)
47-
from zarr.core.metadata import parse_codecs
47+
from zarr.core.metadata.v3 import parse_codecs
4848
from zarr.registry import get_ndbuffer_class, get_pipeline_class, register_codec
4949

5050
if TYPE_CHECKING:

src/zarr/core/array.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@
5555
is_scalar,
5656
pop_fields,
5757
)
58-
from zarr.core.metadata import ArrayMetadata, ArrayV2Metadata, ArrayV3Metadata
58+
from zarr.core.metadata.v2 import ArrayV2Metadata
59+
from zarr.core.metadata.v3 import ArrayV3Metadata
5960
from zarr.core.sync import sync
6061
from zarr.registry import get_pipeline_class
6162
from zarr.store import StoreLike, StorePath, make_store_path
@@ -67,6 +68,7 @@
6768
from collections.abc import Iterable
6869

6970
from zarr.abc.codec import Codec, CodecPipeline
71+
from zarr.core.metadata.common import ArrayMetadata
7072

7173
# Array and AsyncArray are defined in the base ``zarr`` namespace
7274
__all__ = ["parse_array_metadata", "create_codec_pipeline"]

src/zarr/core/metadata/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .v2 import ArrayV2Metadata
2+
from .v3 import ArrayV3Metadata
3+
4+
__all__ = ["ArrayV2Metadata", "ArrayV3Metadata"]

src/zarr/core/metadata/common.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from typing import Any, Literal
7+
8+
import numpy as np
9+
from typing_extensions import Self
10+
11+
from zarr.core.array_spec import ArraySpec
12+
from zarr.core.buffer import Buffer, BufferPrototype
13+
from zarr.core.chunk_grids import ChunkGrid
14+
from zarr.core.common import JSON, ChunkCoords, ZarrFormat
15+
16+
from abc import ABC, abstractmethod
17+
from dataclasses import dataclass
18+
19+
from zarr.abc.metadata import Metadata
20+
21+
22+
@dataclass(frozen=True, kw_only=True)
23+
class ArrayMetadata(Metadata, ABC):
24+
shape: ChunkCoords
25+
fill_value: Any
26+
chunk_grid: ChunkGrid
27+
attributes: dict[str, JSON]
28+
zarr_format: ZarrFormat
29+
30+
@property
31+
@abstractmethod
32+
def dtype(self) -> np.dtype[Any]:
33+
pass
34+
35+
@property
36+
@abstractmethod
37+
def ndim(self) -> int:
38+
pass
39+
40+
@abstractmethod
41+
def get_chunk_spec(
42+
self, _chunk_coords: ChunkCoords, order: Literal["C", "F"], prototype: BufferPrototype
43+
) -> ArraySpec:
44+
pass
45+
46+
@abstractmethod
47+
def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str:
48+
pass
49+
50+
@abstractmethod
51+
def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]:
52+
pass
53+
54+
@abstractmethod
55+
def update_shape(self, shape: ChunkCoords) -> Self:
56+
pass
57+
58+
@abstractmethod
59+
def update_attributes(self, attributes: dict[str, JSON]) -> Self:
60+
pass
61+
62+
63+
def parse_attributes(data: None | dict[str, JSON]) -> dict[str, JSON]:
64+
if data is None:
65+
return {}
66+
67+
return data

src/zarr/core/metadata/v2.py

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from typing import Any, Literal
7+
8+
import numpy.typing as npt
9+
from typing_extensions import Self
10+
11+
from zarr.core.buffer import Buffer, BufferPrototype
12+
from zarr.core.common import JSON, ChunkCoords
13+
14+
import json
15+
from dataclasses import dataclass, field, replace
16+
17+
import numpy as np
18+
19+
from zarr.core.array_spec import ArraySpec
20+
from zarr.core.chunk_grids import RegularChunkGrid
21+
from zarr.core.chunk_key_encodings import parse_separator
22+
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_dtype, parse_shapelike
23+
from zarr.core.config import config, parse_indexing_order
24+
from zarr.core.metadata.common import ArrayMetadata, parse_attributes
25+
26+
27+
@dataclass(frozen=True, kw_only=True)
28+
class ArrayV2Metadata(ArrayMetadata):
29+
shape: ChunkCoords
30+
chunk_grid: RegularChunkGrid
31+
data_type: np.dtype[Any]
32+
fill_value: None | int | float = 0
33+
order: Literal["C", "F"] = "C"
34+
filters: list[dict[str, JSON]] | None = None
35+
dimension_separator: Literal[".", "/"] = "."
36+
compressor: dict[str, JSON] | None = None
37+
attributes: dict[str, JSON] = field(default_factory=dict)
38+
zarr_format: Literal[2] = field(init=False, default=2)
39+
40+
def __init__(
41+
self,
42+
*,
43+
shape: ChunkCoords,
44+
dtype: npt.DTypeLike,
45+
chunks: ChunkCoords,
46+
fill_value: Any,
47+
order: Literal["C", "F"],
48+
dimension_separator: Literal[".", "/"] = ".",
49+
compressor: dict[str, JSON] | None = None,
50+
filters: list[dict[str, JSON]] | None = None,
51+
attributes: dict[str, JSON] | None = None,
52+
):
53+
"""
54+
Metadata for a Zarr version 2 array.
55+
"""
56+
shape_parsed = parse_shapelike(shape)
57+
data_type_parsed = parse_dtype(dtype)
58+
chunks_parsed = parse_shapelike(chunks)
59+
compressor_parsed = parse_compressor(compressor)
60+
order_parsed = parse_indexing_order(order)
61+
dimension_separator_parsed = parse_separator(dimension_separator)
62+
filters_parsed = parse_filters(filters)
63+
fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed)
64+
attributes_parsed = parse_attributes(attributes)
65+
66+
object.__setattr__(self, "shape", shape_parsed)
67+
object.__setattr__(self, "data_type", data_type_parsed)
68+
object.__setattr__(self, "chunk_grid", RegularChunkGrid(chunk_shape=chunks_parsed))
69+
object.__setattr__(self, "compressor", compressor_parsed)
70+
object.__setattr__(self, "order", order_parsed)
71+
object.__setattr__(self, "dimension_separator", dimension_separator_parsed)
72+
object.__setattr__(self, "filters", filters_parsed)
73+
object.__setattr__(self, "fill_value", fill_value_parsed)
74+
object.__setattr__(self, "attributes", attributes_parsed)
75+
76+
# ensure that the metadata document is consistent
77+
_ = parse_metadata(self)
78+
79+
@property
80+
def ndim(self) -> int:
81+
return len(self.shape)
82+
83+
@property
84+
def dtype(self) -> np.dtype[Any]:
85+
return self.data_type
86+
87+
@property
88+
def chunks(self) -> ChunkCoords:
89+
return self.chunk_grid.chunk_shape
90+
91+
def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]:
92+
def _json_convert(
93+
o: Any,
94+
) -> Any:
95+
if isinstance(o, np.dtype):
96+
if o.fields is None:
97+
return o.str
98+
else:
99+
return o.descr
100+
if np.isscalar(o):
101+
# convert numpy scalar to python type, and pass
102+
# python types through
103+
return getattr(o, "item", lambda: o)()
104+
raise TypeError
105+
106+
zarray_dict = self.to_dict()
107+
108+
# todo: remove this check when we can ensure that to_dict always returns dicts.
109+
if not isinstance(zarray_dict, dict):
110+
raise TypeError(f"Invalid type: got {type(zarray_dict)}, expected dict.")
111+
112+
zattrs_dict = zarray_dict.pop("attributes", {})
113+
json_indent = config.get("json_indent")
114+
return {
115+
ZARRAY_JSON: prototype.buffer.from_bytes(
116+
json.dumps(zarray_dict, default=_json_convert, indent=json_indent).encode()
117+
),
118+
ZATTRS_JSON: prototype.buffer.from_bytes(
119+
json.dumps(zattrs_dict, indent=json_indent).encode()
120+
),
121+
}
122+
123+
@classmethod
124+
def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
125+
# make a copy to protect the original from modification
126+
_data = data.copy()
127+
# check that the zarr_format attribute is correct
128+
_ = parse_zarr_format(_data.pop("zarr_format"))
129+
return cls(**_data)
130+
131+
def to_dict(self) -> JSON:
132+
zarray_dict = super().to_dict()
133+
134+
# todo: remove this check when we can ensure that to_dict always returns dicts.
135+
if not isinstance(zarray_dict, dict):
136+
raise TypeError(f"Invalid type: got {type(zarray_dict)}, expected dict.")
137+
138+
_ = zarray_dict.pop("chunk_grid")
139+
zarray_dict["chunks"] = self.chunk_grid.chunk_shape
140+
141+
_ = zarray_dict.pop("data_type")
142+
zarray_dict["dtype"] = self.data_type.str
143+
144+
return zarray_dict
145+
146+
def get_chunk_spec(
147+
self, _chunk_coords: ChunkCoords, order: Literal["C", "F"], prototype: BufferPrototype
148+
) -> ArraySpec:
149+
return ArraySpec(
150+
shape=self.chunk_grid.chunk_shape,
151+
dtype=self.dtype,
152+
fill_value=self.fill_value,
153+
order=order,
154+
prototype=prototype,
155+
)
156+
157+
def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str:
158+
chunk_identifier = self.dimension_separator.join(map(str, chunk_coords))
159+
return "0" if chunk_identifier == "" else chunk_identifier
160+
161+
def update_shape(self, shape: ChunkCoords) -> Self:
162+
return replace(self, shape=shape)
163+
164+
def update_attributes(self, attributes: dict[str, JSON]) -> Self:
165+
return replace(self, attributes=attributes)
166+
167+
168+
def parse_zarr_format(data: Literal[2]) -> Literal[2]:
169+
if data == 2:
170+
return data
171+
raise ValueError(f"Invalid value. Expected 2. Got {data}.")
172+
173+
174+
def parse_filters(data: list[dict[str, JSON]] | None) -> list[dict[str, JSON]] | None:
175+
return data
176+
177+
178+
def parse_compressor(data: dict[str, JSON] | None) -> dict[str, JSON] | None:
179+
return data
180+
181+
182+
def parse_metadata(data: ArrayV2Metadata) -> ArrayV2Metadata:
183+
if (l_chunks := len(data.chunks)) != (l_shape := len(data.shape)):
184+
msg = (
185+
f"The `shape` and `chunks` attributes must have the same length. "
186+
f"`chunks` has length {l_chunks}, but `shape` has length {l_shape}."
187+
)
188+
raise ValueError(msg)
189+
return data
190+
191+
192+
def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any:
193+
"""
194+
Parse a potential fill value into a value that is compatible with the provided dtype.
195+
196+
Parameters
197+
----------
198+
fill_value: Any
199+
A potential fill value.
200+
dtype: np.dtype[Any]
201+
A numpy dtype.
202+
203+
Returns
204+
An instance of `dtype`, or `None`, or any python object (in the case of an object dtype)
205+
"""
206+
207+
if fill_value is None or dtype.hasobject:
208+
# no fill value
209+
pass
210+
elif not isinstance(fill_value, np.void) and fill_value == 0:
211+
# this should be compatible across numpy versions for any array type, including
212+
# structured arrays
213+
fill_value = np.zeros((), dtype=dtype)[()]
214+
215+
elif dtype.kind == "U":
216+
# special case unicode because of encoding issues on Windows if passed through numpy
217+
# https://github.com/alimanfoo/zarr/pull/172#issuecomment-343782713
218+
219+
if not isinstance(fill_value, str):
220+
raise ValueError(
221+
f"fill_value {fill_value!r} is not valid for dtype {dtype}; must be a unicode string"
222+
)
223+
else:
224+
try:
225+
if isinstance(fill_value, bytes) and dtype.kind == "V":
226+
# special case for numpy 1.14 compatibility
227+
fill_value = np.array(fill_value, dtype=dtype.str).view(dtype)[()]
228+
else:
229+
fill_value = np.array(fill_value, dtype=dtype)[()]
230+
231+
except Exception as e:
232+
msg = f"Fill_value {fill_value} is not valid for dtype {dtype}."
233+
raise ValueError(msg) from e
234+
235+
return fill_value

0 commit comments

Comments
 (0)