Make which quantities are sampled more readily configurable #326

Merged
merged 37 commits into from
Feb 13, 2025
Changes from 8 commits
6e1c5b6
Make which quantities are sampled more readily configurable
dylanhmorris Feb 5, 2025
c3b7858
Fix variable name typo
dylanhmorris Feb 5, 2025
3d15528
Predictive flags and CLI
dylanhmorris Feb 6, 2025
0238435
Set flag in forecast state
dylanhmorris Feb 6, 2025
250ea74
Harmonize data generation
dylanhmorris Feb 6, 2025
557330d
Remove print call
dylanhmorris Feb 6, 2025
8baa0f0
Add switches to fitting scripts
dylanhmorris Feb 6, 2025
c40eabb
Add flags to prod job call
dylanhmorris Feb 6, 2025
674af5d
Fix flag ordering and logic
dylanhmorris Feb 6, 2025
18df0b5
Merge branch 'main' into dhm-configure-which-fit
dylanhmorris Feb 10, 2025
1ef2244
Merge branch 'main' into dhm-configure-which-fit
dylanhmorris Feb 10, 2025
098d568
More meaningful flag names
dylanhmorris Feb 10, 2025
85f692a
DRY flag setting
dylanhmorris Feb 10, 2025
31b6f1d
Align CLI with main function
dylanhmorris Feb 10, 2025
0c75437
Fix missing import
dylanhmorris Feb 10, 2025
ca6e01e
Update end-to-end test
dylanhmorris Feb 10, 2025
6252dfc
Fix typo
dylanhmorris Feb 10, 2025
54771ed
Add build model switch test
dylanhmorris Feb 10, 2025
db3564d
Merge branch 'main' into dhm-configure-which-fit
dylanhmorris Feb 11, 2025
25eaca0
Add programmatic model naming
dylanhmorris Feb 11, 2025
d46b958
use lowercase, add unit test
dylanhmorris Feb 11, 2025
536ed3f
End to end test is only pyrenew_e for now
dylanhmorris Feb 11, 2025
c868636
Do not delete test check on additional data
dylanhmorris Feb 11, 2025
2362f42
Remove extra line
dylanhmorris Feb 11, 2025
3327a94
fix typo in end to end test
dylanhmorris Feb 11, 2025
b8dc990
Move argument parsing inside if __name__==__main__ in setup prod job,…
dylanhmorris Feb 11, 2025
e25aac3
Merge branch 'main' into dhm-configure-which-fit
dylanhmorris Feb 11, 2025
26102b7
Qualify namespace in generate_predictive.py
dylanhmorris Feb 11, 2025
f0eef15
Revert import formulation
dylanhmorris Feb 11, 2025
283d5b0
Fix missing space
dylanhmorris Feb 11, 2025
2b3b881
Apply suggestions from code review
dylanhmorris Feb 11, 2025
0b4c4d7
Update pyrenew_hew/util.py
damonbayer Feb 11, 2025
85d1446
Add hew model iterator
dylanhmorris Feb 12, 2025
3661e5e
Merge branch 'main' into dhm-configure-which-fit
dylanhmorris Feb 13, 2025
474d283
Add temporary checks for disallowed models
dylanhmorris Feb 13, 2025
94ebe3b
Fix typos
dylanhmorris Feb 13, 2025
58dfda3
Merge branch 'main' into dhm-configure-which-fit
dylanhmorris Feb 13, 2025
47 changes: 39 additions & 8 deletions pipelines/batch/setup_prod_job.py
@@ -21,6 +21,9 @@ def main(
pool_id: str,
diseases: str | list[str],
output_subdir: str | Path = "./",
sample_ed_visits: bool = False,
sample_hospital_admissions: bool = False,
sample_wastewater: bool = False,
container_image_name: str = "pyrenew-hew",
container_image_version: str = "latest",
n_training_days: int = 90,
@@ -155,24 +158,35 @@ def main(
],
)

sample_ed_visits_flag = "--sample-ed-visits " if sample_ed_visits else ""
sample_hospital_admissions_flag = (
"--sample-hospital-admissions " if sample_hospital_admissions else ""
)
sample_wastewater_flag = (
"--sample-wastewater " if sample_wastewater else ""
)

base_call = (
"/bin/bash -c '"
"python pipelines/forecast_state.py "
"--disease {disease} "
"--state {state} "
"--n-training-days {n_training_days} "
"--n-warmup {n_warmup} "
"--n-samples {n_samples} "
f"--n-training-days {n_training_days} "
f"--n-warmup {n_warmup} "
f"--n-samples {n_samples} "
"--facility-level-nssp-data-dir nssp-etl/gold "
"--state-level-nssp-data-dir "
"nssp-archival-vintages/gold "
"--param-data-dir params "
"--output-dir {output_dir} "
"--priors-path pipelines/priors/prod_priors.py "
"--report-date {report_date} "
"--exclude-last-n-days {exclude_last_n_days} "
f"--exclude-last-n-days {exclude_last_n_days} "
"--no-score "
f"{sample_ed_visits_flag}"
f"{sample_hospital_admissions_flag}"
f"{sample_wastewater_flag}"
"--eval-data-path "
"nssp-etl/latest_comprehensive.parquet"
"'"
)
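The three conditional expressions in this hunk share one pattern: emit `--name ` (with a trailing space) when a boolean is true, otherwise an empty string. A minimal sketch of factoring that out is given here. The helpers `cli_flag` and `sampling_flags` are hypothetical names used for illustration and are not part of the diff shown.

```python
def cli_flag(name: str, enabled: bool) -> str:
    """Return '--name ' (with trailing space) if enabled, else ''."""
    return f"--{name} " if enabled else ""


def sampling_flags(
    sample_ed_visits: bool = False,
    sample_hospital_admissions: bool = False,
    sample_wastewater: bool = False,
) -> str:
    """Concatenate the CLI flags for every enabled quantity."""
    return (
        cli_flag("sample-ed-visits", sample_ed_visits)
        + cli_flag("sample-hospital-admissions", sample_hospital_admissions)
        + cli_flag("sample-wastewater", sample_wastewater)
    )


# Each flag name is written exactly once, so a copy-paste slip
# cannot silently emit the wrong flag.
print(repr(sampling_flags(sample_ed_visits=True, sample_wastewater=True)))
# prints '--sample-ed-visits --sample-wastewater '
```

Because every flag string ends in a space, the result can be spliced into a command string directly before the next option.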
@@ -196,10 +210,6 @@ def main(
state=state,
disease=disease,
report_date="latest",
n_warmup=n_warmup,
n_samples=n_samples,
n_training_days=n_training_days,
exclude_last_n_days=exclude_last_n_days,
output_dir=str(Path("output", output_subdir)),
),
container_settings=container_settings,
@@ -253,6 +263,27 @@ def main(
default="latest",
)


parser.add_argument(
"--sample-ed-visits",
action=argparse.BooleanOptionalAction,
help="If provided, fit to and predict ED visit data.",
)
parser.add_argument(
"--sample-hospital-admissions",
action=argparse.BooleanOptionalAction,
help=("If provided, fit to and predict hospital admissions data."),
)
parser.add_argument(
"--sample-wastewater",
action=argparse.BooleanOptionalAction,
help="If provided, fit to and predict wastewater data.",
)


parser.add_argument(
"--n-training-days",
type=int,
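For reference, a small sketch of how `argparse.BooleanOptionalAction` behaves for one of these switches. The `default=False` here is an assumption added for the sketch; without an explicit default, an omitted flag parses as `None` rather than `False`. The wrapper function `make_parser` is a hypothetical name.

```python
import argparse


def make_parser() -> argparse.ArgumentParser:
    """Build a parser with a single paired --flag / --no-flag switch."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sample-ed-visits",
        action=argparse.BooleanOptionalAction,
        default=False,  # assumption: explicit default, so omission gives False
        help="If provided, fit to and predict ED visit data.",
    )
    return parser


parser = make_parser()
print(parser.parse_args([]).sample_ed_visits)                         # False
print(parser.parse_args(["--sample-ed-visits"]).sample_ed_visits)     # True
print(parser.parse_args(["--no-sample-ed-visits"]).sample_ed_visits)  # False
```

The action registers both `--sample-ed-visits` and `--no-sample-ed-visits`, so no separate `type` conversion is needed.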
21 changes: 16 additions & 5 deletions pipelines/build_pyrenew_model.py
@@ -18,6 +18,9 @@

def build_model_from_dir(
model_dir: Path,
sample_ed_visits: bool = False,
sample_hospital_admissions: bool = False,
sample_wastewater: bool = False,
) -> tuple[PyrenewHEWModel, PyrenewHEWData]:
data_path = Path(model_dir) / "data" / "data_for_model_fit.json"
prior_path = Path(model_dir) / "priors.py"
@@ -47,12 +50,20 @@
jnp.array(model_data["generation_interval_pmf"]),
) # check if off by 1 or reversed

data_observed_disease_ed_visits = jnp.array(
model_data["data_observed_disease_ed_visits"]
data_observed_disease_ed_visits = (
jnp.array(model_data["data_observed_disease_ed_visits"])
if sample_ed_visits
else None
)
data_observed_disease_hospital_admissions = jnp.array(
model_data["data_observed_disease_hospital_admissions"]
data_observed_disease_hospital_admissions = (
jnp.array(model_data["data_observed_disease_hospital_admissions"])
if sample_hospital_admissions
else None
)

# placeholder
data_observed_disease_wastewater = None if sample_wastewater else None

population_size = jnp.array(model_data["state_pop"])

ed_right_truncation_pmf_rv = DeterministicVariable(
@@ -133,7 +144,7 @@
data_observed_disease_hospital_admissions=(
data_observed_disease_hospital_admissions
),
data_observed_disease_wastewater=None, # placeholder
data_observed_disease_wastewater=data_observed_disease_wastewater,
right_truncation_offset=right_truncation_offset,
first_ed_visits_date=first_ed_visits_date,
first_hospital_admissions_date=first_hospital_admissions_date,
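The change to `build_model_from_dir` amounts to "convert the observed series to an array only when that quantity is being sampled, otherwise pass `None`". A minimal sketch of that pattern in isolation follows. The helper `observed_or_none` is a hypothetical name, and `list` stands in for `jnp.array` so the example runs without JAX installed.

```python
from typing import Any, Callable, Mapping, Optional, Sequence


def observed_or_none(
    model_data: Mapping[str, Sequence[float]],
    key: str,
    sample: bool,
    to_array: Callable[[Sequence[float]], Any] = list,
) -> Optional[Any]:
    """Return the converted series if `sample` is true, else None.

    The key is only looked up when the quantity is sampled, so a
    data file that lacks an unused series does not raise KeyError.
    """
    return to_array(model_data[key]) if sample else None


model_data = {"data_observed_disease_ed_visits": (3, 5, 8)}

print(observed_or_none(model_data, "data_observed_disease_ed_visits", True))
# prints [3, 5, 8]
print(observed_or_none(model_data, "data_observed_disease_wastewater", False))
# prints None
```

In the pipeline itself, `to_array` would presumably be `jnp.array`, as in the diff.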
36 changes: 32 additions & 4 deletions pipelines/fit_pyrenew_model.py
@@ -12,6 +12,9 @@
def fit_and_save_model(
model_run_dir: str,
model_name: str,
sample_ed_visits: bool = False,
sample_hospital_admissions: bool = False,
sample_wastewater: bool = False,
n_warmup: int = 1000,
n_samples: int = 1000,
n_chains: int = 4,
@@ -26,12 +29,17 @@
"rng_key must be an integer with which "
"to seed :func:`jax.random.key`"
)
(my_model, my_data) = build_model_from_dir(model_run_dir)
(my_model, my_data) = build_model_from_dir(
model_run_dir,
sample_ed_visits=sample_ed_visits,
sample_hospital_admissions=sample_hospital_admissions,
sample_wastewater=sample_wastewater,
)
my_model.run(
data=my_data,
sample_ed_visits=True,
sample_hospital_admissions=True,
sample_wastewater=False,
sample_ed_visits=sample_ed_visits,
sample_hospital_admissions=sample_hospital_admissions,
sample_wastewater=sample_wastewater,
num_warmup=n_warmup,
num_samples=n_samples,
rng_key=rng_key,
@@ -67,6 +75,26 @@
required=True,
help="Name of the model to use for generating predictions.",
)

parser.add_argument(
"--sample-ed-visits",
action=argparse.BooleanOptionalAction,
help="If provided, fit to ED visit data.",
)
parser.add_argument(
"--sample-hospital-admissions",
action=argparse.BooleanOptionalAction,
help=("If provided, fit to hospital admissions data."),
)
parser.add_argument(
"--sample-wastewater",
action=argparse.BooleanOptionalAction,
help="If provided, fit to wastewater data.",
)

parser.add_argument(
"--n-warmup",
type=int,
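The same three keyword arguments are now passed through `fit_and_save_model`, `build_model_from_dir`, and `my_model.run`. One possible way to keep them in sync is to bundle them in a small immutable object and unpack it at each call site. This is a sketch of that design option only; `SamplingFlags` and `fake_run` are hypothetical names and do not appear in the PR.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class SamplingFlags:
    """Which observed quantities to fit to; all off by default."""

    sample_ed_visits: bool = False
    sample_hospital_admissions: bool = False
    sample_wastewater: bool = False

    def as_kwargs(self) -> dict:
        """Return the flags as a dict suitable for ** unpacking."""
        return asdict(self)


def fake_run(**kwargs):
    """Stand-in for a call such as my_model.run(...); echoes its kwargs."""
    return kwargs


flags = SamplingFlags(sample_ed_visits=True)
print(fake_run(num_warmup=1000, **flags.as_kwargs()))
```

Unpacking with `**flags.as_kwargs()` passes the same three names to every callee, so the builder and the sampler cannot receive different settings by accident.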
35 changes: 33 additions & 2 deletions pipelines/forecast_state.py
@@ -201,7 +201,10 @@
exclude_last_n_days: int = 0,
score: bool = False,
eval_data_path: Path = None,
):
sample_ed_visits: bool = False,
sample_hospital_admissions: bool = False,
sample_wastewater: bool = False,
) -> None:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@@ -337,14 +340,22 @@
n_warmup=n_warmup,
n_samples=n_samples,
n_chains=n_chains,
sample_ed_visits=sample_ed_visits,
sample_hospital_admissions=sample_hospital_admissions,
sample_wastewater=sample_wastewater,
)
logger.info("Model fitting complete")

logger.info("Performing posterior prediction / forecasting...")

n_days_past_last_training = n_forecast_days + exclude_last_n_days
generate_and_save_predictions(
model_run_dir, "pyrenew_e", n_days_past_last_training
model_run_dir,
"pyrenew_e",
n_days_past_last_training,
predict_ed_visits=sample_ed_visits,
predict_hospital_admissions=sample_hospital_admissions,
predict_wastewater=sample_wastewater,
)

logger.info(
@@ -524,6 +535,26 @@
type=Path,
help=("Path to a parquet file containing comprehensive truth data."),
)

parser.add_argument(
"--sample-ed-visits",
action=argparse.BooleanOptionalAction,
help="If provided, fit to ED visit data.",
)
parser.add_argument(
"--sample-hospital-admissions",
action=argparse.BooleanOptionalAction,
help=("If provided, fit to hospital admissions data."),
)
parser.add_argument(
"--sample-wastewater",
action=argparse.BooleanOptionalAction,
help="If provided, fit to wastewater data.",
)

args = parser.parse_args()
numpyro.set_host_device_count(args.n_chains)
main(**vars(args))
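In this revision the model name `"pyrenew_e"` is still hard-coded in `forecast_state.py`, even though the sampled quantities are now configurable. A sketch of deriving the name from the flags follows. The function name `model_name_from_flags`, the letter order (h, e, w), and raising an error when no flag is set are all assumptions for illustration; only the name `pyrenew_e` is taken from the diff.

```python
def model_name_from_flags(
    sample_ed_visits: bool = False,
    sample_hospital_admissions: bool = False,
    sample_wastewater: bool = False,
) -> str:
    """Build a lowercase model name such as 'pyrenew_e' from the flags."""
    letters = (
        ("h" if sample_hospital_admissions else "")
        + ("e" if sample_ed_visits else "")
        + ("w" if sample_wastewater else "")
    )
    if not letters:
        # Assumption of this sketch: a model that samples nothing
        # has no name here and is treated as an error.
        raise ValueError("At least one quantity must be sampled.")
    return f"pyrenew_{letters}"


print(model_name_from_flags(sample_ed_visits=True))
# prints pyrenew_e
print(model_name_from_flags(sample_ed_visits=True, sample_hospital_admissions=True))
# prints pyrenew_he
```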
42 changes: 37 additions & 5 deletions pipelines/generate_predictive.py
@@ -9,13 +9,23 @@


def generate_and_save_predictions(
model_run_dir: str | Path, model_name: str, n_forecast_points: int
model_run_dir: str | Path,
model_name: str,
n_forecast_points: int,
predict_ed_visits: bool = False,
predict_hospital_admissions: bool = False,
predict_wastewater: bool = False,
) -> None:
model_run_dir = Path(model_run_dir)
model_dir = Path(model_run_dir, model_name)
if not model_dir.exists():
raise FileNotFoundError(f"The directory {model_dir} does not exist.")
(my_model, my_data) = build_model_from_dir(model_run_dir)
(my_model, my_data) = build_model_from_dir(
model_run_dir,
sample_ed_visits=predict_ed_visits,
sample_hospital_admissions=predict_hospital_admissions,
sample_wastewater=predict_wastewater,
)

my_model._init_model(1, 1)
fresh_sampler = my_model.mcmc.sampler
@@ -31,9 +41,9 @@

posterior_predictive = my_model.posterior_predictive(
data=forecast_data,
sample_ed_visits=True,
sample_hospital_admissions=True,
sample_wastewater=False,
sample_ed_visits=predict_ed_visits,
sample_hospital_admissions=predict_hospital_admissions,
sample_wastewater=predict_wastewater,
)

idata = az.from_numpyro(
@@ -73,6 +83,28 @@
default=0,
help="Number of time points to forecast (Default: 0).",
)
parser.add_argument(
"--predict-ed-visits",
action=argparse.BooleanOptionalAction,
help="If provided, generate posterior predictions for ED visits.",
)
parser.add_argument(
"--predict-hospital-admissions",
action=argparse.BooleanOptionalAction,
help=(
"If provided, generate posterior predictions "
"for hospital admissions."
),
)
parser.add_argument(
"--predict-wastewater",
action=argparse.BooleanOptionalAction,
help="If provided, generate posterior predictions for wastewater.",
)

args = parser.parse_args()

generate_and_save_predictions(**vars(args))
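With three independent switches there are eight possible configurations, and the coverage warnings on this diff suggest that not all of them are exercised yet. A sketch of enumerating every combination, for instance to parametrize a test, is given here. `FLAG_NAMES` and `all_flag_combinations` are hypothetical names; the keyword names match those of `generate_and_save_predictions`.

```python
from itertools import product

FLAG_NAMES = (
    "predict_ed_visits",
    "predict_hospital_admissions",
    "predict_wastewater",
)


def all_flag_combinations():
    """Yield one kwargs dict per on/off combination of the flags."""
    for values in product((False, True), repeat=len(FLAG_NAMES)):
        yield dict(zip(FLAG_NAMES, values))


combos = list(all_flag_combinations())
print(len(combos))  # 8
print(combos[-1])   # every flag True
```

Each yielded dict can be unpacked into a call with `**`, so a single loop covers all configurations.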