Merge pull request #224 from ledatelescope/update-linalg-algos
Update linalg algos
jaycedowell authored Dec 7, 2023
2 parents f390184 + af885c0 commit 3c68028
Showing 2 changed files with 207 additions and 20 deletions.
223 changes: 204 additions & 19 deletions src/linalg.cu
@@ -476,6 +476,166 @@ BFstatus bfMatMul_ab_exec_nobatch(BFlinalg handle,
return BF_STATUS_SUCCESS;
}

BFstatus bfMatMul_ab_exec_batch(BFlinalg handle,
cudaStream_t stream,
cublasOperation_t trans_a,
cublasOperation_t trans_b,
long m,
long n,
long k,
double alpha,
void const* a_data,
BFdtype a_type,
long a_stride,
long a_batchstride,
void const* b_data,
BFdtype b_type,
long b_stride,
long b_batchstride,
double beta,
void* c_data,
BFdtype c_type,
long c_stride,
long c_batchstride,
long nbatch) {
BF_TRACE_STREAM(stream);
BF_CHECK_CUBLAS(cublasSetStream(handle->cublas(), stream));
BF_CHECK_CUBLAS(cublasSetPointerMode(handle->cublas(),
CUBLAS_POINTER_MODE_HOST));
BF_ASSERT(a_data, BF_STATUS_INVALID_POINTER);
BF_ASSERT(b_data, BF_STATUS_INVALID_POINTER);
BF_ASSERT(c_data, BF_STATUS_INVALID_POINTER);
BF_ASSERT(a_type == b_type, BF_STATUS_UNSUPPORTED_DTYPE);
// TODO: Look into further optimizations using cublasGemmEx algo selection.
switch( a_type ) {
case BF_DTYPE_F32: {
BF_ASSERT(c_type == BF_DTYPE_F32, BF_STATUS_UNSUPPORTED_DTYPE);
// Take true float copies of the scalars; reinterpreting the
// double-precision alpha/beta through a pointer cast would misread them.
const float alpha_f = (float)alpha;
const float beta_f = (float)beta;
BF_CHECK_CUBLAS(cublasSgemmStridedBatched(handle->cublas(),
trans_a,
trans_b,
m,
n,
k,
&alpha_f,
(const float *)a_data,
a_stride,
a_batchstride,
(const float *)b_data,
b_stride,
b_batchstride,
&beta_f,
(float *)c_data,
c_stride,
c_batchstride,
nbatch));
break;
}
case BF_DTYPE_F64: {
BF_ASSERT(c_type == BF_DTYPE_F64, BF_STATUS_UNSUPPORTED_DTYPE);
BF_CHECK_CUBLAS(cublasDgemmStridedBatched(handle->cublas(),
trans_a,
trans_b,
m,
n,
k,
&alpha,
(const double *)a_data,
a_stride,
a_batchstride,
(const double *)b_data,
b_stride,
b_batchstride,
&beta,
(double *)c_data,
c_stride,
c_batchstride,
nbatch));
break;
}
case BF_DTYPE_CI8: {
BF_ASSERT(c_type == BF_DTYPE_CF32, BF_STATUS_UNSUPPORTED_DTYPE);
const cuComplex alpha_cf = make_cuComplex(alpha, 0);
const cuComplex beta_cf = make_cuComplex(beta, 0);
// Mixed-precision path: complex int8 inputs, complex float32 output.
BF_CHECK_CUBLAS(cublasGemmStridedBatchedEx(handle->cublas(),
trans_a,
trans_b,
m,
n,
k,
&alpha_cf,
a_data,
CUDA_C_8I,
a_stride,
a_batchstride,
b_data,
CUDA_C_8I,
b_stride,
b_batchstride,
&beta_cf,
c_data,
CUDA_C_32F,
c_stride,
c_batchstride,
nbatch,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT));
break; // without this break, CI8 would fall through into the CF32 case
}
case BF_DTYPE_CF32: {
BF_ASSERT(c_type == BF_DTYPE_CF32, BF_STATUS_UNSUPPORTED_DTYPE);
const cuComplex alpha_cf = make_cuComplex(alpha, 0);
const cuComplex beta_cf = make_cuComplex(beta, 0);
BF_CHECK_CUBLAS(cublasCgemm3mStridedBatched(handle->cublas(),
trans_a,
trans_b,
m,
n,
k,
&alpha_cf,
(const cuComplex *)a_data,
a_stride,
a_batchstride,
(const cuComplex *)b_data,
b_stride,
b_batchstride,
&beta_cf,
(cuComplex *)c_data,
c_stride,
c_batchstride,
nbatch));
break;
}
case BF_DTYPE_CF64: {
BF_ASSERT(c_type == BF_DTYPE_CF64, BF_STATUS_UNSUPPORTED_DTYPE);
const cuDoubleComplex alpha_cf = make_cuDoubleComplex(alpha, 0);
const cuDoubleComplex beta_cf = make_cuDoubleComplex(beta, 0);
BF_CHECK_CUBLAS(cublasZgemmStridedBatched(handle->cublas(),
trans_a,
trans_b,
m,
n,
k,
&alpha_cf,
(const cuDoubleComplex *)a_data,
a_stride,
a_batchstride,
(const cuDoubleComplex *)b_data,
b_stride,
b_batchstride,
&beta_cf,
(cuDoubleComplex *)c_data,
c_stride,
c_batchstride,
nbatch));
break;
}
default:
BF_FAIL("Supported dtype for input array", BF_STATUS_UNSUPPORTED_DTYPE);
}
return BF_STATUS_SUCCESS;
}

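For reference, here is a minimal, self-contained sketch of the cublasSgemmStridedBatched convention the batch path above relies on (hypothetical shapes, column-major data, contiguously packed batches; error checks omitted for brevity): matrix i of each operand is read from base + i * batchstride, so a single call covers what would otherwise be nbatch separate cublasSgemm launches.

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cublas_v2.h>

int main() {
    // Hypothetical batch of small GEMMs: C_i = A_i * B_i for i = 0..nbatch-1.
    const int m = 4, n = 4, k = 4, nbatch = 8;
    const long long strideA = (long long)m * k; // contiguously packed batches
    const long long strideB = (long long)k * n;
    const long long strideC = (long long)m * n;

    std::vector<float> hA(m * k * nbatch, 1.0f);
    std::vector<float> hB(k * n * nbatch, 2.0f);
    std::vector<float> hC(m * n * nbatch, 0.0f);

    float *dA, *dB, *dC;
    cudaMalloc(&dA, hA.size() * sizeof(float));
    cudaMalloc(&dB, hB.size() * sizeof(float));
    cudaMalloc(&dC, hC.size() * sizeof(float));
    cudaMemcpy(dA, hA.data(), hA.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), hB.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dC, hC.data(), hC.size() * sizeof(float), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);

    const float alpha = 1.0f, beta = 0.0f;
    // One call runs all nbatch products; matrix i of each operand lives at
    // base + i * stride, exactly the layout bfMatMul_ab_exec_batch assumes.
    cublasSgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                              m, n, k,
                              &alpha,
                              dA, m, strideA,
                              dB, k, strideB,
                              &beta,
                              dC, m, strideC,
                              nbatch);

    cudaMemcpy(hC.data(), dC, hC.size() * sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("C_0(0,0) = %.1f (expect k * 1 * 2 = %d)\n", hC[0], 2 * k);

    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}
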
BFstatus bfMatMul_ab_exec(BFlinalg handle,
cudaStream_t stream,
cublasOperation_t trans_a,
@@ -498,13 +658,12 @@ BFstatus bfMatMul_ab_exec(BFlinalg handle,
BFdtype c_type,
long c_stride,
long c_batchstride) {
// We prefer the CI4@CF32 -> CF32 code;
// second choice would be the batched algorithms.

if( trans_a == CUBLAS_OP_T &&
trans_b == CUBLAS_OP_N &&
(a_type == BF_DTYPE_CI4 || a_type == BF_DTYPE_CI8) &&
(b_type == BF_DTYPE_CI16 || b_type == BF_DTYPE_CF16 || b_type == BF_DTYPE_CF32) &&
c_type == BF_DTYPE_CF32 ) {
@@ -518,19 +677,45 @@ BFstatus bfMatMul_ab_exec(BFlinalg handle,
stream));
}

// Use the single strided-batched call for larger batch counts.
// TODO: Why does CI8 yield NaNs when we go into batched execution?
if( nbatch > 12 && a_type != BF_DTYPE_CI8 ) {
BF_CHECK( bfMatMul_ab_exec_batch(handle,
stream,
trans_a,
trans_b,
m,
n,
k,
alpha,
a_data,
a_type,
a_stride,
a_batchstride,
b_data,
b_type,
b_stride,
b_batchstride,
beta,
c_data,
c_type,
c_stride,
c_batchstride,
nbatch) );
} else {
for( long b=0; b<nbatch; ++b ) {
cuda::child_stream child_stream(stream);
BF_CHECK( bfMatMul_ab_exec_nobatch(handle, child_stream,
trans_a, trans_b,
m, n, k,
alpha,
a_data, a_type, a_stride,
b_data, b_type, b_stride,
beta,
c_data, c_type, c_stride) );
a_data = (char*)a_data + a_batchstride * BF_DTYPE_NBYTE(a_type);
b_data = (char*)b_data + b_batchstride * BF_DTYPE_NBYTE(b_type);
c_data = (char*)c_data + c_batchstride * BF_DTYPE_NBYTE(c_type);
}
}
return BF_STATUS_SUCCESS;
}
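For intuition, a minimal sketch of the small-nbatch fallback pattern used above, with plain CUDA streams standing in for bifrost's cuda::child_stream helper and a dummy kernel (batch_work, hypothetical) standing in for the per-batch GEMM; each batch is issued on its own short-lived stream so independent products can still overlap:

#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for one per-batch GEMM: bump every element of one batch.
__global__ void batch_work(float* c, long n) {
    long i = (long)blockIdx.x * blockDim.x + threadIdx.x;
    if( i < n ) c[i] += 1.0f;
}

int main() {
    const long nbatch = 4, c_batchstride = 1 << 20; // hypothetical sizes
    float* c;
    cudaMalloc(&c, nbatch * c_batchstride * sizeof(float));
    cudaMemset(c, 0, nbatch * c_batchstride * sizeof(float));

    // Small-nbatch fallback: issue each batch on its own short-lived stream,
    // advancing the output pointer by the batch stride each iteration
    // (cf. the pointer arithmetic in the loop above).
    const int blocks = (int)((c_batchstride + 255) / 256);
    for( long b = 0; b < nbatch; ++b ) {
        cudaStream_t s;
        cudaStreamCreate(&s);
        batch_work<<<blocks, 256, 0, s>>>(c + b * c_batchstride, c_batchstride);
        cudaStreamDestroy(s); // safe: resources are released once work drains
    }
    cudaDeviceSynchronize();

    float first;
    cudaMemcpy(&first, c, sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("c[0] = %.1f (expect 1.0)\n", first);
    cudaFree(c);
    return 0;
}
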
4 changes: 3 additions & 1 deletion src/linalg_kernels.cu
@@ -533,7 +533,9 @@ void bf_cherk_N(int N, int K, int nbatch,
int A_offset = A_byte_offset / element_bytes;

size_t A_nelement_total =
(A_offset + N) // the elements in the first row of the first batch
+ (K - 1) * A_stride // the elements in the rest of the first batch
+ (nbatch - 1) * A_batchstride; // the elements for the remaining batches
size_t texture_element_limit = 1 << 27;
BF_ASSERT_EXCEPTION(A_nelement_total <= texture_element_limit,
BF_STATUS_UNSUPPORTED_SHAPE);
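
To make the tightened bound concrete, a worked check with hypothetical dimensions:

// Hypothetical shape: N = 256, K = 1000, nbatch = 192, A_offset = 0,
// A_stride = 256, A_batchstride = 256 * 1000 (contiguous batches).
// A_nelement_total = (0 + 256) + 999 * 256 + 191 * 256000
//                  = 256000 + 48896000
//                  = 49152000 elements,
// comfortably under the 1 << 27 = 134217728 texture element limit.
// Unlike the previous max(A_stride * K, A_batchstride * nbatch) + A_offset
// estimate, the sum counts only up to the last element actually touched,
// so padded strides no longer round the bound up by a whole batch.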
