Commit

fix test
sbidari committed Feb 19, 2025
1 parent 1c627ab commit d48dec6
Showing 2 changed files with 55 additions and 22 deletions.
pipelines/prep_data.py (48 changes: 33 additions & 15 deletions)
@@ -11,10 +11,6 @@
 import jax.numpy as jnp
 import polars as pl
 import polars.selectors as cs
-from prep_ww_data import (
-    get_nwss_data,
-    preprocess_ww_data,
-)
 
 _disease_map = {
     "COVID-19": "COVID-19/Omicron",
@@ -51,7 +47,9 @@ def py_scalar_to_r_scalar(py_scalar):
     state_abb_for_query = state_abb if state_abb != "US" else "USA"
 
     temp_file = Path(temp_dir, "nhsn_temp.parquet")
-    api_key_id = credentials_dict.get("nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID"))
+    api_key_id = credentials_dict.get(
+        "nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID")
+    )
     api_key_secret = credentials_dict.get(
         "nhsn_api_key_secret", os.getenv("NHSN_API_KEY_SECRET")
     )
@@ -83,7 +81,9 @@ def py_scalar_to_r_scalar(py_scalar):
     if result.returncode != 0:
         raise RuntimeError(f"pull_and_save_nhsn: {result.stderr}")
     raw_dat = pl.read_parquet(temp_file)
-    dat = raw_dat.with_columns(weekendingdate=pl.col("weekendingdate").cast(pl.Date))
+    dat = raw_dat.with_columns(
+        weekendingdate=pl.col("weekendingdate").cast(pl.Date)
+    )
     return dat
 
 
@@ -105,7 +105,9 @@ def combine_nssp_and_nhsn(
             variable_name="drop_me",
             value_name=".value",
         )
-        .with_columns(pl.col("count_type").replace(count_type_dict).alias(".variable"))
+        .with_columns(
+            pl.col("count_type").replace(count_type_dict).alias(".variable")
+        )
         .select(cs.exclude(["count_type", "drop_me"]))
     )
 
@@ -185,7 +187,9 @@ def process_state_level_data(
 
     if state_abb == "US":
         locations_to_aggregate = (
-            state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique()
+            state_pop_df.filter(pl.col("abb") != "US")
+            .get_column("abb")
+            .unique()
         )
         logger.info("Aggregating state-level data to national")
         state_level_nssp_data = aggregate_to_national(
@@ -212,7 +216,9 @@
             ]
         )
         .with_columns(
-            disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map),
+            disease=pl.col("disease")
+            .cast(pl.Utf8)
+            .replace(_inverse_disease_map),
         )
         .sort(["date", "disease"])
         .collect(streaming=True)
@@ -244,7 +250,9 @@ def aggregate_facility_level_nssp_to_state(
     if state_abb == "US":
         logger.info("Aggregating facility-level data to national")
         locations_to_aggregate = (
-            state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique()
+            state_pop_df.filter(pl.col("abb") != "US")
+            .get_column("abb")
+            .unique()
        )
         facility_level_nssp_data = aggregate_to_national(
             facility_level_nssp_data,
@@ -263,7 +271,9 @@
         .group_by(["reference_date", "disease"])
         .agg(pl.col("value").sum().alias("ed_visits"))
         .with_columns(
-            disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map),
+            disease=pl.col("disease")
+            .cast(pl.Utf8)
+            .replace(_inverse_disease_map),
             geo_value=pl.lit(state_abb).cast(pl.Utf8),
         )
         .rename({"reference_date": "date"})
@@ -354,12 +364,16 @@ def process_and_save_state(
 
     if facility_level_nssp_data is None and state_level_nssp_data is None:
         raise ValueError(
-            "Must provide at least one " "of facility-level and state-level" "NSSP data"
+            "Must provide at least one "
+            "of facility-level and state-level"
+            "NSSP data"
         )
 
     state_pop_df = get_state_pop_df()
 
-    state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item(0, "population")
+    state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item(
+        0, "population"
+    )
 
     (generation_interval_pmf, delay_pmf, right_truncation_pmf) = get_pmfs(
         param_estimates=param_estimates, state_abb=state_abb, disease=disease
@@ -408,13 +422,17 @@ def process_and_save_state(
         credentials_dict=credentials_dict,
     ).with_columns(pl.lit("train").alias("data_type"))
 
-    nssp_training_dates = nssp_training_data.get_column("date").unique().to_list()
+    nssp_training_dates = (
+        nssp_training_data.get_column("date").unique().to_list()
+    )
     nhsn_training_dates = (
         nhsn_training_data.get_column("weekendingdate").unique().to_list()
     )
 
     nhsn_first_date_index = next(
-        i for i, x in enumerate(nssp_training_dates) if x == min(nhsn_training_dates)
+        i
+        for i, x in enumerate(nssp_training_dates)
+        if x == min(nhsn_training_dates)
     )
     nhsn_step_size = 7
 
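A note on the api_key hunk above: the rewrap does not change lookup order. An entry in credentials_dict wins and the environment variable is only a fallback, because it is passed as the default argument to dict.get. Below is a minimal standalone sketch of that pattern; the get_nhsn_credentials helper and its ValueError are illustrative assumptions, and only the key names come from the diff.

import os


def get_nhsn_credentials(credentials_dict=None):
    # Illustrative sketch of the dict-then-environment fallback used in
    # prep_data.py; the helper name and error handling are assumptions.
    credentials_dict = credentials_dict or {}
    # dict.get evaluates its default eagerly, so os.getenv runs even when
    # the key is present in the dict; that is harmless here.
    api_key_id = credentials_dict.get(
        "nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID")
    )
    api_key_secret = credentials_dict.get(
        "nhsn_api_key_secret", os.getenv("NHSN_API_KEY_SECRET")
    )
    if api_key_id is None or api_key_secret is None:
        raise ValueError("NHSN API credentials not found in dict or environment")
    return api_key_id, api_key_secret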
pipelines/prep_ww_data.py (29 changes: 22 additions & 7 deletions)
@@ -47,7 +47,9 @@ def clean_nwss_data(nwss_data):
             .when(pl.col("pcr_target_units") == "log10 copies/l wastewater")
             .then((10 ** pl.col("pcr_target_avg_conc")) / 1000)
             .otherwise(None),
-            lod_sewage=pl.when(pl.col("pcr_target_units") == "copies/l wastewater")
+            lod_sewage=pl.when(
+                pl.col("pcr_target_units") == "copies/l wastewater"
+            )
             .then(pl.col("lod_sewage") / 1000)
             .when(pl.col("pcr_target_units") == "log10 copies/l wastewater")
             .then((10 ** pl.col("lod_sewage")) / 1000)
@@ -128,7 +130,9 @@ def clean_nwss_data(nwss_data):
         )
         .with_columns(
             [
-                pl.col("pcr_target_avg_conc").log().alias("log_genome_copies_per_ml"),
+                pl.col("pcr_target_avg_conc")
+                .log()
+                .alias("log_genome_copies_per_ml"),
                 pl.col("lod_sewage").log().alias("log_lod"),
                 pl.col("location").str.to_uppercase().alias("location"),
                 pl.col("site").cast(pl.String).alias("site"),
@@ -207,7 +211,9 @@ def validate_ww_conc_data(
         .eq(1)
         .all()
     ):
-        raise ValueError("The data contains sites with varying population sizes.")
+        raise ValueError(
+            "The data contains sites with varying population sizes."
+        )
 
     return None
 
@@ -239,10 +245,14 @@ def preprocess_ww_data(
         .with_row_index("lab_site_index")
     )
     site_df = (
-        ww_data_ordered.select([wwtp_col_name]).unique().with_row_index("site_index")
+        ww_data_ordered.select([wwtp_col_name])
+        .unique()
+        .with_row_index("site_index")
     )
     ww_preprocessed = (
-        ww_data_ordered.join(lab_site_df, on=[lab_col_name, wwtp_col_name], how="left")
+        ww_data_ordered.join(
+            lab_site_df, on=[lab_col_name, wwtp_col_name], how="left"
+        )
         .join(site_df, on=wwtp_col_name, how="left")
         .rename(
             {
@@ -252,9 +262,14 @@ def preprocess_ww_data(
         )
         .with_columns(
             lab_site_name=(
-                "Site: " + pl.col(wwtp_col_name) + ", Lab: " + pl.col(lab_col_name)
+                "Site: "
+                + pl.col(wwtp_col_name)
+                + ", Lab: "
+                + pl.col(lab_col_name)
             ),
-            below_lod=(pl.col("log_genome_copies_per_ml") <= pl.col("log_lod")),
+            below_lod=(
+                pl.col("log_genome_copies_per_ml") <= pl.col("log_lod")
+            ),
         )
         .select(
             [
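A note on the clean_nwss_data hunks above: the rewrap leaves the unit conversion logic untouched. Values reported as copies/l wastewater are divided by 1000 to give copies per ml, and log10 copies/l values are first exponentiated and then divided by 1000. Here is a small self-contained polars sketch of that conversion; the example rows are made up, while the column names, unit strings, and arithmetic come from the diff.

import polars as pl

# Made-up example rows; only the column names and unit strings are from the diff.
df = pl.DataFrame(
    {
        "pcr_target_avg_conc": [20000.0, 4.3],
        "pcr_target_units": [
            "copies/l wastewater",
            "log10 copies/l wastewater",
        ],
    }
)

converted = df.with_columns(
    pcr_target_avg_conc=pl.when(
        pl.col("pcr_target_units") == "copies/l wastewater"
    )
    .then(pl.col("pcr_target_avg_conc") / 1000)  # copies/l -> copies/ml
    .when(pl.col("pcr_target_units") == "log10 copies/l wastewater")
    .then((10 ** pl.col("pcr_target_avg_conc")) / 1000)  # undo log10, then convert
    .otherwise(None)
)
print(converted)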
