Skip to content

Commit

Permalink
Add keep_intervals and delete_intervals.
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Aug 9, 2019
1 parent 7993e4b commit 335d16f
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 241 deletions.
53 changes: 53 additions & 0 deletions python/tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1895,3 +1895,56 @@ def test_asdict_not_implemented(self):
t = tskit.BaseTable(None, None)
with self.assertRaises(NotImplementedError):
t.asdict()


class TestIntervalOps(unittest.TestCase):
    """
    Checks for the interval helper functions that the masking and slicing
    operations rely on.
    """
    def test_bad_intervals(self):
        # Inputs that are expected to be rejected with a TypeError.
        for not_intervals in ({}, Exception):
            self.assertRaises(
                TypeError, tskit.intervals_to_np_array, not_intervals, 0, 1)

        # Every (intervals, start, end) triple below is expected to be
        # rejected with a ValueError.
        rejected = [
            # Nested too deeply
            ([[[]]], 0, 1),
            # Wrong shape
            ([[0], [0]], 0, 1),
            ([[[0, 1, 2], [0, 1]]], 0, 1),
            # Out of bounds
            ([[-1, 0]], 0, 1),
            ([[0, 1]], 1, 2),
            ([[0, 1]], 0, 0.5),
            # Overlapping intervals
            ([[0, 1], [0.9, 2.0]], 0, 10),
            # Empty intervals
            ([[0, 0]], 0, 10),
            ([[1, 0]], 0, 10),
        ]
        for candidate, start, end in rejected:
            self.assertRaises(
                ValueError, tskit.intervals_to_np_array, candidate, start, end)

    def test_empty_interval_list(self):
        # An empty list of intervals should convert to an empty result.
        converted = tskit.intervals_to_np_array([], 0, 10)
        self.assertEqual(len(converted), 0)

    def test_negate_intervals(self):
        seq_len = 10
        # Pairs of (input intervals, expected complement within [0, seq_len)).
        expectations = [
            ([], [[0, seq_len]]),
            ([[0, 5], [6, seq_len]], [[5, 6]]),
            ([[0, 5]], [[5, seq_len]]),
            ([[5, seq_len]], [[0, 5]]),
            (
                [[0, 1], [2, 3], [3, 4], [5, 6]],
                [[1, 2], [4, 5], [6, seq_len]],
            ),
        ]
        for given, complement in expectations:
            negated = tskit.negate_intervals(given, 0, seq_len)
            self.assertTrue(np.array_equal(negated, complement))
190 changes: 30 additions & 160 deletions python/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,67 +40,9 @@
import tests.test_wright_fisher as wf


def slice_tables(
tables, start=None, stop=None, reset_coordinates=True, simplify=True,
record_provenance=True):
def simple_keep_intervals(tables, intervals, simplify=True, record_provenance=True):
"""
Simple Python implementation of slice (remove a single interval from a ts).
"""
if start is None:
start = 0
if stop is None:
stop = tables.sequence_length

if start < 0 or stop <= start or stop > tables.sequence_length:
raise ValueError("Slice bounds must be within the existing tree sequence")
ts = tables.tree_sequence()
ret = tables.copy()
ret.edges.clear()
ret.sites.clear()
ret.mutations.clear()
for edge in ts.edges():
if edge.right <= start or edge.left >= stop:
# This edge is outside the sliced area - do not include it
continue
if reset_coordinates:
ret.edges.add_row(
max(start, edge.left) - start, min(stop, edge.right) - start,
edge.parent, edge.child)
else:
ret.edges.add_row(
max(start, edge.left), min(stop, edge.right),
edge.parent, edge.child)
for site in ts.sites():
if start <= site.position < stop:
if reset_coordinates:
site_id = ret.sites.add_row(
site.position - start, site.ancestral_state, site.metadata)
else:
site_id = ret.sites.add_row(
site.position, site.ancestral_state, site.metadata)
for m in site.mutations:
ret.mutations.add_row(
site_id, m.node, m.derived_state, m.parent, m.metadata)
if reset_coordinates:
ret.sequence_length = stop - start
if simplify:
ret.simplify()
if record_provenance:
# TODO add slice arguments here
parameters = {
"command": "slice",
"TODO": "add slice parameters"
}
ret.provenances.add_row(record=json.dumps(
provenance.get_provenance_dict(parameters)))
return ret


