diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 78e7ac2cbc..1178169b28 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -17,6 +17,10 @@ - ``tree.mrca`` now takes 2 or more arguments and gives the common ancestor of them all. (:user:`savitakartik`, :issue:`1340`, :pr:`2121`) +- Add an ``edge`` attribute to the ``Mutation`` class that gives the ID of the + edge that the mutation falls on. + (:user:`jeromekelleher`, :issue:`685`, :pr:`2279`) + **Breaking Changes** - The JSON metadata codec now interprets the empty string as an empty object. This means diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 2c9c737fba..feb39b75ab 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -322,7 +322,7 @@ make_metadata(const char *metadata, Py_ssize_t length) } static PyObject * -make_mutation(const tsk_mutation_t *mutation) +make_mutation_row(const tsk_mutation_t *mutation) { PyObject *ret = NULL; PyObject *metadata = NULL; @@ -339,6 +339,24 @@ make_mutation(const tsk_mutation_t *mutation) return ret; } +static PyObject * +make_mutation_object(const tsk_mutation_t *mutation) +{ + PyObject *ret = NULL; + PyObject *metadata = NULL; + + metadata = make_metadata(mutation->metadata, (Py_ssize_t) mutation->metadata_length); + if (metadata == NULL) { + goto out; + } + ret = Py_BuildValue("iis#iOdi", mutation->site, mutation->node, + mutation->derived_state, (Py_ssize_t) mutation->derived_state_length, + mutation->parent, metadata, mutation->time, mutation->edge); +out: + Py_XDECREF(metadata); + return ret; +} + static PyObject * make_mutation_id_list(const tsk_mutation_t *mutations, tsk_size_t length) { @@ -4016,7 +4034,7 @@ MutationTable_get_row(MutationTable *self, PyObject *args) handle_library_error(err); goto out; } - ret = make_mutation(&mutation); + ret = make_mutation_row(&mutation); out: return ret; } @@ -7984,7 +8002,7 @@ TreeSequence_get_mutation(TreeSequence *self, PyObject *args) handle_library_error(err); goto out; } 
- ret = make_mutation(&record); + ret = make_mutation_object(&record); out: return ret; } diff --git a/python/tests/__init__.py b/python/tests/__init__.py index f5559979f7..bb4151d978 100644 --- a/python/tests/__init__.py +++ b/python/tests/__init__.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2021 Tskit Developers +# Copyright (c) 2018-2022 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -138,7 +138,15 @@ def __init__(self, tree_sequence, breakpoints=None): ll_ts = self._tree_sequence._ll_tree_sequence def make_mutation(id_): - site, node, derived_state, parent, metadata, time = ll_ts.get_mutation(id_) + ( + site, + node, + derived_state, + parent, + metadata, + time, + edge, + ) = ll_ts.get_mutation(id_) return tskit.Mutation( id=id_, site=site, @@ -147,6 +155,7 @@ def make_mutation(id_): derived_state=derived_state, parent=parent, metadata=metadata, + edge=edge, metadata_decoder=tskit.metadata.parse_metadata_schema( ll_ts.get_table_metadata_schemas().mutation ).decode_row, diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 2a62a1ff67..02caa178d4 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1215,22 +1215,20 @@ def test_dump_equality(self, tmp_path): ts2.dump_tables(tc2) assert tc.equals(tc2) - def verify_mutations(self, ts): - mutations = [ts.get_mutation(j) for j in range(ts.get_num_mutations())] - assert ts.get_num_mutations() > 0 - assert len(mutations) == ts.get_num_mutations() - # Check the form of the mutations - for j, (position, nodes, index) in enumerate(mutations): - assert j == index - for node in nodes: + def test_get_mutation_interface(self): + for ts in self.get_example_tree_sequences(): + mutations = [ts.get_mutation(j) for j in range(ts.get_num_mutations())] + assert len(mutations) == ts.get_num_mutations() + # Check the form of the 
mutations + for packed in mutations: + site, node, derived_state, parent, metadata, time, edge = packed + assert isinstance(site, int) assert isinstance(node, int) - assert node >= 0 - assert node <= ts.get_num_nodes() - assert isinstance(position, float) - assert position > 0 - assert position < ts.get_sequence_length() - # mutations must be sorted by position order. - assert mutations == sorted(mutations) + assert isinstance(derived_state, str) + assert isinstance(parent, int) + assert isinstance(metadata, bytes) + assert isinstance(time, float) + assert isinstance(edge, int) def test_get_edge_interface(self): for ts in self.get_example_tree_sequences(): @@ -2718,12 +2716,13 @@ def test_sites(self): all_tree_sites.extend(tree_sites) for ( position, - _ancestral_state, + ancestral_state, mutations, index, metadata, ) in tree_sites: assert st.get_left() <= position < st.get_right() + assert isinstance(ancestral_state, str) assert index == j assert metadata == b"" for mut_id in mutations: @@ -2734,11 +2733,13 @@ def test_sites(self): parent, metadata, time, + edge, ) = ts.get_mutation(mut_id) assert site == index assert mutation_id == mut_id assert st.get_parent(node) != _tskit.NULL assert metadata == b"" + assert edge != _tskit.NULL mutation_id += 1 j += 1 assert all_tree_sites == all_sites diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index fc15494d68..af0f315215 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -6201,6 +6201,74 @@ def test_many_multiroot_trees_recurrent_mutations(self): self.verify_branch_mutations(ts, mutations_per_branch) +class TestMutationEdge: + def verify_mutation_edge(self, ts): + # print(ts.tables) + for mutation in ts.mutations(): + site = ts.site(mutation.site) + if mutation.edge == tskit.NULL: + edges = [ + edge + for edge in ts.edges() + if edge.left <= site.position < edge.right + and mutation.node == edge.child + ] + assert len(edges) == 0 + else: + edge = 
ts.edge(mutation.edge) + assert edge.left <= site.position < edge.right + assert edge.child == mutation.node + + for tree in ts.trees(): + for site in tree.sites(): + for mutation in site.mutations: + assert mutation.edge == ts.mutation(mutation.id).edge + if mutation.edge == tskit.NULL: + assert tree.parent(mutation.node) == tskit.NULL + + def verify_branch_mutations(self, ts, mutations_per_branch): + ts = tsutil.insert_branch_mutations(ts, mutations_per_branch) + assert ts.num_mutations > 1 + self.verify_mutation_edge(ts) + + def test_single_tree_one_mutation_per_branch(self): + ts = msprime.simulate(6, random_seed=10) + self.verify_branch_mutations(ts, 1) + + def test_single_tree_two_mutations_per_branch(self): + ts = msprime.simulate(10, random_seed=9) + self.verify_branch_mutations(ts, 2) + + def test_single_tree_three_mutations_per_branch(self): + ts = msprime.simulate(8, random_seed=9) + self.verify_branch_mutations(ts, 3) + + def test_single_multiroot_tree_recurrent_mutations(self): + ts = msprime.simulate(6, random_seed=10) + ts = tsutil.decapitate(ts, ts.num_edges // 2) + for mutations_per_branch in [1, 2, 3]: + self.verify_branch_mutations(ts, mutations_per_branch) + + def test_many_multiroot_trees_recurrent_mutations(self): + ts = msprime.simulate(7, recombination_rate=1, random_seed=10) + assert ts.num_trees > 3 + ts = tsutil.decapitate(ts, ts.num_edges // 2) + for mutations_per_branch in [1, 2, 3]: + self.verify_branch_mutations(ts, mutations_per_branch) + + @pytest.mark.parametrize("n", range(2, 5)) + @pytest.mark.parametrize("mutations_per_branch", range(3)) + def test_balanced_binary_tree(self, n, mutations_per_branch): + ts = tskit.Tree.generate_balanced(4).tree_sequence + # These trees have a handy property + assert all(edge.id == edge.child for edge in ts.edges()) + for mutation in ts.mutations(): + assert mutation.edge == mutation.node + for site in ts.first().sites(): + for mutation in site.mutations: + assert mutation.edge == mutation.node + 
+ class TestMutationTime: """ Tests that mutation time is correctly specified, and that we correctly diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 2282b8f541..11b83d6bdb 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -319,7 +319,16 @@ class Mutation(util.Dataclass): underlying tree sequence data. """ - __slots__ = ["id", "site", "node", "derived_state", "parent", "metadata", "time"] + __slots__ = [ + "id", + "site", + "node", + "derived_state", + "parent", + "metadata", + "time", + "edge", + ] id: int # noqa A003 """ The integer ID of this mutation. Varies from 0 to @@ -363,6 +372,10 @@ class Mutation(util.Dataclass): """ The occurrence time of this mutation. """ + edge: int + """ + The ID of the edge that this mutation is on. + """ # To get default values on slots we define a custom init def __init__( @@ -374,6 +387,7 @@ def __init__( derived_state=None, parent=NULL, metadata=b"", + edge=NULL, ): self.id = id self.site = site @@ -382,6 +396,7 @@ def __init__( self.derived_state = derived_state self.parent = parent self.metadata = metadata + self.edge = edge # We need a custom eq to compare unknown times. def __eq__(self, other): @@ -392,6 +407,7 @@ def __eq__(self, other): and self.node == other.node and self.derived_state == other.derived_state and self.parent == other.parent + and self.edge == other.edge and self.metadata == other.metadata and ( self.time == other.time @@ -5179,6 +5195,7 @@ def mutation(self, id_): parent, metadata, time, + edge, ) = self._ll_tree_sequence.get_mutation(id_) return Mutation( id=id_, @@ -5188,6 +5205,7 @@ def mutation(self, id_): parent=parent, metadata=metadata, time=time, + edge=edge, metadata_decoder=self.table_metadata_schemas.mutation.decode_row, )