Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Lobatto IIIa-c and Radau IIa solvers #178

Merged
merged 70 commits into from
Sep 1, 2024
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
84b5604
Added tests
axla-io Mar 8, 2024
5644872
Solve routines for FIRM and MIRK
axla-io Mar 8, 2024
b5fa089
Added adaptivity
axla-io Mar 8, 2024
ed8aac4
Added alg_utils
axla-io Mar 8, 2024
1f410e2
Added algorithms
axla-io Mar 8, 2024
b7e6e74
Updated BoundaryValueDiffEq.jl
axla-io Mar 8, 2024
a411397
Added collocation
axla-io Mar 8, 2024
6112055
Added interpolation
axla-io Mar 8, 2024
ec2bde7
Added radau and lobatto tableaus
axla-io Mar 8, 2024
eb58a2a
Added types
axla-io Mar 8, 2024
d2ce79d
updated sparse jacobians
axla-io Mar 8, 2024
cfdf839
Updated utils
axla-io Mar 8, 2024
a229cf5
General debugging
axla-io Mar 9, 2024
13a935d
Tests not related to adaptivity now work
axla-io Mar 9, 2024
6c65096
sparse jac modification
axla-io Mar 9, 2024
531c735
OOP convergence
axla-io Mar 9, 2024
8a0fa31
Tests that work
axla-io Mar 9, 2024
87a496e
Working adaptivity
axla-io Mar 9, 2024
956ae89
Tests run for nested
axla-io Mar 9, 2024
23b90ea
Expanded FIRK interpolations work
axla-io Mar 9, 2024
dedd0c6
Interpolation runs for nested
axla-io Mar 9, 2024
b5dd3e0
Working
axla-io Mar 9, 2024
f65cecc
Fixed vector of vector example
axla-io Mar 16, 2024
ff03c82
testing nested
axla-io Mar 16, 2024
1874be4
Test nest tol
axla-io Mar 17, 2024
5928460
Split tests to nested and non nested
axla-io Mar 17, 2024
e96c3e3
Tolerance docstring
axla-io Mar 17, 2024
abd6a84
Expanded tests are passing
axla-io Mar 31, 2024
aa3ae08
Nested tested
axla-io Mar 31, 2024
6c76956
Updated docs, removed double lines
axla-io Mar 31, 2024
2d80a40
Removed tabs in formatting
axla-io Apr 1, 2024
9cfa960
Merge branch 'master' into al/properlobatto
ErikQQY Jul 21, 2024
b131653
Lobatto methods working
ErikQQY Jul 21, 2024
b77d5a0
Upgrade to the latest ADTypes
ErikQQY Jul 22, 2024
49579a4
Skip tests for overconstrained and underconstrained BVP
ErikQQY Jul 22, 2024
6997d5c
Fix interpolation when plotting
ErikQQY Jul 24, 2024
969a985
Merge branch 'master' into al/properlobatto
ErikQQY Aug 17, 2024
ab4b59f
All test cases should work now
ErikQQY Aug 17, 2024
aabbc4d
Merge branch 'master' into al/properlobatto
ErikQQY Aug 17, 2024
772762c
Remove Aqua ambiguities test
ErikQQY Aug 17, 2024
1dd7eba
Put FIRK nlls tests into different test group
ErikQQY Aug 18, 2024
df610a7
Fix RAT related errors
ErikQQY Aug 18, 2024
ff74dd2
Fix deprecated AutoSparse
ErikQQY Aug 18, 2024
9a02a0c
Fix NLLS tests for FIRK
ErikQQY Aug 18, 2024
4048f45
Fix deprecated errors and only test FIRK for now
ErikQQY Aug 19, 2024
7241024
Don't test all methods and fix time span error
ErikQQY Aug 21, 2024
cbce3df
Fix deprecated usage of RAT
ErikQQY Aug 21, 2024
ee7440c
Fix interpolation for nested FIRK
ErikQQY Aug 22, 2024
ef58021
Fix typo
ErikQQY Aug 22, 2024
e03858b
Add the other tests back and clean up
ErikQQY Aug 24, 2024
a945fa0
All methods are working now
ErikQQY Aug 25, 2024
3222a19
Merge branch 'master' into al/properlobatto
ErikQQY Aug 25, 2024
35cc24a
Fix several test_broken and less allocation
ErikQQY Aug 26, 2024
7df0a73
Fix BigFloat support for FIRK solvers
ErikQQY Aug 26, 2024
c414345
Fix several test_broken
ErikQQY Aug 26, 2024
0cba86b
Fix interp_eval
ErikQQY Aug 26, 2024
27154fa
Use ForwardDiff in FIRK test cases
ErikQQY Aug 27, 2024
d67f975
Upgrade OrdinaryDiffEq
ErikQQY Aug 27, 2024
b92e192
Split into different test groups
ErikQQY Aug 28, 2024
7882dcb
Dont specify os
ErikQQY Aug 28, 2024
7fb997f
Dont specify os
ErikQQY Aug 28, 2024
6aacb4b
Put all of the tests into different test groups
ErikQQY Aug 28, 2024
0dfe5c1
Use more test groups
ErikQQY Aug 29, 2024
b491e65
Fix Interpolation errors and format
ErikQQY Aug 29, 2024
ca3fcc4
Set Aqua tests as test_broken
ErikQQY Aug 29, 2024
8bb2ec4
Fix deprecation errors
ErikQQY Aug 29, 2024
1f87d85
Skip Aqua for unrelated depwarn
ErikQQY Aug 29, 2024
f39a827
Fix nlls tests cost too much time
ErikQQY Aug 29, 2024
e83e925
Skip RadauIIa5 with GaussNewton for now
ErikQQY Aug 30, 2024
f8f5b35
Refactor the internal of FIRK
ErikQQY Sep 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ include("algorithms.jl")
include("alg_utils.jl")

