Skip to content

Commit

Permalink
cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
sbidari committed Feb 19, 2025
1 parent 98b24c3 commit 5c47273
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 54 deletions.
1 change: 0 additions & 1 deletion pipelines/forecast_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,6 @@ def main(
last_training_date=last_training_date,
param_estimates=param_estimates,
model_run_dir=model_run_dir,
ww_data_dir=ww_data_dir,
logger=logger,
credentials_dict=credentials_dict,
)
Expand Down
46 changes: 33 additions & 13 deletions pipelines/prep_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from pathlib import Path

import forecasttools
import jax.numpy as jnp
import jax.numpy as jnp
import polars as pl
import polars.selectors as cs

Expand Down Expand Up @@ -48,7 +46,9 @@ def py_scalar_to_r_scalar(py_scalar):
state_abb_for_query = state_abb if state_abb != "US" else "USA"

temp_file = Path(temp_dir, "nhsn_temp.parquet")
api_key_id = credentials_dict.get("nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID"))
api_key_id = credentials_dict.get(
"nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID")
)
api_key_secret = credentials_dict.get(
"nhsn_api_key_secret", os.getenv("NHSN_API_KEY_SECRET")
)
Expand Down Expand Up @@ -80,7 +80,9 @@ def py_scalar_to_r_scalar(py_scalar):
if result.returncode != 0:
raise RuntimeError(f"pull_and_save_nhsn: {result.stderr}")
raw_dat = pl.read_parquet(temp_file)
dat = raw_dat.with_columns(weekendingdate=pl.col("weekendingdate").cast(pl.Date))
dat = raw_dat.with_columns(
weekendingdate=pl.col("weekendingdate").cast(pl.Date)
)
return dat


Expand All @@ -102,7 +104,9 @@ def combine_nssp_and_nhsn(
variable_name="drop_me",
value_name=".value",
)
.with_columns(pl.col("count_type").replace(count_type_dict).alias(".variable"))
.with_columns(
pl.col("count_type").replace(count_type_dict).alias(".variable")
)
.select(cs.exclude(["count_type", "drop_me"]))
)

