Skip to content

Commit

Permalink
Merge pull request #1799 from jeromekelleher/use-virtual-root
Browse files Browse the repository at this point in the history
Use virtual root
  • Loading branch information
mergify[bot] authored Oct 20, 2021
2 parents e622480 + 02d4771 commit e5dc0f6
Show file tree
Hide file tree
Showing 8 changed files with 211 additions and 130 deletions.
39 changes: 39 additions & 0 deletions c/tests/test_trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -4329,6 +4329,44 @@ test_single_tree_is_descendant(void)
tsk_treeseq_free(&ts);
}

/* Checks tsk_tree_get_total_branch_length on the first tree of the
 * single-tree example: whole-tree queries, subtree queries and
 * out-of-bounds node IDs. */
static void
test_single_tree_total_branch_length(void)
{
int ret;
tsk_treeseq_t ts;
tsk_tree_t tree;
double length;

/* Build the tree sequence from the shared single-tree node/edge text
 * fixtures and seek to its first tree (tsk_tree_first is expected to
 * return 1 here). */
tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL,
NULL, NULL, NULL, 0);
ret = tsk_tree_init(&tree, &ts, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = tsk_tree_first(&tree);
CU_ASSERT_EQUAL_FATAL(ret, 1);

/* TSK_NULL, node 7 and the virtual root must all report the same total
 * of 9. NOTE(review): presumably node 7 is the single root of this
 * example - confirm against single_tree_ex_nodes/edges. */
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, TSK_NULL, &length), 0);
CU_ASSERT_EQUAL_FATAL(length, 9);
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, 7, &length), 0);
CU_ASSERT_EQUAL_FATAL(length, 9);
CU_ASSERT_EQUAL_FATAL(
tsk_tree_get_total_branch_length(&tree, tree.virtual_root, &length), 0);
CU_ASSERT_EQUAL_FATAL(length, 9);
/* Subtree queries. The expected totals are 2 below node 4, 0 below
 * node 0 and 4 below node 5. NOTE(review): node 0 is presumably a leaf,
 * which would explain the zero - confirm against the fixture. */
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, 4, &length), 0);
CU_ASSERT_EQUAL_FATAL(length, 2);
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, 0, &length), 0);
CU_ASSERT_EQUAL_FATAL(length, 0);
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, 5, &length), 0);
CU_ASSERT_EQUAL_FATAL(length, 4);

/* Invalid node IDs must be rejected with TSK_ERR_NODE_OUT_OF_BOUNDS:
 * -2 is a negative ID other than TSK_NULL, and 8 is also rejected.
 * NOTE(review): 8 is presumably one past the virtual root - confirm. */
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, -2, &length),
TSK_ERR_NODE_OUT_OF_BOUNDS);
CU_ASSERT_EQUAL_FATAL(
tsk_tree_get_total_branch_length(&tree, 8, &length), TSK_ERR_NODE_OUT_OF_BOUNDS);

tsk_tree_free(&tree);
tsk_treeseq_free(&ts);
}

