1
- import re
2
- from typing import Any
1
+ from typing import Any , Literal
3
2
4
3
import hypothesis .extra .numpy as npst
5
4
import hypothesis .strategies as st
19
18
max_leaves = 3 ,
20
19
)
21
20
21
+
22
def v3_dtypes() -> st.SearchStrategy[np.dtype]:
    """Return a strategy producing the NumPy dtypes used for Zarr format 3.

    The strategy draws from boolean, signed-integer, unsigned-integer,
    floating-point and complex dtypes. Every family that accepts an
    ``endianness`` argument is pinned to native byte order (``"="``).
    """
    # Dtype families that are currently left out of this strategy:
    #   npst.byte_string_dtypes(endianness="=")
    #   npst.unicode_string_dtypes()
    #   npst.datetime64_dtypes()
    #   npst.timedelta64_dtypes()
    native = "="
    return st.one_of(
        npst.boolean_dtypes(),
        npst.integer_dtypes(endianness=native),
        npst.unsigned_integer_dtypes(endianness=native),
        npst.floating_dtypes(endianness=native),
        npst.complex_number_dtypes(endianness=native),
    )
34
+
35
+
36
def v2_dtypes() -> st.SearchStrategy[np.dtype]:
    """Return a strategy producing the NumPy dtypes used for Zarr format 2.

    The strategy draws from boolean, signed-integer, unsigned-integer,
    floating-point, complex, byte-string, unicode-string and datetime64
    dtypes. Every family that is given an ``endianness`` argument is
    pinned to native byte order (``"="``).
    """
    # Dtype family that is currently left out of this strategy:
    #   npst.timedelta64_dtypes()
    native = "="
    return st.one_of(
        npst.boolean_dtypes(),
        npst.integer_dtypes(endianness=native),
        npst.unsigned_integer_dtypes(endianness=native),
        npst.floating_dtypes(endianness=native),
        npst.complex_number_dtypes(endianness=native),
        npst.byte_string_dtypes(endianness=native),
        npst.unicode_string_dtypes(endianness=native),
        npst.datetime64_dtypes(),
    )
48
+
49
+
22
50
# From https://zarr-specs.readthedocs.io/en/latest/v3/core/v3.0.html#node-names
23
51
# 1. must not be the empty string ("")
24
52
# 2. must not include the character "/"
33
61
array_names = node_names
34
62
attrs = st .none () | st .dictionaries (_attr_keys , _attr_values )
35
63
paths = st .lists (node_names , min_size = 1 ).map (lambda x : "/" .join (x )) | st .just ("/" )
36
- np_arrays = npst .arrays (
37
- # TODO: re-enable timedeltas once they are supported
38
- dtype = npst .scalar_dtypes ().filter (
39
- lambda x : (x .kind not in ["m" , "M" ]) and (x .byteorder not in [">" ])
40
- ),
41
- shape = npst .array_shapes (max_dims = 4 ),
42
- )
43
64
stores = st .builds (MemoryStore , st .just ({}), mode = st .just ("w" ))
44
65
compressors = st .sampled_from ([None , "default" ])
45
- format = st .sampled_from ([2 , 3 ])
66
+ zarr_formats : st .SearchStrategy [Literal [2 , 3 ]] = st .sampled_from ([2 , 3 ])
67
+ array_shapes = npst .array_shapes (max_dims = 4 )
68
+
69
+
70
@st.composite  # type: ignore[misc]
def numpy_arrays(
    draw: st.DrawFn,
    *,
    shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
    zarr_formats: st.SearchStrategy[Literal[2, 3]] = zarr_formats,
) -> Any:
    """
    Draw a NumPy array whose dtype suits a drawn Zarr format.

    Parameters
    ----------
    draw
        Draw function injected by ``st.composite``.
    shapes
        Strategy supplying the array shape.
    zarr_formats
        Strategy supplying the Zarr format; a drawn value of ``3`` selects
        ``v3_dtypes()``, any other value selects ``v2_dtypes()``.

    Returns
    -------
    The array drawn from ``npst.arrays`` with the selected dtype strategy
    and ``shapes``.
    """
    chosen_format = draw(zarr_formats)
    if chosen_format == 3:
        dtype_strategy = v3_dtypes()
    else:
        dtype_strategy = v2_dtypes()
    return draw(npst.arrays(dtype=dtype_strategy, shape=shapes))
