From 97f6d4da79c6fa78e084034908b346fc92979254 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 21 Dec 2023 11:22:13 +0100 Subject: [PATCH] Fix docstring for value_and_pullback_function --- Project.toml | 2 +- src/AbstractDifferentiation.jl | 61 +++++++++++++++++++--------------- 2 files changed, 35 insertions(+), 28 deletions(-) diff --git a/Project.toml b/Project.toml index 12627f8..22d4707 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AbstractDifferentiation" uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" authors = ["Mohamed Tarek and contributors"] -version = "0.6.0" +version = "0.6.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index a868b37..36a2a99 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -226,44 +226,48 @@ The pushfoward function `pf` accepts as input a `Tuple` of tangents, one for eac If `xs` consists of a single element, `pf` can also accept a single tangent instead of a 1-tuple. """ function pushforward_function(ab::AbstractBackend, f, xs...) - return (ds) -> begin - return jacobian( - lowest(ab), - (xds...,) -> begin - if ds isa Tuple - @assert length(xs) == length(ds) - newxs = xs .+ ds .* xds - return f(newxs...) - else - newx = only(xs) + ds * only(xds) - return f(newx) - end - end, - _zero.(xs, ds)..., - ) + function pf(ds) + function pf_aux(xds...) + if ds isa Tuple + @assert length(xs) == length(ds) + newxs = xs .+ ds .* xds + return f(newxs...) + else + newx = only(xs) + ds * only(xds) + return f(newx) + end + end, + return jacobian(lowest(ab), pf_aux, _zero.(xs, ds)...) end + return pf end """ AD.value_and_pushforward_function(ab::AD.AbstractBackend, f, xs...) -Return a function that, given tangents `ts`, computes the tuple `(v, p)` of the function value `v = f(xs...)` and the output `p` of the pushforward function `AD.pushforward_function(ab, f, xs...)` applied to `ts`. +Return a function `vpf` which, given tangents `ts`, computes the tuple `(v, p) = vpf(ts)` composed of + +- the function value `v = f(xs...)` +- the pushforward value `p = pf(ts)` given by the pushforward function `pf = AD.pushforward_function(ab, f, xs...)` applied to `ts`. See also [`AbstractDifferentiation.pushforward_function`](@ref). + +!!! warning + This name should be understood as "(value and pushforward) function", and thus is not aligned with the reverse mode counterpart [`AbstractDifferentiation.value_and_pullback_function`](@ref). """ function value_and_pushforward_function(ab::AbstractBackend, f, xs...) n = length(xs) value = f(xs...) - pf_function = pushforward_function(lowest(ab), f, xs...) + pf = pushforward_function(lowest(ab), f, xs...) - return ds -> begin + function vpf(ds) if !(ds isa Tuple) ds = (ds,) end @assert length(ds) == n - pf = pf_function(ds) - return value, pf + return value, pf(ds) end + return vpf end _zero(::Number, d::Number) = zero(d) @@ -291,21 +295,24 @@ The pullback function `pb` accepts as input a `Tuple` of cotangents, one for eac If `f` has a single output, `pb` can also accept a single input instead of a 1-tuple. """ function pullback_function(ab::AbstractBackend, f, xs...) - _, pbf = value_and_pullback_function(ab, f, xs...) - return pbf + _, pb = value_and_pullback_function(ab, f, xs...) + return pb end """ AD.value_and_pullback_function(ab::AD.AbstractBackend, f, xs...) -Return a function that, given cotangents `ts`, computes the tuple `(v, p)` of the function value `v = f(xs...)` and the output `p` of the pullback function `AD.pullback_function(ab, f, xs...)` applied to `ts`. +Return a tuple `(v, pb)` of the function value `v = f(xs...)` and the pullback function `pb = AD.pullback_function(ab, f, xs...)`. See also [`AbstractDifferentiation.pullback_function`](@ref). + +!!! warning + This name should be understood as "value and (pullback function)", and thus is not aligned with the forward mode counterpart [`AbstractDifferentiation.value_and_pushforward_function`](@ref). """ function value_and_pullback_function(ab::AbstractBackend, f, xs...) value = f(xs...) - function pullback_function(ws) - function pullback_gradient_function(_xs...) + function pb(ws) + function pb_aux(_xs...) vs = f(_xs...) if ws isa Tuple @assert length(vs) == length(ws) @@ -314,9 +321,9 @@ function value_and_pullback_function(ab::AbstractBackend, f, xs...) return _dot(vs, ws) end end - return gradient(lowest(ab), pullback_gradient_function, xs...) + return gradient(lowest(ab), pb_aux, xs...) end - return value, pullback_function + return value, pb end struct LazyDerivative{B,F,X}