From 955ba6488bd2ef1818fcd5f56caa3cbf02a05067 Mon Sep 17 00:00:00 2001 From: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Date: Tue, 18 Feb 2025 16:20:55 -0800 Subject: [PATCH] Optimization for quantized gemm skinny sizes (#411) * Optimization for quantized gemm skinny sizes * lint fix * Add support for bf16/fp16 * code cleanup * code cleanup * lint fix2 * cleanup * Moved the logic into tuned gemm to preserve API compatibility --------- Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Gregory Shtrasberg --- csrc/rocm/custom.cu | 18 + csrc/rocm/custom_kernels.cu | 532 ++++++++++++++++-- csrc/rocm/ops.h | 4 + csrc/rocm/torch_bindings.cpp | 7 + vllm/_custom_ops.py | 6 + .../layers/quantization/utils/w8a8_utils.py | 13 +- vllm/model_executor/layers/tuned_gemm.py | 31 + 7 files changed, 559 insertions(+), 52 deletions(-) diff --git a/csrc/rocm/custom.cu b/csrc/rocm/custom.cu index fae1b4fbfbe33..c799dd273daef 100644 --- a/csrc/rocm/custom.cu +++ b/csrc/rocm/custom.cu @@ -48,6 +48,24 @@ void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, at::cuda::getCurrentCUDAStream(), CuCount); } +void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a, + void* scale_b, const int M, const int K, const int Kp, + const int N, const int Otp_in, cudaStream_t stream, + const int CuCount); + +void wvSpltKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + at::Tensor& scale_a, at::Tensor& scale_b, const int64_t N_in, + const int64_t Otp_in, const int64_t CuCount) { + auto M = in_a.size(0); + auto K = in_a.size(1); + auto Kp = in_a.stride(0); + int N = N_in; + int Otp = Otp_in; + wvSpltKQ_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), + scale_a.data_ptr(), scale_b.data_ptr(), M, K, Kp, N, Otp, + at::cuda::getCurrentCUDAStream(), CuCount); +} + void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, cudaStream_t stream, const int solidx); diff --git a/csrc/rocm/custom_kernels.cu b/csrc/rocm/custom_kernels.cu index ba90b3f75a072..d130461b27e2e 100644 --- a/csrc/rocm/custom_kernels.cu +++ b/csrc/rocm/custom_kernels.cu @@ -1,5 +1,6 @@ #include #include +#include #include #include #include "cuda_compat.h" @@ -327,7 +328,7 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, #define DTYPE half -__device__ __forceinline__ int mindiv(int N, int div1, int div2) { +/*__device__ __forceinline__ int mindiv(int N, int div1, int div2) { int nPrRnd = div1 * div2; int rnds0 = N / nPrRnd; nPrRnd -= div1 * 3; @@ -354,14 +355,391 @@ __device__ __forceinline__ int mindiv(int N, int div1, int div2) { if (rnds0 == rnds8) rtn = div2 - 8; if (rnds0 == rnds9) rtn = div2 - 9; return rtn; +}*/ + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const float* __restrict__ s_A, + const float* __restrict__ s_B, const int _WvPrGrp, + const int Otp, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; + using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; + union bigType { + char f8[A_CHUNK * 2]; + char2 c2[A_CHUNK]; + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + int i[A_CHUNK / 2]; + long l[A_CHUNK / 4]; + intx4 l2[A_CHUNK / 8]; + half8 h8; + }; + + __shared__ half 
s[1024 * 32]; + + for (uint32_t k = 0; k < min(K / 2 * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + if (k_in >= min(K / 2 * M, 32 * 1024)) break; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; + + using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; + floatx16 sum[M][YTILE]; + float sA = *s_A; + float sB = *s_B; + + while (n < N) { + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = {0}; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; + bigType bigB1[UNRL]; + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + + // Fetch the weight matrix from memory! + for (uint32_t k1 = 0; k1 < K / 2; k1 += THRDS * A_CHUNK * UNRL) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K / 2) break; + + const half* B_ = &B[(n + 0) * (Kp / 2) + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * Kp / 2]))); + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * Kp / 2]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * Kp / 2]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * Kp / 2]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * Kp / 2]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * Kp / 2]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * Kp / 2]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * Kp / 2]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K / 2) break; + for (int m = 0; m < M; m++) { + // if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K / 2 * m]))); + // else + // bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K / 2) break; + float aV[A_CHUNK * 2]; + + for (uint32_t m = 0; m < M; m++) { + for (int i = 0; i < A_CHUNK * 2; i += 8) { + sum[m][0] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[m][k2].l[i / 8], bigB0[k2].l[i / 8], sum[m][0], 0, 0, 0); + if (YTILE >= 2) + sum[m][1] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[m][k2].l[i / 8], bigB1[k2].l[i / 8], sum[m][1], 0, 0, 0); + } + } + } + } + + // Final reduction + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + float accm0 = sum[m][y][0]; + float accm16 = sum[m][y][8]; + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][1]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][9]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][2]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][10]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][3]), 
"v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][11]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][4]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][12]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][5]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][13]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][6]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][14]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][7]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][15]), "v"(accm16)); + accm0 += __shfl(accm0, 36); + accm16 += __shfl(accm16, 52); + sum[m][y][0] = accm0 + __shfl(accm16, 16); + } + } + + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + if (Otp == 0) // fp16 + C[n + y + m * N] = __float2half(sum[m][y][0] * sA * sB); + else // if (Otp == 1) //bf16 + *reinterpret_cast<__hip_bfloat16*>(&C[n + y + m * N]) = + __float2bfloat16(sum[m][y][0] * sA * sB); + } + } + } + + n += CuCount * _WvPrGrp * YTILE; + } +} +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltKQ_hf_sml_(const int K, const int Kp, const int N, + const DTYPE* B, const DTYPE* __restrict__ A, + DTYPE* C, const float* __restrict__ s_A, + const float* __restrict__ s_B, + const int _WvPrGrp, const int Otp, + const int CuCount) { + UNREACHABLE_CODE } +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltKQ_hf_(const int K, const int Kp, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const float* __restrict__ s_A, const float* __restrict__ s_B, + const int _WvPrGrp, const int Otp, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using intx2 = __attribute__((__vector_size__(2 * sizeof(int)))) int; + using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int; + union bigType { + char f8[A_CHUNK * 2]; + char2 c2[A_CHUNK]; + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + int i[A_CHUNK / 2]; + long l[A_CHUNK / 4]; + intx4 l2[A_CHUNK / 8]; + half8 h8; + }; + + __shared__ half s[1024 * 32]; + + for (uint32_t k = 0; k < min(K / 2 * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + if (k_in >= min(K / 2 * M, 32 * 1024)) break; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; + + using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float; + floatx16 sum[M][YTILE]; + float sA = *s_A; + float sB = *s_B; + + while (n < N) { + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = {0}; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; + bigType bigB1[UNRL]; + bigType bigB2[UNRL]; + 
bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + + // Fetch the weight matrix from memory! + for (uint32_t k1 = 0; k1 < K / 2; k1 += THRDS * A_CHUNK * UNRL) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K / 2) break; + + const half* B_ = &B[(n + 0) * (Kp / 2) + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * Kp / 2]))); + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * Kp / 2]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * Kp / 2]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * Kp / 2]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * Kp / 2]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * Kp / 2]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * Kp / 2]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * Kp / 2]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K / 2) break; + for (int m = 0; m < M; m++) { + if (k_ + K / 2 * m < 64 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K / 2 * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K / 2 * m]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K / 2) break; + float aV[A_CHUNK * 2]; + + for (uint32_t m = 0; m < M; m++) { + for (int i = 0; i < A_CHUNK * 2; i += 8) { + sum[m][0] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[m][k2].l[i / 8], bigB0[k2].l[i / 8], sum[m][0], 0, 0, 0); + if (YTILE >= 2) + sum[m][1] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bigA[m][k2].l[i / 8], bigB1[k2].l[i / 8], sum[m][1], 0, 0, 0); + } + } + } + } + + // Final reduction + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + float accm0 = sum[m][y][0]; + float accm16 = sum[m][y][8]; + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][1]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][9]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][2]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][10]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][3]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][11]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][4]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][12]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][5]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][13]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][6]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 
row_shl:10 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][14]), "v"(accm16)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm0) + : "0"(accm0), "v"(sum[m][y][7]), "v"(accm0)); + asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 " + : "=v"(accm16) + : "0"(accm16), "v"(sum[m][y][15]), "v"(accm16)); + accm0 += __shfl(accm0, 36); + accm16 += __shfl(accm16, 52); + sum[m][y][0] = accm0 + __shfl(accm16, 16); + } + } + + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + if (Otp == 0) // fp16 + C[n + y + m * N] = __float2half(sum[m][y][0] * sA * sB); + else // if (Otp == 12) //bf16 + *reinterpret_cast<__hip_bfloat16*>(&C[n + y + m * N]) = + __float2bfloat16(sum[m][y][0] * sA * sB); + } + } + } + + n += CuCount * _WvPrGrp * YTILE; + } +} +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltKQ_hf_(const int K, const int Kp, const int N, + const DTYPE* B, const DTYPE* __restrict__ A, + DTYPE* C, const float* __restrict__ s_A, + const float* __restrict__ s_B, const int _WvPrGrp, + const int Otp, const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support // This version targets cases where A[] fits LDS capacity template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, + const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; union bigType { @@ -381,34 +759,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- __shared__ half s[1024 * 32]; - //---------------------------------------------------- - // Computation of columns that need to be committed to memory! - //---------------------------------------------------- - // uint32_t commitColumn[YTILE]; - // for (uint32_t i = 0; i < YTILE; i++) { - // commitColumn[i] = 1; - //} - - // It's worth trying to load-balance... - int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); - - //---------------------------------------------------- - // Indexing function into the column of weight matrix B - // Algorithm does 64 lane k-splitting / wave and uses - // WG ID and Thread ID to find the index. - //---------------------------------------------------- - uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! 
- // if (n < N && (n + YTILE) >= N) { - // uint32_t startColumn = N - YTILE; - // for (uint32_t i = 0; i < (n - startColumn); i++) { - // commitColumn[i] = 0; - // } - // n = startColumn; - //} - //---------------------------------------------------- // Fetch the activation matrix to LDS // Loop iteration: @@ -434,8 +784,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } __syncthreads(); + // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); if (threadIdx.y >= _WvPrGrp) return; + uint32_t n = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; + float sum[M][YTILE]; //---------------------------------------------------- @@ -490,6 +843,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- + // for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { // Fetch the weight matrix from memory! #pragma unroll @@ -632,7 +986,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) template __global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { + const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support @@ -642,7 +996,8 @@ __global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, + const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; union bigType { @@ -670,14 +1025,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) commitColumn[i] = 1; } - // It's worth trying to load-balance... - int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); - //---------------------------------------------------- // Indexing function into the column of weight matrix B // Algorithm does 64 lane k-splitting / wave and uses // WG ID and Thread ID to find the index. //---------------------------------------------------- + // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; // Check whether there will be fragmenation! 
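
For reference, the wvSpltKQ_hf_sml_ / wvSpltKQ_hf_ kernels added above compute a per-tensor-scaled FP8 GEMM: each output element is an FP8 dot product accumulated with __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8, multiplied once by s_A * s_B, and stored as fp16 (Otp == 0) or bf16 (Otp == 1). A minimal PyTorch sketch of the same math follows; it is illustrative only, not part of the patch, and the helper name and tensor layouts are assumptions.

    import torch

    # Reference semantics of the wvSpltKQ kernels (sketch only): dequantize to
    # fp32, matmul, apply the per-tensor scales once, cast to the requested type.
    def wvspltkq_reference(a_fp8: torch.Tensor,   # (M, K) activations, fp8
                           b_fp8: torch.Tensor,   # (N, K) weights, fp8
                           s_a: torch.Tensor,     # scalar per-tensor activation scale
                           s_b: torch.Tensor,     # scalar per-tensor weight scale
                           otp: int) -> torch.Tensor:
        acc = a_fp8.float() @ b_fp8.float().t()        # fp32 accumulation of fp8 products
        acc = acc * s_a.float() * s_b.float()          # scales applied after accumulation
        return acc.half() if otp == 0 else acc.bfloat16()  # Otp: 0 = fp16, 1 = bf16
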
@@ -713,6 +1066,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; } + __syncthreads(); if (threadIdx.y >= _WvPrGrp) return; @@ -915,7 +1269,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) template __global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { + const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support @@ -925,7 +1279,8 @@ __global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, template __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + const DTYPE* __restrict__ A, DTYPE* C, const int _WvPrGrp, + const int CuCount) { using half8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; @@ -954,8 +1309,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) commitColumn[i] = 1; } - // It's worth trying to load-balance... - int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); if (threadIdx.y >= _WvPrGrp) return; //---------------------------------------------------- @@ -1252,11 +1606,40 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) template __global__ void wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { + const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support +int mindiv(int N, int div1, int div2) { + int nPrRnd = div1 * div2; + int rnds0 = N / nPrRnd; + nPrRnd -= div1 * 3; + int rnds3 = N / nPrRnd; + nPrRnd -= div1; + int rnds4 = N / nPrRnd; + nPrRnd -= div1; + int rnds5 = N / nPrRnd; + nPrRnd -= div1; + int rnds6 = N / nPrRnd; + nPrRnd -= div1; + int rnds7 = N / nPrRnd; + nPrRnd -= div1; + int rnds8 = N / nPrRnd; + nPrRnd -= div1; + int rnds9 = N / nPrRnd; + nPrRnd -= div1; + int rtn = div2; + if (rnds0 == rnds3) rtn = div2 - 3; + if (rnds0 == rnds4) rtn = div2 - 4; + if (rnds0 == rnds5) rtn = div2 - 5; + if (rnds0 == rnds6) rtn = div2 - 6; + if (rnds0 == rnds7) rtn = div2 - 7; + if (rnds0 == rnds8) rtn = div2 - 8; + if (rnds0 == rnds9) rtn = div2 - 9; + return rtn; +} + void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, const int K_in, const int N_in, cudaStream_t stream, const int CuCount = 0) { @@ -1269,17 +1652,21 @@ void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, _N) \ { \ dim3 block(64, _WvPrGrp); \ - /*wvSpltK_hf:*/ \ if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSpltK_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ - <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ wvSpltK_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ - <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ } else { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ wvSpltK_hf_big_<64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ - <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ + CuCount); \ } \ } @@ -1306,4 +1693,57 @@ void wvSpltK_(void* 
in_a, void* in_b, void* out_c, const int M_in, if (cudaSuccess != err) { throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } -} \ No newline at end of file +} + +void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a, + void* scale_b, const int M_in, const int K_in, const int Kp_in, + const int N_in, const int Otp_in, cudaStream_t stream, + const int CuCount = 0) { + dim3 grid(CuCount); + half* af4 = reinterpret_cast(in_a); + const half* bf4 = reinterpret_cast(in_b); + auto* c = reinterpret_cast(out_c); + auto* s_a = reinterpret_cast(scale_a); + auto* s_b = reinterpret_cast(scale_b); + +#define WVSPLTKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ + wvSpltKQ_hf_sml_<64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \ + <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, \ + s_b, __wvPrGrp, Otp_in, CuCount); \ + } else { \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ + wvSpltKQ_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ + <<>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, \ + s_b, __wvPrGrp, Otp_in, CuCount); \ + } \ + } + + switch (N_in) { + case 1: + WVSPLTKQ(16, 2, 2, 2, 2, 2, 2, 1) // MI308 + break; + case 2: + WVSPLTKQ(16, 2, 2, 2, 2, 2, 2, 2) // MI308 + break; + case 3: + WVSPLTKQ(16, 4, 7, 7, 1, 1, 1, 3) // MI308 + break; + case 4: + WVSPLTKQ(16, 4, 7, 7, 1, 1, 1, 4) // MI308 + break; + default: + throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + + "," + std::to_string(K_in) + "," + + std::to_string(N_in)); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) { + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); + } +} diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 59bd28e3bc127..0701b5df3f07a 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -11,6 +11,10 @@ void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, const int64_t N_in, const int64_t CuCount); +void wvSpltKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + at::Tensor& scale_a, at::Tensor& scale_b, const int64_t N_in, + const int64_t Otp_in, const int64_t CuCount); + void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 50640a96725e0..4ca24b7c19e7f 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -43,6 +43,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," " int CuCount) -> ()"); rocm_ops.impl("wvSpltK", torch::kCUDA, &wvSpltK); + rocm_ops.def( + "wvSpltKQ(Tensor in_a, Tensor in_b, Tensor! 
out_c, Tensor scale_a, " + "Tensor scale_b," + " int N_in," + " int Otp_in," + " int CuCount) -> ()"); + rocm_ops.impl("wvSpltKQ", torch::kCUDA, &wvSpltKQ); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index bd19dbf56b1de..e45629387e639 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1192,3 +1192,9 @@ def LLMM_Silu(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, def wvSpltK(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, N: int, cu_count: int) -> None: torch.ops._rocm_C.wvSpltK(a, b, out, N, cu_count) + + +def wvSpltKQ(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, + scale_a: torch.Tensor, scale_b: torch.Tensor, N: int, Otp: int, + cu_count: int) -> None: + torch.ops._rocm_C.wvSpltKQ(a, b, out, scale_a, scale_b, N, Otp, cu_count) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index a3f4a4f622492..12a5aac9b8d1d 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -5,6 +5,7 @@ import torch from vllm import _custom_ops as ops +from vllm.model_executor.layers.tuned_gemm import tgemm from vllm.platforms import current_platform # Input scaling factors are no longer optional in _scaled_mm starting @@ -172,12 +173,12 @@ def apply_fp8_linear( if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ - output = torch._scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) + output = tgemm.scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) # A fix for discrepancy in scaled_mm which returns tuple # for torch < 2.5 and a single value in torch >= 2.5 if type(output) is tuple and len(output) == 2: diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 8fb44cdc96c2b..cf3caebf3201b 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import os from pathlib import Path +from typing import Optional import pandas as pd import torch @@ -91,6 +92,36 @@ def apply_skinny(self, m, n, k, inp_view, weights): else: return None + def scaled_mm( + self, + inp: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + n = inp.shape[0] + if n != 1: + return torch._scaled_mm(inp, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + weightT = weight.t() + out = torch.empty(inp.shape[0], + weightT.shape[0], + dtype=out_dtype, + device='cuda') + + Otp = 1 #default bfloat16 + if out_dtype == torch.float16: + Otp = 0 + ops.wvSpltKQ(weightT, inp, out, scale_a, scale_b, n, Otp, + self.cu_count) + return out + def mm(self, inp, weights, bias=None): if not support_tuned_gemms: return F.linear(inp, weights, bias)
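
As a usage sketch of the new Python path (illustrative only; the shapes and the float8_e4m3fnuz dtype are assumptions, not part of the patch): for a single-token ("skinny") per-tensor FP8 GEMM, tgemm.scaled_mm routes to ops.wvSpltKQ, and for n != 1 it falls back to torch._scaled_mm.

    import torch
    from vllm.model_executor.layers.tuned_gemm import tgemm

    # Per-tensor FP8 quantization of a single-token activation and a weight.
    # float8_e4m3fnuz is the ROCm FP8 format; availability depends on the build.
    x = torch.randn(1, 4096, device="cuda")        # (M=1, K) activation
    w = torch.randn(8192, 4096, device="cuda")     # (N, K) weight
    x_scale = x.abs().max().float() / torch.finfo(torch.float8_e4m3fnuz).max
    w_scale = w.abs().max().float() / torch.finfo(torch.float8_e4m3fnuz).max
    qx = (x / x_scale).to(torch.float8_e4m3fnuz)
    qw = (w / w_scale).to(torch.float8_e4m3fnuz)

    # scaled_mm takes the weight in the (K, N) layout torch._scaled_mm expects;
    # for n == 1 it transposes it back and dispatches to ops.wvSpltKQ with
    # Otp = 1 (bf16 output) or Otp = 0 (fp16 output).
    out = tgemm.scaled_mm(qx, qw.t(), out_dtype=torch.bfloat16,
                          scale_a=x_scale, scale_b=w_scale, bias=None)
    assert out.shape == (1, 8192)
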