Skip to content

Commit

Permalink
update ess and autocorr plots
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Feb 18, 2021
1 parent 9e9aa04 commit bf05f4e
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 25 deletions.
12 changes: 8 additions & 4 deletions arviz/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,28 @@ def make_label_vert(self, var_name: str, sel: dict, isel: dict):
if not sel:
return var_name_str
sel_str = self.sel_to_str(sel, isel)
if not var_name:
return sel_str
return f"{var_name_str}\n{sel_str}"

def make_label_flat(self, var_name: str, sel: dict, isel: dict):
var_name_str = self.var_name_to_str(var_name)
if not sel:
return var_name_str
sel_str = self.sel_to_str(sel, isel)
return f"{var_name_str}: {sel_str}"
if not var_name:
return sel_str
return f"{var_name_str}[{sel_str}]"


class DimCoordLabeller(BaseLabeller):
def dim_coord_to_str(self, dim, coord_val, coord_idx):
return f"{dim}[{coord_val}]"
return f"{dim}: {coord_val}"


class DimIdxLabeller(BaseLabeller):
def dim_coord_to_str(self, dim, coord_val, coord_idx):
return f"{dim}[{coord_idx}]"
return f"{dim}#{coord_idx}"


class MapLabeller(BaseLabeller):
Expand All @@ -52,7 +56,7 @@ def __init__(self, var_name_map=None, dim_map=None, coord_map=None):
def dim_coord_to_str(self, dim, coord_val, coord_idx):
dim_str = self.dim_map.get(dim, dim)
coord_str = self.coord_map.get(coord_val, coord_val)
return f"{dim_str}[{coord_str}]"
return super().dim_coord_to_str(dim_str, coord_str, coord_idx)

