Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update end to end test #353

Merged
merged 35 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
240cde2
add wastewater data prep code
sbidari Jan 31, 2025
cad5172
update prep_data.py
sbidari Feb 3, 2025
26bbbcd
pre-commit
sbidari Feb 3, 2025
573f64c
add wastewater data in json file
sbidari Feb 3, 2025
c279103
move processing of wastewater data to pyrenew-hew-data
sbidari Feb 5, 2025
687bc54
syn to use datetime.date for dates throughout, rename ww_sampled to w…
sbidari Feb 6, 2025
1d8e2f0
add a test
sbidari Feb 6, 2025
64d3fd7
fix test
sbidari Feb 6, 2025
3856e02
n_datapoints -> n_data_days, move get_spines function to pyrenewHEWData
sbidari Feb 7, 2025
6bbc57a
drop wastewater data with no LOD reported
sbidari Feb 8, 2025
7614a1b
create fake nwss data
sbidari Feb 13, 2025
ab6beb8
add ww_data_dir to forecast_state.py, test_end_to_end.sh
sbidari Feb 14, 2025
60ced06
add schema for ww dataframe, add placeholder path
sbidari Feb 12, 2025
0056e3d
pre-commit
sbidari Feb 12, 2025
5896b5a
code review suggestions
sbidari Feb 13, 2025
f7dea09
add schema override
sbidari Feb 18, 2025
98b24c3
fix test
sbidari Feb 19, 2025
8f53d08
add wastewater related priors
sbidari Feb 20, 2025
2fcd94b
Merge branch 'main' of https://github.com/CDCgov/pyrenew-hew into sb-…
sbidari Feb 20, 2025
bdcbc2f
add PyrenewWastewaterData
sbidari Feb 24, 2025
83dd48f
code clean up
sbidari Feb 25, 2025
a94c4b6
only render diagnostic report for pyrenew_e
sbidari Feb 25, 2025
b97de0d
update test ww data generation
sbidari Feb 25, 2025
efead4e
nwss data handling, split end-to-end test by disease
sbidari Feb 25, 2025
a6bd1c5
remove unused imports
sbidari Feb 25, 2025
58d6342
end-to-end run test for covid-19 in US
sbidari Feb 26, 2025
ca9ba7e
update create_more_model_test_data to exclude model_run_dir for covid
sbidari Feb 26, 2025
2f54caa
Update pipelines/generate_test_data.R
sbidari Feb 26, 2025
f1fccac
add PyRenewWasteWater args as a dict to PyrenewHewData
sbidari Feb 27, 2025
14ba0bb
Apply suggestions from code review
sbidari Feb 27, 2025
abfdffb
coerce population_size to be an integer
sbidari Feb 27, 2025
e63d3c9
Suggestion for how to handle the population calculation (#365)
dylanhmorris Feb 27, 2025
8899552
add missing args to to_forecast_data, code review suggestions
sbidari Feb 28, 2025
c7a1dfd
Update pyrenew_hew/pyrenew_hew_model.py
dylanhmorris Mar 3, 2025
5d52a5e
Update pyrenew_hew/pyrenew_hew_model.py
dylanhmorris Mar 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions pipelines/build_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
PyrenewHEWModel,
WastewaterObservationProcess,
)
from pyrenew_hew.pyrenew_wastewater_data import PyrenewWastewaterData


