Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change keep/del intervals to in-place #372

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions docs/data-model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -708,13 +708,20 @@ Schema section (TODO).
Table transformation methods
============================

The following methods operate *in place* on a :class:`.TableCollection`,
transforming them while preserving information.
In general, table methods operate *in place* on a :class:`.TableCollection`,
directly altering the data stored within its constituent tables.

In some applications, tables may most naturally be produced in a way that is
logically consistent, but that does not meet all the requirements for validity that
are established for algorithmic and efficiency reasons.
These methods (while having other uses), can be used to make such a set of
tables valid, and thus ready to be loaded into a tree sequence.
are established for algorithmic and efficiency reasons. Several of the methods
below (while also having other uses) can be used to make such a set of tables
valid, and thus ready to be loaded into a tree sequence.

Some of the other methods described in this section also have an equivalent
:class:`.TreeSequence` version: an important distinction is that unlike the
methods here, :class:`.TreeSequence` methods do *not* operate in place, but
rather act in a functional way, returning a new tree sequence while leaving
the original one unchanged.

This section is best skipped unless you are writing a program that records
tables directly.
Expand Down
3 changes: 2 additions & 1 deletion python/tests/test_genotypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,8 @@ def test_jukes_cantor_n20(self):

def test_zero_edge_missing_data(self):
ts = msprime.simulate(10, random_seed=2, mutation_rate=2)
tables = ts.tables.keep_intervals([[0.25, 0.75]])
tables = ts.dump_tables()
tables.keep_intervals([[0.25, 0.75]])
# add some sites in the deleted regions
tables.sites.add_row(0.1, "A")
tables.sites.add_row(0.2, "A")
Expand Down
17 changes: 17 additions & 0 deletions python/tests/test_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,20 @@ def test_form(self):
s = provenance.get_schema()
self.assertEqual(s["schema"], "http://json-schema.org/draft-07/schema#")
self.assertEqual(s["version"], "1.0.0")


class TestTreeSeqEditMethods(unittest.TestCase):
    """
    Check that the 'edit' methods of a tree sequence record themselves
    in the provenance table.
    """
    def test_keep_delete_different(self):
        # keep_intervals and delete_intervals are called with complementary
        # intervals on the same simulated tree sequence.
        ts = msprime.simulate(5, random_seed=1)
        kept = ts.keep_intervals([[0.25, 0.5]])
        deleted = ts.delete_intervals([[0, 0.25], [0.5, 1.0]])
        self.assertEqual(kept.num_provenances, deleted.num_provenances)
        final_index = kept.num_provenances - 1
        paired_provenances = zip(kept.provenances(), deleted.provenances())
        for index, (kept_prov, deleted_prov) in enumerate(paired_provenances):
            if index != final_index:
                # Everything before the final entry must be identical.
                self.assertEqual(kept_prov.record, deleted_prov.record)
            else:
                # Only the final record is expected to differ between the two.
                self.assertNotEqual(kept_prov.record, deleted_prov.record)
100 changes: 60 additions & 40 deletions python/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,30 @@
import tests.test_wright_fisher as wf


def ts_equal(ts_1, ts_2, compare_provenances=True):
    """
    Return whether the tables of two tree sequences compare equal under
    tables_equal, i.e. ignoring provenance timestamps (but not contents).
    """
    first_tables = ts_1.tables
    second_tables = ts_2.tables
    return tables_equal(first_tables, second_tables, compare_provenances)


def tables_equal(table_collection_1, table_collection_2, compare_provenances=True):
    """
    Check equality of tables, ignoring provenance timestamps (but not contents).

    Both collections are iterated in lockstep as pairs whose second element
    is a table. Provenance tables are compared only through their ``record``
    and ``record_offset`` columns, and only when ``compare_provenances`` is
    true; every other table is compared with ``!=``.

    :param table_collection_1: First collection of tables to compare.
    :param table_collection_2: Second collection of tables to compare.
    :param bool compare_provenances: If False, provenance tables are skipped
        entirely.
    :return: True if no compared table differs, False otherwise.
    """
    for (_, table_1), (_, table_2) in zip(table_collection_1, table_collection_2):
        if isinstance(table_1, tskit.ProvenanceTable):
            if not compare_provenances:
                continue
            # np.array_equal returns False on a shape mismatch, whereas an
            # elementwise ``!=`` on arrays of different lengths warns or
            # raises depending on the NumPy version.
            if not np.array_equal(table_1.record, table_2.record):
                return False
            if not np.array_equal(table_1.record_offset, table_2.record_offset):
                return False
        elif table_1 != table_2:
            return False
    return True


