Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
cuda : add batched cuBLAS GEMM for faster attention #3749
cuda : add batched cuBLAS GEMM for faster attention #3749
Changes from 5 commits
8fb1be6
6a30bf3
8d8d54f
84d4ca0
c13fcfb
878aa4f
d415669
3d297c1
27c34c0
d798a17
File filter
Filter by extension
Conversations
Jump to
There are no files selected for viewing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does anyone know if I changed these
cudaMemcpy
tocudaMemcpyAsync
, do I need to add some synchronization before callingcublasGemmBatchedEx
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You only need to run them on the same stream, but I don't think that this can be made async because the host memory may already be freed by the time the copy happens. Running memcpy asynchronously also requires using host pinned memory.
If the
cublasGemmBatchedEx
needs to stay to support GQA, I would consider writing a kernel to calculate these values and calling cublas from the kernel. Additionally,cudaMalloc
is usually very slow, which is why we have the memory pool. These allocations should be changed to use the memory pool.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I reduced the mallocs from 3 to 1, but when I try to replace it with
ggml_cuda_pool_malloc()
(see commented code) the computation crashes with illegal access memory somewhere later. Couldn't figure out what is the cause - probably some memory alignment issue in the pointer arrays.