Skip to content

Commit

Permalink
Merge pull request #2264 from ParamThakkar123/master
Browse files Browse the repository at this point in the history
Added SSPRK methods
  • Loading branch information
ChrisRackauckas authored Jun 30, 2024
2 parents b77a82f + 90aac04 commit 7986c8b
Show file tree
Hide file tree
Showing 20 changed files with 373 additions and 265 deletions.
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqLowStorageRK/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
using SafeTestsets

@time @safetestset "Extrapolation Tests" include("ode_low_storage_rk_tests.jl")
@time @safetestset "Low Storage RK Tests" include("ode_low_storage_rk_tests.jl")
23 changes: 23 additions & 0 deletions lib/OrdinaryDiffEqSSPRK/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name = "OrdinaryDiffEqSSPRK"
uuid = "669c94d9-1f4b-4b64-b377-1aa079aa2388"
authors = ["ParamThakkar123 <paramthakkar864@gmail.com>"]
version = "1.0.0"

[deps]
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"

[compat]
julia = "1.10"

[extras]
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["DiffEqDevTools", "Random", "SafeTestsets", "Test"]
35 changes: 35 additions & 0 deletions lib/OrdinaryDiffEqSSPRK/src/OrdinaryDiffEqSSPRK.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
module OrdinaryDiffEqSSPRK

import OrdinaryDiffEq: alg_order, calculate_residuals!,
initialize!, perform_step!, @unpack, unwrap_alg,
calculate_residuals, ssp_coefficient,
OrdinaryDiffEqAlgorithm,
OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache,
OrdinaryDiffEqNewtonAdaptiveAlgorithm,
OrdinaryDiffEqRosenbrockAdaptiveAlgorithm,
OrdinaryDiffEqAdaptiveAlgorithm, uses_uprev,
alg_cache, _vec, _reshape, @cache, isfsal, full_cache,
constvalue, _unwrap_val, du_alias_or_new,
explicit_rk_docstring, trivial_limiter!,
_ode_interpolant, _ode_interpolant!,
_ode_addsteps!
using DiffEqBase, FastBroadcast, Polyester, MuladdMacro, RecursiveArrayTools
using DiffEqBase: @def
using Static: False

import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA

include("algorithms.jl")
include("alg_utils.jl")
include("ssprk_caches.jl")
include("interp_func.jl")
include("ssprk_perform_step.jl")
include("interpolants.jl")
include("addsteps.jl")
include("functions.jl")

export SSPRK53_2N2, SSPRK22, SSPRK53, SSPRK63, SSPRK83, SSPRK43, SSPRK432, SSPRKMSVS32,
SSPRK54, SSPRK53_2N1, SSPRK104, SSPRK932, SSPRKMSVS43, SSPRK73, SSPRK53_H,
SSPRK33, SHLDDRK_2N, KYKSSPRK42, SHLDDRK52

end
21 changes: 21 additions & 0 deletions lib/OrdinaryDiffEqSSPRK/src/addsteps.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
function _ode_addsteps!(k, t, uprev, u, dt, f, p,
cache::Union{SSPRK22ConstantCache, SSPRK33ConstantCache,
SSPRK43ConstantCache, SSPRK432ConstantCache},
always_calc_begin = false, allow_calc_end = true,
force_calc_end = false)
if length(k) < 1 || always_calc_begin
copyat_or_push!(k, 1, f(uprev, p, t))
end
nothing
end

function _ode_addsteps!(k, t, uprev, u, dt, f, p,
cache::Union{SSPRK22Cache, SSPRK33Cache, SSPRK43Cache,
SSPRK432Cache}, always_calc_begin = false,
allow_calc_end = true, force_calc_end = false)
if length(k) < 1 || always_calc_begin
f(cache.k, uprev, p, t)
copyat_or_push!(k, 1, cache.k)
end
nothing
end
52 changes: 52 additions & 0 deletions lib/OrdinaryDiffEqSSPRK/src/alg_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
isfsal(alg::SSPRK53_2N1) = false
isfsal(alg::SSPRK53_2N2) = false
isfsal(alg::SSPRK22) = false
isfsal(alg::SSPRK33) = false
isfsal(alg::SSPRK53) = false
isfsal(alg::SSPRK53_H) = false
isfsal(alg::SSPRK63) = false
isfsal(alg::SSPRK73) = false
isfsal(alg::SSPRK83) = false
isfsal(alg::SSPRK43) = false
isfsal(alg::SSPRK432) = false
isfsal(alg::SSPRK932) = false
isfsal(alg::SSPRK54) = false
isfsal(alg::SSPRK104) = false

