
Fix and test TuringDirichlet constructors #152

Merged: 3 commits into master from dw/fix_turingdirichlet, Feb 3, 2021
Conversation

devmotion (Member)

This PR fixes TuringLang/Turing.jl#1530 by relaxing the type constraints of TuringDirichlet, in line with recent updates in Distributions that allow integer-valued parameters.

It should also be more efficient to use Fill instead of fill (again in line with changes in Distributions).

I added some tests. Tests of Turing master pass locally with this PR.
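
For context, a rough sketch of the two changes (with hypothetical names, not the actual diff), assuming FillArrays' lazy Fill and SpecialFunctions' loggamma:

using FillArrays: Fill
using SpecialFunctions: loggamma

# Hypothetical struct and field names; the real TuringDirichlet may differ.
struct DirichletSketch{A<:AbstractVector{<:Real},T<:Real}
    alpha::A   # concentration parameters, any real element type
    alpha0::T  # sum of the concentrations
    lmnB::T    # log multivariate beta function (the normalizing constant)
end

# Relaxed constraint: accepts e.g. Vector{Int}, not only floating-point vectors.
function DirichletSketch(alpha::AbstractVector{<:Real})
    alpha0 = sum(alpha)
    lmnB = sum(loggamma, alpha) - loggamma(alpha0)
    return DirichletSketch(alpha, promote(alpha0, lmnB)...)
end

# Symmetric case: the lazy Fill(alpha, n) takes O(1) memory, whereas
# fill(alpha, n) materializes an n-element vector.
DirichletSketch(n::Integer, alpha::Real) = DirichletSketch(Fill(alpha, n))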

devmotion requested a review from yebai on January 30, 2021, 17:05.

luiarthur left a comment:


Looks great to me. Just one minor comment.

Comment on lines +38 to +40
@inbounds for (i, αi) in zip(eachindex(x), d.alpha)
x[i] = rand(rng, Gamma(αi))
end

luiarthur commented on Feb 2, 2021:


How about this:

Suggested change
@inbounds for (i, αi) in zip(eachindex(x), d.alpha)
x[i] = rand(rng, Gamma(αi))
end
@. x = rand(rng, Gamma(d.alpha))

According to BenchmarkTools, this appears much more efficient in terms of speed and memory.
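
For reference, the @. macro dots every function call and the assignment in the expression, so the suggested line is equivalent to this explicitly dotted form (the RNG takes part in the broadcast as a scalar):

x .= rand.(rng, Gamma.(d.alpha))  # each x[i] is drawn from Gamma(d.alpha[i])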

devmotion (Member, Author) replied:


Interesting. Generally, the broadcasting machinery is very complex and often creates a large overhead; it is also the reason for the quite involved ForwardDiff-based broadcasting implementations in e.g. Tracker and Zygote. However, I did not compare the two approaches in this case. What exactly did you benchmark?

luiarthur replied on Feb 3, 2021:


Ah, I see. I did not test with AD. I know PyTorch implements a method for differentiating through random gamma samples, but I didn't think that was implemented here, so I didn't test it.

This is what I did:

using BenchmarkTools
using Distributions
import Random

rng = Random.MersenneTwister(0)

x = rand(10)
alpha = rand(10) * 3

v1 = @benchmark @inbounds for (i, αi) in zip(eachindex(x), alpha)
    x[i] = rand(rng, Gamma(αi))
end

v2 = @benchmark @. x = rand(rng, Gamma(alpha))

v2_against_v1 = judge(median(v2), median(v1))

Results

julia> v2_against_v1
BenchmarkTools.TrialJudgement:
  time:   -68.87% => improvement (5.00% tolerance)
  memory: -96.06% => improvement (1.00% tolerance)

julia> v1
BenchmarkTools.Trial:
  memory estimate:  3.17 KiB
  allocs estimate:  112
  --------------
  minimum time:     8.028 μs (0.00% GC)
  median time:      9.868 μs (0.00% GC)
  mean time:        10.664 μs (1.97% GC)
  maximum time:     1.125 ms (98.60% GC)
  --------------
  samples:          10000
  evals/sample:     3

julia> v2
BenchmarkTools.Trial:
  memory estimate:  128 bytes
  allocs estimate:  5
  --------------
  minimum time:     2.631 μs (0.00% GC)
  median time:      3.072 μs (0.00% GC)
  mean time:        3.202 μs (0.00% GC)
  maximum time:     6.997 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     9

devmotion (Member, Author) replied:


The benchmarks are not correct; neither approach should allocate at all. You can obtain better estimates by defining two functions and interpolating the global variables:

using BenchmarkTools
using Distributions
using Random

Random.seed!(1234)
x = rand(10)
alpha = 3 * rand(10)

function f!(rng, x, alpha)
    @inbounds for (i, αi) in zip(eachindex(x), alpha)
        x[i] = rand(rng, Gamma(αi))
    end
    return x
end

function g!(rng, x, alpha)
    x .= rand.((rng,), Gamma.(alpha))
    return x
end

Random.seed!(1)
v1 = @benchmark f!($(Random.GLOBAL_RNG), $x, $alpha)

Random.seed!(1)
v2 = @benchmark g!($(Random.GLOBAL_RNG), $x, $alpha)

v2_against_v1 = judge(median(v2), median(v1))

Results:

julia> v1 = @benchmark f!($(Random.GLOBAL_RNG), $x, $alpha)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     673.462 ns (0.00% GC)
  median time:      768.867 ns (0.00% GC)
  mean time:        778.740 ns (0.00% GC)
  maximum time:     2.939 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     158

julia> v2 = @benchmark g!($(Random.GLOBAL_RNG), $x, $alpha)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     668.188 ns (0.00% GC)
  median time:      763.924 ns (0.00% GC)
  mean time:        769.628 ns (0.00% GC)
  maximum time:     1.301 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     165

julia> v2_against_v1 = judge(median(v2), median(v1))
BenchmarkTools.TrialJudgement: 
  time:   -0.64% => invariant (5.00% tolerance)
  memory: +0.00% => invariant (1.00% tolerance)

luiarthur replied:

Makes sense. Thanks for looking into this!

luiarthur added:

I don't feel strongly about either implementation in that case. To me, x .= rand.(rng, Gamma.(alpha)) looks a little cleaner, but the original might be clearer to most people.

luiarthur self-requested a review on February 2, 2021, 20:22.
yebai (Member) left a comment:


Looks good to me except for the comment from @luiarthur!

@luiarthur commented:

Why are some of the tests listed as "Broken"?

Test Summary: |  Pass  Broken  Total
distributions | 10907    1102  12009
    Testing DistributionsAD tests passed 

@devmotion (Member, Author) replied:

Some distributions don't work with all AD backends. By marking these tests as broken (where possible; in some cases only a subset of the tests fails, or the tests error in a way we cannot recover from), we can detect when upstream changes, e.g. in the AD backends, fix the problems.
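
For reference, a minimal sketch of the mechanism using Julia's Test stdlib (not the actual DistributionsAD test suite):

using Test

@testset "broken-test mechanism" begin
    @test 2 + 2 == 4        # counted under Pass in the summary
    # A known failure is recorded under Broken instead of Fail, so the
    # suite still succeeds. If the expression ever starts passing,
    # @test_broken reports an error, signalling that the upstream fix
    # has landed and the marker can be removed.
    @test_broken 2 + 2 == 5
end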

yebai merged commit 9806ec3 into master on Feb 3, 2021.
The delete-merged-branch bot deleted the dw/fix_turingdirichlet branch on February 3, 2021, 21:46.
yebai (Member) commented on Feb 3, 2021:

Thanks, @devmotion and @luiarthur!

Successfully merging this pull request may close these issues: Tests on master fail (TuringLang/Turing.jl#1530).