From 15b760343edf0b4cef53d97f4115610b19247c28 Mon Sep 17 00:00:00 2001 From: dominic-chang Date: Wed, 30 Oct 2024 14:07:45 -0400 Subject: [PATCH] Fix broken Enzyme reverse diff rule on Const evaluation --- ext/JacobiEllipticEnzymeExt.jl | 50 +++++++++++++++++++++++++++------- test/autodiff.jl | 12 ++++---- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/ext/JacobiEllipticEnzymeExt.jl b/ext/JacobiEllipticEnzymeExt.jl index 13de3dd..a1b149f 100644 --- a/ext/JacobiEllipticEnzymeExt.jl +++ b/ext/JacobiEllipticEnzymeExt.jl @@ -44,10 +44,8 @@ d end elseif EnzymeRules.needs_shadow(config) if EnzymeRules.width(config) == 1 return (ϕ isa Const ? zero(ϕ.val) : ∂F_∂ϕ(ϕ.val, m.val)*ϕ.dval) +(m isa Const ? zero(m.val) : ∂F_∂m(ϕ.val, m.val)*m.dval) - else return ntuple(i -> (ϕ isa Const ? zero(ϕ.val) : ∂F_∂ϕ(ϕ.val, m.val)*ϕ.dval[i]) + (m isa Const ? zero(m.val) : ∂F_∂m(ϕ.val, m.val)*m.dval[i]), Val(EnzymeRules.width(config))) - end elseif EnzymeRules.needs_primal(config) return func.val(ϕ.val, m.val) @@ -79,17 +77,33 @@ function reverse( dϕ = if ϕ isa Const nothing elseif EnzymeRules.width(config) == 1 - ∂F_∂ϕ(ϕ.val, m.val) * dret.val + if dret isa Type{<:Const} + zero(ϕ.val) + else + ∂F_∂ϕ(ϕ.val, m.val) * dret.val + end else - ntuple(i -> ∂F_∂ϕ(ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config))) + if dret isa Type{<:Const} + ntuple(i -> zero(ϕ.val), Val(EnzymeRules.width(config))) + else + ntuple(i -> ∂F_∂ϕ(ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config))) + end end dm = if m isa Const nothing elseif EnzymeRules.width(config) == 1 - ∂F_∂m(ϕ.val, m.val) * dret.val + if dret isa Type{<:Const} + zero(ϕ.val) + else + ∂F_∂m(ϕ.val, m.val) * dret.val + end else - ntuple(i -> ∂F_∂m(ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config))) + if dret isa Type{<:Const} + ntuple(i -> zero(ϕ.val), Val(EnzymeRules.width(config))) + else + ntuple(i -> ∂F_∂m(ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config))) + end end return (dϕ, dm) end @@ -167,17 +181,33 @@ function reverse( dϕ = if ϕ isa Const nothing elseif EnzymeRules.width(config) == 1 - ∂E_∂ϕ(ϕ.val, m.val) * dret.val + if dret isa Type{<:Const} + zero(ϕ.val) + else + ∂E_∂ϕ(ϕ.val, m.val) * dret.val + end else - ntuple(i -> ∂E_∂ϕ(ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config))) + if dret isa Type{<:Const} + ntuple(i -> zero(ϕ.val), Val(EnzymeRules.width(config))) + else + ntuple(i -> ∂E_∂ϕ(ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config))) + end end dm = if m isa Const nothing elseif EnzymeRules.width(config) == 1 - ∂E_∂m(ϕ.val, m.val) * dret.val + if dret isa Type{<:Const} + zero(ϕ.val) + else + ∂E_∂m(ϕ.val, m.val) * dret.val + end else - ntuple(i -> ∂E_∂m(ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config))) + if dret isa Type{<:Const} + ntuple(i -> zero(ϕ.val), Val(EnzymeRules.width(config))) + else + ntuple(i -> ∂E_∂m(ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config))) + end end return (dϕ, dm) end diff --git a/test/autodiff.jl b/test/autodiff.jl index 9668e4b..cb4bfcd 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -35,8 +35,8 @@ using SpecialFunctions _F = alg.F @test Zygote.gradient(ϕ -> _F(ϕ, m), ϕ)[1] ≈ 1 / √(1 - m*sin(ϕ)^2) atol=1e-5 @test ForwardDiff.derivative(ϕ -> _F(ϕ, m), ϕ) ≈ 1 / √(1 - m*sin(ϕ)^2) atol=1e-5 - @test Enzyme.autodiff(Reverse, ϕ -> _F(ϕ, m), Active, Active(ϕ))[1][1] ≈ 1 / √(1 - m*sin(ϕ)^2) atol=1e-5 - @test Enzyme.autodiff(Forward, ϕ -> _F(ϕ, m), Duplicated, Duplicated(ϕ, 1.0))[1][1] ≈ 1 / √(1 - m*sin(ϕ)^2) atol=1e-5 + @test Enzyme.autodiff(Reverse, _F, Active, Active(ϕ), Const(m))[1][1] ≈ 1 / √(1 - m*sin(ϕ)^2) atol=1e-5 + @test Enzyme.autodiff(Forward, _F, Duplicated, Duplicated(ϕ, 1.0), Const(m))[1][1] ≈ 1 / √(1 - m*sin(ϕ)^2) atol=1e-5 # 4. ∂m(F(ϕ, m)) == E(ϕ, m) / (2 * m * (1 - m)) - F(ϕ, m) / 2m - sin(2ϕ) / (4 * (1-m) * √(1 - m * sin(ϕ)^2)) @test Zygote.gradient(m -> _F(ϕ, m), m)[1] ≈ @@ -47,7 +47,7 @@ using SpecialFunctions alg.E(ϕ, m) / (2 * m * (1 - m)) - alg.F(ϕ, m) / 2 / m - sin(2*ϕ) / (4 * (1 - m) * √(1 - m * sin(ϕ)^2)) atol=1e-5 - @test Enzyme.autodiff(Reverse, m -> _F(ϕ, m), Active, Active(m))[1][1] ≈ + @test Enzyme.autodiff(Reverse, _F, Active, Const(ϕ), Active(m))[1][2] ≈ alg.E(ϕ, m) / (2 * m * (1 - m)) - alg.F(ϕ, m) / 2 / m - sin(2*ϕ) / (4 * (1 - m) * √(1 - m * sin(ϕ)^2)) atol=1e-5 @@ -56,13 +56,13 @@ using SpecialFunctions # 5. ∂ϕ(E(ϕ, m)) == √(1 - m * sin(ϕ)^2) @test Zygote.gradient(ϕ -> _E(ϕ, m), ϕ)[1] ≈ √(1 - m * sin(ϕ)^2) atol=1e-5 @test ForwardDiff.derivative(ϕ -> _E(ϕ, m), ϕ) ≈ √(1 - m * sin(ϕ)^2) atol=1e-5 - @test Enzyme.autodiff(Reverse, ϕ -> _E(ϕ, m), Active, Active(ϕ))[1][1] ≈ √(1 - m * sin(ϕ)^2) atol=1e-5 - @test Enzyme.autodiff(Forward, ϕ -> _E(ϕ, m), Duplicated, Duplicated(ϕ, 1.0))[1][1] ≈ √(1 - m*sin(ϕ)^2) atol=1e-5 + @test Enzyme.autodiff(Reverse, _E, Active, Active(ϕ), Const(m))[1][1] ≈ √(1 - m * sin(ϕ)^2) atol=1e-5 + @test Enzyme.autodiff(Forward, _E, Duplicated, Duplicated(ϕ, 1.0), Const(m))[1][1] ≈ √(1 - m*sin(ϕ)^2) atol=1e-5 # 6. ∂m(E(ϕ, m)) == (E(ϕ, m) - F(ϕ, m)) / 2m @test Zygote.gradient(m -> _E(ϕ, m), m)[1] ≈ (alg.E(ϕ, m) - alg.F(ϕ, m)) / 2m @test ForwardDiff.derivative(m -> _E(ϕ, m), m) ≈ (alg.E(ϕ, m) - alg.F(ϕ, m)) / 2m - @test Enzyme.autodiff(Reverse, m -> _E(ϕ, m), Active, Active(m))[1][1] ≈ (alg.E(ϕ, m) - alg.F(ϕ, m)) / 2m atol=1e-5 + @test Enzyme.autodiff(Reverse, _E, Active, Const(ϕ), Active(m))[1][2] ≈ (alg.E(ϕ, m) - alg.F(ϕ, m)) / 2m atol=1e-5 end end