diff --git a/Project.toml b/Project.toml index 71028ee9..3f1935a8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.22" +version = "0.6.23" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -10,7 +10,6 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" @@ -30,7 +29,6 @@ Compat = "3.6" DiffRules = "0.1, 1.0" Distributions = "0.23.3, 0.24" FillArrays = "0.8, 0.9, 0.10, 0.11" -ForwardDiff = "0.10.6" NaNMath = "0.3" PDMats = "0.9, 0.10, 0.11" Requires = "1" diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index c9306485..acfb8dbd 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -28,7 +28,6 @@ import StatsFuns: logsumexp, nbetalogpdf import Distributions: MvNormal, MvLogNormal, - poissonbinomial_pdf_fft, logpdf, quantile, PoissonBinomial, @@ -65,9 +64,6 @@ include("zygote.jl") @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin using .ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here include("forwarddiff.jl") - - # loads adjoint for `poissonbinomial_pdf` and `poissonbinomial_pdf_fft` - include("zygote_forwarddiff.jl") end @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin diff --git a/src/chainrules.jl b/src/chainrules.jl index ae55d6fb..de84e0db 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -11,15 +11,17 @@ (c, -c, z), ) +# StatsFuns: https://github.com/JuliaStats/StatsFuns.jl/pull/106 + ## Beta ## @scalar_rule( betalogpdf(α::Real, β::Real, x::Number), - @setup(di = digamma(α + β)), + @setup(z = digamma(α + β)), ( - @thunk(log(x) - digamma(α) + di), - @thunk(log(1 - x) - digamma(β) + di), - @thunk((α - 1)/x + (1 - β)/(1 - x)), + log(x) + z - digamma(α), + log1p(-x) + z - digamma(β), + (α - 1) / x + (1 - β) / (1 - x), ), ) @@ -27,10 +29,15 @@ @scalar_rule( gammalogpdf(k::Real, θ::Real, x::Number), + @setup( + invθ = inv(θ), + xoθ = invθ * x, + z = xoθ - k, + ), ( - @thunk(-digamma(k) - log(θ) + log(x)), - @thunk(-k/θ + x/θ^2), - @thunk((k - 1)/x - 1/θ), + log(xoθ) - digamma(k), + invθ * z, + - (1 + z) / x, ), ) @@ -38,48 +45,97 @@ @scalar_rule( chisqlogpdf(k::Real, x::Number), - @setup(ko2 = k / 2), - (@thunk((-logtwo - digamma(ko2) + log(x)) / 2), @thunk((ko2 - 1)/x - one(ko2) / 2)), + @setup(hk = k / 2), + ( + (log(x) - logtwo - digamma(hk)) / 2, + (hk - 1) / x - one(hk) / 2, + ), ) ## FDist ## @scalar_rule( - fdistlogpdf(v1::Real, v2::Real, x::Number), + fdistlogpdf(ν1::Real, ν2::Real, x::Number), @setup( - temp1 = v1 * x + v2, - temp2 = log(temp1), - vsum = v1 + v2, - temp3 = vsum / temp1, - temp4 = digamma(vsum / 2), + xν1 = x * ν1, + temp1 = xν1 + ν2, + a = (x - 1) / temp1, + νsum = ν1 + ν2, + di = digamma(νsum / 2), ), ( - @thunk((log(v1 * x) + 1 - temp2 - x * temp3 - digamma(v1 / 2) + temp4) / 2), - @thunk((log(v2) + 1 - temp2 - temp3 - digamma(v2 / 2) + temp4) / 2), - @thunk(v1 / 2 * (1 / x - temp3) - 1 / x), + (-log1p(ν2 / xν1) - ν2 * a + di - digamma(ν1 / 2)) / 2, + (-log1p(xν1 / ν2) + ν1 * a + di - digamma(ν2 / 2)) / 2, + ((ν1 - 2) / x - ν1 * νsum / temp1) / 2, ), ) ## TDist ## @scalar_rule( - tdistlogpdf(v::Real, x::Number), + tdistlogpdf(ν::Real, x::Number), + @setup( + νp1 = ν + 1, + xsq = x^2, + invν = inv(ν), + a = xsq * invν, + b = νp1 / (ν + xsq), + ), ( - @thunk((digamma((v + 1) / 2) - 1 / v - digamma(v / 2) - log(1 + x^2 / v) + x^2 * (v + 1) / v^2 / (1 + x^2 / v)) / 2), - @thunk(-x * (v + 1) / (v + x^2)), - ) + (digamma(νp1 / 2) - digamma(ν / 2) + a * b - log1p(a) - invν) / 2, + - x * b, + ), ) ## Binomial ## @scalar_rule( - binomlogpdf(n::Int, p::Real, x::Int), - (DoesNotExist(), x / p - (n - x) / (1 - p), DoesNotExist()), + binomlogpdf(n::Real, p::Real, k::Real), + @setup(z = digamma(n - k + 1)), + ( + digamma(n + 2) - z + log1p(-p) - 1 / (1 + n), + (k / p - n) / (1 - p), + z - digamma(k + 1) + logit(p), + ), ) ## Poisson ## @scalar_rule( - poislogpdf(v::Real, x::Int), - (x / v - 1, DoesNotExist()), + poislogpdf(λ::Number, x::Number), + ((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, log(λ) - digamma(x + 1)), +) + +## PoissonBinomial + +function ChainRulesCore.rrule( + ::typeof(Distributions.poissonbinomial_pdf_fft), p::AbstractVector{<:Real} ) + y = Distributions.poissonbinomial_pdf_fft(p) + A = poissonbinomial_partialderivatives(p) + function poissonbinomial_pdf_fft_pullback(Δy) + p̄ = InplaceableThunk( + @thunk(A * Δy), + Δ -> LinearAlgebra.mul!(Δ, A, Δy, true, true), + ) + return (NO_FIELDS, p̄) + end + return y, poissonbinomial_pdf_fft_pullback +end + +if isdefined(Distributions, :poissonbinomial_pdf) + function ChainRulesCore.rrule( + ::typeof(Distributions.poissonbinomial_pdf), p::AbstractVector{<:Real} + ) + y = Distributions.poissonbinomial_pdf(p) + A = poissonbinomial_partialderivatives(p) + function poissonbinomial_pdf_pullback(Δy) + p̄ = InplaceableThunk( + @thunk(A * Δy), + Δ -> LinearAlgebra.mul!(Δ, A, Δy, true, true), + ) + return (NO_FIELDS, p̄) + end + return y, poissonbinomial_pdf_pullback + end +end diff --git a/src/common.jl b/src/common.jl index ba469308..de790393 100644 --- a/src/common.jl +++ b/src/common.jl @@ -46,3 +46,45 @@ parameterless_type(x) = parameterless_type(typeof(x)) parameterless_type(x::Type) = __parameterless_type(x) @non_differentiable adapt_randn(::Any...) + +# PoissonBinomial + +# compute matrix of partial derivatives [∂P(X=j-1)/∂pᵢ]_{i=1,…,n; j=1,…,n+1} +# +# This uses the same dynamic programming "trick" as for the computation of the primals +# in Distributions +# +# Reference (for the primal): +# +# Marlin A. Thomas & Audrey E. Taub (1982) +# Calculating binomial probabilities when the trial probabilities are unequal, +# Journal of Statistical Computation and Simulation, 14:2, 125-131, DOI: 10.1080/00949658208810534 +function poissonbinomial_partialderivatives(p) + n = length(p) + A = zeros(eltype(p), n, n + 1) + @inbounds for j in 1:n + A[j, end] = 1 + end + @inbounds for (i, pi) in enumerate(p) + qi = 1 - pi + for k in (n - i + 1):n + kp1 = k + 1 + for j in 1:(i - 1) + A[j, k] = pi * A[j, k] + qi * A[j, kp1] + end + for j in (i+1):n + A[j, k] = pi * A[j, k] + qi * A[j, kp1] + end + end + for j in 1:(i-1) + A[j, end] *= pi + end + for j in (i+1):n + A[j, end] *= pi + end + end + @inbounds for j in 1:n, i in 1:n + A[i, j] -= A[i, j+1] + end + return A +end diff --git a/src/tracker.jl b/src/tracker.jl index f4a79067..1473e809 100644 --- a/src/tracker.jl +++ b/src/tracker.jl @@ -261,26 +261,22 @@ end PoissonBinomial(p::TrackedArray{<:Real}; check_args=true) = TuringPoissonBinomial(p; check_args = check_args) -# TODO: add adjoints without ForwardDiff poissonbinomial_pdf_fft(x::TrackedArray) = track(poissonbinomial_pdf_fft, x) @grad function poissonbinomial_pdf_fft(x::TrackedArray) x_data = data(x) - T = eltype(x_data) - fft = poissonbinomial_pdf_fft(x_data) - return fft, Δ -> begin - ((ForwardDiff.jacobian(poissonbinomial_pdf_fft, x_data)::Matrix{T})' * Δ,) - end + value = poissonbinomial_pdf_fft(x_data) + A = poissonbinomial_partialderivatives(x_data) + poissonbinomial_pdf_fft_pullback(Δ) = (A * Δ,) + return value, poissonbinomial_pdf_fft_pullback end if isdefined(Distributions, :poissonbinomial_pdf) Distributions.poissonbinomial_pdf(x::TrackedArray) = track(Distributions.poissonbinomial_pdf, x) @grad function Distributions.poissonbinomial_pdf(x::TrackedArray) x_data = data(x) - T = eltype(x_data) value = Distributions.poissonbinomial_pdf(x_data) - function poissonbinomial_pdf_pullback(Δ) - return ((ForwardDiff.jacobian(Distributions.poissonbinomial_pdf, x_data)::Matrix{T})' * Δ,) - end + A = poissonbinomial_partialderivatives(x_data) + poissonbinomial_pdf_pullback(Δ) = (A * Δ,) return value, poissonbinomial_pdf_pullback end end diff --git a/src/zygote.jl b/src/zygote.jl index 7460c039..4feeff16 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -12,14 +12,6 @@ ZygoteRules.@adjoint function Distributions.Uniform(args...) return ZygoteRules.pullback(TuringUniform, args...) end -## PoissonBinomial ## - -# Zygote loads ForwardDiff, so this dummy adjoint should never be needed. -# The adjoint that is used for `poissonbinomial_pdf_fft` is defined in `src/zygote_forwarddiff.jl` -# ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{T}) where T<:Real -# error("This needs ForwardDiff. `using ForwardDiff` should fix this error.") -# end - ## Product # Tests with `Kolmogorov` seem to fail otherwise?! diff --git a/src/zygote_forwarddiff.jl b/src/zygote_forwarddiff.jl deleted file mode 100644 index 7e157379..00000000 --- a/src/zygote_forwarddiff.jl +++ /dev/null @@ -1,20 +0,0 @@ -# Zygote loads ForwardDiff, so this adjoint will autmatically be loaded together -# with `using Zygote`. - -# TODO: add adjoints without ForwardDiff -@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{T}) where T<:Real - fft = poissonbinomial_pdf_fft(x) - return fft, Δ -> begin - ((ForwardDiff.jacobian(poissonbinomial_pdf_fft, x)::Matrix{T})' * Δ,) - end -end - -if isdefined(Distributions, :poissonbinomial_pdf) - @adjoint function Distributions.poissonbinomial_pdf(x::AbstractArray{T}) where T<:Real - value = Distributions.poissonbinomial_pdf(x) - function poissonbinomial_pdf_pullback(Δ) - return ((ForwardDiff.jacobian(Distributions.poissonbinomial_pdf, x)::Matrix{T})' * Δ,) - end - return value, poissonbinomial_pdf_pullback - end -end diff --git a/test/Project.toml b/test/Project.toml index 82dba790..c9446ba7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -16,7 +17,8 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ChainRulesTestUtils = "0.5.3, 0.6" +ChainRulesCore = "0.9" +ChainRulesTestUtils = "0.6.3" Combinatorics = "1.0.2" Distributions = "0.24.3" FiniteDifferences = "0.11.3, 0.12" diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index 42a6e229..a94d86df 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -1,51 +1,63 @@ - @testset "chainrules" begin - x, Δx, x̄ = randn(3) - y, Δy, ȳ = randn(3) - z, Δz, z̄ = randn(3) - Δu = randn() - - ỹ = x + exp(y) + exp(z) - z̃ = x + exp(y) - frule_test(DistributionsAD.uniformlogpdf, (x, Δx), (ỹ, Δy), (z̃, Δz)) - rrule_test(DistributionsAD.uniformlogpdf, Δu, (x, x̄), (ỹ, ȳ), (z̃, z̄)) - - x̃ = exp(x) - ỹ = exp(y) - z̃ = logistic(z) - frule_test(DistributionsAD.betalogpdf, (x̃, Δx), (ỹ, Δy), (z̃, Δz)) - rrule_test(DistributionsAD.betalogpdf, Δu, (x̃, x̄), (ỹ, ȳ), (z̃, z̄)) - - x̃ = exp(x) - ỹ = exp(y) - z̃ = exp(z) - frule_test(DistributionsAD.gammalogpdf, (x̃, Δx), (ỹ, Δy), (z̃, Δz)) - rrule_test(DistributionsAD.gammalogpdf, Δu, (x̃, x̄), (ỹ, ȳ), (z̃, z̄)) - - x̃ = exp(x) - ỹ = exp(y) - z̃ = exp(z) - frule_test(DistributionsAD.chisqlogpdf, (x̃, Δx), (ỹ, Δy)) - rrule_test(DistributionsAD.chisqlogpdf, Δu, (x̃, x̄), (ỹ, ȳ)) - - x̃ = exp(x) - ỹ = exp(y) - z̃ = exp(z) - frule_test(DistributionsAD.fdistlogpdf, (x̃, Δx), (ỹ, Δy), (z̃, Δz)) - rrule_test(DistributionsAD.fdistlogpdf, Δu, (x̃, x̄), (ỹ, ȳ), (z̃, z̄)) - - x̃ = exp(x) - frule_test(DistributionsAD.tdistlogpdf, (x̃, Δx), (y, Δy)) - rrule_test(DistributionsAD.tdistlogpdf, Δu, (x̃, x̄), (y, ȳ)) - - x̃ = rand(1:100) - ỹ = logistic(y) - z̃ = rand(1:x̃) - frule_test(DistributionsAD.binomlogpdf, (x̃, nothing), (ỹ, Δy), (z̃, nothing)) - rrule_test(DistributionsAD.binomlogpdf, Δu, (x̃, nothing), (ỹ, ȳ), (z̃, nothing)) - - x̃ = exp(x) - ỹ = rand(1:100) - frule_test(DistributionsAD.poislogpdf, (x̃, Δx), (ỹ, nothing)) - rrule_test(DistributionsAD.poislogpdf, Δu, (x̃, x̄), (ỹ, nothing)) + x = randn() + z = x + exp(randn()) + y = z + exp(randn()) + test_frule(DistributionsAD.uniformlogpdf, x, y, z) + test_rrule(DistributionsAD.uniformlogpdf, x, y, z) + + # StatsFuns: https://github.com/JuliaStats/StatsFuns.jl/pull/106 + x = exp(randn()) + y = exp(randn()) + z = logistic(randn()) + test_frule(StatsFuns.betalogpdf, x, y, z) + test_rrule(StatsFuns.betalogpdf, x, y, z) + + x = exp(randn()) + y = exp(randn()) + z = exp(randn()) + test_frule(StatsFuns.gammalogpdf, x, y, z) + test_rrule(StatsFuns.gammalogpdf, x, y, z) + + x = exp(randn()) + y = exp(randn()) + test_frule(StatsFuns.chisqlogpdf, x, y) + test_rrule(StatsFuns.chisqlogpdf, x, y) + + x = exp(randn()) + y = exp(randn()) + z = exp(randn()) + test_frule(StatsFuns.fdistlogpdf, x, y, z) + test_rrule(StatsFuns.fdistlogpdf, x, y, z) + + x = exp(randn()) + y = randn() + test_frule(StatsFuns.tdistlogpdf, x, y) + test_rrule(StatsFuns.tdistlogpdf, x, y) + + # use `BigFloat` to avoid Rmath implementation in finite differencing check + # (returns `NaN` for non-integer values) + n = rand(1:100) + x = BigFloat(n) + y = big(logistic(randn())) + z = BigFloat(rand(1:n)) + test_frule(StatsFuns.binomlogpdf, x, y, z) + test_rrule(StatsFuns.binomlogpdf, x, y, z) + + x = big(exp(randn())) + y = BigFloat(rand(1:100)) + test_frule(StatsFuns.poislogpdf, x, y) + test_rrule(StatsFuns.poislogpdf, x, y) + + _, pb = rrule(StatsFuns.poislogpdf, 0.0, 0.0) + _, x̄1, _ = pb(1) + @test x̄1 == -1 + _, pb = rrule(StatsFuns.poislogpdf, 0.0, 1.0) + _, x̄1, _ = pb(1) + @test x̄1 == Inf + + # PoissonBinomial + test_rrule(Distributions.poissonbinomial_pdf_fft, rand(50)) + if isdefined(Distributions, :poissonbinomial_pdf) + test_rrule(Distributions.poissonbinomial_pdf, rand(50)) + end end diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index 90b541c6..4620cae3 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -58,7 +58,6 @@ DistSpec(Skellam, (1.0, 2.0), [-2, -2]; broken=(:Zygote,)), DistSpec(PoissonBinomial, ([0.5, 0.5],), 0), - DistSpec(PoissonBinomial, ([0.5, 0.5],), [0, 0]), DistSpec(TuringPoissonBinomial, ([0.5, 0.5],), 0), DistSpec(TuringPoissonBinomial, ([0.5, 0.5],), [0, 0]), @@ -217,6 +216,9 @@ # Only some Zygote tests are broken and therefore this can not be checked DistSpec(Pareto, (), 1.5; broken=(:Zygote,)), + + # Some tests are broken on some Julia versions, therefore it can't be checked reliably + DistSpec(PoissonBinomial, ([0.5, 0.5],), [0, 0]; broken=(:Zygote,)), ] # Tests that have a `broken` field can be executed but, according to FiniteDifferences, @@ -392,8 +394,19 @@ # Skellam only fails in these tests with ReverseDiff # Ref: https://github.com/TuringLang/DistributionsAD.jl/issues/126 - filldist_broken = d.f(d.θ...) isa Skellam ? (d.broken..., :ReverseDiff) : d.broken - arraydist_broken = d.broken + # PoissonBinomial fails with Zygote + filldist_broken = if d.f(d.θ...) isa Skellam + (d.broken..., :ReverseDiff) + elseif d.f(d.θ...) isa PoissonBinomial + (d.broken..., :Zygote) + else + d.broken + end + arraydist_broken = if d.f(d.θ...) isa PoissonBinomial + (d.broken..., :Zygote) + else + d.broken + end # Create `filldist` distribution f_filldist = (θ...,) -> filldist(d.f(θ...), n) diff --git a/test/runtests.jl b/test/runtests.jl index 779e5ec1..07bfe007 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using DistributionsAD +using ChainRulesCore using ChainRulesTestUtils using Combinatorics using Distributions @@ -12,7 +13,7 @@ using Distributions: meanlogdet using DistributionsAD: TuringUniform, TuringMvNormal, TuringMvLogNormal, TuringPoissonBinomial, TuringDirichlet using StatsBase: entropy -using StatsFuns: binomlogpdf, logsumexp, logistic +using StatsFuns: StatsFuns, logsumexp, logistic import NNlib