Skip to content

Commit

Permalink
Fix Enzyme diffrules
Browse files Browse the repository at this point in the history
  • Loading branch information
dominic-chang committed Mar 31, 2024
1 parent 01a91de commit 5243dd9
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 49 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ ForwardDiff = "0.10"
Setfield = "1.1"
StaticArrays = "1.6"
julia = "1.8"
Enzyme = "0.11.20"
Enzyme = "0.11"
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# JacobiElliptic
JacobiElliptic is an implementation of Toshio Fukushima's algorithms for calculating [Elliptic Integrals and Jacobi Elliptic Functions](https://ieeexplore.ieee.org/document/7203795).
JacobiElliptic is an implementation of Toshio Fukushima's & Billie C. Carlson's for calculating [Elliptic Integrals and Jacobi Elliptic Functions](https://ieeexplore.ieee.org/document/7203795).

## Features
- Type stable and preserving
Expand Down
56 changes: 22 additions & 34 deletions ext/JacobiEllipticEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,72 +18,61 @@ function ∂F_∂ϕ(ϕ, m)
end

function forward(
::Const{typeof(JacobiElliptic.F)},
::Type,
func::Const{typeof(JacobiElliptic.F)},
::Type{<:Duplicated},
ϕ::Const,
m::Duplicated
)
println("In custom forward rule.")

return ∂F_∂m.val, m.val)*m.dval
return Duplicated(func.val.val, m.val), ∂F_∂m.val, m.val)*m.dval)
end

function forward(
::Const{typeof(JacobiElliptic.F)},
::Type,
func::Const{typeof(JacobiElliptic.F)},
::Type{<:Duplicated},
ϕ::Duplicated,
m::Const
)
println("In custom forward rule.")

return ∂F_∂ϕ.val, m.val)*ϕ.dval
return Duplicated(func.val.val, m.val), ∂F_∂ϕ.val, m.val)*ϕ.dval)
end

function forward(
::Const{typeof(JacobiElliptic.F)},
::Type,
func::Const{typeof(JacobiElliptic.F)},
::Type{<:Duplicated},
ϕ::Duplicated,
m::Duplicated
)
println("In custom forward rule.")

return ∂F_∂m.val, m.val)*m.dval + ∂F_∂ϕ.val, m.val)*ϕ.dval
return Duplicated(func.val.val, m.val), ∂F_∂m.val, m.val)*m.dval + ∂F_∂ϕ.val, m.val)*ϕ.dval)
end

function augmented_primal(
config,
config::ConfigWidth{N},
func::Const{typeof(JacobiElliptic.F)},
::Type{<:Active},
tape,
ϕ,
m
)
println("In custom augmented primal rule.")
ϕ::Union{Const,Active},
m::Union{Const,Active}
) where {N}
#println("In custom augmented primal rule.")
# Save x in tape if x will be overwritten
primal = EnzymeRules.needs_primal(config) ? func.val, m.val) : nothing
primal = EnzymeRules.needs_primal(config) ? func.val.val, m.val) : nothing

return EnzymeRules.AugmentedReturn(primal, nothing, tape)
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end


function reverse(config::ConfigWidth{N}, ::Const{JacobiElliptic.F}, dret::Active, tape, ϕ::Const, m::Active) where {N}
println("In custom reverse rule.")
function reverse(config::ConfigWidth{1}, ::Const{typeof(JacobiElliptic.F)}, dret::Active, tape, ϕ::Const, m::Active)
# retrieve x value, either from original x or from tape if x may have been overwritten.
mval = m.val
dm = ∂F_∂m.val, mval) * dret.val
return (dm, )
return (nothing, dm)
end

function reverse(config::ConfigWidth{N}, ::Const{JacobiElliptic.F}, dret::Active, tape, ϕ::Active, m::Const) where {N}
println("In custom reverse rule.")
function reverse(config::ConfigWidth{1}, ::Const{typeof(JacobiElliptic.F)}, dret::Active, tape, ϕ::Active, m::Const)
# retrieve x value, either from original x or from tape if x may have been overwritten.
ϕval = EnzymeRules.overwritten(config)[2] ? tape : ϕ.val
ϕval = ϕ.val
= ∂F_∂ϕ(ϕval, m.val) * dret.val
return (dϕ, )
return (dϕ, nothing)
end

