Hello! I have a minimal example here involving Flux and ChainRulesCore:
using Flux, ChainRulesCore
import ChainRulesCore.rrule

function f(a::Float32, b::Float32)
    return a * b
end

function rrule(::typeof(f), a::Float32, b::Float32)
    println("rrule is called")
    y = f(a, b)
    function pullback(Δy)
        da = @thunk(∇a(a, b, Δy))
        db = @thunk(∇b(a, b, Δy))
        return (NoTangent(), da, db)
    end
    return y, pullback
end

function ∇a(a, b, Δy)
    println("∇a is called")
    return b * Δy
end

function ∇b(a, b, Δy)
    println("∇b is called")
    return a * Δy
end

a = 1f0
b = 2f0
ga = gradient(() -> f(a, b), Flux.params(a))
This defines my custom function f (a multiplication of two scalars) together with its rrule from ChainRulesCore. In the last line, I compute the gradient with respect to the variable a only, as ga = gradient(()->f(a,b), Flux.params(a)). I expected only ∇a to be called, but the log shows both ∇a and ∇b being called:
rrule is called
∇a is called
∇b is called
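My expectation comes from how @thunk defers work: the body of a thunk should run only when something calls unthunk on it. The following is a minimal sketch of that behavior using ChainRulesCore alone, without Flux. The names g, g_pullback and evaluated are hypothetical and exist only for this illustration; g stands in for f above.

```julia
using ChainRulesCore

# Hypothetical log for illustration: records which partials were actually evaluated.
const evaluated = Symbol[]

# Hypothetical stand-in for f above.
g(a, b) = a * b

function ChainRulesCore.rrule(::typeof(g), a, b)
    y = g(a, b)
    function g_pullback(Δy)
        # Each thunk wraps its body in a closure; nothing runs until unthunk is called.
        da = @thunk begin
            push!(evaluated, :a)
            b * Δy
        end
        db = @thunk begin
            push!(evaluated, :b)
            a * Δy
        end
        return (NoTangent(), da, db)
    end
    return y, g_pullback
end

y, back = rrule(g, 1f0, 2f0)
_, da, db = back(1f0)   # calling the pullback only builds the thunks
grad_a = unthunk(da)    # forces the first thunk only; db is left untouched
```

When the rrule is driven by hand like this, only the thunk that is explicitly unthunked gets evaluated, so evaluated contains just :a. With gradient in my example above, both thunks end up being evaluated.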
Any idea why? This could be a problem when f is a complicated function: calling ∇b is unnecessary here, and wasteful if it is time-consuming. Thanks for any help!