diff --git a/arviz/labels.py b/arviz/labels.py index f6cc4cebd1..f0d34fb004 100644 --- a/arviz/labels.py +++ b/arviz/labels.py @@ -23,6 +23,8 @@ def make_label_vert(self, var_name: str, sel: dict, isel: dict): if not sel: return var_name_str sel_str = self.sel_to_str(sel, isel) + if not var_name: + return sel_str return f"{var_name_str}\n{sel_str}" def make_label_flat(self, var_name: str, sel: dict, isel: dict): @@ -30,17 +32,19 @@ def make_label_flat(self, var_name: str, sel: dict, isel: dict): if not sel: return var_name_str sel_str = self.sel_to_str(sel, isel) - return f"{var_name_str}: {sel_str}" + if not var_name: + return sel_str + return f"{var_name_str}[{sel_str}]" class DimCoordLabeller(BaseLabeller): def dim_coord_to_str(self, dim, coord_val, coord_idx): - return f"{dim}[{coord_val}]" + return f"{dim}: {coord_val}" class DimIdxLabeller(BaseLabeller): def dim_coord_to_str(self, dim, coord_val, coord_idx): - return f"{dim}[{coord_idx}]" + return f"{dim}#{coord_idx}" class MapLabeller(BaseLabeller): @@ -52,7 +56,7 @@ def __init__(self, var_name_map=None, dim_map=None, coord_map=None): def dim_coord_to_str(self, dim, coord_val, coord_idx): dim_str = self.dim_map.get(dim, dim) coord_str = self.coord_map.get(coord_val, coord_val) - return f"{dim_str}[{coord_str}]" + return super().dim_coord_to_str(dim_str, coord_str, coord_idx) def var_name_to_str(self, var_name): return self.var_name_map.get(var_name, var_name) diff --git a/arviz/plots/autocorrplot.py b/arviz/plots/autocorrplot.py index f856264831..3f089ed491 100644 --- a/arviz/plots/autocorrplot.py +++ b/arviz/plots/autocorrplot.py @@ -1,5 +1,6 @@ """Autocorrelation plot of data.""" from ..data import convert_to_dataset +from ..labels import BaseLabeller from ..sel_utils import xarray_var_iter from ..rcparams import rcParams from ..utils import _var_names @@ -15,6 +16,7 @@ def plot_autocorr( grid=None, figsize=None, textsize=None, + labeller=None, ax=None, backend=None, backend_config=None, @@ -53,6 +55,8 @@ def plot_autocorr( textsize: float Text size scaling factor for labels, titles and lines. If None it will be autoscaled based on figsize. + labeller : labeller instance, optional + Class providing the method `make_label_vert` to generate the labels in the plot titles. ax: numpy array-like of matplotlib axes or bokeh figures, optional A 2D array of locations into which to plot the densities. If not supplied, Arviz will create its own array of plot areas (and return it). @@ -111,6 +115,9 @@ def plot_autocorr( if max_lag is None: max_lag = min(100, data["draw"].shape[0]) + if labeller is None: + labeller = BaseLabeller() + plotters = filter_plotters_list( list(xarray_var_iter(data, var_names, combined)), "plot_autocorr" ) @@ -125,6 +132,7 @@ def plot_autocorr( cols=cols, combined=combined, textsize=textsize, + labeller=labeller, backend_kwargs=backend_kwargs, show=show, ) diff --git a/arviz/plots/backends/bokeh/autocorrplot.py b/arviz/plots/backends/bokeh/autocorrplot.py index 5f2e14a68a..927da25080 100644 --- a/arviz/plots/backends/bokeh/autocorrplot.py +++ b/arviz/plots/backends/bokeh/autocorrplot.py @@ -4,7 +4,6 @@ from bokeh.models.annotations import Title -from ....sel_utils import make_label from ....stats import autocorr from ...plot_utils import _scale_fig_size from .. import show_layout @@ -20,6 +19,7 @@ def plot_autocorr( cols, combined, textsize, + labeller, backend_config, backend_kwargs, show, @@ -70,7 +70,7 @@ def plot_autocorr( start=-1, end=1, bounds=backend_config["bounds_y_range"], min_interval=0.1 ) - for (var_name, selection, x), ax in zip( + for (var_name, selection, isel, x), ax in zip( plotters, (item for item in axes.flatten() if item is not None) ): x_prime = x @@ -90,7 +90,7 @@ def plot_autocorr( ) title = Title() - title.text = make_label(var_name, selection) + title.text = labeller.make_label_vert(var_name, selection, isel) ax.title = title ax.x_range = data_range_x ax.y_range = data_range_y diff --git a/arviz/plots/backends/bokeh/essplot.py b/arviz/plots/backends/bokeh/essplot.py index ea340d4f68..94206e63ee 100644 --- a/arviz/plots/backends/bokeh/essplot.py +++ b/arviz/plots/backends/bokeh/essplot.py @@ -8,7 +8,6 @@ from .. import show_layout from . import backend_kwarg_defaults, create_axes_grid from ...plot_utils import _scale_fig_size -from ....sel_utils import make_label def plot_ess( @@ -32,6 +31,7 @@ def plot_ess( n_samples, relative, min_ess, + labeller, ylabel, rug, rug_kind, @@ -62,7 +62,7 @@ def plot_ess( else: ax = np.atleast_2d(ax) - for (var_name, selection, x), ax_ in zip( + for (var_name, selection, isel, x), ax_ in zip( plotters, (item for item in ax.flatten() if item is not None) ): bulk_points = ax_.circle(np.asarray(xdata), np.asarray(x), size=6) @@ -154,7 +154,7 @@ def plot_ess( ax_.legend.click_policy = "hide" title = Title() - title.text = make_label(var_name, selection) + title.text = labeller.make_label_vert(var_name, selection, isel) ax_.title = title ax_.xaxis.axis_label = "Total number of draws" if kind == "evolution" else "Quantile" diff --git a/arviz/plots/backends/matplotlib/autocorrplot.py b/arviz/plots/backends/matplotlib/autocorrplot.py index 47346c1d32..d80b5d38fa 100644 --- a/arviz/plots/backends/matplotlib/autocorrplot.py +++ b/arviz/plots/backends/matplotlib/autocorrplot.py @@ -5,7 +5,6 @@ from ....stats import autocorr from ...plot_utils import _scale_fig_size from . import backend_kwarg_defaults, backend_show, create_axes_grid -from ....sel_utils import make_label def plot_autocorr( @@ -17,6 +16,7 @@ def plot_autocorr( cols, combined, textsize, + labeller, backend_kwargs, show, ): @@ -46,7 +46,7 @@ def plot_autocorr( backend_kwargs=backend_kwargs, ) - for (var_name, selection, x), ax in zip(plotters, np.ravel(axes)): + for (var_name, selection, isel, x), ax in zip(plotters, np.ravel(axes)): x_prime = x if combined: x_prime = x.flatten() @@ -56,7 +56,9 @@ def plot_autocorr( ax.fill_between([0, max_lag], -c_i, c_i, color="0.75") ax.vlines(x=np.arange(0, max_lag), ymin=0, ymax=y[0:max_lag], lw=linewidth) - ax.set_title(make_label(var_name, selection), fontsize=titlesize, wrap=True) + ax.set_title( + labeller.make_label_vert(var_name, selection, isel), fontsize=titlesize, wrap=True + ) ax.tick_params(labelsize=xt_labelsize) if np.asarray(axes).size > 0: diff --git a/arviz/plots/backends/matplotlib/essplot.py b/arviz/plots/backends/matplotlib/essplot.py index c99ccc6bc9..e608d86495 100644 --- a/arviz/plots/backends/matplotlib/essplot.py +++ b/arviz/plots/backends/matplotlib/essplot.py @@ -5,7 +5,6 @@ from ...plot_utils import _scale_fig_size from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser -from ....sel_utils import make_label def plot_ess( @@ -29,6 +28,7 @@ def plot_ess( n_samples, relative, min_ess, + labeller, ylabel, rug, rug_kind, @@ -96,7 +96,7 @@ def plot_ess( backend_kwargs=backend_kwargs, ) - for (var_name, selection, x), ax_ in zip(plotters, np.ravel(ax)): + for (var_name, selection, isel, x), ax_ in zip(plotters, np.ravel(ax)): ax_.plot(xdata, x, **kwargs) if kind == "evolution": ess_tail = ess_tail_dataset[var_name].sel(**selection) @@ -144,7 +144,7 @@ def plot_ess( ax_.axhline(400 / n_samples if relative else min_ess, **hline_kwargs) - ax_.set_title(make_label(var_name, selection), fontsize=titlesize, wrap=True) + ax_.set_title(labeller.make_label_vert(var_name, selection, isel), fontsize=titlesize, wrap=True) ax_.tick_params(labelsize=xt_labelsize) ax_.set_xlabel( "Total number of draws" if kind == "evolution" else "Quantile", fontsize=ax_labelsize diff --git a/arviz/plots/essplot.py b/arviz/plots/essplot.py index 3e423df4a3..a8ecce349a 100644 --- a/arviz/plots/essplot.py +++ b/arviz/plots/essplot.py @@ -3,6 +3,7 @@ import xarray as xr from ..data import convert_to_dataset +from ..labels import BaseLabeller from ..rcparams import rcParams from ..sel_utils import xarray_var_iter from ..stats import ess @@ -25,6 +26,7 @@ def plot_ess( n_points=20, extra_methods=False, min_ess=400, + labeller=None, ax=None, extra_kwargs=None, text_kwargs=None, @@ -75,6 +77,8 @@ def plot_ess( Plot mean and sd ESS as horizontal lines. Not taken into account in evolution kind min_ess: int Minimum number of ESS desired. + labeller : labeller instance, optional + Class providing the method `make_label_vert` to generate the labels in the plot titles. ax: numpy array-like of matplotlib axes or bokeh figures, optional A 2D array of locations into which to plot the densities. If not supplied, Arviz will create its own array of plot areas (and return it). @@ -174,6 +178,8 @@ def plot_ess( coords = {} if "chain" in coords or "draw" in coords: raise ValueError("chain and draw are invalid coordinates for this kind of plot") + if labeller is None: + labeller = BaseLabeller() extra_methods = False if kind == "evolution" else extra_methods data = get_coords(convert_to_dataset(idata, group="posterior"), coords) @@ -274,6 +280,7 @@ def plot_ess( n_samples=n_samples, relative=relative, min_ess=min_ess, + labeller=labeller, ylabel=ylabel, rug=rug, rug_kind=rug_kind, diff --git a/arviz/sel_utils.py b/arviz/sel_utils.py index b0835f3633..432ce6a1fe 100644 --- a/arviz/sel_utils.py +++ b/arviz/sel_utils.py @@ -134,11 +134,13 @@ def xarray_sel_iter(data, var_names=None, combined=False, skip_dims=None, revers new_dims = _dims(data, var_name, skip_dims) vals = [purge_duplicates(data[var_name][dim].values) for dim in new_dims] dims = _zip_dims(new_dims, vals) + idims = _zip_dims(new_dims, [range(len(v)) for v in vals]) if reverse_selections: dims = reversed(dims) + idims = reversed(idims) - for selection in dims: - yield var_name, selection + for selection, iselection in zip(dims, idims): + yield var_name, selection, iselection def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, reverse_selections=False): @@ -174,14 +176,14 @@ def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, revers if var_names is None and isinstance(data, xr.DataArray): data_to_sel = {data.name: data} - for var_name, selection in xarray_sel_iter( + for var_name, selection, iselection in xarray_sel_iter( data, var_names=var_names, combined=combined, skip_dims=skip_dims, reverse_selections=reverse_selections, ): - yield var_name, selection, data_to_sel[var_name].sel(**selection).values + yield var_name, selection, iselection, data_to_sel[var_name].sel(**selection).values def xarray_to_ndarray(data, *, var_names=None, combined=True): @@ -213,12 +215,14 @@ def xarray_to_ndarray(data, *, var_names=None, combined=True): iterator1, iterator2 = tee(xarray_sel_iter(data, var_names=var_names, combined=combined)) vars_and_sel = list(iterator1) - unpacked_var_names = [make_label(var_name, selection) for var_name, selection in vars_and_sel] + unpacked_var_names = [ + make_label(var_name, selection) for var_name, selection, _ in vars_and_sel + ] # Merge chains and variables, check dtype to be compatible with divergences data data0 = data_to_sel[vars_and_sel[0][0]].sel(**vars_and_sel[0][1]) unpacked_data = np.empty((len(unpacked_var_names), data0.size), dtype=data0.dtype) - for idx, (var_name, selection) in enumerate(iterator2): + for idx, (var_name, selection, _) in enumerate(iterator2): unpacked_data[idx] = data_to_sel[var_name].sel(**selection).values.flatten() return unpacked_var_names, unpacked_data diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 7f1a34fcce..27054abe56 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -1132,10 +1132,11 @@ def summary( raise TypeError(f"InferenceData does not contain group: {group}") dataset = data[group] else: - dataset = convert_to_dataset(data, group="posterior") + dataset = get_coords(convert_to_dataset(data, group="posterior"), coords) var_names = _var_names(var_names, dataset, filter_vars) dataset = dataset if var_names is None else dataset[var_names] + fmt_group = ("wide", "long", "xarray") if not isinstance(fmt, str) or (fmt.lower() not in fmt_group): raise TypeError(f"Invalid format: '{fmt}'. Formatting options are: {fmt_group}") @@ -1281,9 +1282,9 @@ def summary( if fmt.lower() == "wide": summary_df = pd.DataFrame(np.full((n_vars, n_metrics), np.nan), columns=metric_names) indexs = [] - for i, (var_name, sel, values) in enumerate(xarray_var_iter(joined, skip_dims={"metric"})): + for i, (var_name, sel, isel, values) in enumerate(xarray_var_iter(joined, skip_dims={"metric"})): summary_df.iloc[i] = values - indexs.append(labeller.make_label_flat(var_name, sel, sel)) + indexs.append(labeller.make_label_flat(var_name, sel, isel)) summary_df.index = indexs elif fmt.lower() == "long": df = joined.to_dataframe().reset_index().set_index("metric")