def dice_tables(
tables, intervals, reset_coordinates=True, simplify=True,
record_provenance=True):
"""
Simple Python implementation of dice (keep a set of disjoint intervals from ts).
Simple Python implementation of keep_intervals.
"""
ts = tables.tree_sequence()
last_stop = 0
Expand All @@ -116,42 +58,26 @@ def dice_tables(
tables.edges.clear()
tables.sites.clear()
tables.mutations.clear()
max_right = 0
for edge in ts.edges():
offset = 0
for interval_left, interval_right in intervals:
if not (edge.right <= interval_left or edge.left >= interval_right):
left = max(interval_left, edge.left)
right = min(interval_right, edge.right)
if reset_coordinates:
left = offset + left - interval_left
right = offset + right - interval_left
max_right = max(right, max_right)
tables.edges.add_row(left, right, edge.parent, edge.child)
offset += interval_right - interval_left
if reset_coordinates:
interval_sum = sum(right - left for left, right in intervals)
tables.sequence_length = max(interval_sum, max_right)
for site in ts.sites():
offset = 0
for interval_left, interval_right in intervals:
if interval_left <= site.position < interval_right:
position = site.position
if reset_coordinates:
position = offset + position - interval_left
site_id = tables.sites.add_row(
position, site.ancestral_state, site.metadata)
site.position, site.ancestral_state, site.metadata)
for m in site.mutations:
tables.mutations.add_row(
site_id, m.node, m.derived_state, m.parent, m.metadata)
offset += interval_right - interval_left
if simplify:
tables.simplify()
if record_provenance:
# TODO add dice arguments here
parameters = {
"command": "dice",
"TODO": "add dice parameters"
"command": "keep_intervals",
"TODO": "add parameters"
}
tables.provenances.add_row(record=json.dumps(
provenance.get_provenance_dict(parameters)))
Expand Down Expand Up @@ -4431,31 +4357,11 @@ def test_edge_cases(self):
self.assertEqual(search_sorted([1], v), np.searchsorted([1], v))


@unittest.skip("TODO refactor these tests")
class TestSlice(TopologyTestCase):
"""
Tests for cutting up tree sequences along the genome.
"""
def test_lib_vs_simple_slice(self):
ts = msprime.simulate(
10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2)
for a, b in zip(np.random.uniform(0, 1, 10), np.random.uniform(0, 1, 10)):
if a != b:
for reset_coords in (True, False):
for simplify in (True, False):
for rec_prov in (True, False):
start = min(a, b)
stop = max(a, b)
tables = ts.dump_tables()
t1 = slice_tables(
tables, start, stop, reset_coords, simplify, rec_prov)
t2 = tables.slice(
start, stop, reset_coords, simplify, rec_prov)
# Provenances may differ using timestamps, so ignore them
# (this is a hack, as we prob want to compare their contents)
t2.provenances.clear()
t1.provenances.clear()
self.assertEqual(t1, t2)

def test_slice_by_tree_positions(self):
ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2)
breakpoints = list(ts.breakpoints())
Expand Down Expand Up @@ -4529,21 +4435,19 @@ def test_slice_keep_coordinates(self):
self.assertEqual(ts_sliced.at_index(-1).total_branch_length, 0)


class TestDice(TopologyTestCase):
class TestKeepIntervals(TopologyTestCase):
"""
Tests for the tree sequence dice operation, where we slice out multiple disjoint
Tests for the keep_intervals operation, where we slice out multiple disjoint
intervals concurrently.
"""
def example_intervals(self, tables):
L = tables.sequence_length
yield [(0, L)]

