Skip to content

Commit

Permalink
fixes to matplotlib subplot handling (#1205)
Browse files Browse the repository at this point in the history
* fixes to matplotlib subplot handling

* lint

* minor fix

* remove prints

* extend tests and more fixes

* lint
  • Loading branch information
OriolAbril authored May 24, 2020
1 parent f5bf1f4 commit f5784b5
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 25 deletions.
47 changes: 26 additions & 21 deletions arviz/plots/backends/matplotlib/pairplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,49 +199,52 @@ def plot_pair(
ax.tick_params(labelsize=xt_labelsize)

else:
not_marginals = int(not marginals)
num_subplot_cols = numvars - not_marginals
max_plots = (
numvars ** 2 if rcParams["plot.max_subplots"] is None else rcParams["plot.max_subplots"]
num_subplot_cols ** 2
if rcParams["plot.max_subplots"] is None
else rcParams["plot.max_subplots"]
)
vars_to_plot = np.sum(np.arange(numvars).cumsum() < max_plots)
if vars_to_plot < numvars:
cols_to_plot = np.sum(np.arange(1, num_subplot_cols + 1).cumsum() <= max_plots)
if cols_to_plot < num_subplot_cols:
vars_to_plot = cols_to_plot
warnings.warn(
"rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
"of resulting pair plots with these variables, generating only a "
"{side}x{side} grid".format(max_plots=max_plots, side=vars_to_plot),
UserWarning,
)
numvars = vars_to_plot
else:
vars_to_plot = numvars - not_marginals

(figsize, ax_labelsize, _, xt_labelsize, _, markersize) = _scale_fig_size(
figsize, textsize, numvars - 2, numvars - 2
figsize, textsize, vars_to_plot, vars_to_plot
)

point_estimate_marker_kwargs.setdefault("s", markersize + 50)

if ax is None:
fig, ax = plt.subplots(numvars, numvars, figsize=figsize, **backend_kwargs)
fig, ax = plt.subplots(vars_to_plot, vars_to_plot, figsize=figsize, **backend_kwargs,)
hexbin_values = []
for i in range(0, numvars):
for i in range(0, vars_to_plot):
var1 = infdata_group[i]

for j in range(0, numvars):
var2 = infdata_group[j]
for j in range(0, vars_to_plot):
var2 = infdata_group[j + not_marginals]
if i > j:
if ax[j, i].get_figure() is not None:
ax[j, i].remove()
continue

elif i == j:
if marginals:
loc = "right"
plot_dist(var1, ax=ax[i, j], **marginal_kwargs)
else:
loc = "left"
if ax[j, i].get_figure() is not None:
ax[j, i].remove()
continue
elif i == j and marginals:
loc = "right"
plot_dist(var1, ax=ax[i, j], **marginal_kwargs)

else:
if i == j:
loc = "left"

if "scatter" in kind:
ax[j, i].plot(var1, var2, **scatter_kwargs)

Expand Down Expand Up @@ -285,15 +288,15 @@ def plot_pair(

if reference_values:
x_name = flat_var_names[i]
y_name = flat_var_names[j]
y_name = flat_var_names[j + not_marginals]
if x_name and y_name not in difference:
ax[j, i].plot(
reference_values_copy[x_name],
reference_values_copy[y_name],
**reference_values_kwargs,
)

if j != numvars - 1:
if j != vars_to_plot - 1:
ax[j, i].axes.get_xaxis().set_major_formatter(NullFormatter())
else:
ax[j, i].set_xlabel(
Expand All @@ -303,7 +306,9 @@ def plot_pair(
ax[j, i].axes.get_yaxis().set_major_formatter(NullFormatter())
else:
ax[j, i].set_ylabel(
"{}".format(flat_var_names[j]), fontsize=ax_labelsize, wrap=True
"{}".format(flat_var_names[j + not_marginals]),
fontsize=ax_labelsize,
wrap=True,
)
ax[j, i].tick_params(labelsize=xt_labelsize)

Expand Down
2 changes: 1 addition & 1 deletion arviz/plots/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def plot_trace(

plotters = list(xarray_var_iter(data, var_names=var_names, combined=True, skip_dims=skip_dims))
max_plots = rcParams["plot.max_subplots"]
max_plots = len(plotters) if max_plots is None else max_plots
max_plots = len(plotters) if max_plots is None else max(max_plots // 2, 1)
if len(plotters) > max_plots:
warnings.warn(
"rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
Expand Down
2 changes: 1 addition & 1 deletion arviz/tests/base_tests/test_plots_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_plot_trace_discrete(discrete_model):

def test_plot_trace_max_subplots_warning(models):
with pytest.warns(UserWarning):
with rc_context(rc={"plot.max_subplots": 1}):
with rc_context(rc={"plot.max_subplots": 2}):
axes = plot_trace(models.model_1, backend="bokeh", show=False)
assert axes.shape

Expand Down
19 changes: 17 additions & 2 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,9 @@ def test_plot_trace_discrete(discrete_model):

def test_plot_trace_max_subplots_warning(models):
with pytest.warns(UserWarning):
with rc_context(rc={"plot.max_subplots": 1}):
with rc_context(rc={"plot.max_subplots": 6}):
axes = plot_trace(models.model_1)
assert axes.shape
assert axes.shape == (3, 2)


@pytest.mark.parametrize("kwargs", [{"var_names": ["mu", "tau"], "lines": [("hey", {}, [1])]}])
Expand Down Expand Up @@ -478,6 +478,21 @@ def test_plot_pair_overlaid(models, kwargs):
assert ax.shape


@pytest.mark.parametrize("marginals", [True, False])
@pytest.mark.parametrize("max_subplots", [True, False])
def test_plot_pair_shapes(marginals, max_subplots):
rng = np.random.default_rng()
idata = from_dict({"a": rng.standard_normal((4, 500, 5))})
if max_subplots:
with rc_context({"plot.max_subplots": 6}):
with pytest.warns(UserWarning, match="3x3 grid"):
ax = plot_pair(idata, marginals=marginals)
else:
ax = plot_pair(idata, marginals=marginals)
side = 3 if max_subplots else (4 + marginals)
assert ax.shape == (side, side)


@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
@pytest.mark.parametrize("alpha", [None, 0.2, 1])
@pytest.mark.parametrize("animated", [False, True])
Expand Down

0 comments on commit f5784b5

Please sign in to comment.