static void
test_single_tree_map_mutations(void)
{
Expand Down Expand Up @@ -6605,6 +6643,7 @@ main(int argc, char **argv)
{ "test_single_tree_compute_mutation_times",
test_single_tree_compute_mutation_times },
{ "test_single_tree_is_descendant", test_single_tree_is_descendant },
{ "test_single_tree_total_branch_length", test_single_tree_total_branch_length },
{ "test_single_tree_map_mutations", test_single_tree_map_mutations },
{ "test_single_tree_map_mutations_internal_samples",
test_single_tree_map_mutations_internal_samples },
Expand Down
3 changes: 1 addition & 2 deletions c/tskit/convert.c
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,7 @@ tsk_newick_converter_init(tsk_newick_converter_t *self, const tsk_tree_t *tree,
self->options = options;
self->tree = tree;
self->traversal_stack
= tsk_malloc(tsk_treeseq_get_num_nodes(self->tree->tree_sequence)
* sizeof(*self->traversal_stack));
= tsk_malloc(tsk_tree_get_size_bound(tree) * sizeof(*self->traversal_stack));
if (self->traversal_stack == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
Expand Down
181 changes: 91 additions & 90 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -3509,34 +3509,29 @@ tsk_tree_get_num_samples_by_traversal(
const tsk_tree_t *self, tsk_id_t u, tsk_size_t *num_samples)
{
int ret = 0;
tsk_id_t *stack = NULL;
tsk_id_t v;
tsk_size_t num_nodes, j;
tsk_size_t count = 0;
int stack_top = 0;
const tsk_flags_t *restrict flags = self->tree_sequence->tables->nodes.flags;
tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes));
tsk_id_t v;

stack = tsk_malloc(self->num_nodes * sizeof(*stack));
if (stack == NULL) {
if (nodes == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
}

stack[0] = u;
while (stack_top >= 0) {
v = stack[stack_top];
stack_top--;
if (tsk_treeseq_is_sample(self->tree_sequence, v)) {
ret = tsk_tree_preorder(self, u, nodes, &num_nodes);
if (ret != 0) {
goto out;
}
for (j = 0; j < num_nodes; j++) {
v = nodes[j];
if (flags[v] & TSK_NODE_IS_SAMPLE) {
count++;
}
v = self->left_child[v];
while (v != TSK_NULL) {
stack_top++;
stack[stack_top] = v;
v = self->right_sib[v];
}
}
*num_samples = count;
out:
tsk_safe_free(stack);
tsk_safe_free(nodes);
return ret;
}

Expand Down Expand Up @@ -3636,6 +3631,40 @@ tsk_tree_get_time(const tsk_tree_t *self, tsk_id_t u, double *t)
return ret;
}

/* Stores in ``ret_tbl`` the sum of the branch lengths found by a preorder
 * traversal started at ``node``. A branch length is the time of a node's
 * parent minus the time of the node itself. Returns 0 on success,
 * TSK_ERR_NO_MEMORY if the traversal buffer cannot be allocated, or the
 * error code reported by tsk_tree_preorder. ``ret_tbl`` is written only
 * on success. */
int
tsk_tree_get_total_branch_length(const tsk_tree_t *self, tsk_id_t node, double *ret_tbl)
{
    int ret = 0;
    tsk_size_t num_visited, k;
    tsk_id_t child, par;
    double total = 0;
    const tsk_id_t *restrict parent_of = self->parent;
    const double *restrict node_time = self->tree_sequence->tables->nodes.time;
    /* Buffer receiving the node IDs in traversal order. */
    tsk_id_t *order = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*order));

    if (order == NULL) {
        ret = TSK_ERR_NO_MEMORY;
        goto out;
    }
    ret = tsk_tree_preorder(self, node, order, &num_visited);
    if (ret != 0) {
        goto out;
    }
    /* Start at index 1: the first entry of the traversal is never counted,
     * because the branch above the input node is not part of the result. */
    for (k = 1; k < num_visited; k++) {
        child = order[k];
        par = parent_of[child];
        if (par == TSK_NULL) {
            /* No parent, so there is no branch to add for this node. */
            continue;
        }
        total += node_time[par] - node_time[child];
    }
    *ret_tbl = total;
out:
    tsk_safe_free(order);
    return ret;
}

int TSK_WARN_UNUSED
tsk_tree_get_sites(
const tsk_tree_t *self, const tsk_site_t **sites, tsk_size_t *sites_length)
Expand All @@ -3649,14 +3678,14 @@ tsk_tree_get_sites(
static int
tsk_tree_get_depth_unsafe(const tsk_tree_t *self, tsk_id_t u)
{

tsk_id_t v;
const tsk_id_t *restrict parent = self->parent;
int depth = 0;

if (u == self->virtual_root) {
return -1;
}
for (v = self->parent[u]; v != TSK_NULL; v = self->parent[v]) {
for (v = parent[u]; v != TSK_NULL; v = parent[v]) {
depth++;
}
return depth;
Expand Down Expand Up @@ -4443,6 +4472,9 @@ get_smallest_set_bit(uint64_t v)
* use a general cost matrix, in which case we'll use the Sankoff algorithm. For
* now this is unused.
*
* We should also vectorise the function so that several sites can be processed
* at once.
*
* The algorithm used here is Hartigan parsimony, "Minimum Mutation Fits to a
* Given Tree", Biometrics 1973.
*/
Expand All @@ -4458,30 +4490,34 @@ tsk_tree_map_mutations(tsk_tree_t *self, int8_t *genotypes,
int8_t state;
};
const tsk_size_t num_samples = self->tree_sequence->num_samples;
const tsk_size_t num_nodes = self->num_nodes;
const tsk_id_t *restrict left_child = self->left_child;
const tsk_id_t *restrict right_sib = self->right_sib;
const tsk_id_t *restrict parent = self->parent;
const tsk_size_t N = tsk_treeseq_get_num_nodes(self->tree_sequence);
const tsk_flags_t *restrict node_flags = self->tree_sequence->tables->nodes.flags;
uint64_t optimal_root_set;
uint64_t *restrict optimal_set = tsk_calloc(num_nodes, sizeof(*optimal_set));
tsk_id_t *restrict postorder_stack
= tsk_malloc(num_nodes * sizeof(*postorder_stack));
tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes));
/* Note: to use less memory here and to improve cache performance we should
* probably change to allocating exactly the number of nodes returned by
* a preorder traversal, and then lay the memory out in this order. So, we'd
* need a map from node ID to its index in the preorder traversal, but this
* is trivial to compute. Probably doesn't matter so much at the moment
* when we're doing a single site, but it would make a big difference if
* we were vectorising over lots of sites. */
uint64_t *restrict optimal_set = tsk_calloc(N + 1, sizeof(*optimal_set));
struct stack_elem *restrict preorder_stack
= tsk_malloc(num_nodes * sizeof(*preorder_stack));
tsk_id_t postorder_parent, root, u, v;
= tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*preorder_stack));
tsk_id_t root, u, v;
/* The largest possible number of transitions is one over every sample */
tsk_state_transition_t *transitions = tsk_malloc(num_samples * sizeof(*transitions));
int8_t allele, ancestral_state;
int stack_top;
struct stack_elem s;
tsk_size_t j, num_transitions, max_allele_count;
tsk_size_t j, num_transitions, max_allele_count, num_nodes;
tsk_size_t allele_count[HARTIGAN_MAX_ALLELES];
tsk_size_t non_missing = 0;
int8_t num_alleles = 0;

