From e409ba7dea28a3cd9301c4efee7c367834c9d63e Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 27 Feb 2025 16:16:16 -0500 Subject: [PATCH 1/5] Adjust to fix test --- pyrenew_hew/pyrenew_wastewater_data.py | 27 +++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/pyrenew_hew/pyrenew_wastewater_data.py b/pyrenew_hew/pyrenew_wastewater_data.py index 5df556d7..a80216a7 100644 --- a/pyrenew_hew/pyrenew_wastewater_data.py +++ b/pyrenew_hew/pyrenew_wastewater_data.py @@ -27,15 +27,6 @@ def __init__( def site_subpop_spine(self): ww_data_present = self.data_observed_disease_wastewater is not None if ww_data_present: - add_auxiliary_subpop = ( - self.population_size - > self.data_observed_disease_wastewater.select( - pl.col("site_pop", "site", "lab", "lab_site_index") - ) - .unique() - .get_column("site_pop") - .sum() - ) site_indices = ( self.data_observed_disease_wastewater.select( ["site_index", "site", "site_pop"] @@ -43,15 +34,25 @@ def site_subpop_spine(self): .unique() .sort("site_index") ) + + total_pop_ww = ( + self.data_observed_disease_wastewater.unique( + ["site_pop", "site"] + ) + .get_column("site_pop") + .sum() + ) + + total_pop_no_ww = self.population_size - total_pop_ww > 0 + + add_auxiliary_subpop = total_pop_no_ww > 0 + if add_auxiliary_subpop: aux_subpop = pl.DataFrame( { "site_index": [None], "site": [None], - "site_pop": ( - self.population_size - - site_indices.get_column("site_pop").sum() - ).tolist(), + "site_pop": [total_pop_no_ww], } ) else: From dbdf08c008b8e0d17049e98d0d391895cf9e1fce Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 27 Feb 2025 16:17:53 -0500 Subject: [PATCH 2/5] Update pyrenew_hew/pyrenew_wastewater_data.py --- pyrenew_hew/pyrenew_wastewater_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrenew_hew/pyrenew_wastewater_data.py b/pyrenew_hew/pyrenew_wastewater_data.py index a80216a7..b459131b 100644 --- a/pyrenew_hew/pyrenew_wastewater_data.py +++ b/pyrenew_hew/pyrenew_wastewater_data.py @@ -43,7 +43,7 @@ def site_subpop_spine(self): .sum() ) - total_pop_no_ww = self.population_size - total_pop_ww > 0 + total_pop_no_ww = self.population_size - total_pop_ww add_auxiliary_subpop = total_pop_no_ww > 0 From c3181e2ecd0798ec84779d9c923f7e8643550f93 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 27 Feb 2025 16:27:46 -0500 Subject: [PATCH 3/5] Hack fix --- pyrenew_hew/pyrenew_wastewater_data.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyrenew_hew/pyrenew_wastewater_data.py b/pyrenew_hew/pyrenew_wastewater_data.py index b459131b..40be8edf 100644 --- a/pyrenew_hew/pyrenew_wastewater_data.py +++ b/pyrenew_hew/pyrenew_wastewater_data.py @@ -43,8 +43,9 @@ def site_subpop_spine(self): .sum() ) - total_pop_no_ww = self.population_size - total_pop_ww - + total_pop_no_ww = ( + int(jnp.atleast_1d(self.population_size)[0]) - total_pop_ww + ) add_auxiliary_subpop = total_pop_no_ww > 0 if add_auxiliary_subpop: From dc95f78174170d4a04dba879c93a08ab9fb1d22e Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 27 Feb 2025 17:49:13 -0500 Subject: [PATCH 4/5] pre-commit --- pyrenew_hew/pyrenew_wastewater_data.py | 38 ++++++++++++++++++-------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/pyrenew_hew/pyrenew_wastewater_data.py b/pyrenew_hew/pyrenew_wastewater_data.py index 377066e1..b6d23efc 100644 --- a/pyrenew_hew/pyrenew_wastewater_data.py +++ b/pyrenew_hew/pyrenew_wastewater_data.py @@ -18,7 +18,9 @@ def __init__( data_observed_disease_wastewater: pl.DataFrame = None, population_size: int = None, ) -> None: - self.data_observed_disease_wastewater = data_observed_disease_wastewater + self.data_observed_disease_wastewater = ( + data_observed_disease_wastewater + ) self.population_size = population_size @property @@ -34,7 +36,9 @@ def site_subpop_spine(self): ) total_pop_ww = ( - self.data_observed_disease_wastewater.unique(["site_pop", "site"]) + self.data_observed_disease_wastewater.unique( + ["site_pop", "site"] + ) .get_column("site_pop") .sum() ) @@ -55,10 +59,12 @@ def site_subpop_spine(self): site_subpop_spine = ( pl.concat([aux_subpop, site_indices], how="vertical_relaxed") .with_columns( - subpop_index=pl.col("site_index").cum_count().alias("subpop_index"), - subpop_name=pl.format("Site: {}", pl.col("site")).fill_null( - "remainder of population" - ), + subpop_index=pl.col("site_index") + .cum_count() + .alias("subpop_index"), + subpop_name=pl.format( + "Site: {}", pl.col("site") + ).fill_null("remainder of population"), ) .rename({"site_pop": "subpop_pop"}) ) @@ -108,7 +114,9 @@ def wastewater_data_extended(self): @property def date_observed_disease_wastewater(self): if self.data_observed_disease_wastewater is not None: - return self.data_observed_disease_wastewater.get_column("date").unique() + return self.data_observed_disease_wastewater.get_column( + "date" + ).unique() @property def data_observed_disease_wastewater_conc(self): @@ -144,17 +152,23 @@ def ww_observed_times(self): @property def ww_observed_subpops(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended.get_column("subpop_index").to_numpy() + return self.wastewater_data_extended.get_column( + "subpop_index" + ).to_numpy() @property def ww_observed_lab_sites(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended.get_column("lab_site_index").to_numpy() + return self.wastewater_data_extended.get_column( + "lab_site_index" + ).to_numpy() @property def ww_log_lod(self): if self.data_observed_disease_wastewater is not None: - return self.wastewater_data_extended.get_column("log_lod").to_numpy() + return self.wastewater_data_extended.get_column( + "log_lod" + ).to_numpy() @property def n_ww_lab_sites(self): @@ -166,7 +180,9 @@ def lab_site_to_subpop_map(self): if self.data_observed_disease_wastewater is not None: return ( ( - self.wastewater_data_extended["lab_site_index", "subpop_index"] + self.wastewater_data_extended[ + "lab_site_index", "subpop_index" + ] .unique() .sort(by="lab_site_index") ) From 9da7a9dfbd87c460c7bf733e6d961db080245507 Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 28 Feb 2025 15:31:12 -0500 Subject: [PATCH 5/5] add missing args to to_forecast_data, code review suggestions --- pyrenew_hew/pyrenew_hew_data.py | 4 +++- pyrenew_hew/pyrenew_wastewater_data.py | 19 ++++++++++++------- tests/test_pyrenew_wastewater_data.py | 15 +++++++++++++++ 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index f8af8caf..91c53078 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -34,7 +34,7 @@ def __init__( ww_log_lod: ArrayLike = None, date_observed_disease_wastewater: ArrayLike = None, data_observed_disease_wastewater_conc: ArrayLike = None, - pop_fraction: ArrayLike = jnp.array([1]), + pop_fraction: ArrayLike = None, ) -> None: self.n_ed_visits_data_days_ = n_ed_visits_data_days self.n_hospital_admissions_data_days_ = n_hospital_admissions_data_days @@ -220,5 +220,7 @@ def to_forecast_data(self, n_forecast_points: int) -> Self: ww_observed_subpops=self.ww_observed_subpops, ww_observed_times=self.ww_observed_times, lab_site_to_subpop_map=self.lab_site_to_subpop_map, + ww_log_lod=self.ww_log_lod, + pop_fraction=self.pop_fraction, data_observed_disease_wastewater_conc=None, ) diff --git a/pyrenew_hew/pyrenew_wastewater_data.py b/pyrenew_hew/pyrenew_wastewater_data.py index b6d23efc..bc0c9471 100644 --- a/pyrenew_hew/pyrenew_wastewater_data.py +++ b/pyrenew_hew/pyrenew_wastewater_data.py @@ -192,11 +192,16 @@ def lab_site_to_subpop_map(self): def to_pyrenew_hew_data_args(self): return { - attr: value - for attr, value in ( - (attr, getattr(self, attr)) - for attr, prop in self.__class__.__dict__.items() - if isinstance(prop, property) - ) - if isinstance(value, ArrayLike) + attr: getattr(self, attr) + for attr in [ + "n_ww_lab_sites", + "ww_censored", + "ww_uncensored", + "ww_log_lod", + "ww_observed_lab_sites", + "ww_observed_subpops", + "ww_observed_times", + "data_observed_disease_wastewater_conc", + "lab_site_to_subpop_map", + ] } diff --git a/tests/test_pyrenew_wastewater_data.py b/tests/test_pyrenew_wastewater_data.py index 58056d8d..0d9f51d9 100644 --- a/tests/test_pyrenew_wastewater_data.py +++ b/tests/test_pyrenew_wastewater_data.py @@ -60,6 +60,21 @@ def test_pyrenew_wastewater_data(): assert forecast_data.data_observed_disease_wastewater_conc is None assert data.data_observed_disease_wastewater_conc is not None + assert np.array_equal(data.ww_censored, forecast_data.ww_censored) + assert np.array_equal(data.ww_uncensored, forecast_data.ww_uncensored) + assert np.array_equal(data.ww_log_lod, forecast_data.ww_log_lod) + assert np.array_equal( + data.ww_observed_lab_sites, forecast_data.ww_observed_lab_sites + ) + assert np.array_equal( + data.ww_observed_subpops, forecast_data.ww_observed_subpops + ) + assert np.array_equal( + data.ww_observed_times, forecast_data.ww_observed_times + ) + assert np.array_equal(data.n_ww_lab_sites, forecast_data.n_ww_lab_sites) + assert np.array_equal(data.pop_fraction, forecast_data.pop_fraction) + assert np.array_equal( data.data_observed_disease_wastewater_conc, ww_data["log_genome_copies_per_ml"],