Skip to content

Commit fdd7b62

Browse files
committed
Add merge=True to count allele functions.
1 parent ae8171c commit fdd7b62

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

sgkit/stats/aggregation.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,17 @@ def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None:
4444
out[a] += 1
4545

4646

47-
def count_call_alleles(ds: Dataset) -> Dataset:
47+
def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset:
4848
"""Compute per sample allele counts from genotype calls.
4949
5050
Parameters
5151
----------
5252
ds : Dataset
5353
Genotype call dataset such as from
5454
`sgkit.create_genotype_call_dataset`.
55+
merge : bool
56+
If True, merge the input dataset and the computed variables into
57+
a single dataset; otherwise, return only the computed variables.
5558
5659
Returns
5760
-------
@@ -91,7 +94,7 @@ def count_call_alleles(ds: Dataset) -> Dataset:
9194
G = da.asarray(ds["call_genotype"])
9295
shape = (G.chunks[0], G.chunks[1], n_alleles)
9396
N = da.empty(n_alleles, dtype=np.uint8)
94-
return Dataset(
97+
new_ds = Dataset(
9598
{
9699
"call_allele_count": (
97100
("variants", "samples", "alleles"),
@@ -101,16 +104,20 @@ def count_call_alleles(ds: Dataset) -> Dataset:
101104
)
102105
}
103106
)
107+
return ds.merge(new_ds) if merge else new_ds
104108

105109

106-
def count_variant_alleles(ds: Dataset) -> Dataset:
110+
def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset:
107111
"""Compute allele count from genotype calls.
108112
109113
Parameters
110114
----------
111115
ds : Dataset
112116
Genotype call dataset such as from
113117
`sgkit.create_genotype_call_dataset`.
118+
merge : bool
119+
If True, merge the input dataset and the computed variables into
120+
a single dataset; otherwise, return only the computed variables.
114121
115122
Returns
116123
-------
@@ -139,11 +146,12 @@ def count_variant_alleles(ds: Dataset) -> Dataset:
139146
[2, 2],
140147
[4, 0]], dtype=uint64)
141148
"""
142-
return Dataset(
149+
new_ds = Dataset(
143150
{
144151
"variant_allele_count": (
145152
("variants", "alleles"),
146153
count_call_alleles(ds)["call_allele_count"].sum(dim="samples"),
147154
)
148155
}
149156
)
157+
return ds.merge(new_ds) if merge else new_ds

sgkit/tests/test_aggregation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def get_dataset(calls: ArrayLike, **kwargs: Any) -> Dataset:
2222

2323
def test_count_variant_alleles__single_variant_single_sample():
2424
ds = count_variant_alleles(get_dataset([[[1, 0]]]))
25+
assert "call_genotype" in ds
2526
ac = ds["variant_allele_count"]
2627
np.testing.assert_equal(ac, np.array([[1, 1]]))
2728

@@ -94,6 +95,13 @@ def test_count_variant_alleles__chunked():
9495
xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call]
9596

9697

98+
def test_count_variant_alleles__no_merge():
99+
ds = count_variant_alleles(get_dataset([[[1, 0]]]), merge=False)
100+
assert "call_genotype" not in ds
101+
ac = ds["variant_allele_count"]
102+
np.testing.assert_equal(ac, np.array([[1, 1]]))
103+
104+
97105
def test_count_call_alleles__single_variant_single_sample():
98106
ds = count_call_alleles(get_dataset([[[1, 0]]]))
99107
ac = ds["call_allele_count"]

0 commit comments

Comments
 (0)