Skip to content

Commit 4b46a59

Browse files
committed
Make concat algorithm a parameter
1 parent 6c7fd4a commit 4b46a59

File tree

3 files changed

+89
-58
lines changed

3 files changed

+89
-58
lines changed

sgkit/io/vcf/vcf_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def zarrs_to_dataset(
304304

305305
storage_options = storage_options or {}
306306

307-
datasets = [xr.open_zarr(fsspec.get_mapper(path, **storage_options)) for path in urls] # type: ignore[no-untyped-call]
307+
datasets = [xr.open_zarr(fsspec.get_mapper(path, **storage_options), concat_characters=False) for path in urls] # type: ignore[no-untyped-call]
308308

309309
# Combine the datasets into one
310310
ds = xr.concat(datasets, dim="variants", data_vars="minimal")

sgkit/io/vcfzarr_reader.py

Lines changed: 66 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import tempfile
22
from pathlib import Path
3-
from typing import List, Optional
3+
from typing import Hashable, List, Optional
44

55
import dask.array as da
66
import xarray as xr
@@ -11,6 +11,7 @@
1111
from ..model import DIM_VARIANT, create_genotype_call_dataset
1212
from ..typing import ArrayLike, PathType
1313
from ..utils import encode_array, max_str_len
14+
from .vcf.vcf_reader import zarrs_to_dataset
1415

1516

1617
def _ensure_2d(arr: ArrayLike) -> ArrayLike:
@@ -60,6 +61,7 @@ def vcfzarr_to_zarr(
6061
grouped_by_contig: bool = False,
6162
consolidated: bool = False,
6263
tempdir: Optional[PathType] = None,
64+
concat_algorithm: Optional[str] = None,
6365
) -> None:
6466
"""Convert VCF Zarr files created using scikit-allel to a single Zarr on-disk store in sgkit Xarray format.
6567
@@ -78,6 +80,10 @@ def vcfzarr_to_zarr(
7880
tempdir
7981
Temporary directory where intermediate files are stored. The default None means
8082
use the system default temporary directory.
83+
concat_algorithm
84+
The algorithm to use to concatenate and rechunk Zarr files. The default None means
85+
use the optimized version suitable for large files, whereas ``xarray_internal`` will
86+
use built-in Xarray APIs, which can exhibit high memory usage, see https://github.com/dask/dask/issues/6745.
8187
"""
8288

8389
if consolidated:
@@ -118,50 +124,16 @@ def vcfzarr_to_zarr(
118124
contig_zarr_file = Path(tmpdir) / contig
119125
ds.to_zarr(contig_zarr_file)
120126

121-
zarr_files.append(contig_zarr_file)
122-
123-
zarr_groups = [zarr.open_group(str(f)) for f in zarr_files]
124-
125-
first_zarr_group = zarr_groups[0]
126-
127-
with zarr.open_group(str(output)) as output_zarr:
128-
129-
var_to_attrs = {} # attributes to copy
130-
delayed = [] # do all the rechunking operations in one computation
131-
for var in vars_to_rechunk:
132-
var_to_attrs[var] = first_zarr_group[var].attrs.asdict()
133-
dtype = None
134-
if var == "variant_id":
135-
kind = first_zarr_group[var].dtype.kind
136-
max_len = _get_max_len(zarr_groups, "max_variant_id_length")
137-
dtype = f"{kind}{max_len}"
138-
elif var == "variant_allele":
139-
kind = first_zarr_group[var].dtype.kind
140-
max_len = _get_max_len(zarr_groups, "max_variant_allele_length")
141-
dtype = f"{kind}{max_len}"
142-
143-
arr = concatenate_and_rechunk(
144-
[group[var] for group in zarr_groups], dtype=dtype
145-
)
146-
d = arr.to_zarr(
147-
str(output),
148-
component=var,
149-
overwrite=True,
150-
compute=False,
151-
fill_value=None,
152-
)
153-
delayed.append(d)
154-
da.compute(*delayed)
155-
156-
# copy variables that are not rechunked (e.g. sample_id)
157-
for var in vars_to_copy:
158-
output_zarr[var] = first_zarr_group[var]
159-
output_zarr[var].attrs.update(first_zarr_group[var].attrs)
160-
161-
# copy attributes
162-
output_zarr.attrs.update(first_zarr_group.attrs)
163-
for (var, attrs) in var_to_attrs.items():
164-
output_zarr[var].attrs.update(attrs)
127+
zarr_files.append(str(contig_zarr_file))
128+
129+
if concat_algorithm == "xarray_internal":
130+
ds = zarrs_to_dataset(zarr_files)
131+
ds.to_zarr(output, mode="w")
132+
else:
133+
# Use the optimized algorithm in `concatenate_and_rechunk`
134+
_concat_zarrs_optimized(
135+
zarr_files, output, vars_to_rechunk, vars_to_copy
136+
)
165137

166138

167139
def _vcfzarr_to_dataset(
@@ -222,7 +194,7 @@ def _vcfzarr_to_dataset(
222194
kind = arr.dtype.kind
223195
if kind in ["O", "U", "S"]:
224196
# Compute fixed-length string dtype for array
225-
if kind == "O":
197+
if kind == "O" or var in ("variant_id", "variant_allele"):
226198
kind = "S"
227199
max_len = max_str_len(arr).values
228200
dt = f"{kind}{max_len}"
@@ -239,3 +211,51 @@ def _vcfzarr_to_dataset(
239211
def _get_max_len(zarr_groups: List[zarr.Group], attr_name: str) -> int:
240212
max_len: int = max([group.attrs[attr_name] for group in zarr_groups])
241213
return max_len
214+
215+
216+
def _concat_zarrs_optimized(
217+
zarr_files: List[str],
218+
output: PathType,
219+
vars_to_rechunk: List[Hashable],
220+
vars_to_copy: List[Hashable],
221+
) -> None:
222+
zarr_groups = [zarr.open_group(f) for f in zarr_files]
223+
224+
first_zarr_group = zarr_groups[0]
225+
226+
with zarr.open_group(str(output)) as output_zarr:
227+
228+
var_to_attrs = {} # attributes to copy
229+
delayed = [] # do all the rechunking operations in one computation
230+
for var in vars_to_rechunk:
231+
var_to_attrs[var] = first_zarr_group[var].attrs.asdict()
232+
dtype = None
233+
if var == "variant_id":
234+
max_len = _get_max_len(zarr_groups, "max_variant_id_length")
235+
dtype = f"S{max_len}"
236+
elif var == "variant_allele":
237+
max_len = _get_max_len(zarr_groups, "max_variant_allele_length")
238+
dtype = f"S{max_len}"
239+
240+
arr = concatenate_and_rechunk(
241+
[group[var] for group in zarr_groups], dtype=dtype
242+
)
243+
d = arr.to_zarr(
244+
str(output),
245+
component=var,
246+
overwrite=True,
247+
compute=False,
248+
fill_value=None,
249+
)
250+
delayed.append(d)
251+
da.compute(*delayed)
252+
253+
# copy variables that are not rechunked (e.g. sample_id)
254+
for var in vars_to_copy:
255+
output_zarr[var] = first_zarr_group[var]
256+
output_zarr[var].attrs.update(first_zarr_group[var].attrs)
257+
258+
# copy attributes
259+
output_zarr.attrs.update(first_zarr_group.attrs)
260+
for (var, attrs) in var_to_attrs.items():
261+
output_zarr[var].attrs.update(attrs)

sgkit/tests/test_vcfzarr_reader.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,13 @@ def test_read_vcfzarr(shared_datadir):
7676
"vcfzarr_filename, grouped_by_contig",
7777
[("sample.vcf.zarr.zip", False), ("sample-grouped.vcf.zarr.zip", True)],
7878
)
79-
def test_vcfzarr_to_zarr(shared_datadir, tmp_path, vcfzarr_filename, grouped_by_contig):
79+
@pytest.mark.parametrize(
80+
"concat_algorithm",
81+
[None, "xarray_internal"],
82+
)
83+
def test_vcfzarr_to_zarr(
84+
shared_datadir, tmp_path, vcfzarr_filename, grouped_by_contig, concat_algorithm
85+
):
8086
# The file sample-grouped.vcf.zarr.zip was created by running the following
8187
# in a python session with the scikit-allel package installed.
8288
#
@@ -89,7 +95,12 @@ def test_vcfzarr_to_zarr(shared_datadir, tmp_path, vcfzarr_filename, grouped_by_
8995

9096
path = shared_datadir / vcfzarr_filename
9197
output = tmp_path.joinpath("vcf.zarr").as_posix()
92-
vcfzarr_to_zarr(path, output, grouped_by_contig=grouped_by_contig)
98+
vcfzarr_to_zarr(
99+
path,
100+
output,
101+
grouped_by_contig=grouped_by_contig,
102+
concat_algorithm=concat_algorithm,
103+
)
93104

94105
ds = xr.open_zarr(output) # type: ignore[no-untyped-call]
95106

@@ -119,15 +130,15 @@ def test_vcfzarr_to_zarr(shared_datadir, tmp_path, vcfzarr_filename, grouped_by_
119130
assert_array_equal(
120131
ds["variant_id"],
121132
[
122-
".",
123-
".",
124-
"rs6054257",
125-
".",
126-
"rs6040355",
127-
".",
128-
"microsat1",
129-
".",
130-
"rsTest",
133+
b".",
134+
b".",
135+
b"rs6054257",
136+
b".",
137+
b"rs6040355",
138+
b".",
139+
b"microsat1",
140+
b".",
141+
b"rsTest",
131142
],
132143
)
133144
assert_array_equal(

0 commit comments

Comments
 (0)