Skip to content

Commit 994af64

Browse files
committed
Add initialize_zarr
Closes pydata#8343
1 parent 22ca9ba commit 994af64

File tree

4 files changed

+148
-4
lines changed

4 files changed

+148
-4
lines changed

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ Top-level functions
2727
combine_nested
2828
where
2929
infer_freq
30+
initialize_zarr
3031
full_like
3132
zeros_like
3233
ones_like

xarray/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
open_mfdataset,
1010
save_mfdataset,
1111
)
12-
from xarray.backends.zarr import open_zarr
12+
from xarray.backends.zarr import initialize_zarr, open_zarr
1313
from xarray.coding.cftime_offsets import cftime_range, date_range, date_range_like
1414
from xarray.coding.cftimeindex import CFTimeIndex
1515
from xarray.coding.frequencies import infer_freq
@@ -75,6 +75,7 @@
7575
"full_like",
7676
"get_options",
7777
"infer_freq",
78+
"initialize_zarr",
7879
"load_dataarray",
7980
"load_dataset",
8081
"map_blocks",

xarray/backends/zarr.py

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import json
44
import os
55
import warnings
6-
from collections.abc import Iterable
7-
from typing import TYPE_CHECKING, Any
6+
from collections.abc import Hashable, Iterable
7+
from typing import TYPE_CHECKING, Any, Literal
88

99
import numpy as np
1010

