def _sign(x):
    """Return the sign of `x`.

    Parameters
    ----------
    x : number
        Value whose sign is requested.

    Returns
    -------
    number
        ``x / abs(x)`` (equal to 1 or -1) when `x` is non-zero; `x` itself
        when `x` is falsy (e.g. ``0`` or ``0.0``), so no division by zero
        is attempted.
    """
    return x / abs(x) if x else x


def _cache_ntips(self):
    """Count the tips under every node and store it as `leafcount`.

    Nodes are visited in postorder, so each child's `leafcount` is
    already set by the time its parent is processed.
    """
    for node in self.postorder():
        if node.is_tip():
            node.leafcount = 1
        else:
            node.leafcount = sum(c.leafcount for c in node.children)


def update_geometry(self, use_lengths, depth=None):
    """Calculate tree node attributes such as height and depth.

    The attributes `length`, `depth`, `height`, `leafcount` and
    `edgecount` are set on this node and, recursively, on every node
    below it.

    Parameters
    ----------
    use_lengths : bool
        Specify if the branch length should be incorporated into
        the geometry calculations for visualization. When False (or when
        a node has no length), `length` is OVERWRITTEN: 0 for the node
        the traversal starts from, 1 for every other node.
    depth : number, optional
        Depth of the parent node, i.e. the accumulated branch length
        from the starting node down to the parent. ``None`` marks the
        node the traversal starts from.

    Notes
    -----
    After the call:

    * `depth` is the accumulated branch length from the starting node
      down to and including this node's own branch.
    * `height` is this node's branch length plus the largest `height`
      among its children (just the branch length for a tip).
    * `leafcount` is the number of tips in the subtree.
    * `edgecount` is the number of nodes (each counted with its own
      branch) in the subtree, this node included.

    This is agnostic to scale and orientation.
    """
    if self.length is None or not use_lengths:
        # Unit branch lengths, except for the starting node which has no
        # parent branch to draw.
        self.length = 0 if depth is None else 1

    self.depth = (depth or 0) + self.length

    children = self.children
    if children:
        # Children must be resolved first: height/leafcount/edgecount of
        # this node are aggregates of the children's values.
        for c in children:
            c.update_geometry(use_lengths, self.depth)
        self.height = max(c.height for c in children) + self.length
        self.leafcount = sum(c.leafcount for c in children)
        self.edgecount = sum(c.edgecount for c in children) + 1
    else:
        self.height = self.length
        self.leafcount = self.edgecount = 1
# ----------------------------------------------------------------------------
# Copyright (c) 2016--, gneiss development team.
#
# Distributed under the terms of the GPLv3 License.
#
# The full license is in the file COPYING.txt, distributed with this software.
# ----------------------------------------------------------------------------
import warnings

import pandas as pd

from gneiss.plot._dendrogram import UnrootedDendrogram

try:
    from bokeh.models.glyphs import Circle, Segment
    from bokeh.models import ColumnDataSource, DataRange1d, Plot
    from bokeh.models import HoverTool, BoxZoomTool, ResetTool
except ImportError:
    # Bokeh is optional: warn instead of raising, so that importing this
    # module (and the package that imports it) still succeeds without it.
    _HAS_BOKEH = False
    warnings.warn('Bokeh not installed. '
                  '`radialplot` will not be available', ImportWarning)
else:
    _HAS_BOKEH = True


