Skip to content

Commit

Permalink
bugfixes to beta, stretching and squeezing removed where unnecessary
Browse files Browse the repository at this point in the history
  • Loading branch information
mbi6245 committed Dec 10, 2024
1 parent b514bf0 commit 44009ba
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 29 deletions.
27 changes: 7 additions & 20 deletions plots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,20 @@
"cells": [
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['Unnamed: 0', 'nid', 'location_id', 'year_id', 'sex_id', 'age_cat'], dtype='object')\n",
"(303, 4)\n",
"Index(['Unnamed: 0', 'age_group_id', 'location_id', 'sex_id', 'year_id',\n",
" 'measure_id', 'metric_id', 'stgpr_model_version_id',\n",
" 'modelable_entity_id', 'draw_0',\n",
"Index(['Unnamed: 0', 'age_group_id', 'draw_0', 'draw_1', 'draw_10', 'draw_100',\n",
" 'draw_101', 'draw_102', 'draw_103', 'draw_104',\n",
" ...\n",
" 'draw_242', 'draw_243', 'draw_244', 'draw_245', 'draw_246', 'draw_247',\n",
" 'draw_248', 'draw_249', 'age_group_alternative_name',\n",
" 'age_group_years_start'],\n",
" dtype='object', length=261)\n",
" location_id year_id sex_id age_group_years_start\n",
"0 522 1994 1 25\n",
"1 522 1994 1 30\n",
"2 522 1994 1 35\n",
"3 522 1994 1 40\n",
"4 522 1994 1 45\n",
"['25-29 years' '30-34 years' '35-39 years' '40-44 years' '45-49 years'\n",
" '50-54 years' '55-59 years' '60-64 years' '65-69 years' '70-74 years'\n",
" '75-79 years' '80-84 years' '85-89 years' '90-94 years' '95+ years']\n",
"(7200, 4)\n"
" 'draw_99', 'location_id', 'measure_id', 'metric_id',\n",
" 'modelable_entity_id', 'sex_id', 'year_id', 'model_version_id',\n",
" 'age_group_alternative_name', 'age_group_years_start'],\n",
" dtype='object', length=261)\n"
]
}
],
Expand Down
19 changes: 13 additions & 6 deletions src/ensemble/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ class Fisk(Distribution):

def support(self) -> Tuple[float, float]:
return (0, np.inf)
# when a user passes in a different finite bound (i.e. the lower bound, lb),
# fit the distribution with a translation of the mean only; leave the variance unchanged, because scaling doesn't make sense

def _create_scipy_dist(self):
positive_support(self.mean)
Expand Down Expand Up @@ -311,8 +313,12 @@ def support(self) -> Tuple[float, float]:
return (self.lb, self.ub)

def _create_scipy_dist(self) -> None:
# TODO: PUT THE WARNINGS HERE
# FIX THE MEAN, AND THEN DERIVE A FUNCTION IN TERMS OF ALPHA, THEN WARN
if self.mean**2 <= self.variance:
raise ValueError(
"beta distributions do not exist for certain mean and variance "
+ "combinations. The supplied variance must be in between "
+ "(0, mean^2)"
)
beta_bounds(self.mean)
if self.lb != 0 and self.ub != 1:
mean = (self.mean - self.lb) / self.width
Expand All @@ -325,6 +331,7 @@ def _create_scipy_dist(self) -> None:
beta = (1 - mean) * (mean - mean**2 - var) / var
print(alpha, beta)
self._scipy_dist = stats.beta(a=alpha, b=beta)
print(self._scipy_dist.stats("mv"))

def rvs(self, *args, **kwds):
"""defaults to scipy implementation for generating random variates
Expand All @@ -349,7 +356,7 @@ def pdf(self, x: npt.ArrayLike) -> np.ndarray:
np.ndarray
PDF evaluated at quantile x
"""
return self._stretch(self._scipy_dist.pdf(self._squeeze(x)))
return self._scipy_dist.pdf(self._squeeze(x))

def cdf(self, q: npt.ArrayLike) -> np.ndarray:
"""defaults to scipy implementation for cumulative density function
Expand All @@ -364,7 +371,7 @@ def cdf(self, q: npt.ArrayLike) -> np.ndarray:
np.ndarray
CDF evaluated at quantile q
"""
return self._stretch(self._scipy_dist.cdf(self._squeeze(q)))
return self._scipy_dist.cdf(self._squeeze(q))

def ppf(self, p: npt.ArrayLike) -> np.ndarray:
"""defaults to scipy implementation for percent point function
Expand Down Expand Up @@ -396,9 +403,9 @@ def stats(self, moments: str) -> Union[float, Tuple[float, ...]]:
"""
res_list = []
if "m" in moments:
res_list.append(self._stretch(self.mean))
res_list.append(self._stretch(self._scipy_dist.stats("m")))
if "v" in moments:
res_list.append(self.variance * self.width)
res_list.append(self._scipy_dist.stats("v") * self.width)

# res_list = [res[()] for res in res_list]
if len(res_list) == 1:
Expand Down
19 changes: 18 additions & 1 deletion src/ensemble/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
import warnings
from typing import List, Tuple, Union

import cvxpy as cp
Expand Down Expand Up @@ -38,6 +39,8 @@ def __init__(
named_weights: dict,
mean: float,
variance: float,
lb: float = None,
ub: float = None,
):
self._distributions = list(named_weights.keys())
self._weights = list(named_weights.values())
Expand All @@ -53,6 +56,18 @@ def __init__(
)
self.mean = mean
self.variance = variance
if lb is not None and self.cdf(lb) > 0.05:
warnings.warn(
"Ensemble density less than the specified lower bound "
+ lb
+ " exceeds 0.05. Check for low sample size!"
)
if ub is not None and (1 - self.cdf(ub)) > 0.05:
warnings.warn(
"Ensemble density greater than the specified upper bound "
+ ub
+ " exceeds 0.05. Check for low sample size!"
)

def _ppf_to_solve(self, x: float, p: float) -> float:
"""ensemble_CDF(x) - lower tail probability
Expand Down Expand Up @@ -145,13 +160,15 @@ def cdf(self, q: npt.ArrayLike) -> np.ndarray:
)
)

def ppf(self, p: npt.ArrayLike) -> np.ndarray:
def ppf(self, p: npt.ArrayLike, uncertainty: bool = True) -> np.ndarray:
"""percent point function of ensemble distribution
Parameters
----------
p : npt.ArrayLike
lower tail probability
uncertainty : bool, optional
return a 95% CI using the delta method about p
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def test_normal():

def test_beta():
beta = Beta(BETA_MEAN, BETA_VARIANCE)
# beta = Beta(0.5, 0.249)
res = beta.stats(moments="mv")
print("resulting mean and var: ", res)
assert np.isclose(res[0], BETA_MEAN)
assert np.isclose(res[1], BETA_VARIANCE)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_read.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
[{"named_weights": {"normal": 0.5, "gumbel": 0.5}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}]
[{"named_weights": {"normal": 0.5, "gumbel": 0.5}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}, {"named_weights": {"gamma": 0.2, "invgamma": 0.8}, "mean": 1, "variance": 1}]

0 comments on commit 44009ba

Please sign in to comment.