diff --git a/src/truncated/normal.jl b/src/truncated/normal.jl index fa827d70a6..43b509fa55 100644 --- a/src/truncated/normal.jl +++ b/src/truncated/normal.jl @@ -159,50 +159,48 @@ end # # - Available at http://arxiv.org/abs/0907.4010 -function randnt(rng::AbstractRNG, lb::Float64, ub::Float64, tp::Float64) - local r::Float64 - if tp > 0.3 # has considerable chance of falling in [lb, ub] - r = randn(rng) - while r < lb || r > ub - r = randn(rng) +function randnt(rng::AbstractRNG, lb::T, ub::T, tp::T) where {T<:AbstractFloat} + if 3tp > 1 # has considerable chance of falling in [lb, ub] + while true + r = randn(rng, T) + if lb ≤ r ≤ ub + return r + end end - return r - - else - span = ub - lb - if lb > 0 && span > 2.0 / (lb + sqrt(lb^2 + 4.0)) * exp((lb^2 - lb * sqrt(lb^2 + 4.0)) / 4.0) - a = (lb + sqrt(lb^2 + 4.0))/2.0 - while true - r = rand(rng, Exponential(1.0 / a)) + lb - u = rand(rng) - if u < exp(-0.5 * (r - a)^2) && r < ub - return r - end + end + span = ub - lb + a = (sqrt(lb^2 + 4) + lb) / 2 + if lb > 0 && span > exp(lb * (lb - sqrt(lb^2 + 4)) / 4) / a + while true + r = rand(rng, Exponential(1 / a)) + lb + u = rand(rng, T) + if u < exp(-(r - a)^2 / 2) && r < ub + return r end - elseif ub < 0 && ub - lb > 2.0 / (-ub + sqrt(ub^2 + 4.0)) * exp((ub^2 + ub * sqrt(ub^2 + 4.0)) / 4.0) - a = (-ub + sqrt(ub^2 + 4.0)) / 2.0 - while true - r = rand(rng, Exponential(1.0 / a)) - ub - u = rand(rng) - if u < exp(-0.5 * (r - a)^2) && r < -lb - return -r - end + end + end + b = (sqrt(ub^2 + 4) - ub) / 2 + if ub < 0 && span > exp(ub * (ub + sqrt(ub^2 + 4)) / 4) / b + while true + r = rand(rng, Exponential(1 / b)) - ub + u = rand(rng, T) + if u < exp(-(r - b)^2 / 2) && r < -lb + return -r end + end + end + while true + r = lb + rand(rng, T) * span + u = rand(rng, T) + if lb > 0 + rho = exp((lb^2 - r^2) / 2) + elseif ub < 0 + rho = exp((ub^2 - r^2) / 2) else - while true - r = lb + rand(rng) * (ub - lb) - u = rand(rng) - if lb > 0 - rho = exp((lb^2 - r^2) * 0.5) - elseif ub < 0 - rho = exp((ub^2 - r^2) * 0.5) - else - rho = exp(-r^2 * 0.5) - end - if u < rho - return r - end - end + rho = exp(-r^2 / 2) + end + if u < rho + return r end end end diff --git a/src/univariate/continuous/exponential.jl b/src/univariate/continuous/exponential.jl index 92ee56cb3d..0deaa0385d 100644 --- a/src/univariate/continuous/exponential.jl +++ b/src/univariate/continuous/exponential.jl @@ -88,7 +88,7 @@ cf(d::Exponential, t::Real) = 1/(1 - t * im * scale(d)) #### Sampling -rand(rng::AbstractRNG, d::Exponential) = xval(d, randexp(rng)) +rand(rng::AbstractRNG, d::Exponential{T}) where {T} = xval(d, randexp(rng, T)) #### Fit model