Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
cuda : try cublasGemmStridedBatchedEx
Browse files Browse the repository at this point in the history
ggerganov committed Oct 24, 2023

Verified

This commit was signed with the committer’s verified signature.
ggerganov Georgi Gerganov
1 parent d415669 commit 25a0b90
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
@@ -7134,8 +7134,21 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
}
}
#else
// use cublasGemmBatchedEx
{
if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
// use cublasGemmStridedBatchedEx
CUBLAS_CHECK(
cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), ne02*src0->nb[2], // strideA
(char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), ne12*src1->nb[2]/2, // strideB
&beta_f16, (char *) dst_f16, CUDA_R_16F, ne01, ne12* dst->nb[2]/2, // strideC
ne13,
CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
printf("cublasGemmStridedBatchedEx\n");
} else {
// use cublasGemmBatchedEx
const int ne23 = ne12*ne13;

// TODO: avoid this alloc

0 comments on commit 25a0b90

Please sign in to comment.