def radialplot(tree, node_color='node_color', node_size='node_size',
               node_alpha='node_alpha', edge_color='edge_color',
               edge_alpha='edge_alpha', edge_width='edge_width',
               figsize=(500, 500), **kwargs):
    """ Plots unrooted radial tree.

    Parameters
    ----------
    tree : instance of skbio.TreeNode
        Input tree for plotting.  It is passed through
        `UnrootedDendrogram.from_tree`, which converts the nodes of the
        given tree in place.
    node_color : str
        Name of variable in `tree` to color nodes.
        Nodes lacking the attribute default to light grey (``#D3D3D3``).
    node_size : str
        Name of variable in `tree` that specifies the radius of nodes.
        Defaults to 1 for nodes lacking the attribute.
    node_alpha : str
        Name of variable in `tree` to specify node transparency.
        Defaults to 1 for nodes lacking the attribute.
    edge_color : str
        Name of variable in `tree` to color edges.
        Defaults to black (``#000000``) for nodes lacking the attribute.
    edge_alpha : str
        Name of variable in `tree` to specify edge transparency.
        Defaults to 1 for nodes lacking the attribute.
    edge_width : str
        Name of variable in `tree` to specify edge width.
        Defaults to 1 for nodes lacking the attribute.
    figsize : tuple, int
        Size of resulting figure.  The two entries are forwarded, in
        order, as the `height` and `width` arguments of the dendrogram's
        `coords` method.  default: (500, 500)
    **kwargs: dict
        Plotting options to pass into bokeh.models.Plot

    Returns
    -------
    bokeh.models.Plot
        Interactive plotting instance.

    Raises
    ------
    ImportError
        If bokeh is not installed.

    Notes
    -----
    This assumes that the tree is strictly bifurcating.

    See also
    --------
    bifurcate
    """
    if not _HAS_BOKEH:
        raise ImportError('Bokeh not installed. '
                          '`radialplot` requires bokeh.')

    # This entire function was motivated by
    # http://chuckpr.github.io/blog/trees2.html
    t = UnrootedDendrogram.from_tree(tree)

    nodes = t.coords(figsize[0], figsize[1])

    # fill in all of the node attributes
    def _retrieve(tree, x, default):
        return pd.Series({n.name: getattr(n, x, default)
                          for n in tree.levelorder()})

    # default node color to light grey
    nodes[node_color] = _retrieve(t, node_color, default='#D3D3D3')
    nodes[node_size] = _retrieve(t, node_size, default=1)
    nodes[node_alpha] = _retrieve(t, node_alpha, default=1)

    # Build one row per (parent, child) pair: rows without both children
    # (tips) are dropped, then the two child columns are stacked.
    edges = nodes[['child0', 'child1']]
    edges = edges.dropna(subset=['child0', 'child1'])
    edges = edges.unstack()
    edges = pd.DataFrame({'src_node': edges.index.get_level_values(1),
                          'dest_node': edges.values})
    edges['x0'] = [nodes.loc[n].x for n in edges.src_node]
    edges['x1'] = [nodes.loc[n].x for n in edges.dest_node]
    edges['y0'] = [nodes.loc[n].y for n in edges.src_node]
    edges['y1'] = [nodes.loc[n].y for n in edges.dest_node]
    node_names = [n.name for n in t.levelorder(include_self=True)]
    attrs = pd.DataFrame(index=node_names)

    # default edge color to black
    attrs[edge_color] = _retrieve(t, edge_color, default='#000000')
    attrs[edge_width] = _retrieve(t, edge_width, default=1)
    attrs[edge_alpha] = _retrieve(t, edge_alpha, default=1)

    # An edge takes its styling from the node it leads to; nodes that are
    # nobody's child have no src_node after the merge and are dropped.
    edges = pd.merge(edges, attrs, left_on='dest_node',
                     right_index=True, how='outer')
    edges = edges.dropna(subset=['src_node'])

    node_glyph = Circle(x="x", y="y",
                        radius=node_size,
                        fill_color=node_color,
                        fill_alpha=node_alpha)

    edge_glyph = Segment(x0="x0", y0="y0",
                         x1="x1", y1="y1",
                         line_color=edge_color,
                         line_alpha=edge_alpha,
                         line_width=edge_width)

    def df2ds(df):
        return ColumnDataSource(ColumnDataSource.from_df(df))

    ydr = DataRange1d(range_padding=0.05)
    xdr = DataRange1d(range_padding=0.05)

    plot = Plot(x_range=xdr, y_range=ydr, **kwargs)
    plot.add_glyph(df2ds(edges), edge_glyph)
    node_renderer = plot.add_glyph(df2ds(nodes), node_glyph)

    # TODO: Will need to make the hovertool options more configurable
    # NOTE(review): only the text content of this template is visible in
    # the reviewed source; any surrounding HTML markup appears to have been
    # lost -- confirm against the original file before merging.
    tooltip = """
        name:
        @index
    """
    hover = HoverTool(renderers=[node_renderer], tooltips=tooltip)
    plot.add_tools(hover, BoxZoomTool(), ResetTool())

    return plot
