Skip to content

Commit

Permalink
code review suggestion
Browse files (browse the repository at this point in the history)
  • Loading branch information
sbidari committed Feb 12, 2025
1 parent 5af153b commit 6cb725e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 12 deletions.
44 changes: 33 additions & 11 deletions pipelines/prep_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def py_scalar_to_r_scalar(py_scalar):
state_abb_for_query = state_abb if state_abb != "US" else "USA"

temp_file = Path(temp_dir, "nhsn_temp.parquet")
api_key_id = credentials_dict.get("nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID"))
api_key_id = credentials_dict.get(
"nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID")
)
api_key_secret = credentials_dict.get(
"nhsn_api_key_secret", os.getenv("NHSN_API_KEY_SECRET")
)
Expand Down Expand Up @@ -83,7 +85,9 @@ def py_scalar_to_r_scalar(py_scalar):
if result.returncode != 0:
raise RuntimeError(f"pull_and_save_nhsn: {result.stderr}")
raw_dat = pl.read_parquet(temp_file)
dat = raw_dat.with_columns(weekendingdate=pl.col("weekendingdate").cast(pl.Date))
dat = raw_dat.with_columns(
weekendingdate=pl.col("weekendingdate").cast(pl.Date)
)
return dat


Expand All @@ -105,7 +109,9 @@ def combine_nssp_and_nhsn(
variable_name="drop_me",
value_name=".value",
)
.with_columns(pl.col("count_type").replace(count_type_dict).alias(".variable"))
.with_columns(
pl.col("count_type").replace(count_type_dict).alias(".variable")
)
.select(cs.exclude(["count_type", "drop_me"]))
)

Expand Down Expand Up @@ -185,7 +191,9 @@ def process_state_level_data(

if state_abb == "US":
locations_to_aggregate = (
state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique()
state_pop_df.filter(pl.col("abb") != "US")
.get_column("abb")
.unique()
)
logger.info("Aggregating state-level data to national")
state_level_nssp_data = aggregate_to_national(
Expand All @@ -212,7 +220,9 @@ def process_state_level_data(
]
)
.with_columns(
disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map),
disease=pl.col("disease")
.cast(pl.Utf8)
.replace(_inverse_disease_map),
)
.sort(["date", "disease"])
.collect(streaming=True)
Expand Down Expand Up @@ -244,7 +254,9 @@ def aggregate_facility_level_nssp_to_state(
if state_abb == "US":
logger.info("Aggregating facility-level data to national")
locations_to_aggregate = (
state_pop_df.filter(pl.col("abb") != "US").get_column("abb").unique()
state_pop_df.filter(pl.col("abb") != "US")
.get_column("abb")
.unique()
)
facility_level_nssp_data = aggregate_to_national(
facility_level_nssp_data,
Expand All @@ -263,7 +275,9 @@ def aggregate_facility_level_nssp_to_state(
.group_by(["reference_date", "disease"])
.agg(pl.col("value").sum().alias("ed_visits"))
.with_columns(
disease=pl.col("disease").cast(pl.Utf8).replace(_inverse_disease_map),
disease=pl.col("disease")
.cast(pl.Utf8)
.replace(_inverse_disease_map),
geo_value=pl.lit(state_abb).cast(pl.Utf8),
)
.rename({"reference_date": "date"})
Expand Down Expand Up @@ -353,12 +367,16 @@ def process_and_save_state(

if facility_level_nssp_data is None and state_level_nssp_data is None:
raise ValueError(
"Must provide at least one " "of facility-level and state-level" "NSSP data"
"Must provide at least one "
"of facility-level and state-level"
"NSSP data"
)

state_pop_df = get_state_pop_df()

state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item(0, "population")
state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item(
0, "population"
)

(generation_interval_pmf, delay_pmf, right_truncation_pmf) = get_pmfs(
param_estimates=param_estimates, state_abb=state_abb, disease=disease
Expand Down Expand Up @@ -407,13 +425,17 @@ def process_and_save_state(
credentials_dict=credentials_dict,
).with_columns(pl.lit("train").alias("data_type"))

nssp_training_dates = nssp_training_data.get_column("date").unique().to_list()
nssp_training_dates = (
nssp_training_data.get_column("date").unique().to_list()
)
nhsn_training_dates = (
nhsn_training_data.get_column("weekendingdate").unique().to_list()
)

nhsn_first_date_index = next(
i for i, x in enumerate(nssp_training_dates) if x == min(nhsn_training_dates)
i
for i, x in enumerate(nssp_training_dates)
if x == min(nhsn_training_dates)
)
nhsn_step_size = 7

Expand Down
2 changes: 1 addition & 1 deletion pipelines/prep_ww_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def preprocess_ww_data(
),
below_lod=(
pl.col("log_genome_copies_per_ml") <= pl.col("log_lod")
).cast(pl.Int8),
),
)
.select(
[
Expand Down

0 comments on commit 6cb725e

Please sign in to comment.