Skip to content

Commit

Permalink
Add Python interface for Mutation.edge
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher authored and mergify[bot] committed May 20, 2022
1 parent 856cc6a commit 6736cef
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 22 deletions.
4 changes: 4 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
- ``tree.mrca`` now takes 2 or more arguments and gives the common ancestor of them all.
(:user:`savitakartik`, :issue:`1340`, :pr:`2121`)

- Add an ``edge`` attribute to the ``Mutation`` class that gives the ID of the
edge that the mutation falls on.
(:user:`jeromekelleher`, :issue:`685`, :pr:`2279`)

**Breaking Changes**

- The JSON metadata codec now interprets the empty string as an empty object. This means
Expand Down
24 changes: 21 additions & 3 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ make_metadata(const char *metadata, Py_ssize_t length)
}

static PyObject *
make_mutation(const tsk_mutation_t *mutation)
make_mutation_row(const tsk_mutation_t *mutation)
{
PyObject *ret = NULL;
PyObject *metadata = NULL;
Expand All @@ -339,6 +339,24 @@ make_mutation(const tsk_mutation_t *mutation)
return ret;
}

/*
 * Build the Python tuple for a mutation that includes its edge:
 * (site, node, derived_state, parent, metadata, time, edge).
 *
 * Returns a new reference, or NULL if either make_metadata or
 * Py_BuildValue returns NULL.
 */
static PyObject *
make_mutation_object(const tsk_mutation_t *mutation)
{
    PyObject *result = NULL;
    PyObject *metadata_obj
        = make_metadata(mutation->metadata, (Py_ssize_t) mutation->metadata_length);

    if (metadata_obj != NULL) {
        /* The "O" format takes its own reference to metadata_obj, so the
         * local reference is released straight after building the tuple. */
        result = Py_BuildValue("iis#iOdi", mutation->site, mutation->node,
            mutation->derived_state, (Py_ssize_t) mutation->derived_state_length,
            mutation->parent, metadata_obj, mutation->time, mutation->edge);
        Py_DECREF(metadata_obj);
    }
    return result;
}

static PyObject *
make_mutation_id_list(const tsk_mutation_t *mutations, tsk_size_t length)
{
Expand Down Expand Up @@ -4016,7 +4034,7 @@ MutationTable_get_row(MutationTable *self, PyObject *args)
handle_library_error(err);
goto out;
}
ret = make_mutation(&mutation);
ret = make_mutation_row(&mutation);
out:
return ret;
}
Expand Down Expand Up @@ -7984,7 +8002,7 @@ TreeSequence_get_mutation(TreeSequence *self, PyObject *args)
handle_library_error(err);
goto out;
}
ret = make_mutation(&record);
ret = make_mutation_object(&record);
out:
return ret;
}
Expand Down
13 changes: 11 additions & 2 deletions python/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2018-2021 Tskit Developers
# Copyright (c) 2018-2022 Tskit Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -138,7 +138,15 @@ def __init__(self, tree_sequence, breakpoints=None):
ll_ts = self._tree_sequence._ll_tree_sequence

