Skip to content

Commit c7098ce

Browse files
Tree array access.
Closes #1299
1 parent 976bde5 commit c7098ce

File tree

4 files changed

+253
-15
lines changed

4 files changed

+253
-15
lines changed

python/_tskitmodule.c

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8786,6 +8786,121 @@ Tree_set_root_threshold(Tree *self, PyObject *args)
87868786
return ret;
87878787
}
87888788

8789+
/* The x_array properties are the high-performance zero-copy interface to the
8790+
* corresponding arrays in the tsk_tree object. We use properties and
8791+
* return a new array each time rather than trying to create a single array
8792+
* at Tree initialisation time to avoid a circular reference counting loop,
8793+
* which (it seems) even the cyclic garbage collection support can't resolve.
8794+
*/
8795+
static PyObject *
8796+
Tree_make_array(Tree *self, int dtype, void *data)
8797+
{
8798+
PyObject *ret = NULL;
8799+
PyArrayObject *array = NULL;
8800+
npy_intp dims = self->tree->num_nodes;
8801+
8802+
array = (PyArrayObject *) PyArray_SimpleNewFromData(1, &dims, dtype, data);
8803+
if (array == NULL) {
8804+
goto out;
8805+
}
8806+
PyArray_CLEARFLAGS(array, NPY_ARRAY_WRITEABLE);
8807+
if (PyArray_SetBaseObject(array, (PyObject *) self) != 0) {
8808+
goto out;
8809+
}
8810+
/* PyArray_SetBaseObject steals a reference, so we have to incref the tree
8811+
* object. This makes sure that the Tree instance will stay alive if there
8812+
* are any arrays that refer to its memory. */
8813+
Py_INCREF(self);
8814+
ret = (PyObject *) array;
8815+
array = NULL;
8816+
out:
8817+
Py_XDECREF(array);
8818+
return ret;
8819+
}
8820+
8821+
static PyObject *
8822+
Tree_get_parent_array(Tree *self, void *closure)
8823+
{
8824+
PyObject *ret = NULL;
8825+
8826+
if (Tree_check_state(self) != 0) {
8827+
goto out;
8828+
}
8829+
ret = Tree_make_array(self, NPY_INT32, self->tree->parent);
8830+
out:
8831+
return ret;
8832+
}
8833+
8834+
static PyObject *
8835+
Tree_get_left_child_array(Tree *self, void *closure)
8836+
{
8837+
PyObject *ret = NULL;
8838+
8839+
if (Tree_check_state(self) != 0) {
8840+
goto out;
8841+
}
8842+
ret = Tree_make_array(self, NPY_INT32, self->tree->left_child);
8843+
out:
8844+
return ret;
8845+
}
8846+
8847+
static PyObject *
8848+
Tree_get_right_child_array(Tree *self, void *closure)
8849+
{
8850+
PyObject *ret = NULL;
8851+
8852+
if (Tree_check_state(self) != 0) {
8853+
goto out;
8854+
}
8855+
ret = Tree_make_array(self, NPY_INT32, self->tree->right_child);
8856+
out:
8857+
return ret;
8858+
}
8859+
8860+
static PyObject *
8861+
Tree_get_left_sib_array(Tree *self, void *closure)
8862+
{
8863+
PyObject *ret = NULL;
8864+
8865+
if (Tree_check_state(self) != 0) {
8866+
goto out;
8867+
}
8868+
ret = Tree_make_array(self, NPY_INT32, self->tree->left_sib);
8869+
out:
8870+
return ret;
8871+
}
8872+
8873+
static PyObject *
8874+
Tree_get_right_sib_array(Tree *self, void *closure)
8875+
{
8876+
PyObject *ret = NULL;
8877+
8878+
if (Tree_check_state(self) != 0) {
8879+
goto out;
8880+
}
8881+
ret = Tree_make_array(self, NPY_INT32, self->tree->right_sib);
8882+
out:
8883+
return ret;
8884+
}
8885+
8886+
static PyGetSetDef Tree_getsetters[]
8887+
= { { .name = "parent_array",
8888+
.get = (getter) Tree_get_parent_array,
8889+
.doc = "The parent array in the quintuply linked tree." },
8890+
{ .name = "left_child_array",
8891+
.get = (getter) Tree_get_left_child_array,
8892+
.doc = "The left_child array in the quintuply linked tree." },
8893+
{ .name = "right_child_array",
8894+
.get = (getter) Tree_get_right_child_array,
8895+
.doc = "The right_child array in the quintuply linked tree." },
8896+
{ .name = "left_sib_array",
8897+
.get = (getter) Tree_get_left_sib_array,
8898+
.doc = "The left_sib array in the quintuply linked tree." },
8899+
{ .name = "right_sib_array",
8900+
.get = (getter) Tree_get_right_sib_array,
8901+
.doc = "The right_sib array in the quintuply linked tree." },
8902+
{ NULL } };
8903+
87898904
static PyMethodDef Tree_methods[] = {
87908905
{ .ml_name = "first",
87918906
.ml_meth = (PyCFunction) Tree_first,
@@ -8961,6 +9076,7 @@ static PyTypeObject TreeType = {
89619076
.tp_flags = Py_TPFLAGS_DEFAULT,
89629077
.tp_doc = "Tree objects",
89639078
.tp_methods = Tree_methods,
9079+
.tp_getset = Tree_getsetters,
89649080
.tp_init = (initproc) Tree_init,
89659081
.tp_new = PyType_GenericNew,
89669082
// clang-format on

python/tests/test_highlevel.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2692,6 +2692,39 @@ def test_interval(self):
26922692
breakpoints[i + 1] - breakpoints[i]
26932693
)
26942694

2695+
def verify_tree_arrays(self, tree):
    """
    Check that each ``<name>_array`` property of ``tree`` has one entry per
    node in the tree sequence and agrees, node by node, with the
    corresponding ``<name>(u)`` accessor method.
    """
    array_names = ("parent", "left_child", "right_child", "left_sib", "right_sib")
    ts = tree.tree_sequence
    # All shapes are checked first, before any element-wise comparison.
    for array_name in array_names:
        assert getattr(tree, array_name + "_array").shape == (ts.num_nodes,)
    for u in range(ts.num_nodes):
        for array_name in array_names:
            accessor = getattr(tree, array_name)
            assert accessor(u) == getattr(tree, array_name + "_array")[u]
2708+
2709+
def test_tree_arrays(self):
    """
    The array properties must agree with the per-node accessors on every
    tree of a simulation that contains more than one tree.
    """
    tree_seq = msprime.simulate(10, recombination_rate=1, random_seed=1)
    # Guard: a single-tree simulation would not exercise tree transitions.
    assert tree_seq.num_trees > 1
    for current_tree in tree_seq.trees():
        self.verify_tree_arrays(current_tree)
2714+
2715+
@pytest.mark.parametrize(
    "array", ["parent", "left_child", "right_child", "left_sib", "right_sib"]
)
def test_tree_array_properties(self, array):
    """
    The high-level ``<array>_array`` property returns the same array object
    on every access, that array is backed by the low-level tree, and the
    property cannot be assigned to.
    """
    attr_name = f"{array}_array"
    tree = msprime.simulate(10, random_seed=1).first()
    first_access = getattr(tree, attr_name)
    # Repeated access must hand back the identical object, not a new array.
    assert getattr(tree, attr_name) is first_access
    assert first_access.base is tree._ll_tree
    with pytest.raises(AttributeError):
        setattr(tree, attr_name, None)
2727+
26952728
def verify_empty_tree(self, tree):
26962729
ts = tree.tree_sequence
26972730
assert tree.index == -1
@@ -2711,6 +2744,7 @@ def verify_empty_tree(self, tree):
27112744
assert tree.left_sib(samples[j]) == samples[j - 1]
27122745
if j < ts.num_samples - 1:
27132746
assert tree.right_sib(samples[j]) == samples[j + 1]
2747+
self.verify_tree_arrays(tree)
27142748

27152749
def test_empty_tree(self):
27162750
ts = msprime.simulate(10, recombination_rate=3, length=3, random_seed=42)
@@ -2741,21 +2775,11 @@ def test_clear(self):
27412775
def verify_trees_identical(self, t1, t2):
    """
    Assert that ``t1`` and ``t2`` belong to the same tree sequence object and
    describe the same tree: equal node counts, equal quintuply linked tree
    arrays and equal site lists.
    """
    assert t1.tree_sequence is t2.tree_sequence
    # Node counts are plain integers, so compare by value. An identity
    # check only passes when the interpreter happens to share the objects.
    assert t1.num_nodes == t2.num_nodes
    # array_equal also fails cleanly (rather than broadcasting) if the two
    # arrays ever differ in shape.
    assert np.array_equal(t1.parent_array, t2.parent_array)
    assert np.array_equal(t1.left_child_array, t2.left_child_array)
    assert np.array_equal(t1.right_child_array, t2.right_child_array)
    assert np.array_equal(t1.left_sib_array, t2.left_sib_array)
    assert np.array_equal(t1.right_sib_array, t2.right_sib_array)
    assert list(t1.sites()) == list(t2.sites())
27602784

27612785
def test_copy_seek(self):

python/tests/test_lowlevel.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2550,6 +2550,76 @@ def test_map_mutations_errors(self):
25502550
with pytest.raises(_tskit.LibraryError):
25512551
tree.map_mutations(genotypes)
25522552

2553+
@pytest.mark.parametrize(
    "array", ["parent", "left_child", "right_child", "left_sib", "right_sib"]
)
def test_array_read_only(self, array):
    """
    Neither the ``<array>_array`` attribute nor the array it returns can be
    modified: the attribute rejects assignment, the array rejects writes,
    and the array's WRITEABLE flag cannot be switched back on.
    """
    attr_name = f"{array}_array"
    ll_tree = _tskit.Tree(self.get_example_tree_sequence(10))
    ll_tree.first()
    # The attribute itself has no setter.
    with pytest.raises(AttributeError, match="not writable"):
        setattr(ll_tree, attr_name, None)
    view = getattr(ll_tree, attr_name)
    # Slice and element assignment are both refused by numpy.
    with pytest.raises(ValueError, match="assignment destination"):
        view[:] = 0
    with pytest.raises(ValueError, match="assignment destination"):
        view[0] = 0
    with pytest.raises(ValueError, match="cannot set WRITEABLE"):
        view.setflags(write=True)
2570+
2571+
@pytest.mark.parametrize(
    "array", ["parent", "left_child", "right_child", "left_sib", "right_sib"]
)
def test_array_properties(self, array):
    """
    Check dtype, shape, base object and flags of the low-level
    ``<array>_array`` attribute, and that separately fetched arrays share
    the same underlying memory.
    """
    attr_name = array + "_array"
    ts1 = self.get_example_tree_sequence(10)
    t1 = _tskit.Tree(ts1)
    # The attribute is also read once before the tree is positioned.
    a = getattr(t1, attr_name)
    t1.first()
    a = getattr(t1, attr_name)
    assert a.dtype == np.int32
    assert a.shape == (ts1.get_num_nodes(),)
    assert a.base == t1
    assert not a.flags.writeable
    assert a.flags.aligned
    assert a.flags.c_contiguous
    # Every access builds a new array object with the same contents.
    b = getattr(t1, attr_name)
    assert a is not b
    assert np.all(a == b)
    snapshot = a.copy()
    # Equal array interfaces mean both arrays point at the same memory.
    assert a.__array_interface__ == b.__array_interface__
    t1.next()
    # NB! The arrays are views of the tree's own memory, so their contents
    # change as we iterate along the trees. This is a gotcha that follows
    # from the zero-copy semantics and has to be documented.
    b = getattr(t1, attr_name)
    assert np.all(a == b)
    assert np.any(snapshot != b)
2601+
2602+
@pytest.mark.parametrize(
    "array", ["parent", "left_child", "right_child", "left_sib", "right_sib"]
)
def test_array_lifetime(self, array):
    """
    An array obtained from a tree must remain readable after the local
    references to the tree and to the tree sequence have been deleted.
    """
    ll_ts = self.get_example_tree_sequence(10)
    ll_tree = _tskit.Tree(ll_ts)
    ll_tree.first()
    view = getattr(ll_tree, array + "_array")
    snapshot = view.copy()
    assert view is not snapshot
    del ll_tree
    # Do some memory operations
    scratch = np.ones(10 ** 6)
    assert np.all(view == snapshot)
    del ll_ts
    assert np.all(view == snapshot)
    del view
    # Just do something to touch memory
    snapshot[:] = 0
    assert scratch is not snapshot
2622+
25532623

25542624
class TestTableMetadataSchema(MetadataTestMixin):
25552625
def test_metadata_schema_attribute(self):

python/tskit/trees.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,14 @@ def __init__(
669669
self._tree_sequence = tree_sequence
670670
self._ll_tree = _tskit.Tree(tree_sequence.ll_tree_sequence, **kwargs)
671671
self._ll_tree.set_root_threshold(root_threshold)
672+
# Store the low-level arrays for efficiency. There's no real overhead
673+
# in this, because they refer to the same underlying memory as the
674+
# tree object.
675+
self._parent_array = self._ll_tree.parent_array
676+
self._left_child_array = self._ll_tree.left_child_array
677+
self._right_child_array = self._ll_tree.right_child_array
678+
self._left_sib_array = self._ll_tree.left_sib_array
679+
self._right_sib_array = self._ll_tree.right_sib_array
672680

673681
def copy(self):
674682
"""
@@ -1014,6 +1022,10 @@ def parent(self, u):
10141022
"""
10151023
return self._ll_tree.get_parent(u)
10161024

1025+
@property
def parent_array(self):
    """
    The array stored on this tree as ``_parent_array``. The tests in this
    change expect entry ``u`` to equal ``self.parent(u)``.

    :rtype: numpy.ndarray
    """
    return self._parent_array
1028+
10171029
# Quintuply linked tree structure.
10181030

10191031
def left_child(self, u):
@@ -1035,6 +1047,10 @@ def left_child(self, u):
10351047
"""
10361048
return self._ll_tree.get_left_child(u)
10371049

1050+
@property
def left_child_array(self):
    """
    The array stored on this tree as ``_left_child_array``. The tests in
    this change expect entry ``u`` to equal ``self.left_child(u)``.

    :rtype: numpy.ndarray
    """
    return self._left_child_array
1053+
10381054
def right_child(self, u):
10391055
"""
10401056
Returns the rightmost child of the specified node. Returns
@@ -1054,6 +1070,10 @@ def right_child(self, u):
10541070
"""
10551071
return self._ll_tree.get_right_child(u)
10561072

1073+
@property
def right_child_array(self):
    """
    The array stored on this tree as ``_right_child_array``. The tests in
    this change expect entry ``u`` to equal ``self.right_child(u)``.

    :rtype: numpy.ndarray
    """
    return self._right_child_array
1076+
10571077
def left_sib(self, u):
10581078
"""
10591079
Returns the sibling node to the left of u, or :data:`tskit.NULL`
@@ -1069,6 +1089,10 @@ def left_sib(self, u):
10691089
"""
10701090
return self._ll_tree.get_left_sib(u)
10711091

1092+
@property
def left_sib_array(self):
    """
    The array stored on this tree as ``_left_sib_array``. The tests in this
    change expect entry ``u`` to equal ``self.left_sib(u)``.

    :rtype: numpy.ndarray
    """
    return self._left_sib_array
1095+
10721096
def right_sib(self, u):
10731097
"""
10741098
Returns the sibling node to the right of u, or :data:`tskit.NULL`
@@ -1084,6 +1108,10 @@ def right_sib(self, u):
10841108
"""
10851109
return self._ll_tree.get_right_sib(u)
10861110

1111+
@property
def right_sib_array(self):
    """
    The array stored on this tree as ``_right_sib_array``. The tests in this
    change expect entry ``u`` to equal ``self.right_sib(u)``.

    :rtype: numpy.ndarray
    """
    return self._right_sib_array
1114+
10871115
# Sample list.
10881116

10891117
def left_sample(self, u):

0 commit comments

Comments
 (0)