def simple_keep_intervals(tables, intervals, simplify=True, record_provenance=True):
"""
Simple Python implementation of keep_intervals.
Expand All @@ -54,7 +78,6 @@ def simple_keep_intervals(tables, intervals, simplify=True, record_provenance=Tr
if start < last_stop:
raise ValueError("Intervals must be disjoint")
last_stop = stop
tables = tables.copy()
tables.edges.clear()
tables.sites.clear()
tables.mutations.clear()
Expand Down Expand Up @@ -83,7 +106,6 @@ def simple_keep_intervals(tables, intervals, simplify=True, record_provenance=Tr
}
tables.provenances.add_row(record=json.dumps(
provenance.get_provenance_dict(parameters)))
return tables


def generate_segments(n, sequence_length=100, seed=None):
Expand Down Expand Up @@ -4566,8 +4588,7 @@ def test_slice_by_tree_positions(self):
breakpoints = list(ts.breakpoints())

# Keep the last 3 trees (from 4th last breakpoint onwards)
ts_sliced = ts.tables.keep_intervals(
[[breakpoints[-4], ts.sequence_length]]).tree_sequence()
ts_sliced = ts.keep_intervals([[breakpoints[-4], ts.sequence_length]])
self.assertEqual(ts_sliced.num_trees, 4)
self.assertLess(ts_sliced.num_edges, ts.num_edges)
self.assertAlmostEqual(ts_sliced.sequence_length, 1.0)
Expand All @@ -4577,8 +4598,7 @@ def test_slice_by_tree_positions(self):
self.assertEqual(ts_sliced.num_mutations, last_3_mutations)

# Keep the first 3 trees
ts_sliced = ts.tables.keep_intervals(
[[0, breakpoints[3]]]).tree_sequence()
ts_sliced = ts.keep_intervals([[0, breakpoints[3]]])
self.assertEqual(ts_sliced.num_trees, 4)
self.assertLess(ts_sliced.num_edges, ts.num_edges)
self.assertAlmostEqual(ts_sliced.sequence_length, 1)
Expand All @@ -4588,8 +4608,7 @@ def test_slice_by_tree_positions(self):
self.assertEqual(ts_sliced.num_mutations, first_3_mutations)

# Slice out the middle
ts_sliced = ts.tables.keep_intervals(
[[breakpoints[3], breakpoints[-4]]]).tree_sequence()
ts_sliced = ts.keep_intervals([[breakpoints[3], breakpoints[-4]]])
self.assertEqual(ts_sliced.num_trees, ts.num_trees - 4)
self.assertLess(ts_sliced.num_edges, ts.num_edges)
self.assertAlmostEqual(ts_sliced.sequence_length, 1.0)
Expand All @@ -4599,24 +4618,23 @@ def test_slice_by_tree_positions(self):

def test_slice_by_position(self):
ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2)
ts_sliced = ts.tables.keep_intervals([[0.4, 0.6]]).tree_sequence()
ts_sliced = ts.keep_intervals([[0.4, 0.6]])
positions = ts.tables.sites.position
self.assertEqual(
ts_sliced.num_sites, np.sum((positions >= 0.4) & (positions < 0.6)))

def test_slice_unsimplified(self):
ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2)
ts_sliced = ts.tables.keep_intervals([[0.4, 0.6]], simplify=True).tree_sequence()
ts_sliced = ts.keep_intervals([[0.4, 0.6]], simplify=True)
self.assertNotEqual(ts.num_nodes, ts_sliced.num_nodes)
self.assertAlmostEqual(ts_sliced.sequence_length, 1.0)
ts_sliced = ts.tables.keep_intervals(
[[0.4, 0.6]], simplify=False).tree_sequence()
ts_sliced = ts.keep_intervals([[0.4, 0.6]], simplify=False)
self.assertEqual(ts.num_nodes, ts_sliced.num_nodes)
self.assertAlmostEqual(ts_sliced.sequence_length, 1.0)

def test_slice_coordinates(self):
ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2)
ts_sliced = ts.tables.keep_intervals([[0.4, 0.6]]).tree_sequence()
ts_sliced = ts.keep_intervals([[0.4, 0.6]])
self.assertAlmostEqual(ts_sliced.sequence_length, 1)
self.assertNotEqual(ts_sliced.num_trees, ts.num_trees)
self.assertEqual(ts_sliced.at_index(0).total_branch_length, 0)
Expand Down Expand Up @@ -4647,16 +4665,12 @@ def example_intervals(self, tables):

def do_keep_intervals(
self, tables, intervals, simplify=True, record_provenance=True):
t1 = simple_keep_intervals(tables, intervals, simplify, record_provenance)
t2 = tables.keep_intervals(intervals, simplify, record_provenance)
t3 = t2.copy()
self.assertEqual(len(t1.provenances), len(t2.provenances))
# Provenances may differ using timestamps, so ignore them
# (this is a hack, as we prob want to compare their contents)
t1.provenances.clear()
t2.provenances.clear()
self.assertEqual(t1, t2)
return t3
t1 = tables.copy()
simple_keep_intervals(t1, intervals, simplify, record_provenance)
t2 = tables.copy()
t2.keep_intervals(intervals, simplify, record_provenance)
self.assertTrue(tables_equal(t1, t2))
return t2

def test_migration_error(self):
tables = tskit.TableCollection(1)
Expand Down Expand Up @@ -4714,15 +4728,6 @@ def test_hundred_intervals(self):
for rec_prov in (True, False):
self.do_keep_intervals(tables, intervals, simplify, rec_prov)

def test_read_only(self):
# tables.keep_intervals should not alter the source tables
ts = msprime.simulate(10, random_seed=4, recombination_rate=2, mutation_rate=2)
source_tables = ts.tables
source_tables.keep_intervals([(0.5, 0.511)])
self.assertEqual(source_tables, ts.dump_tables())
source_tables.keep_intervals([(0.5, 0.511)], simplify=False)
self.assertEqual(source_tables, ts.dump_tables())

def test_regular_intervals(self):
ts = msprime.simulate(
3, random_seed=1234, recombination_rate=2, mutation_rate=2)
Expand Down Expand Up @@ -4829,20 +4834,35 @@ class TestKeepDeleteIntervalsExamples(unittest.TestCase):
Simple examples of keep/delete intervals at work.
"""

