Skip to content

Commit

Permalink
Metagroups and InferenceData.map (#1161)
Browse files Browse the repository at this point in the history
* custom validator

* simplify rc file

* start using metagroups

* black

* add also filter_groups capabilities

* add todo to .sel method

* minor fixes

* typo

* add tests and fixes

* send a warning if negation pattern is ignored

* fix tests

makes rctemplate test ignore system templates
  • Loading branch information
OriolAbril authored May 19, 2020
1 parent 18797b8 commit b062cbc
Show file tree
Hide file tree
Showing 7 changed files with 496 additions and 150 deletions.
195 changes: 186 additions & 9 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
import xarray as xr

from ..utils import _subset_list
from ..rcparams import rcParams

SUPPORTED_GROUPS = [
Expand Down Expand Up @@ -216,10 +217,12 @@ def __add__(self, other):
"""Concatenate two InferenceData objects."""
return concat(self, other, copy=True, inplace=False)

def sel(self, inplace=False, chain_prior=False, warmup=False, **kwargs):
def sel(
self, groups=None, filter_groups=None, inplace=False, chain_prior=None, **kwargs,
):
"""Perform an xarray selection on all groups.
Loops over all groups to perform Dataset.sel(key=item)
Loops groups to perform Dataset.sel(key=item)
for every kwarg if key is a dimension of the dataset.
One example could be performing a burn in cut on the InferenceData object
or discarding a chain. The selection is performed on all relevant groups (like
Expand All @@ -228,15 +231,16 @@ def sel(self, inplace=False, chain_prior=False, warmup=False, **kwargs):
Parameters
----------
inplace : bool, optional
groups: str or list of str, optional
Groups where the selection is to be applied. Can either be group names
or metagroup names.
inplace: bool, optional
If ``True``, modify the InferenceData object inplace,
otherwise, return the modified copy.
chain_prior: bool, optional
chain_prior: bool, optional, deprecated
If ``False``, do not select prior related groups using ``chain`` dim.
Otherwise, use selection on ``chain`` if present.
warmup: bool, optional
If ``False``, do not select warmup groups.
**kwargs : mapping
Otherwise, use selection on ``chain`` if present. Default=False
**kwargs: mapping
It must be accepted by Dataset.sel().
Returns
Expand Down Expand Up @@ -269,8 +273,18 @@ def sel(self, inplace=False, chain_prior=False, warmup=False, **kwargs):
...: print(idata_subset.observed_data.coords)
"""
if chain_prior is not None:
warnings.warn(
"chain_prior has been deprecated. Use groups argument and "
"rcParams['data.metagroups'] instead.",
DeprecationWarning,
)
else:
chain_prior = False
groups = self._group_names(groups, filter_groups)

out = self if inplace else deepcopy(self)
for group in self._groups_all if warmup else self._groups:
for group in groups:
dataset = getattr(self, group)
valid_keys = set(kwargs.keys()).intersection(dataset.dims)
if not chain_prior and "prior" in group:
Expand All @@ -282,6 +296,169 @@ def sel(self, inplace=False, chain_prior=False, warmup=False, **kwargs):
else:
return out

def _group_names(self, groups, filter_groups=None):
    """Expand group and metagroup names into the matching group names.

    Parameters
    ----------
    groups: str, list of str or None
        Group or metagroup names. A leading ``~`` negates the selection.
        ``None`` selects every group present in the object.
    filter_groups: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret groups as the real group or metagroup names.
        If "like", interpret groups as substrings of the real group or metagroup names.
        If "regex", interpret groups as regular expressions on the real group or
        metagroup names. A la `pandas.filter`.

    Returns
    -------
    groups: list
        Names of the groups in this InferenceData matched by ``groups``.

    Raises
    ------
    KeyError
        If a requested group name does not match any group of the InferenceData object.
    """
    all_groups = self._groups_all
    if groups is None:
        return all_groups
    if isinstance(groups, str):
        groups = [groups]
    sel_groups = []
    metagroups = rcParams["data.metagroups"]
    for group in groups:
        # startswith is safe on an empty string, unlike group[0]
        if group.startswith("~"):
            # Negated metagroups expand to one negation per member group.
            sel_groups.extend(
                [f"~{item}" for item in metagroups[group[1:]] if item in all_groups]
                if group[1:] in metagroups
                else [group]
            )
        else:
            sel_groups.extend(
                [item for item in metagroups[group] if item in all_groups]
                if group in metagroups
                else [group]
            )

    try:
        group_names = _subset_list(sel_groups, all_groups, filter_items=filter_groups)
    except KeyError as err:
        msg = " ".join(("groups:", f"{err}", "in InferenceData"))
        # Chain the original error so the failing subset lookup stays visible.
        raise KeyError(msg) from err
    return group_names

def map(self, fun, groups=None, filter_groups=None, inplace=False, args=None, **kwargs):
    """Apply a function to multiple groups.

    Parameters
    ----------
    fun: callable
        Function to be applied to each group. Assumes the function is called as
        ``fun(dataset, *args, **kwargs)``.
    groups: str or list of str, optional
        Groups where the function is to be applied. Can either be group names
        or metagroup names.
    filter_groups: {None, "like", "regex"}, optional
        If ``None`` (default), interpret ``groups`` as the real group or metagroup
        names. If "like", interpret them as substrings of the real names. If
        "regex", interpret them as regular expressions on the real names.
        A la ``pandas.filter``.
    inplace: bool, optional
        If ``True``, modify the InferenceData object inplace,
        otherwise, return the modified copy.
    args: array_like, optional
        Positional arguments passed to ``fun``.
    **kwargs: mapping, optional
        Keyword arguments passed to ``fun``.

    Returns
    -------
    InferenceData
        A new InferenceData object by default.
        When `inplace==True` perform selection in place and return `None`

    Examples
    --------
    Shift observed_data, prior_predictive and posterior_predictive.

    .. ipython::

        In [1]: import arviz as az
           ...: idata = az.load_arviz_data("non_centered_eight")
           ...: idata_shifted_obs = idata.map(lambda x: x + 3, groups="observed_vars")
           ...: print(idata_shifted_obs.observed_data)
           ...: print(idata_shifted_obs.posterior_predictive)
    """
    if args is None:
        args = []
    groups = self._group_names(groups, filter_groups)

    out = self if inplace else deepcopy(self)
    for group in groups:
        # Replace each selected group's dataset with the transformed one.
        dataset = getattr(self, group)
        dataset = fun(dataset, *args, **kwargs)
        setattr(out, group, dataset)
    return None if inplace else out

def _wrap_xarray_method(
    self, method, groups=None, filter_groups=None, inplace=False, args=None, **kwargs
):
    """Extend an xarray.Dataset method to the InferenceData object.

    Parameters
    ----------
    method: str
        Method to be extended. Must be the name of a ``xarray.Dataset`` method.
    groups: str or list of str, optional
        Groups where the method is to be applied. Can either be group names
        or metagroup names.
    filter_groups: {None, "like", "regex"}, optional
        If ``None`` (default), interpret ``groups`` as the real group or metagroup
        names. If "like", interpret them as substrings of the real names. If
        "regex", interpret them as regular expressions on the real names.
        A la ``pandas.filter``.
    inplace: bool, optional
        If ``True``, modify the InferenceData object inplace,
        otherwise, return the modified copy.
    args: array_like, optional
        Positional arguments passed to the xarray Dataset method.
    **kwargs: mapping, optional
        Keyword arguments passed to the xarray Dataset method.

    Returns
    -------
    InferenceData
        A new InferenceData object by default.
        When `inplace==True` perform selection in place and return `None`

    Examples
    --------
    Compute the mean of `posterior_groups`:

    .. ipython::

        In [1]: import arviz as az
           ...: idata = az.load_arviz_data("non_centered_eight")
           ...: idata_means = idata._wrap_xarray_method("mean", groups="latent_vars")
           ...: print(idata_means.posterior)
           ...: print(idata_means.observed_data)

    .. ipython::

        In [1]: idata_stack = idata._wrap_xarray_method(
           ...:     "stack",
           ...:     groups=["posterior_groups", "prior_groups"],
           ...:     sample=["chain", "draw"]
           ...: )
           ...: print(idata_stack.posterior)
           ...: print(idata_stack.prior)
           ...: print(idata_stack.observed_data)
    """
    if args is None:
        args = []
    groups = self._group_names(groups, filter_groups)

    # Resolve the name to the unbound Dataset method once, outside the loop.
    method = getattr(xr.Dataset, method)

    out = self if inplace else deepcopy(self)
    for group in groups:
        dataset = getattr(self, group)
        dataset = method(dataset, *args, **kwargs)
        setattr(out, group, dataset)
    return None if inplace else out


# pylint: disable=protected-access, inconsistent-return-statements
def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
Expand Down
70 changes: 58 additions & 12 deletions arviz/rcparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
import re
import pprint
import warnings
import logging
import locale
from collections.abc import MutableMapping
Expand Down Expand Up @@ -165,6 +166,24 @@ def _validate_bokeh_marker(value):
return value


def _validate_dict_of_lists(values):
if isinstance(values, dict):
return {key: tuple(item) for key, item in values.items()}
else:
validated_dict = {}
for value in values:
tup = value.split(":", 1)
if len(tup) != 2:
raise ValueError(f"Could not interpret '{value}' as key: list or str")
key, vals = tup
key = key.strip(' "')
vals = [val.strip(' "') for val in vals.strip(" [],").split(",")]
if key in validated_dict:
warnings.warn(f"Repeated key {key} when validating dict of lists")
validated_dict[key] = tuple(vals)
return validated_dict


def make_iterable_validator(scalar_validator, length=None, allow_none=False, allow_auto=False):
"""Validate value is an iterable datatype."""
# based on matplotlib's _listify_validator function
Expand All @@ -190,10 +209,22 @@ def validate_iterable(value):
_validate_float_or_none, length=2, allow_none=True, allow_auto=True
)

# Default metagroup definitions: each key names a collection of InferenceData
# group names that can be selected together. Exposed as the default of
# rcParams["data.metagroups"] below; the _warmup_* names refer to warmup groups.
METAGROUPS = {
    "posterior_groups": ["posterior", "posterior_predictive", "sample_stats", "log_likelihood"],
    "prior_groups": ["prior", "prior_predictive", "sample_stats_prior"],
    "posterior_groups_warmup": [
        "_warmup_posterior",
        "_warmup_posterior_predictive",
        "_warmup_sample_stats",
    ],
    "latent_vars": ["posterior", "prior"],
    "observed_vars": ["posterior_predictive", "observed_data", "prior_predictive"],
}

defaultParams = { # pylint: disable=invalid-name
"data.http_protocol": ("https", _make_validate_choice({"https", "http"})),
"data.load": ("lazy", _make_validate_choice({"lazy", "eager"})),
"data.metagroups": (METAGROUPS, _validate_dict_of_lists),
"data.index_origin": (0, _make_validate_choice({0, 1}, typeof=int)),
"data.save_warmup": (False, _validate_boolean),
"plot.backend": ("matplotlib", _make_validate_choice({"matplotlib", "bokeh"})),
Expand Down Expand Up @@ -407,20 +438,33 @@ def read_rcfile(fname):
config = RcParams()
with open(fname, "r") as rcfile:
try:
multiline = False
for line_no, line in enumerate(rcfile, 1):
strippedline = line.split("#", 1)[0].strip()
if not strippedline:
continue
tup = strippedline.split(":", 1)
if len(tup) != 2:
error_details = _error_details_fmt % (line_no, line, fname)
_log.warning("Illegal %s", error_details)
continue
key, val = tup
key = key.strip()
val = val.strip()
if key in config:
_log.warning("Duplicate key in file %r line #%d.", fname, line_no)
if multiline:
if strippedline == "}":
multiline = False
val = aux_val
else:
aux_val.append(strippedline)
continue
else:
tup = strippedline.split(":", 1)
if len(tup) != 2:
error_details = _error_details_fmt % (line_no, line, fname)
_log.warning("Illegal %s", error_details)
continue
key, val = tup
key = key.strip()
val = val.strip()
if key in config:
_log.warning("Duplicate key in file %r line #%d.", fname, line_no)
if key in {"data.metagroups"}:
aux_val = []
multiline = True
continue
try:
config[key] = val
except ValueError as verr:
Expand All @@ -439,9 +483,11 @@ def read_rcfile(fname):
return config


def rc_params():
def rc_params(ignore_files=False):
"""Read and validate arvizrc file."""
fname = get_arviz_rcfile()
fname = None
if not ignore_files:
fname = get_arviz_rcfile()
defaults = RcParams([(key, default) for key, (default, _) in defaultParams.items()])
if fname is not None:
file_defaults = read_rcfile(fname)
Expand Down
Loading

0 comments on commit b062cbc

Please sign in to comment.