Skip to content

Commit

Permalink
Merge pull request #2584 from Shreyas-Ekanathan/master
Browse files Browse the repository at this point in the history
Parallelize Radau
  • Loading branch information
ChrisRackauckas authored Jan 31, 2025
2 parents 2046615 + 02a5616 commit d90935d
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 31 deletions.
2 changes: 2 additions & 0 deletions lib/OrdinaryDiffEqFIRK/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand All @@ -34,6 +35,7 @@ ODEProblemLibrary = "0.1.8"
OrdinaryDiffEqCore = "1.14"
OrdinaryDiffEqDifferentiation = "<0.0.1, 1.2"
OrdinaryDiffEqNonlinearSolve = "<0.0.1, 1"
Polyester = "0.7.16"
Random = "<0.0.1, 1"
RecursiveArrayTools = "3.27.0"
Reexport = "1.2.2"
Expand Down
7 changes: 4 additions & 3 deletions lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
OrdinaryDiffEqAlgorithm, OrdinaryDiffEqNewtonAdaptiveAlgorithm,
OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache,
OrdinaryDiffEqAdaptiveAlgorithm, CompiledFloats, uses_uprev,
alg_cache, _vec, _reshape, @cache, isfsal, full_cache,
constvalue, _unwrap_val,
alg_cache, _vec, _reshape, @cache, @threaded, isthreaded, PolyesterThreads,
isfsal, full_cache, constvalue, _unwrap_val,
differentiation_rk_docstring, trivial_limiter!,
_ode_interpolant!, _ode_addsteps!, AbstractController,
qmax_default, alg_adaptive_order, DEFAULT_PRECS,
Expand All @@ -18,7 +18,8 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
get_current_adaptive_order, get_fsalfirstlast,
isfirk, generic_solver_docstring, _bool_to_ADType,
_process_AD_choice
using MuladdMacro, DiffEqBase, RecursiveArrayTools
using MuladdMacro, DiffEqBase, RecursiveArrayTools, Polyester
isfirk, generic_solver_docstring
using SciMLOperators: AbstractSciMLOperator
using LinearAlgebra: I, UniformScaling, mul!, lu
import LinearSolve
Expand Down
36 changes: 19 additions & 17 deletions lib/OrdinaryDiffEqFIRK/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,26 +166,27 @@ function RadauIIA9(; chunk_size = Val{0}(), autodiff = AutoForwardDiff(),
AD_choice)
end

struct AdaptiveRadau{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter} <:
struct AdaptiveRadau{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter, TO} <:
OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
precs::P
smooth_est::Bool
extrapolant::Symbol
κ::Tol
maxiters::Int
fast_convergence_cutoff::C1
new_W_γdt_cutoff::C2
controller::Symbol
step_limiter!::StepLimiter
min_order::Int
max_order::Int
autodiff::AD
linsolve::F
precs::P
smooth_est::Bool
extrapolant::Symbol
κ::Tol
maxiters::Int
fast_convergence_cutoff::C1
new_W_γdt_cutoff::C2
controller::Symbol
step_limiter!::StepLimiter
min_order::Int
max_order::Int
threading::TO
autodiff::AD
end

function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = AutoForwardDiff(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}(), min_order = 5, max_order = 13,
diff_type = Val{:forward}, min_order = 5, max_order = 13, threading = false,
linsolve = nothing, precs = DEFAULT_PRECS,
extrapolant = :dense, fast_convergence_cutoff = 1 // 5,
new_W_γdt_cutoff = 1 // 5,
Expand All @@ -197,7 +198,7 @@ function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = AutoForwardDiff(),
AdaptiveRadau{_unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve),
typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
typeof(κ), typeof(fast_convergence_cutoff),
typeof(new_W_γdt_cutoff), typeof(step_limiter!)}(linsolve,
typeof(new_W_γdt_cutoff), typeof(step_limiter!), typeof(threading)}(linsolve,
precs,
smooth_est,
extrapolant,
Expand All @@ -206,6 +207,7 @@ function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = AutoForwardDiff(),
fast_convergence_cutoff,
new_W_γdt_cutoff,
controller,
step_limiter!, min_order, max_order, AD_choice)
step_limiter!, min_order, max_order, threading,
AD_choice)
end

48 changes: 38 additions & 10 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1619,8 +1619,21 @@ end
if (new_W = do_newW(integrator, alg, new_jac, cache.W_γdt))
@inbounds for II in CartesianIndices(J)
W1[II] = -γdt * mass_matrix[Tuple(II)...] + J[II]
for i in 1 : (num_stages - 1) ÷ 2
W2[i][II] = -(αdt[i] + βdt[i] * im) * mass_matrix[Tuple(II)...] + J[II]
end
if !isthreaded(alg.threading)
@inbounds for II in CartesianIndices(J)
for i in 1 : (num_stages - 1) ÷ 2
W2[i][II] = -(αdt[i] + βdt[i] * im) * mass_matrix[Tuple(II)...] + J[II]
end
end
else
let W1 = W1, W2 = W2, γdt = γdt, αdt = αdt, βdt = βdt, mass_matrix = mass_matrix,
num_stages = num_stages, J = J
@inbounds @threaded alg.threading for i in 1 : (num_stages - 1) ÷ 2
for II in CartesianIndices(J)
W2[i][II] = -(αdt[i] + βdt[i] * im) * mass_matrix[Tuple(II)...] + J[II]
end
end
end
end
integrator.stats.nw += 1
Expand Down Expand Up @@ -1706,16 +1719,30 @@ end
cache.linsolve1 = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff), linu = _vec(dw1)).cache
end

