update end to end test (#353)
sbidari authored Mar 3, 2025
1 parent 128ae86 commit 2f0de0e
Showing 14 changed files with 605 additions and 296 deletions.
38 changes: 32 additions & 6 deletions pipelines/build_pyrenew_model.py
@@ -15,6 +15,7 @@
PyrenewHEWModel,
WastewaterObservationProcess,
)
from pyrenew_hew.pyrenew_wastewater_data import PyrenewWastewaterData


def build_model_from_dir(
@@ -92,13 +93,19 @@ def build_model_from_dir(
data_observed_disease_wastewater = (
pl.DataFrame(
model_data["data_observed_disease_wastewater"],
schema_overrides={"date": pl.Date},
schema_overrides={
"date": pl.Date,
"lab_index": pl.Int64,
"site_index": pl.Int64,
},
)
if fit_wastewater
else None
)

population_size = jnp.array(model_data["state_pop"])
population_size = jnp.array(model_data["state_pop"]).item()

pop_fraction = jnp.array(model_data["pop_fraction"])

ed_right_truncation_pmf_rv = DeterministicVariable(
"right_truncation_pmf", jnp.array(model_data["right_truncation_pmf"])
@@ -114,10 +121,10 @@

first_ed_visits_date = datetime.datetime.strptime(
model_data["nssp_training_dates"][0], "%Y-%m-%d"
)
).date()
first_hospital_admissions_date = datetime.datetime.strptime(
model_data["nhsn_training_dates"][0], "%Y-%m-%d"
)
).date()

priors = runpy.run_path(str(prior_path))

@@ -133,6 +140,20 @@
infection_feedback_strength_rv=priors["inf_feedback_strength_rv"],
infection_feedback_pmf_rv=infection_feedback_pmf_rv,
n_initialization_points=n_initialization_points,
pop_fraction=pop_fraction,
autoreg_rt_subpop_rv=priors["autoreg_rt_subpop_rv"],
sigma_rt_rv=priors["sigma_rt_rv"],
sigma_i_first_obs_rv=priors["sigma_i_first_obs_rv"],
sigma_initial_exp_growth_rate_rv=priors[
"sigma_initial_exp_growth_rate_rv"
],
offset_ref_logit_i_first_obs_rv=priors[
"offset_ref_logit_i_first_obs_rv"
],
offset_ref_initial_exp_growth_rate_rv=priors[
"offset_ref_initial_exp_growth_rate_rv"
],
offset_ref_log_rt_rv=priors["offset_ref_log_rt_rv"],
)

ed_visit_obs_rv = EDVisitObservationProcess(
@@ -175,16 +196,21 @@ def build_model_from_dir(
wastewater_obs_process_rv=wastewater_obs_rv,
)

wastewater_data = PyrenewWastewaterData(
data_observed_disease_wastewater=data_observed_disease_wastewater,
population_size=population_size,
)

dat = PyrenewHEWData(
data_observed_disease_ed_visits=data_observed_disease_ed_visits,
data_observed_disease_hospital_admissions=(
data_observed_disease_hospital_admissions
),
data_observed_disease_wastewater=data_observed_disease_wastewater,
right_truncation_offset=right_truncation_offset,
first_ed_visits_date=first_ed_visits_date,
first_hospital_admissions_date=first_hospital_admissions_date,
population_size=population_size,
pop_fraction=pop_fraction,
**wastewater_data.to_pyrenew_hew_data_args(),
)

return (mod, dat)
43 changes: 25 additions & 18 deletions pipelines/forecast_state.py
@@ -335,22 +335,28 @@ def main(
"No data available for the requested report date " f"{report_date}"
)

available_nwss_reports = get_available_reports(nwss_data_dir)
# assuming NWSS_vintage directory follows naming convention
# using as-of date;
# needs to be modified otherwise

if report_date in available_nwss_reports:
nwss_data_raw = pl.scan_parquet(
Path(nwss_data_dir, f"{report_date}.parquet")
)
nwss_data_cleaned = clean_nwss_data(nwss_data_raw).filter(
(pl.col("location") == state)
& (pl.col("date") >= first_training_date)
)
state_level_nwss_data = preprocess_ww_data(nwss_data_cleaned)
if fit_wastewater:
available_nwss_reports = get_available_reports(nwss_data_dir)
# assuming NWSS_vintage directory follows naming convention
# using as-of date
if report_date in available_nwss_reports:
nwss_data_raw = pl.scan_parquet(
Path(nwss_data_dir, f"{report_date}.parquet")
)
nwss_data_cleaned = clean_nwss_data(nwss_data_raw).filter(
(pl.col("location") == state)
& (pl.col("date") >= first_training_date)
)
state_level_nwss_data = preprocess_ww_data(
nwss_data_cleaned.collect()
)
else:
raise ValueError(
"NWSS data not available for the requested report date "
f"{report_date}"
)
else:
state_level_nwss_data = None ## TO DO: change
state_level_nwss_data = None

param_estimates = pl.scan_parquet(Path(param_data_dir, "prod.parquet"))
model_batch_dir_name = (
@@ -453,9 +459,10 @@ def main(
)
logger.info("Postprocessing complete.")

logger.info("Rendering webpage...")
render_diagnostic_report(model_run_dir)
logger.info("Rendering complete.")
if pyrenew_model_name == "pyrenew_e":
logger.info("Rendering webpage...")
render_diagnostic_report(model_run_dir)
logger.info("Rendering complete.")

if score:
logger.info("Scoring forecast...")
59 changes: 59 additions & 0 deletions pipelines/generate_test_data.R
@@ -304,6 +304,62 @@ generate_fake_param_data <-
)
}


