Skip to content

Commit

Permalink
Add temporary checks for disallowed models
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris committed Feb 13, 2025
1 parent 3661e5e commit 474d283
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
16 changes: 16 additions & 0 deletions pipelines/batch/setup_prod_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,22 @@ def main(
f"supported diseases are: {', '.join(supported_diseases)}"
)

for signal in signals:
fit = locals().get(f"fit_{signal}", False)
forecast = locals().get(f"forecast_{signal}", False)
if fit and not forecast:
ValueError(
"This pipeline does not currently support "
"fitting to but not forecasting a signal. "
f"Asked to fit but not forecast {signal}."
)
any_fit = any([locals().get(f"fit_{signal}", False) for signal in signals])
if not any_fit:
raise ValueError(
"pyrenew_null (fitting to no signals) "
"is not supported by this pipeline"
)

pyrenew_hew_output_container = (
"pyrenew-test-output" if test else "pyrenew-hew-prod-output"
)
Expand Down
19 changes: 18 additions & 1 deletion pipelines/forecast_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,24 @@ def main(
f"model {pyrenew_model_name}, location {state}, "
f"and report date {report_date}"
)

signals = ["ed_visits", "hospital_admissions", "wastewater"]

Check warning on line 228 in pipelines/forecast_state.py

View check run for this annotation

Codecov / codecov/patch

pipelines/forecast_state.py#L228

Added line #L228 was not covered by tests

for signal in signals:
fit = locals().get(f"fit_{signal}", False)
forecast = locals().get(f"forecast_{signal}", False)
if fit and not forecast:
ValueError(

Check warning on line 234 in pipelines/forecast_state.py

View check run for this annotation

Codecov / codecov/patch

pipelines/forecast_state.py#L230-L234

Added lines #L230 - L234 were not covered by tests
"This pipeline does not currently support "
"fitting to but not forecasting a signal. "
f"Asked to fit but not forecast {signal}."
)
any_fit = any([locals().get(f"fit_{signal}", False) for signal in signals])
if not any_fit:
raise ValueError(

Check warning on line 241 in pipelines/forecast_state.py

View check run for this annotation

Codecov / codecov/patch

pipelines/forecast_state.py#L239-L241

Added lines #L239 - L241 were not covered by tests
"pyrenew_null (fitting to no signals) "
"is not supported by this pipeline"
)

if credentials_path is not None:
cp = Path(credentials_path)
if not cp.suffix.lower() == ".toml":
Expand Down

0 comments on commit 474d283

Please sign in to comment.