Skip to content

Commit

Permalink
Correctly (re)order dimensions for bfmi and plot_energy (#2126)
Browse files Browse the repository at this point in the history
* Order dimensions as expected by bfmi

* Test dimensions correctly ordered

* Run black

* Update CHANGELOG.md [skip ci]

* Standardize dimensions

* Add test for plot_energy axis ordering

* Update CHANGELOG.md

* Run black
  • Loading branch information
sethaxen authored Oct 2, 2022
1 parent 3c2dc6e commit bc1e887
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
([2096](https://github.com/arviz-devs/arviz/pull/2096) and [2105](https://github.com/arviz-devs/arviz/pull/2105))
* Bokeh kde contour plots started to use `contourpy` package ([2104](https://github.com/arviz-devs/arviz/pull/2104))
* Update default Bokeh markers for rcparams ([2104](https://github.com/arviz-devs/arviz/pull/2104))
* Correctly (re)order dimensions for `bfmi` and `plot_energy` ([2126](https://github.com/arviz-devs/arviz/pull/2126))

### Deprecation
* Removed `fill_last`, `contour` and `plot_kwargs` arguments from `plot_pair` function ([2085](https://github.com/arviz-devs/arviz/pull/2085))
Expand Down
2 changes: 1 addition & 1 deletion arviz/plots/energyplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def plot_energy(
>>> az.plot_energy(data, kind='hist')
"""
energy = convert_to_dataset(data, group="sample_stats").energy.values
energy = convert_to_dataset(data, group="sample_stats").energy.transpose("chain", "draw").values

if kind == "histogram":
warnings.warn(
Expand Down
2 changes: 1 addition & 1 deletion arviz/stats/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def bfmi(data):
dataset = convert_to_dataset(data, group="sample_stats")
if not hasattr(dataset, "energy"):
raise TypeError("Energy variable was not found.")
return _bfmi(dataset.energy)
return _bfmi(dataset.energy.transpose("chain", "draw"))


def ess(
Expand Down
7 changes: 7 additions & 0 deletions arviz/tests/base_tests/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ def test_bfmi_dataset_bad(self):
with pytest.raises(TypeError):
bfmi(data)

def test_bfmi_correctly_transposed(self):
data = load_arviz_data("centered_eight")
vals1 = bfmi(data)
data.sample_stats["energy"] = data.sample_stats["energy"].T
vals2 = bfmi(data)
assert_almost_equal(vals1, vals2)

def test_deterministic(self):
"""
Test algorithm against posterior (R) convergence functions.
Expand Down
8 changes: 8 additions & 0 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,14 @@ def test_plot_energy_bad(models):
plot_energy(models.model_1, kind="bad_kind")


def test_plot_energy_correctly_transposed():
idata = load_arviz_data("centered_eight")
idata.sample_stats["energy"] = idata.sample_stats.energy.T
ax = plot_energy(idata)
# legend has one entry for each KDE and 1 BFMI for each chain
assert len(ax.legend_.texts) == 2 + len(idata.sample_stats.chain)


def test_plot_parallel_raises_valueerror(df_trace): # pylint: disable=invalid-name
with pytest.raises(ValueError):
plot_parallel(df_trace)
Expand Down

0 comments on commit bc1e887

Please sign in to comment.