Skip to content

Tree array access. #1320

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __getattr__(cls, name):
# -- Project information -----------------------------------------------------

project = "tskit"
copyright = "2018-2020, Tskit developers" # noqa: A001
copyright = "2018-2021, Tskit developers" # noqa: A001
author = "Tskit developers"


Expand Down Expand Up @@ -278,8 +278,8 @@ def handle_item(fieldarg, content):

# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
"https://docs.python.org/": None,
"http://docs.scipy.org/doc/numpy/": None,
"https://docs.python.org/3": None,
"https://numpy.org/doc/stable/": None,
"https://svgwrite.readthedocs.io/en/stable/": None,
}

Expand Down
8 changes: 8 additions & 0 deletions docs/python-api.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
.. |tree_array_warning| replace:: This is a high-performance interface which
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

woah nice

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah!

provides zero-copy access to memory used in the C library.
As a consequence, the values stored in this array will change as
the Tree state is modified when moving along the tree sequence.
Therefore, if you want to compare arrays representing different trees
along the sequence, you must take **copies** of the arrays. See the
:class:`.Tree` documentation for more details.

.. currentmodule:: tskit
.. _sec_python_api:

Expand Down
116 changes: 116 additions & 0 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -8786,6 +8786,121 @@ Tree_set_root_threshold(Tree *self, PyObject *args)
return ret;
}

/* The x_array properties are the high-performance zero-copy interface to the
* corresponding arrays in the tsk_tree object. We use properties and
* return a new array each time rather than trying to create a single array
* at Tree initialisation time to avoid a circular reference counting loop,
* which (it seems) even the cyclic garbage collection support can't resolve.
*/
/* Wrap one of the tsk_tree node arrays in a read-only, zero-copy 1D NumPy
 * array of length tree->num_nodes.
 *
 * self:  the Tree whose memory is exposed; it becomes the array's base object
 *        so that it stays alive for as long as the array does.
 * dtype: NumPy type number describing the elements pointed to by data.
 * data:  pointer to the first element; the memory is NOT copied.
 *
 * Returns a new reference to the array, or NULL if array creation or
 * setting the base object failed.
 */
static PyObject *
Tree_make_array(Tree *self, int dtype, void *data)
{
    PyObject *ret = NULL;
    PyArrayObject *array = NULL;
    npy_intp dims = self->tree->num_nodes;

    array = (PyArrayObject *) PyArray_SimpleNewFromData(1, &dims, dtype, data);
    if (array == NULL) {
        goto out;
    }
    /* The array aliases memory owned by the C library, so never let Python
     * code write through it. */
    PyArray_CLEARFLAGS(array, NPY_ARRAY_WRITEABLE);
    /* PyArray_SetBaseObject steals a reference to the base object, and it
     * does so on its error paths as well. We therefore take the extra
     * reference *before* the call: on success the array owns it (keeping the
     * Tree alive while any array refers to its memory), and on failure the
     * callee has already released it, so the count stays balanced. */
    Py_INCREF(self);
    if (PyArray_SetBaseObject(array, (PyObject *) self) != 0) {
        goto out;
    }
    ret = (PyObject *) array;
    array = NULL;
out:
    Py_XDECREF(array);
    return ret;
}

/* Getter for the "parent_array" property: a read-only int32 view of
 * tree->parent built by Tree_make_array. Returns NULL when
 * Tree_check_state reports a non-zero status (presumably with a Python
 * exception already set -- confirm against Tree_check_state). The closure
 * argument is unused. */
static PyObject *
Tree_get_parent_array(Tree *self, void *closure)
{
    if (Tree_check_state(self) != 0) {
        return NULL;
    }
    return Tree_make_array(self, NPY_INT32, self->tree->parent);
}

/* Getter for the "left_child_array" property: a read-only int32 view of
 * tree->left_child built by Tree_make_array. Returns NULL when
 * Tree_check_state reports a non-zero status (presumably with a Python
 * exception already set -- confirm against Tree_check_state). The closure
 * argument is unused. */
