Skip to content

Commit

Permalink
Optimization for quantized gemm skinny sizes (#411)
Browse files Browse the repository at this point in the history
* Optimization for quantized gemm skinny sizes

* lint fix

* Add support for bf16/fp16

* code cleanup

* code cleanup

* lint fix2

* cleanup

* Moved the logic into tuned gemm to preserve API compatibility

---------

Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
  • Loading branch information
3 people authored Feb 19, 2025
1 parent 17b26bd commit 955ba64
Show file tree
Hide file tree
Showing 7 changed files with 559 additions and 52 deletions.
18 changes: 18 additions & 0 deletions csrc/rocm/custom.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,24 @@ void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::cuda::getCurrentCUDAStream(), CuCount);
}

// Forward declaration of the raw-pointer launcher that wvSpltKQ forwards to;
// its definition is not part of this section of the file.
// NOTE(review): presumably the quantized skinny-GEMM kernel launcher - confirm
// against the definition.
//
// Arguments, as supplied by the wvSpltKQ wrapper in this file:
//   in_a, in_b, out_c  - data_ptr() of the corresponding tensors
//   scale_a, scale_b   - data_ptr() of the scale tensors
//   M, K               - in_a.size(0) and in_a.size(1)
//   Kp                 - in_a.stride(0)
//   N, Otp_in          - caller-provided values narrowed from int64_t to int
//   stream             - result of at::cuda::getCurrentCUDAStream()
//   CuCount            - caller-provided value narrowed from int64_t to int
void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a,
void* scale_b, const int M, const int K, const int Kp,
const int N, const int Otp_in, cudaStream_t stream,
const int CuCount);

// Tensor-level entry point that unpacks its arguments and forwards them to the
// raw-pointer launcher wvSpltKQ_.
//
//   in_a             - supplies M (size(0)), K (size(1)) and Kp (stride(0))
//   in_b, out_c      - passed through as raw data pointers
//   scale_a, scale_b - passed through as raw data pointers
//   N_in, Otp_in     - forwarded after conversion to int
//   CuCount          - forwarded after conversion to int
//
// The launcher takes plain int arguments, so every 64-bit quantity is
// narrowed here; values outside the int range are not checked.
// The call is issued on the stream returned by
// at::cuda::getCurrentCUDAStream().
void wvSpltKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
              at::Tensor& scale_a, at::Tensor& scale_b, const int64_t N_in,
              const int64_t Otp_in, const int64_t CuCount) {
  // Problem geometry is read from in_a only.
  const int rows = static_cast<int>(in_a.size(0));
  const int cols = static_cast<int>(in_a.size(1));
  const int row_stride = static_cast<int>(in_a.stride(0));

  // Scalar arguments narrowed to the launcher's int parameters.
  const int n = static_cast<int>(N_in);
  const int otp = static_cast<int>(Otp_in);
  const int cu_count = static_cast<int>(CuCount);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  wvSpltKQ_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(),
            scale_a.data_ptr(), scale_b.data_ptr(), rows, cols, row_stride, n,
            otp, stream, cu_count);
}

// Forward declaration of LLGemmZZ; its definition lies outside this section of
// the file. It takes three raw pointers (in_a, in_b, out_c), integer M and K,
// the stream to use, and an integer solidx.
// NOTE(review): in_a/in_b look like inputs and out_c the output, and solidx
// presumably selects a solution/kernel variant - confirm against the
// definition.
void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K,
cudaStream_t stream, const int solidx);

Expand Down
Loading

0 comments on commit 955ba64

Please sign in to comment.