From 1f46e17d0c25b7598a520c07ce78b960429a9e25 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 19 Apr 2021 09:35:23 +0200 Subject: [PATCH 1/4] Fix errors with Adapt 3.3.0 --- src/common.jl | 11 ++++++++++- src/forwarddiff.jl | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/common.jl b/src/common.jl index fea04d08..95d5e77d 100644 --- a/src/common.jl +++ b/src/common.jl @@ -32,9 +32,18 @@ end # Tracker's implementation of ldiv isn't good. We'll use Zygote's instead. zygote_ldiv(A::AbstractMatrix, B::AbstractVecOrMat) = A \ B +# fixes `randn` on GPU (https://github.com/TuringLang/DistributionsAD.jl/pull/108) function adapt_randn(rng::AbstractRNG, x::AbstractArray, dims...) - adapt(typeof(x), randn(rng, eltype(x), dims...)) + return adapt_randn(rng, eltype(x), x, dims...) end +function adapt_randn(rng::AbstractRNG, ::Type{T}, x::AbstractArray, dims...) where {T} + return adapt(parameterless_type(x), randn(rng, T, dims...)) +end + +# required by Adapt >= 3.3.0: https://github.com/SciML/OrdinaryDiffEq.jl/issues/1369 +Base.@pure __parameterless_type(T) = Base.typename(T).wrapper +parameterless_type(x) = parameterless_type(typeof(x)) +parameterless_type(x::Type) = __parameterless_type(x) # TODO: should be replaced by @non_differentiable when # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/212 is fixed diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl index 4e591be2..8bafa2d1 100644 --- a/src/forwarddiff.jl +++ b/src/forwarddiff.jl @@ -1,5 +1,5 @@ function adapt_randn(rng::AbstractRNG, x::AbstractArray{<:ForwardDiff.Dual}, dims...) - adapt(typeof(x), randn(rng, ForwardDiff.valtype(eltype(x)), dims...)) + return adapt_randn(rng, ForwardDiff.valtype(eltype(x)), x, dims...) end ## Binomial ## From 111d7b53f725da32ba91acc1156ba9b6f135cd49 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 19 Apr 2021 09:36:19 +0200 Subject: [PATCH 2/4] Use `@non_differentiable` --- Project.toml | 4 +--- src/DistributionsAD.jl | 1 - src/common.jl | 9 +-------- 3 files changed, 2 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 3704083d..341d321d 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,6 @@ version = "0.6.21" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" @@ -24,8 +23,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] Adapt = "2, 3" -ChainRules = "0.7" -ChainRulesCore = "0.9.9" +ChainRulesCore = "0.9.21" Compat = "3.6" DiffRules = "0.1, 1.0" Distributions = "0.23.3, 0.24" diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index c9306485..6729e3e0 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -9,7 +9,6 @@ using PDMats, Compat, Requires, ZygoteRules, - ChainRules, # needed for `ChainRules.chol_blocked_rev` ChainRulesCore, FillArrays, Adapt diff --git a/src/common.jl b/src/common.jl index 95d5e77d..7d1e1417 100644 --- a/src/common.jl +++ b/src/common.jl @@ -45,11 +45,4 @@ Base.@pure __parameterless_type(T) = Base.typename(T).wrapper parameterless_type(x) = parameterless_type(typeof(x)) parameterless_type(x::Type) = __parameterless_type(x) -# TODO: should be replaced by @non_differentiable when -# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/212 is fixed -function ChainRules.rrule(::typeof(adapt_randn), rng::AbstractRNG, x::AbstractArray, dims...) - function adapt_randn_pullback(ΔQ) - return (NO_FIELDS, Zero(), Zero(), map(_ -> Zero(), dims)...) - end - adapt_randn(rng, x, dims...), adapt_randn_pullback -end +ChainRulesCore.@non_differentiable adapt_randn(::Any...) From 28e1cee8622653af836b632d8677a2a6b17e4aa0 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 19 Apr 2021 09:36:40 +0200 Subject: [PATCH 3/4] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 341d321d..34571fee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.21" +version = "0.6.22" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From e120b6f42b32134035b0af51df237627544278b2 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 19 Apr 2021 10:05:28 +0200 Subject: [PATCH 4/4] Do not drop ChainRules dependency --- Project.toml | 2 ++ src/DistributionsAD.jl | 1 + src/common.jl | 2 +- 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 34571fee..71028ee9 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.6.22" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" @@ -23,6 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] Adapt = "2, 3" +ChainRules = "0.7" ChainRulesCore = "0.9.21" Compat = "3.6" DiffRules = "0.1, 1.0" diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index 6729e3e0..c9306485 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -9,6 +9,7 @@ using PDMats, Compat, Requires, ZygoteRules, + ChainRules, # needed for `ChainRules.chol_blocked_rev` ChainRulesCore, FillArrays, Adapt diff --git a/src/common.jl b/src/common.jl index 7d1e1417..ba469308 100644 --- a/src/common.jl +++ b/src/common.jl @@ -45,4 +45,4 @@ Base.@pure __parameterless_type(T) = Base.typename(T).wrapper parameterless_type(x) = parameterless_type(typeof(x)) parameterless_type(x::Type) = __parameterless_type(x) -ChainRulesCore.@non_differentiable adapt_randn(::Any...) +@non_differentiable adapt_randn(::Any...)