Skip to content

Commit

Permalink
Move interval and packing code to util.py
Browse files Browse the repository at this point in the history
Improve testing for keep/delete intervals and document.
  • Loading branch information
jeromekelleher committed Aug 10, 2019
1 parent dab4068 commit 95b7726
Show file tree
Hide file tree
Showing 6 changed files with 401 additions and 259 deletions.
53 changes: 0 additions & 53 deletions python/tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1895,56 +1895,3 @@ def test_asdict_not_implemented(self):
t = tskit.BaseTable(None, None)
with self.assertRaises(NotImplementedError):
t.asdict()


class TestIntervalOps(unittest.TestCase):
    """
    Checks for the interval helpers (``tskit.intervals_to_np_array`` and
    ``tskit.negate_intervals``) used in masks and slicing operations.
    """
    def test_bad_intervals(self):
        """Invalid interval specifications raise TypeError or ValueError."""
        to_array = tskit.intervals_to_np_array

        # Inputs of the wrong type.
        for wrong_type in ({}, Exception):
            self.assertRaises(TypeError, to_array, wrong_type, 0, 1)
        # Input nested more deeply than a list of pairs.
        self.assertRaises(ValueError, to_array, [[[]]], 0, 1)
        # Inputs whose rows are not (left, right) pairs.
        for wrong_shape in ([[0], [0]], [[[0, 1, 2], [0, 1]]]):
            self.assertRaises(ValueError, to_array, wrong_shape, 0, 1)

        # Each row is (intervals, start, end); all must be rejected.
        rejected = (
            # Out of bounds
            ([[-1, 0]], 0, 1),
            ([[0, 1]], 1, 2),
            ([[0, 1]], 0, 0.5),
            # Overlapping intervals
            ([[0, 1], [0.9, 2.0]], 0, 10),
            # Empty intervals
            ([[0, 0]], 0, 10),
            ([[1, 0]], 0, 10),
        )
        for spec, start, end in rejected:
            self.assertRaises(ValueError, to_array, spec, start, end)

    def test_empty_interval_list(self):
        """An empty list of intervals converts to a zero-length result."""
        self.assertEqual(len(tskit.intervals_to_np_array([], 0, 10)), 0)

    def test_negate_intervals(self):
        """negate_intervals over [0, L] returns the listed complements."""
        L = 10
        # Pairs of (input intervals, expected negation).
        expectations = (
            ([], [[0, L]]),
            ([[0, 5], [6, L]], [[5, 6]]),
            ([[0, 5]], [[5, L]]),
            ([[5, L]], [[0, 5]]),
            (
                [[0, 1], [2, 3], [3, 4], [5, 6]],
                [[1, 2], [4, 5], [6, L]],
            ),
        )
        for given, expected in expectations:
            negated = tskit.negate_intervals(given, 0, L)
            self.assertTrue(np.array_equal(negated, expected))
208 changes: 156 additions & 52 deletions python/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def simple_keep_intervals(tables, intervals, simplify=True, record_provenance=Tr
site.position, site.ancestral_state, site.metadata)
for m in site.mutations:
tables.mutations.add_row(
site_id, m.node, m.derived_state, m.parent, m.metadata)
site_id, m.node, m.derived_state, tskit.NULL, m.metadata)
tables.build_index()
tables.compute_mutation_parents()
if simplify:
tables.simplify()
if record_provenance:
Expand Down Expand Up @@ -4357,8 +4359,7 @@ def test_edge_cases(self):
self.assertEqual(search_sorted([1], v), np.searchsorted([1], v))


@unittest.skip("TODO refactor these tests")
class TestSlice(TopologyTestCase):
class TestKeepSingleInterval(unittest.TestCase):
"""
Tests for cutting up tree sequences along the genome.
"""
Expand All @@ -4367,61 +4368,56 @@ def test_slice_by_tree_positions(self):
breakpoints = list(ts.breakpoints())

# Keep the last 3 trees (from 4th last breakpoint onwards)
ts_sliced = ts.tables.slice(start=breakpoints[-4]).tree_sequence()
self.assertEqual(ts_sliced.num_trees, 3)
ts_sliced = ts.tables.keep_intervals(
[[breakpoints[-4], ts.sequence_length]]).tree_sequence()
self.assertEqual(ts_sliced.num_trees, 4)
self.assertLess(ts_sliced.num_edges, ts.num_edges)
self.assertAlmostEqual(ts_sliced.sequence_length, 1.0 - breakpoints[-4])
self.assertAlmostEqual(ts_sliced.sequence_length, 1.0)
last_3_mutations = 0
for tree_index in range(-3, 0):
last_3_mutations += ts.at_index(tree_index).num_mutations
self.assertEqual(ts_sliced.num_mutations, last_3_mutations)

