Skip to content

Commit

Permalink
Update ChainRules definitions and add differential for PoissonBinomia…
Browse files Browse the repository at this point in the history
…l pdf (#162)
  • Loading branch information
devmotion authored Apr 25, 2021
1 parent 978b1fe commit c463960
Show file tree
Hide file tree
Showing 11 changed files with 213 additions and 125 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DistributionsAD"
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
version = "0.6.22"
version = "0.6.23"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -10,7 +10,6 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Expand All @@ -30,7 +29,6 @@ Compat = "3.6"
DiffRules = "0.1, 1.0"
Distributions = "0.23.3, 0.24"
FillArrays = "0.8, 0.9, 0.10, 0.11"
ForwardDiff = "0.10.6"
NaNMath = "0.3"
PDMats = "0.9, 0.10, 0.11"
Requires = "1"
Expand Down
4 changes: 0 additions & 4 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import StatsFuns: logsumexp,
nbetalogpdf
import Distributions: MvNormal,
MvLogNormal,
poissonbinomial_pdf_fft,
logpdf,
quantile,
PoissonBinomial,
Expand Down Expand Up @@ -65,9 +64,6 @@ include("zygote.jl")
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
using .ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here
include("forwarddiff.jl")

# loads adjoint for `poissonbinomial_pdf` and `poissonbinomial_pdf_fft`
include("zygote_forwarddiff.jl")
end

@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
Expand Down
108 changes: 82 additions & 26 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,75 +11,131 @@
(c, -c, z),
)

# StatsFuns: https://github.com/JuliaStats/StatsFuns.jl/pull/106

## Beta ##

@scalar_rule(
betalogpdf::Real, β::Real, x::Number),
@setup(di = digamma+ β)),
@setup(z = digamma+ β)),
(
@thunk(log(x) - digamma) + di),
@thunk(log(1 - x) - digamma) + di),
@thunk((α - 1)/x + (1 - β)/(1 - x)),
log(x) + z - digamma(α),
log1p(-x) + z - digamma(β),
(α - 1) / x + (1 - β) / (1 - x),
),
)

## Gamma ##

@scalar_rule(
gammalogpdf(k::Real, θ::Real, x::Number),
@setup(
invθ = inv(θ),
xoθ = invθ * x,
z = xoθ - k,
),
(
@thunk(-digamma(k) - log(θ) + log(x)),
@thunk(-k/θ + x/θ^2),
@thunk((k - 1)/x - 1/θ),
log(xoθ) - digamma(k),
invθ * z,
- (1 + z) / x,
),
)

## Chisq ##

@scalar_rule(
chisqlogpdf(k::Real, x::Number),
@setup(ko2 = k / 2),
(@thunk((-logtwo - digamma(ko2) + log(x)) / 2), @thunk((ko2 - 1)/x - one(ko2) / 2)),
@setup(hk = k / 2),
(
(log(x) - logtwo - digamma(hk)) / 2,
(hk - 1) / x - one(hk) / 2,
),
)

## FDist ##

@scalar_rule(
fdistlogpdf(v1::Real, v2::Real, x::Number),
fdistlogpdf(ν1::Real, ν2::Real, x::Number),
@setup(
temp1 = v1 * x + v2,
temp2 = log(temp1),
vsum = v1 + v2,
temp3 = vsum / temp1,
temp4 = digamma(vsum / 2),
xν1 = x * ν1,
temp1 = xν1 + ν2,
a = (x - 1) / temp1,
νsum = ν1 + ν2,
di = digamma(νsum / 2),
),
(
@thunk((log(v1 * x) + 1 - temp2 - x * temp3 - digamma(v1 / 2) + temp4) / 2),
@thunk((log(v2) + 1 - temp2 - temp3 - digamma(v2 / 2) + temp4) / 2),
@thunk(v1 / 2 * (1 / x - temp3) - 1 / x),
(-log1p(ν2 / xν1) - ν2 * a + di - digamma(ν1 / 2)) / 2,
(-log1p(xν1 / ν2) + ν1 * a + di - digamma(ν2 / 2)) / 2,
((ν1 - 2) / x - ν1 * νsum / temp1) / 2,
),
)

## TDist ##

@scalar_rule(
tdistlogpdf(v::Real, x::Number),
tdistlogpdf::Real, x::Number),
@setup(
νp1 = ν + 1,
xsq = x^2,
invν = inv(ν),
a = xsq * invν,
b = νp1 /+ xsq),
),
(
@thunk((digamma((v + 1) / 2) - 1 / v - digamma(v / 2) - log(1 + x^2 / v) + x^2 * (v + 1) / v^2 / (1 + x^2 / v)) / 2),
@thunk(-x * (v + 1) / (v + x^2)),
)
(digamma(νp1 / 2) - digamma(ν / 2) + a * b - log1p(a) - invν) / 2,
- x * b,
),
)

## Binomial ##

@scalar_rule(
binomlogpdf(n::Int, p::Real, x::Int),
(DoesNotExist(), x / p - (n - x) / (1 - p), DoesNotExist()),
binomlogpdf(n::Real, p::Real, k::Real),
@setup(z = digamma(n - k + 1)),
(
digamma(n + 2) - z + log1p(-p) - 1 / (1 + n),
(k / p - n) / (1 - p),
z - digamma(k + 1) + logit(p),
),
)

## Poisson ##

@scalar_rule(
poislogpdf(v::Real, x::Int),
(x / v - 1, DoesNotExist()),
poislogpdf::Number, x::Number),
((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, log(λ) - digamma(x + 1)),
)

## PoissonBinomial

