Skip to content

Commit 00debbe

Browse files
Python version of virtual roots in tree arrays.
Includes tests to document differences in semantics to legacy version.
1 parent d9523c7 commit 00debbe

File tree

3 files changed

+248
-45
lines changed

3 files changed

+248
-45
lines changed

python/tests/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def edge_diffs(self):
297297
left = right
298298

299299
def trees(self):
300-
rtt = tsutil.RootThresholdTree(self._tree_sequence)
300+
rtt = tsutil.LegacyRootThresholdTree(self._tree_sequence)
301301
pt = PythonTree(self._tree_sequence.get_num_nodes())
302302
pt.index = 0
303303
for left, right in rtt.iterate():

python/tests/test_topology.py

Lines changed: 100 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# MIT License
22
#
3-
# Copyright (c) 2018-2020 Tskit Developers
3+
# Copyright (c) 2018-2021 Tskit Developers
44
# Copyright (c) 2016-2017 University of Oxford
55
#
66
# Permission is hereby granted, free of charge, to any person obtaining a copy
@@ -323,6 +323,12 @@ def test_multiroot_tree(self):
323323
ts = tsutil.decapitate(ts, ts.num_edges // 2)
324324
self.verify(ts)
325325

326+
def test_all_missing_data(self):
327+
tables = tskit.TableCollection(1)
328+
for _ in range(10):
329+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
330+
self.verify(tables.tree_sequence())
331+
326332

327333
class TestKCMetric(unittest.TestCase):
328334
"""
@@ -6605,6 +6611,11 @@ def verify(self, ts):
66056611
break
66066612
index = tree1.next_sample[index]
66076613
assert samples1 == samples2
6614+
np.testing.assert_array_equal(tree1.parent, tree2.parent_array)
6615+
np.testing.assert_array_equal(tree1.left_child, tree2.left_child_array)
6616+
np.testing.assert_array_equal(tree1.right_child, tree2.right_child_array)
6617+
# We don't compare the sib arrays because these aren't maintained properly
6618+
# for roots in the Python implementation
66086619
assert right == ts.sequence_length
66096620

66106621

@@ -6615,7 +6626,7 @@ class TestOneSampleRoot(ExampleTopologyMixin):
66156626
"""
66166627

66176628
def verify(self, ts):
6618-
tree1 = tsutil.RootThresholdTree(ts, root_threshold=1)
6629+
tree1 = tsutil.LegacyRootThresholdTree(ts, root_threshold=1)
66196630
tree2 = tskit.Tree(ts)
66206631
tree2.first()
66216632
for interval in tree1.iterate():
@@ -6629,35 +6640,102 @@ def verify(self, ts):
66296640
u = tree2.parent(u)
66306641
roots.add(path_end)
66316642
assert set(tree1.roots()) == roots
6643+
np.testing.assert_array_equal(tree1.parent, tree2.parent_array)
6644+
np.testing.assert_array_equal(tree1.left_child, tree2.left_child_array)
6645+
np.testing.assert_array_equal(tree1.right_child, tree2.right_child_array)
6646+
np.testing.assert_array_equal(tree1.left_sib, tree2.left_sib_array)
6647+
np.testing.assert_array_equal(tree1.right_sib, tree2.right_sib_array)
66326648
tree2.next()
66336649
assert tree2.index == -1
66346650

66356651

6636-
class TestKSamplesRoot(ExampleTopologyMixin):
6652+
class RootThreshold(ExampleTopologyMixin):
66376653
"""
66386654
Tests for the root criteria of subtending at least k samples.
66396655
"""
66406656

6657+
@pytest.mark.skip("all missing broken in lib: #1706")
6658+
def test_all_missing_data(self):
6659+
pass
6660+
66416661
def verify(self, ts):
6642-
for k in range(1, 5):
6643-
tree1 = tsutil.RootThresholdTree(ts, root_threshold=k)
6644-
tree2 = tskit.Tree(ts, root_threshold=k)
6645-
tree2.first()
6646-
for interval in tree1.iterate():
6647-
assert interval == tree2.interval
6648-
# Definition here is the set unique path ends from samples
6649-
# that subtend at least k samples
6650-
roots = set()
6651-
for u in ts.samples():
6652-
while u != tskit.NULL:
6653-
path_end = u
6654-
u = tree2.parent(u)
6655-
if tree2.num_samples(path_end) >= k:
6656-
roots.add(path_end)
6657-
assert set(tree1.roots()) == roots
6658-
assert tree1.roots() == tree2.roots
6659-
tree2.next()
6660-
assert tree2.index == -1
6662+
k = self.root_threshold
6663+
trees_py = tsutil.algorithm_R(ts, root_threshold=k)
6664+
tree_lib = tskit.Tree(ts, root_threshold=k)
6665+
tree_lib.first()
6666+
tree_leg = tsutil.LegacyRootThresholdTree(ts, root_threshold=k)
6667+
for (interval_py, tree_py), interval_leg in itertools.zip_longest(
6668+
trees_py, tree_leg.iterate()
6669+
):
6670+
assert interval_py == tree_lib.interval
6671+
assert interval_leg == tree_lib.interval
6672+
6673+
# Definition here is the set unique path ends from samples
6674+
# that subtend at least k samples
6675+
roots = set()
6676+
for u in ts.samples():
6677+
while u != tskit.NULL:
6678+
path_end = u
6679+
u = tree_lib.parent(u)
6680+
if tree_lib.num_samples(path_end) >= k:
6681+
roots.add(path_end)
6682+
assert set(tree_py.roots()) == roots
6683+
assert set(tree_lib.roots) == roots
6684+
assert tree_leg.roots() == tree_lib.roots
6685+
assert len(tree_py.roots()) == tree_lib.num_roots
6686+
assert len(tree_leg.roots()) == tree_lib.num_roots
6687+
6688+
# The legacy class has identical behaviour to the lib version
6689+
assert tree_leg.left_root == tree_lib.left_root
6690+
np.testing.assert_array_equal(tree_leg.parent, tree_lib.parent_array)
6691+
np.testing.assert_array_equal(
6692+
tree_leg.left_child, tree_lib.left_child_array
6693+
)
6694+
np.testing.assert_array_equal(
6695+
tree_leg.right_child, tree_lib.right_child_array
6696+
)
6697+
np.testing.assert_array_equal(tree_leg.left_sib, tree_lib.left_sib_array)
6698+
np.testing.assert_array_equal(tree_leg.right_sib, tree_lib.right_sib_array)
6699+
6700+
# NOTE: the legacy left_root value is *not* necessarily the same as the
6701+
# new left_root.
6702+
# assert tree_leg.left_root == tree_py.left_child[-1]
6703+
6704+
# The virtual root version is identical except for the extra node and
6705+
# the details of the sib arrays.
6706+
np.testing.assert_array_equal(tree_py.parent[:-1], tree_leg.parent)
6707+
np.testing.assert_array_equal(tree_py.left_child[:-1], tree_leg.left_child)
6708+
np.testing.assert_array_equal(
6709+
tree_py.right_child[:-1], tree_leg.right_child
6710+
)
6711+
# The sib arrays are identical except for root nodes.
6712+
for u in range(ts.num_nodes):
6713+
if u not in roots:
6714+
assert tree_py.left_sib[u] == tree_leg.left_sib[u]
6715+
assert tree_py.right_sib[u] == tree_leg.right_sib[u]
6716+
6717+
tree_lib.next()
6718+
assert tree_lib.index == -1
6719+
6720+
6721+
class TestRootThreshold1(RootThreshold):
6722+
root_threshold = 1
6723+
6724+
6725+
class TestRootThreshold2(RootThreshold):
6726+
root_threshold = 2
6727+
6728+
6729+
class TestRootThreshold3(RootThreshold):
6730+
root_threshold = 3
6731+
6732+
6733+
class TestRootThreshold4(RootThreshold):
6734+
root_threshold = 4
6735+
6736+
6737+
class TestRootThreshold10(RootThreshold):
6738+
root_threshold = 10
66616739

66626740

66636741
class TestSquashEdges:

python/tests/tsutil.py

Lines changed: 147 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# MIT License
22
#
3-
# Copyright (c) 2018-2019 Tskit Developers
3+
# Copyright (c) 2018-2021 Tskit Developers
44
# Copyright (C) 2017 University of Oxford
55
#
66
# Permission is hereby granted, free of charge, to any person obtaining a copy
@@ -1175,6 +1175,143 @@ def algorithm_T(ts):
11751175
left = right
11761176

11771177

1178+
class QuintuplyLinkedTree:
1179+
def __init__(self, n, root_threshold=1):
1180+
self.root_threshold = root_threshold
1181+
self.parent = np.zeros(n + 1, dtype=np.int32) - 1
1182+
self.left_child = np.zeros(n + 1, dtype=np.int32) - 1
1183+
self.right_child = np.zeros(n + 1, dtype=np.int32) - 1
1184+
self.left_sib = np.zeros(n + 1, dtype=np.int32) - 1
1185+
self.right_sib = np.zeros(n + 1, dtype=np.int32) - 1
1186+
self.num_samples = np.zeros(n + 1, dtype=np.int32)
1187+
1188+
def __str__(self):
1189+
s = "id\tparent\tlchild\trchild\tlsib\trsib\tnsamp\n"
1190+
for j in range(len(self.parent)):
1191+
s += (
1192+
f"{j}\t{self.parent[j]}\t"
1193+
f"{self.left_child[j]}\t{self.right_child[j]}\t"
1194+
f"{self.left_sib[j]}\t{self.right_sib[j]}\t"
1195+
f"{self.num_samples[j]}\n"
1196+
)
1197+
return s
1198+
1199+
def roots(self):
1200+
roots = []
1201+
u = self.left_child[-1]
1202+
while u != -1:
1203+
roots.append(u)
1204+
u = self.right_sib[u]
1205+
return roots
1206+
1207+
def remove_branch(self, p, c):
1208+
lsib = self.left_sib[c]
1209+
rsib = self.right_sib[c]
1210+
if lsib == -1:
1211+
self.left_child[p] = rsib
1212+
else:
1213+
self.right_sib[lsib] = rsib
1214+
if rsib == -1:
1215+
self.right_child[p] = lsib
1216+
else:
1217+
self.left_sib[rsib] = lsib
1218+
self.parent[c] = -1
1219+
self.left_sib[c] = -1
1220+
self.right_sib[c] = -1
1221+
1222+
def insert_branch(self, p, c):
1223+
assert self.parent[c] == -1, "contradictory edges"
1224+
self.parent[c] = p
1225+
u = self.right_child[p]
1226+
if u == -1:
1227+
self.left_child[p] = c
1228+
self.left_sib[c] = -1
1229+
self.right_sib[c] = -1
1230+
else:
1231+
self.right_sib[u] = c
1232+
self.left_sib[c] = u
1233+
self.right_sib[c] = -1
1234+
self.right_child[p] = c
1235+
1236+
def is_root(self, u):
1237+
return self.num_samples[u] >= self.root_threshold
1238+
1239+
# Note we cheat a bit here and use the -1 == last element semantics from Python.
1240+
# We could use self.insert_branch(N, root) and then set self.parent[root] = -1.
1241+
def insert_root(self, root):
1242+
self.insert_branch(-1, root)
1243+
1244+
def remove_root(self, root):
1245+
self.remove_branch(-1, root)
1246+
1247+
def remove_edge(self, edge):
1248+
self.remove_branch(edge.parent, edge.child)
1249+
1250+
u = edge.parent
1251+
while u != -1:
1252+
path_end = u
1253+
path_end_was_root = self.is_root(u)
1254+
self.num_samples[u] -= self.num_samples[edge.child]
1255+
u = self.parent[u]
1256+
1257+
if path_end_was_root and not self.is_root(path_end):
1258+
self.remove_root(path_end)
1259+
if self.is_root(edge.child):
1260+
self.insert_root(edge.child)
1261+
1262+
def insert_edge(self, edge):
1263+
u = edge.parent
1264+
while u != -1:
1265+
path_end = u
1266+
path_end_was_root = self.is_root(u)
1267+
self.num_samples[u] += self.num_samples[edge.child]
1268+
u = self.parent[u]
1269+
1270+
if self.is_root(edge.child):
1271+
self.remove_root(edge.child)
1272+
if self.is_root(path_end) and not path_end_was_root:
1273+
self.insert_root(path_end)
1274+
1275+
self.insert_branch(edge.parent, edge.child)
1276+
1277+
1278+
def algorithm_R(ts, root_threshold=1):
1279+
"""
1280+
Quintuply linked tree with root tracking.
1281+
"""
1282+
sequence_length = ts.sequence_length
1283+
N = ts.num_nodes
1284+
M = ts.num_edges
1285+
tree = QuintuplyLinkedTree(N, root_threshold=root_threshold)
1286+
edges = list(ts.edges())
1287+
in_order = ts.tables.indexes.edge_insertion_order
1288+
out_order = ts.tables.indexes.edge_removal_order
1289+
1290+
# Initialise the tree
1291+
for u in ts.samples():
1292+
tree.num_samples[u] = 1
1293+
if tree.is_root(u):
1294+
tree.insert_root(u)
1295+
1296+
j = 0
1297+
k = 0
1298+
left = 0
1299+
while j < M or left < sequence_length:
1300+
while k < M and edges[out_order[k]].right == left:
1301+
tree.remove_edge(edges[out_order[k]])
1302+
k += 1
1303+
while j < M and edges[in_order[j]].left == left:
1304+
tree.insert_edge(edges[in_order[j]])
1305+
j += 1
1306+
right = sequence_length
1307+
if j < M:
1308+
right = min(right, edges[in_order[j]].left)
1309+
if k < M:
1310+
right = min(right, edges[out_order[k]].right)
1311+
yield (left, right), tree
1312+
left = right
1313+
1314+
11781315
class SampleListTree:
11791316
"""
11801317
Straightforward implementation of the quintuply linked tree for developing
@@ -1315,15 +1452,8 @@ def sample_lists(self):
13151452
sequence_length = ts.sequence_length
13161453
edges = list(ts.edges())
13171454
M = len(edges)
1318-
time = [ts.node(edge.parent).time for edge in edges]
1319-
in_order = sorted(
1320-
range(M),
1321-
key=lambda j: (edges[j].left, time[j], edges[j].parent, edges[j].child),
1322-
)
1323-
out_order = sorted(
1324-
range(M),
1325-
key=lambda j: (edges[j].right, -time[j], -edges[j].parent, -edges[j].child),
1326-
)
1455+
in_order = ts.tables.indexes.edge_insertion_order
1456+
out_order = ts.tables.indexes.edge_removal_order
13271457
j = 0
13281458
k = 0
13291459
left = 0
@@ -1348,10 +1478,12 @@ def sample_lists(self):
13481478
left = right
13491479

13501480

1351-
class RootThresholdTree:
1481+
class LegacyRootThresholdTree:
13521482
"""
1353-
Straightforward implementation of the quintuply linked tree for developing
1354-
and testing the root_threshold feature.
1483+
Implementation of the quintuply linked tree with root tracking using the
1484+
pre C 1.0/Python 0.4.0 algorithm. We keep this version around to make sure
1485+
that we can be clear what the differences in the semantics of the new
1486+
and old versions are.
13551487
13561488
NOTE: The interface is pretty awkward; it's not intended for anything other
13571489
than testing.
@@ -1512,15 +1644,8 @@ def iterate(self):
15121644
sequence_length = ts.sequence_length
15131645
edges = list(ts.edges())
15141646
M = len(edges)
1515-
time = [ts.node(edge.parent).time for edge in edges]
1516-
in_order = sorted(
1517-
range(M),
1518-
key=lambda j: (edges[j].left, time[j], edges[j].parent, edges[j].child),
1519-
)
1520-
out_order = sorted(
1521-
range(M),
1522-
key=lambda j: (edges[j].right, -time[j], -edges[j].parent, -edges[j].child),
1523-
)
1647+
in_order = ts.tables.indexes.edge_insertion_order
1648+
out_order = ts.tables.indexes.edge_removal_order
15241649
j = 0
15251650
k = 0
15261651
left = 0

0 commit comments

Comments
 (0)