include("mirk_tableaus.jl")
include("lobatto_tableaus.jl")
include("radau_tableaus.jl")

include("solve/single_shooting.jl")
include("solve/multiple_shooting.jl")
include("solve/firk.jl")
include("solve/mirk.jl")

include("collocation.jl")
Expand Down Expand Up @@ -273,6 +276,10 @@ export Shooting, MultipleShooting
export MIRK2, MIRK3, MIRK4, MIRK5, MIRK6
export BVPM2, BVPSOL, COLNEW # From ODEInterface.jl

export RadauIIa1, RadauIIa2, RadauIIa3, RadauIIa5, RadauIIa7
export LobattoIIIa2, LobattoIIIa3, LobattoIIIa4, LobattoIIIa5
export LobattoIIIb2, LobattoIIIb3, LobattoIIIb4, LobattoIIIb5
export LobattoIIIc2, LobattoIIIc3, LobattoIIIc4, LobattoIIIc5
export MIRKJacobianComputationAlgorithm, BVPJacobianAlgorithm

end
249 changes: 246 additions & 3 deletions src/adaptivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,129 @@ After we construct an interpolant, we use interp_eval to evaluate it.
return y
end

@views function interp_eval!(y::AbstractArray, cache::FIRKCacheExpand{iip}, t, mesh, mesh_dt) where {iip}
i = findfirst(x -> x == y, cache.y₀.u)
interp_eval!(cache.y₀.u, i, cache::FIRKCacheExpand{iip}, t, mesh, mesh_dt)
return y
end

@views function interp_eval!(y::AbstractArray, i::Int, cache::FIRKCacheExpand{iip}, t, mesh, mesh_dt) where {iip}
j = interval(mesh, t)
h = mesh_dt[j]
lf = (length(cache.y₀) - 1) / (length(cache.y) - 1) # Cache length factor. We use a h corresponding to cache.y. Note that this assumes equidistributed mesh
if lf > 1
h *= lf
end
τ = (t - mesh[j]) / h

(; f, M, p, ITU, TU) = cache
(; c, a, b) = TU
(; q_coeff, stage) = ITU

K = zeros(eltype(cache.y[1].du), M, stage)
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved

ctr_y0 = (i - 1) * (ITU.stage + 1) + 1
ctr_y = (j - 1) * (ITU.stage + 1) + 1

yᵢ = cache.y[ctr_y].du
yᵢ₊₁ = cache.y[ctr_y + ITU.stage + 1].du

