plot_trace errors for non-default order of dimensions #2102

Closed
sethaxen opened this issue Aug 25, 2022 · 4 comments · Fixed by #2103
Comments

@sethaxen
Member

Describe the bug
The default dimension order of variables is (chain, draw, shape...), but as discussed in #1693, all dimension orders should be supported. plot_trace errors when the dimension order is (draw, chain).

To Reproduce

>>> import arviz as az
>>> import xarray as xr
>>> import numpy as np
>>> nchains, ndraws = 4, 100
>>> ds = xr.Dataset(dict(x = (('chain', 'draw'), np.random.normal(size=(nchains, ndraws))))) 
>>> az.plot_trace(ds)
array([[<AxesSubplot:title={'center':'x'}>,
        <AxesSubplot:title={'center':'x'}>]], dtype=object)
>>> ds = xr.Dataset(dict(x = (('draw', 'chain'), np.random.normal(size=(ndraws, nchains)))))
>>> ds
<xarray.Dataset>
Dimensions:  (draw: 100, chain: 4)
Dimensions without coordinates: draw, chain
Data variables:
    x        (draw, chain) float64 0.7368 0.347 1.105 ... 0.9654 -0.98 0.6756
>>> az.plot_trace(ds)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/sethaxen/software/mambaforge/envs/arvizdev/lib/python3.10/site-packages/arviz/plots/traceplot.py", line 260, in plot_trace
    axes = plot(**trace_plot_args)
  File "/home/sethaxen/software/mambaforge/envs/arvizdev/lib/python3.10/site-packages/arviz/plots/backends/matplotlib/traceplot.py", line 248, in plot_trace
    ax = _plot_chains_mpl(
  File "/home/sethaxen/software/mambaforge/envs/arvizdev/lib/python3.10/site-packages/arviz/plots/backends/matplotlib/traceplot.py", line 476, in _plot_chains_mpl
    aux_kwargs = dealiase_sel_kwargs(trace_kwargs, chain_prop, chain_idx)
  File "/home/sethaxen/software/mambaforge/envs/arvizdev/lib/python3.10/site-packages/arviz/plots/backends/matplotlib/__init__.py", line 99, in dealiase_sel_kwargs
    {prop: props[idx] for prop, props in prop_dict.items()}, "plot"
  File "/home/sethaxen/software/mambaforge/envs/arvizdev/lib/python3.10/site-packages/arviz/plots/backends/matplotlib/__init__.py", line 99, in <dictcomp>
    {prop: props[idx] for prop, props in prop_dict.items()}, "plot"
IndexError: list index out of range
@ahartikainen
Contributor

Problem with sel_utils.xarray_var_iter?

@ahartikainen
Contributor

yield var_name, selection, iselection, data_to_sel[var_name].sel(**selection).values

This returns an array whose shape follows the dimension order of the data. The function currently has no way to reorder dimensions. For example, we could add one more step to that loop and call .transpose, controlled by a new keyword dim_order.
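A minimal sketch of that idea, using plain xarray rather than the actual `xarray_var_iter` code. The helper name `iter_values_in_order` and the exact behavior of the `dim_order` keyword are assumptions made for illustration; nothing here is ArviZ's API.

```python
import numpy as np
import xarray as xr


def iter_values_in_order(data, dim_order=("chain", "draw")):
    """Yield (var_name, values) with the dims in ``dim_order`` moved to the front.

    Hypothetical helper for illustration only, not part of ArviZ.
    """
    for var_name, da in data.data_vars.items():
        # Reorder only the dims that exist on this variable; the Ellipsis
        # keeps all remaining dims in their current relative order.
        leading = [dim for dim in dim_order if dim in da.dims]
        yield var_name, da.transpose(*leading, ...).values


nchains, ndraws = 4, 100
# Same non-default (draw, chain) layout as in the report above.
ds = xr.Dataset({"x": (("draw", "chain"), np.random.normal(size=(ndraws, nchains)))})

shapes = {name: values.shape for name, values in iter_values_in_order(ds)}
print(shapes)  # {'x': (4, 100)}
```

Transposing by dimension name means downstream code can keep assuming that axis 0 is chain and axis 1 is draw, whatever order the input used.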

@sethaxen
Member Author

Quite a few other functions use sel_utils.xarray_var_iter, and I don't think all of them have this issue.

A few more functions that do the wrong thing with dimensions are bfmi (if draw and chain are swapped, it interprets draws as chains; as a consequence, so does the legend of plot_energy), and plot_rank (errors).
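To illustrate how a swapped layout can go wrong silently instead of raising, here is a sketch with a simplified BFMI-style statistic that assumes axis 0 is chain and axis 1 is draw. `bfmi_like` is a stand-in written for this example, not ArviZ's implementation.

```python
import numpy as np


def bfmi_like(energy):
    """BFMI-style statistic per row, assuming shape (chain, draw).

    Simplified stand-in for illustration; not ArviZ's code.
    """
    energy = np.atleast_2d(energy)
    # Mean squared difference between consecutive draws within each row ...
    num = np.square(np.diff(energy, axis=1)).mean(axis=1)
    # ... divided by the sample variance of each row.
    den = np.var(energy, axis=1, ddof=1)
    return num / den


rng = np.random.default_rng(0)
nchains, ndraws = 4, 100
energy = rng.normal(size=(nchains, ndraws))

right = bfmi_like(energy)    # one value per chain
wrong = bfmi_like(energy.T)  # (draw, chain) input: every draw is treated as a chain
print(right.shape, wrong.shape)  # (4,) (100,)
```

No exception is raised in the swapped case; the result simply has one entry per draw, which matches the symptom described for bfmi and the plot_energy legend.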

@ahartikainen
Contributor

I mean that the function returns values in their original order, so we could add the option to change the order there (we still have the dims).
