Skip to content

Commit

Permalink
Add wastewater data to pyrenew_hew_data (#317)
Browse files Browse the repository at this point in the history
* add wastewater data prep code

* pre-commit

* doctring edit

* remove some functions

* update prep_data.py

* pre-commit

* add wastewater data in json file

* more refactor

* remove log offset and change order of some operation for clarity

* move processing of wastewater data to pyrenew-hew-data

* update tests

* restore removed functions

* syn to use datetime.date for dates throughout, rename ww_sampled to ww_observed

* add a test

* fix test

* n_datapoints -> n_data_days, move get_spines function to pyrenewHEWData

* Apply suggestions from code review

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>

* pre-commit

* update clean_and_filter_nwss.py

* make columns name configurable in validate_ww_data and preprocess_ww_data

* drop wastewater data with no LOD reported

* DRY-ify validation checks

* code review suggestions

* remove redundancy and clean up

* add init file

* Apply suggestions from code review

Co-authored-by: Damon Bayer <xum8@cdc.gov>

* code review suggestions

* error msg

* fix test

* add schema for ww dataframe, add placeholder path

* code review suggestion

* code review suggestions

* fix test

* fix end to end test fail

---------

Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>
Co-authored-by: Damon Bayer <xum8@cdc.gov>
  • Loading branch information
3 people authored Feb 19, 2025
1 parent 026f1eb commit 6fac159
Show file tree
Hide file tree
Showing 9 changed files with 691 additions and 88 deletions.
12 changes: 10 additions & 2 deletions pipelines/build_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path

import jax.numpy as jnp
import polars as pl
from pyrenew.deterministic import DeterministicVariable

from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData
Expand Down Expand Up @@ -61,8 +62,14 @@ def build_model_from_dir(
else None
)

# placeholder
data_observed_disease_wastewater = None if sample_wastewater else None
data_observed_disease_wastewater = (
pl.DataFrame(
model_data["data_observed_disease_wastewater"],
schema_overrides={"date": pl.Date},
)
if sample_wastewater
else None
)

population_size = jnp.array(model_data["state_pop"])

Expand Down Expand Up @@ -148,6 +155,7 @@ def build_model_from_dir(
right_truncation_offset=right_truncation_offset,
first_ed_visits_date=first_ed_visits_date,
first_hospital_admissions_date=first_hospital_admissions_date,
population_size=population_size,
)

return (my_model, my_data)
27 changes: 27 additions & 0 deletions pipelines/forecast_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from generate_predictive import ( # noqa
generate_and_save_predictions,
)
from prep_ww_data import clean_nwss_data, preprocess_ww_data


def record_git_info(model_run_dir: Path):
Expand Down Expand Up @@ -192,6 +193,7 @@ def main(
state: str,
facility_level_nssp_data_dir: Path | str,
state_level_nssp_data_dir: Path | str,
nwss_data_dir: Path | str,
param_data_dir: Path | str,
priors_path: Path | str,
output_dir: Path | str,
Expand Down Expand Up @@ -333,6 +335,23 @@ def main(
"No data available for the requested report date " f"{report_date}"
)

available_nwss_reports = get_available_reports(nwss_data_dir)
# assming NWSS_vintage directory follows naming convention
# using as of date
# need to be modified otherwise

if report_date in available_nwss_reports:
nwss_data_raw = pl.scan_parquet(
Path(nwss_data_dir, f"{report_date}.parquet")
)
nwss_data_cleaned = clean_nwss_data(nwss_data_raw).filter(
(pl.col("location") == state)
& (pl.col("date") >= first_training_date)
)
state_level_nwss_data = preprocess_ww_data(nwss_data_cleaned)
else:
state_level_nwss_data = None ## TO DO: change

param_estimates = pl.scan_parquet(Path(param_data_dir, "prod.parquet"))
model_batch_dir_name = (
f"{disease.lower()}_r_{report_date}_f_"
Expand All @@ -357,6 +376,7 @@ def main(
disease=disease,
facility_level_nssp_data=facility_level_nssp_data,
state_level_nssp_data=state_level_nssp_data,
state_level_nwss_data=state_level_nwss_data,
report_date=report_date,
first_training_date=first_training_date,
last_training_date=last_training_date,
Expand Down Expand Up @@ -497,6 +517,13 @@ def main(
),
)

parser.add_argument(
"--nwss-data-dir",
type=Path,
default=Path("private_data", "nwss_vintages"),
help=("Directory in which to look for NWSS data."),
)

parser.add_argument(
"--param-data-dir",
type=Path,
Expand Down
10 changes: 10 additions & 0 deletions pipelines/prep_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pathlib import Path

import forecasttools
import jax.numpy as jnp
import polars as pl
import polars.selectors as cs

Expand Down Expand Up @@ -354,6 +355,7 @@ def process_and_save_state(
logger: Logger = None,
facility_level_nssp_data: pl.LazyFrame = None,
state_level_nssp_data: pl.LazyFrame = None,
state_level_nwss_data: pl.LazyFrame = None,
credentials_dict: dict = None,
) -> None:
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -449,6 +451,12 @@ def process_and_save_state(
"hospital_admissions"
).to_list()

data_observed_disease_wastewater = (
state_level_nwss_data.to_dict(as_series=False)
if state_level_nwss_data is not None
else None
)

data_for_model_fit = {
"inf_to_ed_pmf": delay_pmf,
"generation_interval_pmf": generation_interval_pmf,
Expand All @@ -462,7 +470,9 @@ def process_and_save_state(
"nhsn_step_size": nhsn_step_size,
"state_pop": state_pop,
"right_truncation_offset": right_truncation_offset,
"data_observed_disease_wastewater": data_observed_disease_wastewater,
}

data_dir = Path(model_run_dir, "data")
os.makedirs(data_dir, exist_ok=True)

Expand Down
Loading

0 comments on commit 6fac159

Please sign in to comment.