From 5c4727306dc19cbedfd3955495994dc625b15606 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 19 Feb 2025 17:39:15 -0500 Subject: [PATCH] cleaning --- pipelines/forecast_state.py | 1 - pipelines/prep_data.py | 46 +++++++++++++++------- pipelines/prep_ww_data.py | 3 -- pyrenew_hew/pyrenew_hew_data.py | 38 +++++++++++------- pyrenew_hew/pyrenew_hew_model.py | 67 ++++++++++++++++++++++---------- tests/test_pyrenew_hew_data.py | 13 +++++-- 6 files changed, 114 insertions(+), 54 deletions(-) diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py index 8abef411..34f748f5 100644 --- a/pipelines/forecast_state.py +++ b/pipelines/forecast_state.py @@ -383,7 +383,6 @@ def main( last_training_date=last_training_date, param_estimates=param_estimates, model_run_dir=model_run_dir, - ww_data_dir=ww_data_dir, logger=logger, credentials_dict=credentials_dict, ) diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index 16f76cfa..2ee60a79 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -8,8 +8,6 @@ from pathlib import Path import forecasttools -import jax.numpy as jnp -import jax.numpy as jnp import polars as pl import polars.selectors as cs @@ -48,7 +46,9 @@ def py_scalar_to_r_scalar(py_scalar): state_abb_for_query = state_abb if state_abb != "US" else "USA" temp_file = Path(temp_dir, "nhsn_temp.parquet") - api_key_id = credentials_dict.get("nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID")) + api_key_id = credentials_dict.get( + "nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID") + ) api_key_secret = credentials_dict.get( "nhsn_api_key_secret", os.getenv("NHSN_API_KEY_SECRET") ) @@ -80,7 +80,9 @@ def py_scalar_to_r_scalar(py_scalar): if result.returncode != 0: raise RuntimeError(f"pull_and_save_nhsn: {result.stderr}") raw_dat = pl.read_parquet(temp_file) - dat = raw_dat.with_columns(weekendingdate=pl.col("weekendingdate").cast(pl.Date)) + dat = raw_dat.with_columns( + weekendingdate=pl.col("weekendingdate").cast(pl.Date) + ) return dat @@ -102,7 +104,9 @@ def combine_nssp_and_nhsn( variable_name="drop_me", value_name=".value", ) - .with_columns(pl.col("count_type").replace(count_type_dict).alias(".variable")) + .with_columns( + pl.col("count_type").replace(count_type_dict).alias(".variable") + ) .select(cs.exclude(["count_type", "drop_me"])) ) @@ -182,7 +186,9 @@ def process_state_level_data( if state_abb == "US": locations_to_aggregate = ( - state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique() + state_pop_df.filter(pl.col("abb") != "US") + .get_column("abb") + .unique() ) logger.info("Aggregating state-level data to national") state_level_nssp_data = aggregate_to_national( @@ -209,7 +215,9 @@ def process_state_level_data( ] ) .with_columns( - disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map), + disease=pl.col("disease") + .cast(pl.Utf8) + .replace(_inverse_disease_map), ) .sort(["date", "disease"]) .collect(streaming=True) @@ -241,7 +249,9 @@ def aggregate_facility_level_nssp_to_state( if state_abb == "US": logger.info("Aggregating facility-level data to national") locations_to_aggregate = ( - state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique() + state_pop_df.filter(pl.col("abb") != "US") + .get_column("abb") + .unique() ) facility_level_nssp_data = aggregate_to_national( facility_level_nssp_data, @@ -260,7 +270,9 @@ def aggregate_facility_level_nssp_to_state( .group_by(["reference_date", "disease"]) .agg(pl.col("value").sum().alias("ed_visits")) .with_columns( - disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map), + disease=pl.col("disease") + .cast(pl.Utf8) + .replace(_inverse_disease_map), geo_value=pl.lit(state_abb).cast(pl.Utf8), ) .rename({"reference_date": "date"}) @@ -350,12 +362,16 @@ def process_and_save_state( if facility_level_nssp_data is None and state_level_nssp_data is None: raise ValueError( - "Must provide at least one " "of facility-level and state-level" "NSSP data" + "Must provide at least one " + "of facility-level and state-level" + "NSSP data" ) state_pop_df = get_state_pop_df() - state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item(0, "population") + state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item( + 0, "population" + ) (generation_interval_pmf, delay_pmf, right_truncation_pmf) = get_pmfs( param_estimates=param_estimates, state_abb=state_abb, disease=disease @@ -404,13 +420,17 @@ def process_and_save_state( credentials_dict=credentials_dict, ).with_columns(pl.lit("train").alias("data_type")) - nssp_training_dates = nssp_training_data.get_column("date").unique().to_list() + nssp_training_dates = ( + nssp_training_data.get_column("date").unique().to_list() + ) nhsn_training_dates = ( nhsn_training_data.get_column("weekendingdate").unique().to_list() ) nhsn_first_date_index = next( - i for i, x in enumerate(nssp_training_dates) if x == min(nhsn_training_dates) + i + for i, x in enumerate(nssp_training_dates) + if x == min(nhsn_training_dates) ) nhsn_step_size = 7 diff --git a/pipelines/prep_ww_data.py b/pipelines/prep_ww_data.py index 513143c6..af52d022 100644 --- a/pipelines/prep_ww_data.py +++ b/pipelines/prep_ww_data.py @@ -1,6 +1,3 @@ -import datetime -from pathlib import Path - import polars as pl diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index 002a9cab..ec892848 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -40,7 +40,9 @@ def __init__( self.first_ed_visits_date = first_ed_visits_date self.first_hospital_admissions_date = first_hospital_admissions_date self.first_wastewater_date_ = first_wastewater_date - self.data_observed_disease_wastewater = data_observed_disease_wastewater + self.data_observed_disease_wastewater = ( + data_observed_disease_wastewater + ) self.population_size = population_size self.shedding_offset = shedding_offset self.pop_fraction_ = pop_fraction @@ -132,7 +134,9 @@ def last_data_date_overall(self): @property def n_days_post_init(self): - return (self.last_data_date_overall - self.first_data_date_overall).days + return ( + self.last_data_date_overall - self.first_data_date_overall + ).days @property def site_subpop_spine(self): @@ -173,10 +177,12 @@ def site_subpop_spine(self): site_subpop_spine = ( pl.concat([aux_subpop, site_indices], how="vertical_relaxed") .with_columns( - subpop_index=pl.col("site_index").cum_count().alias("subpop_index"), - subpop_name=pl.format("Site: {}", pl.col("site")).fill_null( - "remainder of population" - ), + subpop_index=pl.col("site_index") + .cum_count() + .alias("subpop_index"), + subpop_name=pl.format( + "Site: {}", pl.col("site") + ).fill_null("remainder of population"), ) .rename({"site_pop": "subpop_pop"}) ) @@ -233,22 +239,24 @@ def pop_fraction(self): @property def data_observed_disease_wastewater_conc(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended["log_genome_copies_per_ml"].to_numpy() + return self.wastewater_data_extended[ + "log_genome_copies_per_ml" + ].to_numpy() @property def ww_censored(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended.filter(pl.col("below_lod") == 1)[ - "ind_rel_to_observed_times" - ].to_numpy() + return self.wastewater_data_extended.filter( + pl.col("below_lod") == 1 + )["ind_rel_to_observed_times"].to_numpy() return None @property def ww_uncensored(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended.filter(pl.col("below_lod") == 0)[ - "ind_rel_to_observed_times" - ].to_numpy() + return self.wastewater_data_extended.filter( + pl.col("below_lod") == 0 + )["ind_rel_to_observed_times"].to_numpy() @property def ww_observed_times(self): @@ -304,7 +312,9 @@ def get_end_date( ) result = None else: - result = first_date + datetime.timedelta(days=n_datapoints * timestep_days) + result = first_date + datetime.timedelta( + days=n_datapoints * timestep_days + ) return result def get_n_data_days( diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index a62151be..da2fc592 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -73,7 +73,9 @@ def __init__( self.autoreg_rt_subpop_rv = autoreg_rt_subpop_rv self.sigma_rt_rv = sigma_rt_rv self.sigma_i_first_obs_rv = sigma_i_first_obs_rv - self.sigma_initial_exp_growth_rate_rv = sigma_initial_exp_growth_rate_rv + self.sigma_initial_exp_growth_rate_rv = ( + sigma_initial_exp_growth_rate_rv + ) self.n_initialization_points = n_initialization_points self.pop_fraction = pop_fraction self.n_subpops = len(pop_fraction) @@ -122,10 +124,13 @@ def sample(self, n_days_post_init: int): + self.offset_ref_logit_i_first_obs_rv(), ) initial_exp_growth_rate_ref_subpop = ( - initial_exp_growth_rate + self.offset_ref_initial_exp_growth_rate_rv() + initial_exp_growth_rate + + self.offset_ref_initial_exp_growth_rate_rv() ) - log_rtu_weekly_ref_subpop = log_rtu_weekly + self.offset_ref_log_rt_rv() + log_rtu_weekly_ref_subpop = ( + log_rtu_weekly + self.offset_ref_log_rt_rv() + ) i_first_obs_over_n_non_ref_subpop_rv = TransformedVariable( "i_first_obs_over_n_non_ref_subpop", DistributionalVariable( @@ -207,7 +212,9 @@ def sample(self, n_days_post_init: int): )[:n_days_post_init, :] ) # indexed rel to first post-init day. - i0_subpop_rv = DeterministicVariable("i0_subpop", i_first_obs_over_n_subpop) + i0_subpop_rv = DeterministicVariable( + "i0_subpop", i_first_obs_over_n_subpop + ) initial_exp_growth_rate_subpop_rv = DeterministicVariable( "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) @@ -312,9 +319,7 @@ def sample( iedr = jnp.repeat( transformation.SigmoidTransform()(p_ed_ar + p_ed_mean), repeats=7, - )[ - :n_datapoints - ] # indexed rel to first ed report day + )[:n_datapoints] # indexed rel to first ed report day # this is only applied after the ed visits are generated, not to all # the latent infections. This is why we cannot apply the iedr in # compute_delay_ascertained_incidence @@ -334,21 +339,28 @@ def sample( )[-n_datapoints:] latent_ed_visits_final = ( - potential_latent_ed_visits * iedr * ed_wday_effect * population_size + potential_latent_ed_visits + * iedr + * ed_wday_effect + * population_size ) if right_truncation_offset is not None: prop_already_reported_tail = jnp.flip( self.ed_right_truncation_cdf_rv()[right_truncation_offset:] ) - n_points_to_prepend = n_datapoints - prop_already_reported_tail.shape[0] + n_points_to_prepend = ( + n_datapoints - prop_already_reported_tail.shape[0] + ) prop_already_reported = jnp.pad( prop_already_reported_tail, (n_points_to_prepend, 0), mode="constant", constant_values=(1, 0), ) - latent_ed_visits_now = latent_ed_visits_final * prop_already_reported + latent_ed_visits_now = ( + latent_ed_visits_final * prop_already_reported + ) else: latent_ed_visits_now = latent_ed_visits_final @@ -374,7 +386,9 @@ def __init__( ihr_rel_iedr_rv: RandomVariable = None, ) -> None: self.inf_to_hosp_admit_rv = inf_to_hosp_admit_rv - self.hosp_admit_neg_bin_concentration_rv = hosp_admit_neg_bin_concentration_rv + self.hosp_admit_neg_bin_concentration_rv = ( + hosp_admit_neg_bin_concentration_rv + ) self.ihr_rv = ihr_rv self.ihr_rel_iedr_rv = ihr_rel_iedr_rv @@ -508,7 +522,10 @@ def normed_shedding_cdf( norm_const = (t_p + t_d) * ((log_base - 1) / jnp.log(log_base) - 1) def ad_pre(x): - return t_p / jnp.log(log_base) * jnp.exp(jnp.log(log_base) * x / t_p) - x + return ( + t_p / jnp.log(log_base) * jnp.exp(jnp.log(log_base) * x / t_p) + - x + ) def ad_post(x): return ( @@ -588,10 +605,12 @@ def sample( def batch_colvolve_fn(m): return jnp.convolve(m, viral_kinetics, mode="valid") - model_net_inf_ind_shedding = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)( - jnp.atleast_2d(latent_infections_subpop) - )[-n_datapoints:, :] - numpyro.deterministic("model_net_inf_ind_shedding", model_net_inf_ind_shedding) + model_net_inf_ind_shedding = jax.vmap( + batch_colvolve_fn, in_axes=1, out_axes=1 + )(jnp.atleast_2d(latent_infections_subpop))[-n_datapoints:, :] + numpyro.deterministic( + "model_net_inf_ind_shedding", model_net_inf_ind_shedding + ) log10_genome_per_inf_ind = self.log10_genome_per_inf_ind_rv() expected_obs_viral_genomes = ( @@ -599,7 +618,9 @@ def batch_colvolve_fn(m): + jnp.log(model_net_inf_ind_shedding + shedding_offset) - jnp.log(self.ww_ml_produced_per_day) ) - numpyro.deterministic("expected_obs_viral_genomes", expected_obs_viral_genomes) + numpyro.deterministic( + "expected_obs_viral_genomes", expected_obs_viral_genomes + ) mode_sigma_ww_site = self.mode_sigma_ww_site_rv() sd_log_sigma_ww_site = self.sd_log_sigma_ww_site_rv() @@ -638,7 +659,11 @@ def batch_colvolve_fn(m): scale=sigma_ww_site[ww_observed_lab_sites[ww_uncensored]], ), ).sample( - obs=(data_observed[ww_uncensored] if data_observed is not None else None), + obs=( + data_observed[ww_uncensored] + if data_observed is not None + else None + ), ) if ww_censored.shape[0] != 0: @@ -695,8 +720,10 @@ def sample( sample_wastewater: bool = False, ) -> dict[str, ArrayLike]: # numpydoc ignore=GL08 n_init_days = self.latent_infection_process_rv.n_initialization_points - latent_infections, latent_infections_subpop = self.latent_infection_process_rv( - n_days_post_init=data.n_days_post_init, + latent_infections, latent_infections_subpop = ( + self.latent_infection_process_rv( + n_days_post_init=data.n_days_post_init, + ) ) first_latent_infection_dow = ( data.first_data_date_overall - datetime.timedelta(days=n_init_days) diff --git a/tests/test_pyrenew_hew_data.py b/tests/test_pyrenew_hew_data.py index e3195f6a..99f42058 100644 --- a/tests/test_pyrenew_hew_data.py +++ b/tests/test_pyrenew_hew_data.py @@ -96,7 +96,10 @@ def test_to_forecast_data( assert forecast_data.n_hospital_admissions_data_days == n_weeks_expected assert forecast_data.right_truncation_offset is None assert forecast_data.first_ed_visits_date == data.first_data_date_overall - assert forecast_data.first_hospital_admissions_date == data.first_data_date_overall + assert ( + forecast_data.first_hospital_admissions_date + == data.first_data_date_overall + ) assert forecast_data.first_wastewater_date == data.first_data_date_overall assert forecast_data.data_observed_disease_wastewater is None @@ -154,7 +157,11 @@ def test_wastewater_data_properties(): data.data_observed_disease_wastewater_conc, ww_data["log_genome_copies_per_ml"], ) - assert len(data.ww_censored) == len(ww_data.filter(pl.col("below_lod") == 1)) - assert len(data.ww_uncensored) == len(ww_data.filter(pl.col("below_lod") == 0)) + assert len(data.ww_censored) == len( + ww_data.filter(pl.col("below_lod") == 1) + ) + assert len(data.ww_uncensored) == len( + ww_data.filter(pl.col("below_lod") == 0) + ) assert jnp.array_equal(data.ww_log_lod, ww_data["log_lod"]) assert data.n_ww_lab_sites == ww_data["lab_site_index"].n_unique()