static PyObject *
Tree_get_left_child_array(Tree *self, void *closure)
{
    if (Tree_check_state(self) != 0) {
        return NULL;
    }
    return Tree_make_array(self, NPY_INT32, self->tree->left_child);
}

/* Getter for the "right_child_array" property: a read-only int32 view of
 * tree->right_child built by Tree_make_array. Returns NULL when
 * Tree_check_state reports a non-zero status (presumably with a Python
 * exception already set -- confirm against Tree_check_state). The closure
 * argument is unused. */
static PyObject *
Tree_get_right_child_array(Tree *self, void *closure)
{
    if (Tree_check_state(self) != 0) {
        return NULL;
    }
    return Tree_make_array(self, NPY_INT32, self->tree->right_child);
}

/* Getter for the "left_sib_array" property: a read-only int32 view of
 * tree->left_sib built by Tree_make_array. Returns NULL when
 * Tree_check_state reports a non-zero status (presumably with a Python
 * exception already set -- confirm against Tree_check_state). The closure
 * argument is unused. */
static PyObject *
Tree_get_left_sib_array(Tree *self, void *closure)
{
    if (Tree_check_state(self) != 0) {
        return NULL;
    }
    return Tree_make_array(self, NPY_INT32, self->tree->left_sib);
}

/* Getter for the "right_sib_array" property: a read-only int32 view of
 * tree->right_sib built by Tree_make_array. Returns NULL when
 * Tree_check_state reports a non-zero status (presumably with a Python
 * exception already set -- confirm against Tree_check_state). The closure
 * argument is unused. */
static PyObject *
Tree_get_right_sib_array(Tree *self, void *closure)
{
    if (Tree_check_state(self) != 0) {
        return NULL;
    }
    return Tree_make_array(self, NPY_INT32, self->tree->right_sib);
}

/* Property table for the Tree type. Each entry exposes one of the tree's
 * node arrays through its getter. No .set function is supplied for any
 * entry, so the attributes cannot be assigned to or deleted from Python.
 * The table is terminated by an all-NULL sentinel entry. */
static PyGetSetDef Tree_getsetters[] = {
    {
        .name = "parent_array",
        .get = (getter) Tree_get_parent_array,
        .doc = "The parent array in the quintuply linked tree.",
    },
    {
        .name = "left_child_array",
        .get = (getter) Tree_get_left_child_array,
        .doc = "The left_child array in the quintuply linked tree.",
    },
    {
        .name = "right_child_array",
        .get = (getter) Tree_get_right_child_array,
        .doc = "The right_child array in the quintuply linked tree.",
    },
    {
        .name = "left_sib_array",
        .get = (getter) Tree_get_left_sib_array,
        .doc = "The left_sib array in the quintuply linked tree.",
    },
    {
        .name = "right_sib_array",
        .get = (getter) Tree_get_right_sib_array,
        .doc = "The right_sib array in the quintuply linked tree.",
    },
    { NULL } /* sentinel */
};

