Skip to content

Commit

Permalink
Fix FiniteDiff derivative (#436)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Sep 1, 2024
1 parent 7c60378 commit e89aab3
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ end
DI.prepare_pullback(f, ::AutoReverseChainRules, x, ty::Tangents) = NoPullbackExtras()

function DI.prepare_pullback_same_point(
f, backend::AutoReverseChainRules, x, ty::Tangents, ::PullbackExtras=NoPullbackExtras()
f, backend::AutoReverseChainRules, x, ty::Tangents, ::NoPullbackExtras
)
rc = ruleconfig(backend)
y, pb = rrule_via_ad(rc, f, x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,39 +23,40 @@ struct FiniteDiffTwoArgDerivativeExtras{C} <: DerivativeExtras
cache::C
end

function DI.prepare_derivative(f!, y, ::AutoFiniteDiff, x)
cache = nothing
function DI.prepare_derivative(f!, y, backend::AutoFiniteDiff, x)
df = similar(y)
cache = GradientCache(df, x, fdtype(backend), eltype(y), FUNCTION_INPLACE)
return FiniteDiffTwoArgDerivativeExtras(cache)
end

function DI.value_and_derivative(
f!, y, backend::AutoFiniteDiff, x, ::FiniteDiffTwoArgDerivativeExtras
f!, y, backend::AutoFiniteDiff, x, extras::FiniteDiffTwoArgDerivativeExtras
)
f!(y, x)
der = finite_difference_gradient(f!, x, fdtype(backend), eltype(y), FUNCTION_INPLACE, y)
der = finite_difference_gradient(f!, x, extras.cache)
return y, der
end

function DI.value_and_derivative!(
f!, y, der, backend::AutoFiniteDiff, x, ::FiniteDiffTwoArgDerivativeExtras
f!, y, der, backend::AutoFiniteDiff, x, extras::FiniteDiffTwoArgDerivativeExtras
)
f!(y, x)
finite_difference_gradient!(der, f!, x, fdtype(backend), eltype(y), FUNCTION_INPLACE, y)
finite_difference_gradient!(der, f!, x, extras.cache)
return y, der
end

function DI.derivative(
f!, y, backend::AutoFiniteDiff, x, ::FiniteDiffTwoArgDerivativeExtras
f!, y, backend::AutoFiniteDiff, x, extras::FiniteDiffTwoArgDerivativeExtras
)
f!(y, x)
der = finite_difference_gradient(f!, x, fdtype(backend), eltype(y), FUNCTION_INPLACE, y)
der = finite_difference_gradient(f!, x, extras.cache)
return der
end

function DI.derivative!(
f!, y, der, backend::AutoFiniteDiff, x, ::FiniteDiffTwoArgDerivativeExtras
f!, y, der, backend::AutoFiniteDiff, x, extras::FiniteDiffTwoArgDerivativeExtras
)
finite_difference_gradient!(der, f!, x, fdtype(backend), eltype(y), FUNCTION_INPLACE)
finite_difference_gradient!(der, f!, x, extras.cache)
return der
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ end

DI.prepare_pullback(f, ::AutoTracker, x, ty::Tangents) = NoPullbackExtras()

function DI.prepare_pullback_same_point(f, ::AutoTracker, x, ty::Tangents, ::PullbackExtras)
function DI.prepare_pullback_same_point(
f, ::AutoTracker, x, ty::Tangents, ::NoPullbackExtras
)
y, pb = forward(f, x)
return TrackerPullbackExtrasSamePoint(y, pb)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ end

DI.prepare_pullback(f, ::AutoZygote, x, ty::Tangents) = NoPullbackExtras()

function DI.prepare_pullback_same_point(f, ::AutoZygote, x, ty::Tangents, ::PullbackExtras)
function DI.prepare_pullback_same_point(
f, ::AutoZygote, x, ty::Tangents, ::NoPullbackExtras
)
y, pb = pullback(f, x)
return ZygotePullbackExtrasSamePoint(y, pb)
end
Expand Down

0 comments on commit e89aab3

Please sign in to comment.