Skip to content

Commit 2a7341c

Browse files
jeromekelleher and mergify[bot]
authored and committed
VCF masking
1 parent c0c32a5 commit 2a7341c

File tree

4 files changed

+280
-2
lines changed

4 files changed

+280
-2
lines changed

python/CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@
3434
methods to return arrays of per-individual times and populations, respectively.
3535
(:user:`petrelharp`, :issue:`1481`, :pr:`2298`).
3636

37+
- Add the ``sample_mask`` and ``site_mask`` arguments to ``write_vcf`` to allow parts
38+
of an output VCF to be omitted or marked as missing data. Also add the
39+
``as_vcf`` convenience function, to return VCF as a string.
40+
(:user:`jeromekelleher`, :pr:`2300`).
41+
3742
**Breaking Changes**
3843

3944
- The JSON metadata codec now interprets the empty string as an empty object. This means

python/tests/test_vcf.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,14 @@
2929
import math
3030
import os
3131
import tempfile
32+
import textwrap
3233

3334
import msprime
3435
import numpy as np
3536
import pytest
3637
import vcf
3738

39+
import tests
3840
import tests.test_wright_fisher as wf
3941
import tskit
4042
from tests import tsutil
@@ -638,3 +640,192 @@ def test_defaults(self):
638640
assert ts.num_sites > 0
639641
with ts_to_pyvcf(ts) as vcf_reader:
640642
assert vcf_reader.samples == ["tsk_0", "tsk_1"]
643+
644+
645+
def drop_header(s):
646+
return "\n".join(line for line in s.splitlines() if not line.startswith("##"))
647+
648+
649+
class TestMasking:
650+
@tests.cached_example
651+
def ts(self):
652+
ts = tskit.Tree.generate_balanced(3, span=10).tree_sequence
653+
ts = tsutil.insert_branch_sites(ts)
654+
return ts
655+
656+
@pytest.mark.parametrize("mask", [[True], np.zeros(5, dtype=bool), []])
657+
def test_site_mask_wrong_size(self, mask):
658+
with pytest.raises(ValueError, match="Site mask must be"):
659+
self.ts().as_vcf(site_mask=mask)
660+
661+
@pytest.mark.parametrize("mask", [[[0, 1], [1, 0]], "abcd"])
662+
def test_site_mask_bad_type(self, mask):
663+
# Converting to a bool array is pretty lax in what it allows.
664+
with pytest.raises(ValueError, match="Site mask must be"):
665+
self.ts().as_vcf(site_mask=mask)
666+
667+
@pytest.mark.parametrize("mask", [[[0, 1], [1, 0]], "abcd"])
668+
def test_sample_mask_bad_type(self, mask):
669+
# Converting to a bool array is pretty lax in what it allows.
670+
with pytest.raises(ValueError, match="Sample mask must be"):
671+
self.ts().as_vcf(sample_mask=mask)
672+
673+
def test_no_masks(self):
674+
s = """\
675+
#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
676+
1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1\t0\t0
677+
1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t1\t1
678+
1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t0
679+
1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t0\t1"""
680+
expected = textwrap.dedent(s)
681+
assert drop_header(self.ts().as_vcf()) == expected
682+
683+
def test_no_masks_triploid(self):
684+
s = """\
685+
#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0
686+
1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1|0|0
687+
1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0|1|1
688+
1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0|1|0
689+
1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0|0|1"""
690+
expected = textwrap.dedent(s)
691+
assert drop_header(self.ts().as_vcf(ploidy=3)) == expected
692+
693+
def test_site_0_masked(self):
694+
s = """\
695+
#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
696+
1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t1\t1
697+
1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t0
698+
1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t0\t1"""
699+
expected = textwrap.dedent(s)
700+
actual = self.ts().as_vcf(site_mask=[True, False, False, False])
701+
assert drop_header(actual) == expected
702+
703+
def test_site_0_masked_triploid(self):
704+
s = """\
705+
#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0
706+
1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0|1|1
707+
1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0|1|0
708+
1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0|0|1"""
709+
expected = textwrap.dedent(s)
710+
actual = self.ts().as_vcf(ploidy=3, site_mask=[True, False, False, False])
711+
assert drop_header(actual) == expected
712+
713+
def test_site_1_masked(self):
714+
s = """\
715+
#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
716+
1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1\t0\t0
717+
1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t0
718+
1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t0\t1"""
719+
expected = textwrap.dedent(s)
720+
actual = self.ts().as_vcf(site_mask=[False, True, False, False])
721+
assert drop_header(actual) == expected
722+
723+
def test_all_sites_masked(self):
724+
s = """\
725+
#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2"""
726+
expected = textwrap.dedent(s)
727+
actual = self.ts().as_vcf(site_mask=[True, True, True, True])
728+
assert drop_header(actual) == expected
729+
730+
def test_all_sites_not_masked(self):
731+
s = """\
732+
#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
733+
1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1\t0\t0
734+
1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t1\t1
735+
1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t0
736+
1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t0\t1"""
737+
expected = textwrap.dedent(s)
738+
actual = self.ts().as_vcf(site_mask=[False, False, False, False])
739+
assert drop_header(actual) == expected
740+
741+
@pytest.mark.parametrize(
742+
"mask",
743+
[[False, False, False], [0, 0, 0], lambda _: [False, False, False]],
744+
)
745+
def test_all_samples_not_masked(self, mask):
746+
s = """\
747+
#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
748+
1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1\t0\t0
749+
1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t1\t1
750+
1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t0
751+
1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t0\t1"""
752+
expected = textwrap.dedent(s)
753+
actual = self.ts().as_vcf(sample_mask=mask)
754+
assert drop_header(actual) == expected
755+
756+
@pytest.mark.parametrize(
757+
"mask", [[True, False, False], [1, 0, 0], lambda _: [True, False, False]]
758+
)
759+
def test_sample_0_masked(self, mask):
760+
s = """\
761+
#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
762+
1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t.\t0\t0
763+
1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t.\t1\t1
764+
1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t.\t1\t0
765+
1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t.\t0\t1"""
766+
expected = textwrap.dedent(s)
767+
actual = self.ts().as_vcf(sample_mask=mask)
768+
assert drop_header(actual) == expected
769+
770+
@pytest.mark.parametrize(
771+
"mask", [[False, True, False], [0, 1, 0], lambda _: [False, True, False]]
772+
)
773+
def test_sample_1_masked(self, mask):
774+
s = """\
775+
#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
776+
1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t1\t.\t0
777+
1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t.\t1
778+
1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t.\t0
779+
1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0\t.\t1"""
780+
expected = textwrap.dedent(s)
781+
actual = self.ts().as_vcf(sample_mask=mask)
782+
assert drop_header(actual) == expected
783+
784+
@pytest.mark.parametrize(
785+
"mask", [[True, True, True], [1, 1, 1], lambda _: [True, True, True]]
786+
)
787+
def test_all_samples_masked(self, mask):
788+
s = """\
789+
#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
790+
1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t.\t.\t.
791+
1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t.\t.\t.
792+
1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t.\t.\t.
793+
1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t.\t.\t."""
794+
expected = textwrap.dedent(s)
795+
actual = self.ts().as_vcf(sample_mask=mask)
796+
assert drop_header(actual) == expected
797+
798+
def test_all_functional_sample_mask(self):
799+
s = """\
800+
#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0\ttsk_1\ttsk_2
801+
1\t0\t0\t0\t1\t.\tPASS\t.\tGT\t.\t0\t0
802+
1\t2\t1\t0\t1\t.\tPASS\t.\tGT\t0\t.\t1
803+
1\t4\t2\t0\t1\t.\tPASS\t.\tGT\t0\t1\t.
804+
1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t.\t0\t1"""
805+
806+
def mask(variant):
807+
a = [0, 0, 0]
808+
a[variant.site.id % 3] = 1
809+
return a
810+
811+
expected = textwrap.dedent(s)
812+
actual = self.ts().as_vcf(sample_mask=mask)
813+
assert drop_header(actual) == expected
814+
815+
@pytest.mark.skipif(not _pysam_imported, reason="pysam not available")
816+
def test_mask_ok_with_pysam(self):
817+
with ts_to_pysam(self.ts(), sample_mask=[0, 0, 1]) as records:
818+
variants = list(records)
819+
assert len(variants) == 4
820+
samples = ["tsk_0", "tsk_1", "tsk_2"]
821+
gts = [variants[0].samples[key]["GT"] for key in samples]
822+
assert gts == [(1,), (0,), (None,)]
823+
824+
gts = [variants[1].samples[key]["GT"] for key in samples]
825+
assert gts == [(0,), (1,), (None,)]
826+
827+
gts = [variants[2].samples[key]["GT"] for key in samples]
828+
assert gts == [(0,), (1,), (None,)]
829+
830+
gts = [variants[3].samples[key]["GT"] for key in samples]
831+
assert gts == [(0,), (0,), (None,)]

