Skip to content

Commit

Permalink
Merge pull request #2559 from efaulhaber/verlet-leapfrog
Browse files Browse the repository at this point in the history
Optimize `VerletLeapfrog` method
  • Loading branch information
ChrisRackauckas authored Dec 23, 2024
2 parents 607c124 + 87ff9ea commit a265106
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 24 deletions.
2 changes: 2 additions & 0 deletions lib/OrdinaryDiffEqSymplecticRK/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ struct VelocityVerlet <: OrdinaryDiffEqPartitionedAlgorithm end
verlet1967, "", "")
struct VerletLeapfrog <: OrdinaryDiffEqPartitionedAlgorithm end

OrdinaryDiffEqCore.default_linear_interpolation(alg::VerletLeapfrog, prob) = true

@doc generic_solver_docstring("2nd order explicit symplectic integrator.",
"PseudoVerletLeapfrog",
"Symplectic Runge-Kutta Methods",
Expand Down
28 changes: 21 additions & 7 deletions lib/OrdinaryDiffEqSymplecticRK/src/symplectic_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,18 @@ function alg_cache(alg::VelocityVerlet, u, rate_prototype, ::Type{uEltypeNoUnits
VelocityVerletConstantCache(uEltypeNoUnits(1 // 2))
end

@cache struct Symplectic2Cache{uType, rateType, tableauType} <: HamiltonMutableCache
@cache struct VerletLeapfrogCache{uType, rateType, uEltypeNoUnits} <:
OrdinaryDiffEqMutableCache
u::uType
uprev::uType
tmp::uType
k::rateType
fsalfirst::rateType
tab::tableauType
half::uEltypeNoUnits
end

struct VerletLeapfrogConstantCache{uEltypeNoUnits} <: HamiltonConstantCache
half::uEltypeNoUnits
end

function alg_cache(alg::VerletLeapfrog, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand All @@ -73,16 +78,24 @@ function alg_cache(alg::VerletLeapfrog, u, rate_prototype, ::Type{uEltypeNoUnits
tmp = zero(u)
k = zero(rate_prototype)
fsalfirst = zero(rate_prototype)
tab = VerletLeapfrogConstantCache(constvalue(uBottomEltypeNoUnits),
constvalue(tTypeNoUnits))
Symplectic2Cache(u, uprev, k, tmp, fsalfirst, tab)
half = uEltypeNoUnits(1 // 2)
VerletLeapfrogCache(u, uprev, k, tmp, fsalfirst, half)
end

function alg_cache(alg::VerletLeapfrog, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
VerletLeapfrogConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
VerletLeapfrogConstantCache(uEltypeNoUnits(1 // 2))
end

@cache struct Symplectic2Cache{uType, rateType, tableauType} <: HamiltonMutableCache
u::uType
uprev::uType
tmp::uType
k::rateType
fsalfirst::rateType
tab::tableauType
end

function alg_cache(alg::PseudoVerletLeapfrog, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand Down Expand Up @@ -422,6 +435,7 @@ function alg_cache(alg::SofSpa10, u, rate_prototype, ::Type{uEltypeNoUnits},
end

function get_fsalfirstlast(
cache::Union{HamiltonMutableCache, VelocityVerletCache, SymplecticEulerCache}, u)
cache::Union{HamiltonMutableCache, VelocityVerletCache, VerletLeapfrogCache,
SymplecticEulerCache}, u)
(cache.fsalfirst, cache.k)
end
71 changes: 62 additions & 9 deletions lib/OrdinaryDiffEqSymplecticRK/src/symplectic_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,12 @@ end
# f.f2(p, q, pa, t) = p which is the Newton/Lagrange equations
# If called with different functions (which are possible in the Hamiltonian case)
# an exception is thrown to avoid silently calculate wrong results.
verify_f2(f, p, q, pa, t, ::Any, ::C) where {C <: HamiltonConstantCache} = f(p, q, pa, t)
function verify_f2(f, res, p, q, pa, t, ::Any, ::C) where {C <: HamiltonMutableCache}
function verify_f2(f, p, q, pa, t, ::Any,
::C) where {C <: Union{HamiltonConstantCache, VerletLeapfrogConstantCache}}
f(p, q, pa, t)
end
function verify_f2(f, res, p, q, pa, t, ::Any,
::C) where {C <: Union{HamiltonMutableCache, VerletLeapfrogCache}}
f(res, p, q, pa, t)
end

Expand Down Expand Up @@ -124,8 +128,8 @@ function store_symp_state!(integrator, ::OrdinaryDiffEqMutableCache, kdu, ku)
end

function initialize!(integrator,
cache::C) where {C <:
Union{HamiltonMutableCache, VelocityVerletCache}}
cache::C) where {C <: Union{
HamiltonMutableCache, VelocityVerletCache, VerletLeapfrogCache}}
integrator.kshortsize = 2
resize!(integrator.k, integrator.kshortsize)
integrator.k[1] = integrator.fsalfirst
Expand All @@ -140,9 +144,8 @@ function initialize!(integrator,
end

function initialize!(integrator,
cache::C) where {
C <:
Union{HamiltonConstantCache, VelocityVerletConstantCache}}
cache::C) where {C <: Union{
HamiltonConstantCache, VelocityVerletConstantCache, VerletLeapfrogConstantCache}}
integrator.kshortsize = 2
integrator.k = typeof(integrator.k)(undef, integrator.kshortsize)

Expand Down Expand Up @@ -171,7 +174,7 @@ end
# v(t+Δt) = v(t) + 1/2*(a(t)+a(t+Δt))*Δt
du = duprev + dt * (half * ku + half * kdu)

OrdinaryDiffEqCore.increment_nf!(integrator.stats, 2)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
store_symp_state!(integrator, cache, du, u, kdu, du)
end

Expand All @@ -186,13 +189,63 @@ end
half = cache.half
@.. broadcast=false u=uprev + dt * duprev + dtsq * (half * ku)
f.f1(kdu, duprev, u, p, t + dt)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 2)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
# v(t+Δt) = v(t) + 1/2*(a(t)+a(t+Δt))*Δt
@.. broadcast=false du=duprev + dt * (half * ku + half * kdu)

store_symp_state!(integrator, cache, kdu, du)
end

@muladd function perform_step!(integrator, cache::VerletLeapfrogConstantCache,
repeat_step = false)
@unpack t, dt, f, p = integrator
duprev, uprev, kduprev, _ = load_symp_state(integrator)

# kick-drift-kick scheme of the Leapfrog method:
# update velocity
half = cache.half
du = duprev + dt * half * kduprev

# update position
tnew = t + half * dt
ku = f.f2(du, uprev, p, tnew)
u = uprev + dt * ku

# update velocity
tnew = tnew + half * dt
kdu = f.f1(du, u, p, tnew)
du = du + dt * half * kdu

OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
integrator.stats.nf2 += 1
store_symp_state!(integrator, cache, du, u, kdu, ku)
end

@muladd function perform_step!(integrator, cache::VerletLeapfrogCache, repeat_step = false)
@unpack t, dt, f, p = integrator
duprev, uprev, kduprev, _ = load_symp_state(integrator)
du, u, kdu, ku = alloc_symp_state(integrator)

# Kick-Drift-Kick scheme of the Verlet Leapfrog method:
# update velocity
half = cache.half
@.. broadcast=false du=duprev + dt * half * kduprev

# update position
tnew = t + half * dt
f.f2(ku, du, uprev, p, tnew)
@.. broadcast=false u=uprev + dt * ku

# update velocity
tnew = tnew + half * dt
f.f1(kdu, du, u, p, tnew)
@.. broadcast=false du=du + dt * half * kdu

OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
integrator.stats.nf2 += 1
store_symp_state!(integrator, cache, kdu, ku)
end

@muladd function perform_step!(integrator, cache::Symplectic2ConstantCache,
repeat_step = false)
@unpack t, dt, f, p = integrator
Expand Down
8 changes: 0 additions & 8 deletions lib/OrdinaryDiffEqSymplecticRK/src/symplectic_tableaus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,6 @@ function McAte2ConstantCache(T, T2)
Symplectic2ConstantCache{T, T2}(a1, a2, b1, b2)
end

function VerletLeapfrogConstantCache(T, T2)
a1 = convert(T, 1 // 2)
a2 = convert(T, 1 // 2)
b1 = convert(T, 0)
b2 = convert(T, 1)
Symplectic2ConstantCache{T, T2}(a1, a2, b1, b2)
end

struct Symplectic3ConstantCache{T, T2} <: HamiltonConstantCache
a1::T
a2::T
Expand Down

0 comments on commit a265106

Please sign in to comment.