diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 97b766f3..3eeb09cb 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -5,7 +5,6 @@ import jax.numpy as jnp import numpyro import numpyro.distributions as dist -import numpyro.distributions.transforms as transforms import pyrenew.transformation as transformation from jax.typing import ArrayLike from numpyro.infer.reparam import LocScaleReparam @@ -121,7 +120,7 @@ def sample(self, n_days_post_init: int): log_rtu_weekly_subpop = log_rtu_weekly[:, jnp.newaxis] else: i_first_obs_over_n_ref_subpop = transformation.SigmoidTransform()( - transforms.logit(i0_first_obs_n) + transformation.SigmoidTransform().inv(i0_first_obs_n) + self.offset_ref_logit_i_first_obs_rv(), ) # Using numpyro.distributions.transform as 'pyrenew.transformation' has no attribute 'logit' initial_exp_growth_rate_ref_subpop = ( @@ -137,7 +136,7 @@ def sample(self, n_days_post_init: int): DistributionalVariable( "i_first_obs_over_n_non_ref_subpop_raw", dist.Normal( - transforms.logit(i0_first_obs_n), + transformation.SigmoidTransform().inv(i0_first_obs_n), self.sigma_i_first_obs_rv(), ), reparam=LocScaleReparam(0),