
Commit 96be43e (parent 85ff3ae)

Prevent concurrent coordinate writes
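The race this commit prevents: each rank previously wrote its partition as
da.to_dataset(), which carries all of da's coordinates, so every rank's
region write also rewrote the same (typically unchunked) coordinate blobs
concurrently. A minimal sketch of the difference, with illustrative names:

import numpy as np
import xarray as xr

da = xr.DataArray(
    np.zeros((4, 4)),
    dims=["a", "b"],
    coords={"a": range(4), "b": range(4)},
    name="foo",
)

# Keeping coordinates means every rank would write "a" and "b" too.
print(sorted(da.to_dataset().variables))  # ['a', 'b', 'foo']

# Dropping them first leaves only the data variable for the region write.
print(sorted(da.drop_vars(da.coords).to_dataset().variables))  # ['foo']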

2 files changed: 27 additions, 6 deletions

test_xpartition.py

Lines changed: 23 additions & 2 deletions
@@ -100,10 +100,15 @@ def da(request):
 def _construct_dataarray(shape, chunks, name):
     dims = list(string.ascii_lowercase[: len(shape)])
     data = np.random.random(shape)
-    da = xr.DataArray(data, dims=dims, name=name)
+    coords = [range(length) for length in shape]
+    da = xr.DataArray(data, dims=dims, name=name, coords=coords)
     if chunks is not None:
         chunks = {dim: chunk for dim, chunk in zip(dims, chunks)}
         da = da.chunk(chunks)
+
+        # Add coverage for chunked coordinates
+        chunked_coord_name = f"{da.name}_chunked_coord"
+        da = da.assign_coords({chunked_coord_name: da.chunk(chunks)})
     return da
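For context, the updated fixture gives every test DataArray integer index
coordinates, and chunked DataArrays additionally get a dask-backed,
non-dimension coordinate. A sketch of the structure this produces (the
shape, chunks, and the name "foo" are made up):

import dask.array
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.random.random((4, 6)),
    dims=["a", "b"],
    coords=[range(4), range(6)],
    name="foo",
).chunk({"a": 2, "b": 3})

# Attach the DataArray itself as a chunked, non-dimension coordinate,
# mirroring the fixture above.
da = da.assign_coords(foo_chunked_coord=da.chunk({"a": 2, "b": 3}))

# The coordinate stays lazy, so it must be partitioned like a data variable.
assert isinstance(da.coords["foo_chunked_coord"].data, dask.array.Array)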
@@ -138,12 +143,23 @@ def ds():
     return xr.merge(unchunked_dataarrays + chunked_dataarrays)


+def get_unchunked_modification_times(ds, store):
+    modification_times = {}
+    for name, variable in ds.variables.items():
+        if not isinstance(variable.data, dask.array.Array):
+            blob_name = ".".join(["0" for _ in variable.dims])
+            blob_path = os.path.join(store, name, blob_name)
+            modification_times[name] = os.path.getmtime(blob_path)
+    return modification_times
+
+
 @pytest.mark.filterwarnings("ignore:Specified Dask chunks")
 @pytest.mark.parametrize("ranks", [1, 2, 3, 5, 10, 11])
 @pytest.mark.parametrize("collect_variable_writes", [False, True])
 def test_dataset_mappable_write(tmpdir, ds, ranks, collect_variable_writes):
     store = os.path.join(tmpdir, "test.zarr")
     ds.partition.initialize_store(store)
+    expected_modification_times = get_unchunked_modification_times(ds, store)

     with multiprocessing.get_context("spawn").Pool(ranks) as pool:
         pool.map(
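The helper relies on the zarr version 2 directory layout, in which each
variable lives in its own directory and an unchunked variable is stored as
a single blob keyed by one "0" per dimension ("0" for 1-D, "0.0" for 2-D,
and so on). A small sketch of that assumption (paths are illustrative):

import os

def unchunked_blob_path(store, name, ndim):
    # One "0" per dimension: the key of the sole chunk of an unchunked array.
    blob_name = ".".join("0" for _ in range(ndim))
    return os.path.join(store, name, blob_name)

assert unchunked_blob_path("test.zarr", "foo", 2) == os.path.join(
    "test.zarr", "foo", "0.0"
)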
@@ -154,6 +170,11 @@ def test_dataset_mappable_write(tmpdir, ds, ranks, collect_variable_writes):
         )

     result = xr.open_zarr(store)
+    resulting_modification_times = get_unchunked_modification_times(ds, store)
+
+    # This checks that all unchunked variables in the dataset were written
+    # only once, upon initialization of the store.
+    assert expected_modification_times == resulting_modification_times
     xr.testing.assert_identical(result, ds)
@@ -317,7 +338,7 @@ def __call__(self, dsk, keys, **kwargs):


 @pytest.mark.parametrize(
-    ("collect_variable_writes", "expected_computes"), [(False, 6), (True, 3)]
+    ("collect_variable_writes", "expected_computes"), [(False, 9), (True, 3)]
 )
 def test_dataset_mappable_write_minimizes_compute_calls(
     tmpdir, collect_variable_writes, expected_computes
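The expected compute count for collect_variable_writes=False rises from 6
to 9, presumably reflecting the chunked coordinates the fixture now
attaches: written individually they add compute calls, whereas collecting
writes by partition still needs only 3. Judging by the
__call__(self, dsk, keys, **kwargs) signature in the hunk header, the test
counts computes with a callable scheduler; the sketch below is an
assumption along those lines, not the repository's exact code:

import dask
import dask.array

class CountingScheduler:
    """Callable scheduler that counts how many times compute is invoked."""

    def __init__(self):
        self.total_computes = 0

    def __call__(self, dsk, keys, **kwargs):
        self.total_computes += 1
        # Delegate the actual work to dask's synchronous scheduler.
        return dask.get(dsk, keys, **kwargs)

scheduler = CountingScheduler()
with dask.config.set(scheduler=scheduler):
    dask.array.ones((4,), chunks=2).sum().compute()
assert scheduler.total_computes == 1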

xpartition.py

Lines changed: 4 additions & 4 deletions
@@ -184,7 +184,7 @@ def isel(self, **block_indexers) -> xr.DataArray:
 def _write_partition_dataarray(
     da: xr.DataArray, store: str, ranks: int, dims: Sequence[Hashable], rank: int
 ):
-    ds = da.to_dataset()
+    ds = da.drop_vars(da.coords).to_dataset()
     partition = da.partition.indexers(ranks, rank, dims)
     if partition is not None:
         ds.isel(partition).to_zarr(store, region=partition)
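The effect: after initialize_store has written coordinates once, each
rank's region write contains only the data variable, so no two ranks ever
touch the same blob. A self-contained sketch of that pattern, approximating
initialize_store with a compute=False write (the store path and sizes are
made up):

import os
import tempfile

import numpy as np
import xarray as xr

store = os.path.join(tempfile.mkdtemp(), "example.zarr")

da = xr.DataArray(
    np.arange(8.0), dims=["a"], coords=[range(8)], name="foo"
).chunk({"a": 4})

# One-time initialization: writes metadata and the coordinate eagerly.
da.to_dataset().to_zarr(store, compute=False)

# Per-rank region write: coordinates dropped, only "foo" is written.
region = {"a": slice(0, 4)}
ds = da.drop_vars(da.coords).to_dataset()
ds.isel(region).to_zarr(store, region=region, mode="r+")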
@@ -214,18 +214,18 @@ def _collect_by_partition(
     DataArrays that can be written out to those partitions.
     """
     dataarrays = collections.defaultdict(list)
-    for da in ds.data_vars.values():
+    for da in {**ds.coords, **ds.data_vars}.values():
         if isinstance(da.data, dask.array.Array):
             partition_dims = [dim for dim in dims if dim in da.dims]
             indexers = da.partition.indexers(ranks, rank, partition_dims)
-            dataarrays[freeze_indexers(indexers)].append(da)
+            dataarrays[freeze_indexers(indexers)].append(da.drop_vars(da.coords))
     return [(unfreeze_indexers(k), xr.merge(v)) for k, v in dataarrays.items()]


 def _write_partition_dataset_via_individual_variables(
     ds: xr.Dataset, store: str, ranks: int, dims: Sequence[Hashable], rank: int
 ):
-    for da in ds.data_vars.values():
+    for da in {**ds.coords, **ds.data_vars}.values():
         if isinstance(da.data, dask.array.Array):
             partition_dims = [dim for dim in dims if dim in da.dims]
             da.partition.write(store, ranks, partition_dims, rank)
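The union mapping {**ds.coords, **ds.data_vars} makes both write paths
treat dask-backed coordinates exactly like data variables, while the
isinstance check still skips eagerly-loaded index coordinates, which were
already written when the store was initialized. A small sketch of which
variables that filter selects (names are illustrative):

import dask.array
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"foo": ("a", np.zeros(4))},
    coords={"a": range(4), "lazy_coord": ("a", np.arange(4.0))},
).chunk({"a": 2})

for name, da in {**ds.coords, **ds.data_vars}.items():
    if isinstance(da.data, dask.array.Array):
        print(name, "-> partitioned write")  # lazy_coord, foo
    else:
        print(name, "-> skipped (index coordinate)")  # a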
