diff --git a/CHANGELOG.md b/CHANGELOG.md index 38e0d430b0..217b4d8ede 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ * Add `skipna` argument to `hpd` and `summary` (#1035) * Added `transform` argument to `plot_trace`, `plot_forest`, `plot_pair`, `plot_posterior`, `plot_rank`, `plot_parallel`, `plot_violin`,`plot_density`, `plot_joint` (#1036) * Add `marker` functionality to `bokeh_plot_elpd` (#1040) +* Add `ridgeplot_quantiles` argument to `plot_forest` (#1047) * Added the functionality [interactive legends](https://docs.bokeh.org/en/1.4.0/docs/user_guide/interaction/legends.html) for bokeh plots of `densityplot`, `energyplot` and `essplot` (#1024) diff --git a/arviz/plots/backends/bokeh/forestplot.py b/arviz/plots/backends/bokeh/forestplot.py index 9121bfd4ee..4b10989c54 100644 --- a/arviz/plots/backends/bokeh/forestplot.py +++ b/arviz/plots/backends/bokeh/forestplot.py @@ -46,6 +46,7 @@ def plot_forest( ridgeplot_overlap, ridgeplot_alpha, ridgeplot_kind, + ridgeplot_quantiles, textsize, ess, r_hat, @@ -59,7 +60,10 @@ def plot_forest( ) if figsize is None: - figsize = (min(12, sum(width_ratios) * 2), plot_handler.fig_height()) + if kind == "ridgeplot": + figsize = (min(14, sum(width_ratios) * 3), plot_handler.fig_height() * 3) + else: + figsize = (min(12, sum(width_ratios) * 2), plot_handler.fig_height()) (figsize, _, _, _, auto_linewidth, auto_markersize) = _scale_fig_size(figsize, textsize, 1.1, 1) @@ -122,7 +126,12 @@ def plot_forest( ) elif kind == "ridgeplot": plot_handler.ridgeplot( - ridgeplot_overlap, linewidth, ridgeplot_alpha, ridgeplot_kind, axes[0, 0] + ridgeplot_overlap, + linewidth, + ridgeplot_alpha, + ridgeplot_kind, + ridgeplot_quantiles, + axes[0, 0], ) else: raise TypeError( @@ -270,7 +279,7 @@ def display_multiple_ropes(self, rope, ax, y, linewidth, rope_var): ) return ax - def ridgeplot(self, mult, linewidth, alpha, ridgeplot_kind, ax): + def ridgeplot(self, mult, linewidth, alpha, ridgeplot_kind, ridgeplot_quantiles, ax): """Draw ridgeplot for each plotter. Parameters @@ -290,7 +299,7 @@ def ridgeplot(self, mult, linewidth, alpha, ridgeplot_kind, ax): if alpha is None: alpha = 1.0 for plotter in list(self.plotters.values())[::-1]: - for x, y_min, y_max, color in list(plotter.ridgeplot(mult, ridgeplot_kind))[::-1]: + for x, y_min, y_max, y_q, color in list(plotter.ridgeplot(mult, ridgeplot_kind))[::-1]: if alpha == 0: border = color facecolor = None @@ -306,16 +315,43 @@ def ridgeplot(self, mult, linewidth, alpha, ridgeplot_kind, ax): fill_color=facecolor, ) else: - patch = ax.patch( - np.concatenate([x, x[::-1]]), - np.concatenate([y_min, y_max[::-1]]), - fill_color=color, - fill_alpha=alpha, - line_dash="solid", - line_width=linewidth, - line_color=border, - ) - patch.level = "overlay" + if ridgeplot_quantiles is None: + patch = ax.patch( + np.concatenate([x, x[::-1]]), + np.concatenate([y_min, y_max[::-1]]), + fill_color=color, + fill_alpha=alpha, + line_dash="solid", + line_width=linewidth, + line_color=border, + ) + patch.level = "overlay" + else: + quantiles = sorted(np.clip(ridgeplot_quantiles, 0, 1)) + if quantiles[0] != 0: + quantiles = [0] + quantiles + if quantiles[-1] != 1: + quantiles = quantiles + [1] + + for quant_0, quant_1 in zip(quantiles[:-1], quantiles[1:]): + idx = (y_q > quant_0) & (y_q < quant_1) + if idx.sum(): + patch_x = np.concatenate( + (x[idx], [x[idx][-1]], x[idx][::-1], [x[idx][0]]) + ) + patch_y = np.concatenate( + ( + y_min[idx], + [y_min[idx][-1]], + y_max[idx][::-1], + [y_max[idx][0]], + ) + ) + patch = ax.patch( + patch_x, patch_y, fill_color=color, fill_alpha=alpha, + ) + patch.level = "overlay" + return ax def forestplot(self, credible_interval, quartiles, linewidth, markersize, ax, rope): @@ -526,7 +562,7 @@ def treeplot(self, qlist, credible_interval): def ridgeplot(self, mult, ridgeplot_kind): """Get data for each ridgeplot for the variable.""" - xvals, yvals, pdfs, colors = [], [], [], [] + xvals, yvals, pdfs, pdfs_q, colors = [], [], [], [], [] for y, *_, values, color in self.iterator(): yvals.append(y) colors.append(color) @@ -544,15 +580,17 @@ def ridgeplot(self, mult, ridgeplot_kind): x = x[:-1] elif kind == "density": density, lower, upper = _fast_kde(values) + density_q = density.cumsum() / density.sum() x = np.linspace(lower, upper, len(density)) xvals.append(x) pdfs.append(density) + pdfs_q.append((density_q)) scaling = max(np.max(j) for j in pdfs) - for y, x, pdf, color in zip(yvals, xvals, pdfs, colors): + for y, x, pdf, pdf_q, color in zip(yvals, xvals, pdfs, pdfs_q, colors): y = y * np.ones_like(x) - yield x, y, mult * pdf / scaling + y, color + yield x, y, mult * pdf / scaling + y, pdf_q, color def ess(self): """Get effective n data for the variable.""" diff --git a/arviz/plots/backends/matplotlib/forestplot.py b/arviz/plots/backends/matplotlib/forestplot.py index 6fec8e375e..d68671dcf8 100644 --- a/arviz/plots/backends/matplotlib/forestplot.py +++ b/arviz/plots/backends/matplotlib/forestplot.py @@ -40,6 +40,7 @@ def plot_forest( ridgeplot_overlap, ridgeplot_alpha, ridgeplot_kind, + ridgeplot_quantiles, textsize, ess, r_hat, @@ -52,7 +53,10 @@ def plot_forest( ) if figsize is None: - figsize = (min(12, sum(width_ratios) * 2), plot_handler.fig_height()) + if kind == "ridgeplot": + figsize = (min(14, sum(width_ratios) * 4), plot_handler.fig_height() * 1.2) + else: + figsize = (min(12, sum(width_ratios) * 2), plot_handler.fig_height()) (figsize, _, titlesize, xt_labelsize, auto_linewidth, auto_markersize) = _scale_fig_size( figsize, textsize, 1.1, 1 @@ -98,7 +102,12 @@ def plot_forest( ) elif kind == "ridgeplot": plot_handler.ridgeplot( - ridgeplot_overlap, linewidth, ridgeplot_alpha, ridgeplot_kind, axes[0] + ridgeplot_overlap, + linewidth, + ridgeplot_alpha, + ridgeplot_kind, + ridgeplot_quantiles, + axes[0], ) else: raise TypeError( @@ -230,7 +239,7 @@ def display_multiple_ropes(self, rope, ax, y, linewidth, rope_var): ) return ax - def ridgeplot(self, mult, linewidth, alpha, ridgeplot_kind, ax): + def ridgeplot(self, mult, linewidth, alpha, ridgeplot_kind, ridgeplot_quantiles, ax): """Draw ridgeplot for each plotter. Parameters @@ -251,7 +260,7 @@ def ridgeplot(self, mult, linewidth, alpha, ridgeplot_kind, ax): alpha = 1.0 zorder = 0 for plotter in self.plotters.values(): - for x, y_min, y_max, color in plotter.ridgeplot(mult, ridgeplot_kind): + for x, y_min, y_max, y_q, color in plotter.ridgeplot(mult, ridgeplot_kind): if alpha == 0: border = color facecolor = "None" @@ -269,9 +278,21 @@ def ridgeplot(self, mult, linewidth, alpha, ridgeplot_kind, ax): zorder=zorder, ) else: - ax.plot(x, y_max, "-", linewidth=linewidth, color=border, zorder=zorder) - ax.plot(x, y_min, "-", linewidth=linewidth, color=border, zorder=zorder) - ax.fill_between(x, y_min, y_max, alpha=alpha, color=color, zorder=zorder) + if ridgeplot_quantiles is not None: + idx = [np.sum(y_q < quant) for quant in ridgeplot_quantiles] + ax.fill_between( + x, + y_min, + y_max, + where=np.isin(x, x[idx], invert=True, assume_unique=True), + alpha=alpha, + color=color, + zorder=zorder, + ) + else: + ax.plot(x, y_max, "-", linewidth=linewidth, color=border, zorder=zorder) + ax.plot(x, y_min, "-", linewidth=linewidth, color=border, zorder=zorder) + ax.fill_between(x, y_min, y_max, alpha=alpha, color=color, zorder=zorder) zorder -= 1 return ax @@ -484,7 +505,7 @@ def treeplot(self, qlist, credible_interval): def ridgeplot(self, mult, ridgeplot_kind): """Get data for each ridgeplot for the variable.""" - xvals, yvals, pdfs, colors = [], [], [], [] + xvals, yvals, pdfs, pdfs_q, colors = [], [], [], [], [] for y, *_, values, color in self.iterator(): yvals.append(y) colors.append(color) @@ -502,15 +523,17 @@ def ridgeplot(self, mult, ridgeplot_kind): x = x[:-1] elif kind == "density": density, lower, upper = _fast_kde(values) + density_q = density.cumsum() / density.sum() x = np.linspace(lower, upper, len(density)) xvals.append(x) pdfs.append(density) + pdfs_q.append((density_q)) scaling = max(np.max(j) for j in pdfs) - for y, x, pdf, color in zip(yvals, xvals, pdfs, colors): + for y, x, pdf, pdf_q, color in zip(yvals, xvals, pdfs, pdfs_q, colors): y = y * np.ones_like(x) - yield x, y, mult * pdf / scaling + y, color + yield x, y, mult * pdf / scaling + y, pdf_q, color def ess(self): """Get effective n data for the variable.""" diff --git a/arviz/plots/forestplot.py b/arviz/plots/forestplot.py index 5b95fa7f1a..db7677c0e6 100644 --- a/arviz/plots/forestplot.py +++ b/arviz/plots/forestplot.py @@ -25,6 +25,7 @@ def plot_forest( ridgeplot_alpha=None, ridgeplot_overlap=2, ridgeplot_kind="auto", + ridgeplot_quantiles=None, figsize=None, ax=None, backend=None, @@ -90,6 +91,9 @@ def plot_forest( ridgeplot_kind : string By default ("auto") continuous variables are plotted using KDEs and discrete ones using histograms. To override this use "hist" to plot histograms and "density" for KDEs + ridgeplot_quantiles : list + Quantiles in ascending order used to segment the KDE. Use [.25, .5, .75] for quartiles. + Defaults to None. figsize : tuple Figure size. If None it will be defined automatically. ax: axes, optional @@ -188,6 +192,7 @@ def plot_forest( ridgeplot_overlap=ridgeplot_overlap, ridgeplot_alpha=ridgeplot_alpha, ridgeplot_kind=ridgeplot_kind, + ridgeplot_quantiles=ridgeplot_quantiles, textsize=textsize, ess=ess, r_hat=r_hat,