Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First pass at trim functions #292

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 119 additions & 12 deletions python/tskit/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,14 +1858,6 @@ def keep_intervals(self, intervals, simplify=True, record_provenance=True):
:rtype: tskit.TableCollection
"""

def keep_with_offset(keep, data, offset):
# We need the astype here for 32 bit machines
lens = np.diff(offset).astype(np.int32)
return (data[np.repeat(keep, lens)],
np.concatenate([
np.array([0], dtype=offset.dtype),
np.cumsum(lens[keep], dtype=offset.dtype)]))

intervals = util.intervals_to_np_array(intervals, 0, self.sequence_length)
if len(self.migrations) > 0:
raise ValueError("Migrations not supported by keep_intervals")
Expand All @@ -1882,9 +1874,9 @@ def keep_with_offset(keep, data, offset):
for s, e in intervals:
curr_keep_sites = np.logical_and(sites.position >= s, sites.position < e)
keep_sites = np.logical_or(keep_sites, curr_keep_sites)
new_as, new_as_offset = keep_with_offset(
new_as, new_as_offset = util.keep_with_offset(
curr_keep_sites, sites.ancestral_state, sites.ancestral_state_offset)
new_md, new_md_offset = keep_with_offset(
new_md, new_md_offset = util.keep_with_offset(
curr_keep_sites, sites.metadata, sites.metadata_offset)
keep_mutations = np.logical_or(
keep_mutations, curr_keep_sites[mutations.site])
Expand All @@ -1900,9 +1892,9 @@ def keep_with_offset(keep, data, offset):
ancestral_state_offset=new_as_offset,
metadata=new_md,
metadata_offset=new_md_offset)
new_ds, new_ds_offset = keep_with_offset(
new_ds, new_ds_offset = util.keep_with_offset(
keep_mutations, mutations.derived_state, mutations.derived_state_offset)
new_md, new_md_offset = keep_with_offset(
new_md, new_md_offset = util.keep_with_offset(
keep_mutations, mutations.metadata, mutations.metadata_offset)
site_map = np.cumsum(keep_sites, dtype=mutations.site.dtype) - 1
tables.mutations.set_columns(
Expand Down Expand Up @@ -1933,6 +1925,121 @@ def keep_with_offset(keep, data, offset):
provenance.get_provenance_dict(parameters)))
return tables

def remove_sites(self, site_ids, record_provenance=True):
"""
Remove the specified sites entirely from the sites and mutations tables in this
collection.

:param list[int] site_ids: A list of site IDs to remove.
:param bool record_provenance: If True, record details of this call to
``remove_sites`` in the returned tree sequence's provenance information.
(Default: True).
"""
if len(site_ids) != 0:
keep_sites = np.ones(self.sites.num_rows, dtype=bool)
keep_sites[util.safe_np_int_cast(site_ids, np.uint32)] = 0
new_as, new_as_offset = util.keep_with_offset(
keep_sites, self.sites.ancestral_state,
self.sites.ancestral_state_offset)
new_md, new_md_offset = util.keep_with_offset(
keep_sites, self.sites.metadata, self.sites.metadata_offset)
self.sites.set_columns(
position=self.sites.position[keep_sites],
ancestral_state=new_as,
ancestral_state_offset=new_as_offset,
metadata=new_md,
metadata_offset=new_md_offset)
# We also need to adjust the mutations table, as it references into sites
keep_mutations = keep_sites[self.mutations.site]
new_ds, new_ds_offset = util.keep_with_offset(
keep_mutations, self.mutations.derived_state,
self.mutations.derived_state_offset)
new_md, new_md_offset = util.keep_with_offset(
keep_mutations, self.mutations.metadata, self.mutations.metadata_offset)
site_map = np.cumsum(keep_sites, dtype=self.mutations.site.dtype) - 1
self.mutations.set_columns(
site=site_map[self.mutations.site[keep_mutations]],
node=self.mutations.node[keep_mutations],
derived_state=new_ds,
derived_state_offset=new_ds_offset,
parent=self.mutations.parent[keep_mutations],
metadata=new_md,
metadata_offset=new_md_offset)
if record_provenance:
# TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243
parameters = {
"command": "remove_sites",
"TODO": "add parameters"
}
self.provenances.add_row(record=json.dumps(
provenance.get_provenance_dict(parameters)))

def ltrim(self, record_provenance=True):
"""
Reset the coordinate system, changing the left and right genomic positions in the
edge table such that the leftmost edge is at position 0. Positions in the sites
table are also adjusted accordingly. Additionally, sites (and associated
mutations) to the left of the new zero point are thrown away.

