Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable dimension order in selection #2103

Merged
merged 13 commits into from
Oct 6, 2022
8 changes: 7 additions & 1 deletion arviz/plots/traceplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,13 @@ def plot_trace(
skip_dims = set(coords_data.dims) - {"chain", "draw"} if compact else set()

plotters = list(
xarray_var_iter(coords_data, var_names=var_names, combined=True, skip_dims=skip_dims)
xarray_var_iter(
coords_data,
var_names=var_names,
combined=True,
skip_dims=skip_dims,
dim_order=["chain", "draw"],
)
)
max_plots = rcParams["plot.max_subplots"]
max_plots = len(plotters) if max_plots is None else max(max_plots // 2, 1)
Expand Down
44 changes: 16 additions & 28 deletions arviz/sel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,31 +55,6 @@ def make_label(var_name, selection, position="below"):
return base.format(var_name, sel)


def purge_duplicates(list_in):
    """Remove duplicates from list while preserving order.

    Parameters
    ----------
    list_in: Iterable
        Items to de-duplicate. Every item must be hashable.

    Returns
    -------
    list
        List of first occurrences in order
    """
    # Dict keys are unique and keep insertion order, so building a dict from
    # the items and reading its keys back yields each first occurrence once.
    return list(dict.fromkeys(list_in))


def _dims(data, var_name, skip_dims):
return [dim for dim in data[var_name].dims if dim not in skip_dims]

Expand Down Expand Up @@ -136,7 +111,7 @@ def xarray_sel_iter(data, var_names=None, combined=False, skip_dims=None, revers
for var_name in var_names:
if var_name in data:
new_dims = _dims(data, var_name, skip_dims)
vals = [purge_duplicates(data[var_name][dim].values) for dim in new_dims]
vals = [list(dict.fromkeys(data[var_name][dim].values)) for dim in new_dims]
dims = _zip_dims(new_dims, vals)
idims = _zip_dims(new_dims, [range(len(v)) for v in vals])
if reverse_selections:
Expand All @@ -147,7 +122,9 @@ def xarray_sel_iter(data, var_names=None, combined=False, skip_dims=None, revers
yield var_name, selection, iselection


def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, reverse_selections=False):
def xarray_var_iter(
data, var_names=None, combined=False, skip_dims=None, reverse_selections=False, dim_order=None
):
"""Convert xarray data to an iterator over vectors.

Iterates over each var_name and all of its coordinates, returning the 1d
Expand All @@ -170,6 +147,9 @@ def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, revers
reverse_selections : bool
Whether to reverse selections before iterating.

dim_order: list
Order for the first dimensions. Skips dimensions not found in the variable.

Returns
-------
Iterator of (str, dict(str, any), np.array)
Expand All @@ -180,14 +160,22 @@ def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, revers
if var_names is None and isinstance(data, xr.DataArray):
data_to_sel = {data.name: data}

if isinstance(dim_order, str):
dim_order = [dim_order]

for var_name, selection, iselection in xarray_sel_iter(
data,
var_names=var_names,
combined=combined,
skip_dims=skip_dims,
reverse_selections=reverse_selections,
):
yield var_name, selection, iselection, data_to_sel[var_name].sel(**selection).values
selected_data = data_to_sel[var_name].sel(**selection)
if dim_order is not None:
dim_order_selected = [dim for dim in dim_order if dim in selected_data.dims]
if dim_order_selected:
selected_data = selected_data.transpose(*dim_order, ...)
ahartikainen marked this conversation as resolved.
Show resolved Hide resolved
yield var_name, selection, iselection, selected_data.values


def xarray_to_ndarray(data, *, var_names=None, combined=True, label_fun=None):
Expand Down