if (optimal_set == NULL || preorder_stack == NULL || postorder_stack == NULL
|| transitions == NULL) {
if (optimal_set == NULL || preorder_stack == NULL || transitions == NULL
|| nodes == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
}
Expand Down Expand Up @@ -4518,68 +4554,33 @@ tsk_tree_map_mutations(tsk_tree_t *self, int8_t *genotypes,
}
}

for (root = self->left_root; root != TSK_NULL; root = self->right_sib[root]) {
/* Do a post order traversal */
postorder_stack[0] = root;
stack_top = 0;
postorder_parent = TSK_NULL;
while (stack_top >= 0) {
u = postorder_stack[stack_top];
if (left_child[u] != TSK_NULL && u != postorder_parent) {
for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) {
stack_top++;
postorder_stack[stack_top] = v;
}
} else {
stack_top--;
postorder_parent = parent[u];

/* Visit u */
tsk_memset(
allele_count, 0, ((size_t) num_alleles) * sizeof(*allele_count));
for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) {
for (allele = 0; allele < num_alleles; allele++) {
allele_count[allele] += bit_is_set(optimal_set[v], allele);
}
}
if (!(node_flags[u] & TSK_NODE_IS_SAMPLE)) {
max_allele_count = 0;
for (allele = 0; allele < num_alleles; allele++) {
max_allele_count
= TSK_MAX(max_allele_count, allele_count[allele]);
}
for (allele = 0; allele < num_alleles; allele++) {
if (allele_count[allele] == max_allele_count) {
optimal_set[u] = set_bit(optimal_set[u], allele);
}
}
}
}
}
ret = tsk_tree_postorder(self, self->virtual_root, nodes, &num_nodes);
if (ret != 0) {
goto out;
}

