Avoid unnecessarily disabling CUDA graphs #7302
Conversation
As discussed in PR ggerganov#6766, CUDA graphs were being disabled in the presence of long prompts. This fixes the issue by preventing the consecutive-update counter from incrementing unnecessarily for tokens in which CUDA graphs are disabled due to batch size > 1.
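For illustration, a minimal standalone model of the counter behaviour described above (field names follow those used later in this thread; the real logic lives in ggml-cuda.cu and is more involved):

```cpp
#include <cstdio>

// Toy model of the consecutive-update counter: tokens evaluated with
// batch size > 1 bypass CUDA graphs entirely, so they should not advance
// the counter; only genuine graph updates at batch size 1 count.
struct cuda_graph_state {
    int  number_consecutive_updates      = 0;
    bool disable_due_to_too_many_updates = false;
};

void track_update(cuda_graph_state & g, int batch_size, bool graph_update_required) {
    if (batch_size > 1) {
        return; // graphs not used for this token: leave the counter untouched
    }
    g.number_consecutive_updates = graph_update_required ? g.number_consecutive_updates + 1 : 0;
    if (g.number_consecutive_updates >= 4) {
        g.disable_due_to_too_many_updates = true;
    }
}

int main() {
    cuda_graph_state g;
    // Long prompt: several batch-size-512 evaluations, then single-token generation.
    for (int i = 0; i < 8; ++i) track_update(g, 512, /*graph_update_required=*/true);
    track_update(g, 1, /*graph_update_required=*/false);
    std::printf("disabled: %d\n", g.disable_due_to_too_many_updates); // prints 0
}
```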
The benchmark is also failing on master, so I don't think it is caused by this.
Any chance of enabling graphs on Turing architecture?
@skoulik The reason that CUDA graphs are disabled for older GPUs is performance - have you tested performance with and without CUDA graphs on your Turing GPU? From my test (on a Quadro RTX 8000), it is actually a bit slower with graphs. To remove the restriction, just comment out this line: https://github.com/agray3/llama.cpp/blob/d5c05821f3c3d6cabe8ac45776fe0ecb0da13eca/ggml-cuda.cu#L2533
If there is a benefit in your case, we could maybe add a GGML_CUDA_ENABLE_GRAPHS environment variable to override the default behaviour (noting that we currently have the opposite, GGML_CUDA_DISABLE_GRAPHS).
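For what it's worth, a hypothetical sketch of such an override (GGML_CUDA_ENABLE_GRAPHS does not exist yet; GGML_CUDA_DISABLE_GRAPHS does, and the Ampere cutoff shown is an assumption inferred from Turing being excluded by default):

```cpp
#include <cstdio>
#include <cstdlib>

// Decide whether CUDA graphs should be used, with a proposed opt-in override.
static bool cuda_graphs_allowed(int cc_major) {
    if (std::getenv("GGML_CUDA_DISABLE_GRAPHS")) {
        return false;       // existing opt-out always wins
    }
    if (std::getenv("GGML_CUDA_ENABLE_GRAPHS")) {
        return true;        // proposed opt-in, e.g. to test Turing
    }
    return cc_major >= 8;   // assumed default: Ampere (cc 8.x) and newer only
}

int main() {
    std::printf("Turing (cc 7.x): %s\n", cuda_graphs_allowed(7) ? "graphs on" : "graphs off");
}
```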
Actually I haven't, because I can't seem to get graphs to run at all. What I found in my use case is that the graphs keep being disabled because of cuda_ctx->cuda_graph->disable_due_to_too_many_updates a few lines lower. This is when I run server inference with a Llama 3 model. Can you tell me the use case(s) where it is supposed to work? I'll be happy to test.
Understood.
That's what I did.
I'm not sure why that is happening - can you please tell me your exact command to allow me to reproduce?
Of course, here is the command I run:
The quants were downloaded from https://huggingface.co/NurtureAI/Meta-Llama-3-8B-Instruct-64k-GGUF
I've added logging and confirmed that all calls end up in disable_due_to_too_many_updates. Not a single graph launch, unfortunately.
@skoulik I tried with your exact command but unfortunately I cannot reproduce - it uses CUDA graphs for me as expected (including on Turing if I remove the restriction). All I can suggest is that you add some print statements to the code to trace back exactly what is causing the updates.
@agray3 ,
I suppose these are expected.
Then, after 4 retries (because of number_consecutive_updates >= 4) it sets disable_due_to_too_many_updates = true. Now, I suppose the question is: what's wrong with the graph topology? Any ideas on how to debug? Update:
I'll look further into why the params are changing.
@agray3, Update:
As for the root causes of cuda_graph_update_required:
node->src[1]->ne[1] == 2, which is OK, I believe. Now, when the generation starts, the updates are caused by has_matching_properties == false, via this check: node->src[i]->data != graph_node_properties->src_address[i]
Not sure what this means and why it works for you.
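For reference, a standalone model of the check quoted above (struct and field names follow the thread; the real code also compares the node's own address, operation and shape):

```cpp
// An update is flagged whenever a source tensor's data pointer differs from
// the one recorded when the graph was captured - e.g. a view whose base
// address has moved between tokens.
constexpr int MAX_SRC = 10; // stand-in for GGML_MAX_SRC

struct graph_node_properties {
    void * node_address;
    void * src_address[MAX_SRC];
};

struct node_view { // stand-in for the relevant ggml_tensor fields
    void * data;
    void * src_data[MAX_SRC];
};

static bool has_matching_properties(const node_view & node, const graph_node_properties & props) {
    if (node.data != props.node_address) {
        return false;
    }
    for (int i = 0; i < MAX_SRC; ++i) {
        if (node.src_data[i] != props.src_address[i]) {
            return false; // forces cuda_graph_update_required for this token
        }
    }
    return true;
}
```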
This change always destroys any instantiated graph, but the instantiated graph is required for the graph update mechanism, which involves updating a previously instantiated graph with a newly captured graph to create an updated instance.
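For reference, a minimal standalone sketch of that instantiate-then-update pattern (illustrative only, using the CUDA 11.x runtime signatures for cudaGraphInstantiate/cudaGraphExecUpdate; this is not the llama.cpp code):

```cpp
#include <cuda_runtime.h>
#include <cstdio>

__global__ void scale(float * x, float a, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= a;
}

int main() {
    const int n = 1 << 20;
    float * x;
    cudaMalloc(&x, n * sizeof(float));

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    cudaGraphExec_t instance = nullptr;

    for (int iter = 0; iter < 10; ++iter) {
        // Re-capture the stream work each iteration (kernel parameters may change).
        cudaGraph_t graph;
        cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
        scale<<<(n + 255) / 256, 256, 0, stream>>>(x, 1.0f + 0.1f * iter, n);
        cudaStreamEndCapture(stream, &graph);

        if (instance == nullptr) {
            // First time only: instantiate an executable graph.
            cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0);
        } else {
            // Afterwards: cheaply update the existing instance from the new capture.
            cudaGraphNode_t error_node;
            cudaGraphExecUpdateResult result;
            if (cudaGraphExecUpdate(instance, graph, &error_node, &result) != cudaSuccess) {
                // Topology changed: fall back to destroying and re-instantiating.
                cudaGetLastError(); // clear the error state
                cudaGraphExecDestroy(instance);
                cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0);
            }
        }
        cudaGraphLaunch(instance, stream);
        cudaGraphDestroy(graph); // the captured graph itself is no longer needed
    }

    cudaStreamSynchronize(stream);
    cudaGraphExecDestroy(instance);
    cudaStreamDestroy(stream);
    cudaFree(x);
    std::printf("done\n");
    return 0;
}
```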
Yeah, it's weird we are seeing different behaviour. Can you please share the full standard output you get from
and I'll cross-check with what I get?
@agray3 ,
Are you saying that the graph instance from the previous graph should work with the newly captured graph? Can you point me to where this is stated in the CUDA documentation, because I struggle to find the confirmation (and my observations seem to prove the opposite, at least in some cases - see below)?
I've added more logging and commented out the disable_due_to_too_many_updates block - see attached.
In case it still doesn't reproduce, please compare your CMake config to mine (attached), especially the LLAMA_CUDA_xxx settings. There might be clues.
I think I may have found the difference. Comparing logs, there is already a difference right at the start where you have the extra lines:
which suggests that OpenCL is in use for your build, whereas I am not using OpenCL (just CUDA). This may be because, as I can see in your CMakeCache.txt file, you have CLBlast enabled. Here is the documentation for cudaGraphExecUpdate:
Good catch, mate! Indeed, I was building with CLBlast support in order to try llama.cpp on my other GPU (Radeon RX 5700). As for the speed. Without graphs:
With graphs:
I'd say it's inconclusive. The measurements fluctuate a lot. I noticed that tensor core utilization sits at 100% almost all the time with graphs, while it is closer to 97% without.
Thank you. I was reading through the documents and blog posts over the weekend. Now I understand the idea behind Instantiate, then loop(Capture, Update) better. The only thing that bothers me after all these exercises is whether the current code is generic enough. What if there are layers with different topologies present in some model? (I understand that my case is a kind of misconfiguration, but I won't dismiss such a possibility.) Wouldn't we be better off if we cached a few graphs - one for each encountered topology - rather than just one? With such a design we'd be able to work without re-instantiating even if there are different topologies present simultaneously (and the code is looping over them).
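A hypothetical sketch of that "one cached graph per topology" idea (nothing like this exists in the codebase; a caller-supplied topology hash is assumed, and the CUDA 11.x signatures are used as above):

```cpp
#include <cuda_runtime.h>
#include <cstdint>
#include <unordered_map>

struct cached_graph {
    cudaGraphExec_t instance = nullptr;
};

// One instantiated graph per distinct topology, keyed by a hash over node ops
// and shapes (computed by the caller). Alternating topologies then update
// their own instance instead of repeatedly forcing re-instantiation.
static std::unordered_map<uint64_t, cached_graph> graph_cache;

static cudaGraphExec_t get_or_update_instance(uint64_t topology_hash, cudaGraph_t captured) {
    cached_graph & entry = graph_cache[topology_hash];
    if (entry.instance == nullptr) {
        cudaGraphInstantiate(&entry.instance, captured, nullptr, nullptr, 0);
    } else {
        cudaGraphNode_t error_node;
        cudaGraphExecUpdateResult result;
        if (cudaGraphExecUpdate(entry.instance, captured, &error_node, &result) != cudaSuccess) {
            cudaGetLastError(); // clear the error and rebuild this entry
            cudaGraphExecDestroy(entry.instance);
            cudaGraphInstantiate(&entry.instance, captured, nullptr, nullptr, 0);
        }
    }
    return entry.instance;
}
```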
BTW, I still believe there is a bug in the cudaErrorGraphExecUpdateFailure handling. The graph is re-instantiated, but num_nodes is not updated. It won't be updated in line 2707 either, because of the check. So if the node count changes, the code won't recover.
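A hedged sketch of the recovery path being described (variable names are taken from this thread, not from the actual patch in #7738):

```cpp
#include <cuda_runtime.h>
#include <cstddef>

// After a failed cudaGraphExecUpdate forces re-instantiation, the cached node
// count has to be refreshed as well; otherwise the later "node count
// unchanged" check keeps comparing against a stale value and never recovers.
static void update_or_reinstantiate(cudaGraphExec_t & instance, cudaGraph_t graph, size_t & cached_num_nodes) {
    cudaGraphNode_t error_node;
    cudaGraphExecUpdateResult result;
    if (cudaGraphExecUpdate(instance, graph, &error_node, &result) == cudaErrorGraphExecUpdateFailure) {
        cudaGetLastError(); // clear the sticky error state
        cudaGraphExecDestroy(instance);
        cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0);

        cudaGraphGetNodes(graph, nullptr, &cached_num_nodes); // the missing refresh
    }
}
```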
Good spot! See #7738
I guess it will be a trade-off between code complexity and the benefit in such cases where this occurs, but I don't have a good feeling for how common/likely such scenarios are.