#' Generate Fake NWSS Data
#'
#' This function generates fake NWSS wastewater data for a set of states
#' and saves it as a parquet file.

generate_fake_nwss_data <- function(
private_data_dir = fs::path(getwd()),
states_to_generate = c("MT", "CA"),
start_reference = as.Date("2024-06-01"),
end_reference = as.Date("2024-12-21"),
site = list(
CA = c(1, 2, 3, 4),
MT = c(5, 6, 7, 8)
),
lab = list(
CA = c(1, 1, 2, 2),
MT = c(3, 3, 4, 4)
),
lod = c(20, 31, 20, 30),
site_pop = list(
CA = c(4e6, 2e6, 1e6, 5e5),
MT = c(3e5, 2e5, 1e5, 5e4)
)) {
ww_dir <- fs::path(private_data_dir, "nwss_vintages")
fs::dir_create(ww_dir, recurse = TRUE)

site_info <- function(state) {
tibble::tibble(
wwtp_id = site[[state]],
lab_id = lab[[state]],
lod_sewage = lod,
population_served = site_pop[[state]],
sample_location = "wwtp",
sample_matrix = "raw wastewater",
pcr_target_units = "copies/l wastewater",
pcr_target = "sars-cov-2",
quality_flag = c("no", NA_character_, "n", "n"),
wwtp_jurisdiction = state
)
}

ww_data <- purrr::map_dfr(states_to_generate, site_info) |>
tidyr::expand_grid(
sample_collect_date = seq(start_reference, end_reference, by = "week")
) |>
dplyr::mutate(
pcr_target_avg_conc = abs(rnorm(dplyr::n(), mean = 500, sd = 50))
)

arrow::write_parquet(
ww_data, fs::path(ww_dir, paste0(end_reference, ".parquet"))
)
}


main <- function(private_data_dir,
target_diseases,
n_forecast_days) {
@@ -335,6 +391,9 @@ main <- function(private_data_dir,
states_to_generate = c("MT", "CA", "US"),
target_diseases = short_target_diseases
)
generate_fake_nwss_data(
private_data_dir
)
}

p <- arg_parser("Create simulated epiweekly data.") |>
20 changes: 20 additions & 0 deletions pipelines/prep_data.py
@@ -457,6 +457,25 @@ def process_and_save_state(
else None
)

if state_level_nwss_data is None:
pop_fraction = jnp.array([1])
else:
subpop_sizes = (
state_level_nwss_data.select(["site_index", "site", "site_pop"])
.unique()
.get_column("site_pop")
.to_numpy()
)
if state_pop > sum(subpop_sizes):
pop_fraction = (
jnp.concatenate(
(jnp.array([state_pop - sum(subpop_sizes)]), subpop_sizes)
)
/ state_pop
)
else:
pop_fraction = subpop_sizes / state_pop

data_for_model_fit = {
"inf_to_ed_pmf": delay_pmf,
"generation_interval_pmf": generation_interval_pmf,
@@ -471,6 +490,7 @@
"state_pop": state_pop,
"right_truncation_offset": right_truncation_offset,
"data_observed_disease_wastewater": data_observed_disease_wastewater,
"pop_fraction": pop_fraction.tolist(),
}

data_dir = Path(model_run_dir, "data")
3 changes: 0 additions & 3 deletions pipelines/prep_ww_data.py
@@ -1,6 +1,3 @@
import datetime
from pathlib import Path

import polars as pl


