From da8d73cc25ebee499dfa99324b8c61d0f3a2f235 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Tue, 23 Aug 2022 21:26:01 +0100 Subject: [PATCH] Add order argument to ts.nodes() Fixes #2370 --- python/CHANGELOG.rst | 5 +++++ python/tests/test_highlevel.py | 27 +++++++++++++++++++++++++++ python/tskit/trees.py | 30 ++++++++++++++++++++++++++---- 3 files changed, 58 insertions(+), 4 deletions(-) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index bfd3e8d776..51379113d3 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -2,6 +2,11 @@ [0.5.3] - 2022-XX-XX -------------------- +**Features** + + - The ``ts.nodes`` method now takes an ``order`` parameter so that nodes + can be visited in time order (:user:`hyanwong`, :pr:`2471`, :issue:`2370`) + **Changes** - Single statistics computed with ``TreeSequence.general_stat`` are now diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 31399199a7..ddc7e54f9d 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -1441,6 +1441,33 @@ def verify_pairwise_diversity(self, ts): def test_pairwise_diversity(self, ts): self.verify_pairwise_diversity(ts) + def test_bad_node_order(self): + ts = tskit.TableCollection(1).tree_sequence() + for order in ["abc", 0, 1, False]: + with pytest.raises(ValueError, match="order"): + ts.nodes(order=order) + + def test_node_order(self): + ts = tskit.Tree.generate_balanced(8).tree_sequence + tables = ts.dump_tables() + tables.subset(np.arange(ts.num_nodes - 1, -1, -1)) # reverse the node order + tables.sort() + ts = tables.tree_sequence() + order = [n.id for n in ts.nodes()] + assert order == list(range(ts.num_nodes)) + order = [n.id for n in ts.nodes(order="id")] + assert order == list(range(ts.num_nodes)) + order = [n.id for n in ts.nodes(order="timeasc")] + assert order != list(range(ts.num_nodes)) + order = np.array(order) + assert np.all(ts.nodes_time[order] == np.sort(ts.nodes_time)) + # Check it conforms to the order 
of parents in the edge table + parent_only_order = order[np.isin(order, ts.edges_parent)] + edge_parents = np.concatenate( + (ts.edges_parent[:-1][np.diff(ts.edges_parent) != 0], ts.edges_parent[-1:]) + ) + assert np.all(parent_only_order == edge_parents) + def verify_edgesets(self, ts): """ Verifies that the edgesets we return are equivalent to the original edges. diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 39f05e7217..bc0be9fc23 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -3800,10 +3800,16 @@ class SimpleContainerSequence: Simple wrapper to allow arrays of SimpleContainers (e.g. edges, nodes) that have a function allowing access by index (e.g. ts.edge(i), ts.node(i)) to be treated as a python sequence, allowing forward and reverse iteration. + + To generate a sequence of items in a different order, the ``order`` parameter allows + an array of indexes to be passed in, such as returned from np.argsort or np.lexsort. """ - def __init__(self, getter, length): - self.getter = getter + def __init__(self, getter, length, order=None): + if order is None: + self.getter = getter + else: + self.getter = lambda index: getter(order[index]) self.length = length def __len__(self): @@ -4463,15 +4469,31 @@ def individuals(self): """ return SimpleContainerSequence(self.individual, self.num_individuals) - def nodes(self): + def nodes(self, *, order=None): """ Returns an iterable sequence of all the :ref:`nodes <sec_node_table_definition>` in this tree sequence. + .. note:: Although node ids are commonly ordered by node time, this is not a + formal tree sequence requirement. 
If you wish to iterate over nodes in + time order, you should therefore use ``order="timeasc"`` (and wrap the + resulting sequence in the standard Python :func:`python:reversed` function + if you wish to iterate over older nodes before younger ones) + + :param str order: The order in which the nodes should be returned: must be + one of "id" (default) or "timeasc" (ascending order of time, then by + ascending node id, matching the first two ordering requirements of + parent nodes in a :meth:`sorted <TableCollection.sort>` edge table). :return: An iterable sequence of all nodes. :rtype: Sequence(:class:`.Node`) """ - return SimpleContainerSequence(self.node, self.num_nodes) + order = "id" if order is None else order + if order not in ["id", "timeasc"]: + raise ValueError('order must be "id" or "timeasc"') + odr = None + if order == "timeasc": + odr = np.lexsort((np.arange(self.num_nodes), self.nodes_time)) + return SimpleContainerSequence(self.node, self.num_nodes, order=odr) def edges(self): """