Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use weak dependencies if supported #68

Merged
merged 16 commits into from
Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AbstractDifferentiation"
uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
authors = ["Mohamed Tarek <mohamed82008@gmail.com> and contributors"]
version = "0.4.3"
version = "0.5.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -10,6 +10,20 @@ ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"

[weakdeps]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
FiniteDifferencesExt = "FiniteDifferences"
ForwardDiffExt = "ForwardDiff"
ReverseDiffExt = "ReverseDiff"
TrackerExt = "Tracker"
ZygoteExt = "Zygote"

[compat]
ChainRulesCore = "1"
Compat = "3, 4"
Expand Down
16 changes: 11 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,24 @@ Julia has more (automatic) differentiation packages than you can count on 2 hand

## Loading `AbstractDifferentiation`

To load `AbstractDifferentiation`, use:
To load `AbstractDifferentiation`, it is recommended to use
```julia
using AbstractDifferentiation
import AbstractDifferentiation as AD
```
`AbstractDifferentiation` exports a single name `AD` which is just an alias for the `AbstractDifferentiation` module itself. You can use this to access names inside `AbstractDifferentiation` using `AD.<>` instead of typing the long name `AbstractDifferentiation`.
on Julia ≥ 1.6 and
```julia
import AbstractDifferentiation
const AD = AbstractDifferentiation
```
on older Julia versions.
With the `AD` alias you can access names inside of `AbstractDifferentiation` using `AD.<>` instead of typing the long name `AbstractDifferentiation`.

## `AbstractDifferentiation` backends

