Skip to content

Makes stats cache optional #2135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions python/tests/test_tree_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,13 @@ def naive_branch_general_stat(


def branch_general_stat(
ts, sample_weights, summary_func, windows=None, polarised=False, span_normalise=True
ts,
sample_weights,
summary_func,
windows=None,
polarised=False,
span_normalise=True,
cache_summary=True,
):
"""
Efficient implementation of the algorithm used as the basis for the
Expand All @@ -200,8 +206,6 @@ def branch_general_stat(
time = ts.tables.nodes.time
parent = np.zeros(ts.num_nodes, dtype=np.int32) - 1
branch_length = np.zeros(ts.num_nodes)
# The value of summary_func(u) for every node.
summary = np.zeros((ts.num_nodes, result_dim))
# The result for the current tree *not* weighted by span.
running_sum = np.zeros(result_dim)

Expand All @@ -211,20 +215,26 @@ def polarised_summary(u):
s += summary_func(total_weight - state[u])
return s

for u in ts.samples():
summary[u] = polarised_summary(u)
# TODO Make this optional and figure out how to cache inline
# The value of summary_func(u) for every node.
# summary = np.zeros((ts.num_nodes, result_dim))
# for u in ts.samples():
# summary[u] = polarised_summary(u)

def summary(u):
return polarised_summary(u)

window_index = 0
for (t_left, t_right), edges_out, edges_in in ts.edge_diffs():
for edge in edges_out:
u = edge.child
running_sum -= branch_length[u] * summary[u]
running_sum -= branch_length[u] * summary(u)
u = edge.parent
while u != -1:
running_sum -= branch_length[u] * summary[u]
running_sum -= branch_length[u] * summary(u)
state[u] -= state[edge.child]
summary[u] = polarised_summary(u)
running_sum += branch_length[u] * summary[u]
# summary(u) = polarised_summary(u)
running_sum += branch_length[u] * summary(u)
u = parent[u]
parent[edge.child] = -1
branch_length[edge.child] = 0
Expand All @@ -233,13 +243,13 @@ def polarised_summary(u):
parent[edge.child] = edge.parent
branch_length[edge.child] = time[edge.parent] - time[edge.child]
u = edge.child
running_sum += branch_length[u] * summary[u]
running_sum += branch_length[u] * summary(u)
u = edge.parent
while u != -1:
running_sum -= branch_length[u] * summary[u]
running_sum -= branch_length[u] * summary(u)
state[u] += state[edge.child]
summary[u] = polarised_summary(u)
running_sum += branch_length[u] * summary[u]
# summary(u) = polarised_summary(u)
running_sum += branch_length[u] * summary(u)
u = parent[u]

# Update the windows
Expand Down