thunk still runs for non Flux.params which leads to unnecessary computation #558

Closed
ziyiyin97 opened this issue Jun 21, 2022 · 3 comments

@ziyiyin97

Hello! I have a minimal example here involving Flux and ChainRulesCore:

using Flux, ChainRulesCore
import ChainRulesCore.rrule

function f(a::Float32, b::Float32)
    return a * b
end

function rrule(::typeof(f), a::Float32, b::Float32)
    println("rrule is called")
    y = f(a,b)
    function pullback(Δy)
        da = @thunk(∇a(a,b,Δy))
        db = @thunk(∇b(a,b,Δy))
        return (NoTangent(), da, db)
    end
    return y, pullback
end

function ∇a(a,b,Δy)
    println("∇a is called")
    return b * Δy
end

function ∇b(a,b,Δy)
    println("∇b is called")
    return a * Δy
end

a = 1f0
b = 2f0

ga = gradient(()->f(a,b), Flux.params(a))

This defines a custom function f (the product of two scalars) and its rrule via ChainRulesCore. In the last line I compute the gradient with respect to a only, as ga = gradient(()->f(a,b), Flux.params(a)), so I expect only ∇a to be called. However, the log shows that both ∇a and ∇b are called:

rrule is called
∇a is called
∇b is called

Any idea why? This could be a problem when f is a complicated function: calling ∇b is unnecessary here and may be time-consuming. Thanks for any help!

@nickrobinson251
Contributor

nickrobinson251 commented Jun 21, 2022

Flux.jl uses Zygote.jl, and Zygote.jl doesn't yet utilise ChainRulesCore.jl's Thunks:
https://github.com/FluxML/Zygote.jl/blob/9602c6b2038879034c2de14d1f4aa251d99c6ea4/src/compiler/chainrules.jl#L104

There is a WIP PR to make Zygote.jl utilise Thunks here: FluxML/Zygote.jl#966
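To illustrate that the laziness comes from ChainRulesCore itself, here is a minimal sketch that calls the rrule from the example directly, without Flux or Zygote. The `calls` counter is a hypothetical addition for illustration, replacing the `println` statements so the effect can be checked. It assumes only ChainRulesCore's `@thunk` and `unthunk`:

```julia
using ChainRulesCore

# Hypothetical counters (not in the original example) recording which helpers run.
const calls = Dict(:∇a => 0, :∇b => 0)

f(a::Float32, b::Float32) = a * b

∇a(a, b, Δy) = (calls[:∇a] += 1; b * Δy)
∇b(a, b, Δy) = (calls[:∇b] += 1; a * Δy)

function ChainRulesCore.rrule(::typeof(f), a::Float32, b::Float32)
    y = f(a, b)
    function pullback(Δy)
        da = @thunk(∇a(a, b, Δy))  # deferred: body runs only when unthunked
        db = @thunk(∇b(a, b, Δy))
        return (NoTangent(), da, db)
    end
    return y, pullback
end

y, pb = rrule(f, 1f0, 2f0)
_, da, db = pb(1f0)   # da and db are still unevaluated thunks here
ga = unthunk(da)      # forces only ∇a; ∇b is never called
```

When the caller unthunks only `da`, `∇b` never runs. The extra call seen in the issue therefore comes from the AD system forcing every tangent, not from the rule definition.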

@ziyiyin97
Author

Thanks for your quick reply. Looking forward to the PR being merged.

@oxinabox
Member

This is a Zygote problem, not a CRC problem.
