Update linalg algorithms for better optimization
dentalfloss1 committed Nov 13, 2023
1 parent 1233200 commit 66cac67
Showing 1 changed file with 192 additions and 16 deletions.
208 changes: 192 additions & 16 deletions src/linalg.cu
@@ -209,7 +209,7 @@ BFstatus bfMatMul_aa_exec(BFlinalg handle,
//bool use_bf_cherk = use_bf_cherk_str && atoi(use_bf_cherk_str);
enum { BF_CUBLAS_CHERK_THRESHOLD = 896 };
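// The branch below presumably selects Bifrost's custom cherk kernel over
// cuBLAS: always on CUDA < 8.0, and for n below the threshold on newer CUDA.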
if( //use_bf_cherk &&
n < BF_CUBLAS_CHERK_THRESHOLD &&
(CUDART_VERSION < 8000 || n < BF_CUBLAS_CHERK_THRESHOLD) &&
trans == CUBLAS_OP_N &&
n % 2 == 0 &&
a_stride % 2 == 0 && a_batchstride % 2 == 0 &&
@@ -476,6 +476,166 @@ BFstatus bfMatMul_ab_exec_nobatch(BFlinalg handle,
return BF_STATUS_SUCCESS;
}

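// Runs nbatch independent GEMMs with a single strided-batched cuBLAS call.
// Strides and batch strides are expressed in elements of the corresponding
// dtype, matching the per-batch pointer arithmetic of the fallback loop.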
BFstatus bfMatMul_ab_exec_batch(BFlinalg handle,
cudaStream_t stream,
cublasOperation_t trans_a,
cublasOperation_t trans_b,
long m,
long n,
long k,
double alpha,
void const* a_data,
BFdtype a_type,
long a_stride,
long a_batchstride,
void const* b_data,
BFdtype b_type,
long b_stride,
long b_batchstride,
double beta,
void* c_data,
BFdtype c_type,
long c_stride,
long c_batchstride,
long nbatch) {
BF_TRACE_STREAM(stream);
BF_CHECK_CUBLAS(cublasSetStream(handle->cublas(), stream));
BF_CHECK_CUBLAS(cublasSetPointerMode(handle->cublas(),
CUBLAS_POINTER_MODE_HOST));
BF_ASSERT(a_data, BF_STATUS_INVALID_POINTER);
BF_ASSERT(b_data, BF_STATUS_INVALID_POINTER);
BF_ASSERT(c_data, BF_STATUS_INVALID_POINTER);
BF_ASSERT(a_type == b_type, BF_STATUS_UNSUPPORTED_DTYPE);
// TODO: Look into optimizations using cublasGemmEx algo selection and
// batched/strided APIs.
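// Each case below maps the Bifrost dtype onto the matching cuBLAS
// strided-batched GEMM, which computes, for every batch index i,
//   C_i = alpha * op(A_i) * op(B_i) + beta * C_i
// with A_i = a_data + i*a_batchstride (and likewise for B and C), while
// a_stride/b_stride/c_stride act as the leading dimensions.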
switch( a_type ) {
case BF_DTYPE_F32: {
BF_ASSERT(c_type == BF_DTYPE_F32, BF_STATUS_UNSUPPORTED_DTYPE);
// cuBLAS expects float scalars here; narrow the double arguments explicitly.
const float alpha_f = (float)alpha;
const float beta_f = (float)beta;
BF_CHECK_CUBLAS(cublasSgemmStridedBatched(handle->cublas(),
trans_a,
trans_b,
m,
n,
k,
&alpha_f,
(const float *)a_data,
a_stride,
a_batchstride,
(const float *)b_data,
b_stride,
b_batchstride,
&beta_f,
(float *)c_data,
c_stride,
c_batchstride,
nbatch));
break;
}
case BF_DTYPE_F64: {
BF_ASSERT(c_type == BF_DTYPE_F64, BF_STATUS_UNSUPPORTED_DTYPE);
BF_CHECK_CUBLAS(cublasDgemmStridedBatched(handle->cublas(),
trans_a,
trans_b,
m,
n,
k,
&alpha,
(const double *)a_data,
a_stride,
a_batchstride,
(const double *)b_data,
b_stride,
b_batchstride,
&beta,
(double *)c_data,
c_stride,
c_batchstride,
nbatch));
break;
}
case BF_DTYPE_CI8: {
BF_ASSERT(c_type == BF_DTYPE_CF32, BF_STATUS_UNSUPPORTED_DTYPE);
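// 8-bit complex inputs have no dedicated strided-batched routine, so the
// generic GemmEx path is used: CUDA_C_8I inputs, CUDA_C_32F output and
// 32-bit float compute.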
cuComplex alpha_cf = make_cuComplex(alpha, 0);
cuComplex beta_cf = make_cuComplex(beta, 0);
BF_CHECK_CUBLAS(cublasGemmStridedBatchedEx(handle->cublas(),
trans_a,
trans_b,
m,
n,
k,
&alpha_cf,
a_data,
CUDA_C_8I,
a_stride,
a_batchstride,
b_data,
CUDA_C_8I,
b_stride,
b_batchstride,
&beta_cf,
c_data,
CUDA_C_32F,
c_stride,
c_batchstride,
nbatch,
CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));
break;
}
case BF_DTYPE_CF32: {
BF_ASSERT(c_type == BF_DTYPE_CF32, BF_STATUS_UNSUPPORTED_DTYPE);
const cuComplex alpha_cf = make_cuComplex(alpha, 0);
const cuComplex beta_cf = make_cuComplex(beta, 0);
BF_CHECK_CUBLAS(cublasCgemm3mStridedBatched(handle->cublas(),
trans_a,
trans_b,
m,
n,
k,
&alpha_cf,
(const cuComplex *)a_data,
a_stride,
a_batchstride,
(const cuComplex *)b_data,
b_stride,
b_batchstride,
&beta_cf,
(cuComplex *)c_data,
c_stride,
c_batchstride,
nbatch
));
break;
}
case BF_DTYPE_CF64: {
BF_ASSERT(c_type == BF_DTYPE_CF64, BF_STATUS_UNSUPPORTED_DTYPE);
const cuDoubleComplex alpha_cf = make_cuDoubleComplex(alpha, 0);
const cuDoubleComplex beta_cf = make_cuDoubleComplex(beta, 0);
BF_CHECK_CUBLAS(cublasZgemmStridedBatched(handle->cublas(),
trans_a,
trans_b,
m,
n,
k,
&alpha_cf,
(const cuDoubleComplex *)a_data,
a_stride,
a_batchstride,
(const cuDoubleComplex *)b_data,
b_stride,
b_batchstride,
&beta_cf,
(cuDoubleComplex *)c_data,
c_stride,
c_batchstride,
nbatch
));
break;
}
default:
BF_FAIL("Supported dtype for input array", BF_STATUS_UNSUPPORTED_DTYPE);
}
return BF_STATUS_SUCCESS;
}

BFstatus bfMatMul_ab_exec(BFlinalg handle,
cudaStream_t stream,
cublasOperation_t trans_a,
@@ -498,16 +658,18 @@ BFstatus bfMatMul_ab_exec(BFlinalg handle,
BFdtype c_type,
long c_stride,
long c_batchstride) {
// TODO: Use batched algos here where possible
// Prefer the custom CI4/CI8 @ CF32 -> CF32 kernel when it applies;
// the batched cuBLAS algorithms are the second choice.

//char* use_bf_cgemm_str = getenv("BF_CGEMM");
//bool use_bf_cgemm = use_bf_cgemm_str && atoi(use_bf_cgemm_str);
// std::cout << "nbatch: "<<nbatch<<std::endl;
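// For small n with a transposed, complex-int A operand, dispatch to the
// hand-written bf_cgemm_TN_smallM kernel instead of cuBLAS.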
if( //use_bf_cgemm &&
n <= 12 &&
trans_a == CUBLAS_OP_T && trans_b == CUBLAS_OP_N &&
(a_type == BF_DTYPE_CI4 || a_type == BF_DTYPE_CI8) &&
(b_type == BF_DTYPE_CI16 || b_type == BF_DTYPE_CF16 || b_type == BF_DTYPE_CF32) &&
c_type == BF_DTYPE_CF32 ) {
//std::cout << "custom kernel" << std::endl;
BF_TRY_RETURN(bf_cgemm_TN_smallM(
m, n, k, nbatch,
alpha,
Expand All @@ -518,19 +680,33 @@ BFstatus bfMatMul_ab_exec(BFlinalg handle,
stream));
}

for( long b=0; b<nbatch; ++b ) {
cuda::child_stream child_stream(stream);
BF_CHECK( bfMatMul_ab_exec_nobatch(handle, child_stream,
trans_a, trans_b,
m, n, k,
alpha,
a_data, a_type, a_stride,
b_data, b_type, b_stride,
beta,
c_data, c_type, c_stride) );
a_data = (char*)a_data + a_batchstride * BF_DTYPE_NBYTE(a_type);
b_data = (char*)b_data + b_batchstride * BF_DTYPE_NBYTE(b_type);
c_data = (char*)c_data + c_batchstride * BF_DTYPE_NBYTE(c_type);
// TODO: Why does ci8 yield nans when we go into batched execution?
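// Use a single strided-batched call for larger batch counts; otherwise fall
// back to one GEMM per batch, each launched on its own child stream.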
if( nbatch > 12 && a_type != BF_DTYPE_CI8 ) {
//std::cout << "nbatch: " << nbatch << std::endl;
BF_CHECK( bfMatMul_ab_exec_batch(handle, stream,
trans_a, trans_b,
m, n, k,
alpha,
a_data, a_type, a_stride, a_batchstride,
b_data, b_type, b_stride, b_batchstride,
beta,
c_data, c_type, c_stride, c_batchstride,
nbatch) );
} else {
for( long b=0; b<nbatch; ++b ) {
cuda::child_stream child_stream(stream);
BF_CHECK( bfMatMul_ab_exec_nobatch(handle, child_stream,
trans_a, trans_b,
m, n, k,
alpha,
a_data, a_type, a_stride,
b_data, b_type, b_stride,
beta,
c_data, c_type, c_stride) );
a_data = (char*)a_data + a_batchstride * BF_DTYPE_NBYTE(a_type);
b_data = (char*)b_data + b_batchstride * BF_DTYPE_NBYTE(b_type);
c_data = (char*)c_data + c_batchstride * BF_DTYPE_NBYTE(c_type);
}
}
return BF_STATUS_SUCCESS;
}