Skip to content

Commit

Permalink
Python 3.5 and 3.6 compatibility.
Browse files Browse the repository at this point in the history
  • Loading branch information
rpgoldman committed Mar 27, 2021
1 parent 7eb162c commit 6d27fd5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 11 deletions.
21 changes: 18 additions & 3 deletions arviz/rcparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@
from collections.abc import MutableMapping
from pathlib import Path
from typing import Any, Dict
from typing_extensions import Literal, get_args
from typing_extensions import Literal

NO_GET_ARGS: bool = False
try:
from typing_extensions import get_args
except ImportError:
NO_GET_ARGS = True


import numpy as np

Expand Down Expand Up @@ -281,9 +288,17 @@ def validate_iterable(value):
"plot.matplotlib.constrained_layout": (True, _validate_boolean),
"plot.matplotlib.show": (False, _validate_boolean),
"stats.hdi_prob": (0.94, _validate_probability),
"stats.information_criterion": ("loo", _make_validate_choice(set(get_args(ICKeyword)))),
"stats.information_criterion": (
"loo",
_make_validate_choice({"loo", "waic"} if NO_GET_ARGS else set(get_args(ICKeyword))),
),
"stats.ic_pointwise": (False, _validate_boolean),
"stats.ic_scale": ("log", _make_validate_choice(set(get_args(ScaleKeyword)))),
"stats.ic_scale": (
"log",
_make_validate_choice(
{"log", "negative_log", "deviance"} if NO_GET_ARGS else set(get_args(ScaleKeyword))
),
),
"stats.ic_compare_method": (
"stacking",
_make_validate_choice({"stacking", "bb-pseudo-bma", "pseudo-bma"}),
Expand Down
21 changes: 13 additions & 8 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
import scipy.stats as st
import xarray as xr
from scipy.optimize import minimize
from typing_extensions import Literal, get_args
from typing_extensions import Literal

NO_GET_ARGS: bool = False
try:
from typing_extensions import get_args
except ImportError:
NO_GET_ARGS = True

from arviz import _log
from ..data import InferenceData, convert_to_dataset, convert_to_inference_data
Expand Down Expand Up @@ -156,12 +162,10 @@ def compare(
scale = cast(ScaleKeyword, scale.lower())
else:
scale = cast(ScaleKeyword, rcParams["stats.ic_scale"])
if scale not in get_args(ScaleKeyword):
raise ValueError(
f"{scale} is not a valid value for scale: must be in {get_args(ScaleKeyword)}"
)
allowable = ["log", "negative_log", "deviance"] if NO_GET_ARGS else get_args(ScaleKeyword)
if scale not in allowable:
raise ValueError(f"{scale} is not a valid value for scale: must be in {allowable}")

assert scale in get_args(ScaleKeyword)
if scale == "log":
scale_value = 1
ascending = False
Expand All @@ -176,8 +180,9 @@ def compare(
ic = cast(ICKeyword, rcParams["stats.information_criterion"])
else:
ic = cast(ICKeyword, ic.lower())
if ic not in get_args(ICKeyword):
raise ValueError(f"{ic} is not a valid value for ic: must be in {get_args(ICKeyword)}")
allowable = ["loo", "waic"] if NO_GET_ARGS else get_args(ICKeyword)
if ic not in allowable:
raise ValueError(f"{ic} is not a valid value for ic: must be in {allowable}")
if ic == "loo":
ic_func: Callable = loo
df_comp = pd.DataFrame(
Expand Down

0 comments on commit 6d27fd5

Please sign in to comment.