diff --git a/src/linalg.cu b/src/linalg.cu
index 1ef5fce68..4bbe40fb7 100644
--- a/src/linalg.cu
+++ b/src/linalg.cu
@@ -476,6 +476,166 @@ BFstatus bfMatMul_ab_exec_nobatch(BFlinalg handle,
     return BF_STATUS_SUCCESS;
 }
 
+BFstatus bfMatMul_ab_exec_batch(BFlinalg handle,
+                                cudaStream_t stream,
+                                cublasOperation_t trans_a,
+                                cublasOperation_t trans_b,
+                                long m,
+                                long n,
+                                long k,
+                                double alpha,
+                                void const* a_data,
+                                BFdtype a_type,
+                                long a_stride,
+                                long a_batchstride,
+                                void const* b_data,
+                                BFdtype b_type,
+                                long b_stride,
+                                long b_batchstride,
+                                double beta,
+                                void* c_data,
+                                BFdtype c_type,
+                                long c_stride,
+                                long c_batchstride,
+                                long nbatch) {
+    BF_TRACE_STREAM(stream);
+    BF_CHECK_CUBLAS(cublasSetStream(handle->cublas(), stream));
+    BF_CHECK_CUBLAS(cublasSetPointerMode(handle->cublas(),
+                                         CUBLAS_POINTER_MODE_HOST));
+    BF_ASSERT(a_data, BF_STATUS_INVALID_POINTER);
+    BF_ASSERT(b_data, BF_STATUS_INVALID_POINTER);
+    BF_ASSERT(c_data, BF_STATUS_INVALID_POINTER);
+    BF_ASSERT(a_type == b_type, BF_STATUS_UNSUPPORTED_DTYPE);
+    // TODO: Look into optimizations using cublasGemmEx algo selection.
+    switch( a_type ) {
+    case BF_DTYPE_F32: {
+        BF_ASSERT(c_type == BF_DTYPE_F32, BF_STATUS_UNSUPPORTED_DTYPE);
+        // cuBLAS reads host-mode scalars as float here, so convert the
+        // doubles by value instead of reinterpreting &alpha/&beta.
+        const float alpha_f = (float)alpha;
+        const float beta_f  = (float)beta;
+        BF_CHECK_CUBLAS(cublasSgemmStridedBatched(handle->cublas(),
+                                                  trans_a,
+                                                  trans_b,
+                                                  m,
+                                                  n,
+                                                  k,
+                                                  &alpha_f,
+                                                  (const float *)a_data,
+                                                  a_stride,
+                                                  a_batchstride,
+                                                  (const float *)b_data,
+                                                  b_stride,
+                                                  b_batchstride,
+                                                  &beta_f,
+                                                  (float *)c_data,
+                                                  c_stride,
+                                                  c_batchstride,
+                                                  nbatch));
+        break;
+    }
+    case BF_DTYPE_F64: {
+        BF_ASSERT(c_type == BF_DTYPE_F64, BF_STATUS_UNSUPPORTED_DTYPE);
+        BF_CHECK_CUBLAS(cublasDgemmStridedBatched(handle->cublas(),
+                                                  trans_a,
+                                                  trans_b,
+                                                  m,
+                                                  n,
+                                                  k,
+                                                  &alpha,
+                                                  (const double *)a_data,
+                                                  a_stride,
+                                                  a_batchstride,
+                                                  (const double *)b_data,
+                                                  b_stride,
+                                                  b_batchstride,
+                                                  &beta,
+                                                  (double *)c_data,
+                                                  c_stride,
+                                                  c_batchstride,
+                                                  nbatch));
+        break;
+    }
+    case BF_DTYPE_CI8: {
+        BF_ASSERT(c_type == BF_DTYPE_CF32, BF_STATUS_UNSUPPORTED_DTYPE);
+        const cuComplex alpha_cf = make_cuComplex((float)alpha, 0.f);
+        const cuComplex beta_cf  = make_cuComplex((float)beta, 0.f);
+        BF_CHECK_CUBLAS(cublasGemmStridedBatchedEx(handle->cublas(),
+                                                   trans_a,
+                                                   trans_b,
+                                                   m,
+                                                   n,
+                                                   k,
+                                                   &alpha_cf,
+                                                   a_data,
+                                                   CUDA_C_8I,
+                                                   a_stride,
+                                                   a_batchstride,
+                                                   b_data,
+                                                   CUDA_C_8I,
+                                                   b_stride,
+                                                   b_batchstride,
+                                                   &beta_cf,
+                                                   c_data,
+                                                   CUDA_C_32F,
+                                                   c_stride,
+                                                   c_batchstride,
+                                                   nbatch,
+                                                   CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));
+        // Without this break the CI8 case fell through into the CF32
+        // case and ran a second GEMM on int8 data.
+        break;
+    }
+    case BF_DTYPE_CF32: {
+        BF_ASSERT(c_type == BF_DTYPE_CF32, BF_STATUS_UNSUPPORTED_DTYPE);
+        const cuComplex alpha_cf = make_cuComplex((float)alpha, 0.f);
+        const cuComplex beta_cf  = make_cuComplex((float)beta, 0.f);
+        BF_CHECK_CUBLAS(cublasCgemm3mStridedBatched(handle->cublas(),
+                                                    trans_a,
+                                                    trans_b,
+                                                    m,
+                                                    n,
+                                                    k,
+                                                    &alpha_cf,
+                                                    (const cuComplex *)a_data,
+                                                    a_stride,
+                                                    a_batchstride,
+                                                    (const cuComplex *)b_data,
+                                                    b_stride,
+                                                    b_batchstride,
+                                                    &beta_cf,
+                                                    (cuComplex *)c_data,
+                                                    c_stride,
+                                                    c_batchstride,
+                                                    nbatch));
+        break;
+    }
+    case BF_DTYPE_CF64: {
+        BF_ASSERT(c_type == BF_DTYPE_CF64, BF_STATUS_UNSUPPORTED_DTYPE);
+        const cuDoubleComplex alpha_cf = make_cuDoubleComplex(alpha, 0.0);
+        const cuDoubleComplex beta_cf  = make_cuDoubleComplex(beta, 0.0);
+        BF_CHECK_CUBLAS(cublasZgemmStridedBatched(handle->cublas(),
+                                                  trans_a,
+                                                  trans_b,
+                                                  m,
+                                                  n,
+                                                  k,
+                                                  &alpha_cf,
+                                                  (const cuDoubleComplex *)a_data,
+                                                  a_stride,
+                                                  a_batchstride,
+                                                  (const cuDoubleComplex *)b_data,
+                                                  b_stride,
+                                                  b_batchstride,
+                                                  &beta_cf,
+                                                  (cuDoubleComplex *)c_data,
+                                                  c_stride,
+                                                  c_batchstride,
+                                                  nbatch));
+        break;
+    }
+    default:
+        BF_FAIL("Supported dtype for input array", BF_STATUS_UNSUPPORTED_DTYPE);
+    }
+    return BF_STATUS_SUCCESS;
+}
+
 BFstatus bfMatMul_ab_exec(BFlinalg handle,
                           cudaStream_t stream,
                           cublasOperation_t trans_a,
@@ -498,13 +658,12 @@ BFstatus bfMatMul_ab_exec(BFlinalg handle,
                           BFdtype c_type,
                           long c_stride,
                           long c_batchstride) {
-    // TODO: Use batched algos here where possible
+    // We prefer the CI4 @ CF32 -> CF32 code path; the batched
+    // algorithms are the second choice.
-    //char* use_bf_cgemm_str = getenv("BF_CGEMM");
-    //bool use_bf_cgemm = use_bf_cgemm_str && atoi(use_bf_cgemm_str);
-    if( //use_bf_cgemm &&
-        n <= 12 &&
-        trans_a == CUBLAS_OP_T && trans_b == CUBLAS_OP_N &&
+
+    if( trans_a == CUBLAS_OP_T &&
+        trans_b == CUBLAS_OP_N &&
         (a_type == BF_DTYPE_CI4 || a_type == BF_DTYPE_CI8) &&
         (b_type == BF_DTYPE_CI16 || b_type == BF_DTYPE_CF16 ||
          b_type == BF_DTYPE_CF32) &&
         c_type == BF_DTYPE_CF32 ) {
@@ -518,19 +677,45 @@ BFstatus bfMatMul_ab_exec(BFlinalg handle,
                          stream));
     }
-    for( long b=0; b<nbatch; ++b ) {
+    if( n > 12 && a_type != BF_DTYPE_CI8 ) {
+        BF_CHECK( bfMatMul_ab_exec_batch(handle,
+                                         stream,
+                                         trans_a,
+                                         trans_b,
+                                         m,
+                                         n,
+                                         k,
+                                         alpha,
+                                         a_data,
+                                         a_type,
+                                         a_stride,
+                                         a_batchstride,
+                                         b_data,
+                                         b_type,
+                                         b_stride,
+                                         b_batchstride,
+                                         beta,
+                                         c_data,
+                                         c_type,
+                                         c_stride,
+                                         c_batchstride,
+                                         nbatch) );
+    } else {
+        for( long b=0; b<nbatch; ++b ) {
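Note on the call pattern used throughout bfMatMul_ab_exec_batch: every cuBLAS *StridedBatched entry point takes, per operand, a leading dimension (the *_stride arguments) plus a batch stride (*_batchstride), so batch b of an operand starts batchstride elements past batch b-1 and a single launch covers all nbatch products. The standalone sketch below mirrors that pattern with plain float matrices; it is illustrative only and not part of the patch (the sizes, the column-major no-transpose layout, and the main() harness are assumptions, not Bifrost code).

// demo.cu -- minimal cublasSgemmStridedBatched usage sketch.
// Build (assumed): nvcc demo.cu -lcublas -o demo
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
    const int m = 2, n = 2, k = 2, nbatch = 3;
    // Column-major storage; batch b starts b*(rows*cols) elements into
    // each buffer, matching the batchstride convention in the patch.
    std::vector<float> ha((size_t)m * k * nbatch, 1.0f);
    std::vector<float> hb((size_t)k * n * nbatch, 2.0f);
    std::vector<float> hc((size_t)m * n * nbatch, 0.0f);

    float *da, *db, *dc;
    cudaMalloc(&da, ha.size() * sizeof(float));
    cudaMalloc(&db, hb.size() * sizeof(float));
    cudaMalloc(&dc, hc.size() * sizeof(float));
    cudaMemcpy(da, ha.data(), ha.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(db, hb.data(), hb.size() * sizeof(float), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);
    // Host-side scalars: the same CUBLAS_POINTER_MODE_HOST convention the
    // patch sets explicitly; cuBLAS reads them before the call returns.
    const float alpha = 1.0f, beta = 0.0f;
    cublasSgemmStridedBatched(handle,
                              CUBLAS_OP_N, CUBLAS_OP_N,
                              m, n, k,
                              &alpha,
                              da, m, (long long)m * k,  // lda, strideA
                              db, k, (long long)k * n,  // ldb, strideB
                              &beta,
                              dc, m, (long long)m * n,  // ldc, strideC
                              nbatch);

    cudaMemcpy(hc.data(), dc, hc.size() * sizeof(float), cudaMemcpyDeviceToHost);
    // Every C entry is a length-k dot product of 1s and 2s: 2*2 = 4.
    printf("c[0]=%g c[last]=%g (expect 4)\n", hc[0], hc.back());

    cublasDestroy(handle);
    cudaFree(da); cudaFree(db); cudaFree(dc);
    return 0;
}

This also shows why the fixed F32 path in the patch converts alpha and beta to float values rather than casting &alpha to float*: in host pointer mode cuBLAS dereferences the scalar pointers as float, so they must point at genuine float storage.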