Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add feature #2, add option to collapse low support branches #5

Merged
merged 2 commits into from
Jun 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 53 additions & 23 deletions shiptv/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,71 +6,101 @@

import click

from shiptv.shiptv import genbank_metadata, add_user_metadata, parse_tree, write_html_tree, parse_leaf_list, prune_tree, reorder_metadata_fields, \
get_metadata_fields, highlight_user_samples, try_fix_serotype_metadata, try_fix_country_metadata, \
try_fix_collection_date_metadata, try_fix_host_metadata
from shiptv.shiptv import genbank_metadata, add_user_metadata, parse_tree, \
write_html_tree, parse_leaf_list, prune_tree, reorder_metadata_fields, \
get_metadata_fields, highlight_user_samples, try_fix_serotype_metadata, \
try_fix_country_metadata, \
try_fix_collection_date_metadata, try_fix_host_metadata, collapse_branches


@click.command()
@click.option('-r', '--ref-genomes-genbank',
required=True,
type=click.Path(exists=True, dir_okay=False),
help='Reference genome sequences Genbank file')
@click.option('-n', '--newick',
required=True,
required=True, type=click.Path(exists=True, dir_okay=False),
help='Phylogenetic tree Newick file')
@click.option('-o', '--output-html',
@click.option('-N', '--output-newick',
type=click.Path(),
help='Output Newick file')
@click.option('-o', '--output-html', type=click.Path(),
required=True, help='Output HTML tree path')
@click.option('-m', '--output-metadata-table',
type=click.Path(),
required=True, help='Output metadata table path')
@click.option('--leaflist',
required=False,
type=click.Path(),
default=None,
help='Optional leaf names to select from phylogenetic tree for pruned tree visualization. '
'One leaf name per line.')
help='Optional leaf names to select from phylogenetic tree for '
'pruned tree visualization. One leaf name per line.')
@click.option('--genbank-metadata-fields',
required=False,
type=click.Path(),
default=None,
help='Optional fields to extract from Genbank source metadata. One field per line.')
help='Optional fields to extract from Genbank source metadata. '
'One field per line.')
@click.option('--user-sample-metadata',
required=False,
type=click.Path(),
default=None,
help='Optional tab-delimited metadata for user samples to join with metadata derived from reference '
'genome sequences Genbank file. '
'Sample IDs must be in the first column.')
help='Optional tab-delimited metadata for user samples to join '
'with metadata derived from reference genome sequences '
'Genbank file. Sample IDs must be in the first column.')
@click.option('--metadata-fields-in-order',
required=False,
type=click.Path(),
default=None,
help='Optional list of fields in order to output in metadata table and HTML tree visualization. '
'One field per line.')
@click.option('--dont-fix-metadata', is_flag=True, help='Do not automatically fix metadata')
help='Optional list of fields in order to output in metadata '
'table and HTML tree visualization. One field per line.')
@click.option('--dont-fix-metadata', is_flag=True,
help='Do not automatically fix metadata')
@click.option('-C', '--collapse-support', default=-1, type=float,
help='Collapse internal branches below specified bootstrap '
'support value (default -1 for no collapsing)')
def main(ref_genomes_genbank,
newick,
output_newick,
output_html,
output_metadata_table,
leaflist,
genbank_metadata_fields,
user_sample_metadata,
metadata_fields_in_order,
dont_fix_metadata):
dont_fix_metadata,
collapse_support):
"""Create HTML tree visualization with metadata.

The metadata for reference genomes is extracted from the specified Genbank file.
The metadata for reference genomes is extracted from the specified Genbank
file.

Any leaf names that are present in the tree but not present in the Genbank file are assumed to be user samples
and are flagged as such in the metadata table as "user_sample"="Yes".
Any leaf names that are present in the tree but not present in the Genbank
file are assumed to be user samples and are flagged as such in the
metadata table as "user_sample"="Yes".
"""
LOG_FORMAT = '%(asctime)s %(levelname)s: %(message)s [in %(filename)s:%(lineno)d]'
logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)
logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s '
'[in %(filename)s:%(lineno)d]',
level=logging.INFO)
tree = parse_tree(newick)
if collapse_support != -1:
logging.info(f'Collapsing internal branches with support values less '
f'than {collapse_support}')
collapse_branches(tree, collapse_support)
if output_newick:
tree.write(outfile=output_newick)

