Skip to content

Commit 562bc08

Browse files
committed
adding time slice to ts.samples()
1 parent 936bf86 commit 562bc08

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

python/tests/test_highlevel.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,22 @@ def test_samples(self):
541541
]
542542
assert total == ts.num_samples
543543

544+
def test_samples_time(self):
545+
rng = np.random.default_rng(seed=923)
546+
ts = self.get_tree_sequence(2)
547+
times = rng.exponential(size=10)
548+
for time in times:
549+
min_samples = []
550+
max_samples = []
551+
for node in ts.nodes():
552+
if node.is_sample:
553+
if node.time <= time:
554+
min_samples.append(node.id)
555+
if node.time >= time:
556+
max_samples.append(node.id)
557+
assert np.array.equal(np.array(min_samples), ts.sampels(min_time=time))
558+
assert np.array.equal(np.array(max_samples), ts.sampels(max_time=time))
559+
544560
def test_genotype_matrix_indexing(self):
545561
num_demes = 4
546562
ts = self.get_tree_sequence(num_demes)

python/tskit/trees.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4800,7 +4800,9 @@ def get_samples(self, population_id=None):
48004800
# Deprecated alias for samples()
48014801
return self.samples(population_id)
48024802

4803-
def samples(self, population=None, population_id=None):
4803+
def samples(
4804+
self, population=None, population_id=None, max_time=None, min_time=None
4805+
):
48044806
"""
48054807
Returns an array of the sample node IDs in this tree sequence. If the
48064808
``population`` parameter is specified, only return sample IDs from that
@@ -4820,10 +4822,17 @@ def samples(self, population=None, population_id=None):
48204822
if population_id is not None:
48214823
population = population_id
48224824
samples = self._ll_tree_sequence.get_samples()
4825+
keep = np.full(shape=samples.shape, fill_value=True)
48234826
if population is not None:
48244827
sample_population = self.tables.nodes.population[samples]
4825-
samples = samples[sample_population == population]
4826-
return samples
4828+
keep = np.logical_and(keep, sample_population == population)
4829+
if max_time is not None:
4830+
sample_times = self.tables.nodes.time[samples]
4831+
keep = np.logical_and(keep, sample_times <= max_time)
4832+
if min_time is not None:
4833+
sample_times = self.tables.nodes.time[samples]
4834+
keep = np.logical_and(keep, sample_times >= min_time)
4835+
return samples[keep]
48274836

48284837
def write_fasta(self, output, sequence_ids=None, wrap_width=60):
48294838
""

0 commit comments

Comments
 (0)