for i in 1 :(num_stages - 1) ÷ 2
@.. cubuff[i]=complex(
fw[2 * i] - αdt[i] * Mw[2 * i] + βdt[i] * Mw[2 * i + 1], fw[2 * i + 1] - βdt[i] * Mw[2 * i] - αdt[i] * Mw[2 * i + 1])
if needfactor
cache.linsolve2[i] = dolinsolve(integrator, linsolve2[i]; A = W2[i], b = _vec(cubuff[i]), linu = _vec(dw2[i])).cache
else
cache.linsolve2[i] = dolinsolve(integrator, linsolve2[i]; A = nothing, b = _vec(cubuff[i]), linu = _vec(dw2[i])).cache
if !isthreaded(alg.threading)
for i in 1 :(num_stages - 1) ÷ 2
@.. cubuff[i]=complex(fw[2 * i] - αdt[i] * Mw[2 * i] + βdt[i] * Mw[2 * i + 1], fw[2 * i + 1] - βdt[i] * Mw[2 * i] - αdt[i] * Mw[2 * i + 1])
if needfactor
cache.linsolve2[i] = dolinsolve(integrator, linsolve2[i]; A = W2[i], b = _vec(cubuff[i]), linu = _vec(dw2[i])).cache
else
cache.linsolve2[i] = dolinsolve(integrator, linsolve2[i]; A = nothing, b = _vec(cubuff[i]), linu = _vec(dw2[i])).cache
end
end
else
let integrator = integrator, linsolve2 = linsolve2, fw = fw, αdt = αdt, βdt = βdt, Mw = Mw, W1 = W1, W2 = W2,
cubuff = cubuff, dw2 = dw2, needfactor = needfactor
@threaded alg.threading for i in 1:(num_stages - 1) ÷ 2
#@show i == Threads.threadid()
@.. cubuff[i]=complex(fw[2 * i] - αdt[i] * Mw[2 * i] + βdt[i] * Mw[2 * i + 1], fw[2 * i + 1] - βdt[i] * Mw[2 * i] - αdt[i] * Mw[2 * i + 1])
if needfactor
cache.linsolve2[i] = dolinsolve(integrator, linsolve2[i]; A = W2[i], b = _vec(cubuff[i]), linu = _vec(dw2[i])).cache
else
cache.linsolve2[i] = dolinsolve(integrator, linsolve2[i]; A = nothing, b = _vec(cubuff[i]), linu = _vec(dw2[i])).cache
end
end
end
end

integrator.stats.nsolve += (num_stages + 1) / 2

for i in 1 : (num_stages - 1) ÷ 2
Expand Down Expand Up @@ -1850,3 +1877,4 @@ end
integrator.stats.nf += 1
return
end

10 changes: 9 additions & 1 deletion lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,19 @@ using GenericSchur
prob_ode_linear_big = remake(prob_ode_linear, u0 = big.(prob_ode_linear.u0), tspan = big.(prob_ode_linear.tspan))
prob_ode_2Dlinear_big = remake(prob_ode_2Dlinear, u0 = big.(prob_ode_2Dlinear.u0), tspan = big.(prob_ode_2Dlinear.tspan))

#non-threaded tests
# For each fixed order `i` (min_order == max_order) and each big-float linear
# problem, run `test_convergence` over the step sizes `dts` and compare
# `sim21.𝒪est[:final]` against `i` within `testTol`.
for i in [5, 9, 13, 17, 21, 25], prob in [prob_ode_linear_big, prob_ode_2Dlinear_big]
    dts = 1 ./ 2 .^ (4.25:-1:0.25)
    local sim21 = test_convergence(dts, prob, AdaptiveRadau(min_order = i, max_order = i))
    # The `≈` comparison is required here: `@test a b atol=...` without an
    # operator is not a valid test expression, and `atol` is only accepted as
    # a keyword of an `isapprox`-style call.
    @test sim21.𝒪est[:final] ≈ i atol=testTol
end
#threaded tests
using OrdinaryDiffEqCore
# Same convergence sweep as the non-threaded case, but constructing the
# algorithm with `threading = OrdinaryDiffEqCore.PolyesterThreads()` so the
# threaded code path is the one exercised.
for i in [5, 9, 13, 17, 21, 25], prob in [prob_ode_linear_big, prob_ode_2Dlinear_big]
    dts = 1 ./ 2 .^ (4.25:-1:0.25)
    local sim21 = test_convergence(
        dts, prob,
        AdaptiveRadau(min_order = i, max_order = i,
            threading = OrdinaryDiffEqCore.PolyesterThreads()))
    # The `≈` comparison is required here: `@test a b atol=...` without an
    # operator is not a valid test expression.
    @test sim21.𝒪est[:final] ≈ i atol=testTol
end

# test adaptivity
for iip in (true, false)
Expand Down Expand Up @@ -68,4 +76,4 @@ for iip in (true, false)
@test sol.stats.njacs < sol.stats.nw # W reuse
end
@test length(sol) < 5000 # the error estimate is not very good
end
end

0 comments on commit d90935d

Please sign in to comment.