Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Violinplot: fix histogram, add rug #997

Merged
merged 3 commits into from
Jan 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 42 additions & 16 deletions arviz/plots/backends/bokeh/violinplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ def plot_violin(
figsize,
rows,
cols,
sharex,
sharey,
kwargs_shade,
shade_kwargs,
shade,
rug,
rug_kwargs,
bw,
credible_interval,
linewidth,
Expand All @@ -40,6 +43,7 @@ def plot_violin(
len(plotters),
rows,
cols,
sharex=sharex,
sharey=sharey,
figsize=figsize,
squeeze=False,
Expand All @@ -54,17 +58,30 @@ def plot_violin(
):
val = x.flatten()
if val[0].dtype.kind == "i":
cat_hist(val, shade, ax_, **kwargs_shade)
dens = cat_hist(val, rug, shade, ax_, **shade_kwargs)
else:
_violinplot(val, shade, bw, ax_, **kwargs_shade)
dens = _violinplot(val, rug, shade, bw, ax_, **shade_kwargs)

if rug:
rug_x = -np.abs(np.random.normal(scale=max(dens) / 3.5, size=len(val)))
ax_.scatter(rug_x, val, **rug_kwargs)

per = np.percentile(val, [25, 75, 50])
hpd_intervals = hpd(val, credible_interval, multimodal=False)

if quartiles:
ax_.line([0, 0], per[:2], line_width=linewidth * 3, line_color="black")
ax_.line([0, 0], hpd_intervals, line_width=linewidth, line_color="black")
ax_.circle(0, per[-1])
ax_.line(
[0, 0], per[:2], line_width=linewidth * 3, line_color="black", line_cap="round"
)
ax_.line([0, 0], hpd_intervals, line_width=linewidth, line_color="black", line_cap="round")
ax_.circle(
0,
per[-1],
line_color="white",
fill_color="white",
size=linewidth * 1.5,
line_width=linewidth,
)

_title = Title()
_title.text = make_label(var_name, selection)
Expand All @@ -80,35 +97,44 @@ def plot_violin(
return ax


def _violinplot(val, shade, bw, ax, **kwargs_shade):
def _violinplot(val, rug, shade, bw, ax, **shade_kwargs):
"""Auxiliary function to plot violinplots."""
density, low_b, up_b = _fast_kde(val, bw=bw)
x = np.linspace(low_b, up_b, len(density))

x = np.concatenate([x, x[::-1]])
density = np.concatenate([-density, density[::-1]])
if not rug:
x = np.concatenate([x, x[::-1]])
density = np.concatenate([-density, density[::-1]])

ax.harea(y=x, x1=density, x2=np.zeros_like(density), fill_alpha=shade, **shade_kwargs)

ax.patch(density, x, fill_alpha=shade, line_width=0, **kwargs_shade)
return density


def cat_hist(val, shade, ax, **kwargs_shade):
def cat_hist(val, rug, shade, ax, **shade_kwargs):
"""Auxiliary function to plot discrete-violinplots."""
bins = get_bins(val)
_, binned_d, _ = histogram(val, bins=bins)

bin_edges = np.linspace(np.min(val), np.max(val), len(bins))
centers = 0.5 * (bin_edges + np.roll(bin_edges, 1))[:-1]
heights = np.diff(bin_edges)
centers = bin_edges[:-1] + heights.mean() / 2
right = 0.5 * binned_d

lefts = -0.5 * binned_d
if rug:
left = 0
else:
left = -right

ax.hbar(
y=centers,
left=lefts,
right=-lefts,
left=left,
right=right,
height=heights,
fill_alpha=shade,
line_alpha=shade,
line_color=None,
**kwargs_shade
**shade_kwargs
)

return binned_d
41 changes: 29 additions & 12 deletions arviz/plots/backends/matplotlib/violinplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ def plot_violin(
figsize,
rows,
cols,
sharex,
sharey,
kwargs_shade,
shade_kwargs,
shade,
rug,
rug_kwargs,
bw,
credible_interval,
linewidth,
Expand All @@ -29,24 +32,31 @@ def plot_violin(
):
"""Matplotlib violin plot."""
if ax is None:
_, ax = _create_axes_grid(
fig, ax = _create_axes_grid(
len(plotters),
rows,
cols,
sharex=sharex,
sharey=sharey,
figsize=figsize,
squeeze=False,
backend_kwargs=backend_kwargs,
)
fig.set_constrained_layout(False)
fig.subplots_adjust(wspace=0)

ax = np.atleast_1d(ax)

for (var_name, selection, x), ax_ in zip(plotters, ax.flatten()):
val = x.flatten()
if val[0].dtype.kind == "i":
cat_hist(val, shade, ax_, **kwargs_shade)
dens = cat_hist(val, rug, shade, ax_, **shade_kwargs)
else:
_violinplot(val, shade, bw, ax_, **kwargs_shade)
dens = _violinplot(val, rug, shade, bw, ax_, **shade_kwargs)

if rug:
rug_x = -np.abs(np.random.normal(scale=max(dens) / 3.5, size=len(val)))
ax_.plot(rug_x, val, **rug_kwargs)

per = np.percentile(val, [25, 75, 50])
hpd_intervals = hpd(val, credible_interval, multimodal=False)
Expand All @@ -67,25 +77,32 @@ def plot_violin(
return ax


def _violinplot(val, shade, bw, ax, **kwargs_shade):
def _violinplot(val, rug, shade, bw, ax, **shade_kwargs):
"""Auxiliary function to plot violinplots."""
density, low_b, up_b = _fast_kde(val, bw=bw)
x = np.linspace(low_b, up_b, len(density))

x = np.concatenate([x, x[::-1]])
density = np.concatenate([-density, density[::-1]])
if not rug:
x = np.concatenate([x, x[::-1]])
density = np.concatenate([-density, density[::-1]])

ax.fill_betweenx(x, density, alpha=shade, lw=0, **kwargs_shade)
ax.fill_betweenx(x, density, alpha=shade, lw=0, **shade_kwargs)
return density


def cat_hist(val, shade, ax, **kwargs_shade):
def cat_hist(val, rug, shade, ax, **shade_kwargs):
"""Auxiliary function to plot discrete-violinplots."""
bins = get_bins(val)
_, binned_d, _ = histogram(val, bins=bins)

bin_edges = np.linspace(np.min(val), np.max(val), len(bins))
centers = 0.5 * (bin_edges + np.roll(bin_edges, 1))[:-1]
heights = np.diff(bin_edges)
centers = bin_edges[:-1] + heights.mean() / 2

if rug:
left = None
else:
left = -0.5 * binned_d

lefts = -0.5 * binned_d
ax.barh(centers, binned_d, height=heights, left=lefts, alpha=shade, **kwargs_shade)
ax.barh(centers, binned_d, height=heights, left=left, alpha=shade, **shade_kwargs)
return binned_d
35 changes: 29 additions & 6 deletions arviz/plots/violinplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ def plot_violin(
data,
var_names=None,
quartiles=True,
rug=False,
credible_interval=0.94,
shade=0.35,
bw=4.5,
sharex=True,
sharey=True,
figsize=None,
textsize=None,
ax=None,
kwargs_shade=None,
shade_kwargs=None,
rug_kwargs=None,
backend=None,
backend_kwargs=None,
show=None,
Expand All @@ -42,6 +45,8 @@ def plot_violin(
quartiles : bool, optional
Flag for plotting the interquartile range, in addition to the credible_interval*100%
intervals. Defaults to True
rug : bool
If True adds a jittered rugplot. Defaults to False.
credible_interval : float, optional
Credible intervals. Defaults to 0.94.
shade : float
Expand All @@ -56,12 +61,17 @@ def plot_violin(
textsize: int
Text size of the point_estimates, axis ticks, and HPD. If None it will be autoscaled
based on figsize.
sharex : bool
Defaults to True, violinplots share a common x-axis scale.
sharey : bool
Defaults to True, violinplots share a common y-axis scale.
ax: axes, optional
Matplotlib axes or bokeh figures.
kwargs_shade : dicts, optional
shade_kwargs : dicts, optional
Additional keywords passed to `fill_between`, or `barh` to control the shade.
rug_kwargs : dict
Keywords passed to the rug plot. If true only the righ half side of the violin will be
plotted.
backend: str, optional
Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
backend_kwargs: bool, optional
Expand All @@ -81,25 +91,30 @@ def plot_violin(
list(xarray_var_iter(data, var_names=var_names, combined=True)), "plot_violin"
)

if kwargs_shade is None:
kwargs_shade = {}
if shade_kwargs is None:
shade_kwargs = {}

rows, cols = default_grid(len(plotters))

(figsize, ax_labelsize, _, xt_labelsize, linewidth, _) = _scale_fig_size(
figsize, textsize, rows, cols
)
ax_labelsize *= 2

if rug_kwargs is None:
rug_kwargs = {}

violinplot_kwargs = dict(
ax=ax,
plotters=plotters,
figsize=figsize,
rows=rows,
cols=cols,
sharex=sharex,
sharey=sharey,
kwargs_shade=kwargs_shade,
shade_kwargs=shade_kwargs,
shade=shade,
rug=rug,
rug_kwargs=rug_kwargs,
bw=bw,
credible_interval=credible_interval,
linewidth=linewidth,
Expand All @@ -115,6 +130,14 @@ def plot_violin(
violinplot_kwargs.pop("ax_labelsize")
violinplot_kwargs.pop("xt_labelsize")

rug_kwargs.setdefault("fill_alpha", 0.1)
rug_kwargs.setdefault("line_alpha", 0.1)

else:
rug_kwargs.setdefault("alpha", 0.1)
rug_kwargs.setdefault("marker", ".")
rug_kwargs.setdefault("linestyle", "")

# TODO: Add backend kwargs
plot = get_plotting_function("plot_violin", "violinplot", backend)
ax = plot(**violinplot_kwargs)
Expand Down