diff --git a/arviz/plots/kdeplot.py b/arviz/plots/kdeplot.py index a0a3fbe989..65c3eea429 100644 --- a/arviz/plots/kdeplot.py +++ b/arviz/plots/kdeplot.py @@ -4,7 +4,8 @@ from scipy.signal import gaussian, convolve, convolve2d # pylint: disable=no-name-in-module from scipy.sparse import coo_matrix from scipy.stats import entropy - +import xarray as xr +from ..data.inference_data import InferenceData from ..utils import conditional_jit from .plot_utils import _scale_fig_size @@ -146,6 +147,13 @@ def plot_kde( figsize, *_, xt_labelsize, linewidth, markersize = _scale_fig_size(figsize, textsize, 1, 1) + if isinstance(values, xr.Dataset): + raise ValueError( + "Xarray dataset object detected.Use plot_posterior, plot_density, plot_joint" + "or plot_pair instead of plot_kde" + ) + if isinstance(values, InferenceData): + raise ValueError(" Inference Data object detected. Use plot_posterior instead of plot_kde") if values2 is None: if plot_kwargs is None: plot_kwargs = {} diff --git a/arviz/plots/posteriorplot.py b/arviz/plots/posteriorplot.py index 60c713504e..a50cca56eb 100644 --- a/arviz/plots/posteriorplot.py +++ b/arviz/plots/posteriorplot.py @@ -358,6 +358,7 @@ def format_axes(): fill_kwargs={"alpha": kwargs.pop("fill_alpha", 0)}, plot_kwargs={"linewidth": linewidth}, ax=ax, + rug=False, ) else: if bins is None: diff --git a/arviz/tests/test_plots.py b/arviz/tests/test_plots.py index f7825c1a99..1142412c63 100644 --- a/arviz/tests/test_plots.py +++ b/arviz/tests/test_plots.py @@ -7,8 +7,7 @@ import pytest import pymc3 as pm - -from ..data import from_dict, from_pymc3 +from ..data import from_dict, from_pymc3, load_arviz_data from ..stats import compare, psislw from .helpers import eight_schools_params, load_cached_models # pylint: disable=unused-import from ..plots import ( @@ -313,6 +312,18 @@ def test_plot_kde_quantiles(continuous_model, kwargs): assert axes +def test_plot_kde_inference_data(): + """ + Ensure that an exception is raised when plot_kde + is used with an inference data or Xarray dataset object. + """ + eight = load_arviz_data("centered_eight") + with pytest.raises(ValueError, match="Inference Data"): + plot_kde(eight) + with pytest.raises(ValueError, match="Xarray"): + plot_kde(eight.posterior) + + def test_plot_khat(): linewidth = np.random.randn(20000, 10) _, khats = psislw(linewidth)