if (!(options & TSK_MM_FIXED_ANCESTRAL_STATE)) {
optimal_root_set = 0;
/* TODO it's annoying that this is essentially the same as the
* visit function above. It would be nice if we had an extra
* node that was the parent of all roots, then the algorithm
* would work as-is */
for (j = 0; j < num_nodes; j++) {
u = nodes[j];
tsk_memset(allele_count, 0, ((size_t) num_alleles) * sizeof(*allele_count));
for (root = self->left_root; root != TSK_NULL; root = right_sib[root]) {
for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) {
for (allele = 0; allele < num_alleles; allele++) {
allele_count[allele] += bit_is_set(optimal_set[root], allele);
allele_count[allele] += bit_is_set(optimal_set[v], allele);
}
}
max_allele_count = 0;
for (allele = 0; allele < num_alleles; allele++) {
max_allele_count = TSK_MAX(max_allele_count, allele_count[allele]);
}
for (allele = 0; allele < num_alleles; allele++) {
if (allele_count[allele] == max_allele_count) {
optimal_root_set = set_bit(optimal_root_set, allele);
/* the virtual root has no flags defined */
if (u == (tsk_id_t) N || !(node_flags[u] & TSK_NODE_IS_SAMPLE)) {
max_allele_count = 0;
for (allele = 0; allele < num_alleles; allele++) {
max_allele_count = TSK_MAX(max_allele_count, allele_count[allele]);
}
for (allele = 0; allele < num_alleles; allele++) {
if (allele_count[allele] == max_allele_count) {
optimal_set[u] = set_bit(optimal_set[u], allele);
}
}
}
ancestral_state = get_smallest_set_bit(optimal_root_set);
}
if (!(options & TSK_MM_FIXED_ANCESTRAL_STATE)) {
ancestral_state = get_smallest_set_bit(optimal_set[N]);
}

num_transitions = 0;
Expand Down Expand Up @@ -4622,8 +4623,8 @@ tsk_tree_map_mutations(tsk_tree_t *self, int8_t *genotypes,
if (preorder_stack != NULL) {
free(preorder_stack);
}
if (postorder_stack != NULL) {
free(postorder_stack);
if (nodes != NULL) {
free(nodes);
}
return ret;
}
Expand Down Expand Up @@ -4888,7 +4889,7 @@ fill_kc_vectors(const tsk_tree_t *t, kc_vectors *kc_vecs)
int ret = 0;
const tsk_treeseq_t *ts = t->tree_sequence;

stack = tsk_malloc(t->num_nodes * sizeof(*stack));
stack = tsk_malloc(tsk_tree_get_size_bound(t) * sizeof(*stack));
if (stack == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
Expand Down Expand Up @@ -5094,7 +5095,7 @@ update_kc_subtree_state(
tsk_id_t *stack = NULL;
int ret = 0;

stack = tsk_malloc(t->num_nodes * sizeof(*stack));
stack = tsk_malloc(tsk_tree_get_size_bound(t) * sizeof(*stack));
if (stack == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
Expand Down
25 changes: 25 additions & 0 deletions c/tskit/trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,31 @@ be greater than or equal to ``num_nodes``.
*/
tsk_size_t tsk_tree_get_size_bound(const tsk_tree_t *self);

/**
@brief Returns the sum of the lengths of all branches reachable from
the specified node, or from all roots if node=TSK_NULL.

@rst
Return the total branch length in a particular subtree or of the
entire tree. If the specified node is TSK_NULL (or the virtual
root) the sum of the lengths of all branches reachable from roots
is returned. Branch length is defined as the difference between the
time of a node's parent and the time of the node itself. The branch
length of a root is zero.

Note that if the specified node is internal its own branch length is
*not* included, so that, e.g., the total branch length of a
leaf node is zero.
@endrst

@param self A pointer to a tsk_tree_t object.
@param node The node whose subtree is summed, or TSK_NULL to return the
total branch length of the whole tree.
@param ret_tbl A pointer to a double in which the total branch length is
stored on success.
@return 0 on success or a negative value on failure.
*/
int tsk_tree_get_total_branch_length(
const tsk_tree_t *self, tsk_id_t node, double *ret_tbl);

/** @} */

int tsk_tree_set_root_threshold(tsk_tree_t *self, tsk_size_t root_threshold);
Expand Down
2 changes: 2 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
Roughly a 10X performance increase for "preorder", "postorder", "timeasc"
and "timedesc" (:user:`jeromekelleher`, :pr:`1704`).

- Substantial performance improvement for ``Tree.total_branch_length``
(:user:`jeromekelleher`, :issue:`1794`, :pr:`1799`).

--------------------
[0.3.7] - 2021-07-08
Expand Down
Loading

0 comments on commit e5dc0f6

Please sign in to comment.