
Avoid unnecessarily disabling CUDA graphs #7302

Merged

Conversation

agray3
Contributor

@agray3 agray3 commented May 15, 2024

As discussed in PR #6766, CUDA graphs were being disabled in the presence of long prompts. This fixes the issue by preventing the consecutive-update counter from incrementing unnecessarily for tokens in which CUDA graphs are disabled due to batch size > 1.

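For illustration, here is a minimal standalone sketch of the heuristic this PR adjusts (the identifiers, and the threshold of 4 mentioned later in the thread, are illustrative rather than the exact ggml-cuda.cu names): the consecutive-update counter should only advance for single-token (batch size 1) graphs, so prompt-processing tokens with batch size > 1 no longer push it past the disable threshold.

#include <cstdio>

// Sketch of the consecutive-update heuristic; identifiers and threshold are illustrative.
struct graph_state {
    int  consecutive_updates     = 0;     // single-token graphs in a row that needed an update
    bool disabled_due_to_updates = false; // once set, fall back to normal kernel launches
};

static void record_token(graph_state & st, int batch_size, bool graph_update_required) {
    if (batch_size > 1) {
        // CUDA graphs are skipped for this token anyway; before this PR the counter kept
        // climbing here during long prompts, eventually disabling graphs for the whole run.
        return;
    }
    st.consecutive_updates = graph_update_required ? st.consecutive_updates + 1 : 0;
    if (st.consecutive_updates >= 4) {    // threshold value follows the thread discussion
        st.disabled_due_to_updates = true;
    }
}

int main() {
    graph_state st;
    record_token(st, /*batch_size=*/512, /*graph_update_required=*/true);   // prompt token: ignored
    record_token(st, /*batch_size=*/1,   /*graph_update_required=*/false);  // generation token
    printf("graphs disabled: %d\n", st.disabled_due_to_updates);            // prints 0
    return 0;
}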
@mofosyne added the Nvidia GPU and Review Complexity : Low labels May 15, 2024
@slaren
Collaborator

slaren commented May 15, 2024

| GPU | Model | Test | t/s master | t/s ag_fix_consecutive_update_counter_increment | Speedup |
| --- | --- | --- | --- | --- | --- |
| RTX 3090 Ti | llama 7B Q4_0 | pp512 | 4938.82 | 4940.78 | 1.00 |
| RTX 3090 Ti | llama 7B Q4_0 | tg128 | 164.12 | 163.80 | 1.00 |
| RTX 3090 Ti | llama 7B Q4_0 | pp512+tg128 | 599.39 | 691.16 | 1.15 |

@slaren
Collaborator

slaren commented May 15, 2024

The benchmark is also failing on master, so I don't think it is caused by this PR.

@slaren slaren merged commit dc02098 into ggerganov:master May 15, 2024
40 of 62 checks passed
teleprint-me pushed a commit to teleprint-me/llama.cpp that referenced this pull request May 17, 2024
@skoulik

skoulik commented May 29, 2024

@agray3

Any chance of enabling graphs on the Turing architecture?
I've tried it on an RTX 2080 Ti (CC 7.5) and graphs seem to be supported.

@agray3
Contributor Author

agray3 commented May 30, 2024

@agray3

Any chance of enabling graphs on the Turing architecture? I've tried it on an RTX 2080 Ti (CC 7.5) and graphs seem to be supported.

@skoulik The reason that CUDA graphs are disabled for older GPUs is related to performance - have you tested performance with and without CUDA graphs on your Turing GPU? From my test (on Quadro RTX 8000), it is actually a bit slower with graphs. To remove the restriction, just comment out the line https://github.com/agray3/llama.cpp/blob/d5c05821f3c3d6cabe8ac45776fe0ecb0da13eca/ggml-cuda.cu#L2533

If there is a benefit in your case, we could maybe add a GGML_CUDA_ENABLE_GRAPHS environment variable to override the default behaviour (noting we currently have the opposite, GGML_CUDA_DISABLE_GRAPHS).
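For context, a hedged sketch of what such an override could look like (the environment-variable names are the ones mentioned above; the helper function and the compute-capability cutoff are assumptions, not the actual ggml-cuda.cu code):

#include <cstdlib>

// Sketch only: decide whether CUDA graphs should be attempted on a device with the
// given compute capability (e.g. 750 for Turing, 800 for Ampere).
bool cuda_graphs_enabled(int compute_capability) {
    if (std::getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr) {
        return false;                     // existing opt-out always wins
    }
    if (std::getenv("GGML_CUDA_ENABLE_GRAPHS") != nullptr) {
        return true;                      // proposed opt-in for older GPUs such as Turing
    }
    return compute_capability >= 800;     // default cutoff is an assumption
}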

@skoulik

skoulik commented May 30, 2024

@agray3

@skoulik The reason that CUDA graphs are disabled for older GPUs is related to performance - have you tested performance with and without CUDA graphs on your Turing GPU?

Actually no, because I haven't been able to. What I found in my use case is that the graphs keep being disabled because of cuda_ctx->cuda_graph->disable_due_to_too_many_updates a few lines lower. This happens when I run server inference with a Llama 3 model. Can you tell me the use case(s) where it is supposed to work? I'll be happy to test.

From my test (on Quadro RTX 8000), it is actually a bit slower with graphs.

Understood.

To remove the restriction, just comment out the line https://github.com/agray3/llama.cpp/blob/d5c05821f3c3d6cabe8ac45776fe0ecb0da13eca/ggml-cuda.cu#L2533

That's what I did.

@agray3
Contributor Author

agray3 commented May 30, 2024

Actually no, because I haven't been able to. What I found in my use case is that the graphs keep being disabled because of cuda_ctx->cuda_graph->disable_due_to_too_many_updates a few lines lower. This happens when I run server inference with a Llama 3 model. Can you tell me the use case(s) where it is supposed to work? I'll be happy to test.

I'm not sure why that is happening - can you please tell me your exact command to allow me to reproduce?

@skoulik

skoulik commented May 31, 2024

Actually no, because I haven't been able to. What I found in my use case is that the graphs keep being disabled because of cuda_ctx->cuda_graph->disable_due_to_too_many_updates a few lines lower. This happens when I run server inference with a Llama 3 model. Can you tell me the use case(s) where it is supposed to work? I'll be happy to test.

I'm not sure why that is happening - can you please tell me your exact command to allow me to reproduce?

Of course, here is the command I run:

main.exe -m models\Meta-Llama-3-8B-Instruct_64K_Q8_0.gguf --ctx-size 65536 --n-gpu-layers 33 -t 1 --flash-attn --override-kv tokenizer.ggml.pre=str:llama3 -s 1

The quants downloaded from https://huggingface.co/NurtureAI/Meta-Llama-3-8B-Instruct-64k-GGUF

I've added logging and can confirm that all calls end up in disable_due_to_too_many_updates. Not a single graph launch, unfortunately.

@agray3
Contributor Author

agray3 commented May 31, 2024

@skoulik I tried with your exact command but unfortunately I cannot reproduce - it uses CUDA graphs for me as expected (including on Turing if I remove the restriction). All I can suggest is that you add some print statements to the code to trace back exactly what is causing the cuda_graph_update_required flag to be set to true across multiple consecutive tokens in your case.

@skoulik

skoulik commented May 31, 2024

@agray3 ,
I've done some more debugging.
On startup I got a bunch of

ggml_backend_cuda_graph_compute: disabling CUDA graphs due to batch size > 1 [ffn_inp-0] [4096 2 1 1]
ggml_backend_cuda_graph_compute: disabling CUDA graphs due to batch size > 1 [ffn_inp-1] [4096 2 1 1]
...
ggml_backend_cuda_graph_compute: disabling CUDA graphs due to batch size > 1 [ffn_inp-30] [4096 2 1 1]

I suppose these are expected.
Then it gets more interesting:
The very first pair of calls to cudaGraphExecUpdate() and cudaGraphLaunch() is fine.
The next call to cudaGraphExecUpdate() fails:

stat  cudaErrorGraphExecUpdateFailure (910)   cudaError
result_info:
result  cudaGraphExecUpdateErrorTopologyChanged (2)
errorNode       0x0000000000000000 <NULL>
errorFromNode   0x0000000000000000 <NULL>

Then, after 4 retries, because of number_consecutive_updates >= 4, it sets disable_due_to_too_many_updates = true.

Now, I suppose the question is: what's wrong with the graph topology? Any ideas on how to debug this?
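One way to chase that down (an assumption about how to debug it, not something the current code does) is to log the node count of each freshly captured graph with cudaGraphGetNodes; a changing count between consecutive captures means cudaGraphExecUpdate is bound to report a topology change:

#include <cuda_runtime.h>
#include <cstdio>

// Call right after cudaStreamEndCapture() with the freshly captured graph.
static void log_graph_node_count(cudaGraph_t graph) {
    size_t num_nodes = 0;
    cudaGraphGetNodes(graph, nullptr, &num_nodes);   // pass nullptr to query the count only
    static size_t prev_nodes = 0;
    if (prev_nodes != 0 && prev_nodes != num_nodes) {
        fprintf(stderr, "captured graph topology changed: %zu -> %zu nodes\n",
                prev_nodes, num_nodes);
    }
    prev_nodes = num_nodes;
}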

Update:
I reckon cudaErrorGraphExecUpdateFailure is also kind of expected. What happens is that the graph is instantiated, then later re-captured because of a parameter change, but the original instance is not destroyed and re-instantiated. I would have destroyed it just before the call to cudaStreamEndCapture() and also set num_nodes = 0 to avoid the error altogether. I mean, like this:

@@ -2670,6 +2670,11 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 
 #ifdef USE_CUDA_GRAPH
         if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
+            if(cuda_ctx->cuda_graph->instance != nullptr) {
+                CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
+                cuda_ctx->cuda_graph->instance = nullptr;
+                cuda_ctx->cuda_graph->num_nodes = 0;
+            }
             if (cuda_ctx->cuda_graph->graph != nullptr) {
                 CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
                 cuda_ctx->cuda_graph->graph = nullptr;

I'll look further into why the params are changing.

@skoulik

skoulik commented May 31, 2024

@agray3 ,
On second thought, the error cannot be "expected", because it prevents the just-captured graph from executing. I reckon the patch that I proposed is valid and should be integrated. What do you think?

Update:
After experimenting, I've found that trying to execute graph instances created from a previous capture always produces cudaErrorGraphExecUpdateFailure, and the error is unrecoverable in the way it is handled by the current implementation. So in my case cudaGraphExecDestroy() before cudaStreamEndCapture() is a must. Might be specific to Turing, though.

@skoulik

skoulik commented May 31, 2024

As for the root causes of cuda_graph_update_required being set:
The first batch of updates mentioned above is caused by:

if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) 

node->src[1]->ne[1] == 2, which is OK, I believe.

Now, when the generation starts, the updates are caused by has_matching_properties == false, via this check: node->src[i]->data != graph_node_properties->src_address[i]

 	main.exe!ggml_graph_node_has_matching_properties(ggml_tensor * node=0x0000021415569da0, ggml_graph_node_properties * graph_node_properties=0x000002104502de00) Line 2506	C++
>	main.exe!ggml_backend_cuda_graph_compute(ggml_backend * backend=0x00000210476b5220, ggml_cgraph * cgraph=0x00000210467da3c0) Line 2566	C++
 	main.exe!ggml_backend_graph_compute_async(ggml_backend * backend=0x00000210476b5220, ggml_cgraph * cgraph=0x00000210467da3c0) Line 282	C
 	main.exe!ggml_backend_sched_compute_splits(ggml_backend_sched * sched=0x0000021415868070) Line 1667	C

The check in question:

    for (int i = 0; i < GGML_MAX_SRC; i++) {
        if (node->src[i] &&
            node->src[i]->data != graph_node_properties->src_address[i] &&
            node->op != GGML_OP_CPY &&
            node->op != GGML_OP_VIEW
        ) {
            return false;
        }
    }

with the observed values:

    node->src[i]->data                       0x0000001916f60000  void *
    graph_node_properties->src_address[i]    0x0000001909258000  void *

Not sure what this means, or why it works for you.

@agray3
Contributor Author

agray3 commented May 31, 2024

I reckon the patch that I proposed is valid and should be integrated. What do you think?

This change always destroys any instantiated graph, but the instantiated graph is required for the graph update mechanism, which involves updating a previously instantiated graph with a newly captured graph to create an updated instance.

Not sure what this means and why it works for you,

Yeah it's weird we are seeing different behaviour. Can you please share the full standard output you get from

main.exe -m models\Meta-Llama-3-8B-Instruct_64K_Q8_0.gguf --ctx-size 65536 --n-gpu-layers 33 -t 1 --flash-attn --override-kv tokenizer.ggml.pre=str:llama3 -s 1

and I'll cross-check with what I get?

@skoulik

skoulik commented Jun 1, 2024

@agray3 ,
Thank you for looking at it.

I reckon the patch that I proposed is valid and should be integrated. What do you think?

This change always destroys any instantiated graph, but the instantiated graph is required for the graph update mechanism, which involves updating a previously instantiated graph with a newly captured graph to create an updated instance.

Are you saying that the graph instance from the previous graph should work with the newly captured graph? Can you point me to where that is stated in the CUDA documentation? I struggle to find the confirmation (and my observations seem to prove the opposite, at least in some cases - see below).
Even if it is true, I'd expect that the graph topologies must match, which is not the case: the first captured graph has only 10 nodes while the next one, which fails, has 17 nodes.
I just can't see how the code below could update the instance with the correct data. For instance, in line 2705 (if (cuda_ctx->cuda_graph->num_nodes == 0)) it won't re-read the new node count. The new graph has 17 nodes, the previous instance has 10, and the count is not updated: cuda_ctx->cuda_graph->num_nodes is still 10 while cgraph->n_nodes is 17. So not all kernel parameters are extracted in the following loop.
However, if the instance is re-created and the node count zeroed out and re-read in line 2707, it works fine.
Something fishy is going on here.

Not sure what this means and why it works for you,

Yeah it's weird we are seeing different behaviour. Can you please share the full standard output you get from

main.exe -m models\Meta-Llama-3-8B-Instruct_64K_Q8_0.gguf --ctx-size 65536 --n-gpu-layers 33 -t 1 --flash-attn --override-kv tokenizer.ggml.pre=str:llama3 -s 1

and I'll cross-check with what I get?

I've added more logging and commented out the disable_due_to_too_many_updates block - see attached.
Looking closely, sometimes the update indeed works, sometimes it does not. Please run the same (attached) version of ggml-cuda.cu so that we can compare apples to apples.
ggml-cuda.cu.more-logging.txt
log.txt

@skoulik

skoulik commented Jun 2, 2024

In case it still doesn't reproduce, please compare your CMake config to mine (attached), especially the LLAMA_CUDA_xxx settings. There might be clues.
CMakeCache.txt

@agray3
Contributor Author

agray3 commented Jun 3, 2024

I think I may have found the difference. Comparing logs, there is already a difference right at the start where you have the extra lines:

ggml_opencl: selecting platform: 'NVIDIA CUDA'
ggml_opencl: selecting device: 'NVIDIA GeForce RTX 2080 Ti'

which suggests that OpenCL is in use in your build, whereas I am not using OpenCL (just CUDA). This may be because, as I can see in your CMakeCache.txt file, you have LLAMA_CLBLAST set. Can you please try a build using just LLAMA_CUDA=ON?
If it still doesn't work, please share your updated log.

Here is the documentation for cudaGraphExecUpdate:
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g34c8e1e6118aa68fc715250dfd93c1c4
You are right that this won't succeed if the topology has changed, in which case we have logic to fall back to a full graph instantiation. But if just the parameters have changed then it will succeed, and is faster than a full instantiation.
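To make that flow concrete, here is a minimal standalone sketch of the instantiate-and-update pattern (CUDA 12 runtime signatures; the captured kernel and the fallback policy are placeholders, not the actual llama.cpp logic):

#include <cuda_runtime.h>

__global__ void dummy_kernel(float * x) { x[0] += 1.0f; }  // stand-in for the real per-token work

void run_token(cudaStream_t stream, float * d_x, cudaGraph_t & graph, cudaGraphExec_t & instance) {
    // Re-capture the work for this token.
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed);
    dummy_kernel<<<1, 1, 0, stream>>>(d_x);
    cudaStreamEndCapture(stream, &graph);

    bool need_instantiate = (instance == nullptr);
    if (!need_instantiate) {
        // Cheap path: patch the existing executable graph with the newly captured parameters.
        cudaGraphExecUpdateResultInfo result_info;
        if (cudaGraphExecUpdate(instance, graph, &result_info) != cudaSuccess) {
            cudaGetLastError();                  // clear the sticky error state
            cudaGraphExecDestroy(instance);      // e.g. topology changed: rebuild from scratch
            instance = nullptr;
            need_instantiate = true;
        }
    }
    if (need_instantiate) {
        cudaGraphInstantiate(&instance, graph, 0);   // expensive full instantiation
    }
    cudaGraphLaunch(instance, stream);
    cudaGraphDestroy(graph);                     // the captured graph is no longer needed
    graph = nullptr;
}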

@skoulik

skoulik commented Jun 3, 2024

I think I may have found the difference. Comparing logs, there is already a difference right at the start where you have the extra lines:

ggml_opencl: selecting platform: 'NVIDIA CUDA'
ggml_opencl: selecting device: 'NVIDIA GeForce RTX 2080 Ti'

which suggests that OpenCL is in use in your build, whereas I am not using OpenCL (just CUDA). This may be because, as I can see in your CMakeCache.txt file, you have LLAMA_CLBLAST set. Can you please try a build using just LLAMA_CUDA=ON? If it still doesn't work, please share your updated log.

Good catch, mate! Indeed I was building with CLBlast support in order to try llama.cpp on my other GPU (Radeon RX 5700).
The last thing I expected was for it to affect the CUDA path. Without CLBlast it seems to work as expected:
log.no-clblast.txt
The graphs are being split differently with CLBlast for some reason. There might be room for improvement in the current graphs code (to handle changing topologies better)?

As for the speed, without graphs:

llama_print_timings:        load time =   10624.56 ms
llama_print_timings:      sample time =      57.01 ms /   836 runs   (    0.07 ms per token, 14663.58 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     0 tokens (-nan(ind) ms per token, -nan(ind) tokens per second)
llama_print_timings:        eval time =   36796.82 ms /   836 runs   (   44.02 ms per token,    22.72 tokens per second)
llama_print_timings:       total time =   37721.09 ms /   836 tokens

llama_print_timings:        load time =   11034.41 ms
llama_print_timings:      sample time =      25.59 ms /   374 runs   (    0.07 ms per token, 14615.08 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     0 tokens (-nan(ind) ms per token, -nan(ind) tokens per second)
llama_print_timings:        eval time =   14160.57 ms /   374 runs   (   37.86 ms per token,    26.41 tokens per second)
llama_print_timings:       total time =   14609.14 ms /   374 tokens

With graphs:

llama_print_timings:        load time =   13810.44 ms
llama_print_timings:      sample time =      55.51 ms /   836 runs   (    0.07 ms per token, 15060.62 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     0 tokens (-nan(ind) ms per token, -nan(ind) tokens per second)
llama_print_timings:        eval time =   29994.80 ms /   836 runs   (   35.88 ms per token,    27.87 tokens per second)
llama_print_timings:       total time =   30885.57 ms /   836 tokens

llama_print_timings:        load time =   10272.23 ms
llama_print_timings:      sample time =      57.02 ms /   836 runs   (    0.07 ms per token, 14661.01 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     0 tokens (-nan(ind) ms per token, -nan(ind) tokens per second)
llama_print_timings:        eval time =   34784.29 ms /   836 runs   (   41.61 ms per token,    24.03 tokens per second)
llama_print_timings:       total time =   35703.72 ms /   836 tokens

I'd say inconclusive. The measurements fluctuate a lot. I noticed that tensor core utilization sits at 100% almost all the time with graphs, while it is closer to 97% without.

Here is the documentation for cudaGraphExecUpdate: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g34c8e1e6118aa68fc715250dfd93c1c4 You are right that this won't succeed if the topology has changed, in which case we have logic to fall back to a full graph instantiation. But if just the parameters have changed then it will succeed, and is faster than a full instantiation.

Thank you. I was reading through the documentation and blog posts over the weekend. Now I understand the idea behind "instantiate, then loop(capture, update)" better.

The only thing that bothers me after all these exercises is whether the current code is generic enough. What if there are layers with different topologies present in some model? (I understand that my case is a kind of misconfiguration, but I wouldn't dismiss the possibility.) Wouldn't we be better off caching a few graphs - one for each encountered topology - rather than just one? With such a design we would be able to work without re-instantiating even if different topologies are present simultaneously (and the code is looping over them).
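To illustrate the idea, a rough standalone sketch of such a topology-keyed cache (the hashing scheme and names are assumptions, and there is no eviction; this is not proposed llama.cpp code):

#include <cuda_runtime.h>
#include <cstdint>
#include <unordered_map>

// Keep one executable graph per observed topology so that alternating topologies do not
// force a full re-instantiation on every switch.
struct cuda_graph_cache {
    std::unordered_map<uint64_t, cudaGraphExec_t> instances;    // topology key -> exec graph

    cudaGraphExec_t get(uint64_t topology_key, cudaGraph_t captured) {
        auto it = instances.find(topology_key);
        if (it == instances.end()) {
            cudaGraphExec_t exec = nullptr;
            cudaGraphInstantiate(&exec, captured, 0);           // first time this topology is seen
            it = instances.emplace(topology_key, exec).first;
        } else {
            cudaGraphExecUpdateResultInfo result_info;
            if (cudaGraphExecUpdate(it->second, captured, &result_info) != cudaSuccess) {
                cudaGetLastError();                             // e.g. a key collision:
                cudaGraphExecDestroy(it->second);               // rebuild this entry from scratch
                it->second = nullptr;
                cudaGraphInstantiate(&it->second, captured, 0);
            }
        }
        return it->second;                                      // ready for cudaGraphLaunch()
    }
};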

@skoulik

skoulik commented Jun 3, 2024

BTW, I still believe there is a bug in the cudaErrorGraphExecUpdateFailure handling. The graph is re-instantiated, but num_nodes is not updated. It won't be updated in line 2707 either, because of the check. So if the node count changes, the code won't recover.
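For reference, a standalone sketch of the kind of recovery being described (identifiers are illustrative, and the actual fix referenced below may differ): when the in-place update fails, drop the stale executable graph and reset the cached node count so it is re-read from the newly captured graph.

#include <cuda_runtime.h>

struct graph_slot {
    cudaGraphExec_t instance  = nullptr;
    int             num_nodes = 0;       // cached count of graph nodes / kernel params
};

void recover_after_update_failure(graph_slot & slot, cudaGraph_t captured) {
    cudaGetLastError();                   // clear cudaErrorGraphExecUpdateFailure
    if (slot.instance != nullptr) {
        cudaGraphExecDestroy(slot.instance);
        slot.instance = nullptr;
    }
    slot.num_nodes = 0;                   // force the node count and params to be re-read
    cudaGraphInstantiate(&slot.instance, captured, 0);
}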

@agray3
Contributor Author

agray3 commented Jun 4, 2024

BTW, I still believe there is a bug in the cudaErrorGraphExecUpdateFailure handling. The graph is re-instantiated, but num_nodes is not updated. It won't be updated in line 2707 either, because of the check. So if the node count changes, the code won't recover.

Good spot! See #7738

@agray3
Contributor Author

agray3 commented Jun 4, 2024

The only thing that bothers me after all these exercises is whether the current code is generic enough. What if there are layers with different topologies present in some model? (I understand that my case is a kind of misconfiguration, but I wouldn't dismiss the possibility.) Wouldn't we be better off caching a few graphs - one for each encountered topology - rather than just one? With such a design we would be able to work without re-instantiating even if different topologies are present simultaneously (and the code is looping over them).

I guess it will be a trade-off between code complexity and the benefit in the cases where this occurs, but I don't have a good feel for how common/likely such scenarios are.
