class TestIntervalOps(unittest.TestCase):
    """
    Tests for the interval utilities that underlie the mask and slicing
    operations.
    """
    def verify_rejected(self, error, intervals, start, end):
        # Converting the given intervals must fail with the given exception type.
        with self.assertRaises(error):
            tskit.intervals_to_np_array(intervals, start, end)

    def test_bad_intervals(self):
        # Each row is (expected exception, intervals, start, end).
        rejected = [
            # Unsupported types
            (TypeError, {}, 0, 1),
            (TypeError, Exception, 0, 1),
            # Wrong nesting depth
            (ValueError, [[[]]], 0, 1),
            # Wrong shape
            (ValueError, [[0], [0]], 0, 1),
            (ValueError, [[[0, 1, 2], [0, 1]]], 0, 1),
            # Out of bounds
            (ValueError, [[-1, 0]], 0, 1),
            (ValueError, [[0, 1]], 1, 2),
            (ValueError, [[0, 1]], 0, 0.5),
            # Overlapping intervals
            (ValueError, [[0, 1], [0.9, 2.0]], 0, 10),
            # Empty intervals
            (ValueError, [[0, 0]], 0, 10),
            (ValueError, [[1, 0]], 0, 10),
        ]
        for error, intervals, start, end in rejected:
            self.verify_rejected(error, intervals, start, end)

    def test_empty_interval_list(self):
        converted = tskit.intervals_to_np_array([], 0, 10)
        self.assertEqual(len(converted), 0)

    def test_negate_intervals(self):
        L = 10
        # Each pair is (intervals to negate, expected complement within [0, L]).
        expectations = [
            ([], [[0, L]]),
            ([[0, 5], [6, L]], [[5, 6]]),
            ([[0, 5]], [[5, L]]),
            ([[5, L]], [[0, 5]]),
            (
                [[0, 1], [2, 3], [3, 4], [5, 6]],
                [[1, 2], [4, 5], [6, L]]
            ),
        ]
        for given, expected in expectations:
            negated = tskit.negate_intervals(given, 0, L)
            self.assertTrue(np.array_equal(negated, expected))
