Skip to content

Cast fill value to array's dtype #2020

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/zarr/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,10 @@ def order(self) -> Literal["C", "F"]:
def read_only(self) -> bool:
    """Return the ``read_only`` attribute of the wrapped ``_async_array``."""
    return self._async_array.read_only

@property
def fill_value(self) -> Any:
    """Return the ``fill_value`` recorded in this array's metadata."""
    return self.metadata.fill_value

def __array__(
self, dtype: npt.DTypeLike | None = None, copy: bool | None = None
) -> NDArrayLike:
Expand Down
130 changes: 122 additions & 8 deletions src/zarr/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import json
from abc import ABC, abstractmethod
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field, replace
from enum import Enum
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, cast, overload

import numpy as np
import numpy.typing as npt
Expand All @@ -32,7 +32,6 @@
ChunkCoords,
ZarrFormat,
parse_dtype,
parse_fill_value,
parse_named_configuration,
parse_shapelike,
)
Expand Down Expand Up @@ -189,7 +188,7 @@ def __init__(
chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid)
chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding)
dimension_names_parsed = parse_dimension_names(dimension_names)
fill_value_parsed = parse_fill_value(fill_value)
fill_value_parsed = parse_fill_value_v3(fill_value, dtype=data_type_parsed)
attributes_parsed = parse_attributes(attributes)
codecs_parsed_partial = parse_codecs(codecs)

Expand Down Expand Up @@ -255,9 +254,18 @@ def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str:
return self.chunk_key_encoding.encode_chunk_key(chunk_coords)

def to_buffer_dict(self) -> dict[str, Buffer]:
def _json_convert(o: np.dtype[Any] | Enum | Codec) -> str | dict[str, Any]:
def _json_convert(o: Any) -> Any:
if isinstance(o, np.dtype):
return str(o)
if np.isscalar(o):
# convert numpy scalar to python type, and pass
# python types through
out = getattr(o, "item", lambda: o)()
if isinstance(out, complex):
# python complex types are not JSON serializable, so we use the
# serialization defined in the zarr v3 spec
return [out.real, out.imag]
return out
if isinstance(o, Enum):
return o.name
# this serializes numcodecs compressors
Expand Down Expand Up @@ -341,7 +349,7 @@ def __init__(
order_parsed = parse_indexing_order(order)
dimension_separator_parsed = parse_separator(dimension_separator)
filters_parsed = parse_filters(filters)
fill_value_parsed = parse_fill_value(fill_value)
fill_value_parsed = parse_fill_value_v2(fill_value, dtype=data_type_parsed)
attributes_parsed = parse_attributes(attributes)

object.__setattr__(self, "shape", shape_parsed)
Expand Down Expand Up @@ -371,13 +379,17 @@ def chunks(self) -> ChunkCoords:

def to_buffer_dict(self) -> dict[str, Buffer]:
def _json_convert(
o: np.dtype[Any],
) -> str | list[tuple[str, str] | tuple[str, str, tuple[int, ...]]]:
o: Any,
) -> Any:
if isinstance(o, np.dtype):
if o.fields is None:
return o.str
else:
return o.descr
if np.isscalar(o):
# convert numpy scalar to python type, and pass
# python types through
return getattr(o, "item", lambda: o)()
raise TypeError

zarray_dict = self.to_dict()
Expand Down Expand Up @@ -517,3 +529,105 @@ def parse_codecs(data: Iterable[Codec | dict[str, JSON]]) -> tuple[Codec, ...]:
out += (get_codec_class(name_parsed).from_dict(c),)

return out


def parse_fill_value_v2(fill_value: Any, dtype: np.dtype[Any]) -> Any:
    """
    Normalize a candidate fill value for a Zarr V2 array.

    All of the work is delegated to ``zarr.v2.util.normalize_fill_value``; this function
    only forwards its two arguments by keyword.

    Parameters
    ----------
    fill_value : Any
        A potential fill value.
    dtype : np.dtype[Any]
        A numpy dtype.

    Returns
    -------
    Any
        An instance of `dtype`, or `None`, or any python object (in the case of an object dtype)
    """
    # Function-scope import: the v2 helper is only loaded when this function is called.
    from zarr.v2.util import normalize_fill_value

    normalized = normalize_fill_value(fill_value=fill_value, dtype=dtype)
    return normalized


BOOL = np.bool_
BOOL_DTYPE = np.dtypes.BoolDType

