diff --git a/CHANGELOG.md b/CHANGELOG.md index a2a8f05f0d..c1c210dce4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Change Log +## v0.x.x Unreleased + +### New features + +### Maintenance and fixes +- Fix dimension ordering for `plot_trace` with divergences ([2151](https://github.com/arviz-devs/arviz/pull/2151)) + +### Deprecation + +### Documentation + ## v0.13.0 (2022 Oct 22) ### New features diff --git a/arviz/plots/traceplot.py b/arviz/plots/traceplot.py index f6113b948a..f8dbde5ad8 100644 --- a/arviz/plots/traceplot.py +++ b/arviz/plots/traceplot.py @@ -174,7 +174,9 @@ def plot_trace( divergences = "top" if rug else "bottom" if divergences: try: - divergence_data = convert_to_dataset(data, group="sample_stats").diverging + divergence_data = convert_to_dataset(data, group="sample_stats").diverging.transpose( + "chain", "draw" + ) except (ValueError, AttributeError): # No sample_stats, or no `.diverging` divergences = None diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index 0dbf9ecfa2..8fad67e249 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -294,6 +294,12 @@ def test_plot_trace_invalid_varname_warning(models, kwargs): assert axes.shape +def test_plot_trace_diverging_correctly_transposed(): + idata = load_arviz_data("centered_eight") + idata.sample_stats["diverging"] = idata.sample_stats.diverging.T + plot_trace(idata, divergences="bottom") + + @pytest.mark.parametrize( "bad_kwargs", [{"var_names": ["mu", "tau"], "lines": [("mu", {}, ["hey"])]}] )