Skip to content

Commit

Permalink
Development/plot dist comparison warning (#1688)
Browse files Browse the repository at this point in the history
* Add warning status for plot_dist_comparison method [WIP]

* Changes to the warning

* Change comparison variable for warning over the limit subplots

* Fix a typo at distcomparisonplot.py, fix axes.shape(2, 3) at test_plot_matplotlib.py

* Add + 1 in the if statement of the warnings

* DElete .DS_Store, alter .gitignore

* Delete .DS_Store

* Fix the counting of the elements of each sublist

* Change total subplots to compare with len plots

* Nested list truncated #1688

* Add warning status for plot_dist_comparison method [WIP]

* Changes to the warning

* Change comparison variable for warning over the limit subplots

* Fix a typo at distcomparisonplot.py, fix axes.shape(2, 3) at test_plot_matplotlib.py

* Add + 1 in the if statement of the warnings

* DElete .DS_Store, alter .gitignore

* Delete .DS_Store

* Fix the counting of the elements of each sublist

* Change total subplots to compare with len plots

* Nested list truncated #1688

* Fix formating

* Correct total subplots counting

* Black reformation

* Fix warning message

* Add change to changelog #1688

* Fix a typo in changelog

* Fix git conflicts

* fix changelog

Co-authored-by: Oriol (ZBook) <oriol.abril.pla@gmail.com>
  • Loading branch information
alexisperakis and OriolAbril authored Mar 4, 2022
1 parent 6cc9f5e commit cd256c4
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ saved_animations/
# mypy
.mypy_cache

# MacOS generated files
.DS_Store

# Stan file
doc/getting_started/*.stan

Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* Fix legend labels in plot_ppc to reflect prior or posterior. ([1967](https://github.com/arviz-devs/arviz/pull/1967))
* Change `DataFrame.append` to `pandas.concat` ([1973](https://github.com/arviz-devs/arviz/pull/1973))
* Fix axis sharing behaviour in `plot_pair`. ([1985](https://github.com/arviz-devs/arviz/pull/1985))
* Added warning message in `plot_dist_comparison()` in case subplots go over the limit ([1688](https://github.com/arviz-devs/arviz/pull/1688))

### Deprecation

Expand Down
18 changes: 17 additions & 1 deletion arviz/plots/distcomparisonplot.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Density Comparison plot."""
import warnings
from ..labels import BaseLabeller
from ..rcparams import rcParams
from ..utils import _var_names, get_coords
from .plot_utils import get_plotting_function
from ..sel_utils import xarray_var_iter
from ..sel_utils import xarray_var_iter, xarray_sel_iter


def plot_dist_comparison(
Expand Down Expand Up @@ -147,6 +148,21 @@ def plot_dist_comparison(
for data, var in zip(datasets, var_names)
]

total_plots = sum(
1 for _ in xarray_sel_iter(datasets[0], var_names=var_names[0], combined=True)
) * (len(groups) + 1)
maxplots = len(dc_plotters[0]) * (len(groups) + 1)

if total_plots > rcParams["plot.max_subplots"]:
warnings.warn(
"rcParams['plot.max_subplots'] ({rcParam}) is smaller than the number "
"of subplots to plot ({len_plotters}), generating only {max_plots} "
"plots".format(
rcParam=rcParams["plot.max_subplots"], len_plotters=total_plots, max_plots=maxplots
),
UserWarning,
)

nvars = len(dc_plotters[0])
ngroups = len(groups)

Expand Down
7 changes: 7 additions & 0 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,13 @@ def test_plot_trace_max_subplots_warning(models):
assert axes.shape == (3, 2)


def test_plot_dist_comparison_warning(models):
with pytest.warns(UserWarning):
with rc_context(rc={"plot.max_subplots": 6}):
axes = plot_dist_comparison(models.model_1)
assert axes.shape == (2, 3)


@pytest.mark.parametrize("kwargs", [{"var_names": ["mu", "tau"], "lines": [("hey", {}, [1])]}])
def test_plot_trace_invalid_varname_warning(models, kwargs):
with pytest.warns(UserWarning, match="valid var.+should be provided"):
Expand Down

0 comments on commit cd256c4

Please sign in to comment.