From e89aab395578e9d06460ee07de957b24cfe18dd7 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 1 Sep 2024 19:54:41 +0200 Subject: [PATCH] Fix FiniteDiff derivative (#436) --- .../reverse_onearg.jl | 2 +- .../twoarg.jl | 21 ++++++++++--------- .../DifferentiationInterfaceTrackerExt.jl | 4 +++- .../DifferentiationInterfaceZygoteExt.jl | 4 +++- 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl index f20579dc1..459086493 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -8,7 +8,7 @@ end DI.prepare_pullback(f, ::AutoReverseChainRules, x, ty::Tangents) = NoPullbackExtras() function DI.prepare_pullback_same_point( - f, backend::AutoReverseChainRules, x, ty::Tangents, ::PullbackExtras=NoPullbackExtras() + f, backend::AutoReverseChainRules, x, ty::Tangents, ::NoPullbackExtras ) rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl index 4e2593c35..47fe7afdb 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl @@ -23,39 +23,40 @@ struct FiniteDiffTwoArgDerivativeExtras{C} <: DerivativeExtras cache::C end -function DI.prepare_derivative(f!, y, ::AutoFiniteDiff, x) - cache = nothing +function DI.prepare_derivative(f!, y, backend::AutoFiniteDiff, x) + df = similar(y) + cache = GradientCache(df, x, fdtype(backend), eltype(y), FUNCTION_INPLACE) return FiniteDiffTwoArgDerivativeExtras(cache) end function DI.value_and_derivative( - f!, y, backend::AutoFiniteDiff, x, ::FiniteDiffTwoArgDerivativeExtras + f!, y, backend::AutoFiniteDiff, x, extras::FiniteDiffTwoArgDerivativeExtras ) f!(y, x) - der = finite_difference_gradient(f!, x, fdtype(backend), eltype(y), FUNCTION_INPLACE, y) + der = finite_difference_gradient(f!, x, extras.cache) return y, der end function DI.value_and_derivative!( - f!, y, der, backend::AutoFiniteDiff, x, ::FiniteDiffTwoArgDerivativeExtras + f!, y, der, backend::AutoFiniteDiff, x, extras::FiniteDiffTwoArgDerivativeExtras ) f!(y, x) - finite_difference_gradient!(der, f!, x, fdtype(backend), eltype(y), FUNCTION_INPLACE, y) + finite_difference_gradient!(der, f!, x, extras.cache) return y, der end function DI.derivative( - f!, y, backend::AutoFiniteDiff, x, ::FiniteDiffTwoArgDerivativeExtras + f!, y, backend::AutoFiniteDiff, x, extras::FiniteDiffTwoArgDerivativeExtras ) f!(y, x) - der = finite_difference_gradient(f!, x, fdtype(backend), eltype(y), FUNCTION_INPLACE, y) + der = finite_difference_gradient(f!, x, extras.cache) return der end function DI.derivative!( - f!, y, der, backend::AutoFiniteDiff, x, ::FiniteDiffTwoArgDerivativeExtras + f!, y, der, backend::AutoFiniteDiff, x, extras::FiniteDiffTwoArgDerivativeExtras ) - finite_difference_gradient!(der, f!, x, fdtype(backend), eltype(y), FUNCTION_INPLACE) + finite_difference_gradient!(der, f!, x, extras.cache) return der end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index f6c27df29..17f6cdce5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -18,7 +18,9 @@ end DI.prepare_pullback(f, ::AutoTracker, x, ty::Tangents) = NoPullbackExtras() -function DI.prepare_pullback_same_point(f, ::AutoTracker, x, ty::Tangents, ::PullbackExtras) +function DI.prepare_pullback_same_point( + f, ::AutoTracker, x, ty::Tangents, ::NoPullbackExtras +) y, pb = forward(f, x) return TrackerPullbackExtrasSamePoint(y, pb) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index c33e12f00..7feb1389d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -28,7 +28,9 @@ end DI.prepare_pullback(f, ::AutoZygote, x, ty::Tangents) = NoPullbackExtras() -function DI.prepare_pullback_same_point(f, ::AutoZygote, x, ty::Tangents, ::PullbackExtras) +function DI.prepare_pullback_same_point( + f, ::AutoZygote, x, ty::Tangents, ::NoPullbackExtras +) y, pb = pullback(f, x) return ZygotePullbackExtrasSamePoint(y, pb) end