[GPU] Use padding in IGEMM pipeline to support unaligned to intrinsic shapes (#19484)

This PR does two things:
1. Allow all GEMM shapes to use the padded TileAndFuse Matmul configuration. This is still gated behind the `iree-codegen-llvmgpu-test-tile-and-fuse-matmul` flag (which defaults to `false`), so the default behavior does not change. However, the following PRs that have landed in the past month make it possible to relax the guards we originally had on this (a sketch of how such a flag gate might look follows this list):
#19196
#19307
llvm/llvm-project#117340
2. Allow fused producers to use the padded TileAndFuse Matmul configuration. The following PRs make this possible now:
#19399
llvm/llvm-project#119039
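
A minimal sketch of how such a test flag can gate the padded path, assuming only the flag name quoted in item 1; the `cl::opt` declaration and the helper around it are illustrative, not the actual IREE source:

```cpp
// Sketch only: the flag name comes from the PR description above; the helper
// is hypothetical and does not mirror the real IREE implementation.
#include "llvm/Support/CommandLine.h"

static llvm::cl::opt<bool> clTestTileAndFuseMatmul(
    "iree-codegen-llvmgpu-test-tile-and-fuse-matmul",
    llvm::cl::desc("Test the padded TileAndFuse Matmul configuration"),
    llvm::cl::init(false)); // off by default, so default behavior is unchanged

// Hypothetical gate: only attempt the padded TileAndFuse Matmul configuration
// when the test flag is set; otherwise fall back to the existing pipelines.
static bool shouldTryTileAndFuseMatmul() { return clTestTileAndFuseMatmul; }
```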

Together, these changes allow us to do padded IGEMM with intrinsics for shapes that are unaligned to the intrinsic, which we use by default.
[Here](https://docs.google.com/spreadsheets/d/1O-SdUZCn5pHsxx7JTGjIIdH6PWCFnvlfe4XBbjEBaIM/edit?gid=0#gid=0)
is the performance difference observed in conv cases in iree-kernel-benchmark-module that utilize this change; a median speedup of 2.26x was observed.
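
For intuition, here is a minimal sketch of the padding arithmetic (illustrative only; the real logic lives in the lowering-config deduction shown in the diff below): an unaligned dimension is rounded up to the next multiple of the intrinsic tile so the MFMA intrinsic can still be used.

```cpp
// Illustrative only: round an unaligned problem dimension up to the next
// multiple of the intrinsic tile size; the remainder is handled via padding.
#include <cstdint>
#include <iostream>

static int64_t padToIntrinsic(int64_t dim, int64_t intrinsicSize) {
  return ((dim + intrinsicSize - 1) / intrinsicSize) * intrinsicSize;
}

int main() {
  // e.g. the 1281-wide channel dimension from the test below, with a 16-wide
  // MFMA tile, gets padded up to 1296.
  std::cout << padToIntrinsic(1281, 16) << "\n"; // prints 1296
}
```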

The numeric changes I observed when enabling this path are the same as those seen for any aligned shape when comparing intrinsic vs. no-intrinsic use. Some differences are generally noticed for narrow types like f16, but they stay within a relative error of 0.001; since our tests use absolute errors, we may have to change some test values to account for this.
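
As a reference for the tolerance discussion above, a small sketch of the two kinds of checks (the function names and default threshold here are illustrative, not IREE test infrastructure):

```cpp
// Sketch of absolute vs. relative error checks; names are illustrative.
#include <cmath>

// Absolute tolerance: a fixed bound, which is what our tests currently use
// and why some expected values may need updating.
static bool withinAbsoluteError(float actual, float expected, float atol) {
  return std::fabs(actual - expected) <= atol;
}

// Relative tolerance: the differences observed for narrow types like f16
// stay within a relative error of 0.001.
static bool withinRelativeError(float actual, float expected,
                                float rtol = 1e-3f) {
  return std::fabs(actual - expected) <= rtol * std::fabs(expected);
}
```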

The perf differences in CI seem to be within the noise margin compared to main:
https://github.com/iree-org/iree/actions/runs/12323399269/attempts/1#summary-34399247902

---------

Signed-off-by: Nirvedh <nirvedh@gmail.com>
nirvedhmeshram authored Dec 18, 2024
1 parent 78ea0ad commit 8ae1b54
Showing 4 changed files with 112 additions and 18 deletions.
@@ -182,8 +182,7 @@ static FailureOr<std::pair<LoweringConfigAttr, int64_t>>
getMatmulLoweringConfigAndWorkgroupSize(SmallVector<int64_t> bounds,
ArrayRef<AffineMap> maps,
ArrayRef<Value> operands,
IREE::GPU::TargetAttr target,
bool hasFusedLeadingOp) {
IREE::GPU::TargetAttr target) {
if (target.getWgp().getMma().empty())
return failure();

@@ -253,13 +252,11 @@ getMatmulLoweringConfigAndWorkgroupSize(SmallVector<int64_t> bounds,
std::optional<GPUMMASchedule> schedule = getMmaScheduleFromProblemAndTarget(
target, problem, transposedLhs, transposedRhs);

// TODO (nirvedhmeshram, jerryyin): Support all GEMM types.
// TODO (nirvedhmeshram): Support fused leading op.
// TODO (nirvedhmeshram, qedawkins): The performance with this will be bad if
// the GEMM is accumulating (i.e doesnt have a zero fill dpsInit) as that
// buffer currently gets materialized as private memory. We need to add
// missing patterns to fix that.
if (!schedule && !contractionDims.batch.empty() && !hasFusedLeadingOp) {
if (!schedule) {
LDBG("Attempting to deduce unaligned TileAndFuse MMA schedulee");
mustBeAligned = false;
doCPromotion = true;
@@ -342,9 +339,6 @@ getMatmulLoweringConfigAndWorkgroupSize(SmallVector<int64_t> bounds,
} else {
// TODO (nirvedhmeshram, Max191, jerryyin) : Add support so that unaligned
// shapes do not require c promotion.
// TODO (nirvedhmeshram, jerryyin) : When using c promotion the heuristics
// used during finding a schedule need to be updated to account for the
// extra shared memory for the result.
GPU::setPromotedOperandList(context, attrs, {0, 1, 2});
SmallVector<int64_t> paddingTileSizes = workgroupTileSizes;
int64_t innerKDim = contractionDims.k.back();
@@ -391,8 +385,7 @@ setIGEMMConvolutionLoweringConfig(IREE::GPU::TargetAttr target,
SmallVector<int64_t> bounds = igemmLoopBounds.value();
FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
getMatmulLoweringConfigAndWorkgroupSize(
bounds, igemmContractionMaps.value(), igemmOperands.value(), target,
/*hasFusedLeadingOp=*/true);
bounds, igemmContractionMaps.value(), igemmOperands.value(), target);
if (failed(configAndWgSize)) {
return failure();
}
@@ -435,8 +428,7 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
LDBG("Matmul TileAndFuse Config");

FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
getMatmulLoweringConfigAndWorkgroupSize(bounds, maps, operands, target,
hasFusedLeadingOp(linalgOp));
getMatmulLoweringConfigAndWorkgroupSize(bounds, maps, operands, target);
if (failed(configAndWgSize)) {
return failure();
}
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1033,6 +1033,8 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,
// Pad allocations with dynamic dimension after linalg lowering but before
// lowering SCF and affine ops.
.addPass(createPadDynamicAllocPass)
// Hoist any newly static allocations from PadDynamicAlloc.
.addPass(createHoistStaticallyBoundAllocationsPass)

.addPass(createLowerAffinePass)
.addPass(createCanonicalizerPass)
@@ -59,7 +59,7 @@ func.func @nchw_conv_mfma() {

// -----

func.func @nhwc_conv_no_mfma() {
func.func @nhwc_conv_unaligned_mfma() {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<2x33x33x128xf32>>
@@ -74,12 +74,22 @@ func.func @nhwc_conv_no_mfma() {
return
}

// CHECK-LABEL: func.func @nhwc_conv_no_mfma
// CHECK-NOT: use_igemm_convolution = true
// CHECK-LABEL: func.func @nhwc_conv_unaligned_mfma
// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
// CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
// CHECK-SAME: use_igemm_convolution = true

// CHECK: linalg.conv_2d_nhwc_hwcf {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
// CHECK-SAME: padding = [2, 1, 32, 64, 32]
// CHECK-SAME: promote_operands = [0, 1, 2]
// CHECK-SAME: reduction = [0, 0, 0, 0, 8]
// CHECK-SAME: subgroup = [2, 1, 2, 1, 0]
// CHECK-SAME: workgroup = [2, 1, 32, 64, 0]

// -----

func.func @nchw_conv_no_mfma() {
func.func @nchw_conv_unaligned_mfma() {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<2x128x34x34xf32>>
Expand All @@ -94,5 +104,15 @@ func.func @nchw_conv_no_mfma() {
return
}

// CHECK-LABEL: func.func @nchw_conv_no_mfma
// CHECK-NOT: use_igemm_convolution = true
// CHECK-LABEL: func.func @nchw_conv_unaligned_mfma
// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64
// CHECK-SAME: #iree_gpu.pipeline_options<prefetch_shared_memory = true, no_reduce_shared_memory_bank_conflicts = false
// CHECK-SAME: use_igemm_convolution = true

// CHECK: linalg.conv_2d_nchw_fchw {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
// CHECK-SAME: padding = [1, 64, 2, 32, 32]
// CHECK-SAME: promote_operands = [0, 1, 2]
// CHECK-SAME: reduction = [0, 0, 0, 0, 8]
// CHECK-SAME: subgroup = [1, 2, 2, 1, 0]
// CHECK-SAME: workgroup = [1, 64, 2, 32, 0]
@@ -78,3 +78,83 @@ hal.executable private @main {
// CHECK: } {mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}

// TODO(Max191): Add tests for more convolution types

// -----

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer, ReadOnly>,
#hal.pipeline.binding<storage_buffer, ReadOnly>,
#hal.pipeline.binding<storage_buffer>
]>
#translation = #iree_codegen.translation_info<pipeline =
LLVMGPUTileAndFuse
workgroup_size = [256, 1, 1]
subgroup_size = 64,
{
gpu_pipeline_options = #iree_gpu.pipeline_options<
prefetch_shared_memory = false,
no_reduce_shared_memory_bank_conflicts = false,
use_igemm_convolution = true>
}>
#config = #iree_gpu.lowering_config<{
padding = [2, 1, 32, 16, 16],
workgroup = [2, 1, 32, 16, 0],
reduction = [0, 0, 0, 0, 1],
subgroup = [1, 1, 1, 1, 0],
mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
promote_operands = [0, 1, 2]
}>
hal.executable private @main {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export public @conv_dispatch_0_conv_2d_nhwc_hwcf_2x17x17x1281x3x3x1281_f16xf16xf32 ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @conv_nhwc_unaligned_stride_2() attributes {translation_info = #translation} {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<2x35x35x1281xf16>>
%1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<3x3x1281x1281xf16>>
%2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<2x17x17x1281xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 35, 35, 1281], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x35x35x1281xf16>> -> tensor<2x35x35x1281xf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 1281, 1281], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<3x3x1281x1281xf16>> -> tensor<3x3x1281x1281xf16>
%5 = tensor.empty() : tensor<2x17x17x1281xf32>
%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x17x17x1281xf32>) -> tensor<2x17x17x1281xf32>
%7 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, lowering_config = #config, strides = dense<2> : tensor<2xi64>} ins(%3, %4 : tensor<2x35x35x1281xf16>, tensor<3x3x1281x1281xf16>) outs(%6 : tensor<2x17x17x1281xf32>) -> tensor<2x17x17x1281xf32>
flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 17, 17, 1281], strides = [1, 1, 1, 1] : tensor<2x17x17x1281xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x17x17x1281xf32>>
return
}
}
}
}

// CHECK-LABEL: func @conv_nhwc_unaligned
// CHECK-DAG: %[[B0:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0)
// CHECK-DAG: %[[B1:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1)
// CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
// CHECK-DAG: memref.alloc() : memref<2x1x2x16x1x16xf32, #gpu.address_space<workgroup>>
// CHECK-DAG: memref.alloc() : memref<16x20xf16, #gpu.address_space<workgroup>>
// CHECK-DAG: memref.alloc() : memref<2x1x32x20xf16, #gpu.address_space<workgroup>>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C721:.+]] = arith.constant 721 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK: scf.forall ({{.*}}) in (17, 81) {
// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C721]] step %[[C1]] {{.*}} -> (vector<1x1x1x1x4x1xf32>)
// CHECK: gpu.barrier
// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<1xf16>
// CHECK-DAG: vector.transfer_write %[[LHS_RD]]
// Note that to simplify the test we are not showing the mapping of the RHS_RD
// to its buffer as it goes through an scf.if/else control structure
// involving allocas.
// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read {{.*}} vector<1xf16>
// CHECK-DAG: vector.transfer_write %[[RHS_RD]]
// CHECK: gpu.barrier
// CHECK-DAG: %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<4xf16>
// CHECK-DAG: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<4x1x1xf16>
// CHECK-COUNT-1: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
// CHECK: %[[LOOP_T:.+]] = vector.shape_cast %[[LOOP]] : vector<1x1x1x1x4x1xf32> to vector<4x1x1xf32>
// CHECK: vector.transfer_write %[[LOOP_T]]
// Note there is a writeback loop here that is skipped to simplify the test.
// CHECK: vector.transfer_write {{.*}}, %[[B2]]
// CHECK: } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
