diff --git a/gneiss/plot/_dendrogram.py b/gneiss/plot/_dendrogram.py index e9884eb..12a8bd8 100644 --- a/gneiss/plot/_dendrogram.py +++ b/gneiss/plot/_dendrogram.py @@ -5,15 +5,11 @@ # # The full license is in the file COPYING.txt, distributed with this software. # ---------------------------------------------------------------------------- +import abc +from collections import namedtuple from skbio import TreeNode import pandas as pd import numpy -import abc - - -def _sign(x): - """Returns True if x is positive, False otherwise.""" - return x and x/abs(x) class Dendrogram(TreeNode): @@ -39,20 +35,16 @@ class Dendrogram(TreeNode): Notes ----- - `length` refers to the branch length connect to the specified subtree. + `length` refers to the branch length of a node to its parent. `leafcount` is the number of tips within a subtree. `height` refers to the longest path from root to the deepst leaf in that subtree. `depth` is the number of nodes found in the longest path. """ - aspect_distorts_lengths = True - def __init__(self, use_lengths=True, **kwargs): """ Constructs a Dendrogram object for visualization. - """ super().__init__(**kwargs) - self.use_lengths_default = use_lengths def _cache_ntips(self): """ Counts the number of leaves under each subtree.""" @@ -80,8 +72,6 @@ def update_geometry(self, use_lengths, depth=None): self.length = 0 else: self.length = 1 - else: - self.length = self.length self.depth = (depth or 0) + self.length @@ -105,7 +95,6 @@ def coords(self, height, width): The height of the canvas. width : int The width of the canvas. - Returns ------- pd.DataFrame @@ -156,9 +145,10 @@ class UnrootedDendrogram(Dendrogram): Attributes ---------- length + leafcount + height + depth """ - aspect_distorts_lengths = True - def __init__(self, **kwargs): """ Constructs a UnrootedDendrogram object for visualization. @@ -190,8 +180,7 @@ def from_tree(cls, tree, use_lengths=True): return tree def rescale(self, width, height): - """ Find best scaling factor for fitting the tree in the dimensions - specified by width and height. + """ Find best scaling factor for fitting the tree in the figure. This method will find the best orientation and scaling possible to fit the tree within the dimensions specified by width and height. @@ -290,3 +279,129 @@ def update_coordinates(self, s, x1, y1, a, da): points += child.update_coordinates(s, x2, y2, a+ca/2, da) a += ca return points + + +Dimensions = namedtuple('Dimensions', ['x', 'y', 'height']) + + +class RootedDendrogram(Dendrogram): + """ Stores data to be plotted as an rooted dendrogram. + + A `RootedDendrogram` object is represents a tree in addition to the + key information required to create a radial tree layout prior to + visualization. + + Parameters + ---------- + use_lengths: bool + Specifies if the branch lengths should be included in the + resulting visualization (default True). + + Attributes + ---------- + length + leafcount + height + depth + """ + + def width_required(self): + return self.leafcount + + @abc.abstractmethod + def xcoords(self, scale, x1): + pass + + @abc.abstractmethod + def ycoords(self, scale, y1): + pass + + def rescale(self, width, height): + """ Update x, y coordinates of tree nodes in canvas. + + Parameters + ---------- + scale : Dimensions + Scaled dimensions of the tree + x1 : int + X-coordinate of parent + """ + xscale = width / self.height + yscale = height / self.width_required() + scale = Dimensions(xscale, yscale, self.height) + + # y coords done postorder, x preorder, y first. + # so it has to be done in 2 passes. + self.update_y_coordinates(scale) + self.update_x_coordinates(scale) + return xscale + + def update_y_coordinates(self, scale, y1=None): + """The second pass through the tree. Y coordinates only + depend on the shape of the tree and yscale. + + Parameters + ---------- + scale : Dimensions + Scaled dimensions of the tree + x1 : int + X-coordinate of parent + """ + if y1 is None: + y1 = self.width_required() * scale.y + child_y = y1 + for child in self.children: + child.update_y_coordinates(scale, child_y) + child_y -= child.width_required() * scale.y + (self.y1, self.y2) = self.ycoords(scale, y1) + + def update_x_coordinates(self, scale, x1=0): + """For non 'square' styles the x coordinates will depend + (a bit) on the y coodinates, so they should be done first. + Parameters + ---------- + scale : Dimensions + Scaled dimensions of the tree + x1 : int + X-coordinate of parent + """ + (self.x1, self.x2) = self.xcoords(scale, x1) + for child in self.children: + child.update_x_coordinates(scale, self.x2) + + +class SquareDendrogram(RootedDendrogram): + + def ycoords(self, scale, y1): + cys = [c.y1 for c in self.children] + if cys: + y2 = (cys[0]+cys[-1]) / 2.0 + else: + y2 = y1 - 0.5 * scale.y + return (y2, y2) + + def xcoords(self, scale, x1): + if self.is_tip(): + return (x1, (scale.height-(self.height-self.length))*scale.x) + else: + # give some margins for internal nodes + dx = scale.x * self.length * 0.95 + x2 = x1 + dx + return (x1, x2) + + @classmethod + def from_tree(cls, tree): + """ Creates an SquareDendrogram object from a skbio tree. + + Parameters + ---------- + tree : skbio.TreeNode + Input skbio tree + Returns + ------- + SquareDendrogram + """ + for n in tree.postorder(include_self=True): + n.__class__ = SquareDendrogram + tree.update_geometry(use_lengths=False) + return tree diff --git a/gneiss/plot/_heatmap.py b/gneiss/plot/_heatmap.py index 97c3815..30bfdeb 100644 --- a/gneiss/plot/_heatmap.py +++ b/gneiss/plot/_heatmap.py @@ -6,112 +6,196 @@ # The full license is in the file COPYING.txt, distributed with this software. # ---------------------------------------------------------------------------- import numpy as np -from ete3 import TreeStyle, AttrFace, ProfileFace -from ete3 import ClusterNode -from ete3.treeview.faces import add_face_to_node -import io +import matplotlib.pyplot as plt +import matplotlib.patches as patches +import pandas as pd +from gneiss.plot._dendrogram import SquareDendrogram -def heatmap(table, tree, cmap='viridis', **kwargs): - """ Plots tree on heatmap +def heatmap(table, tree, mdvar, highlights=None, + grid_col='w', grid_width=2, dendrogram_width=20, + highlight_width=0.02, figsize=(5, 5)): + """ Creates heatmap plotting object Parameters ---------- table : pd.DataFrame - Contingency table where samples correspond to rows and - features correspond to columns. - tree : skbio.TreeNode - A strictly bifurcating tree defining a hierarchical relationship - between all of the features within `table`. - cmap: matplotlib colormap - String or function encoding matplotlib colormap. - labelcolor: str - Color of the node labels. (default 'black') - rowlabel_size : int - Size of row labels. (default 8) - width : int - Heatmap cell width. (default 200) - height : int - Heatmap cell height (default 14) + Contain sample/feature labels along with table of values. + Rows correspond to samples, and columns correspond to features. + tree: skbio.TreeNode + Tree representing the feature hierarchy. + highlights: pd.DataFrame or dict of tuple of str + List of internal nodes in the tree to highlight. + Each internal node must contain two colors, one for the left + subtree and the other for the right subtree highlight. + The first color will always correspond to the left subtree, + and the second color will always correspond to the right subtree. + mdvar: pd.Series + Metadata values for samples. The index must correspond to the + index of `table`. + highlight_width : int + Width of highlights. (default=0.02) + dendrogram_width : int + Width of axes for dendrogram plot. (default=20) + grid_col: str + Color of vertical lines for highlighting sample metadata. + (default='w') + grid_width: int + Width of vertical lines for highlighting sample metadata. + (default=2) + figsize: tuple of int + Species (width, height) for figure. (default=(5, 5)) + Returns ------- - ete.Tree - ETE tree object that will be plotted. - ete.TreeStyle - ETE TreeStyle that decorates the tree and heatmap visualization. + matplotlib.pyplot.figure + Matplotlib figure object + + Note + ---- + The highlights parameter assumes that the tree is bifurcating. + """ + + # get edges from tree + t = SquareDendrogram.from_tree(tree) + t = _tree_coordinates(t) + pts = t.coords(width=dendrogram_width, height=table.shape[0]) + edges = pts[['child0', 'child1']] + edges = edges.dropna(subset=['child0', 'child1']) + edges = edges.unstack() + edges = pd.DataFrame({'src_node': edges.index.get_level_values(1), + 'dest_node': edges.values}) + edges['x0'] = [pts.loc[n].x for n in edges.src_node] + edges['x1'] = [pts.loc[n].x for n in edges.dest_node] + edges['y0'] = [pts.loc[n].y for n in edges.src_node] + edges['y1'] = [pts.loc[n].y for n in edges.dest_node] + + # now plot the stuff + fig = plt.figure(figsize=figsize) + + xwidth = 0.2 + top_buffer = 0.1 + height = 0.8 + + # heatmap axes + [axm_x, axm_y, axm_w, axm_h] = [0, top_buffer, xwidth, height] + + # create a split for the highlights + if highlights is not None: + h = len(highlights) + else: + h = 0 + hwidth = highlight_width + [axs_x, axs_y, axs_w, axs_h] = [xwidth, top_buffer, hwidth * h, height] + + # dendrogram axes on the right side + hstart = xwidth + (h * hwidth) # beginning of heatmap + [ax1_x, ax1_y, ax1_w, ax1_h] = [hstart, top_buffer, 1-hstart, height] + + # plot heatmap + ax_heatmap = fig.add_axes([ax1_x, ax1_y, ax1_w, ax1_h], frame_on=True) + _plot_heatmap(ax_heatmap, table, mdvar, grid_col, grid_width) + + # plot dendrogram + ax_dendrogram = fig.add_axes([axm_x, axm_y, axm_w, axm_h], + frame_on=True, sharey=ax_heatmap) + _plot_dendrogram(ax_dendrogram, table, edges) + + # plot highlights for dendrogram + if highlights is not None: + ax_highlights = fig.add_axes([axs_x, axs_y, axs_w, axs_h], + frame_on=True, sharey=ax_heatmap) + _plot_highlights_dendrogram(ax_highlights, table, t, highlights) + return fig + + +def _tree_coordinates(t): + """ Builds a matrix to link tree positions to matrix""" + # first traverse the tree to count the children + for n in t.postorder(include_self=True): + if n.is_tip(): + n._n_tips = 1 + else: + n._n_tips = sum(c._n_tips for c in n.children) + + for i, n in enumerate(t.levelorder(include_self=True)): + if n.is_root(): + n._k = 0 + n._t = 0 + else: + if n is n.parent.children[0]: + n._k = n.parent._k + n.parent._r + n._t = n.parent._t + else: + n._k = n.parent._k + n._t = n.parent._t + n.parent._l + if n.is_tip(): + continue + n._l, n._r = n.children[0]._n_tips, n.children[1]._n_tips + return t + + +def _plot_highlights_dendrogram(ax_highlights, table, t, highlights): + """ Plots highlights for subtrees in the dendrograms. + + Note that this assumes that the dendrograms are strictly bifurcating + and the highlights only specify the children for a given subtree. """ - # TODO: Allow for the option to encode labels in different colors - # (i.e. pass in a pandas series) - params = {'rowlabel_size': 8, 'width': 200, 'height': 14, - 'cmap': 'viridis', 'labelcolor': 'black', - # TODO: Enable layout - # layout : function, optional - # A layout for formatting the tree visualization. Must take a - # `ete.tree` as a parameter. - 'layout': lambda x: x} - - for key in params.keys(): - params[key] = kwargs.get(key, params[key]) - fsize = params['rowlabel_size'] - width = params['width'] - height = params['height'] - colorscheme = params['cmap'] - layout = params['layout'] - - # Allow for matplotlib colors to be encoded in ETE3 heatmap - # Originally from https://github.com/lthiberiol/virfac - def get_color_gradient(self): - from PyQt4 import QtGui - try: - import matplotlib.pyplot as plt - import matplotlib.colors as colors - import matplotlib.cm as cmx - except: - ImportError("Matplotlib not installed.") - - cNorm = colors.Normalize(vmin=0, vmax=1) - scalarMap = cmx.ScalarMappable(norm=cNorm, - cmap=plt.get_cmap(self.colorscheme)) - color_scale = [] - for scale in np.linspace(0, 1, 255): - [r, g, b, a] = scalarMap.to_rgba(scale, bytes=True) - color_scale.append(QtGui.QColor(r, g, b, a)) - return color_scale - - ProfileFace.get_color_gradient = get_color_gradient - tree.name = "" - - f = io.StringIO() - table.T.to_csv(f, sep='\t', index_label='#Names') - - tr = ClusterNode(str(tree), text_array=str(f.getvalue())) - matrix_max = np.max(table.values) - matrix_min = np.min(table.values) - matrix_avg = matrix_min + ((matrix_max - matrix_min) / 2) - - # Encode the actual profile face - nameFace = AttrFace("name", fsize=fsize) - - def heatmap_layout(node): - # Run the layout passed in first before - # filling in the heatmap - layout(node) - - if node.is_leaf(): - profileFace = ProfileFace( - values_vector=table.loc[:, node.name].values, - style="heatmap", - max_v=matrix_max, min_v=matrix_min, - center_v=matrix_avg, - colorscheme=colorscheme, - width=width, height=height) - - add_face_to_node(profileFace, node, 0, aligned=True) - node.img_style["size"] = 0 - add_face_to_node(nameFace, node, 1, aligned=True) - - ts = TreeStyle() - ts.layout_fn = heatmap_layout - - return tr, ts + offset = 0.5 + + num_h = len(highlights) + hcoords = [] + for i, n in enumerate(highlights.index): + node = t.find(n) + k, l, r = node._k, node._l, node._r + + ax_highlights.add_patch( + patches.Rectangle( + (i/num_h, k-offset), # x, y + 1/num_h, # width + r, # height + facecolor=highlights.iloc[i, 0] + )) + + ax_highlights.add_patch( + patches.Rectangle( + (i/num_h, k+r-offset), # x, y + 1/num_h, # width + l, # height + facecolor=highlights.iloc[i, 1] + )) + hcoords.append((i+offset)/num_h) + ax_highlights.set_ylim([-offset, table.shape[0]-offset]) + ax_highlights.set_yticks([]) + ax_highlights.set_xticks(hcoords) + ax_highlights.set_xticklabels(highlights.index, rotation=90) + + +def _plot_dendrogram(ax_dendrogram, table, edges): + """ Plots the actual dendrogram.""" + offset = 0.5 + # offset = 0 + for i in range(len(edges.index)): + row = edges.iloc[i] + ax_dendrogram.plot([row.x0, row.x1], + [row.y0-offset, row.y1-offset], '-k') + ax_dendrogram.set_ylim([-offset, table.shape[0]-offset]) + ax_dendrogram.set_yticks([]) + ax_dendrogram.set_xticks([]) + + +def _plot_heatmap(ax_heatmap, table, mdvar, grid_col, grid_width): + ax_heatmap.imshow(table, aspect='auto', interpolation='nearest') + ax_heatmap.set_ylim([0, table.shape[0]]) + vcounts = mdvar.value_counts() + ticks = vcounts.sort_index().cumsum() + midpoints = ticks - (ticks - np.array([0] + list(ticks.values[:-1]))) / 2.0 + ax_heatmap.set_xticks(ticks.values-0.5, minor=False) + ax_heatmap.set_xticklabels([], minor=False) + + ax_heatmap.xaxis.grid(True, which='major', color=grid_col, + linestyle='-', linewidth=grid_width) + + ax_heatmap.set_xticks(midpoints-0.5, minor=True) + ax_heatmap.set_xticklabels(vcounts.index, minor=True) diff --git a/gneiss/plot/tests/test_dendrogram.py b/gneiss/plot/tests/test_dendrogram.py index 7b8f065..ef3a903 100644 --- a/gneiss/plot/tests/test_dendrogram.py +++ b/gneiss/plot/tests/test_dendrogram.py @@ -9,7 +9,8 @@ import numpy as np import pandas as pd from skbio import DistanceMatrix, TreeNode -from gneiss.plot._dendrogram import Dendrogram, UnrootedDendrogram +from gneiss.plot._dendrogram import (Dendrogram, UnrootedDendrogram, + SquareDendrogram) from scipy.cluster.hierarchy import ward import pandas.util.testing as pdt @@ -90,8 +91,8 @@ def test_coords(self): def test_rescale(self): t = UnrootedDendrogram.from_tree(self.tree) - self.assertAlmostEquals(t.rescale(500, 500), 91.608680314971238, - places=5) + self.assertAlmostEqual(t.rescale(500, 500), 91.608680314971238, + places=5) def test_update_coordinates(self): t = UnrootedDendrogram.from_tree(self.tree) @@ -110,5 +111,50 @@ def test_update_coordinates(self): pdt.assert_frame_equal(res, exp, check_less_precise=True) +class TestSquareDendrogram(unittest.TestCase): + + def setUp(self): + np.random.seed(0) + self.table = pd.DataFrame(np.random.random((5, 5))) + num_otus = 5 # otus + x = np.random.rand(num_otus) + dm = DistanceMatrix.from_iterable(x, lambda x, y: np.abs(x-y)) + lm = ward(dm.condensed_form()) + t = TreeNode.from_linkage_matrix(lm, np.arange(len(x)).astype(np.str)) + self.tree = SquareDendrogram.from_tree(t) + + for i, n in enumerate(t.postorder()): + if not n.is_tip(): + n.name = "y%d" % i + n.length = np.random.rand()*3 + + def test_from_tree(self): + t = SquareDendrogram.from_tree(self.tree) + self.assertEqual(t.__class__, SquareDendrogram) + + def test_coords(self): + # just test to make sure that the coordinates are calculated properly. + t = SquareDendrogram.from_tree(self.tree) + + exp = pd.DataFrame({'0': [20, 2.5, np.nan, np.nan, True], + '1': [20, 3.5, np.nan, np.nan, True], + '2': [20, 4.5, np.nan, np.nan, True], + '3': [20, 1.5, np.nan, np.nan, True], + '4': [20, 0.5, np.nan, np.nan, True], + 'y5': [14.25, 1, '3', '4', False], + 'y6': [9.5, 1.75, '0', 'y5', False], + 'y7': [4.75, 2.625, '1', 'y6', False], + 'y8': [0, 3.5625, '2', 'y7', False]}, + index=['x', 'y', 'child0', 'child1', 'is_tip']).T + + res = t.coords(width=20, height=self.table.shape[0]) + pdt.assert_frame_equal(exp, res) + + def test_rescale(self): + t = SquareDendrogram.from_tree(self.tree) + res = t.rescale(10, 10) + self.assertEqual(res, 2.5) + + if __name__ == "__main__": unittest.main() diff --git a/gneiss/plot/tests/test_heatmap.py b/gneiss/plot/tests/test_heatmap.py index f9b7945..8219099 100644 --- a/gneiss/plot/tests/test_heatmap.py +++ b/gneiss/plot/tests/test_heatmap.py @@ -1,28 +1,106 @@ from gneiss.plot import heatmap import pandas as pd -from skbio import TreeNode +from skbio import TreeNode, DistanceMatrix +from scipy.cluster.hierarchy import ward +from gneiss.plot._dendrogram import SquareDendrogram +import numpy as np +import numpy.testing.utils as npt import unittest -import os class HeatmapTest(unittest.TestCase): def setUp(self): - self.fname = 'test.pdf' - - def tearDown(self): - if os.path.exists(self.fname): - os.remove(self.fname) - - def test_not_fail(self): - t = pd.DataFrame({'a': [1, 2, 3], - 'b': [4, 5, 6], - 'c': [7, 8, 9]}, - index=['x', 'y', 'z']) - r = TreeNode.read([r"((a,b),c);"]) - tr, ts = heatmap(t, r, cmap='viridis', rowlabel_size=14) - tr.render(file_name=self.fname, tree_style=ts) - self.assertTrue(os.path.exists(self.fname)) - self.assertTrue(os.path.getsize(self.fname) > 0) + np.random.seed(0) + self.table = pd.DataFrame(np.random.random((5, 5))) + num_otus = 5 # otus + x = np.random.rand(num_otus) + dm = DistanceMatrix.from_iterable(x, lambda x, y: np.abs(x-y)) + lm = ward(dm.condensed_form()) + t = TreeNode.from_linkage_matrix(lm, np.arange(len(x)).astype(np.str)) + self.t = SquareDendrogram.from_tree(t) + self.md = pd.Series(['a', 'a', 'a', 'b', 'b']) + for i, n in enumerate(t.postorder()): + if not n.is_tip(): + n.name = "y%d" % i + n.length = np.random.rand()*3 + + self.highlights = pd.DataFrame({'y8': ['#FF0000', '#00FF00'], + 'y6': ['#0000FF', '#F0000F']}).T + + def test_basic(self): + fig = heatmap(self.table, self.t, self.md) + + # Test to see if the lineages of the tree are ok + lines = list(fig.get_axes()[1].get_lines()) + pts = self.t.coords(width=20, height=self.table.shape[0]) + pts['y'] = pts['y'] - 0.5 # account for offset + pts['x'] = pts['x'].astype(np.float) + pts['y'] = pts['y'].astype(np.float) + + npt.assert_allclose(lines[0]._xy, + pts.loc[['y5', '3'], ['x', 'y']]) + npt.assert_allclose(lines[1]._xy, + pts.loc[['y6', '0'], ['x', 'y']].values) + npt.assert_allclose(lines[2]._xy, + pts.loc[['y7', '1'], ['x', 'y']].values) + npt.assert_allclose(lines[3]._xy, + pts.loc[['y8', '2'], ['x', 'y']].values) + npt.assert_allclose(lines[4]._xy, + pts.loc[['y5', '4'], ['x', 'y']].values) + npt.assert_allclose(lines[5]._xy, + pts.loc[['y6', 'y5'], ['x', 'y']].values) + npt.assert_allclose(lines[6]._xy, + pts.loc[['y7', 'y6'], ['x', 'y']].values) + npt.assert_allclose(lines[7]._xy, + pts.loc[['y8', 'y7'], ['x', 'y']].values) + + # Make sure that the metadata labels are set properly + res = str(fig.get_axes()[0].get_xticklabels(minor=True)[0]) + self.assertEqual(res, "Text(0,0,'a')") + + res = str(fig.get_axes()[0].get_xticklabels(minor=True)[1]) + self.assertEqual(res, "Text(0,0,'b')") + + def test_basic_highlights(self): + fig = heatmap(self.table, self.t, self.md, self.highlights) + + # Test to see if the lineages of the tree are ok + lines = list(fig.get_axes()[1].get_lines()) + pts = self.t.coords(width=20, height=self.table.shape[0]) + pts['y'] = pts['y'] - 0.5 # account for offset + pts['x'] = pts['x'].astype(np.float) + pts['y'] = pts['y'].astype(np.float) + + npt.assert_allclose(lines[0]._xy, + pts.loc[['y5', '3'], ['x', 'y']].values) + npt.assert_allclose(lines[1]._xy, + pts.loc[['y6', '0'], ['x', 'y']].values) + npt.assert_allclose(lines[2]._xy, + pts.loc[['y7', '1'], ['x', 'y']].values) + npt.assert_allclose(lines[3]._xy, + pts.loc[['y8', '2'], ['x', 'y']].values) + npt.assert_allclose(lines[4]._xy, + pts.loc[['y5', '4'], ['x', 'y']].values) + npt.assert_allclose(lines[5]._xy, + pts.loc[['y6', 'y5'], ['x', 'y']].values) + npt.assert_allclose(lines[6]._xy, + pts.loc[['y7', 'y6'], ['x', 'y']].values) + npt.assert_allclose(lines[7]._xy, + pts.loc[['y8', 'y7'], ['x', 'y']].values) + + # Make sure that the metadata labels are set properly + res = str(fig.get_axes()[0].get_xticklabels(minor=True)[0]) + self.assertEqual(res, "Text(0,0,'a')") + + res = str(fig.get_axes()[0].get_xticklabels(minor=True)[1]) + self.assertEqual(res, "Text(0,0,'b')") + + # Make sure that the highlight labels are set properly + res = str(fig.get_axes()[2].get_xticklabels()[0]) + self.assertEqual(res, "Text(0,0,'y6')") + + res = str(fig.get_axes()[2].get_xticklabels()[1]) + self.assertEqual(res, "Text(0,0,'y8')") if __name__ == "__main__":