Skip to content

Commit 1efee21

Browse files
mufernandomergify-bot
authored andcommitted
adding time slicing to ts.samples()
1 parent fe6121a commit 1efee21

File tree

3 files changed

+203
-13
lines changed

3 files changed

+203
-13
lines changed

python/CHANGELOG.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
- Add ``__setitem__`` to all tables allowing single rows to be updated. For example
2323
``tables.nodes[0] = tables.nodes[0].replace(flags=tskit.NODE_IS_SAMPLE)``
2424
(:user:`jeromekelleher`, :user:`benjeffery`, :issue:`1545`, :pr:`1600`).
25+
- Add a new parameter ``time`` to ``TreeSequence.samples()`` that selects
26+
samples at a specific time point or time interval.
27+
(:user:`mufernando`, :user:`petrelharp`, :issue:`1692`, :pr:`1700`).
2528

2629
- Add ``table.metadata_vector`` to all table classes to allow easy extraction of a single
2730
metadata key into an array

python/tests/test_highlevel.py

Lines changed: 167 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,23 @@ def get_mrca(pi, x, y):
346346
return mrca
347347

348348

349+
def get_samples(ts, time=None, population=None):
    """
    Simple node-by-node reference implementation of sample selection, used
    to cross-check ``ts.samples()``.

    :param ts: Object whose ``nodes()`` method iterates over nodes exposing
        ``is_sample()``, ``time``, ``population`` and ``id``.
    :param time: Either None (do not filter by time), a single numeric value
        (keep samples whose time is approximately equal to it, as judged by
        ``np.isclose``), or a pair ``(min_time, max_time)`` (keep samples
        with ``min_time <= t < max_time``). Python numbers, numpy scalars,
        0-d arrays, lists, tuples and 1-d arrays are all accepted.
    :param population: If not None, keep only samples whose ``population``
        equals this value.
    :return: The IDs of the selected sample nodes, in iteration order.
    :rtype: numpy.ndarray (dtype=np.int32)
    :raises ValueError: If ``time`` cannot be converted to one or two floats.
    """
    if time is not None:
        # ndmin=1 turns scalars (including numpy scalars and 0-d arrays)
        # into 1-d arrays, so every accepted input form is handled by shape
        # rather than by Python type.
        time = np.array(time, ndmin=1, dtype=float)
        if time.shape not in ((1,), (2,)):
            raise ValueError(
                "time must be either a single value or a pair of values "
                "(min_time, max_time)."
            )
    samples = []
    for node in ts.nodes():
        if not node.is_sample():
            continue
        if population is not None and node.population != population:
            continue
        if time is not None:
            if time.shape == (1,):
                if not np.isclose(node.time, time[0]):
                    continue
            elif not (time[0] <= node.time < time[1]):
                # Half-open interval: min_time is included, max_time is not.
                continue
        samples.append(node.id)
    # Pin the dtype so that an empty selection is not returned as float64.
    return np.array(samples, dtype=np.int32)
364+
365+
349366
class TestMRCACalculator:
350367
"""
351368
Class to test the Schieber-Vishkin algorithm.
@@ -509,11 +526,14 @@ class TestNumpySamples:
509526
various methods.
510527
"""
511528

