From 9da7a9dfbd87c460c7bf733e6d961db080245507 Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 28 Feb 2025 15:31:12 -0500 Subject: [PATCH] add missing args to to_forecast_data, code review suggestions --- pyrenew_hew/pyrenew_hew_data.py | 4 +++- pyrenew_hew/pyrenew_wastewater_data.py | 19 ++++++++++++------- tests/test_pyrenew_wastewater_data.py | 15 +++++++++++++++ 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index f8af8caf..91c53078 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -34,7 +34,7 @@ def __init__( ww_log_lod: ArrayLike = None, date_observed_disease_wastewater: ArrayLike = None, data_observed_disease_wastewater_conc: ArrayLike = None, - pop_fraction: ArrayLike = jnp.array([1]), + pop_fraction: ArrayLike = None, ) -> None: self.n_ed_visits_data_days_ = n_ed_visits_data_days self.n_hospital_admissions_data_days_ = n_hospital_admissions_data_days @@ -220,5 +220,7 @@ def to_forecast_data(self, n_forecast_points: int) -> Self: ww_observed_subpops=self.ww_observed_subpops, ww_observed_times=self.ww_observed_times, lab_site_to_subpop_map=self.lab_site_to_subpop_map, + ww_log_lod=self.ww_log_lod, + pop_fraction=self.pop_fraction, data_observed_disease_wastewater_conc=None, ) diff --git a/pyrenew_hew/pyrenew_wastewater_data.py b/pyrenew_hew/pyrenew_wastewater_data.py index b6d23efc..bc0c9471 100644 --- a/pyrenew_hew/pyrenew_wastewater_data.py +++ b/pyrenew_hew/pyrenew_wastewater_data.py @@ -192,11 +192,16 @@ def lab_site_to_subpop_map(self): def to_pyrenew_hew_data_args(self): return { - attr: value - for attr, value in ( - (attr, getattr(self, attr)) - for attr, prop in self.__class__.__dict__.items() - if isinstance(prop, property) - ) - if isinstance(value, ArrayLike) + attr: getattr(self, attr) + for attr in [ + "n_ww_lab_sites", + "ww_censored", + "ww_uncensored", + "ww_log_lod", + "ww_observed_lab_sites", + "ww_observed_subpops", + "ww_observed_times", + "data_observed_disease_wastewater_conc", + "lab_site_to_subpop_map", + ] } diff --git a/tests/test_pyrenew_wastewater_data.py b/tests/test_pyrenew_wastewater_data.py index 58056d8d..0d9f51d9 100644 --- a/tests/test_pyrenew_wastewater_data.py +++ b/tests/test_pyrenew_wastewater_data.py @@ -60,6 +60,21 @@ def test_pyrenew_wastewater_data(): assert forecast_data.data_observed_disease_wastewater_conc is None assert data.data_observed_disease_wastewater_conc is not None + assert np.array_equal(data.ww_censored, forecast_data.ww_censored) + assert np.array_equal(data.ww_uncensored, forecast_data.ww_uncensored) + assert np.array_equal(data.ww_log_lod, forecast_data.ww_log_lod) + assert np.array_equal( + data.ww_observed_lab_sites, forecast_data.ww_observed_lab_sites + ) + assert np.array_equal( + data.ww_observed_subpops, forecast_data.ww_observed_subpops + ) + assert np.array_equal( + data.ww_observed_times, forecast_data.ww_observed_times + ) + assert np.array_equal(data.n_ww_lab_sites, forecast_data.n_ww_lab_sites) + assert np.array_equal(data.pop_fraction, forecast_data.pop_fraction) + assert np.array_equal( data.data_observed_disease_wastewater_conc, ww_data["log_genome_copies_per_ml"],