Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/arviz-devs/arviz into devel…
Browse files Browse the repository at this point in the history
…opment/plot_dist_comparison_warning

Resolve conflicts
  • Loading branch information
alexisperakis committed May 24, 2021
2 parents bf598b9 + 7ebedd2 commit 92f9947
Show file tree
Hide file tree
Showing 52 changed files with 581 additions and 124 deletions.
3 changes: 3 additions & 0 deletions .azure-pipelines/azure-pipelines-external.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ jobs:
if [ "$(pymc3.version)" = "github" ]; then
# Pip installation is failing for some reason. This is the same thing
git clone https://github.com/pymc-devs/pymc3
cd pymc3
git checkout v3
cd ..
pip install $PWD/pymc3
# python -m pip --no-cache-dir --log log.txt install git+https://github.com/pymc-devs/pymc3
# cat log.txt
Expand Down
3 changes: 2 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ disable=missing-docstring,
not-an-iterable,
no-member,
#TODO: Remove this once todos are done
fixme
fixme,
consider-using-with


# Enable the message, report, category or checker with the given id(s). You can
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@

## v0.9.0 (2020 June 23)
### New features
* loo-pit plot. The kde is computed over the data interval (this could be shorter than [0, 1]). The HDI is computed analitically ([1215](https://github.com/arviz-devs/arviz/pull/1215))
* loo-pit plot. The kde is computed over the data interval (this could be shorter than [0, 1]). The HDI is computed analytically ([1215](https://github.com/arviz-devs/arviz/pull/1215))
* Added `html_repr` of InferenceData objects for jupyter notebooks. ([1217](https://github.com/arviz-devs/arviz/pull/1217))
* Added support for PyJAGS via the function `from_pyjags`. ([1219](https://github.com/arviz-devs/arviz/pull/1219) and [1245](https://github.com/arviz-devs/arviz/pull/1245))
* `from_pymc3` can now retrieve `coords` and `dims` from model context ([1228](https://github.com/arviz-devs/arviz/pull/1228), [1240](https://github.com/arviz-devs/arviz/pull/1240) and [1249](https://github.com/arviz-devs/arviz/pull/1249))
Expand Down
45 changes: 45 additions & 0 deletions arviz/data/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Low level converters usually used by other functions."""
import datetime
import functools
import re
import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
Expand Down Expand Up @@ -438,3 +439,47 @@ def _make_json_serializable(data: dict) -> dict:
f"Value associated with variable `{type(value)}` is not JSON serializable."
)
return ret


def infer_stan_dtypes(stan_code):
    """Infer Stan integer variables from the generated quantities block.

    Parameters
    ----------
    stan_code : str
        Source code of a Stan program.

    Returns
    -------
    dict
        Mapping ``{variable_name: "int"}`` for every variable declared as ``int``
        inside the ``generated quantities`` block. An empty dict is returned when
        the program has no such block (or the block has no opening brace).
    """
    # Strip old-style ``#`` line comments (this also drops ``#include`` lines).
    stan_code = "\n".join(line.split("#", 1)[0] for line in stan_code.splitlines())

    # Strip ``//`` and ``/* */`` comments as well as quoted string literals, so that
    # braces or keywords appearing inside them cannot confuse the parsing below.
    pattern_remove_comments = re.compile(
        r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', re.DOTALL | re.MULTILINE
    )
    stan_code = pattern_remove_comments.sub("", stan_code)

    # Locate the generated quantities block; nothing to infer without it.
    gen_quantities_location = stan_code.find("generated quantities")
    if gen_quantities_location == -1:
        return {}

    block_start = stan_code.find("{", gen_quantities_location)
    if block_start == -1:
        # Malformed program: block keyword without a body.
        return {}

    # Walk forward to the brace that closes the block, tracking nesting depth.
    # If the braces never balance, fall back to the end of the code.
    block_end = len(stan_code)
    depth = 0
    for position in range(block_start, len(stan_code)):
        char = stan_code[position]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                block_end = position + 1
                break

    block_code = stan_code[block_start:block_end]

    # ``\bint\b`` matches the ``int`` keyword only (not ``print``, ``n_int``, ...).
    # Constraints such as ``<lower=0>`` may follow directly or after whitespace,
    # so both ``int<lower=0> y`` and ``int <lower=0> y`` are recognised.
    stan_integer = r"\bint\b"
    stan_ws = r"\s*"  # 0 or more whitespace
    stan_limits = r"(?:<[^>]+>\s*)*"  # ignore group: 0 or more <....>
    stan_param = r"([^;=\s\[]+)"  # capture group: ends= ";", "=", "[" or whitespace
    pattern_int = re.compile(
        "".join((stan_integer, stan_ws, stan_limits, stan_param)), re.IGNORECASE
    )
    return {key.strip(): "int" for key in pattern_int.findall(block_code)}
2 changes: 1 addition & 1 deletion arviz/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def load_arviz_data(dataset=None, data_home=None):
Run with no parameters to get a list of all available models.
The directory to save to can also be set with the environement
The directory to save to can also be set with the environment
variable `ARVIZ_HOME`. The checksum of the dataset is checked against a
hardcoded value to watch for data corruption.
Expand Down
2 changes: 1 addition & 1 deletion arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1874,7 +1874,7 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
msg = "Mismatch between the groups."
raise TypeError(msg)
for group in arg._groups_all:
# handle data groups seperately
# handle data groups separately
if group not in ["observed_data", "constant_data", "predictions_constant_data"]:
# assert that groups are equal
if group not in arg0_groups:
Expand Down
23 changes: 18 additions & 5 deletions arviz/data/io_cmdstan.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# pylint: disable=too-many-lines
"""CmdStan-specific conversion code."""
import logging
import os
import re
from collections import defaultdict
from glob import glob
from pathlib import Path
from typing import Dict, List, Optional, Union

import numpy as np

from .. import utils
from ..rcparams import rcParams
from .base import CoordSpec, DimSpec, dict_to_dataset, requires
from .base import CoordSpec, DimSpec, dict_to_dataset, infer_stan_dtypes, requires
from .inference_data import InferenceData

_log = logging.getLogger(__name__)
Expand Down Expand Up @@ -83,9 +85,19 @@ def __init__(
self.index_origin = index_origin

if dtypes is None:
self.dtypes = {}
else:
self.dtypes = dtypes
dtypes = {}
elif isinstance(dtypes, str):
dtypes_path = Path(dtypes)
if dtypes_path.exists():
with dtypes_path.open("r") as f_obj:
model_code = f_obj.read()
else:
model_code = dtypes

dtypes = infer_stan_dtypes(model_code)

self.dtypes = dtypes

# populate posterior and sample_stats
self._parse_posterior()
self._parse_prior()
Expand Down Expand Up @@ -963,8 +975,9 @@ def from_cmdstan(
save_warmup : bool
Save warmup iterations into InferenceData object, if found in the input files.
If not defined, use default defined by the rcParams.
dtypes : dict
dtypes : dict or str
A dictionary containing dtype information (int, float) for parameters.
If input is a string, it is assumed to be a model code or path to model code file.
Returns
-------
Expand Down
Loading

0 comments on commit 92f9947

Please sign in to comment.