@@ -346,6 +346,23 @@ def get_mrca(pi, x, y):
346
346
return mrca
347
347
348
348
349
+ def get_samples (ts , time = None , population = None ):
350
+ samples = []
351
+ for node in ts .nodes ():
352
+ keep = bool (node .is_sample ())
353
+ if time is not None :
354
+ if isinstance (time , (int , float )):
355
+ keep &= np .isclose (node .time , time )
356
+ if isinstance (time , (tuple , list )):
357
+ keep &= node .time >= time [0 ]
358
+ keep &= node .time < time [1 ]
359
+ if population is not None :
360
+ keep &= node .population == population
361
+ if keep :
362
+ samples .append (node .id )
363
+ return np .array (samples )
364
+
365
+
349
366
class TestMRCACalculator :
350
367
"""
351
368
Class to test the Schieber-Vishkin algorithm.
@@ -509,11 +526,14 @@ class TestNumpySamples:
509
526
various methods.
510
527
"""
511
528
512
- def get_tree_sequence (self , num_demes = 4 ):
513
- n = 40
529
+ def get_tree_sequence (self , num_demes = 4 , times = None , n = 40 ):
530
+ if times is None :
531
+ times = [0 ]
514
532
return msprime .simulate (
515
533
samples = [
516
- msprime .Sample (time = 0 , population = j % num_demes ) for j in range (n )
534
+ msprime .Sample (time = t , population = j % num_demes )
535
+ for j in range (n )
536
+ for t in times
517
537
],
518
538
population_configurations = [
519
539
msprime .PopulationConfiguration () for _ in range (num_demes )
@@ -541,6 +561,150 @@ def test_samples(self):
541
561
]
542
562
assert total == ts .num_samples
543
563
564
+ @pytest .mark .parametrize ("time" , [0 , 0.1 , 1 / 3 , 1 / 4 , 5 / 7 ])
565
+ def test_samples_time (self , time ):
566
+ ts = self .get_tree_sequence (num_demes = 2 , n = 20 , times = [time ])
567
+ assert np .array_equal (get_samples (ts , time = time ), ts .samples (time = time ))
568
+ for population in (None , 0 ):
569
+ assert np .array_equal (
570
+ get_samples (ts , time = time , population = population ),
571
+ ts .samples (time = time , population = population ),
572
+ )
573
+
574
+ @pytest .mark .parametrize (
575
+ "time_interval" ,
576
+ [
577
+ [0 , 0.1 ],
578
+ (0 , 1 / 3 ),
579
+ np .array ([1 / 4 , 2 / 3 ]),
580
+ (0.345 , 5 / 7 ),
581
+ (- 1 , 1 ),
582
+ ],
583
+ )
584
+ def test_samples_time_interval (self , time_interval ):
585
+ rng = np .random .default_rng (seed = 931 )
586
+ times = rng .uniform (low = time_interval [0 ], high = time_interval [1 ], size = 20 )
587
+ ts = self .get_tree_sequence (num_demes = 2 , n = 1 , times = times )
588
+ assert np .array_equal (
589
+ get_samples (ts , time = time_interval ),
590
+ ts .samples (time = time_interval ),
591
+ )
592
+ for population in (None , 0 ):
593
+ assert np .array_equal (
594
+ get_samples (ts , time = time_interval , population = population ),
595
+ ts .samples (time = time_interval , population = population ),
596
+ )
597
+
598
+ def test_samples_example (self ):
599
+ tables = tskit .TableCollection (sequence_length = 10 )
600
+ time = [np .array (0 ), 0 , np .array ([1 ]), 1 , 1 , 3 , 3.00001 , 3.0 - 0.0001 , 1 / 3 ]
601
+ pops = [1 , 3 , 1 , 2 , 1 , 1 , 1 , 3 , 1 ]
602
+ for _ in range (max (pops ) + 1 ):
603
+ tables .populations .add_row ()
604
+ for t , p in zip (time , pops ):
605
+ tables .nodes .add_row (
606
+ flags = tskit .NODE_IS_SAMPLE ,
607
+ time = t ,
608
+ population = p ,
609
+ )
610
+ # add not-samples also
611
+ for t , p in zip (time , pops ):
612
+ tables .nodes .add_row (
613
+ flags = 0 ,
614
+ time = t ,
615
+ population = p ,
616
+ )
617
+ ts = tables .tree_sequence ()
618
+ assert np .array_equal (
619
+ ts .samples (),
620
+ np .arange (len (time )),
621
+ )
622
+ assert np .array_equal (
623
+ ts .samples (time = [0 , np .inf ]),
624
+ np .arange (len (time )),
625
+ )
626
+ assert np .array_equal (
627
+ ts .samples (time = 0 ),
628
+ [0 , 1 ],
629
+ )
630
+ # default tolerance is 1e-5
631
+ assert np .array_equal (
632
+ ts .samples (time = 0.3333333 ),
633
+ [8 ],
634
+ )
635
+ assert np .array_equal (
636
+ ts .samples (time = 3 ),
637
+ [5 , 6 ],
638
+ )
639
+ assert np .array_equal (
640
+ ts .samples (time = 1 ),
641
+ [2 , 3 , 4 ],
642
+ )
643
+ assert np .array_equal (
644
+ ts .samples (time = 1 , population = 2 ),
645
+ [3 ],
646
+ )
647
+ assert np .array_equal (
648
+ ts .samples (population = 0 ),
649
+ [],
650
+ )
651
+ assert np .array_equal (
652
+ ts .samples (population = 1 ),
653
+ [0 , 2 , 4 , 5 , 6 , 8 ],
654
+ )
655
+ assert np .array_equal (
656
+ ts .samples (population = 2 ),
657
+ [3 ],
658
+ )
659
+ assert np .array_equal (
660
+ ts .samples (time = [0 , 3 ]),
661
+ [0 , 1 , 2 , 3 , 4 , 7 , 8 ],
662
+ )
663
+ # note tuple instead of array
664
+ assert np .array_equal (
665
+ ts .samples (time = (1 , 3 )),
666
+ [2 , 3 , 4 , 7 ],
667
+ )
668
+ assert np .array_equal (
669
+ ts .samples (time = [0 , 3 ], population = 1 ),
670
+ [0 , 2 , 4 , 8 ],
671
+ )
672
+ assert np .array_equal (
673
+ ts .samples (time = [0.333333 , 3 ]),
674
+ [2 , 3 , 4 , 7 , 8 ],
675
+ )
676
+ assert np .array_equal (
677
+ ts .samples (time = [100 , np .inf ]),
678
+ [],
679
+ )
680
+ assert np .array_equal (
681
+ ts .samples (time = - 1 ),
682
+ [],
683
+ )
684
+ assert np .array_equal (
685
+ ts .samples (time = [- 100 , 100 ]),
686
+ np .arange (len (time )),
687
+ )
688
+ assert np .array_equal (
689
+ ts .samples (time = [- 100 , - 1 ]),
690
+ [],
691
+ )
692
+
693
+ def test_samples_time_errors (self ):
694
+ ts = self .get_tree_sequence (4 )
695
+ # error incorrect types
696
+ with pytest .raises (ValueError ):
697
+ ts .samples (time = "s" )
698
+ with pytest .raises (ValueError ):
699
+ ts .samples (time = [])
700
+ with pytest .raises (ValueError ):
701
+ ts .samples (time = np .array ([1 , 2 , 3 ]))
702
+ with pytest .raises (ValueError ):
703
+ ts .samples (time = (1 , 2 , 3 ))
704
+ # error using min and max switched
705
+ with pytest .raises (ValueError ):
706
+ ts .samples (time = (2.4 , 1 ))
707
+
544
708
def test_genotype_matrix_indexing (self ):
545
709
num_demes = 4
546
710
ts = self .get_tree_sequence (num_demes )
0 commit comments