To use `AbstractDifferentiation`, first construct a backend instance `ab::AD.AbstractBackend` using your favorite differentiation package in Julia that supports `AbstractDifferentiation`.
In particular, you may want to use `AD.ReverseRuleConfigBackend(ruleconfig)` for any [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible reverse mode differentiation package.

The following backends are temporarily made available by `AbstractDifferentiation` as soon as their corresponding package is loaded (thanks to [Requires.jl](https://github.com/JuliaPackaging/Requires.jl)):
The following backends are temporarily made available by `AbstractDifferentiation` as soon as their corresponding package is loaded (thanks to [weak dependencies](https://pkgdocs.julialang.org/dev/creating-packages/#Weak-dependencies) on Julia ≥ 1.9 and [Requires.jl](https://github.com/JuliaPackaging/Requires.jl) on older Julia versions):

- `AD.ForwardDiffBackend()` for [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl)
- `AD.FiniteDifferencesBackend()` for [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl)
Expand All @@ -35,7 +41,7 @@ In the long term, these backend objects (and many more) will be defined within t
Here's an example:

```julia
julia> using AbstractDifferentiation, Zygote
julia> import AbstractDifferentiation as AD, Zygote

julia> ab = AD.ZygoteBackend()
AbstractDifferentiation.ReverseRuleConfigBackend{Zygote.ZygoteRuleConfig{Zygote.Context}}(Zygote.ZygoteRuleConfig{Zygote.Context}(Zygote.Context(nothing)))
Expand Down
36 changes: 36 additions & 0 deletions ext/FiniteDifferencesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
module FiniteDifferencesExt

using AbstractDifferentiation: AbstractDifferentiation, EXTENSIONS_SUPPORTED, FiniteDifferencesBackend
if EXTENSIONS_SUPPORTED
using FiniteDifferences: FiniteDifferences
else
using ..FiniteDifferences: FiniteDifferences
end

const AD = AbstractDifferentiation

"""
FiniteDifferencesBackend(method=FiniteDifferences.central_fdm(5, 1))

Create an AD backend that uses forward mode with FiniteDifferences.jl.
"""
AD.FiniteDifferencesBackend() = FiniteDifferencesBackend(FiniteDifferences.central_fdm(5, 1))

AD.@primitive function jacobian(ba::FiniteDifferencesBackend, f, xs...)
return FiniteDifferences.jacobian(ba.method, f, xs...)
end

function AD.pushforward_function(ba::FiniteDifferencesBackend, f, xs...)
return function pushforward(vs)
ws = FiniteDifferences.jvp(ba.method, f, tuple.(xs, vs)...)
return length(xs) == 1 ? (ws,) : ws
end
end

function AD.pullback_function(ba::FiniteDifferencesBackend, f, xs...)
function pullback(vs)
return FiniteDifferences.j′vp(ba.method, f, vs, xs...)
end
end

end # module
44 changes: 23 additions & 21 deletions src/forwarddiff.jl → ext/ForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
using .ForwardDiff: ForwardDiff, DiffResults, StaticArrays
module ForwardDiffExt

"""
ForwardDiffBackend{CS}

AD backend that uses forward mode with ForwardDiff.jl.

The type parameter `CS` denotes the chunk size of the differentiation algorithm. If it is
`Nothing`, then ForwardiffDiff uses a heuristic to set the chunk size based on the input.
using AbstractDifferentiation: AbstractDifferentiation, asarray, EXTENSIONS_SUPPORTED, ForwardDiffBackend
if EXTENSIONS_SUPPORTED
using ForwardDiff: ForwardDiff, DiffResults
else
using ..ForwardDiff: ForwardDiff, DiffResults
end
if VERSION < v"1.4.0-DEV.142"
using Compat: only
end

See also: [ForwardDiff.jl: Configuring Chunk Size](https://juliadiff.org/ForwardDiff.jl/dev/user/advanced/#Configuring-Chunk-Size)
"""
struct ForwardDiffBackend{CS} <: AbstractForwardMode end
const AD = AbstractDifferentiation

"""
ForwardDiffBackend(; chunksize::Union{Val,Nothing}=nothing)
Expand All @@ -23,11 +23,11 @@ ForwarddDiff uses a heuristic to set the chunk size based on the input. Alternat

See also: [ForwardDiff.jl: Configuring Chunk Size](https://juliadiff.org/ForwardDiff.jl/dev/user/advanced/#Configuring-Chunk-Size)
"""
function ForwardDiffBackend(; chunksize::Union{Val,Nothing}=nothing)
function AD.ForwardDiffBackend(; chunksize::Union{Val,Nothing}=nothing)
return ForwardDiffBackend{getchunksize(chunksize)}()
end

@primitive function pushforward_function(ba::ForwardDiffBackend, f, xs...)
AD.@primitive function pushforward_function(ba::ForwardDiffBackend, f, xs...)
return function pushforward(vs)
if length(xs) == 1
v = vs isa Tuple ? only(vs) : vs
Expand All @@ -38,35 +38,35 @@ end
end
end

primal_value(x::ForwardDiff.Dual) = ForwardDiff.value(x)
primal_value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x)
AD.primal_value(x::ForwardDiff.Dual) = ForwardDiff.value(x)
AD.primal_value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x)

# these implementations are more efficient than the fallbacks

function gradient(ba::ForwardDiffBackend, f, x::AbstractArray)
function AD.gradient(ba::ForwardDiffBackend, f, x::AbstractArray)
cfg = ForwardDiff.GradientConfig(f, x, chunk(ba, x))
return (ForwardDiff.gradient(f, x, cfg),)
end

function jacobian(ba::ForwardDiffBackend, f, x::AbstractArray)
function AD.jacobian(ba::ForwardDiffBackend, f, x::AbstractArray)
cfg = ForwardDiff.JacobianConfig(asarray ∘ f, x, chunk(ba, x))
return (ForwardDiff.jacobian(asarray ∘ f, x, cfg),)
end
jacobian(::ForwardDiffBackend, f, x::Number) = (ForwardDiff.derivative(f, x),)
AD.jacobian(::ForwardDiffBackend, f, x::Number) = (ForwardDiff.derivative(f, x),)

function hessian(ba::ForwardDiffBackend, f, x::AbstractArray)
function AD.hessian(ba::ForwardDiffBackend, f, x::AbstractArray)
cfg = ForwardDiff.HessianConfig(f, x, chunk(ba, x))
return (ForwardDiff.hessian(f, x, cfg),)
end

function value_and_gradient(ba::ForwardDiffBackend, f, x::AbstractArray)
function AD.value_and_gradient(ba::ForwardDiffBackend, f, x::AbstractArray)
result = DiffResults.GradientResult(x)
cfg = ForwardDiff.GradientConfig(f, x, chunk(ba, x))
ForwardDiff.gradient!(result, f, x, cfg)
return DiffResults.value(result), (DiffResults.derivative(result),)
end

function value_and_hessian(ba::ForwardDiffBackend, f, x)
function AD.value_and_hessian(ba::ForwardDiffBackend, f, x)
result = DiffResults.HessianResult(x)
cfg = ForwardDiff.HessianConfig(f, result, x, chunk(ba, x))
ForwardDiff.hessian!(result, f, x, cfg)
Expand All @@ -82,3 +82,5 @@ getchunksize(::Val{N}) where {N} = N

chunk(::ForwardDiffBackend{Nothing}, x) = ForwardDiff.Chunk(x)
chunk(::ForwardDiffBackend{N}, _) where {N} = ForwardDiff.Chunk{N}()

end # module
39 changes: 23 additions & 16 deletions src/reversediff.jl → ext/ReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
using .ReverseDiff: ReverseDiff, DiffResults
module ReverseDiffExt

primal_value(x::ReverseDiff.TrackedReal) = ReverseDiff.value(x)
primal_value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x)
primal_value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x)
using AbstractDifferentiation: AbstractDifferentiation, asarray, EXTENSIONS_SUPPORTED, ReverseDiffBackend
if EXTENSIONS_SUPPORTED
using ReverseDiff: ReverseDiff, DiffResults
else
using ..ReverseDiff: ReverseDiff, DiffResults
end
if VERSION < v"1.4.0-DEV.142"
using Compat: only
end

"""
ReverseDiffBackend
const AD = AbstractDifferentiation

AD backend that uses reverse mode with ReverseDiff.jl.
"""
struct ReverseDiffBackend <: AbstractReverseMode end
AD.primal_value(x::ReverseDiff.TrackedReal) = ReverseDiff.value(x)
AD.primal_value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x)
AD.primal_value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x)

@primitive function jacobian(ba::ReverseDiffBackend, f, xs...)
AD.@primitive function jacobian(ba::ReverseDiffBackend, f, xs...)
xs_arr = map(asarray, xs)
tape = ReverseDiff.JacobianTape(xs_arr) do (xs_arr...)
xs_new = map(xs, xs_arr) do x, x_arr
Expand All @@ -24,11 +29,11 @@ struct ReverseDiffBackend <: AbstractReverseMode end
return x isa Number ? vec(result) : result
end
end
function jacobian(ba::ReverseDiffBackend, f, xs::AbstractArray...)
function AD.jacobian(ba::ReverseDiffBackend, f, xs::AbstractArray...)
return ReverseDiff.jacobian(asarray ∘ f, xs)
end

function derivative(ba::ReverseDiffBackend, f, xs::Number...)
function AD.derivative(ba::ReverseDiffBackend, f, xs::Number...)
tape = ReverseDiff.InstructionTape()
xs_tracked = ReverseDiff.TrackedReal.(xs, zero.(xs), Ref(tape))
y_tracked = f(xs_tracked...)
Expand All @@ -37,24 +42,26 @@ function derivative(ba::ReverseDiffBackend, f, xs::Number...)
return ReverseDiff.deriv.(xs_tracked)
end

function gradient(ba::ReverseDiffBackend, f, xs::AbstractArray...)
function AD.gradient(ba::ReverseDiffBackend, f, xs::AbstractArray...)
return ReverseDiff.gradient(f, xs)
end

function hessian(ba::ReverseDiffBackend, f, x::AbstractArray)
function AD.hessian(ba::ReverseDiffBackend, f, x::AbstractArray)
return (ReverseDiff.hessian(f, x),)
end

function value_and_gradient(ba::ReverseDiffBackend, f, x::AbstractArray)
function AD.value_and_gradient(ba::ReverseDiffBackend, f, x::AbstractArray)
result = DiffResults.GradientResult(x)
cfg = ReverseDiff.GradientConfig(x)
ReverseDiff.gradient!(result, f, x, cfg)
return DiffResults.value(result), (DiffResults.derivative(result),)
end

function value_and_hessian(ba::ReverseDiffBackend, f, x)
function AD.value_and_hessian(ba::ReverseDiffBackend, f, x)
result = DiffResults.HessianResult(x)
cfg = ReverseDiff.HessianConfig(result, x)
ReverseDiff.hessian!(result, f, x, cfg)
return DiffResults.value(result), (DiffResults.hessian(result),)
end

end # module
41 changes: 41 additions & 0 deletions ext/TrackerExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
module TrackerExt

using AbstractDifferentiation: AbstractDifferentiation, EXTENSIONS_SUPPORTED, TrackerBackend
if EXTENSIONS_SUPPORTED
using Tracker: Tracker
else
using ..Tracker: Tracker
end

const AD = AbstractDifferentiation

function AD.second_lowest(::TrackerBackend)
return throw(ArgumentError("Tracker backend does not support nested differentiation."))
end

AD.primal_value(x::Tracker.TrackedReal) = Tracker.data(x)
AD.primal_value(x::Tracker.TrackedArray) = Tracker.data(x)
AD.primal_value(x::AbstractArray{<:Tracker.TrackedReal}) = Tracker.data.(x)

AD.@primitive function pullback_function(ba::TrackerBackend, f, xs...)
value, back = Tracker.forward(f, xs...)
function pullback(ws)
if ws isa Tuple && !(value isa Tuple)
@assert length(ws) == 1
map(Tracker.data, back(ws[1]))
else
map(Tracker.data, back(ws))
end
end
return pullback
end

function AD.derivative(ba::TrackerBackend, f, xs::Number...)
return Tracker.data.(Tracker.gradient(f, xs...))
end

function AD.gradient(ba::TrackerBackend, f, xs::AbstractVector...)
return Tracker.data.(Tracker.gradient(f, xs...))
end

end # module
15 changes: 15 additions & 0 deletions ext/ZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module ZygoteExt

using AbstractDifferentiation: AbstractDifferentiation, EXTENSIONS_SUPPORTED, ReverseRuleConfigBackend

if EXTENSIONS_SUPPORTED
using Zygote: Zygote
else
using ..Zygote: Zygote
end

@static if isdefined(AbstractDifferentiation, :ZygoteBackend) && isdefined(Zygote, :ZygoteRuleConfig)
AbstractDifferentiation.ZygoteBackend() = ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())
end

end # module
28 changes: 17 additions & 11 deletions src/AbstractDifferentiation.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
module AbstractDifferentiation

using LinearAlgebra, ExprTools, Requires, Compat
using LinearAlgebra, ExprTools
using ChainRulesCore: ChainRulesCore

export AD

const AD = AbstractDifferentiation
if VERSION < v"1.1.0-DEV.792"
using Compat: eachcol
end

abstract type AbstractBackend end
abstract type AbstractFiniteDifference <: AbstractBackend end
Expand Down Expand Up @@ -645,14 +644,21 @@ end
@inline asarray(x::AbstractArray) = x

include("ruleconfig.jl")
include("backends.jl")

# TODO: Replace with proper version
const EXTENSIONS_SUPPORTED = isdefined(Base, :get_extension)
if !EXTENSIONS_SUPPORTED
using Requires: @require
end
function __init__()
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl")
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("reversediff.jl")
@require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("finitedifferences.jl")
@require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("tracker.jl")
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
@static if !EXTENSIONS_SUPPORTED
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can move this out of the init definition

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried many different versions and this was the (only?) one that seemed to work. But I'll give it another shot.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I think the problem was that I had put using Requires in the same check (outside of ìnit) which does not work. Having to if statements seems to work though.

Copy link
Member Author

@devmotion devmotion Feb 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("../ext/ForwardDiffExt.jl")
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("../ext/ReverseDiffExt.jl")
@require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("../ext/FiniteDifferencesExt.jl")
@require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("../ext/TrackerExt.jl")
@static if VERSION >= v"1.6"
ZygoteBackend() = ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("../ext/ZygoteExt.jl")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this Zygote one versioned off?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For simplicity, just bump to LTS?

end
end
end
Expand Down
Loading