python/tskit/trees.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5325,6 +5325,18 @@ def samples(self, population=None, *, population_id=None, time=None):
53255325
)
53265326
return samples[keep]
53275327

5328+
def as_vcf(self, **kwargs):
5329+
"""
5330+
Return the result of :meth:`.write_vcf` as a string.
5331+
Keyword parameters are as defined in :meth:`.write_vcf`.
5332+
5333+
:return: A VCF encoding of the variants in this tree sequence as a string.
5334+
:rtype: str
5335+
"""
5336+
buff = io.StringIO()
5337+
self.write_vcf(buff, **kwargs)
5338+
return buff.getvalue()
5339+
53285340
def write_vcf(
53295341
self,
53305342
output,
@@ -5333,6 +5345,8 @@ def write_vcf(
53335345
individuals=None,
53345346
individual_names=None,
53355347
position_transform=None,
5348+
site_mask=None,
5349+
sample_mask=None,
53365350
):
53375351
"""
53385352
Writes a VCF formatted file to the specified file-like object.
@@ -5463,6 +5477,23 @@ def write_vcf(
54635477
54645478
$ tskit vcf example.trees | bcftools view -O b > example.bcf
54655479
5480+
5481+
The ``sample_mask`` argument provides a general way to mask out
5482+
parts of the output, which can be helpful when simulating missing
5483+
data. In this (contrived) example, we create a sample mask function
5484+
that marks one genotype missing in each variant in a regular
5485+
pattern:
5486+
5487+
.. code-block:: python
5488+
5489+
def sample_mask(variant):
5490+
sample_mask = np.zeros(ts.num_samples, dtype=bool)
5491+
sample_mask[variant.site.id % ts.num_samples] = 1
5492+
return sample_mask
5493+
5494+
5495+
ts.write_vcf(sys.stdout, sample_mask=sample_mask)
5496+
54665497
:param io.IOBase output: The file-like object to write the VCF output.
54675498
:param int ploidy: The ploidy of the individuals to be written to
54685499
VCF. The sample size must be evenly divisible by ploidy. Cannot be
@@ -5488,6 +5519,20 @@ def write_vcf(
54885519
pre 0.2.0 legacy behaviour of rounding values to the nearest integer
54895520
(starting from 1) and avoiding the output of identical positions
54905521
by incrementing is used.
5522+
:param site_mask: A numpy boolean array (or something convertible to
5523+
a numpy boolean array) with num_sites elements, used to mask out
5524+
sites in the output. If ``site_mask[j]`` is True, then this
5525+
site (i.e., the line in the VCF file) will be omitted.
5526+
:param sample_mask: A numpy boolean array (or something convertible to
5527+
a numpy boolean array) with num_samples elements, or a callable
5528+
that returns such an array, such that if
5529+
``sample_mask[j]`` is True, then the genotype for sample ``j``
5530+
will be marked as missing using a ".". If ``sample_mask`` is a
5531+
callable, it must take a single argument and return a boolean
5532+
numpy array. This function will be called for each (unmasked) site
5533+
with the corresponding :class:`.Variant` object, allowing
5534+
for dynamic masks to be generated. See above for example
5535+
usage.
54915536
"""
54925537
writer = vcf.VcfWriter(
54935538
self,
@@ -5496,6 +5541,8 @@ def write_vcf(
54965541
individuals=individuals,
54975542
individual_names=individual_names,
54985543
position_transform=position_transform,
5544+
site_mask=site_mask,
5545+
sample_mask=sample_mask,
54995546
)
55005547
writer.write(output)
55015548

python/tskit/vcf.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ def __init__(
5858
individuals=None,
5959
individual_names=None,
6060
position_transform=None,
61+
site_mask=None,
62+
sample_mask=None,
6163
):
6264
self.tree_sequence = tree_sequence
6365
self.contig_id = contig_id
@@ -98,6 +100,18 @@ def __init__(
98100
# from the legacy VCF output code.
99101
self.contig_length = max(self.transformed_positions[-1], self.contig_length)
100102

103+
if site_mask is None:
104+
site_mask = np.zeros(tree_sequence.num_sites, dtype=bool)
105+
self.site_mask = np.array(site_mask, dtype=bool)
106+
if self.site_mask.shape != (tree_sequence.num_sites,):
107+
raise ValueError("Site mask must be 1D a boolean array of length num_sites")
108+
109+
self.sample_mask = sample_mask
110+
if sample_mask is not None:
111+
if not callable(sample_mask):
112+
sample_mask = np.array(sample_mask, dtype=bool)
113+
self.sample_mask = lambda _: sample_mask
114+
101115
def __make_sample_mapping(self, ploidy):
102116
"""
103117
Compute the sample IDs for each VCF individual and the template for
@@ -176,6 +190,12 @@ def write(self, output):
176190
indexes = np.array(indexes, dtype=int)
177191

178192
for variant in self.tree_sequence.variants(samples=self.samples):
193+
site_id = variant.site.id
194+
# We check the mask before we do any checks so we can use this as a
195+
# way of skipping problematic sites.
196+
if self.site_mask[site_id]:
197+
continue
198+
179199
if variant.num_alleles > 9:
180200
raise ValueError(
181201
"More than 9 alleles not currently supported. Please open an issue "
@@ -187,7 +207,6 @@ def write(self, output):
187207
"on GitHub if this limitation affects you."
188208
)
189209
pos = self.transformed_positions[variant.index]
190-
site_id = variant.site.id
191210
ref = variant.alleles[0]
192211
alt = ",".join(variant.alleles[1:]) if len(variant.alleles) > 1 else "."
193212
print(
@@ -204,7 +223,23 @@ def write(self, output):
204223
end="\t",
205224
file=output,
206225
)
207-
gt_array[indexes] = variant.genotypes + ord("0")
226+
# NOTE: when we support missing data we should be able to
227+
# simply add ``and not variant.has_missing_data`` here.
228+
# Probably OK to take the perf hit in making the missing
229+
# data case go in with the more general sample masking case.
230+
if self.sample_mask is None:
231+
gt_array[indexes] = variant.genotypes + ord("0")
232+
else:
233+
genotypes = variant.genotypes.copy()
234+
sample_mask = np.array(self.sample_mask(variant), dtype=bool)
235+
if sample_mask.shape != genotypes.shape:
236+
raise ValueError(
237+
"Sample mask must be a numpy array of size num_samples"
238+
)
239+
gt_array[indexes] = genotypes + ord("0")
240+
genotypes[sample_mask] = -1
241+
missing = genotypes == -1
242+
gt_array[indexes[missing]] = ord(".")
208243
g_bytes = memoryview(gt_array).tobytes()
209244
g_str = g_bytes.decode()
210245
print(g_str, end="", file=output)

0 commit comments

Comments
 (0)