# Keep the first 3 trees
ts_sliced = ts.tables.slice(stop=breakpoints[3]).tree_sequence()
self.assertEqual(ts_sliced.num_trees, 3)
ts_sliced = ts.tables.keep_intervals(
[[0, breakpoints[3]]]).tree_sequence()
self.assertEqual(ts_sliced.num_trees, 4)
self.assertLess(ts_sliced.num_edges, ts.num_edges)
self.assertAlmostEqual(ts_sliced.sequence_length, breakpoints[3])
self.assertAlmostEqual(ts_sliced.sequence_length, 1)
first_3_mutations = 0
for tree_index in range(0, 3):
first_3_mutations += ts.at_index(tree_index).num_mutations
self.assertEqual(ts_sliced.num_mutations, first_3_mutations)

# Slice out the middle
ts_sliced = ts.tables.slice(breakpoints[3], breakpoints[-4]).tree_sequence()
self.assertEqual(ts_sliced.num_trees, ts.num_trees - 6)
ts_sliced = ts.tables.keep_intervals(
[[breakpoints[3], breakpoints[-4]]]).tree_sequence()
self.assertEqual(ts_sliced.num_trees, ts.num_trees - 4)
self.assertLess(ts_sliced.num_edges, ts.num_edges)
self.assertAlmostEqual(
ts_sliced.sequence_length, breakpoints[-4] - breakpoints[3])
self.assertAlmostEqual(ts_sliced.sequence_length, 1.0)
self.assertEqual(
ts_sliced.num_mutations,
ts.num_mutations - first_3_mutations - last_3_mutations)

def test_slice_by_position(self):
ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2)
ts_sliced = ts.tables.slice(0.4, 0.6).tree_sequence()
ts_sliced = ts.tables.keep_intervals([[0.4, 0.6]]).tree_sequence()
positions = ts.tables.sites.position
self.assertEqual(
ts_sliced.num_sites, np.sum((positions >= 0.4) & (positions < 0.6)))

def test_slice_bounds(self):
    """tables.slice raises ValueError for each of three bad start/stop inputs."""
    tables = msprime.simulate(
        5, random_seed=1, recombination_rate=2, mutation_rate=2).tables
    with self.assertRaises(ValueError):
        tables.slice(-1)
    with self.assertRaises(ValueError):
        tables.slice(stop=2)
    with self.assertRaises(ValueError):
        tables.slice(0.8, 0.2)

def test_slice_unsimplified(self):
ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2)
ts_sliced = ts.tables.slice(0.4, 0.6, simplify=True).tree_sequence()
ts_sliced = ts.tables.keep_intervals([[0.4, 0.6]], simplify=True).tree_sequence()
self.assertNotEqual(ts.num_nodes, ts_sliced.num_nodes)
self.assertAlmostEqual(ts_sliced.sequence_length, 0.2)
ts_sliced = ts.tables.slice(0.4, 0.6, simplify=False).tree_sequence()
self.assertAlmostEqual(ts_sliced.sequence_length, 1.0)
ts_sliced = ts.tables.keep_intervals([[0.4, 0.6]], simplify=False).tree_sequence()
self.assertEqual(ts.num_nodes, ts_sliced.num_nodes)
self.assertAlmostEqual(ts_sliced.sequence_length, 0.2)
self.assertAlmostEqual(ts_sliced.sequence_length, 1.0)

