Skip to content

Commit

Permalink
Addressing issue number #470 (using plot_kde with inference data obje…
Browse files Browse the repository at this point in the history
…ct) (#600)

* Made the following additions:
 * Using inference data or xr.dataset with plot_kde now raises an error prompting user to use plot_posterior instead.
 * Rug plots displayed by default for plot_posterior plots.
 * Not included the tests yet.

* Added the test

* Made the suggested changes.

* fixed load_arviz_data import

* Made the following additions:
 * Using inference data or xr.dataset with plot_kde now raises an error prompting user to use plot_posterior instead.
 * Rug plots displayed by default for plot_posterior plots.
 * Not included the tests yet.

* Added the test

* Made the suggested changes.

* fixed load_arviz_data import

* Made the suggested changes.

* Added a docstring to my test.
  • Loading branch information
Ban-zee authored and canyon289 committed Mar 7, 2019
1 parent 43d8e46 commit 99dd219
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
10 changes: 9 additions & 1 deletion arviz/plots/kdeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from scipy.signal import gaussian, convolve, convolve2d # pylint: disable=no-name-in-module
from scipy.sparse import coo_matrix
from scipy.stats import entropy

import xarray as xr
from ..data.inference_data import InferenceData
from ..utils import conditional_jit
from .plot_utils import _scale_fig_size

Expand Down Expand Up @@ -146,6 +147,13 @@ def plot_kde(

figsize, *_, xt_labelsize, linewidth, markersize = _scale_fig_size(figsize, textsize, 1, 1)

if isinstance(values, xr.Dataset):
raise ValueError(
"Xarray dataset object detected.Use plot_posterior, plot_density, plot_joint"
"or plot_pair instead of plot_kde"
)
if isinstance(values, InferenceData):
raise ValueError(" Inference Data object detected. Use plot_posterior instead of plot_kde")
if values2 is None:
if plot_kwargs is None:
plot_kwargs = {}
Expand Down
1 change: 1 addition & 0 deletions arviz/plots/posteriorplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def format_axes():
fill_kwargs={"alpha": kwargs.pop("fill_alpha", 0)},
plot_kwargs={"linewidth": linewidth},
ax=ax,
rug=False,
)
else:
if bins is None:
Expand Down
15 changes: 13 additions & 2 deletions arviz/tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import pytest
import pymc3 as pm


from ..data import from_dict, from_pymc3
from ..data import from_dict, from_pymc3, load_arviz_data
from ..stats import compare, psislw
from .helpers import eight_schools_params, load_cached_models # pylint: disable=unused-import
from ..plots import (
Expand Down Expand Up @@ -313,6 +312,18 @@ def test_plot_kde_quantiles(continuous_model, kwargs):
assert axes


def test_plot_kde_inference_data():
"""
Ensure that an exception is raised when plot_kde
is used with an inference data or Xarray dataset object.
"""
eight = load_arviz_data("centered_eight")
with pytest.raises(ValueError, match="Inference Data"):
plot_kde(eight)
with pytest.raises(ValueError, match="Xarray"):
plot_kde(eight.posterior)


def test_plot_khat():
linewidth = np.random.randn(20000, 10)
_, khats = psislw(linewidth)
Expand Down

0 comments on commit 99dd219

Please sign in to comment.