512-
def get_tree_sequence(self, num_demes=4):
513-
n = 40
529+
def get_tree_sequence(self, num_demes=4, times=None, n=40):
530+
if times is None:
531+
times = [0]
514532
return msprime.simulate(
515533
samples=[
516-
msprime.Sample(time=0, population=j % num_demes) for j in range(n)
534+
msprime.Sample(time=t, population=j % num_demes)
535+
for j in range(n)
536+
for t in times
517537
],
518538
population_configurations=[
519539
msprime.PopulationConfiguration() for _ in range(num_demes)
@@ -541,6 +561,150 @@ def test_samples(self):
541561
]
542562
assert total == ts.num_samples
543563

564+
@pytest.mark.parametrize("time", [0, 0.1, 1 / 3, 1 / 4, 5 / 7])
def test_samples_time(self, time):
    """
    Selecting samples at a single time point must agree with the
    node-by-node reference ``get_samples``, with and without a
    population filter.
    """
    # The parametrized time is always one of the sampling times.
    sampling_times = [time, 0.2, 1, 15]
    ts = self.get_tree_sequence(num_demes=2, n=20, times=sampling_times)
    expected = get_samples(ts, time=time)
    observed = ts.samples(time=time)
    assert np.array_equal(expected, observed)
    for population in (None, 0):
        expected = get_samples(ts, time=time, population=population)
        observed = ts.samples(time=time, population=population)
        assert np.array_equal(expected, observed)
573+
574+
@pytest.mark.parametrize(
    "time_interval",
    [
        [0, 0.1],
        (0, 1 / 3),
        np.array([1 / 4, 2 / 3]),
        (0.345, 5 / 7),
        (-1, 1),
    ],
)
def test_samples_time_interval(self, time_interval):
    """
    Selecting samples in a ``(min_time, max_time)`` interval must agree
    with the node-by-node reference ``get_samples``, with and without a
    population filter.
    """
    rng = np.random.default_rng(seed=931)
    # Draw sampling times from min_time up to twice max_time, so that
    # times can land above the interval as well as inside it.
    lower = time_interval[0]
    upper = 2 * time_interval[1]
    times = rng.uniform(low=lower, high=upper, size=20)
    ts = self.get_tree_sequence(num_demes=2, n=1, times=times)
    expected = get_samples(ts, time=time_interval)
    observed = ts.samples(time=time_interval)
    assert np.array_equal(expected, observed)
    for population in (None, 0):
        expected = get_samples(ts, time=time_interval, population=population)
        observed = ts.samples(time=time_interval, population=population)
        assert np.array_equal(expected, observed)
597+
598+
def test_samples_example(self):
    """
    Check ``ts.samples()`` against hand-computed answers on a small
    hand-built node table containing both sample and non-sample nodes.
    """
    tables = tskit.TableCollection(sequence_length=10)
    time = [np.array(0), 0, np.array([1]), 1, 1, 3, 3.00001, 3.0 - 0.0001, 1 / 3]
    pops = [1, 3, 1, 2, 1, 1, 1, 3, 1]
    for _ in range(max(pops) + 1):
        tables.populations.add_row()
    # First the sample nodes, then an identical set of non-sample nodes
    # that must never be returned.
    for node_flags in (tskit.NODE_IS_SAMPLE, 0):
        for node_time, node_pop in zip(time, pops):
            tables.nodes.add_row(
                flags=node_flags,
                time=node_time,
                population=node_pop,
            )
    ts = tables.tree_sequence()
    every_sample = np.arange(len(time))
    # Each entry is (keyword arguments to ts.samples, expected node IDs).
    cases = [
        ({}, every_sample),
        ({"time": [0, np.inf]}, every_sample),
        ({"time": 0}, [0, 1]),
        # default tolerance is 1e-5
        ({"time": 0.3333333}, [8]),
        ({"time": 3}, [5, 6]),
        ({"time": 1}, [2, 3, 4]),
        ({"time": 1, "population": 2}, [3]),
        ({"population": 0}, []),
        ({"population": 1}, [0, 2, 4, 5, 6, 8]),
        ({"population": 2}, [3]),
        ({"time": [0, 3]}, [0, 1, 2, 3, 4, 7, 8]),
        # note tuple instead of array
        ({"time": (1, 3)}, [2, 3, 4, 7]),
        ({"time": [0, 3], "population": 1}, [0, 2, 4, 8]),
        ({"time": [0.333333, 3]}, [2, 3, 4, 7, 8]),
        ({"time": [100, np.inf]}, []),
        ({"time": -1}, []),
        ({"time": [-100, 100]}, every_sample),
        ({"time": [-100, -1]}, []),
    ]
    for kwargs, expected in cases:
        assert np.array_equal(ts.samples(**kwargs), expected)
692+
693+
def test_samples_time_errors(self):
    """
    ``ts.samples()`` must raise ``ValueError`` for a ``time`` argument that
    is not numeric, is not of length one or two, or is an interval whose
    bounds are given in the wrong order.
    """
    ts = self.get_tree_sequence(4)
    invalid_times = [
        # error incorrect types
        "s",
        [],
        np.array([1, 2, 3]),
        (1, 2, 3),
        # error using min and max switched
        (2.4, 1),
    ]
    for bad_time in invalid_times:
        with pytest.raises(ValueError):
            ts.samples(time=bad_time)
707+
544708
def test_genotype_matrix_indexing(self):
545709
num_demes = 4
546710
ts = self.get_tree_sequence(num_demes)

python/tskit/trees.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4800,15 +4800,21 @@ def get_samples(self, population_id=None):
48004800
# Deprecated alias for samples()
48014801
return self.samples(population_id)
48024802

4803-
def samples(self, population=None, population_id=None):
4804-
"""
4805-
Returns an array of the sample node IDs in this tree sequence. If the
4806-
``population`` parameter is specified, only return sample IDs from that
4807-
population.
4808-
4809-
:param int population: The population of interest. If None,
4810-
return all samples.
4803+
def samples(self, population=None, population_id=None, time=None):
4804+
"""
4805+
Returns an array of the sample node IDs in this tree sequence. If
4806+
``population`` is specified, only return sample IDs from that population.
4807+
It is also possible to restrict samples by time using the parameter
4808+
``time``. If ``time`` is a numeric value, only return sample IDs whose node
4809+
time is approximately equal to the specified time. If ``time`` is a pair
4810+
of values of the form ``(min_time, max_time)``, only return sample IDs
4811+
whose node time ``t`` is in this interval such that ``min_time <= t < max_time``.
4812+
4813+
:param int population: The population of interest. If None, do not
4814+
filter samples by population.
48114815
:param int population_id: Deprecated alias for ``population``.
4816+
:param float,tuple time: The time or time interval of interest. If
4817+
None, do not filter samples by time.
48124818
:return: A numpy array of the node IDs for the samples of interest,
48134819
listed in numerical order.
48144820
:rtype: numpy.ndarray (dtype=np.int32)
@@ -4820,10 +4826,27 @@ def samples(self, population=None, population_id=None):
48204826
if population_id is not None:
48214827
population = population_id
48224828
samples = self._ll_tree_sequence.get_samples()
4829+
keep = np.full(shape=samples.shape, fill_value=True)
48234830
if population is not None:
48244831
sample_population = self.tables.nodes.population[samples]
4825-
samples = samples[sample_population == population]
4826-
return samples
4832+
keep = np.logical_and(keep, sample_population == population)
4833+
if time is not None:
4834+
# ndmin is set so that scalars are converted into 1d arrays
4835+
time = np.array(time, ndmin=1, dtype=float)
4836+
sample_times = self.tables.nodes.time[samples]
4837+
if time.shape == (1,):
4838+
keep = np.logical_and(keep, np.isclose(sample_times, time))
4839+
elif time.shape == (2,):
4840+
if time[1] <= time[0]:
4841+
raise ValueError("time_interval max is less than or equal to min.")
4842+
keep = np.logical_and(keep, sample_times >= time[0])
4843+
keep = np.logical_and(keep, sample_times < time[1])
4844+
else:
4845+
raise ValueError(
4846+
"time must be either a single value or a pair of values "
4847+
"(min_time, max_time)."
4848+
)
4849+
return samples[keep]
48274850

48284851
def write_fasta(self, output, sequence_ids=None, wrap_width=60):
48294852
""

0 commit comments

Comments
 (0)