if iip
dyᵢ = copy(yᵢ)
dyᵢ₊₁ = copy(yᵢ₊₁)

f(dyᵢ, yᵢ, cache.p, mesh[j])
f(dyᵢ₊₁, yᵢ₊₁, cache.p, mesh[j + 1])
else
dyᵢ = f(yᵢ, cache.p, mesh[j])
dyᵢ₊₁ = f(yᵢ₊₁, cache.p, mesh[j + 1])
end

# Load interpolation residual
for jj in 1:stage
K[:, jj] = cache.y[ctr_y + jj].du
end

z₁, z₁′ = eval_q(yᵢ, 0.5, h, q_coeff, K) # Evaluate q(x) at midpoints
S_coeffs = get_S_coeffs(h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)

#FIXME: need a better way to fix this
eltype(y) <: AbstractArray ? (y[ctr_y0] = S_interpolate(τ * h, S_coeffs)) : (y .= S_interpolate(τ * h, S_coeffs))
if ctr_y0 > length(y)
for (k, ci) in enumerate(c)
eltype(y) <: AbstractArray ? (y[ctr_y0 + k] = dS_interpolate(τ * h + (1 - τ * h) * ci, S_coeffs)) : (y = dS_interpolate(τ * h + (1 - τ * h) * ci, S_coeffs))
end
end

return eltype(y) <: AbstractArray ? y[ctr_y0] : y
end

@views function interp_eval!(y::AbstractArray, cache::FIRKCacheNested{iip}, t, mesh, mesh_dt) where {iip}
j = interval(mesh, t)
h = mesh_dt[j]
lf = (length(cache.y₀) - 1) / (length(cache.y) - 1) # Cache length factor. We use a h corresponding to cache.y. Note that this assumes equidistributed mesh
if lf > 1
h *= lf
end
τ = (t - mesh[j]) / h

(; f, M, p, k_discrete, ITU, TU, nest_cache, p_nestprob, prob) = cache
(; c, a, b) = TU
(; q_coeff, stage) = ITU

yᵢ = copy(cache.y[j].du)
yᵢ₊₁ = copy(cache.y[j + 1].du)

if iip
dyᵢ = copy(yᵢ)
dyᵢ₊₁ = copy(yᵢ₊₁)
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved

f(dyᵢ, yᵢ, cache.p, mesh[j])
f(dyᵢ₊₁, yᵢ₊₁, cache.p, mesh[j + 1])
else
dyᵢ = f(yᵢ, cache.p, mesh[j])
dyᵢ₊₁ = f(yᵢ₊₁, cache.p, mesh[j + 1])
end

# Load interpolation residual
y_i = eltype(yᵢ) == Float64 ? yᵢ : [y.value for y in yᵢ]

p_nestprob[1:2] .= promote(mesh[j], mesh_dt[j], one(eltype(y_i)))[1:2]
p_nestprob[3:end] .= y_i

solve_cache!(nest_cache, k_discrete[j].du, p_nestprob)
K = nest_cache.u

z₁, z₁′ = eval_q(yᵢ, 0.5, h, q_coeff, K) # Evaluate q(x) at midpoints
S_coeffs = get_S_coeffs(h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)

y .= S_interpolate(τ * h, S_coeffs)

return y
end

function get_S_coeffs(h, yᵢ, yᵢ₊₁, dyᵢ, dyᵢ₊₁, ymid, dymid)
vals = vcat(yᵢ, yᵢ₊₁, dyᵢ, dyᵢ₊₁, ymid, dymid)
M = length(yᵢ)
A = s_constraints(M, h)
coeffs = reshape(A \ vals, 6, M)'
return coeffs
end

# S forward Interpolation
function S_interpolate(t, coeffs)
ts = [t^(i - 1) for i in axes(coeffs, 2)]
return coeffs * ts
end

function dS_interpolate(t, S_coeffs)
ts = zeros(size(S_coeffs, 2))
for i in 2:size(S_coeffs, 2)
ts[i] = (i - 1) * t^(i - 2)
end
return S_coeffs * ts
end