@@ -19,6 +19,7 @@
1919
)
2020
from xarray.backends.store import StoreBackendEntrypoint
2121
from xarray.core import indexing
22+
from xarray.core.common import zeros_like
2223
from xarray.core.parallelcompat import guess_chunkmanager
2324
from xarray.core.pycompat import integer_types
2425
from xarray.core.utils import (
@@ -34,11 +35,97 @@
3435
from xarray.backends.common import AbstractDataStore
3536
from xarray.core.dataset import Dataset
3637

37-
3838
# need some special secret attributes to tell us the dimensions
3939
DIMENSION_KEY = "_ARRAY_DIMENSIONS"
4040

4141

42+
def initialize_zarr(
    store,
    ds: Dataset,
    *,
    region_dims: Iterable[Hashable] | None = None,
    mode: Literal["w", "w-"] = "w-",
    **kwargs,
) -> Dataset:
    """
    Initialize a Zarr store with metadata.

    This function initializes a Zarr store with metadata that describes the entire dataset.
    If ``region_dims`` is specified, it will also

    1. Write variables that don't contain any of ``region_dims``, and
    2. Return a dataset with variables that do contain one or more of ``region_dims``.
       This dataset can be used for region writes in parallel.

    Parameters
    ----------
    store : MutableMapping or str
        Zarr store to write to.
    ds : Dataset
        Dataset to write.
    region_dims : Iterable[Hashable], optional
        An iterable of dimension names that will be passed to the ``region``
        kwarg of ``to_zarr`` later.
    mode : {'w', 'w-'}
        Write mode for initializing the store.

    Returns
    -------
    Dataset
        Dataset containing variables with one or more ``region_dims``
        dimensions. Use this for writing to the store in parallel later.

    Raises
    ------
    ValueError
        If ``compute`` is passed in ``kwargs``, if ``ds`` is not chunked,
        or if ``mode`` is not ``'w'`` or ``'w-'``.
    """

    if "compute" in kwargs:
        raise ValueError("The ``compute`` kwarg is not supported in `initialize_zarr`.")

    if not ds.chunks:
        raise ValueError("This function should be used with chunked Datasets.")

    if mode not in ["w", "w-"]:
        raise ValueError(
            f"Only mode='w' or mode='w-' is allowed for initialize_zarr. Received mode={mode!r}"
        )

    # region_dims is consumed more than once below (drop_dims and set());
    # materialize it so a one-shot iterator (e.g. a generator) isn't
    # silently exhausted after the first use.
    if region_dims is not None:
        region_dims = tuple(region_dims)

    # TODO: what should we do here.
    # compute=False only skips dask variables.
    # - We could replace all dask variables with zeros_like
    # - and then write all other variables eagerly.
    # Right now we do two writes for eager variables
    template = zeros_like(ds)
    template.to_zarr(store, mode=mode, **kwargs, compute=False)

    if region_dims:
        # Eagerly write everything that has none of the region dims;
        # those variables can't be written via region= later.
        after_drop = ds.drop_dims(region_dims)

        # we have to remove the dropped variables from the encoding dictionary :/
        new_encoding = kwargs.pop("encoding", None)
        if new_encoding:
            new_encoding = {k: v for k, v in new_encoding.items() if k in after_drop}

        after_drop.to_zarr(
            store, **kwargs, encoding=new_encoding, compute=True, mode="a"
        )

        # can't use drop_dims since that will also remove any variable
        # with any of the dims to be dropped
        # even if they also have one or more of region_dims
        dims_to_drop = set(ds.dims) - set(region_dims)
        vars_to_drop = [
            name
            for name, var in ds._variables.items()
            if set(var.dims).issubset(dims_to_drop)
        ]
        return ds.drop_vars(vars_to_drop)

    else:
        return ds
127+
128+
42129
def encode_zarr_attr_value(value):
43130
"""
44131
Encode a attribute value as something that can be serialized as json

xarray/tests/test_backends.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
)
4949
from xarray.backends.pydap_ import PydapDataStore
5050
from xarray.backends.scipy_ import ScipyBackendEntrypoint
51+
from xarray.backends.zarr import initialize_zarr
5152
from xarray.coding.strings import check_vlen_dtype, create_vlen_dtype
5253
from xarray.coding.variables import SerializationWarning
5354
from xarray.conventions import encode_dataset_coordinates
@@ -5434,3 +5435,57 @@ def test_zarr_region_transpose(tmp_path):
54345435
ds_region.to_zarr(
54355436
tmp_path / "test.zarr", region={"x": slice(0, 1), "y": slice(0, 1)}
54365437
)
5438+
5439+
5440+
@requires_dask
@requires_zarr
def test_initialize_zarr(tmp_path) -> None:
    # TODO:
    # 1. with encoding
    # 2. with regions
    # 3. w-
    # 4. mode = r?
    # 5. mode=w+?
    # 5. no region_dims
    xvals = np.arange(0, 50, 10)
    yvals = np.arange(0, 20, 2)
    chunked = dask.array.ones((5, 10), chunks=(1, -1))
    ds = xr.Dataset(
        {
            "xy": (("x", "y"), chunked),
            "xonly": ("x", chunked[:, 0]),
            "yonly": ("y", chunked[0, :]),
            "eager_xonly": ("x", chunked[:, 0].compute()),
            "eager_yonly": ("y", chunked[0, :].compute().astype(int)),
            "scalar": 2,
        },
        coords={"x": xvals, "y": yvals},
    )
    store = tmp_path / "foo.zarr"

    # an unsupported mode is rejected before anything is written
    with pytest.raises(ValueError, match="Only mode"):
        initialize_zarr(store, ds, mode="r")

    expected_on_disk = ds.copy(deep=True).assign(
        {
            # chunked variables are all NaNs (really fill_value?)
            "xy": xr.full_like(ds.xy, fill_value=np.nan),
            "xonly": xr.full_like(ds.xonly, fill_value=np.nan),
            # eager variables with region_dim are all zeros (since we do zeros_like)
            "eager_xonly": xr.full_like(ds.xonly, fill_value=0),
            # eager variables without region_dim are identical
            # but are subject to two writes, first zeros then actual values
            "eager_yonly": ds.yonly,
        }
    )
    expected_after_init = ds.drop_vars(["yonly", "eager_yonly", "y", "scalar"])

    after_init = initialize_zarr(store, ds, region_dims=("x",))
    assert_identical(expected_after_init, after_init)

    with xr.open_zarr(store) as actual:
        assert_identical(expected_on_disk, actual)

    # region writes, one x-chunk at a time, should reproduce the full dataset
    for idx in range(ds.sizes["x"]):
        after_init.isel(x=[idx]).to_zarr(store, region={"x": slice(idx, idx + 1)})
    with xr.open_zarr(store) as actual:
        assert_identical(ds, actual)

0 commit comments

Comments
 (0)