diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 451a827699..ffeff4b2c5 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -4329,6 +4329,44 @@ test_single_tree_is_descendant(void) tsk_treeseq_free(&ts); } +static void +test_single_tree_total_branch_length(void) +{ /* Exercises tsk_tree_get_total_branch_length() on the first tree of the sequence built from single_tree_ex_nodes and single_tree_ex_edges: expected totals for several start nodes, then rejection of out-of-range node IDs. */ + int ret; + tsk_treeseq_t ts; + tsk_tree_t tree; + double length; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL, + NULL, NULL, NULL, 0); + ret = tsk_tree_init(&tree, &ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_first(&tree); + CU_ASSERT_EQUAL_FATAL(ret, 1); /* tsk_tree_first must return 1 here before any query is made */ + + CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, TSK_NULL, &length), 0); + CU_ASSERT_EQUAL_FATAL(length, 9); /* TSK_NULL, node 7 and tree.virtual_root are all expected to give the same total of 9 */ + CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, 7, &length), 0); + CU_ASSERT_EQUAL_FATAL(length, 9); + CU_ASSERT_EQUAL_FATAL( + tsk_tree_get_total_branch_length(&tree, tree.virtual_root, &length), 0); + CU_ASSERT_EQUAL_FATAL(length, 9); + CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, 4, &length), 0); + CU_ASSERT_EQUAL_FATAL(length, 2); /* starting from node 4 only part of the tree is summed: expected total 2 */ + CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, 0, &length), 0); + CU_ASSERT_EQUAL_FATAL(length, 0); /* node 0 is expected to have a total of exactly zero */ + CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, 5, &length), 0); + CU_ASSERT_EQUAL_FATAL(length, 4); + + CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, -2, &length), + TSK_ERR_NODE_OUT_OF_BOUNDS); + CU_ASSERT_EQUAL_FATAL( + tsk_tree_get_total_branch_length(&tree, 8, &length), TSK_ERR_NODE_OUT_OF_BOUNDS); /* node IDs -2 and 8 must both be reported as TSK_ERR_NODE_OUT_OF_BOUNDS */ + + tsk_tree_free(&tree); + tsk_treeseq_free(&ts); +} + static void test_single_tree_map_mutations(void) { @@ -6605,6 +6643,7 @@ main(int argc, char **argv) { "test_single_tree_compute_mutation_times", test_single_tree_compute_mutation_times }, { "test_single_tree_is_descendant", test_single_tree_is_descendant }, + { "test_single_tree_total_branch_length", test_single_tree_total_branch_length }, { "test_single_tree_map_mutations", 
test_single_tree_map_mutations }, { "test_single_tree_map_mutations_internal_samples", test_single_tree_map_mutations_internal_samples }, diff --git a/c/tskit/convert.c b/c/tskit/convert.c index ed25c1ab84..77e0f71669 100644 --- a/c/tskit/convert.c +++ b/c/tskit/convert.c @@ -154,8 +154,7 @@ tsk_newick_converter_init(tsk_newick_converter_t *self, const tsk_tree_t *tree, self->options = options; self->tree = tree; self->traversal_stack - = tsk_malloc(tsk_treeseq_get_num_nodes(self->tree->tree_sequence) - * sizeof(*self->traversal_stack)); + = tsk_malloc(tsk_tree_get_size_bound(tree) * sizeof(*self->traversal_stack)); if (self->traversal_stack == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 9036c897fe..d407fb740f 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3509,34 +3509,29 @@ tsk_tree_get_num_samples_by_traversal( const tsk_tree_t *self, tsk_id_t u, tsk_size_t *num_samples) { int ret = 0; - tsk_id_t *stack = NULL; - tsk_id_t v; + tsk_size_t num_nodes, j; tsk_size_t count = 0; - int stack_top = 0; + const tsk_flags_t *restrict flags = self->tree_sequence->tables->nodes.flags; + tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes)); + tsk_id_t v; - stack = tsk_malloc(self->num_nodes * sizeof(*stack)); - if (stack == NULL) { + if (nodes == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } - - stack[0] = u; - while (stack_top >= 0) { - v = stack[stack_top]; - stack_top--; - if (tsk_treeseq_is_sample(self->tree_sequence, v)) { + ret = tsk_tree_preorder(self, u, nodes, &num_nodes); + if (ret != 0) { + goto out; + } + for (j = 0; j < num_nodes; j++) { + v = nodes[j]; + if (flags[v] & TSK_NODE_IS_SAMPLE) { count++; } - v = self->left_child[v]; - while (v != TSK_NULL) { - stack_top++; - stack[stack_top] = v; - v = self->right_sib[v]; - } } *num_samples = count; out: - tsk_safe_free(stack); + tsk_safe_free(nodes); return ret; } @@ -3636,6 +3631,40 @@ tsk_tree_get_time(const tsk_tree_t *self, 
tsk_id_t u, double *t) return ret; } +int +tsk_tree_get_total_branch_length(const tsk_tree_t *self, tsk_id_t node, double *ret_tbl) +{ /* Sums time[parent] - time[child] over every node that tsk_tree_preorder() reports for the traversal started at node, excluding the first reported node. Returns 0 and stores the sum in *ret_tbl on success; on failure returns a nonzero error code and leaves *ret_tbl unwritten. */ + int ret = 0; + tsk_size_t j, num_nodes; + tsk_id_t u, v; + const tsk_id_t *restrict parent = self->parent; + const double *restrict time = self->tree_sequence->tables->nodes.time; + tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes)); /* buffer for the traversal order, sized by tsk_tree_get_size_bound() */ + double sum = 0; + + if (nodes == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + ret = tsk_tree_preorder(self, node, nodes, &num_nodes); /* fills nodes and sets num_nodes to the number of entries written */ + if (ret != 0) { + goto out; /* traversal errors are returned to the caller unchanged */ + } + /* We always skip the first node because we don't return the branch length + * over the input node. */ + for (j = 1; j < num_nodes; j++) { + u = nodes[j]; + v = parent[u]; + if (v != TSK_NULL) { /* a node without a parent contributes no branch */ + sum += time[v] - time[u]; + } + } + *ret_tbl = sum; /* only reached when both the allocation and the traversal succeeded */ +out: + tsk_safe_free(nodes); /* single exit point: runs on success and on every error path */ + return ret; +} + int TSK_WARN_UNUSED tsk_tree_get_sites( const tsk_tree_t *self, const tsk_site_t **sites, tsk_size_t *sites_length) @@ -3649,14 +3678,14 @@ tsk_tree_get_sites( static int tsk_tree_get_depth_unsafe(const tsk_tree_t *self, tsk_id_t u) { - tsk_id_t v; + const tsk_id_t *restrict parent = self->parent; int depth = 0; if (u == self->virtual_root) { return -1; } - for (v = self->parent[u]; v != TSK_NULL; v = self->parent[v]) { + for (v = parent[u]; v != TSK_NULL; v = parent[v]) { depth++; } return depth; @@ -4443,6 +4472,9 @@ get_smallest_set_bit(uint64_t v) * use a general cost matrix, in which case we'll use the Sankoff algorithm. For * now this is unused. * + * We should also vectorise the function so that several sites can be processed + * at once. + * * The algorithm used here is Hartigan parsimony, "Minimum Mutation Fits to a * Given Tree", Biometrics 1973. 
*/ @@ -4458,30 +4490,34 @@ tsk_tree_map_mutations(tsk_tree_t *self, int8_t *genotypes, int8_t state; }; const tsk_size_t num_samples = self->tree_sequence->num_samples; - const tsk_size_t num_nodes = self->num_nodes; const tsk_id_t *restrict left_child = self->left_child; const tsk_id_t *restrict right_sib = self->right_sib; - const tsk_id_t *restrict parent = self->parent; + const tsk_size_t N = tsk_treeseq_get_num_nodes(self->tree_sequence); const tsk_flags_t *restrict node_flags = self->tree_sequence->tables->nodes.flags; - uint64_t optimal_root_set; - uint64_t *restrict optimal_set = tsk_calloc(num_nodes, sizeof(*optimal_set)); - tsk_id_t *restrict postorder_stack - = tsk_malloc(num_nodes * sizeof(*postorder_stack)); + tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes)); + /* Note: to use less memory here and to improve cache performance we should + * probably change to allocating exactly the number of nodes returned by + * a preorder traversal, and then lay the memory out in this order. So, we'd + * need a map from node ID to its index in the preorder traversal, but this + * is trivial to compute. Probably doesn't matter so much at the moment + * when we're doing a single site, but it would make a big difference if + * we were vectorising over lots of sites. 
*/ + uint64_t *restrict optimal_set = tsk_calloc(N + 1, sizeof(*optimal_set)); struct stack_elem *restrict preorder_stack - = tsk_malloc(num_nodes * sizeof(*preorder_stack)); - tsk_id_t postorder_parent, root, u, v; + = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*preorder_stack)); + tsk_id_t root, u, v; /* The largest possible number of transitions is one over every sample */ tsk_state_transition_t *transitions = tsk_malloc(num_samples * sizeof(*transitions)); int8_t allele, ancestral_state; int stack_top; struct stack_elem s; - tsk_size_t j, num_transitions, max_allele_count; + tsk_size_t j, num_transitions, max_allele_count, num_nodes; tsk_size_t allele_count[HARTIGAN_MAX_ALLELES]; tsk_size_t non_missing = 0; int8_t num_alleles = 0; - if (optimal_set == NULL || preorder_stack == NULL || postorder_stack == NULL - || transitions == NULL) { + if (optimal_set == NULL || preorder_stack == NULL || transitions == NULL + || nodes == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } @@ -4518,68 +4554,33 @@ tsk_tree_map_mutations(tsk_tree_t *self, int8_t *genotypes, } } - for (root = self->left_root; root != TSK_NULL; root = self->right_sib[root]) { - /* Do a post order traversal */ - postorder_stack[0] = root; - stack_top = 0; - postorder_parent = TSK_NULL; - while (stack_top >= 0) { - u = postorder_stack[stack_top]; - if (left_child[u] != TSK_NULL && u != postorder_parent) { - for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) { - stack_top++; - postorder_stack[stack_top] = v; - } - } else { - stack_top--; - postorder_parent = parent[u]; - - /* Visit u */ - tsk_memset( - allele_count, 0, ((size_t) num_alleles) * sizeof(*allele_count)); - for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) { - for (allele = 0; allele < num_alleles; allele++) { - allele_count[allele] += bit_is_set(optimal_set[v], allele); - } - } - if (!(node_flags[u] & TSK_NODE_IS_SAMPLE)) { - max_allele_count = 0; - for (allele = 0; allele < num_alleles; allele++) { - max_allele_count - = 
TSK_MAX(max_allele_count, allele_count[allele]); - } - for (allele = 0; allele < num_alleles; allele++) { - if (allele_count[allele] == max_allele_count) { - optimal_set[u] = set_bit(optimal_set[u], allele); - } - } - } - } - } + ret = tsk_tree_postorder(self, self->virtual_root, nodes, &num_nodes); + if (ret != 0) { + goto out; } - - if (!(options & TSK_MM_FIXED_ANCESTRAL_STATE)) { - optimal_root_set = 0; - /* TODO it's annoying that this is essentially the same as the - * visit function above. It would be nice if we had an extra - * node that was the parent of all roots, then the algorithm - * would work as-is */ + for (j = 0; j < num_nodes; j++) { + u = nodes[j]; tsk_memset(allele_count, 0, ((size_t) num_alleles) * sizeof(*allele_count)); - for (root = self->left_root; root != TSK_NULL; root = right_sib[root]) { + for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) { for (allele = 0; allele < num_alleles; allele++) { - allele_count[allele] += bit_is_set(optimal_set[root], allele); + allele_count[allele] += bit_is_set(optimal_set[v], allele); } } - max_allele_count = 0; - for (allele = 0; allele < num_alleles; allele++) { - max_allele_count = TSK_MAX(max_allele_count, allele_count[allele]); - } - for (allele = 0; allele < num_alleles; allele++) { - if (allele_count[allele] == max_allele_count) { - optimal_root_set = set_bit(optimal_root_set, allele); + /* the virtual root has no flags defined */ + if (u == (tsk_id_t) N || !(node_flags[u] & TSK_NODE_IS_SAMPLE)) { + max_allele_count = 0; + for (allele = 0; allele < num_alleles; allele++) { + max_allele_count = TSK_MAX(max_allele_count, allele_count[allele]); + } + for (allele = 0; allele < num_alleles; allele++) { + if (allele_count[allele] == max_allele_count) { + optimal_set[u] = set_bit(optimal_set[u], allele); + } } } - ancestral_state = get_smallest_set_bit(optimal_root_set); + } + if (!(options & TSK_MM_FIXED_ANCESTRAL_STATE)) { + ancestral_state = get_smallest_set_bit(optimal_set[N]); } num_transitions 
= 0; @@ -4622,8 +4623,8 @@ tsk_tree_map_mutations(tsk_tree_t *self, int8_t *genotypes, if (preorder_stack != NULL) { free(preorder_stack); } - if (postorder_stack != NULL) { - free(postorder_stack); + if (nodes != NULL) { + free(nodes); } return ret; } @@ -4888,7 +4889,7 @@ fill_kc_vectors(const tsk_tree_t *t, kc_vectors *kc_vecs) int ret = 0; const tsk_treeseq_t *ts = t->tree_sequence; - stack = tsk_malloc(t->num_nodes * sizeof(*stack)); + stack = tsk_malloc(tsk_tree_get_size_bound(t) * sizeof(*stack)); if (stack == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; @@ -5094,7 +5095,7 @@ update_kc_subtree_state( tsk_id_t *stack = NULL; int ret = 0; - stack = tsk_malloc(t->num_nodes * sizeof(*stack)); + stack = tsk_malloc(tsk_tree_get_size_bound(t) * sizeof(*stack)); if (stack == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 5851b1e2ce..e17ff07a8c 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -436,6 +436,31 @@ be greater than or equal to ``num_nodes``. */ tsk_size_t tsk_tree_get_size_bound(const tsk_tree_t *self); +/** +@brief Returns the sum of the lengths of all branches reachable from + the specified node, or from all roots if node=TSK_NULL. + +@rst +Return the total branch length in a particular subtree or of the +entire tree. If the specified node is TSK_NULL (or the virtual +root) the sum of the lengths of all branches reachable from roots +is returned. Branch length is defined as the difference between the time +of a node and its parent. The branch length of a root is zero. + +Note that the branch length of the specified node itself is +*not* included, so that, e.g., the total branch length of a +leaf node is zero. +@endrst + +@param self A pointer to a tsk_tree_t object. +@param node The tree node to compute the total branch length from, or TSK_NULL to return the + total branch length of the tree. +@param ret_tbl A pointer to a double in which to store the returned total branch length. 
+@return 0 on success or a negative value on failure. +*/ +int tsk_tree_get_total_branch_length( + const tsk_tree_t *self, tsk_id_t node, double *ret_tbl); + /** @} */ int tsk_tree_set_root_threshold(tsk_tree_t *self, tsk_size_t root_threshold); diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index ccf94ebd73..4c3ff7a0c6 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -60,6 +60,8 @@ Roughly a 10X performance increase for "preorder", "postorder", "timeasc" and "timedesc" (:user:`jeromekelleher`, :pr:`1704`). +- Substantial performance improvement for ``Tree.total_branch_length`` + (:user:`jeromekelleher`, :issue:`1794` :pr:`1799`) -------------------- [0.3.7] - 2021-07-08 diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 96f340cf0c..71155c23ed 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -9433,6 +9433,26 @@ Tree_get_num_edges(Tree *self) return ret; } +static PyObject * +Tree_get_total_branch_length(Tree *self) +{ /* Python-facing wrapper: calls tsk_tree_get_total_branch_length() with node = TSK_NULL and returns the result as a Python object built from the C double, or NULL when the state check or the library call fails. */ + PyObject *ret = NULL; /* stays NULL on every failure path */ + double length; + int err; + + if (Tree_check_state(self) != 0) { + goto out; /* NOTE(review): Tree_check_state presumably sets the Python exception itself - confirm */ + } + err = tsk_tree_get_total_branch_length(self->tree, TSK_NULL, &length); + if (err != 0) { + handle_library_error(err); /* NOTE(review): expected to raise the matching Python exception for err - confirm */ + goto out; + } + ret = Py_BuildValue("d", length); /* format unit d: C double to Python float */ +out: + return ret; +} + static PyObject * Tree_get_index(Tree *self) { @@ -10363,6 +10383,10 @@ static PyMethodDef Tree_methods[] = { .ml_meth = (PyCFunction) Tree_get_num_edges, .ml_flags = METH_NOARGS, .ml_doc = "Returns the number of branches in this tree." 
}, + { .ml_name = "get_total_branch_length", + .ml_meth = (PyCFunction) Tree_get_total_branch_length, + .ml_flags = METH_NOARGS, + .ml_doc = "Returns the sum of the branch lengths reachable from roots" }, { .ml_name = "get_left", .ml_meth = (PyCFunction) Tree_get_left, .ml_flags = METH_NOARGS, diff --git a/python/tests/test_parsimony.py b/python/tests/test_parsimony.py index cf8afbb6bc..dc31d1d1da 100644 --- a/python/tests/test_parsimony.py +++ b/python/tests/test_parsimony.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2019-2020 Tskit Developers +# Copyright (c) 2019-2021 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,10 +22,10 @@ """ Tests for the tree parsimony methods. """ +import dataclasses import io import itertools -import attr import Bio.Phylo.TreeConstruction import msprime import numpy as np @@ -219,48 +219,39 @@ def hartigan_map_mutations(tree, genotypes, alleles, ancestral_state=None): optimal_set[u] = 1 allele_count = np.zeros(num_alleles, dtype=int) - for root in tree.roots: - for u in tree.nodes(root, order="postorder"): - allele_count[:] = 0 - for v in tree.children(u): - for j in range(num_alleles): - allele_count[j] += optimal_set[v, j] - if not tree.is_sample(u): - max_allele_count = np.max(allele_count) - optimal_set[u, allele_count == max_allele_count] = 1 - - if ancestral_state is None: + for u in tree.nodes(tree.virtual_root, order="postorder"): allele_count[:] = 0 - for v in tree.roots: + for v in tree.children(u): for j in range(num_alleles): allele_count[j] += optimal_set[v, j] - max_allele_count = np.max(allele_count) - optimal_root_set = np.zeros(num_alleles, dtype=int) - optimal_root_set[allele_count == max_allele_count] = 1 - ancestral_state = np.argmax(optimal_root_set) + if not tree.is_sample(u): + max_allele_count = np.max(allele_count) + optimal_set[u, allele_count == max_allele_count] = 1 
- @attr.s + if ancestral_state is None: + ancestral_state = np.argmax(optimal_set[tree.virtual_root]) + + @dataclasses.dataclass class StackElement: - node = attr.ib() - state = attr.ib() - mutation_parent = attr.ib() + node: int + state: int + mutation_parent: int mutations = [] - for root in tree.roots: - stack = [StackElement(root, ancestral_state, -1)] - while len(stack) > 0: - s = stack.pop() - if optimal_set[s.node, s.state] == 0: - s.state = np.argmax(optimal_set[s.node]) - mutation = tskit.Mutation( - node=s.node, - derived_state=alleles[s.state], - parent=s.mutation_parent, - ) - s.mutation_parent = len(mutations) - mutations.append(mutation) - for v in tree.children(s.node): - stack.append(StackElement(v, s.state, s.mutation_parent)) + stack = [StackElement(root, ancestral_state, -1) for root in reversed(tree.roots)] + while len(stack) > 0: + s = stack.pop() + if optimal_set[s.node, s.state] == 0: + s.state = np.argmax(optimal_set[s.node]) + mutation = tskit.Mutation( + node=s.node, + derived_state=alleles[s.state], + parent=s.mutation_parent, + ) + s.mutation_parent = len(mutations) + mutations.append(mutation) + for v in tree.children(s.node): + stack.append(StackElement(v, s.state, s.mutation_parent)) return alleles[ancestral_state], mutations diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 450f3bb991..6373dcfe5e 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -994,7 +994,7 @@ def total_branch_length(self): :return: The sum of lengths of branches in this tree. :rtype: float """ - return sum(self.branch_length(u) for u in self.nodes()) + return self._ll_tree.get_total_branch_length() def get_mrca(self, u, v): # Deprecated alias for mrca