"""
interval(mesh, t)

Expand All @@ -26,7 +149,7 @@ end

Generate new mesh based on the defect.
"""
@views function mesh_selector!(cache::MIRKCache{iip, T}) where {iip, T}
@views function mesh_selector!(cache::Union{MIRKCache{iip, T}, FIRKCacheExpand{iip, T}, FIRKCacheNested{iip, T}}) where {iip, T}
(; M, order, defect, mesh, mesh_dt) = cache
(abstol, _, _), kwargs = __split_mirk_kwargs(; cache.kwargs...)
N = length(cache.mesh)
Expand Down Expand Up @@ -134,7 +257,7 @@ function half_mesh!(mesh::Vector{T}, mesh_dt::Vector{T}) where {T}
end
return mesh, mesh_dt
end
half_mesh!(cache::MIRKCache) = half_mesh!(cache.mesh, cache.mesh_dt)
half_mesh!(cache::Union{MIRKCache, FIRKCacheNested, FIRKCacheExpand}) = half_mesh!(cache.mesh, cache.mesh_dt)

"""
defect_estimate!(cache::MIRKCache)
Expand Down Expand Up @@ -169,7 +292,7 @@ an interpolant

z, z′ = sum_stages!(cache, w₂, w₂′, i)
if iip
yᵢ₂ = cache.y[i + 1].du
yᵢ₂ = cache.y[i+1].du
f(yᵢ₂, z, cache.p, mesh[i] + (T(1) - τ_star) * dt)
else
yᵢ₂ = f(z, cache.p, mesh[i] + (T(1) - τ_star) * dt)
Expand All @@ -183,6 +306,126 @@ an interpolant
return maximum(Base.Fix1(maximum, abs), defect.u)
end

@views function defect_estimate!(cache::FIRKCacheExpand{iip, T}) where {iip, T}
(; f, M, stage, mesh, mesh_dt, defect, ITU) = cache
(; q_coeff, τ_star) = ITU

ctr = 1
K = zeros(eltype(cache.y[1].du), M, stage)
for i in 1:(length(mesh) - 1)
h = mesh_dt[i]

# Load interpolation residual
for j in 1:stage
K[:, j] = cache.y[ctr + j].du
end

# Defect estimate from q(x) at y_i + τ* * h
yᵢ₁ = copy(cache.y[ctr].du)
yᵢ₂ = copy(yᵢ₁)
z₁, z₁′ = eval_q(yᵢ₁, τ_star, h, q_coeff, K)
if iip
f(yᵢ₁, z₁, cache.p, mesh[i] + τ_star * h)
else
yᵢ₁ = f(z₁, cache.p, mesh[i] + τ_star * h)
end
yᵢ₁ .= (z₁′ .- yᵢ₁) ./ (abs.(yᵢ₁) .+ T(1))
est₁ = maximum(abs, yᵢ₁)

z₂, z₂′ = eval_q(yᵢ₂, (T(1) - τ_star), h, q_coeff, K)
# Defect estimate from q(x) at y_i + (1-τ*) * h
if iip
f(yᵢ₂, z₂, cache.p, mesh[i] + (T(1) - τ_star) * h)
else
yᵢ₂ = f(z₂, cache.p, mesh[i] + (T(1) - τ_star) * h)
end
yᵢ₂ .= (z₂′ .- yᵢ₂) ./ (abs.(yᵢ₂) .+ T(1))
est₂ = maximum(abs, yᵢ₂)

defect.u[i] .= est₁ > est₂ ? yᵢ₁ : yᵢ₂
ctr += stage + 1 # Advance one step
end

return maximum(Base.Fix1(maximum, abs), defect)
end

@views function defect_estimate!(cache::FIRKCacheNested{iip, T}) where {iip, T}
(; f, M, stage, mesh, mesh_dt, defect, TU, ITU, nest_cache, p_nestprob, prob) = cache
(; a, c) = TU
(; q_coeff, τ_star) = ITU

for i in 1:(length(mesh) - 1)
h = mesh_dt[i]
yᵢ₁ = copy(cache.y[i].du)
yᵢ₂ = copy(yᵢ₁)

K = copy(cache.k_discrete[i].du)

if minimum(abs.(K)) < 1e-2
K = fill(one(eltype(K)), size(K))
end

y_i = eltype(yᵢ₁) == Float64 ? yᵢ₁ : [y.value for y in yᵢ₁]

p_nestprob[1:2] .= promote(mesh[i], mesh_dt[i], one(eltype(y_i)))[1:2]
p_nestprob[3:end] = y_i
solve_cache!(nest_cache, K, p_nestprob)

# Defect estimate from q(x) at y_i + τ* * h
z₁, z₁′ = eval_q(yᵢ₁, τ_star, h, q_coeff, nest_cache.u)
if iip
f(yᵢ₁, z₁, cache.p, mesh[i] + τ_star * h)
else
yᵢ₁ = f(z₁, cache.p, mesh[i] + τ_star * h)
end
yᵢ₁ .= (z₁′ .- yᵢ₁) ./ (abs.(yᵢ₁) .+ T(1))
est₁ = maximum(abs, yᵢ₁)

# Defect estimate from q(x) at y_i + (1-τ*) * h
z₂, z₂′ = eval_q(yᵢ₂, (T(1) - τ_star), h, q_coeff, nest_cache.u)
if iip
f(yᵢ₂, z₂, cache.p, mesh[i] + (T(1) - τ_star) * h)
else
yᵢ₂ = f(z₂, cache.p, mesh[i] + (T(1) - τ_star) * h)
end
yᵢ₂ .= (z₂′ .- yᵢ₂) ./ (abs.(yᵢ₂) .+ T(1))
est₂ = maximum(abs, yᵢ₂)

defect.u[i] .= est₁ > est₂ ? yᵢ₁ : yᵢ₂
end

return maximum(Base.Fix1(maximum, abs), defect)
end

function get_q_coeffs(A, ki, h)
coeffs = A * ki
for i in axes(coeffs, 1)
coeffs[i] = coeffs[i] / (h^(i - 1))
end
return coeffs
end

function apply_q(y_i, τ, h, coeffs)
return y_i + sum(coeffs[i] * (τ * h)^(i) for i in axes(coeffs, 1))
end

function apply_q_prime(τ, h, coeffs)
return sum(i * coeffs[i] * (τ * h)^(i - 1) for i in axes(coeffs, 1))
end

function eval_q(y_i, τ, h, A, K)
M = size(K, 1)
q = zeros(M)
q′ = zeros(M)
for i in 1:M
ki = @view K[i, :]
coeffs = get_q_coeffs(A, ki, h)
q[i] = apply_q(y_i[i], τ, h, coeffs)
q′[i] = apply_q_prime(τ, h, coeffs)
end
return q, q′
end

"""
interp_setup!(cache::MIRKCache)

