Skip to content

Commit

Permalink
avoid duplicated computations
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed May 28, 2024
1 parent 41e67e0 commit f7dc5ac
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 23 deletions.
8 changes: 3 additions & 5 deletions arviz/plots/backends/bokeh/khatplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def plot_khat(
figsize,
xdata,
khats,
sample_size,
good_k,
kwargs,
threshold,
coord_labels,
Expand Down Expand Up @@ -55,10 +55,8 @@ def plot_khat(
if hlines_kwargs is None:
hlines_kwargs = {}

if sample_size is None:
if good_k is None:
good_k = 0.7
else:
good_k = min(1 - 1 / np.log10(sample_size), 0.7)

hlines_kwargs.setdefault("hlines", [0, good_k, 1])

Expand Down Expand Up @@ -137,7 +135,7 @@ def plot_khat(
xmax = len(khats)

if show_bins:
bin_edges = np.array([ymin, threshold, 1, ymax])
bin_edges = np.array([ymin, good_k, 1, ymax])
bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)]
hist, _, _ = histogram(khats, bin_edges)
for idx, count in enumerate(hist):
Expand Down
8 changes: 3 additions & 5 deletions arviz/plots/backends/matplotlib/khatplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def plot_khat(
figsize,
xdata,
khats,
sample_size,
good_k,
kwargs,
threshold,
coord_labels,
Expand Down Expand Up @@ -62,10 +62,8 @@ def plot_khat(
backend_kwargs.setdefault("figsize", figsize)
backend_kwargs["squeeze"] = True

if sample_size is None:
if good_k is None:
good_k = 0.7
else:
good_k = min(1 - 1 / np.log10(sample_size), 0.7)

hlines_kwargs = matplotlib_kwarg_dealiaser(hlines_kwargs, "hlines")
hlines_kwargs.setdefault("hlines", [0, good_k, 1])
Expand Down Expand Up @@ -157,7 +155,7 @@ def plot_khat(
)

if show_bins:
bin_edges = np.array([ymin, 0.5, 0.7, 1, ymax])
bin_edges = np.array([ymin, good_k, 1, ymax])
bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)]
hist, _, _ = histogram(khats, bin_edges)
for idx, count in enumerate(hist):
Expand Down
12 changes: 7 additions & 5 deletions arviz/plots/khatplot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Pareto tail indices plot."""

import logging
import warnings

import numpy as np
from xarray import DataArray
Expand Down Expand Up @@ -163,19 +164,20 @@ def plot_khat(
color = "C0"

if isinstance(khats, np.ndarray):
_log.warning(
warnings.warn(
"support for arrays will be deprecated, please use ELPDData."
"The reason for this, is that we need to know the numbers of draws"
"sampled from the posterior"
"sampled from the posterior",
FutureWarning,
)
khats = khats.flatten()
xlabels = False
legend = False
dims = []
sample_size = None
good_k = None
else:
if isinstance(khats, ELPDData):
sample_size = khats.n_samples
good_k = khats.good_k
khats = khats.pareto_k
if not isinstance(khats, DataArray):
raise ValueError("Incorrect khat data input. Check the documentation")
Expand All @@ -197,7 +199,7 @@ def plot_khat(
figsize=figsize,
xdata=xdata,
khats=khats,
sample_size=sample_size,
good_k=good_k,
kwargs=kwargs,
threshold=threshold,
coord_labels=coord_labels,
Expand Down
15 changes: 13 additions & 2 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,8 +819,17 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):

if not pointwise:
return ELPDData(
data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale],
index=["elpd_loo", "se", "p_loo", "n_samples", "n_data_points", "warning", "scale"],
data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale, good_k],
index=[
"elpd_loo",
"se",
"p_loo",
"n_samples",
"n_data_points",
"warning",
"scale",
"good_k",
],
)
if np.equal(loo_lppd, loo_lppd_i).all(): # pylint: disable=no-member
warnings.warn(
Expand All @@ -838,6 +847,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
loo_lppd_i.rename("loo_i"),
pareto_shape,
scale,
good_k,
],
index=[
"elpd_loo",
Expand All @@ -849,6 +859,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
"loo_i",
"pareto_k",
"scale",
"good_k",
],
)

Expand Down
14 changes: 8 additions & 6 deletions arviz/stats/stats_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,10 +454,9 @@ def get_log_likelihood(idata, var_name=None, single_var=True):
Pareto k diagnostic values:
{{0:>{0}}} {{1:>6}}
(-Inf, 0.5] (good) {{2:{0}d}} {{6:6.1f}}%
(0.5, 0.7] (ok) {{3:{0}d}} {{7:6.1f}}%
(0.7, 1] (bad) {{4:{0}d}} {{8:6.1f}}%
(1, Inf) (very bad) {{5:{0}d}} {{9:6.1f}}%
(-Inf, {{8:.1f}}] (good) {{2:{0}d}} {{5:6.1f}}%
({{8:.1f}}, 1] (bad) {{3:{0}d}} {{6:6.1f}}%
(1, Inf) (very bad) {{4:{0}d}} {{7:6.1f}}%
"""
SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"}

Expand Down Expand Up @@ -488,11 +487,14 @@ def __str__(self):
base += "\n\nThere has been a warning during the calculation. Please check the results."

if kind == "loo" and "pareto_k" in self:
bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
bins = np.asarray([-np.inf, self.good_k, 1, np.inf])
counts, *_ = _histogram(self.pareto_k.values, bins)
extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
extended = extended.format(
"Count", "Pct.", *[*counts, *(counts / np.sum(counts) * 100)]
"Count",
"Pct.",
*[*counts, *(counts / np.sum(counts) * 100)],
self.good_k,
)
base = "\n".join([base, extended])
return base
Expand Down

0 comments on commit f7dc5ac

Please sign in to comment.