Skip to content

Commit

Permalink
give separate heuristics to IGEMM
Browse files Browse the repository at this point in the history
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
  • Loading branch information
nirvedhmeshram committed Jan 13, 2025
1 parent d999ed1 commit 91a09c7
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ LogicalResult setDataTiledMultiMmaLoweringConfig(
/// problem based on the available mma intrinsics.
static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
IREE::GPU::TargetAttr target, GPUMatmulShapeType problem,
bool transposedLhs, bool transposedRhs, bool mustBeAligned = true,
bool doCPromotion = false) {
bool transposedLhs, bool transposedRhs, bool isIGEMM,
bool mustBeAligned = true, bool doCPromotion = false) {
const int64_t targetSubgroupSize = target.getPreferredSubgroupSize();
SmallVector<GPUMatmulShapeType> intrinsics;
for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
Expand All @@ -142,20 +142,22 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
// See https://github.com/iree-org/iree/issues/16341 for details.
int64_t mSize = ShapedType::getNumElements(problem.mSizes);
int64_t nSize = ShapedType::getNumElements(problem.nSizes);
int64_t cacheLineSizeElements = kCacheLineSizeBits / inBitWidth;
int64_t bestKElementCountPerSubgroup =
isIGEMM ? cacheLineSizeElements / 2 : cacheLineSizeElements;
if (mSize * nSize <= 512 * 512) {
// For matmuls with small M*N size, we want to distribute M*N onto more
// workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup
// and a larger bestKTileCountPerSubgroup.
seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
/*bestMNTileCountPerSubgroup=*/4,
/*bestKTileCountPerSubgroup=*/8,
/*bestKElementCountPerSubgroup*/ kCacheLineSizeBits * 2 /
inBitWidth};
/*bestKTileCountPerSubgroup=*/8, bestKElementCountPerSubgroup * 2};
} else {
int64_t bestKElementCountPerSubgroup =
isIGEMM ? cacheLineSizeElements / 2 : cacheLineSizeElements;
seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
/*bestMNTileCountPerSubgroup=*/16,
/*bestKTileCountPerSubgroup=*/4,
/*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / inBitWidth};
/*bestKTileCountPerSubgroup=*/4, bestKElementCountPerSubgroup};
}

// We target slightly below the full available shared Memory to leave room for
Expand All @@ -181,7 +183,8 @@ static FailureOr<std::pair<LoweringConfigAttr, int64_t>>
getMatmulLoweringConfigAndWorkgroupSize(SmallVector<int64_t> bounds,
ArrayRef<AffineMap> maps,
ArrayRef<Value> operands,
IREE::GPU::TargetAttr target) {
IREE::GPU::TargetAttr target,
bool isIGEMM) {
if (target.getWgp().getMma().empty())
return failure();

Expand Down Expand Up @@ -249,7 +252,7 @@ getMatmulLoweringConfigAndWorkgroupSize(SmallVector<int64_t> bounds,
bool mustBeAligned = true;
bool doCPromotion = false;
std::optional<GPUMMASchedule> schedule = getMmaScheduleFromProblemAndTarget(
target, problem, transposedLhs, transposedRhs);
target, problem, transposedLhs, transposedRhs, isIGEMM);

// TODO (nirvedhmeshram, qedawkins): The performance with this will be bad if
// the GEMM is accumulating (i.e doesnt have a zero fill dpsInit) as that
Expand All @@ -259,9 +262,9 @@ getMatmulLoweringConfigAndWorkgroupSize(SmallVector<int64_t> bounds,
LDBG("Attempting to deduce unaligned TileAndFuse MMA schedulee");
mustBeAligned = false;
doCPromotion = true;
schedule = getMmaScheduleFromProblemAndTarget(target, problem,
transposedLhs, transposedRhs,
mustBeAligned, doCPromotion);
schedule = getMmaScheduleFromProblemAndTarget(
target, problem, transposedLhs, transposedRhs, isIGEMM, mustBeAligned,
doCPromotion);
}

if (!schedule) {
Expand Down Expand Up @@ -384,7 +387,8 @@ setIGEMMConvolutionLoweringConfig(IREE::GPU::TargetAttr target,
SmallVector<int64_t> bounds = igemmLoopBounds.value();
FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
getMatmulLoweringConfigAndWorkgroupSize(
bounds, igemmContractionMaps.value(), igemmOperands.value(), target);
bounds, igemmContractionMaps.value(), igemmOperands.value(), target,
/*isIGEMM=*/true);
if (failed(configAndWgSize)) {
return failure();
}
Expand Down Expand Up @@ -434,7 +438,8 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
LDBG("Matmul TileAndFuse Config");

FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
getMatmulLoweringConfigAndWorkgroupSize(bounds, maps, operands, target);
getMatmulLoweringConfigAndWorkgroupSize(bounds, maps, operands, target,
/*isIGEMM=*/false);
if (failed(configAndWgSize)) {
return failure();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func.func @nhwc_conv_mfma() {
// CHECK: linalg.conv_2d_nhwc_hwcf {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
// CHECK-SAME: promote_operands = [0, 1]
// CHECK-SAME: reduction = [0, 0, 0, 0, 16]
// CHECK-SAME: reduction = [0, 0, 0, 0, 8]
// CHECK-SAME: subgroup = [1, 2, 2, 1, 0]
// CHECK-SAME: workgroup = [1, 2, 32, 64, 0]

Expand Down Expand Up @@ -53,7 +53,7 @@ func.func @nchw_conv_mfma() {
// CHECK: linalg.conv_2d_nchw_fchw {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
// CHECK-SAME: promote_operands = [0, 1]
// CHECK-SAME: reduction = [0, 0, 0, 0, 16]
// CHECK-SAME: reduction = [0, 0, 0, 0, 8]
// CHECK-SAME: subgroup = [1, 2, 2, 1, 0]
// CHECK-SAME: workgroup = [1, 64, 2, 32, 0]

Expand Down Expand Up @@ -81,9 +81,9 @@ func.func @nhwc_conv_unaligned_mfma() {

// CHECK: linalg.conv_2d_nhwc_hwcf {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
// CHECK-SAME: padding = [2, 1, 32, 64, 64]
// CHECK-SAME: padding = [2, 1, 32, 64, 32]
// CHECK-SAME: promote_operands = [0, 1, 2]
// CHECK-SAME: reduction = [0, 0, 0, 0, 16]
// CHECK-SAME: reduction = [0, 0, 0, 0, 8]
// CHECK-SAME: subgroup = [2, 1, 2, 1, 0]
// CHECK-SAME: workgroup = [2, 1, 32, 64, 0]

Expand Down Expand Up @@ -111,8 +111,8 @@ func.func @nchw_conv_unaligned_mfma() {

// CHECK: linalg.conv_2d_nchw_fchw {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
// CHECK-SAME: padding = [1, 64, 2, 32, 64]
// CHECK-SAME: padding = [1, 64, 2, 32, 32]
// CHECK-SAME: promote_operands = [0, 1, 2]
// CHECK-SAME: reduction = [0, 0, 0, 0, 16]
// CHECK-SAME: reduction = [0, 0, 0, 0, 8]
// CHECK-SAME: subgroup = [1, 2, 2, 1, 0]
// CHECK-SAME: workgroup = [1, 64, 2, 32, 0]

0 comments on commit 91a09c7

Please sign in to comment.