Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better HVP in reverse over forward #494

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 30 additions & 25 deletions DifferentiationInterface/src/second_order/hvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ struct ForwardOverReverseHVPPrep{G,E<:PushforwardPrep} <: HVPPrep
outer_pushforward_prep::E
end

# NOTE(review): this excerpt is a rendered diff whose +/- markers were stripped; the
# field-less one-line definition and the parametric definition below look like the
# pre-change and post-change versions of the same struct — confirm against the repository.
struct ReverseOverForwardHVPPrep <: HVPPrep end
# HVP preparation type with two fields: `inner_pushforward` (type parameter `P`) and
# `outer_gradient_prep` (type parameter `E`). In this excerpt, `_prepare_hvp_aux`
# constructs it from a pushforward closure and the result of `prepare_gradient`.
struct ReverseOverForwardHVPPrep{P,E} <: HVPPrep
inner_pushforward::P
outer_gradient_prep::E
end

struct ReverseOverReverseHVPPrep{G,E<:PullbackPrep} <: HVPPrep
inner_gradient::G
Expand Down Expand Up @@ -111,9 +114,19 @@ function _prepare_hvp_aux(
tx::Tangents,
contexts::Vararg{Context,C},
) where {F,C}
rewrap = Rewrap(contexts...)
# gradient of pushforward
# uses dx in the closure so it can't be prepared
return ReverseOverForwardHVPPrep()
function inner_pushforward(_x, _dx, unannotated_contexts...)
annotated_contexts = rewrap(unannotated_contexts...)
ty = pushforward(
f, nested(inner(backend)), _x, Tangents(_dx), annotated_contexts...
)
return only(ty)
end
outer_gradient_prep = prepare_gradient(
inner_pushforward, outer(backend), x, contexts...
)
return ReverseOverForwardHVPPrep(inner_pushforward, outer_gradient_prep)
end

function _prepare_hvp_aux(
Expand Down Expand Up @@ -168,23 +181,15 @@ end

# `hvp` method for the `ReverseOverForwardHVPPrep` case: maps over the tangents in `tx`
# and, for each `dx`, calls `gradient` with `outer(backend)` on a pushforward of `f`;
# the collected results are returned as `tg`.
# NOTE(review): this block is a rendered diff whose +/- markers were stripped, so lines
# from both revisions appear interleaved (two prep parameters, a local
# `inner_pushforward` definition next to the one destructured from `prep`, and two
# `gradient` calls) — confirm which lines belong to which side against the repository.
function hvp(
f::F,
::ReverseOverForwardHVPPrep,
prep::ReverseOverForwardHVPPrep,
backend::AbstractADType,
x,
tx::Tangents,
contexts::Vararg{Context,C},
) where {F,C}
# Used by the locally defined `inner_pushforward` below to re-annotate contexts.
rewrap = Rewrap(contexts...)
# NOTE(review): `outer_gradient_prep` is extracted here, but neither `gradient` call
# below receives it, whereas `hvp!` passes it to `gradient!` — confirm whether the
# prepared gradient was meant to be reused here as well.
@compat (; inner_pushforward, outer_gradient_prep) = prep
tg = map(tx) do dx
# Local closure variant: captures `dx` and takes only `_x` plus contexts.
function inner_pushforward(_x, unannotated_contexts...)
annotated_contexts = rewrap(unannotated_contexts...)
return only(
pushforward(
f, nested(inner(backend)), _x, Tangents(dx), annotated_contexts...
),
)
end
gradient(only ∘ inner_pushforward, outer(backend), x, contexts...)
# This variant supplies the tangent as an explicit `Constant(dx)` argument
# instead of capturing it in a closure.
gradient(inner_pushforward, outer(backend), x, Constant(dx), contexts...)
end
return tg
end
Expand Down Expand Up @@ -234,23 +239,23 @@ end
# In-place `hvp!` method for the `ReverseOverForwardHVPPrep` case: iterates over
# `eachindex(tx.d, tg.d)`, writes a gradient into `tg.d[b]` with `gradient!` for each
# index `b`, and returns `tg`.
# NOTE(review): this block is a rendered diff whose +/- markers were stripped, so lines
# from both revisions appear interleaved (two prep parameters, a local
# `inner_pushforward` definition next to the one destructured from `prep`, and two
# `gradient!` calls) — confirm which lines belong to which side against the repository.
function hvp!(
f::F,
tg::Tangents,
::ReverseOverForwardHVPPrep,
prep::ReverseOverForwardHVPPrep,
backend::AbstractADType,
x,
tx::Tangents,
contexts::Vararg{Context,C},
) where {F,C}
# Used by the locally defined `inner_pushforward` below to re-annotate contexts.
rewrap = Rewrap(contexts...)
@compat (; inner_pushforward, outer_gradient_prep) = prep
for b in eachindex(tx.d, tg.d)
# Local closure variant: captures `tx.d[b]` and takes only `_x` plus contexts.
function inner_pushforward(_x, unannotated_contexts...)
annotated_contexts = rewrap(unannotated_contexts...)
return only(
pushforward(
f, nested(inner(backend)), _x, Tangents(tx.d[b]), annotated_contexts...
),
)
end
gradient!(only ∘ inner_pushforward, tg.d[b], outer(backend), x, contexts...)
# This variant passes `outer_gradient_prep` and supplies the tangent as an
# explicit `Constant(tx.d[b])` argument instead of capturing it in a closure.
gradient!(
inner_pushforward,
tg.d[b],
outer_gradient_prep,
outer(backend),
x,
Constant(tx.d[b]),
contexts...,
)
end
return tg
end
Expand Down
18 changes: 18 additions & 0 deletions DifferentiationInterface/test/Misc/FromPrimitive/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@ fromprimitive_backends = [ #
AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=5)),
]

# Second-order backends combining the two FromPrimitive wrappers around ForwardDiff
# (chunk size 5): once with the forward wrapper as the first `SecondOrder` argument
# and once with the reverse wrapper first.
fromprimitive_secondorder_backends = let
    forward_wrapper = AutoForwardFromPrimitive(AutoForwardDiff(; chunksize=5))
    reverse_wrapper = AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=5))
    [
        SecondOrder(forward_wrapper, reverse_wrapper),
        SecondOrder(reverse_wrapper, forward_wrapper),
    ]
end

for backend in vcat(fromprimitive_backends)
@test check_available(backend)
@test check_inplace(backend)
Expand All @@ -19,3 +30,10 @@ end
# Run `test_differentiation` on `fromprimitive_backends` with the default scenarios,
# requesting the constantified ones as well (`include_constantified=true`).
test_differentiation(
fromprimitive_backends, default_scenarios(; include_constantified=true); logging=LOGGING
);

# Same call for the second-order backend list, with `first_order=false`
# (presumably this disables the first-order operator tests — confirm in the test utilities).
test_differentiation(
fromprimitive_secondorder_backends,
default_scenarios(; include_constantified=true);
first_order=false,
logging=LOGGING,
);
Loading