Skip to content

Commit

Permalink
Stop relying on Enzyme internals (#511)
Browse files Browse the repository at this point in the history
* Stop relying on Enzyme internals

* Remove useless converter
  • Loading branch information
gdalle authored Sep 30, 2024
1 parent b93c17e commit c3f5360
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 153 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ ADTypes = "1.9.0"
ChainRulesCore = "1.23.0"
Compat = "3.46,4.2"
Diffractor = "=0.2.6"
Enzyme = "0.13.2"
Enzyme = "0.13.6"
FastDifferentiation = "0.3.17"
FiniteDiff = "2.23.1"
FiniteDifferences = "0.12.31"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@ using Enzyme:
ForwardWithPrimal,
MixedDuplicated,
Mode,
NoPrimal,
Reverse,
ReverseMode,
ReverseModeSplit,
ReverseSplitNoPrimal,
ReverseSplitWidth,
ReverseSplitWithPrimal,
ReverseWithPrimal,
WithPrimal,
autodiff,
autodiff_thunk,
create_shadows,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function DI.value_and_pushforward(
dx_sametype = convert(typeof(x), only(tx))
x_and_dx = Duplicated(x, dx_sametype)
dy, y = autodiff(
forward_mode_withprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...
forward_withprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...
)
return y, (dy,)
end
Expand All @@ -39,7 +39,7 @@ function DI.value_and_pushforward(
tx_sametype = map(Fix1(convert, typeof(x)), tx)
x_and_tx = BatchDuplicated(x, tx_sametype)
ty, y = autodiff(
forward_mode_withprimal(backend), f_and_df, x_and_tx, map(translate, contexts)...
forward_withprimal(backend), f_and_df, x_and_tx, map(translate, contexts)...
)
return y, values(ty)
end
Expand All @@ -56,9 +56,7 @@ function DI.pushforward(
dx_sametype = convert(typeof(x), only(tx))
x_and_dx = Duplicated(x, dx_sametype)
dy = only(
autodiff(
forward_mode_noprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...
),
autodiff(forward_noprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...)
)
return (dy,)
end
Expand All @@ -75,9 +73,7 @@ function DI.pushforward(
tx_sametype = map(Fix1(convert, typeof(x)), tx)
x_and_tx = BatchDuplicated(x, tx_sametype)
ty = only(
autodiff(
forward_mode_noprimal(backend), f_and_df, x_and_tx, map(translate, contexts)...
),
autodiff(forward_noprimal(backend), f_and_df, x_and_tx, map(translate, contexts)...)
)
return values(ty)
end
Expand Down Expand Up @@ -134,7 +130,7 @@ function DI.gradient(
) where {F,B}
f_and_df = get_f_and_df(f, backend)
derivs = gradient(
forward_mode_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
forward_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
)
return only(derivs)
end
Expand All @@ -147,7 +143,7 @@ function DI.value_and_gradient(
) where {F,B}
f_and_df = get_f_and_df(f, backend)
(; derivs, val) = gradient(
forward_mode_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
forward_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
)
return val, only(derivs)
end
Expand Down Expand Up @@ -197,7 +193,7 @@ function DI.jacobian(
) where {F,B}
f_and_df = get_f_and_df(f, backend)
derivs = jacobian(
forward_mode_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
forward_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
)
jac_tensor = only(derivs)
return maybe_reshape(jac_tensor, prep.output_length, length(x))
Expand All @@ -211,7 +207,7 @@ function DI.value_and_jacobian(
) where {F,B}
f_and_df = get_f_and_df(f, backend)
(; derivs, val) = jacobian(
forward_mode_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
forward_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
)
jac_tensor = only(derivs)
return val, maybe_reshape(jac_tensor, prep.output_length, length(x))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function DI.value_and_pushforward(
x_and_dx = Duplicated(x, dx_sametype)
y_and_dy = Duplicated(y, dy_sametype)
autodiff(
forward_mode_noprimal(backend),
forward_noprimal(backend),
f!_and_df!,
Const,
y_and_dy,
Expand All @@ -51,7 +51,7 @@ function DI.value_and_pushforward(
x_and_tx = BatchDuplicated(x, tx_sametype)
y_and_ty = BatchDuplicated(y, ty_sametype)
autodiff(
forward_mode_noprimal(backend),
forward_noprimal(backend),
f!_and_df!,
Const,
y_and_ty,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function batch_seeded_autodiff_thunk(
::Type{RA},
args::Vararg{Annotation,N},
) where {ReturnPrimal,B,FA<:Annotation,RA<:Annotation,N}
rmode_rightwidth = set_width(rmode, Val(B))
rmode_rightwidth = ReverseSplitWidth(rmode, Val(B))
forward, reverse = autodiff_thunk(rmode_rightwidth, FA, RA, typeof.(args)...)
tape, result, shadow_results = forward(f, args...)
if RA <: Active
Expand Down Expand Up @@ -70,7 +70,7 @@ function DI.value_and_pullback(
contexts::Vararg{Context,C},
) where {F,C}
f_and_df = force_annotation(get_f_and_df(f, backend))
mode = reverse_mode_split_withprimal(backend)
mode = reverse_split_withprimal(backend)
RA = eltype(ty) <: Number ? Active : Duplicated
dinputs, result = seeded_autodiff_thunk(
mode, only(ty), f_and_df, RA, Active(x), map(translate, contexts)...
Expand All @@ -87,7 +87,7 @@ function DI.value_and_pullback(
contexts::Vararg{Context,C},
) where {F,B,C}
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
mode = reverse_mode_split_withprimal(backend)
mode = reverse_split_withprimal(backend)
RA = eltype(ty) <: Number ? Active : BatchDuplicated
dinputs, result = batch_seeded_autodiff_thunk(
mode, ty, f_and_df, RA, Active(x), map(translate, contexts)...
Expand All @@ -104,7 +104,7 @@ function DI.value_and_pullback(
contexts::Vararg{Context,C},
) where {F,C}
f_and_df = force_annotation(get_f_and_df(f, backend))
mode = reverse_mode_split_withprimal(backend)
mode = reverse_split_withprimal(backend)
RA = eltype(ty) <: Number ? Active : Duplicated
dx = make_zero(x)
_, result = seeded_autodiff_thunk(
Expand All @@ -122,7 +122,7 @@ function DI.value_and_pullback(
contexts::Vararg{Context,C},
) where {F,B,C}
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
mode = reverse_mode_split_withprimal(backend)
mode = reverse_split_withprimal(backend)
RA = eltype(ty) <: Number ? Active : BatchDuplicated
tx = ntuple(_ -> make_zero(x), Val(B))
_, result = batch_seeded_autodiff_thunk(
Expand Down Expand Up @@ -154,7 +154,7 @@ function DI.value_and_pullback!(
contexts::Vararg{Context,C},
) where {F,C}
f_and_df = force_annotation(get_f_and_df(f, backend))
mode = reverse_mode_split_withprimal(backend)
mode = reverse_split_withprimal(backend)
RA = eltype(ty) <: Number ? Active : Duplicated
dx_righttype = convert(typeof(x), only(tx))
make_zero!(dx_righttype)
Expand All @@ -180,7 +180,7 @@ function DI.value_and_pullback!(
contexts::Vararg{Context,C},
) where {F,B,C}
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
mode = reverse_mode_split_withprimal(backend)
mode = reverse_split_withprimal(backend)
RA = eltype(ty) <: Number ? Active : BatchDuplicated
tx_righttype = map(Fix1(convert, typeof(x)), tx)
make_zero!(tx_righttype)
Expand Down Expand Up @@ -227,9 +227,7 @@ function DI.gradient(
contexts::Vararg{Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
derivs = gradient(
reverse_mode_noprimal(backend), f_and_df, x, map(translate, contexts)...
)
derivs = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...)
return first(derivs)
end

Expand All @@ -245,7 +243,7 @@ function DI.gradient!(
dx_righttype = convert(typeof(x), grad)
make_zero!(dx_righttype)
autodiff(
reverse_mode_noprimal(backend),
reverse_noprimal(backend),
f_and_df,
Active,
Duplicated(x, dx_righttype),
Expand All @@ -264,7 +262,7 @@ function DI.value_and_gradient(
) where {F,C}
f_and_df = get_f_and_df(f, backend)
(; derivs, val) = gradient(
reverse_mode_withprimal(backend), f_and_df, x, map(translate, contexts)...
reverse_withprimal(backend), f_and_df, x, map(translate, contexts)...
)
return val, first(derivs)
end
Expand All @@ -281,7 +279,7 @@ function DI.value_and_gradient!(
dx_righttype = convert(typeof(x), grad)
make_zero!(dx_righttype)
_, y = autodiff(
reverse_mode_withprimal(backend),
reverse_withprimal(backend),
f_and_df,
Active,
Duplicated(x, dx_righttype),
Expand All @@ -308,7 +306,7 @@ function DI.jacobian(
backend::AutoEnzyme{<:ReverseMode,Nothing},
x,
) where {F,Sy,B}
derivs = jacobian(reverse_mode_noprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B))
derivs = jacobian(reverse_noprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B))
jac_tensor = only(derivs)
return maybe_reshape(jac_tensor, prod(Sy), length(x))
end
Expand All @@ -320,7 +318,7 @@ function DI.value_and_jacobian(
x,
) where {F,Sy,B}
(; derivs, val) = jacobian(
reverse_mode_withprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B)
reverse_withprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B)
)
jac_tensor = only(derivs)
return val, maybe_reshape(jac_tensor, prod(Sy), length(x))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function DI.value_and_pullback(
y_and_dy = Duplicated(y, dy_sametype)
dinputs = only(
autodiff(
reverse_mode_noprimal(backend),
reverse_noprimal(backend),
f!_and_df!,
Const,
y_and_dy,
Expand All @@ -51,7 +51,7 @@ function DI.value_and_pullback(
y_and_ty = BatchDuplicated(y, ty_sametype)
dinputs = only(
autodiff(
reverse_mode_noprimal(backend),
reverse_noprimal(backend),
f!_and_df!,
Const,
y_and_ty,
Expand All @@ -78,7 +78,7 @@ function DI.value_and_pullback(
x_and_dx = Duplicated(x, dx_sametype)
y_and_dy = Duplicated(y, dy_sametype)
autodiff(
reverse_mode_noprimal(backend),
reverse_noprimal(backend),
f!_and_df!,
Const,
y_and_dy,
Expand All @@ -103,7 +103,7 @@ function DI.value_and_pullback(
x_and_tx = BatchDuplicated(x, tx_sametype)
y_and_ty = BatchDuplicated(y, ty_sametype)
autodiff(
reverse_mode_noprimal(backend),
reverse_noprimal(backend),
f!_and_df!,
Const,
y_and_ty,
Expand Down
Loading

0 comments on commit c3f5360

Please sign in to comment.