def test_slice_keep_coordinates(self):
def test_slice_coordinates(self):
ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2)
ts_sliced = ts.tables.slice(0.4, 0.6, reset_coordinates=False).tree_sequence()
ts_sliced = ts.tables.keep_intervals([[0.4, 0.6]]).tree_sequence()
self.assertAlmostEqual(ts_sliced.sequence_length, 1)
self.assertNotEqual(ts_sliced.num_trees, ts.num_trees)
self.assertEqual(ts_sliced.at_index(0).total_branch_length, 0)
Expand All @@ -4442,7 +4438,13 @@ class TestKeepIntervals(TopologyTestCase):
"""
def example_intervals(self, tables):
    """
    Yield a fixed series of interval lists, each a list of (left, right)
    tuples expressed as fractions of ``tables.sequence_length``.
    """
    length = tables.sequence_length
    interval_lists = (
        [],
        [(0, length)],
        [(0, length / 2), (length / 2, length)],
        [(0, 0.25 * length), (0.75 * length, length)],
        [(0.25 * length, length)],
        [(0.25 * length, 0.5 * length)],
        [(0.25 * length, 0.5 * length), (0.75 * length, 0.8 * length)],
    )
    for interval_list in interval_lists:
        yield interval_list

def do_keep_intervals(
self, tables, intervals, simplify=True, record_provenance=True):
Expand All @@ -4454,13 +4456,29 @@ def do_keep_intervals(
# (this is a hack, as we prob want to compare their contents)
t1.provenances.clear()
t2.provenances.clear()
# print(t1.edges)
# print(t2.edges)
# print(t1.edges == t2.edges)
# print(t1.sequence_length == t2.sequence_length)
self.assertEqual(t1, t2)
return t3

def test_migration_error(self):
    """keep_intervals raises ValueError when a migration row is present."""
    with_migration = tskit.TableCollection(1)
    with_migration.migrations.add_row(0, 1, 0, 0, 0, 0)
    self.assertRaises(ValueError, with_migration.keep_intervals, [[0, 1]])

def test_bad_intervals(self):
    """Both keep_intervals and delete_intervals reject each bad interval list."""
    tables = tskit.TableCollection(10)
    rejected = (
        [[1, 1]],
        [[-1, 0]],
        [[0, 11]],
        [[0, 5], [4, 6]],
    )
    for interval_list in rejected:
        self.assertRaises(ValueError, tables.keep_intervals, interval_list)
        self.assertRaises(ValueError, tables.delete_intervals, interval_list)

def test_one_interval(self):
ts = msprime.simulate(
10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2)
Expand Down Expand Up @@ -4516,30 +4534,116 @@ def test_regular_intervals(self):
intervals = [(x, x + eps) for x in breaks[:-1]]
self.do_keep_intervals(tables, intervals)

def test_empty_tables(self):
tables = tskit.TableCollection(1.0)
diced = self.do_keep_intervals(tables, [(0, 1)])
self.assertEqual(diced.sequence_length, 1)
self.assertEqual(len(diced.edges), 0)
for intervals in self.example_intervals(tables):
diced = self.do_keep_intervals(tables, [(0, 1)])
self.assertEqual(len(diced.edges), 0)
self.assertAlmostEqual(
diced.sequence_length, sum(r - l for l, r in intervals))

@unittest.skip("Sites not working - not sure why not")
def test_no_edges_sites(self):
tables = tskit.TableCollection(1.0)
tables.sites.add_row(0.1, "A")
tables.sites.add_row(0.2, "T")
diced = self.do_dice(tables, [(0, 1)])
self.assertEqual(diced.sequence_length, 1)
self.assertEqual(len(diced.edges), 0)
print(tables.sites)
print(diced.sites)
self.assertEqual(tables.sites, diced.sites)
for intervals in self.example_intervals(tables):
diced = self.do_dice(tables, [(0, 1)])
self.assertEqual(len(tables.sites), 2)
diced = self.do_keep_intervals(tables, intervals)
self.assertEqual(diced.sequence_length, 1)
self.assertEqual(len(diced.edges), 0)
self.assertAlmostEqual(
diced.sequence_length, sum(r - l for l, r in intervals))
self.assertEqual(len(diced.sites), 0)

def verify(self, tables):
    """
    Run do_keep_intervals on ``tables`` for every list produced by
    example_intervals, first with simplify=True and then simplify=False.
    """
    for interval_list in self.example_intervals(tables):
        self.do_keep_intervals(tables, interval_list, simplify=True)
        self.do_keep_intervals(tables, interval_list, simplify=False)

def test_empty_tables(self):
    """Run verify on a freshly constructed TableCollection of length 1.0."""
    self.verify(tskit.TableCollection(1.0))

def test_single_tree_jukes_cantor(self):
    """Run verify on a simulation passed through tsutil.jukes_cantor."""
    simulated = msprime.simulate(6, random_seed=1, mutation_rate=1)
    with_mutations = tsutil.jukes_cantor(simulated, 20, 1, seed=10)
    self.verify(with_mutations.tables)

def test_single_tree_multichar_mutations(self):
    """Run verify on a simulation passed through insert_multichar_mutations."""
    simulated = msprime.simulate(6, random_seed=1, mutation_rate=1)
    with_mutations = tsutil.insert_multichar_mutations(simulated)
    self.verify(with_mutations.tables)

def test_many_trees_infinite_sites(self):
    """Run verify on a simulation with recombination and mutation."""
    simulated = msprime.simulate(
        6, recombination_rate=2, mutation_rate=2, random_seed=1)
    # Make sure the example has sites and more than two trees before
    # handing it to verify.
    self.assertGreater(simulated.num_sites, 0)
    self.assertGreater(simulated.num_trees, 2)
    self.verify(simulated.tables)

def test_many_trees_sequence_length_infinite_sites(self):
    """Run verify on simulations with sequence lengths 0.5, 1.5 and 3.3333."""
    for sequence_length in (0.5, 1.5, 3.3333):
        simulated = msprime.simulate(
            6,
            length=sequence_length,
            recombination_rate=2,
            mutation_rate=1,
            random_seed=1,
        )
        self.verify(simulated.tables)

def test_wright_fisher_unsimplified(self):
    """Run verify on sorted, unsimplified wf_sim output with mutations added."""
    wf_tables = wf.wf_sim(
        4, 5, seed=1, deep_history=True, initial_generation_samples=False,
        num_loci=10)
    wf_tables.sort()
    unmutated = wf_tables.tree_sequence()
    mutated = msprime.mutate(unmutated, rate=0.05, random_seed=234)
    # The example must contain at least one site.
    self.assertGreater(mutated.num_sites, 0)
    self.verify(mutated.tables)

def test_wright_fisher_initial_generation(self):
    """Run verify on sorted, simplified wf_sim output with mutations added."""
    wf_tables = wf.wf_sim(
        6, 5, seed=3, deep_history=True, initial_generation_samples=True,
        num_loci=2)
    wf_tables.sort()
    wf_tables.simplify()
    unmutated = wf_tables.tree_sequence()
    mutated = msprime.mutate(unmutated, rate=0.08, random_seed=2)
    # The example must contain at least one site.
    self.assertGreater(mutated.num_sites, 0)
    self.verify(mutated.tables)

def test_wright_fisher_initial_generation_no_deep_history(self):
    """Run verify on simplified wf_sim output generated with deep_history=False."""
    wf_tables = wf.wf_sim(
        7, 15, seed=202, deep_history=False, initial_generation_samples=True,
        num_loci=5)
    wf_tables.sort()
    wf_tables.simplify()
    unmutated = wf_tables.tree_sequence()
    mutated = msprime.mutate(unmutated, rate=0.01, random_seed=2)
    # The example must contain at least one site.
    self.assertGreater(mutated.num_sites, 0)
    self.verify(mutated.tables)

def test_wright_fisher_unsimplified_multiple_roots(self):
    """Run verify on unsimplified wf_sim output generated with deep_history=False."""
    wf_tables = wf.wf_sim(
        8, 15, seed=1, deep_history=False, initial_generation_samples=False,
        num_loci=20)
    wf_tables.sort()
    unmutated = wf_tables.tree_sequence()
    mutated = msprime.mutate(unmutated, rate=0.006, random_seed=2)
    # The example must contain at least one site.
    self.assertGreater(mutated.num_sites, 0)
    self.verify(mutated.tables)

def test_wright_fisher_simplified(self):
    """Run verify on wf_sim output simplified as a tree sequence, then mutated."""
    wf_tables = wf.wf_sim(
        9, 10, seed=1, deep_history=True, initial_generation_samples=False,
        num_loci=5)
    wf_tables.sort()
    simplified = wf_tables.tree_sequence().simplify()
    mutated = msprime.mutate(simplified, rate=0.01, random_seed=1234)
    # The example must contain at least one site.
    self.assertGreater(mutated.num_sites, 0)
    self.verify(mutated.tables)


class TestKeepDeleteIntervalsExamples(unittest.TestCase):
    """
    Small examples checking that keep_intervals and delete_intervals give
    equal tables when called with complementary interval lists.
    """

    def test_single_tree_keep_middle(self):
        """Keeping 0.25-0.5 matches deleting 0-0.25 and 0.5-1.0."""
        tables = msprime.simulate(10, random_seed=2).tables
        kept = tables.keep_intervals([[0.25, 0.5]])
        deleted = tables.delete_intervals([[0, 0.25], [0.5, 1.0]])
        # Provenances are emptied on both results so they do not take part
        # in the equality check below.
        kept.provenances.clear()
        deleted.provenances.clear()
        self.assertEqual(kept, deleted)

    def test_single_tree_delete_middle(self):
        """Deleting 0.25-0.5 matches keeping 0-0.25 and 0.5-1.0."""
        tables = msprime.simulate(10, random_seed=2).tables
        deleted = tables.delete_intervals([[0.25, 0.5]])
        kept = tables.keep_intervals([[0, 0.25], [0.5, 1.0]])
        # Provenances are emptied on both results so they do not take part
        # in the equality check below.
        deleted.provenances.clear()
        kept.provenances.clear()
        self.assertEqual(deleted, kept)
Loading

0 comments on commit 95b7726

Please sign in to comment.