diff --git a/arviz/plots/backends/bokeh/khatplot.py b/arviz/plots/backends/bokeh/khatplot.py index c09ff99f36..a424680a16 100644 --- a/arviz/plots/backends/bokeh/khatplot.py +++ b/arviz/plots/backends/bokeh/khatplot.py @@ -21,7 +21,7 @@ def plot_khat( figsize, xdata, khats, - sample_size, + good_k, kwargs, threshold, coord_labels, @@ -55,10 +55,8 @@ def plot_khat( if hlines_kwargs is None: hlines_kwargs = {} - if sample_size is None: + if good_k is None: good_k = 0.7 - else: - good_k = min(1 - 1 / np.log10(sample_size), 0.7) hlines_kwargs.setdefault("hlines", [0, good_k, 1]) @@ -137,7 +135,7 @@ def plot_khat( xmax = len(khats) if show_bins: - bin_edges = np.array([ymin, threshold, 1, ymax]) + bin_edges = np.array([ymin, good_k, 1, ymax]) bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)] hist, _, _ = histogram(khats, bin_edges) for idx, count in enumerate(hist): diff --git a/arviz/plots/backends/matplotlib/khatplot.py b/arviz/plots/backends/matplotlib/khatplot.py index f6628b3ae6..af30bc832d 100644 --- a/arviz/plots/backends/matplotlib/khatplot.py +++ b/arviz/plots/backends/matplotlib/khatplot.py @@ -20,7 +20,7 @@ def plot_khat( figsize, xdata, khats, - sample_size, + good_k, kwargs, threshold, coord_labels, @@ -62,10 +62,8 @@ def plot_khat( backend_kwargs.setdefault("figsize", figsize) backend_kwargs["squeeze"] = True - if sample_size is None: + if good_k is None: good_k = 0.7 - else: - good_k = min(1 - 1 / np.log10(sample_size), 0.7) hlines_kwargs = matplotlib_kwarg_dealiaser(hlines_kwargs, "hlines") hlines_kwargs.setdefault("hlines", [0, good_k, 1]) @@ -157,7 +155,7 @@ def plot_khat( ) if show_bins: - bin_edges = np.array([ymin, 0.5, 0.7, 1, ymax]) + bin_edges = np.array([ymin, good_k, 1, ymax]) bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)] hist, _, _ = histogram(khats, bin_edges) for idx, count in enumerate(hist): diff --git a/arviz/plots/khatplot.py b/arviz/plots/khatplot.py index e2bc0bbdad..6e4383ae8f 100644 --- a/arviz/plots/khatplot.py +++ b/arviz/plots/khatplot.py @@ -1,6 +1,7 @@ """Pareto tail indices plot.""" import logging +import warnings import numpy as np from xarray import DataArray @@ -163,19 +164,20 @@ def plot_khat( color = "C0" if isinstance(khats, np.ndarray): - _log.warning( + warnings.warn( "support for arrays will be deprecated, please use ELPDData." "The reason for this, is that we need to know the numbers of draws" - "sampled from the posterior" + "sampled from the posterior", + FutureWarning, ) khats = khats.flatten() xlabels = False legend = False dims = [] - sample_size = None + good_k = None else: if isinstance(khats, ELPDData): - sample_size = khats.n_samples + good_k = khats.good_k khats = khats.pareto_k if not isinstance(khats, DataArray): raise ValueError("Incorrect khat data input. Check the documentation") @@ -197,7 +199,7 @@ def plot_khat( figsize=figsize, xdata=xdata, khats=khats, - sample_size=sample_size, + good_k=good_k, kwargs=kwargs, threshold=threshold, coord_labels=coord_labels, diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index f2dc579203..e1484c3e92 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -819,8 +819,17 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): if not pointwise: return ELPDData( - data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale], - index=["elpd_loo", "se", "p_loo", "n_samples", "n_data_points", "warning", "scale"], + data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale, good_k], + index=[ + "elpd_loo", + "se", + "p_loo", + "n_samples", + "n_data_points", + "warning", + "scale", + "good_k", + ], ) if np.equal(loo_lppd, loo_lppd_i).all(): # pylint: disable=no-member warnings.warn( @@ -838,6 +847,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): loo_lppd_i.rename("loo_i"), pareto_shape, scale, + good_k, ], index=[ "elpd_loo", @@ -849,6 +859,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): "loo_i", "pareto_k", "scale", + "good_k", ], ) diff --git a/arviz/stats/stats_utils.py b/arviz/stats/stats_utils.py index 7a5772f920..1f0f8c9892 100644 --- a/arviz/stats/stats_utils.py +++ b/arviz/stats/stats_utils.py @@ -454,10 +454,9 @@ def get_log_likelihood(idata, var_name=None, single_var=True): Pareto k diagnostic values: {{0:>{0}}} {{1:>6}} -(-Inf, 0.5] (good) {{2:{0}d}} {{6:6.1f}}% - (0.5, 0.7] (ok) {{3:{0}d}} {{7:6.1f}}% - (0.7, 1] (bad) {{4:{0}d}} {{8:6.1f}}% - (1, Inf) (very bad) {{5:{0}d}} {{9:6.1f}}% +(-Inf, {{8:.1f}}] (good) {{2:{0}d}} {{5:6.1f}}% + ({{8:.1f}}, 1] (bad) {{3:{0}d}} {{6:6.1f}}% + (1, Inf) (very bad) {{4:{0}d}} {{7:6.1f}}% """ SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"} @@ -488,11 +487,14 @@ def __str__(self): base += "\n\nThere has been a warning during the calculation. Please check the results." if kind == "loo" and "pareto_k" in self: - bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf]) + bins = np.asarray([-np.inf, self.good_k, 1, np.inf]) counts, *_ = _histogram(self.pareto_k.values, bins) extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts))))) extended = extended.format( - "Count", "Pct.", *[*counts, *(counts / np.sum(counts) * 100)] + "Count", + "Pct.", + *[*counts, *(counts / np.sum(counts) * 100)], + self.good_k, ) base = "\n".join([base, extended]) return base