Skip to content

Commit

Permalink
Fix Enzyme's batched pullback and Jacobian (#499)
Browse files Browse the repository at this point in the history
* Fix Enzyme's batched pullback and Jacobian

* Version
  • Loading branch information
gdalle authored Sep 25, 2024
1 parent ea192f6 commit d9a5cab
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 26 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ ADTypes = "1.7.0"
ChainRulesCore = "1.23.0"
Compat = "3.46,4.2"
Diffractor = "=0.2.6"
Enzyme = "0.13.1"
Enzyme = "0.13.2"
FastDifferentiation = "0.3.17"
FiniteDiff = "2.23.1"
FiniteDifferences = "0.12.31"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ end

function batch_seeded_autodiff_thunk(
rmode::ReverseModeSplit{ReturnPrimal},
dresults::NTuple,
dresults::NTuple{B},
f::FA,
::Type{RA},
args::Vararg{Annotation,N},
) where {ReturnPrimal,FA<:Annotation,RA<:Annotation,N}
forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...)
) where {ReturnPrimal,B,FA<:Annotation,RA<:Annotation,N}
rmode_rightwidth = set_width(rmode, Val(B))
forward, reverse = autodiff_thunk(rmode_rightwidth, FA, RA, typeof.(args)...)
tape, result, shadow_results = forward(f, args...)
if RA <: Active
dresults_righttype = map(Fix1(convert, typeof(result)), dresults)
Expand Down Expand Up @@ -79,20 +80,19 @@ end

function DI.value_and_pullback(
f::F,
prep::NoPullbackPrep,
::NoPullbackPrep,
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
x::Number,
ty::Tangents{B},
contexts::Vararg{Context,C},
) where {F,B,C}
# TODO: improve
ys_and_dxs = map(ty.d) do dy
y, tx = DI.value_and_pullback(f, prep, backend, x, Tangents(dy), contexts...)
y, only(tx)
end
y = first(ys_and_dxs[1])
dxs = last.(ys_and_dxs)
return y, Tangents(dxs...)
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
mode = reverse_mode_split_withprimal(backend)
RA = eltype(ty) <: Number ? Active : BatchDuplicated
dinputs, result = batch_seeded_autodiff_thunk(
mode, NTuple(ty), f_and_df, RA, Active(x), map(translate, contexts)...
)
return result, Tangents(first(dinputs)...)
end

function DI.value_and_pullback(
Expand Down Expand Up @@ -293,37 +293,37 @@ end

## Jacobian

struct EnzymeReverseOneArgJacobianPrep{M,B} <: JacobianPrep end
struct EnzymeReverseOneArgJacobianPrep{Sy,B} <: JacobianPrep end

function DI.prepare_jacobian(f::F, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) where {F}
y = f(x)
M = length(y)
B = pick_batchsize(backend, M)
return EnzymeReverseOneArgJacobianPrep{M,B}()
Sy = size(y)
B = pick_batchsize(backend, prod(Sy))
return EnzymeReverseOneArgJacobianPrep{Sy,B}()
end

function DI.jacobian(
f::F,
::EnzymeReverseOneArgJacobianPrep{M,B},
::EnzymeReverseOneArgJacobianPrep{Sy,B},
backend::AutoEnzyme{<:ReverseMode,Nothing},
x,
) where {F,M,B}
derivs = jacobian(reverse_mode_noprimal(backend), f, x; n_outs=Val((M,)), chunk=Val(B))
) where {F,Sy,B}
derivs = jacobian(reverse_mode_noprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B))
jac_tensor = only(derivs)
return maybe_reshape(jac_tensor, M, length(x))
return maybe_reshape(jac_tensor, prod(Sy), length(x))
end

function DI.value_and_jacobian(
f::F,
prep::EnzymeReverseOneArgJacobianPrep{M,B},
::EnzymeReverseOneArgJacobianPrep{Sy,B},
backend::AutoEnzyme{<:ReverseMode,Nothing},
x,
) where {F,M,B}
) where {F,Sy,B}
(; derivs, val) = jacobian(
reverse_mode_withprimal(backend), f, x; n_outs=Val((M,)), chunk=Val(B)
reverse_mode_withprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B)
)
jac_tensor = derivs
return val, maybe_reshape(jac_tensor, M, length(x))
jac_tensor = only(derivs)
return val, maybe_reshape(jac_tensor, prod(Sy), length(x))
end

function DI.jacobian!(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,32 @@ function mode_split(
}()
end

function set_width(
::ReverseModeSplit{
ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI,ErrIfFuncWritten
},
::Val{NewWidth},
) where {
ReturnPrimal,
ReturnShadow,
RuntimeActivity,
Width,
ModifiedBetween,
ABI,
ErrIfFuncWritten,
NewWidth,
}
return ReverseModeSplit{
ReturnPrimal,
ReturnShadow,
RuntimeActivity,
NewWidth,
ModifiedBetween,
ABI,
ErrIfFuncWritten,
}()
end

mode_noprimal(mode::Mode) = mode_noprimal(typeof(mode))
mode_withprimal(mode::Mode) = mode_withprimal(typeof(mode))
mode_split(mode::Mode) = mode_split(typeof(mode))
Expand Down

0 comments on commit d9a5cab

Please sign in to comment.