From 978b1fece5784369ebbf8e7d4cbebb5034ff4961 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 19 Apr 2021 17:51:38 +0200 Subject: [PATCH] Fix errors with Adapt 3.3.0 (#161) * Fix errors with Adapt 3.3.0 * Use `@non_differentiable` * Bump version * Do not drop ChainRules dependency --- Project.toml | 4 ++-- src/common.jl | 20 +++++++++++--------- src/forwarddiff.jl | 2 +- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 3704083d..71028ee9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.21" +version = "0.6.22" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -25,7 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] Adapt = "2, 3" ChainRules = "0.7" -ChainRulesCore = "0.9.9" +ChainRulesCore = "0.9.21" Compat = "3.6" DiffRules = "0.1, 1.0" Distributions = "0.23.3, 0.24" diff --git a/src/common.jl b/src/common.jl index fea04d08..ba469308 100644 --- a/src/common.jl +++ b/src/common.jl @@ -32,15 +32,17 @@ end # Tracker's implementation of ldiv isn't good. We'll use Zygote's instead. zygote_ldiv(A::AbstractMatrix, B::AbstractVecOrMat) = A \ B +# fixes `randn` on GPU (https://github.com/TuringLang/DistributionsAD.jl/pull/108) function adapt_randn(rng::AbstractRNG, x::AbstractArray, dims...) - adapt(typeof(x), randn(rng, eltype(x), dims...)) + return adapt_randn(rng, eltype(x), x, dims...) end - -# TODO: should be replaced by @non_differentiable when -# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/212 is fixed -function ChainRules.rrule(::typeof(adapt_randn), rng::AbstractRNG, x::AbstractArray, dims...) - function adapt_randn_pullback(ΔQ) - return (NO_FIELDS, Zero(), Zero(), map(_ -> Zero(), dims)...) - end - adapt_randn(rng, x, dims...), adapt_randn_pullback +function adapt_randn(rng::AbstractRNG, ::Type{T}, x::AbstractArray, dims...) where {T} + return adapt(parameterless_type(x), randn(rng, T, dims...)) end + +# required by Adapt >= 3.3.0: https://github.com/SciML/OrdinaryDiffEq.jl/issues/1369 +Base.@pure __parameterless_type(T) = Base.typename(T).wrapper +parameterless_type(x) = parameterless_type(typeof(x)) +parameterless_type(x::Type) = __parameterless_type(x) + +@non_differentiable adapt_randn(::Any...) diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl index 4e591be2..8bafa2d1 100644 --- a/src/forwarddiff.jl +++ b/src/forwarddiff.jl @@ -1,5 +1,5 @@ function adapt_randn(rng::AbstractRNG, x::AbstractArray{<:ForwardDiff.Dual}, dims...) - adapt(typeof(x), randn(rng, ForwardDiff.valtype(eltype(x)), dims...)) + return adapt_randn(rng, ForwardDiff.valtype(eltype(x)), x, dims...) end ## Binomial ##