Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix type stability of sampling from Chisq, TDist, Gamma #1885

Merged
merged 12 commits into from
Aug 23, 2024
Merged
2 changes: 1 addition & 1 deletion src/samplers/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,6 @@ end

function rand(rng::AbstractRNG, s::GammaIPSampler)
x = rand(rng, s.s)
e = randexp(rng)
e = randexp(rng, typeof(x))
x*exp(s.nia*e)
end
2 changes: 1 addition & 1 deletion src/univariate/continuous/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ cf(d::Exponential, t::Real) = 1/(1 - t * im * scale(d))


#### Sampling
rand(rng::AbstractRNG, d::Exponential) = xval(d, randexp(rng))
rand(rng::AbstractRNG, d::Exponential{T}) where {T} = xval(d, randexp(rng, T))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a slightly different case and easily broken (e.g., when T is not a floating point number type). In the TDist and Gamma case we just try to avoid promotions of a sample from another rand call, whereas this case goes deeper into the question of how rand should behave wrt parameters etc. (see also #1433 (comment)).

Suggested change
rand(rng::AbstractRNG, d::Exponential{T}) where {T} = xval(d, randexp(rng, T))
rand(rng::AbstractRNG, d::Exponential) = xval(d, randexp(rng))

Copy link
Contributor Author

@Red-Portal Red-Portal Aug 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we do this though, the return type of rand(Gamma(Float32, Float32)) changes depending on the value of the shape parameter because shape == 1 samples from Exponential. (This is why the tests are currently failing.) Should we let this happen? I imagine some people will be super surprised by such behavior.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @devmotion could you comment on this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point. I didn't realize that GammaMTSampler already respects the parameter types (but samples are not necessarily of the parameter type:

d = shape(g) - 1//3
c = inv(3 * sqrt(d))
# Pre-compute scaling factor
κ = d * scale(g)
# We also pre-compute the factor in the squeeze function
return GammaMTSampler(promote(d, c, κ, 331//10_000)...)
).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's an argument for using the same approach as for Normal here, until we move to a better/different API:

Suggested change
rand(rng::AbstractRNG, d::Exponential{T}) where {T} = xval(d, randexp(rng, T))
rand(rng::AbstractRNG, d::Exponential{T}) where {T} = xval(d, randexp(rng, float(T)))



#### Fit model
Expand Down
2 changes: 1 addition & 1 deletion src/univariate/continuous/tdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ end
function rand(rng::AbstractRNG, d::TDist)
ν = d.ν
z = sqrt(rand(rng, Chisq{typeof(ν)}(ν)) / ν)
return randn(rng) / (isinf(ν) ? one(z) : z)
return randn(rng, typeof(ν)) / (isinf(ν) ? one(z) : z)
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved
end

function cf(d::TDist{T}, t::Real) where T <: Real
Expand Down
10 changes: 8 additions & 2 deletions test/univariate/continuous/chisq.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
test_cgf(Chisq(1), (0.49, -1, -100, -1f6))
test_cgf(Chisq(3), (0.49, -1, -100, -1f6))

@testset "Chisq" begin
test_cgf(Chisq(1), (0.49, -1, -100, -1.0f6))
test_cgf(Chisq(3), (0.49, -1, -100, -1.0f6))

@test rand(Chisq(1.0)) isa Float64
@test rand(Chisq(1.0f0)) isa Float32
end
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved
40 changes: 25 additions & 15 deletions test/univariate/continuous/gamma.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,35 @@
using Test, Distributions, OffsetArrays

test_cgf(Gamma(1 ,1 ), (0.9, -1, -100f0, -1e6))
test_cgf(Gamma(10 ,1 ), (0.9, -1, -100f0, -1e6))
test_cgf(Gamma(0.2, 10), (0.08, -1, -100f0, -1e6))
@testset "Gamma" begin
test_cgf(Gamma(1, 1), (0.9, -1, -100.0f0, -1e6))
test_cgf(Gamma(10, 1), (0.9, -1, -100.0f0, -1e6))
test_cgf(Gamma(0.2, 10), (0.08, -1, -100.0f0, -1e6))

@testset "Gamma suffstats and OffsetArrays" begin
a = rand(Gamma(), 11)
wa = 1.0:11.0
@testset "Gamma suffstats and OffsetArrays" begin
a = rand(Gamma(), 11)
wa = 1.0:11.0

resulta = @inferred(suffstats(Gamma, a))
resulta = @inferred(suffstats(Gamma, a))

resultwa = @inferred(suffstats(Gamma, a, wa))
resultwa = @inferred(suffstats(Gamma, a, wa))

b = OffsetArray(a, -5:5)
wb = OffsetArray(wa, -5:5)
b = OffsetArray(a, -5:5)
wb = OffsetArray(wa, -5:5)

resultb = @inferred(suffstats(Gamma, b))
@test resulta == resultb
resultb = @inferred(suffstats(Gamma, b))
@test resulta == resultb

resultwb = @inferred(suffstats(Gamma, b, wb))
@test resultwa == resultwb
resultwb = @inferred(suffstats(Gamma, b, wb))
@test resultwa == resultwb

@test_throws DimensionMismatch suffstats(Gamma, a, wb)
@test_throws DimensionMismatch suffstats(Gamma, a, wb)
end

@test rand(Gamma(1.0, 1.0)) isa Float64
@test rand(Gamma(0.5, 1.0)) isa Float64
@test rand(Gamma(2.0, 1.0)) isa Float64

@test rand(Gamma(1.0f0, 1.0f0)) isa Float32
@test rand(Gamma(0.5f0, 1.0f0)) isa Float32
@test rand(Gamma(2.0f0, 1.0f0)) isa Float32
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved
end
19 changes: 14 additions & 5 deletions test/univariate/continuous/tdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,19 @@ using ForwardDiff

using Test

@testset "Type stability of `rand` (#1614)" begin
if VERSION >= v"1.9.0-DEV.348"
# randn(::BigFloat) was only added in https://github.com/JuliaLang/julia/pull/44714
@inferred(rand(TDist(big"1.0")))
@testset "TDist" begin
@testset "Type stability of `rand` (#1614)" begin
if VERSION >= v"1.9.0-DEV.348"
# randn(::BigFloat) was only added in https://github.com/JuliaLang/julia/pull/44714
@inferred(rand(TDist(big"1.0")))
end
@inferred(rand(TDist(ForwardDiff.Dual(1.0))))

end
@inferred(rand(TDist(ForwardDiff.Dual(1.0))))

@test rand(TDist(1.0)) isa Float64
@test rand(TDist(1.0f0)) isa Float32

@test entropy(TDist(1.0)) isa Float64
@test entropy(TDist(1.0f0)) isa Float32
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved
end
Loading