From 51ce7e860512e0e4ea68ba0cdf2f85f9445cf001 Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Sat, 13 Feb 2021 13:42:59 +0000 Subject: [PATCH] remove varname from legend in plotppc --- CHANGELOG.md | 1 + arviz/plots/backends/matplotlib/ppcplot.py | 26 ++++++++----------- .../tests/base_tests/test_plots_matplotlib.py | 8 ++++++ 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a41456f2e5..af7a9bd3ef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ### Maintenance and fixes * Updated `from_cmdstan` and `from_numpyro` converter to follow schema convention ([1541](https://github.com/arviz-devs/arviz/pull/1541) and [1525](https://github.com/arviz-devs/arviz/pull/1525)) * Fix calculation of mode as point estimate ([1552](https://github.com/arviz-devs/arviz/pull/1552)) +* Remove variable name from legend in posterior predictive plot ([1559](https://github.com/arviz-devs/arviz/pull/1559)) ### Deprecation * Removed Geweke diagnostic ([1545](https://github.com/arviz-devs/arviz/pull/1545)) diff --git a/arviz/plots/backends/matplotlib/ppcplot.py b/arviz/plots/backends/matplotlib/ppcplot.py index 0cae2dcc71..320dc58469 100644 --- a/arviz/plots/backends/matplotlib/ppcplot.py +++ b/arviz/plots/backends/matplotlib/ppcplot.py @@ -124,14 +124,12 @@ def plot_ppc( plot_kwargs = {"color": color, "alpha": alpha, "linewidth": 0.5 * linewidth} if dtype == "i": plot_kwargs["drawstyle"] = "steps-pre" - ax_i.plot( - [], color=color, label="{} predictive {}".format(group.capitalize(), pp_var_name) - ) + ax_i.plot([], color=color, label="{} predictive".format(group.capitalize())) if observed: if dtype == "f": plot_kde( obs_vals, - label="Observed {}".format(var_name), + label="Observed", plot_kwargs={"color": "k", "linewidth": linewidth, "zorder": 3}, fill_kwargs={"alpha": 0}, ax=ax_i, @@ -144,7 +142,7 @@ def plot_ppc( ax_i.plot( bin_edges, hist, - label="Observed {}".format(var_name), + label="Observed", color="k", linewidth=linewidth, zorder=3, @@ -179,7 +177,7 @@ def plot_ppc( ax_i.plot(x_s, y_s, **plot_kwargs) if mean: - label = "{} predictive mean {}".format(group.capitalize(), pp_var_name) + label = "{} predictive mean".format(group.capitalize()) if dtype == "f": rep = len(pp_densities) len_density = len(pp_densities[0]) @@ -224,7 +222,7 @@ def plot_ppc( *_empirical_cdf(obs_vals), color="k", linewidth=linewidth, - label="Observed {}".format(var_name), + label="Observed", drawstyle=drawstyle, zorder=3 ) @@ -253,7 +251,7 @@ def plot_ppc( drawstyle=drawstyle, linewidth=linewidth ) - ax_i.plot([], color=color, label="Posterior predictive {}".format(pp_var_name)) + ax_i.plot([], color=color, label="Posterior predictive") if mean: ax_i.plot( *_empirical_cdf(pp_vals.flatten()), @@ -261,7 +259,7 @@ def plot_ppc( linestyle="--", linewidth=linewidth * 1.5, drawstyle=drawstyle, - label="Posterior predictive mean {}".format(pp_var_name) + label="Posterior predictive mean" ) ax_i.set_yticks([0, 0.5, 1]) @@ -276,7 +274,7 @@ def plot_ppc( "linewidth": linewidth * 1.5, "zorder": 3, }, - label="Posterior predictive mean {}".format(pp_var_name), + label="Posterior predictive mean", ax=ax_i, legend=legend, ) @@ -290,7 +288,7 @@ def plot_ppc( hist, color=color, linewidth=linewidth * 1.5, - label="Posterior predictive mean {}".format(pp_var_name), + label="Posterior predictive mean", zorder=3, linestyle="--", drawstyle="steps-pre", @@ -316,7 +314,7 @@ def plot_ppc( color="k", markersize=markersize, alpha=alpha, - label="Observed {}".format(var_name), + label="Observed", zorder=4, ) @@ -340,9 +338,7 @@ def plot_ppc( vals, yvals, "o", zorder=2, color=color, markersize=markersize, alpha=alpha ) - ax_i.plot( - [], color=color, marker="o", label="Posterior predictive {}".format(pp_var_name) - ) + ax_i.plot([], color=color, marker="o", label="Posterior predictive") ax_i.set_yticks([]) diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index 4c3ee5cef5..a28c51c750 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -794,6 +794,14 @@ def test_plot_ppc_bad_ax(models, fig_ax): plot_ppc(models.model_1, ax=ax2) +def test_plot_legend(models): + axes = plot_ppc(models.model_1) + legend_texts = axes.get_legend().get_texts() + result = [i.get_text() for i in legend_texts] + expected = ["Posterior predictive", "Observed", "Posterior predictive mean"] + assert result == expected + + @pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"])) def test_plot_violin(models, var_names): axes = plot_violin(models.model_1, var_names=var_names)