41 changes: 41 additions & 0 deletions pipelines/priors/prod_priors.py
@@ -117,6 +117,47 @@
"mode_sd_ww_site", dist.TruncatedNormal(0, 0.25, low=0)
)

autoreg_rt_subpop_rv = DistributionalVariable(
"autoreg_rt_subpop", dist.Beta(1, 4)
)
sigma_rt_rv = DistributionalVariable(
"sigma_rt", dist.TruncatedNormal(0, 0.1, low=0)
)

sigma_i_first_obs_rv = DistributionalVariable(
"sigma_i_first_obs",
dist.TruncatedNormal(0, 0.5, low=0),
)

sigma_initial_exp_growth_rate_rv = DistributionalVariable(
"sigma_initial_exp_growth_rate",
dist.TruncatedNormal(
0,
0.05,
low=0,
),
)

offset_ref_logit_i_first_obs_rv = DistributionalVariable(
"offset_ref_logit_i_first_obs",
dist.Normal(0, 0.25),
)

offset_ref_initial_exp_growth_rate_rv = DistributionalVariable(
"offset_ref_initial_exp_growth_rate",
dist.TruncatedNormal(
0,
0.025,
low=-0.01,
high=0.01,
),
)

offset_ref_log_rt_rv = DistributionalVariable(
"offset_ref_log_r_t",
dist.Normal(0, 0.2),
)

# model constants related to wastewater obs process
ww_ml_produced_per_day = 227000
max_shed_interval = 26
4 changes: 0 additions & 4 deletions pipelines/tests/create_more_model_test_data.R
@@ -13,10 +13,6 @@ try(source("pipelines/plot_and_save_state_forecast.R"))
try(source("pipelines/score_forecast.R"))

model_batch_dirs <- c(
path(
"pipelines/tests/private_data",
"covid-19_r_2024-12-21_f_2024-10-22_t_2024-12-20"
),
path(
"pipelines/tests/private_data",
"influenza_r_2024-12-21_f_2024-10-22_t_2024-12-20"
20 changes: 13 additions & 7 deletions pipelines/tests/test_build_pyrenew_model.py
@@ -13,7 +13,7 @@ def mock_data():
{
"data_observed_disease_ed_visits": [1, 2, 3],
"data_observed_disease_hospital_admissions": [4, 5, 6],
"state_pop": [7, 8, 9],
"state_pop": [10000],
"generation_interval_pmf": [0.1, 0.2, 0.7],
"inf_to_ed_pmf": [0.4, 0.5, 0.1],
"inf_to_hosp_admit_pmf": [0.0, 0.7, 0.1, 0.1, 0.1],
@@ -32,13 +32,14 @@ def mock_data():
],
"site": ["1.0", "1.0", "2.0", "2.0"],
"lab": ["1.0", "1.0", "1.0", "1.0"],
"site_pop": [4000000, 4000000, 2000000, 2000000],
"site_pop": [4000, 4000, 2000, 2000],
"site_index": [1, 1, 0, 0],
"lab_site_index": [1, 1, 0, 0],
"log_genome_copies_per_ml": [0.1, 0.1, 0.5, 0.4],
"log_lod": [1.1, 2.0, 1.5, 2.1],
"below_lod": [False, False, False, False],
},
"pop_fraction": [0.4, 0.4, 0.2],
}
)

@@ -72,6 +73,14 @@ def mock_priors():
mode_sd_ww_site_rv = None
max_shed_interval = None
ww_ml_produced_per_day = None
pop_fraction = None
autoreg_rt_subpop_rv = None
sigma_rt_rv = None
sigma_i_first_obs_rv = None
sigma_initial_exp_growth_rate_rv = None
offset_ref_logit_i_first_obs_rv = None
offset_ref_initial_exp_growth_rate_rv = None
offset_ref_log_rt_rv = None
"""


@@ -91,7 +100,7 @@ def test_build_model_from_dir(tmp_path, mock_data, mock_priors):
_, data = build_model_from_dir(model_dir)
assert data.data_observed_disease_ed_visits is None
assert data.data_observed_disease_hospital_admissions is None
assert data.data_observed_disease_wastewater is None
assert data.data_observed_disease_wastewater_conc is None

# Test when all `fit_` arguments are True
_, data = build_model_from_dir(
@@ -108,7 +117,4 @@
data.data_observed_disease_hospital_admissions,
jnp.array(model_data["data_observed_disease_hospital_admissions"]),
)
assert pl.DataFrame(
model_data["data_observed_disease_wastewater"],
schema_overrides={"date": pl.Date},
).equals(data.data_observed_disease_wastewater)
assert data.data_observed_disease_wastewater_conc is not None