def var_name_to_str(self, var_name):
return self.var_name_map.get(var_name, var_name)
Expand Down
8 changes: 8 additions & 0 deletions arviz/plots/autocorrplot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Autocorrelation plot of data."""
from ..data import convert_to_dataset
from ..labels import BaseLabeller
from ..sel_utils import xarray_var_iter
from ..rcparams import rcParams
from ..utils import _var_names
Expand All @@ -15,6 +16,7 @@ def plot_autocorr(
grid=None,
figsize=None,
textsize=None,
labeller=None,
ax=None,
backend=None,
backend_config=None,
Expand Down Expand Up @@ -53,6 +55,8 @@ def plot_autocorr(
textsize: float
Text size scaling factor for labels, titles and lines. If None it will be autoscaled based
on figsize.
labeller : labeller instance, optional
Class providing the method `make_label_vert` to generate the labels in the plot titles.
ax: numpy array-like of matplotlib axes or bokeh figures, optional
A 2D array of locations into which to plot the autocorrelation. If not supplied, ArviZ will
create its own array of plot areas (and return it).
Expand Down Expand Up @@ -111,6 +115,9 @@ def plot_autocorr(
if max_lag is None:
max_lag = min(100, data["draw"].shape[0])

if labeller is None:
labeller = BaseLabeller()

plotters = filter_plotters_list(
list(xarray_var_iter(data, var_names, combined)), "plot_autocorr"
)
Expand All @@ -125,6 +132,7 @@ def plot_autocorr(
cols=cols,
combined=combined,
textsize=textsize,
labeller=labeller,
backend_kwargs=backend_kwargs,
show=show,
)
Expand Down
6 changes: 3 additions & 3 deletions arviz/plots/backends/bokeh/autocorrplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from bokeh.models.annotations import Title


from ....sel_utils import make_label
from ....stats import autocorr
from ...plot_utils import _scale_fig_size
from .. import show_layout
Expand All @@ -20,6 +19,7 @@ def plot_autocorr(
cols,
combined,
textsize,
labeller,
backend_config,
backend_kwargs,
show,
Expand Down Expand Up @@ -70,7 +70,7 @@ def plot_autocorr(
start=-1, end=1, bounds=backend_config["bounds_y_range"], min_interval=0.1
)

for (var_name, selection, x), ax in zip(
for (var_name, selection, isel, x), ax in zip(
plotters, (item for item in axes.flatten() if item is not None)
):
x_prime = x
Expand All @@ -90,7 +90,7 @@ def plot_autocorr(
)

title = Title()
title.text = make_label(var_name, selection)
title.text = labeller.make_label_vert(var_name, selection, isel)
ax.title = title
ax.x_range = data_range_x
ax.y_range = data_range_y
Expand Down
6 changes: 3 additions & 3 deletions arviz/plots/backends/bokeh/essplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from .. import show_layout
from . import backend_kwarg_defaults, create_axes_grid
from ...plot_utils import _scale_fig_size
from ....sel_utils import make_label


def plot_ess(
Expand All @@ -32,6 +31,7 @@ def plot_ess(
n_samples,
relative,
min_ess,
labeller,
ylabel,
rug,
rug_kind,
Expand Down Expand Up @@ -62,7 +62,7 @@ def plot_ess(
else:
ax = np.atleast_2d(ax)

for (var_name, selection, x), ax_ in zip(
for (var_name, selection, isel, x), ax_ in zip(
plotters, (item for item in ax.flatten() if item is not None)
):
bulk_points = ax_.circle(np.asarray(xdata), np.asarray(x), size=6)
Expand Down Expand Up @@ -154,7 +154,7 @@ def plot_ess(
ax_.legend.click_policy = "hide"

title = Title()
title.text = make_label(var_name, selection)
title.text = labeller.make_label_vert(var_name, selection, isel)
ax_.title = title

ax_.xaxis.axis_label = "Total number of draws" if kind == "evolution" else "Quantile"
Expand Down
8 changes: 5 additions & 3 deletions arviz/plots/backends/matplotlib/autocorrplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ....stats import autocorr
from ...plot_utils import _scale_fig_size
from . import backend_kwarg_defaults, backend_show, create_axes_grid
from ....sel_utils import make_label


def plot_autocorr(
Expand All @@ -17,6 +16,7 @@ def plot_autocorr(
cols,
combined,
textsize,
labeller,
backend_kwargs,
show,
):
Expand Down Expand Up @@ -46,7 +46,7 @@ def plot_autocorr(
backend_kwargs=backend_kwargs,
)

for (var_name, selection, x), ax in zip(plotters, np.ravel(axes)):
for (var_name, selection, isel, x), ax in zip(plotters, np.ravel(axes)):
x_prime = x
if combined:
x_prime = x.flatten()
Expand All @@ -56,7 +56,9 @@ def plot_autocorr(
ax.fill_between([0, max_lag], -c_i, c_i, color="0.75")
ax.vlines(x=np.arange(0, max_lag), ymin=0, ymax=y[0:max_lag], lw=linewidth)

ax.set_title(make_label(var_name, selection), fontsize=titlesize, wrap=True)
ax.set_title(
labeller.make_label_vert(var_name, selection, isel), fontsize=titlesize, wrap=True
)
ax.tick_params(labelsize=xt_labelsize)

if np.asarray(axes).size > 0:
Expand Down
6 changes: 3 additions & 3 deletions arviz/plots/backends/matplotlib/essplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from ...plot_utils import _scale_fig_size
from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
from ....sel_utils import make_label


def plot_ess(
Expand All @@ -29,6 +28,7 @@ def plot_ess(
n_samples,
relative,
min_ess,
labeller,
ylabel,
rug,
rug_kind,
Expand Down Expand Up @@ -96,7 +96,7 @@ def plot_ess(
backend_kwargs=backend_kwargs,
)

for (var_name, selection, x), ax_ in zip(plotters, np.ravel(ax)):
for (var_name, selection, isel, x), ax_ in zip(plotters, np.ravel(ax)):
ax_.plot(xdata, x, **kwargs)
if kind == "evolution":
ess_tail = ess_tail_dataset[var_name].sel(**selection)
Expand Down Expand Up @@ -144,7 +144,7 @@ def plot_ess(

ax_.axhline(400 / n_samples if relative else min_ess, **hline_kwargs)

ax_.set_title(make_label(var_name, selection), fontsize=titlesize, wrap=True)
ax_.set_title(labeller.make_label_vert(var_name, selection, isel), fontsize=titlesize, wrap=True)
ax_.tick_params(labelsize=xt_labelsize)
ax_.set_xlabel(
"Total number of draws" if kind == "evolution" else "Quantile", fontsize=ax_labelsize
Expand Down
7 changes: 7 additions & 0 deletions arviz/plots/essplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import xarray as xr

from ..data import convert_to_dataset
from ..labels import BaseLabeller
from ..rcparams import rcParams
from ..sel_utils import xarray_var_iter
from ..stats import ess
Expand All @@ -25,6 +26,7 @@ def plot_ess(
n_points=20,
extra_methods=False,
min_ess=400,
labeller=None,
ax=None,
extra_kwargs=None,
text_kwargs=None,
Expand Down Expand Up @@ -75,6 +77,8 @@ def plot_ess(
Plot mean and sd ESS as horizontal lines. Not taken into account in evolution kind
min_ess: int
Minimum number of ESS desired.
labeller : labeller instance, optional
Class providing the method `make_label_vert` to generate the labels in the plot titles.
ax: numpy array-like of matplotlib axes or bokeh figures, optional
A 2D array of locations into which to plot the ESS. If not supplied, ArviZ will create
its own array of plot areas (and return it).
Expand Down Expand Up @@ -174,6 +178,8 @@ def plot_ess(
coords = {}
if "chain" in coords or "draw" in coords:
raise ValueError("chain and draw are invalid coordinates for this kind of plot")
if labeller is None:
labeller = BaseLabeller()
extra_methods = False if kind == "evolution" else extra_methods

data = get_coords(convert_to_dataset(idata, group="posterior"), coords)
Expand Down Expand Up @@ -274,6 +280,7 @@ def plot_ess(
n_samples=n_samples,
relative=relative,
min_ess=min_ess,
labeller=labeller,
ylabel=ylabel,
rug=rug,
rug_kind=rug_kind,
Expand Down
16 changes: 10 additions & 6 deletions arviz/sel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,13 @@ def xarray_sel_iter(data, var_names=None, combined=False, skip_dims=None, revers
new_dims = _dims(data, var_name, skip_dims)
vals = [purge_duplicates(data[var_name][dim].values) for dim in new_dims]
dims = _zip_dims(new_dims, vals)
idims = _zip_dims(new_dims, [range(len(v)) for v in vals])
if reverse_selections:
dims = reversed(dims)
idims = reversed(idims)

for selection in dims:
yield var_name, selection
for selection, iselection in zip(dims, idims):
yield var_name, selection, iselection


def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, reverse_selections=False):
Expand Down Expand Up @@ -174,14 +176,14 @@ def xarray_var_iter(data, var_names=None, combined=False, skip_dims=None, revers
if var_names is None and isinstance(data, xr.DataArray):
data_to_sel = {data.name: data}

for var_name, selection in xarray_sel_iter(
for var_name, selection, iselection in xarray_sel_iter(
data,
var_names=var_names,
combined=combined,
skip_dims=skip_dims,
reverse_selections=reverse_selections,
):
yield var_name, selection, data_to_sel[var_name].sel(**selection).values
yield var_name, selection, iselection, data_to_sel[var_name].sel(**selection).values


def xarray_to_ndarray(data, *, var_names=None, combined=True):
Expand Down Expand Up @@ -213,12 +215,14 @@ def xarray_to_ndarray(data, *, var_names=None, combined=True):

iterator1, iterator2 = tee(xarray_sel_iter(data, var_names=var_names, combined=combined))
vars_and_sel = list(iterator1)
unpacked_var_names = [make_label(var_name, selection) for var_name, selection in vars_and_sel]
unpacked_var_names = [
make_label(var_name, selection) for var_name, selection, _ in vars_and_sel
]

# Merge chains and variables, check dtype to be compatible with divergences data
data0 = data_to_sel[vars_and_sel[0][0]].sel(**vars_and_sel[0][1])
unpacked_data = np.empty((len(unpacked_var_names), data0.size), dtype=data0.dtype)
for idx, (var_name, selection) in enumerate(iterator2):
for idx, (var_name, selection, _) in enumerate(iterator2):
unpacked_data[idx] = data_to_sel[var_name].sel(**selection).values.flatten()

return unpacked_var_names, unpacked_data
7 changes: 4 additions & 3 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,10 +1132,11 @@ def summary(
raise TypeError(f"InferenceData does not contain group: {group}")
dataset = data[group]
else:
dataset = convert_to_dataset(data, group="posterior")
dataset = get_coords(convert_to_dataset(data, group="posterior"), coords)
var_names = _var_names(var_names, dataset, filter_vars)
dataset = dataset if var_names is None else dataset[var_names]


fmt_group = ("wide", "long", "xarray")
if not isinstance(fmt, str) or (fmt.lower() not in fmt_group):
raise TypeError(f"Invalid format: '{fmt}'. Formatting options are: {fmt_group}")
Expand Down Expand Up @@ -1281,9 +1282,9 @@ def summary(
if fmt.lower() == "wide":
summary_df = pd.DataFrame(np.full((n_vars, n_metrics), np.nan), columns=metric_names)
indexs = []
for i, (var_name, sel, values) in enumerate(xarray_var_iter(joined, skip_dims={"metric"})):
for i, (var_name, sel, isel, values) in enumerate(xarray_var_iter(joined, skip_dims={"metric"})):
summary_df.iloc[i] = values
indexs.append(labeller.make_label_flat(var_name, sel, sel))
indexs.append(labeller.make_label_flat(var_name, sel, isel))
summary_df.index = indexs
elif fmt.lower() == "long":
df = joined.to_dataframe().reset_index().set_index("metric")
Expand Down

0 comments on commit bf05f4e

Please sign in to comment.