Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use DifferentiationInterface for AD in Implicit Solvers #2567

Draft
wants to merge 21 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lib/OrdinaryDiffEqDifferentiation/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ version = "1.3.0"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does Enzyme need to become a dependency? This adds significant install overhead, but if AutoEnzyme is to be the new default AD then it makes sense

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, probably doesn't need to be a dependency unless we're committing to having it be the default.

FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -16,7 +18,9 @@ LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

Expand All @@ -25,6 +29,7 @@ ADTypes = "1.11"
ArrayInterface = "7"
DiffEqBase = "6"
DiffEqDevTools = "2.44.4"
DifferentiationInterface = "0.6.23"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
DifferentiationInterface = "0.6.23"
DifferentiationInterface = "0.6.28"

the other deps are also missing compat bounds?

FastBroadcast = "0.3"
FiniteDiff = "2"
ForwardDiff = "0.10"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
module OrdinaryDiffEqDifferentiation

import ADTypes: AutoFiniteDiff, AutoForwardDiff, AbstractADType
import ADTypes
import ADTypes: AutoFiniteDiff, AutoForwardDiff, AbstractADType, AutoSparse

import SparseDiffTools: SparseDiffTools, matrix_colors, forwarddiff_color_jacobian!,
forwarddiff_color_jacobian, ForwardColorJacCache,
default_chunk_size, getsize, JacVec

import ForwardDiff, FiniteDiff
import SparseMatrixColorings: GreedyColoringAlgorithm
import SparseConnectivityTracer: TracerSparsityDetector

import ForwardDiff, FiniteDiff, Enzyme
import ForwardDiff.Dual
import LinearSolve
import LinearSolve: OperatorAssumptions
Expand Down Expand Up @@ -46,6 +50,8 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplici

import OrdinaryDiffEqCore: get_chunksize, resize_J_W!, resize_nlsolver!, alg_autodiff, _get_fwd_tag

import DifferentiationInterface as DI

using FastBroadcast: @..

@static if isdefined(DiffEqBase, :OrdinaryDiffEqTag)
Expand Down
106 changes: 76 additions & 30 deletions lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,47 +47,93 @@ function DiffEqBase.prepare_alg(
u0::AbstractArray{T},
p, prob) where {AD, FDT, T}

# If not using autodiff or norecompile mode or very large bitsize (like a dual number u0 already)
# don't use a large chunksize as it will either error or not be beneficial
# If prob.f.f is a FunctionWrappersWrappers from ODEFunction, need to set chunksize to 1

if alg_autodiff(alg) isa AutoForwardDiff && ((prob.f isa ODEFunction &&
prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper) || (isbitstype(T) && sizeof(T) > 24))
return remake(alg, autodiff = AutoForwardDiff(chunksize = 1, tag = alg_autodiff(alg).tag))
end
autodiff = prepare_ADType(alg_autodiff(alg), prob, u0, p, standardtag(alg))

