Skip to content

Commit ff5a0a0

Browse files
authored
Merge output variables with input dataset (#217)
* Count allele functions should return datasets, not arrays.
* Add merge=True to count allele functions.
* Issue a MergeWarning when input variables are overwritten.
1 parent 5e04447 commit ff5a0a0

File tree

5 files changed

+137
-35
lines changed

5 files changed

+137
-35
lines changed

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ fail_under = 100
4242
[tool:pytest]
4343
addopts = --doctest-modules --ignore=validation
4444
norecursedirs = .eggs docs
45+
filterwarnings =
46+
error
4547

4648
[flake8]
4749
ignore =

sgkit/stats/aggregation.py

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import dask.array as da
22
import numpy as np
3-
import xarray as xr
43
from numba import guvectorize
5-
from xarray import DataArray, Dataset
4+
from xarray import Dataset
65

7-
from ..typing import ArrayLike
6+
from sgkit.typing import ArrayLike
7+
from sgkit.utils import merge_datasets
88

99

1010
@guvectorize( # type: ignore
@@ -45,21 +45,27 @@ def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None:
4545
out[a] += 1
4646

4747

48-
def count_call_alleles(ds: Dataset) -> DataArray:
48+
def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset:
4949
"""Compute per sample allele counts from genotype calls.
5050
5151
Parameters
5252
----------
5353
ds : Dataset
5454
Genotype call dataset such as from
5555
`sgkit.create_genotype_call_dataset`.
56+
merge : bool, optional
57+
If True (the default), merge the input dataset and the computed
58+
output variables into a single dataset. Output variables will
59+
overwrite any input variables with the same name, and a warning
60+
will be issued in this case.
61+
If False, return only the computed output variables.
5662
5763
Returns
5864
-------
59-
call_allele_count : DataArray
60-
Allele counts with shape (variants, samples, alleles) and values
61-
corresponding to the number of non-missing occurrences
62-
of each allele.
65+
Dataset
66+
Array `call_allele_count` of allele counts with
67+
shape (variants, samples, alleles) and values corresponding to
68+
the number of non-missing occurrences of each allele.
6369
6470
Examples
6571
--------
@@ -75,7 +81,7 @@ def count_call_alleles(ds: Dataset) -> DataArray:
7581
2 0/1 1/0
7682
3 0/0 0/0
7783
78-
>>> sg.count_call_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE
84+
>>> sg.count_call_alleles(ds)["call_allele_count"].values # doctest: +NORMALIZE_WHITESPACE
7985
array([[[1, 1],
8086
[1, 1]],
8187
<BLANKLINE>
@@ -92,28 +98,40 @@ def count_call_alleles(ds: Dataset) -> DataArray:
9298
G = da.asarray(ds["call_genotype"])
9399
shape = (G.chunks[0], G.chunks[1], n_alleles)
94100
N = da.empty(n_alleles, dtype=np.uint8)
95-
return xr.DataArray(
96-
da.map_blocks(count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2),
97-
dims=("variants", "samples", "alleles"),
98-
name="call_allele_count",
101+
new_ds = Dataset(
102+
{
103+
"call_allele_count": (
104+
("variants", "samples", "alleles"),
105+
da.map_blocks(
106+
count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2
107+
),
108+
)
109+
}
99110
)
111+
return merge_datasets(ds, new_ds) if merge else new_ds
100112

101113

102-
def count_variant_alleles(ds: Dataset) -> DataArray:
114+
def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset:
103115
"""Compute allele count from genotype calls.
104116
105117
Parameters
106118
----------
107119
ds : Dataset
108120
Genotype call dataset such as from
109121
`sgkit.create_genotype_call_dataset`.
122+
merge : bool, optional
123+
If True (the default), merge the input dataset and the computed
124+
output variables into a single dataset. Output variables will
125+
overwrite any input variables with the same name, and a warning
126+
will be issued in this case.
127+
If False, return only the computed output variables.
110128
111129
Returns
112130
-------
113-
variant_allele_count : DataArray
114-
Allele counts with shape (variants, alleles) and values
115-
corresponding to the number of non-missing occurrences
116-
of each allele.
131+
Dataset
132+
Array `variant_allele_count` of allele counts with
133+
shape (variants, alleles) and values corresponding to
134+
the number of non-missing occurrences of each allele.
117135
118136
Examples
119137
--------
@@ -129,13 +147,18 @@ def count_variant_alleles(ds: Dataset) -> DataArray:
129147
2 0/1 1/0
130148
3 0/0 0/0
131149
132-
>>> sg.count_variant_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE
150+
>>> sg.count_variant_alleles(ds)["variant_allele_count"].values # doctest: +NORMALIZE_WHITESPACE
133151
array([[2, 2],
134152
[1, 3],
135153
[2, 2],
136154
[4, 0]], dtype=uint64)
137155
"""
138-
return xr.DataArray(
139-
count_call_alleles(ds).sum(dim="samples").rename("variant_allele_count"),
140-
dims=("variants", "alleles"),
156+
new_ds = Dataset(
157+
{
158+
"variant_allele_count": (
159+
("variants", "alleles"),
160+
count_call_alleles(ds)["call_allele_count"].sum(dim="samples"),
161+
)
162+
}
141163
)
164+
return merge_datasets(ds, new_ds) if merge else new_ds

sgkit/tests/test_aggregation.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,26 @@ def get_dataset(calls: ArrayLike, **kwargs: Any) -> Dataset:
2121

2222

2323
def test_count_variant_alleles__single_variant_single_sample():
24-
ac = count_variant_alleles(get_dataset([[[1, 0]]]))
24+
ds = count_variant_alleles(get_dataset([[[1, 0]]]))
25+
assert "call_genotype" in ds
26+
ac = ds["variant_allele_count"]
2527
np.testing.assert_equal(ac, np.array([[1, 1]]))
2628

2729

2830
def test_count_variant_alleles__multi_variant_single_sample():
29-
ac = count_variant_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]))
31+
ds = count_variant_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]))
32+
ac = ds["variant_allele_count"]
3033
np.testing.assert_equal(ac, np.array([[2, 0], [1, 1], [1, 1], [0, 2]]))
3134

3235

3336
def test_count_variant_alleles__single_variant_multi_sample():
34-
ac = count_variant_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]]))
37+
ds = count_variant_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]]))
38+
ac = ds["variant_allele_count"]
3539
np.testing.assert_equal(ac, np.array([[4, 4]]))
3640

3741

3842
def test_count_variant_alleles__multi_variant_multi_sample():
39-
ac = count_variant_alleles(
43+
ds = count_variant_alleles(
4044
get_dataset(
4145
[
4246
[[0, 0], [0, 0], [0, 0]],
@@ -46,11 +50,12 @@ def test_count_variant_alleles__multi_variant_multi_sample():
4650
]
4751
)
4852
)
53+
ac = ds["variant_allele_count"]
4954
np.testing.assert_equal(ac, np.array([[6, 0], [5, 1], [2, 4], [0, 6]]))
5055

5156

5257
def test_count_variant_alleles__missing_data():
53-
ac = count_variant_alleles(
58+
ds = count_variant_alleles(
5459
get_dataset(
5560
[
5661
[[-1, -1], [-1, -1], [-1, -1]],
@@ -60,11 +65,12 @@ def test_count_variant_alleles__missing_data():
6065
]
6166
)
6267
)
68+
ac = ds["variant_allele_count"]
6369
np.testing.assert_equal(ac, np.array([[0, 0], [2, 1], [1, 2], [0, 6]]))
6470

6571

6672
def test_count_variant_alleles__higher_ploidy():
67-
ac = count_variant_alleles(
73+
ds = count_variant_alleles(
6874
get_dataset(
6975
[
7076
[[-1, -1, 0], [-1, -1, 1], [-1, -1, 2]],
@@ -74,6 +80,7 @@ def test_count_variant_alleles__higher_ploidy():
7480
n_ploidy=3,
7581
)
7682
)
83+
ac = ds["variant_allele_count"]
7784
np.testing.assert_equal(ac, np.array([[1, 1, 1, 0], [1, 2, 2, 1]]))
7885

7986

@@ -88,23 +95,33 @@ def test_count_variant_alleles__chunked():
8895
xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call]
8996

9097

98+
def test_count_variant_alleles__no_merge():
99+
ds = count_variant_alleles(get_dataset([[[1, 0]]]), merge=False)
100+
assert "call_genotype" not in ds
101+
ac = ds["variant_allele_count"]
102+
np.testing.assert_equal(ac, np.array([[1, 1]]))
103+
104+
91105
def test_count_call_alleles__single_variant_single_sample():
92-
ac = count_call_alleles(get_dataset([[[1, 0]]]))
106+
ds = count_call_alleles(get_dataset([[[1, 0]]]))
107+
ac = ds["call_allele_count"]
93108
np.testing.assert_equal(ac, np.array([[[1, 1]]]))
94109

95110

96111
def test_count_call_alleles__multi_variant_single_sample():
97-
ac = count_call_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]))
112+
ds = count_call_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]]))
113+
ac = ds["call_allele_count"]
98114
np.testing.assert_equal(ac, np.array([[[2, 0]], [[1, 1]], [[1, 1]], [[0, 2]]]))
99115

100116

101117
def test_count_call_alleles__single_variant_multi_sample():
102-
ac = count_call_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]]))
118+
ds = count_call_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]]))
119+
ac = ds["call_allele_count"]
103120
np.testing.assert_equal(ac, np.array([[[2, 0], [1, 1], [1, 1], [0, 2]]]))
104121

105122

106123
def test_count_call_alleles__multi_variant_multi_sample():
107-
ac = count_call_alleles(
124+
ds = count_call_alleles(
108125
get_dataset(
109126
[
110127
[[0, 0], [0, 0], [0, 0]],
@@ -114,6 +131,7 @@ def test_count_call_alleles__multi_variant_multi_sample():
114131
]
115132
)
116133
)
134+
ac = ds["call_allele_count"]
117135
np.testing.assert_equal(
118136
ac,
119137
np.array(
@@ -128,7 +146,7 @@ def test_count_call_alleles__multi_variant_multi_sample():
128146

129147

130148
def test_count_call_alleles__missing_data():
131-
ac = count_call_alleles(
149+
ds = count_call_alleles(
132150
get_dataset(
133151
[
134152
[[-1, -1], [-1, -1], [-1, -1]],
@@ -138,6 +156,7 @@ def test_count_call_alleles__missing_data():
138156
]
139157
)
140158
)
159+
ac = ds["call_allele_count"]
141160
np.testing.assert_equal(
142161
ac,
143162
np.array(
@@ -152,7 +171,7 @@ def test_count_call_alleles__missing_data():
152171

153172

154173
def test_count_call_alleles__higher_ploidy():
155-
ac = count_call_alleles(
174+
ds = count_call_alleles(
156175
get_dataset(
157176
[
158177
[[-1, -1, 0], [-1, -1, 1], [-1, -1, 2]],
@@ -162,6 +181,7 @@ def test_count_call_alleles__higher_ploidy():
162181
n_ploidy=3,
163182
)
164183
)
184+
ac = ds["call_allele_count"]
165185
np.testing.assert_equal(
166186
ac,
167187
np.array(

sgkit/tests/test_utils.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
from typing import Any, List
22

3+
import dask.array as da
34
import numpy as np
45
import pytest
6+
import xarray as xr
57
from hypothesis import given, settings
68
from hypothesis import strategies as st
79

8-
from sgkit.utils import check_array_like, encode_array, split_array_chunks
10+
from sgkit.utils import (
11+
MergeWarning,
12+
check_array_like,
13+
encode_array,
14+
merge_datasets,
15+
split_array_chunks,
16+
)
917

1018

1119
def test_check_array_like():
@@ -66,6 +74,20 @@ def test_encode_array(
6674
np.testing.assert_equal(n, expected_names)
6775

6876

77+
def test_merge_datasets():
78+
ds = xr.Dataset(dict(x=xr.DataArray(da.zeros(100))))
79+
80+
new_ds1 = xr.Dataset(dict(y=xr.DataArray(da.zeros(100))))
81+
new_ds2 = xr.Dataset(dict(y=xr.DataArray(da.ones(100))))
82+
83+
ds = merge_datasets(ds, new_ds1)
84+
assert "y" in ds
85+
86+
with pytest.warns(MergeWarning):
87+
ds = merge_datasets(ds, new_ds2)
88+
np.testing.assert_equal(ds["y"].values, np.ones(100))
89+
90+
6991
@pytest.mark.parametrize(
7092
"n,blocks,expected_chunks",
7193
[

sgkit/utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import warnings
12
from typing import Any, List, Set, Tuple, Union
23

34
import numpy as np
5+
from xarray import Dataset
46

57
from .typing import ArrayLike, DType
68

@@ -100,6 +102,39 @@ def encode_array(x: ArrayLike) -> Tuple[ArrayLike, List[Any]]:
100102
return rank[inverse], names[index]
101103

102104

105+
class MergeWarning(UserWarning):
106+
"""Warnings about merging datasets."""
107+
108+
pass
109+
110+
111+
def merge_datasets(input: Dataset, output: Dataset) -> Dataset:
112+
"""Merge the input and output datasets into a new dataset, giving precedence to variables in the output.
113+
114+
Parameters
115+
----------
116+
input : Dataset
117+
The input dataset.
118+
output : Dataset
119+
The output dataset.
120+
121+
Returns
122+
-------
123+
Dataset
124+
The merged dataset. If `input` and `output` have variables with the same name, a `MergeWarning`
125+
is issued, and the variables from the `output` dataset are used.
126+
"""
127+
input_vars = {str(v) for v in input.data_vars.keys()}
128+
output_vars = {str(v) for v in output.data_vars.keys()}
129+
clobber_vars = sorted(list(input_vars & output_vars))
130+
if len(clobber_vars) > 0:
131+
warnings.warn(
132+
f"The following variables in the input dataset will be replaced in the output: {', '.join(clobber_vars)}",
133+
MergeWarning,
134+
)
135+
return output.merge(input, compat="override")
136+
137+
103138
def split_array_chunks(n: int, blocks: int) -> Tuple[int, ...]:
104139
"""Compute chunk sizes for an array split into blocks.
105140

0 commit comments

Comments
 (0)