@@ -100,10 +100,15 @@ def da(request):
 def _construct_dataarray(shape, chunks, name):
     dims = list(string.ascii_lowercase[: len(shape)])
     data = np.random.random(shape)
-    da = xr.DataArray(data, dims=dims, name=name)
+    coords = [range(length) for length in shape]
+    da = xr.DataArray(data, dims=dims, name=name, coords=coords)
     if chunks is not None:
         chunks = {dim: chunk for dim, chunk in zip(dims, chunks)}
         da = da.chunk(chunks)
+
+        # Add coverage for chunked coordinates
+        chunked_coord_name = f"{da.name}_chunked_coord"
+        da = da.assign_coords({chunked_coord_name: da.chunk(chunks)})
     return da
 
 
@@ -138,12 +143,23 @@ def ds():
     return xr.merge(unchunked_dataarrays + chunked_dataarrays)
 
 
+def get_unchunked_modification_times(ds, store):
+    modification_times = {}
+    for name, variable in ds.variables.items():
+        if not isinstance(variable.data, dask.array.Array):
+            blob_name = ".".join(["0" for _ in variable.dims])
+            blob_path = os.path.join(store, name, blob_name)
+            modification_times[name] = os.path.getmtime(blob_path)
+    return modification_times
+
+
 @pytest.mark.filterwarnings("ignore:Specified Dask chunks")
 @pytest.mark.parametrize("ranks", [1, 2, 3, 5, 10, 11])
 @pytest.mark.parametrize("collect_variable_writes", [False, True])
 def test_dataset_mappable_write(tmpdir, ds, ranks, collect_variable_writes):
     store = os.path.join(tmpdir, "test.zarr")
     ds.partition.initialize_store(store)
+    expected_modification_times = get_unchunked_modification_times(ds, store)
 
     with multiprocessing.get_context("spawn").Pool(ranks) as pool:
         pool.map(
@@ -154,6 +170,11 @@ def test_dataset_mappable_write(tmpdir, ds, ranks, collect_variable_writes):
         )
 
     result = xr.open_zarr(store)
+    resulting_modification_times = get_unchunked_modification_times(ds, store)
+
+    # This checks that all unchunked variables in the dataset were written
+    # only once, upon initialization of the store.
+    assert expected_modification_times == resulting_modification_times
     xr.testing.assert_identical(result, ds)
 
 
@@ -317,7 +338,7 @@ def __call__(self, dsk, keys, **kwargs):
 
 
 @pytest.mark.parametrize(
-    ("collect_variable_writes", "expected_computes"), [(False, 6), (True, 3)]
+    ("collect_variable_writes", "expected_computes"), [(False, 9), (True, 3)]
 )
 def test_dataset_mappable_write_minimizes_compute_calls(
     tmpdir, collect_variable_writes, expected_computes
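
For context on the technique used by the new `get_unchunked_modification_times` helper: in zarr v2's directory-store layout, an unchunked N-dimensional variable is persisted as a single chunk file whose name is "0" repeated once per dimension and joined with dots (e.g. `0.0` for a 2-D variable), which is why the helper builds `blob_name` that way. Comparing that file's mtime before and after the mapped writes reveals whether the variable was rewritten. Below is a minimal standalone sketch of the idea (not part of this diff), assuming zarr v2 chunk naming; the store path and variable name are illustrative only.

```python
import os
import tempfile

import numpy as np
import xarray as xr

with tempfile.TemporaryDirectory() as tmpdir:
    store = os.path.join(tmpdir, "example.zarr")
    ds = xr.Dataset({"unchunked": (("a", "b"), np.zeros((2, 2)))})
    ds.to_zarr(store)  # writes the variable once

    # Assumes zarr v2 naming: a single-chunk 2-D array is one file "0.0".
    blob = os.path.join(store, "unchunked", "0.0")
    before = os.path.getmtime(blob)

    # A pure read must leave the chunk file untouched, so the mtime is stable.
    xr.open_zarr(store).load()
    assert os.path.getmtime(blob) == before
```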