# If the autodiff alg is AutoFiniteDiff, prob.f.f isa FunctionWrappersWrapper,
# and fdtype is complex, fdtype needs to change to something not complex
if alg_autodiff(alg) isa AutoFiniteDiff
if alg_difftype(alg) == Val{:complex} && (prob.f isa ODEFunction && prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
@warn "AutoFiniteDiff fdtype complex is not compatible with this function"
return remake(alg, autodiff = AutoFiniteDiff(fdtype = Val{:forward}()))
#sparsity preparation

sparsity = prob.f.sparsity

if sparsity isa SparseMatrixCSC
if f.mass_matrix isa UniformScaling
idxs = diagind(sparsity)
@. @view(sparsity[idxs]) = 1
else
idxs = findall(!iszero, f.mass_matrix)
@. @view(sparsity[idxs]) = @view(f.mass_matrix[idxs])
end
return alg
end

L = StaticArrayInterface.known_length(typeof(u0))
if L === nothing # dynamic sized
sparsity_detector = isnothing(sparsity) ? TracerSparsityDetector() : ADTypes.KnownJacobianSparsityDetector(sparsity)
color_alg = DiffEqBase.has_colorvec(prob.f) ? ADTypes.ConstantColoringAlgorithm(sparsity, prob.f.colorvec) : GreedyColoringAlgorithm()

autodiff = AutoSparse(autodiff, sparsity_detector = sparsity_detector, coloring_algorithm = color_alg)

alg = remake(alg, autodiff = autodiff)

return alg
end

"""
    prepare_ADType(autodiff_alg::AutoSparse, prob, u0, p, standardtag)

Unwrap an `AutoSparse` backend and prepare its underlying dense AD backend.
The sparse wrapper (sparsity detector + coloring algorithm) is re-applied by
the caller after preparation, so only the dense backend needs adjusting here.
"""
function prepare_ADType(autodiff_alg::AutoSparse, prob, u0, p, standardtag)
    # Qualify as `ADTypes.dense_ad`: `dense_ad` is not imported into this
    # module (only `AutoFiniteDiff`, `AutoForwardDiff`, `AbstractADType`,
    # `AutoSparse` are), and the other call sites in this file already use
    # the qualified form — bare `dense_ad` would throw an UndefVarError.
    return prepare_ADType(ADTypes.dense_ad(autodiff_alg), prob, u0, p, standardtag)
end

"""
    prepare_ADType(autodiff_alg::AutoForwardDiff, prob, u0, p, standardtag)

Prepare a `ForwardDiff` backend for the implicit solver. When `standardtag`
is set, a solver-specific `ForwardDiff.Tag` is built from the element type of
`u0`. If the problem function is a `FunctionWrappersWrapper` (from
`ODEFunction`) or the state eltype is an unusually wide bits type (e.g. `u0`
is already a dual number), the backend is replaced by one with `chunksize = 1`,
since a larger chunk would either error or bring no benefit.

NOTE(review): automatic chunk-size selection (`ForwardDiff.pickchunksize` /
static sizing via `pick_static_chunksize`) is currently disabled in this
draft and may be restored later.
"""
function prepare_ADType(autodiff_alg::AutoForwardDiff, prob, u0, p, standardtag)
    tag = standardtag ? ForwardDiff.Tag(OrdinaryDiffEqTag(), eltype(u0)) : nothing

    elt = eltype(u0)
    is_wrapped = prob.f isa ODEFunction &&
                 prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper
    wide_eltype = isbitstype(elt) && sizeof(elt) > 24

    if is_wrapped || wide_eltype
        # Force chunksize 1: FunctionWrappersWrappers require it, and a wide
        # eltype (like an existing Dual) makes larger chunks counterproductive.
        autodiff_alg = AutoForwardDiff(chunksize = 1, tag = tag)
    end

    return autodiff_alg
end

cs = ForwardDiff.pickchunksize(x)
return remake(alg,
autodiff = AutoForwardDiff(
chunksize = cs))
else # statically sized
cs = pick_static_chunksize(Val{L}())
cs = SciMLBase._unwrap_val(cs)
return remake(
alg, autodiff = AutoForwardDiff(chunksize = cs))
function prepare_ADType(alg::AutoFiniteDiff, prob, u0, p, standardtag)
# If the autodiff alg is AutoFiniteDiff, prob.f.f isa FunctionWrappersWrapper,
# and fdtype is complex, fdtype needs to change to something not complex
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that DI does not explicitly support complex numbers yet. What I mean by that is that we forward things to the backend as much as possible, so if the backend does support complex numbers then it will probably work, but there are no tests or hard API guarantees on that. See JuliaDiff/DifferentiationInterface.jl#646 for the discussion

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also note that some differentiation operators are not defined unambiguously for complex numbers (e.g. the derivative for complex input)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enzyme has an explicit variant of modes for complex numbers, that it probably would be wise to similarly wrap here (by default it will instead err warning about ambiguity if a function returns a complex number otherwise): https://enzyme.mit.edu/julia/stable/api/#EnzymeCore.ReverseHolomorphic . @gdalle I'm not sure DI supports this yet? so perhaps that means you may need to just call Enzyme.jacobian / autodiff directly in that case

Copy link

@gdalle gdalle Jan 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jClugstor can you maybe specify where we will encounter complex numbers by filling the following table?

|                          | derivative | jacobian |
|--------------------------|------------|----------|
| complex inputs possible  | yes / no   | yes / no |
| complex outputs possible | yes / no   | yes / no |

When there are both complex inputs and complex outputs, that's where we run into trouble because we cannot represent derivatives as a single scalar. In that case, the differentiation operators are not clearly defined (the Jacobian matrix is basically twice as big as it should be) so we would need to figure out what convention the ODE solvers need (see https://discourse.julialang.org/t/taking-complex-autodiff-seriously-in-chainrules/39317).

@wsmoses I understand your concern, but I find it encouraging that DI actually allowed Enzyme to be used here for the first time (or at least so I've been told). This makes me think that the right approach is to handle complex numbers properly in DI instead of introducing a special case for Enzyme?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure adding proper complex number support to DI would be great, but a three line change here to use in-spec Complex support when there's already overloads for other ADTypes feels reasonable?

e.g. something like

function jacobian(f, x::AbstractArray{<:Complex}, integrator::WhatevertheTypeIs{<:AutoEnzyme})
  Enzyme.jacobian(ReverseHolomorphic, f, x)
end

from the discussion in JuliaDiff/DifferentiationInterface.jl#646 I think DI complex support is a much thornier issue. In particular, various tools have different conventions (e.g. jax vs pytorch pick different conjugates of what is propagated). So either DI needs to make a choice and shim/force all tools to use it (definitely doable), and then user code must be converted to that convention (e.g. a separate shim on the user side). For example, suppose DI picked a different conjugate from ForwardDiff.jl. DI could write its shim once in ForwardDiff to convert, which is reasonable. But suppose one was defining a custom rule within ForwardDiff and the code called DI somewhere; now that user code needs to conditionally apply a different shim to conjugate, which feels kind of nasty to put everywhere (in contrast to a self-consistent assumption). I suppose the other alternative is for DI to not pick a convention, but that again prevents users from using it, since it's not possible to know whether they get the correct value for them -- and worse, they won't know when they need to do a conversion or not.

Thus, if complex support is desired, a three line patch where things are explicitly supported seems okay (at least until the DI story is figured out)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that for now, this change seems to do the job (although it raises the question of consistency with the other backends that are handled via DI). But what will happen if the function in question is not holomorphic? That's the thorniest part of the problem, and that's why I wanted to inquire a bit more as to what kind of functions we can expect. Perhaps @jClugstor or @ChrisRackauckas can tell us more?

In any case, I have started a discussion on Discourse to figure out the right conventions: https://discourse.julialang.org/t/choosing-a-convention-for-complex-numbers-in-differentiationinterface/124433

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also note that the Enzyme-specific fix only handles dense Jacobians, not sparse Jacobians (which are one of the main reasons to use DI in the first place)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I can't really tell you much about the complex number support, other than previously only ForwardDiff or FiniteDiff were used, so when someone used an implicit solver on a complex problem, their conventions were used I guess. Also just wanted to note that the code this comment is on is just making sure that the FiniteDiff fdtype isn't complex if the function is a function wrapper and doesn't have to do with complex numbers through the solver in general.

if alg.fdtype == Val{:complex}() && (prob.f isa ODEFunction && prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
@warn "AutoFiniteDiff fdtype complex is not compatible with this function"
return AutoFiniteDiff(fdtype = Val{:forward}())
end
return alg
end

# Fallback method: AD backends that need no special preparation
# (anything other than AutoSparse/AutoForwardDiff/AutoFiniteDiff)
# are passed through unchanged.
prepare_ADType(alg::AbstractADType, prob, u0, p, standardtag) = alg

#function prepare_ADType(alg::DiffEqAutoAD, prob, u0, p, standardtag)

#end

@generated function pick_static_chunksize(::Val{chunksize}) where {chunksize}
x = ForwardDiff.pickchunksize(chunksize)
:(Val{$x}())
Expand Down
26 changes: 22 additions & 4 deletions lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,18 @@ function calc_tderivative!(integrator, cache, dtd1, repeat_step)
else
tf.uprev = uprev
tf.p = p
derivative!(dT, tf, t, du2, integrator, cache.grad_config)
alg = unwrap_alg(integrator, true)
#derivative!(dT, tf, t, du2, integrator, cache.grad_config)
autodiff_alg = alg_autodiff(alg)

autodiff_alg = if autodiff_alg isa AutoSparse
ADTypes.dense_ad(autodiff_alg)
else
autodiff_alg
end

autodiff_alg = ADTypes.dense_ad(alg_autodiff(alg))
DI.derivative!(tf, linsolve_tmp, dT, cache.grad_config, autodiff_alg, t)
end
end

Expand All @@ -48,7 +59,7 @@ function calc_tderivative!(integrator, cache, dtd1, repeat_step)
end

function calc_tderivative(integrator, cache)
@unpack t, dt, uprev, u, f, p = integrator
@unpack t, dt, uprev, u, f, p, alg = integrator

# Time derivative
if DiffEqBase.has_tgrad(f)
Expand All @@ -57,7 +68,15 @@ function calc_tderivative(integrator, cache)
tf = cache.tf
tf.u = uprev
tf.p = p
dT = derivative(tf, t, integrator)

autodiff_alg = alg_autodiff(alg)
autodiff_alg = if autodiff_alg isa AutoSparse
autodiff_alg = ADTypes.dense_ad(autodiff_alg)
else
autodiff_alg
end

dT = DI.derivative(tf, autodiff_alg, t)
end
dT
end
Expand Down Expand Up @@ -97,7 +116,6 @@ function calc_J(integrator, cache, next_step::Bool = false)
uf.f = nlsolve_f(f, alg)
uf.p = p
uf.t = t

J = jacobian(uf, uprev, integrator)
end

Expand Down
Loading
Loading