def make_mutation(id_):
site, node, derived_state, parent, metadata, time = ll_ts.get_mutation(id_)
(
site,
node,
derived_state,
parent,
metadata,
time,
edge,
) = ll_ts.get_mutation(id_)
return tskit.Mutation(
id=id_,
site=site,
Expand All @@ -147,6 +155,7 @@ def make_mutation(id_):
derived_state=derived_state,
parent=parent,
metadata=metadata,
edge=edge,
metadata_decoder=tskit.metadata.parse_metadata_schema(
ll_ts.get_table_metadata_schemas().mutation
).decode_row,
Expand Down
33 changes: 17 additions & 16 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,22 +1215,20 @@ def test_dump_equality(self, tmp_path):
ts2.dump_tables(tc2)
assert tc.equals(tc2)

def verify_mutations(self, ts):
mutations = [ts.get_mutation(j) for j in range(ts.get_num_mutations())]
assert ts.get_num_mutations() > 0
assert len(mutations) == ts.get_num_mutations()
# Check the form of the mutations
for j, (position, nodes, index) in enumerate(mutations):
assert j == index
for node in nodes:
def test_get_mutation_interface(self):
for ts in self.get_example_tree_sequences():
mutations = [ts.get_mutation(j) for j in range(ts.get_num_mutations())]
assert len(mutations) == ts.get_num_mutations()
# Check the form of the mutations
for packed in mutations:
site, node, derived_state, parent, metadata, time, edge = packed
assert isinstance(site, int)
assert isinstance(node, int)
assert node >= 0
assert node <= ts.get_num_nodes()
assert isinstance(position, float)
assert position > 0
assert position < ts.get_sequence_length()
# mutations must be sorted by position order.
assert mutations == sorted(mutations)
assert isinstance(derived_state, str)
assert isinstance(parent, int)
assert isinstance(metadata, bytes)
assert isinstance(time, float)
assert isinstance(edge, int)

def test_get_edge_interface(self):
for ts in self.get_example_tree_sequences():
Expand Down Expand Up @@ -2718,12 +2716,13 @@ def test_sites(self):
all_tree_sites.extend(tree_sites)
for (
position,
_ancestral_state,
ancestral_state,
mutations,
index,
metadata,
) in tree_sites:
assert st.get_left() <= position < st.get_right()
assert isinstance(ancestral_state, str)
assert index == j
assert metadata == b""
for mut_id in mutations:
Expand All @@ -2734,11 +2733,13 @@ def test_sites(self):
parent,
metadata,
time,
edge,
) = ts.get_mutation(mut_id)
assert site == index
assert mutation_id == mut_id
assert st.get_parent(node) != _tskit.NULL
assert metadata == b""
assert edge != _tskit.NULL
mutation_id += 1
j += 1
assert all_tree_sites == all_sites
Expand Down
68 changes: 68 additions & 0 deletions python/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -6201,6 +6201,74 @@ def test_many_multiroot_trees_recurrent_mutations(self):
self.verify_branch_mutations(ts, mutations_per_branch)


