Skip to content

Commit

Permalink
RuleConfig and Zygote support (#49)
Browse files Browse the repository at this point in the history
* ruleconfig support and Zygote tests

* actually run the tests :)

* avoid loading Yota

* lower bound Zygote compat

* move imports

* bump version

* avoid testing ruleconfig in 1.0

* remove Zygote compat
  • Loading branch information
mohamed82008 authored Feb 8, 2022
1 parent 9a3b564 commit 8f0d6db
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 1 deletion.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
name = "AbstractDifferentiation"
uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
authors = ["Mohamed Tarek <mohamed82008@gmail.com> and contributors"]
version = "0.4.0"
version = "0.4.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"

[compat]
ChainRulesCore = "1"
Compat = "3"
ExprTools = "0.1"
ForwardDiff = "0.10"
Expand Down
7 changes: 7 additions & 0 deletions src/AbstractDifferentiation.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module AbstractDifferentiation

using LinearAlgebra, ExprTools, Requires, Compat
using ChainRulesCore: RuleConfig, rrule_via_ad

export AD

Expand Down Expand Up @@ -643,11 +644,17 @@ end
@inline asarray(x) = [x]
@inline asarray(x::AbstractArray) = x

include("ruleconfig.jl")
function __init__()
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl")
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("reversediff.jl")
@require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("finitedifferences.jl")
@require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("tracker.jl")
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
@static if VERSION >= v"1.6"
ZygoteBackend() = ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())
end
end
end

end
19 changes: 19 additions & 0 deletions src/ruleconfig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
ReverseRuleConfigBackend
AD backend that uses reverse mode with any ChainRules-compatible reverse-mode AD package.
"""
struct ReverseRuleConfigBackend{RC <: RuleConfig} <: AbstractReverseMode
ruleconfig::RC
end

AD.@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...)
return (vs) -> begin
_, back = rrule_via_ad(ab.ruleconfig, f, xs...)
if vs isa Tuple && length(vs) === 1
return Base.tail(back(vs[1]))
else
return Base.tail(back(vs))
end
end
end
33 changes: 33 additions & 0 deletions test/ruleconfig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using AbstractDifferentiation
using Test
using Zygote

@testset "ReverseRuleConfigBackend(ZygoteRuleConfig())" begin
backends = [@inferred(AD.ZygoteBackend())]
@testset for backend in backends
@testset "Derivative" begin
test_derivatives(backend)
end
@testset "Gradient" begin
test_gradients(backend)
end
@testset "Jacobian" begin
test_jacobians(backend)
end
@testset "jvp" begin
test_jvp(backend)
end
@testset "j′vp" begin
test_j′vp(backend)
end
@testset "Lazy Derivative" begin
test_lazy_derivatives(backend)
end
@testset "Lazy Gradient" begin
test_lazy_gradients(backend)
end
@testset "Lazy Jacobian" begin
test_lazy_jacobians(backend)
end
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,7 @@ using Test
include("reversediff.jl")
include("finitedifferences.jl")
include("tracker.jl")
@static if VERSION >= v"1.6"
include("ruleconfig.jl")
end
end

2 comments on commit 8f0d6db

@mohamed82008
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/54230

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.1 -m "<description of version>" 8f0d6db070abe8defaeb7a9c637b077b9895cd92
git push origin v0.4.1

Please sign in to comment.