Skip to content

Commit

Permalink
Fix docstring for value_and_pullback_function
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Dec 21, 2023
1 parent 211b675 commit 97f6d4d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 28 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AbstractDifferentiation"
uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
authors = ["Mohamed Tarek <mohamed82008@gmail.com> and contributors"]
version = "0.6.0"
version = "0.6.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
61 changes: 34 additions & 27 deletions src/AbstractDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,44 +226,48 @@ The pushfoward function `pf` accepts as input a `Tuple` of tangents, one for eac
If `xs` consists of a single element, `pf` can also accept a single tangent instead of a 1-tuple.
"""
function pushforward_function(ab::AbstractBackend, f, xs...)
return (ds) -> begin
return jacobian(
lowest(ab),
(xds...,) -> begin
if ds isa Tuple
@assert length(xs) == length(ds)
newxs = xs .+ ds .* xds
return f(newxs...)
else
newx = only(xs) + ds * only(xds)
return f(newx)
end
end,
_zero.(xs, ds)...,
)
function pf(ds)
function pf_aux(xds...)
if ds isa Tuple
@assert length(xs) == length(ds)
newxs = xs .+ ds .* xds
return f(newxs...)
else
newx = only(xs) + ds * only(xds)
return f(newx)

Check warning on line 237 in src/AbstractDifferentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractDifferentiation.jl#L236-L237

Added lines #L236 - L237 were not covered by tests
end
end,
return jacobian(lowest(ab), pf_aux, _zero.(xs, ds)...)
end
return pf
end

"""
AD.value_and_pushforward_function(ab::AD.AbstractBackend, f, xs...)
Return a function that, given tangents `ts`, computes the tuple `(v, p)` of the function value `v = f(xs...)` and the output `p` of the pushforward function `AD.pushforward_function(ab, f, xs...)` applied to `ts`.
Return a function `vpf` which, given tangents `ts`, computes the tuple `(v, p) = vpf(ts)` composed of
- the function value `v = f(xs...)`
- the pushforward value `p = pf(ts)` given by the pushforward function `pf = AD.pushforward_function(ab, f, xs...)` applied to `ts`.
See also [`AbstractDifferentiation.pushforward_function`](@ref).
!!! warning
This name should be understood as "(value and pushforward) function", and thus is not aligned with the reverse mode counterpart [`AbstractDifferentiation.value_and_pullback_function`](@ref).
"""
function value_and_pushforward_function(ab::AbstractBackend, f, xs...)
n = length(xs)
value = f(xs...)
pf_function = pushforward_function(lowest(ab), f, xs...)
pf = pushforward_function(lowest(ab), f, xs...)

return ds -> begin
function vpf(ds)
if !(ds isa Tuple)
ds = (ds,)
end
@assert length(ds) == n
pf = pf_function(ds)
return value, pf
return value, pf(ds)
end
return vpf
end

_zero(::Number, d::Number) = zero(d)
Expand Down Expand Up @@ -291,21 +295,24 @@ The pullback function `pb` accepts as input a `Tuple` of cotangents, one for eac
If `f` has a single output, `pb` can also accept a single input instead of a 1-tuple.
"""
function pullback_function(ab::AbstractBackend, f, xs...)
_, pbf = value_and_pullback_function(ab, f, xs...)
return pbf
_, pb = value_and_pullback_function(ab, f, xs...)
return pb
end

"""
AD.value_and_pullback_function(ab::AD.AbstractBackend, f, xs...)
Return a function that, given cotangents `ts`, computes the tuple `(v, p)` of the function value `v = f(xs...)` and the output `p` of the pullback function `AD.pullback_function(ab, f, xs...)` applied to `ts`.
Return a tuple `(v, pb)` of the function value `v = f(xs...)` and the pullback function `pb = AD.pullback_function(ab, f, xs...)`.
See also [`AbstractDifferentiation.pullback_function`](@ref).
!!! warning
This name should be understood as "value and (pullback function)", and thus is not aligned with the forward mode counterpart [`AbstractDifferentiation.value_and_pushforward_function`](@ref).
"""
function value_and_pullback_function(ab::AbstractBackend, f, xs...)
value = f(xs...)
function pullback_function(ws)
function pullback_gradient_function(_xs...)
function pb(ws)
function pb_aux(_xs...)
vs = f(_xs...)
if ws isa Tuple
@assert length(vs) == length(ws)
Expand All @@ -314,9 +321,9 @@ function value_and_pullback_function(ab::AbstractBackend, f, xs...)
return _dot(vs, ws)
end
end
return gradient(lowest(ab), pullback_gradient_function, xs...)
return gradient(lowest(ab), pb_aux, xs...)
end
return value, pullback_function
return value, pb
end

struct LazyDerivative{B,F,X}
Expand Down

0 comments on commit 97f6d4d

Please sign in to comment.