def do_dice(
self, tables, intervals, reset_coordinates=True, simplify=True,
record_provenance=True):
t1 = dice_tables(
tables, intervals, reset_coordinates, simplify, record_provenance)
t2 = tables.dice(intervals, reset_coordinates, simplify, record_provenance)
def do_keep_intervals(
self, tables, intervals, simplify=True, record_provenance=True):
t1 = simple_keep_intervals(tables, intervals, simplify, record_provenance)
t2 = tables.keep_intervals(intervals, simplify, record_provenance)
t3 = t2.copy()
self.assertEqual(len(t1.provenances), len(t2.provenances))
# Provenances may differ using timestamps, so ignore them
Expand All @@ -4557,83 +4461,49 @@ def do_dice(
self.assertEqual(t1, t2)
return t3

def verify_multi_slice(self, tables, intervals):
# Dice should be equal to repeated applications of slice
results = [slice_tables(
tables, interval[0], interval[1], reset_coordinates=False,
simplify=False) for interval in intervals]
result = tables.copy()
result.edges.clear()
result.sites.clear()
result.mutations.clear()
for partial in results:
for edge in partial.edges:
result.edges.add_row(edge.left, edge.right, edge.parent, edge.child)
offset = len(result.sites)
for site in partial.sites:
result.sites.add_row(site.position, site.ancestral_state, site.metadata)
for mutation in partial.mutations:
result.mutations.add_row(
site=offset + mutation.site, node=mutation.node,
parent=mutation.parent, derived_state=mutation.derived_state,
metadata=mutation.metadata)
result.sort()
diced = tables.dice(intervals, reset_coordinates=False, simplify=False)
result.provenances.clear()
diced.provenances.clear()
self.assertEqual(result, diced)

def test_one_interval(self):
ts = msprime.simulate(
10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2)
tables = ts.tables
intervals = [(0.3, 0.7)]
self.verify_multi_slice(tables, intervals)
for reset_coords in (True, False):
for simplify in (True, False):
for rec_prov in (True, False):
self.do_dice(tables, intervals, reset_coords, simplify, rec_prov)
for simplify in (True, False):
for rec_prov in (True, False):
self.do_keep_intervals(tables, intervals, simplify, rec_prov)

def test_two_intervals(self):
ts = msprime.simulate(
10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2)
tables = ts.tables
intervals = [(0.1, 0.2), (0.8, 0.9)]
self.verify_multi_slice(tables, intervals)
for reset_coords in (True, False):
for simplify in (True, False):
for rec_prov in (True, False):
self.do_dice(tables, intervals, reset_coords, simplify, rec_prov)
for simplify in (True, False):
for rec_prov in (True, False):
self.do_keep_intervals(tables, intervals, simplify, rec_prov)

def test_ten_intervals(self):
ts = msprime.simulate(
10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2)
tables = ts.tables
intervals = [(x, x + 0.05) for x in np.arange(0.0, 1.0, 0.1)]
self.verify_multi_slice(tables, intervals)
for reset_coords in (True, False):
for simplify in (True, False):
for rec_prov in (True, False):
self.do_dice(tables, intervals, reset_coords, simplify, rec_prov)
for simplify in (True, False):
for rec_prov in (True, False):
self.do_keep_intervals(tables, intervals, simplify, rec_prov)

def test_hundred_intervals(self):
ts = msprime.simulate(
10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2)
tables = ts.tables
intervals = [(x, x + 0.005) for x in np.arange(0.0, 1.0, 0.01)]
self.verify_multi_slice(tables, intervals)
for reset_coords in (True, False):
for simplify in (True, False):
for rec_prov in (True, False):
self.do_dice(tables, intervals, reset_coords, simplify, rec_prov)
for simplify in (True, False):
for rec_prov in (True, False):
self.do_keep_intervals(tables, intervals, simplify, rec_prov)

def test_read_only(self):
# tables.dice should not alter the source tables
# tables.keep_intervals should not alter the source tables
ts = msprime.simulate(10, random_seed=4, recombination_rate=2, mutation_rate=2)
source_tables = ts.tables
source_tables.dice([(0.5, 0.511)])
source_tables.keep_intervals([(0.5, 0.511)])
self.assertEqual(source_tables, ts.dump_tables())
source_tables.dice([(0.5, 0.511)], simplify=False)
source_tables.keep_intervals([(0.5, 0.511)], simplify=False)
self.assertEqual(source_tables, ts.dump_tables())

def test_regular_intervals(self):
Expand All @@ -4644,15 +4514,15 @@ def test_regular_intervals(self):
for num_intervals in range(2, 10):
breaks = np.linspace(0, ts.sequence_length, num=num_intervals)
intervals = [(x, x + eps) for x in breaks[:-1]]
self.do_dice(tables, intervals)
self.do_keep_intervals(tables, intervals)

def test_empty_tables(self):
tables = tskit.TableCollection(1.0)
diced = self.do_dice(tables, [(0, 1)])
diced = self.do_keep_intervals(tables, [(0, 1)])
self.assertEqual(diced.sequence_length, 1)
self.assertEqual(len(diced.edges), 0)
for intervals in self.example_intervals(tables):
diced = self.do_dice(tables, [(0, 1)])
diced = self.do_keep_intervals(tables, [(0, 1)])
self.assertEqual(len(diced.edges), 0)
self.assertAlmostEqual(
diced.sequence_length, sum(r - l for l, r in intervals))
Expand Down
Loading

0 comments on commit 335d16f

Please sign in to comment.