diff --git a/CHANGELOG.md b/CHANGELOG.md index 269fa95967..be2e14a36e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ * Set `fill_last` argument of `plot_kde` to False by default (#1158) * plot_ppc animation: improve docs and error handling (#1162) * Fix import error when wrapped function docstring is empty (#1192) +* Fix passing axes to plot_density with several datasets ([#1198](https://github.com/arviz-devs/arviz/pull/1198)) ### Deprecation * `hpd` function deprecated in favor of `hdi`. `credible_interval` argument replaced by `hdi_prob`throughout with exception of `plot_loo_pit` (#1176) diff --git a/arviz/plots/backends/matplotlib/densityplot.py b/arviz/plots/backends/matplotlib/densityplot.py index acb4053cb6..5953369de3 100644 --- a/arviz/plots/backends/matplotlib/densityplot.py +++ b/arviz/plots/backends/matplotlib/densityplot.py @@ -74,8 +74,8 @@ def plot_density( if n_data > 1: for m_idx, label in enumerate(data_labels): - ax[0].plot([], label=label, c=colors[m_idx], markersize=markersize) - ax[0].legend(fontsize=xt_labelsize) + ax.item(0).plot([], label=label, c=colors[m_idx], markersize=markersize) + ax.item(0).legend(fontsize=xt_labelsize) if backend_show(show): plt.show() diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index c994d213dd..c678064499 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -101,12 +101,16 @@ def fig_ax(): {"hdi_markers": ["v"]}, {"shade": 1}, {"transform": lambda x: x + 1}, + {"ax": plt.subplots(6, 3)[1]}, ], ) def test_plot_density_float(models, kwargs): obj = [getattr(models, model_fit) for model_fit in ["model_1", "model_2"]] axes = plot_density(obj, **kwargs) - assert axes.shape[0] >= 18 + if "ax" in kwargs: + assert axes.shape == (6, 3) + else: + assert axes.shape[0] >= 18 def test_plot_density_discrete(discrete_model):