From 154769ca230b281961bb67198bd81a3c2cea0872 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 27 Dec 2022 16:07:13 +0100 Subject: [PATCH 01/14] Reorganize files --- src/finitedifferences.jl => ext/FiniteDifferencesExt.jl | 0 src/forwarddiff.jl => ext/ForwardDiffExt.jl | 0 src/reversediff.jl => ext/ReverseDiffExt.jl | 0 src/tracker.jl => ext/TrackerExt.jl | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename src/finitedifferences.jl => ext/FiniteDifferencesExt.jl (100%) rename src/forwarddiff.jl => ext/ForwardDiffExt.jl (100%) rename src/reversediff.jl => ext/ReverseDiffExt.jl (100%) rename src/tracker.jl => ext/TrackerExt.jl (100%) diff --git a/src/finitedifferences.jl b/ext/FiniteDifferencesExt.jl similarity index 100% rename from src/finitedifferences.jl rename to ext/FiniteDifferencesExt.jl diff --git a/src/forwarddiff.jl b/ext/ForwardDiffExt.jl similarity index 100% rename from src/forwarddiff.jl rename to ext/ForwardDiffExt.jl diff --git a/src/reversediff.jl b/ext/ReverseDiffExt.jl similarity index 100% rename from src/reversediff.jl rename to ext/ReverseDiffExt.jl diff --git a/src/tracker.jl b/ext/TrackerExt.jl similarity index 100% rename from src/tracker.jl rename to ext/TrackerExt.jl From b659ca27d2d56f84011ac893550bde5940a8e639 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 27 Dec 2022 16:08:38 +0100 Subject: [PATCH 02/14] Use weak dependencies and remove `AD` alias --- Project.toml | 14 ++++++++ README.md | 16 ++++++--- ext/FiniteDifferencesExt.jl | 28 +++++++-------- ext/ForwardDiffExt.jl | 41 +++++++++++----------- ext/ReverseDiffExt.jl | 36 +++++++++++--------- ext/TrackerExt.jl | 30 +++++++++------- ext/ZygoteExt.jl | 15 ++++++++ src/AbstractDifferentiation.jl | 25 ++++++++------ src/backends.jl | 62 ++++++++++++++++++++++++++++++++++ src/ruleconfig.jl | 2 +- test/defaults.jl | 1 + 11 files changed, 189 insertions(+), 81 deletions(-) create mode 100644 ext/ZygoteExt.jl create mode 100644 src/backends.jl diff --git 
a/Project.toml b/Project.toml index 965889e..431b2ba 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,20 @@ ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +[weakdeps] +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[extensions] +FiniteDifferencesExt = "FiniteDifferences" +ForwardDiffExt = "ForwardDiff" +ReverseDiffExt = "ReverseDiff" +TrackerExt = "Tracker" +ZygoteExt = "Zygote" + [compat] ChainRulesCore = "1" Compat = "3, 4" diff --git a/README.md b/README.md index 42125b8..62fab0c 100644 --- a/README.md +++ b/README.md @@ -11,18 +11,24 @@ Julia has more (automatic) differentiation packages than you can count on 2 hand ## Loading `AbstractDifferentiation` -To load `AbstractDifferentiation`, use: +To load `AbstractDifferentiation`, it is recommended to use ```julia -using AbstractDifferentiation +import AbstractDifferentiation as AD ``` -`AbstractDifferentiation` exports a single name `AD` which is just an alias for the `AbstractDifferentiation` module itself. You can use this to access names inside `AbstractDifferentiation` using `AD.<>` instead of typing the long name `AbstractDifferentiation`. +on Julia ≥ 1.6 and +```julia +import AbstractDifferentiation +const AD = AbstractDifferentiation +``` +on older Julia versions. +With the `AD` alias you can access names inside of `AbstractDifferentiation` using `AD.<>` instead of typing the long name `AbstractDifferentiation`. ## `AbstractDifferentiation` backends To use `AbstractDifferentiation`, first construct a backend instance `ab::AD.AbstractBackend` using your favorite differentiation package in Julia that supports `AbstractDifferentiation`. 
In particular, you may want to use `AD.ReverseRuleConfigBackend(ruleconfig)` for any [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible reverse mode differentiation package. -The following backends are temporarily made available by `AbstractDifferentiation` as soon as their corresponding package is loaded (thanks to [Requires.jl](https://github.com/JuliaPackaging/Requires.jl)): +The following backends are temporarily made available by `AbstractDifferentiation` as soon as their corresponding package is loaded (thanks to [weak dependencies](https://pkgdocs.julialang.org/dev/creating-packages/#Weak-dependencies) on Julia ≥ 1.9 and [Requires.jl](https://github.com/JuliaPackaging/Requires.jl) on older Julia versions): - `AD.ForwardDiffBackend()` for [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) - `AD.FiniteDifferencesBackend()` for [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl) @@ -35,7 +41,7 @@ In the long term, these backend objects (and many more) will be defined within t Here's an example: ```julia -julia> using AbstractDifferentiation, Zygote +julia> import AbstractDifferentiation as AD, Zygote julia> ab = AD.ZygoteBackend() AbstractDifferentiation.ReverseRuleConfigBackend{Zygote.ZygoteRuleConfig{Zygote.Context}}(Zygote.ZygoteRuleConfig{Zygote.Context}(Zygote.Context(nothing))) diff --git a/ext/FiniteDifferencesExt.jl b/ext/FiniteDifferencesExt.jl index dc7d1ef..c4e5c7e 100644 --- a/ext/FiniteDifferencesExt.jl +++ b/ext/FiniteDifferencesExt.jl @@ -1,36 +1,36 @@ -using .FiniteDifferences: FiniteDifferences +module FiniteDifferencesExt -""" - FiniteDifferencesBackend{M} - -AD backend that uses forward mode with FiniteDifferences.jl. - -The type parameter `M` is the type of the method used to perform finite differences. 
-""" -struct FiniteDifferencesBackend{M} <: AbstractFiniteDifference - method::M +using AbstractDifferentiation: AbstractDifferentiation, EXTENSIONS_SUPPORTED, FiniteDifferencesBackend +if EXTENSIONS_SUPPORTED + using FiniteDifferences: FiniteDifferences +else + using ..FiniteDifferences: FiniteDifferences end +const AD = AbstractDifferentiation + """ FiniteDifferencesBackend(method=FiniteDifferences.central_fdm(5, 1)) Create an AD backend that uses forward mode with FiniteDifferences.jl. """ -FiniteDifferencesBackend() = FiniteDifferencesBackend(FiniteDifferences.central_fdm(5, 1)) +AD.FiniteDifferencesBackend() = FiniteDifferencesBackend(FiniteDifferences.central_fdm(5, 1)) -@primitive function jacobian(ba::FiniteDifferencesBackend, f, xs...) +AD.@primitive function jacobian(ba::FiniteDifferencesBackend, f, xs...) return FiniteDifferences.jacobian(ba.method, f, xs...) end -function pushforward_function(ba::FiniteDifferencesBackend, f, xs...) +function AD.pushforward_function(ba::FiniteDifferencesBackend, f, xs...) return function pushforward(vs) ws = FiniteDifferences.jvp(ba.method, f, tuple.(xs, vs)...) return length(xs) == 1 ? (ws,) : ws end end -function pullback_function(ba::FiniteDifferencesBackend, f, xs...) +function AD.pullback_function(ba::FiniteDifferencesBackend, f, xs...) function pullback(vs) return FiniteDifferences.j′vp(ba.method, f, vs, xs...) end end + +end # module diff --git a/ext/ForwardDiffExt.jl b/ext/ForwardDiffExt.jl index 08b7fa7..673cc21 100644 --- a/ext/ForwardDiffExt.jl +++ b/ext/ForwardDiffExt.jl @@ -1,16 +1,13 @@ -using .ForwardDiff: ForwardDiff, DiffResults, StaticArrays +module ForwardDiffExt -""" - ForwardDiffBackend{CS} - -AD backend that uses forward mode with ForwardDiff.jl. 
+using AbstractDifferentiation: AbstractDifferentiation, asarray, EXTENSIONS_SUPPORTED, ForwardDiffBackend +if EXTENSIONS_SUPPORTED + using ForwardDiff: ForwardDiff, DiffResults +else + using ..ForwardDiff: ForwardDiff, DiffResults +end -The type parameter `CS` denotes the chunk size of the differentiation algorithm. If it is -`Nothing`, then ForwardiffDiff uses a heuristic to set the chunk size based on the input. - -See also: [ForwardDiff.jl: Configuring Chunk Size](https://juliadiff.org/ForwardDiff.jl/dev/user/advanced/#Configuring-Chunk-Size) -""" -struct ForwardDiffBackend{CS} <: AbstractForwardMode end +const AD = AbstractDifferentiation """ ForwardDiffBackend(; chunksize::Union{Val,Nothing}=nothing) @@ -23,11 +20,11 @@ ForwarddDiff uses a heuristic to set the chunk size based on the input. Alternat See also: [ForwardDiff.jl: Configuring Chunk Size](https://juliadiff.org/ForwardDiff.jl/dev/user/advanced/#Configuring-Chunk-Size) """ -function ForwardDiffBackend(; chunksize::Union{Val,Nothing}=nothing) +function AD.ForwardDiffBackend(; chunksize::Union{Val,Nothing}=nothing) return ForwardDiffBackend{getchunksize(chunksize)}() end -@primitive function pushforward_function(ba::ForwardDiffBackend, f, xs...) +AD.@primitive function pushforward_function(ba::ForwardDiffBackend, f, xs...) return function pushforward(vs) if length(xs) == 1 v = vs isa Tuple ? 
only(vs) : vs @@ -38,35 +35,35 @@ end end end -primal_value(x::ForwardDiff.Dual) = ForwardDiff.value(x) -primal_value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) +AD.primal_value(x::ForwardDiff.Dual) = ForwardDiff.value(x) +AD.primal_value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) # these implementations are more efficient than the fallbacks -function gradient(ba::ForwardDiffBackend, f, x::AbstractArray) +function AD.gradient(ba::ForwardDiffBackend, f, x::AbstractArray) cfg = ForwardDiff.GradientConfig(f, x, chunk(ba, x)) return (ForwardDiff.gradient(f, x, cfg),) end -function jacobian(ba::ForwardDiffBackend, f, x::AbstractArray) +function AD.jacobian(ba::ForwardDiffBackend, f, x::AbstractArray) cfg = ForwardDiff.JacobianConfig(asarray ∘ f, x, chunk(ba, x)) return (ForwardDiff.jacobian(asarray ∘ f, x, cfg),) end -jacobian(::ForwardDiffBackend, f, x::Number) = (ForwardDiff.derivative(f, x),) +AD.jacobian(::ForwardDiffBackend, f, x::Number) = (ForwardDiff.derivative(f, x),) -function hessian(ba::ForwardDiffBackend, f, x::AbstractArray) +function AD.hessian(ba::ForwardDiffBackend, f, x::AbstractArray) cfg = ForwardDiff.HessianConfig(f, x, chunk(ba, x)) return (ForwardDiff.hessian(f, x, cfg),) end -function value_and_gradient(ba::ForwardDiffBackend, f, x::AbstractArray) +function AD.value_and_gradient(ba::ForwardDiffBackend, f, x::AbstractArray) result = DiffResults.GradientResult(x) cfg = ForwardDiff.GradientConfig(f, x, chunk(ba, x)) ForwardDiff.gradient!(result, f, x, cfg) return DiffResults.value(result), (DiffResults.derivative(result),) end -function value_and_hessian(ba::ForwardDiffBackend, f, x) +function AD.value_and_hessian(ba::ForwardDiffBackend, f, x) result = DiffResults.HessianResult(x) cfg = ForwardDiff.HessianConfig(f, result, x, chunk(ba, x)) ForwardDiff.hessian!(result, f, x, cfg) @@ -82,3 +79,5 @@ getchunksize(::Val{N}) where {N} = N chunk(::ForwardDiffBackend{Nothing}, x) = ForwardDiff.Chunk(x) 
chunk(::ForwardDiffBackend{N}, _) where {N} = ForwardDiff.Chunk{N}() + +end # module diff --git a/ext/ReverseDiffExt.jl b/ext/ReverseDiffExt.jl index e1a0764..1a605a2 100644 --- a/ext/ReverseDiffExt.jl +++ b/ext/ReverseDiffExt.jl @@ -1,17 +1,19 @@ -using .ReverseDiff: ReverseDiff, DiffResults +module ReverseDiffExt -primal_value(x::ReverseDiff.TrackedReal) = ReverseDiff.value(x) -primal_value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x) -primal_value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x) +using AbstractDifferentiation: AbstractDifferentiation, asarray, EXTENSIONS_SUPPORTED, ReverseDiffBackend +if EXTENSIONS_SUPPORTED + using ReverseDiff: ReverseDiff, DiffResults +else + using ..ReverseDiff: ReverseDiff, DiffResults +end -""" - ReverseDiffBackend +const AD = AbstractDifferentiation -AD backend that uses reverse mode with ReverseDiff.jl. -""" -struct ReverseDiffBackend <: AbstractReverseMode end +AD.primal_value(x::ReverseDiff.TrackedReal) = ReverseDiff.value(x) +AD.primal_value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x) +AD.primal_value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x) -@primitive function jacobian(ba::ReverseDiffBackend, f, xs...) +AD.@primitive function jacobian(ba::ReverseDiffBackend, f, xs...) xs_arr = map(asarray, xs) tape = ReverseDiff.JacobianTape(xs_arr) do (xs_arr...) xs_new = map(xs, xs_arr) do x, x_arr @@ -24,11 +26,11 @@ struct ReverseDiffBackend <: AbstractReverseMode end return x isa Number ? vec(result) : result end end -function jacobian(ba::ReverseDiffBackend, f, xs::AbstractArray...) +function AD.jacobian(ba::ReverseDiffBackend, f, xs::AbstractArray...) return ReverseDiff.jacobian(asarray ∘ f, xs) end -function derivative(ba::ReverseDiffBackend, f, xs::Number...) +function AD.derivative(ba::ReverseDiffBackend, f, xs::Number...) tape = ReverseDiff.InstructionTape() xs_tracked = ReverseDiff.TrackedReal.(xs, zero.(xs), Ref(tape)) y_tracked = f(xs_tracked...) 
@@ -37,24 +39,26 @@ function derivative(ba::ReverseDiffBackend, f, xs::Number...) return ReverseDiff.deriv.(xs_tracked) end -function gradient(ba::ReverseDiffBackend, f, xs::AbstractArray...) +function AD.gradient(ba::ReverseDiffBackend, f, xs::AbstractArray...) return ReverseDiff.gradient(f, xs) end -function hessian(ba::ReverseDiffBackend, f, x::AbstractArray) +function AD.hessian(ba::ReverseDiffBackend, f, x::AbstractArray) return (ReverseDiff.hessian(f, x),) end -function value_and_gradient(ba::ReverseDiffBackend, f, x::AbstractArray) +function AD.value_and_gradient(ba::ReverseDiffBackend, f, x::AbstractArray) result = DiffResults.GradientResult(x) cfg = ReverseDiff.GradientConfig(x) ReverseDiff.gradient!(result, f, x, cfg) return DiffResults.value(result), (DiffResults.derivative(result),) end -function value_and_hessian(ba::ReverseDiffBackend, f, x) +function AD.value_and_hessian(ba::ReverseDiffBackend, f, x) result = DiffResults.HessianResult(x) cfg = ReverseDiff.HessianConfig(result, x) ReverseDiff.hessian!(result, f, x, cfg) return DiffResults.value(result), (DiffResults.hessian(result),) end + +end # module diff --git a/ext/TrackerExt.jl b/ext/TrackerExt.jl index eb51610..4ae9eff 100644 --- a/ext/TrackerExt.jl +++ b/ext/TrackerExt.jl @@ -1,21 +1,23 @@ -using .Tracker: Tracker +module TrackerExt -""" - TrackerBackend +using AbstractDifferentiation: AbstractDifferentiation, EXTENSIONS_SUPPORTED, TrackerBackend +if EXTENSIONS_SUPPORTED + using Tracker: Tracker +else + using ..Tracker: Tracker +end -AD backend that uses reverse mode with Tracker.jl. 
-""" -struct TrackerBackend <: AbstractReverseMode end +const AD = AbstractDifferentiation -function second_lowest(::TrackerBackend) +function AD.second_lowest(::TrackerBackend) return throw(ArgumentError("Tracker backend does not support nested differentiation.")) end -primal_value(x::Tracker.TrackedReal) = Tracker.data(x) -primal_value(x::Tracker.TrackedArray) = Tracker.data(x) -primal_value(x::AbstractArray{<:Tracker.TrackedReal}) = Tracker.data.(x) +AD.primal_value(x::Tracker.TrackedReal) = Tracker.data(x) +AD.primal_value(x::Tracker.TrackedArray) = Tracker.data(x) +AD.primal_value(x::AbstractArray{<:Tracker.TrackedReal}) = Tracker.data.(x) -@primitive function pullback_function(ba::TrackerBackend, f, xs...) +AD.@primitive function pullback_function(ba::TrackerBackend, f, xs...) value, back = Tracker.forward(f, xs...) function pullback(ws) if ws isa Tuple && !(value isa Tuple) @@ -28,10 +30,12 @@ primal_value(x::AbstractArray{<:Tracker.TrackedReal}) = Tracker.data.(x) return pullback end -function derivative(ba::TrackerBackend, f, xs::Number...) +function AD.derivative(ba::TrackerBackend, f, xs::Number...) return Tracker.data.(Tracker.gradient(f, xs...)) end -function gradient(ba::TrackerBackend, f, xs::AbstractVector...) +function AD.gradient(ba::TrackerBackend, f, xs::AbstractVector...) 
return Tracker.data.(Tracker.gradient(f, xs...)) end + +end # module diff --git a/ext/ZygoteExt.jl b/ext/ZygoteExt.jl new file mode 100644 index 0000000..190f251 --- /dev/null +++ b/ext/ZygoteExt.jl @@ -0,0 +1,15 @@ +module ZygoteExt + +using AbstractDifferentiation: AbstractDifferentiation, EXTENSIONS_SUPPORTED, ReverseRuleConfigBackend + +if EXTENSIONS_SUPPORTED + using Zygote: Zygote +else + using ..Zygote: Zygote +end + +@static if isdefined(AbstractDifferentiation, :ZygoteBackend) && isdefined(Zygote, :ZygoteRuleConfig) + AbstractDifferentiation.ZygoteBackend() = ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()) +end + +end # module diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index 4143be0..3205ad0 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -1,12 +1,8 @@ module AbstractDifferentiation -using LinearAlgebra, ExprTools, Requires, Compat +using LinearAlgebra, ExprTools, Compat using ChainRulesCore: ChainRulesCore -export AD - -const AD = AbstractDifferentiation - abstract type AbstractBackend end abstract type AbstractFiniteDifference <: AbstractBackend end abstract type AbstractForwardMode <: AbstractBackend end @@ -645,14 +641,21 @@ end @inline asarray(x::AbstractArray) = x include("ruleconfig.jl") +include("backends.jl") + +# TODO: Replace with proper version +const EXTENSIONS_SUPPORTED = isdefined(Base, :get_extension) +if !EXTENSIONS_SUPPORTED + using Requires: @require +end function __init__() - @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl") - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("reversediff.jl") - @require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("finitedifferences.jl") - @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("tracker.jl") - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + @static if !EXTENSIONS_SUPPORTED + @require ForwardDiff = 
"f6369f11-7733-5829-9624-2563aa707210" include("../ext/ForwardDiffExt.jl") + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("../ext/ReverseDiffExt.jl") + @require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("../ext/FiniteDifferencesExt.jl") + @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("../ext/TrackerExt.jl") @static if VERSION >= v"1.6" - ZygoteBackend() = ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()) + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("../ext/ZygoteExt.jl") end end end diff --git a/src/backends.jl b/src/backends.jl new file mode 100644 index 0000000..392e8df --- /dev/null +++ b/src/backends.jl @@ -0,0 +1,62 @@ +""" + FiniteDifferencesBackend{M} + +AD backend that uses forward mode with FiniteDifferences.jl. + +The type parameter `M` is the type of the method used to perform finite differences. + +!!! note + To be able to use this backend, you have to load FiniteDifferences. +""" +struct FiniteDifferencesBackend{M} <: AbstractFiniteDifference + method::M +end + +""" + ForwardDiffBackend{CS} + +AD backend that uses forward mode with ForwardDiff.jl. + +The type parameter `CS` denotes the chunk size of the differentiation algorithm. If it is +`Nothing`, then ForwardDiff uses a heuristic to set the chunk size based on the input. + +See also: [ForwardDiff.jl: Configuring Chunk Size](https://juliadiff.org/ForwardDiff.jl/dev/user/advanced/#Configuring-Chunk-Size) + +!!! note + To be able to use this backend, you have to load ForwardDiff. +""" +struct ForwardDiffBackend{CS} <: AbstractForwardMode end + +""" + ReverseDiffBackend + +AD backend that uses reverse mode with ReverseDiff.jl. + +!!! note + To be able to use this backend, you have to load ReverseDiff. +""" +struct ReverseDiffBackend <: AbstractReverseMode end + +""" + TrackerBackend + +AD backend that uses reverse mode with Tracker.jl. + +!!! note + To be able to use this backend, you have to load Tracker. 
+""" +struct TrackerBackend <: AbstractReverseMode end + +@static if VERSION >= v"1.6" +""" + ZygoteBackend() + +Create an AD backend that uses reverse mode with Zygote.jl. + +It is a special case of [`ReverseRuleConfigBackend`](@ref). + +!!! note + To be able to use this backend, you have to load Zygote. +""" +function ZygoteBackend end +end diff --git a/src/ruleconfig.jl b/src/ruleconfig.jl index 1dcb2c1..daa580d 100644 --- a/src/ruleconfig.jl +++ b/src/ruleconfig.jl @@ -7,7 +7,7 @@ struct ReverseRuleConfigBackend{RC<:ChainRulesCore.RuleConfig{>:ChainRulesCore.H ruleconfig::RC end -AD.@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...) +@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...) _, back = ChainRulesCore.rrule_via_ad(ab.ruleconfig, f, xs...) pullback(vs) = Base.tail(back(vs)) pullback(vs::Tuple{Any}) = Base.tail(back(first(vs))) diff --git a/test/defaults.jl b/test/defaults.jl index 1f5abfb..c343a15 100644 --- a/test/defaults.jl +++ b/test/defaults.jl @@ -2,6 +2,7 @@ using AbstractDifferentiation using Test using FiniteDifferences, ForwardDiff, Zygote +const AD = AbstractDifferentiation const FDM = FiniteDifferences ## FiniteDifferences From 87694f398b2c4227f708e3563acd0f4901a8e82a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 27 Dec 2022 16:08:58 +0100 Subject: [PATCH 03/14] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 431b2ba..6c0352a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AbstractDifferentiation" uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" authors = ["Mohamed Tarek and contributors"] -version = "0.4.3" +version = "0.5.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From db4af2d7a8c3c8797b3a5ad6a35454b7b0cecf3f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 27 Dec 2022 17:37:29 +0100 Subject: [PATCH 04/14] Load `only` on Julia < 1.4 --- 
ext/ForwardDiffExt.jl | 3 +++ ext/ReverseDiffExt.jl | 3 +++ src/AbstractDifferentiation.jl | 2 +- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/ext/ForwardDiffExt.jl b/ext/ForwardDiffExt.jl index 673cc21..3f7f8e4 100644 --- a/ext/ForwardDiffExt.jl +++ b/ext/ForwardDiffExt.jl @@ -6,6 +6,9 @@ if EXTENSIONS_SUPPORTED else using ..ForwardDiff: ForwardDiff, DiffResults end +if VERSION < v"1.4.0-DEV.142" + using Compat: only +end const AD = AbstractDifferentiation diff --git a/ext/ReverseDiffExt.jl b/ext/ReverseDiffExt.jl index 1a605a2..6e81456 100644 --- a/ext/ReverseDiffExt.jl +++ b/ext/ReverseDiffExt.jl @@ -6,6 +6,9 @@ if EXTENSIONS_SUPPORTED else using ..ReverseDiff: ReverseDiff, DiffResults end +if VERSION < v"1.4.0-DEV.142" + using Compat: only +end const AD = AbstractDifferentiation diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index 3205ad0..0f48da2 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -1,6 +1,6 @@ module AbstractDifferentiation -using LinearAlgebra, ExprTools, Compat +using LinearAlgebra, ExprTools using ChainRulesCore: ChainRulesCore abstract type AbstractBackend end From ec29a2d7013fb12a9746c15633bc5510cbb304eb Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 27 Dec 2022 18:33:48 +0100 Subject: [PATCH 05/14] Use `Compat.eachcol` --- src/AbstractDifferentiation.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index 0f48da2..2e530fb 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -2,6 +2,9 @@ module AbstractDifferentiation using LinearAlgebra, ExprTools using ChainRulesCore: ChainRulesCore +if VERSION < v"1.1.0-DEV.792" + using Compat: eachcol +end abstract type AbstractBackend end abstract type AbstractFiniteDifference <: AbstractBackend end From 5403dd2338f8970a598c1e6999d9d76f69630835 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 16 Feb 2023 
16:55:22 +0100 Subject: [PATCH 06/14] Rename extensions --- Project.toml | 10 +++++----- ... => AbstractDifferentiationFiniteDifferencesExt.jl} | 2 +- ...Ext.jl => AbstractDifferentiationForwardDiffExt.jl} | 2 +- ...Ext.jl => AbstractDifferentiationReverseDiffExt.jl} | 2 +- ...ckerExt.jl => AbstractDifferentiationTrackerExt.jl} | 2 +- ...ygoteExt.jl => AbstractDifferentiationZygoteExt.jl} | 2 +- src/AbstractDifferentiation.jl | 10 +++++----- 7 files changed, 15 insertions(+), 15 deletions(-) rename ext/{FiniteDifferencesExt.jl => AbstractDifferentiationFiniteDifferencesExt.jl} (95%) rename ext/{ForwardDiffExt.jl => AbstractDifferentiationForwardDiffExt.jl} (98%) rename ext/{ReverseDiffExt.jl => AbstractDifferentiationReverseDiffExt.jl} (97%) rename ext/{TrackerExt.jl => AbstractDifferentiationTrackerExt.jl} (96%) rename ext/{ZygoteExt.jl => AbstractDifferentiationZygoteExt.jl} (91%) diff --git a/Project.toml b/Project.toml index 6c0352a..d13d232 100644 --- a/Project.toml +++ b/Project.toml @@ -18,11 +18,11 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -FiniteDifferencesExt = "FiniteDifferences" -ForwardDiffExt = "ForwardDiff" -ReverseDiffExt = "ReverseDiff" -TrackerExt = "Tracker" -ZygoteExt = "Zygote" +AbstractDifferentiationFiniteDifferencesExt = "FiniteDifferences" +AbstractDifferentiationForwardDiffExt = "ForwardDiff" +AbstractDifferentiationReverseDiffExt = "ReverseDiff" +AbstractDifferentiationTrackerExt = "Tracker" +AbstractDifferentiationZygoteExt = "Zygote" [compat] ChainRulesCore = "1" diff --git a/ext/FiniteDifferencesExt.jl b/ext/AbstractDifferentiationFiniteDifferencesExt.jl similarity index 95% rename from ext/FiniteDifferencesExt.jl rename to ext/AbstractDifferentiationFiniteDifferencesExt.jl index c4e5c7e..c9cbcce 100644 --- a/ext/FiniteDifferencesExt.jl +++ b/ext/AbstractDifferentiationFiniteDifferencesExt.jl @@ -1,4 +1,4 @@ -module FiniteDifferencesExt +module 
AbstractDifferentiationFiniteDifferencesExt using AbstractDifferentiation: AbstractDifferentiation, EXTENSIONS_SUPPORTED, FiniteDifferencesBackend if EXTENSIONS_SUPPORTED diff --git a/ext/ForwardDiffExt.jl b/ext/AbstractDifferentiationForwardDiffExt.jl similarity index 98% rename from ext/ForwardDiffExt.jl rename to ext/AbstractDifferentiationForwardDiffExt.jl index 673cc21..968a0aa 100644 --- a/ext/ForwardDiffExt.jl +++ b/ext/AbstractDifferentiationForwardDiffExt.jl @@ -1,4 +1,4 @@ -module ForwardDiffExt +module AbstractDifferentiationForwardDiffExt using AbstractDifferentiation: AbstractDifferentiation, asarray, EXTENSIONS_SUPPORTED, ForwardDiffBackend if EXTENSIONS_SUPPORTED diff --git a/ext/ReverseDiffExt.jl b/ext/AbstractDifferentiationReverseDiffExt.jl similarity index 97% rename from ext/ReverseDiffExt.jl rename to ext/AbstractDifferentiationReverseDiffExt.jl index 1a605a2..49d8b34 100644 --- a/ext/ReverseDiffExt.jl +++ b/ext/AbstractDifferentiationReverseDiffExt.jl @@ -1,4 +1,4 @@ -module ReverseDiffExt +module AbstractDifferentiationReverseDiffExt using AbstractDifferentiation: AbstractDifferentiation, asarray, EXTENSIONS_SUPPORTED, ReverseDiffBackend if EXTENSIONS_SUPPORTED diff --git a/ext/TrackerExt.jl b/ext/AbstractDifferentiationTrackerExt.jl similarity index 96% rename from ext/TrackerExt.jl rename to ext/AbstractDifferentiationTrackerExt.jl index 4ae9eff..83494b7 100644 --- a/ext/TrackerExt.jl +++ b/ext/AbstractDifferentiationTrackerExt.jl @@ -1,4 +1,4 @@ -module TrackerExt +module AbstractDifferentiationTrackerExt using AbstractDifferentiation: AbstractDifferentiation, EXTENSIONS_SUPPORTED, TrackerBackend if EXTENSIONS_SUPPORTED diff --git a/ext/ZygoteExt.jl b/ext/AbstractDifferentiationZygoteExt.jl similarity index 91% rename from ext/ZygoteExt.jl rename to ext/AbstractDifferentiationZygoteExt.jl index 190f251..04d013f 100644 --- a/ext/ZygoteExt.jl +++ b/ext/AbstractDifferentiationZygoteExt.jl @@ -1,4 +1,4 @@ -module ZygoteExt +module 
AbstractDifferentiationZygoteExt using AbstractDifferentiation: AbstractDifferentiation, EXTENSIONS_SUPPORTED, ReverseRuleConfigBackend diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index 3205ad0..7272ea7 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -650,12 +650,12 @@ if !EXTENSIONS_SUPPORTED end function __init__() @static if !EXTENSIONS_SUPPORTED - @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("../ext/ForwardDiffExt.jl") - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("../ext/ReverseDiffExt.jl") - @require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("../ext/FiniteDifferencesExt.jl") - @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("../ext/TrackerExt.jl") + @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("../ext/AbstractDifferentiationForwardDiffExt.jl") + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("../ext/AbstractDifferentiationReverseDiffExt.jl") + @require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("../ext/AbstractDifferentiationFiniteDifferencesExt.jl") + @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("../ext/AbstractDifferentiationTrackerExt.jl") @static if VERSION >= v"1.6" - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("../ext/ZygoteExt.jl") + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("../ext/AbstractDifferentiationZygoteExt.jl") end end end From a83e71c68b8cf146ca187b1a077b7cb3b02af330 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 16 Feb 2023 17:01:45 +0100 Subject: [PATCH 07/14] Move check out of `__init__` --- src/AbstractDifferentiation.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index 7272ea7..efd4511 100644 --- a/src/AbstractDifferentiation.jl +++ 
b/src/AbstractDifferentiation.jl @@ -646,10 +646,10 @@ include("backends.jl") # TODO: Replace with proper version const EXTENSIONS_SUPPORTED = isdefined(Base, :get_extension) if !EXTENSIONS_SUPPORTED - using Requires: @require + using Requires: @require end -function __init__() - @static if !EXTENSIONS_SUPPORTED +if !EXTENSIONS_SUPPORTED + function __init__() @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("../ext/AbstractDifferentiationForwardDiffExt.jl") @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("../ext/AbstractDifferentiationReverseDiffExt.jl") @require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("../ext/AbstractDifferentiationFiniteDifferencesExt.jl") From 6fe27db44c9664ccd4f37eb633dd2d5edbbc2a83 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 28 Feb 2023 13:46:52 +0100 Subject: [PATCH 08/14] Move ChainRulesCore support in extension and drop Julia < 1.6 --- Project.toml | 19 +++++++----- ...bstractDifferentiationChainRulesCoreExt.jl | 19 ++++++++++++ ext/AbstractDifferentiationForwardDiffExt.jl | 9 +++--- ext/AbstractDifferentiationReverseDiffExt.jl | 9 +++--- ext/AbstractDifferentiationZygoteExt.jl | 5 +--- src/AbstractDifferentiation.jl | 29 +++++-------------- src/backends.jl | 15 ++++++++-- src/ruleconfig.jl | 15 ---------- test/defaults.jl | 12 +++----- test/runtests.jl | 4 +-- test/test_utils.jl | 8 +---- 11 files changed, 66 insertions(+), 78 deletions(-) create mode 100644 ext/AbstractDifferentiationChainRulesCoreExt.jl delete mode 100644 src/ruleconfig.jl diff --git a/Project.toml b/Project.toml index d13d232..0cf73bf 100644 --- a/Project.toml +++ b/Project.toml @@ -5,12 +5,13 @@ version = "0.5.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Requires = "ae029012-a4dd-5104-9daa-d747884805df" [weakdeps] 
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -18,20 +19,24 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] +AbstractDifferentiationChainRulesCoreExt = "ChainRulesCore" AbstractDifferentiationFiniteDifferencesExt = "FiniteDifferences" -AbstractDifferentiationForwardDiffExt = "ForwardDiff" -AbstractDifferentiationReverseDiffExt = "ReverseDiff" +AbstractDifferentiationForwardDiffExt = ["DiffResults", "ForwardDiff"] +AbstractDifferentiationReverseDiffExt = ["DiffResults", "ReverseDiff"] AbstractDifferentiationTrackerExt = "Tracker" -AbstractDifferentiationZygoteExt = "Zygote" +AbstractDifferentiationZygoteExt = ["ChainRulesCore", "Zygote"] [compat] ChainRulesCore = "1" -Compat = "3, 4" +DiffResults = "1" ExprTools = "0.1" +FiniteDifferences = "0.12" ForwardDiff = "0.10" -Requires = "0.5, 1" +Requires = "1" ReverseDiff = "1" -julia = "1" +Tracker = "0.2" +Zygote = "0.6" +julia = "1.6" [extras] FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" diff --git a/ext/AbstractDifferentiationChainRulesCoreExt.jl b/ext/AbstractDifferentiationChainRulesCoreExt.jl new file mode 100644 index 0000000..00c665c --- /dev/null +++ b/ext/AbstractDifferentiationChainRulesCoreExt.jl @@ -0,0 +1,19 @@ +module AbstractDifferentiationChainRulesCoreExt + +using AbstractDifferentiation: AbstractDifferentiation, EXTENSIONS_SUPPORTED, ReverseRuleConfigBackend +if EXTENSIONS_SUPPORTED + using ChainRulesCore: ChainRulesCore +else + using .ChainRulesCore: ChainRulesCore +end + +const AD = AbstractDifferentiation + +AD.@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...) + _, back = ChainRulesCore.rrule_via_ad(ab.ruleconfig, f, xs...) 
+ pullback(vs) = Base.tail(back(vs)) + pullback(vs::Tuple{Any}) = Base.tail(back(first(vs))) + return pullback +end + +end # module diff --git a/ext/AbstractDifferentiationForwardDiffExt.jl b/ext/AbstractDifferentiationForwardDiffExt.jl index 5f59cf3..260943b 100644 --- a/ext/AbstractDifferentiationForwardDiffExt.jl +++ b/ext/AbstractDifferentiationForwardDiffExt.jl @@ -2,12 +2,11 @@ module AbstractDifferentiationForwardDiffExt using AbstractDifferentiation: AbstractDifferentiation, asarray, EXTENSIONS_SUPPORTED, ForwardDiffBackend if EXTENSIONS_SUPPORTED - using ForwardDiff: ForwardDiff, DiffResults + using DiffResults: DiffResults + using ForwardDiff: ForwardDiff else - using ..ForwardDiff: ForwardDiff, DiffResults -end -if VERSION < v"1.4.0-DEV.142" - using Compat: only + using ..DiffResults: DiffResults + using ..ForwardDiff: ForwardDiff end const AD = AbstractDifferentiation diff --git a/ext/AbstractDifferentiationReverseDiffExt.jl b/ext/AbstractDifferentiationReverseDiffExt.jl index dbcf506..548b0eb 100644 --- a/ext/AbstractDifferentiationReverseDiffExt.jl +++ b/ext/AbstractDifferentiationReverseDiffExt.jl @@ -2,12 +2,11 @@ module AbstractDifferentiationReverseDiffExt using AbstractDifferentiation: AbstractDifferentiation, asarray, EXTENSIONS_SUPPORTED, ReverseDiffBackend if EXTENSIONS_SUPPORTED - using ReverseDiff: ReverseDiff, DiffResults + using DiffResults: DiffResults + using ReverseDiff: ReverseDiff else - using ..ReverseDiff: ReverseDiff, DiffResults -end -if VERSION < v"1.4.0-DEV.142" - using Compat: only + using ..DiffResults: DiffResults + using ..ReverseDiff: ReverseDiff end const AD = AbstractDifferentiation diff --git a/ext/AbstractDifferentiationZygoteExt.jl b/ext/AbstractDifferentiationZygoteExt.jl index 04d013f..04983a0 100644 --- a/ext/AbstractDifferentiationZygoteExt.jl +++ b/ext/AbstractDifferentiationZygoteExt.jl @@ -1,15 +1,12 @@ module AbstractDifferentiationZygoteExt using AbstractDifferentiation: AbstractDifferentiation, 
EXTENSIONS_SUPPORTED, ReverseRuleConfigBackend - if EXTENSIONS_SUPPORTED using Zygote: Zygote else using ..Zygote: Zygote end -@static if isdefined(AbstractDifferentiation, :ZygoteBackend) && isdefined(Zygote, :ZygoteRuleConfig) - AbstractDifferentiation.ZygoteBackend() = ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()) -end +AbstractDifferentiation.ZygoteBackend() = ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()) end # module diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index ca603bc..e810003 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -1,10 +1,6 @@ module AbstractDifferentiation using LinearAlgebra, ExprTools -using ChainRulesCore: ChainRulesCore -if VERSION < v"1.1.0-DEV.792" - using Compat: eachcol -end abstract type AbstractBackend end abstract type AbstractFiniteDifference <: AbstractBackend end @@ -526,16 +522,10 @@ function define_pushforward_function_and_friends(fdef) pff = AbstractDifferentiation.pushforward_function($(args...),) if eltype(identity_like) <: Tuple{Vararg{Union{AbstractMatrix, Number}}} return map(identity_like) do identity_like_i - if VERSION < v"1.3" - return reduce(hcat, map(AbstractDifferentiation._eachcol.(identity_like_i)...) do (cols...) - pff(cols) - end) - else return mapreduce(hcat, AbstractDifferentiation._eachcol.(identity_like_i)...) do (cols...) pff(cols) end end - end elseif eltype(identity_like) <: AbstractMatrix # needed for the computation of the Hessian and Jacobian ret = hcat.(mapslices(identity_like[1], dims=1) do cols @@ -569,14 +559,8 @@ function define_pullback_function_and_friends(fdef) identity_like = AbstractDifferentiation.identity_matrix_like(value) if eltype(identity_like) <: Tuple{Vararg{AbstractMatrix}} return map(identity_like) do identity_like_i - if VERSION < v"1.3" - return reduce(vcat, map(AbstractDifferentiation._eachcol.(identity_like_i)...) do (cols...) 
- value_and_pbf(cols)[2]' - end) - else return mapreduce(vcat, AbstractDifferentiation._eachcol.(identity_like_i)...) do (cols...) value_and_pbf(cols)[2]' - end end end elseif eltype(identity_like) <: AbstractMatrix @@ -650,16 +634,17 @@ include("backends.jl") const EXTENSIONS_SUPPORTED = isdefined(Base, :get_extension) if !EXTENSIONS_SUPPORTED using Requires: @require + include("../ext/AbstractDifferentiationChainRulesCoreExt.jl") end -if !EXTENSIONS_SUPPORTED +@static if !EXTENSIONS_SUPPORTED function __init__() - @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("../ext/AbstractDifferentiationForwardDiffExt.jl") - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("../ext/AbstractDifferentiationReverseDiffExt.jl") + @require DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" begin + @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("../ext/AbstractDifferentiationForwardDiffExt.jl") + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("../ext/AbstractDifferentiationReverseDiffExt.jl") + end @require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("../ext/AbstractDifferentiationFiniteDifferencesExt.jl") @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("../ext/AbstractDifferentiationTrackerExt.jl") - @static if VERSION >= v"1.6" - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("../ext/AbstractDifferentiationZygoteExt.jl") - end + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("../ext/AbstractDifferentiationZygoteExt.jl") end end diff --git a/src/backends.jl b/src/backends.jl index 392e8df..de3517c 100644 --- a/src/backends.jl +++ b/src/backends.jl @@ -47,7 +47,19 @@ AD backend that uses reverse mode with Tracker.jl. 
""" struct TrackerBackend <: AbstractReverseMode end -@static if VERSION >= v"1.6" + +""" + ReverseRuleConfigBackend + +AD backend that uses reverse mode with any ChainRules-compatible reverse-mode AD package. + +!!! note + To be able to use this backend, you have to load ChainRulesCore. +""" +struct ReverseRuleConfigBackend{RC} <: AbstractReverseMode + ruleconfig::RC +end + """ ZygoteBackend() @@ -59,4 +71,3 @@ It is a special case of [`ReverseRuleConfigBackend`](@ref). To be able to use this backend, you have to load Zygote. """ function ZygoteBackend end -end diff --git a/src/ruleconfig.jl b/src/ruleconfig.jl deleted file mode 100644 index daa580d..0000000 --- a/src/ruleconfig.jl +++ /dev/null @@ -1,15 +0,0 @@ -""" - ReverseRuleConfigBackend - -AD backend that uses reverse mode with any ChainRules-compatible reverse-mode AD package. -""" -struct ReverseRuleConfigBackend{RC<:ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}} <: AbstractReverseMode - ruleconfig::RC -end - -@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...) - _, back = ChainRulesCore.rrule_via_ad(ab.ruleconfig, f, xs...) 
- pullback(vs) = Base.tail(back(vs)) - pullback(vs::Tuple{Any}) = Base.tail(back(first(vs))) - return pullback -end diff --git a/test/defaults.jl b/test/defaults.jl index c343a15..2e4c1bc 100644 --- a/test/defaults.jl +++ b/test/defaults.jl @@ -216,10 +216,8 @@ end # Zygote over Zygote problems backends = AD.HigherOrderBackend((forwarddiff_backend2,zygote_backend1)) test_hessians(backends) - if VERSION >= v"1.3" - backends = AD.HigherOrderBackend((zygote_backend1,forwarddiff_backend1)) - test_hessians(backends) - end + backends = AD.HigherOrderBackend((zygote_backend1,forwarddiff_backend1)) + test_hessians(backends) # fails: # backends = AD.HigherOrderBackend((zygote_backend1,forwarddiff_backend2)) # test_hessians(backends) @@ -243,10 +241,8 @@ end # Zygote over Zygote problems backends = AD.HigherOrderBackend((forwarddiff_backend2,zygote_backend1)) test_lazy_hessians(backends) - if VERSION >= v"1.3" - backends = AD.HigherOrderBackend((zygote_backend1,forwarddiff_backend1)) - test_lazy_hessians(backends) - end + backends = AD.HigherOrderBackend((zygote_backend1,forwarddiff_backend1)) + test_lazy_hessians(backends) end end end diff --git a/test/runtests.jl b/test/runtests.jl index b435dee..d79bafc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,5 @@ using Test include("reversediff.jl") include("finitedifferences.jl") include("tracker.jl") - @static if VERSION >= v"1.6" - include("ruleconfig.jl") - end + include("ruleconfig.jl") end diff --git a/test/test_utils.jl b/test/test_utils.jl index bc1f95d..6d28445 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -16,13 +16,7 @@ dfgraddydy(x, y) = zeros(length(y),length(y)) function fjac(x, y) x + -3*y + [y[2:end];zero(y[end])]/2# Bidiagonal(-ones(length(y)) * 3, ones(length(y) - 1) / 2, :U) * y end -function dfjacdx(x, y) - if VERSION < v"1.3" - return Matrix{Float64}(I, length(x), length(x)) - else - return I(length(x)) - end -end +dfjacdx(x, y) = I(length(x)) dfjacdy(x, y) = 
Bidiagonal(-ones(length(y)) * 3, ones(length(y) - 1) / 2, :U) # Jvp From 6dccd1707ee920c9d9005d3e53ce7da7db2710f3 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 28 Feb 2023 14:17:14 +0100 Subject: [PATCH 09/14] Simplify code using features of Julia >= 1.6 --- README.md | 8 +---- ...bstractDifferentiationChainRulesCoreExt.jl | 8 ++--- ...ractDifferentiationFiniteDifferencesExt.jl | 14 ++++----- ext/AbstractDifferentiationForwardDiffExt.jl | 30 +++++++++---------- ext/AbstractDifferentiationReverseDiffExt.jl | 26 ++++++++-------- ext/AbstractDifferentiationTrackerExt.jl | 15 ++++------ ext/AbstractDifferentiationZygoteExt.jl | 6 ++-- src/AbstractDifferentiation.jl | 21 +++++-------- test/defaults.jl | 6 ++-- 9 files changed, 54 insertions(+), 80 deletions(-) diff --git a/README.md b/README.md index 62fab0c..6e8b94a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # AbstractDifferentiation -[![CI](https://github.com/JuliaDiff/AbstractDifferentiation.jl/workflows/CI/badge.svg?branch=master)](https://github.com/JuliaDiff/AbstractDifferentiation.jl/actions?query=workflow%3ACI) +[![CI](https://github.com/JuliaDiff/AbstractDifferentiation.jl/actions/workflows/CI.yml/badge.svg?branch=master)](https://github.com/JuliaDiff/AbstractDifferentiation.jl/actions/workflows/CI.yml?query=branch%3Amaster) [![Coverage](https://codecov.io/gh/JuliaDiff/AbstractDifferentiation.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaDiff/AbstractDifferentiation.jl) ## Motivation @@ -15,12 +15,6 @@ To load `AbstractDifferentiation`, it is recommended to use ```julia import AbstractDifferentiation as AD ``` -on Julia ≥ 1.6 and -```julia -import AbstractDifferentiation -const AD = AbstractDifferentiation -``` -on older Julia versions. With the `AD` alias you can access names inside of `AbstractDifferentiation` using `AD.<>` instead of typing the long name `AbstractDifferentiation`. 
## `AbstractDifferentiation` backends diff --git a/ext/AbstractDifferentiationChainRulesCoreExt.jl b/ext/AbstractDifferentiationChainRulesCoreExt.jl index 00c665c..0655512 100644 --- a/ext/AbstractDifferentiationChainRulesCoreExt.jl +++ b/ext/AbstractDifferentiationChainRulesCoreExt.jl @@ -1,15 +1,13 @@ module AbstractDifferentiationChainRulesCoreExt -using AbstractDifferentiation: AbstractDifferentiation, EXTENSIONS_SUPPORTED, ReverseRuleConfigBackend -if EXTENSIONS_SUPPORTED +import AbstractDifferentiation as AD +if AD.EXTENSIONS_SUPPORTED using ChainRulesCore: ChainRulesCore else using .ChainRulesCore: ChainRulesCore end -const AD = AbstractDifferentiation - -AD.@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...) +AD.@primitive function pullback_function(ab::AD.ReverseRuleConfigBackend, f, xs...) _, back = ChainRulesCore.rrule_via_ad(ab.ruleconfig, f, xs...) pullback(vs) = Base.tail(back(vs)) pullback(vs::Tuple{Any}) = Base.tail(back(first(vs))) diff --git a/ext/AbstractDifferentiationFiniteDifferencesExt.jl b/ext/AbstractDifferentiationFiniteDifferencesExt.jl index c9cbcce..67e229a 100644 --- a/ext/AbstractDifferentiationFiniteDifferencesExt.jl +++ b/ext/AbstractDifferentiationFiniteDifferencesExt.jl @@ -1,33 +1,31 @@ module AbstractDifferentiationFiniteDifferencesExt -using AbstractDifferentiation: AbstractDifferentiation, EXTENSIONS_SUPPORTED, FiniteDifferencesBackend -if EXTENSIONS_SUPPORTED +import AbstractDifferentiation as AD +if AD.EXTENSIONS_SUPPORTED using FiniteDifferences: FiniteDifferences else using ..FiniteDifferences: FiniteDifferences end -const AD = AbstractDifferentiation - """ FiniteDifferencesBackend(method=FiniteDifferences.central_fdm(5, 1)) Create an AD backend that uses forward mode with FiniteDifferences.jl. 
""" -AD.FiniteDifferencesBackend() = FiniteDifferencesBackend(FiniteDifferences.central_fdm(5, 1)) +AD.FiniteDifferencesBackend() = AD.FiniteDifferencesBackend(FiniteDifferences.central_fdm(5, 1)) -AD.@primitive function jacobian(ba::FiniteDifferencesBackend, f, xs...) +AD.@primitive function jacobian(ba::AD.FiniteDifferencesBackend, f, xs...) return FiniteDifferences.jacobian(ba.method, f, xs...) end -function AD.pushforward_function(ba::FiniteDifferencesBackend, f, xs...) +function AD.pushforward_function(ba::AD.FiniteDifferencesBackend, f, xs...) return function pushforward(vs) ws = FiniteDifferences.jvp(ba.method, f, tuple.(xs, vs)...) return length(xs) == 1 ? (ws,) : ws end end -function AD.pullback_function(ba::FiniteDifferencesBackend, f, xs...) +function AD.pullback_function(ba::AD.FiniteDifferencesBackend, f, xs...) function pullback(vs) return FiniteDifferences.j′vp(ba.method, f, vs, xs...) end diff --git a/ext/AbstractDifferentiationForwardDiffExt.jl b/ext/AbstractDifferentiationForwardDiffExt.jl index 260943b..55f950b 100644 --- a/ext/AbstractDifferentiationForwardDiffExt.jl +++ b/ext/AbstractDifferentiationForwardDiffExt.jl @@ -1,7 +1,7 @@ module AbstractDifferentiationForwardDiffExt -using AbstractDifferentiation: AbstractDifferentiation, asarray, EXTENSIONS_SUPPORTED, ForwardDiffBackend -if EXTENSIONS_SUPPORTED +import AbstractDifferentiation as AD +if AD.EXTENSIONS_SUPPORTED using DiffResults: DiffResults using ForwardDiff: ForwardDiff else @@ -9,8 +9,6 @@ else using ..ForwardDiff: ForwardDiff end -const AD = AbstractDifferentiation - """ ForwardDiffBackend(; chunksize::Union{Val,Nothing}=nothing) @@ -23,10 +21,10 @@ ForwarddDiff uses a heuristic to set the chunk size based on the input. 
Alternat See also: [ForwardDiff.jl: Configuring Chunk Size](https://juliadiff.org/ForwardDiff.jl/dev/user/advanced/#Configuring-Chunk-Size) """ function AD.ForwardDiffBackend(; chunksize::Union{Val,Nothing}=nothing) - return ForwardDiffBackend{getchunksize(chunksize)}() + return AD.ForwardDiffBackend{getchunksize(chunksize)}() end -AD.@primitive function pushforward_function(ba::ForwardDiffBackend, f, xs...) +AD.@primitive function pushforward_function(ba::AD.ForwardDiffBackend, f, xs...) return function pushforward(vs) if length(xs) == 1 v = vs isa Tuple ? only(vs) : vs @@ -42,30 +40,30 @@ AD.primal_value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) # these implementations are more efficient than the fallbacks -function AD.gradient(ba::ForwardDiffBackend, f, x::AbstractArray) +function AD.gradient(ba::AD.ForwardDiffBackend, f, x::AbstractArray) cfg = ForwardDiff.GradientConfig(f, x, chunk(ba, x)) return (ForwardDiff.gradient(f, x, cfg),) end -function AD.jacobian(ba::ForwardDiffBackend, f, x::AbstractArray) - cfg = ForwardDiff.JacobianConfig(asarray ∘ f, x, chunk(ba, x)) - return (ForwardDiff.jacobian(asarray ∘ f, x, cfg),) +function AD.jacobian(ba::AD.ForwardDiffBackend, f, x::AbstractArray) + cfg = ForwardDiff.JacobianConfig(AD.asarray ∘ f, x, chunk(ba, x)) + return (ForwardDiff.jacobian(AD.asarray ∘ f, x, cfg),) end -AD.jacobian(::ForwardDiffBackend, f, x::Number) = (ForwardDiff.derivative(f, x),) +AD.jacobian(::AD.ForwardDiffBackend, f, x::Number) = (ForwardDiff.derivative(f, x),) -function AD.hessian(ba::ForwardDiffBackend, f, x::AbstractArray) +function AD.hessian(ba::AD.ForwardDiffBackend, f, x::AbstractArray) cfg = ForwardDiff.HessianConfig(f, x, chunk(ba, x)) return (ForwardDiff.hessian(f, x, cfg),) end -function AD.value_and_gradient(ba::ForwardDiffBackend, f, x::AbstractArray) +function AD.value_and_gradient(ba::AD.ForwardDiffBackend, f, x::AbstractArray) result = DiffResults.GradientResult(x) cfg = ForwardDiff.GradientConfig(f, x, 
chunk(ba, x)) ForwardDiff.gradient!(result, f, x, cfg) return DiffResults.value(result), (DiffResults.derivative(result),) end -function AD.value_and_hessian(ba::ForwardDiffBackend, f, x) +function AD.value_and_hessian(ba::AD.ForwardDiffBackend, f, x) result = DiffResults.HessianResult(x) cfg = ForwardDiff.HessianConfig(f, result, x, chunk(ba, x)) ForwardDiff.hessian!(result, f, x, cfg) @@ -79,7 +77,7 @@ end getchunksize(::Nothing) = Nothing getchunksize(::Val{N}) where {N} = N -chunk(::ForwardDiffBackend{Nothing}, x) = ForwardDiff.Chunk(x) -chunk(::ForwardDiffBackend{N}, _) where {N} = ForwardDiff.Chunk{N}() +chunk(::AD.ForwardDiffBackend{Nothing}, x) = ForwardDiff.Chunk(x) +chunk(::AD.ForwardDiffBackend{N}, _) where {N} = ForwardDiff.Chunk{N}() end # module diff --git a/ext/AbstractDifferentiationReverseDiffExt.jl b/ext/AbstractDifferentiationReverseDiffExt.jl index 548b0eb..1e3f2a5 100644 --- a/ext/AbstractDifferentiationReverseDiffExt.jl +++ b/ext/AbstractDifferentiationReverseDiffExt.jl @@ -1,7 +1,7 @@ module AbstractDifferentiationReverseDiffExt -using AbstractDifferentiation: AbstractDifferentiation, asarray, EXTENSIONS_SUPPORTED, ReverseDiffBackend -if EXTENSIONS_SUPPORTED +import AbstractDifferentiation as AD +if AD.EXTENSIONS_SUPPORTED using DiffResults: DiffResults using ReverseDiff: ReverseDiff else @@ -9,30 +9,28 @@ else using ..ReverseDiff: ReverseDiff end -const AD = AbstractDifferentiation - AD.primal_value(x::ReverseDiff.TrackedReal) = ReverseDiff.value(x) AD.primal_value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x) AD.primal_value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x) -AD.@primitive function jacobian(ba::ReverseDiffBackend, f, xs...) - xs_arr = map(asarray, xs) +AD.@primitive function jacobian(::AD.ReverseDiffBackend, f, xs...) + xs_arr = map(AD.asarray, xs) tape = ReverseDiff.JacobianTape(xs_arr) do (xs_arr...) xs_new = map(xs, xs_arr) do x, x_arr return x isa Number ? 
only(x_arr) : x_arr end - return asarray(f(xs_new...)) + return AD.asarray(f(xs_new...)) end results = ReverseDiff.jacobian!(tape, xs_arr) return map(xs, results) do x, result return x isa Number ? vec(result) : result end end -function AD.jacobian(ba::ReverseDiffBackend, f, xs::AbstractArray...) - return ReverseDiff.jacobian(asarray ∘ f, xs) +function AD.jacobian(::AD.ReverseDiffBackend, f, xs::AbstractArray...) + return ReverseDiff.jacobian(AD.asarray ∘ f, xs) end -function AD.derivative(ba::ReverseDiffBackend, f, xs::Number...) +function AD.derivative(::AD.ReverseDiffBackend, f, xs::Number...) tape = ReverseDiff.InstructionTape() xs_tracked = ReverseDiff.TrackedReal.(xs, zero.(xs), Ref(tape)) y_tracked = f(xs_tracked...) @@ -41,22 +39,22 @@ function AD.derivative(ba::ReverseDiffBackend, f, xs::Number...) return ReverseDiff.deriv.(xs_tracked) end -function AD.gradient(ba::ReverseDiffBackend, f, xs::AbstractArray...) +function AD.gradient(::AD.ReverseDiffBackend, f, xs::AbstractArray...) 
return ReverseDiff.gradient(f, xs) end -function AD.hessian(ba::ReverseDiffBackend, f, x::AbstractArray) +function AD.hessian(::AD.ReverseDiffBackend, f, x::AbstractArray) return (ReverseDiff.hessian(f, x),) end -function AD.value_and_gradient(ba::ReverseDiffBackend, f, x::AbstractArray) +function AD.value_and_gradient(::AD.ReverseDiffBackend, f, x::AbstractArray) result = DiffResults.GradientResult(x) cfg = ReverseDiff.GradientConfig(x) ReverseDiff.gradient!(result, f, x, cfg) return DiffResults.value(result), (DiffResults.derivative(result),) end -function AD.value_and_hessian(ba::ReverseDiffBackend, f, x) +function AD.value_and_hessian(::AD.ReverseDiffBackend, f, x) result = DiffResults.HessianResult(x) cfg = ReverseDiff.HessianConfig(result, x) ReverseDiff.hessian!(result, f, x, cfg) diff --git a/ext/AbstractDifferentiationTrackerExt.jl b/ext/AbstractDifferentiationTrackerExt.jl index 83494b7..aa9a716 100644 --- a/ext/AbstractDifferentiationTrackerExt.jl +++ b/ext/AbstractDifferentiationTrackerExt.jl @@ -1,14 +1,12 @@ module AbstractDifferentiationTrackerExt -using AbstractDifferentiation: AbstractDifferentiation, EXTENSIONS_SUPPORTED, TrackerBackend -if EXTENSIONS_SUPPORTED +import AbstractDifferentiation as AD +if AD.EXTENSIONS_SUPPORTED using Tracker: Tracker else using ..Tracker: Tracker end -const AD = AbstractDifferentiation - function AD.second_lowest(::TrackerBackend) return throw(ArgumentError("Tracker backend does not support nested differentiation.")) end @@ -17,12 +15,11 @@ AD.primal_value(x::Tracker.TrackedReal) = Tracker.data(x) AD.primal_value(x::Tracker.TrackedArray) = Tracker.data(x) AD.primal_value(x::AbstractArray{<:Tracker.TrackedReal}) = Tracker.data.(x) -AD.@primitive function pullback_function(ba::TrackerBackend, f, xs...) +AD.@primitive function pullback_function(::AD.TrackerBackend, f, xs...) value, back = Tracker.forward(f, xs...) 
function pullback(ws) if ws isa Tuple && !(value isa Tuple) - @assert length(ws) == 1 - map(Tracker.data, back(ws[1])) + map(Tracker.data, back(only(ws))) else map(Tracker.data, back(ws)) end @@ -30,11 +27,11 @@ AD.@primitive function pullback_function(ba::TrackerBackend, f, xs...) return pullback end -function AD.derivative(ba::TrackerBackend, f, xs::Number...) +function AD.derivative(::AD.TrackerBackend, f, xs::Number...) return Tracker.data.(Tracker.gradient(f, xs...)) end -function AD.gradient(ba::TrackerBackend, f, xs::AbstractVector...) +function AD.gradient(::AD.TrackerBackend, f, xs::AbstractVector...) return Tracker.data.(Tracker.gradient(f, xs...)) end diff --git a/ext/AbstractDifferentiationZygoteExt.jl b/ext/AbstractDifferentiationZygoteExt.jl index 04983a0..bd65f84 100644 --- a/ext/AbstractDifferentiationZygoteExt.jl +++ b/ext/AbstractDifferentiationZygoteExt.jl @@ -1,12 +1,12 @@ module AbstractDifferentiationZygoteExt -using AbstractDifferentiation: AbstractDifferentiation, EXTENSIONS_SUPPORTED, ReverseRuleConfigBackend -if EXTENSIONS_SUPPORTED +import AbstractDifferentiation as AD +if AD.EXTENSIONS_SUPPORTED using Zygote: Zygote else using ..Zygote: Zygote end -AbstractDifferentiation.ZygoteBackend() = ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()) +AD.ZygoteBackend() = AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()) end # module diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index e810003..66f95dc 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -50,8 +50,7 @@ end function hessian(ab::AbstractBackend, f, x) if x isa Tuple # only support computation of Hessian for functions with single input argument - @assert length(x) == 1 - x = x[1] + x = only(x) end return jacobian(second_lowest(ab), x -> begin gradient(lowest(ab), f, x)[1] # gradient returns a tuple @@ -87,8 +86,7 @@ end function value_and_hessian(ab::AbstractBackend, f, x) if x isa Tuple # only support computation of 
Hessian for functions with single input argument - @assert length(x) == 1 - x = x[1] + x = only(x) end local value @@ -110,8 +108,7 @@ end function value_and_hessian(ab::HigherOrderBackend, f, x) if x isa Tuple # only support computation of Hessian for functions with single input argument - @assert length(x) == 1 - x = x[1] + x = only(x) end local value primalcalled = false @@ -128,8 +125,7 @@ end function value_gradient_and_hessian(ab::AbstractBackend, f, x) if x isa Tuple # only support computation of Hessian for functions with single input argument - @assert length(x) == 1 - x = x[1] + x = only(x) end local value primalcalled = false @@ -146,8 +142,7 @@ end function value_gradient_and_hessian(ab::HigherOrderBackend, f, x) if x isa Tuple # only support computation of Hessian for functions with single input argument - @assert length(x) == 1 - x = x[1] + x = only(x) end local value primalcalled = false @@ -174,8 +169,7 @@ function pushforward_function( newxs = xs .+ ds .* xds return f(newxs...) else - @assert length(xs) == length(xds) == 1 - newx = xs[1] + ds * xds[1] + newx = only(xs) + ds * only(xds) return f(newx) end end, _zero.(xs, ds)...) @@ -219,8 +213,7 @@ _zero(::Any, d::Any) = zero(d) @inline _dot(x, y) = dot(x, y) @inline function _dot(x::AbstractVector, y::UniformScaling) - @assert length(x) == 1 - return @inbounds dot(x[1], y.λ) + return @inbounds dot(only(x), y.λ) end @inline function _dot(x::AbstractVector, y::AbstractMatrix) @assert size(y, 2) == 1 diff --git a/test/defaults.jl b/test/defaults.jl index 2e4c1bc..ea1d080 100644 --- a/test/defaults.jl +++ b/test/defaults.jl @@ -39,8 +39,7 @@ AD.@primitive function pullback_function(ab::FDMBackend3, f, xs...) if vs isa AbstractVector return FDM.j′vp(ab.alg, f, vs, xs...) else - @assert length(vs) == 1 - return FDM.j′vp(ab.alg, f, vs[1], xs...) + return FDM.j′vp(ab.alg, f, only(vs), xs...) end end end @@ -98,8 +97,7 @@ AD.@primitive function pullback_function(ab::ZygoteBackend1, f, xs...) 
if vs isa AbstractVector back(vs) else - @assert length(vs) == 1 - back(vs[1]) + back(only(vs)) end end end From 9e77312dd70df76bef0de01222e1cd32efc4cf7c Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 28 Feb 2023 14:21:30 +0100 Subject: [PATCH 10/14] Update CI --- .github/workflows/CI.yml | 9 ++++----- Project.toml | 4 +++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 53ea562..6283227 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,16 +20,15 @@ jobs: fail-fast: false matrix: version: - - '1.0' - '1.6' - '1' - #- 'nightly' + - 'nightly' os: - ubuntu-latest arch: - x64 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} @@ -38,6 +37,6 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v1 + - uses: codecov/codecov-action@v3 with: - file: lcov.info + files: lcov.info diff --git a/Project.toml b/Project.toml index 0cf73bf..8024bfe 100644 --- a/Project.toml +++ b/Project.toml @@ -39,6 +39,8 @@ Zygote = "0.6" julia = "1.6" [extras] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -48,4 +50,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Tracker", "Zygote"] +test = ["Test", "ChainRulesCore", "DiffResults", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Tracker", "Zygote"] From d42dbe6f35fef657a60cafec8fcca075a2252801 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 28 Feb 2023 
14:26:13 +0100 Subject: [PATCH 11/14] Fix include statements --- src/AbstractDifferentiation.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index 66f95dc..41a0353 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -620,7 +620,6 @@ end @inline asarray(x) = [x] @inline asarray(x::AbstractArray) = x -include("ruleconfig.jl") include("backends.jl") # TODO: Replace with proper version From 90fcbae1f483c20f4aaffd49e4b6e331a208c0db Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 28 Feb 2023 15:20:47 +0100 Subject: [PATCH 12/14] More fixes --- ...bstractDifferentiationChainRulesCoreExt.jl | 8 ++----- ext/AbstractDifferentiationForwardDiffExt.jl | 2 +- ext/AbstractDifferentiationReverseDiffExt.jl | 2 +- ext/AbstractDifferentiationTrackerExt.jl | 2 +- src/AbstractDifferentiation.jl | 24 +++++++++---------- src/backends.jl | 3 ++- test/test_utils.jl | 4 ++-- 7 files changed, 21 insertions(+), 24 deletions(-) diff --git a/ext/AbstractDifferentiationChainRulesCoreExt.jl b/ext/AbstractDifferentiationChainRulesCoreExt.jl index 0655512..673de12 100644 --- a/ext/AbstractDifferentiationChainRulesCoreExt.jl +++ b/ext/AbstractDifferentiationChainRulesCoreExt.jl @@ -1,13 +1,9 @@ module AbstractDifferentiationChainRulesCoreExt import AbstractDifferentiation as AD -if AD.EXTENSIONS_SUPPORTED - using ChainRulesCore: ChainRulesCore -else - using .ChainRulesCore: ChainRulesCore -end +using ChainRulesCore: ChainRulesCore -AD.@primitive function pullback_function(ab::AD.ReverseRuleConfigBackend, f, xs...) +AD.@primitive function pullback_function(ba::AD.ReverseRuleConfigBackend, f, xs...) _, back = ChainRulesCore.rrule_via_ad(ab.ruleconfig, f, xs...) 
pullback(vs) = Base.tail(back(vs)) pullback(vs::Tuple{Any}) = Base.tail(back(first(vs))) diff --git a/ext/AbstractDifferentiationForwardDiffExt.jl b/ext/AbstractDifferentiationForwardDiffExt.jl index bfa25bd..55f950b 100644 --- a/ext/AbstractDifferentiationForwardDiffExt.jl +++ b/ext/AbstractDifferentiationForwardDiffExt.jl @@ -24,7 +24,7 @@ function AD.ForwardDiffBackend(; chunksize::Union{Val,Nothing}=nothing) return AD.ForwardDiffBackend{getchunksize(chunksize)}() end -AD.@primitive function pushforward_function(::AD.ForwardDiffBackend, f, xs...) +AD.@primitive function pushforward_function(ba::AD.ForwardDiffBackend, f, xs...) return function pushforward(vs) if length(xs) == 1 v = vs isa Tuple ? only(vs) : vs diff --git a/ext/AbstractDifferentiationReverseDiffExt.jl b/ext/AbstractDifferentiationReverseDiffExt.jl index 1e3f2a5..9e8f46d 100644 --- a/ext/AbstractDifferentiationReverseDiffExt.jl +++ b/ext/AbstractDifferentiationReverseDiffExt.jl @@ -13,7 +13,7 @@ AD.primal_value(x::ReverseDiff.TrackedReal) = ReverseDiff.value(x) AD.primal_value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x) AD.primal_value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x) -AD.@primitive function jacobian(::AD.ReverseDiffBackend, f, xs...) +AD.@primitive function jacobian(ba::AD.ReverseDiffBackend, f, xs...) xs_arr = map(AD.asarray, xs) tape = ReverseDiff.JacobianTape(xs_arr) do (xs_arr...) xs_new = map(xs, xs_arr) do x, x_arr diff --git a/ext/AbstractDifferentiationTrackerExt.jl b/ext/AbstractDifferentiationTrackerExt.jl index aa9a716..53770fd 100644 --- a/ext/AbstractDifferentiationTrackerExt.jl +++ b/ext/AbstractDifferentiationTrackerExt.jl @@ -15,7 +15,7 @@ AD.primal_value(x::Tracker.TrackedReal) = Tracker.data(x) AD.primal_value(x::Tracker.TrackedArray) = Tracker.data(x) AD.primal_value(x::AbstractArray{<:Tracker.TrackedReal}) = Tracker.data.(x) -AD.@primitive function pullback_function(::AD.TrackerBackend, f, xs...) 
+AD.@primitive function pullback_function(ba::AD.TrackerBackend, f, xs...) value, back = Tracker.forward(f, xs...) function pullback(ws) if ws isa Tuple && !(value isa Tuple) diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index 41a0353..144f382 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -506,16 +506,16 @@ macro primitive(expr) end function define_pushforward_function_and_friends(fdef) - fdef[:name] = :(AbstractDifferentiation.pushforward_function) + fdef[:name] = :($(AbstractDifferentiation).pushforward_function) args = fdef[:args] funcs = quote $(ExprTools.combinedef(fdef)) - function AbstractDifferentiation.jacobian($(args...),) - identity_like = AbstractDifferentiation.identity_matrix_like($(args[3:end]...),) - pff = AbstractDifferentiation.pushforward_function($(args...),) + function $(AbstractDifferentiation).jacobian($(args...),) + identity_like = $(identity_matrix_like)($(args[3:end]...),) + pff = $(pushforward_function)($(args...),) if eltype(identity_like) <: Tuple{Vararg{Union{AbstractMatrix, Number}}} return map(identity_like) do identity_like_i - return mapreduce(hcat, AbstractDifferentiation._eachcol.(identity_like_i)...) do (cols...) + return mapreduce(hcat, $(_eachcol).(identity_like_i)...) do (cols...) 
pff(cols) end end @@ -542,17 +542,17 @@ function define_pushforward_function_and_friends(fdef) end function define_pullback_function_and_friends(fdef) - fdef[:name] = :(AbstractDifferentiation.pullback_function) + fdef[:name] = :($(AbstractDifferentiation).pullback_function) args = fdef[:args] funcs = quote $(ExprTools.combinedef(fdef)) - function AbstractDifferentiation.jacobian($(args...),) - value_and_pbf = AbstractDifferentiation.value_and_pullback_function($(args...),) + function $(AbstractDifferentiation).jacobian($(args...),) + value_and_pbf = $(value_and_pullback_function)($(args...),) value, _ = value_and_pbf(nothing) - identity_like = AbstractDifferentiation.identity_matrix_like(value) + identity_like = $(identity_matrix_like)(value) if eltype(identity_like) <: Tuple{Vararg{AbstractMatrix}} return map(identity_like) do identity_like_i - return mapreduce(vcat, AbstractDifferentiation._eachcol.(identity_like_i)...) do (cols...) + return mapreduce(vcat, $(_eachcol).(identity_like_i)...) do (cols...) value_and_pbf(cols)[2]' end end @@ -575,12 +575,12 @@ _eachcol(a::Number) = (a,) _eachcol(a) = eachcol(a) function define_jacobian_and_friends(fdef) - fdef[:name] = :(AbstractDifferentiation.jacobian) + fdef[:name] = :($(AbstractDifferentiation).jacobian) return ExprTools.combinedef(fdef) end function define_primal_value(fdef) - fdef[:name] = :(AbstractDifferentiation.primal_value) + fdef[:name] = :($(AbstractDifferentiation).primal_value) return ExprTools.combinedef(fdef) end diff --git a/src/backends.jl b/src/backends.jl index de3517c..7009195 100644 --- a/src/backends.jl +++ b/src/backends.jl @@ -54,7 +54,8 @@ struct TrackerBackend <: AbstractReverseMode end AD backend that uses reverse mode with any ChainRules-compatible reverse-mode AD package. !!! note - To be able to use this backend, you have to load ChainRulesCore. 
+ On Julia >= 1.9, you have to load ChainRulesCore (possibly implicitly by loading + a ChainRules-compatible AD package) to be able to use this backend. """ struct ReverseRuleConfigBackend{RC} <: AbstractReverseMode ruleconfig::RC diff --git a/test/test_utils.jl b/test/test_utils.jl index 6d28445..3711fcd 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -150,8 +150,8 @@ function test_hessians(backend; multiple_inputs=false, test_types=true) else # explicit test that AbstractDifferentiation throws an error # don't support tuple of Hessians - @test_throws AssertionError H1 = AD.hessian(backend, fgrad, (xvec, yvec)) - @test_throws MethodError H1 = AD.hessian(backend, fgrad, xvec, yvec) + @test_throws ArgumentError AD.hessian(backend, fgrad, (xvec, yvec)) + @test_throws MethodError AD.hessian(backend, fgrad, xvec, yvec) end # @test dfgraddxdx(xvec,yvec) ≈ H1[1] atol=1e-10 From 8378e262b8282e6941cba4d0f2abf803de182867 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 28 Feb 2023 15:38:15 +0100 Subject: [PATCH 13/14] Additional fixes --- ext/AbstractDifferentiationChainRulesCoreExt.jl | 2 +- ext/AbstractDifferentiationTrackerExt.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/AbstractDifferentiationChainRulesCoreExt.jl b/ext/AbstractDifferentiationChainRulesCoreExt.jl index 673de12..c2a2ff6 100644 --- a/ext/AbstractDifferentiationChainRulesCoreExt.jl +++ b/ext/AbstractDifferentiationChainRulesCoreExt.jl @@ -4,7 +4,7 @@ import AbstractDifferentiation as AD using ChainRulesCore: ChainRulesCore AD.@primitive function pullback_function(ba::AD.ReverseRuleConfigBackend, f, xs...) - _, back = ChainRulesCore.rrule_via_ad(ab.ruleconfig, f, xs...) + _, back = ChainRulesCore.rrule_via_ad(ba.ruleconfig, f, xs...) 
pullback(vs) = Base.tail(back(vs)) pullback(vs::Tuple{Any}) = Base.tail(back(first(vs))) return pullback diff --git a/ext/AbstractDifferentiationTrackerExt.jl b/ext/AbstractDifferentiationTrackerExt.jl index 53770fd..14b4d08 100644 --- a/ext/AbstractDifferentiationTrackerExt.jl +++ b/ext/AbstractDifferentiationTrackerExt.jl @@ -7,7 +7,7 @@ else using ..Tracker: Tracker end -function AD.second_lowest(::TrackerBackend) +function AD.second_lowest(::AD.TrackerBackend) return throw(ArgumentError("Tracker backend does not support nested differentiation.")) end From 9db3b75eeca4ca1a7b51250700d4cc256114d005 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 1 Mar 2023 23:01:54 +0100 Subject: [PATCH 14/14] Simplify Zygote extension --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8024bfe..94bcbae 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,7 @@ AbstractDifferentiationFiniteDifferencesExt = "FiniteDifferences" AbstractDifferentiationForwardDiffExt = ["DiffResults", "ForwardDiff"] AbstractDifferentiationReverseDiffExt = ["DiffResults", "ReverseDiff"] AbstractDifferentiationTrackerExt = "Tracker" -AbstractDifferentiationZygoteExt = ["ChainRulesCore", "Zygote"] +AbstractDifferentiationZygoteExt = "Zygote" [compat] ChainRulesCore = "1"