def build_model_from_dir(
Expand Down Expand Up @@ -92,14 +93,20 @@ def build_model_from_dir(
data_observed_disease_wastewater = (
pl.DataFrame(
model_data["data_observed_disease_wastewater"],
schema_overrides={"date": pl.Date},
schema_overrides={
"date": pl.Date,
"lab_index": pl.Int64,
"site_index": pl.Int64,
},
)
if fit_wastewater
else None
)

population_size = jnp.array(model_data["state_pop"])

pop_fraction = jnp.array(model_data["pop_fraction"])

ed_right_truncation_pmf_rv = DeterministicVariable(
"right_truncation_pmf", jnp.array(model_data["right_truncation_pmf"])
)
Expand All @@ -114,10 +121,10 @@ def build_model_from_dir(

first_ed_visits_date = datetime.datetime.strptime(
model_data["nssp_training_dates"][0], "%Y-%m-%d"
)
).date()
first_hospital_admissions_date = datetime.datetime.strptime(
model_data["nhsn_training_dates"][0], "%Y-%m-%d"
)
).date()

priors = runpy.run_path(str(prior_path))

Expand All @@ -133,6 +140,20 @@ def build_model_from_dir(
infection_feedback_strength_rv=priors["inf_feedback_strength_rv"],
infection_feedback_pmf_rv=infection_feedback_pmf_rv,
n_initialization_points=n_initialization_points,
pop_fraction=pop_fraction,
autoreg_rt_subpop_rv=priors["autoreg_rt_subpop_rv"],
sigma_rt_rv=priors["sigma_rt_rv"],
sigma_i_first_obs_rv=priors["sigma_i_first_obs_rv"],
sigma_initial_exp_growth_rate_rv=priors[
"sigma_initial_exp_growth_rate_rv"
],
offset_ref_logit_i_first_obs_rv=priors[
"offset_ref_logit_i_first_obs_rv"
],
offset_ref_initial_exp_growth_rate_rv=priors[
"offset_ref_initial_exp_growth_rate_rv"
],
offset_ref_log_rt_rv=priors["offset_ref_log_rt_rv"],
)

ed_visit_obs_rv = EDVisitObservationProcess(
Expand Down Expand Up @@ -175,16 +196,21 @@ def build_model_from_dir(
wastewater_obs_process_rv=wastewater_obs_rv,
)

wastewater_data = PyrenewWastewaterData(
data_observed_disease_wastewater=data_observed_disease_wastewater,
population_size=population_size,
pop_fraction=pop_fraction,
)

dat = PyrenewHEWData(
data_observed_disease_ed_visits=data_observed_disease_ed_visits,
data_observed_disease_hospital_admissions=(
data_observed_disease_hospital_admissions
),
data_observed_disease_wastewater=data_observed_disease_wastewater,
right_truncation_offset=right_truncation_offset,
first_ed_visits_date=first_ed_visits_date,
first_hospital_admissions_date=first_hospital_admissions_date,
population_size=population_size,
wastewater_data=wastewater_data,
)

return (mod, dat)
2 changes: 1 addition & 1 deletion pipelines/forecast_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def main(
(pl.col("location") == state)
& (pl.col("date") >= first_training_date)
)
state_level_nwss_data = preprocess_ww_data(nwss_data_cleaned)
state_level_nwss_data = preprocess_ww_data(nwss_data_cleaned.collect())
else:
state_level_nwss_data = None ## TO DO: change

Expand Down
48 changes: 48 additions & 0 deletions pipelines/generate_test_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,51 @@ generate_fake_param_data <-
)
}


#' Generate Fake NWSS Data
#'
#' This function generates fake wastewater data for a
#' and saves it as a parquet file.

generate_fake_nwss_data <- function(
private_data_dir = fs::path(getwd()),
states_to_generate = c("MT", "CA"),
start_reference = as.Date("2024-06-01"),
end_reference = as.Date("2024-12-21"),
site = c(1, 2, 3, 4),
lab = c(1, 1, 3, 3),
lod = c(20, 31, 20, 30),
site_pop = c(4e6, 2e6, 1e6, 5e5)) {
ww_dir <- fs::path(private_data_dir, "nwss_vintages")
fs::dir_create(ww_dir, recurse = TRUE)

site_info <- tibble::tibble(
wwtp_id = site,
lab_id = lab,
lod_sewage = lod,
population_served = site_pop,
sample_location = "wwtp",
sample_matrix = "raw wastewater",
pcr_target_units = "copies/l wastewater",
pcr_target = "sars-cov-2",
quality_flag = c("no", NA_character_, "n", "n")
)

ww_data <- tidyr::expand_grid(
sample_collect_date = seq(start_reference, end_reference, by = "week"),
wwtp_jurisdiction = states_to_generate,
site_info
) |>
dplyr::mutate(
pcr_target_avg_conc = abs(rnorm(dplyr::n(), mean = 500, sd = 50))
)

arrow::write_parquet(
ww_data, fs::path(ww_dir, paste0(end_reference, ".parquet"))
)
}


main <- function(private_data_dir,
target_diseases,
n_forecast_days) {
Expand Down Expand Up @@ -335,6 +380,9 @@ main <- function(private_data_dir,
states_to_generate = c("MT", "CA", "US"),
target_diseases = short_target_diseases
)
generate_fake_nwss_data(
private_data_dir,
)
}

p <- arg_parser("Create simulated epiweekly data.") |>
Expand Down
19 changes: 19 additions & 0 deletions pipelines/prep_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,24 @@ def process_and_save_state(
else None
)

if state_level_nwss_data is None:
pop_fraction = jnp.array([1])
else:
subpop_sizes = (
state_level_nwss_data.select(["site_index", "site", "site_pop"])
.unique()["site_pop"]
.to_numpy()
)
pop_fraction = jnp.where(
state_pop > subpop_sizes.sum(),
jnp.concatenate(
[jnp.array([state_pop - subpop_sizes.sum()]), subpop_sizes],
axis=0,
)
/ state_pop,
subpop_sizes / state_pop,
)

data_for_model_fit = {
"inf_to_ed_pmf": delay_pmf,
"generation_interval_pmf": generation_interval_pmf,
Expand All @@ -471,6 +489,7 @@ def process_and_save_state(
"state_pop": state_pop,
"right_truncation_offset": right_truncation_offset,
"data_observed_disease_wastewater": data_observed_disease_wastewater,
"pop_fraction": pop_fraction,
}

data_dir = Path(model_run_dir, "data")
Expand Down
3 changes: 0 additions & 3 deletions pipelines/prep_ww_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import datetime
from pathlib import Path

import polars as pl


Expand Down
41 changes: 41 additions & 0 deletions pipelines/priors/prod_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,47 @@
"mode_sd_ww_site", dist.TruncatedNormal(0, 0.25, low=0)
)

