From 66cac67695cbded90acb28d3785b6017bbc5ae3c Mon Sep 17 00:00:00 2001
From: Sarah Chastain
Date: Mon, 13 Nov 2023 13:17:27 -0700
Subject: [PATCH] Update linalg algorithms for better optimization

---
 src/linalg.cu | 208 ++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 192 insertions(+), 16 deletions(-)

diff --git a/src/linalg.cu b/src/linalg.cu
index 1ef5fce68..a27741e37 100644
--- a/src/linalg.cu
+++ b/src/linalg.cu
@@ -209,7 +209,7 @@ BFstatus bfMatMul_aa_exec(BFlinalg handle,
 	//bool use_bf_cherk = use_bf_cherk_str && atoi(use_bf_cherk_str);
 	enum { BF_CUBLAS_CHERK_THRESHOLD = 896 };
 	if( //use_bf_cherk &&
-	    n < BF_CUBLAS_CHERK_THRESHOLD &&
+	    (CUDART_VERSION < 8000 || n < BF_CUBLAS_CHERK_THRESHOLD) &&
 	    trans == CUBLAS_OP_N &&
 	    n % 2 == 0 &&
 	    a_stride % 2 == 0 &&
 	    a_batchstride % 2 == 0 &&
@@ -476,6 +476,166 @@ BFstatus bfMatMul_ab_exec_nobatch(BFlinalg handle,
 	return BF_STATUS_SUCCESS;
 }
 
+BFstatus bfMatMul_ab_exec_batch(BFlinalg handle,
+                                cudaStream_t stream,
+                                cublasOperation_t trans_a,
+                                cublasOperation_t trans_b,
+                                long m,
+                                long n,
+                                long k,
+                                double alpha,
+                                void const* a_data,
+                                BFdtype a_type,
+                                long a_stride,
+                                long a_batchstride,
+                                void const* b_data,
+                                BFdtype b_type,
+                                long b_stride,
+                                long b_batchstride,
+                                double beta,
+                                void* c_data,
+                                BFdtype c_type,
+                                long c_stride,
+                                long c_batchstride,
+                                long nbatch) {
+	BF_TRACE_STREAM(stream);
+	BF_CHECK_CUBLAS(cublasSetStream(handle->cublas(), stream));
+	BF_CHECK_CUBLAS(cublasSetPointerMode(handle->cublas(),
+	                                     CUBLAS_POINTER_MODE_HOST));
+	BF_ASSERT(a_data, BF_STATUS_INVALID_POINTER);
+	BF_ASSERT(b_data, BF_STATUS_INVALID_POINTER);
+	BF_ASSERT(c_data, BF_STATUS_INVALID_POINTER);
+	BF_ASSERT(a_type == b_type, BF_STATUS_UNSUPPORTED_DTYPE);
+	// TODO: Look into optimizations using cublasGemmEx algo selection and
+	//       batched/strided APIs.
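+	// Dispatch on the input dtype to the matching strided-batched cuBLAS
+	// routine; each case also asserts that the output dtype is consistent
+	// with the inputs before issuing the call.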
+	switch( a_type ) {
+	case BF_DTYPE_F32: {
+		BF_ASSERT(c_type == BF_DTYPE_F32, BF_STATUS_UNSUPPORTED_DTYPE);
+		// Narrow alpha/beta explicitly; reinterpreting the doubles via
+		// (float*) casts would pass garbage scale factors.
+		float alpha_f = (float)alpha;
+		float beta_f  = (float)beta;
+		BF_CHECK_CUBLAS(cublasSgemmStridedBatched(handle->cublas(),
+		                                          trans_a,
+		                                          trans_b,
+		                                          m,
+		                                          n,
+		                                          k,
+		                                          &alpha_f,
+		                                          (const float *)a_data,
+		                                          a_stride,
+		                                          a_batchstride,
+		                                          (const float *)b_data,
+		                                          b_stride,
+		                                          b_batchstride,
+		                                          &beta_f,
+		                                          (float *)c_data,
+		                                          c_stride,
+		                                          c_batchstride,
+		                                          nbatch));
+		break;
+	}
+	case BF_DTYPE_F64: {
+		BF_ASSERT(c_type == BF_DTYPE_F64, BF_STATUS_UNSUPPORTED_DTYPE);
+		BF_CHECK_CUBLAS(cublasDgemmStridedBatched(handle->cublas(),
+		                                          trans_a,
+		                                          trans_b,
+		                                          m,
+		                                          n,
+		                                          k,
+		                                          &alpha,
+		                                          (const double *)a_data,
+		                                          a_stride,
+		                                          a_batchstride,
+		                                          (const double *)b_data,
+		                                          b_stride,
+		                                          b_batchstride,
+		                                          &beta,
+		                                          (double *)c_data,
+		                                          c_stride,
+		                                          c_batchstride,
+		                                          nbatch));
+		break;
+	}
+	case BF_DTYPE_CI8: {
+		BF_ASSERT(c_type == BF_DTYPE_CF32, BF_STATUS_UNSUPPORTED_DTYPE);
+		const cuComplex alpha_cf = make_cuComplex(alpha, 0);
+		const cuComplex beta_cf  = make_cuComplex(beta, 0);
+		BF_CHECK_CUBLAS(cublasGemmStridedBatchedEx(handle->cublas(),
+		                                           trans_a,
+		                                           trans_b,
+		                                           m,
+		                                           n,
+		                                           k,
+		                                           &alpha_cf,
+		                                           a_data,
+		                                           CUDA_C_8I,
+		                                           a_stride,
+		                                           a_batchstride,
+		                                           b_data,
+		                                           CUDA_C_8I,
+		                                           b_stride,
+		                                           b_batchstride,
+		                                           &beta_cf,
+		                                           c_data,
+		                                           CUDA_C_32F,
+		                                           c_stride,
+		                                           c_batchstride,
+		                                           nbatch,
+		                                           CUBLAS_COMPUTE_32F,
+		                                           CUBLAS_GEMM_DEFAULT));
+		// A break was missing here; falling through into the CF32 case
+		// would issue a second, wrong GEMM on the same buffers.
+		break;
+	}
+	case BF_DTYPE_CF32: {
+		BF_ASSERT(c_type == BF_DTYPE_CF32, BF_STATUS_UNSUPPORTED_DTYPE);
+		const cuComplex alpha_cf = make_cuComplex(alpha, 0);
+		const cuComplex beta_cf  = make_cuComplex(beta, 0);
+		BF_CHECK_CUBLAS(cublasCgemm3mStridedBatched(handle->cublas(),
+		                                            trans_a,
+		                                            trans_b,
+		                                            m,
+		                                            n,
+		                                            k,
+		                                            &alpha_cf,
+		                                            (const cuComplex *)a_data,
+		                                            a_stride,
+		                                            a_batchstride,
+		                                            (const cuComplex *)b_data,
+		                                            b_stride,
+		                                            b_batchstride,
+		                                            &beta_cf,
+		                                            (cuComplex *)c_data,
+		                                            c_stride,
+		                                            c_batchstride,
+		                                            nbatch));
+		break;
+	}
+	case BF_DTYPE_CF64: {
+		BF_ASSERT(c_type == BF_DTYPE_CF64, BF_STATUS_UNSUPPORTED_DTYPE);
+		const cuDoubleComplex alpha_cf = make_cuDoubleComplex(alpha, 0);
+		const cuDoubleComplex beta_cf  = make_cuDoubleComplex(beta, 0);
+		BF_CHECK_CUBLAS(cublasZgemmStridedBatched(handle->cublas(),
+		                                          trans_a,
+		                                          trans_b,
+		                                          m,
+		                                          n,
+		                                          k,
+		                                          &alpha_cf,
+		                                          (const cuDoubleComplex *)a_data,
+		                                          a_stride,
+		                                          a_batchstride,
+		                                          (const cuDoubleComplex *)b_data,
+		                                          b_stride,
+		                                          b_batchstride,
+		                                          &beta_cf,
+		                                          (cuDoubleComplex *)c_data,
+		                                          c_stride,
+		                                          c_batchstride,
+		                                          nbatch));
+		break;
+	}
+	default:
+		BF_FAIL("Supported dtype for input array", BF_STATUS_UNSUPPORTED_DTYPE);
+	}
+	return BF_STATUS_SUCCESS;
+}
+
 BFstatus bfMatMul_ab_exec(BFlinalg handle,
                           cudaStream_t stream,
                           cublasOperation_t trans_a,
@@ -498,16 +658,18 @@ BFstatus bfMatMul_ab_exec(BFlinalg handle,
                           BFdtype c_type,
                           long c_stride,
                           long c_batchstride) {
-	// TODO: Use batched algos here where possible
+	// We prefer the CI8@CF32 -> CF32 code path;
+	// second choice would be the batched algorithms.
 	//char* use_bf_cgemm_str = getenv("BF_CGEMM");
 	//bool use_bf_cgemm = use_bf_cgemm_str && atoi(use_bf_cgemm_str);
+	// std::cout << "nbatch: " << nbatch << std::endl;
+	if( nbatch > 12 && a_type != BF_DTYPE_CI8 ) {
+		//std::cout << "nbatch: " << nbatch << std::endl;
+		BF_CHECK( bfMatMul_ab_exec_batch(handle, stream,
+		                                 trans_a, trans_b,
+		                                 m, n, k,
+		                                 alpha,
+		                                 a_data, a_type, a_stride, a_batchstride,
+		                                 b_data, b_type, b_stride, b_batchstride,
+		                                 beta,
+		                                 c_data, c_type, c_stride, c_batchstride,
+		                                 nbatch) );
+	} else {
+		for( long b=0; b<nbatch; ++b ) {
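[The patch is truncated above, inside the fallback branch; presumably the loop
invokes the per-batch (nobatch) path once per batch element.]

As context for the dispatch this patch introduces, below is a minimal,
self-contained sketch (not part of the patch) contrasting the per-batch GEMM
loop with the single cublasSgemmStridedBatched call that
bfMatMul_ab_exec_batch uses for BF_DTYPE_F32. The matrix sizes, column-major
leading dimensions, and zero-initialized buffers are illustrative assumptions;
only the cuBLAS calls themselves mirror the patch. Build with
"nvcc -o batched_gemm batched_gemm.cu -lcublas".

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    const int m = 64, n = 64, k = 64, nbatch = 32;
    // Element strides between consecutive matrices in each packed buffer.
    const long long strideA = (long long)m * k;
    const long long strideB = (long long)k * n;
    const long long strideC = (long long)m * n;

    float *d_A, *d_B, *d_C;
    cudaMalloc(&d_A, sizeof(float) * strideA * nbatch);
    cudaMalloc(&d_B, sizeof(float) * strideB * nbatch);
    cudaMalloc(&d_C, sizeof(float) * strideC * nbatch);
    cudaMemset(d_A, 0, sizeof(float) * strideA * nbatch);
    cudaMemset(d_B, 0, sizeof(float) * strideB * nbatch);

    cublasHandle_t handle;
    cublasCreate(&handle);
    const float alpha = 1.0f, beta = 0.0f;

    // Fallback path: one cuBLAS call per batch element. Kernel-launch
    // overhead dominates when each per-batch GEMM is small.
    for (int b = 0; b < nbatch; ++b) {
        cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                    m, n, k, &alpha,
                    d_A + b * strideA, m,
                    d_B + b * strideB, k, &beta,
                    d_C + b * strideC, m);
    }

    // Batched path: one call covers all nbatch problems; cuBLAS advances
    // by the fixed element strides between consecutive matrices itself.
    cublasStatus_t s = cublasSgemmStridedBatched(
        handle, CUBLAS_OP_N, CUBLAS_OP_N,
        m, n, k, &alpha,
        d_A, m, strideA,
        d_B, k, strideB, &beta,
        d_C, m, strideC,
        nbatch);
    printf("strided-batched sgemm status: %d\n", (int)s);

    cublasDestroy(handle);
    cudaFree(d_A); cudaFree(d_B); cudaFree(d_C);
    return 0;
}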