From 3218c0866d027fe795a84a20b51759eeeb6770ed Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 3 Mar 2023 21:53:57 +0100 Subject: [PATCH] New Zygote context in every call to `AD.pullback_function` (#77) * New Zygote context in every call to `AD.pullback_function` * Make fix more modular --- Project.toml | 2 +- ...bstractDifferentiationChainRulesCoreExt.jl | 2 +- ext/AbstractDifferentiationZygoteExt.jl | 5 +++++ src/backends.jl | 5 +++++ test/ruleconfig.jl | 22 +++++++++++++++++++ 5 files changed, 34 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 94bcbae..70e3c2e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AbstractDifferentiation" uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" authors = ["Mohamed Tarek and contributors"] -version = "0.5.0" +version = "0.5.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/ext/AbstractDifferentiationChainRulesCoreExt.jl b/ext/AbstractDifferentiationChainRulesCoreExt.jl index c2a2ff6..4424381 100644 --- a/ext/AbstractDifferentiationChainRulesCoreExt.jl +++ b/ext/AbstractDifferentiationChainRulesCoreExt.jl @@ -4,7 +4,7 @@ import AbstractDifferentiation as AD using ChainRulesCore: ChainRulesCore AD.@primitive function pullback_function(ba::AD.ReverseRuleConfigBackend, f, xs...) - _, back = ChainRulesCore.rrule_via_ad(ba.ruleconfig, f, xs...) + _, back = ChainRulesCore.rrule_via_ad(AD.ruleconfig(ba), f, xs...) pullback(vs) = Base.tail(back(vs)) pullback(vs::Tuple{Any}) = Base.tail(back(first(vs))) return pullback diff --git a/ext/AbstractDifferentiationZygoteExt.jl b/ext/AbstractDifferentiationZygoteExt.jl index bd65f84..85e4a5c 100644 --- a/ext/AbstractDifferentiationZygoteExt.jl +++ b/ext/AbstractDifferentiationZygoteExt.jl @@ -9,4 +9,9 @@ end AD.ZygoteBackend() = AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()) +# Context should not persist between different AD calls: fixes #69 +function AD.ruleconfig(::AD.ReverseRuleConfigBackend{<:Zygote.ZygoteRuleConfig}) + return Zygote.ZygoteRuleConfig() +end + end # module diff --git a/src/backends.jl b/src/backends.jl index 7009195..7985b14 100644 --- a/src/backends.jl +++ b/src/backends.jl @@ -61,6 +61,11 @@ struct ReverseRuleConfigBackend{RC} <: AbstractReverseMode ruleconfig::RC end +# internal function for extracting the rule config +# falls back to returning the wrapped `ruleconfig` but can be specialized +# e.g., for Zygote to fix #69 +ruleconfig(ba::ReverseRuleConfigBackend) = ba.ruleconfig + """ ZygoteBackend() diff --git a/test/ruleconfig.jl b/test/ruleconfig.jl index 412d89c..f18330c 100644 --- a/test/ruleconfig.jl +++ b/test/ruleconfig.jl @@ -30,4 +30,26 @@ using Zygote test_lazy_jacobians(backend) end end + + # issue #69 + @testset "Zygote context" begin + ad = AD.ZygoteBackend() + + # example in #69: context is not mutated + @test ad.ruleconfig.context.cache === nothing + @test AD.derivative(ad, exp, 1.0) === (exp(1.0),) + @test ad.ruleconfig.context.cache === nothing + @test AD.derivative(ad, exp, 1.0) === (exp(1.0),) + @test ad.ruleconfig.context.cache === nothing + + # Jacobian computation still works + # https://github.com/JuliaDiff/AbstractDifferentiation.jl/pull/70#issuecomment-1449481724 + function f(x, a) + r = Ref(x) + r[] = r[] + r[] + r[] = r[] * a + r[] + end + @test AD.jacobian(ad, f, [1, 2, 3], 3) == ([6.0 0.0 0.0; 0.0 6.0 0.0; 0.0 0.0 6.0], [2.0, 4.0, 6.0]) + end end