function ChainRulesCore.rrule(
::typeof(Distributions.poissonbinomial_pdf_fft), p::AbstractVector{<:Real}
)
y = Distributions.poissonbinomial_pdf_fft(p)
A = poissonbinomial_partialderivatives(p)
function poissonbinomial_pdf_fft_pullback(Δy)
= InplaceableThunk(
@thunk(A * Δy),
Δ -> LinearAlgebra.mul!(Δ, A, Δy, true, true),
)
return (NO_FIELDS, p̄)
end
return y, poissonbinomial_pdf_fft_pullback
end

if isdefined(Distributions, :poissonbinomial_pdf)
function ChainRulesCore.rrule(
::typeof(Distributions.poissonbinomial_pdf), p::AbstractVector{<:Real}
)
y = Distributions.poissonbinomial_pdf(p)
A = poissonbinomial_partialderivatives(p)
function poissonbinomial_pdf_pullback(Δy)
= InplaceableThunk(
@thunk(A * Δy),
Δ -> LinearAlgebra.mul!(Δ, A, Δy, true, true),
)
return (NO_FIELDS, p̄)
end
return y, poissonbinomial_pdf_pullback
end
end
42 changes: 42 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,45 @@ parameterless_type(x) = parameterless_type(typeof(x))
parameterless_type(x::Type) = __parameterless_type(x)

@non_differentiable adapt_randn(::Any...)

# PoissonBinomial

# compute matrix of partial derivatives [∂P(X=j-1)/∂pᵢ]_{i=1,…,n; j=1,…,n+1}
#
# This uses the same dynamic programming "trick" as for the computation of the primals
# in Distributions
#
# Reference (for the primal):
#
# Marlin A. Thomas & Audrey E. Taub (1982)
# Calculating binomial probabilities when the trial probabilities are unequal,
# Journal of Statistical Computation and Simulation, 14:2, 125-131, DOI: 10.1080/00949658208810534
function poissonbinomial_partialderivatives(p)
n = length(p)
A = zeros(eltype(p), n, n + 1)
@inbounds for j in 1:n
A[j, end] = 1
end
@inbounds for (i, pi) in enumerate(p)
qi = 1 - pi
for k in (n - i + 1):n
kp1 = k + 1
for j in 1:(i - 1)
A[j, k] = pi * A[j, k] + qi * A[j, kp1]
end
for j in (i+1):n
A[j, k] = pi * A[j, k] + qi * A[j, kp1]
end
end
for j in 1:(i-1)
A[j, end] *= pi
end
for j in (i+1):n
A[j, end] *= pi
end
end
@inbounds for j in 1:n, i in 1:n
A[i, j] -= A[i, j+1]
end
return A
end
16 changes: 6 additions & 10 deletions src/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,26 +261,22 @@ end
PoissonBinomial(p::TrackedArray{<:Real}; check_args=true) =
TuringPoissonBinomial(p; check_args = check_args)

# TODO: add adjoints without ForwardDiff
poissonbinomial_pdf_fft(x::TrackedArray) = track(poissonbinomial_pdf_fft, x)
@grad function poissonbinomial_pdf_fft(x::TrackedArray)
x_data = data(x)
T = eltype(x_data)
fft = poissonbinomial_pdf_fft(x_data)
return fft, Δ -> begin
((ForwardDiff.jacobian(poissonbinomial_pdf_fft, x_data)::Matrix{T})' * Δ,)
end
value = poissonbinomial_pdf_fft(x_data)
A = poissonbinomial_partialderivatives(x_data)
poissonbinomial_pdf_fft_pullback(Δ) = (A * Δ,)
return value, poissonbinomial_pdf_fft_pullback
end

if isdefined(Distributions, :poissonbinomial_pdf)
Distributions.poissonbinomial_pdf(x::TrackedArray) = track(Distributions.poissonbinomial_pdf, x)
@grad function Distributions.poissonbinomial_pdf(x::TrackedArray)
x_data = data(x)
T = eltype(x_data)
value = Distributions.poissonbinomial_pdf(x_data)
function poissonbinomial_pdf_pullback(Δ)
return ((ForwardDiff.jacobian(Distributions.poissonbinomial_pdf, x_data)::Matrix{T})' * Δ,)
end
A = poissonbinomial_partialderivatives(x_data)
poissonbinomial_pdf_pullback(Δ) = (A * Δ,)
return value, poissonbinomial_pdf_pullback
end
end
Expand Down
8 changes: 0 additions & 8 deletions src/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,6 @@ ZygoteRules.@adjoint function Distributions.Uniform(args...)
return ZygoteRules.pullback(TuringUniform, args...)
end

## PoissonBinomial ##

# Zygote loads ForwardDiff, so this dummy adjoint should never be needed.
# The adjoint that is used for `poissonbinomial_pdf_fft` is defined in `src/zygote_forwarddiff.jl`
# ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{T}) where T<:Real
# error("This needs ForwardDiff. `using ForwardDiff` should fix this error.")
# end

## Product

# Tests with `Kolmogorov` seem to fail otherwise?!
Expand Down
20 changes: 0 additions & 20 deletions src/zygote_forwarddiff.jl

This file was deleted.

4 changes: 3 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -16,7 +17,8 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ChainRulesTestUtils = "0.5.3, 0.6"
ChainRulesCore = "0.9"
ChainRulesTestUtils = "0.6.3"
Combinatorics = "1.0.2"
Distributions = "0.24.3"
FiniteDifferences = "0.11.3, 0.12"
Expand Down
Loading

2 comments on commit c463960

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/35276

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.23 -m "<description of version>" c46396060564a1c565ba5e5031ba1a8788d6ed57
git push origin v0.6.23

Please sign in to comment.