diff --git a/python/test_dice.py b/python/test_dice.py deleted file mode 100644 index 419baaeed1..0000000000 --- a/python/test_dice.py +++ /dev/null @@ -1,157 +0,0 @@ -import msprime -import numpy as np -import tskit - -def find_interval_left(x, breaks): - """ - Given a location x and an increasing vector of breakpoints breaks, - return the index k such that breaks[k] <= x < breaks[k+1], - returning -1 if x < breaks[0] and len(breaks)-1 if x >= breaks[-1]. - """ - if x < breaks[0]: - return -1 - if x >= breaks[-1]: - return len(breaks) - 1 - i = 0 - j = len(breaks) - 1 - while i + 1 < j: - k = int((i + j)/2) - if breaks[k] <= x: - i = k - else: - j = k - return i - -breaks = np.array([0.0, 1.0, 3.5, 5.0]) -for x, y in [(0.0, 0), - (0.5, 0), - (1.0, 1), - (4.0, 2), - (5.0, 3), - (100, 3)]: - assert find_interval_left(x, breaks) == y - - -def find_interval_right(x, breaks): - """ - Given a location x and an increasing vector of breakpoints breaks, - return the index k such that breaks[k] < x <= breaks[k+1]. - returning -1 if x <= breaks[0] and len(breaks)-1 if x > breaks[-1]. - """ - if x <= breaks[0]: - return -1 - if x > breaks[-1]: - return len(breaks) - 1 - i = 0 - j = len(breaks) - 1 - while i + 1 < j: - k = int((i + j)/2) - if breaks[k] < x: - i = k - else: - j = k - return i - -breaks = np.array([0.0, 1.0, 3.5, 5.0]) -for x, y in [(0.0, -1), - (0.5, 0), - (1.0, 0), - (1.2, 1), - (4.0, 2)]: - assert find_interval_right(x, breaks) == y - -def interval_index(x, starts, ends): - """ - Returns the index of the interval that the position x lies in, - or -1 if it does not lie in an interval. - """ - i = find_interval_left(x, starts) - j = find_interval_left(x, ends) - if j >= i: - out = -1 - else: - out = i - return out - -starts = np.array([0.0, 1.0, 3.5, 5.0]) -ends = np.array([0.5, 2.0, 4.0, 6.0]) -for x, y in [(-1.0, -1), - (0.0, 0), - (0.25, 0), - (0.5, -1), - (1.0, 1), - (1.2, 1), - (5.2, 3), - (8.0, -1)]: - assert interval_index(x, starts, ends) == y - -def do_overlap(segment, starts, ends): - """ - Given a segment = [left, right), yield the segments - found by intersecting it with the intervals described by starts, ends, - which should be sorted and nonoverlapping. - """ - assert(len(starts) == len(ends)) - assert(np.all(starts < ends)) - assert(np.all(ends[:-1] < starts[1:])) - left, right = segment - # the index of the first interval that ends at or before `left` - a = find_interval_left(left, ends) - # the index of the first interval that starts before `right` - b = find_interval_right(right, starts) - for k in range(a+1, b+1): - yield (max(left, starts[k]), min(right, ends[k])) - -starts = np.array([0.0, 1.0, 3.5, 5.0]) -ends = np.array([0.5, 2.0, 4.0, 6.0]) -for x, y in [((0.0, 0.5), [(0.0, 0.5)]), - ((0.0, 0.7), [(0.0, 0.5)]), - ((0.0, 1.2), [(0.0, 0.5), (1.0, 1.2)]), - ((0.4, 1.2), [(0.4, 0.5), (1.0, 1.2)]), - ((0.5, 1.0), []), - ((0.5, 1.2), [(1.0, 1.2)]), - ((0.6, 4.2), [(1.0, 2.0), (3.5, 4.0)]), - ((-1.0, 6.2), [(0.0, 0.5), (1.0, 2.0), (3.5, 4.0), (5.0, 6.0)])]: - out = list(do_overlap(x, starts, ends)) - assert len(y) == len(out) - for a, b in zip(y, out): - assert a == b - - -def dice(ts, starts, ends): - """ - Remove edges and sites of the tree sequence that do *not* lie in the collection - of half-open intervals [s, e) given by starts, ends. 
- """ - assert(len(starts) == len(ends)) - assert(np.all(starts < ends)) - assert(np.all(ends[:-1] < starts[1:])) - tables = ts.tables - tables.edges.clear() - for e in ts.edges(): - for l,r in do_overlap((e.left, e.right), starts, ends): - tables.edges.add_row(left=l, right=r, parent=e.parent, child=e.child) - tables.sites.clear() - site_map = np.repeat(-1, ts.num_sites) - for i, s in enumerate(ts.sites()): - if interval_index(s.position, starts, ends) >= 0: - j = tables.sites.add_row(s.position, s.ancestral_state, s.metadata) - site_map[i] = j - tables.mutations.clear() - for m in ts.mutations(): - s = site_map[m.site] - if s >= 0: - tables.mutations.add_row( - site=s, node=m.node, - derived_state=m.derived_state, - parent=-1, - metadata=m.metadata) - tables.build_index() - tables.compute_mutation_parents() - return tables - - -ts = msprime.simulate(10, recombination_rate=2, mutation_rate=1, length=10, random_seed=23) - -sub_tables = dice(ts, starts=[1.0, 5.0], ends=[2.0, 10.0]) -sub_ts = sub_tables.tree_sequence() diff --git a/python/tests/test_genotypes.py b/python/tests/test_genotypes.py index 4193646045..fd1fecd6c3 100644 --- a/python/tests/test_genotypes.py +++ b/python/tests/test_genotypes.py @@ -410,9 +410,8 @@ def test_jukes_cantor_n20(self): def test_zero_edge_missing_data(self): ts = msprime.simulate(10, random_seed=2, mutation_rate=2) - ts = ts.slice(0.25, 0.75, reset_coordinates=False) + tables = ts.tables.keep_intervals([[0.25, 0.75]]) # add some sites in the deleted regions - tables = ts.dump_tables() tables.sites.add_row(0.1, "A") tables.sites.add_row(0.2, "A") tables.sites.add_row(0.8, "A") diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 0003aa46f6..c379d35143 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -1949,56 +1949,3 @@ def test_asdict_not_implemented(self): t = tskit.BaseTable(None, None) with self.assertRaises(NotImplementedError): t.asdict() - - -class TestIntervalOps(unittest.TestCase): - """ - Test cases for the interval operations used in masks and slicing operations. 
- """ - def test_bad_intervals(self): - for bad_type in [{}, Exception]: - with self.assertRaises(TypeError): - tskit.intervals_to_np_array(bad_type, 0, 1) - for bad_depth in [[[[]]]]: - with self.assertRaises(ValueError): - tskit.intervals_to_np_array(bad_depth, 0, 1) - for bad_shape in [[[0], [0]], [[[0, 1, 2], [0, 1]]]]: - with self.assertRaises(ValueError): - tskit.intervals_to_np_array(bad_shape, 0, 1) - - # Out of bounds - with self.assertRaises(ValueError): - tskit.intervals_to_np_array([[-1, 0]], 0, 1) - with self.assertRaises(ValueError): - tskit.intervals_to_np_array([[0, 1]], 1, 2) - with self.assertRaises(ValueError): - tskit.intervals_to_np_array([[0, 1]], 0, 0.5) - - # Overlapping intervals - with self.assertRaises(ValueError): - tskit.intervals_to_np_array([[0, 1], [0.9, 2.0]], 0, 10) - - # Empty intervals - for bad_interval in [[0, 0], [1, 0]]: - with self.assertRaises(ValueError): - tskit.intervals_to_np_array([bad_interval], 0, 10) - - def test_empty_interval_list(self): - intervals = tskit.intervals_to_np_array([], 0, 10) - self.assertEqual(len(intervals), 0) - - def test_negate_intervals(self): - L = 10 - cases = [ - ([], [[0, L]]), - ([[0, 5], [6, L]], [[5, 6]]), - ([[0, 5]], [[5, L]]), - ([[5, L]], [[0, 5]]), - ( - [[0, 1], [2, 3], [3, 4], [5, 6]], - [[1, 2], [4, 5], [6, L]] - ), - ] - for source, dest in cases: - self.assertTrue(np.array_equal( - tskit.negate_intervals(source, 0, L), dest)) diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index 690c8fe79b..f5d6bfc0a2 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -71,7 +71,9 @@ def simple_keep_intervals(tables, intervals, simplify=True, record_provenance=Tr site.position, site.ancestral_state, site.metadata) for m in site.mutations: tables.mutations.add_row( - site_id, m.node, m.derived_state, m.parent, m.metadata) + site_id, m.node, m.derived_state, tskit.NULL, m.metadata) + tables.build_index() + tables.compute_mutation_parents() if simplify: tables.simplify() if record_provenance: @@ -4357,8 +4359,7 @@ def test_edge_cases(self): self.assertEqual(search_sorted([1], v), np.searchsorted([1], v)) -@unittest.skip("TODO refactor these tests") -class TestSlice(TopologyTestCase): +class TestKeepSingleInterval(unittest.TestCase): """ Tests for cutting up tree sequences along the genome. 
""" @@ -4367,61 +4368,57 @@ def test_slice_by_tree_positions(self): breakpoints = list(ts.breakpoints()) # Keep the last 3 trees (from 4th last breakpoint onwards) - ts_sliced = ts.tables.slice(start=breakpoints[-4]).tree_sequence() - self.assertEqual(ts_sliced.num_trees, 3) + ts_sliced = ts.tables.keep_intervals( + [[breakpoints[-4], ts.sequence_length]]).tree_sequence() + self.assertEqual(ts_sliced.num_trees, 4) self.assertLess(ts_sliced.num_edges, ts.num_edges) - self.assertAlmostEqual(ts_sliced.sequence_length, 1.0 - breakpoints[-4]) + self.assertAlmostEqual(ts_sliced.sequence_length, 1.0) last_3_mutations = 0 for tree_index in range(-3, 0): last_3_mutations += ts.at_index(tree_index).num_mutations self.assertEqual(ts_sliced.num_mutations, last_3_mutations) # Keep the first 3 trees - ts_sliced = ts.tables.slice(stop=breakpoints[3]).tree_sequence() - self.assertEqual(ts_sliced.num_trees, 3) + ts_sliced = ts.tables.keep_intervals( + [[0, breakpoints[3]]]).tree_sequence() + self.assertEqual(ts_sliced.num_trees, 4) self.assertLess(ts_sliced.num_edges, ts.num_edges) - self.assertAlmostEqual(ts_sliced.sequence_length, breakpoints[3]) + self.assertAlmostEqual(ts_sliced.sequence_length, 1) first_3_mutations = 0 for tree_index in range(0, 3): first_3_mutations += ts.at_index(tree_index).num_mutations self.assertEqual(ts_sliced.num_mutations, first_3_mutations) # Slice out the middle - ts_sliced = ts.tables.slice(breakpoints[3], breakpoints[-4]).tree_sequence() - self.assertEqual(ts_sliced.num_trees, ts.num_trees - 6) + ts_sliced = ts.tables.keep_intervals( + [[breakpoints[3], breakpoints[-4]]]).tree_sequence() + self.assertEqual(ts_sliced.num_trees, ts.num_trees - 4) self.assertLess(ts_sliced.num_edges, ts.num_edges) - self.assertAlmostEqual( - ts_sliced.sequence_length, breakpoints[-4] - breakpoints[3]) + self.assertAlmostEqual(ts_sliced.sequence_length, 1.0) self.assertEqual( ts_sliced.num_mutations, ts.num_mutations - first_3_mutations - last_3_mutations) def test_slice_by_position(self): ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2) - ts_sliced = ts.tables.slice(0.4, 0.6).tree_sequence() + ts_sliced = ts.tables.keep_intervals([[0.4, 0.6]]).tree_sequence() positions = ts.tables.sites.position self.assertEqual( ts_sliced.num_sites, np.sum((positions >= 0.4) & (positions < 0.6))) - def test_slice_bounds(self): - ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2) - tables = ts.tables - self.assertRaises(ValueError, tables.slice, -1) - self.assertRaises(ValueError, tables.slice, stop=2) - self.assertRaises(ValueError, tables.slice, 0.8, 0.2) - def test_slice_unsimplified(self): ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2) - ts_sliced = ts.tables.slice(0.4, 0.6, simplify=True).tree_sequence() + ts_sliced = ts.tables.keep_intervals([[0.4, 0.6]], simplify=True).tree_sequence() self.assertNotEqual(ts.num_nodes, ts_sliced.num_nodes) - self.assertAlmostEqual(ts_sliced.sequence_length, 0.2) - ts_sliced = ts.tables.slice(0.4, 0.6, simplify=False).tree_sequence() + self.assertAlmostEqual(ts_sliced.sequence_length, 1.0) + ts_sliced = ts.tables.keep_intervals( + [[0.4, 0.6]], simplify=False).tree_sequence() self.assertEqual(ts.num_nodes, ts_sliced.num_nodes) - self.assertAlmostEqual(ts_sliced.sequence_length, 0.2) + self.assertAlmostEqual(ts_sliced.sequence_length, 1.0) - def test_slice_keep_coordinates(self): + def test_slice_coordinates(self): ts = msprime.simulate(5, random_seed=1, recombination_rate=2, 
mutation_rate=2) - ts_sliced = ts.tables.slice(0.4, 0.6, reset_coordinates=False).tree_sequence() + ts_sliced = ts.tables.keep_intervals([[0.4, 0.6]]).tree_sequence() self.assertAlmostEqual(ts_sliced.sequence_length, 1) self.assertNotEqual(ts_sliced.num_trees, ts.num_trees) self.assertEqual(ts_sliced.at_index(0).total_branch_length, 0) @@ -4442,7 +4439,13 @@ class TestKeepIntervals(TopologyTestCase): """ def example_intervals(self, tables): L = tables.sequence_length + yield [] yield [(0, L)] + yield [(0, L / 2), (L / 2, L)] + yield [(0, 0.25 * L), (0.75 * L, L)] + yield [(0.25 * L, L)] + yield [(0.25 * L, 0.5 * L)] + yield [(0.25 * L, 0.5 * L), (0.75 * L, 0.8 * L)] def do_keep_intervals( self, tables, intervals, simplify=True, record_provenance=True): @@ -4454,13 +4457,29 @@ def do_keep_intervals( # (this is a hack, as we prob want to compare their contents) t1.provenances.clear() t2.provenances.clear() - # print(t1.edges) - # print(t2.edges) - # print(t1.edges == t2.edges) - # print(t1.sequence_length == t2.sequence_length) self.assertEqual(t1, t2) return t3 + def test_migration_error(self): + tables = tskit.TableCollection(1) + tables.migrations.add_row(0, 1, 0, 0, 0, 0) + with self.assertRaises(ValueError): + tables.keep_intervals([[0, 1]]) + + def test_bad_intervals(self): + tables = tskit.TableCollection(10) + bad_intervals = [ + [[1, 1]], + [[-1, 0]], + [[0, 11]], + [[0, 5], [4, 6]] + ] + for intervals in bad_intervals: + with self.assertRaises(ValueError): + tables.keep_intervals(intervals) + with self.assertRaises(ValueError): + tables.delete_intervals(intervals) + def test_one_interval(self): ts = msprime.simulate( 10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2) @@ -4516,30 +4535,116 @@ def test_regular_intervals(self): intervals = [(x, x + eps) for x in breaks[:-1]] self.do_keep_intervals(tables, intervals) - def test_empty_tables(self): - tables = tskit.TableCollection(1.0) - diced = self.do_keep_intervals(tables, [(0, 1)]) - self.assertEqual(diced.sequence_length, 1) - self.assertEqual(len(diced.edges), 0) - for intervals in self.example_intervals(tables): - diced = self.do_keep_intervals(tables, [(0, 1)]) - self.assertEqual(len(diced.edges), 0) - self.assertAlmostEqual( - diced.sequence_length, sum(r - l for l, r in intervals)) - - @unittest.skip("Sites not working - not sure why not") def test_no_edges_sites(self): tables = tskit.TableCollection(1.0) tables.sites.add_row(0.1, "A") tables.sites.add_row(0.2, "T") - diced = self.do_dice(tables, [(0, 1)]) - self.assertEqual(diced.sequence_length, 1) - self.assertEqual(len(diced.edges), 0) - print(tables.sites) - print(diced.sites) - self.assertEqual(tables.sites, diced.sites) for intervals in self.example_intervals(tables): - diced = self.do_dice(tables, [(0, 1)]) + self.assertEqual(len(tables.sites), 2) + diced = self.do_keep_intervals(tables, intervals) + self.assertEqual(diced.sequence_length, 1) self.assertEqual(len(diced.edges), 0) - self.assertAlmostEqual( - diced.sequence_length, sum(r - l for l, r in intervals)) + self.assertEqual(len(diced.sites), 0) + + def verify(self, tables): + for intervals in self.example_intervals(tables): + for simplify in [True, False]: + self.do_keep_intervals(tables, intervals, simplify=simplify) + + def test_empty_tables(self): + tables = tskit.TableCollection(1.0) + self.verify(tables) + + def test_single_tree_jukes_cantor(self): + ts = msprime.simulate(6, random_seed=1, mutation_rate=1) + ts = tsutil.jukes_cantor(ts, 20, 1, seed=10) + self.verify(ts.tables) + + def 
test_single_tree_multichar_mutations(self): + ts = msprime.simulate(6, random_seed=1, mutation_rate=1) + ts = tsutil.insert_multichar_mutations(ts) + self.verify(ts.tables) + + def test_many_trees_infinite_sites(self): + ts = msprime.simulate(6, recombination_rate=2, mutation_rate=2, random_seed=1) + self.assertGreater(ts.num_sites, 0) + self.assertGreater(ts.num_trees, 2) + self.verify(ts.tables) + + def test_many_trees_sequence_length_infinite_sites(self): + for L in [0.5, 1.5, 3.3333]: + ts = msprime.simulate( + 6, length=L, recombination_rate=2, mutation_rate=1, random_seed=1) + self.verify(ts.tables) + + def test_wright_fisher_unsimplified(self): + tables = wf.wf_sim( + 4, 5, seed=1, deep_history=True, initial_generation_samples=False, + num_loci=10) + tables.sort() + ts = msprime.mutate(tables.tree_sequence(), rate=0.05, random_seed=234) + self.assertGreater(ts.num_sites, 0) + self.verify(ts.tables) + + def test_wright_fisher_initial_generation(self): + tables = wf.wf_sim( + 6, 5, seed=3, deep_history=True, initial_generation_samples=True, + num_loci=2) + tables.sort() + tables.simplify() + ts = msprime.mutate(tables.tree_sequence(), rate=0.08, random_seed=2) + self.assertGreater(ts.num_sites, 0) + self.verify(ts.tables) + + def test_wright_fisher_initial_generation_no_deep_history(self): + tables = wf.wf_sim( + 7, 15, seed=202, deep_history=False, initial_generation_samples=True, + num_loci=5) + tables.sort() + tables.simplify() + ts = msprime.mutate(tables.tree_sequence(), rate=0.01, random_seed=2) + self.assertGreater(ts.num_sites, 0) + self.verify(ts.tables) + + def test_wright_fisher_unsimplified_multiple_roots(self): + tables = wf.wf_sim( + 8, 15, seed=1, deep_history=False, initial_generation_samples=False, + num_loci=20) + tables.sort() + ts = msprime.mutate(tables.tree_sequence(), rate=0.006, random_seed=2) + self.assertGreater(ts.num_sites, 0) + self.verify(ts.tables) + + def test_wright_fisher_simplified(self): + tables = wf.wf_sim( + 9, 10, seed=1, deep_history=True, initial_generation_samples=False, + num_loci=5) + tables.sort() + ts = tables.tree_sequence().simplify() + ts = msprime.mutate(ts, rate=0.01, random_seed=1234) + self.assertGreater(ts.num_sites, 0) + self.verify(ts.tables) + + +class TestKeepDeleteIntervalsExamples(unittest.TestCase): + """ + Simple examples of keep/delete intervals at work. 
+ """ + + def test_single_tree_keep_middle(self): + ts = msprime.simulate(10, random_seed=2) + tables = ts.tables + t_keep = tables.keep_intervals([[0.25, 0.5]]) + t_delete = tables.delete_intervals([[0, 0.25], [0.5, 1.0]]) + t_keep.provenances.clear() + t_delete.provenances.clear() + self.assertEqual(t_keep, t_delete) + + def test_single_tree_delete_middle(self): + ts = msprime.simulate(10, random_seed=2) + tables = ts.tables + t_keep = tables.delete_intervals([[0.25, 0.5]]) + t_delete = tables.keep_intervals([[0, 0.25], [0.5, 1.0]]) + t_keep.provenances.clear() + t_delete.provenances.clear() + self.assertEqual(t_keep, t_delete) diff --git a/python/tests/test_util.py b/python/tests/test_util.py index b591fe1bc8..a41a85f350 100644 --- a/python/tests/test_util.py +++ b/python/tests/test_util.py @@ -50,9 +50,8 @@ def test_basic_arrays(self): self.assertEqual(pickle.dumps(converted), pickle.dumps(target)) def test_copy(self): - """ - Check that a copy is not returned if copy=False & the original matches the specs - """ + # Check that a copy is not returned if copy=False & the original matches + # the specs for dtype in self.dtypes_to_test: for orig in (np.array([0, 1], dtype=dtype), np.array([], dtype=dtype)): converted = util.safe_np_int_cast(orig, dtype=dtype, copy=True) @@ -66,9 +65,7 @@ def test_copy(self): self.assertNotEqual(id(orig), id(converted)) def test_empty_arrays(self): - """ - Empty arrays of any type (including float) should be allowed - """ + # Empty arrays of any type (including float) should be allowed for dtype in self.dtypes_to_test: target = np.array([], dtype=dtype) converted = util.safe_np_int_cast([], dtype=dtype) @@ -78,9 +75,7 @@ def test_empty_arrays(self): self.assertEqual(pickle.dumps(converted), pickle.dumps(target)) def test_bad_types(self): - """ - Shouldn't be able to convert a float (possibility of rounding error) - """ + # Shouldn't be able to convert a float (possibility of rounding error) for dtype in self.dtypes_to_test: for bad_type in [[0.1], ['str'], {}, [{}], np.array([0, 1], dtype=np.float)]: self.assertRaises(TypeError, util.safe_np_int_cast, bad_type, dtype) @@ -100,3 +95,56 @@ def test_overflow(self): self.assertEqual( # Test numpy array pickle.dumps(target), pickle.dumps(util.safe_np_int_cast(np.array([good_node]), dtype))) + + +class TestIntervalOps(unittest.TestCase): + """ + Test cases for the interval operations used in masks and slicing operations. 
+ """ + def test_bad_intervals(self): + for bad_type in [{}, Exception]: + with self.assertRaises(TypeError): + util.intervals_to_np_array(bad_type, 0, 1) + for bad_depth in [[[[]]]]: + with self.assertRaises(ValueError): + util.intervals_to_np_array(bad_depth, 0, 1) + for bad_shape in [[[0], [0]], [[[0, 1, 2], [0, 1]]]]: + with self.assertRaises(ValueError): + util.intervals_to_np_array(bad_shape, 0, 1) + + # Out of bounds + with self.assertRaises(ValueError): + util.intervals_to_np_array([[-1, 0]], 0, 1) + with self.assertRaises(ValueError): + util.intervals_to_np_array([[0, 1]], 1, 2) + with self.assertRaises(ValueError): + util.intervals_to_np_array([[0, 1]], 0, 0.5) + + # Overlapping intervals + with self.assertRaises(ValueError): + util.intervals_to_np_array([[0, 1], [0.9, 2.0]], 0, 10) + + # Empty intervals + for bad_interval in [[0, 0], [1, 0]]: + with self.assertRaises(ValueError): + util.intervals_to_np_array([bad_interval], 0, 10) + + def test_empty_interval_list(self): + intervals = util.intervals_to_np_array([], 0, 10) + self.assertEqual(len(intervals), 0) + + def test_negate_intervals(self): + L = 10 + cases = [ + ([], [[0, L]]), + ([[0, 5], [6, L]], [[5, 6]]), + ([[0, 5]], [[5, L]]), + ([[5, L]], [[0, 5]]), + ( + [[0, 1], [2, 3], [3, 4], [5, 6]], + [[1, 2], [4, 5], [6, L]] + ), + ] + for source, dest in cases: + self.assertTrue(np.array_equal( + util.negate_intervals(source, 0, L), dest)) diff --git a/python/tskit/__init__.py b/python/tskit/__init__.py index f183d5d06c..5ce56c9643 100644 --- a/python/tskit/__init__.py +++ b/python/tskit/__init__.py @@ -34,3 +34,4 @@ from tskit.tables import * # NOQA from tskit.stats import * # NOQA from tskit.exceptions import * # NOQA +from tskit.util import * # NOQA diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 4abb3ceb96..a9f8715425 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -223,7 +223,7 @@ def __str__(self): flags = self.flags location = self.location location_offset = self.location_offset - metadata = unpack_bytes(self.metadata, self.metadata_offset) + metadata = util.unpack_bytes(self.metadata, self.metadata_offset) ret = "id\tflags\tlocation\tmetadata\n" for j in range(self.num_rows): md = base64.b64encode(metadata[j]).decode('utf8') @@ -413,7 +413,7 @@ def __str__(self): flags = self.flags population = self.population individual = self.individual - metadata = unpack_bytes(self.metadata, self.metadata_offset) + metadata = util.unpack_bytes(self.metadata, self.metadata_offset) ret = "id\tflags\tpopulation\tindividual\ttime\tmetadata\n" for j in range(self.num_rows): md = base64.b64encode(metadata[j]).decode('utf8') @@ -876,9 +876,9 @@ def metadata_offset(self): def __str__(self): position = self.position - ancestral_state = unpack_strings( + ancestral_state = util.unpack_strings( self.ancestral_state, self.ancestral_state_offset) - metadata = unpack_bytes(self.metadata, self.metadata_offset) + metadata = util.unpack_bytes(self.metadata, self.metadata_offset) ret = "id\tposition\tancestral_state\tmetadata\n" for j in range(self.num_rows): md = base64.b64encode(metadata[j]).decode('utf8') @@ -1072,8 +1072,9 @@ def __str__(self): site = self.site node = self.node parent = self.parent - derived_state = unpack_strings(self.derived_state, self.derived_state_offset) - metadata = unpack_bytes(self.metadata, self.metadata_offset) + derived_state = util.unpack_strings( + self.derived_state, self.derived_state_offset) + metadata = util.unpack_bytes(self.metadata, self.metadata_offset) ret = 
"id\tsite\tnode\tderived_state\tparent\tmetadata\n" for j in range(self.num_rows): md = base64.b64encode(metadata[j]).decode('utf8') @@ -1259,7 +1260,7 @@ def add_row(self, metadata=None): return self.ll_table.add_row(metadata=metadata) def __str__(self): - metadata = unpack_bytes(self.metadata, self.metadata_offset) + metadata = util.unpack_bytes(self.metadata, self.metadata_offset) ret = "id\tmetadata\n" for j in range(self.num_rows): md = base64.b64encode(metadata[j]).decode('utf8') @@ -1370,8 +1371,8 @@ def append_columns( record=record, record_offset=record_offset)) def __str__(self): - timestamp = unpack_strings(self.timestamp, self.timestamp_offset) - record = unpack_strings(self.record, self.record_offset) + timestamp = util.unpack_strings(self.timestamp, self.timestamp_offset) + record = util.unpack_strings(self.record, self.record_offset) ret = "id\ttimestamp\trecord\n" for j in range(self.num_rows): ret += "{}\t{}\t{}\n".format(j, timestamp[j], record[j]) @@ -1803,9 +1804,32 @@ def deduplicate_sites(self): # TODO add provenance def delete_intervals(self, intervals, simplify=True, record_provenance=True): - # TODO document + """ + Returns a copy of this set of tables for which information in the + specified list of genomic intervals has been deleted. Edges spanning + these intervals are truncated or deleted, and sites falling within them are + discarded. + + Note that node IDs may change as a result of this operation, + as by default :meth:`.simplify` is called on the resulting tables to + remove redundant nodes. If you wish to keep node IDs stable between + this set of tables and the returned tables, specify ``simplify=True``. + + See also :meth:`.keep_intervals`. + + :param array_like intervals: A list (start, end) pairs describing the + genomic intervals to delete. Intervals must be non-overlapping and + in increasing order. The list of intervals must be interpretable as a + 2D numpy array with shape (N, 2), where N is the number of intervals. + :param bool simplify: If True, run simplify on the tables so that nodes + no longer used are discarded. (Default: True). + :param bool record_provenance: If True, record details of this operation + in the returned table collection's provenance information. + (Default: True). + :rtype: tskit.TableCollection + """ return self.keep_intervals( - negate_intervals(intervals, 0, self.sequence_length), + util.negate_intervals(intervals, 0, self.sequence_length), simplify=simplify, record_provenance=record_provenance) def keep_intervals(self, intervals, simplify=True, record_provenance=True): @@ -1815,7 +1839,17 @@ def keep_intervals(self, intervals, simplify=True, record_provenance=True): these intervals, and sites not falling within these intervals are discarded. - :param array_like intervals: TODO description of genomic intervals. + Note that node IDs may change as a result of this operation, + as by default :meth:`.simplify` is called on the resulting tables to + remove redundant nodes. If you wish to keep node IDs stable between + this set of tables and the returned tables, specify ``simplify=True``. + + See also :meth:`.delete_intervals`. + + :param array_like intervals: A list (start, end) pairs describing the + genomic intervals to keep. Intervals must be non-overlapping and + in increasing order. The list of intervals must be interpretable as a + 2D numpy array with shape (N, 2), where N is the number of intervals. :param bool simplify: If True, run simplify on the tables so that nodes no longer used are discarded. (Default: True). 
:param bool record_provenance: If True, record details of this operation @@ -1825,13 +1859,16 @@ def keep_intervals(self, intervals, simplify=True, record_provenance=True): """ def keep_with_offset(keep, data, offset): - lens = np.diff(offset) + # We need the astype here for 32 bit machines + lens = np.diff(offset).astype(np.int32) return (data[np.repeat(keep, lens)], np.concatenate([ np.array([0], dtype=offset.dtype), np.cumsum(lens[keep], dtype=offset.dtype)])) - intervals = intervals_to_np_array(intervals, 0, self.sequence_length) + intervals = util.intervals_to_np_array(intervals, 0, self.sequence_length) + if len(self.migrations) > 0: + raise ValueError("Migrations not supported by keep_intervals") tables = self.copy() sites = self.sites @@ -1873,16 +1910,23 @@ def keep_with_offset(keep, data, offset): node=mutations.node[keep_mutations], derived_state=new_ds, derived_state_offset=new_ds_offset, - parent=mutations.parent[keep_mutations], + # TODO Compute the mutation parents properly here. We're being + # lazy right now and just asking compute_mutation_parents to do + # it for us, but we have to build_index to do this and also + # run compute_mutation_parents. + parent=np.zeros(np.sum(keep_mutations), dtype=np.int32) - 1, metadata=new_md, metadata_offset=new_md_offset) tables.sort() + # See note above on compute_mutation_parents; we don't need these + # two steps if we do it properly. + tables.build_index() + tables.compute_mutation_parents() if simplify: tables.simplify() if record_provenance: - # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243 parameters = { - "command": "keep_slices", + "command": "keep_intervals", "TODO": "add parameters" } tables.provenances.add_row(record=json.dumps( @@ -1908,132 +1952,3 @@ def drop_index(self): indexed this method has no effect. """ self.ll_tables.drop_index() - - -############################################# -# Table functions. -############################################# - -def pack_bytes(data): - """ - Packs the specified list of bytes into a flattened numpy array of 8 bit integers - and corresponding offsets. See :ref:`sec_encoding_ragged_columns` for details - of this encoding. - - :param list[bytes] data: The list of bytes values to encode. - :return: The tuple (packed, offset) of numpy arrays representing the flattened - input data and offsets. - :rtype: numpy.array (dtype=np.int8), numpy.array (dtype=np.uint32). - """ - n = len(data) - offsets = np.zeros(n + 1, dtype=np.uint32) - for j in range(n): - offsets[j + 1] = offsets[j] + len(data[j]) - column = np.zeros(offsets[-1], dtype=np.int8) - for j, value in enumerate(data): - column[offsets[j]: offsets[j + 1]] = bytearray(value) - return column, offsets - - -def unpack_bytes(packed, offset): - """ - Unpacks a list of bytes from the specified numpy arrays of packed byte - data and corresponding offsets. See :ref:`sec_encoding_ragged_columns` for details - of this encoding. - - :param numpy.ndarray packed: The flattened array of byte values. - :param numpy.ndarray offset: The array of offsets into the ``packed`` array. - :return: The list of bytes values unpacked from the parameter arrays. - :rtype: list[bytes] - """ - # This could be done a lot more efficiently... 
- ret = [] - for j in range(offset.shape[0] - 1): - raw = packed[offset[j]: offset[j + 1]].tobytes() - ret.append(raw) - return ret - - -def pack_strings(strings, encoding="utf8"): - """ - Packs the specified list of strings into a flattened numpy array of 8 bit integers - and corresponding offsets using the specified text encoding. - See :ref:`sec_encoding_ragged_columns` for details of this encoding of - columns of variable length data. - - :param list[str] data: The list of strings to encode. - :param str encoding: The text encoding to use when converting string data - to bytes. See the :mod:`codecs` module for information on available - string encodings. - :return: The tuple (packed, offset) of numpy arrays representing the flattened - input data and offsets. - :rtype: numpy.array (dtype=np.int8), numpy.array (dtype=np.uint32). - """ - return pack_bytes([bytearray(s.encode(encoding)) for s in strings]) - - -def unpack_strings(packed, offset, encoding="utf8"): - """ - Unpacks a list of strings from the specified numpy arrays of packed byte - data and corresponding offsets using the specified text encoding. - See :ref:`sec_encoding_ragged_columns` for details of this encoding of - columns of variable length data. - - :param numpy.ndarray packed: The flattened array of byte values. - :param numpy.ndarray offset: The array of offsets into the ``packed`` array. - :param str encoding: The text encoding to use when converting string data - to bytes. See the :mod:`codecs` module for information on available - string encodings. - :return: The list of strings unpacked from the parameter arrays. - :rtype: list[str] - """ - return [b.decode(encoding) for b in unpack_bytes(packed, offset)] - - -#################### -# Interval utilities -#################### - - -def intervals_to_np_array(intervals, start, end): - """ - Converts the specified intervals to a numpy array and checks for - errors. - """ - intervals = np.array(intervals, dtype=np.float64) - # Special case the empty list of intervals - if len(intervals) == 0: - intervals = np.zeros((0, 2), dtype=np.float64) - if len(intervals.shape) != 2: - raise ValueError("Intervals must be a 2D numpy array") - if intervals.shape[1] != 2: - raise ValueError("Intervals array shape must be (N, 2)") - # TODO do this with numpy operations. - last_right = start - for left, right in intervals: - if left < start or right > end: - raise ValueError( - "Intervals must be within {} and {}".format(start, end)) - if right <= left: - raise ValueError("Bad interval: right <= left") - if left < last_right: - raise ValueError("Intervals must be disjoint.") - last_right = right - return intervals - - -def negate_intervals(intervals, start, end): - """ - Returns the set of intervals *not* covered by the specified set of - disjoint intervals in the specfied range. 
- """ - intervals = intervals_to_np_array(intervals, start, end) - other_intervals = [] - last_right = start - for left, right in intervals: - if left != last_right: - other_intervals.append((last_right, left)) - last_right = right - if last_right != end: - other_intervals.append((last_right, end)) - return np.array(other_intervals) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 17edb60a6b..af64938901 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -3206,130 +3206,6 @@ def simplify( else: return new_ts - def slice( - self, start=None, stop=None, reset_coordinates=True, simplify=True, - record_provenance=True): - """ - Truncate this tree sequence to include only information in the genomic interval - between ``start`` and ``stop``. Edges are truncated to this interval, and sites - not covered by the sliced region are thrown away. - - :param float start: The leftmost genomic position, giving the start point - of the kept region. Tree sequence information along the genome prior to (but - not including) this point will be discarded. If None, set equal to zero. - :param float stop: The rightmost genomic position, giving the end point of the - kept region. Tree sequence information at this point and further along the - genomic sequence will be discarded. If None, is set equal to the current tree - sequence's ``sequence_length``. - :param bool reset_coordinates: Reset the genomic coordinates such that position - 0 in the returned tree sequence corresponds to position ``start`` in the - original one, and the returned tree sequence has sequence length - ``stop``-``start``. Sites and tree intervals will all have their positions - shifted to reflect the new coordinate system. If ``False``, do not rescale: - the resulting tree sequence will therefore cover the same genomic span as - the original, but will have tree and site information missing for genomic - regions outside the sliced region. (Default: True) - :param bool simplify: If True, simplify the resulting tree sequence so that nodes - no longer used in the resulting trees are discarded. (Default: True). - :param bool record_provenance: If True, record details of this call to - slice in the returned tree sequence's provenance information. - (Default: True). - :return: The sliced tree sequence. 
- :rtype: .TreeSequence - """ - def keep_with_offset(keep, data, offset): - # Need to case diff to int32 for 32bit builds - lens = np.diff(offset).astype(np.int32) - return (data[np.repeat(keep, lens)], - np.concatenate([ - np.array([0], dtype=offset.dtype), - np.cumsum(lens[keep], dtype=offset.dtype)])) - if start is None: - start = np.array([0]) - if np.ndim(start) == 0: - start = np.array([start]) - if stop is None: - stop = np.array([self.sequence_length]) - if np.ndim(stop) == 0: - stop = np.array([stop]) - if (np.any(start < 0) or np.any(stop <= start) or - np.any(stop > self.sequence_length)): - raise ValueError("Slice bounds must be within the existing tree sequence") - edges = self.tables.edges - sites = self.tables.sites - mutations = self.tables.mutations - tables = self.dump_tables() - tables.edges.clear() - tables.sites.clear() - keep_sites = np.repeat(False, sites.num_rows) - keep_mutations = np.repeat(False, mutations.num_rows) - tmp_offset = 0 - for s, e in zip(start, stop): - curr_keep_sites = np.logical_and(sites.position >= s, - sites.position < e) - keep_sites = np.logical_or(keep_sites, curr_keep_sites) - new_as, new_as_offset = keep_with_offset( - curr_keep_sites, sites.ancestral_state, sites.ancestral_state_offset) - new_md, new_md_offset = keep_with_offset( - curr_keep_sites, sites.metadata, sites.metadata_offset) - keep_mutations = np.logical_or(keep_mutations, - curr_keep_sites[mutations.site]) - keep_edges = np.logical_not( - np.logical_or(edges.right <= s, edges.left >= e)) - if reset_coordinates: - tables.edges.append_columns( - left=tmp_offset + np.fmax(s, edges.left[keep_edges]) - s, - right=tmp_offset + np.fmin(e, edges.right[keep_edges]) - s, - parent=edges.parent[keep_edges], - child=edges.child[keep_edges]) - tables.sites.append_columns( - position=tmp_offset + sites.position[curr_keep_sites] - s, - ancestral_state=new_as, - ancestral_state_offset=new_as_offset, - metadata=new_md, - metadata_offset=new_md_offset) - tmp_offset += e - s - else: - tables.edges.append_columns( - left=np.fmax(s, edges.left[keep_edges]), - right=np.fmin(e, edges.right[keep_edges]), - parent=edges.parent[keep_edges], - child=edges.child[keep_edges]) - tables.sites.append_columns( - position=sites.position[curr_keep_sites], - ancestral_state=new_as, - ancestral_state_offset=new_as_offset, - metadata=new_md, - metadata_offset=new_md_offset) - new_ds, new_ds_offset = keep_with_offset( - keep_mutations, mutations.derived_state, - mutations.derived_state_offset) - new_md, new_md_offset = keep_with_offset( - keep_mutations, mutations.metadata, mutations.metadata_offset) - if reset_coordinates: - tables.sequence_length = max(tmp_offset, np.max(tables.edges.right)) - site_map = np.cumsum(keep_sites, dtype=mutations.site.dtype) - 1 - tables.mutations.set_columns( - site=site_map[mutations.site[keep_mutations]], - node=mutations.node[keep_mutations], - derived_state=new_ds, - derived_state_offset=new_ds_offset, - parent=mutations.parent[keep_mutations], - metadata=new_md, - metadata_offset=new_md_offset) - tables.sort() - if simplify: - tables.simplify() - if record_provenance: - # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243 - parameters = { - "command": "slice", - "TODO": "add slice parameters" - } - tables.provenances.add_row(record=json.dumps( - provenance.get_provenance_dict(parameters))) - return tables.tree_sequence() - def draw_svg(self, path=None, **kwargs): # TODO document this method, including semantic details of the # returned SVG object. 
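The tests above pin down the semantics of the replacement API. As a quick illustration (a sketch only, not part of the diff: it assumes msprime is installed, and the seed and interval endpoints are arbitrary), keeping an interval and deleting its complement agree up to provenance, and coordinates are never rescaled:

import msprime

ts = msprime.simulate(10, recombination_rate=2, mutation_rate=2, random_seed=23)
tables = ts.tables

# Unlike the removed slice() method, keep_intervals never rescales
# coordinates: the result keeps the input's sequence length, and trees
# outside the kept interval [0.25, 0.5) are simply empty.
kept = tables.keep_intervals([[0.25, 0.5]])
assert kept.sequence_length == ts.sequence_length

# Deleting the complementary intervals yields identical tables, up to the
# provenance entry recorded by each call.
deleted = tables.delete_intervals([[0, 0.25], [0.5, 1.0]])
kept.provenances.clear()
deleted.provenances.clear()
assert kept == deleted

Callers who relied on slice(..., reset_coordinates=True) to obtain a shorter coordinate system now need to rescale positions themselves; in exchange, positions stay directly comparable between the input and output tables.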
diff --git a/python/tskit/util.py b/python/tskit/util.py index 7933199db4..6b5a0c80ac 100644 --- a/python/tskit/util.py +++ b/python/tskit/util.py @@ -57,3 +57,131 @@ def safe_np_int_cast(int_array, dtype, copy=False): # Raise a TypeError when we try to convert from, e.g., a float. casting = 'same_kind' return int_array.astype(dtype, casting=casting, copy=copy) + + + # + # Pack/unpack lists of data into flattened numpy arrays. + # + + def pack_bytes(data): + """ + Packs the specified list of bytes into a flattened numpy array of 8 bit integers + and corresponding offsets. See :ref:`sec_encoding_ragged_columns` for details + of this encoding. + + :param list[bytes] data: The list of bytes values to encode. + :return: The tuple (packed, offset) of numpy arrays representing the flattened + input data and offsets. + :rtype: numpy.array (dtype=np.int8), numpy.array (dtype=np.uint32). + """ + n = len(data) + offsets = np.zeros(n + 1, dtype=np.uint32) + for j in range(n): + offsets[j + 1] = offsets[j] + len(data[j]) + column = np.zeros(offsets[-1], dtype=np.int8) + for j, value in enumerate(data): + column[offsets[j]: offsets[j + 1]] = bytearray(value) + return column, offsets + + + def unpack_bytes(packed, offset): + """ + Unpacks a list of bytes from the specified numpy arrays of packed byte + data and corresponding offsets. See :ref:`sec_encoding_ragged_columns` for details + of this encoding. + + :param numpy.ndarray packed: The flattened array of byte values. + :param numpy.ndarray offset: The array of offsets into the ``packed`` array. + :return: The list of bytes values unpacked from the parameter arrays. + :rtype: list[bytes] + """ + # This could be done a lot more efficiently... + ret = [] + for j in range(offset.shape[0] - 1): + raw = packed[offset[j]: offset[j + 1]].tobytes() + ret.append(raw) + return ret + + + def pack_strings(strings, encoding="utf8"): + """ + Packs the specified list of strings into a flattened numpy array of 8 bit integers + and corresponding offsets using the specified text encoding. + See :ref:`sec_encoding_ragged_columns` for details of this encoding of + columns of variable length data. + + :param list[str] strings: The list of strings to encode. + :param str encoding: The text encoding to use when converting string data + to bytes. See the :mod:`codecs` module for information on available + string encodings. + :return: The tuple (packed, offset) of numpy arrays representing the flattened + input data and offsets. + :rtype: numpy.array (dtype=np.int8), numpy.array (dtype=np.uint32). + """ + return pack_bytes([bytearray(s.encode(encoding)) for s in strings]) + + + def unpack_strings(packed, offset, encoding="utf8"): + """ + Unpacks a list of strings from the specified numpy arrays of packed byte + data and corresponding offsets using the specified text encoding. + See :ref:`sec_encoding_ragged_columns` for details of this encoding of + columns of variable length data. + + :param numpy.ndarray packed: The flattened array of byte values. + :param numpy.ndarray offset: The array of offsets into the ``packed`` array. + :param str encoding: The text encoding to use when decoding the packed + byte data. See the :mod:`codecs` module for information on available + string encodings. + :return: The list of strings unpacked from the parameter arrays. 
+ :rtype: list[str] + """ + return [b.decode(encoding) for b in unpack_bytes(packed, offset)] + + + # + # Interval utilities + # + + def intervals_to_np_array(intervals, start, end): + """ + Converts the specified intervals to a numpy array and checks for + errors. + """ + intervals = np.array(intervals, dtype=np.float64) + # Special case the empty list of intervals + if len(intervals) == 0: + intervals = np.zeros((0, 2), dtype=np.float64) + if len(intervals.shape) != 2: + raise ValueError("Intervals must be a 2D numpy array") + if intervals.shape[1] != 2: + raise ValueError("Intervals array shape must be (N, 2)") + # TODO do this with numpy operations. + last_right = start + for left, right in intervals: + if left < start or right > end: + raise ValueError( + "Intervals must be within {} and {}".format(start, end)) + if right <= left: + raise ValueError("Bad interval: right <= left") + if left < last_right: + raise ValueError("Intervals must be disjoint.") + last_right = right + return intervals + + + def negate_intervals(intervals, start, end): + """ + Returns the set of intervals *not* covered by the specified set of + disjoint intervals in the specified range. + """ + intervals = intervals_to_np_array(intervals, start, end) + other_intervals = [] + last_right = start + for left, right in intervals: + if left != last_right: + other_intervals.append((last_right, left)) + last_right = right + if last_right != end: + other_intervals.append((last_right, end)) + return np.array(other_intervals)
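To make the contract of these helpers concrete, here is a small worked example (a sketch that mirrors the cases exercised in TestIntervalOps; it assumes the module is importable as tskit.util):

import numpy as np
from tskit import util

L = 10
# Input is validated: intervals must form an (N, 2) array, lie within
# [start, end], be non-empty, and be disjoint and increasing.
intervals = util.intervals_to_np_array([[0, 1], [2, 3], [3, 4], [5, 6]], 0, L)
assert intervals.shape == (4, 2) and intervals.dtype == np.float64

# Negation returns the uncovered gaps, including the tail up to L.
assert np.array_equal(
    util.negate_intervals(intervals, 0, L), [[1, 2], [4, 5], [6, L]])

# Overlapping intervals are rejected.
try:
    util.intervals_to_np_array([[0, 1], [0.9, 2.0]], 0, L)
except ValueError as err:
    assert "disjoint" in str(err)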