static PyMethodDef Tree_methods[] = {
{ .ml_name = "first",
.ml_meth = (PyCFunction) Tree_first,
Expand Down Expand Up @@ -8961,6 +9076,7 @@ static PyTypeObject TreeType = {
.tp_flags = Py_TPFLAGS_DEFAULT,
.tp_doc = "Tree objects",
.tp_methods = Tree_methods,
.tp_getset = Tree_getsetters,
.tp_init = (initproc) Tree_init,
.tp_new = PyType_GenericNew,
// clang-format on
Expand Down
58 changes: 43 additions & 15 deletions python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2726,6 +2726,41 @@ def test_interval(self):
breakpoints[i + 1] - breakpoints[i]
)

def verify_tree_arrays(self, tree):
    """
    Check that every ``*_array`` property of ``tree`` has one entry per node
    of its tree sequence, and that each entry equals the value returned by
    the per-node accessor of the same name.
    """
    ts = tree.tree_sequence
    names = ("parent", "left_child", "right_child", "left_sib", "right_sib")
    # First pass: all five arrays must be 1D with length num_nodes.
    for name in names:
        assert getattr(tree, name + "_array").shape == (ts.num_nodes,)
    # Second pass: array contents must agree with the accessor methods.
    for node in range(ts.num_nodes):
        for name in names:
            accessor = getattr(tree, name)
            assert accessor(node) == getattr(tree, name + "_array")[node]

def test_tree_arrays(self):
    """
    Run the array/accessor consistency check on every tree of a simulated
    tree sequence that contains more than one tree.
    """
    simulated = msprime.simulate(10, recombination_rate=1, random_seed=1)
    # The check is only meaningful if we actually move between trees.
    assert simulated.num_trees > 1
    for current_tree in simulated.trees():
        self.verify_tree_arrays(current_tree)

@pytest.mark.parametrize(
    "array", ["parent", "left_child", "right_child", "left_sib", "right_sib"]
)
def test_tree_array_properties(self, array):
    """
    Each ``<array>_array`` attribute of the high-level Tree must return the
    same ndarray object on repeated access, be backed by the low-level tree,
    and reject both assignment and deletion.
    """
    attr = array + "_array"
    ts = msprime.simulate(10, random_seed=1)
    tree = ts.first()
    first_access = getattr(tree, attr)
    # Repeated access hands back the identical object, not just equal data.
    assert getattr(tree, attr) is first_access
    # The array's memory belongs to the low-level tree.
    assert first_access.base is tree._ll_tree
    with pytest.raises(AttributeError):
        setattr(tree, attr, None)
    with pytest.raises(AttributeError):
        delattr(tree, attr)

def verify_empty_tree(self, tree):
ts = tree.tree_sequence
assert tree.index == -1
Expand All @@ -2745,6 +2780,7 @@ def verify_empty_tree(self, tree):
assert tree.left_sib(samples[j]) == samples[j - 1]
if j < ts.num_samples - 1:
assert tree.right_sib(samples[j]) == samples[j + 1]
self.verify_tree_arrays(tree)

def test_empty_tree(self):
ts = msprime.simulate(10, recombination_rate=3, length=3, random_seed=42)
Expand Down Expand Up @@ -2775,21 +2811,11 @@ def test_clear(self):
def verify_trees_identical(self, t1, t2):
    """
    Assert that ``t1`` and ``t2`` are trees over the same tree sequence and
    are in the same state: same node count, same topology (checked through
    both the per-node accessors and the ``*_array`` properties) and the same
    sites.
    """
    names = ("parent", "left_child", "right_child", "left_sib", "right_sib")
    assert t1.tree_sequence is t2.tree_sequence
    # Node counts are plain ints, so compare them by value. An identity
    # check (``is``) only holds for the small ints that CPython caches and
    # would fail spuriously for larger trees.
    assert t1.num_nodes == t2.num_nodes
    # Compare the topology through the per-node accessor methods.
    for name in names:
        get1 = getattr(t1, name)
        get2 = getattr(t2, name)
        assert [get1(u) for u in range(t1.num_nodes)] == [
            get2(u) for u in range(t2.num_nodes)
        ]
    # Compare the topology through the array properties. array_equal also
    # treats a shape mismatch as "not equal" instead of broadcasting.
    for name in names:
        assert np.array_equal(
            getattr(t1, name + "_array"), getattr(t2, name + "_array")
        )
    assert list(t1.sites()) == list(t2.sites())

def test_copy_seek(self):
Expand All @@ -2807,6 +2833,8 @@ def test_copy_seek(self):
tree.clear()
copy = tree.copy()
tree.first()
# Make sure the underlying arrays are different
assert np.any(tree.parent_array != copy.parent_array)
copy.first()
while tree.index != -1:
self.verify_trees_identical(tree, copy)
Expand Down
70 changes: 70 additions & 0 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1929,6 +1929,8 @@ class TestTree(LowLevelTestCase):
Tests on the low-level tree interface.
"""

ARRAY_NAMES = ["parent", "left_child", "right_child", "left_sib", "right_sib"]

def test_options(self):
ts = self.get_example_tree_sequence()
st = _tskit.Tree(ts)
Expand Down Expand Up @@ -2550,6 +2552,74 @@ def test_map_mutations_errors(self):
with pytest.raises(_tskit.LibraryError):
tree.map_mutations(genotypes)

@pytest.mark.parametrize("array", ARRAY_NAMES)
def test_array_read_only(self, array):
    """
    The low-level ``<array>_array`` attribute can be neither rebound nor
    deleted, and the ndarray it returns rejects writes and cannot be made
    writeable again.
    """
    attr = array + "_array"
    tree_sequence = self.get_example_tree_sequence(10)
    ll_tree = _tskit.Tree(tree_sequence)
    ll_tree.first()

    # The attribute itself is read-only on the Tree object.
    with pytest.raises(AttributeError, match="not writable"):
        setattr(ll_tree, attr, None)
    with pytest.raises(AttributeError, match="not writable"):
        delattr(ll_tree, attr)

    # The returned array is a read-only view.
    view = getattr(ll_tree, attr)
    with pytest.raises(ValueError, match="assignment destination"):
        view[:] = 0
    with pytest.raises(ValueError, match="assignment destination"):
        view[0] = 0
    # ...and the writeable flag cannot be switched back on.
    with pytest.raises(ValueError, match="cannot set WRITEABLE"):
        view.setflags(write=True)

@pytest.mark.parametrize("array", ARRAY_NAMES)
def test_array_properties(self, array):
    """
    Check dtype, shape, flags and base object of the low-level
    ``<array>_array`` property, and that separate accesses return distinct
    ndarray objects sharing the same underlying memory.
    """
    ts1 = self.get_example_tree_sequence(10)
    t1 = _tskit.Tree(ts1)
    # Read the property once before the tree has been positioned; the value
    # is discarded and replaced below. NOTE(review): presumably this is a
    # smoke test that access works before first() is called -- confirm.
    a = getattr(t1, array + "_array")
    t1.first()
    a = getattr(t1, array + "_array")
    assert a.dtype == np.int32
    assert a.shape == (ts1.get_num_nodes(),)
    # The base must be this very Tree object, so check identity rather than
    # equality.
    assert a.base is t1
    assert not a.flags.writeable
    assert a.flags.aligned
    assert a.flags.c_contiguous
    assert not a.flags.owndata
    b = getattr(t1, array + "_array")
    assert a is not b
    assert np.all(a == b)
    a_copy = a.copy()
    # This checks that the underlying pointer to memory is the same in
    # both arrays.
    assert a.__array_interface__ == b.__array_interface__
    t1.next()
    # NB! Because we are pointing to the underlying memory, the arrays
    # will change as we iterate along the trees! This is a gotcha, but
    # it's just something we have to document as it's a consequence of the
    # zero copy semantics.
    b = getattr(t1, array + "_array")
    assert np.all(a == b)
    assert np.any(a_copy != b)

@pytest.mark.parametrize("array", ARRAY_NAMES)
def test_array_lifetime(self, array):
    """
    An array obtained from a low-level Tree must remain readable, with
    unchanged contents, after the local references to the Tree and to the
    tree sequence have been deleted.
    """
    tree_sequence = self.get_example_tree_sequence(10)
    ll_tree = _tskit.Tree(tree_sequence)
    ll_tree.first()
    view = getattr(ll_tree, array + "_array")
    snapshot = view.copy()
    assert view is not snapshot
    # Drop our reference to the tree; the view should keep it alive.
    del ll_tree
    # Do some memory operations
    scratch = np.ones(10 ** 6)
    assert np.all(view == snapshot)
    del tree_sequence
    assert np.all(view == snapshot)
    del view
    # Just do something to touch memory
    snapshot[:] = 0
    assert scratch is not snapshot


class TestTableMetadataSchema(MetadataTestMixin):
def test_metadata_schema_attribute(self):
Expand Down
Loading