Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Script to run small version of prod pipeline with current branch #135

Merged
merged 15 commits into from
Nov 18, 2024
Merged
1 change: 1 addition & 0 deletions pipelines/batch/setup_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def main(pool_name: str) -> None:
"prod-param-estimates",
"pyrenew-hew-prod-output",
"pyrenew-hew-config",
"pyrenew-test-output",
],
account_names=creds.azure_blob_storage_account,
identity_references=node_id_ref,
Expand Down
25 changes: 17 additions & 8 deletions pipelines/batch/setup_prod_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def main(
"VI",
"WY",
],
test: bool = False,
) -> None:
"""
job_id
Expand Down Expand Up @@ -78,6 +79,12 @@ def main(
f"supported diseases are: {', '.join(supported_diseases)}"
)

pyrenew_hew_output_container = (
"pyrenew-test-output" if test else "pyrenew-hew-prod-output"
)
n_warmup = 200 if test else 1000
n_samples = 200 if test else 500

creds = EnvCredentialHandler()
client = get_batch_service_client(creds)
job = models.JobAddParameter(
Expand Down Expand Up @@ -108,7 +115,7 @@ def main(
"target": "/pyrenew-hew/params",
},
{
"source": "pyrenew-hew-prod-output",
"source": pyrenew_hew_output_container,
"target": "/pyrenew-hew/output",
},
{
Expand All @@ -124,8 +131,8 @@ def main(
"--disease {disease} "
"--state {state} "
"--n-training-days 90 "
"--n-warmup 1000 "
"--n-samples 500 "
"--n-warmup {n_warmup} "
"--n-samples {n_samples} "
"--facility-level-nssp-data-dir nssp-etl/gold "
"--state-level-nssp-data-dir "
"nssp-archival-vintages/gold "
Expand All @@ -145,11 +152,11 @@ def main(
"https://www2.census.gov/geo/docs/reference/state.txt", separator="|"
)

all_locations = (
locations.filter(~pl.col("STUSAB").is_in(excluded_locations))
.get_column("STUSAB")
.to_list()
) + ["US"]
all_locations = [
loc
for loc in locations.get_column("STUSAB").to_list() + ["US"]
if loc not in excluded_locations
]

for disease, state in itertools.product(disease_list, all_locations):
task = get_task_config(
Expand All @@ -158,6 +165,8 @@ def main(
state=state,
disease=disease,
report_date="latest",
n_warmup=n_warmup,
n_samples=n_samples,
),
container_settings=container_settings,
)
Expand Down
98 changes: 98 additions & 0 deletions pipelines/batch/setup_test_prod_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""
Set up a multi-location, multi-date,
potentially multi-disease end to end
retrospective evaluation run for pyrenew-hew
on Azure Batch.
"""

import argparse
import os
from datetime import datetime, timezone
from pathlib import Path

from pygit2 import Repository
from setup_prod_job import main

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Test production pipeline on small subset of locations"
)
parser.add_argument(
"tag",
type=str,
help="The tag name to use for the container image version",
default=Path(Repository(os.getcwd()).head.name).stem,
)

args = parser.parse_args()

tag = args.tag
print(f"Using tag {tag}")
current_datetime = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%SZ")
tag = Path(Repository(os.getcwd()).head.name).stem

locs_to_exclude = [ # keep CA, MN, SD, and US
"AS",
"GU",
"MO",
"MP",
"PR",
"UM",
"VI",
"WY",
"AK",
"AL",
"AR",
"AZ",
"CO",
"CT",
"DC",
"DE",
"FL",
"GA",
"HI",
"IA",
"ID",
"IL",
"IN",
"KS",
"KY",
"LA",
"MA",
"MD",
"ME",
"MI",
"MS",
"MT",
"NC",
"ND",
"NE",
"NH",
"NJ",
"NM",
"NV",
"NY",
"OH",
"OK",
"OR",
"PA",
"RI",
"SC",
"TN",
"TX",
"UT",
"VA",
"VT",
"WA",
"WI",
"WV",
]
main(
job_id=f"pyrenew-hew-test-{current_datetime}",
pool_id="pyrenew-pool",
diseases=["COVID-19", "Influenza"],
container_image_name="pyrenew-hew",
container_image_version=tag,
excluded_locations=locs_to_exclude,
test=True,
)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ ipykernel = "^6.29.5"
polars = "^1.5.0"
pypdf = "^5.1.0"
pyarrow = "^18.0.0"
pygit2 = "^1.16.0"

[tool.poetry.group.azurebatch.dependencies]
azuretools = {git = "https://github.com/cdcent/cfa-stf-azuretools"}
Expand Down
Loading