Expand Down Expand Up @@ -182,7 +186,9 @@ def process_state_level_data(

if state_abb == "US":
locations_to_aggregate = (
state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique()
state_pop_df.filter(pl.col("abb") != "US")
.get_column("abb")
.unique()
)
logger.info("Aggregating state-level data to national")
state_level_nssp_data = aggregate_to_national(
Expand All @@ -209,7 +215,9 @@ def process_state_level_data(
]
)
.with_columns(
disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map),
disease=pl.col("disease")
.cast(pl.Utf8)
.replace(_inverse_disease_map),
)
.sort(["date", "disease"])
.collect(streaming=True)
Expand Down Expand Up @@ -241,7 +249,9 @@ def aggregate_facility_level_nssp_to_state(
if state_abb == "US":
logger.info("Aggregating facility-level data to national")
locations_to_aggregate = (
state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique()
state_pop_df.filter(pl.col("abb") != "US")
.get_column("abb")
.unique()
)
facility_level_nssp_data = aggregate_to_national(
facility_level_nssp_data,
Expand All @@ -260,7 +270,9 @@ def aggregate_facility_level_nssp_to_state(
.group_by(["reference_date", "disease"])
.agg(pl.col("value").sum().alias("ed_visits"))
.with_columns(
disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map),
disease=pl.col("disease")
.cast(pl.Utf8)
.replace(_inverse_disease_map),
geo_value=pl.lit(state_abb).cast(pl.Utf8),
)
.rename({"reference_date": "date"})
Expand Down Expand Up @@ -350,12 +362,16 @@ def process_and_save_state(

if facility_level_nssp_data is None and state_level_nssp_data is None:
raise ValueError(
"Must provide at least one " "of facility-level and state-level" "NSSP data"
"Must provide at least one "
"of facility-level and state-level"
"NSSP data"
)

state_pop_df = get_state_pop_df()

state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item(0, "population")
state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item(
0, "population"
)

(generation_interval_pmf, delay_pmf, right_truncation_pmf) = get_pmfs(
param_estimates=param_estimates, state_abb=state_abb, disease=disease
Expand Down Expand Up @@ -404,13 +420,17 @@ def process_and_save_state(
credentials_dict=credentials_dict,
).with_columns(pl.lit("train").alias("data_type"))

nssp_training_dates = nssp_training_data.get_column("date").unique().to_list()
nssp_training_dates = (
nssp_training_data.get_column("date").unique().to_list()
)
nhsn_training_dates = (
nhsn_training_data.get_column("weekendingdate").unique().to_list()
)

nhsn_first_date_index = next(
i for i, x in enumerate(nssp_training_dates) if x == min(nhsn_training_dates)
i
for i, x in enumerate(nssp_training_dates)
if x == min(nhsn_training_dates)
)
nhsn_step_size = 7

Expand Down
3 changes: 0 additions & 3 deletions pipelines/prep_ww_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import datetime
from pathlib import Path

import polars as pl


Expand Down
38 changes: 24 additions & 14 deletions pyrenew_hew/pyrenew_hew_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def __init__(
self.first_ed_visits_date = first_ed_visits_date
self.first_hospital_admissions_date = first_hospital_admissions_date
self.first_wastewater_date_ = first_wastewater_date
self.data_observed_disease_wastewater = data_observed_disease_wastewater
self.data_observed_disease_wastewater = (
data_observed_disease_wastewater
)
self.population_size = population_size
self.shedding_offset = shedding_offset
self.pop_fraction_ = pop_fraction
Expand Down Expand Up @@ -132,7 +134,9 @@ def last_data_date_overall(self):

@property
def n_days_post_init(self):
return (self.last_data_date_overall - self.first_data_date_overall).days
return (
self.last_data_date_overall - self.first_data_date_overall
).days

@property
def site_subpop_spine(self):
Expand Down Expand Up @@ -173,10 +177,12 @@ def site_subpop_spine(self):
site_subpop_spine = (
pl.concat([aux_subpop, site_indices], how="vertical_relaxed")
.with_columns(
subpop_index=pl.col("site_index").cum_count().alias("subpop_index"),
subpop_name=pl.format("Site: {}", pl.col("site")).fill_null(
"remainder of population"
),
subpop_index=pl.col("site_index")
.cum_count()
.alias("subpop_index"),
subpop_name=pl.format(
"Site: {}", pl.col("site")
).fill_null("remainder of population"),
)
.rename({"site_pop": "subpop_pop"})
)
Expand Down Expand Up @@ -233,22 +239,24 @@ def pop_fraction(self):
@property
def data_observed_disease_wastewater_conc(self):
if self.data_observed_disease_wastewater is not None:
return self.wastewater_data_extended["log_genome_copies_per_ml"].to_numpy()
return self.wastewater_data_extended[
"log_genome_copies_per_ml"
].to_numpy()

@property
def ww_censored(self):
if self.data_observed_disease_wastewater is not None:
return self.wastewater_data_extended.filter(pl.col("below_lod") == 1)[
"ind_rel_to_observed_times"
].to_numpy()
return self.wastewater_data_extended.filter(
pl.col("below_lod") == 1
)["ind_rel_to_observed_times"].to_numpy()
return None

@property
def ww_uncensored(self):
if self.data_observed_disease_wastewater is not None:
return self.wastewater_data_extended.filter(pl.col("below_lod") == 0)[
"ind_rel_to_observed_times"
].to_numpy()
return self.wastewater_data_extended.filter(
pl.col("below_lod") == 0
)["ind_rel_to_observed_times"].to_numpy()

@property
def ww_observed_times(self):
Expand Down Expand Up @@ -304,7 +312,9 @@ def get_end_date(
)
result = None
else:
result = first_date + datetime.timedelta(days=n_datapoints * timestep_days)
result = first_date + datetime.timedelta(
days=n_datapoints * timestep_days
)
return result

def get_n_data_days(
Expand Down
67 changes: 47 additions & 20 deletions pyrenew_hew/pyrenew_hew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def __init__(
self.autoreg_rt_subpop_rv = autoreg_rt_subpop_rv
self.sigma_rt_rv = sigma_rt_rv
self.sigma_i_first_obs_rv = sigma_i_first_obs_rv
self.sigma_initial_exp_growth_rate_rv = sigma_initial_exp_growth_rate_rv
self.sigma_initial_exp_growth_rate_rv = (
sigma_initial_exp_growth_rate_rv
)
self.n_initialization_points = n_initialization_points
self.pop_fraction = pop_fraction
self.n_subpops = len(pop_fraction)
Expand Down Expand Up @@ -122,10 +124,13 @@ def sample(self, n_days_post_init: int):
+ self.offset_ref_logit_i_first_obs_rv(),
)
initial_exp_growth_rate_ref_subpop = (
initial_exp_growth_rate + self.offset_ref_initial_exp_growth_rate_rv()
initial_exp_growth_rate
+ self.offset_ref_initial_exp_growth_rate_rv()
)

log_rtu_weekly_ref_subpop = log_rtu_weekly + self.offset_ref_log_rt_rv()
log_rtu_weekly_ref_subpop = (
log_rtu_weekly + self.offset_ref_log_rt_rv()
)
i_first_obs_over_n_non_ref_subpop_rv = TransformedVariable(
"i_first_obs_over_n_non_ref_subpop",
DistributionalVariable(
Expand Down Expand Up @@ -207,7 +212,9 @@ def sample(self, n_days_post_init: int):
)[:n_days_post_init, :]
) # indexed rel to first post-init day.

i0_subpop_rv = DeterministicVariable("i0_subpop", i_first_obs_over_n_subpop)
i0_subpop_rv = DeterministicVariable(
"i0_subpop", i_first_obs_over_n_subpop
)
initial_exp_growth_rate_subpop_rv = DeterministicVariable(
"initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop
)
Expand Down Expand Up @@ -312,9 +319,7 @@ def sample(
iedr = jnp.repeat(
transformation.SigmoidTransform()(p_ed_ar + p_ed_mean),
repeats=7,
)[
:n_datapoints
] # indexed rel to first ed report day
)[:n_datapoints] # indexed rel to first ed report day
# this is only applied after the ed visits are generated, not to all
# the latent infections. This is why we cannot apply the iedr in
# compute_delay_ascertained_incidence
Expand All @@ -334,21 +339,28 @@ def sample(
)[-n_datapoints:]

latent_ed_visits_final = (
potential_latent_ed_visits * iedr * ed_wday_effect * population_size
potential_latent_ed_visits
* iedr
* ed_wday_effect
* population_size
)

if right_truncation_offset is not None:
prop_already_reported_tail = jnp.flip(
self.ed_right_truncation_cdf_rv()[right_truncation_offset:]
)
n_points_to_prepend = n_datapoints - prop_already_reported_tail.shape[0]
n_points_to_prepend = (
n_datapoints - prop_already_reported_tail.shape[0]
)
prop_already_reported = jnp.pad(
prop_already_reported_tail,
(n_points_to_prepend, 0),
mode="constant",
constant_values=(1, 0),
)
latent_ed_visits_now = latent_ed_visits_final * prop_already_reported
latent_ed_visits_now = (
latent_ed_visits_final * prop_already_reported
)
else:
latent_ed_visits_now = latent_ed_visits_final

Expand All @@ -374,7 +386,9 @@ def __init__(
ihr_rel_iedr_rv: RandomVariable = None,
) -> None:
self.inf_to_hosp_admit_rv = inf_to_hosp_admit_rv
self.hosp_admit_neg_bin_concentration_rv = hosp_admit_neg_bin_concentration_rv
self.hosp_admit_neg_bin_concentration_rv = (
hosp_admit_neg_bin_concentration_rv
)
self.ihr_rv = ihr_rv
self.ihr_rel_iedr_rv = ihr_rel_iedr_rv

Expand Down Expand Up @@ -508,7 +522,10 @@ def normed_shedding_cdf(
norm_const = (t_p + t_d) * ((log_base - 1) / jnp.log(log_base) - 1)

def ad_pre(x):
return t_p / jnp.log(log_base) * jnp.exp(jnp.log(log_base) * x / t_p) - x
return (
t_p / jnp.log(log_base) * jnp.exp(jnp.log(log_base) * x / t_p)
- x
)

def ad_post(x):
return (
Expand Down Expand Up @@ -588,18 +605,22 @@ def sample(
def batch_colvolve_fn(m):
return jnp.convolve(m, viral_kinetics, mode="valid")

model_net_inf_ind_shedding = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)(
jnp.atleast_2d(latent_infections_subpop)
)[-n_datapoints:, :]
numpyro.deterministic("model_net_inf_ind_shedding", model_net_inf_ind_shedding)
model_net_inf_ind_shedding = jax.vmap(
batch_colvolve_fn, in_axes=1, out_axes=1
)(jnp.atleast_2d(latent_infections_subpop))[-n_datapoints:, :]
numpyro.deterministic(
"model_net_inf_ind_shedding", model_net_inf_ind_shedding
)

log10_genome_per_inf_ind = self.log10_genome_per_inf_ind_rv()
expected_obs_viral_genomes = (
jnp.log(10) * log10_genome_per_inf_ind
+ jnp.log(model_net_inf_ind_shedding + shedding_offset)
- jnp.log(self.ww_ml_produced_per_day)
)
numpyro.deterministic("expected_obs_viral_genomes", expected_obs_viral_genomes)
numpyro.deterministic(
"expected_obs_viral_genomes", expected_obs_viral_genomes
)

mode_sigma_ww_site = self.mode_sigma_ww_site_rv()
sd_log_sigma_ww_site = self.sd_log_sigma_ww_site_rv()
Expand Down Expand Up @@ -638,7 +659,11 @@ def batch_colvolve_fn(m):
scale=sigma_ww_site[ww_observed_lab_sites[ww_uncensored]],
),
).sample(
obs=(data_observed[ww_uncensored] if data_observed is not None else None),
obs=(
data_observed[ww_uncensored]
if data_observed is not None
else None
),
)

if ww_censored.shape[0] != 0:
Expand Down Expand Up @@ -695,8 +720,10 @@ def sample(
sample_wastewater: bool = False,
) -> dict[str, ArrayLike]: # numpydoc ignore=GL08
n_init_days = self.latent_infection_process_rv.n_initialization_points
latent_infections, latent_infections_subpop = self.latent_infection_process_rv(
n_days_post_init=data.n_days_post_init,
latent_infections, latent_infections_subpop = (
self.latent_infection_process_rv(
n_days_post_init=data.n_days_post_init,
)
)
first_latent_infection_dow = (
data.first_data_date_overall - datetime.timedelta(days=n_init_days)
Expand Down
Loading

0 comments on commit 5c47273

Please sign in to comment.