Skip to content

Commit 692593b

Browse files
Fix fill_value handling for complex dtypes (#2200)
* Fix fill_value handling for complex & datetime dtypes * cleanup * more cleanup * more cleanup * Fix default fill_value * Fixes * Add booleans * Add v2, v3 specific dtypes * Add version.py to gitignore * cleanpu * style: pre-commit fixes --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent fafd0bf commit 692593b

File tree

6 files changed

+77
-65
lines changed

6 files changed

+77
-65
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,5 @@ fixture/
8484
.DS_Store
8585
tests/.hypothesis
8686
.hypothesis/
87+
88+
zarr/version.py

src/zarr/core/array.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -252,12 +252,6 @@ async def _create_v3(
252252
shape = parse_shapelike(shape)
253253
codecs = list(codecs) if codecs is not None else [BytesCodec()]
254254

255-
if fill_value is None:
256-
if dtype == np.dtype("bool"):
257-
fill_value = False
258-
else:
259-
fill_value = 0
260-
261255
if chunk_key_encoding is None:
262256
chunk_key_encoding = ("default", "/")
263257
assert chunk_key_encoding is not None
@@ -281,7 +275,6 @@ async def _create_v3(
281275
)
282276

283277
array = cls(metadata=metadata, store_path=store_path)
284-
285278
await array._save_metadata(metadata)
286279
return array
287280

src/zarr/core/buffer/core.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,14 @@ def __repr__(self) -> str:
464464

465465
def all_equal(self, other: Any, equal_nan: bool = True) -> bool:
466466
"""Compare to `other` using np.array_equal."""
467+
if other is None:
468+
# Handle None fill_value for Zarr V2
469+
return False
467470
# use array_equal to obtain equal_nan=True functionality
468471
data, other = np.broadcast_arrays(self._data, other)
469-
result = np.array_equal(self._data, other, equal_nan=equal_nan)
472+
result = np.array_equal(
473+
self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "US" else False
474+
)
470475
return result
471476

472477
def fill(self, value: Any) -> None:

src/zarr/core/metadata/v3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def parse_fill_value(
360360
if fill_value is None:
361361
return dtype.type(0)
362362
if isinstance(fill_value, Sequence) and not isinstance(fill_value, str):
363-
if dtype in (np.complex64, np.complex128):
363+
if dtype.type in (np.complex64, np.complex128):
364364
dtype = cast(COMPLEX_DTYPE, dtype)
365365
if len(fill_value) == 2:
366366
# complex datatypes serialize to JSON arrays with two elements
@@ -391,7 +391,7 @@ def parse_fill_value(
391391
pass
392392
elif fill_value in ["Infinity", "-Infinity"] and not np.isfinite(casted_value):
393393
pass
394-
elif dtype.kind == "f":
394+
elif dtype.kind in "cf":
395395
# float comparison is not exact, especially when dtype <float64
396396
# so we us np.isclose for this comparison.
397397
# this also allows us to compare nan fill_values

src/zarr/testing/strategies.py

Lines changed: 62 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import re
2-
from typing import Any
1+
from typing import Any, Literal
32

43
import hypothesis.extra.numpy as npst
54
import hypothesis.strategies as st
@@ -19,6 +18,35 @@
1918
max_leaves=3,
2019
)
2120

21+
22+
def v3_dtypes() -> st.SearchStrategy[np.dtype]:
23+
return (
24+
npst.boolean_dtypes()
25+
| npst.integer_dtypes(endianness="=")
26+
| npst.unsigned_integer_dtypes(endianness="=")
27+
| npst.floating_dtypes(endianness="=")
28+
| npst.complex_number_dtypes(endianness="=")
29+
# | npst.byte_string_dtypes(endianness="=")
30+
# | npst.unicode_string_dtypes()
31+
# | npst.datetime64_dtypes()
32+
# | npst.timedelta64_dtypes()
33+
)
34+
35+
36+
def v2_dtypes() -> st.SearchStrategy[np.dtype]:
37+
return (
38+
npst.boolean_dtypes()
39+
| npst.integer_dtypes(endianness="=")
40+
| npst.unsigned_integer_dtypes(endianness="=")
41+
| npst.floating_dtypes(endianness="=")
42+
| npst.complex_number_dtypes(endianness="=")
43+
| npst.byte_string_dtypes(endianness="=")
44+
| npst.unicode_string_dtypes(endianness="=")
45+
| npst.datetime64_dtypes()
46+
# | npst.timedelta64_dtypes()
47+
)
48+
49+
2250
# From https://zarr-specs.readthedocs.io/en/latest/v3/core/v3.0.html#node-names
2351
# 1. must not be the empty string ("")
2452
# 2. must not include the character "/"
@@ -33,21 +61,29 @@
3361
array_names = node_names
3462
attrs = st.none() | st.dictionaries(_attr_keys, _attr_values)
3563
paths = st.lists(node_names, min_size=1).map(lambda x: "/".join(x)) | st.just("/")
36-
np_arrays = npst.arrays(
37-
# TODO: re-enable timedeltas once they are supported
38-
dtype=npst.scalar_dtypes().filter(
39-
lambda x: (x.kind not in ["m", "M"]) and (x.byteorder not in [">"])
40-
),
41-
shape=npst.array_shapes(max_dims=4),
42-
)
4364
stores = st.builds(MemoryStore, st.just({}), mode=st.just("w"))
4465
compressors = st.sampled_from([None, "default"])
45-
format = st.sampled_from([2, 3])
66+
zarr_formats: st.SearchStrategy[Literal[2, 3]] = st.sampled_from([2, 3])
67+
array_shapes = npst.array_shapes(max_dims=4)
68+
69+
70+
@st.composite # type: ignore[misc]
71+
def numpy_arrays(
72+
draw: st.DrawFn,
73+
*,
74+
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
75+
zarr_formats: st.SearchStrategy[Literal[2, 3]] = zarr_formats,
76+
) -> Any:
77+
"""
78+
Generate numpy arrays that can be saved in the provided Zarr format.
79+
"""
80+
zarr_format = draw(zarr_formats)
81+
return draw(npst.arrays(dtype=v3_dtypes() if zarr_format == 3 else v2_dtypes(), shape=shapes))
4682

4783

4884
@st.composite # type: ignore[misc]
4985
def np_array_and_chunks(
50-
draw: st.DrawFn, *, arrays: st.SearchStrategy[np.ndarray] = np_arrays
86+
draw: st.DrawFn, *, arrays: st.SearchStrategy[np.ndarray] = numpy_arrays
5187
) -> tuple[np.ndarray, tuple[int]]: # type: ignore[type-arg]
5288
"""A hypothesis strategy to generate small sized random arrays.
5389
@@ -66,73 +102,49 @@ def np_array_and_chunks(
66102
def arrays(
67103
draw: st.DrawFn,
68104
*,
105+
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
69106
compressors: st.SearchStrategy = compressors,
70107
stores: st.SearchStrategy[StoreLike] = stores,
71-
arrays: st.SearchStrategy[np.ndarray] = np_arrays,
72108
paths: st.SearchStrategy[None | str] = paths,
73109
array_names: st.SearchStrategy = array_names,
110+
arrays: st.SearchStrategy | None = None,
74111
attrs: st.SearchStrategy = attrs,
75-
format: st.SearchStrategy = format,
112+
zarr_formats: st.SearchStrategy = zarr_formats,
76113
) -> Array:
77114
store = draw(stores)
78-
nparray, chunks = draw(np_array_and_chunks(arrays=arrays))
79115
path = draw(paths)
80116
name = draw(array_names)
81117
attributes = draw(attrs)
82-
zarr_format = draw(format)
118+
zarr_format = draw(zarr_formats)
119+
if arrays is None:
120+
arrays = numpy_arrays(shapes=shapes, zarr_formats=st.just(zarr_format))
121+
nparray, chunks = draw(np_array_and_chunks(arrays=arrays))
122+
# test that None works too.
123+
fill_value = draw(st.one_of([st.none(), npst.from_dtype(nparray.dtype)]))
83124
# compressor = draw(compressors)
84125

85-
# TODO: clean this up
86-
# if path is None and name is None:
87-
# array_path = None
88-
# array_name = None
89-
# elif path is None and name is not None:
90-
# array_path = f"{name}"
91-
# array_name = f"/{name}"
92-
# elif path is not None and name is None:
93-
# array_path = path
94-
# array_name = None
95-
# elif path == "/":
96-
# assert name is not None
97-
# array_path = name
98-
# array_name = "/" + name
99-
# else:
100-
# assert name is not None
101-
# array_path = f"{path}/{name}"
102-
# array_name = "/" + array_path
103-
104126
expected_attrs = {} if attributes is None else attributes
105127

106128
array_path = path + ("/" if not path.endswith("/") else "") + name
107129
root = Group.from_store(store, zarr_format=zarr_format)
108-
fill_value_args: tuple[Any, ...] = tuple()
109-
if nparray.dtype.kind == "M":
110-
m = re.search(r"\[(.+)\]", nparray.dtype.str)
111-
if not m:
112-
raise ValueError(f"Couldn't find precision for dtype '{nparray.dtype}.")
113-
114-
fill_value_args = (
115-
# e.g. ns, D
116-
m.groups()[0],
117-
)
118130

119131
a = root.create_array(
120132
array_path,
121133
shape=nparray.shape,
122134
chunks=chunks,
123-
dtype=nparray.dtype.str,
135+
dtype=nparray.dtype,
124136
attributes=attributes,
125-
# compressor=compressor, # TODO: FIXME
126-
fill_value=nparray.dtype.type(0, *fill_value_args),
137+
# compressor=compressor, # FIXME
138+
fill_value=fill_value,
127139
)
128140

129141
assert isinstance(a, Array)
142+
assert a.fill_value is not None
143+
assert isinstance(root[array_path], Array)
130144
assert nparray.shape == a.shape
131145
assert chunks == a.chunks
132146
assert array_path == a.path, (path, name, array_path, a.name, a.path)
133-
# assert array_path == a.name, (path, name, array_path, a.name, a.path)
134-
# assert a.basename is None # TODO
135-
# assert a.store == normalize_store_arg(store)
147+
assert a.basename == name, (a.basename, name)
136148
assert dict(a.attrs) == expected_attrs
137149

138150
a[:] = nparray

tests/v3/test_properties.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
import hypothesis.strategies as st # noqa: E402
99
from hypothesis import given # noqa: E402
1010

11-
from zarr.testing.strategies import arrays, basic_indices, np_arrays # noqa: E402
11+
from zarr.testing.strategies import arrays, basic_indices, numpy_arrays, zarr_formats # noqa: E402
1212

1313

14-
@given(st.data())
15-
def test_roundtrip(data: st.DataObject) -> None:
16-
nparray = data.draw(np_arrays)
17-
zarray = data.draw(arrays(arrays=st.just(nparray)))
14+
@given(data=st.data(), zarr_format=zarr_formats)
15+
def test_roundtrip(data: st.DataObject, zarr_format: int) -> None:
16+
nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format)))
17+
zarray = data.draw(arrays(arrays=st.just(nparray), zarr_formats=st.just(zarr_format)))
1818
assert_array_equal(nparray, zarray[:])
1919

2020

0 commit comments

Comments
 (0)