Fix and test TuringDirichlet constructors #152
Conversation
Looks great to me. Just one minor comment.
@inbounds for (i, αi) in zip(eachindex(x), d.alpha)
    x[i] = rand(rng, Gamma(αi))
end
How about this:
-@inbounds for (i, αi) in zip(eachindex(x), d.alpha)
-    x[i] = rand(rng, Gamma(αi))
-end
+@. x = rand(rng, Gamma(d.alpha))
According to BenchmarkTools, this appears much more efficient in terms of both speed and memory.
Interesting. Generally, the broadcasting machinery is very complex and often introduces a large overhead. It also leads to quite involved implementations of broadcasting in, e.g., Tracker and Zygote, based on ForwardDiff. However, I did not compare the two in this case. What exactly did you benchmark?
Ah, I see. I did not test with AD. I know PyTorch implements a method to differentiate through random Gammas. But I didn't think that was implemented here, so I didn't test.
This is what I did:
using BenchmarkTools
using Distributions
import Random
rng = Random.MersenneTwister(0)
x = rand(10)
alpha = rand(10) * 3
v1 = @benchmark @inbounds for (i, αi) in zip(eachindex(x), alpha)
x[i] = rand(rng, Gamma(αi))
end
v2 = @benchmark @. x = rand(rng, Gamma(alpha))
v2_against_v1 = judge(median(v2), median(v1))
Results
julia> v2_against_v1
BenchmarkTools.TrialJudgement:
time: -68.87% => improvement (5.00% tolerance)
memory: -96.06% => improvement (1.00% tolerance)
julia> v1
BenchmarkTools.Trial:
memory estimate: 3.17 KiB
allocs estimate: 112
--------------
minimum time: 8.028 μs (0.00% GC)
median time: 9.868 μs (0.00% GC)
mean time: 10.664 μs (1.97% GC)
maximum time: 1.125 ms (98.60% GC)
--------------
samples: 10000
evals/sample: 3
julia> v2
BenchmarkTools.Trial:
memory estimate: 128 bytes
allocs estimate: 5
--------------
minimum time: 2.631 μs (0.00% GC)
median time: 3.072 μs (0.00% GC)
mean time: 3.202 μs (0.00% GC)
maximum time: 6.997 μs (0.00% GC)
--------------
samples: 10000
evals/sample: 9
The benchmarks are not correct; neither approach should allocate at all. You can obtain better estimates by defining two functions and interpolating the global variables:
using BenchmarkTools
using Distributions
using Random
Random.seed!(1234)
x = rand(10)
alpha = 3 * rand(10)
function f!(rng, x, alpha)
@inbounds for (i, αi) in zip(eachindex(x), alpha)
x[i] = rand(rng, Gamma(αi))
end
return x
end
function g!(rng, x, alpha)
x .= rand.((rng,), Gamma.(alpha))
return x
end
Random.seed!(1)
v1 = @benchmark f!($(Random.GLOBAL_RNG), $x, $alpha)
Random.seed!(1)
v2 = @benchmark g!($(Random.GLOBAL_RNG), $x, $alpha)
v2_against_v1 = judge(median(v2), median(v1))
julia> v1 = @benchmark f!($(Random.GLOBAL_RNG), $x, $alpha)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 673.462 ns (0.00% GC)
median time: 768.867 ns (0.00% GC)
mean time: 778.740 ns (0.00% GC)
maximum time: 2.939 μs (0.00% GC)
--------------
samples: 10000
evals/sample: 158
julia> v2 = @benchmark g!($(Random.GLOBAL_RNG), $x, $alpha)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 668.188 ns (0.00% GC)
median time: 763.924 ns (0.00% GC)
mean time: 769.628 ns (0.00% GC)
maximum time: 1.301 μs (0.00% GC)
--------------
samples: 10000
evals/sample: 165
julia> v2_against_v1 = judge(median(v2), median(v1))
BenchmarkTools.TrialJudgement:
time: -0.64% => invariant (5.00% tolerance)
memory: +0.00% => invariant (1.00% tolerance)
Makes sense. Thanks for looking into this!
I don't feel strongly about either implementation in that case. To me, x .= rand.(rng, Gamma.(alpha)) looks a little cleaner, but the original might be clearer to most people.
Looks good to me except the comment from @luiarthur!
Why are some of the tests listed as "Broken"?
Some distributions don't work for all AD backends. By marking them as broken (when possible; in some cases only a subset of the tests fail, or the tests error in a way we can't recover from), we can check whether upstream changes in, e.g., the AD backends fix the problems.
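As an illustration of the "broken" marking (not code from this PR; the test names are hypothetical), Julia's Test standard library provides @test_broken, which records a known failure as "Broken" instead of failing the suite, and errors if the expression unexpectedly starts passing:

```julia
using Test

@testset "hypothetical AD backend tests" begin
    # A test that currently passes.
    @test 1 + 1 == 2
    # A known failure: reported as "Broken" rather than "Fail".
    # If an upstream fix makes this expression true, the testset errors,
    # signalling that @test_broken can be upgraded to @test.
    @test_broken 1 + 1 == 3
end
```

This is what makes upstream fixes visible: a broken test that begins passing shows up immediately in CI.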
thanks, @devmotion and @luiarthur!
This PR fixes TuringLang/Turing.jl#1530 by relaxing the type constraints of TuringDirichlet, in line with recent updates in Distributions that allow integer-valued parameters. It should also be more efficient to use Fill instead of fill (again in line with changes in Distributions). I added some tests. Tests of Turing master pass locally with this PR.