diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 7188de257ca8..7ebd5fe49257 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -654,7 +654,7 @@ def TileLargeTensorsPass : ]; let options = [ Option<"maxVectorSize", "max-vector-size", "int64_t", - /*default=*/"64", + /*default=*/"128", "Maximum static size to tile to (i.e. all remaining ops will be smaller)">, ]; } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir index 66c73da981c0..5c5a102444e1 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir @@ -17,8 +17,8 @@ func.func @simple_generic(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>, %5: te // CHECK-LABEL: func.func @simple_generic // CHECK: scf.for %{{.*}} = %c0 to %c64 step %c1 -// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c64 -// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x64xf32>) +// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c128 +// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x128xf32>) // ----- @@ -79,7 +79,7 @@ func.func @multiple_use_tilable_op(%3: tensor<64x256xf32>, %4: tensor<64x256xf32 // CHECK-LABEL: func.func @multiple_use_tilable_op // CHECK: %[[ADD_TILING:.+]] = scf.for -// CHECK: linalg.add {{.*}} -> tensor<1x64xf32> +// CHECK: linalg.add {{.*}} -> tensor<1x128xf32> // CHECK: %[[T_TILING:.+]] = scf.for // CHECK: %[[FUSED_ADD:.+]] = linalg.add {{.*}} -> tensor<64x1xf32> // CHECK: linalg.transpose ins(%[[FUSED_ADD]] diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index a11935114eba..9aff3a263a9d 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -118,8 +118,8 @@ LogicalResult setDataTiledMultiMmaLoweringConfig( /// problem based on the available mma intrinsics. static std::optional getMmaScheduleFromProblemAndTarget( IREE::GPU::TargetAttr target, GPUMatmulShapeType problem, - bool transposedLhs, bool transposedRhs, bool mustBeAligned = true, - bool doCPromotion = false) { + bool transposedLhs, bool transposedRhs, bool isIGEMM, + bool mustBeAligned = true, bool doCPromotion = false) { const int64_t targetSubgroupSize = target.getPreferredSubgroupSize(); SmallVector intrinsics; for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { @@ -142,23 +142,28 @@ static std::optional getMmaScheduleFromProblemAndTarget( // See https://github.com/iree-org/iree/issues/16341 for details. int64_t mSize = ShapedType::getNumElements(problem.mSizes); int64_t nSize = ShapedType::getNumElements(problem.nSizes); + int64_t cacheLineSizeElements = kCacheLineSizeBits / inBitWidth; + int64_t bestKElementCountPerSubgroup = + isIGEMM ? cacheLineSizeElements / 2 : cacheLineSizeElements; if (mSize * nSize <= 512 * 512) { // For matmuls with small M*N size, we want to distribute M*N onto more // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup // and a larger bestKTileCountPerSubgroup. seeds = {/*bestSubgroupCountPerWorkgroup=*/4, /*bestMNTileCountPerSubgroup=*/4, - /*bestKTileCountPerSubgroup=*/8, - /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / inBitWidth}; + /*bestKTileCountPerSubgroup=*/8, bestKElementCountPerSubgroup * 2}; } else { seeds = {/*bestSubgroupCountPerWorkgroup=*/4, - /*bestMNTileCountPerSubgroup=*/16, - /*bestKTileCountPerSubgroup=*/4, - /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / 2 / - inBitWidth}; + /*bestMNTileCountPerSubgroup=*/8, + /*bestKTileCountPerSubgroup=*/4, bestKElementCountPerSubgroup}; } - int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes(); + // We target slightly below the full available shared Memory to leave room for + // `GPUReduceBankConflictsPass` that will pad shared memory without keeping + // track of usage. We can drop this after solving + // https://github.com/iree-org/iree/issues/19675 + int64_t maxSharedMemoryBytes = + target.getWgp().getMaxWorkgroupMemoryBytes() - 64 * inBitWidth; // First try to find a schedule with an exactly matching intrinsic. std::optional schedule = deduceMMASchedule( @@ -176,7 +181,8 @@ static FailureOr> getMatmulLoweringConfigAndWorkgroupSize(SmallVector bounds, ArrayRef maps, ArrayRef operands, - IREE::GPU::TargetAttr target) { + IREE::GPU::TargetAttr target, + bool isIGEMM) { if (target.getWgp().getMma().empty()) return failure(); @@ -244,7 +250,7 @@ getMatmulLoweringConfigAndWorkgroupSize(SmallVector bounds, bool mustBeAligned = true; bool doCPromotion = false; std::optional schedule = getMmaScheduleFromProblemAndTarget( - target, problem, transposedLhs, transposedRhs); + target, problem, transposedLhs, transposedRhs, isIGEMM); // TODO (nirvedhmeshram, qedawkins): The performance with this will be bad if // the GEMM is accumulating (i.e doesnt have a zero fill dpsInit) as that @@ -254,9 +260,9 @@ getMatmulLoweringConfigAndWorkgroupSize(SmallVector bounds, LDBG("Attempting to deduce unaligned TileAndFuse MMA schedulee"); mustBeAligned = false; doCPromotion = true; - schedule = getMmaScheduleFromProblemAndTarget(target, problem, - transposedLhs, transposedRhs, - mustBeAligned, doCPromotion); + schedule = getMmaScheduleFromProblemAndTarget( + target, problem, transposedLhs, transposedRhs, isIGEMM, mustBeAligned, + doCPromotion); } if (!schedule) { @@ -379,7 +385,8 @@ setIGEMMConvolutionLoweringConfig(IREE::GPU::TargetAttr target, SmallVector bounds = igemmLoopBounds.value(); FailureOr> configAndWgSize = getMatmulLoweringConfigAndWorkgroupSize( - bounds, igemmContractionMaps.value(), igemmOperands.value(), target); + bounds, igemmContractionMaps.value(), igemmOperands.value(), target, + /*isIGEMM=*/true); if (failed(configAndWgSize)) { return failure(); } @@ -422,7 +429,8 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, LDBG("Matmul TileAndFuse Config"); FailureOr> configAndWgSize = - getMatmulLoweringConfigAndWorkgroupSize(bounds, maps, operands, target); + getMatmulLoweringConfigAndWorkgroupSize(bounds, maps, operands, target, + /*isIGEMM=*/false); if (failed(configAndWgSize)) { return failure(); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index e09a72b52df5..8687bf5238d3 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -48,10 +48,10 @@ #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") namespace mlir::iree_compiler { -llvm::cl::opt clGPUTestTileAndFuseMatmul( - "iree-codegen-llvmgpu-test-tile-and-fuse-matmul", +llvm::cl::opt clGPUEnableTileAndFuseMatmul( + "iree-codegen-llvmgpu-enable-tile-and-fuse-matmul", llvm::cl::desc("test the the tile and fuse pipeline for matmul"), - llvm::cl::init(false)); + llvm::cl::init(true)); llvm::cl::opt clGPUTestTileAndFuseVectorize( "iree-codegen-llvmgpu-test-tile-and-fuse-vectorize", @@ -620,6 +620,9 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, /*canUpcastAcc=*/true); } + LDBG("transposedLhs: " << transposedLhs); + LDBG("transposedRhs: " << transposedRhs); + // Only batch_matmul is supported in the LLVMGPUPadAndVectorDistribute // pipeline. // TODO(hanchung): Support cases that there are fused producers. @@ -2352,7 +2355,7 @@ static LogicalResult setRootConfig(IREE::GPU::TargetAttr target, LDBG("Tile and fuse data tiled multi_mma config"); return success(); } - if (clGPUTestTileAndFuseMatmul) { + if (clGPUEnableTileAndFuseMatmul) { if (succeeded(IREE::GPU::setMatmulLoweringConfig(target, entryPointFn, computeOp))) { LDBG("Tile and fuse matmul config"); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir index 6f94069f2c6d..6cdb721c06c3 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir @@ -1,5 +1,5 @@ // RUN: iree-opt --mlir-print-local-scope --split-input-file --iree-gpu-test-target=gfx942 \ -// RUN: --iree-codegen-llvmgpu-test-tile-and-fuse-matmul=true --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \ +// RUN: --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \ // RUN: --iree-codegen-llvmgpu-use-igemm=false \ // RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s @@ -39,7 +39,7 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor // CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: mma_kind = #iree_gpu.mma_layout // CHECK-SAME: promote_operands = [0, 1] -// CHECK-SAME: reduction = [0, 0, 0, 0, 4] +// CHECK-SAME: reduction = [0, 0, 0, 0, 8] // CHECK-SAME: subgroup = [1, 1, 4, 1, 0] // CHECK-SAME: workgroup = [1, 1, 64, 64, 0] @@ -74,7 +74,7 @@ func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4 // CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: mma_kind = #iree_gpu.mma_layout // CHECK-SAME: promote_operands = [0, 1] -// CHECK-SAME: reduction = [0, 0, 0, 0, 4, 1] +// CHECK-SAME: reduction = [0, 0, 0, 0, 8, 1] // CHECK-SAME: subgroup = [2, 2, 1, 1, 0, 0] // CHECK-SAME: workgroup = [2, 2, 32, 32, 0, 0] @@ -136,9 +136,9 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor< // CHECK: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: mma_kind = #iree_gpu.mma_layout // CHECK-SAME: promote_operands = [0, 1] -// CHECK-SAME: reduction = [0, 0, 2] -// CHECK-SAME: subgroup = [4, 4, 0] -// CHECK-SAME: workgroup = [128, 128, 0] +// CHECK-SAME: reduction = [0, 0, 4] +// CHECK-SAME: subgroup = [2, 4, 0] +// CHECK-SAME: workgroup = [64, 128, 0] // ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir index 3198f1592bdd..e42bdc266742 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir @@ -1,4 +1,5 @@ // RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 --iree-codegen-llvmgpu-use-vector-distribution \ +// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false \ // RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s --check-prefix=WMMA // TODO: This test is still using the legacy LLVMGPU kernel config. This needs diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx942.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx942.mlir index 373b67b04e8f..f43881200de0 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx942.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx942.mlir @@ -1,5 +1,6 @@ // RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --iree-codegen-llvmgpu-use-vector-distribution \ // RUN: --iree-codegen-llvmgpu-use-unaligned-gemm-vector-distribution --iree-codegen-llvmgpu-use-igemm=false \ +// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false \ // RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s // TODO: This test is still using the legacy LLVMGPU kernel config. This needs diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir index 1a521e61ebd9..525086f8b5bd 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir @@ -1013,10 +1013,8 @@ hal.executable public @main { // CHECK: scf.yield %[[REDUCE]] // CHECK: scf.for %{{.*}} = %{{.*}} to %c16 step %c1 -// CHECK: scf.for -// CHECK-COUNT-4: arith.addf {{.*}} : vector<9xf32> -// CHECK: vector.transfer_write {{.*}} vector<9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type> - +// CHECK-COUNT-4: arith.addf {{.*}} : vector<9x9xf32> +// CHECK: vector.transfer_write {{.*}} vector<9x9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type> // ----- #pipeline_layout = #hal.pipeline.layout, %arg1 : tensor<512x128xf32>, return %1 : tensor<384x128xf32> } // CHECK: #[[CONFIG:.+]] = #iree_codegen.lowering_config -// CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info, promote_operands = [0, 1], reduction = [0, 0, 32], subgroup_m_count = 2 : i64, subgroup_n_count = 2 : i64, workgroup = [64, 64, 0]}> +// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout, promote_operands = [0, 1], reduction = [0, 0, 16], subgroup = [2, 2, 0], workgroup = [64, 64, 0]}> // CHECK: iree_linalg_ext.yield // ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir index 642c6ed1a179..cd93a7b0268b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir @@ -1,7 +1,7 @@ // RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-codegen-llvmgpu-configuration-pipeline)" \ -// RUN: --iree-gpu-test-target=sm_60 %s | FileCheck %s +// RUN: --iree-gpu-test-target=sm_60 --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false %s | FileCheck %s // RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-codegen-llvmgpu-configuration-pipeline)" \ -// RUN: --iree-gpu-test-target=sm_80 %s | FileCheck %s --check-prefix=SM80 +// RUN: --iree-gpu-test-target=sm_80 --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false %s | FileCheck %s --check-prefix=SM80 // Transform dialect attributes are tested separately. diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir index be163903800e..8ea01ff2b54e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir @@ -1,4 +1,6 @@ -// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))' --iree-gpu-test-target=sm_80 -split-input-file %s -o - | FileCheck %s +// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant( \ +// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))' \ +// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false --iree-gpu-test-target=sm_80 -split-input-file %s -o - | FileCheck %s // This test checks that the lowering of nvvm includes the extraction // and optimization of address computations. diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir index c0cd53377863..2065390cd199 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir @@ -1,4 +1,7 @@ -// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" -iree-codegen-llvmgpu-use-mma-sync %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 \ +// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant( \ +// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" \ +// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false -iree-codegen-llvmgpu-use-mma-sync %s | FileCheck %s // Verify that a simple element wise op gets lowered succefully all the way to // nvvm/llvm dialect via mma.sync path. diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir index ad6aad32420c..b210a806ae3a 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir @@ -1,5 +1,11 @@ -// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_60 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s -// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s --check-prefix=SM80 +// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_60 \ +// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant( \ +// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" \ +// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 \ +// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant( \ +// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" \ +// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s --check-prefix=SM80 // Verify that a simple element wise op gets lowered succefully all the way to // nvvm/llvm dialect.