Skip to content

Commit

Permalink
Update mcse_sd calculation to not use normality assumption. (#2167)
Browse files Browse the repository at this point in the history
* Update mcsd calculation to not use normality assumption.

See https://github.com/stan-dev/posterior/pull/233/files

* Update diagnostics.py

* Remove ddof=1 from mcse_sd

Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com>

* Fix black

* Update arviz/stats/diagnostics.py

Co-authored-by: Seth Axen <seth@sethaxen.com>

* update changelog

* black

* fix function and update tests

* black again

* fix multichain diagnostics version of mcse_sd

* use fix in both ess and mcse

---------

Co-authored-by: Osvaldo A Martin <aloctavodia@gmail.com>
Co-authored-by: Seth Axen <seth@sethaxen.com>
Co-authored-by: Oriol (ProDesk) <oriol.abril.pla@gmail.com>
  • Loading branch information
4 people authored Dec 24, 2024
1 parent 00658ff commit 27bd755
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 66 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
### Maintenance and fixes
- Make `arviz.data.generate_dims_coords` handle `dims` and `default_dims` consistently ([2395](https://github.com/arviz-devs/arviz/pull/2395))
- Only emit a warning for custom groups in `InferenceData` when explicitly requested ([2401](https://github.com/arviz-devs/arviz/pull/2401))
- Update `method="sd"` of `mcse` to not use normality assumption ([2167](https://github.com/arviz-devs/arviz/pull/2167))

### Documentation
- Add example of ECDF comparison plot to gallery ([2178](https://github.com/arviz-devs/arviz/pull/2178))
Expand Down
32 changes: 18 additions & 14 deletions arviz/stats/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,8 +744,8 @@ def _ess_sd(ary, relative=False):
ary = np.asarray(ary)
if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
return np.nan
ary = _split_chains(ary)
return min(_ess(ary, relative=relative), _ess(ary**2, relative=relative))
ary = (ary - ary.mean()) ** 2
return _ess(_split_chains(ary), relative=relative)


def _ess_quantile(ary, prob, relative=False):
Expand Down Expand Up @@ -838,13 +838,15 @@ def _mcse_sd(ary):
ary = np.asarray(ary)
if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
return np.nan
ess = _ess_sd(ary)
sims_c2 = (ary - ary.mean()) ** 2
ess = _ess_mean(sims_c2)
evar = (sims_c2).mean()
varvar = ((sims_c2**2).mean() - evar**2) / ess
varsd = varvar / evar / 4
if _numba_flag:
sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)).item())
mcse_sd_value = float(_sqrt(np.ravel(varsd), np.zeros(1)))
else:
sd = np.std(ary, ddof=1)
fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess) ** (ess - 1) - 1)
mcse_sd_value = sd * fac_mcse_sd
mcse_sd_value = np.sqrt(varsd)
return mcse_sd_value


Expand Down Expand Up @@ -973,19 +975,21 @@ def _multichain_statistics(ary, focus="mean"):
# ess mean
ess_mean_value = _ess_mean(ary)

# ess sd
ess_sd_value = _ess_sd(ary)

# mcse_mean
sd = np.std(ary, ddof=1)
mcse_mean_value = sd / np.sqrt(ess_mean_value)
sims_c2 = (ary - ary.mean()) ** 2
sims_c2_sum = sims_c2.sum()
var = sims_c2_sum / (sims_c2.size - 1)
mcse_mean_value = np.sqrt(var / ess_mean_value)

# ess bulk
ess_bulk_value = _ess(z_split)

# mcse_sd
fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess_sd_value) ** (ess_sd_value - 1) - 1)
mcse_sd_value = sd * fac_mcse_sd
evar = sims_c2_sum / sims_c2.size
ess_mean_sims = _ess_mean(sims_c2)
varvar = ((sims_c2**2).mean() - evar**2) / ess_mean_sims
varsd = varvar / evar / 4
mcse_sd_value = np.sqrt(varsd)

return (
mcse_mean_value,
Expand Down
9 changes: 5 additions & 4 deletions arviz/tests/base_tests/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,11 @@ def test_deterministic(self):
```
Reference file:
Created: 2020-08-31
System: Ubuntu 18.04.5 LTS
R version 4.0.2 (2020-06-22)
posterior 0.1.2
Created: 2024-12-20
System: Ubuntu 24.04.1 LTS
R version 4.4.2 (2024-10-31)
posterior version from https://github.com/stan-dev/posterior/pull/388
(after release 1.6.0 but before the fixes in the PR were released).
"""
# download input files
here = os.path.dirname(os.path.abspath(__file__))
Expand Down
Loading

0 comments on commit 27bd755

Please sign in to comment.