From 240cde212c467fb13091282b7fac36f4826a35dc Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 31 Jan 2025 15:41:37 -0500 Subject: [PATCH 01/34] add wastewater data prep code --- pipelines/prep_ww_data.py | 29 +++++++---------------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/pipelines/prep_ww_data.py b/pipelines/prep_ww_data.py index 513143c6..760f1f62 100644 --- a/pipelines/prep_ww_data.py +++ b/pipelines/prep_ww_data.py @@ -47,9 +47,7 @@ def clean_nwss_data(nwss_data): .when(pl.col("pcr_target_units") == "log10 copies/l wastewater") .then((10 ** pl.col("pcr_target_avg_conc")) / 1000) .otherwise(None), - lod_sewage=pl.when( - pl.col("pcr_target_units") == "copies/l wastewater" - ) + lod_sewage=pl.when(pl.col("pcr_target_units") == "copies/l wastewater") .then(pl.col("lod_sewage") / 1000) .when(pl.col("pcr_target_units") == "log10 copies/l wastewater") .then((10 ** pl.col("lod_sewage")) / 1000) @@ -130,9 +128,7 @@ def clean_nwss_data(nwss_data): ) .with_columns( [ - pl.col("pcr_target_avg_conc") - .log() - .alias("log_genome_copies_per_ml"), + pl.col("pcr_target_avg_conc").log().alias("log_genome_copies_per_ml"), pl.col("lod_sewage").log().alias("log_lod"), pl.col("location").str.to_uppercase().alias("location"), pl.col("site").cast(pl.String).alias("site"), @@ -211,9 +207,7 @@ def validate_ww_conc_data( .eq(1) .all() ): - raise ValueError( - "The data contains sites with varying population sizes." - ) + raise ValueError("The data contains sites with varying population sizes.") return None @@ -245,14 +239,10 @@ def preprocess_ww_data( .with_row_index("lab_site_index") ) site_df = ( - ww_data_ordered.select([wwtp_col_name]) - .unique() - .with_row_index("site_index") + ww_data_ordered.select([wwtp_col_name]).unique().with_row_index("site_index") ) ww_preprocessed = ( - ww_data_ordered.join( - lab_site_df, on=[lab_col_name, wwtp_col_name], how="left" - ) + ww_data_ordered.join(lab_site_df, on=[lab_col_name, wwtp_col_name], how="left") .join(site_df, on=wwtp_col_name, how="left") .rename( { @@ -262,14 +252,9 @@ def preprocess_ww_data( ) .with_columns( lab_site_name=( - "Site: " - + pl.col(wwtp_col_name) - + ", Lab: " - + pl.col(lab_col_name) - ), - below_lod=( - pl.col("log_genome_copies_per_ml") <= pl.col("log_lod") + "Site: " + pl.col(wwtp_col_name) + ", Lab: " + pl.col(lab_col_name) ), + below_lod=(pl.col("log_genome_copies_per_ml") <= pl.col("log_lod")), ) .select( [ From cad517239421945c43c1f914f6d2425795b02c5c Mon Sep 17 00:00:00 2001 From: Subekshya Date: Mon, 3 Feb 2025 19:27:45 +0000 Subject: [PATCH 02/34] update prep_data.py --- pipelines/prep_data.py | 47 +++++++++++++----------------------------- 1 file changed, 14 insertions(+), 33 deletions(-) diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index 14e657f9..42efde5b 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -11,6 +11,9 @@ import jax.numpy as jnp import polars as pl import polars.selectors as cs +import jax.numpy as jnp + +from prep_ww_data import get_nwss_data _disease_map = { "COVID-19": "COVID-19/Omicron", @@ -47,9 +50,7 @@ def py_scalar_to_r_scalar(py_scalar): state_abb_for_query = state_abb if state_abb != "US" else "USA" temp_file = Path(temp_dir, "nhsn_temp.parquet") - api_key_id = credentials_dict.get( - "nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID") - ) + api_key_id = credentials_dict.get("nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID")) api_key_secret = credentials_dict.get( "nhsn_api_key_secret", os.getenv("NHSN_API_KEY_SECRET") ) @@ -81,9 +82,7 @@ def py_scalar_to_r_scalar(py_scalar): if result.returncode != 0: raise RuntimeError(f"pull_and_save_nhsn: {result.stderr}") raw_dat = pl.read_parquet(temp_file) - dat = raw_dat.with_columns( - weekendingdate=pl.col("weekendingdate").cast(pl.Date) - ) + dat = raw_dat.with_columns(weekendingdate=pl.col("weekendingdate").cast(pl.Date)) return dat @@ -105,9 +104,7 @@ def combine_nssp_and_nhsn( variable_name="drop_me", value_name=".value", ) - .with_columns( - pl.col("count_type").replace(count_type_dict).alias(".variable") - ) + .with_columns(pl.col("count_type").replace(count_type_dict).alias(".variable")) .select(cs.exclude(["count_type", "drop_me"])) ) @@ -187,9 +184,7 @@ def process_state_level_data( if state_abb == "US": locations_to_aggregate = ( - state_pop_df.filter(pl.col("abb") != "US") - .get_column("abb") - .unique() + state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique() ) logger.info("Aggregating state-level data to national") state_level_nssp_data = aggregate_to_national( @@ -216,9 +211,7 @@ def process_state_level_data( ] ) .with_columns( - disease=pl.col("disease") - .cast(pl.Utf8) - .replace(_inverse_disease_map), + disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map), ) .sort(["date", "disease"]) .collect(streaming=True) @@ -250,9 +243,7 @@ def aggregate_facility_level_nssp_to_state( if state_abb == "US": logger.info("Aggregating facility-level data to national") locations_to_aggregate = ( - state_pop_df.filter(pl.col("abb") != "US") - .get_column("abb") - .unique() + state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique() ) facility_level_nssp_data = aggregate_to_national( facility_level_nssp_data, @@ -271,9 +262,7 @@ def aggregate_facility_level_nssp_to_state( .group_by(["reference_date", "disease"]) .agg(pl.col("value").sum().alias("ed_visits")) .with_columns( - disease=pl.col("disease") - .cast(pl.Utf8) - .replace(_inverse_disease_map), + disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map), geo_value=pl.lit(state_abb).cast(pl.Utf8), ) .rename({"reference_date": "date"}) @@ -363,16 +352,12 @@ def process_and_save_state( if facility_level_nssp_data is None and state_level_nssp_data is None: raise ValueError( - "Must provide at least one " - "of facility-level and state-level" - "NSSP data" + "Must provide at least one " "of facility-level and state-level" "NSSP data" ) state_pop_df = get_state_pop_df() - state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item( - 0, "population" - ) + state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item(0, "population") (generation_interval_pmf, delay_pmf, right_truncation_pmf) = get_pmfs( param_estimates=param_estimates, state_abb=state_abb, disease=disease @@ -421,17 +406,13 @@ def process_and_save_state( credentials_dict=credentials_dict, ).with_columns(pl.lit("train").alias("data_type")) - nssp_training_dates = ( - nssp_training_data.get_column("date").unique().to_list() - ) + nssp_training_dates = nssp_training_data.get_column("date").unique().to_list() nhsn_training_dates = ( nhsn_training_data.get_column("weekendingdate").unique().to_list() ) nhsn_first_date_index = next( - i - for i, x in enumerate(nssp_training_dates) - if x == min(nhsn_training_dates) + i for i, x in enumerate(nssp_training_dates) if x == min(nhsn_training_dates) ) nhsn_step_size = 7 From 26bbbcdb9cbf26d20f9d8efcdfb413eba333bae1 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 3 Feb 2025 15:07:34 -0500 Subject: [PATCH 03/34] pre-commit --- pipelines/prep_data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index 42efde5b..22f9df05 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -9,10 +9,9 @@ import forecasttools import jax.numpy as jnp +import jax.numpy as jnp import polars as pl import polars.selectors as cs -import jax.numpy as jnp - from prep_ww_data import get_nwss_data _disease_map = { From 573f64ccb8f5264183ce5b6306f15ba628dba1e2 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 3 Feb 2025 17:39:55 -0500 Subject: [PATCH 04/34] add wastewater data in json file --- pipelines/prep_data.py | 1 - pyrenew_hew/pyrenew_hew_data.py | 38 ++++++++++++--------------------- 2 files changed, 14 insertions(+), 25 deletions(-) diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index 22f9df05..2b549c3b 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -448,7 +448,6 @@ def process_and_save_state( "nhsn_training_dates": nhsn_training_dates, "nhsn_first_date_index": nhsn_first_date_index, "nhsn_step_size": nhsn_step_size, - "state_pop": state_pop, "right_truncation_offset": right_truncation_offset, "data_observed_disease_wastewater": data_observed_disease_wastewater, } diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index 99de33f4..cf6b41c0 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -40,9 +40,7 @@ def __init__( self.first_ed_visits_date = first_ed_visits_date self.first_hospital_admissions_date = first_hospital_admissions_date self.first_wastewater_date_ = first_wastewater_date - self.data_observed_disease_wastewater = ( - data_observed_disease_wastewater - ) + self.data_observed_disease_wastewater = data_observed_disease_wastewater self.population_size = population_size self.shedding_offset = shedding_offset self.pop_fraction_ = pop_fraction @@ -128,9 +126,7 @@ def last_data_date_overall(self): @property def n_days_post_init(self): - return ( - self.last_data_date_overall - self.first_data_date_overall - ).days + return (self.last_data_date_overall - self.first_data_date_overall).days @property def site_subpop_spine(self): @@ -171,12 +167,10 @@ def site_subpop_spine(self): site_subpop_spine = ( pl.concat([aux_subpop, site_indices], how="vertical_relaxed") .with_columns( - subpop_index=pl.col("site_index") - .cum_count() - .alias("subpop_index"), - subpop_name=pl.format( - "Site: {}", pl.col("site") - ).fill_null("remainder of population"), + subpop_index=pl.col("site_index").cum_count().alias("subpop_index"), + subpop_name=pl.format("Site: {}", pl.col("site")).fill_null( + "remainder of population" + ), ) .rename({"site_pop": "subpop_pop"}) ) @@ -233,24 +227,22 @@ def pop_fraction(self): @property def data_observed_disease_wastewater_conc(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended[ - "log_genome_copies_per_ml" - ].to_numpy() + return self.wastewater_data_extended["log_genome_copies_per_ml"].to_numpy() @property def ww_censored(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended.filter( - pl.col("below_lod") == 1 - )["ind_rel_to_observed_times"].to_numpy() + return self.wastewater_data_extended.filter(pl.col("below_lod") == 1)[ + "ind_rel_to_observed_times" + ].to_numpy() return None @property def ww_uncensored(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended.filter( - pl.col("below_lod") == 0 - )["ind_rel_to_observed_times"].to_numpy() + return self.wastewater_data_extended.filter(pl.col("below_lod") == 0)[ + "ind_rel_to_observed_times" + ].to_numpy() @property def ww_observed_times(self): @@ -306,9 +298,7 @@ def get_end_date( ) result = None else: - result = first_date + datetime.timedelta( - days=n_datapoints * timestep_days - ) + result = first_date + datetime.timedelta(days=n_datapoints * timestep_days) return result def get_n_data_days( From c279103bba22b82a32b91e6090228221f7dfd864 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 5 Feb 2025 16:29:41 -0500 Subject: [PATCH 05/34] move processing of wastewater data to pyrenew-hew-data --- pipelines/build_pyrenew_model.py | 2 ++ pyrenew_hew/pyrenew_hew_data.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index 0f4359c8..c41b3f80 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -123,6 +123,8 @@ def build_model_from_dir( right_truncation_offset = model_data["right_truncation_offset"] + wastewater_data = pl.DataFrame(model_data["train_disease_wastewater"]) + my_latent_infection_model = LatentInfectionProcess( i0_first_obs_n_rv=priors["i0_first_obs_n_rv"], initialization_rate_rv=priors["initialization_rate_rv"], diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index cf6b41c0..002a9cab 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -76,6 +76,12 @@ def first_wastewater_date(self): return self.data_observed_disease_wastewater["date"].min() return self.first_wastewater_date_ + @property + def first_wastewater_date(self): + if self.data_observed_disease_wastewater is not None: + return self.data_observed_disease_wastewater["date"].min() + return self.first_wastewater_date_ + @property def last_ed_visits_date(self): return self.get_end_date( From 687bc54c88d014fc5d46d87f401a7d70418d290f Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 5 Feb 2025 19:35:14 -0500 Subject: [PATCH 06/34] syn to use datetime.date for dates throughout, rename ww_sampled to ww_observed --- pyrenew_hew/pyrenew_hew_model.py | 67 ++++++++++---------------------- tests/test_pyrenew_hew_data.py | 13 ++----- 2 files changed, 23 insertions(+), 57 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index da2fc592..a62151be 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -73,9 +73,7 @@ def __init__( self.autoreg_rt_subpop_rv = autoreg_rt_subpop_rv self.sigma_rt_rv = sigma_rt_rv self.sigma_i_first_obs_rv = sigma_i_first_obs_rv - self.sigma_initial_exp_growth_rate_rv = ( - sigma_initial_exp_growth_rate_rv - ) + self.sigma_initial_exp_growth_rate_rv = sigma_initial_exp_growth_rate_rv self.n_initialization_points = n_initialization_points self.pop_fraction = pop_fraction self.n_subpops = len(pop_fraction) @@ -124,13 +122,10 @@ def sample(self, n_days_post_init: int): + self.offset_ref_logit_i_first_obs_rv(), ) initial_exp_growth_rate_ref_subpop = ( - initial_exp_growth_rate - + self.offset_ref_initial_exp_growth_rate_rv() + initial_exp_growth_rate + self.offset_ref_initial_exp_growth_rate_rv() ) - log_rtu_weekly_ref_subpop = ( - log_rtu_weekly + self.offset_ref_log_rt_rv() - ) + log_rtu_weekly_ref_subpop = log_rtu_weekly + self.offset_ref_log_rt_rv() i_first_obs_over_n_non_ref_subpop_rv = TransformedVariable( "i_first_obs_over_n_non_ref_subpop", DistributionalVariable( @@ -212,9 +207,7 @@ def sample(self, n_days_post_init: int): )[:n_days_post_init, :] ) # indexed rel to first post-init day. - i0_subpop_rv = DeterministicVariable( - "i0_subpop", i_first_obs_over_n_subpop - ) + i0_subpop_rv = DeterministicVariable("i0_subpop", i_first_obs_over_n_subpop) initial_exp_growth_rate_subpop_rv = DeterministicVariable( "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) @@ -319,7 +312,9 @@ def sample( iedr = jnp.repeat( transformation.SigmoidTransform()(p_ed_ar + p_ed_mean), repeats=7, - )[:n_datapoints] # indexed rel to first ed report day + )[ + :n_datapoints + ] # indexed rel to first ed report day # this is only applied after the ed visits are generated, not to all # the latent infections. This is why we cannot apply the iedr in # compute_delay_ascertained_incidence @@ -339,28 +334,21 @@ def sample( )[-n_datapoints:] latent_ed_visits_final = ( - potential_latent_ed_visits - * iedr - * ed_wday_effect - * population_size + potential_latent_ed_visits * iedr * ed_wday_effect * population_size ) if right_truncation_offset is not None: prop_already_reported_tail = jnp.flip( self.ed_right_truncation_cdf_rv()[right_truncation_offset:] ) - n_points_to_prepend = ( - n_datapoints - prop_already_reported_tail.shape[0] - ) + n_points_to_prepend = n_datapoints - prop_already_reported_tail.shape[0] prop_already_reported = jnp.pad( prop_already_reported_tail, (n_points_to_prepend, 0), mode="constant", constant_values=(1, 0), ) - latent_ed_visits_now = ( - latent_ed_visits_final * prop_already_reported - ) + latent_ed_visits_now = latent_ed_visits_final * prop_already_reported else: latent_ed_visits_now = latent_ed_visits_final @@ -386,9 +374,7 @@ def __init__( ihr_rel_iedr_rv: RandomVariable = None, ) -> None: self.inf_to_hosp_admit_rv = inf_to_hosp_admit_rv - self.hosp_admit_neg_bin_concentration_rv = ( - hosp_admit_neg_bin_concentration_rv - ) + self.hosp_admit_neg_bin_concentration_rv = hosp_admit_neg_bin_concentration_rv self.ihr_rv = ihr_rv self.ihr_rel_iedr_rv = ihr_rel_iedr_rv @@ -522,10 +508,7 @@ def normed_shedding_cdf( norm_const = (t_p + t_d) * ((log_base - 1) / jnp.log(log_base) - 1) def ad_pre(x): - return ( - t_p / jnp.log(log_base) * jnp.exp(jnp.log(log_base) * x / t_p) - - x - ) + return t_p / jnp.log(log_base) * jnp.exp(jnp.log(log_base) * x / t_p) - x def ad_post(x): return ( @@ -605,12 +588,10 @@ def sample( def batch_colvolve_fn(m): return jnp.convolve(m, viral_kinetics, mode="valid") - model_net_inf_ind_shedding = jax.vmap( - batch_colvolve_fn, in_axes=1, out_axes=1 - )(jnp.atleast_2d(latent_infections_subpop))[-n_datapoints:, :] - numpyro.deterministic( - "model_net_inf_ind_shedding", model_net_inf_ind_shedding - ) + model_net_inf_ind_shedding = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)( + jnp.atleast_2d(latent_infections_subpop) + )[-n_datapoints:, :] + numpyro.deterministic("model_net_inf_ind_shedding", model_net_inf_ind_shedding) log10_genome_per_inf_ind = self.log10_genome_per_inf_ind_rv() expected_obs_viral_genomes = ( @@ -618,9 +599,7 @@ def batch_colvolve_fn(m): + jnp.log(model_net_inf_ind_shedding + shedding_offset) - jnp.log(self.ww_ml_produced_per_day) ) - numpyro.deterministic( - "expected_obs_viral_genomes", expected_obs_viral_genomes - ) + numpyro.deterministic("expected_obs_viral_genomes", expected_obs_viral_genomes) mode_sigma_ww_site = self.mode_sigma_ww_site_rv() sd_log_sigma_ww_site = self.sd_log_sigma_ww_site_rv() @@ -659,11 +638,7 @@ def batch_colvolve_fn(m): scale=sigma_ww_site[ww_observed_lab_sites[ww_uncensored]], ), ).sample( - obs=( - data_observed[ww_uncensored] - if data_observed is not None - else None - ), + obs=(data_observed[ww_uncensored] if data_observed is not None else None), ) if ww_censored.shape[0] != 0: @@ -720,10 +695,8 @@ def sample( sample_wastewater: bool = False, ) -> dict[str, ArrayLike]: # numpydoc ignore=GL08 n_init_days = self.latent_infection_process_rv.n_initialization_points - latent_infections, latent_infections_subpop = ( - self.latent_infection_process_rv( - n_days_post_init=data.n_days_post_init, - ) + latent_infections, latent_infections_subpop = self.latent_infection_process_rv( + n_days_post_init=data.n_days_post_init, ) first_latent_infection_dow = ( data.first_data_date_overall - datetime.timedelta(days=n_init_days) diff --git a/tests/test_pyrenew_hew_data.py b/tests/test_pyrenew_hew_data.py index 99f42058..e3195f6a 100644 --- a/tests/test_pyrenew_hew_data.py +++ b/tests/test_pyrenew_hew_data.py @@ -96,10 +96,7 @@ def test_to_forecast_data( assert forecast_data.n_hospital_admissions_data_days == n_weeks_expected assert forecast_data.right_truncation_offset is None assert forecast_data.first_ed_visits_date == data.first_data_date_overall - assert ( - forecast_data.first_hospital_admissions_date - == data.first_data_date_overall - ) + assert forecast_data.first_hospital_admissions_date == data.first_data_date_overall assert forecast_data.first_wastewater_date == data.first_data_date_overall assert forecast_data.data_observed_disease_wastewater is None @@ -157,11 +154,7 @@ def test_wastewater_data_properties(): data.data_observed_disease_wastewater_conc, ww_data["log_genome_copies_per_ml"], ) - assert len(data.ww_censored) == len( - ww_data.filter(pl.col("below_lod") == 1) - ) - assert len(data.ww_uncensored) == len( - ww_data.filter(pl.col("below_lod") == 0) - ) + assert len(data.ww_censored) == len(ww_data.filter(pl.col("below_lod") == 1)) + assert len(data.ww_uncensored) == len(ww_data.filter(pl.col("below_lod") == 0)) assert jnp.array_equal(data.ww_log_lod, ww_data["log_lod"]) assert data.n_ww_lab_sites == ww_data["lab_site_index"].n_unique() From 1d8e2f03033738c43adf22b61f093fea3394262d Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 6 Feb 2025 01:33:30 -0500 Subject: [PATCH 07/34] add a test --- tests/test_pyrenew_hew_data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_pyrenew_hew_data.py b/tests/test_pyrenew_hew_data.py index e3195f6a..fb0e1533 100644 --- a/tests/test_pyrenew_hew_data.py +++ b/tests/test_pyrenew_hew_data.py @@ -5,6 +5,7 @@ import polars as pl import pytest +from pipelines.prep_ww_data import get_date_time_spine from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData From 64d3fd7533bca87a2f9b6109412773b956acacc9 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 6 Feb 2025 10:08:25 -0500 Subject: [PATCH 08/34] fix test --- pyrenew_hew/pyrenew_hew_data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index 002a9cab..726c5a5c 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -351,5 +351,6 @@ def to_forecast_data(self, n_forecast_points: int) -> Self: first_hospital_admissions_date=(self.first_data_date_overall), first_wastewater_date=self.first_data_date_overall, pop_fraction=self.pop_fraction, + wastewater_data=self.wastewater_data, right_truncation_offset=None, # by default, want forecasts of complete reports ) From 3856e02128eef1422527bb9c893553b81a69d158 Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 7 Feb 2025 14:29:48 -0500 Subject: [PATCH 09/34] n_datapoints -> n_data_days, move get_spines function to pyrenewHEWData --- pipelines/build_pyrenew_model.py | 4 +++- pipelines/prep_data.py | 1 + pyrenew_hew/pyrenew_hew_data.py | 1 - tests/test_pyrenew_hew_data.py | 1 - 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index c41b3f80..e3616d5a 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -123,7 +123,9 @@ def build_model_from_dir( right_truncation_offset = model_data["right_truncation_offset"] - wastewater_data = pl.DataFrame(model_data["train_disease_wastewater"]) + data_observed_disease_wastewater = pl.DataFrame( + model_data["data_observed_disease_wastewater"] + ) my_latent_infection_model = LatentInfectionProcess( i0_first_obs_n_rv=priors["i0_first_obs_n_rv"], diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index 2b549c3b..22f9df05 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -448,6 +448,7 @@ def process_and_save_state( "nhsn_training_dates": nhsn_training_dates, "nhsn_first_date_index": nhsn_first_date_index, "nhsn_step_size": nhsn_step_size, + "state_pop": state_pop, "right_truncation_offset": right_truncation_offset, "data_observed_disease_wastewater": data_observed_disease_wastewater, } diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index 726c5a5c..002a9cab 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -351,6 +351,5 @@ def to_forecast_data(self, n_forecast_points: int) -> Self: first_hospital_admissions_date=(self.first_data_date_overall), first_wastewater_date=self.first_data_date_overall, pop_fraction=self.pop_fraction, - wastewater_data=self.wastewater_data, right_truncation_offset=None, # by default, want forecasts of complete reports ) diff --git a/tests/test_pyrenew_hew_data.py b/tests/test_pyrenew_hew_data.py index fb0e1533..e3195f6a 100644 --- a/tests/test_pyrenew_hew_data.py +++ b/tests/test_pyrenew_hew_data.py @@ -5,7 +5,6 @@ import polars as pl import pytest -from pipelines.prep_ww_data import get_date_time_spine from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData From 6bbc57a980063c2b5b235d6648104a13ae3cc1b4 Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 7 Feb 2025 19:16:09 -0500 Subject: [PATCH 10/34] drop wastewater data with no LOD reported --- pipelines/prep_ww_data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pipelines/prep_ww_data.py b/pipelines/prep_ww_data.py index 760f1f62..1e176ff8 100644 --- a/pipelines/prep_ww_data.py +++ b/pipelines/prep_ww_data.py @@ -52,6 +52,9 @@ def clean_nwss_data(nwss_data): .when(pl.col("pcr_target_units") == "log10 copies/l wastewater") .then((10 ** pl.col("lod_sewage")) / 1000) .otherwise(None), + sample_collect_date=pl.col("sample_collect_date").str.to_date( + format="%Y-%m-%d" + ), ) .filter( ( From 7614a1b3bd28113781fd2eccf62b6a835c1441fe Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 13 Feb 2025 16:52:00 -0500 Subject: [PATCH 11/34] create fake nwss data --- pipelines/generate_test_data.R | 45 ++++++++++++++++++++++++++++++++++ pipelines/prep_ww_data.py | 3 --- 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/pipelines/generate_test_data.R b/pipelines/generate_test_data.R index 4f6d3c88..5c353883 100644 --- a/pipelines/generate_test_data.R +++ b/pipelines/generate_test_data.R @@ -304,6 +304,51 @@ generate_fake_param_data <- ) } + +#' Generate Fake NWSS Data +#' +#' This function generates fake wastewater data for a +#' and saves it as a parquet file. + +create_test_nwss_data <- function( + private_data_dir = fs::path(getwd()), + states_to_generate = c("MT", "CA"), + start_reference = as.Date("2024-06-01"), + end_reference = as.Date("2024-12-21"), + site = c(1, 2, 3, 4), + lab = c(1, 1, 3, 3), + lod = c(20, 31, 20, 30), + site_pop = c(4e6, 2e6, 1e6, 5e5)) { + ww_dir <- fs::path(private_data_dir, "nwss-vintages") + fs::dir_create(ww_dir, recurse = TRUE) + + site_info <- tibble::tibble( + wwtp_id = site, + lab_id = lab, + lod_sewage = lod, + population_served = site_pop, + sample_location = "wwtp", + sample_matrix = "primary sludge", + pcr_target_units = "copies/l wastewater", + pcr_target = "sars-cov-2", + quality_flag = NA + ) + + ww_data <- tidyr::expand_grid( + sample_collect_date = seq(start_reference, end_reference, by = "week"), + wwtp_jurisdiction = states_to_generate, + site_info + ) |> + dplyr::mutate( + pcr_target_avg_conc = abs(rnorm(dplyr::n(), mean = 500, sd = 50)) + ) + + arrow::write_parquet( + ww_data, fs::path(ww_dir, paste0(end_reference, ".parquet")) + ) +} + + main <- function(private_data_dir, target_diseases, n_forecast_days) { diff --git a/pipelines/prep_ww_data.py b/pipelines/prep_ww_data.py index 1e176ff8..760f1f62 100644 --- a/pipelines/prep_ww_data.py +++ b/pipelines/prep_ww_data.py @@ -52,9 +52,6 @@ def clean_nwss_data(nwss_data): .when(pl.col("pcr_target_units") == "log10 copies/l wastewater") .then((10 ** pl.col("lod_sewage")) / 1000) .otherwise(None), - sample_collect_date=pl.col("sample_collect_date").str.to_date( - format="%Y-%m-%d" - ), ) .filter( ( From ab6beb81241a1c5c2f30aab3639d7d43f4aeca7d Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 13 Feb 2025 19:29:30 -0500 Subject: [PATCH 12/34] add ww_data_dir to forecast_state.py, test_end_to_end.sh --- pipelines/forecast_state.py | 9 +++++++ pipelines/generate_test_data.R | 11 +++++--- pipelines/prep_data.py | 40 ++++++++++++++++++++++-------- pipelines/tests/test_end_to_end.sh | 1 + 4 files changed, 47 insertions(+), 14 deletions(-) diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py index b9b5bec2..8abef411 100644 --- a/pipelines/forecast_state.py +++ b/pipelines/forecast_state.py @@ -195,6 +195,7 @@ def main( state_level_nssp_data_dir: Path | str, nwss_data_dir: Path | str, param_data_dir: Path | str, + ww_data_dir: Path | str, priors_path: Path | str, output_dir: Path | str, n_training_days: int, @@ -382,6 +383,7 @@ def main( last_training_date=last_training_date, param_estimates=param_estimates, model_run_dir=model_run_dir, + ww_data_dir=ww_data_dir, logger=logger, credentials_dict=credentials_dict, ) @@ -535,6 +537,13 @@ def main( required=True, ) + parser.add_argument( + "--ww-data-dir", + type=Path, + default=Path("private_data", "nwss_vintages"), + help=("Directory in which to look for NWSS wastewater data"), + ) + parser.add_argument( "--priors-path", type=Path, diff --git a/pipelines/generate_test_data.R b/pipelines/generate_test_data.R index 5c353883..84d5c2d1 100644 --- a/pipelines/generate_test_data.R +++ b/pipelines/generate_test_data.R @@ -310,7 +310,7 @@ generate_fake_param_data <- #' This function generates fake wastewater data for a #' and saves it as a parquet file. -create_test_nwss_data <- function( +generate_fake_nwss_data <- function( private_data_dir = fs::path(getwd()), states_to_generate = c("MT", "CA"), start_reference = as.Date("2024-06-01"), @@ -319,7 +319,7 @@ create_test_nwss_data <- function( lab = c(1, 1, 3, 3), lod = c(20, 31, 20, 30), site_pop = c(4e6, 2e6, 1e6, 5e5)) { - ww_dir <- fs::path(private_data_dir, "nwss-vintages") + ww_dir <- fs::path(private_data_dir, "nwss_vintages") fs::dir_create(ww_dir, recurse = TRUE) site_info <- tibble::tibble( @@ -328,10 +328,10 @@ create_test_nwss_data <- function( lod_sewage = lod, population_served = site_pop, sample_location = "wwtp", - sample_matrix = "primary sludge", + sample_matrix = "raw wastewater", pcr_target_units = "copies/l wastewater", pcr_target = "sars-cov-2", - quality_flag = NA + quality_flag = c("no", NA_character_, "n", "n") ) ww_data <- tidyr::expand_grid( @@ -380,6 +380,9 @@ main <- function(private_data_dir, states_to_generate = c("MT", "CA", "US"), target_diseases = short_target_diseases ) + generate_fake_nwss_data( + private_data_dir, + ) } p <- arg_parser("Create simulated epiweekly data.") |> diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index 22f9df05..fd91a463 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -81,7 +81,9 @@ def py_scalar_to_r_scalar(py_scalar): if result.returncode != 0: raise RuntimeError(f"pull_and_save_nhsn: {result.stderr}") raw_dat = pl.read_parquet(temp_file) - dat = raw_dat.with_columns(weekendingdate=pl.col("weekendingdate").cast(pl.Date)) + dat = raw_dat.with_columns( + weekendingdate=pl.col("weekendingdate").cast(pl.Date) + ) return dat @@ -103,7 +105,9 @@ def combine_nssp_and_nhsn( variable_name="drop_me", value_name=".value", ) - .with_columns(pl.col("count_type").replace(count_type_dict).alias(".variable")) + .with_columns( + pl.col("count_type").replace(count_type_dict).alias(".variable") + ) .select(cs.exclude(["count_type", "drop_me"])) ) @@ -183,7 +187,9 @@ def process_state_level_data( if state_abb == "US": locations_to_aggregate = ( - state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique() + state_pop_df.filter(pl.col("abb") != "US") + .get_column("abb") + .unique() ) logger.info("Aggregating state-level data to national") state_level_nssp_data = aggregate_to_national( @@ -210,7 +216,9 @@ def process_state_level_data( ] ) .with_columns( - disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map), + disease=pl.col("disease") + .cast(pl.Utf8) + .replace(_inverse_disease_map), ) .sort(["date", "disease"]) .collect(streaming=True) @@ -242,7 +250,9 @@ def aggregate_facility_level_nssp_to_state( if state_abb == "US": logger.info("Aggregating facility-level data to national") locations_to_aggregate = ( - state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique() + state_pop_df.filter(pl.col("abb") != "US") + .get_column("abb") + .unique() ) facility_level_nssp_data = aggregate_to_national( facility_level_nssp_data, @@ -261,7 +271,9 @@ def aggregate_facility_level_nssp_to_state( .group_by(["reference_date", "disease"]) .agg(pl.col("value").sum().alias("ed_visits")) .with_columns( - disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map), + disease=pl.col("disease") + .cast(pl.Utf8) + .replace(_inverse_disease_map), geo_value=pl.lit(state_abb).cast(pl.Utf8), ) .rename({"reference_date": "date"}) @@ -351,12 +363,16 @@ def process_and_save_state( if facility_level_nssp_data is None and state_level_nssp_data is None: raise ValueError( - "Must provide at least one " "of facility-level and state-level" "NSSP data" + "Must provide at least one " + "of facility-level and state-level" + "NSSP data" ) state_pop_df = get_state_pop_df() - state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item(0, "population") + state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item( + 0, "population" + ) (generation_interval_pmf, delay_pmf, right_truncation_pmf) = get_pmfs( param_estimates=param_estimates, state_abb=state_abb, disease=disease @@ -405,13 +421,17 @@ def process_and_save_state( credentials_dict=credentials_dict, ).with_columns(pl.lit("train").alias("data_type")) - nssp_training_dates = nssp_training_data.get_column("date").unique().to_list() + nssp_training_dates = ( + nssp_training_data.get_column("date").unique().to_list() + ) nhsn_training_dates = ( nhsn_training_data.get_column("weekendingdate").unique().to_list() ) nhsn_first_date_index = next( - i for i, x in enumerate(nssp_training_dates) if x == min(nhsn_training_dates) + i + for i, x in enumerate(nssp_training_dates) + if x == min(nhsn_training_dates) ) nhsn_step_size = 7 diff --git a/pipelines/tests/test_end_to_end.sh b/pipelines/tests/test_end_to_end.sh index 309bb8b4..4d4367b8 100755 --- a/pipelines/tests/test_end_to_end.sh +++ b/pipelines/tests/test_end_to_end.sh @@ -28,6 +28,7 @@ do --state-level-nssp-data-dir "$BASE_DIR/private_data/nssp_state_level_gold" \ --priors-path pipelines/priors/prod_priors.py \ --param-data-dir "$BASE_DIR/private_data/prod_param_estimates" \ + --ww-data-dir "$BASE_DIR/private_data/nwss_vintages" \ --output-dir "$BASE_DIR/private_data" \ --n-training-days 60 \ --n-chains 2 \ From 60ced068395d0942b387e6bff397317617083353 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 12 Feb 2025 18:27:55 -0500 Subject: [PATCH 13/34] add schema for ww dataframe, add placeholder path --- pipelines/build_pyrenew_model.py | 13 ++++++++++- pipelines/prep_data.py | 40 ++++++++------------------------ 2 files changed, 22 insertions(+), 31 deletions(-) diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index e3616d5a..d112b04a 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -124,7 +124,18 @@ def build_model_from_dir( right_truncation_offset = model_data["right_truncation_offset"] data_observed_disease_wastewater = pl.DataFrame( - model_data["data_observed_disease_wastewater"] + model_data["data_observed_disease_wastewater"], + schema={ + "date": pl.Date, + "site": pl.String, + "lab": pl.String, + "site_pop": pl.Int64, + "site_index": pl.Int64, + "lab_site_index": pl.Int64, + "log_genomes_copies_per_ml": pl.Float64, + "log_lod": pl.Float64, + "below_lod": pl.Int64, + }, ) my_latent_infection_model = LatentInfectionProcess( diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index fd91a463..22f9df05 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -81,9 +81,7 @@ def py_scalar_to_r_scalar(py_scalar): if result.returncode != 0: raise RuntimeError(f"pull_and_save_nhsn: {result.stderr}") raw_dat = pl.read_parquet(temp_file) - dat = raw_dat.with_columns( - weekendingdate=pl.col("weekendingdate").cast(pl.Date) - ) + dat = raw_dat.with_columns(weekendingdate=pl.col("weekendingdate").cast(pl.Date)) return dat @@ -105,9 +103,7 @@ def combine_nssp_and_nhsn( variable_name="drop_me", value_name=".value", ) - .with_columns( - pl.col("count_type").replace(count_type_dict).alias(".variable") - ) + .with_columns(pl.col("count_type").replace(count_type_dict).alias(".variable")) .select(cs.exclude(["count_type", "drop_me"])) ) @@ -187,9 +183,7 @@ def process_state_level_data( if state_abb == "US": locations_to_aggregate = ( - state_pop_df.filter(pl.col("abb") != "US") - .get_column("abb") - .unique() + state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique() ) logger.info("Aggregating state-level data to national") state_level_nssp_data = aggregate_to_national( @@ -216,9 +210,7 @@ def process_state_level_data( ] ) .with_columns( - disease=pl.col("disease") - .cast(pl.Utf8) - .replace(_inverse_disease_map), + disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map), ) .sort(["date", "disease"]) .collect(streaming=True) @@ -250,9 +242,7 @@ def aggregate_facility_level_nssp_to_state( if state_abb == "US": logger.info("Aggregating facility-level data to national") locations_to_aggregate = ( - state_pop_df.filter(pl.col("abb") != "US") - .get_column("abb") - .unique() + state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique() ) facility_level_nssp_data = aggregate_to_national( facility_level_nssp_data, @@ -271,9 +261,7 @@ def aggregate_facility_level_nssp_to_state( .group_by(["reference_date", "disease"]) .agg(pl.col("value").sum().alias("ed_visits")) .with_columns( - disease=pl.col("disease") - .cast(pl.Utf8) - .replace(_inverse_disease_map), + disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map), geo_value=pl.lit(state_abb).cast(pl.Utf8), ) .rename({"reference_date": "date"}) @@ -363,16 +351,12 @@ def process_and_save_state( if facility_level_nssp_data is None and state_level_nssp_data is None: raise ValueError( - "Must provide at least one " - "of facility-level and state-level" - "NSSP data" + "Must provide at least one " "of facility-level and state-level" "NSSP data" ) state_pop_df = get_state_pop_df() - state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item( - 0, "population" - ) + state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item(0, "population") (generation_interval_pmf, delay_pmf, right_truncation_pmf) = get_pmfs( param_estimates=param_estimates, state_abb=state_abb, disease=disease @@ -421,17 +405,13 @@ def process_and_save_state( credentials_dict=credentials_dict, ).with_columns(pl.lit("train").alias("data_type")) - nssp_training_dates = ( - nssp_training_data.get_column("date").unique().to_list() - ) + nssp_training_dates = nssp_training_data.get_column("date").unique().to_list() nhsn_training_dates = ( nhsn_training_data.get_column("weekendingdate").unique().to_list() ) nhsn_first_date_index = next( - i - for i, x in enumerate(nssp_training_dates) - if x == min(nhsn_training_dates) + i for i, x in enumerate(nssp_training_dates) if x == min(nhsn_training_dates) ) nhsn_step_size = 7 From 0056e3df8c7860f1800c3321bf3d133cc4f013d9 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 12 Feb 2025 18:39:54 -0500 Subject: [PATCH 14/34] pre-commit --- pipelines/prep_data.py | 44 +++++++++++++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index 22f9df05..837d2ce2 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -49,7 +49,9 @@ def py_scalar_to_r_scalar(py_scalar): state_abb_for_query = state_abb if state_abb != "US" else "USA" temp_file = Path(temp_dir, "nhsn_temp.parquet") - api_key_id = credentials_dict.get("nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID")) + api_key_id = credentials_dict.get( + "nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID") + ) api_key_secret = credentials_dict.get( "nhsn_api_key_secret", os.getenv("NHSN_API_KEY_SECRET") ) @@ -81,7 +83,9 @@ def py_scalar_to_r_scalar(py_scalar): if result.returncode != 0: raise RuntimeError(f"pull_and_save_nhsn: {result.stderr}") raw_dat = pl.read_parquet(temp_file) - dat = raw_dat.with_columns(weekendingdate=pl.col("weekendingdate").cast(pl.Date)) + dat = raw_dat.with_columns( + weekendingdate=pl.col("weekendingdate").cast(pl.Date) + ) return dat @@ -103,7 +107,9 @@ def combine_nssp_and_nhsn( variable_name="drop_me", value_name=".value", ) - .with_columns(pl.col("count_type").replace(count_type_dict).alias(".variable")) + .with_columns( + pl.col("count_type").replace(count_type_dict).alias(".variable") + ) .select(cs.exclude(["count_type", "drop_me"])) ) @@ -183,7 +189,9 @@ def process_state_level_data( if state_abb == "US": locations_to_aggregate = ( - state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique() + state_pop_df.filter(pl.col("abb") != "US") + .get_column("abb") + .unique() ) logger.info("Aggregating state-level data to national") state_level_nssp_data = aggregate_to_national( @@ -210,7 +218,9 @@ def process_state_level_data( ] ) .with_columns( - disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map), + disease=pl.col("disease") + .cast(pl.Utf8) + .replace(_inverse_disease_map), ) .sort(["date", "disease"]) .collect(streaming=True) @@ -242,7 +252,9 @@ def aggregate_facility_level_nssp_to_state( if state_abb == "US": logger.info("Aggregating facility-level data to national") locations_to_aggregate = ( - state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique() + state_pop_df.filter(pl.col("abb") != "US") + .get_column("abb") + .unique() ) facility_level_nssp_data = aggregate_to_national( facility_level_nssp_data, @@ -261,7 +273,9 @@ def aggregate_facility_level_nssp_to_state( .group_by(["reference_date", "disease"]) .agg(pl.col("value").sum().alias("ed_visits")) .with_columns( - disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map), + disease=pl.col("disease") + .cast(pl.Utf8) + .replace(_inverse_disease_map), geo_value=pl.lit(state_abb).cast(pl.Utf8), ) .rename({"reference_date": "date"}) @@ -351,12 +365,16 @@ def process_and_save_state( if facility_level_nssp_data is None and state_level_nssp_data is None: raise ValueError( - "Must provide at least one " "of facility-level and state-level" "NSSP data" + "Must provide at least one " + "of facility-level and state-level" + "NSSP data" ) state_pop_df = get_state_pop_df() - state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item(0, "population") + state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item( + 0, "population" + ) (generation_interval_pmf, delay_pmf, right_truncation_pmf) = get_pmfs( param_estimates=param_estimates, state_abb=state_abb, disease=disease @@ -405,13 +423,17 @@ def process_and_save_state( credentials_dict=credentials_dict, ).with_columns(pl.lit("train").alias("data_type")) - nssp_training_dates = nssp_training_data.get_column("date").unique().to_list() + nssp_training_dates = ( + nssp_training_data.get_column("date").unique().to_list() + ) nhsn_training_dates = ( nhsn_training_data.get_column("weekendingdate").unique().to_list() ) nhsn_first_date_index = next( - i for i, x in enumerate(nssp_training_dates) if x == min(nhsn_training_dates) + i + for i, x in enumerate(nssp_training_dates) + if x == min(nhsn_training_dates) ) nhsn_step_size = 7 From 5896b5a4baaf7b7b7d8a13a0c88f499fbc2eb389 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 13 Feb 2025 17:27:30 -0500 Subject: [PATCH 15/34] code review suggestions --- pipelines/build_pyrenew_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index d112b04a..c9a7e146 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -134,7 +134,7 @@ def build_model_from_dir( "lab_site_index": pl.Int64, "log_genomes_copies_per_ml": pl.Float64, "log_lod": pl.Float64, - "below_lod": pl.Int64, + "below_lod": pl.Boolean, }, ) From f7dea095174b1ca6d584cd3c86e92457f40993cb Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 18 Feb 2025 14:34:33 -0500 Subject: [PATCH 16/34] add schema override --- pipelines/build_pyrenew_model.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index c9a7e146..0f4359c8 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -123,21 +123,6 @@ def build_model_from_dir( right_truncation_offset = model_data["right_truncation_offset"] - data_observed_disease_wastewater = pl.DataFrame( - model_data["data_observed_disease_wastewater"], - schema={ - "date": pl.Date, - "site": pl.String, - "lab": pl.String, - "site_pop": pl.Int64, - "site_index": pl.Int64, - "lab_site_index": pl.Int64, - "log_genomes_copies_per_ml": pl.Float64, - "log_lod": pl.Float64, - "below_lod": pl.Boolean, - }, - ) - my_latent_infection_model = LatentInfectionProcess( i0_first_obs_n_rv=priors["i0_first_obs_n_rv"], initialization_rate_rv=priors["initialization_rate_rv"], From 98b24c318913dcdacc62ebd2814fbb1f820accf4 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 19 Feb 2025 10:11:33 -0500 Subject: [PATCH 17/34] fix test --- pipelines/prep_data.py | 45 ++++++++++----------------------------- pipelines/prep_ww_data.py | 29 +++++++++++++++++++------ 2 files changed, 33 insertions(+), 41 deletions(-) diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index 837d2ce2..16f76cfa 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -12,7 +12,6 @@ import jax.numpy as jnp import polars as pl import polars.selectors as cs -from prep_ww_data import get_nwss_data _disease_map = { "COVID-19": "COVID-19/Omicron", @@ -49,9 +48,7 @@ def py_scalar_to_r_scalar(py_scalar): state_abb_for_query = state_abb if state_abb != "US" else "USA" temp_file = Path(temp_dir, "nhsn_temp.parquet") - api_key_id = credentials_dict.get( - "nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID") - ) + api_key_id = credentials_dict.get("nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID")) api_key_secret = credentials_dict.get( "nhsn_api_key_secret", os.getenv("NHSN_API_KEY_SECRET") ) @@ -83,9 +80,7 @@ def py_scalar_to_r_scalar(py_scalar): if result.returncode != 0: raise RuntimeError(f"pull_and_save_nhsn: {result.stderr}") raw_dat = pl.read_parquet(temp_file) - dat = raw_dat.with_columns( - weekendingdate=pl.col("weekendingdate").cast(pl.Date) - ) + dat = raw_dat.with_columns(weekendingdate=pl.col("weekendingdate").cast(pl.Date)) return dat @@ -107,9 +102,7 @@ def combine_nssp_and_nhsn( variable_name="drop_me", value_name=".value", ) - .with_columns( - pl.col("count_type").replace(count_type_dict).alias(".variable") - ) + .with_columns(pl.col("count_type").replace(count_type_dict).alias(".variable")) .select(cs.exclude(["count_type", "drop_me"])) ) @@ -189,9 +182,7 @@ def process_state_level_data( if state_abb == "US": locations_to_aggregate = ( - state_pop_df.filter(pl.col("abb") != "US") - .get_column("abb") - .unique() + state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique() ) logger.info("Aggregating state-level data to national") state_level_nssp_data = aggregate_to_national( @@ -218,9 +209,7 @@ def process_state_level_data( ] ) .with_columns( - disease=pl.col("disease") - .cast(pl.Utf8) - .replace(_inverse_disease_map), + disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map), ) .sort(["date", "disease"]) .collect(streaming=True) @@ -252,9 +241,7 @@ def aggregate_facility_level_nssp_to_state( if state_abb == "US": logger.info("Aggregating facility-level data to national") locations_to_aggregate = ( - state_pop_df.filter(pl.col("abb") != "US") - .get_column("abb") - .unique() + state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique() ) facility_level_nssp_data = aggregate_to_national( facility_level_nssp_data, @@ -273,9 +260,7 @@ def aggregate_facility_level_nssp_to_state( .group_by(["reference_date", "disease"]) .agg(pl.col("value").sum().alias("ed_visits")) .with_columns( - disease=pl.col("disease") - .cast(pl.Utf8) - .replace(_inverse_disease_map), + disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map), geo_value=pl.lit(state_abb).cast(pl.Utf8), ) .rename({"reference_date": "date"}) @@ -365,16 +350,12 @@ def process_and_save_state( if facility_level_nssp_data is None and state_level_nssp_data is None: raise ValueError( - "Must provide at least one " - "of facility-level and state-level" - "NSSP data" + "Must provide at least one " "of facility-level and state-level" "NSSP data" ) state_pop_df = get_state_pop_df() - state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item( - 0, "population" - ) + state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item(0, "population") (generation_interval_pmf, delay_pmf, right_truncation_pmf) = get_pmfs( param_estimates=param_estimates, state_abb=state_abb, disease=disease @@ -423,17 +404,13 @@ def process_and_save_state( credentials_dict=credentials_dict, ).with_columns(pl.lit("train").alias("data_type")) - nssp_training_dates = ( - nssp_training_data.get_column("date").unique().to_list() - ) + nssp_training_dates = nssp_training_data.get_column("date").unique().to_list() nhsn_training_dates = ( nhsn_training_data.get_column("weekendingdate").unique().to_list() ) nhsn_first_date_index = next( - i - for i, x in enumerate(nssp_training_dates) - if x == min(nhsn_training_dates) + i for i, x in enumerate(nssp_training_dates) if x == min(nhsn_training_dates) ) nhsn_step_size = 7 diff --git a/pipelines/prep_ww_data.py b/pipelines/prep_ww_data.py index 760f1f62..513143c6 100644 --- a/pipelines/prep_ww_data.py +++ b/pipelines/prep_ww_data.py @@ -47,7 +47,9 @@ def clean_nwss_data(nwss_data): .when(pl.col("pcr_target_units") == "log10 copies/l wastewater") .then((10 ** pl.col("pcr_target_avg_conc")) / 1000) .otherwise(None), - lod_sewage=pl.when(pl.col("pcr_target_units") == "copies/l wastewater") + lod_sewage=pl.when( + pl.col("pcr_target_units") == "copies/l wastewater" + ) .then(pl.col("lod_sewage") / 1000) .when(pl.col("pcr_target_units") == "log10 copies/l wastewater") .then((10 ** pl.col("lod_sewage")) / 1000) @@ -128,7 +130,9 @@ def clean_nwss_data(nwss_data): ) .with_columns( [ - pl.col("pcr_target_avg_conc").log().alias("log_genome_copies_per_ml"), + pl.col("pcr_target_avg_conc") + .log() + .alias("log_genome_copies_per_ml"), pl.col("lod_sewage").log().alias("log_lod"), pl.col("location").str.to_uppercase().alias("location"), pl.col("site").cast(pl.String).alias("site"), @@ -207,7 +211,9 @@ def validate_ww_conc_data( .eq(1) .all() ): - raise ValueError("The data contains sites with varying population sizes.") + raise ValueError( + "The data contains sites with varying population sizes." + ) return None @@ -239,10 +245,14 @@ def preprocess_ww_data( .with_row_index("lab_site_index") ) site_df = ( - ww_data_ordered.select([wwtp_col_name]).unique().with_row_index("site_index") + ww_data_ordered.select([wwtp_col_name]) + .unique() + .with_row_index("site_index") ) ww_preprocessed = ( - ww_data_ordered.join(lab_site_df, on=[lab_col_name, wwtp_col_name], how="left") + ww_data_ordered.join( + lab_site_df, on=[lab_col_name, wwtp_col_name], how="left" + ) .join(site_df, on=wwtp_col_name, how="left") .rename( { @@ -252,9 +262,14 @@ def preprocess_ww_data( ) .with_columns( lab_site_name=( - "Site: " + pl.col(wwtp_col_name) + ", Lab: " + pl.col(lab_col_name) + "Site: " + + pl.col(wwtp_col_name) + + ", Lab: " + + pl.col(lab_col_name) + ), + below_lod=( + pl.col("log_genome_copies_per_ml") <= pl.col("log_lod") ), - below_lod=(pl.col("log_genome_copies_per_ml") <= pl.col("log_lod")), ) .select( [ From 8f53d081dcf9926a76daf3706266968318237c37 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 20 Feb 2025 18:10:02 -0500 Subject: [PATCH 18/34] add wastewater related priors --- pipelines/build_pyrenew_model.py | 44 +++++++++++------ pipelines/forecast_state.py | 11 +---- pipelines/prep_data.py | 46 +++++++++++++----- pipelines/prep_ww_data.py | 3 -- pipelines/priors/prod_priors.py | 41 ++++++++++++++++ pipelines/tests/test_end_to_end.sh | 8 ++-- pyrenew_hew/pyrenew_hew_data.py | 77 ++++++++++++++++-------------- pyrenew_hew/pyrenew_hew_model.py | 76 +++++++++++++++++++---------- tests/test_pyrenew_hew_data.py | 13 +++-- 9 files changed, 213 insertions(+), 106 deletions(-) diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index 0f4359c8..97585852 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -114,15 +114,27 @@ def build_model_from_dir( first_ed_visits_date = datetime.datetime.strptime( model_data["nssp_training_dates"][0], "%Y-%m-%d" - ) + ).date() first_hospital_admissions_date = datetime.datetime.strptime( model_data["nhsn_training_dates"][0], "%Y-%m-%d" - ) + ).date() priors = runpy.run_path(str(prior_path)) right_truncation_offset = model_data["right_truncation_offset"] + my_data = PyrenewHEWData( + data_observed_disease_ed_visits=data_observed_disease_ed_visits, + data_observed_disease_hospital_admissions=( + data_observed_disease_hospital_admissions + ), + data_observed_disease_wastewater=data_observed_disease_wastewater, + right_truncation_offset=right_truncation_offset, + first_ed_visits_date=first_ed_visits_date, + first_hospital_admissions_date=first_hospital_admissions_date, + population_size=population_size, + ) + my_latent_infection_model = LatentInfectionProcess( i0_first_obs_n_rv=priors["i0_first_obs_n_rv"], initialization_rate_rv=priors["initialization_rate_rv"], @@ -133,6 +145,22 @@ def build_model_from_dir( infection_feedback_strength_rv=priors["inf_feedback_strength_rv"], infection_feedback_pmf_rv=infection_feedback_pmf_rv, n_initialization_points=uot, + pop_fraction=my_data.pop_fraction + if fit_wastewater + else jnp.array([1]), + autoreg_rt_subpop_rv=priors["autoreg_rt_subpop_rv"], + sigma_rt_rv=priors["sigma_rt_rv"], + sigma_i_first_obs_rv=priors["sigma_i_first_obs_rv"], + sigma_initial_exp_growth_rate_rv=priors[ + "sigma_initial_exp_growth_rate_rv" + ], + offset_ref_logit_i_first_obs_rv=priors[ + "offset_ref_logit_i_first_obs_rv" + ], + offset_ref_initial_exp_growth_rate_rv=priors[ + "offset_ref_initial_exp_growth_rate_rv" + ], + offset_ref_log_rt_rv=priors["offset_ref_log_rt_rv"], ) my_ed_visit_obs_model = EDVisitObservationProcess( @@ -173,16 +201,4 @@ def build_model_from_dir( wastewater_obs_process_rv=my_wastewater_obs_model, ) - my_data = PyrenewHEWData( - data_observed_disease_ed_visits=data_observed_disease_ed_visits, - data_observed_disease_hospital_admissions=( - data_observed_disease_hospital_admissions - ), - data_observed_disease_wastewater=data_observed_disease_wastewater, - right_truncation_offset=right_truncation_offset, - first_ed_visits_date=first_ed_visits_date, - first_hospital_admissions_date=first_hospital_admissions_date, - population_size=population_size, - ) - return (my_model, my_data) diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py index 8abef411..e73cf85c 100644 --- a/pipelines/forecast_state.py +++ b/pipelines/forecast_state.py @@ -195,7 +195,6 @@ def main( state_level_nssp_data_dir: Path | str, nwss_data_dir: Path | str, param_data_dir: Path | str, - ww_data_dir: Path | str, priors_path: Path | str, output_dir: Path | str, n_training_days: int, @@ -349,7 +348,7 @@ def main( (pl.col("location") == state) & (pl.col("date") >= first_training_date) ) - state_level_nwss_data = preprocess_ww_data(nwss_data_cleaned) + state_level_nwss_data = preprocess_ww_data(nwss_data_cleaned.collect()) else: state_level_nwss_data = None ## TO DO: change @@ -383,7 +382,6 @@ def main( last_training_date=last_training_date, param_estimates=param_estimates, model_run_dir=model_run_dir, - ww_data_dir=ww_data_dir, logger=logger, credentials_dict=credentials_dict, ) @@ -537,13 +535,6 @@ def main( required=True, ) - parser.add_argument( - "--ww-data-dir", - type=Path, - default=Path("private_data", "nwss_vintages"), - help=("Directory in which to look for NWSS wastewater data"), - ) - parser.add_argument( "--priors-path", type=Path, diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index 16f76cfa..2ee60a79 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -8,8 +8,6 @@ from pathlib import Path import forecasttools -import jax.numpy as jnp -import jax.numpy as jnp import polars as pl import polars.selectors as cs @@ -48,7 +46,9 @@ def py_scalar_to_r_scalar(py_scalar): state_abb_for_query = state_abb if state_abb != "US" else "USA" temp_file = Path(temp_dir, "nhsn_temp.parquet") - api_key_id = credentials_dict.get("nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID")) + api_key_id = credentials_dict.get( + "nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID") + ) api_key_secret = credentials_dict.get( "nhsn_api_key_secret", os.getenv("NHSN_API_KEY_SECRET") ) @@ -80,7 +80,9 @@ def py_scalar_to_r_scalar(py_scalar): if result.returncode != 0: raise RuntimeError(f"pull_and_save_nhsn: {result.stderr}") raw_dat = pl.read_parquet(temp_file) - dat = raw_dat.with_columns(weekendingdate=pl.col("weekendingdate").cast(pl.Date)) + dat = raw_dat.with_columns( + weekendingdate=pl.col("weekendingdate").cast(pl.Date) + ) return dat @@ -102,7 +104,9 @@ def combine_nssp_and_nhsn( variable_name="drop_me", value_name=".value", ) - .with_columns(pl.col("count_type").replace(count_type_dict).alias(".variable")) + .with_columns( + pl.col("count_type").replace(count_type_dict).alias(".variable") + ) .select(cs.exclude(["count_type", "drop_me"])) ) @@ -182,7 +186,9 @@ def process_state_level_data( if state_abb == "US": locations_to_aggregate = ( - state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique() + state_pop_df.filter(pl.col("abb") != "US") + .get_column("abb") + .unique() ) logger.info("Aggregating state-level data to national") state_level_nssp_data = aggregate_to_national( @@ -209,7 +215,9 @@ def process_state_level_data( ] ) .with_columns( - disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map), + disease=pl.col("disease") + .cast(pl.Utf8) + .replace(_inverse_disease_map), ) .sort(["date", "disease"]) .collect(streaming=True) @@ -241,7 +249,9 @@ def aggregate_facility_level_nssp_to_state( if state_abb == "US": logger.info("Aggregating facility-level data to national") locations_to_aggregate = ( - state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique() + state_pop_df.filter(pl.col("abb") != "US") + .get_column("abb") + .unique() ) facility_level_nssp_data = aggregate_to_national( facility_level_nssp_data, @@ -260,7 +270,9 @@ def aggregate_facility_level_nssp_to_state( .group_by(["reference_date", "disease"]) .agg(pl.col("value").sum().alias("ed_visits")) .with_columns( - disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map), + disease=pl.col("disease") + .cast(pl.Utf8) + .replace(_inverse_disease_map), geo_value=pl.lit(state_abb).cast(pl.Utf8), ) .rename({"reference_date": "date"}) @@ -350,12 +362,16 @@ def process_and_save_state( if facility_level_nssp_data is None and state_level_nssp_data is None: raise ValueError( - "Must provide at least one " "of facility-level and state-level" "NSSP data" + "Must provide at least one " + "of facility-level and state-level" + "NSSP data" ) state_pop_df = get_state_pop_df() - state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item(0, "population") + state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item( + 0, "population" + ) (generation_interval_pmf, delay_pmf, right_truncation_pmf) = get_pmfs( param_estimates=param_estimates, state_abb=state_abb, disease=disease @@ -404,13 +420,17 @@ def process_and_save_state( credentials_dict=credentials_dict, ).with_columns(pl.lit("train").alias("data_type")) - nssp_training_dates = nssp_training_data.get_column("date").unique().to_list() + nssp_training_dates = ( + nssp_training_data.get_column("date").unique().to_list() + ) nhsn_training_dates = ( nhsn_training_data.get_column("weekendingdate").unique().to_list() ) nhsn_first_date_index = next( - i for i, x in enumerate(nssp_training_dates) if x == min(nhsn_training_dates) + i + for i, x in enumerate(nssp_training_dates) + if x == min(nhsn_training_dates) ) nhsn_step_size = 7 diff --git a/pipelines/prep_ww_data.py b/pipelines/prep_ww_data.py index 513143c6..af52d022 100644 --- a/pipelines/prep_ww_data.py +++ b/pipelines/prep_ww_data.py @@ -1,6 +1,3 @@ -import datetime -from pathlib import Path - import polars as pl diff --git a/pipelines/priors/prod_priors.py b/pipelines/priors/prod_priors.py index c8b1beaa..3ca83eaf 100644 --- a/pipelines/priors/prod_priors.py +++ b/pipelines/priors/prod_priors.py @@ -117,6 +117,47 @@ "mode_sd_ww_site", dist.TruncatedNormal(0, 0.25, low=0) ) +autoreg_rt_subpop_rv = DistributionalVariable( + "autoreg_rt_subpop", dist.Beta(1, 4) +) +sigma_rt_rv = DistributionalVariable( + "sigma_rt", dist.TruncatedNormal(0, 0.1, low=0) +) + +sigma_i_first_obs_rv = DistributionalVariable( + "sigma_i_first_obs", + dist.TruncatedNormal(0, 0.5, low=0), +) + +sigma_initial_exp_growth_rate_rv = DistributionalVariable( + "sigma_initial_exp_growth_rate", + dist.TruncatedNormal( + 0, + 0.05, + low=0, + ), +) + +offset_ref_logit_i_first_obs_rv = DistributionalVariable( + "offset_ref_logit_i_first_obs", + dist.Normal(0, 0.25), +) + +offset_ref_initial_exp_growth_rate_rv = DistributionalVariable( + "offset_ref_initial_exp_growth_rate", + dist.TruncatedNormal( + 0, + 0.025, + low=-0.01, + high=0.01, + ), +) + +offset_ref_log_rt_rv = DistributionalVariable( + "offset_ref_log_r_t", + dist.Normal(0, 0.2), +) + # model constants related to wastewater obs process ww_ml_produced_per_day = 227000 max_shed_interval = 26 diff --git a/pipelines/tests/test_end_to_end.sh b/pipelines/tests/test_end_to_end.sh index 4d4367b8..eee0582d 100755 --- a/pipelines/tests/test_end_to_end.sh +++ b/pipelines/tests/test_end_to_end.sh @@ -28,18 +28,18 @@ do --state-level-nssp-data-dir "$BASE_DIR/private_data/nssp_state_level_gold" \ --priors-path pipelines/priors/prod_priors.py \ --param-data-dir "$BASE_DIR/private_data/prod_param_estimates" \ - --ww-data-dir "$BASE_DIR/private_data/nwss_vintages" \ + --nwss-data-dir "$BASE_DIR/private_data/nwss_vintages" \ --output-dir "$BASE_DIR/private_data" \ --n-training-days 60 \ --n-chains 2 \ --n-samples 250 \ --n-warmup 250 \ --fit-ed-visits \ - --no-fit-hospital-admissions \ - --no-fit-wastewater \ + --fit-hospital-admissions \ + --fit-wastewater \ --forecast-ed-visits \ --forecast-hospital-admissions \ - --no-forecast-wastewater \ + --forecast-wastewater \ --score \ --eval-data-path "$BASE_DIR/private_data/nssp-etl" if [ $? -ne 0 ]; then diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index 002a9cab..2637df83 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -40,7 +40,9 @@ def __init__( self.first_ed_visits_date = first_ed_visits_date self.first_hospital_admissions_date = first_hospital_admissions_date self.first_wastewater_date_ = first_wastewater_date - self.data_observed_disease_wastewater = data_observed_disease_wastewater + self.data_observed_disease_wastewater = ( + data_observed_disease_wastewater + ) self.population_size = population_size self.shedding_offset = shedding_offset self.pop_fraction_ = pop_fraction @@ -76,12 +78,6 @@ def first_wastewater_date(self): return self.data_observed_disease_wastewater["date"].min() return self.first_wastewater_date_ - @property - def first_wastewater_date(self): - if self.data_observed_disease_wastewater is not None: - return self.data_observed_disease_wastewater["date"].min() - return self.first_wastewater_date_ - @property def last_ed_visits_date(self): return self.get_end_date( @@ -132,7 +128,9 @@ def last_data_date_overall(self): @property def n_days_post_init(self): - return (self.last_data_date_overall - self.first_data_date_overall).days + return ( + self.last_data_date_overall - self.first_data_date_overall + ).days @property def site_subpop_spine(self): @@ -155,28 +153,33 @@ def site_subpop_spine(self): .unique() .sort("site_index") ) - if add_auxiliary_subpop: - aux_subpop = pl.DataFrame( - { - "site_index": [None], - "site": [None], - "site_pop": [ - self.population_size - - site_indices.select(pl.col("site_pop")) - .get_column("site_pop") - .sum() - ], - } - ) - else: - aux_subpop = pl.DataFrame() + + site_pop = jnp.where( + add_auxiliary_subpop, + self.population_size + - site_indices.select(pl.col("site_pop")) + .get_column("site_pop") + .sum(), + jnp.nan, + ) + + aux_subpop = pl.DataFrame( + {"site_index": [None], "site": [None], "site_pop": site_pop} + ).with_columns( + pl.col("site_index").cast(pl.Int64), + pl.col("site").cast(pl.String), + pl.col("site_pop").cast(pl.Int64), + ) + site_subpop_spine = ( pl.concat([aux_subpop, site_indices], how="vertical_relaxed") .with_columns( - subpop_index=pl.col("site_index").cum_count().alias("subpop_index"), - subpop_name=pl.format("Site: {}", pl.col("site")).fill_null( - "remainder of population" - ), + subpop_index=pl.col("site_index") + .cum_count() + .alias("subpop_index"), + subpop_name=pl.format( + "Site: {}", pl.col("site") + ).fill_null("remainder of population"), ) .rename({"site_pop": "subpop_pop"}) ) @@ -233,22 +236,24 @@ def pop_fraction(self): @property def data_observed_disease_wastewater_conc(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended["log_genome_copies_per_ml"].to_numpy() + return self.wastewater_data_extended[ + "log_genome_copies_per_ml" + ].to_numpy() @property def ww_censored(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended.filter(pl.col("below_lod") == 1)[ - "ind_rel_to_observed_times" - ].to_numpy() + return self.wastewater_data_extended.filter( + pl.col("below_lod") == 1 + )["ind_rel_to_observed_times"].to_numpy() return None @property def ww_uncensored(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended.filter(pl.col("below_lod") == 0)[ - "ind_rel_to_observed_times" - ].to_numpy() + return self.wastewater_data_extended.filter( + pl.col("below_lod") == 0 + )["ind_rel_to_observed_times"].to_numpy() @property def ww_observed_times(self): @@ -304,7 +309,9 @@ def get_end_date( ) result = None else: - result = first_date + datetime.timedelta(days=n_datapoints * timestep_days) + result = first_date + datetime.timedelta( + days=n_datapoints * timestep_days + ) return result def get_n_data_days( diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index a62151be..97b766f3 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -5,6 +5,7 @@ import jax.numpy as jnp import numpyro import numpyro.distributions as dist +import numpyro.distributions.transforms as transforms import pyrenew.transformation as transformation from jax.typing import ArrayLike from numpyro.infer.reparam import LocScaleReparam @@ -73,7 +74,9 @@ def __init__( self.autoreg_rt_subpop_rv = autoreg_rt_subpop_rv self.sigma_rt_rv = sigma_rt_rv self.sigma_i_first_obs_rv = sigma_i_first_obs_rv - self.sigma_initial_exp_growth_rate_rv = sigma_initial_exp_growth_rate_rv + self.sigma_initial_exp_growth_rate_rv = ( + sigma_initial_exp_growth_rate_rv + ) self.n_initialization_points = n_initialization_points self.pop_fraction = pop_fraction self.n_subpops = len(pop_fraction) @@ -118,20 +121,23 @@ def sample(self, n_days_post_init: int): log_rtu_weekly_subpop = log_rtu_weekly[:, jnp.newaxis] else: i_first_obs_over_n_ref_subpop = transformation.SigmoidTransform()( - transformation.logit(i0_first_obs_n) + transforms.logit(i0_first_obs_n) + self.offset_ref_logit_i_first_obs_rv(), - ) + ) # Using numpyro.distributions.transform as 'pyrenew.transformation' has no attribute 'logit' initial_exp_growth_rate_ref_subpop = ( - initial_exp_growth_rate + self.offset_ref_initial_exp_growth_rate_rv() + initial_exp_growth_rate + + self.offset_ref_initial_exp_growth_rate_rv() ) - log_rtu_weekly_ref_subpop = log_rtu_weekly + self.offset_ref_log_rt_rv() + log_rtu_weekly_ref_subpop = ( + log_rtu_weekly + self.offset_ref_log_rt_rv() + ) i_first_obs_over_n_non_ref_subpop_rv = TransformedVariable( "i_first_obs_over_n_non_ref_subpop", DistributionalVariable( "i_first_obs_over_n_non_ref_subpop_raw", dist.Normal( - transformation.logit(i0_first_obs_n), + transforms.logit(i0_first_obs_n), self.sigma_i_first_obs_rv(), ), reparam=LocScaleReparam(0), @@ -207,7 +213,9 @@ def sample(self, n_days_post_init: int): )[:n_days_post_init, :] ) # indexed rel to first post-init day. - i0_subpop_rv = DeterministicVariable("i0_subpop", i_first_obs_over_n_subpop) + i0_subpop_rv = DeterministicVariable( + "i0_subpop", i_first_obs_over_n_subpop + ) initial_exp_growth_rate_subpop_rv = DeterministicVariable( "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) @@ -312,9 +320,7 @@ def sample( iedr = jnp.repeat( transformation.SigmoidTransform()(p_ed_ar + p_ed_mean), repeats=7, - )[ - :n_datapoints - ] # indexed rel to first ed report day + )[:n_datapoints] # indexed rel to first ed report day # this is only applied after the ed visits are generated, not to all # the latent infections. This is why we cannot apply the iedr in # compute_delay_ascertained_incidence @@ -334,21 +340,28 @@ def sample( )[-n_datapoints:] latent_ed_visits_final = ( - potential_latent_ed_visits * iedr * ed_wday_effect * population_size + potential_latent_ed_visits + * iedr + * ed_wday_effect + * population_size ) if right_truncation_offset is not None: prop_already_reported_tail = jnp.flip( self.ed_right_truncation_cdf_rv()[right_truncation_offset:] ) - n_points_to_prepend = n_datapoints - prop_already_reported_tail.shape[0] + n_points_to_prepend = ( + n_datapoints - prop_already_reported_tail.shape[0] + ) prop_already_reported = jnp.pad( prop_already_reported_tail, (n_points_to_prepend, 0), mode="constant", constant_values=(1, 0), ) - latent_ed_visits_now = latent_ed_visits_final * prop_already_reported + latent_ed_visits_now = ( + latent_ed_visits_final * prop_already_reported + ) else: latent_ed_visits_now = latent_ed_visits_final @@ -374,7 +387,9 @@ def __init__( ihr_rel_iedr_rv: RandomVariable = None, ) -> None: self.inf_to_hosp_admit_rv = inf_to_hosp_admit_rv - self.hosp_admit_neg_bin_concentration_rv = hosp_admit_neg_bin_concentration_rv + self.hosp_admit_neg_bin_concentration_rv = ( + hosp_admit_neg_bin_concentration_rv + ) self.ihr_rv = ihr_rv self.ihr_rel_iedr_rv = ihr_rel_iedr_rv @@ -508,7 +523,10 @@ def normed_shedding_cdf( norm_const = (t_p + t_d) * ((log_base - 1) / jnp.log(log_base) - 1) def ad_pre(x): - return t_p / jnp.log(log_base) * jnp.exp(jnp.log(log_base) * x / t_p) - x + return ( + t_p / jnp.log(log_base) * jnp.exp(jnp.log(log_base) * x / t_p) + - x + ) def ad_post(x): return ( @@ -588,10 +606,12 @@ def sample( def batch_colvolve_fn(m): return jnp.convolve(m, viral_kinetics, mode="valid") - model_net_inf_ind_shedding = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)( - jnp.atleast_2d(latent_infections_subpop) - )[-n_datapoints:, :] - numpyro.deterministic("model_net_inf_ind_shedding", model_net_inf_ind_shedding) + model_net_inf_ind_shedding = jax.vmap( + batch_colvolve_fn, in_axes=1, out_axes=1 + )(jnp.atleast_2d(latent_infections_subpop))[-n_datapoints:, :] + numpyro.deterministic( + "model_net_inf_ind_shedding", model_net_inf_ind_shedding + ) log10_genome_per_inf_ind = self.log10_genome_per_inf_ind_rv() expected_obs_viral_genomes = ( @@ -599,7 +619,9 @@ def batch_colvolve_fn(m): + jnp.log(model_net_inf_ind_shedding + shedding_offset) - jnp.log(self.ww_ml_produced_per_day) ) - numpyro.deterministic("expected_obs_viral_genomes", expected_obs_viral_genomes) + numpyro.deterministic( + "expected_obs_viral_genomes", expected_obs_viral_genomes + ) mode_sigma_ww_site = self.mode_sigma_ww_site_rv() sd_log_sigma_ww_site = self.sd_log_sigma_ww_site_rv() @@ -638,7 +660,11 @@ def batch_colvolve_fn(m): scale=sigma_ww_site[ww_observed_lab_sites[ww_uncensored]], ), ).sample( - obs=(data_observed[ww_uncensored] if data_observed is not None else None), + obs=( + data_observed[ww_uncensored] + if data_observed is not None + else None + ), ) if ww_censored.shape[0] != 0: @@ -695,8 +721,10 @@ def sample( sample_wastewater: bool = False, ) -> dict[str, ArrayLike]: # numpydoc ignore=GL08 n_init_days = self.latent_infection_process_rv.n_initialization_points - latent_infections, latent_infections_subpop = self.latent_infection_process_rv( - n_days_post_init=data.n_days_post_init, + latent_infections, latent_infections_subpop = ( + self.latent_infection_process_rv( + n_days_post_init=data.n_days_post_init, + ) ) first_latent_infection_dow = ( data.first_data_date_overall - datetime.timedelta(days=n_init_days) @@ -732,7 +760,6 @@ def sample( site_level_observed_wastewater, population_level_latent_wastewater, ) = self.wastewater_obs_process_rv( - latent_infections=latent_infections, latent_infections_subpop=latent_infections_subpop, data_observed=data.data_observed_disease_wastewater_conc, n_datapoints=data.n_wastewater_data_days, @@ -745,6 +772,7 @@ def sample( lab_site_to_subpop_map=data.lab_site_to_subpop_map, n_ww_lab_sites=data.n_ww_lab_sites, shedding_offset=1e-8, + pop_fraction=data.pop_fraction, ) return { diff --git a/tests/test_pyrenew_hew_data.py b/tests/test_pyrenew_hew_data.py index e3195f6a..99f42058 100644 --- a/tests/test_pyrenew_hew_data.py +++ b/tests/test_pyrenew_hew_data.py @@ -96,7 +96,10 @@ def test_to_forecast_data( assert forecast_data.n_hospital_admissions_data_days == n_weeks_expected assert forecast_data.right_truncation_offset is None assert forecast_data.first_ed_visits_date == data.first_data_date_overall - assert forecast_data.first_hospital_admissions_date == data.first_data_date_overall + assert ( + forecast_data.first_hospital_admissions_date + == data.first_data_date_overall + ) assert forecast_data.first_wastewater_date == data.first_data_date_overall assert forecast_data.data_observed_disease_wastewater is None @@ -154,7 +157,11 @@ def test_wastewater_data_properties(): data.data_observed_disease_wastewater_conc, ww_data["log_genome_copies_per_ml"], ) - assert len(data.ww_censored) == len(ww_data.filter(pl.col("below_lod") == 1)) - assert len(data.ww_uncensored) == len(ww_data.filter(pl.col("below_lod") == 0)) + assert len(data.ww_censored) == len( + ww_data.filter(pl.col("below_lod") == 1) + ) + assert len(data.ww_uncensored) == len( + ww_data.filter(pl.col("below_lod") == 0) + ) assert jnp.array_equal(data.ww_log_lod, ww_data["log_lod"]) assert data.n_ww_lab_sites == ww_data["lab_site_index"].n_unique() From bdcbc2fb8d8fafa368d0f451ba4b1467471303d4 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 24 Feb 2025 10:53:20 -0500 Subject: [PATCH 19/34] add PyrenewWastewaterData --- pipelines/build_pyrenew_model.py | 40 ++-- pipelines/prep_data.py | 20 ++ pipelines/tests/test_build_pyrenew_model.py | 20 +- pyrenew_hew/pyrenew_hew_data.py | 243 +++++--------------- pyrenew_hew/pyrenew_wastewater_data.py | 177 ++++++++++++++ tests/test_pyrenew_hew_data.py | 70 +----- tests/test_pyrenew_wastewater_data.py | 58 +++++ 7 files changed, 354 insertions(+), 274 deletions(-) create mode 100644 pyrenew_hew/pyrenew_wastewater_data.py create mode 100644 tests/test_pyrenew_wastewater_data.py diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index 3cf3e4ee..914299ec 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -15,6 +15,7 @@ PyrenewHEWModel, WastewaterObservationProcess, ) +from pyrenew_hew.pyrenew_wastewater_data import PyrenewWastewaterData def build_model_from_dir( @@ -92,7 +93,11 @@ def build_model_from_dir( data_observed_disease_wastewater = ( pl.DataFrame( model_data["data_observed_disease_wastewater"], - schema_overrides={"date": pl.Date}, + schema_overrides={ + "date": pl.Date, + "lab_index": pl.Int64, + "site_index": pl.Int64, + }, ) if fit_wastewater else None @@ -100,6 +105,8 @@ def build_model_from_dir( population_size = jnp.array(model_data["state_pop"]) + pop_fraction = jnp.array(model_data["pop_fraction"]) + ed_right_truncation_pmf_rv = DeterministicVariable( "right_truncation_pmf", jnp.array(model_data["right_truncation_pmf"]) ) @@ -123,18 +130,6 @@ def build_model_from_dir( right_truncation_offset = model_data["right_truncation_offset"] - dat = PyrenewHEWData( - data_observed_disease_ed_visits=data_observed_disease_ed_visits, - data_observed_disease_hospital_admissions=( - data_observed_disease_hospital_admissions - ), - data_observed_disease_wastewater=data_observed_disease_wastewater, - right_truncation_offset=right_truncation_offset, - first_ed_visits_date=first_ed_visits_date, - first_hospital_admissions_date=first_hospital_admissions_date, - population_size=population_size, - ) - latent_infections_rv = LatentInfectionProcess( i0_first_obs_n_rv=priors["i0_first_obs_n_rv"], initialization_rate_rv=priors["initialization_rate_rv"], @@ -145,7 +140,7 @@ def build_model_from_dir( infection_feedback_strength_rv=priors["inf_feedback_strength_rv"], infection_feedback_pmf_rv=infection_feedback_pmf_rv, n_initialization_points=n_initialization_points, - pop_fraction=dat.pop_fraction if fit_wastewater else jnp.array([1]), + pop_fraction=pop_fraction, autoreg_rt_subpop_rv=priors["autoreg_rt_subpop_rv"], sigma_rt_rv=priors["sigma_rt_rv"], sigma_i_first_obs_rv=priors["sigma_i_first_obs_rv"], @@ -201,4 +196,21 @@ def build_model_from_dir( wastewater_obs_process_rv=wastewater_obs_rv, ) + wastewater_data = PyrenewWastewaterData( + data_observed_disease_wastewater=data_observed_disease_wastewater, + population_size=population_size, + pop_fraction=pop_fraction, + ) + + dat = PyrenewHEWData( + data_observed_disease_ed_visits=data_observed_disease_ed_visits, + data_observed_disease_hospital_admissions=( + data_observed_disease_hospital_admissions + ), + right_truncation_offset=right_truncation_offset, + first_ed_visits_date=first_ed_visits_date, + first_hospital_admissions_date=first_hospital_admissions_date, + wastewater_data=wastewater_data, + ) + return (mod, dat) diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index 2ee60a79..0740988a 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -8,6 +8,7 @@ from pathlib import Path import forecasttools +import jax.numpy as jnp import polars as pl import polars.selectors as cs @@ -456,6 +457,24 @@ def process_and_save_state( else None ) + if state_level_nwss_data is None: + pop_fraction = jnp.array([1]) + else: + subpop_sizes = ( + state_level_nwss_data.select(["site_index", "site", "site_pop"]) + .unique()["site_pop"] + .to_numpy() + ) + pop_fraction = jnp.where( + state_pop > subpop_sizes.sum(), + jnp.concatenate( + [jnp.array([state_pop - subpop_sizes.sum()]), subpop_sizes], + axis=0, + ) + / state_pop, + subpop_sizes / state_pop, + ) + data_for_model_fit = { "inf_to_ed_pmf": delay_pmf, "generation_interval_pmf": generation_interval_pmf, @@ -470,6 +489,7 @@ def process_and_save_state( "state_pop": state_pop, "right_truncation_offset": right_truncation_offset, "data_observed_disease_wastewater": data_observed_disease_wastewater, + "pop_fraction": pop_fraction, } data_dir = Path(model_run_dir, "data") diff --git a/pipelines/tests/test_build_pyrenew_model.py b/pipelines/tests/test_build_pyrenew_model.py index 95839652..a9a0613d 100644 --- a/pipelines/tests/test_build_pyrenew_model.py +++ b/pipelines/tests/test_build_pyrenew_model.py @@ -13,7 +13,7 @@ def mock_data(): { "data_observed_disease_ed_visits": [1, 2, 3], "data_observed_disease_hospital_admissions": [4, 5, 6], - "state_pop": [7, 8, 9], + "state_pop": [10000], "generation_interval_pmf": [0.1, 0.2, 0.7], "inf_to_ed_pmf": [0.4, 0.5, 0.1], "inf_to_hosp_admit_pmf": [0.0, 0.7, 0.1, 0.1, 0.1], @@ -32,13 +32,14 @@ def mock_data(): ], "site": ["1.0", "1.0", "2.0", "2.0"], "lab": ["1.0", "1.0", "1.0", "1.0"], - "site_pop": [4000000, 4000000, 2000000, 2000000], + "site_pop": [4000, 4000, 2000, 2000], "site_index": [1, 1, 0, 0], "lab_site_index": [1, 1, 0, 0], "log_genome_copies_per_ml": [0.1, 0.1, 0.5, 0.4], "log_lod": [1.1, 2.0, 1.5, 2.1], "below_lod": [False, False, False, False], }, + "pop_fraction": [0.4, 0.4, 0.2], } ) @@ -72,6 +73,14 @@ def mock_priors(): mode_sd_ww_site_rv = None max_shed_interval = None ww_ml_produced_per_day = None +pop_fraction=None +autoreg_rt_subpop_rv=None +sigma_rt_rv=None +sigma_i_first_obs_rv=None +sigma_initial_exp_growth_rate_rv=None +offset_ref_logit_i_first_obs_rv=None +offset_ref_initial_exp_growth_rate_rv=None +offset_ref_log_rt_rv=None """ @@ -91,7 +100,7 @@ def test_build_model_from_dir(tmp_path, mock_data, mock_priors): _, data = build_model_from_dir(model_dir) assert data.data_observed_disease_ed_visits is None assert data.data_observed_disease_hospital_admissions is None - assert data.data_observed_disease_wastewater is None + assert data.data_observed_disease_wastewater_conc is None # Test when all `fit_` arguments are True _, data = build_model_from_dir( @@ -108,7 +117,4 @@ def test_build_model_from_dir(tmp_path, mock_data, mock_priors): data.data_observed_disease_hospital_admissions, jnp.array(model_data["data_observed_disease_hospital_admissions"]), ) - assert pl.DataFrame( - model_data["data_observed_disease_wastewater"], - schema_overrides={"date": pl.Date}, - ).equals(data.data_observed_disease_wastewater) + # assert data.data_observed_disease_wastewater_conc is not None diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index 2637df83..e0b880a3 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -1,10 +1,10 @@ import datetime from typing import Self -import jax.numpy as jnp -import polars as pl from jax.typing import ArrayLike +from pyrenew_hew.pyrenew_wastewater_data import PyrenewWastewaterData + class PyrenewHEWData: """ @@ -19,19 +19,15 @@ def __init__( n_wastewater_data_days: int = None, data_observed_disease_ed_visits: ArrayLike = None, data_observed_disease_hospital_admissions: ArrayLike = None, - data_observed_disease_wastewater: pl.DataFrame = None, right_truncation_offset: int = None, first_ed_visits_date: datetime.date = None, first_hospital_admissions_date: datetime.date = None, first_wastewater_date: datetime.date = None, - population_size: int = None, - shedding_offset: float = 1e-8, - pop_fraction: ArrayLike = jnp.array([1]), + wastewater_data: PyrenewWastewaterData = None, ) -> None: self.n_ed_visits_data_days_ = n_ed_visits_data_days self.n_hospital_admissions_data_days_ = n_hospital_admissions_data_days self.n_wastewater_data_days_ = n_wastewater_data_days - self.data_observed_disease_ed_visits = data_observed_disease_ed_visits self.data_observed_disease_hospital_admissions = ( data_observed_disease_hospital_admissions @@ -40,12 +36,51 @@ def __init__( self.first_ed_visits_date = first_ed_visits_date self.first_hospital_admissions_date = first_hospital_admissions_date self.first_wastewater_date_ = first_wastewater_date - self.data_observed_disease_wastewater = ( - data_observed_disease_wastewater + self.date_observed_disease_wastewater = ( + None + if wastewater_data is None + else wastewater_data.date_observed_disease_wastewater + ) + self.pop_fraction = ( + None if wastewater_data is None else wastewater_data.pop_fraction + ) + self.data_observed_disease_wastewater_conc = ( + None + if wastewater_data is None + else wastewater_data.data_observed_disease_wastewater_conc + ) + self.ww_censored = ( + None if wastewater_data is None else wastewater_data.ww_censored + ) + self.ww_uncensored = ( + None if wastewater_data is None else wastewater_data.ww_uncensored + ) + self.ww_observed_times = ( + None + if wastewater_data is None + else wastewater_data.ww_observed_times + ) + self.ww_observed_subpops = ( + None + if wastewater_data is None + else wastewater_data.ww_observed_subpops + ) + self.ww_observed_lab_sites = ( + None + if wastewater_data is None + else wastewater_data.ww_observed_lab_sites + ) + self.ww_log_lod = ( + None if wastewater_data is None else wastewater_data.ww_log_lod + ) + self.n_ww_lab_sites = ( + None if wastewater_data is None else wastewater_data.n_ww_lab_sites + ) + self.lab_site_to_subpop_map = ( + None + if wastewater_data is None + else wastewater_data.lab_site_to_subpop_map ) - self.population_size = population_size - self.shedding_offset = shedding_offset - self.pop_fraction_ = pop_fraction @property def n_ed_visits_data_days(self): @@ -65,19 +100,23 @@ def n_hospital_admissions_data_days(self): def n_wastewater_data_days(self): return self.get_n_wastewater_data_days( n_datapoints=self.n_wastewater_data_days_, - date_array=( - None - if self.data_observed_disease_wastewater is None - else self.data_observed_disease_wastewater["date"] - ), + date_array=self.date_observed_disease_wastewater, ) @property def first_wastewater_date(self): - if self.data_observed_disease_wastewater is not None: - return self.data_observed_disease_wastewater["date"].min() + if self.date_observed_disease_wastewater is not None: + return self.date_observed_disease_wastewater.min() return self.first_wastewater_date_ + @property + def last_wastewater_date(self): + return self.get_end_date( + self.first_wastewater_date, + self.n_wastewater_data_days, + timestep_days=1, + ) + @property def last_ed_visits_date(self): return self.get_end_date( @@ -94,14 +133,6 @@ def last_hospital_admissions_date(self): timestep_days=7, ) - @property - def last_wastewater_date(self): - return self.get_end_date( - self.first_wastewater_date, - self.n_wastewater_data_days, - timestep_days=1, - ) - @property def first_data_dates(self): return dict( @@ -132,163 +163,6 @@ def n_days_post_init(self): self.last_data_date_overall - self.first_data_date_overall ).days - @property - def site_subpop_spine(self): - ww_data_present = self.data_observed_disease_wastewater is not None - if ww_data_present: - # Check if auxiliary subpopulation needs to be added - add_auxiliary_subpop = ( - self.population_size - > self.data_observed_disease_wastewater.select( - pl.col("site_pop", "site", "lab", "lab_site_index") - ) - .unique() - .get_column("site_pop") - .sum() - ) - site_indices = ( - self.data_observed_disease_wastewater.select( - ["site_index", "site", "site_pop"] - ) - .unique() - .sort("site_index") - ) - - site_pop = jnp.where( - add_auxiliary_subpop, - self.population_size - - site_indices.select(pl.col("site_pop")) - .get_column("site_pop") - .sum(), - jnp.nan, - ) - - aux_subpop = pl.DataFrame( - {"site_index": [None], "site": [None], "site_pop": site_pop} - ).with_columns( - pl.col("site_index").cast(pl.Int64), - pl.col("site").cast(pl.String), - pl.col("site_pop").cast(pl.Int64), - ) - - site_subpop_spine = ( - pl.concat([aux_subpop, site_indices], how="vertical_relaxed") - .with_columns( - subpop_index=pl.col("site_index") - .cum_count() - .alias("subpop_index"), - subpop_name=pl.format( - "Site: {}", pl.col("site") - ).fill_null("remainder of population"), - ) - .rename({"site_pop": "subpop_pop"}) - ) - else: - site_subpop_spine = pl.DataFrame( - { - "site_index": [None], - "site": [None], - "subpop_pop": [self.population_size], - "subpop_index": [1], - "subpop_name": ["total population"], - } - ) - return site_subpop_spine - - @property - def date_time_spine(self): - if self.data_observed_disease_wastewater is not None: - date_time_spine = pl.DataFrame( - { - "date": pl.date_range( - start=self.first_wastewater_date, - end=self.last_wastewater_date, - interval="1d", - eager=True, - ) - } - ).with_row_index("t") - return date_time_spine - - @property - def wastewater_data_extended(self): - if self.data_observed_disease_wastewater is not None: - return ( - self.data_observed_disease_wastewater.join( - self.date_time_spine, on="date", how="left", coalesce=True - ) - .join( - self.site_subpop_spine, - on=["site_index", "site"], - how="left", - coalesce=True, - ) - .with_row_index("ind_rel_to_observed_times") - ) - - @property - def pop_fraction(self): - if self.data_observed_disease_wastewater is not None: - subpop_sizes = self.site_subpop_spine["subpop_pop"].to_numpy() - return subpop_sizes / self.population_size - return self.pop_fraction_ - - @property - def data_observed_disease_wastewater_conc(self): - if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended[ - "log_genome_copies_per_ml" - ].to_numpy() - - @property - def ww_censored(self): - if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended.filter( - pl.col("below_lod") == 1 - )["ind_rel_to_observed_times"].to_numpy() - return None - - @property - def ww_uncensored(self): - if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended.filter( - pl.col("below_lod") == 0 - )["ind_rel_to_observed_times"].to_numpy() - - @property - def ww_observed_times(self): - if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended["t"].to_numpy() - - @property - def ww_observed_subpops(self): - if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended["subpop_index"].to_numpy() - - @property - def ww_observed_lab_sites(self): - if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended["lab_site_index"].to_numpy() - - @property - def ww_log_lod(self): - if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended["log_lod"].to_numpy() - - @property - def n_ww_lab_sites(self): - if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended["lab_site_index"].n_unique() - - @property - def lab_site_to_subpop_map(self): - if self.data_observed_disease_wastewater is not None: - return ( - self.wastewater_data_extended["lab_site_index", "subpop_index"] - .unique() - .sort(by="lab_site_index") - )["subpop_index"].to_numpy() - def get_end_date( self, first_date: datetime.date, @@ -357,6 +231,5 @@ def to_forecast_data(self, n_forecast_points: int) -> Self: first_ed_visits_date=self.first_data_date_overall, first_hospital_admissions_date=(self.first_data_date_overall), first_wastewater_date=self.first_data_date_overall, - pop_fraction=self.pop_fraction, right_truncation_offset=None, # by default, want forecasts of complete reports ) diff --git a/pyrenew_hew/pyrenew_wastewater_data.py b/pyrenew_hew/pyrenew_wastewater_data.py new file mode 100644 index 00000000..915bdb32 --- /dev/null +++ b/pyrenew_hew/pyrenew_wastewater_data.py @@ -0,0 +1,177 @@ +import datetime +from typing import Self + +import jax +import jax.numpy as jnp +import polars as pl +from jax.typing import ArrayLike + + +class PyrenewWastewaterData: + """ + Class for holding wastewater input data + to a PyrenewHEW model. + """ + + def __init__( + self, + data_observed_disease_wastewater: pl.DataFrame = None, + population_size: int = None, + pop_fraction: ArrayLike = jnp.array([1]), + ) -> None: + self.data_observed_disease_wastewater = ( + data_observed_disease_wastewater + ) + self.population_size = population_size + self.pop_fraction = pop_fraction + + @property + def site_subpop_spine(self): + ww_data_present = self.data_observed_disease_wastewater is not None + if ww_data_present: + add_auxiliary_subpop = ( + self.population_size + > self.data_observed_disease_wastewater.select( + pl.col("site_pop", "site", "lab", "lab_site_index") + ) + .unique() + .get_column("site_pop") + .sum() + ) + site_indices = ( + self.data_observed_disease_wastewater.select( + ["site_index", "site", "site_pop"] + ) + .unique() + .sort("site_index") + ) + if add_auxiliary_subpop: + aux_subpop = pl.DataFrame( + { + "site_index": [None], + "site": [None], + "site_pop": [ + self.population_size + - site_indices.select(pl.col("site_pop")) + .get_column("site_pop") + .sum() + ], + } + ) + else: + aux_subpop = pl.DataFrame() + site_subpop_spine = ( + pl.concat([aux_subpop, site_indices], how="vertical_relaxed") + .with_columns( + subpop_index=pl.col("site_index") + .cum_count() + .alias("subpop_index"), + subpop_name=pl.format( + "Site: {}", pl.col("site") + ).fill_null("remainder of population"), + ) + .rename({"site_pop": "subpop_pop"}) + ) + else: + site_subpop_spine = pl.DataFrame( + { + "site_index": [None], + "site": [None], + "subpop_pop": [self.population_size], + "subpop_index": [1], + "subpop_name": ["total population"], + } + ) + return site_subpop_spine + + @property + def date_time_spine(self): + if self.data_observed_disease_wastewater is not None: + date_time_spine = pl.DataFrame( + { + "date": pl.date_range( + start=self.date_observed_disease_wastewater.min(), + end=self.date_observed_disease_wastewater.max(), + interval="1d", + eager=True, + ) + } + ).with_row_index("t") + return date_time_spine + + @property + def wastewater_data_extended(self): + if self.data_observed_disease_wastewater is not None: + return ( + self.data_observed_disease_wastewater.join( + self.date_time_spine, on="date", how="left", coalesce=True + ) + .join( + self.site_subpop_spine, + on=["site_index", "site"], + how="left", + coalesce=True, + ) + .with_row_index("ind_rel_to_observed_times") + ) + + @property + def date_observed_disease_wastewater(self): + if self.data_observed_disease_wastewater is not None: + return self.data_observed_disease_wastewater["date"].to_numpy() + + @property + def data_observed_disease_wastewater_conc(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended[ + "log_genome_copies_per_ml" + ].to_numpy() + + @property + def ww_censored(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended.filter( + pl.col("below_lod") == 1 + )["ind_rel_to_observed_times"].to_numpy() + return None + + @property + def ww_uncensored(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended.filter( + pl.col("below_lod") == 0 + )["ind_rel_to_observed_times"].to_numpy() + + @property + def ww_observed_times(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended["t"].to_numpy() + + @property + def ww_observed_subpops(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended["subpop_index"].to_numpy() + + @property + def ww_observed_lab_sites(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended["lab_site_index"].to_numpy() + + @property + def ww_log_lod(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended["log_lod"].to_numpy() + + @property + def n_ww_lab_sites(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended["lab_site_index"].n_unique() + + @property + def lab_site_to_subpop_map(self): + if self.data_observed_disease_wastewater is not None: + return ( + self.wastewater_data_extended["lab_site_index", "subpop_index"] + .unique() + .sort(by="lab_site_index") + )["subpop_index"].to_numpy() diff --git a/tests/test_pyrenew_hew_data.py b/tests/test_pyrenew_hew_data.py index 99f42058..8b22bef5 100644 --- a/tests/test_pyrenew_hew_data.py +++ b/tests/test_pyrenew_hew_data.py @@ -1,11 +1,9 @@ import datetime -import jax.numpy as jnp -import numpy as np -import polars as pl import pytest from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData +from pyrenew_hew.pyrenew_wastewater_data import PyrenewWastewaterData @pytest.mark.parametrize( @@ -75,13 +73,12 @@ def test_to_forecast_data( """ Test the to_forecast_data method """ + data = PyrenewHEWData( n_ed_visits_data_days=n_ed_visits_data_days, n_hospital_admissions_data_days=n_hospital_admissions_data_days, - n_wastewater_data_days=n_wastewater_data_days, first_ed_visits_date=first_ed_visits_date, first_hospital_admissions_date=first_hospital_admissions_date, - first_wastewater_date=first_wastewater_date, right_truncation_offset=right_truncation_offset, ) @@ -101,67 +98,4 @@ def test_to_forecast_data( == data.first_data_date_overall ) assert forecast_data.first_wastewater_date == data.first_data_date_overall - assert forecast_data.data_observed_disease_wastewater is None - - -def test_wastewater_data_properties(): - first_training_date = datetime.date(2023, 1, 1) - last_training_date = datetime.date(2023, 7, 23) - dates = pl.date_range( - first_training_date, - last_training_date, - interval="1w", - closed="both", - eager=True, - ) - - ww_raw = pl.DataFrame( - { - "date": dates.extend(dates), - "site": [200] * 30 + [100] * 30, - "lab": [21] * 60, - "lab_site_index": [1] * 30 + [2] * 30, - "site_index": [1] * 30 + [2] * 30, - "log_genome_copies_per_ml": np.log( - np.abs(np.random.normal(loc=500, scale=50, size=60)) - ), - "log_lod": np.log([20] * 30 + [15] * 30), - "site_pop": [200_000] * 30 + [400_000] * 30, - } - ) - - ww_data = ww_raw.with_columns( - (pl.col("log_genome_copies_per_ml") <= pl.col("log_lod")) - .cast(pl.Int8) - .alias("below_lod") - ) - - first_ed_visits_date = datetime.date(2023, 1, 1) - first_hospital_admissions_date = datetime.date(2023, 1, 1) - first_wastewater_date = datetime.date(2023, 1, 1) - n_forecast_points = 10 - - data = PyrenewHEWData( - first_ed_visits_date=first_ed_visits_date, - first_hospital_admissions_date=first_hospital_admissions_date, - first_wastewater_date=first_wastewater_date, - data_observed_disease_wastewater=ww_data, - population_size=1e6, - ) - - forecast_data = data.to_forecast_data(n_forecast_points) assert forecast_data.data_observed_disease_wastewater_conc is None - assert data.data_observed_disease_wastewater_conc is not None - - assert jnp.array_equal( - data.data_observed_disease_wastewater_conc, - ww_data["log_genome_copies_per_ml"], - ) - assert len(data.ww_censored) == len( - ww_data.filter(pl.col("below_lod") == 1) - ) - assert len(data.ww_uncensored) == len( - ww_data.filter(pl.col("below_lod") == 0) - ) - assert jnp.array_equal(data.ww_log_lod, ww_data["log_lod"]) - assert data.n_ww_lab_sites == ww_data["lab_site_index"].n_unique() diff --git a/tests/test_pyrenew_wastewater_data.py b/tests/test_pyrenew_wastewater_data.py new file mode 100644 index 00000000..2f61ecc9 --- /dev/null +++ b/tests/test_pyrenew_wastewater_data.py @@ -0,0 +1,58 @@ +import datetime + +import jax.numpy as jnp +import numpy as np +import polars as pl + +from pyrenew_hew.pyrenew_wastewater_data import PyrenewWastewaterData + + +def test_wastewater_data_properties(): + first_training_date = datetime.date(2023, 1, 1) + last_training_date = datetime.date(2023, 7, 23) + dates = pl.date_range( + first_training_date, + last_training_date, + interval="1w", + closed="both", + eager=True, + ) + + ww_raw = pl.DataFrame( + { + "date": dates.extend(dates), + "site": ["200"] * 30 + ["100"] * 30, + "lab": ["21"] * 60, + "lab_site_index": [1] * 30 + [2] * 30, + "site_index": [1] * 30 + [2] * 30, + "log_genome_copies_per_ml": np.log( + np.abs(np.random.normal(loc=500, scale=50, size=60)) + ), + "log_lod": np.log([20] * 30 + [15] * 30), + "site_pop": [200_000] * 30 + [400_000] * 30, + } + ) + + ww_data = ww_raw.with_columns( + (pl.col("log_genome_copies_per_ml") <= pl.col("log_lod")) + .cast(pl.Int8) + .alias("below_lod") + ) + + data = PyrenewWastewaterData( + data_observed_disease_wastewater=ww_data, + population_size=1e6, + ) + + assert jnp.array_equal( + data.data_observed_disease_wastewater_conc, + ww_data["log_genome_copies_per_ml"], + ) + assert len(data.ww_censored) == len( + ww_data.filter(pl.col("below_lod") == 1) + ) + assert len(data.ww_uncensored) == len( + ww_data.filter(pl.col("below_lod") == 0) + ) + assert jnp.array_equal(data.ww_log_lod, ww_data["log_lod"]) + assert data.n_ww_lab_sites == ww_data["lab_site_index"].n_unique() From 83dd48ff4471bd2a7f0a341639b0cd1430e4c121 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 24 Feb 2025 19:42:27 -0500 Subject: [PATCH 20/34] code clean up --- pipelines/prep_data.py | 18 +++---- pyrenew_hew/pyrenew_hew_data.py | 74 +++++++++++++++++++++----- pyrenew_hew/pyrenew_wastewater_data.py | 60 +++++++++++++-------- tests/test_pyrenew_wastewater_data.py | 16 +++--- 4 files changed, 116 insertions(+), 52 deletions(-) diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index 0740988a..00cf46c3 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -9,6 +9,7 @@ import forecasttools import jax.numpy as jnp +import numpy as np import polars as pl import polars.selectors as cs @@ -465,15 +466,12 @@ def process_and_save_state( .unique()["site_pop"] .to_numpy() ) - pop_fraction = jnp.where( - state_pop > subpop_sizes.sum(), - jnp.concatenate( - [jnp.array([state_pop - subpop_sizes.sum()]), subpop_sizes], - axis=0, - ) - / state_pop, - subpop_sizes / state_pop, - ) + if state_pop > sum(subpop_sizes): + pop_fraction = ( + [state_pop - sum(subpop_sizes)] + subpop_sizes + ) / state_pop + else: + pop_fraction = subpop_sizes / state_pop data_for_model_fit = { "inf_to_ed_pmf": delay_pmf, @@ -489,7 +487,7 @@ def process_and_save_state( "state_pop": state_pop, "right_truncation_offset": right_truncation_offset, "data_observed_disease_wastewater": data_observed_disease_wastewater, - "pop_fraction": pop_fraction, + "pop_fraction": pop_fraction.tolist(), } data_dir = Path(model_run_dir, "data") diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index e0b880a3..93be55f5 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -1,6 +1,7 @@ import datetime from typing import Self +import jax.numpy as jnp from jax.typing import ArrayLike from pyrenew_hew.pyrenew_wastewater_data import PyrenewWastewaterData @@ -24,6 +25,12 @@ def __init__( first_hospital_admissions_date: datetime.date = None, first_wastewater_date: datetime.date = None, wastewater_data: PyrenewWastewaterData = None, + n_ww_lab_sites: int = None, + ww_censored: ArrayLike = None, + ww_uncensored: ArrayLike = None, + ww_observed_subpops: ArrayLike = None, + ww_observed_times: ArrayLike = None, + ww_observed_lab_sites: ArrayLike = None, ) -> None: self.n_ed_visits_data_days_ = n_ed_visits_data_days self.n_hospital_admissions_data_days_ = n_hospital_admissions_data_days @@ -42,7 +49,9 @@ def __init__( else wastewater_data.date_observed_disease_wastewater ) self.pop_fraction = ( - None if wastewater_data is None else wastewater_data.pop_fraction + jnp.array([1]) + if wastewater_data is None + else jnp.array(wastewater_data.pop_fraction) ) self.data_observed_disease_wastewater_conc = ( None @@ -50,32 +59,65 @@ def __init__( else wastewater_data.data_observed_disease_wastewater_conc ) self.ww_censored = ( - None if wastewater_data is None else wastewater_data.ww_censored + ww_censored + if ww_censored is not None + else ( + None + if wastewater_data is None + else wastewater_data.ww_censored + ) ) self.ww_uncensored = ( - None if wastewater_data is None else wastewater_data.ww_uncensored + ww_uncensored + if ww_uncensored is not None + else ( + None + if wastewater_data is None + else wastewater_data.ww_uncensored + ) ) self.ww_observed_times = ( - None - if wastewater_data is None - else wastewater_data.ww_observed_times + ww_observed_times + if ww_observed_times is not None + else ( + None + if wastewater_data is None + else wastewater_data.ww_observed_times + ) ) self.ww_observed_subpops = ( - None - if wastewater_data is None - else wastewater_data.ww_observed_subpops + ww_observed_subpops + if ww_observed_subpops is not None + else ( + None + if wastewater_data is None + else wastewater_data.ww_observed_subpops + ) ) + self.ww_observed_lab_sites = ( - None - if wastewater_data is None - else wastewater_data.ww_observed_lab_sites + ww_observed_lab_sites + if ww_observed_lab_sites is not None + else ( + None + if wastewater_data is None + else wastewater_data.ww_observed_lab_sites + ) ) + self.ww_log_lod = ( None if wastewater_data is None else wastewater_data.ww_log_lod ) self.n_ww_lab_sites = ( - None if wastewater_data is None else wastewater_data.n_ww_lab_sites + n_ww_lab_sites + if n_ww_lab_sites is not None + else ( + None + if wastewater_data is None + else wastewater_data.n_ww_lab_sites + ) ) + self.lab_site_to_subpop_map = ( None if wastewater_data is None @@ -232,4 +274,10 @@ def to_forecast_data(self, n_forecast_points: int) -> Self: first_hospital_admissions_date=(self.first_data_date_overall), first_wastewater_date=self.first_data_date_overall, right_truncation_offset=None, # by default, want forecasts of complete reports + n_ww_lab_sites=self.n_ww_lab_sites, + ww_uncensored=self.ww_uncensored, + ww_censored=self.ww_censored, + ww_observed_lab_sites=self.ww_observed_lab_sites, + ww_observed_subpops=self.ww_observed_subpops, + ww_observed_times=self.ww_observed_times, ) diff --git a/pyrenew_hew/pyrenew_wastewater_data.py b/pyrenew_hew/pyrenew_wastewater_data.py index 915bdb32..cc0d027d 100644 --- a/pyrenew_hew/pyrenew_wastewater_data.py +++ b/pyrenew_hew/pyrenew_wastewater_data.py @@ -50,12 +50,10 @@ def site_subpop_spine(self): { "site_index": [None], "site": [None], - "site_pop": [ + "site_pop": ( self.population_size - - site_indices.select(pl.col("site_pop")) - .get_column("site_pop") - .sum() - ], + - site_indices.get_column("site_pop").sum() + ).tolist(), } ) else: @@ -118,49 +116,61 @@ def wastewater_data_extended(self): @property def date_observed_disease_wastewater(self): if self.data_observed_disease_wastewater is not None: - return self.data_observed_disease_wastewater["date"].to_numpy() + return self.data_observed_disease_wastewater.get_column( + "date" + ).unique() @property def data_observed_disease_wastewater_conc(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended[ + return self.wastewater_data_extended.get_column( "log_genome_copies_per_ml" - ].to_numpy() + ).to_numpy() @property def ww_censored(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended.filter( - pl.col("below_lod") == 1 - )["ind_rel_to_observed_times"].to_numpy() + return ( + self.wastewater_data_extended.filter(pl.col("below_lod") == 1) + .get_column("ind_rel_to_observed_times") + .to_numpy() + ) return None @property def ww_uncensored(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended.filter( - pl.col("below_lod") == 0 - )["ind_rel_to_observed_times"].to_numpy() + return ( + self.wastewater_data_extended.filter(pl.col("below_lod") == 0) + .get_column("ind_rel_to_observed_times") + .to_numpy() + ) @property def ww_observed_times(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended["t"].to_numpy() + return self.wastewater_data_extended.get_column("t").to_numpy() @property def ww_observed_subpops(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended["subpop_index"].to_numpy() + return self.wastewater_data_extended.get_column( + "subpop_index" + ).to_numpy() @property def ww_observed_lab_sites(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended["lab_site_index"].to_numpy() + return self.wastewater_data_extended.get_column( + "lab_site_index" + ).to_numpy() @property def ww_log_lod(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended["log_lod"].to_numpy() + return self.wastewater_data_extended.get_column( + "log_lod" + ).to_numpy() @property def n_ww_lab_sites(self): @@ -171,7 +181,13 @@ def n_ww_lab_sites(self): def lab_site_to_subpop_map(self): if self.data_observed_disease_wastewater is not None: return ( - self.wastewater_data_extended["lab_site_index", "subpop_index"] - .unique() - .sort(by="lab_site_index") - )["subpop_index"].to_numpy() + ( + self.wastewater_data_extended[ + "lab_site_index", "subpop_index" + ] + .unique() + .sort(by="lab_site_index") + ) + .get_column("subpop_index") + .to_numpy() + ) diff --git a/tests/test_pyrenew_wastewater_data.py b/tests/test_pyrenew_wastewater_data.py index 2f61ecc9..19c4a4a2 100644 --- a/tests/test_pyrenew_wastewater_data.py +++ b/tests/test_pyrenew_wastewater_data.py @@ -4,6 +4,7 @@ import numpy as np import polars as pl +from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData from pyrenew_hew.pyrenew_wastewater_data import PyrenewWastewaterData @@ -29,19 +30,20 @@ def test_wastewater_data_properties(): np.abs(np.random.normal(loc=500, scale=50, size=60)) ), "log_lod": np.log([20] * 30 + [15] * 30), - "site_pop": [200_000] * 30 + [400_000] * 30, + "site_pop": [200000] * 30 + [400000] * 30, } ) ww_data = ww_raw.with_columns( - (pl.col("log_genome_copies_per_ml") <= pl.col("log_lod")) - .cast(pl.Int8) - .alias("below_lod") + below_lod=pl.col("log_genome_copies_per_ml") <= pl.col("log_lod") ) - data = PyrenewWastewaterData( - data_observed_disease_wastewater=ww_data, - population_size=1e6, + data = PyrenewHEWData( + wastewater_data=PyrenewWastewaterData( + data_observed_disease_wastewater=ww_data, + population_size=1e6, + pop_fraction=[0.4, 0.2, 0.4], + ), ) assert jnp.array_equal( From a94c4b6d2a2c42ae7a47f517d9455bb15b836ae0 Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 25 Feb 2025 15:17:08 -0500 Subject: [PATCH 21/34] only render diagnostic report for pyrenew_e --- pipelines/forecast_state.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py index e73cf85c..719de5c4 100644 --- a/pipelines/forecast_state.py +++ b/pipelines/forecast_state.py @@ -453,9 +453,10 @@ def main( ) logger.info("Postprocessing complete.") - logger.info("Rendering webpage...") - render_diagnostic_report(model_run_dir) - logger.info("Rendering complete.") + if pyrenew_model_name == "pyrenew_e": + logger.info("Rendering webpage...") + render_diagnostic_report(model_run_dir) + logger.info("Rendering complete.") if score: logger.info("Scoring forecast...") From b97de0d8794319c31f225d0946ae2536ced2a699 Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 25 Feb 2025 15:18:37 -0500 Subject: [PATCH 22/34] update test ww data generation --- pipelines/generate_test_data.R | 49 +++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/pipelines/generate_test_data.R b/pipelines/generate_test_data.R index 84d5c2d1..7de4d9b2 100644 --- a/pipelines/generate_test_data.R +++ b/pipelines/generate_test_data.R @@ -315,30 +315,41 @@ generate_fake_nwss_data <- function( states_to_generate = c("MT", "CA"), start_reference = as.Date("2024-06-01"), end_reference = as.Date("2024-12-21"), - site = c(1, 2, 3, 4), - lab = c(1, 1, 3, 3), + site = list( + CA = c(1, 2, 3, 4), + MT = c(5, 6, 7, 8) + ), + lab = list( + CA = c(1, 1, 2, 2), + MT = c(3, 3, 4, 4) + ), lod = c(20, 31, 20, 30), - site_pop = c(4e6, 2e6, 1e6, 5e5)) { + site_pop = list( + CA = c(4e6, 2e6, 1e6, 5e5), + MT = c(3e5, 2e5, 1e5, 5e4) + )) { ww_dir <- fs::path(private_data_dir, "nwss_vintages") fs::dir_create(ww_dir, recurse = TRUE) - site_info <- tibble::tibble( - wwtp_id = site, - lab_id = lab, - lod_sewage = lod, - population_served = site_pop, - sample_location = "wwtp", - sample_matrix = "raw wastewater", - pcr_target_units = "copies/l wastewater", - pcr_target = "sars-cov-2", - quality_flag = c("no", NA_character_, "n", "n") - ) + site_info <- function(state) { + tibble::tibble( + wwtp_id = site[[state]], + lab_id = lab[[state]], + lod_sewage = lod, + population_served = site_pop[[state]], + sample_location = "wwtp", + sample_matrix = "raw wastewater", + pcr_target_units = "copies/l wastewater", + pcr_target = "sars-cov-2", + quality_flag = c("no", NA_character_, "n", "n"), + wwtp_jurisdiction = state + ) + } - ww_data <- tidyr::expand_grid( - sample_collect_date = seq(start_reference, end_reference, by = "week"), - wwtp_jurisdiction = states_to_generate, - site_info - ) |> + ww_data <- purrr::map_dfr(states_to_generate, site_info) |> + tidyr::expand_grid( + sample_collect_date = seq(start_reference, end_reference, by = "week") + ) |> dplyr::mutate( pcr_target_avg_conc = abs(rnorm(dplyr::n(), mean = 500, sd = 50)) ) From efead4e1d2c1f28e249ffdeae82c040cbf07e39f Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 25 Feb 2025 15:50:47 -0500 Subject: [PATCH 23/34] nwss data handling, split end-to-end test by disease --- pipelines/forecast_state.py | 36 +++++++++------- pipelines/prep_data.py | 10 +++-- pipelines/tests/test_end_to_end.sh | 43 +++++++++++++++---- pyrenew_hew/pyrenew_hew_data.py | 12 ++++-- tests/test_pyrenew_wastewater_data.py | 60 --------------------------- 5 files changed, 73 insertions(+), 88 deletions(-) delete mode 100644 tests/test_pyrenew_wastewater_data.py diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py index 719de5c4..5717336c 100644 --- a/pipelines/forecast_state.py +++ b/pipelines/forecast_state.py @@ -335,22 +335,28 @@ def main( "No data available for the requested report date " f"{report_date}" ) - available_nwss_reports = get_available_reports(nwss_data_dir) - # assming NWSS_vintage directory follows naming convention - # using as of date - # need to be modified otherwise - - if report_date in available_nwss_reports: - nwss_data_raw = pl.scan_parquet( - Path(nwss_data_dir, f"{report_date}.parquet") - ) - nwss_data_cleaned = clean_nwss_data(nwss_data_raw).filter( - (pl.col("location") == state) - & (pl.col("date") >= first_training_date) - ) - state_level_nwss_data = preprocess_ww_data(nwss_data_cleaned.collect()) + if fit_wastewater: + available_nwss_reports = get_available_reports(nwss_data_dir) + # assming NWSS_vintage directory follows naming convention + # using as of date + if report_date in available_nwss_reports: + nwss_data_raw = pl.scan_parquet( + Path(nwss_data_dir, f"{report_date}.parquet") + ) + nwss_data_cleaned = clean_nwss_data(nwss_data_raw).filter( + (pl.col("location") == state) + & (pl.col("date") >= first_training_date) + ) + state_level_nwss_data = preprocess_ww_data( + nwss_data_cleaned.collect() + ) + else: + raise ValueError( + "NWSS data not available for the requested report date " + f"{report_date}" + ) else: - state_level_nwss_data = None ## TO DO: change + state_level_nwss_data = None param_estimates = pl.scan_parquet(Path(param_data_dir, "prod.parquet")) model_batch_dir_name = ( diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index 00cf46c3..d57a6ce8 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -463,13 +463,17 @@ def process_and_save_state( else: subpop_sizes = ( state_level_nwss_data.select(["site_index", "site", "site_pop"]) - .unique()["site_pop"] + .unique() + .get_column("site_pop") .to_numpy() ) if state_pop > sum(subpop_sizes): pop_fraction = ( - [state_pop - sum(subpop_sizes)] + subpop_sizes - ) / state_pop + jnp.concatenate( + (jnp.array([state_pop - sum(subpop_sizes)]), subpop_sizes) + ) + / state_pop + ) else: pop_fraction = subpop_sizes / state_pop diff --git a/pipelines/tests/test_end_to_end.sh b/pipelines/tests/test_end_to_end.sh index eee0582d..75420052 100755 --- a/pipelines/tests/test_end_to_end.sh +++ b/pipelines/tests/test_end_to_end.sh @@ -16,13 +16,11 @@ if [ $? -ne 0 ]; then else echo "TEST-MODE: Finished generating test data" fi -echo "TEST-MODE: Running forecasting pipeline for two diseases in multiple locations" -for state in CA MT US +echo "TEST-MODE: Running forecasting pipeline for COVID-19 in multiple states" +for state in CA MT do - for disease in COVID-19 Influenza - do python pipelines/forecast_state.py \ - --disease $disease \ + --disease COVID-19 \ --state $state \ --facility-level-nssp-data-dir "$BASE_DIR/private_data/nssp_etl_gold" \ --state-level-nssp-data-dir "$BASE_DIR/private_data/nssp_state_level_gold" \ @@ -46,9 +44,40 @@ do echo "TEST-MODE FAIL: Forecasting/postprocessing/scoring pipeline failed" exit 1 else - echo "TEST-MODE: Finished forecasting/postprocessing/scoring pipeline for disease" $disease "in location" $state"." + echo "TEST-MODE: Finished forecasting/postprocessing/scoring pipeline for COVID-19 in location" $state"." + fi +done + +echo "TEST-MODE: Running forecasting pipeline for Influenza in multiple states" +for state in CA MT US +do + python pipelines/forecast_state.py \ + --disease Influenza \ + --state $state \ + --facility-level-nssp-data-dir "$BASE_DIR/private_data/nssp_etl_gold" \ + --state-level-nssp-data-dir "$BASE_DIR/private_data/nssp_state_level_gold" \ + --priors-path pipelines/priors/prod_priors.py \ + --param-data-dir "$BASE_DIR/private_data/prod_param_estimates" \ + --nwss-data-dir "$BASE_DIR/private_data/nwss_vintages" \ + --output-dir "$BASE_DIR/private_data" \ + --n-training-days 60 \ + --n-chains 2 \ + --n-samples 250 \ + --n-warmup 250 \ + --fit-ed-visits \ + --fit-hospital-admissions \ + --no-fit-wastewater \ + --forecast-ed-visits \ + --forecast-hospital-admissions \ + --no-forecast-wastewater \ + --score \ + --eval-data-path "$BASE_DIR/private_data/nssp-etl" + if [ $? -ne 0 ]; then + echo "TEST-MODE FAIL: Forecasting/postprocessing/scoring pipeline failed" + exit 1 + else + echo "TEST-MODE: Finished forecasting/postprocessing/scoring pipeline for Influenza in location" $state"." fi - done done echo "TEST-MODE: pipeline runs complete for all location/disease pairs." diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index 93be55f5..9023297c 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -31,6 +31,7 @@ def __init__( ww_observed_subpops: ArrayLike = None, ww_observed_times: ArrayLike = None, ww_observed_lab_sites: ArrayLike = None, + lab_site_to_subpop_map: ArrayLike = None, ) -> None: self.n_ed_visits_data_days_ = n_ed_visits_data_days self.n_hospital_admissions_data_days_ = n_hospital_admissions_data_days @@ -119,9 +120,13 @@ def __init__( ) self.lab_site_to_subpop_map = ( - None - if wastewater_data is None - else wastewater_data.lab_site_to_subpop_map + lab_site_to_subpop_map + if lab_site_to_subpop_map is not None + else ( + None + if wastewater_data is None + else wastewater_data.lab_site_to_subpop_map + ) ) @property @@ -280,4 +285,5 @@ def to_forecast_data(self, n_forecast_points: int) -> Self: ww_observed_lab_sites=self.ww_observed_lab_sites, ww_observed_subpops=self.ww_observed_subpops, ww_observed_times=self.ww_observed_times, + lab_site_to_subpop_map=self.lab_site_to_subpop_map, ) diff --git a/tests/test_pyrenew_wastewater_data.py b/tests/test_pyrenew_wastewater_data.py deleted file mode 100644 index 19c4a4a2..00000000 --- a/tests/test_pyrenew_wastewater_data.py +++ /dev/null @@ -1,60 +0,0 @@ -import datetime - -import jax.numpy as jnp -import numpy as np -import polars as pl - -from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData -from pyrenew_hew.pyrenew_wastewater_data import PyrenewWastewaterData - - -def test_wastewater_data_properties(): - first_training_date = datetime.date(2023, 1, 1) - last_training_date = datetime.date(2023, 7, 23) - dates = pl.date_range( - first_training_date, - last_training_date, - interval="1w", - closed="both", - eager=True, - ) - - ww_raw = pl.DataFrame( - { - "date": dates.extend(dates), - "site": ["200"] * 30 + ["100"] * 30, - "lab": ["21"] * 60, - "lab_site_index": [1] * 30 + [2] * 30, - "site_index": [1] * 30 + [2] * 30, - "log_genome_copies_per_ml": np.log( - np.abs(np.random.normal(loc=500, scale=50, size=60)) - ), - "log_lod": np.log([20] * 30 + [15] * 30), - "site_pop": [200000] * 30 + [400000] * 30, - } - ) - - ww_data = ww_raw.with_columns( - below_lod=pl.col("log_genome_copies_per_ml") <= pl.col("log_lod") - ) - - data = PyrenewHEWData( - wastewater_data=PyrenewWastewaterData( - data_observed_disease_wastewater=ww_data, - population_size=1e6, - pop_fraction=[0.4, 0.2, 0.4], - ), - ) - - assert jnp.array_equal( - data.data_observed_disease_wastewater_conc, - ww_data["log_genome_copies_per_ml"], - ) - assert len(data.ww_censored) == len( - ww_data.filter(pl.col("below_lod") == 1) - ) - assert len(data.ww_uncensored) == len( - ww_data.filter(pl.col("below_lod") == 0) - ) - assert jnp.array_equal(data.ww_log_lod, ww_data["log_lod"]) - assert data.n_ww_lab_sites == ww_data["lab_site_index"].n_unique() From a6bd1c545fd2abfd779af8e8b2489de573ae7b6e Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 25 Feb 2025 16:02:00 -0500 Subject: [PATCH 24/34] remove unused imports --- pipelines/prep_data.py | 1 - pipelines/tests/test_build_pyrenew_model.py | 2 +- pipelines/tests/test_end_to_end.sh | 3 +-- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index d57a6ce8..2649cd7d 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -9,7 +9,6 @@ import forecasttools import jax.numpy as jnp -import numpy as np import polars as pl import polars.selectors as cs diff --git a/pipelines/tests/test_build_pyrenew_model.py b/pipelines/tests/test_build_pyrenew_model.py index a9a0613d..008ac8c5 100644 --- a/pipelines/tests/test_build_pyrenew_model.py +++ b/pipelines/tests/test_build_pyrenew_model.py @@ -117,4 +117,4 @@ def test_build_model_from_dir(tmp_path, mock_data, mock_priors): data.data_observed_disease_hospital_admissions, jnp.array(model_data["data_observed_disease_hospital_admissions"]), ) - # assert data.data_observed_disease_wastewater_conc is not None + assert data.data_observed_disease_wastewater_conc is not None diff --git a/pipelines/tests/test_end_to_end.sh b/pipelines/tests/test_end_to_end.sh index 75420052..9c98fd4d 100755 --- a/pipelines/tests/test_end_to_end.sh +++ b/pipelines/tests/test_end_to_end.sh @@ -58,14 +58,13 @@ do --state-level-nssp-data-dir "$BASE_DIR/private_data/nssp_state_level_gold" \ --priors-path pipelines/priors/prod_priors.py \ --param-data-dir "$BASE_DIR/private_data/prod_param_estimates" \ - --nwss-data-dir "$BASE_DIR/private_data/nwss_vintages" \ --output-dir "$BASE_DIR/private_data" \ --n-training-days 60 \ --n-chains 2 \ --n-samples 250 \ --n-warmup 250 \ --fit-ed-visits \ - --fit-hospital-admissions \ + --no-fit-hospital-admissions \ --no-fit-wastewater \ --forecast-ed-visits \ --forecast-hospital-admissions \ From 58d6342c6009f86a06fa6b3db9665632a86c3afc Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 25 Feb 2025 19:12:22 -0500 Subject: [PATCH 25/34] end-to-end run test for covid-19 in US --- pipelines/tests/test_end_to_end.sh | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/pipelines/tests/test_end_to_end.sh b/pipelines/tests/test_end_to_end.sh index 9c98fd4d..9c7730bc 100755 --- a/pipelines/tests/test_end_to_end.sh +++ b/pipelines/tests/test_end_to_end.sh @@ -48,6 +48,35 @@ do fi done +echo "TEST-MODE: Running forecasting pipeline for COVID-19 in US" +python pipelines/forecast_state.py \ + --disease COVID-19 \ + --state US \ + --facility-level-nssp-data-dir "$BASE_DIR/private_data/nssp_etl_gold" \ + --state-level-nssp-data-dir "$BASE_DIR/private_data/nssp_state_level_gold" \ + --priors-path pipelines/priors/prod_priors.py \ + --param-data-dir "$BASE_DIR/private_data/prod_param_estimates" \ + --nwss-data-dir "$BASE_DIR/private_data/nwss_vintages" \ + --output-dir "$BASE_DIR/private_data" \ + --n-training-days 60 \ + --n-chains 2 \ + --n-samples 250 \ + --n-warmup 250 \ + --fit-ed-visits \ + --fit-hospital-admissions \ + --no-fit-wastewater \ + --forecast-ed-visits \ + --forecast-hospital-admissions \ + --no-forecast-wastewater \ + --score \ + --eval-data-path "$BASE_DIR/private_data/nssp-etl" +if [ $? -ne 0 ]; then + echo "TEST-MODE FAIL: Forecasting/postprocessing/scoring pipeline failed" + exit 1 +else + echo "TEST-MODE: Finished forecasting/postprocessing/scoring pipeline for COVID-19 in location US." +fi + echo "TEST-MODE: Running forecasting pipeline for Influenza in multiple states" for state in CA MT US do From ca9ba7efc86ff533de901f19627dbc4784446828 Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 25 Feb 2025 19:15:43 -0500 Subject: [PATCH 26/34] update create_more_model_test_data to exclude model_run_dir for covid --- pipelines/tests/create_more_model_test_data.R | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pipelines/tests/create_more_model_test_data.R b/pipelines/tests/create_more_model_test_data.R index ff0e90a1..fc2e1d30 100644 --- a/pipelines/tests/create_more_model_test_data.R +++ b/pipelines/tests/create_more_model_test_data.R @@ -13,10 +13,6 @@ try(source("pipelines/plot_and_save_state_forecast.R")) try(source("pipelines/score_forecast.R")) model_batch_dirs <- c( - path( - "pipelines/tests/private_data", - "covid-19_r_2024-12-21_f_2024-10-22_t_2024-12-20" - ), path( "pipelines/tests/private_data", "influenza_r_2024-12-21_f_2024-10-22_t_2024-12-20" From 2f54caa1bb862de545bedb5234b30fe417cd6805 Mon Sep 17 00:00:00 2001 From: Subekshya Bidari <37636707+sbidari@users.noreply.github.com> Date: Wed, 26 Feb 2025 12:13:27 -0500 Subject: [PATCH 27/34] Update pipelines/generate_test_data.R Co-authored-by: Dylan H. Morris --- pipelines/generate_test_data.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelines/generate_test_data.R b/pipelines/generate_test_data.R index 7de4d9b2..6d302a81 100644 --- a/pipelines/generate_test_data.R +++ b/pipelines/generate_test_data.R @@ -392,7 +392,7 @@ main <- function(private_data_dir, target_diseases = short_target_diseases ) generate_fake_nwss_data( - private_data_dir, + private_data_dir ) } From f1fccac4288bd8d15f46fcd51a9d135daed78e45 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 27 Feb 2025 13:28:08 -0500 Subject: [PATCH 28/34] add PyRenewWasteWater args as a dict to PyrenewHewData --- pipelines/build_pyrenew_model.py | 4 +- pyrenew_hew/pyrenew_hew_data.py | 97 +++++--------------------- pyrenew_hew/pyrenew_wastewater_data.py | 13 +++- tests/test_pyrenew_hew_data.py | 2 + tests/test_pyrenew_wastewater_data.py | 74 ++++++++++++++++++++ 5 files changed, 105 insertions(+), 85 deletions(-) create mode 100644 tests/test_pyrenew_wastewater_data.py diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index 914299ec..565c9b13 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -199,7 +199,6 @@ def build_model_from_dir( wastewater_data = PyrenewWastewaterData( data_observed_disease_wastewater=data_observed_disease_wastewater, population_size=population_size, - pop_fraction=pop_fraction, ) dat = PyrenewHEWData( @@ -210,7 +209,8 @@ def build_model_from_dir( right_truncation_offset=right_truncation_offset, first_ed_visits_date=first_ed_visits_date, first_hospital_admissions_date=first_hospital_admissions_date, - wastewater_data=wastewater_data, + pop_fraction=pop_fraction, + **wastewater_data.to_pyrenew_hew_data_args(), ) return (mod, dat) diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index 9023297c..f8af8caf 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -24,7 +24,6 @@ def __init__( first_ed_visits_date: datetime.date = None, first_hospital_admissions_date: datetime.date = None, first_wastewater_date: datetime.date = None, - wastewater_data: PyrenewWastewaterData = None, n_ww_lab_sites: int = None, ww_censored: ArrayLike = None, ww_uncensored: ArrayLike = None, @@ -32,6 +31,10 @@ def __init__( ww_observed_times: ArrayLike = None, ww_observed_lab_sites: ArrayLike = None, lab_site_to_subpop_map: ArrayLike = None, + ww_log_lod: ArrayLike = None, + date_observed_disease_wastewater: ArrayLike = None, + data_observed_disease_wastewater_conc: ArrayLike = None, + pop_fraction: ArrayLike = jnp.array([1]), ) -> None: self.n_ed_visits_data_days_ = n_ed_visits_data_days self.n_hospital_admissions_data_days_ = n_hospital_admissions_data_days @@ -45,89 +48,20 @@ def __init__( self.first_hospital_admissions_date = first_hospital_admissions_date self.first_wastewater_date_ = first_wastewater_date self.date_observed_disease_wastewater = ( - None - if wastewater_data is None - else wastewater_data.date_observed_disease_wastewater - ) - self.pop_fraction = ( - jnp.array([1]) - if wastewater_data is None - else jnp.array(wastewater_data.pop_fraction) + date_observed_disease_wastewater ) + self.pop_fraction = pop_fraction self.data_observed_disease_wastewater_conc = ( - None - if wastewater_data is None - else wastewater_data.data_observed_disease_wastewater_conc - ) - self.ww_censored = ( - ww_censored - if ww_censored is not None - else ( - None - if wastewater_data is None - else wastewater_data.ww_censored - ) - ) - self.ww_uncensored = ( - ww_uncensored - if ww_uncensored is not None - else ( - None - if wastewater_data is None - else wastewater_data.ww_uncensored - ) - ) - self.ww_observed_times = ( - ww_observed_times - if ww_observed_times is not None - else ( - None - if wastewater_data is None - else wastewater_data.ww_observed_times - ) - ) - self.ww_observed_subpops = ( - ww_observed_subpops - if ww_observed_subpops is not None - else ( - None - if wastewater_data is None - else wastewater_data.ww_observed_subpops - ) - ) - - self.ww_observed_lab_sites = ( - ww_observed_lab_sites - if ww_observed_lab_sites is not None - else ( - None - if wastewater_data is None - else wastewater_data.ww_observed_lab_sites - ) - ) - - self.ww_log_lod = ( - None if wastewater_data is None else wastewater_data.ww_log_lod - ) - self.n_ww_lab_sites = ( - n_ww_lab_sites - if n_ww_lab_sites is not None - else ( - None - if wastewater_data is None - else wastewater_data.n_ww_lab_sites - ) - ) - - self.lab_site_to_subpop_map = ( - lab_site_to_subpop_map - if lab_site_to_subpop_map is not None - else ( - None - if wastewater_data is None - else wastewater_data.lab_site_to_subpop_map - ) + data_observed_disease_wastewater_conc ) + self.ww_censored = ww_censored + self.ww_uncensored = ww_uncensored + self.ww_observed_times = ww_observed_times + self.ww_observed_subpops = ww_observed_subpops + self.ww_observed_lab_sites = ww_observed_lab_sites + self.ww_log_lod = ww_log_lod + self.n_ww_lab_sites = n_ww_lab_sites + self.lab_site_to_subpop_map = lab_site_to_subpop_map @property def n_ed_visits_data_days(self): @@ -286,4 +220,5 @@ def to_forecast_data(self, n_forecast_points: int) -> Self: ww_observed_subpops=self.ww_observed_subpops, ww_observed_times=self.ww_observed_times, lab_site_to_subpop_map=self.lab_site_to_subpop_map, + data_observed_disease_wastewater_conc=None, ) diff --git a/pyrenew_hew/pyrenew_wastewater_data.py b/pyrenew_hew/pyrenew_wastewater_data.py index cc0d027d..5df556d7 100644 --- a/pyrenew_hew/pyrenew_wastewater_data.py +++ b/pyrenew_hew/pyrenew_wastewater_data.py @@ -17,13 +17,11 @@ def __init__( self, data_observed_disease_wastewater: pl.DataFrame = None, population_size: int = None, - pop_fraction: ArrayLike = jnp.array([1]), ) -> None: self.data_observed_disease_wastewater = ( data_observed_disease_wastewater ) self.population_size = population_size - self.pop_fraction = pop_fraction @property def site_subpop_spine(self): @@ -191,3 +189,14 @@ def lab_site_to_subpop_map(self): .get_column("subpop_index") .to_numpy() ) + + def to_pyrenew_hew_data_args(self): + return { + attr: value + for attr, value in ( + (attr, getattr(self, attr)) + for attr, prop in self.__class__.__dict__.items() + if isinstance(prop, property) + ) + if isinstance(value, ArrayLike) + } diff --git a/tests/test_pyrenew_hew_data.py b/tests/test_pyrenew_hew_data.py index 8b22bef5..cebff9fd 100644 --- a/tests/test_pyrenew_hew_data.py +++ b/tests/test_pyrenew_hew_data.py @@ -73,6 +73,7 @@ def test_to_forecast_data( """ Test the to_forecast_data method """ + ww_dat = PyrenewWastewaterData() data = PyrenewHEWData( n_ed_visits_data_days=n_ed_visits_data_days, @@ -80,6 +81,7 @@ def test_to_forecast_data( first_ed_visits_date=first_ed_visits_date, first_hospital_admissions_date=first_hospital_admissions_date, right_truncation_offset=right_truncation_offset, + **ww_dat.to_pyrenew_hew_data_args(), ) assert data.right_truncation_offset == right_truncation_offset diff --git a/tests/test_pyrenew_wastewater_data.py b/tests/test_pyrenew_wastewater_data.py new file mode 100644 index 00000000..58056d8d --- /dev/null +++ b/tests/test_pyrenew_wastewater_data.py @@ -0,0 +1,74 @@ +import datetime + +import numpy as np +import polars as pl + +from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData +from pyrenew_hew.pyrenew_wastewater_data import PyrenewWastewaterData + + +def test_pyrenew_wastewater_data(): + first_training_date = datetime.date(2023, 1, 1) + last_training_date = datetime.date(2023, 7, 23) + dates = pl.date_range( + first_training_date, + last_training_date, + interval="1w", + closed="both", + eager=True, + ) + + ww_raw = pl.DataFrame( + { + "date": dates.extend(dates), + "site": [200] * 30 + [100] * 30, + "lab": [21] * 60, + "lab_site_index": [1] * 30 + [2] * 30, + "site_index": [1] * 30 + [2] * 30, + "log_genome_copies_per_ml": np.log( + np.abs(np.random.normal(loc=500, scale=50, size=60)) + ), + "log_lod": np.log([20] * 30 + [15] * 30), + "site_pop": [200_000] * 30 + [400_000] * 30, + } + ) + + ww_data = ww_raw.with_columns( + (pl.col("log_genome_copies_per_ml") <= pl.col("log_lod")) + .cast(pl.Int8) + .alias("below_lod") + ) + + first_ed_visits_date = datetime.date(2023, 1, 1) + first_hospital_admissions_date = datetime.date(2023, 1, 1) + first_wastewater_date = datetime.date(2023, 1, 1) + n_forecast_points = 10 + + wastewater_data = PyrenewWastewaterData( + data_observed_disease_wastewater=ww_data, + population_size=1e6, + ) + + data = PyrenewHEWData( + first_ed_visits_date=first_ed_visits_date, + first_hospital_admissions_date=first_hospital_admissions_date, + first_wastewater_date=first_wastewater_date, + **wastewater_data.to_pyrenew_hew_data_args(), + ) + + forecast_data = data.to_forecast_data(n_forecast_points) + assert forecast_data.data_observed_disease_wastewater_conc is None + assert data.data_observed_disease_wastewater_conc is not None + + assert np.array_equal( + data.data_observed_disease_wastewater_conc, + ww_data["log_genome_copies_per_ml"], + ) + assert len(data.ww_censored) == len( + ww_data.filter(pl.col("below_lod") == 1) + ) + assert len(data.ww_uncensored) == len( + ww_data.filter(pl.col("below_lod") == 0) + ) + assert np.array_equal(data.ww_log_lod, ww_data["log_lod"]) + assert data.n_ww_lab_sites == ww_data["lab_site_index"].n_unique() From 14ba0bb62b69b31702fb2d7a2c738999cc00eb35 Mon Sep 17 00:00:00 2001 From: Subekshya Bidari <37636707+sbidari@users.noreply.github.com> Date: Thu, 27 Feb 2025 17:25:59 -0500 Subject: [PATCH 29/34] Apply suggestions from code review Co-authored-by: Dylan H. Morris --- pyrenew_hew/pyrenew_hew_model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 97b766f3..3eeb09cb 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -5,7 +5,6 @@ import jax.numpy as jnp import numpyro import numpyro.distributions as dist -import numpyro.distributions.transforms as transforms import pyrenew.transformation as transformation from jax.typing import ArrayLike from numpyro.infer.reparam import LocScaleReparam @@ -121,7 +120,7 @@ def sample(self, n_days_post_init: int): log_rtu_weekly_subpop = log_rtu_weekly[:, jnp.newaxis] else: i_first_obs_over_n_ref_subpop = transformation.SigmoidTransform()( - transforms.logit(i0_first_obs_n) + transformation.SigmoidTransform().inv(i0_first_obs_n) + self.offset_ref_logit_i_first_obs_rv(), ) # Using numpyro.distributions.transform as 'pyrenew.transformation' has no attribute 'logit' initial_exp_growth_rate_ref_subpop = ( @@ -137,7 +136,7 @@ def sample(self, n_days_post_init: int): DistributionalVariable( "i_first_obs_over_n_non_ref_subpop_raw", dist.Normal( - transforms.logit(i0_first_obs_n), + transformation.SigmoidTransform().inv(i0_first_obs_n), self.sigma_i_first_obs_rv(), ), reparam=LocScaleReparam(0), From abfdffb3aed9b512a30ff06d8be3e139e23edfc2 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 27 Feb 2025 17:27:55 -0500 Subject: [PATCH 30/34] coerce population_size to be an integer --- pipelines/build_pyrenew_model.py | 2 +- pyrenew_hew/pyrenew_wastewater_data.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index 565c9b13..9ecc26b9 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -103,7 +103,7 @@ def build_model_from_dir( else None ) - population_size = jnp.array(model_data["state_pop"]) + population_size = jnp.array(model_data["state_pop"]).item() pop_fraction = jnp.array(model_data["pop_fraction"]) diff --git a/pyrenew_hew/pyrenew_wastewater_data.py b/pyrenew_hew/pyrenew_wastewater_data.py index 5df556d7..a7c846bf 100644 --- a/pyrenew_hew/pyrenew_wastewater_data.py +++ b/pyrenew_hew/pyrenew_wastewater_data.py @@ -51,7 +51,7 @@ def site_subpop_spine(self): "site_pop": ( self.population_size - site_indices.get_column("site_pop").sum() - ).tolist(), + ), } ) else: From e63d3c988f948f1f5fd4ed17dd38af4480e8b440 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 27 Feb 2025 17:50:28 -0500 Subject: [PATCH 31/34] Suggestion for how to handle the population calculation (#365) * Adjust to fix test * Update pyrenew_hew/pyrenew_wastewater_data.py * Hack fix * pre-commit --------- Co-authored-by: sbidari --- pyrenew_hew/pyrenew_wastewater_data.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/pyrenew_hew/pyrenew_wastewater_data.py b/pyrenew_hew/pyrenew_wastewater_data.py index a7c846bf..b6d23efc 100644 --- a/pyrenew_hew/pyrenew_wastewater_data.py +++ b/pyrenew_hew/pyrenew_wastewater_data.py @@ -27,15 +27,6 @@ def __init__( def site_subpop_spine(self): ww_data_present = self.data_observed_disease_wastewater is not None if ww_data_present: - add_auxiliary_subpop = ( - self.population_size - > self.data_observed_disease_wastewater.select( - pl.col("site_pop", "site", "lab", "lab_site_index") - ) - .unique() - .get_column("site_pop") - .sum() - ) site_indices = ( self.data_observed_disease_wastewater.select( ["site_index", "site", "site_pop"] @@ -43,15 +34,24 @@ def site_subpop_spine(self): .unique() .sort("site_index") ) + + total_pop_ww = ( + self.data_observed_disease_wastewater.unique( + ["site_pop", "site"] + ) + .get_column("site_pop") + .sum() + ) + + total_pop_no_ww = self.population_size - total_pop_ww + add_auxiliary_subpop = total_pop_no_ww > 0 + if add_auxiliary_subpop: aux_subpop = pl.DataFrame( { "site_index": [None], "site": [None], - "site_pop": ( - self.population_size - - site_indices.get_column("site_pop").sum() - ), + "site_pop": [total_pop_no_ww], } ) else: From 8899552ab05d335a2ed7f3fd4ef8f06316670c07 Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 28 Feb 2025 15:41:52 -0500 Subject: [PATCH 32/34] add missing args to to_forecast_data, code review suggestions --- pyrenew_hew/pyrenew_hew_data.py | 4 +++- pyrenew_hew/pyrenew_wastewater_data.py | 19 ++++++++++++------- tests/test_pyrenew_wastewater_data.py | 15 +++++++++++++++ 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index f8af8caf..91c53078 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -34,7 +34,7 @@ def __init__( ww_log_lod: ArrayLike = None, date_observed_disease_wastewater: ArrayLike = None, data_observed_disease_wastewater_conc: ArrayLike = None, - pop_fraction: ArrayLike = jnp.array([1]), + pop_fraction: ArrayLike = None, ) -> None: self.n_ed_visits_data_days_ = n_ed_visits_data_days self.n_hospital_admissions_data_days_ = n_hospital_admissions_data_days @@ -220,5 +220,7 @@ def to_forecast_data(self, n_forecast_points: int) -> Self: ww_observed_subpops=self.ww_observed_subpops, ww_observed_times=self.ww_observed_times, lab_site_to_subpop_map=self.lab_site_to_subpop_map, + ww_log_lod=self.ww_log_lod, + pop_fraction=self.pop_fraction, data_observed_disease_wastewater_conc=None, ) diff --git a/pyrenew_hew/pyrenew_wastewater_data.py b/pyrenew_hew/pyrenew_wastewater_data.py index b6d23efc..bc0c9471 100644 --- a/pyrenew_hew/pyrenew_wastewater_data.py +++ b/pyrenew_hew/pyrenew_wastewater_data.py @@ -192,11 +192,16 @@ def lab_site_to_subpop_map(self): def to_pyrenew_hew_data_args(self): return { - attr: value - for attr, value in ( - (attr, getattr(self, attr)) - for attr, prop in self.__class__.__dict__.items() - if isinstance(prop, property) - ) - if isinstance(value, ArrayLike) + attr: getattr(self, attr) + for attr in [ + "n_ww_lab_sites", + "ww_censored", + "ww_uncensored", + "ww_log_lod", + "ww_observed_lab_sites", + "ww_observed_subpops", + "ww_observed_times", + "data_observed_disease_wastewater_conc", + "lab_site_to_subpop_map", + ] } diff --git a/tests/test_pyrenew_wastewater_data.py b/tests/test_pyrenew_wastewater_data.py index 58056d8d..0d9f51d9 100644 --- a/tests/test_pyrenew_wastewater_data.py +++ b/tests/test_pyrenew_wastewater_data.py @@ -60,6 +60,21 @@ def test_pyrenew_wastewater_data(): assert forecast_data.data_observed_disease_wastewater_conc is None assert data.data_observed_disease_wastewater_conc is not None + assert np.array_equal(data.ww_censored, forecast_data.ww_censored) + assert np.array_equal(data.ww_uncensored, forecast_data.ww_uncensored) + assert np.array_equal(data.ww_log_lod, forecast_data.ww_log_lod) + assert np.array_equal( + data.ww_observed_lab_sites, forecast_data.ww_observed_lab_sites + ) + assert np.array_equal( + data.ww_observed_subpops, forecast_data.ww_observed_subpops + ) + assert np.array_equal( + data.ww_observed_times, forecast_data.ww_observed_times + ) + assert np.array_equal(data.n_ww_lab_sites, forecast_data.n_ww_lab_sites) + assert np.array_equal(data.pop_fraction, forecast_data.pop_fraction) + assert np.array_equal( data.data_observed_disease_wastewater_conc, ww_data["log_genome_copies_per_ml"], From c7a1dfd11f76edb3acb0303de03877f88b6faf47 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 3 Mar 2025 12:26:42 -0500 Subject: [PATCH 33/34] Update pyrenew_hew/pyrenew_hew_model.py --- pyrenew_hew/pyrenew_hew_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 3eeb09cb..255cb6cc 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -122,7 +122,6 @@ def sample(self, n_days_post_init: int): i_first_obs_over_n_ref_subpop = transformation.SigmoidTransform()( transformation.SigmoidTransform().inv(i0_first_obs_n) + self.offset_ref_logit_i_first_obs_rv(), - ) # Using numpyro.distributions.transform as 'pyrenew.transformation' has no attribute 'logit' initial_exp_growth_rate_ref_subpop = ( initial_exp_growth_rate + self.offset_ref_initial_exp_growth_rate_rv() From 5d52a5e83d989cef5c276cc9983f0ceabe602371 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 3 Mar 2025 12:30:15 -0500 Subject: [PATCH 34/34] Update pyrenew_hew/pyrenew_hew_model.py --- pyrenew_hew/pyrenew_hew_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 255cb6cc..a0b8f520 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -122,6 +122,7 @@ def sample(self, n_days_post_init: int): i_first_obs_over_n_ref_subpop = transformation.SigmoidTransform()( transformation.SigmoidTransform().inv(i0_first_obs_n) + self.offset_ref_logit_i_first_obs_rv(), + ) initial_exp_growth_rate_ref_subpop = ( initial_exp_growth_rate + self.offset_ref_initial_exp_growth_rate_rv()