diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index 1a0b2f05..9ecc26b9 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -15,6 +15,7 @@ PyrenewHEWModel, WastewaterObservationProcess, ) +from pyrenew_hew.pyrenew_wastewater_data import PyrenewWastewaterData def build_model_from_dir( @@ -92,13 +93,19 @@ def build_model_from_dir( data_observed_disease_wastewater = ( pl.DataFrame( model_data["data_observed_disease_wastewater"], - schema_overrides={"date": pl.Date}, + schema_overrides={ + "date": pl.Date, + "lab_index": pl.Int64, + "site_index": pl.Int64, + }, ) if fit_wastewater else None ) - population_size = jnp.array(model_data["state_pop"]) + population_size = jnp.array(model_data["state_pop"]).item() + + pop_fraction = jnp.array(model_data["pop_fraction"]) ed_right_truncation_pmf_rv = DeterministicVariable( "right_truncation_pmf", jnp.array(model_data["right_truncation_pmf"]) @@ -114,10 +121,10 @@ def build_model_from_dir( first_ed_visits_date = datetime.datetime.strptime( model_data["nssp_training_dates"][0], "%Y-%m-%d" - ) + ).date() first_hospital_admissions_date = datetime.datetime.strptime( model_data["nhsn_training_dates"][0], "%Y-%m-%d" - ) + ).date() priors = runpy.run_path(str(prior_path)) @@ -133,6 +140,20 @@ def build_model_from_dir( infection_feedback_strength_rv=priors["inf_feedback_strength_rv"], infection_feedback_pmf_rv=infection_feedback_pmf_rv, n_initialization_points=n_initialization_points, + pop_fraction=pop_fraction, + autoreg_rt_subpop_rv=priors["autoreg_rt_subpop_rv"], + sigma_rt_rv=priors["sigma_rt_rv"], + sigma_i_first_obs_rv=priors["sigma_i_first_obs_rv"], + sigma_initial_exp_growth_rate_rv=priors[ + "sigma_initial_exp_growth_rate_rv" + ], + offset_ref_logit_i_first_obs_rv=priors[ + "offset_ref_logit_i_first_obs_rv" + ], + offset_ref_initial_exp_growth_rate_rv=priors[ + "offset_ref_initial_exp_growth_rate_rv" + ], + offset_ref_log_rt_rv=priors["offset_ref_log_rt_rv"], ) ed_visit_obs_rv = EDVisitObservationProcess( @@ -175,16 +196,21 @@ def build_model_from_dir( wastewater_obs_process_rv=wastewater_obs_rv, ) + wastewater_data = PyrenewWastewaterData( + data_observed_disease_wastewater=data_observed_disease_wastewater, + population_size=population_size, + ) + dat = PyrenewHEWData( data_observed_disease_ed_visits=data_observed_disease_ed_visits, data_observed_disease_hospital_admissions=( data_observed_disease_hospital_admissions ), - data_observed_disease_wastewater=data_observed_disease_wastewater, right_truncation_offset=right_truncation_offset, first_ed_visits_date=first_ed_visits_date, first_hospital_admissions_date=first_hospital_admissions_date, - population_size=population_size, + pop_fraction=pop_fraction, + **wastewater_data.to_pyrenew_hew_data_args(), ) return (mod, dat) diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py index b9b5bec2..5717336c 100644 --- a/pipelines/forecast_state.py +++ b/pipelines/forecast_state.py @@ -335,22 +335,28 @@ def main( "No data available for the requested report date " f"{report_date}" ) - available_nwss_reports = get_available_reports(nwss_data_dir) - # assming NWSS_vintage directory follows naming convention - # using as of date - # need to be modified otherwise - - if report_date in available_nwss_reports: - nwss_data_raw = pl.scan_parquet( - Path(nwss_data_dir, f"{report_date}.parquet") - ) - nwss_data_cleaned = clean_nwss_data(nwss_data_raw).filter( - (pl.col("location") == state) - & (pl.col("date") >= first_training_date) - ) - state_level_nwss_data = preprocess_ww_data(nwss_data_cleaned) + if fit_wastewater: + available_nwss_reports = get_available_reports(nwss_data_dir) + # assming NWSS_vintage directory follows naming convention + # using as of date + if report_date in available_nwss_reports: + nwss_data_raw = pl.scan_parquet( + Path(nwss_data_dir, f"{report_date}.parquet") + ) + nwss_data_cleaned = clean_nwss_data(nwss_data_raw).filter( + (pl.col("location") == state) + & (pl.col("date") >= first_training_date) + ) + state_level_nwss_data = preprocess_ww_data( + nwss_data_cleaned.collect() + ) + else: + raise ValueError( + "NWSS data not available for the requested report date " + f"{report_date}" + ) else: - state_level_nwss_data = None ## TO DO: change + state_level_nwss_data = None param_estimates = pl.scan_parquet(Path(param_data_dir, "prod.parquet")) model_batch_dir_name = ( @@ -453,9 +459,10 @@ def main( ) logger.info("Postprocessing complete.") - logger.info("Rendering webpage...") - render_diagnostic_report(model_run_dir) - logger.info("Rendering complete.") + if pyrenew_model_name == "pyrenew_e": + logger.info("Rendering webpage...") + render_diagnostic_report(model_run_dir) + logger.info("Rendering complete.") if score: logger.info("Scoring forecast...") diff --git a/pipelines/generate_test_data.R b/pipelines/generate_test_data.R index 4f6d3c88..6d302a81 100644 --- a/pipelines/generate_test_data.R +++ b/pipelines/generate_test_data.R @@ -304,6 +304,62 @@ generate_fake_param_data <- ) } + +#' Generate Fake NWSS Data +#' +#' This function generates fake wastewater data for a +#' and saves it as a parquet file. + +generate_fake_nwss_data <- function( + private_data_dir = fs::path(getwd()), + states_to_generate = c("MT", "CA"), + start_reference = as.Date("2024-06-01"), + end_reference = as.Date("2024-12-21"), + site = list( + CA = c(1, 2, 3, 4), + MT = c(5, 6, 7, 8) + ), + lab = list( + CA = c(1, 1, 2, 2), + MT = c(3, 3, 4, 4) + ), + lod = c(20, 31, 20, 30), + site_pop = list( + CA = c(4e6, 2e6, 1e6, 5e5), + MT = c(3e5, 2e5, 1e5, 5e4) + )) { + ww_dir <- fs::path(private_data_dir, "nwss_vintages") + fs::dir_create(ww_dir, recurse = TRUE) + + site_info <- function(state) { + tibble::tibble( + wwtp_id = site[[state]], + lab_id = lab[[state]], + lod_sewage = lod, + population_served = site_pop[[state]], + sample_location = "wwtp", + sample_matrix = "raw wastewater", + pcr_target_units = "copies/l wastewater", + pcr_target = "sars-cov-2", + quality_flag = c("no", NA_character_, "n", "n"), + wwtp_jurisdiction = state + ) + } + + ww_data <- purrr::map_dfr(states_to_generate, site_info) |> + tidyr::expand_grid( + sample_collect_date = seq(start_reference, end_reference, by = "week") + ) |> + dplyr::mutate( + pcr_target_avg_conc = abs(rnorm(dplyr::n(), mean = 500, sd = 50)) + ) + + arrow::write_parquet( + ww_data, fs::path(ww_dir, paste0(end_reference, ".parquet")) + ) +} + + main <- function(private_data_dir, target_diseases, n_forecast_days) { @@ -335,6 +391,9 @@ main <- function(private_data_dir, states_to_generate = c("MT", "CA", "US"), target_diseases = short_target_diseases ) + generate_fake_nwss_data( + private_data_dir + ) } p <- arg_parser("Create simulated epiweekly data.") |> diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index 14e657f9..2649cd7d 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -457,6 +457,25 @@ def process_and_save_state( else None ) + if state_level_nwss_data is None: + pop_fraction = jnp.array([1]) + else: + subpop_sizes = ( + state_level_nwss_data.select(["site_index", "site", "site_pop"]) + .unique() + .get_column("site_pop") + .to_numpy() + ) + if state_pop > sum(subpop_sizes): + pop_fraction = ( + jnp.concatenate( + (jnp.array([state_pop - sum(subpop_sizes)]), subpop_sizes) + ) + / state_pop + ) + else: + pop_fraction = subpop_sizes / state_pop + data_for_model_fit = { "inf_to_ed_pmf": delay_pmf, "generation_interval_pmf": generation_interval_pmf, @@ -471,6 +490,7 @@ def process_and_save_state( "state_pop": state_pop, "right_truncation_offset": right_truncation_offset, "data_observed_disease_wastewater": data_observed_disease_wastewater, + "pop_fraction": pop_fraction.tolist(), } data_dir = Path(model_run_dir, "data") diff --git a/pipelines/prep_ww_data.py b/pipelines/prep_ww_data.py index 513143c6..af52d022 100644 --- a/pipelines/prep_ww_data.py +++ b/pipelines/prep_ww_data.py @@ -1,6 +1,3 @@ -import datetime -from pathlib import Path - import polars as pl diff --git a/pipelines/priors/prod_priors.py b/pipelines/priors/prod_priors.py index c8b1beaa..3ca83eaf 100644 --- a/pipelines/priors/prod_priors.py +++ b/pipelines/priors/prod_priors.py @@ -117,6 +117,47 @@ "mode_sd_ww_site", dist.TruncatedNormal(0, 0.25, low=0) ) +autoreg_rt_subpop_rv = DistributionalVariable( + "autoreg_rt_subpop", dist.Beta(1, 4) +) +sigma_rt_rv = DistributionalVariable( + "sigma_rt", dist.TruncatedNormal(0, 0.1, low=0) +) + +sigma_i_first_obs_rv = DistributionalVariable( + "sigma_i_first_obs", + dist.TruncatedNormal(0, 0.5, low=0), +) + +sigma_initial_exp_growth_rate_rv = DistributionalVariable( + "sigma_initial_exp_growth_rate", + dist.TruncatedNormal( + 0, + 0.05, + low=0, + ), +) + +offset_ref_logit_i_first_obs_rv = DistributionalVariable( + "offset_ref_logit_i_first_obs", + dist.Normal(0, 0.25), +) + +offset_ref_initial_exp_growth_rate_rv = DistributionalVariable( + "offset_ref_initial_exp_growth_rate", + dist.TruncatedNormal( + 0, + 0.025, + low=-0.01, + high=0.01, + ), +) + +offset_ref_log_rt_rv = DistributionalVariable( + "offset_ref_log_r_t", + dist.Normal(0, 0.2), +) + # model constants related to wastewater obs process ww_ml_produced_per_day = 227000 max_shed_interval = 26 diff --git a/pipelines/tests/create_more_model_test_data.R b/pipelines/tests/create_more_model_test_data.R index ff0e90a1..fc2e1d30 100644 --- a/pipelines/tests/create_more_model_test_data.R +++ b/pipelines/tests/create_more_model_test_data.R @@ -13,10 +13,6 @@ try(source("pipelines/plot_and_save_state_forecast.R")) try(source("pipelines/score_forecast.R")) model_batch_dirs <- c( - path( - "pipelines/tests/private_data", - "covid-19_r_2024-12-21_f_2024-10-22_t_2024-12-20" - ), path( "pipelines/tests/private_data", "influenza_r_2024-12-21_f_2024-10-22_t_2024-12-20" diff --git a/pipelines/tests/test_build_pyrenew_model.py b/pipelines/tests/test_build_pyrenew_model.py index 95839652..008ac8c5 100644 --- a/pipelines/tests/test_build_pyrenew_model.py +++ b/pipelines/tests/test_build_pyrenew_model.py @@ -13,7 +13,7 @@ def mock_data(): { "data_observed_disease_ed_visits": [1, 2, 3], "data_observed_disease_hospital_admissions": [4, 5, 6], - "state_pop": [7, 8, 9], + "state_pop": [10000], "generation_interval_pmf": [0.1, 0.2, 0.7], "inf_to_ed_pmf": [0.4, 0.5, 0.1], "inf_to_hosp_admit_pmf": [0.0, 0.7, 0.1, 0.1, 0.1], @@ -32,13 +32,14 @@ def mock_data(): ], "site": ["1.0", "1.0", "2.0", "2.0"], "lab": ["1.0", "1.0", "1.0", "1.0"], - "site_pop": [4000000, 4000000, 2000000, 2000000], + "site_pop": [4000, 4000, 2000, 2000], "site_index": [1, 1, 0, 0], "lab_site_index": [1, 1, 0, 0], "log_genome_copies_per_ml": [0.1, 0.1, 0.5, 0.4], "log_lod": [1.1, 2.0, 1.5, 2.1], "below_lod": [False, False, False, False], }, + "pop_fraction": [0.4, 0.4, 0.2], } ) @@ -72,6 +73,14 @@ def mock_priors(): mode_sd_ww_site_rv = None max_shed_interval = None ww_ml_produced_per_day = None +pop_fraction=None +autoreg_rt_subpop_rv=None +sigma_rt_rv=None +sigma_i_first_obs_rv=None +sigma_initial_exp_growth_rate_rv=None +offset_ref_logit_i_first_obs_rv=None +offset_ref_initial_exp_growth_rate_rv=None +offset_ref_log_rt_rv=None """ @@ -91,7 +100,7 @@ def test_build_model_from_dir(tmp_path, mock_data, mock_priors): _, data = build_model_from_dir(model_dir) assert data.data_observed_disease_ed_visits is None assert data.data_observed_disease_hospital_admissions is None - assert data.data_observed_disease_wastewater is None + assert data.data_observed_disease_wastewater_conc is None # Test when all `fit_` arguments are True _, data = build_model_from_dir( @@ -108,7 +117,4 @@ def test_build_model_from_dir(tmp_path, mock_data, mock_priors): data.data_observed_disease_hospital_admissions, jnp.array(model_data["data_observed_disease_hospital_admissions"]), ) - assert pl.DataFrame( - model_data["data_observed_disease_wastewater"], - schema_overrides={"date": pl.Date}, - ).equals(data.data_observed_disease_wastewater) + assert data.data_observed_disease_wastewater_conc is not None diff --git a/pipelines/tests/test_end_to_end.sh b/pipelines/tests/test_end_to_end.sh index 309bb8b4..9c7730bc 100755 --- a/pipelines/tests/test_end_to_end.sh +++ b/pipelines/tests/test_end_to_end.sh @@ -16,13 +16,72 @@ if [ $? -ne 0 ]; then else echo "TEST-MODE: Finished generating test data" fi -echo "TEST-MODE: Running forecasting pipeline for two diseases in multiple locations" +echo "TEST-MODE: Running forecasting pipeline for COVID-19 in multiple states" +for state in CA MT +do + python pipelines/forecast_state.py \ + --disease COVID-19 \ + --state $state \ + --facility-level-nssp-data-dir "$BASE_DIR/private_data/nssp_etl_gold" \ + --state-level-nssp-data-dir "$BASE_DIR/private_data/nssp_state_level_gold" \ + --priors-path pipelines/priors/prod_priors.py \ + --param-data-dir "$BASE_DIR/private_data/prod_param_estimates" \ + --nwss-data-dir "$BASE_DIR/private_data/nwss_vintages" \ + --output-dir "$BASE_DIR/private_data" \ + --n-training-days 60 \ + --n-chains 2 \ + --n-samples 250 \ + --n-warmup 250 \ + --fit-ed-visits \ + --fit-hospital-admissions \ + --fit-wastewater \ + --forecast-ed-visits \ + --forecast-hospital-admissions \ + --forecast-wastewater \ + --score \ + --eval-data-path "$BASE_DIR/private_data/nssp-etl" + if [ $? -ne 0 ]; then + echo "TEST-MODE FAIL: Forecasting/postprocessing/scoring pipeline failed" + exit 1 + else + echo "TEST-MODE: Finished forecasting/postprocessing/scoring pipeline for COVID-19 in location" $state"." + fi +done + +echo "TEST-MODE: Running forecasting pipeline for COVID-19 in US" +python pipelines/forecast_state.py \ + --disease COVID-19 \ + --state US \ + --facility-level-nssp-data-dir "$BASE_DIR/private_data/nssp_etl_gold" \ + --state-level-nssp-data-dir "$BASE_DIR/private_data/nssp_state_level_gold" \ + --priors-path pipelines/priors/prod_priors.py \ + --param-data-dir "$BASE_DIR/private_data/prod_param_estimates" \ + --nwss-data-dir "$BASE_DIR/private_data/nwss_vintages" \ + --output-dir "$BASE_DIR/private_data" \ + --n-training-days 60 \ + --n-chains 2 \ + --n-samples 250 \ + --n-warmup 250 \ + --fit-ed-visits \ + --fit-hospital-admissions \ + --no-fit-wastewater \ + --forecast-ed-visits \ + --forecast-hospital-admissions \ + --no-forecast-wastewater \ + --score \ + --eval-data-path "$BASE_DIR/private_data/nssp-etl" +if [ $? -ne 0 ]; then + echo "TEST-MODE FAIL: Forecasting/postprocessing/scoring pipeline failed" + exit 1 +else + echo "TEST-MODE: Finished forecasting/postprocessing/scoring pipeline for COVID-19 in location US." +fi + +echo "TEST-MODE: Running forecasting pipeline for Influenza in multiple states" for state in CA MT US do - for disease in COVID-19 Influenza - do python pipelines/forecast_state.py \ - --disease $disease \ + --disease Influenza \ --state $state \ --facility-level-nssp-data-dir "$BASE_DIR/private_data/nssp_etl_gold" \ --state-level-nssp-data-dir "$BASE_DIR/private_data/nssp_state_level_gold" \ @@ -45,9 +104,8 @@ do echo "TEST-MODE FAIL: Forecasting/postprocessing/scoring pipeline failed" exit 1 else - echo "TEST-MODE: Finished forecasting/postprocessing/scoring pipeline for disease" $disease "in location" $state"." + echo "TEST-MODE: Finished forecasting/postprocessing/scoring pipeline for Influenza in location" $state"." fi - done done echo "TEST-MODE: pipeline runs complete for all location/disease pairs." diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index 99de33f4..91c53078 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -2,9 +2,10 @@ from typing import Self import jax.numpy as jnp -import polars as pl from jax.typing import ArrayLike +from pyrenew_hew.pyrenew_wastewater_data import PyrenewWastewaterData + class PyrenewHEWData: """ @@ -19,19 +20,25 @@ def __init__( n_wastewater_data_days: int = None, data_observed_disease_ed_visits: ArrayLike = None, data_observed_disease_hospital_admissions: ArrayLike = None, - data_observed_disease_wastewater: pl.DataFrame = None, right_truncation_offset: int = None, first_ed_visits_date: datetime.date = None, first_hospital_admissions_date: datetime.date = None, first_wastewater_date: datetime.date = None, - population_size: int = None, - shedding_offset: float = 1e-8, - pop_fraction: ArrayLike = jnp.array([1]), + n_ww_lab_sites: int = None, + ww_censored: ArrayLike = None, + ww_uncensored: ArrayLike = None, + ww_observed_subpops: ArrayLike = None, + ww_observed_times: ArrayLike = None, + ww_observed_lab_sites: ArrayLike = None, + lab_site_to_subpop_map: ArrayLike = None, + ww_log_lod: ArrayLike = None, + date_observed_disease_wastewater: ArrayLike = None, + data_observed_disease_wastewater_conc: ArrayLike = None, + pop_fraction: ArrayLike = None, ) -> None: self.n_ed_visits_data_days_ = n_ed_visits_data_days self.n_hospital_admissions_data_days_ = n_hospital_admissions_data_days self.n_wastewater_data_days_ = n_wastewater_data_days - self.data_observed_disease_ed_visits = data_observed_disease_ed_visits self.data_observed_disease_hospital_admissions = ( data_observed_disease_hospital_admissions @@ -40,12 +47,21 @@ def __init__( self.first_ed_visits_date = first_ed_visits_date self.first_hospital_admissions_date = first_hospital_admissions_date self.first_wastewater_date_ = first_wastewater_date - self.data_observed_disease_wastewater = ( - data_observed_disease_wastewater + self.date_observed_disease_wastewater = ( + date_observed_disease_wastewater ) - self.population_size = population_size - self.shedding_offset = shedding_offset - self.pop_fraction_ = pop_fraction + self.pop_fraction = pop_fraction + self.data_observed_disease_wastewater_conc = ( + data_observed_disease_wastewater_conc + ) + self.ww_censored = ww_censored + self.ww_uncensored = ww_uncensored + self.ww_observed_times = ww_observed_times + self.ww_observed_subpops = ww_observed_subpops + self.ww_observed_lab_sites = ww_observed_lab_sites + self.ww_log_lod = ww_log_lod + self.n_ww_lab_sites = n_ww_lab_sites + self.lab_site_to_subpop_map = lab_site_to_subpop_map @property def n_ed_visits_data_days(self): @@ -65,19 +81,23 @@ def n_hospital_admissions_data_days(self): def n_wastewater_data_days(self): return self.get_n_wastewater_data_days( n_datapoints=self.n_wastewater_data_days_, - date_array=( - None - if self.data_observed_disease_wastewater is None - else self.data_observed_disease_wastewater["date"] - ), + date_array=self.date_observed_disease_wastewater, ) @property def first_wastewater_date(self): - if self.data_observed_disease_wastewater is not None: - return self.data_observed_disease_wastewater["date"].min() + if self.date_observed_disease_wastewater is not None: + return self.date_observed_disease_wastewater.min() return self.first_wastewater_date_ + @property + def last_wastewater_date(self): + return self.get_end_date( + self.first_wastewater_date, + self.n_wastewater_data_days, + timestep_days=1, + ) + @property def last_ed_visits_date(self): return self.get_end_date( @@ -94,14 +114,6 @@ def last_hospital_admissions_date(self): timestep_days=7, ) - @property - def last_wastewater_date(self): - return self.get_end_date( - self.first_wastewater_date, - self.n_wastewater_data_days, - timestep_days=1, - ) - @property def first_data_dates(self): return dict( @@ -132,160 +144,6 @@ def n_days_post_init(self): self.last_data_date_overall - self.first_data_date_overall ).days - @property - def site_subpop_spine(self): - ww_data_present = self.data_observed_disease_wastewater is not None - if ww_data_present: - # Check if auxiliary subpopulation needs to be added - add_auxiliary_subpop = ( - self.population_size - > self.data_observed_disease_wastewater.select( - pl.col("site_pop", "site", "lab", "lab_site_index") - ) - .unique() - .get_column("site_pop") - .sum() - ) - site_indices = ( - self.data_observed_disease_wastewater.select( - ["site_index", "site", "site_pop"] - ) - .unique() - .sort("site_index") - ) - if add_auxiliary_subpop: - aux_subpop = pl.DataFrame( - { - "site_index": [None], - "site": [None], - "site_pop": [ - self.population_size - - site_indices.select(pl.col("site_pop")) - .get_column("site_pop") - .sum() - ], - } - ) - else: - aux_subpop = pl.DataFrame() - site_subpop_spine = ( - pl.concat([aux_subpop, site_indices], how="vertical_relaxed") - .with_columns( - subpop_index=pl.col("site_index") - .cum_count() - .alias("subpop_index"), - subpop_name=pl.format( - "Site: {}", pl.col("site") - ).fill_null("remainder of population"), - ) - .rename({"site_pop": "subpop_pop"}) - ) - else: - site_subpop_spine = pl.DataFrame( - { - "site_index": [None], - "site": [None], - "subpop_pop": [self.population_size], - "subpop_index": [1], - "subpop_name": ["total population"], - } - ) - return site_subpop_spine - - @property - def date_time_spine(self): - if self.data_observed_disease_wastewater is not None: - date_time_spine = pl.DataFrame( - { - "date": pl.date_range( - start=self.first_wastewater_date, - end=self.last_wastewater_date, - interval="1d", - eager=True, - ) - } - ).with_row_index("t") - return date_time_spine - - @property - def wastewater_data_extended(self): - if self.data_observed_disease_wastewater is not None: - return ( - self.data_observed_disease_wastewater.join( - self.date_time_spine, on="date", how="left", coalesce=True - ) - .join( - self.site_subpop_spine, - on=["site_index", "site"], - how="left", - coalesce=True, - ) - .with_row_index("ind_rel_to_observed_times") - ) - - @property - def pop_fraction(self): - if self.data_observed_disease_wastewater is not None: - subpop_sizes = self.site_subpop_spine["subpop_pop"].to_numpy() - return subpop_sizes / self.population_size - return self.pop_fraction_ - - @property - def data_observed_disease_wastewater_conc(self): - if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended[ - "log_genome_copies_per_ml" - ].to_numpy() - - @property - def ww_censored(self): - if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended.filter( - pl.col("below_lod") == 1 - )["ind_rel_to_observed_times"].to_numpy() - return None - - @property - def ww_uncensored(self): - if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended.filter( - pl.col("below_lod") == 0 - )["ind_rel_to_observed_times"].to_numpy() - - @property - def ww_observed_times(self): - if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended["t"].to_numpy() - - @property - def ww_observed_subpops(self): - if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended["subpop_index"].to_numpy() - - @property - def ww_observed_lab_sites(self): - if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended["lab_site_index"].to_numpy() - - @property - def ww_log_lod(self): - if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended["log_lod"].to_numpy() - - @property - def n_ww_lab_sites(self): - if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended["lab_site_index"].n_unique() - - @property - def lab_site_to_subpop_map(self): - if self.data_observed_disease_wastewater is not None: - return ( - self.wastewater_data_extended["lab_site_index", "subpop_index"] - .unique() - .sort(by="lab_site_index") - )["subpop_index"].to_numpy() - def get_end_date( self, first_date: datetime.date, @@ -354,6 +212,15 @@ def to_forecast_data(self, n_forecast_points: int) -> Self: first_ed_visits_date=self.first_data_date_overall, first_hospital_admissions_date=(self.first_data_date_overall), first_wastewater_date=self.first_data_date_overall, - pop_fraction=self.pop_fraction, right_truncation_offset=None, # by default, want forecasts of complete reports + n_ww_lab_sites=self.n_ww_lab_sites, + ww_uncensored=self.ww_uncensored, + ww_censored=self.ww_censored, + ww_observed_lab_sites=self.ww_observed_lab_sites, + ww_observed_subpops=self.ww_observed_subpops, + ww_observed_times=self.ww_observed_times, + lab_site_to_subpop_map=self.lab_site_to_subpop_map, + ww_log_lod=self.ww_log_lod, + pop_fraction=self.pop_fraction, + data_observed_disease_wastewater_conc=None, ) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index da2fc592..a0b8f520 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -120,7 +120,7 @@ def sample(self, n_days_post_init: int): log_rtu_weekly_subpop = log_rtu_weekly[:, jnp.newaxis] else: i_first_obs_over_n_ref_subpop = transformation.SigmoidTransform()( - transformation.logit(i0_first_obs_n) + transformation.SigmoidTransform().inv(i0_first_obs_n) + self.offset_ref_logit_i_first_obs_rv(), ) initial_exp_growth_rate_ref_subpop = ( @@ -136,7 +136,7 @@ def sample(self, n_days_post_init: int): DistributionalVariable( "i_first_obs_over_n_non_ref_subpop_raw", dist.Normal( - transformation.logit(i0_first_obs_n), + transformation.SigmoidTransform().inv(i0_first_obs_n), self.sigma_i_first_obs_rv(), ), reparam=LocScaleReparam(0), @@ -759,7 +759,6 @@ def sample( site_level_observed_wastewater, population_level_latent_wastewater, ) = self.wastewater_obs_process_rv( - latent_infections=latent_infections, latent_infections_subpop=latent_infections_subpop, data_observed=data.data_observed_disease_wastewater_conc, n_datapoints=data.n_wastewater_data_days, @@ -772,6 +771,7 @@ def sample( lab_site_to_subpop_map=data.lab_site_to_subpop_map, n_ww_lab_sites=data.n_ww_lab_sites, shedding_offset=1e-8, + pop_fraction=data.pop_fraction, ) return { diff --git a/pyrenew_hew/pyrenew_wastewater_data.py b/pyrenew_hew/pyrenew_wastewater_data.py new file mode 100644 index 00000000..bc0c9471 --- /dev/null +++ b/pyrenew_hew/pyrenew_wastewater_data.py @@ -0,0 +1,207 @@ +import datetime +from typing import Self + +import jax +import jax.numpy as jnp +import polars as pl +from jax.typing import ArrayLike + + +class PyrenewWastewaterData: + """ + Class for holding wastewater input data + to a PyrenewHEW model. + """ + + def __init__( + self, + data_observed_disease_wastewater: pl.DataFrame = None, + population_size: int = None, + ) -> None: + self.data_observed_disease_wastewater = ( + data_observed_disease_wastewater + ) + self.population_size = population_size + + @property + def site_subpop_spine(self): + ww_data_present = self.data_observed_disease_wastewater is not None + if ww_data_present: + site_indices = ( + self.data_observed_disease_wastewater.select( + ["site_index", "site", "site_pop"] + ) + .unique() + .sort("site_index") + ) + + total_pop_ww = ( + self.data_observed_disease_wastewater.unique( + ["site_pop", "site"] + ) + .get_column("site_pop") + .sum() + ) + + total_pop_no_ww = self.population_size - total_pop_ww + add_auxiliary_subpop = total_pop_no_ww > 0 + + if add_auxiliary_subpop: + aux_subpop = pl.DataFrame( + { + "site_index": [None], + "site": [None], + "site_pop": [total_pop_no_ww], + } + ) + else: + aux_subpop = pl.DataFrame() + site_subpop_spine = ( + pl.concat([aux_subpop, site_indices], how="vertical_relaxed") + .with_columns( + subpop_index=pl.col("site_index") + .cum_count() + .alias("subpop_index"), + subpop_name=pl.format( + "Site: {}", pl.col("site") + ).fill_null("remainder of population"), + ) + .rename({"site_pop": "subpop_pop"}) + ) + else: + site_subpop_spine = pl.DataFrame( + { + "site_index": [None], + "site": [None], + "subpop_pop": [self.population_size], + "subpop_index": [1], + "subpop_name": ["total population"], + } + ) + return site_subpop_spine + + @property + def date_time_spine(self): + if self.data_observed_disease_wastewater is not None: + date_time_spine = pl.DataFrame( + { + "date": pl.date_range( + start=self.date_observed_disease_wastewater.min(), + end=self.date_observed_disease_wastewater.max(), + interval="1d", + eager=True, + ) + } + ).with_row_index("t") + return date_time_spine + + @property + def wastewater_data_extended(self): + if self.data_observed_disease_wastewater is not None: + return ( + self.data_observed_disease_wastewater.join( + self.date_time_spine, on="date", how="left", coalesce=True + ) + .join( + self.site_subpop_spine, + on=["site_index", "site"], + how="left", + coalesce=True, + ) + .with_row_index("ind_rel_to_observed_times") + ) + + @property + def date_observed_disease_wastewater(self): + if self.data_observed_disease_wastewater is not None: + return self.data_observed_disease_wastewater.get_column( + "date" + ).unique() + + @property + def data_observed_disease_wastewater_conc(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended.get_column( + "log_genome_copies_per_ml" + ).to_numpy() + + @property + def ww_censored(self): + if self.data_observed_disease_wastewater is not None: + return ( + self.wastewater_data_extended.filter(pl.col("below_lod") == 1) + .get_column("ind_rel_to_observed_times") + .to_numpy() + ) + return None + + @property + def ww_uncensored(self): + if self.data_observed_disease_wastewater is not None: + return ( + self.wastewater_data_extended.filter(pl.col("below_lod") == 0) + .get_column("ind_rel_to_observed_times") + .to_numpy() + ) + + @property + def ww_observed_times(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended.get_column("t").to_numpy() + + @property + def ww_observed_subpops(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended.get_column( + "subpop_index" + ).to_numpy() + + @property + def ww_observed_lab_sites(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended.get_column( + "lab_site_index" + ).to_numpy() + + @property + def ww_log_lod(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended.get_column( + "log_lod" + ).to_numpy() + + @property + def n_ww_lab_sites(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended["lab_site_index"].n_unique() + + @property + def lab_site_to_subpop_map(self): + if self.data_observed_disease_wastewater is not None: + return ( + ( + self.wastewater_data_extended[ + "lab_site_index", "subpop_index" + ] + .unique() + .sort(by="lab_site_index") + ) + .get_column("subpop_index") + .to_numpy() + ) + + def to_pyrenew_hew_data_args(self): + return { + attr: getattr(self, attr) + for attr in [ + "n_ww_lab_sites", + "ww_censored", + "ww_uncensored", + "ww_log_lod", + "ww_observed_lab_sites", + "ww_observed_subpops", + "ww_observed_times", + "data_observed_disease_wastewater_conc", + "lab_site_to_subpop_map", + ] + } diff --git a/tests/test_pyrenew_hew_data.py b/tests/test_pyrenew_hew_data.py index 99f42058..cebff9fd 100644 --- a/tests/test_pyrenew_hew_data.py +++ b/tests/test_pyrenew_hew_data.py @@ -1,11 +1,9 @@ import datetime -import jax.numpy as jnp -import numpy as np -import polars as pl import pytest from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData +from pyrenew_hew.pyrenew_wastewater_data import PyrenewWastewaterData @pytest.mark.parametrize( @@ -75,14 +73,15 @@ def test_to_forecast_data( """ Test the to_forecast_data method """ + ww_dat = PyrenewWastewaterData() + data = PyrenewHEWData( n_ed_visits_data_days=n_ed_visits_data_days, n_hospital_admissions_data_days=n_hospital_admissions_data_days, - n_wastewater_data_days=n_wastewater_data_days, first_ed_visits_date=first_ed_visits_date, first_hospital_admissions_date=first_hospital_admissions_date, - first_wastewater_date=first_wastewater_date, right_truncation_offset=right_truncation_offset, + **ww_dat.to_pyrenew_hew_data_args(), ) assert data.right_truncation_offset == right_truncation_offset @@ -101,67 +100,4 @@ def test_to_forecast_data( == data.first_data_date_overall ) assert forecast_data.first_wastewater_date == data.first_data_date_overall - assert forecast_data.data_observed_disease_wastewater is None - - -def test_wastewater_data_properties(): - first_training_date = datetime.date(2023, 1, 1) - last_training_date = datetime.date(2023, 7, 23) - dates = pl.date_range( - first_training_date, - last_training_date, - interval="1w", - closed="both", - eager=True, - ) - - ww_raw = pl.DataFrame( - { - "date": dates.extend(dates), - "site": [200] * 30 + [100] * 30, - "lab": [21] * 60, - "lab_site_index": [1] * 30 + [2] * 30, - "site_index": [1] * 30 + [2] * 30, - "log_genome_copies_per_ml": np.log( - np.abs(np.random.normal(loc=500, scale=50, size=60)) - ), - "log_lod": np.log([20] * 30 + [15] * 30), - "site_pop": [200_000] * 30 + [400_000] * 30, - } - ) - - ww_data = ww_raw.with_columns( - (pl.col("log_genome_copies_per_ml") <= pl.col("log_lod")) - .cast(pl.Int8) - .alias("below_lod") - ) - - first_ed_visits_date = datetime.date(2023, 1, 1) - first_hospital_admissions_date = datetime.date(2023, 1, 1) - first_wastewater_date = datetime.date(2023, 1, 1) - n_forecast_points = 10 - - data = PyrenewHEWData( - first_ed_visits_date=first_ed_visits_date, - first_hospital_admissions_date=first_hospital_admissions_date, - first_wastewater_date=first_wastewater_date, - data_observed_disease_wastewater=ww_data, - population_size=1e6, - ) - - forecast_data = data.to_forecast_data(n_forecast_points) assert forecast_data.data_observed_disease_wastewater_conc is None - assert data.data_observed_disease_wastewater_conc is not None - - assert jnp.array_equal( - data.data_observed_disease_wastewater_conc, - ww_data["log_genome_copies_per_ml"], - ) - assert len(data.ww_censored) == len( - ww_data.filter(pl.col("below_lod") == 1) - ) - assert len(data.ww_uncensored) == len( - ww_data.filter(pl.col("below_lod") == 0) - ) - assert jnp.array_equal(data.ww_log_lod, ww_data["log_lod"]) - assert data.n_ww_lab_sites == ww_data["lab_site_index"].n_unique() diff --git a/tests/test_pyrenew_wastewater_data.py b/tests/test_pyrenew_wastewater_data.py new file mode 100644 index 00000000..0d9f51d9 --- /dev/null +++ b/tests/test_pyrenew_wastewater_data.py @@ -0,0 +1,89 @@ +import datetime + +import numpy as np +import polars as pl + +from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData +from pyrenew_hew.pyrenew_wastewater_data import PyrenewWastewaterData + + +def test_pyrenew_wastewater_data(): + first_training_date = datetime.date(2023, 1, 1) + last_training_date = datetime.date(2023, 7, 23) + dates = pl.date_range( + first_training_date, + last_training_date, + interval="1w", + closed="both", + eager=True, + ) + + ww_raw = pl.DataFrame( + { + "date": dates.extend(dates), + "site": [200] * 30 + [100] * 30, + "lab": [21] * 60, + "lab_site_index": [1] * 30 + [2] * 30, + "site_index": [1] * 30 + [2] * 30, + "log_genome_copies_per_ml": np.log( + np.abs(np.random.normal(loc=500, scale=50, size=60)) + ), + "log_lod": np.log([20] * 30 + [15] * 30), + "site_pop": [200_000] * 30 + [400_000] * 30, + } + ) + + ww_data = ww_raw.with_columns( + (pl.col("log_genome_copies_per_ml") <= pl.col("log_lod")) + .cast(pl.Int8) + .alias("below_lod") + ) + + first_ed_visits_date = datetime.date(2023, 1, 1) + first_hospital_admissions_date = datetime.date(2023, 1, 1) + first_wastewater_date = datetime.date(2023, 1, 1) + n_forecast_points = 10 + + wastewater_data = PyrenewWastewaterData( + data_observed_disease_wastewater=ww_data, + population_size=1e6, + ) + + data = PyrenewHEWData( + first_ed_visits_date=first_ed_visits_date, + first_hospital_admissions_date=first_hospital_admissions_date, + first_wastewater_date=first_wastewater_date, + **wastewater_data.to_pyrenew_hew_data_args(), + ) + + forecast_data = data.to_forecast_data(n_forecast_points) + assert forecast_data.data_observed_disease_wastewater_conc is None + assert data.data_observed_disease_wastewater_conc is not None + + assert np.array_equal(data.ww_censored, forecast_data.ww_censored) + assert np.array_equal(data.ww_uncensored, forecast_data.ww_uncensored) + assert np.array_equal(data.ww_log_lod, forecast_data.ww_log_lod) + assert np.array_equal( + data.ww_observed_lab_sites, forecast_data.ww_observed_lab_sites + ) + assert np.array_equal( + data.ww_observed_subpops, forecast_data.ww_observed_subpops + ) + assert np.array_equal( + data.ww_observed_times, forecast_data.ww_observed_times + ) + assert np.array_equal(data.n_ww_lab_sites, forecast_data.n_ww_lab_sites) + assert np.array_equal(data.pop_fraction, forecast_data.pop_fraction) + + assert np.array_equal( + data.data_observed_disease_wastewater_conc, + ww_data["log_genome_copies_per_ml"], + ) + assert len(data.ww_censored) == len( + ww_data.filter(pl.col("below_lod") == 1) + ) + assert len(data.ww_uncensored) == len( + ww_data.filter(pl.col("below_lod") == 0) + ) + assert np.array_equal(data.ww_log_lod, ww_data["log_lod"]) + assert data.n_ww_lab_sites == ww_data["lab_site_index"].n_unique()