From b74c64bc123d4792c14a47b29001174aa6f79eff Mon Sep 17 00:00:00 2001 From: Jamie Morton Date: Mon, 27 Feb 2017 14:41:36 -0800 Subject: [PATCH 01/12] ENH: heatmap --- gneiss/plot/_heatmap.py | 278 +++++++++++++++++++----------- gneiss/plot/tests/test_heatmap.py | 148 ++++++++++++++-- 2 files changed, 309 insertions(+), 117 deletions(-) diff --git a/gneiss/plot/_heatmap.py b/gneiss/plot/_heatmap.py index 97c3815..d344980 100644 --- a/gneiss/plot/_heatmap.py +++ b/gneiss/plot/_heatmap.py @@ -5,113 +5,193 @@ # # The full license is in the file COPYING.txt, distributed with this software. # ---------------------------------------------------------------------------- -import numpy as np -from ete3 import TreeStyle, AttrFace, ProfileFace -from ete3 import ClusterNode -from ete3.treeview.faces import add_face_to_node import io +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.patches as patches +import pandas as pd +from gneiss.plot._dendrogram import SquareDendrogram -def heatmap(table, tree, cmap='viridis', **kwargs): - """ Plots tree on heatmap +def heatmap(table, tree, mdvar, highlights=None, + grid_col='w', grid_width=2, dendrogram_width=20, + figsize=(5, 5)): + """ Creates heatmap plotting object Parameters ---------- table : pd.DataFrame - Contingency table where samples correspond to rows and - features correspond to columns. - tree : skbio.TreeNode - A strictly bifurcating tree defining a hierarchical relationship - between all of the features within `table`. - cmap: matplotlib colormap - String or function encoding matplotlib colormap. - labelcolor: str - Color of the node labels. (default 'black') - rowlabel_size : int - Size of row labels. (default 8) - width : int - Heatmap cell width. (default 200) - height : int - Heatmap cell height (default 14) + Contain sample/feature labels along with table of values. + Rows correspond to samples, and columns correspond to features. + tree: skbio.TreeNode + Tree representing the feature hierarchy. + highlights: pd.DataFrame or dict of tuple of str + List of internal nodes in the tree to highlight. + Each internal node must contain two colors, one for the left + subtree and the other for the right subtree highlight. + The first color will always correspond to the left subtree, + and the second color will always correspond to the right subtree. + mdvar: pd.Series + Metadata values for samples. The index must correspond to the + index of `table`. + dendrogram_width : int + Width of axes for dendrogram plot. + grid_col: str + Color of vertical lines for highlighting sample metadata. + grid_width: int + Width of vertical lines for highlighting sample metadata. + figsize: tuple of int + Species (width, height) for figure. Returns ------- - ete.Tree - ETE tree object that will be plotted. - ete.TreeStyle - ETE TreeStyle that decorates the tree and heatmap visualization. + matplotlib.pyplot.figure + Matplotlib figure object + + Note + ---- + The highlights parameter assumes that the tree is bifurcating. + """ + + # get edges from tree + t = SquareDendrogram.from_tree(tree) + t = _tree_coordinates(t) + pts = t.coords(width=dendrogram_width, height=table.shape[0]) + edges = pts[['child0', 'child1']] + edges = edges.dropna(subset=['child0', 'child1']) + edges = edges.unstack() + edges = pd.DataFrame({'src_node': edges.index.get_level_values(1), + 'dest_node': edges.values}) + edges['x0'] = [pts.loc[n].x for n in edges.src_node] + edges['x1'] = [pts.loc[n].x for n in edges.dest_node] + edges['y0'] = [pts.loc[n].y for n in edges.src_node] + edges['y1'] = [pts.loc[n].y for n in edges.dest_node] + + # now plot the stuff + fig = plt.figure(figsize=figsize) + + xwidth = 0.2 + top_buffer = 0.1 + height = 0.8 + + # heatmap axes + [axm_x, axm_y, axm_w, axm_h] = [0, top_buffer, xwidth, height] + + # create a split for the highlights + if highlights is not None: + h = len(highlights) + else: + h = 0 + hwidth = 0.02 + [axs_x, axs_y, axs_w, axs_h] = [xwidth, top_buffer, hwidth * h, height] + + # dendrogram axes on the right side + hstart = xwidth + (h * hwidth) # beginning of heatmap + [ax1_x, ax1_y, ax1_w, ax1_h] = [hstart, top_buffer, 1-hstart, height] + + # plot heatmap + ax_heatmap = fig.add_axes([ax1_x, ax1_y, ax1_w, ax1_h], frame_on=True) + _plot_heatmap(ax_heatmap, table, mdvar, grid_col, grid_width) + + # plot dendrogram + ax_dendrogram = fig.add_axes([axm_x, axm_y, axm_w, axm_h], + frame_on=True, sharey=ax_heatmap) + _plot_dendrogram(ax_dendrogram, table, edges) + + # plot highlights for dendrogram + if highlights is not None: + ax_highlights = fig.add_axes([axs_x, axs_y, axs_w, axs_h], + frame_on=True, sharey=ax_heatmap) + _plot_highlights_dendrogram(ax_highlights, table, t, highlights) + return fig + + +def _tree_coordinates(t): + """ Builds a matrix to link tree positions to matrix""" + # first traverse the tree to count the children + for n in t.postorder(include_self=True): + if n.is_tip(): + n._n_tips = 1 + else: + n._n_tips = sum(c._n_tips for c in n.children) + + for i, n in enumerate(t.levelorder(include_self=True)): + if n.is_root(): + n._k = 0 + n._t = 0 + else: + if n is n.parent.children[0]: + n._k = n.parent._k + n.parent._r + n._t = n.parent._t + else: + n._k = n.parent._k + n._t = n.parent._t + n.parent._l + if n.is_tip(): + continue + n._l, n._r = n.children[0]._n_tips, n.children[1]._n_tips + return t + + +def _plot_highlights_dendrogram(ax_highlights, table, t, highlights): + """ Plots highlights for subtrees in the dendrograms. + + Note that this assumes that the dendrograms are strictly bifurcating + and the highlights only specify the children for a given subtree. """ - # TODO: Allow for the option to encode labels in different colors - # (i.e. pass in a pandas series) - params = {'rowlabel_size': 8, 'width': 200, 'height': 14, - 'cmap': 'viridis', 'labelcolor': 'black', - # TODO: Enable layout - # layout : function, optional - # A layout for formatting the tree visualization. Must take a - # `ete.tree` as a parameter. - 'layout': lambda x: x} - - for key in params.keys(): - params[key] = kwargs.get(key, params[key]) - fsize = params['rowlabel_size'] - width = params['width'] - height = params['height'] - colorscheme = params['cmap'] - layout = params['layout'] - - # Allow for matplotlib colors to be encoded in ETE3 heatmap - # Originally from https://github.com/lthiberiol/virfac - def get_color_gradient(self): - from PyQt4 import QtGui - try: - import matplotlib.pyplot as plt - import matplotlib.colors as colors - import matplotlib.cm as cmx - except: - ImportError("Matplotlib not installed.") - - cNorm = colors.Normalize(vmin=0, vmax=1) - scalarMap = cmx.ScalarMappable(norm=cNorm, - cmap=plt.get_cmap(self.colorscheme)) - color_scale = [] - for scale in np.linspace(0, 1, 255): - [r, g, b, a] = scalarMap.to_rgba(scale, bytes=True) - color_scale.append(QtGui.QColor(r, g, b, a)) - return color_scale - - ProfileFace.get_color_gradient = get_color_gradient - tree.name = "" - - f = io.StringIO() - table.T.to_csv(f, sep='\t', index_label='#Names') - - tr = ClusterNode(str(tree), text_array=str(f.getvalue())) - matrix_max = np.max(table.values) - matrix_min = np.min(table.values) - matrix_avg = matrix_min + ((matrix_max - matrix_min) / 2) - - # Encode the actual profile face - nameFace = AttrFace("name", fsize=fsize) - - def heatmap_layout(node): - # Run the layout passed in first before - # filling in the heatmap - layout(node) - - if node.is_leaf(): - profileFace = ProfileFace( - values_vector=table.loc[:, node.name].values, - style="heatmap", - max_v=matrix_max, min_v=matrix_min, - center_v=matrix_avg, - colorscheme=colorscheme, - width=width, height=height) - - add_face_to_node(profileFace, node, 0, aligned=True) - node.img_style["size"] = 0 - add_face_to_node(nameFace, node, 1, aligned=True) - - ts = TreeStyle() - ts.layout_fn = heatmap_layout - - return tr, ts + offset = 0.5 + + num_h = len(highlights) + hcoords = [] + for i, n in enumerate(highlights.index): + node = t.find(n) + k, l, r = node._k, node._l, node._r + + ax_highlights.add_patch( + patches.Rectangle( + (i/num_h, k-offset), # x, y + 1/num_h, # width + r, # height + facecolor=highlights.iloc[i, 0] + )) + + ax_highlights.add_patch( + patches.Rectangle( + (i/num_h, k+r-offset), # x, y + 1/num_h, # width + l, # height + facecolor=highlights.iloc[i, 1] + )) + hcoords.append((i+offset)/num_h) + ax_highlights.set_ylim([-offset, table.shape[0]-offset]) + ax_highlights.set_yticks([]) + ax_highlights.set_xticks(hcoords) + ax_highlights.set_xticklabels(highlights, rotation=90) + + +def _plot_dendrogram(ax_dendrogram, table, edges): + """ Plots the actual dendrogram.""" + offset = 0.5 + # offset = 0 + for i in range(len(edges.index)): + row = edges.iloc[i] + ax_dendrogram.plot([row.x0, row.x1], + [row.y0-offset, row.y1-offset], '-k') + ax_dendrogram.set_ylim([-offset, table.shape[0]-offset]) + ax_dendrogram.set_yticks([]) + ax_dendrogram.set_xticks([]) + + +def _plot_heatmap(ax_heatmap, table, mdvar, grid_col, grid_width): + ax_heatmap.imshow(table, aspect='auto', interpolation='nearest') + ax_heatmap.set_ylim([0, table.shape[0]]) + vcounts = mdvar.value_counts() + ticks = vcounts.sort_index().cumsum() + midpoints = ticks - (ticks - np.array([0] + list(ticks.values[:-1]))) / 2.0 + ax_heatmap.set_xticks(ticks.values-0.5, minor=False) + ax_heatmap.set_xticklabels([], minor=False) + + ax_heatmap.xaxis.grid(True, which='major', color=grid_col, + linestyle='-', linewidth=grid_width) + + ax_heatmap.set_xticks(midpoints-0.5, minor=True) + ax_heatmap.set_xticklabels(vcounts.index, minor=True) diff --git a/gneiss/plot/tests/test_heatmap.py b/gneiss/plot/tests/test_heatmap.py index f9b7945..223fe21 100644 --- a/gneiss/plot/tests/test_heatmap.py +++ b/gneiss/plot/tests/test_heatmap.py @@ -1,29 +1,141 @@ +import os from gneiss.plot import heatmap import pandas as pd -from skbio import TreeNode +from skbio import TreeNode, DistanceMatrix +from scipy.cluster.hierarchy import ward +from gneiss.plot._dendrogram import SquareDendrogram +import numpy as np +import numpy.testing.utils as npt import unittest -import os class HeatmapTest(unittest.TestCase): def setUp(self): - self.fname = 'test.pdf' - - def tearDown(self): - if os.path.exists(self.fname): - os.remove(self.fname) - - def test_not_fail(self): - t = pd.DataFrame({'a': [1, 2, 3], - 'b': [4, 5, 6], - 'c': [7, 8, 9]}, - index=['x', 'y', 'z']) - r = TreeNode.read([r"((a,b),c);"]) - tr, ts = heatmap(t, r, cmap='viridis', rowlabel_size=14) - tr.render(file_name=self.fname, tree_style=ts) - self.assertTrue(os.path.exists(self.fname)) - self.assertTrue(os.path.getsize(self.fname) > 0) + np.random.seed(0) + self.table = pd.DataFrame(np.random.random((5, 5))) + num_otus = 5 # otus + x = np.random.rand(num_otus) + dm = DistanceMatrix.from_iterable(x, lambda x, y: np.abs(x-y)) + lm = ward(dm.condensed_form()) + t = TreeNode.from_linkage_matrix(lm, np.arange(len(x)).astype(np.str)) + self.t = SquareDendrogram.from_tree(t) + self.md = pd.Series(['a', 'a', 'a', 'b', 'b']) + for i, n in enumerate(t.postorder()): + if not n.is_tip(): + n.name = "y%d" % i + n.length = np.random.rand()*3 + + self.highlights = pd.DataFrame({'y8': ['#FF0000', '#00FF00'], + 'y6': ['#0000FF', '#F0000F']}).T + + def test_basic(self): + fig = heatmap(self.table, self.t, self.md) + + # Test to see if the lineages of the tree are ok + lines = list(fig.get_axes()[1].get_lines()) + pts = self.t.coords(width=20, height=self.table.shape[0]) + pts['y'] = pts['y'] - 0.5 # account for offset + pts['x'] = pts['x'].astype(np.float) + pts['y'] = pts['y'].astype(np.float) + + npt.assert_allclose(lines[0]._xy, + pts.loc[['y5', '3'], ['x', 'y']]) + npt.assert_allclose(lines[1]._xy, + pts.loc[['y6', '0'], ['x', 'y']].values) + npt.assert_allclose(lines[2]._xy, + pts.loc[['y7', '1'], ['x', 'y']].values) + npt.assert_allclose(lines[3]._xy, + pts.loc[['y8', '2'], ['x', 'y']].values) + npt.assert_allclose(lines[4]._xy, + pts.loc[['y5', '4'], ['x', 'y']].values) + npt.assert_allclose(lines[5]._xy, + pts.loc[['y6', 'y5'], ['x', 'y']].values) + npt.assert_allclose(lines[6]._xy, + pts.loc[['y7', 'y6'], ['x', 'y']].values) + npt.assert_allclose(lines[7]._xy, + pts.loc[['y8', 'y7'], ['x', 'y']].values) + + # Make sure that the metadata labels are set properly + res = str(fig.get_axes()[0].get_xticklabels(minor=True)[0]) + self.assertEqual(res, "Text(0,0,'a')") + + res = str(fig.get_axes()[0].get_xticklabels(minor=True)[1]) + self.assertEqual(res, "Text(0,0,'b')") + + # make sure that xlims are set properly + self.assertEqual(fig.get_axes()[0].get_xlim(), + (-0.5, 4.5)) + + self.assertEqual(fig.get_axes()[1].get_xlim(), + (-1.0, 21.0)) + + # make sure that ylims are set properly + self.assertEqual(fig.get_axes()[0].get_ylim(), + (-0.5, 4.5)) + + self.assertEqual(fig.get_axes()[1].get_ylim(), + (-0.5, 4.5)) + + def test_basic_highlights(self): + fig = heatmap(self.table, self.t, self.md, self.highlights) + + # Test to see if the lineages of the tree are ok + lines = list(fig.get_axes()[1].get_lines()) + pts = self.t.coords(width=20, height=self.table.shape[0]) + pts['y'] = pts['y'] - 0.5 # account for offset + pts['x'] = pts['x'].astype(np.float) + pts['y'] = pts['y'].astype(np.float) + + npt.assert_allclose(lines[0]._xy, + pts.loc[['y5', '3'], ['x', 'y']].values) + npt.assert_allclose(lines[1]._xy, + pts.loc[['y6', '0'], ['x', 'y']].values) + npt.assert_allclose(lines[2]._xy, + pts.loc[['y7', '1'], ['x', 'y']].values) + npt.assert_allclose(lines[3]._xy, + pts.loc[['y8', '2'], ['x', 'y']].values) + npt.assert_allclose(lines[4]._xy, + pts.loc[['y5', '4'], ['x', 'y']].values) + npt.assert_allclose(lines[5]._xy, + pts.loc[['y6', 'y5'], ['x', 'y']].values) + npt.assert_allclose(lines[6]._xy, + pts.loc[['y7', 'y6'], ['x', 'y']].values) + npt.assert_allclose(lines[7]._xy, + pts.loc[['y8', 'y7'], ['x', 'y']].values) + + # Make sure that the metadata labels are set properly + res = str(fig.get_axes()[0].get_xticklabels(minor=True)[0]) + self.assertEqual(res, "Text(0,0,'a')") + + res = str(fig.get_axes()[0].get_xticklabels(minor=True)[1]) + self.assertEqual(res, "Text(0,0,'b')") + + # Make sure that the highlight labels are set properly + res = str(fig.get_axes()[2].get_xticklabels()[0]) + self.assertEqual(res, "Text(0,0,'0')") + + res = str(fig.get_axes()[2].get_xticklabels()[0]) + self.assertEqual(res, "Text(0,0,'0')") + + # make sure that xlims are set properly + self.assertEqual(fig.get_axes()[0].get_xlim(), + (-0.5, 4.5)) + + self.assertEqual(fig.get_axes()[1].get_xlim(), + (-1.0, 21.0)) + + self.assertEqual(fig.get_axes()[2].get_xlim(), + (0.0, 1.0)) + + # make sure that ylims are set properly + self.assertEqual(fig.get_axes()[0].get_ylim(), + (-0.5, 4.5)) + + self.assertEqual(fig.get_axes()[1].get_ylim(), + (-0.5, 4.5)) + self.assertEqual(fig.get_axes()[1].get_ylim(), + (-0.5, 4.5)) if __name__ == "__main__": unittest.main() From c39cdb517713bff5447d930b9d843e600e9eecc7 Mon Sep 17 00:00:00 2001 From: Jamie Morton Date: Mon, 27 Feb 2017 14:49:17 -0800 Subject: [PATCH 02/12] ENH: Including dendrogram code --- gneiss/plot/_dendrogram.py | 176 +++++++++++++++++++++++++++++++------ 1 file changed, 148 insertions(+), 28 deletions(-) diff --git a/gneiss/plot/_dendrogram.py b/gneiss/plot/_dendrogram.py index f37175f..fe4e589 100644 --- a/gneiss/plot/_dendrogram.py +++ b/gneiss/plot/_dendrogram.py @@ -5,26 +5,29 @@ # # The full license is in the file COPYING.txt, distributed with this software. # ---------------------------------------------------------------------------- +import abc +from collections import namedtuple from skbio import TreeNode import pandas as pd import numpy -import abc + + +def _sign(x): + """Returns True if x is positive, False otherwise.""" + return x and x/abs(x) class Dendrogram(TreeNode): """ Stores data to be plotted as a dendrogram. - A `Dendrogram` object is represents a tree in addition to the key information required to create a tree layout prior to visualization. No layouts are specified within this class, since this serves as a super class for different tree layouts. - Parameters ---------- use_lengths: bool Specifies if the branch lengths should be included in the resulting visualization (default True). - Attributes ---------- length @@ -33,7 +36,6 @@ class Dendrogram(TreeNode): def __init__(self, use_lengths=True, **kwargs): """ Constructs a Dendrogram object for visualization. - """ super().__init__(**kwargs) self.use_lengths_default = use_lengths @@ -41,20 +43,43 @@ def __init__(self, use_lengths=True, **kwargs): def _cache_ntips(self): for n in self.postorder(): if n.is_tip(): - n._n_tips = 1 + n.leafcount = 1 else: - n._n_tips = sum(c._n_tips for c in n.children) + n.leafcount = sum(c.leafcount for c in n.children) + + def update_geometry(self, use_lengths, depth=None): + """Calculate tree node attributes such as height and depth. + Despite the name this first pass is ignorant of issues like + scale and orientation""" + if self.length is None or not use_lengths: + if depth is None: + self.length = 0 + else: + self.length = 1 + else: + self.length = self.length + + self.depth = (depth or 0) + self.length + + children = self.children + if children: + for c in children: + c.update_geometry(use_lengths, self.depth) + self.height = max([c.height for c in children]) + self.length + self.leafcount = sum([c.leafcount for c in children]) + + else: + self.height = self.length + self.leafcount = self.edgecount = 1 def coords(self, height, width): """ Returns coordinates of nodes to be rendered in plot. - Parameters ---------- height : int The height of the canvas. width : int The width of the canvas. - Returns ------- pd.DataFrame @@ -91,17 +116,14 @@ def rescale(self, width, height): class UnrootedDendrogram(Dendrogram): """ Stores data to be plotted as an unrooted dendrogram. - A `UnrootedDendrogram` object is represents a tree in addition to the key information required to create a radial tree layout prior to visualization. - Parameters ---------- use_lengths: bool Specifies if the branch lengths should be included in the resulting visualization (default True). - Attributes ---------- length @@ -110,7 +132,6 @@ class UnrootedDendrogram(Dendrogram): def __init__(self, **kwargs): """ Constructs a UnrootedDendrogram object for visualization. - Parameters ---------- use_lengths: bool @@ -120,46 +141,41 @@ def __init__(self, **kwargs): super().__init__(**kwargs) @classmethod - def from_tree(cls, tree): + def from_tree(cls, tree, use_lengths=True): """ Creates an UnrootedDendrogram object from a skbio tree. - Parameters ---------- tree : skbio.TreeNode Input skbio tree - Returns ------- UnrootedDendrogram """ for n in tree.postorder(): n.__class__ = UnrootedDendrogram - tree._cache_ntips() + + tree.update_geometry(use_lengths) return tree def rescale(self, width, height): """ Find best scaling factor for fitting the tree in the dimensions specified by width and height. - This method will find the best orientation and scaling possible to fit the tree within the dimensions specified by width and height. - Parameters ---------- width : float width of the canvas height : float height of the canvas - Returns ------- best_scaling : float largest scaling factor in which the tree can fit in the canvas. - Notes ----- """ - angle = (2 * numpy.pi) / self._n_tips + angle = (2 * numpy.pi) / self.leafcount # this loop is a horrible brute force hack # there are better (but complex) ways to find # the best rotation of the tree to fit the display. @@ -189,13 +205,11 @@ def rescale(self, width, height): def update_coordinates(self, s, x1, y1, a, da): """ Update x, y coordinates of tree nodes in canvas. - `update_coordinates` will recursively updating the plotting parameters for all of the nodes within the tree. This can be applied when the tree becomes modified (i.e. pruning or collapsing) and the resulting coordinates need to be modified to reflect the changes to the tree structure. - Parameters ---------- s : float @@ -208,12 +222,10 @@ def update_coordinates(self, s, x1, y1, a, da): angle (degrees) da : float angle resolution (degrees) - Returns ------- points : list of tuple 2D coordinates of all of the nodes. - Notes ----- This function has a little bit of recursion. This will @@ -224,7 +236,7 @@ def update_coordinates(self, s, x1, y1, a, da): y2 = y1 + self.length * s * numpy.cos(a) (self.x1, self.y1, self.x2, self.y2, self.angle) = (x1, y1, x2, y2, a) # TODO: Add functionality that allows for collapsing of nodes - a = a - self._n_tips * da / 2 + a = a - self.leafcount * da / 2 if self.is_tip(): points = [(x2, y2)] else: @@ -234,7 +246,115 @@ def update_coordinates(self, s, x1, y1, a, da): # need to be refactored to remove the recursion. for child in self.children: # calculate the arc that covers the subtree. - ca = child._n_tips * da + ca = child.leafcount * da points += child.update_coordinates(s, x2, y2, a+ca/2, da) a += ca return points + + +Dimensions = namedtuple('Dimensions', ['x', 'y', 'height']) + + +class RootedDendrogram(Dendrogram): + """RootedDendrogram subclasses provide ycoords and xcoords, which examine + attributes of a node (its length, coodinates of its children) and return + a tuple for start/end of the line representing the edge.""" + + def width_required(self): + return self.leafcount + + @abc.abstractmethod + def xcoords(self, scale, x1): + pass + + @abc.abstractmethod + def ycoords(self, scale, y1): + pass + + def rescale(self, width, height): + """ Update x, y coordinates of tree nodes in canvas. + Parameters + ---------- + scale : Dimensions + Scaled dimensions of the tree + x1 : int + X-coordinate of parent + """ + xscale = width / self.height + yscale = height / self.width_required() + scale = Dimensions(xscale, yscale, self.height) + + # y coords done postorder, x preorder, y first. + # so it has to be done in 2 passes. + self.update_y_coordinates(scale) + self.update_x_coordinates(scale) + return xscale + + def update_y_coordinates(self, scale, y1=None): + """The second pass through the tree. Y coordinates only + depend on the shape of the tree and yscale. + Parameters + ---------- + scale : Dimensions + Scaled dimensions of the tree + x1 : int + X-coordinate of parent + """ + if y1 is None: + y1 = self.width_required() * scale.y + child_y = y1 + for child in self.children: + child.update_y_coordinates(scale, child_y) + child_y -= child.width_required() * scale.y + (self.y1, self.y2) = self.ycoords(scale, y1) + + def update_x_coordinates(self, scale, x1=0): + """For non 'square' styles the x coordinates will depend + (a bit) on the y coodinates, so they should be done first. + Parameters + ---------- + scale : Dimensions + Scaled dimensions of the tree + x1 : int + X-coordinate of parent + """ + (self.x1, self.x2) = self.xcoords(scale, x1) + for child in self.children: + child.update_x_coordinates(scale, self.x2) + + +class SquareDendrogram(RootedDendrogram): + aspect_distorts_lengths = False + + def ycoords(self, scale, y1): + cys = [c.y1 for c in self.children] + if cys: + y2 = (cys[0]+cys[-1]) / 2.0 + else: + y2 = y1 - 0.5 * scale.y + return (y2, y2) + + def xcoords(self, scale, x1): + if self.is_tip(): + return (x1, (scale.height-(self.height-self.length))*scale.x) + else: + # give some margins for internal nodes + dx = scale.x * self.length * 0.95 + x2 = x1 + dx + return (x1, x2) + + @classmethod + def from_tree(cls, tree): + """ Creates an SquareDendrogram object from a skbio tree. + Parameters + ---------- + tree : skbio.TreeNode + Input skbio tree + Returns + ------- + SquareDendrogram + """ + for n in tree.postorder(include_self=True): + n.__class__ = SquareDendrogram + tree.update_geometry(use_lengths=False) + return tree From a227a627863f9f46d4d3963d6551c73dd4d4101c Mon Sep 17 00:00:00 2001 From: Jamie Morton Date: Mon, 27 Feb 2017 15:19:43 -0800 Subject: [PATCH 03/12] STY: pep8 --- gneiss/plot/_heatmap.py | 1 - gneiss/plot/tests/test_heatmap.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/gneiss/plot/_heatmap.py b/gneiss/plot/_heatmap.py index d344980..b3cef67 100644 --- a/gneiss/plot/_heatmap.py +++ b/gneiss/plot/_heatmap.py @@ -5,7 +5,6 @@ # # The full license is in the file COPYING.txt, distributed with this software. # ---------------------------------------------------------------------------- -import io import numpy as np import matplotlib.pyplot as plt import matplotlib.patches as patches diff --git a/gneiss/plot/tests/test_heatmap.py b/gneiss/plot/tests/test_heatmap.py index 223fe21..62e3345 100644 --- a/gneiss/plot/tests/test_heatmap.py +++ b/gneiss/plot/tests/test_heatmap.py @@ -1,4 +1,3 @@ -import os from gneiss.plot import heatmap import pandas as pd from skbio import TreeNode, DistanceMatrix @@ -137,5 +136,6 @@ def test_basic_highlights(self): self.assertEqual(fig.get_axes()[1].get_ylim(), (-0.5, 4.5)) + if __name__ == "__main__": unittest.main() From 1b813bad663820e2bd95374ba087f8fd6a0b515b Mon Sep 17 00:00:00 2001 From: Jamie Morton Date: Mon, 27 Feb 2017 16:17:47 -0800 Subject: [PATCH 04/12] TST: Removing offending matplotlib lines. --- gneiss/plot/tests/test_heatmap.py | 33 ------------------------------- 1 file changed, 33 deletions(-) diff --git a/gneiss/plot/tests/test_heatmap.py b/gneiss/plot/tests/test_heatmap.py index 62e3345..f527b87 100644 --- a/gneiss/plot/tests/test_heatmap.py +++ b/gneiss/plot/tests/test_heatmap.py @@ -61,19 +61,6 @@ def test_basic(self): res = str(fig.get_axes()[0].get_xticklabels(minor=True)[1]) self.assertEqual(res, "Text(0,0,'b')") - # make sure that xlims are set properly - self.assertEqual(fig.get_axes()[0].get_xlim(), - (-0.5, 4.5)) - - self.assertEqual(fig.get_axes()[1].get_xlim(), - (-1.0, 21.0)) - - # make sure that ylims are set properly - self.assertEqual(fig.get_axes()[0].get_ylim(), - (-0.5, 4.5)) - - self.assertEqual(fig.get_axes()[1].get_ylim(), - (-0.5, 4.5)) def test_basic_highlights(self): fig = heatmap(self.table, self.t, self.md, self.highlights) @@ -116,26 +103,6 @@ def test_basic_highlights(self): res = str(fig.get_axes()[2].get_xticklabels()[0]) self.assertEqual(res, "Text(0,0,'0')") - # make sure that xlims are set properly - self.assertEqual(fig.get_axes()[0].get_xlim(), - (-0.5, 4.5)) - - self.assertEqual(fig.get_axes()[1].get_xlim(), - (-1.0, 21.0)) - - self.assertEqual(fig.get_axes()[2].get_xlim(), - (0.0, 1.0)) - - # make sure that ylims are set properly - self.assertEqual(fig.get_axes()[0].get_ylim(), - (-0.5, 4.5)) - - self.assertEqual(fig.get_axes()[1].get_ylim(), - (-0.5, 4.5)) - - self.assertEqual(fig.get_axes()[1].get_ylim(), - (-0.5, 4.5)) - if __name__ == "__main__": unittest.main() From 5b0d3bd05044c804dea8c7032037f28e007ac617 Mon Sep 17 00:00:00 2001 From: Jamie Morton Date: Mon, 27 Feb 2017 20:18:59 -0800 Subject: [PATCH 05/12] STY: pep8 --- gneiss/plot/tests/test_heatmap.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gneiss/plot/tests/test_heatmap.py b/gneiss/plot/tests/test_heatmap.py index f527b87..07c2807 100644 --- a/gneiss/plot/tests/test_heatmap.py +++ b/gneiss/plot/tests/test_heatmap.py @@ -61,7 +61,6 @@ def test_basic(self): res = str(fig.get_axes()[0].get_xticklabels(minor=True)[1]) self.assertEqual(res, "Text(0,0,'b')") - def test_basic_highlights(self): fig = heatmap(self.table, self.t, self.md, self.highlights) From 62ed51f532fd1e72ecf125f5f58f832b5fff8488 Mon Sep 17 00:00:00 2001 From: Jamie Morton Date: Mon, 27 Feb 2017 22:21:13 -0800 Subject: [PATCH 06/12] TST: fix dendrogram test --- gneiss/plot/tests/test_dendrogram.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/gneiss/plot/tests/test_dendrogram.py b/gneiss/plot/tests/test_dendrogram.py index 0892d5c..7b8f065 100644 --- a/gneiss/plot/tests/test_dendrogram.py +++ b/gneiss/plot/tests/test_dendrogram.py @@ -32,13 +32,13 @@ def test_cache_ntips(self): t._cache_ntips() - self.assertEquals(t._n_tips, 4) - self.assertEquals(t.children[0]._n_tips, 2) - self.assertEquals(t.children[1]._n_tips, 2) - self.assertEquals(t.children[0].children[0]._n_tips, 1) - self.assertEquals(t.children[0].children[1]._n_tips, 1) - self.assertEquals(t.children[1].children[0]._n_tips, 1) - self.assertEquals(t.children[1].children[1]._n_tips, 1) + self.assertEquals(t.leafcount, 4) + self.assertEquals(t.children[0].leafcount, 2) + self.assertEquals(t.children[1].leafcount, 2) + self.assertEquals(t.children[0].children[0].leafcount, 1) + self.assertEquals(t.children[0].children[1].leafcount, 1) + self.assertEquals(t.children[1].children[0].leafcount, 1) + self.assertEquals(t.children[1].children[1].leafcount, 1) class TestUnrootedDendrogram(unittest.TestCase): From 8d8a06d53e3778b1b99e24dbd445ac391a894cde Mon Sep 17 00:00:00 2001 From: Jamie Morton Date: Tue, 28 Feb 2017 17:02:18 -0800 Subject: [PATCH 07/12] TST: Adding tests for SquareDendrogram --- gneiss/plot/_dendrogram.py | 71 ++++++++++++++++++---------- gneiss/plot/tests/test_dendrogram.py | 43 ++++++++++++++++- 2 files changed, 87 insertions(+), 27 deletions(-) diff --git a/gneiss/plot/_dendrogram.py b/gneiss/plot/_dendrogram.py index d52b539..8b52fbd 100644 --- a/gneiss/plot/_dendrogram.py +++ b/gneiss/plot/_dendrogram.py @@ -12,27 +12,20 @@ import numpy -def _sign(x): - """Returns True if x is positive, False otherwise.""" - return x and x/abs(x) - - -def _sign(x): - """Returns True if x is positive, False otherwise.""" - return x and x/abs(x) - - class Dendrogram(TreeNode): """ Stores data to be plotted as a dendrogram. + A `Dendrogram` object is represents a tree in addition to the key information required to create a tree layout prior to visualization. No layouts are specified within this class, since this serves as a super class for different tree layouts. + Parameters ---------- use_lengths: bool Specifies if the branch lengths should be included in the resulting visualization (default True). + Attributes ---------- length @@ -61,19 +54,6 @@ def _cache_ntips(self): for n in self.postorder(): if n.is_tip(): n.leafcount = 1 -<<<<<<< HEAD - else: - n.leafcount = sum(c.leafcount for c in n.children) - - def update_geometry(self, use_lengths, depth=None): - """Calculate tree node attributes such as height and depth. - Despite the name this first pass is ignorant of issues like - scale and orientation""" - if self.length is None or not use_lengths: - if depth is None: - self.length = 0 - else: -======= else: n.leafcount = sum(c.leafcount for c in n.children) @@ -94,7 +74,6 @@ def update_geometry(self, use_lengths, depth=None): if depth is None: self.length = 0 else: ->>>>>>> 92b37faf8dbb7f05088e9e5edd2739d47dea8743 self.length = 1 else: self.length = self.length @@ -114,6 +93,7 @@ def update_geometry(self, use_lengths, depth=None): def coords(self, height, width): """ Returns coordinates of nodes to be rendered in plot. + Parameters ---------- height : int @@ -156,22 +136,29 @@ def rescale(self, width, height): class UnrootedDendrogram(Dendrogram): """ Stores data to be plotted as an unrooted dendrogram. + A `UnrootedDendrogram` object is represents a tree in addition to the key information required to create a radial tree layout prior to visualization. + Parameters ---------- use_lengths: bool Specifies if the branch lengths should be included in the resulting visualization (default True). + Attributes ---------- length + leafcount + height + depth """ aspect_distorts_lengths = True def __init__(self, **kwargs): """ Constructs a UnrootedDendrogram object for visualization. + Parameters ---------- use_lengths: bool @@ -183,10 +170,12 @@ def __init__(self, **kwargs): @classmethod def from_tree(cls, tree, use_lengths=True): """ Creates an UnrootedDendrogram object from a skbio tree. + Parameters ---------- tree : skbio.TreeNode Input skbio tree + Returns ------- UnrootedDendrogram @@ -198,20 +187,23 @@ def from_tree(cls, tree, use_lengths=True): return tree def rescale(self, width, height): - """ Find best scaling factor for fitting the tree in the dimensions - specified by width and height. + """ Find best scaling factor for fitting the tree in the figure. + This method will find the best orientation and scaling possible to fit the tree within the dimensions specified by width and height. + Parameters ---------- width : float width of the canvas height : float height of the canvas + Returns ------- best_scaling : float largest scaling factor in which the tree can fit in the canvas. + Notes ----- """ @@ -245,11 +237,13 @@ def rescale(self, width, height): def update_coordinates(self, s, x1, y1, a, da): """ Update x, y coordinates of tree nodes in canvas. + `update_coordinates` will recursively updating the plotting parameters for all of the nodes within the tree. This can be applied when the tree becomes modified (i.e. pruning or collapsing) and the resulting coordinates need to be modified to reflect the changes to the tree structure. + Parameters ---------- s : float @@ -262,10 +256,12 @@ def update_coordinates(self, s, x1, y1, a, da): angle (degrees) da : float angle resolution (degrees) + Returns ------- points : list of tuple 2D coordinates of all of the nodes. + Notes ----- This function has a little bit of recursion. This will @@ -296,6 +292,26 @@ def update_coordinates(self, s, x1, y1, a, da): class RootedDendrogram(Dendrogram): + """ Stores data to be plotted as an rooted dendrogram. + + A `RootedDendrogram` object is represents a tree in addition to the + key information required to create a radial tree layout prior to + visualization. + + Parameters + ---------- + use_lengths: bool + Specifies if the branch lengths should be included in the + resulting visualization (default True). + + Attributes + ---------- + length + leafcount + height + depth + """ + """RootedDendrogram subclasses provide ycoords and xcoords, which examine attributes of a node (its length, coodinates of its children) and return a tuple for start/end of the line representing the edge.""" @@ -313,6 +329,7 @@ def ycoords(self, scale, y1): def rescale(self, width, height): """ Update x, y coordinates of tree nodes in canvas. + Parameters ---------- scale : Dimensions @@ -333,6 +350,7 @@ def rescale(self, width, height): def update_y_coordinates(self, scale, y1=None): """The second pass through the tree. Y coordinates only depend on the shape of the tree and yscale. + Parameters ---------- scale : Dimensions @@ -386,6 +404,7 @@ def xcoords(self, scale, x1): @classmethod def from_tree(cls, tree): """ Creates an SquareDendrogram object from a skbio tree. + Parameters ---------- tree : skbio.TreeNode diff --git a/gneiss/plot/tests/test_dendrogram.py b/gneiss/plot/tests/test_dendrogram.py index 7b8f065..0dee71f 100644 --- a/gneiss/plot/tests/test_dendrogram.py +++ b/gneiss/plot/tests/test_dendrogram.py @@ -9,7 +9,8 @@ import numpy as np import pandas as pd from skbio import DistanceMatrix, TreeNode -from gneiss.plot._dendrogram import Dendrogram, UnrootedDendrogram +from gneiss.plot._dendrogram import (Dendrogram, UnrootedDendrogram, + SquareDendrogram) from scipy.cluster.hierarchy import ward import pandas.util.testing as pdt @@ -110,5 +111,45 @@ def test_update_coordinates(self): pdt.assert_frame_equal(res, exp, check_less_precise=True) +class TestSquareDendrogram(unittest.TestCase): + + def setUp(self): + np.random.seed(0) + self.table = pd.DataFrame(np.random.random((5, 5))) + num_otus = 5 # otus + x = np.random.rand(num_otus) + dm = DistanceMatrix.from_iterable(x, lambda x, y: np.abs(x-y)) + lm = ward(dm.condensed_form()) + t = TreeNode.from_linkage_matrix(lm, np.arange(len(x)).astype(np.str)) + self.tree = SquareDendrogram.from_tree(t) + + for i, n in enumerate(t.postorder()): + if not n.is_tip(): + n.name = "y%d" % i + n.length = np.random.rand()*3 + + def test_from_tree(self): + t = SquareDendrogram.from_tree(self.tree) + self.assertEqual(t.__class__, SquareDendrogram) + + def test_coords(self): + # just test to make sure that the coordinates are calculated properly. + t = SquareDendrogram.from_tree(self.tree) + + exp = pd.DataFrame({'0': [20, 2.5, np.nan, np.nan, True], + '1': [20, 3.5, np.nan, np.nan, True], + '2': [20, 4.5, np.nan, np.nan, True], + '3': [20, 1.5, np.nan, np.nan, True], + '4': [20, 0.5, np.nan, np.nan, True], + 'y5': [14.25, 1, '3', '4', False], + 'y6': [9.5, 1.75, '0', 'y5', False], + 'y7': [4.75, 2.625, '1', 'y6', False], + 'y8': [0, 3.5625, '2', 'y7', False]}, + index=['x', 'y', 'child0', 'child1', 'is_tip']).T + + res = t.coords(width=20, height=self.table.shape[0]) + pdt.assert_frame_equal(exp, res) + + if __name__ == "__main__": unittest.main() From 2a9f89e0e2d229344a757186c3d1fb3326cf8ade Mon Sep 17 00:00:00 2001 From: Jamie Morton Date: Tue, 28 Feb 2017 17:05:25 -0800 Subject: [PATCH 08/12] TST: Adding test_rescale --- gneiss/plot/_dendrogram.py | 4 ---- gneiss/plot/tests/test_dendrogram.py | 5 +++++ 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/gneiss/plot/_dendrogram.py b/gneiss/plot/_dendrogram.py index 8b52fbd..d293443 100644 --- a/gneiss/plot/_dendrogram.py +++ b/gneiss/plot/_dendrogram.py @@ -312,10 +312,6 @@ class RootedDendrogram(Dendrogram): depth """ - """RootedDendrogram subclasses provide ycoords and xcoords, which examine - attributes of a node (its length, coodinates of its children) and return - a tuple for start/end of the line representing the edge.""" - def width_required(self): return self.leafcount diff --git a/gneiss/plot/tests/test_dendrogram.py b/gneiss/plot/tests/test_dendrogram.py index 0dee71f..4a8ef99 100644 --- a/gneiss/plot/tests/test_dendrogram.py +++ b/gneiss/plot/tests/test_dendrogram.py @@ -150,6 +150,11 @@ def test_coords(self): res = t.coords(width=20, height=self.table.shape[0]) pdt.assert_frame_equal(exp, res) + def test_rescale(self): + t = SquareDendrogram.from_tree(self.tree) + res = t.rescale(10, 10) + self.assertEqual(res, 2.5) + if __name__ == "__main__": unittest.main() From 174d4f8270b15889a9d63082df3ae02bfe8cd6b8 Mon Sep 17 00:00:00 2001 From: Jamie Morton Date: Wed, 1 Mar 2017 12:28:17 -0800 Subject: [PATCH 09/12] FIX: fixing highlights --- gneiss/plot/_heatmap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gneiss/plot/_heatmap.py b/gneiss/plot/_heatmap.py index b3cef67..a153a86 100644 --- a/gneiss/plot/_heatmap.py +++ b/gneiss/plot/_heatmap.py @@ -164,7 +164,7 @@ def _plot_highlights_dendrogram(ax_highlights, table, t, highlights): ax_highlights.set_ylim([-offset, table.shape[0]-offset]) ax_highlights.set_yticks([]) ax_highlights.set_xticks(hcoords) - ax_highlights.set_xticklabels(highlights, rotation=90) + ax_highlights.set_xticklabels(highlights.index, rotation=90) def _plot_dendrogram(ax_dendrogram, table, edges): From f64d3519bf98c3dbdedd7ff51b6d9e1981437123 Mon Sep 17 00:00:00 2001 From: Jamie Morton Date: Wed, 1 Mar 2017 12:35:42 -0800 Subject: [PATCH 10/12] ENH: Adding parameter to specify the width of the highlights --- gneiss/plot/_heatmap.py | 13 +++++++++---- gneiss/plot/tests/test_dendrogram.py | 4 ++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/gneiss/plot/_heatmap.py b/gneiss/plot/_heatmap.py index a153a86..30bfdeb 100644 --- a/gneiss/plot/_heatmap.py +++ b/gneiss/plot/_heatmap.py @@ -14,7 +14,7 @@ def heatmap(table, tree, mdvar, highlights=None, grid_col='w', grid_width=2, dendrogram_width=20, - figsize=(5, 5)): + highlight_width=0.02, figsize=(5, 5)): """ Creates heatmap plotting object Parameters @@ -33,14 +33,19 @@ def heatmap(table, tree, mdvar, highlights=None, mdvar: pd.Series Metadata values for samples. The index must correspond to the index of `table`. + highlight_width : int + Width of highlights. (default=0.02) dendrogram_width : int - Width of axes for dendrogram plot. + Width of axes for dendrogram plot. (default=20) grid_col: str Color of vertical lines for highlighting sample metadata. + (default='w') grid_width: int Width of vertical lines for highlighting sample metadata. + (default=2) figsize: tuple of int - Species (width, height) for figure. + Species (width, height) for figure. (default=(5, 5)) + Returns ------- @@ -81,7 +86,7 @@ def heatmap(table, tree, mdvar, highlights=None, h = len(highlights) else: h = 0 - hwidth = 0.02 + hwidth = highlight_width [axs_x, axs_y, axs_w, axs_h] = [xwidth, top_buffer, hwidth * h, height] # dendrogram axes on the right side diff --git a/gneiss/plot/tests/test_dendrogram.py b/gneiss/plot/tests/test_dendrogram.py index 4a8ef99..ef3a903 100644 --- a/gneiss/plot/tests/test_dendrogram.py +++ b/gneiss/plot/tests/test_dendrogram.py @@ -91,8 +91,8 @@ def test_coords(self): def test_rescale(self): t = UnrootedDendrogram.from_tree(self.tree) - self.assertAlmostEquals(t.rescale(500, 500), 91.608680314971238, - places=5) + self.assertAlmostEqual(t.rescale(500, 500), 91.608680314971238, + places=5) def test_update_coordinates(self): t = UnrootedDendrogram.from_tree(self.tree) From 21532b7f61964c5fd4ad65d517ca4950d4a2bedb Mon Sep 17 00:00:00 2001 From: Jamie Morton Date: Wed, 1 Mar 2017 16:52:50 -0800 Subject: [PATCH 11/12] TST: fixing labels in test --- gneiss/plot/tests/test_heatmap.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gneiss/plot/tests/test_heatmap.py b/gneiss/plot/tests/test_heatmap.py index 07c2807..8219099 100644 --- a/gneiss/plot/tests/test_heatmap.py +++ b/gneiss/plot/tests/test_heatmap.py @@ -97,10 +97,10 @@ def test_basic_highlights(self): # Make sure that the highlight labels are set properly res = str(fig.get_axes()[2].get_xticklabels()[0]) - self.assertEqual(res, "Text(0,0,'0')") + self.assertEqual(res, "Text(0,0,'y6')") - res = str(fig.get_axes()[2].get_xticklabels()[0]) - self.assertEqual(res, "Text(0,0,'0')") + res = str(fig.get_axes()[2].get_xticklabels()[1]) + self.assertEqual(res, "Text(0,0,'y8')") if __name__ == "__main__": From b0caea48b167e43e3a3dbd849c17608ed5569d8e Mon Sep 17 00:00:00 2001 From: Jamie Morton Date: Wed, 1 Mar 2017 17:32:45 -0800 Subject: [PATCH 12/12] STY: Addressing comments Removing attributes that aren't used --- gneiss/plot/_dendrogram.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/gneiss/plot/_dendrogram.py b/gneiss/plot/_dendrogram.py index d293443..12a8bd8 100644 --- a/gneiss/plot/_dendrogram.py +++ b/gneiss/plot/_dendrogram.py @@ -35,19 +35,16 @@ class Dendrogram(TreeNode): Notes ----- - `length` refers to the branch length connect to the specified subtree. + `length` refers to the branch length of a node to its parent. `leafcount` is the number of tips within a subtree. `height` refers to the longest path from root to the deepst leaf in that subtree. `depth` is the number of nodes found in the longest path. """ - aspect_distorts_lengths = True - def __init__(self, use_lengths=True, **kwargs): """ Constructs a Dendrogram object for visualization. """ super().__init__(**kwargs) - self.use_lengths_default = use_lengths def _cache_ntips(self): """ Counts the number of leaves under each subtree.""" @@ -75,8 +72,6 @@ def update_geometry(self, use_lengths, depth=None): self.length = 0 else: self.length = 1 - else: - self.length = self.length self.depth = (depth or 0) + self.length @@ -154,8 +149,6 @@ class UnrootedDendrogram(Dendrogram): height depth """ - aspect_distorts_lengths = True - def __init__(self, **kwargs): """ Constructs a UnrootedDendrogram object for visualization. @@ -378,7 +371,6 @@ def update_x_coordinates(self, scale, x1=0): class SquareDendrogram(RootedDendrogram): - aspect_distorts_lengths = False def ycoords(self, scale, y1): cys = [c.y1 for c in self.children]