alg_order(alg::KYKSSPRK42) = 2
alg_order(alg::SSPRKMSVS32) = 2
alg_order(alg::SSPRK33) = 3
alg_order(alg::SSPRK53_2N1) = 3
alg_order(alg::SSPRK53_2N2) = 3
alg_order(alg::SSPRK22) = 2
alg_order(alg::SSPRK53) = 3
alg_order(alg::SSPRK53_H) = 3
alg_order(alg::SSPRK63) = 3
alg_order(alg::SSPRK73) = 3
alg_order(alg::SSPRK83) = 3
alg_order(alg::SSPRK43) = 3
alg_order(alg::SSPRK432) = 3
alg_order(alg::SSPRKMSVS43) = 3
alg_order(alg::SSPRK932) = 3
alg_order(alg::SSPRK54) = 4
alg_order(alg::SSPRK104) = 4
alg_order(alg::SHLDDRK_2N) = 4
alg_order(alg::SHLDDRK52) = 2

ssp_coefficient(alg::SSPRK53_2N1) = 2.18
ssp_coefficient(alg::SSPRK53_2N2) = 2.148
ssp_coefficient(alg::SSPRK53) = 2.65
ssp_coefficient(alg::SSPRK53_H) = 2.65
ssp_coefficient(alg::SSPRK63) = 3.518
ssp_coefficient(alg::SSPRK73) = 4.2879
ssp_coefficient(alg::SSPRK83) = 5.107
ssp_coefficient(alg::SSPRK43) = 2
ssp_coefficient(alg::SSPRK432) = 2
ssp_coefficient(alg::SSPRK932) = 6
ssp_coefficient(alg::SSPRK54) = 1.508
ssp_coefficient(alg::SSPRK104) = 6
ssp_coefficient(alg::SSPRK33) = 1
ssp_coefficient(alg::SSPRK22) = 1
ssp_coefficient(alg::SSPRKMSVS32) = 0.5
ssp_coefficient(alg::SSPRKMSVS43) = 0.33
ssp_coefficient(alg::KYKSSPRK42) = 2.459
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
using OrdinaryDiffEq: explicit_rk_docstring
using Static: False
@inline trivial_limiter!(u, integrator, p, t) = nothing