INTEGER_DTYPE = (
np.dtypes.Int8DType
| np.dtypes.Int16DType
| np.dtypes.Int32DType
| np.dtypes.Int64DType
| np.dtypes.UByteDType
| np.dtypes.UInt16DType
| np.dtypes.UInt32DType
| np.dtypes.UInt64DType
)

INTEGER = np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64
FLOAT_DTYPE = np.dtypes.Float16DType | np.dtypes.Float32DType | np.dtypes.Float64DType
FLOAT = np.float16 | np.float32 | np.float64
COMPLEX_DTYPE = np.dtypes.Complex64DType | np.dtypes.Complex128DType
COMPLEX = np.complex64 | np.complex128
# todo: r* dtypes


@overload
def parse_fill_value_v3(fill_value: Any, dtype: BOOL_DTYPE) -> BOOL: ...


@overload
def parse_fill_value_v3(fill_value: Any, dtype: INTEGER_DTYPE) -> INTEGER: ...


@overload
def parse_fill_value_v3(fill_value: Any, dtype: FLOAT_DTYPE) -> FLOAT: ...


@overload
def parse_fill_value_v3(fill_value: Any, dtype: COMPLEX_DTYPE) -> COMPLEX: ...


def parse_fill_value_v3(
fill_value: Any, dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE
) -> BOOL | INTEGER | FLOAT | COMPLEX:
"""
Parse `fill_value`, a potential fill value, into an instance of `dtype`, a data type.
If `fill_value` is `None`, then this function will return the result of casting the value 0
to the provided data type. Otherwise, `fill_value` will be cast to the provided data type.

Note that some numpy dtypes use very permissive casting rules. For example,
`np.bool_({'not remotely a bool'})` returns `True`. Thus this function should not be used for
validating that the provided fill value is a valid instance of the data type.

Parameters
----------
fill_value: Any
A potential fill value.
dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE
A numpy data type that models a data type defined in the Zarr V3 specification.

Returns
-------
A scalar instance of `dtype`
"""
if fill_value is None:
return dtype.type(0)
if isinstance(fill_value, Sequence) and not isinstance(fill_value, str):
if dtype in (np.complex64, np.complex128):
dtype = cast(COMPLEX_DTYPE, dtype)
if len(fill_value) == 2:
# complex datatypes serialize to JSON arrays with two elements
return dtype.type(complex(*fill_value))
else:
msg = (
f"Got an invalid fill value for complex data type {dtype}."
f"Expected a sequence with 2 elements, but {fill_value} has "
f"length {len(fill_value)}."
)
raise ValueError(msg)
msg = f"Cannot parse non-string sequence {fill_value} as a scalar with type {dtype}."
raise TypeError(msg)
return dtype.type(fill_value)
49 changes: 49 additions & 0 deletions tests/v3/test_array.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest

