From 99dd21959aac0e9152ec8f8d933f925001accf37 Mon Sep 17 00:00:00 2001 From: Aniruddha Banerjea <29407246+Ban-zee@users.noreply.github.com> Date: Fri, 8 Mar 2019 01:13:04 +0530 Subject: [PATCH] Addressing issue number #470 (using plot_kde with inference data object) (#600) * Made the following additions: * Using inference data or xr.dataset with plot_kde now raises an error prompting user to use plot_posterior instead. * Rug plots displayed by default for plot_posterior plots. * Not included the tests yet. * Added the test * Made the suggested changes. * fixed load_arviz_data import * Made the following additions: * Using inference data or xr.dataset with plot_kde now raises an error prompting user to use plot_posterior instead. * Rug plots displayed by default for plot_posterior plots. * Not included the tests yet. * Added the test * Made the suggested changes. * fixed load_arviz_data import * Made the suggested changes. * Added a docstring to my test. --- arviz/plots/kdeplot.py | 10 +++++++++- arviz/plots/posteriorplot.py | 1 + arviz/tests/test_plots.py | 15 +++++++++++++-- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/arviz/plots/kdeplot.py b/arviz/plots/kdeplot.py index a0a3fbe989..65c3eea429 100644 --- a/arviz/plots/kdeplot.py +++ b/arviz/plots/kdeplot.py @@ -4,7 +4,8 @@ from scipy.signal import gaussian, convolve, convolve2d # pylint: disable=no-name-in-module from scipy.sparse import coo_matrix from scipy.stats import entropy - +import xarray as xr +from ..data.inference_data import InferenceData from ..utils import conditional_jit from .plot_utils import _scale_fig_size @@ -146,6 +147,13 @@ def plot_kde( figsize, *_, xt_labelsize, linewidth, markersize = _scale_fig_size(figsize, textsize, 1, 1) + if isinstance(values, xr.Dataset): + raise ValueError( + "Xarray dataset object detected.Use plot_posterior, plot_density, plot_joint" + "or plot_pair instead of plot_kde" + ) + if isinstance(values, InferenceData): + raise ValueError(" Inference Data object detected. Use plot_posterior instead of plot_kde") if values2 is None: if plot_kwargs is None: plot_kwargs = {} diff --git a/arviz/plots/posteriorplot.py b/arviz/plots/posteriorplot.py index 60c713504e..a50cca56eb 100644 --- a/arviz/plots/posteriorplot.py +++ b/arviz/plots/posteriorplot.py @@ -358,6 +358,7 @@ def format_axes(): fill_kwargs={"alpha": kwargs.pop("fill_alpha", 0)}, plot_kwargs={"linewidth": linewidth}, ax=ax, + rug=False, ) else: if bins is None: diff --git a/arviz/tests/test_plots.py b/arviz/tests/test_plots.py index f7825c1a99..1142412c63 100644 --- a/arviz/tests/test_plots.py +++ b/arviz/tests/test_plots.py @@ -7,8 +7,7 @@ import pytest import pymc3 as pm - -from ..data import from_dict, from_pymc3 +from ..data import from_dict, from_pymc3, load_arviz_data from ..stats import compare, psislw from .helpers import eight_schools_params, load_cached_models # pylint: disable=unused-import from ..plots import ( @@ -313,6 +312,18 @@ def test_plot_kde_quantiles(continuous_model, kwargs): assert axes +def test_plot_kde_inference_data(): + """ + Ensure that an exception is raised when plot_kde + is used with an inference data or Xarray dataset object. + """ + eight = load_arviz_data("centered_eight") + with pytest.raises(ValueError, match="Inference Data"): + plot_kde(eight) + with pytest.raises(ValueError, match="Xarray"): + plot_kde(eight.posterior) + + def test_plot_khat(): linewidth = np.random.randn(20000, 10) _, khats = psislw(linewidth)