|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from typing import TYPE_CHECKING |
| 4 | + |
| 5 | +if TYPE_CHECKING: |
| 6 | + from typing import Any, Literal |
| 7 | + |
| 8 | + import numpy.typing as npt |
| 9 | + from typing_extensions import Self |
| 10 | + |
| 11 | + from zarr.core.buffer import Buffer, BufferPrototype |
| 12 | + from zarr.core.common import JSON, ChunkCoords |
| 13 | + |
| 14 | +import json |
| 15 | +from dataclasses import dataclass, field, replace |
| 16 | + |
| 17 | +import numpy as np |
| 18 | + |
| 19 | +from zarr.core.array_spec import ArraySpec |
| 20 | +from zarr.core.chunk_grids import RegularChunkGrid |
| 21 | +from zarr.core.chunk_key_encodings import parse_separator |
| 22 | +from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_dtype, parse_shapelike |
| 23 | +from zarr.core.config import config, parse_indexing_order |
| 24 | +from zarr.core.metadata.common import ArrayMetadata, parse_attributes |
| 25 | + |
| 26 | + |
| 27 | +@dataclass(frozen=True, kw_only=True) |
| 28 | +class ArrayV2Metadata(ArrayMetadata): |
| 29 | + shape: ChunkCoords |
| 30 | + chunk_grid: RegularChunkGrid |
| 31 | + data_type: np.dtype[Any] |
| 32 | + fill_value: None | int | float = 0 |
| 33 | + order: Literal["C", "F"] = "C" |
| 34 | + filters: list[dict[str, JSON]] | None = None |
| 35 | + dimension_separator: Literal[".", "/"] = "." |
| 36 | + compressor: dict[str, JSON] | None = None |
| 37 | + attributes: dict[str, JSON] = field(default_factory=dict) |
| 38 | + zarr_format: Literal[2] = field(init=False, default=2) |
| 39 | + |
| 40 | + def __init__( |
| 41 | + self, |
| 42 | + *, |
| 43 | + shape: ChunkCoords, |
| 44 | + dtype: npt.DTypeLike, |
| 45 | + chunks: ChunkCoords, |
| 46 | + fill_value: Any, |
| 47 | + order: Literal["C", "F"], |
| 48 | + dimension_separator: Literal[".", "/"] = ".", |
| 49 | + compressor: dict[str, JSON] | None = None, |
| 50 | + filters: list[dict[str, JSON]] | None = None, |
| 51 | + attributes: dict[str, JSON] | None = None, |
| 52 | + ): |
| 53 | + """ |
| 54 | + Metadata for a Zarr version 2 array. |
| 55 | + """ |
| 56 | + shape_parsed = parse_shapelike(shape) |
| 57 | + data_type_parsed = parse_dtype(dtype) |
| 58 | + chunks_parsed = parse_shapelike(chunks) |
| 59 | + compressor_parsed = parse_compressor(compressor) |
| 60 | + order_parsed = parse_indexing_order(order) |
| 61 | + dimension_separator_parsed = parse_separator(dimension_separator) |
| 62 | + filters_parsed = parse_filters(filters) |
| 63 | + fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed) |
| 64 | + attributes_parsed = parse_attributes(attributes) |
| 65 | + |
| 66 | + object.__setattr__(self, "shape", shape_parsed) |
| 67 | + object.__setattr__(self, "data_type", data_type_parsed) |
| 68 | + object.__setattr__(self, "chunk_grid", RegularChunkGrid(chunk_shape=chunks_parsed)) |
| 69 | + object.__setattr__(self, "compressor", compressor_parsed) |
| 70 | + object.__setattr__(self, "order", order_parsed) |
| 71 | + object.__setattr__(self, "dimension_separator", dimension_separator_parsed) |
| 72 | + object.__setattr__(self, "filters", filters_parsed) |
| 73 | + object.__setattr__(self, "fill_value", fill_value_parsed) |
| 74 | + object.__setattr__(self, "attributes", attributes_parsed) |
| 75 | + |
| 76 | + # ensure that the metadata document is consistent |
| 77 | + _ = parse_metadata(self) |
| 78 | + |
| 79 | + @property |
| 80 | + def ndim(self) -> int: |
| 81 | + return len(self.shape) |
| 82 | + |
| 83 | + @property |
| 84 | + def dtype(self) -> np.dtype[Any]: |
| 85 | + return self.data_type |
| 86 | + |
| 87 | + @property |
| 88 | + def chunks(self) -> ChunkCoords: |
| 89 | + return self.chunk_grid.chunk_shape |
| 90 | + |
| 91 | + def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]: |
| 92 | + def _json_convert( |
| 93 | + o: Any, |
| 94 | + ) -> Any: |
| 95 | + if isinstance(o, np.dtype): |
| 96 | + if o.fields is None: |
| 97 | + return o.str |
| 98 | + else: |
| 99 | + return o.descr |
| 100 | + if np.isscalar(o): |
| 101 | + # convert numpy scalar to python type, and pass |
| 102 | + # python types through |
| 103 | + return getattr(o, "item", lambda: o)() |
| 104 | + raise TypeError |
| 105 | + |
| 106 | + zarray_dict = self.to_dict() |
| 107 | + |
| 108 | + # todo: remove this check when we can ensure that to_dict always returns dicts. |
| 109 | + if not isinstance(zarray_dict, dict): |
| 110 | + raise TypeError(f"Invalid type: got {type(zarray_dict)}, expected dict.") |
| 111 | + |
| 112 | + zattrs_dict = zarray_dict.pop("attributes", {}) |
| 113 | + json_indent = config.get("json_indent") |
| 114 | + return { |
| 115 | + ZARRAY_JSON: prototype.buffer.from_bytes( |
| 116 | + json.dumps(zarray_dict, default=_json_convert, indent=json_indent).encode() |
| 117 | + ), |
| 118 | + ZATTRS_JSON: prototype.buffer.from_bytes( |
| 119 | + json.dumps(zattrs_dict, indent=json_indent).encode() |
| 120 | + ), |
| 121 | + } |
| 122 | + |
| 123 | + @classmethod |
| 124 | + def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata: |
| 125 | + # make a copy to protect the original from modification |
| 126 | + _data = data.copy() |
| 127 | + # check that the zarr_format attribute is correct |
| 128 | + _ = parse_zarr_format(_data.pop("zarr_format")) |
| 129 | + return cls(**_data) |
| 130 | + |
| 131 | + def to_dict(self) -> JSON: |
| 132 | + zarray_dict = super().to_dict() |
| 133 | + |
| 134 | + # todo: remove this check when we can ensure that to_dict always returns dicts. |
| 135 | + if not isinstance(zarray_dict, dict): |
| 136 | + raise TypeError(f"Invalid type: got {type(zarray_dict)}, expected dict.") |
| 137 | + |
| 138 | + _ = zarray_dict.pop("chunk_grid") |
| 139 | + zarray_dict["chunks"] = self.chunk_grid.chunk_shape |
| 140 | + |
| 141 | + _ = zarray_dict.pop("data_type") |
| 142 | + zarray_dict["dtype"] = self.data_type.str |
| 143 | + |
| 144 | + return zarray_dict |
| 145 | + |
| 146 | + def get_chunk_spec( |
| 147 | + self, _chunk_coords: ChunkCoords, order: Literal["C", "F"], prototype: BufferPrototype |
| 148 | + ) -> ArraySpec: |
| 149 | + return ArraySpec( |
| 150 | + shape=self.chunk_grid.chunk_shape, |
| 151 | + dtype=self.dtype, |
| 152 | + fill_value=self.fill_value, |
| 153 | + order=order, |
| 154 | + prototype=prototype, |
| 155 | + ) |
| 156 | + |
| 157 | + def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str: |
| 158 | + chunk_identifier = self.dimension_separator.join(map(str, chunk_coords)) |
| 159 | + return "0" if chunk_identifier == "" else chunk_identifier |
| 160 | + |
| 161 | + def update_shape(self, shape: ChunkCoords) -> Self: |
| 162 | + return replace(self, shape=shape) |
| 163 | + |
| 164 | + def update_attributes(self, attributes: dict[str, JSON]) -> Self: |
| 165 | + return replace(self, attributes=attributes) |
| 166 | + |
| 167 | + |
| 168 | +def parse_zarr_format(data: Literal[2]) -> Literal[2]: |
| 169 | + if data == 2: |
| 170 | + return data |
| 171 | + raise ValueError(f"Invalid value. Expected 2. Got {data}.") |
| 172 | + |
| 173 | + |
| 174 | +def parse_filters(data: list[dict[str, JSON]] | None) -> list[dict[str, JSON]] | None: |
| 175 | + return data |
| 176 | + |
| 177 | + |
| 178 | +def parse_compressor(data: dict[str, JSON] | None) -> dict[str, JSON] | None: |
| 179 | + return data |
| 180 | + |
| 181 | + |
| 182 | +def parse_metadata(data: ArrayV2Metadata) -> ArrayV2Metadata: |
| 183 | + if (l_chunks := len(data.chunks)) != (l_shape := len(data.shape)): |
| 184 | + msg = ( |
| 185 | + f"The `shape` and `chunks` attributes must have the same length. " |
| 186 | + f"`chunks` has length {l_chunks}, but `shape` has length {l_shape}." |
| 187 | + ) |
| 188 | + raise ValueError(msg) |
| 189 | + return data |
| 190 | + |
| 191 | + |
| 192 | +def parse_fill_value(fill_value: Any, dtype: np.dtype[Any]) -> Any: |
| 193 | + """ |
| 194 | + Parse a potential fill value into a value that is compatible with the provided dtype. |
| 195 | +
|
| 196 | + Parameters |
| 197 | + ---------- |
| 198 | + fill_value: Any |
| 199 | + A potential fill value. |
| 200 | + dtype: np.dtype[Any] |
| 201 | + A numpy dtype. |
| 202 | +
|
| 203 | + Returns |
| 204 | + An instance of `dtype`, or `None`, or any python object (in the case of an object dtype) |
| 205 | + """ |
| 206 | + |
| 207 | + if fill_value is None or dtype.hasobject: |
| 208 | + # no fill value |
| 209 | + pass |
| 210 | + elif not isinstance(fill_value, np.void) and fill_value == 0: |
| 211 | + # this should be compatible across numpy versions for any array type, including |
| 212 | + # structured arrays |
| 213 | + fill_value = np.zeros((), dtype=dtype)[()] |
| 214 | + |
| 215 | + elif dtype.kind == "U": |
| 216 | + # special case unicode because of encoding issues on Windows if passed through numpy |
| 217 | + # https://github.com/alimanfoo/zarr/pull/172#issuecomment-343782713 |
| 218 | + |
| 219 | + if not isinstance(fill_value, str): |
| 220 | + raise ValueError( |
| 221 | + f"fill_value {fill_value!r} is not valid for dtype {dtype}; must be a unicode string" |
| 222 | + ) |
| 223 | + else: |
| 224 | + try: |
| 225 | + if isinstance(fill_value, bytes) and dtype.kind == "V": |
| 226 | + # special case for numpy 1.14 compatibility |
| 227 | + fill_value = np.array(fill_value, dtype=dtype.str).view(dtype)[()] |
| 228 | + else: |
| 229 | + fill_value = np.array(fill_value, dtype=dtype)[()] |
| 230 | + |
| 231 | + except Exception as e: |
| 232 | + msg = f"Fill_value {fill_value} is not valid for dtype {dtype}." |
| 233 | + raise ValueError(msg) from e |
| 234 | + |
| 235 | + return fill_value |
0 commit comments