df_metadata = genbank_metadata(ref_genomes_genbank)
logging.info(f'Parsed metadata from "{ref_genomes_genbank}" with columns "{";".join(df_metadata.columns)}')
logging.info(f'Parsed metadata from "{ref_genomes_genbank}" with columns '
f'"{";".join(df_metadata.columns)}')

metadata_fields = get_metadata_fields(genbank_metadata_fields)
# only use columns present in the reference genome metadata
metadata_fields = [x for x in metadata_fields if x in list(df_metadata.columns)]
metadata_fields = [x for x in metadata_fields if
x in list(df_metadata.columns)]
logging.info(f'Metadata table fields: {";".join(metadata_fields)}')
df_metadata = highlight_user_samples(df_metadata, metadata_fields, tree.get_leaf_names())
df_metadata = highlight_user_samples(df_metadata, metadata_fields,
tree.get_leaf_names())
if dont_fix_metadata:
logging.warning('Not fixing any genome metadata.')
else:
Expand Down
66 changes: 44 additions & 22 deletions shiptv/shiptv.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# -*- coding: utf-8 -*-

"""Main module."""
import jinja2
import logging
from typing import List, Dict, Optional

import jinja2
import pandas as pd
import pkg_resources
from pkg_resources import resource_filename
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from ete3 import Tree
from typing import List, Dict, Optional

# Filesystem path of the ``tmpl/phylocanvas.html`` resource shipped inside the
# ``shiptv`` package; opened by ``write_html_tree`` as the Jinja2 template.
html_template = pkg_resources.resource_filename('shiptv', 'tmpl/phylocanvas.html')
html_template = resource_filename('shiptv', 'tmpl/phylocanvas.html')


def read_lines(filepath: str) -> List[str]:
Expand All @@ -24,37 +25,40 @@ def read_lines(filepath: str) -> List[str]:

def genbank_source_metadata(rec: SeqRecord) -> Dict[str, str]:
    """Get source feature metadata dictionary for a SeqRecord.

    The qualifiers of the record's first feature are used (presumably the
    Genbank ``source`` feature -- the feature type is not checked here).

    Args:
        rec: Sequence record whose first feature holds the metadata.

    Returns:
        Mapping of qualifier name to value. Qualifier values that are
        single-element lists are unwrapped to the element itself; ``None``
        values and lists of any other length are returned unchanged.
        An empty dict is returned when the record has no features.
    """
    # A record without any features has no metadata to report; return an
    # empty mapping instead of failing with an IndexError on features[0].
    if not rec.features:
        return {}
    qualifiers = rec.features[0].qualifiers
    return {key: values[0] if values is not None and len(values) == 1
            else values
            for key, values in qualifiers.items()}


def genbank_metadata(genbank: str) -> pd.DataFrame:
    """Parse genome metadata from Genbank file into Pandas DataFrame.

    Args:
        genbank: Path to a Genbank file readable by ``SeqIO.parse``.

    Returns:
        DataFrame with one row per record ID and one column per source
        qualifier. When both ``isolate`` and ``strain`` columns exist,
        ``strain`` is replaced by ``isolate`` with gaps filled from the
        original ``strain`` values.
    """
    # Index the records by ID first; a repeated ID keeps only the last record.
    records_by_id = {}
    for record in SeqIO.parse(genbank, 'genbank'):
        records_by_id[record.id] = record

    metadata_by_id = {}
    for genome_id, record in records_by_id.items():
        metadata_by_id[genome_id] = genbank_source_metadata(record)

    # The dict-of-dicts constructor puts genomes in columns; flip it so that
    # each genome becomes a row.
    df_metadata = pd.DataFrame(metadata_by_id).transpose()

    if {'isolate', 'strain'}.issubset(df_metadata.columns):
        isolate = df_metadata['isolate']
        df_metadata['strain'] = isolate.combine_first(df_metadata['strain'])
    return df_metadata