@doc explicit_rk_docstring(
"A third-order, five-stage explicit strong stability preserving (SSP) low-storage method.
Fixed timestep only.",
Expand Down
52 changes: 52 additions & 0 deletions lib/OrdinaryDiffEqSSPRK/src/functions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
@eval @inline function DiffEqBase.get_tmp_cache(integrator, alg::SSPRK22,
cache::OrdinaryDiffEqConstantCache)
nothing
end

@eval @inline function DiffEqBase.get_tmp_cache(integrator, alg::SSPRK33,
cache::OrdinaryDiffEqConstantCache)
nothing
end

@eval @inline function DiffEqBase.get_tmp_cache(integrator, alg::SSPRK53_2N1,
cache::OrdinaryDiffEqConstantCache)
nothing
end

@eval @inline function DiffEqBase.get_tmp_cache(integrator, alg::SSPRK53_2N2,
cache::OrdinaryDiffEqConstantCache)
nothing
end

@eval @inline function DiffEqBase.get_tmp_cache(integrator, alg::SSPRK432,
cache::OrdinaryDiffEqConstantCache)
nothing
end

@eval @inline function DiffEqBase.get_tmp_cache(integrator, alg::SSPRK932,
cache::OrdinaryDiffEqConstantCache)
nothing
end

@eval @inline function DiffEqBase.get_tmp_cache(integrator, alg::OrdinaryDiffEqNewtonAdaptiveAlgorithm,
cache::OrdinaryDiffEqConstantCache)
nothing
end

@eval @inline function DiffEqBase.get_tmp_cache(integrator, alg::OrdinaryDiffEqRosenbrockAdaptiveAlgorithm,
cache::OrdinaryDiffEqConstantCache)
nothing
end

@eval @inline function DiffEqBase.get_tmp_cache(integrator, alg::OrdinaryDiffEqAlgorithm,
cache::OrdinaryDiffEqConstantCache)
nothing
end

@inline function DiffEqBase.get_tmp_cache(integrator,
alg::Union{SSPRK22, SSPRK33, SSPRK53_2N1,
SSPRK53_2N2, SSPRK43, SSPRK432,
SSPRK932},
cache::OrdinaryDiffEqMutableCache)
(cache.k,)
end
10 changes: 10 additions & 0 deletions lib/OrdinaryDiffEqSSPRK/src/interp_func.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
function DiffEqBase.interp_summary(::Type{cacheType},
dense::Bool) where {
cacheType <:
Union{SSPRK22, SSPRK22ConstantCache,
SSPRK33, SSPRK33ConstantCache,
SSPRK43, SSPRK43ConstantCache,
SSPRK432, SSPRK432ConstantCache
}}
dense ? "2nd order \"free\" SSP interpolation" : "1st order linear"
end
135 changes: 135 additions & 0 deletions lib/OrdinaryDiffEqSSPRK/src/interpolants.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
const NEGZERO = Float16(-0.0f0)

@def ssprkpre0 begin
c00 = @evalpoly(Θ, 1, NEGZERO, -1)
c10 = Θ^2
b10dt = Θ * @evalpoly(Θ, 1, -1) * dt
end

@def ssprkpre1 begin
b10diff = @evalpoly(Θ, 1, -2)
c10diffinvdt = 2Θ * inv(dt) # = -c00diff * inv(dt)
end

@def ssprkpre2 begin
invdt = inv(dt)
b10diff2invdt = -2 * invdt
c10diff2invdt2 = 2 * invdt^2 # = -c00diff2 * inv(dt)^2
end

@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
cache::Union{SSPRK22ConstantCache, SSPRK22Cache,
SSPRK33ConstantCache, SSPRK33Cache,
SSPRK43ConstantCache, SSPRK43Cache,
SSPRK432ConstantCache, SSPRK432Cache},
idxs::Nothing, T::Type{Val{0}}, differential_vars::Nothing)
@ssprkpre0
@inbounds @.. broadcast=false y₀*c00+y₁*c10+k[1]*b10dt
end

@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
cache::Union{SSPRK22ConstantCache, SSPRK22Cache,
SSPRK33ConstantCache, SSPRK33Cache,
SSPRK43ConstantCache, SSPRK43Cache,
SSPRK432ConstantCache, SSPRK432Cache}, idxs,
T::Type{Val{0}}, differential_vars::Nothing)
@ssprkpre0
@views @.. broadcast=false y₀[idxs]*c00+y₁[idxs]*c10+k[1][idxs]*b10dt
end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{SSPRK22ConstantCache, SSPRK22Cache,
SSPRK33ConstantCache, SSPRK33Cache,
SSPRK43ConstantCache, SSPRK43Cache,
SSPRK432ConstantCache, SSPRK432Cache},
idxs::Nothing, T::Type{Val{0}}, differential_vars::Nothing)
@ssprkpre0
@inbounds @.. broadcast=false out=y₀ * c00 + y₁ * c10 + k[1] * b10dt
out
end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{SSPRK22ConstantCache, SSPRK22Cache,
SSPRK33ConstantCache, SSPRK33Cache,
SSPRK43ConstantCache, SSPRK43Cache,
SSPRK432ConstantCache, SSPRK432Cache}, idxs,
T::Type{Val{0}}, differential_vars::Nothing)
@ssprkpre0
@views @.. broadcast=false out=y₀[idxs] * c00 + y₁[idxs] * c10 + k[1][idxs] * b10dt
out
end

