From c6f20d1406a3a8c4f134c4a764d16e157a184338 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Thu, 18 Jul 2024 19:43:09 -0700 Subject: [PATCH] perf: use stmatrix in epilogue for sm90+ (#380) sm90+ can benefit from stmatrix in epilogue. --- include/flashinfer/attention/prefill.cuh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index e1c676bb5..13d7a54fe 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -893,6 +893,11 @@ __device__ __forceinline__ void write_o_reg_gmem( for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t o_frag_f16[4]; vec_cast((DTypeOut*)o_frag_f16, o_frag[fx][fy]); +#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED + uint32_t o_smem_offset_w = smem_t::get_permuted_offset( + (warp_idx_x * num_frags_x + fx) * 16 + lane_idx % 16, fy * 2 + lane_idx / 16); + o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); +#else uint32_t o_smem_offset_w = smem_t::get_permuted_offset( (warp_idx_x * num_frags_x + fx) * 16 + lane_idx / 4, fy * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0]; @@ -901,6 +906,7 @@ __device__ __forceinline__ void write_o_reg_gmem( ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2]; ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + 8 * channel_size_128b_out))[lane_idx % 4] = o_frag_f16[3]; +#endif } }