class TestMutationEdge:
    """
    Tests that the ``edge`` attribute of a mutation identifies the edge that
    the mutation falls on, or is ``tskit.NULL`` when there is no such edge.
    """

    def verify_mutation_edge(self, ts):
        """
        Check that each mutation's ``edge`` agrees with the edge table and
        with the mutations reported when iterating over the trees.
        """
        for mutation in ts.mutations():
            site = ts.site(mutation.site)
            if mutation.edge == tskit.NULL:
                # With no edge recorded, the edge table must not contain any
                # edge covering the site that has this node as its child.
                edges = [
                    edge
                    for edge in ts.edges()
                    if edge.left <= site.position < edge.right
                    and mutation.node == edge.child
                ]
                assert len(edges) == 0
            else:
                edge = ts.edge(mutation.edge)
                assert edge.left <= site.position < edge.right
                assert edge.child == mutation.node

        # Mutations reached via the trees must report the same edge as
        # those fetched directly by ID.
        for tree in ts.trees():
            for site in tree.sites():
                for mutation in site.mutations:
                    assert mutation.edge == ts.mutation(mutation.id).edge
                    if mutation.edge == tskit.NULL:
                        assert tree.parent(mutation.node) == tskit.NULL

    def verify_branch_mutations(self, ts, mutations_per_branch):
        """
        Insert ``mutations_per_branch`` mutations per branch using
        ``tsutil.insert_branch_mutations`` and verify the resulting edges.
        """
        ts = tsutil.insert_branch_mutations(ts, mutations_per_branch)
        assert ts.num_mutations > 1
        self.verify_mutation_edge(ts)

    def test_single_tree_one_mutation_per_branch(self):
        ts = msprime.simulate(6, random_seed=10)
        self.verify_branch_mutations(ts, 1)

    def test_single_tree_two_mutations_per_branch(self):
        ts = msprime.simulate(10, random_seed=9)
        self.verify_branch_mutations(ts, 2)

    def test_single_tree_three_mutations_per_branch(self):
        ts = msprime.simulate(8, random_seed=9)
        self.verify_branch_mutations(ts, 3)

    def test_single_multiroot_tree_recurrent_mutations(self):
        ts = msprime.simulate(6, random_seed=10)
        ts = tsutil.decapitate(ts, ts.num_edges // 2)
        for mutations_per_branch in [1, 2, 3]:
            self.verify_branch_mutations(ts, mutations_per_branch)

    def test_many_multiroot_trees_recurrent_mutations(self):
        ts = msprime.simulate(7, recombination_rate=1, random_seed=10)
        assert ts.num_trees > 3
        ts = tsutil.decapitate(ts, ts.num_edges // 2)
        for mutations_per_branch in [1, 2, 3]:
            self.verify_branch_mutations(ts, mutations_per_branch)

    @pytest.mark.parametrize("n", range(2, 5))
    @pytest.mark.parametrize("mutations_per_branch", range(3))
    def test_balanced_binary_tree(self, n, mutations_per_branch):
        """
        Check mutation edges on a balanced tree with ``n`` leaves against an
        expected mapping derived directly from the edge table.
        """
        # Previously the tree size was hard-coded and no mutations were
        # inserted, so the parameters were unused and the assertion loops
        # never executed.
        ts = tskit.Tree.generate_balanced(n).tree_sequence
        ts = tsutil.insert_branch_mutations(ts, mutations_per_branch)
        if mutations_per_branch > 0:
            assert ts.num_mutations > 0
        # With a single tree each child node appears in at most one edge, so
        # the edge above a node can be looked up by child. The lookup does
        # not rely on edge IDs coinciding with child node IDs, which need
        # not hold for every tree size.
        assert ts.num_trees == 1
        edge_above = {edge.child: edge.id for edge in ts.edges()}
        for mutation in ts.mutations():
            # Nodes with no edge above them (e.g. the root) map to NULL.
            assert mutation.edge == edge_above.get(mutation.node, tskit.NULL)
        for site in ts.first().sites():
            for mutation in site.mutations:
                assert mutation.edge == edge_above.get(mutation.node, tskit.NULL)


class TestMutationTime:
"""
Tests that mutation time is correctly specified, and that we correctly
Expand Down
20 changes: 19 additions & 1 deletion python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,16 @@ class Mutation(util.Dataclass):
underlying tree sequence data.
"""

__slots__ = ["id", "site", "node", "derived_state", "parent", "metadata", "time"]
__slots__ = [
"id",
"site",
"node",
"derived_state",
"parent",
"metadata",
"time",
"edge",
]
id: int # noqa A003
"""
The integer ID of this mutation. Varies from 0 to
Expand Down Expand Up @@ -363,6 +372,10 @@ class Mutation(util.Dataclass):
"""
The occurrence time of this mutation.
"""
edge: int
"""
The ID of the edge that this mutation is on.
"""

# To get default values on slots we define a custom init
def __init__(
Expand All @@ -374,6 +387,7 @@ def __init__(
derived_state=None,
parent=NULL,
metadata=b"",
edge=NULL,
):
self.id = id
self.site = site
Expand All @@ -382,6 +396,7 @@ def __init__(
self.derived_state = derived_state
self.parent = parent
self.metadata = metadata
self.edge = edge

# We need a custom eq to compare unknown times.
def __eq__(self, other):
Expand All @@ -392,6 +407,7 @@ def __eq__(self, other):
and self.node == other.node
and self.derived_state == other.derived_state
and self.parent == other.parent
and self.edge == other.edge
and self.metadata == other.metadata
and (
self.time == other.time
Expand Down Expand Up @@ -5179,6 +5195,7 @@ def mutation(self, id_):
parent,
metadata,
time,
edge,
) = self._ll_tree_sequence.get_mutation(id_)
return Mutation(
id=id_,
Expand All @@ -5188,6 +5205,7 @@ def mutation(self, id_):
parent=parent,
metadata=metadata,
time=time,
edge=edge,
metadata_decoder=self.table_metadata_schemas.mutation.decode_row,
)

Expand Down

0 comments on commit 6736cef

Please sign in to comment.