|
3 | 3 | import json
|
4 | 4 | import os
|
5 | 5 | import warnings
|
6 |
| -from collections.abc import Iterable |
7 |
| -from typing import TYPE_CHECKING, Any |
| 6 | +from collections.abc import Hashable, Iterable |
| 7 | +from typing import TYPE_CHECKING, Any, Literal |
8 | 8 |
|
9 | 9 | import numpy as np
|
10 | 10 |
|
|
19 | 19 | )
|
20 | 20 | from xarray.backends.store import StoreBackendEntrypoint
|
21 | 21 | from xarray.core import indexing
|
| 22 | +from xarray.core.common import zeros_like |
22 | 23 | from xarray.core.parallelcompat import guess_chunkmanager
|
23 | 24 | from xarray.core.pycompat import integer_types
|
24 | 25 | from xarray.core.utils import (
|
|
34 | 35 | from xarray.backends.common import AbstractDataStore
|
35 | 36 | from xarray.core.dataset import Dataset
|
36 | 37 |
|
37 |
| - |
38 | 38 | # need some special secret attributes to tell us the dimensions
|
39 | 39 | DIMENSION_KEY = "_ARRAY_DIMENSIONS"
|
40 | 40 |
|
41 | 41 |
|
def initialize_zarr(
    store,
    ds: Dataset,
    *,
    region_dims: Iterable[Hashable] | None = None,
    mode: Literal["w", "w-"] = "w-",
    **kwargs,
) -> Dataset:
    """
    Initialize a Zarr store with metadata.

    This function initializes a Zarr store with metadata that describes the
    entire dataset. If ``region_dims`` is specified, it will also

    1. Write variables that don't contain any of ``region_dims``, and
    2. Return a dataset with variables that do contain one or more of
       ``region_dims``. This dataset can be used for region writes in
       parallel later.

    Parameters
    ----------
    store : MutableMapping or str
        Zarr store to write to.
    ds : Dataset
        Dataset to write.
    region_dims : Iterable[Hashable], optional
        An iterable of dimension names that will be passed to the ``region``
        kwarg of ``to_zarr`` later.
    mode : {'w', 'w-'}
        Write mode for initializing the store.
    **kwargs
        Additional keyword arguments forwarded to ``Dataset.to_zarr``.
        ``compute`` is not allowed.

    Returns
    -------
    Dataset
        Dataset containing variables with one or more ``region_dims``
        dimensions. Use this for writing to the store in parallel later.

    Raises
    ------
    ValueError
        If ``compute`` is passed in ``kwargs``, if ``ds`` is not a chunked
        Dataset, or if ``mode`` is not ``'w'`` or ``'w-'``.
    """
    if "compute" in kwargs:
        raise ValueError("The ``compute`` kwarg is not supported in `initialize_zarr`.")

    if not ds.chunks:
        raise ValueError("This function should be used with chunked Datasets.")

    if mode not in ("w", "w-"):
        raise ValueError(
            f"Only mode='w' or mode='w-' is allowed for initialize_zarr. Received mode={mode!r}"
        )

    # TODO: what should we do here.
    # compute=False only skips dask variables.
    # - We could replace all dask variables with zeros_like
    # - and then write all other variables eagerly.
    # Right now we do two writes for eager variables
    template = zeros_like(ds)
    template.to_zarr(store, mode=mode, **kwargs, compute=False)

    if region_dims:
        # Variables with none of region_dims can't be written region-by-region
        # later, so write them eagerly now.
        after_drop = ds.drop_dims(region_dims)

        # we have to remove the dropped variables from the encoding dictionary :/
        new_encoding = kwargs.pop("encoding", None)
        if new_encoding:
            new_encoding = {k: v for k, v in new_encoding.items() if k in after_drop}

        after_drop.to_zarr(
            store, **kwargs, encoding=new_encoding, compute=True, mode="a"
        )

        # can't use drop_dims since that will also remove any variable
        # with any of the dims to be dropped
        # even if they also have one or more of region_dims
        dims_to_drop = set(ds.dims) - set(region_dims)
        vars_to_drop = [
            name
            for name, var in ds._variables.items()
            if set(var.dims).issubset(dims_to_drop)
        ]
        return ds.drop_vars(vars_to_drop)

    else:
        return ds
| 128 | + |
42 | 129 | def encode_zarr_attr_value(value):
|
43 | 130 | """
|
44 | 131 | Encode a attribute value as something that can be serialized as json
|
|
0 commit comments