Expand Down
25 changes: 25 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,33 @@ for order in (2, 3, 4, 5, 6)
@eval alg_stage(::$(alg)) = $(order - 1)
end

for stage in (1, 2, 3, 5, 7)
alg = Symbol("RadauIIa$(stage)")
@eval alg_order(::$(alg)) = $(2 * stage - 1)
@eval alg_stage(::$(alg)) = $stage
end

for stage in (2, 3, 4, 5)
alg = Symbol("LobattoIIIa$(stage)")
@eval alg_order(::$(alg)) = $(2 * stage - 2)
@eval alg_stage(::$(alg)) = $stage
end

for stage in (2, 3, 4, 5)
alg = Symbol("LobattoIIIb$(stage)")
@eval alg_order(::$(alg)) = $(2 * stage - 2)
@eval alg_stage(::$(alg)) = $stage
end

for stage in (2, 3, 4, 5)
alg = Symbol("LobattoIIIc$(stage)")
@eval alg_order(::$(alg)) = $(2 * stage - 2)
@eval alg_stage(::$(alg)) = $stage
end

SciMLBase.isautodifferentiable(::BoundaryValueDiffEqAlgorithm) = true
SciMLBase.allows_arbitrary_number_types(::BoundaryValueDiffEqAlgorithm) = true
SciMLBase.allowscomplex(alg::BoundaryValueDiffEqAlgorithm) = true

SciMLBase.isadaptive(alg::AbstractMIRK) = true
SciMLBase.isadaptive(alg::AbstractFIRK) = true
Loading
Loading