100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -40,67 +40,9 @@ import tests.test_wright_fisher as wf -def slice_tables( - tables, start=None, stop=None, reset_coordinates=True, simplify=True, - record_provenance=True): +def simple_keep_intervals(tables, intervals, simplify=True, record_provenance=True): """ - Simple Python implementation of slice (remove a single interval from a ts). - """ - if start is None: - start = 0 - if stop is None: - stop = tables.sequence_length - - if start < 0 or stop <= start or stop > tables.sequence_length: - raise ValueError("Slice bounds must be within the existing tree sequence") - ts = tables.tree_sequence() - ret = tables.copy() - ret.edges.clear() - ret.sites.clear() - ret.mutations.clear() - for edge in ts.edges(): - if edge.right <= start or edge.left >= stop: - # This edge is outside the sliced area - do not include it - continue - if reset_coordinates: - ret.edges.add_row( - max(start, edge.left) - start, min(stop, edge.right) - start, - edge.parent, edge.child) - else: - ret.edges.add_row( - max(start, edge.left), min(stop, edge.right), - edge.parent, edge.child) - for site in ts.sites(): - if start <= site.position < stop: - if reset_coordinates: - site_id = ret.sites.add_row( - site.position - start, site.ancestral_state, site.metadata) - else: - site_id = ret.sites.add_row( - site.position, site.ancestral_state, site.metadata) - for m in site.mutations: - ret.mutations.add_row( - site_id, m.node, m.derived_state, m.parent, m.metadata) - if reset_coordinates: - ret.sequence_length = stop - start - if simplify: - ret.simplify() - if record_provenance: - # TODO add slice arguments here - parameters = { - "command": "slice", - "TODO": "add slice parameters" - } - ret.provenances.add_row(record=json.dumps( - provenance.get_provenance_dict(parameters))) - return ret - - -def dice_tables( - tables, intervals, reset_coordinates=True, simplify=True, - record_provenance=True): - """ - Simple 
Python implementation of dice (keep a set of disjoint intervals from ts). + Simple Python implementation of keep_intervals. """ ts = tables.tree_sequence() last_stop = 0 @@ -116,42 +58,26 @@ def dice_tables( tables.edges.clear() tables.sites.clear() tables.mutations.clear() - max_right = 0 for edge in ts.edges(): - offset = 0 for interval_left, interval_right in intervals: if not (edge.right <= interval_left or edge.left >= interval_right): left = max(interval_left, edge.left) right = min(interval_right, edge.right) - if reset_coordinates: - left = offset + left - interval_left - right = offset + right - interval_left - max_right = max(right, max_right) tables.edges.add_row(left, right, edge.parent, edge.child) - offset += interval_right - interval_left - if reset_coordinates: - interval_sum = sum(right - left for left, right in intervals) - tables.sequence_length = max(interval_sum, max_right) for site in ts.sites(): - offset = 0 for interval_left, interval_right in intervals: if interval_left <= site.position < interval_right: - position = site.position - if reset_coordinates: - position = offset + position - interval_left site_id = tables.sites.add_row( - position, site.ancestral_state, site.metadata) + site.position, site.ancestral_state, site.metadata) for m in site.mutations: tables.mutations.add_row( site_id, m.node, m.derived_state, m.parent, m.metadata) - offset += interval_right - interval_left if simplify: tables.simplify() if record_provenance: - # TODO add dice arguments here parameters = { - "command": "dice", - "TODO": "add dice parameters" + "command": "keep_intervals", + "TODO": "add parameters" } tables.provenances.add_row(record=json.dumps( provenance.get_provenance_dict(parameters))) @@ -4431,31 +4357,11 @@ def test_edge_cases(self): self.assertEqual(search_sorted([1], v), np.searchsorted([1], v)) +@unittest.skip("TODO refactor these tests") class TestSlice(TopologyTestCase): """ Tests for cutting up tree sequences along the genome. 
""" - def test_lib_vs_simple_slice(self): - ts = msprime.simulate( - 10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2) - for a, b in zip(np.random.uniform(0, 1, 10), np.random.uniform(0, 1, 10)): - if a != b: - for reset_coords in (True, False): - for simplify in (True, False): - for rec_prov in (True, False): - start = min(a, b) - stop = max(a, b) - tables = ts.dump_tables() - t1 = slice_tables( - tables, start, stop, reset_coords, simplify, rec_prov) - t2 = tables.slice( - start, stop, reset_coords, simplify, rec_prov) - # Provenances may differ using timestamps, so ignore them - # (this is a hack, as we prob want to compare their contents) - t2.provenances.clear() - t1.provenances.clear() - self.assertEqual(t1, t2) - def test_slice_by_tree_positions(self): ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2) breakpoints = list(ts.breakpoints()) @@ -4529,21 +4435,19 @@ def test_slice_keep_coordinates(self): self.assertEqual(ts_sliced.at_index(-1).total_branch_length, 0) -class TestDice(TopologyTestCase): +class TestKeepIntervals(TopologyTestCase): """ - Tests for the tree sequence dice operation, where we slice out multiple disjoint + Tests for keep_intervals operation, where we slice out multiple disjoint intervals concurrently. 
""" def example_intervals(self, tables): L = tables.sequence_length yield [(0, L)] - def do_dice( - self, tables, intervals, reset_coordinates=True, simplify=True, - record_provenance=True): - t1 = dice_tables( - tables, intervals, reset_coordinates, simplify, record_provenance) - t2 = tables.dice(intervals, reset_coordinates, simplify, record_provenance) + def do_keep_intervals( + self, tables, intervals, simplify=True, record_provenance=True): + t1 = simple_keep_intervals(tables, intervals, simplify, record_provenance) + t2 = tables.keep_intervals(intervals, simplify, record_provenance) t3 = t2.copy() self.assertEqual(len(t1.provenances), len(t2.provenances)) # Provenances may differ using timestamps, so ignore them @@ -4557,83 +4461,49 @@ def do_dice( self.assertEqual(t1, t2) return t3 - def verify_multi_slice(self, tables, intervals): - # Dice should be equal to repeated applications of slice - results = [slice_tables( - tables, interval[0], interval[1], reset_coordinates=False, - simplify=False) for interval in intervals] - result = tables.copy() - result.edges.clear() - result.sites.clear() - result.mutations.clear() - for partial in results: - for edge in partial.edges: - result.edges.add_row(edge.left, edge.right, edge.parent, edge.child) - offset = len(result.sites) - for site in partial.sites: - result.sites.add_row(site.position, site.ancestral_state, site.metadata) - for mutation in partial.mutations: - result.mutations.add_row( - site=offset + mutation.site, node=mutation.node, - parent=mutation.parent, derived_state=mutation.derived_state, - metadata=mutation.metadata) - result.sort() - diced = tables.dice(intervals, reset_coordinates=False, simplify=False) - result.provenances.clear() - diced.provenances.clear() - self.assertEqual(result, diced) - def test_one_interval(self): ts = msprime.simulate( 10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2) tables = ts.tables intervals = [(0.3, 0.7)] - self.verify_multi_slice(tables, 
intervals) - for reset_coords in (True, False): - for simplify in (True, False): - for rec_prov in (True, False): - self.do_dice(tables, intervals, reset_coords, simplify, rec_prov) + for simplify in (True, False): + for rec_prov in (True, False): + self.do_keep_intervals(tables, intervals, simplify, rec_prov) def test_two_intervals(self): ts = msprime.simulate( 10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2) tables = ts.tables intervals = [(0.1, 0.2), (0.8, 0.9)] - self.verify_multi_slice(tables, intervals) - for reset_coords in (True, False): - for simplify in (True, False): - for rec_prov in (True, False): - self.do_dice(tables, intervals, reset_coords, simplify, rec_prov) + for simplify in (True, False): + for rec_prov in (True, False): + self.do_keep_intervals(tables, intervals, simplify, rec_prov) def test_ten_intervals(self): ts = msprime.simulate( 10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2) tables = ts.tables intervals = [(x, x + 0.05) for x in np.arange(0.0, 1.0, 0.1)] - self.verify_multi_slice(tables, intervals) - for reset_coords in (True, False): - for simplify in (True, False): - for rec_prov in (True, False): - self.do_dice(tables, intervals, reset_coords, simplify, rec_prov) + for simplify in (True, False): + for rec_prov in (True, False): + self.do_keep_intervals(tables, intervals, simplify, rec_prov) def test_hundred_intervals(self): ts = msprime.simulate( 10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2) tables = ts.tables intervals = [(x, x + 0.005) for x in np.arange(0.0, 1.0, 0.01)] - self.verify_multi_slice(tables, intervals) - for reset_coords in (True, False): - for simplify in (True, False): - for rec_prov in (True, False): - self.do_dice(tables, intervals, reset_coords, simplify, rec_prov) + for simplify in (True, False): + for rec_prov in (True, False): + self.do_keep_intervals(tables, intervals, simplify, rec_prov) def test_read_only(self): - # tables.dice 
should not alter the source tables + # tables.keep_intervals should not alter the source tables ts = msprime.simulate(10, random_seed=4, recombination_rate=2, mutation_rate=2) source_tables = ts.tables - source_tables.dice([(0.5, 0.511)]) + source_tables.keep_intervals([(0.5, 0.511)]) self.assertEqual(source_tables, ts.dump_tables()) - source_tables.dice([(0.5, 0.511)], simplify=False) + source_tables.keep_intervals([(0.5, 0.511)], simplify=False) self.assertEqual(source_tables, ts.dump_tables()) def test_regular_intervals(self): @@ -4644,15 +4514,15 @@ def test_regular_intervals(self): for num_intervals in range(2, 10): breaks = np.linspace(0, ts.sequence_length, num=num_intervals) intervals = [(x, x + eps) for x in breaks[:-1]] - self.do_dice(tables, intervals) + self.do_keep_intervals(tables, intervals) def test_empty_tables(self): tables = tskit.TableCollection(1.0) - diced = self.do_dice(tables, [(0, 1)]) + diced = self.do_keep_intervals(tables, [(0, 1)]) self.assertEqual(diced.sequence_length, 1) self.assertEqual(len(diced.edges), 0) for intervals in self.example_intervals(tables): - diced = self.do_dice(tables, [(0, 1)]) + diced = self.do_keep_intervals(tables, [(0, 1)]) self.assertEqual(len(diced.edges), 0) self.assertAlmostEqual( diced.sequence_length, sum(r - l for l, r in intervals)) diff --git a/python/tskit/tables.py b/python/tskit/tables.py index b59ada247d..51e4cf1f83 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -1802,27 +1802,20 @@ def deduplicate_sites(self): self.ll_tables.deduplicate_sites() # TODO add provenance - def slice( - self, start=None, stop=None, reset_coordinates=True, simplify=True, - record_provenance=True): + def delete_intervals(self, intervals, simplify=True, record_provenance=True): + # TODO document + return self.keep_intervals( + negate_intervals(intervals, 0, self.sequence_length), + simplify=simplify, record_provenance=record_provenance) + + def keep_intervals(self, intervals, simplify=True, 
record_provenance=True): """ Returns a copy of this set of tables which include only information in - the genomic interval between ``start`` and ``stop``. Edges are truncated - to this interval, and sites not covered by the sliced region are discarded. - - :param float start: The leftmost genomic position, giving the start point - of the kept region. Tree sequence information along the genome prior to (but - not including) this point will be discarded. If None, set equal to zero. - :param float stop: The rightmost genomic position, giving the end point of the - kept region. Tree sequence information at this point and further along the - genomic sequence will be discarded. If None, is set equal to the - TableCollection's ``sequence_length``. - :param bool reset_coordinates: Reset the genomic coordinates such that position - 0 in the tables after this operation corresponds to the current position - ``start`` and the returned tree sequence has sequence length - ``stop``-``start``. Sites and tree intervals will all have their positions - shifted to reflect the new coordinate system. If ``False``, coordinates are - before and after this operation are directly comparable (Default: True). + the specified list of genomic intervals. Edges are truncated to within + these intervals, and sites not falling within these intervals are + discarded. + + :param array_like intervals: TODO description of genomic intervals. :param bool simplify: If True, run simplify on the tables so that nodes no longer used are discarded. (Default: True). :param bool record_provenance: If True, record details of this operation @@ -1830,19 +1823,6 @@ def slice( (Default: True). 
:rtype: tskit.TableCollection """ - if start is None: - start = 0 - if stop is None: - stop = self.sequence_length - if start is None and stop is None: - raise ValueError("At least one of start or stop coordinates required") - return self.dice( - [(start, stop)], reset_coordinates=reset_coordinates, simplify=simplify, - record_provenance=record_provenance) - - def dice( - self, intervals, reset_coordinates=True, simplify=True, - record_provenance=True): def keep_with_offset(keep, data, offset): lens = np.diff(offset) @@ -1851,22 +1831,8 @@ def keep_with_offset(keep, data, offset): np.array([0], dtype=offset.dtype), np.cumsum(lens[keep], dtype=offset.dtype)])) - # Not strictly necessary for the current implementation, but if we want to push - # this down into C, it'll be much more convenient to interpret the input as a - # 2D numpy array. So, we do this to ensure forward compatability. - intervals = np.array(intervals) - if len(intervals.shape) != 2 and intervals.shape[1] != 2: - raise ValueError("Intervals must be a list of 2-tuples or 2D numpy array") - - last_right = 0 - for left, right in intervals: - if left < 0 or right > self.sequence_length: - raise ValueError( - "Slice bounds must be within the existing tree sequence") - if right <= left: - raise ValueError("Bad dice interval") - if left < last_right: - raise ValueError("Intervals must be disjoint.") + intervals = intervals_to_np_array(intervals, 0, self.sequence_length) + tables = self.copy() sites = self.sites edges = self.edges @@ -1876,7 +1842,6 @@ def keep_with_offset(keep, data, offset): tables.mutations.clear() keep_sites = np.repeat(False, sites.num_rows) keep_mutations = np.repeat(False, mutations.num_rows) - offset = 0 for s, e in intervals: curr_keep_sites = np.logical_and(sites.position >= s, sites.position < e) keep_sites = np.logical_or(keep_sites, curr_keep_sites) @@ -1887,40 +1852,21 @@ def keep_with_offset(keep, data, offset): keep_mutations = np.logical_or( keep_mutations, 
curr_keep_sites[mutations.site]) keep_edges = np.logical_not(np.logical_or(edges.right <= s, edges.left >= e)) - if reset_coordinates: - tables.edges.append_columns( - left=offset + np.fmax(s, edges.left[keep_edges]) - s, - right=offset + np.fmin(e, edges.right[keep_edges]) - s, - parent=edges.parent[keep_edges], - child=edges.child[keep_edges]) - tables.sites.append_columns( - position=offset + sites.position[curr_keep_sites] - s, - ancestral_state=new_as, - ancestral_state_offset=new_as_offset, - metadata=new_md, - metadata_offset=new_md_offset) - offset += e - s - else: - tables.edges.append_columns( - left=np.fmax(s, edges.left[keep_edges]), - right=np.fmin(e, edges.right[keep_edges]), - parent=edges.parent[keep_edges], - child=edges.child[keep_edges]) - tables.sites.append_columns( - position=sites.position[curr_keep_sites], - ancestral_state=new_as, - ancestral_state_offset=new_as_offset, - metadata=new_md, - metadata_offset=new_md_offset) + tables.edges.append_columns( + left=np.fmax(s, edges.left[keep_edges]), + right=np.fmin(e, edges.right[keep_edges]), + parent=edges.parent[keep_edges], + child=edges.child[keep_edges]) + tables.sites.append_columns( + position=sites.position[curr_keep_sites], + ancestral_state=new_as, + ancestral_state_offset=new_as_offset, + metadata=new_md, + metadata_offset=new_md_offset) new_ds, new_ds_offset = keep_with_offset( keep_mutations, mutations.derived_state, mutations.derived_state_offset) new_md, new_md_offset = keep_with_offset( keep_mutations, mutations.metadata, mutations.metadata_offset) - if reset_coordinates: - if len(tables.edges) > 0: - tables.sequence_length = max(offset, np.max(tables.edges.right)) - else: - tables.sequence_length = offset site_map = np.cumsum(keep_sites, dtype=mutations.site.dtype) - 1 tables.mutations.set_columns( site=site_map[mutations.site[keep_mutations]], @@ -1936,8 +1882,8 @@ def keep_with_offset(keep, data, offset): if record_provenance: # TODO replace with a version of 
####################
# Interval utilities
####################


def intervals_to_np_array(intervals, start, end):
    """
    Converts the specified intervals to a numpy array and checks them for
    errors.

    :param array_like intervals: A sequence of ``(left, right)`` pairs. An
        empty sequence is accepted and yields an array of shape ``(0, 2)``.
    :param float start: The smallest permitted ``left`` coordinate.
    :param float end: The largest permitted ``right`` coordinate.
    :return: The intervals as a ``float64`` array of shape ``(N, 2)``.
    :rtype: numpy.ndarray
    :raises ValueError: If the input is not two dimensional with two columns,
        contains NaN, has an interval outside ``[start, end]``, has an
        interval with ``right <= left``, or has an interval that begins
        before the preceding interval ends.
    """
    intervals = np.array(intervals, dtype=np.float64)
    # Special case the empty list of intervals
    if len(intervals) == 0:
        intervals = np.zeros((0, 2), dtype=np.float64)
    if len(intervals.shape) != 2:
        raise ValueError("Intervals must be a 2D numpy array")
    if intervals.shape[1] != 2:
        raise ValueError("Intervals array shape must be (N, 2)")
    # Every ordering comparison against NaN is False, so NaN coordinates would
    # pass all of the checks below unless they are rejected explicitly.
    if np.any(np.isnan(intervals)):
        raise ValueError("Intervals must not contain NaN values")
    left = intervals[:, 0]
    right = intervals[:, 1]
    if np.any(left < start) or np.any(right > end):
        raise ValueError(
            "Intervals must be within {} and {}".format(start, end))
    if np.any(right <= left):
        raise ValueError("Bad interval: right <= left")
    # Each interval must begin at or after the end of the one before it; the
    # first interval is already constrained by the bounds check above.
    if np.any(left[1:] < right[:-1]):
        raise ValueError("Intervals must be disjoint.")
    return intervals


def negate_intervals(intervals, start, end):
    """
    Returns the set of intervals *not* covered by the specified set of
    disjoint intervals in the specified range.

    :param array_like intervals: A sequence of ``(left, right)`` pairs, which
        is validated by :func:`intervals_to_np_array`.
    :param float start: The left end of the range to negate within.
    :param float end: The right end of the range to negate within.
    :return: The uncovered intervals as a ``float64`` array of shape
        ``(K, 2)``; ``K`` is zero when the input covers the whole range.
    :rtype: numpy.ndarray
    :raises ValueError: If the intervals fail validation.
    """
    intervals = intervals_to_np_array(intervals, start, end)
    # Flanking the flattened coordinates with the range bounds and pairing
    # them up again gives the candidate gaps
    # (start, l0), (r0, l1), ..., (r_last, end).
    bounds = np.concatenate(([start], intervals.ravel(), [end]))
    gaps = bounds.astype(np.float64, copy=False).reshape(-1, 2)
    # Zero-length gaps arise where intervals touch each other or a bound.
    return gaps[gaps[:, 0] != gaps[:, 1]]