diff --git a/.gitignore b/.gitignore index b70c2b7848..a699fab126 100644 --- a/.gitignore +++ b/.gitignore @@ -70,6 +70,9 @@ saved_animations/ # mypy .mypy_cache +# MacOS generated files +.DS_Store + # Stan file doc/getting_started/*.stan diff --git a/CHANGELOG.md b/CHANGELOG.md index 64b9aeac0f..b820cec538 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ * Fix legend labels in plot_ppc to reflect prior or posterior. ([1967](https://github.com/arviz-devs/arviz/pull/1967)) * Change `DataFrame.append` to `pandas.concat` ([1973](https://github.com/arviz-devs/arviz/pull/1973)) * Fix axis sharing behaviour in `plot_pair`. ([1985](https://github.com/arviz-devs/arviz/pull/1985)) +* Added warning message in `plot_dist_comparison()` in case subplots go over the limit ([1688](https://github.com/arviz-devs/arviz/pull/1688)) ### Deprecation diff --git a/arviz/plots/distcomparisonplot.py b/arviz/plots/distcomparisonplot.py index e3c115e431..606a96fc7c 100644 --- a/arviz/plots/distcomparisonplot.py +++ b/arviz/plots/distcomparisonplot.py @@ -1,9 +1,10 @@ """Density Comparison plot.""" +import warnings from ..labels import BaseLabeller from ..rcparams import rcParams from ..utils import _var_names, get_coords from .plot_utils import get_plotting_function -from ..sel_utils import xarray_var_iter +from ..sel_utils import xarray_var_iter, xarray_sel_iter def plot_dist_comparison( @@ -147,6 +148,21 @@ def plot_dist_comparison( for data, var in zip(datasets, var_names) ] + total_plots = sum( + 1 for _ in xarray_sel_iter(datasets[0], var_names=var_names[0], combined=True) + ) * (len(groups) + 1) + maxplots = len(dc_plotters[0]) * (len(groups) + 1) + + if total_plots > rcParams["plot.max_subplots"]: + warnings.warn( + "rcParams['plot.max_subplots'] ({rcParam}) is smaller than the number " + "of subplots to plot ({len_plotters}), generating only {max_plots} " + "plots".format( + rcParam=rcParams["plot.max_subplots"], len_plotters=total_plots, max_plots=maxplots + ), + UserWarning, + ) + nvars = len(dc_plotters[0]) ngroups = len(groups) diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index 1fa7af4e7b..235fdb6745 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -257,6 +257,13 @@ def test_plot_trace_max_subplots_warning(models): assert axes.shape == (3, 2) +def test_plot_dist_comparison_warning(models): + with pytest.warns(UserWarning): + with rc_context(rc={"plot.max_subplots": 6}): + axes = plot_dist_comparison(models.model_1) + assert axes.shape == (2, 3) + + @pytest.mark.parametrize("kwargs", [{"var_names": ["mu", "tau"], "lines": [("hey", {}, [1])]}]) def test_plot_trace_invalid_varname_warning(models, kwargs): with pytest.warns(UserWarning, match="valid var.+should be provided"):