feat: mma rowsum for fp8 (flashinfer-ai#180)
Both e4m3 and e5m2.
yzh119 authored Mar 13, 2024
1 parent d305798 commit 5af935c
Showing 1 changed file with 41 additions and 3 deletions.
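
The heart of the change: the new rowsum_f8f8f32 computes row sums with a single mma by multiplying the fp8 fragment A against a B operand filled with 1.0, so C = A x B + C accumulates each row's total. The magic numbers 943208504 (0x38383838) and 1010580540 (0x3C3C3C3C) in the diff below are four packed fp8 encodings of 1.0 for e4m3 and e5m2 respectively. A minimal host-side sketch, not part of the commit, that reproduces the constants with cuda_fp8.h:

// Sanity check for the packed B-operand constants (illustrative only).
// 1.0 encodes as 0x38 in e4m3 (exponent bias 7) and as 0x3C in e5m2
// (exponent bias 15); replicating the byte four times fills one 32-bit
// mma register with four fp8 ones.
#include <cstdint>
#include <cstdio>
#include <cuda_fp8.h>

int main() {
  __nv_fp8_e4m3 one_e4m3{1.0f};
  __nv_fp8_e5m2 one_e5m2{1.0f};
  uint32_t b_e4m3 = 0x01010101u * one_e4m3.__x;  // 0x38383838 == 943208504
  uint32_t b_e5m2 = 0x01010101u * one_e5m2.__x;  // 0x3C3C3C3C == 1010580540
  printf("%u %u\n", b_e4m3, b_e5m2);  // prints: 943208504 1010580540
  return 0;
}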
44 changes: 41 additions & 3 deletions include/flashinfer/mma.cuh
@@ -133,6 +133,7 @@ __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t* R, T* smem_ptr) {
 template <typename T, MMAMode mma_mode = MMAMode::kInplaceUpdate>
 __device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* C, uint32_t* A,
                                                                    uint32_t* B) {
+  static_assert(sizeof(T) == 1, "DType must be 8bit floating data type");
 #if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED)
   if constexpr (mma_mode == MMAMode::kInit) {
     if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
@@ -216,7 +217,7 @@ __device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* C, uin
     }
   }
 #else
-  static_assert(false, "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+");
+#error "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"
 #endif
 }
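
A note on the hunk above: it swaps the static_assert(false, ...) under #else for #error. One plausible reading, not stated in the commit: before C++23 (P2593), a static_assert with a non-dependent false condition inside a template is ill-formed even if the template is never instantiated, and compilers disagree on when to diagnose it, so #error fails deterministically at preprocessing time with a clearer message. A sketch of the pitfall, illustrative only:

// Pre-C++23, writing `static_assert(false, "unsupported");` here is
// ill-formed even if not_supported<T> is never instantiated, because the
// condition does not depend on T. The classic workaround makes the
// condition dependent; the commit sidesteps the question with #error.
template <typename T>
void not_supported() {
  static_assert(sizeof(T) == 0, "unsupported");  // fires only on instantiation
}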

@@ -387,8 +388,45 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u
 #endif
 }

-// template <typename DType>
-// __device__ __forceinline__ void
+/*!
+ * \brief Use mma instructions to compute rowsum.
+ */
+template <typename DType>
+__device__ __forceinline__ void rowsum_f8f8f32(float* d, DType* s) {
+  static_assert(sizeof(DType) == 1, "DType must be 8bit floating data type");
+  uint32_t* s_u32 = (uint32_t*)(s);
+#if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED)
+  if constexpr (std::is_same<DType, __nv_fp8_e4m3>::value) {
+    asm volatile(
+        "{\n"
+        ".reg .f32 ph;\n"
+        "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
+        "{%0,  ph,  %1,  ph},"
+        "{%2,  %3,  %4,  %5},"
+        "{%6,  %7},"
+        "{%8,  0.,  %9,  0.};\n"
+        "}\n"
+        : "=f"(d[0]), "=f"(d[1])
+        : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]), "r"(943208504),
+          "r"(943208504), "f"(d[0]), "f"(d[1]));
+  } else {  // e5m2
+    asm volatile(
+        "{\n"
+        ".reg .f32 ph;\n"
+        "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 "
+        "{%0,  ph,  %1,  ph},"
+        "{%2,  %3,  %4,  %5},"
+        "{%6,  %7},"
+        "{%8,  0.,  %9,  0.};\n"
+        "}\n"
+        : "=f"(d[0]), "=f"(d[1])
+        : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]), "r"(1010580540),
+          "r"(1010580540), "f"(d[0]), "f"(d[1]));
+  }
+#else
+#error "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"
+#endif
+}
+
 /*!
  * \brief Use mma instructions to compute rowsum.
  */
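
For orientation, a hedged sketch of how the new primitive might be called from a kernel. Two assumptions not visible in this diff: the function sits in the flashinfer::mma namespace like the header's other helpers, and s_frag already holds a 16x32 fp8 tile in the standard m16n8k32 register layout (normally loaded with ldmatrix). Because the asm reads d as an accumulator, it must be zero-initialized:

// Illustrative usage only, not from the commit.
#include <cuda_fp8.h>
#include <flashinfer/mma.cuh>

__global__ void rowsum_demo_kernel(float* out) {
  const uint32_t lane_id = threadIdx.x % 32;
  uint32_t s_frag[4] = {0, 0, 0, 0};  // 16 fp8 values per thread (4 per reg);
                                      // a real kernel fills this via ldmatrix
  float d[2] = {0.f, 0.f};            // per-thread row-sum accumulators
  flashinfer::mma::rowsum_f8f8f32(d, reinterpret_cast<__nv_fp8_e4m3*>(s_frag));
  // Every column of A x ones(32x8) equals the row sum, so d[0] and d[1] are
  // already the sums of rows (lane_id / 4) and (lane_id / 4 + 8) of the tile.
  if (lane_id % 4 == 0) {
    out[lane_id / 4] = d[0];
    out[lane_id / 4 + 8] = d[1];
  }
}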
