From 474d283ad8eb78ed4ba1200b4978ba8c94ee0864 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Wed, 12 Feb 2025 20:37:22 -0500 Subject: [PATCH] Add temporary checks for disallowed models --- pipelines/batch/setup_prod_job.py | 16 ++++++++++++++++ pipelines/forecast_state.py | 19 ++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/pipelines/batch/setup_prod_job.py b/pipelines/batch/setup_prod_job.py index 683d3c9a..9af70f7c 100644 --- a/pipelines/batch/setup_prod_job.py +++ b/pipelines/batch/setup_prod_job.py @@ -133,6 +133,22 @@ def main( f"supported diseases are: {', '.join(supported_diseases)}" ) + for signal in signals: + fit = locals().get(f"fit_{signal}", False) + forecast = locals().get(f"forecast_{signal}", False) + if fit and not forecast: + ValueError( + "This pipeline does not currently support " + "fitting to but not forecasting a signal. " + f"Asked to fit but not forecast {signal}." + ) + any_fit = any([locals().get(f"fit_{signal}", False) for signal in signals]) + if not any_fit: + raise ValueError( + "pyrenew_null (fitting to no signals) " + "is not supported by this pipeline" + ) + pyrenew_hew_output_container = ( "pyrenew-test-output" if test else "pyrenew-hew-prod-output" ) diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py index 90ee6cc4..fe549495 100644 --- a/pipelines/forecast_state.py +++ b/pipelines/forecast_state.py @@ -225,7 +225,24 @@ def main( f"model {pyrenew_model_name}, location {state}, " f"and report date {report_date}" ) - + signals = ["ed_visits", "hospital_admissions", "wastewater"] + + for signal in signals: + fit = locals().get(f"fit_{signal}", False) + forecast = locals().get(f"forecast_{signal}", False) + if fit and not forecast: + ValueError( + "This pipeline does not currently support " + "fitting to but not forecasting a signal. " + f"Asked to fit but not forecast {signal}." + ) + any_fit = any([locals().get(f"fit_{signal}", False) for signal in signals]) + if not any_fit: + raise ValueError( + "pyrenew_null (fitting to no signals) " + "is not supported by this pipeline" + ) + if credentials_path is not None: cp = Path(credentials_path) if not cp.suffix.lower() == ".toml":