def fix_host_metadata(df: pd.DataFrame) -> None:
    """Replace known host synonyms in the ``host`` column with one category.

    Note:
        This function modifies the supplied DataFrame in place.

    Args:
        df: Metadata table with a ``host`` column. Only exact matches of the
            synonyms listed below are replaced; other values are untouched.
    """
    host_synonyms = {
        'cattle': [
            'Bos taurus',
            'bovine',
            'cattle (Ankole cow, sentinel herd)',
            'Ankole cow',
            'Cattle',
        ],
        'sheep': [
            'ovine',
        ],
        'pig': [
            'sus scrofa domesticus',
            'swine',
            'porcine',
        ],
    }
    for canonical, synonyms in host_synonyms.items():
        # Assign through .loc on the frame itself: chained assignment
        # (df.host[mask] = ...) may write to a temporary copy and is a no-op
        # under pandas copy-on-write.
        df.loc[df.host.isin(synonyms), 'host'] = canonical


def fix_collection_date(df_metadata):
Expand All @@ -64,30 +68,30 @@ def fix_collection_date(df_metadata):
df_metadata.collection_date = [str(x).split()[0] if not pd.isnull(x) else None for x in dates]


def fix_country_region(df):
    """Split ``country`` values of the form "country: region" into two columns.

    Note:
        This function modifies the supplied DataFrame in place.

    Args:
        df: Metadata table with a string ``country`` column. After the call,
            ``country`` holds the text before the first colon and a new
            ``region`` column holds the text after the last colon (missing
            when the value contains no colon). Both are stripped of
            surrounding whitespace.
    """
    raw_country = df.country
    # Extract both parts from the untouched column before overwriting it.
    region = raw_country.str.extract(r'.*:\s*(.*)\s*', expand=False)
    country = raw_country.str.extract(r'([^:]+)(:\s*.*\s*)?')[0]
    # The greedy capture groups keep surrounding whitespace (e.g.
    # "Canada : Ontario " -> "Canada " / "Ontario "), so strip explicitly.
    df['region'] = region.str.strip()
    df['country'] = country.str.strip()


def add_user_metadata(df: pd.DataFrame, user_sample_metadata: str) -> None:
    """Merge a tab-delimited user metadata table into ``df``.

    Note:
        This function modifies the supplied DataFrame in place.

    Args:
        df: Metadata table indexed by sample ID.
        user_sample_metadata: Path to a tab-delimited file whose first column
            holds the sample IDs. Columns not yet in ``df`` are added
            (initialised to ``None``). Rows whose ID is not in ``df.index``
            are ignored; for the others the user values overwrite the
            existing ones.
    """
    user_table = pd.read_csv(user_sample_metadata, sep='\t', index_col=0)
    logging.info(f'Read user sample metadata table from '
                 f'"{user_sample_metadata}" with '
                 f'{user_table.shape[0]} rows and columns: '
                 f'{";".join(user_table.columns)}')

    # Make sure every user-supplied column exists before rows are assigned.
    for column in user_table.columns:
        if column in df.columns:
            continue
        df[column] = None

    for sample_id, user_row in user_table.iterrows():
        if sample_id in df.index:
            current = df.loc[sample_id, :]
            # Start from the non-null existing values, then let the user
            # supplied values take precedence.
            merged = current[current.notnull()].to_dict()
            merged.update(user_row.to_dict())
            df.loc[sample_id, :] = pd.Series(merged)


def parse_tree(newick):
def parse_tree(newick: str) -> Tree:
# Read phylogenetic tree newick file using ete3
tree = Tree(newick=newick)
# Calculate the midpoint node
Expand All @@ -97,7 +101,9 @@ def parse_tree(newick):
return tree


