diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index 951761b2..bf0b31b0 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -4,6 +4,7 @@ from pathlib import Path import jax.numpy as jnp +import polars as pl from pyrenew.deterministic import DeterministicVariable from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData @@ -61,8 +62,14 @@ def build_model_from_dir( else None ) - # placeholder - data_observed_disease_wastewater = None if sample_wastewater else None + data_observed_disease_wastewater = ( + pl.DataFrame( + model_data["data_observed_disease_wastewater"], + schema_overrides={"date": pl.Date}, + ) + if sample_wastewater + else None + ) population_size = jnp.array(model_data["state_pop"]) @@ -148,6 +155,7 @@ def build_model_from_dir( right_truncation_offset=right_truncation_offset, first_ed_visits_date=first_ed_visits_date, first_hospital_admissions_date=first_hospital_admissions_date, + population_size=population_size, ) return (my_model, my_data) diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py index da873c03..b9b5bec2 100644 --- a/pipelines/forecast_state.py +++ b/pipelines/forecast_state.py @@ -22,6 +22,7 @@ from generate_predictive import ( # noqa generate_and_save_predictions, ) +from prep_ww_data import clean_nwss_data, preprocess_ww_data def record_git_info(model_run_dir: Path): @@ -192,6 +193,7 @@ def main( state: str, facility_level_nssp_data_dir: Path | str, state_level_nssp_data_dir: Path | str, + nwss_data_dir: Path | str, param_data_dir: Path | str, priors_path: Path | str, output_dir: Path | str, @@ -333,6 +335,23 @@ def main( "No data available for the requested report date " f"{report_date}" ) + available_nwss_reports = get_available_reports(nwss_data_dir) + # assuming NWSS_vintage directory follows naming convention + # using as of date + # need to be modified otherwise + + if report_date in available_nwss_reports: + nwss_data_raw = pl.scan_parquet( + 
Path(nwss_data_dir, f"{report_date}.parquet") + ) + nwss_data_cleaned = clean_nwss_data(nwss_data_raw).filter( + (pl.col("location") == state) + & (pl.col("date") >= first_training_date) + ) + state_level_nwss_data = preprocess_ww_data(nwss_data_cleaned) + else: + state_level_nwss_data = None ## TO DO: change + param_estimates = pl.scan_parquet(Path(param_data_dir, "prod.parquet")) model_batch_dir_name = ( f"{disease.lower()}_r_{report_date}_f_" @@ -357,6 +376,7 @@ def main( disease=disease, facility_level_nssp_data=facility_level_nssp_data, state_level_nssp_data=state_level_nssp_data, + state_level_nwss_data=state_level_nwss_data, report_date=report_date, first_training_date=first_training_date, last_training_date=last_training_date, @@ -497,6 +517,13 @@ def main( ), ) + parser.add_argument( + "--nwss-data-dir", + type=Path, + default=Path("private_data", "nwss_vintages"), + help=("Directory in which to look for NWSS data."), + ) + parser.add_argument( "--param-data-dir", type=Path, diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index 3b0a087c..14e657f9 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -8,6 +8,7 @@ from pathlib import Path import forecasttools +import jax.numpy as jnp import polars as pl import polars.selectors as cs @@ -354,6 +355,7 @@ def process_and_save_state( logger: Logger = None, facility_level_nssp_data: pl.LazyFrame = None, state_level_nssp_data: pl.LazyFrame = None, + state_level_nwss_data: pl.LazyFrame = None, credentials_dict: dict = None, ) -> None: logging.basicConfig(level=logging.INFO) @@ -449,6 +451,12 @@ def process_and_save_state( "hospital_admissions" ).to_list() + data_observed_disease_wastewater = ( + state_level_nwss_data.to_dict(as_series=False) + if state_level_nwss_data is not None + else None + ) + data_for_model_fit = { "inf_to_ed_pmf": delay_pmf, "generation_interval_pmf": generation_interval_pmf, @@ -462,7 +470,9 @@ def process_and_save_state( "nhsn_step_size": nhsn_step_size, 
"state_pop": state_pop, "right_truncation_offset": right_truncation_offset, + "data_observed_disease_wastewater": data_observed_disease_wastewater, } + data_dir = Path(model_run_dir, "data") os.makedirs(data_dir, exist_ok=True) diff --git a/pipelines/prep_ww_data.py b/pipelines/prep_ww_data.py new file mode 100644 index 00000000..513143c6 --- /dev/null +++ b/pipelines/prep_ww_data.py @@ -0,0 +1,288 @@ +import datetime +from pathlib import Path + +import polars as pl + + +def clean_nwss_data(nwss_data): + """ + Parameters + ---------- + nwss_data: + vintaged/pulled nwss data + + Return + ------ + A site-lab level dataset, filtered to only the columns we use + for model fitting + """ + nwss_subset = ( + nwss_data.filter( + pl.col("sample_location") == "wwtp", + pl.col("sample_matrix") != "primary sludge", + pl.col("pcr_target_units") != "copies/g dry sludge", + pl.col("pcr_target") == "sars-cov-2", + pl.col("lab_id").is_not_null(), + pl.col("wwtp_id").is_not_null(), + pl.col("lod_sewage").is_not_null(), + ) + .select( + [ + "lab_id", + "sample_collect_date", + "wwtp_id", + "pcr_target_avg_conc", + "wwtp_jurisdiction", + "population_served", + "pcr_target_units", + "lod_sewage", + "quality_flag", + ] + ) + .with_columns( + pcr_target_avg_conc=pl.when( + pl.col("pcr_target_units") == "copies/l wastewater" + ) + .then(pl.col("pcr_target_avg_conc") / 1000) + .when(pl.col("pcr_target_units") == "log10 copies/l wastewater") + .then((10 ** pl.col("pcr_target_avg_conc")) / 1000) + .otherwise(None), + lod_sewage=pl.when( + pl.col("pcr_target_units") == "copies/l wastewater" + ) + .then(pl.col("lod_sewage") / 1000) + .when(pl.col("pcr_target_units") == "log10 copies/l wastewater") + .then((10 ** pl.col("lod_sewage")) / 1000) + .otherwise(None), + ) + .filter( + ( + ~pl.col("quality_flag").is_in( + [ + "yes", + "y", + "result is not quantifiable", + "temperature not assessed upon arrival at the laboratory", + "> max temp and/or hold time", + ] + ) + ) + | 
(pl.col("quality_flag").is_null()) + ) + ).drop(["quality_flag", "pcr_target_units"]) + + # Remove if any exact duplicates of pcr_target_avg_conc + # values present for each combination of wwtp_id, lab_id, + # and sample_collect_date + nwss_subset_clean = nwss_subset.unique( + subset=[ + "sample_collect_date", + "wwtp_id", + "lab_id", + "pcr_target_avg_conc", + ] + ) + + # replaces time-varying population if present in the NWSS dataset. + # Model does not allow time varying population + nwss_subset_clean_pop = ( + nwss_subset_clean.group_by("wwtp_id") + .agg( + [ + pl.col("population_served") + .mean() + .round() + .cast(pl.Int64) + .alias("population_served") + ] + ) + .join(nwss_subset_clean, on=["wwtp_id"], how="left") + .select( + [ + "sample_collect_date", + "wwtp_id", + "lab_id", + "pcr_target_avg_conc", + "wwtp_jurisdiction", + "lod_sewage", + "population_served", + ] + ) + .unique( + [ + "wwtp_id", + "lab_id", + "sample_collect_date", + "pcr_target_avg_conc", + ] + ) + ) + + ww_data = ( + nwss_subset_clean_pop.rename( + { + "sample_collect_date": "date", + "population_served": "site_pop", + "wwtp_jurisdiction": "location", + "wwtp_id": "site", + "lab_id": "lab", + } + ) + .with_columns( + [ + pl.col("pcr_target_avg_conc") + .log() + .alias("log_genome_copies_per_ml"), + pl.col("lod_sewage").log().alias("log_lod"), + pl.col("location").str.to_uppercase().alias("location"), + pl.col("site").cast(pl.String).alias("site"), + pl.col("lab").cast(pl.String).alias("lab"), + ] + ) + .select( + [ + "date", + "site", + "lab", + "log_genome_copies_per_ml", + "log_lod", + "site_pop", + "location", + ] + ) + ) + return ww_data + + +def check_missing_values(df: pl.DataFrame, columns: list[str]): + """Raises an error if missing values in a given column(s).""" + missing_cols = [col for col in columns if df[col].has_nulls()] + if missing_cols: + raise ValueError(f"Missing values in column(s): {missing_cols}") + + +def validate_ww_conc_data( + ww_data: pl.DataFrame, + 
conc_col_name: str = "log_genome_copies_per_ml", + lod_col_name: str = "log_lod", + date_col_name: str = "date", + wwtp_col_name: str = "site", + wwtp_pop_name: str = "site_pop", + lab_col_name: str = "lab", +): + """ + Checks nwss data for missing values and data types. + """ + if ww_data.is_empty(): + raise ValueError("Input DataFrame 'ww_data' is empty.") + + required_cols = [ + conc_col_name, + lod_col_name, + date_col_name, + wwtp_col_name, + wwtp_pop_name, + lab_col_name, + ] + + assert all( + col in ww_data.columns for col in required_cols + ), "One or more required column(s) missing" + + check_missing_values( + ww_data, + required_cols, + ) + + assert ww_data[conc_col_name].dtype.is_float() + assert ww_data[lod_col_name].dtype.is_float() + assert ww_data[date_col_name].dtype == pl.Date + assert ww_data[wwtp_pop_name].dtype.is_integer() + assert ww_data[wwtp_col_name].dtype == pl.String() + assert ww_data[lab_col_name].dtype == pl.String() + + if (ww_data[wwtp_pop_name] < 0).any(): + raise ValueError("Site populations have negative values.") + + if ( + not ww_data.group_by(wwtp_col_name) + .n_unique() + .get_column(wwtp_pop_name) + .eq(1) + .all() + ): + raise ValueError( + "The data contains sites with varying population sizes." + ) + + return None + + +def preprocess_ww_data( + ww_data, + conc_col_name: str = "log_genome_copies_per_ml", + lod_col_name: str = "log_lod", + date_col_name: str = "date", + wwtp_col_name: str = "site", + wwtp_pop_name: str = "site_pop", + lab_col_name: str = "lab", +): + """ + Creates indices for wastewater-treatment plant names and + flags concentration data below the limit of detection. 
+ + """ + validate_ww_conc_data( + ww_data, + conc_col_name=conc_col_name, + lod_col_name=lod_col_name, + date_col_name=date_col_name, + ) + ww_data_ordered = ww_data.sort(by=wwtp_pop_name, descending=True) + lab_site_df = ( + ww_data_ordered.select([lab_col_name, wwtp_col_name]) + .unique() + .with_row_index("lab_site_index") + ) + site_df = ( + ww_data_ordered.select([wwtp_col_name]) + .unique() + .with_row_index("site_index") + ) + ww_preprocessed = ( + ww_data_ordered.join( + lab_site_df, on=[lab_col_name, wwtp_col_name], how="left" + ) + .join(site_df, on=wwtp_col_name, how="left") + .rename( + { + lod_col_name: "log_lod", + conc_col_name: "log_genome_copies_per_ml", + } + ) + .with_columns( + lab_site_name=( + "Site: " + + pl.col(wwtp_col_name) + + ", Lab: " + + pl.col(lab_col_name) + ), + below_lod=( + pl.col("log_genome_copies_per_ml") <= pl.col("log_lod") + ), + ) + .select( + [ + "date", + "site", + "lab", + "site_pop", + "site_index", + "lab_site_index", + "log_genome_copies_per_ml", + "log_lod", + "below_lod", + ] + ) + ) + return ww_preprocessed diff --git a/pipelines/tests/__init__.py b/pipelines/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pipelines/tests/test_build_pyrenew_model.py b/pipelines/tests/test_build_pyrenew_model.py index e8b5cdf1..648e31f8 100644 --- a/pipelines/tests/test_build_pyrenew_model.py +++ b/pipelines/tests/test_build_pyrenew_model.py @@ -1,7 +1,7 @@ import json -from pathlib import Path import jax.numpy as jnp +import polars as pl import pytest from pipelines.build_pyrenew_model import build_model_from_dir @@ -20,6 +20,22 @@ def mock_data(): "nssp_training_dates": ["2025-01-01"], "nhsn_training_dates": ["2025-01-02"], "right_truncation_offset": 10, + "data_observed_disease_wastewater": { + "date": [ + "2025-01-01", + "2025-01-01", + "2025-01-02", + "2025-01-02", + ], + "site": ["1.0", "1.0", "2.0", "2.0"], + "lab": ["1.0", "1.0", "1.0", "1.0"], + "site_pop": [4000000, 4000000, 2000000, 2000000], 
+ "site_index": [1, 1, 0, 0], + "lab_site_index": [1, 1, 0, 0], + "log_genome_copies_per_ml": [0.1, 0.1, 0.5, 0.4], + "log_lod": [1.1, 2.0, 1.5, 2.1], + "below_lod": [False, False, False, False], + }, } ) @@ -86,5 +102,7 @@ def test_build_model_from_dir(tmp_path, mock_data, mock_priors): data.data_observed_disease_hospital_admissions, jnp.array(model_data["data_observed_disease_hospital_admissions"]), ) - assert data.data_observed_disease_wastewater is None - ## Update this if wastewater data is added later + assert pl.DataFrame( + model_data["data_observed_disease_wastewater"], + schema_overrides={"date": pl.Date}, + ).equals(data.data_observed_disease_wastewater) diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index 4c437d83..99de33f4 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -1,6 +1,8 @@ import datetime from typing import Self +import jax.numpy as jnp +import polars as pl from jax.typing import ArrayLike @@ -12,63 +14,75 @@ class PyrenewHEWData: def __init__( self, - n_ed_visits_datapoints: int = None, - n_hospital_admissions_datapoints: int = None, - n_wastewater_datapoints: int = None, + n_ed_visits_data_days: int = None, + n_hospital_admissions_data_days: int = None, + n_wastewater_data_days: int = None, data_observed_disease_ed_visits: ArrayLike = None, data_observed_disease_hospital_admissions: ArrayLike = None, - data_observed_disease_wastewater: ArrayLike = None, + data_observed_disease_wastewater: pl.DataFrame = None, right_truncation_offset: int = None, - first_ed_visits_date: datetime.datetime.date = None, - first_hospital_admissions_date: datetime.datetime.date = None, - first_wastewater_date: datetime.datetime.date = None, + first_ed_visits_date: datetime.date = None, + first_hospital_admissions_date: datetime.date = None, + first_wastewater_date: datetime.date = None, + population_size: int = None, + shedding_offset: float = 1e-8, + pop_fraction: ArrayLike = jnp.array([1]), ) 
-> None: - self.n_ed_visits_datapoints_ = n_ed_visits_datapoints - self.n_hospital_admissions_datapoints_ = ( - n_hospital_admissions_datapoints - ) - self.n_wastewater_datapoints_ = n_wastewater_datapoints + self.n_ed_visits_data_days_ = n_ed_visits_data_days + self.n_hospital_admissions_data_days_ = n_hospital_admissions_data_days + self.n_wastewater_data_days_ = n_wastewater_data_days self.data_observed_disease_ed_visits = data_observed_disease_ed_visits self.data_observed_disease_hospital_admissions = ( data_observed_disease_hospital_admissions ) - self.data_observed_disease_wastewater = ( - data_observed_disease_wastewater - ) - self.right_truncation_offset = right_truncation_offset - self.first_ed_visits_date = first_ed_visits_date self.first_hospital_admissions_date = first_hospital_admissions_date - self.first_wastewater_date = first_wastewater_date + self.first_wastewater_date_ = first_wastewater_date + self.data_observed_disease_wastewater = ( + data_observed_disease_wastewater + ) + self.population_size = population_size + self.shedding_offset = shedding_offset + self.pop_fraction_ = pop_fraction @property - def n_ed_visits_datapoints(self): - return self.get_n_datapoints( - n_datapoints=self.n_ed_visits_datapoints_, + def n_ed_visits_data_days(self): + return self.get_n_data_days( + n_datapoints=self.n_ed_visits_data_days_, data_array=self.data_observed_disease_ed_visits, ) @property - def n_hospital_admissions_datapoints(self): - return self.get_n_datapoints( - n_datapoints=self.n_hospital_admissions_datapoints_, + def n_hospital_admissions_data_days(self): + return self.get_n_data_days( + n_datapoints=self.n_hospital_admissions_data_days_, data_array=self.data_observed_disease_hospital_admissions, ) @property - def n_wastewater_datapoints(self): - return self.get_n_datapoints( - n_datapoints=self.n_wastewater_datapoints_, - data_array=self.data_observed_disease_wastewater, + def n_wastewater_data_days(self): + return self.get_n_wastewater_data_days( + 
n_datapoints=self.n_wastewater_data_days_, + date_array=( + None + if self.data_observed_disease_wastewater is None + else self.data_observed_disease_wastewater["date"] + ), ) + @property + def first_wastewater_date(self): + if self.data_observed_disease_wastewater is not None: + return self.data_observed_disease_wastewater["date"].min() + return self.first_wastewater_date_ + @property def last_ed_visits_date(self): return self.get_end_date( self.first_ed_visits_date, - self.n_ed_visits_datapoints, + self.n_ed_visits_data_days, timestep_days=1, ) @@ -76,7 +90,7 @@ def last_ed_visits_date(self): def last_hospital_admissions_date(self): return self.get_end_date( self.first_hospital_admissions_date, - self.n_hospital_admissions_datapoints, + self.n_hospital_admissions_data_days, timestep_days=7, ) @@ -84,7 +98,7 @@ def last_hospital_admissions_date(self): def last_wastewater_date(self): return self.get_end_date( self.first_wastewater_date, - self.n_wastewater_datapoints, + self.n_wastewater_data_days, timestep_days=1, ) @@ -118,12 +132,166 @@ def n_days_post_init(self): self.last_data_date_overall - self.first_data_date_overall ).days + @property + def site_subpop_spine(self): + ww_data_present = self.data_observed_disease_wastewater is not None + if ww_data_present: + # Check if auxiliary subpopulation needs to be added + add_auxiliary_subpop = ( + self.population_size + > self.data_observed_disease_wastewater.select( + pl.col("site_pop", "site", "lab", "lab_site_index") + ) + .unique() + .get_column("site_pop") + .sum() + ) + site_indices = ( + self.data_observed_disease_wastewater.select( + ["site_index", "site", "site_pop"] + ) + .unique() + .sort("site_index") + ) + if add_auxiliary_subpop: + aux_subpop = pl.DataFrame( + { + "site_index": [None], + "site": [None], + "site_pop": [ + self.population_size + - site_indices.select(pl.col("site_pop")) + .get_column("site_pop") + .sum() + ], + } + ) + else: + aux_subpop = pl.DataFrame() + site_subpop_spine = ( + 
pl.concat([aux_subpop, site_indices], how="vertical_relaxed") + .with_columns( + subpop_index=pl.col("site_index") + .cum_count() + .alias("subpop_index"), + subpop_name=pl.format( + "Site: {}", pl.col("site") + ).fill_null("remainder of population"), + ) + .rename({"site_pop": "subpop_pop"}) + ) + else: + site_subpop_spine = pl.DataFrame( + { + "site_index": [None], + "site": [None], + "subpop_pop": [self.population_size], + "subpop_index": [1], + "subpop_name": ["total population"], + } + ) + return site_subpop_spine + + @property + def date_time_spine(self): + if self.data_observed_disease_wastewater is not None: + date_time_spine = pl.DataFrame( + { + "date": pl.date_range( + start=self.first_wastewater_date, + end=self.last_wastewater_date, + interval="1d", + eager=True, + ) + } + ).with_row_index("t") + return date_time_spine + + @property + def wastewater_data_extended(self): + if self.data_observed_disease_wastewater is not None: + return ( + self.data_observed_disease_wastewater.join( + self.date_time_spine, on="date", how="left", coalesce=True + ) + .join( + self.site_subpop_spine, + on=["site_index", "site"], + how="left", + coalesce=True, + ) + .with_row_index("ind_rel_to_observed_times") + ) + + @property + def pop_fraction(self): + if self.data_observed_disease_wastewater is not None: + subpop_sizes = self.site_subpop_spine["subpop_pop"].to_numpy() + return subpop_sizes / self.population_size + return self.pop_fraction_ + + @property + def data_observed_disease_wastewater_conc(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended[ + "log_genome_copies_per_ml" + ].to_numpy() + + @property + def ww_censored(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended.filter( + pl.col("below_lod") == 1 + )["ind_rel_to_observed_times"].to_numpy() + return None + + @property + def ww_uncensored(self): + if self.data_observed_disease_wastewater is not None: + return 
self.wastewater_data_extended.filter( + pl.col("below_lod") == 0 + )["ind_rel_to_observed_times"].to_numpy() + + @property + def ww_observed_times(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended["t"].to_numpy() + + @property + def ww_observed_subpops(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended["subpop_index"].to_numpy() + + @property + def ww_observed_lab_sites(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended["lab_site_index"].to_numpy() + + @property + def ww_log_lod(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended["log_lod"].to_numpy() + + @property + def n_ww_lab_sites(self): + if self.data_observed_disease_wastewater is not None: + return self.wastewater_data_extended["lab_site_index"].n_unique() + + @property + def lab_site_to_subpop_map(self): + if self.data_observed_disease_wastewater is not None: + return ( + self.wastewater_data_extended["lab_site_index", "subpop_index"] + .unique() + .sort(by="lab_site_index") + )["subpop_index"].to_numpy() + def get_end_date( self, - first_date: datetime.datetime.date, + first_date: datetime.date, n_datapoints: int, timestep_days: int = 1, - ) -> datetime.datetime.date: + ) -> datetime.date: """ Get end date from a first date and a number of datapoints, with handling of None values and non-daily timeseries @@ -143,7 +311,7 @@ def get_end_date( ) return result - def get_n_datapoints( + def get_n_data_days( self, n_datapoints: int = None, data_array: ArrayLike = None ) -> int: if n_datapoints is None and data_array is None: @@ -159,16 +327,33 @@ def get_n_datapoints( else: return n_datapoints + def get_n_wastewater_data_days( + self, n_datapoints: int = None, date_array: ArrayLike = None + ) -> int: + if n_datapoints is None and date_array is None: + return 0 + elif date_array is not None and n_datapoints is 
not None: + raise ValueError( + "Must provide at most one out of a " + "number of datapoints to simulate and " + "an array of dates wastewater data is " + "observed." + ) + elif date_array is not None: + return (max(date_array) - min(date_array)).days + else: + return n_datapoints + def to_forecast_data(self, n_forecast_points: int) -> Self: n_days = self.n_days_post_init + n_forecast_points n_weeks = n_days // 7 return PyrenewHEWData( - n_ed_visits_datapoints=n_days, - n_hospital_admissions_datapoints=n_weeks, - n_wastewater_datapoints=n_days, + n_ed_visits_data_days=n_days, + n_hospital_admissions_data_days=n_weeks, + n_wastewater_data_days=n_days, first_ed_visits_date=self.first_data_date_overall, first_hospital_admissions_date=(self.first_data_date_overall), first_wastewater_date=self.first_data_date_overall, - right_truncation_offset=None, - # by default, want forecasts of complete reports + pop_fraction=self.pop_fraction, + right_truncation_offset=None, # by default, want forecasts of complete reports ) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index c52ecfaf..da2fc592 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -589,9 +589,9 @@ def sample( n_datapoints: int, ww_uncensored: ArrayLike, ww_censored: ArrayLike, - ww_sampled_lab_sites: ArrayLike, - ww_sampled_subpops: ArrayLike, - ww_sampled_times: ArrayLike, + ww_observed_lab_sites: ArrayLike, + ww_observed_subpops: ArrayLike, + ww_observed_times: ArrayLike, ww_log_lod: ArrayLike, lab_site_to_subpop_map: ArrayLike, n_ww_lab_sites: int, @@ -648,15 +648,15 @@ def batch_colvolve_fn(m): # multiply the expected observed genomes by the site-specific multiplier at that sampling time expected_obs_log_v_site = ( - expected_obs_viral_genomes[ww_sampled_times, ww_sampled_subpops] - + mode_ww_site[ww_sampled_lab_sites] + expected_obs_viral_genomes[ww_observed_times, ww_observed_subpops] + + mode_ww_site[ww_observed_lab_sites] ) 
DistributionalVariable( "log_conc_obs", dist.Normal( loc=expected_obs_log_v_site[ww_uncensored], - scale=sigma_ww_site[ww_sampled_lab_sites[ww_uncensored]], + scale=sigma_ww_site[ww_observed_lab_sites[ww_uncensored]], ), ).sample( obs=( @@ -669,7 +669,7 @@ def batch_colvolve_fn(m): if ww_censored.shape[0] != 0: log_cdf_values = dist.Normal( loc=expected_obs_log_v_site[ww_censored], - scale=sigma_ww_site[ww_sampled_lab_sites[ww_censored]], + scale=sigma_ww_site[ww_observed_lab_sites[ww_censored]], ).log_cdf(ww_log_lod[ww_censored]) numpyro.factor("log_prob_censored", log_cdf_values.sum()) @@ -741,7 +741,7 @@ def sample( latent_infections=latent_infections, population_size=self.population_size, data_observed=data.data_observed_disease_ed_visits, - n_datapoints=data.n_ed_visits_datapoints, + n_datapoints=data.n_ed_visits_data_days, right_truncation_offset=data.right_truncation_offset, ) @@ -750,7 +750,7 @@ def sample( latent_infections=latent_infections, first_latent_infection_dow=first_latent_infection_dow, population_size=self.population_size, - n_datapoints=data.n_hospital_admissions_datapoints, + n_datapoints=data.n_hospital_admissions_data_days, data_observed=(data.data_observed_disease_hospital_admissions), iedr=iedr, ) @@ -761,16 +761,16 @@ def sample( ) = self.wastewater_obs_process_rv( latent_infections=latent_infections, latent_infections_subpop=latent_infections_subpop, - data_observed=data.data_observed_disease_wastewater, - n_datapoints=data.n_wastewater_datapoints, - ww_uncensored=None, # placeholder - ww_censored=None, # placeholder - ww_sampled_lab_sites=None, # placeholder - ww_sampled_subpops=None, # placeholder - ww_sampled_times=None, # placeholder - ww_log_lod=None, # placeholder - lab_site_to_subpop_map=None, # placeholder - n_ww_lab_sites=None, # placeholder + data_observed=data.data_observed_disease_wastewater_conc, + n_datapoints=data.n_wastewater_data_days, + ww_uncensored=data.ww_uncensored, + ww_censored=data.ww_censored, + 
ww_observed_lab_sites=data.ww_observed_lab_sites, + ww_observed_subpops=data.ww_observed_subpops, + ww_observed_times=data.ww_observed_times, + ww_log_lod=data.ww_log_lod, + lab_site_to_subpop_map=data.lab_site_to_subpop_map, + n_ww_lab_sites=data.n_ww_lab_sites, shedding_offset=1e-8, ) diff --git a/tests/test_pyrenew_hew_data.py b/tests/test_pyrenew_hew_data.py index 90ebbfac..99f42058 100644 --- a/tests/test_pyrenew_hew_data.py +++ b/tests/test_pyrenew_hew_data.py @@ -1,5 +1,8 @@ -from datetime import datetime +import datetime +import jax.numpy as jnp +import numpy as np +import polars as pl import pytest from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData @@ -7,9 +10,9 @@ @pytest.mark.parametrize( [ - "n_ed_visits_datapoints", - "n_hospital_admissions_datapoints", - "n_wastewater_datapoints", + "n_ed_visits_data_days", + "n_hospital_admissions_data_days", + "n_wastewater_data_days", "right_truncation_offset", "first_ed_visits_date", "first_hospital_admissions_date", @@ -22,9 +25,9 @@ 0, 0, 5, - datetime(2023, 1, 1), - datetime(2022, 2, 5), - datetime(2025, 12, 5), + datetime.date(2023, 1, 1), + datetime.date(2022, 2, 5), + datetime.date(2025, 12, 5), 10, ], [ @@ -32,9 +35,9 @@ 325, 2, 5, - datetime(2025, 1, 1), - datetime(2023, 5, 25), - datetime(2022, 4, 5), + datetime.date(2025, 1, 1), + datetime.date(2023, 5, 25), + datetime.date(2022, 4, 5), 10, ], [ @@ -42,9 +45,9 @@ 0, 2, 3, - datetime(2025, 1, 1), - datetime(2025, 2, 5), - datetime(2024, 12, 5), + datetime.date(2025, 1, 1), + datetime.date(2025, 2, 5), + datetime.date(2024, 12, 5), 30, ], [ @@ -52,17 +55,17 @@ 0, 23, 3, - datetime(2025, 1, 1), - datetime(2025, 2, 5), - datetime(2024, 12, 5), + datetime.date(2025, 1, 1), + datetime.date(2025, 2, 5), + datetime.date(2024, 12, 5), 30, ], ], ) def test_to_forecast_data( - n_ed_visits_datapoints: int, - n_hospital_admissions_datapoints: int, - n_wastewater_datapoints: int, + n_ed_visits_data_days: int, + n_hospital_admissions_data_days: int, + 
n_wastewater_data_days: int, right_truncation_offset: int, first_ed_visits_date: datetime.date, first_hospital_admissions_date: datetime.date, @@ -73,9 +76,9 @@ def test_to_forecast_data( Test the to_forecast_data method """ data = PyrenewHEWData( - n_ed_visits_datapoints=n_ed_visits_datapoints, - n_hospital_admissions_datapoints=n_hospital_admissions_datapoints, - n_wastewater_datapoints=n_wastewater_datapoints, + n_ed_visits_data_days=n_ed_visits_data_days, + n_hospital_admissions_data_days=n_hospital_admissions_data_days, + n_wastewater_data_days=n_wastewater_data_days, first_ed_visits_date=first_ed_visits_date, first_hospital_admissions_date=first_hospital_admissions_date, first_wastewater_date=first_wastewater_date, @@ -88,9 +91,9 @@ def test_to_forecast_data( forecast_data = data.to_forecast_data(n_forecast_points) n_days_expected = data.n_days_post_init + n_forecast_points n_weeks_expected = n_days_expected // 7 - assert forecast_data.n_ed_visits_datapoints == n_days_expected - assert forecast_data.n_wastewater_datapoints == n_days_expected - assert forecast_data.n_hospital_admissions_datapoints == n_weeks_expected + assert forecast_data.n_ed_visits_data_days == n_days_expected + assert forecast_data.n_wastewater_data_days == n_days_expected + assert forecast_data.n_hospital_admissions_data_days == n_weeks_expected assert forecast_data.right_truncation_offset is None assert forecast_data.first_ed_visits_date == data.first_data_date_overall assert ( @@ -98,3 +101,67 @@ def test_to_forecast_data( == data.first_data_date_overall ) assert forecast_data.first_wastewater_date == data.first_data_date_overall + assert forecast_data.data_observed_disease_wastewater is None + + +def test_wastewater_data_properties(): + first_training_date = datetime.date(2023, 1, 1) + last_training_date = datetime.date(2023, 7, 23) + dates = pl.date_range( + first_training_date, + last_training_date, + interval="1w", + closed="both", + eager=True, + ) + + ww_raw = pl.DataFrame( + { + 
"date": dates.extend(dates), + "site": [200] * 30 + [100] * 30, + "lab": [21] * 60, + "lab_site_index": [1] * 30 + [2] * 30, + "site_index": [1] * 30 + [2] * 30, + "log_genome_copies_per_ml": np.log( + np.abs(np.random.normal(loc=500, scale=50, size=60)) + ), + "log_lod": np.log([20] * 30 + [15] * 30), + "site_pop": [200_000] * 30 + [400_000] * 30, + } + ) + + ww_data = ww_raw.with_columns( + (pl.col("log_genome_copies_per_ml") <= pl.col("log_lod")) + .cast(pl.Int8) + .alias("below_lod") + ) + + first_ed_visits_date = datetime.date(2023, 1, 1) + first_hospital_admissions_date = datetime.date(2023, 1, 1) + first_wastewater_date = datetime.date(2023, 1, 1) + n_forecast_points = 10 + + data = PyrenewHEWData( + first_ed_visits_date=first_ed_visits_date, + first_hospital_admissions_date=first_hospital_admissions_date, + first_wastewater_date=first_wastewater_date, + data_observed_disease_wastewater=ww_data, + population_size=1e6, + ) + + forecast_data = data.to_forecast_data(n_forecast_points) + assert forecast_data.data_observed_disease_wastewater_conc is None + assert data.data_observed_disease_wastewater_conc is not None + + assert jnp.array_equal( + data.data_observed_disease_wastewater_conc, + ww_data["log_genome_copies_per_ml"], + ) + assert len(data.ww_censored) == len( + ww_data.filter(pl.col("below_lod") == 1) + ) + assert len(data.ww_uncensored) == len( + ww_data.filter(pl.col("below_lod") == 0) + ) + assert jnp.array_equal(data.ww_log_lod, ww_data["log_lod"]) + assert data.n_ww_lab_sites == ww_data["lab_site_index"].n_unique()