def test_basic_plot(self):
    """Render a 3-tip tree and compare both bokeh data sources.

    The first renderer of the plot is expected to carry the edge table
    and the second one the node table.
    """
    exp_edges = {'dest_node': ['0', '1', '2', 'y3'],
                 'edge_alpha': [1, 1, 1, 1],
                 'edge_color': ['#00FF00', '#00FF00',
                                '#00FF00', '#FF0000'],
                 'edge_width': [2, 2, 2, 2],
                 'index': [0, 1, 2, 3],
                 'src_node': ['y3', 'y4', 'y3', 'y4'],
                 'x0': [338.2612593838583,
                        193.1688862557773,
                        338.2612593838583,
                        193.1688862557773],
                 'x1': [487.5, 12.499999999999972,
                        324.89684138234867, 338.2612593838583],
                 'y0': [271.7282256126416,
                        365.95231443706376,
                        271.7282256126416,
                        365.95231443706376],
                 'y1': [347.7691620070637,
                        483.2800610261029,
                        16.719938973897143,
                        271.7282256126416]}

    exp_nodes = {'child0': [np.nan, np.nan, np.nan, '0', '1'],
                 'child1': [np.nan, np.nan, np.nan, '2', 'y3'],
                 'color': ['#1C9099', '#1C9099', '#1C9099',
                           '#FF999F', '#FF999F'],
                 'index': ['0', '1', '2', 'y3', 'y4'],
                 'is_tip': [True, True, True, False, False],
                 'node_alpha': [1, 1, 1, 1, 1],
                 'node_size': [10, 10, 10, 10, 10],
                 'x': [487.5,
                       12.499999999999972,
                       324.89684138234867,
                       338.26125938385832,
                       193.16888625577729],
                 'y': [347.7691620070637,
                       483.28006102610289,
                       16.719938973897143,
                       271.72822561264161,
                       365.95231443706376]}
    # Fixed seed: the expected coordinates above depend on both the random
    # distances and the random branch lengths drawn below.
    np.random.seed(0)
    num_otus = 3  # otus
    x = np.random.rand(num_otus)
    dm = DistanceMatrix.from_iterable(x, lambda x, y: np.abs(x-y))
    lm = ward(dm.condensed_form())
    # builtin `str` instead of the `np.str` alias, which newer NumPy
    # releases no longer provide
    t = TreeNode.from_linkage_matrix(lm, np.arange(len(x)).astype(str))
    t = UnrootedDendrogram.from_tree(t)
    # incorporate colors in tree
    for i, n in enumerate(t.postorder(include_self=True)):
        if not n.is_tip():
            n.name = "y%d" % i
            n.color = '#FF999F'
            n.edge_color = '#FF0000'
            n.node_size = 10
        else:
            n.color = '#1C9099'
            n.edge_color = '#00FF00'
            n.node_size = 10
        n.length = np.random.rand()*3
        n.edge_width = 2
    p = radialplot(t, node_color='color', edge_color='edge_color',
                   node_size='node_size', edge_width='edge_width')

    self.assertDictEqual(p.renderers[0].data_source.data, exp_edges)
    self.assertDictEqual(p.renderers[1].data_source.data, exp_nodes)