From 10b8ed7417eff2354d14b37722d58e8ac1e7bdb7 Mon Sep 17 00:00:00 2001 From: Xiaozhu Meng Date: Tue, 28 Jan 2025 10:00:46 -0800 Subject: [PATCH] amd fp8 rowwise batched gemm tuning (#3624) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/702 Add a range-based lookup over M for fixed (B, N, K) shapes, falling back to the existing heuristic dispatch for all other shapes. Reviewed By: jwfromm Differential Revision: D68780527 --- .../fp8_rowwise_batched_gemm.hip | 86 ++++++++++- ...16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip | 39 +++++ ...32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip | 39 +++++ ...8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip | 72 +++------ ...8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5.hip | 72 +++------ ...6x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip | 39 +++++ ...8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3.hip | 39 +++++ ...8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip | 39 +++++ ...8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip | 39 +++++ ...6x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3.hip | 39 +++++ ...2x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3.hip | 39 +++++ ...8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip | 74 +++------ ...8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip | 39 +++++ ...8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip | 39 +++++ ...8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip | 39 +++++ ...8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip | 72 +++------ ...8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip | 73 +++------ ...6x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip | 39 +++++ ...32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip | 39 +++++ ...6x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip | 39 +++++ ...6x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip | 39 +++++ ...32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip | 39 +++++ ...32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip | 39 +++++ ...32x2x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip | 39 +++++ .../fp8_rowwise_batched_kernel_manifest.h | 144 ++++++++++++++++++ 25 files changed, 1036 insertions(+), 259 deletions(-) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x256x128_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3.hip create mode 100644 
fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x64x192x256_32x32_1x3_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/fp8_rowwise_batched_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/fp8_rowwise_batched_gemm.hip index 1942a5ddb7..6c0cb5330a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/fp8_rowwise_batched_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/fp8_rowwise_batched_gemm.hip @@ -25,6 +25,81 @@ namespace fbgemm_gpu { using RowwiseBatchedKernel = std::function< at::Tensor(at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor)>; +using BNKLookupTableType = std::map<int, RowwiseBatchedKernel>; + +// Define a custom hash function for std::tuple<int, int, int> +struct IntTupleHash { + size_t operator()(const std::tuple<int, int, int>& t) const { + auto hash1 = std::hash<int>{}(std::get<0>(t)); + auto hash2 = std::hash<int>{}(std::get<1>(t)); + auto hash3 = std::hash<int>{}(std::get<2>(t)); + return hash1 ^ hash2 ^ hash3; + } +}; + +static const std::map<int, RowwiseBatchedKernel> B_2_N_5120_K_8192_dispatch_table = { + { 8, fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2}, + { 16, fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_intrawave_v2}, + { 32, fp8_rowwise_batched_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + { 64, fp8_rowwise_batched_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + { 72, fp8_rowwise_batched_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + { 96, 
fp8_rowwise_batched_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + { 192, fp8_rowwise_batched_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + { 248, fp8_rowwise_batched_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3}, + { 384, fp8_rowwise_batched_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + { 512, fp8_rowwise_batched_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3}, + { 640, fp8_rowwise_batched_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + { 768, fp8_rowwise_batched_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + { 896, fp8_rowwise_batched_256x128x256x128_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + { 1024, fp8_rowwise_batched_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + { 1568, fp8_rowwise_batched_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + { 1792, fp8_rowwise_batched_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + { 2304, fp8_rowwise_batched_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + { 2816, fp8_rowwise_batched_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + { 3360, fp8_rowwise_batched_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + { 8992, fp8_rowwise_batched_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, +}; + +static const std::map<int, RowwiseBatchedKernel> B_2_N_8192_K_5120_dispatch_table = { + { 4, fp8_rowwise_batched_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2}, + { 8, fp8_rowwise_batched_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1}, + { 16, fp8_rowwise_batched_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3}, + { 32, fp8_rowwise_batched_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + { 64, fp8_rowwise_batched_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + { 128, fp8_rowwise_batched_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + { 192, fp8_rowwise_batched_256x64x192x256_32x32_1x3_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + { 208, fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5}, + { 232, fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + { 256, fp8_rowwise_batched_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + { 384, fp8_rowwise_batched_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + { 512, fp8_rowwise_batched_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + { 768, fp8_rowwise_batched_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + { 896, fp8_rowwise_batched_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + { 1024, fp8_rowwise_batched_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + { 1280, fp8_rowwise_batched_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + { 1792, fp8_rowwise_batched_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + { 2048, fp8_rowwise_batched_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + { 2304, 
fp8_rowwise_batched_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + { 2560, fp8_rowwise_batched_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + { 3136, fp8_rowwise_batched_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + { 8992, fp8_rowwise_batched_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, +}; + +static const std::unordered_map<std::tuple<int, int, int>, BNKLookupTableType, IntTupleHash> BNK_lookup_table = { + {{2, 5120, 8192}, B_2_N_5120_K_8192_dispatch_table}, + {{2, 8192, 5120}, B_2_N_8192_K_5120_dispatch_table} +}; + +RowwiseBatchedKernel rowwise_batched_bnk_lookup(int M, const BNKLookupTableType& table) { + auto it = table.lower_bound(M); + if (it != table.end()) { + return it->second; + } else { + --it; + return it->second; + } +} + RowwiseBatchedKernel rowwise_batched_heuristic_dispatch(int B, int M, int N, int K) { // Use shape heuristics to guess what the best kernel might be for the given @@ -114,6 +189,15 @@ rowwise_batched_heuristic_dispatch(int B, int M, int N, int K) { return fp8_rowwise_batched_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3; } +RowwiseBatchedKernel +rowwise_batched_dispatch(int B, int M, int N, int K) { + auto it = BNK_lookup_table.find({B, N, K}); + if (it != BNK_lookup_table.end()) { + return rowwise_batched_bnk_lookup(M, it->second); + } + return rowwise_batched_heuristic_dispatch(B, M, N, K); +} + at::Tensor f8f8bf16_rowwise_batched( at::Tensor XQ, at::Tensor WQ, @@ -165,7 +249,7 @@ at::Tensor f8f8bf16_rowwise_batched( } RowwiseBatchedKernel selected_kernel = - rowwise_batched_heuristic_dispatch(B, M, N, K); + rowwise_batched_dispatch(B, M, N, K); return selected_kernel(XQ, WQ, x_scale, w_scale, Y); } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip new file mode 100644 index 0000000000..b4dcc651b0 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<16, 8, 1>, + S<16, 8, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 0000000000..cde89b9cd1 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<32, 4, 1>, + S<32, 4, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip index 5f75483921..bab6f9ae63 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -15,55 +15,25 @@ fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_i at::Tensor x_scale, at::Tensor w_scale, at::Tensor Y) { - // Check if this input needs to be padded. - int M = XQ.size(1); - int N = WQ.size(1); - int K = WQ.size(2); - bool pad = (K % 128 != 0); - - if (pad) { - using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; - // Run kernel instance. - return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>( - XQ, WQ, x_scale, w_scale, Y); - } else { - using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; - // Run kernel instance. 
- return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>( - XQ, WQ, x_scale, w_scale, Y); - } + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); } + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5.hip index 0fa17d0849..07a5c442b0 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5.hip @@ -15,55 +15,25 @@ fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_i at::Tensor x_scale, at::Tensor w_scale, at::Tensor Y) { - // Check if this input needs to be padded. - int M = XQ.size(1); - int N = WQ.size(1); - int K = WQ.size(2); - bool pad = (K % 128 != 0); - - if (pad) { - using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v5, - ck::tensor_operation::device::GemmSpecialization::KPadding>; - // Run kernel instance. - return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>( - XQ, WQ, x_scale, w_scale, Y); - } else { - using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v5, - ck::tensor_operation::device::GemmSpecialization::Default>; - // Run kernel instance. - return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>( - XQ, WQ, x_scale, w_scale, Y); - } + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v5, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); } + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip new file mode 100644 index 0000000000..337bfac402 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 128, + 256, + 32, + 32, + 2, + 2, + S<16, 16, 1>, + S<16, 16, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3.hip new file mode 100644 index 0000000000..53daf76014 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 160, + 128, + 32, + 32, + 1, + 5, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 64, 1, 4>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip new file mode 100644 index 0000000000..68ff57ee94 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 192, + 128, + 32, + 32, + 2, + 3, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x256x128_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x256x128_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip new file mode 100644 index 0000000000..c9ec1c99fd --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x256x128_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_256x128x256x128_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 256, + 128, + 32, + 32, + 2, + 4, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3.hip new file mode 100644 index 0000000000..52a9b9cced --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 96, + 256, + 32, + 32, + 1, + 3, + S<16, 16, 1>, + S<16, 16, 1>, + S<1, 64, 1, 4>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3.hip new file mode 100644 index 0000000000..88dc96a7ac --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 16, + 64, + 512, + 16, + 16, + 1, + 1, + S<32, 8, 1>, + S<32, 8, 1>, + S<1, 16, 1, 16>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip index a8a1938baa..16305ff289 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -15,57 +15,25 @@ fp8_rowwise_batched_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_i at::Tensor x_scale, at::Tensor w_scale, at::Tensor Y) { - // This kernel works well for many medium to large shapes. 
- - int M = XQ.size(1); - int N = WQ.size(1); - int K = WQ.size(2); - - bool kpad = K % 128 != 0; - - if (kpad) { - using DeviceGemmInstance = DeviceGemmHelper< - 256, - 224, - 256, - 128, - 16, - 16, - 7, - 8, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; - // Run kernel instance. - return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>( - XQ, WQ, x_scale, w_scale, Y); - } else { - using DeviceGemmInstance = DeviceGemmHelper< - 256, - 224, - 256, - 128, - 16, - 16, - 7, - 8, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; - // Run kernel instance. - return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>( - XQ, WQ, x_scale, w_scale, Y); - } + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 224, + 256, + 128, + 16, + 16, + 7, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); } + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip new file mode 100644 index 0000000000..7e60aa00c6 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 128, + 128, + 16, + 16, + 8, + 4, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip new file mode 100644 index 0000000000..d1f0d1d06d --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 160, + 128, + 16, + 16, + 8, + 5, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 64, 1, 4>, + S<8, 8, 1>, + 2, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip new file mode 100644 index 0000000000..91bbfdc494 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 192, + 128, + 16, + 16, + 8, + 6, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip index 6fc0b5c027..33e04605a7 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip @@ -15,55 +15,25 @@ fp8_rowwise_batched_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_i at::Tensor x_scale, at::Tensor w_scale, at::Tensor Y) { - // Check if this input needs to be padded. - int M = XQ.size(1); - int N = WQ.size(1); - int K = WQ.size(2); - bool pad = (M % 256 != 0) || (N % 224 != 0) || (K % 128 != 0); - - // This kernel seems optimal in the most purely compute bound tasks. - if (pad) { - using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 224, - 128, - 16, - 16, - 8, - 7, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 64, 1, 4>, - S<8, 8, 1>, - 2, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3>; - // Run kernel instance. - return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>( - XQ, WQ, x_scale, w_scale, Y); - } else { - using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 224, - 128, - 16, - 16, - 8, - 7, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 64, 1, 4>, - S<8, 8, 1>, - 2, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; - // Run kernel instance. - return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>( - XQ, WQ, x_scale, w_scale, Y); - } + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 224, + 128, + 16, + 16, + 8, + 7, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 64, 1, 4>, + S<8, 8, 1>, + 2, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); } + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip index d6d4984bae..603561c268 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -15,56 +15,25 @@ fp8_rowwise_batched_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_i at::Tensor x_scale, at::Tensor w_scale, at::Tensor Y) { - // Check if this input needs to be padded. 
- int M = XQ.size(1); - int N = WQ.size(1); - int K = WQ.size(2); - bool pad = (K % 128 != 0); - - // Dispatch based on whether padding is needed or not. - if (pad) { - using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 128, - 16, - 16, - 8, - 8, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; - // Run kernel instance. - return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>( - XQ, WQ, x_scale, w_scale, Y); - } else { - using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 128, - 16, - 16, - 8, - 8, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; - // Run kernel instance. - return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>( - XQ, WQ, x_scale, w_scale, Y); - } + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 256, + 128, + 16, + 16, + 8, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); } + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip new file mode 100644 index 0000000000..7bf049bb64 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 32, + 128, + 256, + 32, + 32, + 1, + 1, + S<16, 16, 1>, + S<16, 16, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip new file mode 100644 index 0000000000..75db812799 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 32, + 64, + 512, + 16, + 16, + 1, + 2, + S<32, 8, 1>, + S<32, 8, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip new file mode 100644 index 0000000000..7dfa22b7c9 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 64, + 128, + 256, + 32, + 32, + 1, + 2, + S<16, 16, 1>, + S<16, 16, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x64x192x256_32x32_1x3_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x64x192x256_32x32_1x3_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip new file mode 100644 index 0000000000..402b7389c0 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x64x192x256_32x32_1x3_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_256x64x192x256_32x32_1x3_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 64, + 192, + 256, + 32, + 32, + 1, + 3, + S<16, 16, 1>, + S<16, 16, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip new file mode 100644 index 0000000000..6100e683f1 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 64, + 64, + 512, + 32, + 32, + 1, + 1, + S<32, 8, 1>, + S<32, 8, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 0000000000..e204e1879f --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<32, 2, 1>, + S<32, 2, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 0000000000..e735e99769 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_batched_common.h" + +at::Tensor +fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<32, 2, 1>, + S<32, 2, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); +} + diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_kernel_manifest.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_kernel_manifest.h index 93bbbb2d28..16ee1e8813 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_kernel_manifest.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_kernel_manifest.h @@ -170,3 +170,147 @@ fp8_rowwise_batched_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_in at::Tensor x_scale, at::Tensor w_scale, at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_256x128x256x128_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, 
at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_256x64x192x256_32x32_1x3_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_batched_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y);
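
Note on the dispatch scheme this patch adds to fp8_rowwise_batched_gemm.hip: for a (B, N, K) combination that has a tuned table, the kernel is chosen by a range lookup over M, where std::map::lower_bound(M) returns the entry with the smallest key >= M, and any M above the largest key reuses the last entry. Below is a minimal standalone sketch of the same scheme, not part of the patch; the stub kernels small_m_kernel/large_m_kernel and the example table contents are hypothetical stand-ins for the tuned fp8_rowwise_batched_* entries.

#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include <unordered_map>

// Stand-in for the at::Tensor-returning kernel functions in the patch.
using Kernel = std::string (*)();

// Hypothetical stubs; in the patch these are fp8_rowwise_batched_* kernels.
std::string small_m_kernel() { return "small-M kernel"; }
std::string large_m_kernel() { return "large-M kernel"; }

// Mirrors BNKLookupTableType: keys are M-range upper bounds.
using MLookupTable = std::map<int, Kernel>;

// Same idea as IntTupleHash in the patch: combine the three element hashes.
struct IntTupleHash {
  size_t operator()(const std::tuple<int, int, int>& t) const {
    return std::hash<int>{}(std::get<0>(t)) ^
        std::hash<int>{}(std::get<1>(t)) ^ std::hash<int>{}(std::get<2>(t));
  }
};

// Example inner table for one tuned (B, N, K) shape (entries are made up).
static const MLookupTable example_dispatch_table = {
    {64, small_m_kernel},   // M in (0, 64]    -> small-M kernel
    {1024, large_m_kernel}, // M in (64, 1024] -> large-M kernel
};

static const std::unordered_map<std::tuple<int, int, int>, MLookupTable, IntTupleHash>
    bnk_lookup_table = {
        {{2, 5120, 8192}, example_dispatch_table},
};

// Range lookup over M, as in rowwise_batched_bnk_lookup: lower_bound finds
// the first upper bound >= M; past the largest bound, reuse the last entry.
Kernel range_lookup(int M, const MLookupTable& table) {
  auto it = table.lower_bound(M);
  if (it == table.end()) {
    --it;
  }
  return it->second;
}

int main() {
  const int B = 2, M = 100, N = 5120, K = 8192;
  auto it = bnk_lookup_table.find({B, N, K});
  if (it != bnk_lookup_table.end()) {
    std::cout << range_lookup(M, it->second)() << "\n"; // prints "large-M kernel"
  } else {
    std::cout << "fall back to heuristic dispatch\n";
  }
  return 0;
}

The patched rowwise_batched_dispatch behaves the same way: an exact (B, N, K) hit routes through the M-range table, and every other shape falls back to rowwise_batched_heuristic_dispatch.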