autoreg_rt_subpop_rv = DistributionalVariable(
"autoreg_rt_subpop", dist.Beta(1, 4)
)
sigma_rt_rv = DistributionalVariable(
"sigma_rt", dist.TruncatedNormal(0, 0.1, low=0)
)

sigma_i_first_obs_rv = DistributionalVariable(
"sigma_i_first_obs",
dist.TruncatedNormal(0, 0.5, low=0),
)

sigma_initial_exp_growth_rate_rv = DistributionalVariable(
"sigma_initial_exp_growth_rate",
dist.TruncatedNormal(
0,
0.05,
low=0,
),
)

offset_ref_logit_i_first_obs_rv = DistributionalVariable(
"offset_ref_logit_i_first_obs",
dist.Normal(0, 0.25),
)

offset_ref_initial_exp_growth_rate_rv = DistributionalVariable(
"offset_ref_initial_exp_growth_rate",
dist.TruncatedNormal(
0,
0.025,
low=-0.01,
high=0.01,
),
)

offset_ref_log_rt_rv = DistributionalVariable(
"offset_ref_log_r_t",
dist.Normal(0, 0.2),
)

# model constants related to wastewater obs process
ww_ml_produced_per_day = 227000
max_shed_interval = 26
20 changes: 13 additions & 7 deletions pipelines/tests/test_build_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def mock_data():
{
"data_observed_disease_ed_visits": [1, 2, 3],
"data_observed_disease_hospital_admissions": [4, 5, 6],
"state_pop": [7, 8, 9],
"state_pop": [10000],
"generation_interval_pmf": [0.1, 0.2, 0.7],
"inf_to_ed_pmf": [0.4, 0.5, 0.1],
"inf_to_hosp_admit_pmf": [0.0, 0.7, 0.1, 0.1, 0.1],
Expand All @@ -32,13 +32,14 @@ def mock_data():
],
"site": ["1.0", "1.0", "2.0", "2.0"],
"lab": ["1.0", "1.0", "1.0", "1.0"],
"site_pop": [4000000, 4000000, 2000000, 2000000],
"site_pop": [4000, 4000, 2000, 2000],
"site_index": [1, 1, 0, 0],
"lab_site_index": [1, 1, 0, 0],
"log_genome_copies_per_ml": [0.1, 0.1, 0.5, 0.4],
"log_lod": [1.1, 2.0, 1.5, 2.1],
"below_lod": [False, False, False, False],
},
"pop_fraction": [0.4, 0.4, 0.2],
}
)

Expand Down Expand Up @@ -72,6 +73,14 @@ def mock_priors():
mode_sd_ww_site_rv = None
max_shed_interval = None
ww_ml_produced_per_day = None
pop_fraction=None
autoreg_rt_subpop_rv=None
sigma_rt_rv=None
sigma_i_first_obs_rv=None
sigma_initial_exp_growth_rate_rv=None
offset_ref_logit_i_first_obs_rv=None
offset_ref_initial_exp_growth_rate_rv=None
offset_ref_log_rt_rv=None
"""


Expand All @@ -91,7 +100,7 @@ def test_build_model_from_dir(tmp_path, mock_data, mock_priors):
_, data = build_model_from_dir(model_dir)
assert data.data_observed_disease_ed_visits is None
assert data.data_observed_disease_hospital_admissions is None
assert data.data_observed_disease_wastewater is None
assert data.data_observed_disease_wastewater_conc is None

# Test when all `fit_` arguments are True
_, data = build_model_from_dir(
Expand All @@ -108,7 +117,4 @@ def test_build_model_from_dir(tmp_path, mock_data, mock_priors):
data.data_observed_disease_hospital_admissions,
jnp.array(model_data["data_observed_disease_hospital_admissions"]),
)
assert pl.DataFrame(
model_data["data_observed_disease_wastewater"],
schema_overrides={"date": pl.Date},
).equals(data.data_observed_disease_wastewater)
# assert data.data_observed_disease_wastewater_conc is not None
7 changes: 4 additions & 3 deletions pipelines/tests/test_end_to_end.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,18 @@ do
--state-level-nssp-data-dir "$BASE_DIR/private_data/nssp_state_level_gold" \
--priors-path pipelines/priors/prod_priors.py \
--param-data-dir "$BASE_DIR/private_data/prod_param_estimates" \
--nwss-data-dir "$BASE_DIR/private_data/nwss_vintages" \
--output-dir "$BASE_DIR/private_data" \
--n-training-days 60 \
--n-chains 2 \
--n-samples 250 \
--n-warmup 250 \
--fit-ed-visits \
--no-fit-hospital-admissions \
--no-fit-wastewater \
--fit-hospital-admissions \
--fit-wastewater \
--forecast-ed-visits \
--forecast-hospital-admissions \
--no-forecast-wastewater \
--forecast-wastewater \
--score \
--eval-data-path "$BASE_DIR/private_data/nssp-etl"
if [ $? -ne 0 ]; then
Expand Down
Loading
Loading