Skip to content

Commit

Permalink
amd fp8 rowwise batched gemm tuning (#3624)
Browse files Browse the repository at this point in the history
Summary:

X-link: facebookresearch/FBGEMM#702

Range-based lookup for fixed B, N, and K

Reviewed By: jwfromm

Differential Revision: D68780527
  • Loading branch information
mxz297 authored and facebook-github-bot committed Jan 28, 2025
1 parent bab9b62 commit 10b8ed7
Show file tree
Hide file tree
Showing 25 changed files with 1,036 additions and 259 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,81 @@ namespace fbgemm_gpu {
using RowwiseBatchedKernel = std::function<
at::Tensor(at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor)>;

using BNKLookupTableType = std::map<int, RowwiseBatchedKernel>;

// Define a custom hash function for std::tuple<int, int, int>
struct IntTupleHash {
size_t operator()(const std::tuple<int, int, int>& t) const {
auto hash1 = std::hash<int>{}(std::get<0>(t));
auto hash2 = std::hash<int>{}(std::get<1>(t));
auto hash3 = std::hash<int>{}(std::get<2>(t));
return hash1 ^ hash2 ^ hash3;
}
};

static const std::map<int, RowwiseBatchedKernel> B_2_N_5120_K_8192_dispatch_table = {
{ 8, fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2},
{ 16, fp8_rowwise_batched_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_intrawave_v2},
{ 32, fp8_rowwise_batched_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 64, fp8_rowwise_batched_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 72, fp8_rowwise_batched_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 96, fp8_rowwise_batched_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 192, fp8_rowwise_batched_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 248, fp8_rowwise_batched_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
{ 384, fp8_rowwise_batched_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 512, fp8_rowwise_batched_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
{ 640, fp8_rowwise_batched_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 768, fp8_rowwise_batched_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 896, fp8_rowwise_batched_256x128x256x128_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 1024, fp8_rowwise_batched_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 1568, fp8_rowwise_batched_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 1792, fp8_rowwise_batched_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 2304, fp8_rowwise_batched_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 2816, fp8_rowwise_batched_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 3360, fp8_rowwise_batched_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 8992, fp8_rowwise_batched_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
};

static const std::map<int, RowwiseBatchedKernel> B_2_N_8192_K_5120_dispatch_table = {
{ 4, fp8_rowwise_batched_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
{ 8, fp8_rowwise_batched_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1},
{ 16, fp8_rowwise_batched_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3},
{ 32, fp8_rowwise_batched_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 64, fp8_rowwise_batched_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 128, fp8_rowwise_batched_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 192, fp8_rowwise_batched_256x64x192x256_32x32_1x3_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 208, fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5},
{ 232, fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 256, fp8_rowwise_batched_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 384, fp8_rowwise_batched_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 512, fp8_rowwise_batched_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 768, fp8_rowwise_batched_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 896, fp8_rowwise_batched_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 1024, fp8_rowwise_batched_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 1280, fp8_rowwise_batched_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 1792, fp8_rowwise_batched_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 2048, fp8_rowwise_batched_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 2304, fp8_rowwise_batched_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 2560, fp8_rowwise_batched_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 3136, fp8_rowwise_batched_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 8992, fp8_rowwise_batched_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
};

static const std::unordered_map<std::tuple<int, int, int>, BNKLookupTableType, IntTupleHash> BNK_lookup_table = {
{{2, 5120, 8192}, B_2_N_5120_K_8192_dispatch_table},
{{2, 8192, 5120}, B_2_N_8192_K_5120_dispatch_table}
};

RowwiseBatchedKernel rowwise_batched_bnk_lookup(int M, const BNKLookupTableType& table) {
auto it = table.lower_bound(M);
if (it != table.end()) {
return it->second;
} else {
--it;
return it->second;
}
}

RowwiseBatchedKernel
rowwise_batched_heuristic_dispatch(int B, int M, int N, int K) {
// Use shape heuristics to guess what the best kernel might be for the given
Expand Down Expand Up @@ -114,6 +189,15 @@ rowwise_batched_heuristic_dispatch(int B, int M, int N, int K) {
return fp8_rowwise_batched_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3;
}

RowwiseBatchedKernel
rowwise_batched_dispatch(int B, int M, int N, int K) {
auto it = BNK_lookup_table.find({B, N, K});
if (it != BNK_lookup_table.end()) {
return rowwise_batched_bnk_lookup(M, it->second);
}
return rowwise_batched_heuristic_dispatch(B, M, N, K);
}

at::Tensor f8f8bf16_rowwise_batched(
at::Tensor XQ,
at::Tensor WQ,
Expand Down Expand Up @@ -165,7 +249,7 @@ at::Tensor f8f8bf16_rowwise_batched(
}

RowwiseBatchedKernel selected_kernel =
rowwise_batched_heuristic_dispatch(B, M, N, K);
rowwise_batched_dispatch(B, M, N, K);
return selected_kernel(XQ, WQ, x_scale, w_scale, Y);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_batched_common.h"

at::Tensor
fp8_rowwise_batched_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_interwave_v1(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
using DeviceGemmInstance = DeviceGemmHelper<
128,
16,
32,
256,
16,
16,
1,
1,
S<16, 8, 1>,
S<16, 8, 1>,
S<1, 16, 1, 8>,
S<4, 4, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v1,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_batched_common.h"

at::Tensor
fp8_rowwise_batched_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
using DeviceGemmInstance = DeviceGemmHelper<
128,
16,
32,
512,
16,
16,
1,
1,
S<32, 4, 1>,
S<32, 4, 1>,
S<1, 16, 1, 8>,
S<4, 4, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v2,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}

Original file line number Diff line number Diff line change
Expand Up @@ -15,55 +15,25 @@ fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_i
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
// Check if this input needs to be padded.
int M = XQ.size(1);
int N = WQ.size(1);
int K = WQ.size(2);
bool pad = (K % 128 != 0);

if (pad) {
using DeviceGemmInstance = DeviceGemmHelper<
256,
128,
128,
128,
32,
32,
2,
2,
S<8, 32, 1>,
S<8, 32, 1>,
S<1, 32, 1, 8>,
S<8, 8, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3,
ck::tensor_operation::device::GemmSpecialization::KPadding>;
// Run kernel instance.
return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
} else {
using DeviceGemmInstance = DeviceGemmHelper<
256,
128,
128,
128,
32,
32,
2,
2,
S<8, 32, 1>,
S<8, 32, 1>,
S<1, 32, 1, 8>,
S<8, 8, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
}
using DeviceGemmInstance = DeviceGemmHelper<
256,
128,
128,
128,
32,
32,
2,
2,
S<8, 32, 1>,
S<8, 32, 1>,
S<1, 32, 1, 8>,
S<8, 8, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}

Original file line number Diff line number Diff line change
Expand Up @@ -15,55 +15,25 @@ fp8_rowwise_batched_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_i
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
// Check if this input needs to be padded.
int M = XQ.size(1);
int N = WQ.size(1);
int K = WQ.size(2);
bool pad = (K % 128 != 0);

if (pad) {
using DeviceGemmInstance = DeviceGemmHelper<
256,
128,
128,
128,
32,
32,
2,
2,
S<8, 32, 1>,
S<8, 32, 1>,
S<1, 32, 1, 8>,
S<8, 8, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v5,
ck::tensor_operation::device::GemmSpecialization::KPadding>;
// Run kernel instance.
return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
} else {
using DeviceGemmInstance = DeviceGemmHelper<
256,
128,
128,
128,
32,
32,
2,
2,
S<8, 32, 1>,
S<8, 32, 1>,
S<1, 32, 1, 8>,
S<8, 8, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v5,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
}
using DeviceGemmInstance = DeviceGemmHelper<
256,
128,
128,
128,
32,
32,
2,
2,
S<8, 32, 1>,
S<8, 32, 1>,
S<1, 32, 1, 8>,
S<8, 8, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v5,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_batched_common.h"

at::Tensor
fp8_rowwise_batched_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
using DeviceGemmInstance = DeviceGemmHelper<
256,
128,
128,
256,
32,
32,
2,
2,
S<16, 16, 1>,
S<16, 16, 1>,
S<1, 32, 1, 8>,
S<8, 8, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_batched_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}

Loading

0 comments on commit 10b8ed7

Please sign in to comment.