Skip to content

Commit

Permalink
Violinplot: fix histogram, add rug (#997)
Browse files Browse the repository at this point in the history
* fix histogram, add rug

* add rug to bokeh

* remove redundant line, make bokeh plot look closer to matplotlib, fix scale jitter
  • Loading branch information
aloctavodia authored Jan 20, 2020
1 parent 9987acb commit 3dee7bc
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 34 deletions.
58 changes: 42 additions & 16 deletions arviz/plots/backends/bokeh/violinplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ def plot_violin(
figsize,
rows,
cols,
sharex,
sharey,
kwargs_shade,
shade_kwargs,
shade,
rug,
rug_kwargs,
bw,
credible_interval,
linewidth,
Expand All @@ -40,6 +43,7 @@ def plot_violin(
len(plotters),
rows,
cols,
sharex=sharex,
sharey=sharey,
figsize=figsize,
squeeze=False,
Expand All @@ -54,17 +58,30 @@ def plot_violin(
):
val = x.flatten()
if val[0].dtype.kind == "i":
cat_hist(val, shade, ax_, **kwargs_shade)
dens = cat_hist(val, rug, shade, ax_, **shade_kwargs)
else:
_violinplot(val, shade, bw, ax_, **kwargs_shade)
dens = _violinplot(val, rug, shade, bw, ax_, **shade_kwargs)

if rug:
rug_x = -np.abs(np.random.normal(scale=max(dens) / 3.5, size=len(val)))
ax_.scatter(rug_x, val, **rug_kwargs)

per = np.percentile(val, [25, 75, 50])
hpd_intervals = hpd(val, credible_interval, multimodal=False)

if quartiles:
ax_.line([0, 0], per[:2], line_width=linewidth * 3, line_color="black")
ax_.line([0, 0], hpd_intervals, line_width=linewidth, line_color="black")
ax_.circle(0, per[-1])
ax_.line(
[0, 0], per[:2], line_width=linewidth * 3, line_color="black", line_cap="round"
)
ax_.line([0, 0], hpd_intervals, line_width=linewidth, line_color="black", line_cap="round")
ax_.circle(
0,
per[-1],
line_color="white",
fill_color="white",
size=linewidth * 1.5,
line_width=linewidth,
)

_title = Title()
_title.text = make_label(var_name, selection)
Expand All @@ -80,35 +97,44 @@ def plot_violin(
return ax


def _violinplot(val, shade, bw, ax, **kwargs_shade):
def _violinplot(val, rug, shade, bw, ax, **shade_kwargs):
"""Auxiliary function to plot violinplots."""
density, low_b, up_b = _fast_kde(val, bw=bw)
x = np.linspace(low_b, up_b, len(density))

x = np.concatenate([x, x[::-1]])
density = np.concatenate([-density, density[::-1]])
if not rug:
x = np.concatenate([x, x[::-1]])
density = np.concatenate([-density, density[::-1]])

ax.harea(y=x, x1=density, x2=np.zeros_like(density), fill_alpha=shade, **shade_kwargs)

ax.patch(density, x, fill_alpha=shade, line_width=0, **kwargs_shade)
return density


def cat_hist(val, shade, ax, **kwargs_shade):
def cat_hist(val, rug, shade, ax, **shade_kwargs):
"""Auxiliary function to plot discrete-violinplots."""
bins = get_bins(val)
_, binned_d, _ = histogram(val, bins=bins)

bin_edges = np.linspace(np.min(val), np.max(val), len(bins))
centers = 0.5 * (bin_edges + np.roll(bin_edges, 1))[:-1]
heights = np.diff(bin_edges)
centers = bin_edges[:-1] + heights.mean() / 2
right = 0.5 * binned_d

lefts = -0.5 * binned_d
if rug:
left = 0
else:
left = -right

ax.hbar(
y=centers,
left=lefts,
right=-lefts,
left=left,
right=right,
height=heights,
fill_alpha=shade,
line_alpha=shade,
line_color=None,
**kwargs_shade
**shade_kwargs
)

return binned_d
41 changes: 29 additions & 12 deletions arviz/plots/backends/matplotlib/violinplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ def plot_violin(
figsize,
rows,
cols,
sharex,
sharey,
kwargs_shade,
shade_kwargs,
shade,
rug,
rug_kwargs,
bw,
credible_interval,
linewidth,
Expand All @@ -28,24 +31,31 @@ def plot_violin(
):
"""Matplotlib violin plot."""
if ax is None:
_, ax = _create_axes_grid(
fig, ax = _create_axes_grid(
len(plotters),
rows,
cols,
sharex=sharex,
sharey=sharey,
figsize=figsize,
squeeze=False,
backend_kwargs=backend_kwargs,
)
fig.set_constrained_layout(False)
fig.subplots_adjust(wspace=0)

ax = np.atleast_1d(ax)

for (var_name, selection, x), ax_ in zip(plotters, ax.flatten()):
val = x.flatten()
if val[0].dtype.kind == "i":
cat_hist(val, shade, ax_, **kwargs_shade)
dens = cat_hist(val, rug, shade, ax_, **shade_kwargs)
else:
_violinplot(val, shade, bw, ax_, **kwargs_shade)
dens = _violinplot(val, rug, shade, bw, ax_, **shade_kwargs)

if rug:
rug_x = -np.abs(np.random.normal(scale=max(dens) / 3.5, size=len(val)))
ax_.plot(rug_x, val, **rug_kwargs)

per = np.percentile(val, [25, 75, 50])
hpd_intervals = hpd(val, credible_interval, multimodal=False)
Expand All @@ -66,25 +76,32 @@ def plot_violin(
return ax


def _violinplot(val, shade, bw, ax, **kwargs_shade):
def _violinplot(val, rug, shade, bw, ax, **shade_kwargs):
"""Auxiliary function to plot violinplots."""
density, low_b, up_b = _fast_kde(val, bw=bw)
x = np.linspace(low_b, up_b, len(density))

x = np.concatenate([x, x[::-1]])
density = np.concatenate([-density, density[::-1]])
if not rug:
x = np.concatenate([x, x[::-1]])
density = np.concatenate([-density, density[::-1]])

ax.fill_betweenx(x, density, alpha=shade, lw=0, **kwargs_shade)
ax.fill_betweenx(x, density, alpha=shade, lw=0, **shade_kwargs)
return density


def cat_hist(val, shade, ax, **kwargs_shade):
def cat_hist(val, rug, shade, ax, **shade_kwargs):
"""Auxiliary function to plot discrete-violinplots."""
bins = get_bins(val)
_, binned_d, _ = histogram(val, bins=bins)

bin_edges = np.linspace(np.min(val), np.max(val), len(bins))
centers = 0.5 * (bin_edges + np.roll(bin_edges, 1))[:-1]
heights = np.diff(bin_edges)
centers = bin_edges[:-1] + heights.mean() / 2

if rug:
left = None
else:
left = -0.5 * binned_d

lefts = -0.5 * binned_d
ax.barh(centers, binned_d, height=heights, left=lefts, alpha=shade, **kwargs_shade)
ax.barh(centers, binned_d, height=heights, left=left, alpha=shade, **shade_kwargs)
return binned_d
35 changes: 29 additions & 6 deletions arviz/plots/violinplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ def plot_violin(
data,
var_names=None,
quartiles=True,
rug=False,
credible_interval=0.94,
shade=0.35,
bw=4.5,
sharex=True,
sharey=True,
figsize=None,
textsize=None,
ax=None,
kwargs_shade=None,
shade_kwargs=None,
rug_kwargs=None,
backend=None,
backend_kwargs=None,
show=None,
Expand All @@ -42,6 +45,8 @@ def plot_violin(
quartiles : bool, optional
Flag for plotting the interquartile range, in addition to the credible_interval*100%
intervals. Defaults to True
rug : bool
If True adds a jittered rugplot. Defaults to False.
credible_interval : float, optional
Credible intervals. Defaults to 0.94.
shade : float
Expand All @@ -56,12 +61,17 @@ def plot_violin(
textsize: int
Text size of the point_estimates, axis ticks, and HPD. If None it will be autoscaled
based on figsize.
sharex : bool
Defaults to True, violinplots share a common x-axis scale.
sharey : bool
Defaults to True, violinplots share a common y-axis scale.
ax: axes, optional
Matplotlib axes or bokeh figures.
kwargs_shade : dicts, optional
shade_kwargs : dict, optional
Additional keywords passed to `fill_betweenx` or `barh` to control the shade.
rug_kwargs : dict
Keywords passed to the rug plot. If `rug` is True, only the right half of the violin will be
plotted.
backend: str, optional
Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
backend_kwargs: bool, optional
Expand All @@ -81,25 +91,30 @@ def plot_violin(
list(xarray_var_iter(data, var_names=var_names, combined=True)), "plot_violin"
)

if kwargs_shade is None:
kwargs_shade = {}
if shade_kwargs is None:
shade_kwargs = {}

rows, cols = default_grid(len(plotters))

(figsize, ax_labelsize, _, xt_labelsize, linewidth, _) = _scale_fig_size(
figsize, textsize, rows, cols
)
ax_labelsize *= 2

if rug_kwargs is None:
rug_kwargs = {}

violinplot_kwargs = dict(
ax=ax,
plotters=plotters,
figsize=figsize,
rows=rows,
cols=cols,
sharex=sharex,
sharey=sharey,
kwargs_shade=kwargs_shade,
shade_kwargs=shade_kwargs,
shade=shade,
rug=rug,
rug_kwargs=rug_kwargs,
bw=bw,
credible_interval=credible_interval,
linewidth=linewidth,
Expand All @@ -115,6 +130,14 @@ def plot_violin(
violinplot_kwargs.pop("ax_labelsize")
violinplot_kwargs.pop("xt_labelsize")

rug_kwargs.setdefault("fill_alpha", 0.1)
rug_kwargs.setdefault("line_alpha", 0.1)

else:
rug_kwargs.setdefault("alpha", 0.1)
rug_kwargs.setdefault("marker", ".")
rug_kwargs.setdefault("linestyle", "")

# TODO: Add backend kwargs
plot = get_plotting_function("plot_violin", "violinplot", backend)
ax = plot(**violinplot_kwargs)
Expand Down

0 comments on commit 3dee7bc

Please sign in to comment.