Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Dylan H. Morris <dylanhmorris@users.noreply.github.com>
  • Loading branch information
sbidari and dylanhmorris authored Feb 27, 2025
1 parent f1fccac commit 14ba0bb
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions pyrenew_hew/pyrenew_hew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import numpyro.distributions.transforms as transforms
import pyrenew.transformation as transformation
from jax.typing import ArrayLike
from numpyro.infer.reparam import LocScaleReparam
Expand Down Expand Up @@ -121,7 +120,7 @@ def sample(self, n_days_post_init: int):
log_rtu_weekly_subpop = log_rtu_weekly[:, jnp.newaxis]
else:
i_first_obs_over_n_ref_subpop = transformation.SigmoidTransform()(
transforms.logit(i0_first_obs_n)
transformation.SigmoidTransform().inv(i0_first_obs_n)
+ self.offset_ref_logit_i_first_obs_rv(),
) # Using numpyro.distributions.transform as 'pyrenew.transformation' has no attribute 'logit'
initial_exp_growth_rate_ref_subpop = (
Expand All @@ -137,7 +136,7 @@ def sample(self, n_days_post_init: int):
DistributionalVariable(
"i_first_obs_over_n_non_ref_subpop_raw",
dist.Normal(
transforms.logit(i0_first_obs_n),
transformation.SigmoidTransform().inv(i0_first_obs_n),
self.sigma_i_first_obs_rv(),
),
reparam=LocScaleReparam(0),
Expand Down

0 comments on commit 14ba0bb

Please sign in to comment.