function reverse(config::ConfigWidth{N}, ::Const{JacobiElliptic.F}, dret::Active, tape, ϕ::Active, m::Active) where {N}
println("In custom reverse rule.")
function reverse(config, ::Const{typeof(JacobiElliptic.F)}, dret::Active, tape, ϕ::Union{Active, Duplicated}, m::Union{Active, Duplicated})
# retrieve x value, either from original x or from tape if x may have been overwritten.
ϕval = ϕ.val
mval = m.val
Expand All @@ -92,5 +81,4 @@ function reverse(config::ConfigWidth{N}, ::Const{JacobiElliptic.F}, dret::Active
return (dϕ, dm)
end


end # module
24 changes: 12 additions & 12 deletions src/Fukushima.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ Returns the complete elliptic integral of the first kind.
- `m` : Elliptic modulus
"""
function K(m::T) where T
!(-1 m 1) && throw(DomainError("argument m not in [0,1]"))
m == one(T) && return T(Inf)

# I didn't really see any speedup from evalpoly, so I left the evaluation in this form
Expand Down Expand Up @@ -50,7 +49,7 @@ function K(m::T) where T
x == zero(T) && return T/2)
x == one(T) && return T(Inf)
x > one(T) && return T(NaN)

if x < T(0.1)
t = poly1( x - T(0.05));
elseif ( x < T(0.2))
Expand Down Expand Up @@ -79,7 +78,7 @@ function K(m::T) where T
end
# Complete the transformation mentioned above for m < 0:
flag && return t / sqrt( one(T) - m );

return t
end

Expand All @@ -93,7 +92,6 @@ Returns the complete elliptic integral of the second kind.
- `m` : Elliptic modulus
"""
function E(m::T) where T
!(-1 m 1) && throw(DomainError("argument m not in [0,1]"))
m == zero(T) && return T/2)
m == one(T) && return one(T)

Expand Down Expand Up @@ -224,13 +222,14 @@ function asn(s::A, m::B) where {A,B}
yA = T(0.04094) - T(0.00652)*m
y = s * s
y < yA && return s*serf(y, m)

p = one(T)
for _ in 1:10
y = y * inv((1+√(1-y))*(1+√(1-m*y)))
p += p
y < yA && return p*√y*serf(y, m)
end
return T(NaN)
end

"""
Expand All @@ -257,6 +256,7 @@ function acn(c::A, m::B) where {A,B}
x = (x + d)/(1+d)
p += p
end
return T(NaN)
end

@inline function rawF(sinφ::A, m::B) where {A,B}
Expand Down Expand Up @@ -552,7 +552,7 @@ function Pi(n::A, φ::B, m::C) where {A,B,C}
return (FukushimaT(t1, h1) - n1*J(n1, φ, m))
end
return n*J(n, φ, m) + F(φ, m)
end
end

"""
``J (n;\\varphi \\,|\\,m)=\\frac{\\Pi(n;\\varphi|\\, m) - F(\\varphi|\\,m)}{n}.``
Expand All @@ -566,7 +566,7 @@ Returns the associate incomplete elliptic integral of the third kind.
- `m` : Elliptic modulus
"""
function J(n::A, φ::B, m::C) where {A, B, C} #Appendix A
T = promote_type(A,B,C)
T = promote_type(A,B,C)
# Reduction of Amplitude
φ == zero(T) && return zero(T)
φ == T/2) && return J(n,m)
Expand Down Expand Up @@ -818,16 +818,16 @@ function FukushimaT(t::A, h::B) where {A,B}
return t
else
arg = t * (-h)
ans = abs(arg) < one(T) ? atanh(arg) : custom_atanh(arg)
return ans / (-h)
end
ans = abs(arg) < one(T) ? atanh(arg) : custom_atanh(arg)
return ans / (-h)
end
end

#https://link-springer-com.ezp-prod1.hul.harvard.edu/article/T(10).1007/BF02165405
function cel(kc::A, p::B, a::C, b::D) where {A,B,C,D}
T = promote_type(A,B,C,D)
#ca = T(1e-6)
ca = eps(T)
ca = eps(T)
kc = abs(kc)
e = kc
m = one(T)
Expand Down Expand Up @@ -932,7 +932,7 @@ function fold_1_00(u1::T, m::T, Kscreen::T, Kactual::T, kp::T) where T

if u1 > Kscreen/2
sn, cn, dn = fold_0_50(Kactual - u1, m, Kscreen, Kactual, kp)
elseif u1 > Kscreen/4
elseif u1 > Kscreen/4
sn, cn, dn = fold_0_25(Kactual/2 - u1, m, kp)
else
sn, cn, dn = _ΔXNloop(u1, m, u1 > zero(T) ? max(6+(floor(log2(u1))), one(T)) : zero(T))
Expand Down
3 changes: 2 additions & 1 deletion src/JacobiElliptic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct Fukushima <: AbstractAlgorithm end
struct Carlson <: AbstractAlgorithm end


func_syms = [:E, :F, :K, :Pi, :J, :sn, :cn, :dn, :nn, :sd, :dd, :nd, :sc, :cc, :dc, :nc, :ss, :cs, :ds, :ns, :am, :cd]
func_syms = [:E, :F, :K, :Pi, :J, :sn, :cn, :dn, :nn, :sd, :dd, :nd, :sc, :cc, :dc, :nc, :ss, :cs, :ds, :ns, :cd]
sym_list = []


Expand All @@ -46,5 +46,6 @@ end

asn = FukushimaAlg.asn
acn = FukushimaAlg.asn
am = CarlsonAlg.am

end

0 comments on commit 5243dd9

Please sign in to comment.