
Fix and test TuringDirichlet constructors #152

Merged Feb 3, 2021 · 3 commits
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
 name = "DistributionsAD"
 uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
-version = "0.6.18"
+version = "0.6.19"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
72 changes: 37 additions & 35 deletions src/multivariate.jl
@@ -1,52 +1,54 @@
 ## Dirichlet ##

-struct TuringDirichlet{T, TV <: AbstractVector} <: ContinuousMultivariateDistribution
+struct TuringDirichlet{T<:Real,TV<:AbstractVector,S<:Real} <: ContinuousMultivariateDistribution
     alpha::TV
     alpha0::T
-    lmnB::T
-end
-Base.length(d::TuringDirichlet) = length(d.alpha)
-function check(alpha)
-    all(ai -> ai > 0, alpha) ||
-        throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
-end
-
-function Distributions._rand!(rng::Random.AbstractRNG,
-                              d::TuringDirichlet,
-                              x::AbstractVector{<:Real})
-    s = 0.0
-    n = length(x)
-    α = d.alpha
-    for i in 1:n
-        @inbounds s += (x[i] = rand(rng, Gamma(α[i])))
-    end
-    Distributions.multiply!(x, inv(s)) # this returns x
+    lmnB::S
 end

 function TuringDirichlet(alpha::AbstractVector)
-    check(alpha)
+    all(ai -> ai > 0, alpha) ||
+        throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
+
     alpha0 = sum(alpha)
     lmnB = sum(loggamma, alpha) - loggamma(alpha0)
-    T = promote_type(typeof(alpha0), typeof(lmnB))
-    TV = typeof(alpha)
-    TuringDirichlet{T, TV}(alpha, alpha0, lmnB)
-end
-
-function TuringDirichlet(d::Integer, alpha::Real)
-    alpha0 = alpha * d
-    _alpha = fill(alpha, d)
-    lmnB = loggamma(alpha) * d - loggamma(alpha0)
-    T = promote_type(typeof(alpha0), typeof(lmnB))
-    TV = typeof(_alpha)
-    TuringDirichlet{T, TV}(_alpha, alpha0, lmnB)
-end
-function TuringDirichlet(alpha::AbstractVector{T}) where {T <: Integer}
-    TuringDirichlet(float.(alpha))
+
+    return TuringDirichlet(alpha, alpha0, lmnB)
 end
-TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, Float64(alpha))
+
+TuringDirichlet(d::Integer, alpha::Real) = TuringDirichlet(Fill(alpha, d))
+
+# TODO: remove?
+TuringDirichlet(alpha::AbstractVector{<:Integer}) = TuringDirichlet(float.(alpha))
+TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, float(alpha))

 # TODO: remove and use `Dirichlet` only for `Tracker.TrackedVector`
 Distributions.Dirichlet(alpha::AbstractVector) = TuringDirichlet(alpha)

+TuringDirichlet(d::Dirichlet) = TuringDirichlet(d.alpha, d.alpha0, d.lmnB)
+
+Base.length(d::TuringDirichlet) = length(d.alpha)
+
+# copied from Distributions
+# TODO: remove and use `Dirichlet`?
+function Distributions._rand!(
+    rng::Random.AbstractRNG,
+    d::TuringDirichlet,
+    x::AbstractVector{<:Real},
+)
+    @inbounds for (i, αi) in zip(eachindex(x), d.alpha)
+        x[i] = rand(rng, Gamma(αi))
+    end
Comment on lines +38 to +40
@luiarthur · Feb 2, 2021

How about this:

Suggested change:

-    @inbounds for (i, αi) in zip(eachindex(x), d.alpha)
-        x[i] = rand(rng, Gamma(αi))
-    end
+    @. x = rand(rng, Gamma(d.alpha))

According to BenchmarkTools, this appears much more efficient in terms of speed and memory.

Member Author

Interesting. Generally, the broadcasting machinery is very complex and often creates a large overhead. This also leads to quite involved broadcasting implementations based on ForwardDiff in e.g. Tracker and Zygote. However, I did not compare it in this case. What exactly did you benchmark?

@luiarthur · Feb 3, 2021

Ah, I see. I did not test with AD. I know PyTorch implements a method to differentiate through random Gammas. But I didn't think that was implemented here, so I didn't test.

This is what I did:

using BenchmarkTools
using Distributions
import Random

rng = Random.MersenneTwister(0)

x = rand(10)
alpha = rand(10) * 3

v1 = @benchmark @inbounds for (i, αi) in zip(eachindex(x), alpha)
    x[i] = rand(rng, Gamma(αi))
end

v2 = @benchmark @. x = rand(rng, Gamma(alpha))

v2_against_v1 = judge(median(v2), median(v1))

Results

julia> v2_against_v1
BenchmarkTools.TrialJudgement:
  time:   -68.87% => improvement (5.00% tolerance)
  memory: -96.06% => improvement (1.00% tolerance)

julia> v1
BenchmarkTools.Trial:
  memory estimate:  3.17 KiB
  allocs estimate:  112
  --------------
  minimum time:     8.028 μs (0.00% GC)
  median time:      9.868 μs (0.00% GC)
  mean time:        10.664 μs (1.97% GC)
  maximum time:     1.125 ms (98.60% GC)
  --------------
  samples:          10000
  evals/sample:     3

julia> v2
BenchmarkTools.Trial:
  memory estimate:  128 bytes
  allocs estimate:  5
  --------------
  minimum time:     2.631 μs (0.00% GC)
  median time:      3.072 μs (0.00% GC)
  mean time:        3.202 μs (0.00% GC)
  maximum time:     6.997 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     9

Member Author

The benchmarks are not correct; neither approach should have any allocations. You can obtain better estimates by defining two functions and interpolating the global variables:

using BenchmarkTools
using Distributions
using Random

Random.seed!(1234)
x = rand(10)
alpha = 3 * rand(10)

function f!(rng, x, alpha)
    @inbounds for (i, αi) in zip(eachindex(x), alpha)
        x[i] = rand(rng, Gamma(αi))
    end
    return x
end

function g!(rng, x, alpha)
    x .= rand.((rng,), Gamma.(alpha))
    return x
end

Random.seed!(1)
v1 = @benchmark f!($(Random.GLOBAL_RNG), $x, $alpha)

Random.seed!(1)
v2 = @benchmark g!($(Random.GLOBAL_RNG), $x, $alpha)

v2_against_v1 = judge(median(v2), median(v1))
julia> v1 = @benchmark f!($(Random.GLOBAL_RNG), $x, $alpha)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     673.462 ns (0.00% GC)
  median time:      768.867 ns (0.00% GC)
  mean time:        778.740 ns (0.00% GC)
  maximum time:     2.939 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     158

julia> v2 = @benchmark g!($(Random.GLOBAL_RNG), $x, $alpha)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     668.188 ns (0.00% GC)
  median time:      763.924 ns (0.00% GC)
  mean time:        769.628 ns (0.00% GC)
  maximum time:     1.301 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     165

julia> v2_against_v1 = judge(median(v2), median(v1))
BenchmarkTools.TrialJudgement: 
  time:   -0.64% => invariant (5.00% tolerance)
  memory: +0.00% => invariant (1.00% tolerance)


Makes sense. Thanks for looking into this!


I don't feel strongly about either implementation in that case. To me, x .= rand.(rng, Gamma.(alpha)) looks a little cleaner, but the original might be clearer to most people.
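The broadcasted variant settled on above can be sketched outside the package as plain Julia (the variable names `rng`, `alpha`, and `x` are illustrative, not part of the PR): the key detail is wrapping `rng` in a one-element tuple so that broadcasting treats it as a scalar rather than trying to iterate it.

```julia
using Distributions
using Random

rng = Random.MersenneTwister(0)
alpha = [0.5, 1.0, 2.0]
x = similar(alpha)

# Broadcasted in-place sampling: `(rng,)` is a scalar under broadcasting,
# so each entry draws from its own Gamma(alpha[i]) using the same rng.
x .= rand.((rng,), Gamma.(alpha))

# Normalize onto the probability simplex, as `_rand!` does via `multiply!`.
x ./= sum(x)
```

This is the same shape as the `g!` benchmark function in the thread above.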

+    Distributions.multiply!(x, inv(sum(x))) # this returns x
+end
+
+function Distributions._rand!(
+    rng::AbstractRNG,
+    d::TuringDirichlet{<:Real,<:FillArrays.AbstractFill},
+    x::AbstractVector{<:Real}
+)
+    rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
+    Distributions.multiply!(x, inv(sum(x))) # this returns x
+end

 function Distributions._logpdf(d::TuringDirichlet, x::AbstractVector{<:Real})
     return simplex_logpdf(d.alpha, d.lmnB, x)
 end
6 changes: 3 additions & 3 deletions src/reversediff.jl
@@ -260,13 +260,13 @@
 Dirichlet(alpha::AbstractVector{<:TrackedReal}) = TuringDirichlet(alpha)
 Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha)

 function _logpdf(d::Dirichlet, x::AbstractVector{<:TrackedReal})
-    return _logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
+    return _logpdf(TuringDirichlet(d), x)
 end
 function logpdf(d::Dirichlet, x::AbstractMatrix{<:TrackedReal})
-    return logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
+    return logpdf(TuringDirichlet(d), x)
 end
 function loglikelihood(d::Dirichlet, x::AbstractMatrix{<:TrackedReal})
-    return loglikelihood(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
+    return loglikelihood(TuringDirichlet(d), x)
 end

 # default definition of `loglikelihood` yields gradients of zero?!
7 changes: 3 additions & 4 deletions src/tracker.jl
@@ -371,13 +371,13 @@
 Distributions.Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha)
 Distributions.Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha)

 function Distributions._logpdf(d::Dirichlet, x::TrackedVector{<:Real})
-    return Distributions._logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
+    return Distributions._logpdf(TuringDirichlet(d), x)
 end
 function Distributions.logpdf(d::Dirichlet, x::TrackedMatrix{<:Real})
-    return logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
+    return logpdf(TuringDirichlet(d), x)
 end
 function Distributions.loglikelihood(d::Dirichlet, x::TrackedMatrix{<:Real})
-    return loglikelihood(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
+    return loglikelihood(TuringDirichlet(d), x)
 end

 # Fix ambiguities

@@ -615,4 +615,3 @@ Distributions.InverseWishart(df::TrackedReal, S::AbstractMatrix{<:Real}) = TuringInverseWishart(df, S)
 Distributions.InverseWishart(df::Real, S::TrackedMatrix) = TuringInverseWishart(df, S)
 Distributions.InverseWishart(df::TrackedReal, S::TrackedMatrix) = TuringInverseWishart(df, S)
 Distributions.InverseWishart(df::TrackedReal, S::AbstractPDMat{<:TrackedReal}) = TuringInverseWishart(df, S)
-
38 changes: 38 additions & 0 deletions test/others.jl
@@ -298,4 +298,42 @@
             end
         end
     end
+
+    @testset "TuringDirichlet" begin
+        dim = 3
+        n = 4
+        for alpha in (2, rand())
+            d1 = TuringDirichlet(dim, alpha)
+            d2 = Dirichlet(dim, alpha)
+            d3 = TuringDirichlet(d2)
+            @test d1.alpha == d2.alpha == d3.alpha
+            @test d1.alpha0 == d2.alpha0 == d3.alpha0
+            @test d1.lmnB == d2.lmnB == d3.lmnB
+
+            s1 = rand(d1)
+            @test s1 isa Vector{Float64}
+            @test length(s1) == dim
+
+            s2 = rand(d1, n)
+            @test s2 isa Matrix{Float64}
+            @test size(s2) == (dim, n)
+        end
+
+        for alpha in (ones(Int, dim), rand(dim))
+            d1 = TuringDirichlet(alpha)
+            d2 = Dirichlet(alpha)
+            d3 = TuringDirichlet(d2)
+            @test d1.alpha == d2.alpha == d3.alpha
+            @test d1.alpha0 == d2.alpha0 == d3.alpha0
+            @test d1.lmnB == d2.lmnB == d3.lmnB
+
+            s1 = rand(d1)
+            @test s1 isa Vector{Float64}
+            @test length(s1) == dim
+
+            s2 = rand(d1, n)
+            @test s2 isa Matrix{Float64}
+            @test size(s2) == (dim, n)
+        end
+    end
 end
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -10,7 +10,7 @@ using Random, LinearAlgebra, Test

 using Distributions: meanlogdet
 using DistributionsAD: TuringUniform, TuringMvNormal, TuringMvLogNormal,
-                       TuringPoissonBinomial
+                       TuringPoissonBinomial, TuringDirichlet
 using StatsBase: entropy
 using StatsFuns: binomlogpdf, logsumexp, logistic