diff --git a/CHANGELOG.md b/CHANGELOG.md index 80171eb746..e8956ad4d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ * Added confidence interval band to auto-correlation plot ([1535](https://github.com/arviz-devs/arviz/pull/1535)) ### Maintenance and fixes +* Updated `from_cmdstan` and `from_numpyro` converter to follow schema convention ([1541](https://github.com/arviz-devs/arviz/pull/1541) and [1525](https://github.com/arviz-devs/arviz/pull/1525)) ### Deprecation * Removed Geweke diagnostic ([1545](https://github.com/arviz-devs/arviz/pull/1545)) diff --git a/arviz/data/io_cmdstan.py b/arviz/data/io_cmdstan.py index 9b9a350d14..368e3ccf71 100644 --- a/arviz/data/io_cmdstan.py +++ b/arviz/data/io_cmdstan.py @@ -203,22 +203,24 @@ def posterior_to_xarray(self): def sample_stats_to_xarray(self): """Extract sample_stats from fit.""" dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64} + rename_dict = { + "divergent": "diverging", + "n_leapfrog": "n_steps", + "treedepth": "tree_depth", + "stepsize": "step_size", + "accept_stat": "acceptance_rate", + } sampler_params, sampler_params_warmup = self.sample_stats for j, s_params in enumerate(sampler_params): - rename_dict = {} for key in s_params: - key_, *end = key.split(".") - name = re.sub("__$", "", key_) - name = "diverging" if name == "divergent" else name - rename_dict[key] = ".".join((name, *end)) - sampler_params[j][key] = s_params[key].astype(dtypes.get(key_)) - sampler_params_warmup[j][key] = sampler_params_warmup[j][key].astype( - dtypes.get(key_) + name = re.sub("__$", "", key) + name = rename_dict.get(name, name) + sampler_params[j][name] = s_params[key].astype(dtypes.get(key)) + sampler_params_warmup[j][name] = sampler_params_warmup[j][key].astype( + dtypes.get(key) ) - sampler_params[j] = sampler_params[j].rename(columns=rename_dict) - sampler_params_warmup[j] = sampler_params_warmup[j].rename(columns=rename_dict) data = _unpack_dataframes(sampler_params) data_warmup = _unpack_dataframes(sampler_params_warmup) return (