Skip to content

Commit

Permalink
Deprecate Tree.num_nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong committed Nov 30, 2021
1 parent 1e15ea0 commit 545924b
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 49 deletions.
5 changes: 5 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

**Breaking changes**

- The ``Tree.num_nodes`` method is now deprecated with a warning, because it confusingly
returns the number of nodes in the entire tree sequence, rather than in the tree. Text
summaries of trees (e.g. ``str(tree)``) now return the number of nodes in the tree,
not in the entire tree sequence (:user:`hyanwong`, :issue:`1966` :pr:`1968`)

- The CLI ``info`` command now gives more detailed information on the tree sequence
(:user:`benjeffery`, :pr:`1611`)

Expand Down
17 changes: 0 additions & 17 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -9486,19 +9486,6 @@ Tree_get_sample_size(Tree *self)
return ret;
}

static PyObject *
Tree_get_num_nodes(Tree *self)
{
PyObject *ret = NULL;

if (Tree_check_state(self) != 0) {
goto out;
}
ret = Py_BuildValue("n", (Py_ssize_t) self->tree->num_nodes);
out:
return ret;
}

static PyObject *
Tree_get_num_roots(Tree *self)
{
Expand Down Expand Up @@ -10478,10 +10465,6 @@ static PyMethodDef Tree_methods[] = {
.ml_meth = (PyCFunction) Tree_get_sample_size,
.ml_flags = METH_NOARGS,
.ml_doc = "Returns the number of samples in this tree." },
{ .ml_name = "get_num_nodes",
.ml_meth = (PyCFunction) Tree_get_num_nodes,
.ml_flags = METH_NOARGS,
.ml_doc = "Returns the number of nodes in this tree." },
{ .ml_name = "get_num_roots",
.ml_meth = (PyCFunction) Tree_get_num_roots,
.ml_flags = METH_NOARGS,
Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_combinatorics.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def test_span(self):
span = 8
# Create a start tree, with a single root
tsk_tree = tskit.Tree.unrank(n, (0, 0), span=span)
assert tsk_tree.num_nodes == n + 1
assert tsk_tree.tree_sequence.num_nodes == n + 1
assert tsk_tree.interval.left == 0
assert tsk_tree.interval.right == span
assert tsk_tree.tree_sequence.sequence_length == span
Expand Down
6 changes: 3 additions & 3 deletions python/tests/test_drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def get_left_neighbour(tree, traversal_order):
parent = tree.parent(u)
children[parent].append(u)

left_neighbour = np.full(tree.num_nodes, tskit.NULL, dtype=int)
left_neighbour = np.full(tree.tree_sequence.num_nodes, tskit.NULL, dtype=int)
for u in tree.nodes():
next_left = tskit.NULL
child = u
Expand Down Expand Up @@ -1613,10 +1613,10 @@ def test_node_labels(self):
labels = {u: "XXX" for u in t.nodes()}
svg = t.draw(format="svg", node_labels=labels)
self.verify_basic_svg(svg)
assert svg.count("XXX") == t.num_nodes
assert svg.count("XXX") == t.tree_sequence.num_nodes
svg = t.draw_svg(node_label_attrs={u: {"text": labels[u]} for u in t.nodes()})
self.verify_basic_svg(svg)
assert svg.count("XXX") == t.num_nodes
assert svg.count("XXX") == t.tree_sequence.num_nodes

def test_one_node_label(self):
t = self.get_binary_tree()
Expand Down
22 changes: 15 additions & 7 deletions python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,11 +880,11 @@ class HighLevelTestCase:

def verify_tree_mrcas(self, st):
# Check the mrcas
oriented_forest = [st.get_parent(j) for j in range(st.num_nodes)]
oriented_forest = [st.get_parent(j) for j in range(st.tree_sequence.num_nodes)]
mrca_calc = tests.MRCACalculator(oriented_forest)
# We've done exhaustive tests elsewhere, no need to go
# through the combinations.
for j in range(st.num_nodes):
for j in range(st.tree_sequence.num_nodes):
mrca = st.get_mrca(0, j)
assert mrca == mrca_calc.get_mrca(0, j)
if mrca != tskit.NULL:
Expand Down Expand Up @@ -3050,7 +3050,7 @@ def verify_nx_algorithm_equivalence(self, tree, g):
) | {root}

# test MRCA
if tree.num_nodes < 20:
if tree.tree_sequence.num_nodes < 20:
for u, v in itertools.combinations(tree.nodes(), 2):
mrca = nx.lowest_common_ancestor(g, u, v)
if mrca is None:
Expand Down Expand Up @@ -3165,12 +3165,12 @@ def is_descendant(tree, u, v):
return v in path

tree = self.get_tree()
for u, v in itertools.product(range(tree.num_nodes), repeat=2):
for u, v in itertools.product(range(tree.tree_sequence.num_nodes), repeat=2):
assert is_descendant(tree, u, v) == tree.is_descendant(u, v)
# All nodes are descendents of themselves
for u in range(tree.num_nodes + 1):
for u in range(tree.tree_sequence.num_nodes + 1):
assert tree.is_descendant(u, u)
for bad_node in [-1, -2, tree.num_nodes + 1]:
for bad_node in [-1, -2, tree.tree_sequence.num_nodes + 1]:
with pytest.raises(ValueError):
tree.is_descendant(0, bad_node)
with pytest.raises(ValueError):
Expand Down Expand Up @@ -3206,10 +3206,19 @@ def test_apis(self):
assert t1.get_mrca(*pair) == t1.mrca(*pair)
assert t1.get_tmrca(*pair) == t1.tmrca(*pair)

@pytest.mark.filterwarnings("error::FutureWarning")
def test_deprecated_apis(self):
t1 = self.get_tree()
assert t1.get_length() == t1.span
assert t1.length == t1.span
assert t1.num_nodes == t1.tree_sequence.num_nodes

@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_deprecated_api_warnings(self):
# Deprecated and will be removed
t1 = self.get_tree()
with pytest.raises(FutureWarning, match="Tree.tree_sequence.num_nodes"):
t1.num_nodes

def test_seek_index(self):
ts = msprime.simulate(10, recombination_rate=3, length=5, random_seed=42)
Expand Down Expand Up @@ -3389,7 +3398,6 @@ def test_clear(self):

def verify_trees_identical(self, t1, t2):
assert t1.tree_sequence is t2.tree_sequence
assert t1.num_nodes is t2.num_nodes
assert np.all(t1.parent_array == t2.parent_array)
assert np.all(t1.left_child_array == t2.left_child_array)
assert np.all(t1.right_child_array == t2.right_child_array)
Expand Down
19 changes: 9 additions & 10 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@
import tskit


def get_tracked_sample_counts(st, tracked_samples):
def get_tracked_sample_counts(ts, st, tracked_samples):
"""
Returns a list giving the number of samples in the specified list
that are in the subtree rooted at each node.
"""
nu = [0 for j in range(st.get_num_nodes())]
nu = [0 for j in range(ts.get_num_nodes())]
for j in tracked_samples:
# Duplicates not permitted.
assert nu[j] == 0
Expand All @@ -59,7 +59,7 @@ def get_sample_counts(tree_sequence, st):
"""
Returns a list of the sample node counts for the specified tree.
"""
nu = [0 for j in range(st.get_num_nodes())]
nu = [0 for j in range(tree_sequence.get_num_nodes())]
for j in range(tree_sequence.get_num_samples()):
u = j
while u != _tskit.NULL:
Expand Down Expand Up @@ -2669,7 +2669,6 @@ def test_constructor(self):
_tskit.Tree(ts, options=bad_type)
for ts in self.get_example_tree_sequences():
st = _tskit.Tree(ts)
assert st.get_num_nodes() == ts.get_num_nodes()
# An uninitialised tree should always be zero.
samples = ts.get_samples()
assert st.get_left_child(st.get_virtual_root()) == samples[0]
Expand Down Expand Up @@ -2741,16 +2740,16 @@ def test_count_all_samples(self):
st = _tskit.Tree(ts)
# Without initialisation we should be 0 samples for every node
# that is not a sample.
for j in range(st.get_num_nodes()):
for j in range(ts.get_num_nodes()):
count = 1 if j < ts.get_num_samples() else 0
assert st.get_num_samples(j) == count
assert st.get_num_tracked_samples(j) == 0
while st.next():
nu = get_sample_counts(ts, st)
nu_prime = [st.get_num_samples(j) for j in range(st.get_num_nodes())]
nu_prime = [st.get_num_samples(j) for j in range(ts.get_num_nodes())]
assert nu == nu_prime
# For tracked samples, this should be all zeros.
nu = [st.get_num_tracked_samples(j) for j in range(st.get_num_nodes())]
nu = [st.get_num_tracked_samples(j) for j in range(ts.get_num_nodes())]
assert nu == list([0 for _ in nu])

def test_count_tracked_samples(self):
Expand All @@ -2772,9 +2771,9 @@ def test_count_tracked_samples(self):
random.shuffle(subset)
st = _tskit.Tree(ts, tracked_samples=subset)
while st.next():
nu = get_tracked_sample_counts(st, subset)
nu = get_tracked_sample_counts(ts, st, subset)
nu_prime = [
st.get_num_tracked_samples(j) for j in range(st.get_num_nodes())
st.get_num_tracked_samples(j) for j in range(ts.get_num_nodes())
]
assert nu == nu_prime
# Passing duplicated values should raise an error
Expand Down Expand Up @@ -2977,7 +2976,7 @@ def test_sample_list(self):
assert t.get_right_sample(j) == j

# All non-tree nodes should have 0
for j in range(t.get_num_nodes()):
for j in range(ts.get_num_nodes()):
if (
t.get_parent(j) == _tskit.NULL
and t.get_left_child(j) == _tskit.NULL
Expand Down
4 changes: 2 additions & 2 deletions python/tests/test_parsimony.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def sankoff_score(tree, genotypes, cost_matrix):
the cost of transitioning from each allele to every other allele.
"""
num_alleles = cost_matrix.shape[0]
S = np.zeros((tree.num_nodes, num_alleles))
S = np.zeros((tree.tree_sequence.num_nodes, num_alleles))
for allele, u in zip(genotypes, tree.tree_sequence.samples()):
S[u, :] = INF
S[u, allele] = 0
Expand Down Expand Up @@ -126,7 +126,7 @@ def fitch_map_mutations(tree, genotypes, alleles):
if np.sum(not_missing) == 0:
raise ValueError("Must have at least one non-missing genotype")
num_alleles = np.max(genotypes[not_missing]) + 1
A = np.zeros((tree.num_nodes, num_alleles), dtype=np.int8)
A = np.zeros((tree.tree_sequence.num_nodes, num_alleles), dtype=np.int8)
for allele, u in zip(genotypes, tree.tree_sequence.samples()):
if allele != -1:
A[u, allele] = 1
Expand Down
4 changes: 2 additions & 2 deletions python/tskit/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1656,7 +1656,7 @@ def get_left_neighbour(tree, traversal_order):
for u in tree.nodes(order=traversal_order):
children[tree.parent(u)].append(u)

left_neighbour = np.full(tree.num_nodes + 1, NULL, dtype=int)
left_neighbour = np.full(tree.tree_sequence.num_nodes + 1, NULL, dtype=int)

def find_neighbours(u, neighbour):
left_neighbour[u] = neighbour
Expand All @@ -1677,7 +1677,7 @@ def get_left_child(tree, traversal_order):
specified traversal order. If a node has no children or NULL is passed
in, return NULL.
"""
left_child = np.full(tree.num_nodes + 1, NULL, dtype=int)
left_child = np.full(tree.tree_sequence.num_nodes + 1, NULL, dtype=int)
for u in tree.nodes(order=traversal_order):
parent = tree.parent(u)
if parent != NULL and left_child[parent] == NULL:
Expand Down
25 changes: 19 additions & 6 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,12 +1403,23 @@ def is_descendant(self, u, v):
def num_nodes(self):
"""
Returns the number of nodes in the :class:`TreeSequence` this tree is in.
Equivalent to ``tree.tree_sequence.num_nodes``. To find the number of
nodes that are reachable from all roots use ``len(list(tree.nodes()))``.
Equivalent to ``tree.tree_sequence.num_nodes``.
.. deprecated:: 0.4
Use :attr:`Tree.tree_sequence.num_nodes<TreeSequence.num_nodes>` if you want
the number of nodes in the entire tree sequence, or
``len(tree.preorder())`` to find the number of nodes that are
reachable from all roots in this tree.
:rtype: int
"""
return self._ll_tree.get_num_nodes()
warnings.warn(
"This property is a deprecated alias for Tree.tree_sequence.num_nodes "
"and will be removed in the future",
FutureWarning,
)
return self.tree_sequence.num_nodes

@property
def num_roots(self):
Expand Down Expand Up @@ -2344,7 +2355,7 @@ def _as_newick_fast(self, *, root, precision, legacy_ms_labels):
single_node_size = (
5 + max_label_size + math.ceil(math.log10(root_time)) + precision
)
buffer_size = 1 + single_node_size * self.num_nodes
buffer_size = 1 + single_node_size * self.tree_sequence.num_nodes
return self._ll_tree.get_newick(
precision=precision,
root=root,
Expand Down Expand Up @@ -2538,7 +2549,9 @@ def parent_dict(self):

def get_parent_dict(self):
pi = {
u: self.parent(u) for u in range(self.num_nodes) if self.parent(u) != NULL
u: self.parent(u)
for u in range(self.tree_sequence.num_nodes)
if self.parent(u) != NULL
}
return pi

Expand All @@ -2550,7 +2563,7 @@ def __str__(self):
f"{self.interval.left:.8g}-{self.interval.right:.8g} ({self.span:.8g})",
],
["Roots", str(self.num_roots)],
["Nodes", str(self.num_nodes)],
["Nodes", str(len(self.preorder()))],
["Sites", str(self.num_sites)],
["Mutations", str(self.num_mutations)],
["Total Branch Length", f"{self.total_branch_length:.8g}"],
Expand Down
2 changes: 1 addition & 1 deletion python/tskit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def tree_html(tree):
<tr><td>Index</td><td>{tree.index}</td></tr>
<tr><td>Interval</td><td>{tree.interval.left:.8g}-{tree.interval.right:.8g} ({tree.span:.8g})</td></tr>
<tr><td>Roots</td><td>{tree.num_roots}</td></tr>
<tr><td>Nodes</td><td>{tree.num_nodes}</td></tr>
<tr><td>Nodes</td><td>{len(tree.preorder())}</td></tr>
<tr><td>Sites</td><td>{tree.num_sites}</td></tr>
<tr><td>Mutations</td><td>{tree.num_mutations}</td></tr>
<tr><td>Total Branch Length</td><td>{tree.total_branch_length:.8g}</td></tr>
Expand Down

0 comments on commit 545924b

Please sign in to comment.