from zarr.array import Array
Expand Down Expand Up @@ -34,3 +35,51 @@ def test_array_name_properties_with_group(
assert spam.path == "bar/spam"
assert spam.name == "/bar/spam"
assert spam.basename == "spam"


@pytest.mark.parametrize("store", ["memory"], indirect=True)
@pytest.mark.parametrize("specifiy_fill_value", [True, False])
@pytest.mark.parametrize("dtype_str", ["bool", "uint8", "complex64"])
def test_array_v3_fill_value_default(
    store: MemoryStore, specifiy_fill_value: bool, dtype_str: str
) -> None:
    """
    Check the default fill value of a Zarr V3 array.

    Whether ``fill_value=None`` is passed explicitly or the argument is omitted, the array's
    ``fill_value`` attribute is expected to equal 0 cast to the array's dtype.
    """
    array_shape = (10,)
    if not specifiy_fill_value:
        # Leave ``fill_value`` out entirely.
        arr = Array.create(
            store=store, shape=array_shape, dtype=dtype_str, zarr_format=3, chunk_shape=array_shape
        )
    else:
        # Pass ``fill_value=None`` explicitly.
        arr = Array.create(
            store=store,
            shape=array_shape,
            dtype=dtype_str,
            zarr_format=3,
            chunk_shape=array_shape,
            fill_value=None,
        )

    expected = np.dtype(dtype_str).type(0)
    assert arr.fill_value == expected
    assert arr.fill_value.dtype == arr.dtype


@pytest.mark.parametrize("store", ["memory"], indirect=True)
@pytest.mark.parametrize("fill_value", [False, 0.0, 1, 2.3])
@pytest.mark.parametrize("dtype_str", ["bool", "uint8", "float32", "complex64"])
def test_array_v3_fill_value(
    store: MemoryStore, fill_value: bool | int | float, dtype_str: str
) -> None:
    """
    Check that an explicit fill value is cast to the array's dtype.

    For every combination of the parametrized fill values and dtypes, the array's
    ``fill_value`` attribute is expected to equal the fill value cast with the dtype's
    scalar type, and to carry the array's dtype.
    """
    shape = (10,)
    arr = Array.create(
        store=store,
        shape=shape,
        dtype=dtype_str,
        zarr_format=3,
        chunk_shape=shape,
        fill_value=fill_value,
    )

    expected = np.dtype(dtype_str).type(fill_value)
    assert arr.fill_value == expected
    assert arr.fill_value.dtype == arr.dtype
60 changes: 0 additions & 60 deletions tests/v3/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,60 +0,0 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Any

from zarr.metadata import parse_dimension_names, parse_zarr_format_v2, parse_zarr_format_v3


# TODO: implement this test; it is currently an empty placeholder.
def test_datatype_enum(): ...


# TODO: implement this test; it is currently an empty placeholder.
# This will almost certainly become a collection of tests.
def test_array_metadata_v3(): ...


# TODO: implement this test; it is currently an empty placeholder.
# This will almost certainly become a collection of tests.
def test_array_metadata_v2(): ...


# NOTE(review): this function's name lacks the ``test_`` prefix, so pytest's default
# collection rules will presumably not run it -- confirm and rename if unintended.
@pytest.mark.parametrize("data", [None, ("a", "b", "c"), ["a", "a", "a"]])
def parse_dimension_names_valid(data: Sequence[str] | None) -> None:
    """Check that ``parse_dimension_names`` returns a value equal to its input."""
    parsed = parse_dimension_names(data)
    assert parsed == data


# NOTE(review): this function's name lacks the ``test_`` prefix, so pytest's default
# collection rules will presumably not run it -- confirm and rename if unintended.
@pytest.mark.parametrize("data", [(), [1, 2, "a"], {"foo": 10}])
def parse_dimension_names_invalid(data: Any) -> None:
    """Check that ``parse_dimension_names`` raises ``TypeError`` for the given input."""
    expect_type_error = pytest.raises(TypeError, match="Expected either None or iterable of str,")
    with expect_type_error:
        parse_dimension_names(data)


# TODO: implement this test; it is currently an empty placeholder.
def test_parse_attributes() -> None: ...


def test_parse_zarr_format_v3_valid() -> None:
    """Check that ``parse_zarr_format_v3`` returns 3 when given 3."""
    parsed = parse_zarr_format_v3(3)
    assert parsed == 3


@pytest.mark.parametrize("data", [None, 1, 2, 4, 5, "3"])
def test_parse_zarr_foramt_v3_invalid(data: Any) -> None:
    """Check that ``parse_zarr_format_v3`` raises ``ValueError`` for each parametrized value."""
    expected_message = f"Invalid value. Expected 3. Got {data}"
    with pytest.raises(ValueError, match=expected_message):
        parse_zarr_format_v3(data)


def test_parse_zarr_format_v2_valid() -> None:
    """Check that ``parse_zarr_format_v2`` returns 2 when given 2."""
    parsed = parse_zarr_format_v2(2)
    assert parsed == 2


@pytest.mark.parametrize("data", [None, 1, 3, 4, 5, "3"])
def test_parse_zarr_foramt_v2_invalid(data: Any) -> None:
    """Check that ``parse_zarr_format_v2`` raises ``ValueError`` for each parametrized value."""
    expected_message = f"Invalid value. Expected 2. Got {data}"
    with pytest.raises(ValueError, match=expected_message):
        parse_zarr_format_v2(data)
20 changes: 20 additions & 0 deletions tests/v3/test_metadata/test_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any

import pytest

from zarr.metadata import parse_zarr_format_v2


def test_parse_zarr_format_valid() -> None:
    """Check that ``parse_zarr_format_v2`` returns 2 when given 2."""
    parsed = parse_zarr_format_v2(2)
    assert parsed == 2


@pytest.mark.parametrize("data", [None, 1, 3, 4, 5, "3"])
def test_parse_zarr_format_invalid(data: Any) -> None:
    """Check that ``parse_zarr_format_v2`` raises ``ValueError`` for each parametrized value."""
    expected_message = f"Invalid value. Expected 2. Got {data}"
    with pytest.raises(ValueError, match=expected_message):
        parse_zarr_format_v2(data)
Loading