Skip to content

Commit

Permalink
Add order argument to ts.nodes()
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong committed Aug 26, 2022
1 parent 1ad7ac6 commit da8d73c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 4 deletions.
5 changes: 5 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
[0.5.3] - 2022-XX-XX
--------------------

**Features**

- The ``ts.nodes`` method now takes an ``order`` parameter so that nodes
can be visited in time order (:user:`hyanwong`, :pr:`2471`, :issue:`2370`)

**Changes**

- Single statistics computed with ``TreeSequence.general_stat`` are now
Expand Down
27 changes: 27 additions & 0 deletions python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1441,6 +1441,33 @@ def verify_pairwise_diversity(self, ts):
def test_pairwise_diversity(self, ts):
self.verify_pairwise_diversity(ts)

def test_bad_node_order(self):
    """
    ``ts.nodes`` must reject every unrecognised ``order`` value with a
    ValueError whose message mentions "order".
    """
    # A tree sequence built from a fresh table collection is enough here:
    # only the argument validation is being exercised.
    ts = tskit.TableCollection(1).tree_sequence()
    invalid_orders = ["abc", 0, 1, False]
    for bad_value in invalid_orders:
        with pytest.raises(ValueError, match="order"):
            ts.nodes(order=bad_value)

def test_node_order(self):
    """
    Check the ``order`` argument of ``ts.nodes`` on a tree sequence whose
    node ids have been reversed: the default and "id" orders must follow
    node id, while "timeasc" must follow ascending node time and agree with
    the order of parents in the sorted edge table.
    """
    balanced = tskit.Tree.generate_balanced(8).tree_sequence
    tables = balanced.dump_tables()
    # Reverse the node order so that the ids are no longer in time order.
    reversed_ids = np.arange(balanced.num_nodes - 1, -1, -1)
    tables.subset(reversed_ids)
    tables.sort()
    ts = tables.tree_sequence()
    by_id = list(range(ts.num_nodes))

    # Default and explicit "id" ordering both iterate in node id order.
    assert [node.id for node in ts.nodes()] == by_id
    assert [node.id for node in ts.nodes(order="id")] == by_id

    # Time ordering must differ from id ordering and be sorted by node time.
    by_time = [node.id for node in ts.nodes(order="timeasc")]
    assert by_time != by_id
    by_time = np.array(by_time)
    assert np.all(ts.nodes_time[by_time] == np.sort(ts.nodes_time))

    # Check it conforms to the order of parents in the edge table.
    is_parent = np.isin(by_time, ts.edges_parent)
    parents_in_time_order = by_time[is_parent]
    # Collapse each run of identical parent ids in the edge table down to
    # one entry: keep the last element of every run, plus the final edge.
    run_ends = np.diff(ts.edges_parent) != 0
    edge_parents = np.concatenate(
        (ts.edges_parent[:-1][run_ends], ts.edges_parent[-1:])
    )
    assert np.all(parents_in_time_order == edge_parents)

def verify_edgesets(self, ts):
"""
Verifies that the edgesets we return are equivalent to the original edges.
Expand Down
30 changes: 26 additions & 4 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -3800,10 +3800,16 @@ class SimpleContainerSequence:
Simple wrapper to allow arrays of SimpleContainers (e.g. edges, nodes) that have a
function allowing access by index (e.g. ts.edge(i), ts.node(i)) to be treated as a
python sequence, allowing forward and reverse iteration.
To generate a sequence of items in a different order, the ``order`` parameter allows
an array of indexes to be passed in, such as returned from np.argsort or np.lexsort.
"""

def __init__(self, getter, length, order=None):
    """
    Wrap an index-based accessor so that it can be used as a sequence.

    :param getter: Callable taking an integer index and returning the item
        at that index (e.g. ``ts.node``).
    :param int length: The number of items in the sequence.
    :param order: Optional array of indexes, such as returned from
        ``np.argsort`` or ``np.lexsort``. When given, position ``i`` of the
        sequence maps to ``getter(order[i])``; when ``None`` (the default)
        position ``i`` maps to ``getter(i)``.
    """
    if order is None:
        self.getter = getter
    else:
        # Redirect every lookup through the permutation so that the
        # sequence is presented in the requested order.
        self.getter = lambda index: getter(order[index])
    self.length = length

def __len__(self):
Expand Down Expand Up @@ -4463,15 +4469,31 @@ def individuals(self):
"""
return SimpleContainerSequence(self.individual, self.num_individuals)

def nodes(self, *, order=None):
    """
    Returns an iterable sequence of all the :ref:`nodes <sec_node_table_definition>`
    in this tree sequence.

    .. note:: Although node ids are commonly ordered by node time, this is not a
        formal tree sequence requirement. If you wish to iterate over nodes in
        time order, you should therefore use ``order="timeasc"`` (and wrap the
        resulting sequence in the standard Python :func:`python:reversed` function
        if you wish to iterate over older nodes before younger ones).

    :param str order: The order in which the nodes should be returned: must be
        one of "id" (default) or "timeasc" (ascending order of time, then by
        ascending node id, matching the first two ordering requirements of
        parent nodes in a :meth:`sorted <TableCollection.sort>` edge table).
    :return: An iterable sequence of all nodes.
    :rtype: Sequence(:class:`.Node`)
    :raises ValueError: If ``order`` is not ``None``, "id" or "timeasc".
    """
    order = "id" if order is None else order
    if order not in ["id", "timeasc"]:
        raise ValueError('order must be "id" or "timeasc"')
    index_order = None
    if order == "timeasc":
        # np.lexsort uses the LAST key as the primary sort key, so this
        # sorts by node time and breaks ties by ascending node id.
        index_order = np.lexsort((np.arange(self.num_nodes), self.nodes_time))
    return SimpleContainerSequence(self.node, self.num_nodes, order=index_order)

def edges(self):
"""
Expand Down

0 comments on commit da8d73c

Please sign in to comment.