Skip to content

Commit

Permalink
generic truncated normal & exponential
Browse files Browse the repository at this point in the history
  • Loading branch information
cossio committed Mar 7, 2020
1 parent be19a9e commit 073c529
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 41 deletions.
78 changes: 38 additions & 40 deletions src/truncated/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,50 +159,48 @@ end
#
# - Available at http://arxiv.org/abs/0907.4010

function randnt(rng::AbstractRNG, lb::Float64, ub::Float64, tp::Float64)
local r::Float64
if tp > 0.3 # has considerable chance of falling in [lb, ub]
r = randn(rng)
while r < lb || r > ub
r = randn(rng)
function randnt(rng::AbstractRNG, lb::T, ub::T, tp::T) where {T<:AbstractFloat}
if 3tp > 1 # has considerable chance of falling in [lb, ub]
while true
r = randn(rng, T)
if lb r ub
return r
end
end
return r

else
span = ub - lb
if lb > 0 && span > 2.0 / (lb + sqrt(lb^2 + 4.0)) * exp((lb^2 - lb * sqrt(lb^2 + 4.0)) / 4.0)
a = (lb + sqrt(lb^2 + 4.0))/2.0
while true
r = rand(rng, Exponential(1.0 / a)) + lb
u = rand(rng)
if u < exp(-0.5 * (r - a)^2) && r < ub
return r
end
end
span = ub - lb
a = (sqrt(lb^2 + 4) + lb) / 2
if lb > 0 && span > exp(lb * (lb - sqrt(lb^2 + 4)) / 4) / a
while true
r = rand(rng, Exponential(1 / a)) + lb
u = rand(rng, T)
if u < exp(-(r - a)^2 / 2) && r < ub
return r
end
elseif ub < 0 && ub - lb > 2.0 / (-ub + sqrt(ub^2 + 4.0)) * exp((ub^2 + ub * sqrt(ub^2 + 4.0)) / 4.0)
a = (-ub + sqrt(ub^2 + 4.0)) / 2.0
while true
r = rand(rng, Exponential(1.0 / a)) - ub
u = rand(rng)
if u < exp(-0.5 * (r - a)^2) && r < -lb
return -r
end
end
end
b = (sqrt(ub^2 + 4) - ub) / 2
if ub < 0 && span > exp(ub * (ub + sqrt(ub^2 + 4)) / 4) / b
while true
r = rand(rng, Exponential(1 / b)) - ub
u = rand(rng, T)
if u < exp(-(r - b)^2 / 2) && r < -lb
return -r
end
end
end
while true
r = lb + rand(rng, T) * span
u = rand(rng, T)
if lb > 0
rho = exp((lb^2 - r^2) / 2)
elseif ub < 0
rho = exp((ub^2 - r^2) / 2)
else
while true
r = lb + rand(rng) * (ub - lb)
u = rand(rng)
if lb > 0
rho = exp((lb^2 - r^2) * 0.5)
elseif ub < 0
rho = exp((ub^2 - r^2) * 0.5)
else
rho = exp(-r^2 * 0.5)
end
if u < rho
return r
end
end
rho = exp(-r^2 / 2)
end
if u < rho
return r
end
end
end
2 changes: 1 addition & 1 deletion src/univariate/continuous/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ cf(d::Exponential, t::Real) = 1/(1 - t * im * scale(d))


#### Sampling
rand(rng::AbstractRNG, d::Exponential) = xval(d, randexp(rng))
rand(rng::AbstractRNG, d::Exponential{T}) where {T} = xval(d, randexp(rng, T))


#### Fit model
Expand Down

0 comments on commit 073c529

Please sign in to comment.