Development/plot dist comparison warning #1688

Merged
Changes from 10 commits
31 commits
876dcc6
Add warning status for plot_dist_comparison method [WIP]
alexisperakis May 5, 2021
9721102
Changes to the warning
alexisperakis May 5, 2021
260e2e7
Change comparison variable for warning over the limit subplots
alexisperakis May 5, 2021
fde9f77
Fix a typo at distcomparisonplot.py, fix axes.shape(2, 3) at test_plo…
alexisperakis May 5, 2021
7c2b569
Add + 1 in the if statement of the warnings
alexisperakis May 6, 2021
079cf5e
DElete .DS_Store, alter .gitignore
alexisperakis May 7, 2021
2553ad5
Delete .DS_Store
alexisperakis May 9, 2021
49ce2fe
Fix the counting of the elements of each sublist
alexisperakis May 9, 2021
62a7a67
Change total subplots to compare with len plots
alexisperakis May 10, 2021
ee8c9d1
Nested list truncated #1688
alexisperakis May 10, 2021
ab06cd2
Add warning status for plot_dist_comparison method [WIP]
alexisperakis May 5, 2021
ae55486
Changes to the warning
alexisperakis May 5, 2021
ebe83e5
Change comparison variable for warning over the limit subplots
alexisperakis May 5, 2021
7faa75b
Fix a typo at distcomparisonplot.py, fix axes.shape(2, 3) at test_plo…
alexisperakis May 5, 2021
d0d7a30
Add + 1 in the if statement of the warnings
alexisperakis May 6, 2021
d5d805b
DElete .DS_Store, alter .gitignore
alexisperakis May 7, 2021
4864849
Delete .DS_Store
alexisperakis May 9, 2021
ed73a13
Fix the counting of the elements of each sublist
alexisperakis May 9, 2021
e7f6a97
Change total subplots to compare with len plots
alexisperakis May 10, 2021
f596483
Nested list truncated #1688
alexisperakis May 10, 2021
c59c6fe
Fix formating
alexisperakis May 10, 2021
bf376f1
Resolve conflicts
alexisperakis May 12, 2021
d54e336
Correct total subplots counting
alexisperakis May 13, 2021
9911abc
Black reformation
alexisperakis May 13, 2021
fcefcfd
Fix warning message
alexisperakis May 15, 2021
3a63dbe
Add change to changelog #1688
alexisperakis May 16, 2021
a5a6060
Fix a typo in changelog
alexisperakis May 16, 2021
bf598b9
Fix git conflicts
alexisperakis May 18, 2021
92f9947
Merge branch 'main' of https://github.com/arviz-devs/arviz into devel…
alexisperakis May 24, 2021
56fc92c
Merge remote-tracking branch 'upstream/main' into development/plot_di…
OriolAbril Mar 4, 2022
9a2a1f1
fix changelog
OriolAbril Mar 4, 2022
3 changes: 3 additions & 0 deletions .gitignore
@@ -70,6 +70,9 @@ saved_animations/
# mypy
.mypy_cache

# MacOS generated files
.DS_Store

# Stan file
doc/getting_started/*.stan

17 changes: 17 additions & 0 deletions arviz/plots/distcomparisonplot.py
@@ -1,4 +1,5 @@
"""Density Comparison plot."""
import warnings
from ..labels import BaseLabeller
from ..rcparams import rcParams
from ..utils import _var_names, get_coords
@@ -134,6 +135,22 @@ def plot_dist_comparison(
        for data, var in zip(datasets, var_names)
    ]

    total_plots = sum(
        sum(1 for _ in xarray_var_iter(data, var_names=var, combined=True))
        for data, var in zip(datasets, var_names)
    )
Member

The same comment as on the next line applies here: the nested loop is not needed and is actually incorrect. It should be something like:

total_plots = sum(1 for _ in xarray_sel_iter(datasets[0], var_names=var_names[0], combined=True))

Also, since we only want to count, we can use xarray_sel_iter instead of xarray_var_iter. We don't need to subset the data to get the values at each position just to count how many positions there are.
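
To illustrate the counting argument, here is a minimal pure-Python sketch. It does not call ArviZ: `count_selections` and the dimension sizes are hypothetical stand-ins for what iterating over the selections would count, assuming one subplot position per combination of the non-sample dimensions and assuming that "chain" and "draw" are the dimensions skipped when chains are combined.

```python
from itertools import product


def count_selections(var_dims, skip_dims=("chain", "draw")):
    """Count plot positions without touching any data values.

    var_dims maps a variable name to a {dimension name: size} dict
    (hypothetical input). Dimensions in skip_dims are treated as sample
    dimensions and are not iterated over.
    """
    total = 0
    for dims in var_dims.values():
        sizes = [size for name, size in dims.items() if name not in skip_dims]
        # One position per combination of the remaining coordinates;
        # a variable with no remaining dimensions yields a single position.
        total += sum(1 for _ in product(*(range(size) for size in sizes)))
    return total


# Hypothetical model: scalar "mu" and "tau", plus "theta" with 8 schools.
dims = {
    "mu": {"chain": 4, "draw": 500},
    "tau": {"chain": 4, "draw": 500},
    "theta": {"chain": 4, "draw": 500, "school": 8},
}
n_positions = count_selections(dims)
print(n_positions)
```

Only the dimension sizes are needed for the count, which is the point of the comment above: the values at each position are never read.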

Contributor Author

Okay, thank you for this explanation, it was really useful, and sorry for all the trouble caused; it's my first time contributing.
What I have understood and already done (no commit yet) is to compare total_plots, the number of plots requested by the user, with rcParams["plot.max_subplots"]. If total_plots is greater, the warning is raised, and the plots are not truncated again because they are already truncated at creation. I now use maxplots only in the warning message. I've also included a test case.

Member

Sounds great, and don't worry about the discussion and back-and-forth comments, I think they were useful to all of us.

    maxplots = sum(len(splot) for splot in dc_plotters)
Member

maxplots as defined here is not necessarily equal to the rcParam. Plots here are made in blocks of 3-4 subplots, so maxplots will be equal to or lower than the rcParam. To avoid confusion, I think the warning should use the rcParam in the first occurrence, and maxplots can be used in the second. It is neither necessary nor correct to loop over dc_plotters though; it should be:

maxplots = len(dc_plotters[0]) * (len(groups) + 1)
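
The gap between the rcParam and maxplots can be checked with a small arithmetic sketch. It assumes, as the comment describes, that each variable takes len(groups) + 1 subplots, and additionally assumes that the number of variables kept is the rcParam floor-divided by that block size with a minimum of one. `plotted_subplots` is a hypothetical helper, not ArviZ code.

```python
def plotted_subplots(max_subplots, n_groups, n_vars_requested):
    """Return (variables kept, subplots drawn) under the block layout.

    Assumes each variable occupies n_groups + 1 subplots and that at
    least one variable is always kept (an assumption of this sketch).
    """
    block = n_groups + 1
    n_vars_kept = min(n_vars_requested, max(max_subplots // block, 1))
    return n_vars_kept, n_vars_kept * block


# Two groups (for example prior and posterior) give blocks of 3 subplots.
kept, maxplots = plotted_subplots(max_subplots=7, n_groups=2, n_vars_requested=10)
print(kept, maxplots)  # 2 variables kept, 6 subplots drawn, below the limit of 7
```

With a limit of 7 and blocks of 3, only 6 subplots are drawn, so a message that reports maxplots as if it were the rcParam value would show 6 where the user set 7.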


    if total_plots > rcParams["plot.max_subplots"]:
        warnings.warn(
            "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
            "of variables to plot ({len_plotters}), generating only {max_plots} "
            "plots".format(max_plots=maxplots, len_plotters=total_plots),
Member

This needs to be updated. I'm not sure what the best way to do so is, but these are the things that need to change:

  • "smaller than the number of variables to plot" does not work here. In most plots the number of subplots equals the number of variables, but here each variable has 3-4 subplots. Maybe using "subplots" instead of "variables" is enough?
  • "rcParams['plot.max_subplots'] ({max_plots})": precisely because of this 3-4 factor, max_plots is not necessarily equal to the rcParam. I think it will be very confusing for users to set the rcParam to 7 and then get a warning saying we'll only plot 6 subplots because they set the rcParam to 6.
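
One possible way to address both points is sketched below. The wording and the helper name `warn_if_truncated` are only suggestions for illustration, not the message that was eventually merged: the rcParam value is reported first, "subplots" replaces "variables", and the number of subplots actually generated is reported separately.

```python
import warnings


def warn_if_truncated(total_plots, max_subplots, maxplots):
    """Emit a UserWarning when more subplots are requested than allowed.

    The message reports the rcParam value first and the number of
    subplots actually generated second, since the two can differ.
    """
    if total_plots > max_subplots:
        warnings.warn(
            f"rcParams['plot.max_subplots'] ({max_subplots}) is smaller than the "
            f"number of subplots to plot ({total_plots}), generating only "
            f"{maxplots} subplots",
            UserWarning,
        )
        return True
    return False


# Limit of 7, 30 subplots requested, 6 actually drawn (blocks of 3).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warned = warn_if_truncated(total_plots=30, max_subplots=7, maxplots=6)
message = str(caught[0].message)
print(message)
```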

            UserWarning,
        )
        for i in range(0, len(dc_plotters)):
            dc_plotters[i] = dc_plotters[i][:len_plots]
Member

This is not necessary: the inner lists are already truncated at creation, so this would have no effect.
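
A tiny sketch of why the second slice is a no-op, using plain lists with hypothetical variable names in place of the real plotter entries:

```python
len_plots = 2  # hypothetical number of variables kept per dataset
variables = ["mu", "tau", "theta"]

# Inner lists are truncated when they are built...
dc_plotters = [variables[:len_plots] for _ in range(2)]
before = [list(inner) for inner in dc_plotters]

# ...so slicing them again with the same bound changes nothing.
for i in range(len(dc_plotters)):
    dc_plotters[i] = dc_plotters[i][:len_plots]

print(dc_plotters == before)
```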


    nvars = len(dc_plotters[0])
    ngroups = len(groups)

7 changes: 7 additions & 0 deletions arviz/tests/base_tests/test_plots_matplotlib.py
@@ -235,6 +235,13 @@ def test_plot_trace_max_subplots_warning(models):
    assert axes.shape == (3, 2)


def test_plot_dist_comparison_warning(models):
    with pytest.warns(UserWarning):
        with rc_context(rc={"plot.max_subplots": 6}):
            axes = plot_dist_comparison(models.model_1)
    assert axes.shape == (2, 3)


@pytest.mark.parametrize("kwargs", [{"var_names": ["mu", "tau"], "lines": [("hey", {}, [1])]}])
def test_plot_trace_invalid_varname_warning(models, kwargs):
    with pytest.warns(UserWarning, match="valid var.+should be provided"):