Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update diagnostics #1366

Merged
merged 15 commits into from
Sep 2, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* Replaced `_fast_kde()` with `kde()` which now also supports circular variables via the argument `circular` ([#1284](https://github.com/arviz-devs/arviz/pull/1284)).
* Increased `from_pystan` attrs information content ([#1353](https://github.com/arviz-devs/arviz/pull/1353))
* Allow `plot_trace` to return and accept axes ([#1361](https://github.com/arviz-devs/arviz/pull/1361))
* Update diagnostics to be on par with posterior package ([#1366](https://github.com/arviz-devs/arviz/pull/1366))

### Deprecation

Expand Down
75 changes: 48 additions & 27 deletions arviz/stats/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from .stats_utils import autocov as _autocov
from .stats_utils import not_valid as _not_valid
from .stats_utils import quantile as _quantile
from .stats_utils import rint as _rint
from .stats_utils import stats_variance_2d as svar
from .stats_utils import wrap_xarray_ufunc as _wrap_xarray_ufunc

Expand Down Expand Up @@ -321,6 +320,7 @@ def mcse(data, *, var_names=None, method="mean", prob=None):
Select mcse method. Valid methods are:
- "mean"
- "sd"
- "median"
- "quantile"

prob : float
Expand All @@ -345,10 +345,15 @@ def mcse(data, *, var_names=None, method="mean", prob=None):

.. ipython::

In [1]: az.mcse(data, method="quantile", prob=.7)
In [1]: az.mcse(data, method="quantile", prob=0.7)

"""
methods = {"mean": _mcse_mean, "sd": _mcse_sd, "quantile": _mcse_quantile}
methods = {
"mean": _mcse_mean,
"sd": _mcse_sd,
"median": _mcse_median,
"quantile": _mcse_quantile,
}
if method not in methods:
raise TypeError(
"mcse method {} not found. Valid methods are:\n{}".format(
Expand Down Expand Up @@ -530,6 +535,29 @@ def _bfmi(energy):
return num / den


def _backtransform_ranks(arr, c=3 / 8):  # pylint: disable=invalid-name
    """Backtransformation of ranks.

    Each rank is mapped to ``(rank - c) / (size - 2 * c + 1)``, where ``size`` is
    the total number of elements in ``arr``.

    Parameters
    ----------
    arr : np.ndarray
        Ranks array
    c : float
        Fractional offset. Defaults to c = 3/8 as recommended by Blom (1958).

    Returns
    -------
    np.ndarray
        Backtransformed ranks, with the same shape as ``arr``.

    References
    ----------
    Blom, G. (1958). Statistical Estimates and Transformed Beta-Variables. Wiley; New York.
    """
    arr = np.asarray(arr)
    # ``size`` counts every element, so multi-dimensional input is treated as one sample.
    size = arr.size
    return (arr - c) / (size - 2 * c + 1)


def _z_scale(ary):
"""Calculate z_scale.

Expand All @@ -542,9 +570,9 @@ def _z_scale(ary):
np.ndarray
"""
ary = np.asarray(ary)
size = ary.size
rank = stats.rankdata(ary, method="average")
z = stats.norm.ppf((rank - 0.5) / size)
rank = _backtransform_ranks(rank)
z = stats.norm.ppf(rank)
z = z.reshape(ary.shape)
return z

Expand Down Expand Up @@ -811,26 +839,6 @@ def _ess_identity(ary, relative=False):
return _ess(ary, relative=relative)


def _conv_quantile(ary, prob):
    """Return the quantile MCSE, two order-statistic bounds and the quantile ESS.

    Parameters
    ----------
    ary : array-like
        Draws; converted with ``np.asarray``.
    prob : float
        Probability passed on to ``_ess_quantile`` and used in the Beta parameters.

    Returns
    -------
    tuple
        ``(mcse, lower, upper, ess)``, labelled "mcse, Q05, Q95, Seff" by the original
        summary line. Four ``np.nan`` values are returned when ``_not_valid`` rejects
        the input.
    """
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return (np.nan,) * 4

    ess = _ess_quantile(ary, prob)
    # The first pair is approximately the standard normal CDF at -1 and +1;
    # the second pair is the 5% / 95% levels.
    levels = [0.1586553, 0.8413447, 0.05, 0.95]
    with np.errstate(invalid="ignore"):
        beta_ppf = stats.beta.ppf(levels, ess * prob + 1, ess * (1 - prob) + 1)

    ordered = np.sort(ary.ravel())
    n_values = ordered.size
    last = n_values - 1
    # Convert Beta quantiles to zero-based positions in the sorted draws.
    positions = beta_ppf * n_values - 1

    # nanmax / nanmin clamp to the valid index range and fall back to the
    # boundary index when the Beta quantile is NaN.
    sd_lower = ordered[_rint(np.nanmax((positions[0], 0)))]
    sd_upper = ordered[_rint(np.nanmin((positions[1], last)))]
    lower = ordered[_rint(np.nanmax((positions[2], 0)))]
    upper = ordered[_rint(np.nanmin((positions[3], last)))]
    return (sd_upper - sd_lower) / 2, lower, upper, ess


def _mcse_mean(ary):
"""Compute the Markov Chain mean error."""
_numba_flag = Numba.numba_flag
Expand Down Expand Up @@ -862,13 +870,26 @@ def _mcse_sd(ary):
return mcse_sd_value


def _mcse_median(ary):
    """Compute the Markov Chain median error.

    Delegates to ``_mcse_quantile`` evaluated at probability 0.5.
    """
    median_prob = 0.5
    return _mcse_quantile(ary, median_prob)


def _mcse_quantile(ary, prob):
    """Compute the Markov Chain quantile error at quantile=prob.

    Parameters
    ----------
    ary : array-like
        Draws; converted with ``np.asarray``.
    prob : float
        Probability of the quantile whose Monte Carlo standard error is computed.

    Returns
    -------
    float
        Half the distance between two order statistics of the flattened draws, or
        ``np.nan`` when ``_not_valid`` rejects the input.
    """
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ess = _ess_quantile(ary, prob)
    # Approximately the standard normal CDF at -1 and +1.
    probability = [0.1586553, 0.8413447]
    with np.errstate(invalid="ignore"):
        ppf = stats.beta.ppf(probability, ess * prob + 1, ess * (1 - prob) + 1)
    sorted_ary = np.sort(ary.ravel())
    size = sorted_ary.size
    # Convert Beta quantiles to zero-based positions in the sorted draws.
    ppf_size = ppf * size - 1
    # Round the lower position down and the upper position up. nanmax / nanmin
    # clamp to the valid index range and fall back to the boundary index when
    # the Beta quantile is NaN.
    th1 = sorted_ary[int(np.floor(np.nanmax((ppf_size[0], 0))))]
    th2 = sorted_ary[int(np.ceil(np.nanmin((ppf_size[1], size - 1))))]
    return (th2 - th1) / 2


def _mc_error(ary, batches=5, circular=False):
Expand Down
6 changes: 0 additions & 6 deletions arviz/stats/stats_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,12 +318,6 @@ def logsumexp(ary, *, b=None, b_inv=None, axis=None, keepdims=False, out=None, c
return out if out.shape else dtype(out)


def rint(num):
    """Round to the nearest integer and return it as a Python ``int``.

    Rounding is delegated to ``np.rint``.
    """
    return int(np.rint(num))


def quantile(ary, q, axis=None, limit=None):
"""Use same quantile function as R (Type 7)."""
if limit is None:
Expand Down
91 changes: 54 additions & 37 deletions arviz/tests/base_tests/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from ...rcparams import rcParams
from ...stats import bfmi, ess, geweke, mcse, rhat
from ...stats.diagnostics import (
_conv_quantile,
_ess,
_ess_quantile,
_mc_error,
_mcse_quantile,
_multichain_statistics,
_rhat,
_rhat_rank,
Expand Down Expand Up @@ -54,63 +54,73 @@ def test_bfmi_dataset_bad(self):

def test_deterministic(self):
"""
Test algorithm against RStan monitor.R functions.
monitor.R :
https://github.com/stan-dev/rstan/blob/425d195565c4d9bcbcb8cccf513e140e6908ca62/rstan/rstan/R/monitor.R
Test algorithm against posterior (R) convergence functions.

posterior: https://github.com/stan-dev/posterior
R code:
```
source('~/monitor.R')
library("posterior")
data2 <- read.csv("blocker.2.csv", comment.char = "#")
data1 <- read.csv("blocker.1.csv", comment.char = "#")
output <- matrix(ncol=15, nrow=length(names(data1))-4)
output <- matrix(ncol=17, nrow=length(names(data1))-4)
j = 0
for (i in 1:length(names(data1))) {
name = names(data1)[i]
ary = matrix(c(data1[,name], data2[,name]), 1000, 2)
if (!endsWith(name, "__"))
j <- j + 1
output[j,] <- c(
rhat(ary),
rhat_rfun(ary),
ess_bulk(ary),
ess_tail(ary),
ess_mean(ary),
ess_sd(ary),
ess_rfun(ary),
ess_quantile(ary, 0.01),
ess_quantile(ary, 0.1),
ess_quantile(ary, 0.3),
mcse_mean(ary),
mcse_sd(ary),
mcse_quantile(ary, prob=0.01),
mcse_quantile(ary, prob=0.1),
mcse_quantile(ary, prob=0.3))
}
name = names(data1)[i]
ary = matrix(c(data1[,name], data2[,name]), 1000, 2)
if (!endsWith(name, "__"))
j <- j + 1
output[j,] <- c(
posterior::rhat(ary),
posterior::rhat_basic(ary, FALSE),
posterior::ess_bulk(ary),
posterior::ess_tail(ary),
posterior::ess_mean(ary),
posterior::ess_sd(ary),
posterior::ess_median(ary),
posterior::ess_basic(ary, FALSE),
posterior::ess_quantile(ary, 0.01),
posterior::ess_quantile(ary, 0.1),
posterior::ess_quantile(ary, 0.3),
posterior::mcse_mean(ary),
posterior::mcse_sd(ary),
posterior::mcse_median(ary),
posterior::mcse_quantile(ary, prob=0.01),
posterior::mcse_quantile(ary, prob=0.1),
posterior::mcse_quantile(ary, prob=0.3))
}
df = data.frame(output, row.names = names(data1)[5:ncol(data1)])
colnames(df) <- c("rhat_rank",
"rhat_raw",
"ess_bulk",
"ess_tail",
"ess_mean",
"ess_sd",
"ess_median",
"ess_raw",
"ess_quantile01",
"ess_quantile10",
"ess_quantile30",
"mcse_mean",
"mcse_sd",
"mcse_median",
"mcse_quantile01",
"mcse_quantile10",
"mcse_quantile30")
write.csv(df, "reference_values.csv")
write.csv(df, "reference_posterior.csv")
```
Reference file:

Created: 2020-08-31
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very detailed

System: Ubuntu 18.04.5 LTS
R version 4.0.2 (2020-06-22)
posterior 0.1.2
"""
# download input files
here = os.path.dirname(os.path.abspath(__file__))
data_directory = os.path.join(here, "..", "saved_models")
path = os.path.join(data_directory, "stan_diagnostics", "blocker.[0-9].csv")
posterior = from_cmdstan(path)
reference_path = os.path.join(data_directory, "stan_diagnostics", "reference_values.csv")
reference_path = os.path.join(data_directory, "stan_diagnostics", "reference_posterior.csv")
reference = pd.read_csv(reference_path, index_col=0).sort_index(axis=1).sort_index(axis=0)
# test arviz functions
funcs = {
Expand All @@ -120,12 +130,14 @@ def test_deterministic(self):
"ess_tail": lambda x: ess(x, method="tail"),
"ess_mean": lambda x: ess(x, method="mean"),
"ess_sd": lambda x: ess(x, method="sd"),
"ess_median": lambda x: ess(x, method="median"),
"ess_raw": lambda x: ess(x, method="identity"),
"ess_quantile01": lambda x: ess(x, method="quantile", prob=0.01),
"ess_quantile10": lambda x: ess(x, method="quantile", prob=0.1),
"ess_quantile30": lambda x: ess(x, method="quantile", prob=0.3),
"mcse_mean": lambda x: mcse(x, method="mean"),
"mcse_sd": lambda x: mcse(x, method="sd"),
"mcse_median": lambda x: mcse(x, method="median"),
"mcse_quantile01": lambda x: mcse(x, method="quantile", prob=0.01),
"mcse_quantile10": lambda x: mcse(x, method="quantile", prob=0.1),
"mcse_quantile30": lambda x: mcse(x, method="quantile", prob=0.3),
Expand All @@ -136,21 +148,26 @@ def test_deterministic(self):
key = key + ".{}".format(list(coord_dict.values())[0] + 1)
results[key] = {func_name: func(vals) for func_name, func in funcs.items()}
arviz_data = pd.DataFrame.from_dict(results).T.sort_index(axis=1).sort_index(axis=0)

# check column names
assert set(arviz_data.columns) == set(reference.columns)

# check parameter names
assert set(arviz_data.index) == set(reference.index)

# check equality (rhat_rank has accuracy < 6e-5, at least with this data, R vs Py)
# this is due to numerical accuracy in calculation leading to rankdata
# function, which scales minimal difference to larger scale
# test first with numpy
assert_array_almost_equal(reference, arviz_data, decimal=4)

# then test manually (more strict)
# NOTE: abs() must wrap the median itself, not the comparison; wrapping the
# boolean result lets any negative median pass regardless of its magnitude.
rhat_rank_diff = reference["rhat_rank"] - arviz_data["rhat_rank"]
assert (abs(rhat_rank_diff) < 6e-5).all(None)
assert (abs(np.median(rhat_rank_diff)) < 1e-8).all(None)

not_rhat = [col for col in reference.columns if col != "rhat_rank"]
not_rhat_diff = reference[not_rhat] - arviz_data[not_rhat]
assert (abs(not_rhat_diff).values < 1e-8).all(None)
assert (abs(np.median(not_rhat_diff)) < 1e-8).all(None)

@pytest.mark.parametrize("method", ("rank", "split", "folded", "z_scale", "identity"))
@pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
Expand Down Expand Up @@ -350,7 +367,7 @@ def test_effective_sample_size_dataset(self, data, method, var_names, relative):
ess_hat = ess(data, var_names=var_names, method=method, relative=relative)
assert np.all(ess_hat.mu.values > n_low) # This might break if the data is regenerated

@pytest.mark.parametrize("mcse_method", ("mean", "sd", "quantile"))
@pytest.mark.parametrize("mcse_method", ("mean", "sd", "median", "quantile"))
def test_mcse_array(self, mcse_method):
if mcse_method == "quantile":
mcse_hat = mcse(np.random.randn(4, 100), method=mcse_method, prob=0.34)
Expand All @@ -362,7 +379,7 @@ def test_mcse_ndarray(self):
with pytest.raises(TypeError):
mcse(np.random.randn(2, 300, 10))

@pytest.mark.parametrize("mcse_method", ("mean", "sd", "quantile"))
@pytest.mark.parametrize("mcse_method", ("mean", "sd", "median", "quantile"))
@pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
def test_mcse_dataset(self, data, mcse_method, var_names):
if mcse_method == "quantile":
Expand All @@ -371,7 +388,7 @@ def test_mcse_dataset(self, data, mcse_method, var_names):
mcse_hat = mcse(data, var_names=var_names, method=mcse_method)
assert mcse_hat # This might break if the data is regenerated

@pytest.mark.parametrize("mcse_method", ("mean", "sd", "quantile"))
@pytest.mark.parametrize("mcse_method", ("mean", "sd", "median", "quantile"))
@pytest.mark.parametrize("chain", (None, 1, 2))
@pytest.mark.parametrize("draw", (1, 2, 3, 4))
@pytest.mark.parametrize("use_nan", (True, False))
Expand Down Expand Up @@ -515,12 +532,12 @@ def test_mc_error_nan(self, size, ndim):
else:
assert np.isnan(_mc_error(x))

@pytest.mark.parametrize("func", ("_conv_quantile", "_z_scale"))
@pytest.mark.parametrize("func", ("_mcse_quantile", "_z_scale"))
def test_nan_behaviour(self, func):
data = np.random.randn(100, 4)
data[0, 0] = np.nan # pylint: disable=unsupported-assignment-operation
if func == "_conv_quantile":
assert np.isnan(_conv_quantile(data, 0.5)).all(None)
if func == "_mcse_quantile":
assert np.isnan(_mcse_quantile(data, 0.5)).all(None)
else:
assert not np.isnan(_z_scale(data)).all(None)
assert not np.isnan(_z_scale(data)).any(None)
Expand Down
Loading