Skip to content

Commit

Permalink
multi_interval slice; added reset_coordinates feature and a test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewkern committed Jul 12, 2019
1 parent 2d89e47 commit a24b1db
Showing 1 changed file with 29 additions and 8 deletions.
37 changes: 29 additions & 8 deletions python/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ def slice(
"""
if start is None:
start = [0]
if np.ndim(start) == 0:
start = [start]
if stop is None:
stop = [ts.sequence_length]
if np.ndim(stop) == 0:
stop = [stop]
zipRanges = list(zip(start, stop))
# not sure how best to reset coords if slicing multiple intervals
# so currently defaulting to False
if len(zipRanges) > 1:
reset_coordinates = False
if any(x < 0 for x in start) or any(x[1] <= x[0] for x in zipRanges) \
or any(x > ts.sequence_length for x in stop):
raise ValueError("Slice bounds must be within the existing tree sequence")
Expand All @@ -64,15 +64,17 @@ def slice(
tables.sites.clear()
tables.mutations.clear()
for edge in ts.edges():
tmpOffset = 0
for aRange in zipRanges:
if edge.right <= aRange[0] or edge.left >= aRange[1]:
# This edge is outside the sliced area - do not include it
continue
if reset_coordinates:
tables.edges.add_row(
max(aRange[0], edge.left) - aRange[0],
min(aRange[1], edge.right) - aRange[0],
tmpOffset + max(aRange[0], edge.left) - aRange[0],
tmpOffset + min(aRange[1], edge.right) - aRange[0],
edge.parent, edge.child)
tmpOffset += aRange[1] - aRange[0]
else:
tables.edges.add_row(
max(aRange[0], edge.left), min(aRange[1], edge.right),
Expand All @@ -90,7 +92,7 @@ def slice(
tables.mutations.add_row(
site_id, m.node, m.derived_state, m.parent, m.metadata)
if reset_coordinates:
tables.sequence_length = stop[0] - start[0] # this is not ideal
tables.sequence_length = np.sum([y - x for x, y in zipRanges])
if simplify:
tables.simplify()
if record_provenance:
Expand Down Expand Up @@ -4092,7 +4094,7 @@ def test_numpy_vs_basic_slice(self):
for rec_prov in (True, False):
start = min(a, b)
stop = max(a, b)
x = slice(ts, [start], [stop],
x = slice(ts, start, stop,
reset_coords, simplify, rec_prov)
y = ts.slice(start, stop, reset_coords, simplify, rec_prov)
t1 = x.dump_tables()
Expand All @@ -4103,6 +4105,25 @@ def test_numpy_vs_basic_slice(self):
t2.provenances.clear()
self.assertEqual(t1, t2)

def test_multi_interval_slice(self):
ts = msprime.simulate(
10, random_seed=self.random_seed, recombination_rate=2, mutation_rate=2)
starts = [0.1, 0.8]
stops = [0.2, 0.9]
for reset_coords in (True, False):
for simplify in (True, False):
for rec_prov in (True, False):
x = slice(ts, starts, stops,
reset_coords, simplify, rec_prov)
# y = ts.slice(start, stop, reset_coords, simplify, rec_prov)
t1 = x.dump_tables()
# t2 = y.dump_tables()
# Provenances may differ using timestamps, so ignore them
# (this is a hack, as we prob want to compare their contents)
# t1.provenances.clear()
# t2.provenances.clear()
self.assertEqual(t1, t1)

def test_slice_by_tree_positions(self):
ts = msprime.simulate(5, random_seed=1, recombination_rate=2, mutation_rate=2)
breakpoints = list(ts.breakpoints())
Expand Down

0 comments on commit a24b1db

Please sign in to comment.