Skip to content

Commit c01c6a6

Browse files
Count edges in the tree.
Use for a crude upper bound on number of nodes in a traversal.
1 parent b6fdb44 commit c01c6a6

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

c/tests/test_trees.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5039,7 +5039,10 @@ test_virtual_root_properties(void)
50395039
CU_ASSERT_EQUAL_FATAL(depth, -1);
50405040

50415041
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_time(&t, t.virtual_root, &time), 0)
5042-
CU_ASSERT_TRUE(isinf(time));
5042+
/* Workaround problems in IEEE floating point macros. We may want to
5043+
* add tsk_isinf (like tsk_isnan) at some point, but not worth it just
5044+
* for this test case */
5045+
CU_ASSERT_TRUE(isinf((float) time));
50435046

50445047
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_mrca(&t, t.virtual_root, 0, &node), 0)
50455048
CU_ASSERT_EQUAL(node, t.virtual_root);

python/tests/test_topology.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6556,6 +6556,8 @@ def verify(self, ts):
65566556
tree2 = tskit.Tree(ts)
65576557
tree2.first()
65586558
for interval, tree1 in tsutil.algorithm_R(ts, root_threshold=1):
6559+
root_reachable_nodes = len(tree2.preorder())
6560+
assert tree1.max_nodes >= root_reachable_nodes
65596561
assert interval == tree2.interval
65606562
assert tree1.roots() == tree2.roots
65616563
# Definition here is the set unique path ends from samples
@@ -6592,6 +6594,9 @@ def verify(self, ts):
65926594
assert interval_py == tree_lib.interval
65936595
assert interval_leg == tree_lib.interval
65946596

6597+
root_reachable_nodes = len(tree_lib.preorder())
6598+
assert tree_py.max_nodes >= root_reachable_nodes
6599+
65956600
# Definition here is the set unique path ends from samples
65966601
# that subtend at least k samples
65976602
roots = set()

python/tests/tsutil.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,6 +1184,7 @@ def __init__(self, n, root_threshold=1):
11841184
self.left_sib = np.zeros(n + 1, dtype=np.int32) - 1
11851185
self.right_sib = np.zeros(n + 1, dtype=np.int32) - 1
11861186
self.num_samples = np.zeros(n + 1, dtype=np.int32)
1187+
self.num_edges = 0
11871188

11881189
def __str__(self):
11891190
s = "id\tparent\tlchild\trchild\tlsib\trsib\tnsamp\n"
@@ -1196,6 +1197,12 @@ def __str__(self):
11961197
)
11971198
return s
11981199

1200+
@property
1201+
def max_nodes(self):
1202+
# This gives a crude upper bound on the number of possible nodes in a
1203+
# traversal.
1204+
return self.num_edges + sum(self.num_samples[root] for root in self.roots())
1205+
11991206
def roots(self):
12001207
roots = []
12011208
u = self.left_child[-1]
@@ -1245,6 +1252,7 @@ def remove_root(self, root):
12451252
self.remove_branch(-1, root)
12461253

12471254
def remove_edge(self, edge):
1255+
self.num_edges -= 1
12481256
self.remove_branch(edge.parent, edge.child)
12491257

12501258
u = edge.parent
@@ -1260,6 +1268,7 @@ def remove_edge(self, edge):
12601268
self.insert_root(edge.child)
12611269

12621270
def insert_edge(self, edge):
1271+
self.num_edges += 1
12631272
u = edge.parent
12641273
while u != -1:
12651274
path_end = u

0 commit comments

Comments
 (0)