Skip to content

Commit

Permalink
Merge branch 'main' into dhm-make-generate-predictive-use-dirname
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris authored Feb 19, 2025
2 parents f2be950 + 6fac159 commit 08513be
Show file tree
Hide file tree
Showing 9 changed files with 691 additions and 88 deletions.
12 changes: 10 additions & 2 deletions pipelines/build_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path

import jax.numpy as jnp
import polars as pl
from pyrenew.deterministic import DeterministicVariable

from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData
Expand Down Expand Up @@ -88,8 +89,14 @@ def build_model_from_dir(
else None
)

# placeholder
data_observed_disease_wastewater = None if fit_wastewater else None
data_observed_disease_wastewater = (
pl.DataFrame(
model_data["data_observed_disease_wastewater"],
schema_overrides={"date": pl.Date},
)
if sample_wastewater
else None
)

population_size = jnp.array(model_data["state_pop"])

Expand Down Expand Up @@ -175,6 +182,7 @@ def build_model_from_dir(
right_truncation_offset=right_truncation_offset,
first_ed_visits_date=first_ed_visits_date,
first_hospital_admissions_date=first_hospital_admissions_date,
population_size=population_size,
)

return (my_model, my_data)
27 changes: 27 additions & 0 deletions pipelines/forecast_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from generate_predictive import ( # noqa
generate_and_save_predictions,
)
from prep_ww_data import clean_nwss_data, preprocess_ww_data


def record_git_info(model_run_dir: Path):
Expand Down Expand Up @@ -192,6 +193,7 @@ def main(
state: str,
facility_level_nssp_data_dir: Path | str,
state_level_nssp_data_dir: Path | str,
nwss_data_dir: Path | str,
param_data_dir: Path | str,
priors_path: Path | str,
output_dir: Path | str,
Expand Down Expand Up @@ -333,6 +335,23 @@ def main(
"No data available for the requested report date " f"{report_date}"
)

available_nwss_reports = get_available_reports(nwss_data_dir)
# Assuming the NWSS vintage directory follows the naming convention
# of using the as-of date as the file name.
# This will need to be modified otherwise.

if report_date in available_nwss_reports:
nwss_data_raw = pl.scan_parquet(
Path(nwss_data_dir, f"{report_date}.parquet")
)
nwss_data_cleaned = clean_nwss_data(nwss_data_raw).filter(
(pl.col("location") == state)
& (pl.col("date") >= first_training_date)
)
state_level_nwss_data = preprocess_ww_data(nwss_data_cleaned)
else:
state_level_nwss_data = None ## TO DO: change

param_estimates = pl.scan_parquet(Path(param_data_dir, "prod.parquet"))
model_batch_dir_name = (
f"{disease.lower()}_r_{report_date}_f_"
Expand All @@ -357,6 +376,7 @@ def main(
disease=disease,
facility_level_nssp_data=facility_level_nssp_data,
state_level_nssp_data=state_level_nssp_data,
state_level_nwss_data=state_level_nwss_data,
report_date=report_date,
first_training_date=first_training_date,
last_training_date=last_training_date,
Expand Down Expand Up @@ -497,6 +517,13 @@ def main(
),
)

parser.add_argument(
"--nwss-data-dir",
type=Path,
default=Path("private_data", "nwss_vintages"),
help=("Directory in which to look for NWSS data."),
)

parser.add_argument(
"--param-data-dir",
type=Path,
Expand Down
10 changes: 10 additions & 0 deletions pipelines/prep_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pathlib import Path

import forecasttools
import jax.numpy as jnp
import polars as pl
import polars.selectors as cs

Expand Down Expand Up @@ -354,6 +355,7 @@ def process_and_save_state(
logger: Logger = None,
facility_level_nssp_data: pl.LazyFrame = None,
state_level_nssp_data: pl.LazyFrame = None,
state_level_nwss_data: pl.LazyFrame = None,
credentials_dict: dict = None,
) -> None:
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -449,6 +451,12 @@ def process_and_save_state(
"hospital_admissions"
).to_list()

data_observed_disease_wastewater = (
state_level_nwss_data.to_dict(as_series=False)
if state_level_nwss_data is not None
else None
)

data_for_model_fit = {
"inf_to_ed_pmf": delay_pmf,
"generation_interval_pmf": generation_interval_pmf,
Expand All @@ -462,7 +470,9 @@ def process_and_save_state(
"nhsn_step_size": nhsn_step_size,
"state_pop": state_pop,
"right_truncation_offset": right_truncation_offset,
"data_observed_disease_wastewater": data_observed_disease_wastewater,
}

data_dir = Path(model_run_dir, "data")
os.makedirs(data_dir, exist_ok=True)

Expand Down
Loading

0 comments on commit 08513be

Please sign in to comment.