From 07912832abae027e3ab839be50373d7b8afb8903 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Mon, 15 Mar 2021 10:49:27 -0500 Subject: [PATCH 01/18] Modify compare() docstring, and var_name param and error-check. Check to see if the log_likelihood groups have more than one data variable, and if so, require the var_name parameter. This avoids having a less understandable error message crop up later. Introduce `var_name` parameter, and pass it through to the IC function invoked by compare. --- arviz/stats/stats.py | 41 +++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 0b32afbcc0..515fd42c09 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -2,13 +2,14 @@ """Statistical functions in ArviZ.""" import warnings from copy import deepcopy -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Dict import numpy as np import pandas as pd import scipy.stats as st import xarray as xr from scipy.optimize import minimize +from typing_extensions import Literal from arviz import _log from ..data import InferenceData, convert_to_dataset, convert_to_inference_data @@ -27,9 +28,6 @@ from ..sel_utils import xarray_var_iter from ..labels import BaseLabeller -if TYPE_CHECKING: - from typing_extensions import Literal - __all__ = [ "apply_test_function", @@ -45,7 +43,14 @@ def compare( - dataset_dict, ic=None, method="stacking", b_samples=1000, alpha=1, seed=None, scale=None + dataset_dict: Optional[Dict[str, InferenceData]], + ic: Optional[Literal["loo", "waic"]] = None, + method: Literal["stacking", "BB-pseudo-BMA", "pseudo-MA"] = "stacking", + b_samples: int = 1000, + alpha: float = 1, + seed=None, + scale: Optional[Literal["log", "negative_log", "deviance"]] = None, + var_name: Optional[str] = None, ): r"""Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation. @@ -58,10 +63,10 @@ def compare( ---------- dataset_dict: dict[str] -> InferenceData A dictionary of model names and InferenceData objects - ic: str + ic: str, optional Information Criterion (PSIS-LOO `loo` or WAIC `waic`) used to compare models. Defaults to ``rcParams["stats.information_criterion"]``. - method: str + method: str, optional Method used to estimate the weights for each model. Available options are: - 'stacking' : (default) stacking of predictive distributions. @@ -71,18 +76,18 @@ def compare( weighting, without Bootstrap stabilization (not recommended). For more information read https://arxiv.org/abs/1704.02030 - b_samples: int + b_samples: int, optional default = 1000 Number of samples taken by the Bayesian bootstrap estimation. Only useful when method = 'BB-pseudo-BMA'. - alpha: float + alpha: float, optional The shape parameter in the Dirichlet distribution used for the Bayesian bootstrap. Only useful when method = 'BB-pseudo-BMA'. When alpha=1 (default), the distribution is uniform on the simplex. A smaller alpha will keeps the final weights more away from 0 and 1. - seed: int or np.random.RandomState instance + seed: int or np.random.RandomState instance, optional If int or RandomState, use it for seeding Bayesian bootstrap. Only useful when method = 'BB-pseudo-BMA'. Default None the global np.random state is used. - scale: str + scale: str, optional Output scale for IC. Available options are: - `log` : (default) log-score (after Vehtari et al. (2017)) @@ -91,6 +96,9 @@ def compare( A higher log-score (or a lower deviance) indicates a model with better predictive accuracy. + var_name: str, optional + If there is more than a single observed variable in the ``InferenceData``, which + should be used as the basis for comparison. Returns ------- @@ -206,7 +214,16 @@ def compare( names = [] for name, dataset in dataset_dict.items(): names.append(name) - ics = ics.append([ic_func(dataset, pointwise=True, scale=scale)]) + if len(dataset.log_likelihood.data_vars) > 1: + raise ValueError( + ( + f"Dataset {name} has multiple variables in its log_likelihood.\n" + "In such cases, the var_name parameter is mandatory." + ) + ) + # Here is where the IC function is actually computed -- the rest of this + # function is argument processing and return value formatting + ics = ics.append([ic_func(dataset, pointwise=True, scale=scale, var_name=var_name)]) ics.index = names ics.sort_values(by=ic, inplace=True, ascending=ascending) ics[ic_i] = ics[ic_i].apply(lambda x: x.values.flatten()) From be8c21da8e76388fd551ae82aafd6b9340532a43 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Mon, 15 Mar 2021 12:32:38 -0500 Subject: [PATCH 02/18] Fix error in argument test. --- arviz/stats/stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 515fd42c09..103c66df71 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -214,7 +214,7 @@ def compare( names = [] for name, dataset in dataset_dict.items(): names.append(name) - if len(dataset.log_likelihood.data_vars) > 1: + if len(dataset.log_likelihood.data_vars) > 1 and var_name is None: raise ValueError( ( f"Dataset {name} has multiple variables in its log_likelihood.\n" From ce480fde14ba6bbc952d6eeddd47c4ee61781654 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Mon, 15 Mar 2021 15:08:16 -0500 Subject: [PATCH 03/18] More explicit error message. Previously, if we got a TypeError when trying to find a log_likelihood in from_pymc3() that TypeError would be squashed completely. Now we will echo it to the log before handling it. --- arviz/data/io_pymc3.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/arviz/data/io_pymc3.py b/arviz/data/io_pymc3.py index 79bba68eb8..f0b8b40855 100644 --- a/arviz/data/io_pymc3.py +++ b/arviz/data/io_pymc3.py @@ -89,7 +89,8 @@ def __init__( # this permits us to get the model from command-line argument or from with model: try: self.model = self.pymc3.modelcontext(model or self.model) - except TypeError: + except TypeError as e: + _log.error("Got error %s trying to find log_likelihood in translation.", e) self.model = None if self.model is None: @@ -249,12 +250,17 @@ def _extract_log_likelihood(self, trace): "`pip install pymc3>=3.8` or `conda install -c conda-forge pymc3>=3.8`." ) from err for var, log_like_fun in cached: - for k, chain in enumerate(trace.chains): - log_like_chain = [ - self.log_likelihood_vals_point(point, var, log_like_fun) - for point in trace.points([chain]) - ] - log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k) + try: + for k, chain in enumerate(trace.chains): + log_like_chain = [ + self.log_likelihood_vals_point(point, var, log_like_fun) + for point in trace.points([chain]) + ] + log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k) + except TypeError as e: + raise TypeError( + *tuple(["While computing log-likelihood for {var}: "] + list(e.args)) + ) from e return log_likelihood_dict.trace_dict @requires("trace") From 63c4d06e63dec5e23180a50dd73a6607a0b018a9 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Sat, 20 Mar 2021 12:55:04 -0500 Subject: [PATCH 04/18] Fix type signature of compare(). Took @OriolAbril correction. --- arviz/stats/stats.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 103c66df71..d3537690a6 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -2,7 +2,7 @@ """Statistical functions in ArviZ.""" import warnings from copy import deepcopy -from typing import List, Optional, Tuple, Union, Dict +from typing import List, Optional, Tuple, Union, Mapping import numpy as np import pandas as pd @@ -43,7 +43,7 @@ def compare( - dataset_dict: Optional[Dict[str, InferenceData]], + dataset_dict: Mapping[str, InferenceData], ic: Optional[Literal["loo", "waic"]] = None, method: Literal["stacking", "BB-pseudo-BMA", "pseudo-MA"] = "stacking", b_samples: int = 1000, From 37c2e8144a418b2e7557ad9a2c1274df27509286 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Sat, 20 Mar 2021 14:33:46 -0500 Subject: [PATCH 05/18] Annotated IC function errors from compare(). Catch errors from IC functions invoked inside compare and annotate them with information about the source `InferenceData` object. --- arviz/stats/stats.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index d3537690a6..f25ba7f55a 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -150,6 +150,11 @@ def compare( loo : Compute the Pareto Smoothed importance sampling Leave One Out cross-validation. waic : Compute the widely applicable information criterion. + Notes + ----- + If the `log_likelihood` group is not present in the input datasets, ArviZ will attempt + to compute it. + """ names = list(dataset_dict.keys()) scale = rcParams["stats.ic_scale"] if scale is None else scale.lower() @@ -214,16 +219,14 @@ def compare( names = [] for name, dataset in dataset_dict.items(): names.append(name) - if len(dataset.log_likelihood.data_vars) > 1 and var_name is None: - raise ValueError( - ( - f"Dataset {name} has multiple variables in its log_likelihood.\n" - "In such cases, the var_name parameter is mandatory." - ) - ) - # Here is where the IC function is actually computed -- the rest of this - # function is argument processing and return value formatting - ics = ics.append([ic_func(dataset, pointwise=True, scale=scale, var_name=var_name)]) + try: + # Here is where the IC function is actually computed -- the rest of this + # function is argument processing and return value formatting + ics = ics.append([ic_func(dataset, pointwise=True, scale=scale, var_name=var_name)]) + except Exception as e: + raise e.__class__( + f"Encountered error trying to compute {ic} from model {name}: {e}" + ) from e ics.index = names ics.sort_values(by=ic, inplace=True, ascending=ascending) ics[ic_i] = ics[ic_i].apply(lambda x: x.values.flatten()) From b50b6de3614a5440bda8d294bf4c4f1ee6a81381 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Sat, 20 Mar 2021 14:53:13 -0500 Subject: [PATCH 06/18] Remove incorrect docstring. --- arviz/stats/stats.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index f25ba7f55a..7b9415b161 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -149,12 +149,6 @@ def compare( -------- loo : Compute the Pareto Smoothed importance sampling Leave One Out cross-validation. waic : Compute the widely applicable information criterion. - - Notes - ----- - If the `log_likelihood` group is not present in the input datasets, ArviZ will attempt - to compute it. - """ names = list(dataset_dict.keys()) scale = rcParams["stats.ic_scale"] if scale is None else scale.lower() From e8503856f37f78a0bfe27c49d5200960684873e9 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Fri, 26 Mar 2021 15:11:07 -0500 Subject: [PATCH 07/18] pylint --- arviz/data/io_pymc3.py | 4 ++-- arviz/stats/stats.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/arviz/data/io_pymc3.py b/arviz/data/io_pymc3.py index f0b8b40855..fb8f806840 100644 --- a/arviz/data/io_pymc3.py +++ b/arviz/data/io_pymc3.py @@ -89,7 +89,7 @@ def __init__( # this permits us to get the model from command-line argument or from with model: try: self.model = self.pymc3.modelcontext(model or self.model) - except TypeError as e: + except TypeError as e: # pylint: disable=invalid-name _log.error("Got error %s trying to find log_likelihood in translation.", e) self.model = None @@ -257,7 +257,7 @@ def _extract_log_likelihood(self, trace): for point in trace.points([chain]) ] log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k) - except TypeError as e: + except TypeError as e: # pylint: disable=invalid-name raise TypeError( *tuple(["While computing log-likelihood for {var}: "] + list(e.args)) ) from e diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 7b9415b161..25df699717 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -217,10 +217,8 @@ def compare( # Here is where the IC function is actually computed -- the rest of this # function is argument processing and return value formatting ics = ics.append([ic_func(dataset, pointwise=True, scale=scale, var_name=var_name)]) - except Exception as e: - raise e.__class__( - f"Encountered error trying to compute {ic} from model {name}: {e}" - ) from e + except Exception as e: # pylint: disable=invalid-name + raise e.__class__(f"Encountered error trying to compute {ic} from model {name}.") from e ics.index = names ics.sort_values(by=ic, inplace=True, ascending=ascending) ics[ic_i] = ics[ic_i].apply(lambda x: x.values.flatten()) From c3672704a34dfbc8b41247ae9bba4ab883368d1d Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Fri, 26 Mar 2021 16:51:41 -0500 Subject: [PATCH 08/18] mypy fixes. --- arviz/rcparams.py | 9 ++++++--- arviz/stats/stats.py | 20 ++++++++++++++------ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/arviz/rcparams.py b/arviz/rcparams.py index f8ffebbd8d..f0ac060e2c 100644 --- a/arviz/rcparams.py +++ b/arviz/rcparams.py @@ -8,12 +8,15 @@ import warnings from collections.abc import MutableMapping from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Literal, get_args import numpy as np _log = logging.getLogger(__name__) +ScaleKeyword = Literal["log", "negative_log", "deviance"] +ICKeyword = Literal["loo", "waic"] + def _make_validate_choice(accepted_values, allow_none=False, typeof=str): """Validate value is in accepted_values. @@ -274,9 +277,9 @@ def validate_iterable(value): "plot.matplotlib.constrained_layout": (True, _validate_boolean), "plot.matplotlib.show": (False, _validate_boolean), "stats.hdi_prob": (0.94, _validate_probability), - "stats.information_criterion": ("loo", _make_validate_choice({"waic", "loo"})), + "stats.information_criterion": ("loo", _make_validate_choice(set(get_args(ICKeyword)))), "stats.ic_pointwise": (False, _validate_boolean), - "stats.ic_scale": ("log", _make_validate_choice({"deviance", "log", "negative_log"})), + "stats.ic_scale": ("log", _make_validate_choice(set(get_args(ScaleKeyword)))), } diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 25df699717..353856bc9a 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -2,7 +2,7 @@ """Statistical functions in ArviZ.""" import warnings from copy import deepcopy -from typing import List, Optional, Tuple, Union, Mapping +from typing import List, Optional, Tuple, Union, Mapping, cast, get_args import numpy as np import pandas as pd @@ -13,7 +13,7 @@ from arviz import _log from ..data import InferenceData, convert_to_dataset, convert_to_inference_data -from ..rcparams import rcParams +from ..rcparams import rcParams, ScaleKeyword, ICKeyword from ..utils import Numba, _numba_var, _var_names, get_coords from .density_utils import get_bins as _get_bins from .density_utils import histogram as _histogram @@ -44,12 +44,12 @@ def compare( dataset_dict: Mapping[str, InferenceData], - ic: Optional[Literal["loo", "waic"]] = None, + ic: Optional[ICKeyword] = None, method: Literal["stacking", "BB-pseudo-BMA", "pseudo-MA"] = "stacking", b_samples: int = 1000, alpha: float = 1, seed=None, - scale: Optional[Literal["log", "negative_log", "deviance"]] = None, + scale: Optional[ScaleKeyword] = None, var_name: Optional[str] = None, ): r"""Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation. @@ -151,7 +151,11 @@ def compare( waic : Compute the widely applicable information criterion. """ names = list(dataset_dict.keys()) - scale = rcParams["stats.ic_scale"] if scale is None else scale.lower() + if scale is not None: + scale = cast(ScaleKeyword, scale.lower()) + else: + scale = cast(ScaleKeyword, rcParams["stats.ic_scale"]) + assert scale in get_args(ScaleKeyword) if scale == "log": scale_value = 1 ascending = False @@ -162,7 +166,11 @@ def compare( scale_value = -2 ascending = True - ic = rcParams["stats.information_criterion"] if ic is None else ic.lower() + if ic is None: + ic = cast(ICKeyword, rcParams["stats.information_criterion"]) + else: + ic = cast(ICKeyword, ic.lower()) + assert ic in get_args(ICKeyword) if ic == "loo": ic_func = loo df_comp = pd.DataFrame( From cf2e8b530a1ad907c11b4f046e5222a0c671db58 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Sat, 27 Mar 2021 10:27:23 -0500 Subject: [PATCH 09/18] Make "e" acceptable variable name. Many examples show this as a good variable name for exceptions, particularly in "except as e:" --- .pylintrc | 3 ++- arviz/data/io_pymc3.py | 4 ++-- arviz/stats/stats.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.pylintrc b/.pylintrc index f629f9ed8d..53675c1a99 100644 --- a/.pylintrc +++ b/.pylintrc @@ -244,6 +244,7 @@ function-naming-style=snake_case # Good variable names which should always be accepted, separated by a comma good-names=b, + e, i, j, k, @@ -265,7 +266,7 @@ good-names=b, ok, sd, tr, - eta, + eta, Run, _log, _ diff --git a/arviz/data/io_pymc3.py b/arviz/data/io_pymc3.py index c640b32fa7..6d95fbed46 100644 --- a/arviz/data/io_pymc3.py +++ b/arviz/data/io_pymc3.py @@ -89,7 +89,7 @@ def __init__( # this permits us to get the model from command-line argument or from with model: try: self.model = self.pymc3.modelcontext(model or self.model) - except TypeError as e: # pylint: disable=invalid-name + except TypeError as e: _log.error("Got error %s trying to find log_likelihood in translation.", e) self.model = None @@ -259,7 +259,7 @@ def _extract_log_likelihood(self, trace): for point in trace.points([chain]) ] log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k) - except TypeError as e: # pylint: disable=invalid-name + except TypeError as e: raise TypeError( *tuple(["While computing log-likelihood for {var}: "] + list(e.args)) ) from e diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 42402a8fac..5d78216b92 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -227,7 +227,7 @@ def compare( # Here is where the IC function is actually computed -- the rest of this # function is argument processing and return value formatting ics = ics.append([ic_func(dataset, pointwise=True, scale=scale, var_name=var_name)]) - except Exception as e: # pylint: disable=invalid-name + except Exception as e: raise e.__class__(f"Encountered error trying to compute {ic} from model {name}.") from e ics.index = names ics.sort_values(by=ic, inplace=True, ascending=ascending) From 7b0ad6e7fff2f77409f456a0579c39d36f03d770 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Sat, 27 Mar 2021 10:30:32 -0500 Subject: [PATCH 10/18] Changelog update. --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96eff9724e..da8bd5c3de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ * Added interactive legend to bokeh `forestplot` ([1591](https://github.com/arviz-devs/arviz/pull/1591)) * Added interactive legend to bokeh `ppcplot` ([1602](https://github.com/arviz-devs/arviz/pull/1602)) * Added `data.log_likelihood`, `stats.ic_compare_method` and `plot.density_kind` to `rcParams` ([1611](https://github.com/arviz-devs/arviz/pull/1611)) +* Improve error messages in `stats.compare()`, and `var_name` parameter. ([1616](https://github.com/arviz-devs/arviz/pull/1616)) ### Maintenance and fixes * Enforced using coordinate values as default labels ([1201](https://github.com/arviz-devs/arviz/pull/1201)) From 31be0802e1b37577798f5a23b409a48048f50b41 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Sat, 27 Mar 2021 10:41:53 -0500 Subject: [PATCH 11/18] Backward-compatibility fix. --- arviz/rcparams.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/arviz/rcparams.py b/arviz/rcparams.py index 3341152037..6c1506eb43 100644 --- a/arviz/rcparams.py +++ b/arviz/rcparams.py @@ -8,7 +8,8 @@ import warnings from collections.abc import MutableMapping from pathlib import Path -from typing import Any, Dict, Literal, get_args +from typing import Any, Dict +from typing_extensions import Literal, get_args import numpy as np From 6edc71eb52488599e292424d410c3dd427093aaf Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Sat, 27 Mar 2021 12:12:24 -0500 Subject: [PATCH 12/18] Fix test. Error now caught sooner. --- arviz/tests/base_tests/test_stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index ce7b4ad413..a28be61d65 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -151,7 +151,7 @@ def test_compare_same(centered_eight, multidim_models, method, multidim): def test_compare_unknown_ic_and_method(centered_eight, non_centered_eight): model_dict = {"centered": centered_eight, "non_centered": non_centered_eight} - with pytest.raises(NotImplementedError): + with pytest.raises(ValueError): compare(model_dict, ic="Unknown", method="stacking") with pytest.raises(ValueError): compare(model_dict, ic="loo", method="Unknown") From 7eb162cfe4a1753a3efbdd0971593780f44bb7f3 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Sat, 27 Mar 2021 12:27:41 -0500 Subject: [PATCH 13/18] Fix mypy issues. --- arviz/stats/stats.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 5d78216b92..d3a7410eed 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -2,14 +2,14 @@ """Statistical functions in ArviZ.""" import warnings from copy import deepcopy -from typing import List, Optional, Tuple, Union, Mapping, cast, get_args +from typing import List, Optional, Tuple, Union, Mapping, cast, Callable import numpy as np import pandas as pd import scipy.stats as st import xarray as xr from scipy.optimize import minimize -from typing_extensions import Literal +from typing_extensions import Literal, get_args from arviz import _log from ..data import InferenceData, convert_to_dataset, convert_to_inference_data @@ -156,6 +156,11 @@ def compare( scale = cast(ScaleKeyword, scale.lower()) else: scale = cast(ScaleKeyword, rcParams["stats.ic_scale"]) + if scale not in get_args(ScaleKeyword): + raise ValueError( + f"{scale} is not a valid value for scale: must be in {get_args(ScaleKeyword)}" + ) + assert scale in get_args(ScaleKeyword) if scale == "log": scale_value = 1 @@ -171,9 +176,10 @@ def compare( ic = cast(ICKeyword, rcParams["stats.information_criterion"]) else: ic = cast(ICKeyword, ic.lower()) - assert ic in get_args(ICKeyword) + if ic not in get_args(ICKeyword): + raise ValueError(f"{ic} is not a valid value for ic: must be in {get_args(ICKeyword)}") if ic == "loo": - ic_func = loo + ic_func: Callable = loo df_comp = pd.DataFrame( index=names, columns=[ @@ -187,7 +193,7 @@ def compare( "warning", "loo_scale", ], - dtype=np.float, + dtype=np.float_, ) scale_col = "loo_scale" elif ic == "waic": @@ -205,7 +211,7 @@ def compare( "warning", "waic_scale", ], - dtype=np.float, + dtype=np.float_, ) scale_col = "waic_scale" else: @@ -1300,7 +1306,9 @@ def summary( n_vars = np.sum([joined[var].size // n_metrics for var in joined.data_vars]) if fmt.lower() == "wide": - summary_df = pd.DataFrame(np.full((n_vars, n_metrics), np.nan), columns=metric_names) + summary_df = pd.DataFrame( + (np.full(cast(Tuple[int, int], (n_vars, n_metrics)), np.nan)), columns=metric_names + ) indexs = [] for i, (var_name, sel, isel, values) in enumerate( xarray_var_iter(joined, skip_dims={"metric"}) From 6d27fd51914329c32ddbd3896cae52d952cecbc6 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Sat, 27 Mar 2021 14:00:12 -0500 Subject: [PATCH 14/18] Python 3.5 and 3.6 compatibility. --- arviz/rcparams.py | 21 ++++++++++++++++++--- arviz/stats/stats.py | 21 +++++++++++++-------- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/arviz/rcparams.py b/arviz/rcparams.py index 6c1506eb43..b22ba80f7d 100644 --- a/arviz/rcparams.py +++ b/arviz/rcparams.py @@ -9,7 +9,14 @@ from collections.abc import MutableMapping from pathlib import Path from typing import Any, Dict -from typing_extensions import Literal, get_args +from typing_extensions import Literal + +NO_GET_ARGS: bool = False +try: + from typing_extensions import get_args +except ImportError: + NO_GET_ARGS = True + import numpy as np @@ -281,9 +288,17 @@ def validate_iterable(value): "plot.matplotlib.constrained_layout": (True, _validate_boolean), "plot.matplotlib.show": (False, _validate_boolean), "stats.hdi_prob": (0.94, _validate_probability), - "stats.information_criterion": ("loo", _make_validate_choice(set(get_args(ICKeyword)))), + "stats.information_criterion": ( + "loo", + _make_validate_choice({"loo", "waic"} if NO_GET_ARGS else set(get_args(ICKeyword))), + ), "stats.ic_pointwise": (False, _validate_boolean), - "stats.ic_scale": ("log", _make_validate_choice(set(get_args(ScaleKeyword)))), + "stats.ic_scale": ( + "log", + _make_validate_choice( + {"log", "negative_log", "deviance"} if NO_GET_ARGS else set(get_args(ScaleKeyword)) + ), + ), "stats.ic_compare_method": ( "stacking", _make_validate_choice({"stacking", "bb-pseudo-bma", "pseudo-bma"}), diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index d3a7410eed..59ba7f1945 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -9,7 +9,13 @@ import scipy.stats as st import xarray as xr from scipy.optimize import minimize -from typing_extensions import Literal, get_args +from typing_extensions import Literal + +NO_GET_ARGS: bool = False +try: + from typing_extensions import get_args +except ImportError: + NO_GET_ARGS = True from arviz import _log from ..data import InferenceData, convert_to_dataset, convert_to_inference_data @@ -156,12 +162,10 @@ def compare( scale = cast(ScaleKeyword, scale.lower()) else: scale = cast(ScaleKeyword, rcParams["stats.ic_scale"]) - if scale not in get_args(ScaleKeyword): - raise ValueError( - f"{scale} is not a valid value for scale: must be in {get_args(ScaleKeyword)}" - ) + allowable = ["log", "negative_log", "deviance"] if NO_GET_ARGS else get_args(ScaleKeyword) + if scale not in allowable: + raise ValueError(f"{scale} is not a valid value for scale: must be in {allowable}") - assert scale in get_args(ScaleKeyword) if scale == "log": scale_value = 1 ascending = False @@ -176,8 +180,9 @@ def compare( ic = cast(ICKeyword, rcParams["stats.information_criterion"]) else: ic = cast(ICKeyword, ic.lower()) - if ic not in get_args(ICKeyword): - raise ValueError(f"{ic} is not a valid value for ic: must be in {get_args(ICKeyword)}") + allowable = ["loo", "waic"] if NO_GET_ARGS else get_args(ICKeyword) + if ic not in allowable: + raise ValueError(f"{ic} is not a valid value for ic: must be in {allowable}") if ic == "loo": ic_func: Callable = loo df_comp = pd.DataFrame( From 9e61c585f6cc1e81d95f09d66089f05aa5320e21 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Sat, 27 Mar 2021 20:21:14 -0500 Subject: [PATCH 15/18] Whitespace issue caught by Oriol. --- .pylintrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index 53675c1a99..fda2139e63 100644 --- a/.pylintrc +++ b/.pylintrc @@ -266,7 +266,7 @@ good-names=b, ok, sd, tr, - eta, + eta, Run, _log, _ From 1a31db273fac99098849a04c889ab42f6d8185f3 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Mon, 29 Mar 2021 12:02:51 -0500 Subject: [PATCH 16/18] Test for error-trapping. Make sure we check for multiple observed variables in compare() and that support for "var_name" works. --- arviz/tests/base_tests/test_stats.py | 31 ++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index a28be61d65..b95148b0d5 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -40,6 +40,23 @@ def non_centered_eight(): return non_centered_eight +@pytest.fixture +def multivariable_log_likelihood(centered_eight): + centered_eight = centered_eight.copy() + centered_eight.add_groups({"log_likelihood": centered_eight.sample_stats.log_likelihood}) + centered_eight.log_likelihood = centered_eight.log_likelihood.rename_vars( + {"log_likelihood": "obs"} + ) + new_arr = DataArray( + np.zeros(centered_eight.log_likelihood["obs"].values.shape), + dims=["chain", "draw", "school"], + coords=centered_eight.log_likelihood.coords, + ) + centered_eight.log_likelihood["decoy"] = new_arr + delattr(centered_eight, "sample_stats") + return centered_eight + + def test_hdp(): normal_sample = np.random.randn(5000000) interval = hdi(normal_sample) @@ -192,6 +209,20 @@ def test_compare_different_size(centered_eight, non_centered_eight): compare(model_dict, ic="waic", method="stacking") +@pytest.mark.parametrize("ic", ["loo", "waic"]) +def test_compare_multiple_obs(multivariable_log_likelihood, centered_eight, non_centered_eight, ic): + compare_dict = { + "centered_eight": centered_eight, + "non_centered_eight": non_centered_eight, + "problematic": multivariable_log_likelihood, + } + with pytest.raises(TypeError, match="several log likelihood arrays"): + get_log_likelihood(compare_dict["problematic"]) + with pytest.raises(TypeError, match="model problematic"): + compare(compare_dict, ic=ic) + assert compare(compare_dict, ic=ic, var_name="obs") is not None + + def test_summary_ndarray(): array = np.random.randn(4, 100, 2) summary_df = summary(array) From a6d14ef9c68002c711e39605d96cc895ff6c7379 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Mon, 29 Mar 2021 12:05:33 -0500 Subject: [PATCH 17/18] Don't let sample_stats shadow log_likelihood. Previously, we checked for `sample_stats` in `get_log_likelihood()` *before* reading `log_likelihood`. Add a check that `log_likelihood` must be missing before we check `sample_stats`. --- arviz/stats/stats_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/arviz/stats/stats_utils.py b/arviz/stats/stats_utils.py index 78fdb63ce3..e0f2d20ca5 100644 --- a/arviz/stats/stats_utils.py +++ b/arviz/stats/stats_utils.py @@ -412,7 +412,11 @@ def not_valid(ary, check_nan=True, check_shape=True, nan_kwargs=None, shape_kwar def get_log_likelihood(idata, var_name=None): """Retrieve the log likelihood dataarray of a given variable.""" - if hasattr(idata, "sample_stats") and hasattr(idata.sample_stats, "log_likelihood"): + if ( + not hasattr(idata, "log_likelihood") + and hasattr(idata, "sample_stats") + and hasattr(idata.sample_stats, "log_likelihood") + ): warnings.warn( "Storing the log_likelihood in sample_stats groups has been deprecated", DeprecationWarning, From 29a03ea783f1f6837d4fe39a7117c2102a972545 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Mon, 29 Mar 2021 16:56:28 -0500 Subject: [PATCH 18/18] Improvements suggested by OriolAbril. * Limit recomputation in tests by scoping a fixture. * Test for expected IC in `compare` test. * Refine type assertion. --- arviz/stats/stats.py | 2 +- arviz/tests/base_tests/test_stats.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index 59ba7f1945..f3fadb3778 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -1312,7 +1312,7 @@ def summary( if fmt.lower() == "wide": summary_df = pd.DataFrame( - (np.full(cast(Tuple[int, int], (n_vars, n_metrics)), np.nan)), columns=metric_names + (np.full((cast(int, n_vars), n_metrics), np.nan)), columns=metric_names ) indexs = [] for i, (var_name, sel, isel, values) in enumerate( diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index b95148b0d5..4f737d5fbc 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -40,7 +40,7 @@ def non_centered_eight(): return non_centered_eight -@pytest.fixture +@pytest.fixture(scope="module") def multivariable_log_likelihood(centered_eight): centered_eight = centered_eight.copy() centered_eight.add_groups({"log_likelihood": centered_eight.sample_stats.log_likelihood}) @@ -218,7 +218,7 @@ def test_compare_multiple_obs(multivariable_log_likelihood, centered_eight, non_ } with pytest.raises(TypeError, match="several log likelihood arrays"): get_log_likelihood(compare_dict["problematic"]) - with pytest.raises(TypeError, match="model problematic"): + with pytest.raises(TypeError, match=f"{ic}.*model problematic"): compare(compare_dict, ic=ic) assert compare(compare_dict, ic=ic, var_name="obs") is not None