@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
cache::Union{SSPRK22ConstantCache, SSPRK22Cache,
SSPRK33ConstantCache, SSPRK33Cache,
SSPRK43ConstantCache, SSPRK43Cache,
SSPRK432ConstantCache, SSPRK432Cache},
idxs::Nothing, T::Type{Val{1}}, differential_vars::Nothing)
@ssprkpre1
@inbounds @.. broadcast=false (y₁ - y₀) * c10diffinvdt+k[1] * b10diff
end

@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
cache::Union{SSPRK22ConstantCache, SSPRK22Cache,
SSPRK33ConstantCache, SSPRK33Cache,
SSPRK43ConstantCache, SSPRK43Cache,
SSPRK432ConstantCache, SSPRK432Cache}, idxs,
T::Type{Val{1}}, differential_vars::Nothing)
@ssprkpre1
@views @.. broadcast=false (y₁[idxs] - y₀[idxs]) * c10diffinvdt+k[1][idxs] * b10diff
end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{SSPRK22ConstantCache, SSPRK22Cache,
SSPRK33ConstantCache, SSPRK33Cache,
SSPRK43ConstantCache, SSPRK43Cache,
SSPRK432ConstantCache, SSPRK432Cache},
idxs::Nothing, T::Type{Val{1}}, differential_vars::Nothing)
@ssprkpre1
@inbounds @.. broadcast=false out=(y₁ - y₀) * c10diffinvdt + k[1] * b10diff
out
end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{SSPRK22ConstantCache, SSPRK22Cache,
SSPRK33ConstantCache, SSPRK33Cache,
SSPRK43ConstantCache, SSPRK43Cache,
SSPRK432ConstantCache, SSPRK432Cache}, idxs,
T::Type{Val{1}}, differential_vars::Nothing)
@ssprkpre1
@views @.. broadcast=false out=(y₁[idxs] - y₀[idxs]) * c10diffinvdt +
k[1][idxs] * b10diff
out
end

@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
cache::Union{SSPRK22ConstantCache, SSPRK22Cache,
SSPRK33ConstantCache, SSPRK33Cache,
SSPRK43ConstantCache, SSPRK43Cache,
SSPRK432ConstantCache, SSPRK432Cache},
idxs::Nothing, T::Type{Val{2}}, differential_vars::Nothing)
@ssprkpre2
@inbounds @.. broadcast=false (y₁ - y₀) * c10diff2invdt2+k[1] * b10diff2invdt
end

@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
cache::Union{SSPRK22ConstantCache, SSPRK22Cache,
SSPRK33ConstantCache, SSPRK33Cache,
SSPRK43ConstantCache, SSPRK43Cache,
SSPRK432ConstantCache, SSPRK432Cache}, idxs,
T::Type{Val{2}}, differential_vars::Nothing)
@ssprkpre2
@views @.. broadcast=false (y₁[idxs] - y₀[idxs]) *
c10diff2invdt2+k[1][idxs] * b10diff2invdt
end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{SSPRK22ConstantCache, SSPRK22Cache,
SSPRK33ConstantCache, SSPRK33Cache,
SSPRK43ConstantCache, SSPRK43Cache,
SSPRK432ConstantCache, SSPRK432Cache},
idxs::Nothing, T::Type{Val{2}}, differential_vars::Nothing)
@ssprkpre2
@inbounds @.. broadcast=false out=(y₁ - y₀) * c10diff2invdt2 + k[1] * b10diff2invdt
out
end
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -1704,4 +1704,4 @@ end
stage_limiter!(u, integrator, p, t + dt)
step_limiter!(u, integrator, p, t + dt)
integrator.stats.nf += 10
end
end
File renamed without changes.
3 changes: 3 additions & 0 deletions lib/OrdinaryDiffEqSSPRK/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
using SafeTestsets

@time @safetestset "SSPRK Tests" include("ode_ssprk_tests.jl")
Loading

0 comments on commit 7986c8b

Please sign in to comment.