def write_html_tree(df_metadata, output_html, tree):
def write_html_tree(df_metadata: pd.DataFrame,
output_html: str,
tree: Tree) -> None:
with open(html_template) as fh, open(output_html, 'w') as fout:
tmpl = jinja2.Template(fh.read())
fout.write(tmpl.render(newick_string=tree.write(),
Expand Down Expand Up @@ -183,3 +189,19 @@ def try_fix_host_metadata(df_metadata: pd.DataFrame) -> None:
f'to use more consistent host type categories. '
f'Before: {before_host_types} host types. '
f'After: {after_host_types} host types.')


def collapse_branches(tree: Tree, collapse_support: float) -> None:
    """Collapse internal branches whose support is below a threshold.

    Every node that is neither a leaf nor the root and whose ``support`` is
    strictly less than ``collapse_support`` is removed via ``node.delete()``.

    Note:
        This function modifies the supplied `tree` object.

    Args:
        tree: ete3 Tree object
        collapse_support: Support threshold
    """
    # Kept lazy on purpose: each node is tested at the moment the traversal
    # yields it, so deletions interleave with the walk.
    weakly_supported = (
        node for node in tree.traverse()
        if not (node.is_leaf() or node.is_root())
        and node.support < collapse_support
    )
    for node in weakly_supported:
        node.delete()
52 changes: 51 additions & 1 deletion tests/test_shiptv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from pandas.testing import assert_frame_equal, assert_series_equal

from shiptv import cli
from shiptv.shiptv import fix_collection_date
from shiptv.shiptv import collapse_branches, fix_collection_date
from ete3 import Tree

# Absolute paths to the Genbank and Newick fixture files under ``tests/data``
# that the tests in this module pass to the CLI and to ``Tree``.
input_ref_genbank = abspath('tests/data/fmdv-5.gb')
input_newick = abspath('tests/data/fmdv-5.newick')
Expand All @@ -28,17 +29,66 @@ def test_command_line_interface():
with runner.isolated_filesystem():
out_html = 'test.html'
out_table = 'test.tsv'
out_newick = 'test.newick'
test_result = runner.invoke(cli.main, ['-r', input_ref_genbank,
'-n', input_newick,
'-N', out_newick,
'-o', out_html,
'-m', out_table])
assert test_result.exit_code == 0
assert exists(out_html)
assert exists(out_table)
assert exists(out_newick)
assert open(input_newick).read() != open(out_newick).read()
assert_frame_equal(pd.read_csv(expected_table, sep='\t'), pd.read_csv(out_table, sep='\t'))

with runner.isolated_filesystem():
out_html = 'test.html'
out_table = 'test.tsv'
test_result = runner.invoke(cli.main, ['-r', input_ref_genbank,
'-n', input_newick,
'-o', out_html,
'-m', out_table,
'-C', 95])
assert test_result.exit_code == 0
assert exists(out_html)
assert exists(out_table)
assert_frame_equal(pd.read_csv(expected_table, sep='\t'), pd.read_csv(out_table, sep='\t'))


def test_collapse_branches():
"""Check the tree topology before and after ``collapse_branches(tree, 95)``."""
# Expected ``get_ascii()`` rendering of the tree as read from ``input_newick``.
before_tree_ascii = """
/-MK088171.1
|
--|--MK071699.1
|
| /-MH845413.2
\-|
| /-MH784405.1
\-|
\-MH784404.1
""".strip()
# Expected rendering once branches with support below 95 have been collapsed.
after_tree_ascii = """
/-MK088171.1
|
|--MK071699.1
--|
|--MH845413.2
|
| /-MH784405.1
\-|
\-MH784404.1
""".strip()

tree = Tree(newick=input_newick)
assert tree.get_ascii().strip() == before_tree_ascii
# collapse_branches mutates ``tree`` in place, so re-render the same object.
collapse_branches(tree, 95)
assert tree.get_ascii().strip() == after_tree_ascii


def test_fix_collection_date():
    """Check the ``collection_year`` column produced by ``fix_collection_date``.

    Of the four inputs, only '2000' and '2009/03/31' are expected to yield a
    year (2000 and 2009); the other two are expected to yield a missing value.
    """
    raw_dates = ['1994/1995', '2000', 'not-a-date', '2009/03/31']
    frame = pd.DataFrame({'collection_date': raw_dates})

    fix_collection_date(frame)

    expected_years = pd.Series([None, 2000, None, 2009],
                               name='collection_year')
    assert_series_equal(frame.collection_year, expected_years)