
How can Zygote match the performance of Tracker.jl? (Plus issues in Tracker.jl itself) #1550

Closed
deveshjawla opened this issue Jan 6, 2025 · 1 comment


@deveshjawla

Can someone help me understand what causes Tracker.jl to fail when the gradient is computed with linking? And how could Zygote match Tracker's performance in the "standard" (non-linked) case of TuringBenchmarking?

I get the following output:

┌ Warning: Gradient computation (with linking) failed for AutoTracker(): MethodError(copyto!, (0.0, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}}(+, (0.0, -0.1371315395009085))), 0x0000000000006a89)
└ @ TuringBenchmarking ~/.julia/packages/TuringBenchmarking/fc6o7/src/TuringBenchmarking.jl:243
"gradient" => 3-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "AutoMooncake{Mooncake.Config}(Mooncake.Config(false, false))" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: ["AutoMooncake{Mooncake.Config}(Mooncake.Config(false, false))"]
                  "linked" => Trial(531.125 μs)
                  "standard" => Trial(532.459 μs)
          "AutoZygote()" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: ["Zygote"]
                  "linked" => Trial(2.382 ms)
                  "standard" => Trial(2.252 ms)
          "AutoTracker()" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: ["Tracker"]
                  "linked" => 0-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                  "standard" => Trial(99.125 μs)

MWE:

using Lux, Zygote, Tracker, Mooncake, Turing, Random, TuringBenchmarking, Functors
nn = Chain(Dense(10, 5, relu), Dense(5, 1, use_bias=false))
rng = Xoshiro(0)
ps, st = Lux.setup(rng, nn)
num_params = Lux.parameterlength(nn) # number of parameters in NN
const model = StatefulLuxLayer{true}(nn, nothing, st)

function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1
    function get_ps(x)
        z = reshape(view(ps_new, i:(i+length(x)-1)), size(x))
        i += length(x)
        return z
    end
    return fmap(get_ps, ps)
end

@model function BNN(x, y, num_p)
    θ_p ~ MvNormal(zeros(num_p), ones(num_p))

    preds = Lux.apply(model, x, vector_to_parameters(θ_p, ps))

    sigma ~ Gamma(0.1, 1.0) # Prior on the observation noise standard deviation (used as the Normal scale below)

    y[:] ~ Product(Normal.(vec(preds), sigma))
end

benchmark_result = benchmark_model(
    BNN(randn(10, 10), randn(1, 10), num_params);
    adbackends=[AutoZygote(), AutoTracker(), AutoMooncake(; config=Mooncake.Config(; debug_mode=false))],
)
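The `MethodError` in the warning has the shape Julia produces whenever an in-place broadcast such as `x .+= y` targets a plain scalar: the dotted update lowers to `copyto!(x, Broadcasted(+, (x, y)))`, and Base defines no `copyto!` method for a `Float64` destination. The standalone sketch below uses no Turing, Tracker or Lux code and only reproduces that error shape; it is an illustration under that assumption and does not establish where in the linked Tracker code path such a broadcast actually happens. The names `acc` and `delta` are hypothetical.

```julia
# Hypothetical reproduction: `acc` stands in for a scalar accumulator that ends up
# as a plain Float64 instead of an array; `delta` is the value being added.
acc = 0.0
delta = -0.1371315395009085

# `acc .+= delta` lowers to Base.Broadcast.materialize!(acc, broadcasted(+, acc, delta)),
# which calls copyto!(acc, ::Broadcasted{DefaultArrayStyle{0}}). No such method
# exists for a Float64 destination, so this throws a MethodError.
err = try
    acc .+= delta
    nothing
catch e
    e
end

@show err isa MethodError
@show err.f === copyto!

# Out-of-place addition rebinds `acc` instead of writing into it, so it succeeds.
acc = acc + delta
@show acc
```

If the real failure follows this pattern, the destination would need to be mutable (for example a 0-dimensional array from `fill(0.0)`) or the update would need to be out of place; which of these applies here would have to be confirmed in the actual Turing/Tracker code path.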
@ToucheSir
Member

This seems like an issue that should be reported first on the relevant Turing repositories' issue trackers. There is far too much library code layered on top of Zygote and Tracker for this to be a true minimal working example.

I would recommend narrowing down the specific performance bottlenecks with the help of the Turing developers, then opening a more focused issue with an improved MWE here or on the Tracker.jl repo.
