Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed Posterior plot errors with boolean array. #1707

Merged
merged 11 commits into from
Mar 23, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
* Fixed xarray related tests. ([1726](https://github.com/arviz-devs/arviz/pull/1726))
* Fix Bokeh deprecation warnings ([1657](https://github.com/arviz-devs/arviz/pull/1657))
* Fix credible inteval percentage in legend in `plot_loo_pit` ([1745](https://github.com/arviz-devs/arviz/pull/1745))
* Fixed plot_posterior with boolean data ([1707](https://github.com/arviz-devs/arviz/pull/1707))

### Deprecation
* Deprecated `index_origin` and `order` arguments in `az.summary` ([1201](https://github.com/arviz-devs/arviz/pull/1201))
Expand Down
22 changes: 17 additions & 5 deletions arviz/plots/backends/bokeh/posteriorplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,18 +286,30 @@ def format_axes():
show=False,
)
_, hist, edges = histogram(values, bins="auto")
else:
elif values.dtype.kind == "i" or (values.dtype.kind == "f" and kind == "hist"):
if bins is None:
if values.dtype.kind == "i":
bins = get_bins(values)
else:
bins = "auto"
bins = get_bins(values)
kwargs.setdefault("align", "left")
kwargs.setdefault("color", "blue")
_, hist, edges = histogram(values, bins=bins)
ax.quad(
top=hist, bottom=0, left=edges[:-1], right=edges[1:], fill_alpha=0.35, line_alpha=0.35
)
elif values.dtype.kind == "b":
if bins is None:
bins = "auto"
kwargs.setdefault("color", "blue")

hist = np.array([(~values).sum(), values.sum()])
edges = np.array([-0.5, 0.5, 1.5])
ax.quad(
top=hist, bottom=0, left=edges[:-1], right=edges[1:], fill_alpha=0.35, line_alpha=0.35
)
hdi_prob = "hide"
ax.xaxis.ticker = [0, 1]
ax.xaxis.major_label_overrides = {0: "False", 1: "True"}
else:
raise TypeError("Values must be float, integer or boolean")

format_axes()
max_data = hist.max()
Expand Down
21 changes: 13 additions & 8 deletions arviz/plots/backends/matplotlib/posteriorplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,18 +319,23 @@ def format_axes():
rug=False,
show=False,
)
else:
elif values.dtype.kind == "i" or (values.dtype.kind == "f" and kind == "hist"):
if bins is None:
if values.dtype.kind == "i":
xmin = values.min()
xmax = values.max()
bins = get_bins(values)
ax.set_xlim(xmin - 0.5, xmax + 0.5)
else:
bins = "auto"
xmin = values.min()
xmax = values.max()
bins = get_bins(values)
ax.set_xlim(xmin - 0.5, xmax + 0.5)
kwargs.setdefault("align", "left")
kwargs.setdefault("color", "C0")
ax.hist(values, bins=bins, alpha=0.35, **kwargs)
elif values.dtype.kind == "b":
if bins is None:
bins = "auto"
kwargs.setdefault("color", "C0")
ax.bar(["False", "True"], [(~values).sum(), values.sum()], alpha=0.35, **kwargs)
hdi_prob = "hide"
utkarsh-maheshwari marked this conversation as resolved.
Show resolved Hide resolved
else:
raise TypeError("Values must be float, integer or boolean")
ahartikainen marked this conversation as resolved.
Show resolved Hide resolved

plot_height = ax.get_ylim()[1]

Expand Down
11 changes: 11 additions & 0 deletions arviz/tests/base_tests/test_plots_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,17 @@ def test_plot_posterior_discrete(discrete_model, kwargs):
assert axes.shape


def test_plot_posterior_boolean():
data = np.random.choice(a=[False, True], size=(4, 100))
axes = plot_posterior(data, backend="bokeh", show=False)
assert axes.shape


def test_plot_posterior_bad_type():
with pytest.raises(TypeError):
plot_posterior(np.array(["a", "b", "c"]), backend="bokeh", show=False)


def test_plot_posterior_bad(models):
with pytest.raises(ValueError):
plot_posterior(models.model_1, backend="bokeh", show=False, rope="bad_value")
Expand Down
16 changes: 15 additions & 1 deletion arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,12 +951,26 @@ def test_plot_posterior(models, kwargs):
assert axes.shape


def test_plot_posterior_boolean():
data = np.random.choice(a=[False, True], size=(4, 100))
axes = plot_posterior(data)
assert axes
ahartikainen marked this conversation as resolved.
Show resolved Hide resolved
plt.draw()
labels = [label.get_text() for label in axes.get_xticklabels()]
assert all(item in labels for item in ("True", "False"))


@pytest.mark.parametrize("kwargs", [{}, {"point_estimate": "mode"}, {"bins": None, "kind": "hist"}])
def test_plot_posterior_discrete(discrete_model, kwargs):
axes = plot_posterior(discrete_model, **kwargs)
assert axes.shape



def test_plot_posterior_bad_type(models):
with pytest.raises(TypeError):
plot_posterior(np.array(["a", "b", "c"]))


def test_plot_posterior_bad(models):
with pytest.raises(ValueError):
plot_posterior(models.model_1, rope="bad_value")
Expand Down