:param bool record_provenance: If True, record details of this call to
``remove_sites`` in the returned tree sequence's provenance information.
(Default: True).
"""
leftmost = np.max(self.edges.left)
self.remove_sites(
np.where(self.sites.position < leftmost), record_provenance=False)
self.edges.set_columns(
left=self.edges.left - leftmost, right=self.edges.right - leftmost,
parent=self.edges.parent, child=self.edges.child)
self.sites.set_columns(
position=self.sites.position - leftmost,
ancestral_state=self.sites.ancestral_state,
ancestral_state_offset=self.sites.ancestral_state_offset,
metadata=self.sites.metadata,
metadata_offset=self.sites.metadata_offset)
self.sequence_length = self.sequence_length - leftmost

if record_provenance:
# TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243
parameters = {
"command": "ltrim",
}
self.provenances.add_row(record=json.dumps(
provenance.get_provenance_dict(parameters)))

def rtrim(self, record_provenance=True):
"""
Reset the ``sequence_length`` property so that the sequence ends at maximum value
of ``edges.right``. Additionally, sites (and associated mutations) at positions
greater than the new ``sequence_length`` are thrown away.

:param bool record_provenance: If True, record details of this call to
``rtrim`` in the returned tree sequence's provenance information.
(Default: True).
"""
rightmost = np.max(self.edges.right)
self.remove_sites(
np.where(self.sites.position >= rightmost), record_provenance=False)
self.sequence_length = rightmost
if record_provenance:
# TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243
parameters = {
"command": "rtrim",
}
self.provenances.add_row(record=json.dumps(
provenance.get_provenance_dict(parameters)))

def trim(self, record_provenance=True):
self.rtrim(record_provenance=False)
self.ltrim(record_provenance=False)
if record_provenance:
# TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243
parameters = {
"command": "trim",
}
self.provenances.add_row(record=json.dumps(
provenance.get_provenance_dict(parameters)))

def has_index(self):
"""
Returns True if this TableCollection is indexed.
Expand Down
107 changes: 102 additions & 5 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -3222,11 +3222,11 @@ def simplify(
sequence.
:rtype: .TreeSequence or a (.TreeSequence, numpy.array) tuple
"""
tables = self.dump_tables()
ts_tables = self.dump_tables()
if samples is None:
samples = self.get_samples()
assert tables.sequence_length == self.sequence_length
node_map = tables.simplify(
assert ts_tables.sequence_length == self.sequence_length
node_map = ts_tables.simplify(
samples=samples,
filter_zero_mutation_sites=filter_zero_mutation_sites,
reduce_to_site_topology=reduce_to_site_topology,
Expand All @@ -3242,15 +3242,112 @@ def simplify(
"command": "simplify",
"TODO": "add simplify parameters"
}
tables.provenances.add_row(record=json.dumps(
ts_tables.provenances.add_row(record=json.dumps(
provenance.get_provenance_dict(parameters)))
new_ts = tables.tree_sequence()
new_ts = ts_tables.tree_sequence()
assert new_ts.sequence_length == self.sequence_length
if map_nodes:
return new_ts, node_map
else:
return new_ts

def remove_sites(self, site_ids, record_provenance=True):
"""
Remove the specified sites entirely from a tree sequence

:param list site_ids: The list of sites to remove. This
may be a numpy array (or array-like) object (dtype=np.uint32).
:param bool record_provenance: If True, record details of this call to
slice in the returned tree sequence's provenance information.
(Default: True).
:return: The sliced tree sequence.
:rtype: .TreeSequence
"""
ts_tables = self.dump_tables()
ts_tables.remove_sites(site_ids, record_provenance)
return ts_tables.tree_sequence()

def ltrim(self, record_provenance=True):
"""
Trim from the left side of this tree sequence any genomic region with no trees.
This is equivalent to resetting the coordinate system so that the leftmost edge
in the tree sequence starts at position 0. Sites and mutations within the trimmed
region are thrown away.

:param bool record_provenance: If True, record details of this call to
slice in the returned tree sequence's provenance information (Default: True).
:return: The sliced tree sequence.
:rtype: .TreeSequence
"""
ts_tables = self.dump_tables()
ts_tables.ltrim(record_provenance)
return ts_tables.tree_sequence()

def rtrim(self, record_provenance=True):
"""
Trim from the right side of this tree sequence any genomic region with no trees.
This is equivalent to setting the sequence_length to the end point of the
rightmost edge in the tree sequence. Sites and mutations within the trimmed
region are thrown away.

:param bool record_provenance: If True, record details of this call to
slice in the returned tree sequence's provenance information (Default: True).
:return: The sliced tree sequence.
:rtype: .TreeSequence
"""
ts_tables = self.dump_tables()
ts_tables.rtrim(record_provenance)
return ts_tables.tree_sequence()

def trim(self, record_provenance=True):
"""
Trim from the right side of this tree sequence any genomic region with no trees.
This is equivalent to setting the sequence_length to the end point of the
rightmost edge in the tree sequence. Sites and mutations within the trimmed
region are thrown away.

:param bool record_provenance: If True, record details of this call to
slice in the returned tree sequence's provenance information.
(Default: True).
:return: The sliced tree sequence.
:rtype: .TreeSequence
"""
ts_tables = self.dump_tables()
ts_tables.trim(record_provenance)
return ts_tables.tree_sequence()

def slice(self, start=None, stop=None, simplify=True, record_provenance=True):
"""
Extract a portion of the tree sequence covering a restricted genomic region. This
is essentially a wrapper that runs ``keep_intervals([(start, stop)]).trim()``
on the tables that describe this tree sequence.

:param float start: The leftmost genomic position, giving the start point
of the kept region. Tree sequence information along the genome prior to (but
not including) this point will be discarded. If None, treat as zero.
:param float stop: The rightmost genomic position, giving the end point of the
kept region. Tree sequence information at this point and further along the
genomic sequence will be discarded. If None, treat as equal to the current
tree sequence's ``sequence_length``.
:param bool simplify: If True, simplify the resulting tree sequence so that nodes
no longer used in the resulting trees are discarded. (Default: True).
:param bool record_provenance: If True, record details of this call to
slice in the returned tree sequence's provenance information (Default: True).
:return: The sliced tree sequence.
:rtype: .TreeSequence
"""
ts_tables = self.dump_tables()
ts_tables.keep_interval([(start, stop)], simplify=simplify).trim()
if record_provenance:
# TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243
parameters = {
"command": "slice",
"TODO": "add slice parameters"
}
ts_tables.provenances.add_row(record=json.dumps(
provenance.get_provenance_dict(parameters)))
return ts_tables.tree_sequence()

def draw_svg(self, path=None, **kwargs):
# TODO document this method, including semantic details of the
# returned SVG object.
Expand Down
14 changes: 13 additions & 1 deletion python/tskit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def intervals_to_np_array(intervals, start, end):
def negate_intervals(intervals, start, end):
"""
Returns the set of intervals *not* covered by the specified set of
disjoint intervals in the specfied range.
disjoint intervals in the specified range.
"""
intervals = intervals_to_np_array(intervals, start, end)
other_intervals = []
Expand All @@ -185,3 +185,15 @@ def negate_intervals(intervals, start, end):
if last_right != end:
other_intervals.append((last_right, end))
return np.array(other_intervals)


def keep_with_offset(keep, data, offset):
"""
Used e.g. in keep_intervals and other functions to do fast numpy removal from tables
"""
# We need the astype here for 32 bit machines
lens = np.diff(offset).astype(np.int32)
return (data[np.repeat(keep, lens)],
np.concatenate([
np.array([0], dtype=offset.dtype),
np.cumsum(lens[keep], dtype=offset.dtype)]))