diff --git a/CHANGELOG.md b/CHANGELOG.md index d160f32122..dd242c1986 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ * Bokeh kde contour plots started to use `contourpy` package ([2104](https://github.com/arviz-devs/arviz/pull/2104)) * Update default Bokeh markers for rcparams ([2104](https://github.com/arviz-devs/arviz/pull/2104)) * Correctly (re)order dimensions for `bfmi` and `plot_energy` ([2126](https://github.com/arviz-devs/arviz/pull/2126)) +* Fix dependency on the order of the `chain` and `draw` dimensions in `plot_autocorr`, `plot_rank` and `plot_trace` ([2103](https://github.com/arviz-devs/arviz/pull/2103)) ### Deprecation * Removed `fill_last`, `contour` and `plot_kwargs` arguments from `plot_pair` function ([2085](https://github.com/arviz-devs/arviz/pull/2085)) diff --git a/arviz/plots/autocorrplot.py b/arviz/plots/autocorrplot.py index 91ece635c3..dd557cef0a 100644 --- a/arviz/plots/autocorrplot.py +++ b/arviz/plots/autocorrplot.py @@ -127,7 +127,8 @@ def plot_autocorr( labeller = BaseLabeller() plotters = filter_plotters_list( - list(xarray_var_iter(data, var_names, combined)), "plot_autocorr" + list(xarray_var_iter(data, var_names, combined, dim_order=["chain", "draw"])), + "plot_autocorr", ) rows, cols = default_grid(len(plotters), grid=grid) diff --git a/arviz/plots/rankplot.py b/arviz/plots/rankplot.py index 0b5eb97c99..eeefe67ca4 100644 --- a/arviz/plots/rankplot.py +++ b/arviz/plots/rankplot.py @@ -174,6 +174,7 @@ def plot_rank( posterior_data, var_names=var_names, combined=True, + dim_order=["chain", "draw"], ) ), "plot_rank", diff --git a/arviz/plots/traceplot.py b/arviz/plots/traceplot.py index f02356a4ae..f6113b948a 100644 --- a/arviz/plots/traceplot.py +++ b/arviz/plots/traceplot.py @@ -201,7 +201,13 @@ def plot_trace( skip_dims = set(coords_data.dims) - {"chain", "draw"} if compact else set() plotters = list( - xarray_var_iter(coords_data, var_names=var_names, combined=True, skip_dims=skip_dims) + xarray_var_iter( + coords_data, + var_names=var_names, + combined=True, + skip_dims=skip_dims, + 
dim_order=["chain", "draw"], + ) ) max_plots = rcParams["plot.max_subplots"] max_plots = len(plotters) if max_plots is None else max(max_plots // 2, 1) diff --git a/arviz/sel_utils.py b/arviz/sel_utils.py index 11773d7967..4c041f13b2 100644 --- a/arviz/sel_utils.py +++ b/arviz/sel_utils.py @@ -55,31 +55,6 @@ def make_label(var_name, selection, position="below"): return base.format(var_name, sel) -def purge_duplicates(list_in): - """Remove duplicates from list while preserving order. - - Parameters - ---------- - list_in: Iterable - - Returns - ------- - list - List of first occurrences in order - """ - # Algorithm taken from Stack Overflow, - # https://stackoverflow.com/questions/480214. Content by Georgy - # Skorobogatov (https://stackoverflow.com/users/7851470/georgy) and - # Markus Jarderot - # (https://stackoverflow.com/users/22364/markus-jarderot), licensed - # under CC-BY-SA 4.0. - # https://creativecommons.org/licenses/by-sa/4.0/. - - seen = set() - seen_add = seen.add - return [x for x in list_in if not (x in seen or seen_add(x))] - - def _dims(data, var_name, skip_dims): return [dim for dim in data[var_name].dims if dim not in skip_dims] @@ -136,7 +111,7 @@ def xarray_sel_iter(data, var_names=None, combined=False, skip_dims=None, revers for var_name in var_names: if var_name in data: new_dims = _dims(data, var_name, skip_dims) - vals = [purge_duplicates(data[var_name][dim].values) for dim in new_dims] + vals = [list(dict.fromkeys(data[var_name][dim].values)) for dim in new_dims] dims = _zip_dims(new_dims, vals) idims = _zip_dims(new_dims, [range(len(v)) for v in vals]) if reverse_selections: @@ -147,7 +122,9 @@ def xarray_sel_iter(data, var_names=None, combined=False, skip_dims=None, revers yield var_name, selection, iselection -def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, reverse_selections=False): +def xarray_var_iter( + data, var_names=None, combined=False, skip_dims=None, reverse_selections=False, dim_order=None +): 
"""Convert xarray data to an iterator over vectors. Iterates over each var_name and all of its coordinates, returning the 1d @@ -170,6 +147,9 @@ def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, revers reverse_selections : bool Whether to reverse selections before iterating. + dim_order: list + Order for the first dimensions. Skips dimensions not found in the variable. + Returns ------- Iterator of (str, dict(str, any), np.array) @@ -180,6 +160,9 @@ def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, revers if var_names is None and isinstance(data, xr.DataArray): data_to_sel = {data.name: data} + if isinstance(dim_order, str): + dim_order = [dim_order] + for var_name, selection, iselection in xarray_sel_iter( data, var_names=var_names, @@ -187,7 +170,12 @@ def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, revers skip_dims=skip_dims, reverse_selections=reverse_selections, ): - yield var_name, selection, iselection, data_to_sel[var_name].sel(**selection).values + selected_data = data_to_sel[var_name].sel(**selection) + if dim_order is not None: + dim_order_selected = [dim for dim in dim_order if dim in selected_data.dims] + if dim_order_selected: + selected_data = selected_data.transpose(*dim_order_selected, ...) 
+ yield var_name, selection, iselection, selected_data.values def xarray_to_ndarray(data, *, var_names=None, combined=True, label_fun=None): diff --git a/arviz/tests/helpers.py b/arviz/tests/helpers.py index 9cc79157cf..23c198071f 100644 --- a/arviz/tests/helpers.py +++ b/arviz/tests/helpers.py @@ -1,4 +1,4 @@ -# pylint: disable=redefined-outer-name, comparison-with-callable +# pylint: disable=redefined-outer-name, comparison-with-callable, protected-access """Test helper functions.""" import gzip import importlib @@ -51,7 +51,7 @@ def chains(): return 2 -def create_model(seed=10): +def create_model(seed=10, transpose=False): """Create model with fake data.""" np.random.seed(seed) nchains = 4 @@ -104,10 +104,15 @@ def create_model(seed=10): }, coords={"obs_dim": range(data["J"])}, ) + if transpose: + for group in model._groups: + group_dataset = getattr(model, group) + if all(dim in group_dataset.dims for dim in ("draw", "chain")): + setattr(model, group, group_dataset.transpose(*["draw", "chain"], ...)) return model -def create_multidimensional_model(seed=10): +def create_multidimensional_model(seed=10, transpose=False): """Create model with fake data.""" np.random.seed(seed) nchains = 4 @@ -155,6 +160,11 @@ def create_multidimensional_model(seed=10): dims={"y": ["dim1", "dim2"], "log_likelihood": ["dim1", "dim2"]}, coords={"dim1": range(ndim1), "dim2": range(ndim2)}, ) + if transpose: + for group in model._groups: + group_dataset = getattr(model, group) + if all(dim in group_dataset.dims for dim in ("draw", "chain")): + setattr(model, group, group_dataset.transpose(*["draw", "chain"], ...)) return model @@ -195,7 +205,7 @@ def models(): class Models: model_1 = create_model(seed=10) - model_2 = create_model(seed=11) + model_2 = create_model(seed=11, transpose=True) return Models() @@ -207,7 +217,7 @@ def multidim_models(): class Models: model_1 = create_multidimensional_model(seed=10) - model_2 = create_multidimensional_model(seed=11) + model_2 = 
create_multidimensional_model(seed=11, transpose=True) return Models()