46
82
47
83
48
84
@st .composite # type: ignore[misc]
49
85
def np_array_and_chunks (
50
- draw : st .DrawFn , * , arrays : st .SearchStrategy [np .ndarray ] = np_arrays
86
+ draw : st .DrawFn , * , arrays : st .SearchStrategy [np .ndarray ] = numpy_arrays
51
87
) -> tuple [np .ndarray , tuple [int ]]: # type: ignore[type-arg]
52
88
"""A hypothesis strategy to generate small sized random arrays.
53
89
@@ -66,73 +102,49 @@ def np_array_and_chunks(
66
102
def arrays (
67
103
draw : st .DrawFn ,
68
104
* ,
105
+ shapes : st .SearchStrategy [tuple [int , ...]] = array_shapes ,
69
106
compressors : st .SearchStrategy = compressors ,
70
107
stores : st .SearchStrategy [StoreLike ] = stores ,
71
- arrays : st .SearchStrategy [np .ndarray ] = np_arrays ,
72
108
paths : st .SearchStrategy [None | str ] = paths ,
73
109
array_names : st .SearchStrategy = array_names ,
110
+ arrays : st .SearchStrategy | None = None ,
74
111
attrs : st .SearchStrategy = attrs ,
75
- format : st .SearchStrategy = format ,
112
+ zarr_formats : st .SearchStrategy = zarr_formats ,
76
113
) -> Array :
77
114
store = draw (stores )
78
- nparray , chunks = draw (np_array_and_chunks (arrays = arrays ))
79
115
path = draw (paths )
80
116
name = draw (array_names )
81
117
attributes = draw (attrs )
82
- zarr_format = draw (format )
118
+ zarr_format = draw (zarr_formats )
119
+ if arrays is None :
120
+ arrays = numpy_arrays (shapes = shapes , zarr_formats = st .just (zarr_format ))
121
+ nparray , chunks = draw (np_array_and_chunks (arrays = arrays ))
122
+ # test that None works too.
123
+ fill_value = draw (st .one_of ([st .none (), npst .from_dtype (nparray .dtype )]))
83
124
# compressor = draw(compressors)
84
125
85
- # TODO: clean this up
86
- # if path is None and name is None:
87
- # array_path = None
88
- # array_name = None
89
- # elif path is None and name is not None:
90
- # array_path = f"{name}"
91
- # array_name = f"/{name}"
92
- # elif path is not None and name is None:
93
- # array_path = path
94
- # array_name = None
95
- # elif path == "/":
96
- # assert name is not None
97
- # array_path = name
98
- # array_name = "/" + name
99
- # else:
100
- # assert name is not None
101
- # array_path = f"{path}/{name}"
102
- # array_name = "/" + array_path
103
-
104
126
expected_attrs = {} if attributes is None else attributes
105
127
106
128
array_path = path + ("/" if not path .endswith ("/" ) else "" ) + name
107
129
root = Group .from_store (store , zarr_format = zarr_format )
108
- fill_value_args : tuple [Any , ...] = tuple ()
109
- if nparray .dtype .kind == "M" :
110
- m = re .search (r"\[(.+)\]" , nparray .dtype .str )
111
- if not m :
112
- raise ValueError (f"Couldn't find precision for dtype '{ nparray .dtype } ." )
113
-
114
- fill_value_args = (
115
- # e.g. ns, D
116
- m .groups ()[0 ],
117
- )
118
130
119
131
a = root .create_array (
120
132
array_path ,
121
133
shape = nparray .shape ,
122
134
chunks = chunks ,
123
- dtype = nparray .dtype . str ,
135
+ dtype = nparray .dtype ,
124
136
attributes = attributes ,
125
- # compressor=compressor, # TODO: FIXME
126
- fill_value = nparray . dtype . type ( 0 , * fill_value_args ) ,
137
+ # compressor=compressor, # FIXME
138
+ fill_value = fill_value ,
127
139
)
128
140
129
141
assert isinstance (a , Array )
142
+ assert a .fill_value is not None
143
+ assert isinstance (root [array_path ], Array )
130
144
assert nparray .shape == a .shape
131
145
assert chunks == a .chunks
132
146
assert array_path == a .path , (path , name , array_path , a .name , a .path )
133
- # assert array_path == a.name, (path, name, array_path, a.name, a.path)
134
- # assert a.basename is None # TODO
135
- # assert a.store == normalize_store_arg(store)
147
+ assert a .basename == name , (a .basename , name )
136
148
assert dict (a .attrs ) == expected_attrs
137
149
138
150
a [:] = nparray
0 commit comments