Skip to content

Commit

Permalink
Use revised Pareto k threshold (#2349)
Browse files Browse the repository at this point in the history
* use revised Pareto k threshold

* avoid duplicated computations

* fix per comments

* fix ValueError and warning

* update changelog
  • Loading branch information
aloctavodia authored Jun 5, 2024
1 parent e9d13bf commit 3a454f7
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 26 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## v0.x.x Unreleased

### New features
- Use revised Pareto k threshold ([2349](https://github.com/arviz-devs/arviz/pull/2349))

### Maintenance and fixes
- Ensure support with numpy 2.0 ([2321](https://github.com/arviz-devs/arviz/pull/2321))
Expand All @@ -12,6 +13,8 @@
- Fix legend overwriting issue in `plot_trace` ([2334](https://github.com/arviz-devs/arviz/pull/2334))

### Deprecation
- Support for arrays and DataArrays in plot_khat has been deprecated. Only ELPDdata will be supported in the future ([2349](https://github.com/arviz-devs/arviz/pull/2349))


### Documentation

Expand Down
11 changes: 8 additions & 3 deletions arviz/plots/backends/bokeh/khatplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def plot_khat(
figsize,
xdata,
khats,
good_k,
kwargs,
threshold,
coord_labels,
Expand Down Expand Up @@ -53,7 +54,11 @@ def plot_khat(

if hlines_kwargs is None:
hlines_kwargs = {}
hlines_kwargs.setdefault("hlines", [0, 0.5, 0.7, 1])

if good_k is None:
good_k = 0.7

hlines_kwargs.setdefault("hlines", [0, good_k, 1])

cmap = None
if isinstance(color, str):
Expand All @@ -75,7 +80,7 @@ def plot_khat(
rgba_c = cmap(color)

khats = khats if isinstance(khats, np.ndarray) else khats.values.flatten()
alphas = 0.5 + 0.2 * (khats > 0.5) + 0.3 * (khats > 1)
alphas = 0.5 + 0.2 * (khats > good_k) + 0.3 * (khats > 1)

rgba_c = vectorized_to_hex(rgba_c)

Expand Down Expand Up @@ -130,7 +135,7 @@ def plot_khat(
xmax = len(khats)

if show_bins:
bin_edges = np.array([ymin, 0.5, 0.7, 1, ymax])
bin_edges = np.array([ymin, good_k, 1, ymax])
bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)]
hist, _, _ = histogram(khats, bin_edges)
for idx, count in enumerate(hist):
Expand Down
10 changes: 7 additions & 3 deletions arviz/plots/backends/matplotlib/khatplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def plot_khat(
figsize,
xdata,
khats,
good_k,
kwargs,
threshold,
coord_labels,
Expand Down Expand Up @@ -61,8 +62,11 @@ def plot_khat(
backend_kwargs.setdefault("figsize", figsize)
backend_kwargs["squeeze"] = True

if good_k is None:
good_k = 0.7

hlines_kwargs = matplotlib_kwarg_dealiaser(hlines_kwargs, "hlines")
hlines_kwargs.setdefault("hlines", [0, 0.5, 0.7, 1])
hlines_kwargs.setdefault("hlines", [0, good_k, 1])
hlines_kwargs.setdefault("linestyle", [":", "-.", "--", "-"])
hlines_kwargs.setdefault("alpha", 0.7)
hlines_kwargs.setdefault("zorder", -1)
Expand Down Expand Up @@ -102,7 +106,7 @@ def plot_khat(
rgba_c = cmap(norm_fun(color))

khats = khats if isinstance(khats, np.ndarray) else khats.values.flatten()
alphas = 0.5 + 0.2 * (khats > 0.5) + 0.3 * (khats > 1)
alphas = 0.5 + 0.2 * (khats > good_k) + 0.3 * (khats > 1)
rgba_c[:, 3] = alphas
rgba_c = vectorized_to_hex(rgba_c)
kwargs["c"] = rgba_c
Expand Down Expand Up @@ -151,7 +155,7 @@ def plot_khat(
)

if show_bins:
bin_edges = np.array([ymin, 0.5, 0.7, 1, ymax])
bin_edges = np.array([ymin, good_k, 1, ymax])
bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)]
hist, _, _ = histogram(khats, bin_edges)
for idx, count in enumerate(hist):
Expand Down
24 changes: 20 additions & 4 deletions arviz/plots/khatplot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Pareto tail indices plot."""

import logging
import warnings

import numpy as np
from xarray import DataArray
Expand Down Expand Up @@ -40,10 +41,8 @@ def plot_khat(
Parameters
----------
khats : ELPDData or array-like
The input Pareto tail indices to be plotted. It can be an ``ELPDData`` object containing
Pareto shapes or an array. In this second case, all the values in the array are interpreted
as Pareto tail indices.
khats : ELPDData
The input Pareto tail indices to be plotted.
color : str or array_like, default "C0"
Colors of the scatter plot, if color is a str all dots will have the same color,
if it is the size of the observations, each dot will have the specified color,
Expand Down Expand Up @@ -165,13 +164,29 @@ def plot_khat(
color = "C0"

if isinstance(khats, np.ndarray):
warnings.warn(
"support for arrays will be deprecated, please use ELPDData."
"The reason for this, is that we need to know the numbers of draws"
"sampled from the posterior",
FutureWarning,
)
khats = khats.flatten()
xlabels = False
legend = False
dims = []
good_k = None
else:
if isinstance(khats, ELPDData):
good_k = khats.good_k
khats = khats.pareto_k
else:
good_k = None
warnings.warn(
"support for DataArrays will be deprecated, please use ELPDData."
"The reason for this, is that we need to know the numbers of draws"
"sampled from the posterior",
FutureWarning,
)
if not isinstance(khats, DataArray):
raise ValueError("Incorrect khat data input. Check the documentation")

Expand All @@ -192,6 +207,7 @@ def plot_khat(
figsize=figsize,
xdata=xdata,
khats=khats,
good_k=good_k,
kwargs=kwargs,
threshold=threshold,
coord_labels=coord_labels,
Expand Down
34 changes: 24 additions & 10 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,8 +715,9 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
se: standard error of the elpd
p_loo: effective number of parameters
shape_warn: bool
True if the estimated shape parameter of
Pareto distribution is greater than 0.7 for one or more samples
True if the estimated shape parameter of Pareto distribution is greater than a thresold
value for one or more samples. For a sample size S, the thresold is compute as
min(1 - 1/log10(S), 0.7)
loo_i: array of pointwise predictive accuracy, only if pointwise True
pareto_k: array of Pareto shape values, only if pointwise True
scale: scale of the elpd
Expand Down Expand Up @@ -785,13 +786,15 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
log_weights += log_likelihood

warn_mg = False
if np.any(pareto_shape > 0.7):
good_k = min(1 - 1 / np.log10(n_samples), 0.7)

if np.any(pareto_shape > good_k):
warnings.warn(
"Estimated shape parameter of Pareto distribution is greater than 0.7 for "
"one or more samples. You should consider using a more robust model, this is because "
"importance sampling is less likely to work well if the marginal posterior and "
"LOO posterior are very different. This is more likely to happen with a non-robust "
"model and highly influential observations."
f"Estimated shape parameter of Pareto distribution is greater than {good_k:.2f} "
"for one or more samples. You should consider using a more robust model, this is "
"because importance sampling is less likely to work well if the marginal posterior "
"and LOO posterior are very different. This is more likely to happen with a "
"non-robust model and highly influential observations."
)
warn_mg = True

Expand All @@ -816,8 +819,17 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):

if not pointwise:
return ELPDData(
data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale],
index=["elpd_loo", "se", "p_loo", "n_samples", "n_data_points", "warning", "scale"],
data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale, good_k],
index=[
"elpd_loo",
"se",
"p_loo",
"n_samples",
"n_data_points",
"warning",
"scale",
"good_k",
],
)
if np.equal(loo_lppd, loo_lppd_i).all(): # pylint: disable=no-member
warnings.warn(
Expand All @@ -835,6 +847,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
loo_lppd_i.rename("loo_i"),
pareto_shape,
scale,
good_k,
],
index=[
"elpd_loo",
Expand All @@ -846,6 +859,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
"loo_i",
"pareto_k",
"scale",
"good_k",
],
)

Expand Down
14 changes: 8 additions & 6 deletions arviz/stats/stats_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,10 +454,9 @@ def get_log_likelihood(idata, var_name=None, single_var=True):
Pareto k diagnostic values:
{{0:>{0}}} {{1:>6}}
(-Inf, 0.5] (good) {{2:{0}d}} {{6:6.1f}}%
(0.5, 0.7] (ok) {{3:{0}d}} {{7:6.1f}}%
(0.7, 1] (bad) {{4:{0}d}} {{8:6.1f}}%
(1, Inf) (very bad) {{5:{0}d}} {{9:6.1f}}%
(-Inf, {{8:.2f}}] (good) {{2:{0}d}} {{5:6.1f}}%
({{8:.2f}}, 1] (bad) {{3:{0}d}} {{6:6.1f}}%
(1, Inf) (very bad) {{4:{0}d}} {{7:6.1f}}%
"""
SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"}

Expand Down Expand Up @@ -488,11 +487,14 @@ def __str__(self):
base += "\n\nThere has been a warning during the calculation. Please check the results."

if kind == "loo" and "pareto_k" in self:
bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
bins = np.asarray([-np.inf, self.good_k, 1, np.inf])
counts, *_ = _histogram(self.pareto_k.values, bins)
extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
extended = extended.format(
"Count", "Pct.", *[*counts, *(counts / np.sum(counts) * 100)]
"Count",
"Pct.",
*[*counts, *(counts / np.sum(counts) * 100)],
self.good_k,
)
base = "\n".join([base, extended])
return base
Expand Down

0 comments on commit 3a454f7

Please sign in to comment.