Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Single sided violin plots (enabbling split-violin plots) [WIP] - feedback welcome #1996

Merged
merged 18 commits into from
Jul 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
## v0.x.x Unreleased

### New features
* Add `side` argument to `plot_violin` to allow single-sided violin plots ([1996](https://github.com/arviz-devs/arviz/pull/1996))

### Maintenance and fixes
* Add exception in `az.plot_hdi` for `x` of type `np.datetime64` and `smooth=True` ([2016](https://github.com/arviz-devs/arviz/pull/2016))
* Change `ax.plot` usage to `ax.scatter` in `plot_pair`. ([1990](https://github.com/arviz-devs/arviz/pull/1990))
* Change `ax.plot` usage to `ax.scatter` in `plot_pair` ([1990](https://github.com/arviz-devs/arviz/pull/1990))

### Deprecation

Expand Down
41 changes: 30 additions & 11 deletions arviz/plots/backends/bokeh/violinplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def plot_violin(
shade_kwargs,
shade,
rug,
side,
rug_kwargs,
bw,
textsize,
Expand Down Expand Up @@ -64,10 +65,9 @@ def plot_violin(
):
val = x.flatten()
if val[0].dtype.kind == "i":
dens = cat_hist(val, rug, shade, ax_, **shade_kwargs)
dens = cat_hist(val, rug, side, shade, ax_, **shade_kwargs)
else:
dens = _violinplot(val, rug, shade, bw, circular, ax_, **shade_kwargs)

dens = _violinplot(val, rug, side, shade, bw, circular, ax_, **shade_kwargs)
if rug:
rug_x = -np.abs(np.random.normal(scale=max(dens) / 3.5, size=len(val)))
ax_.scatter(rug_x, val, **rug_kwargs)
Expand Down Expand Up @@ -102,32 +102,51 @@ def plot_violin(
return ax


def _violinplot(val, rug, shade, bw, circular, ax, **shade_kwargs):
def _violinplot(val, rug, side, shade, bw, circular, ax, **shade_kwargs):
"""Auxiliary function to plot violinplots."""
if bw == "default":
bw = "taylor" if circular else "experimental"
x, density = kde(val, circular=circular, bw=bw)

if not rug:
if rug and side == "both":
side = "right"

if side == "left":
dens = -density
elif side == "right":
x = x[::-1]
dens = density[::-1]
elif side == "both":
x = np.concatenate([x, x[::-1]])
density = np.concatenate([-density, density[::-1]])
dens = np.concatenate([-density, density[::-1]])

ax.harea(y=x, x1=density, x2=np.zeros_like(density), fill_alpha=shade, **shade_kwargs)
ax.harea(y=x, x1=dens, x2=np.zeros_like(dens), fill_alpha=shade, **shade_kwargs)

return density


def cat_hist(val, rug, shade, ax, **shade_kwargs):
def cat_hist(val, rug, side, shade, ax, **shade_kwargs):
"""Auxiliary function to plot discrete-violinplots."""
bins = get_bins(val)
_, binned_d, _ = histogram(val, bins=bins)

bin_edges = np.linspace(np.min(val), np.max(val), len(bins))
heights = np.diff(bin_edges)
centers = bin_edges[:-1] + heights.mean() / 2
right = 0.5 * binned_d

left = 0 if rug else -right
bar_length = 0.5 * binned_d

if rug and side == "both":
side = "right"

if side == "right":
left = 0
right = bar_length
elif side == "left":
left = -bar_length
right = 0
elif side == "both":
left = -bar_length
right = bar_length

ax.hbar(
y=centers,
Expand Down
33 changes: 25 additions & 8 deletions arviz/plots/backends/matplotlib/violinplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def plot_violin(
shade,
rug,
rug_kwargs,
side,
bw,
textsize,
labeller,
Expand Down Expand Up @@ -68,9 +69,9 @@ def plot_violin(
for (var_name, selection, isel, x), ax_ in zip(plotters, ax.flatten()):
val = x.flatten()
if val[0].dtype.kind == "i":
dens = cat_hist(val, rug, shade, ax_, **shade_kwargs)
dens = cat_hist(val, rug, side, shade, ax_, **shade_kwargs)
else:
dens = _violinplot(val, rug, shade, bw, circular, ax_, **shade_kwargs)
dens = _violinplot(val, rug, side, shade, bw, circular, ax_, **shade_kwargs)

if rug:
rug_x = -np.abs(np.random.normal(scale=max(dens) / 3.5, size=len(val)))
Expand Down Expand Up @@ -101,21 +102,29 @@ def plot_violin(
return ax


def _violinplot(val, rug, shade, bw, circular, ax, **shade_kwargs):
def _violinplot(val, rug, side, shade, bw, circular, ax, **shade_kwargs):
"""Auxiliary function to plot violinplots."""
if bw == "default":
bw = "taylor" if circular else "experimental"
x, density = kde(val, circular=circular, bw=bw)

if not rug:
if rug and side == "both":
side = "right"

if side == "left":
dens = -density
elif side == "right":
x = x[::-1]
dens = density[::-1]
elif side == "both":
x = np.concatenate([x, x[::-1]])
density = np.concatenate([-density, density[::-1]])
dens = np.concatenate([-density, density[::-1]])

ax.fill_betweenx(x, density, alpha=shade, lw=0, **shade_kwargs)
ax.fill_betweenx(x, dens, alpha=shade, lw=0, **shade_kwargs)
return density


def cat_hist(val, rug, shade, ax, **shade_kwargs):
def cat_hist(val, rug, side, shade, ax, **shade_kwargs):
"""Auxiliary function to plot discrete-violinplots."""
bins = get_bins(val)
_, binned_d, _ = histogram(val, bins=bins)
Expand All @@ -124,7 +133,15 @@ def cat_hist(val, rug, shade, ax, **shade_kwargs):
heights = np.diff(bin_edges)
centers = bin_edges[:-1] + heights.mean() / 2

left = None if rug else -0.5 * binned_d
if rug and side == "both":
side = "right"

if side == "right":
left = None
elif side == "left":
left = -binned_d
elif side == "both":
left = -0.5 * binned_d

ax.barh(centers, binned_d, height=heights, left=left, alpha=shade, **shade_kwargs)
return binned_d
9 changes: 9 additions & 0 deletions arviz/plots/violinplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def plot_violin(
transform=None,
quartiles=True,
rug=False,
side="both",
hdi_prob=None,
shade=0.35,
bw="default",
Expand Down Expand Up @@ -61,6 +62,10 @@ def plot_violin(
intervals. Defaults to ``True``.
rug: bool
If ``True`` adds a jittered rugplot. Defaults to ``False``.
side : {"both", "left", "right"}, default "both"
If ``both``, both sides of the violin plot are rendered. If ``left`` or ``right``, only
the respective side is rendered. By separately plotting left and right halfs with
different data, split violin plots can be achieved.
hdi_prob: float, optional
Plots highest posterior density interval for chosen percentage of density.
Defaults to 0.94.
Expand Down Expand Up @@ -162,6 +167,7 @@ def plot_violin(
shade=shade,
rug=rug,
rug_kwargs=rug_kwargs,
side=side,
bw=bw,
textsize=textsize,
labeller=labeller,
Expand All @@ -176,6 +182,9 @@ def plot_violin(
backend = rcParams["plot.backend"]
backend = backend.lower()

if side not in ("both", "left", "right"):
raise ValueError(f"'side' can only be 'both', 'left', or 'right', got: '{side}'")

# TODO: Add backend kwargs
plot = get_plotting_function("plot_violin", "violinplot", backend)
ax = plot(**violinplot_kwargs)
Expand Down
8 changes: 6 additions & 2 deletions arviz/tests/base_tests/test_plots_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,8 +848,12 @@ def test_plot_parallel_exception(models, var_names):


@pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
def test_plot_violin(models, var_names):
axes = plot_violin(models.model_1, var_names=var_names, backend="bokeh", show=False)
@pytest.mark.parametrize("side", ["both", "left", "right"])
@pytest.mark.parametrize("rug", [True])
def test_plot_violin(models, var_names, side, rug):
axes = plot_violin(
models.model_1, var_names=var_names, side=side, rug=rug, backend="bokeh", show=False
)
assert axes.shape


Expand Down
6 changes: 4 additions & 2 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,8 +908,10 @@ def test_plot_legend(models):


@pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
def test_plot_violin(models, var_names):
axes = plot_violin(models.model_1, var_names=var_names)
@pytest.mark.parametrize("side", ["both", "left", "right"])
@pytest.mark.parametrize("rug", [True])
def test_plot_violin(models, var_names, side, rug):
axes = plot_violin(models.model_1, var_names=var_names, side=side, rug=rug)
assert axes.shape


Expand Down
23 changes: 23 additions & 0 deletions examples/bokeh/bokeh_plot_violin_single_sided.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Single sided Violinplot
==========

_thumb: .2, .8
"""
import arviz as az

data = az.load_arviz_data("rugby")
labeller = az.labels.MapLabeller(var_name_map={"defs": "atts | defs"})

p1 = az.plot_violin(
data.posterior["atts"], side="left", backend="bokeh", show=False, labeller=labeller
)
p2 = az.plot_violin(
data.posterior["defs"],
side="right",
ax=p1,
backend="bokeh",
shade_kwargs={"color": "lightsalmon"},
show=True,
labeller=labeller,
)
14 changes: 14 additions & 0 deletions examples/matplotlib/mpl_plot_violin_single_sided.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
Violin plot single sided
===========

_thumb: .2, .8
_example_title: Single sided violin plot
"""
import arviz as az

data = az.load_arviz_data("rugby")

labeller = az.labels.MapLabeller(var_name_map={"defs": "atts | defs"})
axs = az.plot_violin(data, var_names=["atts"], side="left", show=False)
az.plot_violin(data, var_names=["defs"], side="right", labeller=labeller, ax=axs, show=True)