def test_single_tree_keep_middle(self):
def test_tables_single_tree_keep_middle(self):
ts = msprime.simulate(10, random_seed=2)
tables = ts.tables
t_keep = tables.keep_intervals([[0.25, 0.5]])
t_delete = tables.delete_intervals([[0, 0.25], [0.5, 1.0]])
t_keep = ts.dump_tables()
t_keep.keep_intervals([[0.25, 0.5]], record_provenance=False)
t_delete = ts.dump_tables()
t_delete.delete_intervals([[0, 0.25], [0.5, 1.0]], record_provenance=False)
t_keep.provenances.clear()
t_delete.provenances.clear()
self.assertEqual(t_keep, t_delete)

def test_single_tree_delete_middle(self):
def test_tables_single_tree_delete_middle(self):
ts = msprime.simulate(10, random_seed=2)
tables = ts.tables
t_keep = tables.delete_intervals([[0.25, 0.5]])
t_delete = tables.keep_intervals([[0, 0.25], [0.5, 1.0]])
t_keep = ts.dump_tables()
t_keep.delete_intervals([[0.25, 0.5]], record_provenance=False)
t_delete = ts.dump_tables()
t_delete.keep_intervals([[0, 0.25], [0.5, 1.0]], record_provenance=False)
t_keep.provenances.clear()
t_delete.provenances.clear()
self.assertEqual(t_keep, t_delete)

def test_ts_single_tree_keep_middle(self):
    # Keeping [0.25, 0.5) and deleting its complement on the same tree
    # sequence are asserted to give equal results.
    source = msprime.simulate(10, random_seed=2)
    kept = source.keep_intervals([[0.25, 0.5]], record_provenance=False)
    deleted = source.delete_intervals(
        [[0, 0.25], [0.5, 1.0]], record_provenance=False)
    self.assertTrue(ts_equal(kept, deleted))

def test_ts_single_tree_delete_middle(self):
    # Deleting [0.25, 0.5) and keeping its complement on the same tree
    # sequence are asserted to give equal results.
    source = msprime.simulate(10, random_seed=2)
    deleted = source.delete_intervals([[0.25, 0.5]], record_provenance=False)
    kept = source.keep_intervals(
        [[0, 0.25], [0.5, 1.0]], record_provenance=False)
    # Both calls pass record_provenance=False, so presumably neither adds a
    # provenance entry; ts_equal compares provenances by default.
    self.assertTrue(ts_equal(deleted, kept))
Loading