diff --git a/CHANGELOG.md b/CHANGELOG.md index 07a2fd5998..59fa526c1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ ### New features +- Bayes Factor plot: Use arviz's kde instead of the one from scipy ([2237](https://github.com/arviz-devs/arviz/pull/2237)) + ### Maintenance and fixes ### Deprecation diff --git a/arviz/plots/backends/matplotlib/bfplot.py b/arviz/plots/backends/matplotlib/bfplot.py index 824d6a3512..8d3201334f 100644 --- a/arviz/plots/backends/matplotlib/bfplot.py +++ b/arviz/plots/backends/matplotlib/bfplot.py @@ -1,5 +1,4 @@ import matplotlib.pyplot as plt -import numpy as np from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser from ...distplot import plot_dist @@ -10,11 +9,8 @@ def plot_bf( ax, bf_10, bf_01, - xlim, prior, posterior, - prior_pdf, - posterior_pdf, ref_val, prior_at_ref_val, posterior_at_ref_val, @@ -52,14 +48,22 @@ def plot_bf( if ax is None: _, ax = create_axes_grid(1, backend_kwargs=backend_kwargs) - x = np.linspace(*xlim, 5000) - - if posterior.dtype.kind == "f": - ax.plot(x, prior_pdf(x), color=colors[0], label="Prior", **plot_kwargs) - ax.plot(x, posterior_pdf(x), color=colors[1], label="Posterior", **plot_kwargs) - elif posterior.dtype.kind == "i": - plot_dist(prior, color=colors[0], label="Prior", ax=ax, hist_kwargs=hist_kwargs) - plot_dist(posterior, color=colors[1], label="Posterior", ax=ax, hist_kwargs=hist_kwargs) + plot_dist( + prior, + color=colors[0], + label="Prior", + ax=ax, + plot_kwargs=plot_kwargs, + hist_kwargs=hist_kwargs, + ) + plot_dist( + posterior, + color=colors[1], + label="Posterior", + ax=ax, + plot_kwargs=plot_kwargs, + hist_kwargs=hist_kwargs, + ) ax.plot(ref_val, posterior_at_ref_val, "ko", lw=1.5) ax.plot(ref_val, prior_at_ref_val, "ko", lw=1.5) diff --git a/arviz/plots/bfplot.py b/arviz/plots/bfplot.py index 3dd2e0ca2e..bc1d1899f6 100644 --- a/arviz/plots/bfplot.py +++ b/arviz/plots/bfplot.py @@ -1,10 +1,12 @@ # Plotting and reporting Bayes Factor given idata, var name, prior distribution and reference value +# pylint: disable=unbalanced-tuple-unpacking import logging -from scipy.stats import gaussian_kde +from numpy import interp from ..data.utils import extract from .plot_utils import get_plotting_function +from ..stats.density_utils import _kde_linear _log = logging.getLogger(__name__) @@ -14,7 +16,6 @@ def plot_bf( var_name, prior=None, ref_val=0, - xlim=None, colors=("C0", "C1"), figsize=None, textsize=None, @@ -47,8 +48,6 @@ def plot_bf( In case we want to use different prior, for example for sensitivity analysis. ref_val : int, default 0 Point-null for Bayes factor estimation. - xlim : tuple, optional - Set the x limits, which might be used for visualization purposes. colors : tuple, default ('C0', 'C1') Tuple of valid Matplotlib colors. First element for the prior, second for the posterior. figsize : (float, float), optional @@ -94,7 +93,7 @@ def plot_bf( ... prior={"a":np.random.normal(0, 1, 5000)}) >>> az.plot_bf(idata, var_name="a", ref_val=0) """ - posterior = extract(idata, var_names=var_name) + posterior = extract(idata, var_names=var_name).values if ref_val > posterior.max() or ref_val < posterior.min(): _log.warning( @@ -106,21 +105,15 @@ def plot_bf( _log.warning("Posterior distribution has {posterior.ndim} dimensions") if prior is None: - prior = extract(idata, var_names=var_name, group="prior") - - if xlim is None: - xlim = (prior.min(), prior.max()) + prior = extract(idata, var_names=var_name, group="prior").values if posterior.dtype.kind == "f": - posterior_pdf = gaussian_kde(posterior) - prior_pdf = gaussian_kde(prior) - - posterior_at_ref_val = posterior_pdf(ref_val) - prior_at_ref_val = prior_pdf(ref_val) + posterior_grid, posterior_pdf = _kde_linear(posterior) + prior_grid, prior_pdf = _kde_linear(prior) + posterior_at_ref_val = interp(ref_val, posterior_grid, posterior_pdf) + prior_at_ref_val = interp(ref_val, prior_grid, prior_pdf) elif posterior.dtype.kind == "i": - prior_pdf = None - posterior_pdf = None posterior_at_ref_val = (posterior == ref_val).mean() prior_at_ref_val = (prior == ref_val).mean() @@ -131,11 +124,8 @@ def plot_bf( ax=ax, bf_10=bf_10.item(), bf_01=bf_01.item(), - xlim=xlim, prior=prior, posterior=posterior, - prior_pdf=prior_pdf, - posterior_pdf=posterior_pdf, ref_val=ref_val, prior_at_ref_val=prior_at_ref_val, posterior_at_ref_val=posterior_at_ref_val,