Skip to content

Commit f7988b0

Browse files
Add tsk_tree_size_bound to enable safe allocations
Closes #1725
1 parent 967bd2b commit f7988b0

File tree

6 files changed

+83
-7
lines changed

6 files changed

+83
-7
lines changed

c/tskit/trees.c

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3626,7 +3626,6 @@ tsk_tree_get_time(const tsk_tree_t *self, tsk_id_t u, double *t)
36263626
if (u == self->virtual_root) {
36273627
*t = INFINITY;
36283628
} else {
3629-
36303629
ret = tsk_treeseq_get_node(self->tree_sequence, u, &node);
36313630
if (ret != 0) {
36323631
goto out;
@@ -4259,16 +4258,48 @@ tsk_tree_clear(tsk_tree_t *self)
42594258
return ret;
42604259
}
42614260

4261+
tsk_size_t
4262+
tsk_tree_get_size_bound(const tsk_tree_t *self)
4263+
{
4264+
tsk_size_t bound = 0;
4265+
4266+
if (self->tree_sequence != NULL) {
4267+
/* This is a safe upper bound which can be computed cheaply.
4268+
* We have at most n roots and each edge adds at most one new
4269+
* node to the tree. We also allow space for the virtual root,
4270+
* to simplify client code.
4271+
*
4272+
* In the common case of a binary tree with a single root, we have
4273+
* 2n - 1 nodes in total, and 2n - 2 edges. Therefore, we return
4274+
* 3n - 1, which over-estimates the true node count by about one half,
4275+
* so we allocate roughly 1.5 times as much memory as we need.
4276+
*
4277+
* Since tracking the exact number of nodes in the tree would require
4278+
* storing the number of nodes beneath every node and complicate
4279+
* the tree transition method, this seems like a good compromise
4280+
* and will result in less memory usage overall in nearly all cases.
4281+
*/
4282+
bound = 1 + self->tree_sequence->num_samples + self->num_edges;
4283+
}
4284+
return bound;
4285+
}
4286+
42624287
/* Traversal orders */
42634288

4289+
static tsk_id_t *
4290+
tsk_tree_alloc_node_stack(const tsk_tree_t *self)
4291+
{
4292+
return tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(tsk_id_t));
4293+
}
4294+
42644295
int
42654296
tsk_tree_preorder(
42664297
const tsk_tree_t *self, tsk_id_t root, tsk_id_t *nodes, tsk_size_t *num_nodes_ret)
42674298
{
42684299
int ret = 0;
42694300
const tsk_id_t *restrict right_child = self->right_child;
42704301
const tsk_id_t *restrict left_sib = self->left_sib;
4271-
tsk_id_t *restrict stack = tsk_malloc((self->num_nodes + 1) * sizeof(*stack));
4302+
tsk_id_t *restrict stack = tsk_tree_alloc_node_stack(self);
42724303
tsk_size_t num_nodes = 0;
42734304
tsk_id_t u, v;
42744305
int stack_top;
@@ -4320,7 +4351,7 @@ tsk_tree_postorder(
43204351
const tsk_id_t *restrict right_child = self->right_child;
43214352
const tsk_id_t *restrict left_sib = self->left_sib;
43224353
const tsk_id_t *restrict parent = self->parent;
4323-
tsk_id_t *restrict stack = tsk_malloc((self->num_nodes + 1) * sizeof(*stack));
4354+
tsk_id_t *restrict stack = tsk_tree_alloc_node_stack(self);
43244355
tsk_size_t num_nodes = 0;
43254356
tsk_id_t u, v, postorder_parent;
43264357
int stack_top;

c/tskit/trees.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,31 @@ int tsk_tree_prev(tsk_tree_t *self);
411411
int tsk_tree_clear(tsk_tree_t *self);
412412

413413
void tsk_tree_print_state(const tsk_tree_t *self, FILE *out);
414+
415+
/**
416+
@brief Return an upper bound on the number of nodes reachable
417+
from the roots of this tree.
418+
419+
@rst
420+
This function provides an upper bound on the number of nodes that
421+
can be reached in tree traversals, and is intended to be used
422+
for memory allocation purposes. If ``num_nodes`` is the number
423+
of nodes visited in a tree traversal from the virtual root
424+
(e.g., ``tsk_tree_preorder(tree, tree->virtual_root, nodes,
425+
&num_nodes)``), the bound ``N`` returned here is guaranteed to
426+
be greater than or equal to ``num_nodes``.
427+
428+
.. warning:: The precise value returned is not defined and should
429+
not be relied upon, as it may change from version to version.
430+
431+
@endrst
432+
433+
@param self A pointer to a tsk_tree_t object.
434+
@return An upper bound on the number of nodes reachable from the roots
435+
of this tree, or zero if this tree has not been initialised.
436+
*/
437+
tsk_size_t tsk_tree_get_size_bound(const tsk_tree_t *self);
438+
414439
/** @} */
415440

416441
int tsk_tree_set_root_threshold(tsk_tree_t *self, tsk_size_t root_threshold);

python/_tskitmodule.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10070,7 +10070,7 @@ Tree_get_traversal_array(Tree *self, PyObject *args, tsk_traversal_func *func)
1007010070
if (!PyArg_ParseTuple(args, "i", &root)) {
1007110071
goto out;
1007210072
}
10073-
data = PyDataMem_NEW((self->tree->num_nodes + 1) * sizeof(*data));
10073+
data = PyDataMem_NEW(tsk_tree_get_size_bound(self->tree) * sizeof(*data));
1007410074
if (data == NULL) {
1007510075
ret = PyErr_NoMemory();
1007610076
goto out;

python/tests/test_highlevel.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,24 @@ def get_samples(ts, time=None, population=None):
445445

446446

447447
class TestTreeTraversals:
448+
def test_bad_traversal_order(self):
449+
ts = msprime.sim_ancestry(2, random_seed=234)
450+
tree = ts.first()
451+
for bad_order in ["pre", "post", "preorderorder", ("x",), b"preorder"]:
452+
with pytest.raises(ValueError, match="Traversal order"):
453+
tree.nodes(order=bad_order)
454+
455+
@pytest.mark.parametrize("order", list(traversal_map.keys()))
456+
def test_returned_types(self, order):
457+
ts = msprime.sim_ancestry(2, random_seed=234)
458+
tree = ts.first()
459+
iterator = tree.nodes(order=order)
460+
assert isinstance(iterator, collections.abc.Iterable)
461+
lst = list(iterator)
462+
assert len(lst) > 0
463+
for u in lst:
464+
assert isinstance(u, int)
465+
448466
@pytest.mark.parametrize("ts", get_example_tree_sequences())
449467
@pytest.mark.parametrize("order", list(traversal_map.keys()))
450468
def test_traversals_virtual_root(self, ts, order):

python/tests/test_lowlevel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2714,7 +2714,7 @@ def test_cleared_tree(self):
27142714
def check_tree(tree):
27152715
assert tree.get_index() == -1
27162716
assert tree.get_left_child(tree.get_virtual_root()) == samples[0]
2717-
assert tree.get_num_edge() == 0
2717+
assert tree.get_num_edges() == 0
27182718
assert tree.get_mrca(0, 1) == _tskit.NULL
27192719
for u in range(ts.get_num_nodes()):
27202720
assert tree.get_parent(u) == _tskit.NULL

python/tskit/trees.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2127,9 +2127,11 @@ def timedesc(self, u=NULL):
21272127
return self.timeasc(u)[::-1]
21282128

21292129
def _preorder_traversal(self, root):
2130+
# Return Python integers for compatibility
21302131
return map(int, self.preorder(root))
21312132

21322133
def _postorder_traversal(self, root):
2134+
# Return Python integers for compatibility
21332135
return map(int, self.postorder(root))
21342136

21352137
def _inorder_traversal(self, root):
@@ -2166,13 +2168,13 @@ def _timeasc_traversal(self, root):
21662168
"""
21672169
Sorts by increasing time but falls back to increasing ID for equal times.
21682170
"""
2169-
yield from self.timeasc(root)
2171+
return map(int, self.timeasc(root))
21702172

21712173
def _timedesc_traversal(self, root):
21722174
"""
21732175
The reverse of timeasc.
21742176
"""
2175-
yield from self.timedesc(root)
2177+
return map(int, self.timedesc(root))
21762178

21772179
def _minlex_postorder_traversal(self, root):
21782180
"""

0 commit comments

Comments
 (0)