diff --git a/python/tests/test_genotypes.py b/python/tests/test_genotypes.py
index 4193646045..fd1fecd6c3 100644
--- a/python/tests/test_genotypes.py
+++ b/python/tests/test_genotypes.py
@@ -410,9 +410,8 @@ def test_jukes_cantor_n20(self):
 
     def test_zero_edge_missing_data(self):
         ts = msprime.simulate(10, random_seed=2, mutation_rate=2)
-        ts = ts.slice(0.25, 0.75, reset_coordinates=False)
+        tables = ts.tables.keep_intervals([[0.25, 0.75]])
         # add some sites in the deleted regions
-        tables = ts.dump_tables()
         tables.sites.add_row(0.1, "A")
         tables.sites.add_row(0.2, "A")
         tables.sites.add_row(0.8, "A")
diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py
index 2c77991bae..c379d35143 100644
--- a/python/tests/test_tables.py
+++ b/python/tests/test_tables.py
@@ -1600,7 +1600,22 @@ def test_asdict(self):
             "provenances": t.provenances.asdict()}
         d2 = t.asdict()
         self.assertEqual(set(d1.keys()), set(d2.keys()))
-        # TODO test the fromdict constructor
+
+    def test_from_dict(self):
+        ts = msprime.simulate(10, mutation_rate=1, random_seed=1)
+        t1 = ts.tables
+        d = {
+            "sequence_length": t1.sequence_length,
+            "individuals": t1.individuals.asdict(),
+            "populations": t1.populations.asdict(),
+            "nodes": t1.nodes.asdict(),
+            "edges": t1.edges.asdict(),
+            "sites": t1.sites.asdict(),
+            "mutations": t1.mutations.asdict(),
+            "migrations": t1.migrations.asdict(),
+            "provenances": t1.provenances.asdict()}
+        t2 = tskit.TableCollection.fromdict(d)
+        self.assertEqual(t1, t2)
 
     def test_iter(self):
         def test_iter(table_collection):
@@ -1621,6 +1636,21 @@ def test_equals_sequence_length(self):
             tskit.TableCollection(sequence_length=1),
             tskit.TableCollection(sequence_length=2))
 
+    def test_copy(self):
+        pop_configs = [msprime.PopulationConfiguration(5) for _ in range(2)]
+        migration_matrix = [[0, 1], [1, 0]]
+        t1 = msprime.simulate(
+            population_configurations=pop_configs,
+            migration_matrix=migration_matrix,
+            mutation_rate=1,
+            record_migrations=True,
+            random_seed=100).dump_tables()
+        t2 = t1.copy()
+        self.assertIsNot(t1, t2)
+        self.assertEqual(t1, t2)
+        t1.edges.clear()
+        self.assertNotEqual(t1, t2)
+
     def test_equals(self):
         pop_configs = [msprime.PopulationConfiguration(5) for _ in range(2)]
         migration_matrix = [[0, 1], [1, 0]]
@@ -1637,6 +1667,9 @@ def test_equals(self):
             record_migrations=True,
             random_seed=1).dump_tables()
         self.assertEqual(t1, t1)
+        self.assertEqual(t1, t1.copy())
+        self.assertEqual(t1.copy(), t1)
+
         # The provenances may or may not be equal depending on the clock
         # precision for record. So clear them first.
         t1.provenances.clear()
diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py
index a8ee0932ea..f5d6bfc0a2 100644
--- a/python/tests/test_topology.py
+++ b/python/tests/test_topology.py
@@ -40,61 +40,50 @@
 import tests.test_wright_fisher as wf
 
 
-def slice(
-        ts, start=None, stop=None, reset_coordinates=True, simplify=True,
-        record_provenance=True):
+def simple_keep_intervals(tables, intervals, simplify=True, record_provenance=True):
     """
-    A clearer but slower implementation of TreeSequence.slice() defined in trees.py
+    Simple Python implementation of keep_intervals.
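+
+    Takes a TableCollection and a list of disjoint (start, stop) intervals,
+    given in increasing order, and returns a new TableCollection in the same
+    coordinate system, in which edges have been truncated to the kept
+    intervals and sites (and their mutations) outside the kept intervals
+    have been discarded. Mutation parents are recomputed, so the returned
+    tables are indexed as a side effect.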
""" - if start is None: - start = 0 - if stop is None: - stop = ts.sequence_length - - if start < 0 or stop <= start or stop > ts.sequence_length: - raise ValueError("Slice bounds must be within the existing tree sequence") - tables = ts.dump_tables() + ts = tables.tree_sequence() + last_stop = 0 + for start, stop in intervals: + if start < 0 or stop > ts.sequence_length: + raise ValueError("Slice bounds must be within the existing tree sequence") + if start >= stop: + raise ValueError("Interval error: start must be < stop") + if start < last_stop: + raise ValueError("Intervals must be disjoint") + last_stop = stop + tables = tables.copy() tables.edges.clear() tables.sites.clear() tables.mutations.clear() for edge in ts.edges(): - if edge.right <= start or edge.left >= stop: - # This edge is outside the sliced area - do not include it - continue - if reset_coordinates: - tables.edges.add_row( - max(start, edge.left) - start, min(stop, edge.right) - start, - edge.parent, edge.child) - else: - tables.edges.add_row( - max(start, edge.left), min(stop, edge.right), - edge.parent, edge.child) + for interval_left, interval_right in intervals: + if not (edge.right <= interval_left or edge.left >= interval_right): + left = max(interval_left, edge.left) + right = min(interval_right, edge.right) + tables.edges.add_row(left, right, edge.parent, edge.child) for site in ts.sites(): - if start <= site.position < stop: - if reset_coordinates: - site_id = tables.sites.add_row( - site.position - start, site.ancestral_state, site.metadata) - else: + for interval_left, interval_right in intervals: + if interval_left <= site.position < interval_right: site_id = tables.sites.add_row( site.position, site.ancestral_state, site.metadata) - for m in site.mutations: - tables.mutations.add_row( - site_id, m.node, m.derived_state, m.parent, m.metadata) - if reset_coordinates: - tables.sequence_length = stop - start + for m in site.mutations: + tables.mutations.add_row( + site_id, m.node, m.derived_state, tskit.NULL, m.metadata) + tables.build_index() + tables.compute_mutation_parents() if simplify: tables.simplify() if record_provenance: - # TODO add slice arguments here - # TODO also make sure we convert all the arguments so that they are - # definitely JSON encodable. parameters = { - "command": "slice", - "TODO": "add slice parameters" + "command": "keep_intervals", + "TODO": "add parameters" } tables.provenances.add_row(record=json.dumps( provenance.get_provenance_dict(parameters))) - return tables.tree_sequence() + return tables def generate_segments(n, sequence_length=100, seed=None): @@ -4370,89 +4359,66 @@ def test_edge_cases(self): self.assertEqual(search_sorted([1], v), np.searchsorted([1], v)) -class TestSlice(TopologyTestCase): +class TestKeepSingleInterval(unittest.TestCase): """ Tests for cutting up tree sequences along the genome. 
""" - def test_numpy_vs_basic_slice(self): - ts = msprime.simulate( - 10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2) - for a, b in zip(np.random.uniform(0, 1, 10), np.random.uniform(0, 1, 10)): - if a != b: - for reset_coords in (True, False): - for simplify in (True, False): - for rec_prov in (True, False): - start = min(a, b) - stop = max(a, b) - x = slice(ts, start, stop, reset_coords, simplify, rec_prov) - y = ts.slice(start, stop, reset_coords, simplify, rec_prov) - t1 = x.dump_tables() - t2 = y.dump_tables() - # Provenances may differ using timestamps, so ignore them - # (this is a hack, as we prob want to compare their contents) - t1.provenances.clear() - t2.provenances.clear() - self.assertEqual(t1, t2) - def test_slice_by_tree_positions(self): ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2) breakpoints = list(ts.breakpoints()) # Keep the last 3 trees (from 4th last breakpoint onwards) - ts_sliced = ts.slice(start=breakpoints[-4]) - self.assertEqual(ts_sliced.num_trees, 3) + ts_sliced = ts.tables.keep_intervals( + [[breakpoints[-4], ts.sequence_length]]).tree_sequence() + self.assertEqual(ts_sliced.num_trees, 4) self.assertLess(ts_sliced.num_edges, ts.num_edges) - self.assertAlmostEqual(ts_sliced.sequence_length, 1.0 - breakpoints[-4]) + self.assertAlmostEqual(ts_sliced.sequence_length, 1.0) last_3_mutations = 0 for tree_index in range(-3, 0): last_3_mutations += ts.at_index(tree_index).num_mutations self.assertEqual(ts_sliced.num_mutations, last_3_mutations) # Keep the first 3 trees - ts_sliced = ts.slice(stop=breakpoints[3]) - self.assertEqual(ts_sliced.num_trees, 3) + ts_sliced = ts.tables.keep_intervals( + [[0, breakpoints[3]]]).tree_sequence() + self.assertEqual(ts_sliced.num_trees, 4) self.assertLess(ts_sliced.num_edges, ts.num_edges) - self.assertAlmostEqual(ts_sliced.sequence_length, breakpoints[3]) + self.assertAlmostEqual(ts_sliced.sequence_length, 1) first_3_mutations = 0 for tree_index in range(0, 3): first_3_mutations += ts.at_index(tree_index).num_mutations self.assertEqual(ts_sliced.num_mutations, first_3_mutations) # Slice out the middle - ts_sliced = ts.slice(breakpoints[3], breakpoints[-4]) - self.assertEqual(ts_sliced.num_trees, ts.num_trees - 6) + ts_sliced = ts.tables.keep_intervals( + [[breakpoints[3], breakpoints[-4]]]).tree_sequence() + self.assertEqual(ts_sliced.num_trees, ts.num_trees - 4) self.assertLess(ts_sliced.num_edges, ts.num_edges) - self.assertAlmostEqual( - ts_sliced.sequence_length, breakpoints[-4] - breakpoints[3]) + self.assertAlmostEqual(ts_sliced.sequence_length, 1.0) self.assertEqual( ts_sliced.num_mutations, ts.num_mutations - first_3_mutations - last_3_mutations) def test_slice_by_position(self): ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2) - ts_sliced = ts.slice(0.4, 0.6) + ts_sliced = ts.tables.keep_intervals([[0.4, 0.6]]).tree_sequence() positions = ts.tables.sites.position self.assertEqual( ts_sliced.num_sites, np.sum((positions >= 0.4) & (positions < 0.6))) - def test_slice_bounds(self): - ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2) - self.assertRaises(ValueError, ts.slice, -1) - self.assertRaises(ValueError, ts.slice, stop=2) - self.assertRaises(ValueError, ts.slice, 0.8, 0.2) - def test_slice_unsimplified(self): ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2) - ts_sliced = ts.slice(0.4, 0.6, simplify=True) + ts_sliced = ts.tables.keep_intervals([[0.4, 0.6]], 
+            [[0.4, 0.6]], simplify=True).tree_sequence()
         self.assertNotEqual(ts.num_nodes, ts_sliced.num_nodes)
-        self.assertAlmostEqual(ts_sliced.sequence_length, 0.2)
-        ts_sliced = ts.slice(0.4, 0.6, simplify=False)
+        self.assertAlmostEqual(ts_sliced.sequence_length, 1.0)
+        ts_sliced = ts.tables.keep_intervals(
+            [[0.4, 0.6]], simplify=False).tree_sequence()
         self.assertEqual(ts.num_nodes, ts_sliced.num_nodes)
-        self.assertAlmostEqual(ts_sliced.sequence_length, 0.2)
+        self.assertAlmostEqual(ts_sliced.sequence_length, 1.0)
 
-    def test_slice_keep_coordinates(self):
+    def test_slice_coordinates(self):
         ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2)
-        ts_sliced = ts.slice(0.4, 0.6, reset_coordinates=False)
+        ts_sliced = ts.tables.keep_intervals([[0.4, 0.6]]).tree_sequence()
         self.assertAlmostEqual(ts_sliced.sequence_length, 1)
         self.assertNotEqual(ts_sliced.num_trees, ts.num_trees)
         self.assertEqual(ts_sliced.at_index(0).total_branch_length, 0)
@@ -4464,3 +4430,221 @@ def test_slice_keep_coordinates(self):
         self.assertEqual(ts_sliced.at(0.6).total_branch_length, 0)
         self.assertEqual(ts_sliced.at(0.999).total_branch_length, 0)
         self.assertEqual(ts_sliced.at_index(-1).total_branch_length, 0)
+
+
+class TestKeepIntervals(TopologyTestCase):
+    """
+    Tests for the keep_intervals operation, where we retain multiple
+    disjoint intervals of the genome simultaneously.
+    """
+    def example_intervals(self, tables):
+        L = tables.sequence_length
+        yield []
+        yield [(0, L)]
+        yield [(0, L / 2), (L / 2, L)]
+        yield [(0, 0.25 * L), (0.75 * L, L)]
+        yield [(0.25 * L, L)]
+        yield [(0.25 * L, 0.5 * L)]
+        yield [(0.25 * L, 0.5 * L), (0.75 * L, 0.8 * L)]
+
+    def do_keep_intervals(
+            self, tables, intervals, simplify=True, record_provenance=True):
+        t1 = simple_keep_intervals(tables, intervals, simplify, record_provenance)
+        t2 = tables.keep_intervals(intervals, simplify, record_provenance)
+        t3 = t2.copy()
+        self.assertEqual(len(t1.provenances), len(t2.provenances))
+        # Provenances may differ in their timestamps, so ignore them
+        # (this is a hack, as we probably want to compare their contents)
+        t1.provenances.clear()
+        t2.provenances.clear()
+        self.assertEqual(t1, t2)
+        return t3
+
+    def test_migration_error(self):
+        tables = tskit.TableCollection(1)
+        tables.migrations.add_row(0, 1, 0, 0, 0, 0)
+        with self.assertRaises(ValueError):
+            tables.keep_intervals([[0, 1]])
+
+    def test_bad_intervals(self):
+        tables = tskit.TableCollection(10)
+        bad_intervals = [
+            [[1, 1]],
+            [[-1, 0]],
+            [[0, 11]],
+            [[0, 5], [4, 6]]
+        ]
+        for intervals in bad_intervals:
+            with self.assertRaises(ValueError):
+                tables.keep_intervals(intervals)
+            with self.assertRaises(ValueError):
+                tables.delete_intervals(intervals)
+
+    def test_one_interval(self):
+        ts = msprime.simulate(
+            10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2)
+        tables = ts.tables
+        intervals = [(0.3, 0.7)]
+        for simplify in (True, False):
+            for rec_prov in (True, False):
+                self.do_keep_intervals(tables, intervals, simplify, rec_prov)
+
+    def test_two_intervals(self):
+        ts = msprime.simulate(
+            10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2)
+        tables = ts.tables
+        intervals = [(0.1, 0.2), (0.8, 0.9)]
+        for simplify in (True, False):
+            for rec_prov in (True, False):
+                self.do_keep_intervals(tables, intervals, simplify, rec_prov)
+
+    def test_ten_intervals(self):
+        ts = msprime.simulate(
+            10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2)
+        tables = ts.tables
+        intervals = [(x, x + 0.05) for x in np.arange(0.0, 1.0, 0.1)]
+        for simplify in (True, False):
+            for rec_prov in (True, False):
+                self.do_keep_intervals(tables, intervals, simplify, rec_prov)
+
+    def test_hundred_intervals(self):
+        ts = msprime.simulate(
+            10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2)
+        tables = ts.tables
+        intervals = [(x, x + 0.005) for x in np.arange(0.0, 1.0, 0.01)]
+        for simplify in (True, False):
+            for rec_prov in (True, False):
+                self.do_keep_intervals(tables, intervals, simplify, rec_prov)
+
+    def test_read_only(self):
+        # tables.keep_intervals should not alter the source tables
+        ts = msprime.simulate(10, random_seed=4, recombination_rate=2, mutation_rate=2)
+        source_tables = ts.tables
+        source_tables.keep_intervals([(0.5, 0.511)])
+        self.assertEqual(source_tables, ts.dump_tables())
+        source_tables.keep_intervals([(0.5, 0.511)], simplify=False)
+        self.assertEqual(source_tables, ts.dump_tables())
+
+    def test_regular_intervals(self):
+        ts = msprime.simulate(
+            3, random_seed=1234, recombination_rate=2, mutation_rate=2)
+        tables = ts.tables
+        eps = 0.0125
+        for num_intervals in range(2, 10):
+            breaks = np.linspace(0, ts.sequence_length, num=num_intervals)
+            intervals = [(x, x + eps) for x in breaks[:-1]]
+            self.do_keep_intervals(tables, intervals)
+
+    def test_no_edges_sites(self):
+        tables = tskit.TableCollection(1.0)
+        tables.sites.add_row(0.1, "A")
+        tables.sites.add_row(0.2, "T")
+        for intervals in self.example_intervals(tables):
+            self.assertEqual(len(tables.sites), 2)
+            diced = self.do_keep_intervals(tables, intervals)
+            self.assertEqual(diced.sequence_length, 1)
+            self.assertEqual(len(diced.edges), 0)
+            self.assertEqual(len(diced.sites), 0)
+
+    def verify(self, tables):
+        for intervals in self.example_intervals(tables):
+            for simplify in [True, False]:
+                self.do_keep_intervals(tables, intervals, simplify=simplify)
+
+    def test_empty_tables(self):
+        tables = tskit.TableCollection(1.0)
+        self.verify(tables)
+
+    def test_single_tree_jukes_cantor(self):
+        ts = msprime.simulate(6, random_seed=1, mutation_rate=1)
+        ts = tsutil.jukes_cantor(ts, 20, 1, seed=10)
+        self.verify(ts.tables)
+
+    def test_single_tree_multichar_mutations(self):
+        ts = msprime.simulate(6, random_seed=1, mutation_rate=1)
+        ts = tsutil.insert_multichar_mutations(ts)
+        self.verify(ts.tables)
+
+    def test_many_trees_infinite_sites(self):
+        ts = msprime.simulate(6, recombination_rate=2, mutation_rate=2, random_seed=1)
+        self.assertGreater(ts.num_sites, 0)
+        self.assertGreater(ts.num_trees, 2)
+        self.verify(ts.tables)
+
+    def test_many_trees_sequence_length_infinite_sites(self):
+        for L in [0.5, 1.5, 3.3333]:
+            ts = msprime.simulate(
+                6, length=L, recombination_rate=2, mutation_rate=1, random_seed=1)
+            self.verify(ts.tables)
+
+    def test_wright_fisher_unsimplified(self):
+        tables = wf.wf_sim(
+            4, 5, seed=1, deep_history=True, initial_generation_samples=False,
+            num_loci=10)
+        tables.sort()
+        ts = msprime.mutate(tables.tree_sequence(), rate=0.05, random_seed=234)
+        self.assertGreater(ts.num_sites, 0)
+        self.verify(ts.tables)
+
+    def test_wright_fisher_initial_generation(self):
+        tables = wf.wf_sim(
+            6, 5, seed=3, deep_history=True, initial_generation_samples=True,
+            num_loci=2)
+        tables.sort()
+        tables.simplify()
+        ts = msprime.mutate(tables.tree_sequence(), rate=0.08, random_seed=2)
+        self.assertGreater(ts.num_sites, 0)
+        self.verify(ts.tables)
+
+    def test_wright_fisher_initial_generation_no_deep_history(self):
+        tables = wf.wf_sim(
+            7, 15, seed=202, deep_history=False, initial_generation_samples=True,
+            num_loci=5)
+        tables.sort()
+        tables.simplify()
+        ts = msprime.mutate(tables.tree_sequence(), rate=0.01, random_seed=2)
+        self.assertGreater(ts.num_sites, 0)
+        self.verify(ts.tables)
+
+    def test_wright_fisher_unsimplified_multiple_roots(self):
+        tables = wf.wf_sim(
+            8, 15, seed=1, deep_history=False, initial_generation_samples=False,
+            num_loci=20)
+        tables.sort()
+        ts = msprime.mutate(tables.tree_sequence(), rate=0.006, random_seed=2)
+        self.assertGreater(ts.num_sites, 0)
+        self.verify(ts.tables)
+
+    def test_wright_fisher_simplified(self):
+        tables = wf.wf_sim(
+            9, 10, seed=1, deep_history=True, initial_generation_samples=False,
+            num_loci=5)
+        tables.sort()
+        ts = tables.tree_sequence().simplify()
+        ts = msprime.mutate(ts, rate=0.01, random_seed=1234)
+        self.assertGreater(ts.num_sites, 0)
+        self.verify(ts.tables)
+
+
+class TestKeepDeleteIntervalsExamples(unittest.TestCase):
+    """
+    Simple examples of keep/delete intervals at work.
+    """
+
+    def test_single_tree_keep_middle(self):
+        ts = msprime.simulate(10, random_seed=2)
+        tables = ts.tables
+        t_keep = tables.keep_intervals([[0.25, 0.5]])
+        t_delete = tables.delete_intervals([[0, 0.25], [0.5, 1.0]])
+        t_keep.provenances.clear()
+        t_delete.provenances.clear()
+        self.assertEqual(t_keep, t_delete)
+
+    def test_single_tree_delete_middle(self):
+        ts = msprime.simulate(10, random_seed=2)
+        tables = ts.tables
+        t_delete = tables.delete_intervals([[0.25, 0.5]])
+        t_keep = tables.keep_intervals([[0, 0.25], [0.5, 1.0]])
+        t_keep.provenances.clear()
+        t_delete.provenances.clear()
+        self.assertEqual(t_keep, t_delete)
diff --git a/python/tests/test_util.py b/python/tests/test_util.py
index b591fe1bc8..a41a85f350 100644
--- a/python/tests/test_util.py
+++ b/python/tests/test_util.py
@@ -50,9 +50,8 @@ def test_basic_arrays(self):
         self.assertEqual(pickle.dumps(converted), pickle.dumps(target))
 
     def test_copy(self):
-        """
-        Check that a copy is not returned if copy=False & the original matches the specs
-        """
+        # Check that a copy is not returned if copy=False & the original matches
+        # the specs
         for dtype in self.dtypes_to_test:
             for orig in (np.array([0, 1], dtype=dtype), np.array([], dtype=dtype)):
                 converted = util.safe_np_int_cast(orig, dtype=dtype, copy=True)
@@ -66,9 +65,7 @@ def test_copy(self):
             self.assertNotEqual(id(orig), id(converted))
 
     def test_empty_arrays(self):
-        """
-        Empty arrays of any type (including float) should be allowed
-        """
+        # Empty arrays of any type (including float) should be allowed
        for dtype in self.dtypes_to_test:
            target = np.array([], dtype=dtype)
            converted = util.safe_np_int_cast([], dtype=dtype)
@@ -78,9 +75,7 @@ def test_empty_arrays(self):
         self.assertEqual(pickle.dumps(converted), pickle.dumps(target))
 
     def test_bad_types(self):
-        """
-        Shouldn't be able to convert a float (possibility of rounding error)
-        """
+        # Shouldn't be able to convert a float (possibility of rounding error)
         for dtype in self.dtypes_to_test:
             for bad_type in [[0.1], ['str'], {}, [{}], np.array([0, 1], dtype=np.float)]:
                 self.assertRaises(TypeError, util.safe_np_int_cast, bad_type, dtype)
@@ -100,3 +95,56 @@ def test_overflow(self):
         self.assertEqual(  # Test numpy array
             pickle.dumps(target),
             pickle.dumps(util.safe_np_int_cast(np.array([good_node]), dtype)))
+
+
+class TestIntervalOps(unittest.TestCase):
+    """
+    Test cases for the interval operations used in masks and slicing operations.
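+    Specifically: input validation in ``intervals_to_np_array`` and interval
+    complementation in ``negate_intervals``.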
+    """
+    def test_bad_intervals(self):
+        for bad_type in [{}, Exception]:
+            with self.assertRaises(TypeError):
+                util.intervals_to_np_array(bad_type, 0, 1)
+        for bad_depth in [[[[]]]]:
+            with self.assertRaises(ValueError):
+                util.intervals_to_np_array(bad_depth, 0, 1)
+        for bad_shape in [[[0], [0]], [[[0, 1, 2], [0, 1]]]]:
+            with self.assertRaises(ValueError):
+                util.intervals_to_np_array(bad_shape, 0, 1)
+
+        # Out of bounds
+        with self.assertRaises(ValueError):
+            util.intervals_to_np_array([[-1, 0]], 0, 1)
+        with self.assertRaises(ValueError):
+            util.intervals_to_np_array([[0, 1]], 1, 2)
+        with self.assertRaises(ValueError):
+            util.intervals_to_np_array([[0, 1]], 0, 0.5)
+
+        # Overlapping intervals
+        with self.assertRaises(ValueError):
+            util.intervals_to_np_array([[0, 1], [0.9, 2.0]], 0, 10)
+
+        # Empty intervals
+        for bad_interval in [[0, 0], [1, 0]]:
+            with self.assertRaises(ValueError):
+                util.intervals_to_np_array([bad_interval], 0, 10)
+
+    def test_empty_interval_list(self):
+        intervals = util.intervals_to_np_array([], 0, 10)
+        self.assertEqual(len(intervals), 0)
+
+    def test_negate_intervals(self):
+        L = 10
+        cases = [
+            ([], [[0, L]]),
+            ([[0, 5], [6, L]], [[5, 6]]),
+            ([[0, 5]], [[5, L]]),
+            ([[5, L]], [[0, 5]]),
+            (
+                [[0, 1], [2, 3], [3, 4], [5, 6]],
+                [[1, 2], [4, 5], [6, L]]
+            ),
+        ]
+        for source, dest in cases:
+            self.assertTrue(np.array_equal(
+                util.negate_intervals(source, 0, L), dest))
diff --git a/python/tskit/__init__.py b/python/tskit/__init__.py
index f183d5d06c..5ce56c9643 100644
--- a/python/tskit/__init__.py
+++ b/python/tskit/__init__.py
@@ -34,3 +34,4 @@
 from tskit.tables import *  # NOQA
 from tskit.stats import *  # NOQA
 from tskit.exceptions import *  # NOQA
+from tskit.util import *  # NOQA
diff --git a/python/tskit/tables.py b/python/tskit/tables.py
index c682829430..a9f8715425 100644
--- a/python/tskit/tables.py
+++ b/python/tskit/tables.py
@@ -26,6 +26,7 @@
 import base64
 import collections
 import datetime
+import json
 import warnings
 
 import numpy as np
@@ -37,6 +38,7 @@
 # can't do this in Py3.
 import tskit
 import tskit.util as util
+import tskit.provenance as provenance
 
 
 IndividualTableRow = collections.namedtuple(
@@ -221,7 +223,7 @@ def __str__(self):
         flags = self.flags
         location = self.location
         location_offset = self.location_offset
-        metadata = unpack_bytes(self.metadata, self.metadata_offset)
+        metadata = util.unpack_bytes(self.metadata, self.metadata_offset)
         ret = "id\tflags\tlocation\tmetadata\n"
         for j in range(self.num_rows):
             md = base64.b64encode(metadata[j]).decode('utf8')
@@ -411,7 +413,7 @@ def __str__(self):
         flags = self.flags
         population = self.population
         individual = self.individual
-        metadata = unpack_bytes(self.metadata, self.metadata_offset)
+        metadata = util.unpack_bytes(self.metadata, self.metadata_offset)
         ret = "id\tflags\tpopulation\tindividual\ttime\tmetadata\n"
         for j in range(self.num_rows):
             md = base64.b64encode(metadata[j]).decode('utf8')
@@ -874,9 +876,9 @@ def metadata_offset(self):
 
     def __str__(self):
         position = self.position
-        ancestral_state = unpack_strings(
+        ancestral_state = util.unpack_strings(
             self.ancestral_state, self.ancestral_state_offset)
-        metadata = unpack_bytes(self.metadata, self.metadata_offset)
+        metadata = util.unpack_bytes(self.metadata, self.metadata_offset)
         ret = "id\tposition\tancestral_state\tmetadata\n"
         for j in range(self.num_rows):
             md = base64.b64encode(metadata[j]).decode('utf8')
@@ -1070,8 +1072,9 @@ def __str__(self):
         site = self.site
         node = self.node
         parent = self.parent
-        derived_state = unpack_strings(self.derived_state, self.derived_state_offset)
-        metadata = unpack_bytes(self.metadata, self.metadata_offset)
+        derived_state = util.unpack_strings(
+            self.derived_state, self.derived_state_offset)
+        metadata = util.unpack_bytes(self.metadata, self.metadata_offset)
         ret = "id\tsite\tnode\tderived_state\tparent\tmetadata\n"
         for j in range(self.num_rows):
             md = base64.b64encode(metadata[j]).decode('utf8')
@@ -1257,7 +1260,7 @@ def add_row(self, metadata=None):
         return self.ll_table.add_row(metadata=metadata)
 
     def __str__(self):
-        metadata = unpack_bytes(self.metadata, self.metadata_offset)
+        metadata = util.unpack_bytes(self.metadata, self.metadata_offset)
         ret = "id\tmetadata\n"
         for j in range(self.num_rows):
             md = base64.b64encode(metadata[j]).decode('utf8')
@@ -1368,8 +1371,8 @@ def append_columns(
             record=record, record_offset=record_offset))
 
     def __str__(self):
-        timestamp = unpack_strings(self.timestamp, self.timestamp_offset)
-        record = unpack_strings(self.record, self.record_offset)
+        timestamp = util.unpack_strings(self.timestamp, self.timestamp_offset)
+        record = util.unpack_strings(self.record, self.record_offset)
         ret = "id\ttimestamp\trecord\n"
         for j in range(self.num_rows):
             ret += "{}\t{}\t{}\n".format(j, timestamp[j], record[j])
@@ -1576,6 +1579,15 @@ def fromdict(self, tables_dict):
         tables.provenances.set_columns(**tables_dict["provenances"])
         return tables
 
+    def copy(self):
+        """
+        Returns a deep copy of this TableCollection.
+
+        :return: A deep copy of this TableCollection.
+        :rtype: .TableCollection
+        """
+        return TableCollection.fromdict(self.asdict())
+
     def tree_sequence(self):
         """
         Returns a :class:`TreeSequence` instance with the structure defined by the
@@ -1791,6 +1803,136 @@ def deduplicate_sites(self):
         self.ll_tables.deduplicate_sites()
         # TODO add provenance
 
+    def delete_intervals(self, intervals, simplify=True, record_provenance=True):
+        """
+        Returns a copy of this set of tables for which information in the
+        specified list of genomic intervals has been deleted. Edges spanning
+        these intervals are truncated or deleted, and sites falling within them are
+        discarded.
+
+        Note that node IDs may change as a result of this operation, as by
+        default :meth:`.simplify` is called on the resulting tables to remove
+        redundant nodes. If you wish to keep node IDs stable between this set
+        of tables and the returned tables, specify ``simplify=False``.
+
+        See also :meth:`.keep_intervals`.
+
+        :param array_like intervals: A list of (start, end) pairs describing the
+            genomic intervals to delete. Intervals must be non-overlapping and
+            in increasing order. The list of intervals must be interpretable as a
+            2D numpy array with shape (N, 2), where N is the number of intervals.
+        :param bool simplify: If True, run simplify on the tables so that nodes
+            no longer used are discarded. (Default: True).
+        :param bool record_provenance: If True, record details of this operation
+            in the returned table collection's provenance information.
+            (Default: True).
+        :rtype: tskit.TableCollection
+        """
+        return self.keep_intervals(
+            util.negate_intervals(intervals, 0, self.sequence_length),
+            simplify=simplify, record_provenance=record_provenance)
+
+    def keep_intervals(self, intervals, simplify=True, record_provenance=True):
+        """
+        Returns a copy of this set of tables which includes only information in
+        the specified list of genomic intervals. Edges are truncated to within
+        these intervals, and sites not falling within these intervals are
+        discarded.
+
+        Note that node IDs may change as a result of this operation, as by
+        default :meth:`.simplify` is called on the resulting tables to remove
+        redundant nodes. If you wish to keep node IDs stable between this set
+        of tables and the returned tables, specify ``simplify=False``.
+
+        See also :meth:`.delete_intervals`.
+
+        :param array_like intervals: A list of (start, end) pairs describing the
+            genomic intervals to keep. Intervals must be non-overlapping and
+            in increasing order. The list of intervals must be interpretable as a
+            2D numpy array with shape (N, 2), where N is the number of intervals.
+        :param bool simplify: If True, run simplify on the tables so that nodes
+            no longer used are discarded. (Default: True).
+        :param bool record_provenance: If True, record details of this operation
+            in the returned table collection's provenance information.
+            (Default: True).
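+
+        For example, given tables covering a sequence of length 1, the central
+        half of the genome can be kept, and a tree sequence built from the
+        result, with::
+
+            tables = ts.dump_tables()
+            middle = tables.keep_intervals([[0.25, 0.75]])
+            ts_middle = middle.tree_sequence()
+
+        :return: A copy of these tables restricted to the given intervals.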
+        :rtype: tskit.TableCollection
+        """
+
+        def keep_with_offset(keep, data, offset):
+            # We need the astype here for 32 bit machines
+            lens = np.diff(offset).astype(np.int32)
+            return (data[np.repeat(keep, lens)],
+                    np.concatenate([
+                        np.array([0], dtype=offset.dtype),
+                        np.cumsum(lens[keep], dtype=offset.dtype)]))
+
+        intervals = util.intervals_to_np_array(intervals, 0, self.sequence_length)
+        if len(self.migrations) > 0:
+            raise ValueError("Migrations not supported by keep_intervals")
+
+        tables = self.copy()
+        sites = self.sites
+        edges = self.edges
+        mutations = self.mutations
+        tables.edges.clear()
+        tables.sites.clear()
+        tables.mutations.clear()
+        keep_sites = np.repeat(False, sites.num_rows)
+        keep_mutations = np.repeat(False, mutations.num_rows)
+        for s, e in intervals:
+            curr_keep_sites = np.logical_and(sites.position >= s, sites.position < e)
+            keep_sites = np.logical_or(keep_sites, curr_keep_sites)
+            new_as, new_as_offset = keep_with_offset(
+                curr_keep_sites, sites.ancestral_state, sites.ancestral_state_offset)
+            new_md, new_md_offset = keep_with_offset(
+                curr_keep_sites, sites.metadata, sites.metadata_offset)
+            keep_mutations = np.logical_or(
+                keep_mutations, curr_keep_sites[mutations.site])
+            # Keep the edges that overlap this interval, truncating them at
+            # the interval boundaries.
+            keep_edges = np.logical_not(np.logical_or(edges.right <= s, edges.left >= e))
+            tables.edges.append_columns(
+                left=np.fmax(s, edges.left[keep_edges]),
+                right=np.fmin(e, edges.right[keep_edges]),
+                parent=edges.parent[keep_edges],
+                child=edges.child[keep_edges])
+            tables.sites.append_columns(
+                position=sites.position[curr_keep_sites],
+                ancestral_state=new_as,
+                ancestral_state_offset=new_as_offset,
+                metadata=new_md,
+                metadata_offset=new_md_offset)
+        new_ds, new_ds_offset = keep_with_offset(
+            keep_mutations, mutations.derived_state, mutations.derived_state_offset)
+        new_md, new_md_offset = keep_with_offset(
+            keep_mutations, mutations.metadata, mutations.metadata_offset)
+        # Map each kept site to its new, compacted site ID.
+        site_map = np.cumsum(keep_sites, dtype=mutations.site.dtype) - 1
+        tables.mutations.set_columns(
+            site=site_map[mutations.site[keep_mutations]],
+            node=mutations.node[keep_mutations],
+            derived_state=new_ds,
+            derived_state_offset=new_ds_offset,
+            # TODO Compute the mutation parents properly here. We're being
+            # lazy right now and just asking compute_mutation_parents to do
+            # it for us, but we have to build_index to do this and also
+            # run compute_mutation_parents.
+            parent=np.zeros(np.sum(keep_mutations), dtype=np.int32) - 1,
+            metadata=new_md,
+            metadata_offset=new_md_offset)
+        tables.sort()
+        # See note above on compute_mutation_parents; we don't need these
+        # two steps if we do it properly.
+        tables.build_index()
+        tables.compute_mutation_parents()
+        if simplify:
+            tables.simplify()
+        if record_provenance:
+            parameters = {
+                "command": "keep_intervals",
+                "TODO": "add parameters"
+            }
+            tables.provenances.add_row(record=json.dumps(
+                provenance.get_provenance_dict(parameters)))
+        return tables
+
     def has_index(self):
         """
         Returns True if this TableCollection is indexed.
@@ -1810,83 +1952,3 @@ def drop_index(self):
         indexed this method has no effect.
         """
         self.ll_tables.drop_index()
-
-
-#############################################
-# Table functions.
-#############################################
-
-def pack_bytes(data):
-    """
-    Packs the specified list of bytes into a flattened numpy array of 8 bit integers
-    and corresponding offsets. See :ref:`sec_encoding_ragged_columns` for details
-    of this encoding.
-
-    :param list[bytes] data: The list of bytes values to encode.
-    :return: The tuple (packed, offset) of numpy arrays representing the flattened
-        input data and offsets.
-    :rtype: numpy.array (dtype=np.int8), numpy.array (dtype=np.uint32).
-    """
-    n = len(data)
-    offsets = np.zeros(n + 1, dtype=np.uint32)
-    for j in range(n):
-        offsets[j + 1] = offsets[j] + len(data[j])
-    column = np.zeros(offsets[-1], dtype=np.int8)
-    for j, value in enumerate(data):
-        column[offsets[j]: offsets[j + 1]] = bytearray(value)
-    return column, offsets
-
-
-def unpack_bytes(packed, offset):
-    """
-    Unpacks a list of bytes from the specified numpy arrays of packed byte
-    data and corresponding offsets. See :ref:`sec_encoding_ragged_columns` for details
-    of this encoding.
-
-    :param numpy.ndarray packed: The flattened array of byte values.
-    :param numpy.ndarray offset: The array of offsets into the ``packed`` array.
-    :return: The list of bytes values unpacked from the parameter arrays.
-    :rtype: list[bytes]
-    """
-    # This could be done a lot more efficiently...
-    ret = []
-    for j in range(offset.shape[0] - 1):
-        raw = packed[offset[j]: offset[j + 1]].tobytes()
-        ret.append(raw)
-    return ret
-
-
-def pack_strings(strings, encoding="utf8"):
-    """
-    Packs the specified list of strings into a flattened numpy array of 8 bit integers
-    and corresponding offsets using the specified text encoding.
-    See :ref:`sec_encoding_ragged_columns` for details of this encoding of
-    columns of variable length data.
-
-    :param list[str] data: The list of strings to encode.
-    :param str encoding: The text encoding to use when converting string data
-        to bytes. See the :mod:`codecs` module for information on available
-        string encodings.
-    :return: The tuple (packed, offset) of numpy arrays representing the flattened
-        input data and offsets.
-    :rtype: numpy.array (dtype=np.int8), numpy.array (dtype=np.uint32).
-    """
-    return pack_bytes([bytearray(s.encode(encoding)) for s in strings])
-
-
-def unpack_strings(packed, offset, encoding="utf8"):
-    """
-    Unpacks a list of strings from the specified numpy arrays of packed byte
-    data and corresponding offsets using the specified text encoding.
-    See :ref:`sec_encoding_ragged_columns` for details of this encoding of
-    columns of variable length data.
-
-    :param numpy.ndarray packed: The flattened array of byte values.
-    :param numpy.ndarray offset: The array of offsets into the ``packed`` array.
-    :param str encoding: The text encoding to use when converting string data
-        to bytes. See the :mod:`codecs` module for information on available
-        string encodings.
-    :return: The list of strings unpacked from the parameter arrays.
-    :rtype: list[str]
-    """
-    return [b.decode(encoding) for b in unpack_bytes(packed, offset)]
diff --git a/python/tskit/trees.py b/python/tskit/trees.py
index a75056d22f..af64938901 100644
--- a/python/tskit/trees.py
+++ b/python/tskit/trees.py
@@ -3206,114 +3206,6 @@ def simplify(
         else:
             return new_ts
 
-    def slice(
-            self, start=None, stop=None, reset_coordinates=True, simplify=True,
-            record_provenance=True):
-        """
-        Truncate this tree sequence to include only information in the genomic interval
-        between ``start`` and ``stop``. Edges are truncated to this interval, and sites
-        not covered by the sliced region are thrown away.
-
-        :param float start: The leftmost genomic position, giving the start point
-            of the kept region. Tree sequence information along the genome prior to (but
-            not including) this point will be discarded. If None, set equal to zero.
-        :param float stop: The rightmost genomic position, giving the end point of the
-            kept region. Tree sequence information at this point and further along the
-            genomic sequence will be discarded. If None, is set equal to the current tree
-            sequence's ``sequence_length``.
-        :param bool reset_coordinates: Reset the genomic coordinates such that position
-            0 in the returned tree sequence corresponds to position ``start`` in the
-            original one, and the returned tree sequence has sequence length
-            ``stop``-``start``. Sites and tree intervals will all have their positions
-            shifted to reflect the new coordinate system. If ``False``, do not rescale:
-            the resulting tree sequence will therefore cover the same genomic span as
-            the original, but will have tree and site information missing for genomic
-            regions outside the sliced region. (Default: True)
-        :param bool simplify: If True, simplify the resulting tree sequence so that nodes
-            no longer used in the resulting trees are discarded. (Default: True).
-        :param bool record_provenance: If True, record details of this call to
-            slice in the returned tree sequence's provenance information.
-            (Default: True).
-        :return: The sliced tree sequence.
-        :rtype: .TreeSequence
-        """
-        def keep_with_offset(keep, data, offset):
-            # Need to case diff to int32 for 32bit builds
-            lens = np.diff(offset).astype(np.int32)
-            return (data[np.repeat(keep, lens)],
-                    np.concatenate([
-                        np.array([0], dtype=offset.dtype),
-                        np.cumsum(lens[keep], dtype=offset.dtype)]))
-        if start is None:
-            start = 0
-        if stop is None:
-            stop = self.sequence_length
-        if start < 0 or stop <= start or stop > self.sequence_length:
-            raise ValueError("Slice bounds must be within the existing tree sequence")
-        edges = self.tables.edges
-        sites = self.tables.sites
-        mutations = self.tables.mutations
-        keep_edges = np.logical_not(
-            np.logical_or(edges.right <= start, edges.left >= stop))
-        keep_sites = np.logical_and(sites.position >= start, sites.position < stop)
-        keep_mutations = keep_sites[mutations.site]
-        tables = self.dump_tables()
-        new_as, new_as_offset = keep_with_offset(
-            keep_sites, sites.ancestral_state, sites.ancestral_state_offset)
-        new_md, new_md_offset = keep_with_offset(
-            keep_sites, sites.metadata, sites.metadata_offset)
-
-        if reset_coordinates:
-            tables.edges.set_columns(
-                left=np.fmax(start, edges.left[keep_edges]) - start,
-                right=np.fmin(stop, edges.right[keep_edges]) - start,
-                parent=edges.parent[keep_edges],
-                child=edges.child[keep_edges])
-            tables.sites.set_columns(
-                position=sites.position[keep_sites] - start,
-                ancestral_state=new_as,
-                ancestral_state_offset=new_as_offset,
-                metadata=new_md,
-                metadata_offset=new_md_offset)
-            tables.sequence_length = stop - start
-        else:
-            tables.edges.set_columns(
-                left=np.fmax(start, edges.left[keep_edges]),
-                right=np.fmin(stop, edges.right[keep_edges]),
-                parent=edges.parent[keep_edges],
-                child=edges.child[keep_edges])
-            tables.sites.set_columns(
-                position=sites.position[keep_sites],
-                ancestral_state=new_as,
-                ancestral_state_offset=new_as_offset,
-                metadata=new_md,
-                metadata_offset=new_md_offset)
-        new_ds, new_ds_offset = keep_with_offset(
-            keep_mutations, mutations.derived_state,
-            mutations.derived_state_offset)
-        new_md, new_md_offset = keep_with_offset(
-            keep_mutations, mutations.metadata, mutations.metadata_offset)
-        site_map = np.cumsum(keep_sites, dtype=mutations.site.dtype) - 1
-        tables.mutations.set_columns(
-            site=site_map[mutations.site[keep_mutations]],
-            node=mutations.node[keep_mutations],
-            derived_state=new_ds,
-            derived_state_offset=new_ds_offset,
-            parent=mutations.parent[keep_mutations],
-            metadata=new_md,
-            metadata_offset=new_md_offset)
-        if simplify:
-            tables.simplify()
-        if record_provenance:
-            # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243
-            parameters = {
-                "command": "slice",
-                "TODO": "add slice parameters"
-            }
-            tables.provenances.add_row(record=json.dumps(
-                provenance.get_provenance_dict(parameters)))
-        return tables.tree_sequence()
-
     def draw_svg(self, path=None, **kwargs):
         # TODO document this method, including semantic details of the
         # returned SVG object.
diff --git a/python/tskit/util.py b/python/tskit/util.py
index 7933199db4..6b5a0c80ac 100644
--- a/python/tskit/util.py
+++ b/python/tskit/util.py
@@ -57,3 +57,131 @@ def safe_np_int_cast(int_array, dtype, copy=False):
         # Raise a TypeError when we try to convert from, e.g., a float.
         casting = 'same_kind'
     return int_array.astype(dtype, casting=casting, copy=copy)
+
+
+#
+# Pack/unpack lists of data into flattened numpy arrays.
+#
+
+def pack_bytes(data):
+    """
+    Packs the specified list of bytes into a flattened numpy array of 8 bit integers
+    and corresponding offsets. See :ref:`sec_encoding_ragged_columns` for details
+    of this encoding.
+
+    :param list[bytes] data: The list of bytes values to encode.
+    :return: The tuple (packed, offset) of numpy arrays representing the flattened
+        input data and offsets.
+    :rtype: numpy.array (dtype=np.int8), numpy.array (dtype=np.uint32).
+    """
+    n = len(data)
+    offsets = np.zeros(n + 1, dtype=np.uint32)
+    for j in range(n):
+        offsets[j + 1] = offsets[j] + len(data[j])
+    column = np.zeros(offsets[-1], dtype=np.int8)
+    for j, value in enumerate(data):
+        column[offsets[j]: offsets[j + 1]] = bytearray(value)
+    return column, offsets
+
+
+def unpack_bytes(packed, offset):
+    """
+    Unpacks a list of bytes from the specified numpy arrays of packed byte
+    data and corresponding offsets. See :ref:`sec_encoding_ragged_columns` for details
+    of this encoding.
+
+    :param numpy.ndarray packed: The flattened array of byte values.
+    :param numpy.ndarray offset: The array of offsets into the ``packed`` array.
+    :return: The list of bytes values unpacked from the parameter arrays.
+    :rtype: list[bytes]
+    """
+    # This could be done a lot more efficiently...
+    ret = []
+    for j in range(offset.shape[0] - 1):
+        raw = packed[offset[j]: offset[j + 1]].tobytes()
+        ret.append(raw)
+    return ret
+
+
+def pack_strings(strings, encoding="utf8"):
+    """
+    Packs the specified list of strings into a flattened numpy array of 8 bit integers
+    and corresponding offsets using the specified text encoding.
+    See :ref:`sec_encoding_ragged_columns` for details of this encoding of
+    columns of variable length data.
+
+    :param list[str] strings: The list of strings to encode.
+    :param str encoding: The text encoding to use when converting string data
+        to bytes. See the :mod:`codecs` module for information on available
+        string encodings.
+    :return: The tuple (packed, offset) of numpy arrays representing the flattened
+        input data and offsets.
+    :rtype: numpy.array (dtype=np.int8), numpy.array (dtype=np.uint32).
+    """
+    return pack_bytes([bytearray(s.encode(encoding)) for s in strings])
+
+
+def unpack_strings(packed, offset, encoding="utf8"):
+    """
+    Unpacks a list of strings from the specified numpy arrays of packed byte
+    data and corresponding offsets using the specified text encoding.
+    See :ref:`sec_encoding_ragged_columns` for details of this encoding of
+    columns of variable length data.
+
+    :param numpy.ndarray packed: The flattened array of byte values.
+    :param numpy.ndarray offset: The array of offsets into the ``packed`` array.
+    :param str encoding: The text encoding to use when converting string data
+        to bytes. See the :mod:`codecs` module for information on available
+        string encodings.
+    :return: The list of strings unpacked from the parameter arrays.
+    :rtype: list[str]
+    """
+    return [b.decode(encoding) for b in unpack_bytes(packed, offset)]
+
+
+#
+# Interval utilities
+#
+
+def intervals_to_np_array(intervals, start, end):
+    """
+    Converts the specified intervals to a numpy array and checks for
+    errors.
+    """
+    intervals = np.array(intervals, dtype=np.float64)
+    # Special case the empty list of intervals
+    if len(intervals) == 0:
+        intervals = np.zeros((0, 2), dtype=np.float64)
+    if len(intervals.shape) != 2:
+        raise ValueError("Intervals must be a 2D numpy array")
+    if intervals.shape[1] != 2:
+        raise ValueError("Intervals array shape must be (N, 2)")
+    # TODO do this with numpy operations.
+    last_right = start
+    for left, right in intervals:
+        if left < start or right > end:
+            raise ValueError(
+                "Intervals must be within {} and {}".format(start, end))
+        if right <= left:
+            raise ValueError("Bad interval: right <= left")
+        if left < last_right:
+            raise ValueError("Intervals must be disjoint.")
+        last_right = right
+    return intervals
+
+
+def negate_intervals(intervals, start, end):
+    """
+    Returns the set of intervals *not* covered by the specified set of
+    disjoint intervals in the specified range.
+    """
+    intervals = intervals_to_np_array(intervals, start, end)
+    other_intervals = []
+    last_right = start
+    for left, right in intervals:
+        if left != last_right:
+            other_intervals.append((last_right, left))
+        last_right = right
+    if last_right != end:
+        other_intervals.append((last_right, end))
+    return np.array(other_intervals)