From 1a0ecbf7ceca5099223adac4934d7d56f44ec6f3 Mon Sep 17 00:00:00 2001 From: Nirvedh Date: Wed, 18 Dec 2024 14:46:10 -0600 Subject: [PATCH 1/6] [GPU] Enable tile and fuse matmul by default Signed-off-by: Nirvedh Signed-off-by: Nirvedh Meshram --- .../Dialect/GPU/TargetUtils/ConfigUtils.cpp | 18 ++++++++++++++++-- .../compiler/Codegen/LLVMGPU/KernelConfig.cpp | 8 ++++---- .../test/ROCDL/config_igemm_tile_and_fuse.mlir | 4 ++-- .../test/ROCDL/config_tile_and_fuse.mlir | 6 +++--- .../config_vector_distribute_gfx1100.mlir | 1 + .../ROCDL/config_vector_distribute_gfx942.mlir | 1 + .../Codegen/LLVMGPU/test/config_custom_op.mlir | 4 ++-- .../LLVMGPU/test/gpu_set_num_workgroups.mlir | 4 ++-- .../test/nvvm_extract_address_computation.mlir | 4 +++- .../test/nvvm_mma_sync_pipeline_test.mlir | 5 ++++- .../LLVMGPU/test/nvvm_pipeline_test.mlir | 10 ++++++++-- 11 files changed, 46 insertions(+), 19 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index a11935114eba..c1a4896f5430 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -386,9 +386,16 @@ setIGEMMConvolutionLoweringConfig(IREE::GPU::TargetAttr target, std::array workgroupSize = {configAndWgSize->second, 1, 1}; LoweringConfigAttr loweringConfig = configAndWgSize->first; + bool usePrefetchSharedMemory = true; + // Prefetching has issues when doing c promotion, see + // https://github.com/iree-org/iree/issues/19612. + if (llvm::any_of(getPromotedOperandList(loweringConfig).value(), + [](int64_t promote) { return promote == 2; })) { + usePrefetchSharedMemory = false; + } SmallVector pipelineAttrs; auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get( - linalgOp->getContext(), /*prefetchSharedMemory=*/true, + linalgOp->getContext(), /*prefetchSharedMemory=*/usePrefetchSharedMemory, /*no_reduce_shared_memory_bank_conflicts=*/false, /*use_igemm_convolution=*/true, /*reorder_workgroups_strategy=*/std::nullopt); @@ -429,9 +436,16 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, std::array workgroupSize = {configAndWgSize->second, 1, 1}; LoweringConfigAttr loweringConfig = configAndWgSize->first; + bool usePrefetchSharedMemory = true; + // Prefetching has issues when doing c promotion, see + // https://github.com/iree-org/iree/issues/19612. 
+ if (llvm::any_of(getPromotedOperandList(loweringConfig).value(), + [](int64_t promote) { return promote == 2; })) { + usePrefetchSharedMemory = false; + } SmallVector pipelineAttrs; auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get( - linalgOp->getContext(), /*prefetchSharedMemory=*/true, + linalgOp->getContext(), /*prefetchSharedMemory=*/usePrefetchSharedMemory, /*no_reduce_shared_memory_bank_conflicts=*/false, /*use_igemm_convolution=*/false, /*reorder_workgroups_strategy=*/std::nullopt); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index e09a72b52df5..2ba637f9889c 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -48,10 +48,10 @@ #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") namespace mlir::iree_compiler { -llvm::cl::opt<bool> clGPUTestTileAndFuseMatmul( - "iree-codegen-llvmgpu-test-tile-and-fuse-matmul", +llvm::cl::opt<bool> clGPUEnableTileAndFuseMatmul( + "iree-codegen-llvmgpu-enable-tile-and-fuse-matmul", llvm::cl::desc("test the tile and fuse pipeline for matmul"), - llvm::cl::init(false)); + llvm::cl::init(true)); llvm::cl::opt<bool> clGPUTestTileAndFuseVectorize( "iree-codegen-llvmgpu-test-tile-and-fuse-vectorize", @@ -2352,7 +2352,7 @@ static LogicalResult setRootConfig(IREE::GPU::TargetAttr target, LDBG("Tile and fuse data tiled multi_mma config"); return success(); } - if (clGPUTestTileAndFuseMatmul) { + if (clGPUEnableTileAndFuseMatmul) { if (succeeded(IREE::GPU::setMatmulLoweringConfig(target, entryPointFn, computeOp))) { LDBG("Tile and fuse matmul config"); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir index cf170ef7d930..d8af22e58664 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir @@ -76,7 +76,7 @@ func.func @nhwc_conv_unaligned_mfma() { // CHECK-LABEL: func.func @nhwc_conv_unaligned_mfma // CHECK-SAME: #iree_codegen.translation_info, // CHECK-LABEL: func.func @unaligned_to_intrinsic_batched_matmul // CHECK-SAME: #iree_codegen.translation_info} +// CHECK-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options} // CHECK: linalg.batch_matmul {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: padding = [1, 16, 16, 4] // CHECK-SAME: promote_operands = [0, 1, 2] @@ -306,7 +306,7 @@ func.func @unaligned_to_intrinsic_batched_matmul_tiling_check(%lhs : tensor<12x5 // CHECK-LABEL: func.func @unaligned_to_intrinsic_batched_matmul_tiling_check // CHECK-SAME: #iree_codegen.translation_info} +// CHECK-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options} // CHECK: linalg.batch_matmul {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: padding = [1, 16, 512, 4] // CHECK-SAME: promote_operands = [0, 1, 2] diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir index 3198f1592bdd..e42bdc266742 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx1100.mlir @@ -1,4 +1,5 @@ // RUN: iree-opt --split-input-file 
--iree-gpu-test-target=gfx1100 --iree-codegen-llvmgpu-use-vector-distribution \ +// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false \ // RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s --check-prefix=WMMA // TODO: This test is still using the legacy LLVMGPU kernel config. This needs diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx942.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx942.mlir index 373b67b04e8f..f43881200de0 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx942.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx942.mlir @@ -1,5 +1,6 @@ // RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --iree-codegen-llvmgpu-use-vector-distribution \ // RUN: --iree-codegen-llvmgpu-use-unaligned-gemm-vector-distribution --iree-codegen-llvmgpu-use-igemm=false \ +// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false \ // RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s // TODO: This test is still using the legacy LLVMGPU kernel config. This needs diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir index 62ccec73c67a..bea2f2abe738 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir @@ -33,14 +33,14 @@ func.func @custom_op(%arg0 : tensor<384x512xf32>, %arg1 : tensor<512x128xf32>, return %1 : tensor<384x128xf32> } // CHECK: #[[CONFIG:.+]] = #iree_codegen.lowering_config -// CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info, promote_operands = [0, 1], reduction = [0, 0, 32], subgroup_m_count = 2 : i64, subgroup_n_count = 2 : i64, workgroup = [64, 64, 0]}> +// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout, promote_operands = [0, 1], reduction = [0, 0, 8], subgroup = [2, 2, 0], workgroup = [64, 64, 0]}> // CHECK: iree_linalg_ext.yield // ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir index 642c6ed1a179..cd93a7b0268b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir @@ -1,7 +1,7 @@ // RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-codegen-llvmgpu-configuration-pipeline)" \ -// RUN: --iree-gpu-test-target=sm_60 %s | FileCheck %s +// RUN: --iree-gpu-test-target=sm_60 --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false %s | FileCheck %s // RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-codegen-llvmgpu-configuration-pipeline)" \ -// RUN: --iree-gpu-test-target=sm_80 %s | FileCheck %s --check-prefix=SM80 +// RUN: --iree-gpu-test-target=sm_80 --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false %s | FileCheck %s --check-prefix=SM80 // Transform dialect attributes are tested separately. 
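The ConfigUtils.cpp hunks at the top of this patch gate shared-memory prefetching on whether the accumulator operand is promoted. A minimal, self-contained sketch of that predicate follows (illustrative names only — the real code queries getPromotedOperandList() on the lowering config rather than taking a plain vector):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Operand indices follow the contraction convention: 0 = LHS, 1 = RHS,
    // 2 = accumulator (C). Prefetching currently miscompiles when C is
    // promoted (https://github.com/iree-org/iree/issues/19612), so it is
    // disabled whenever index 2 appears among the promoted operands.
    bool shouldPrefetchSharedMemory(const std::vector<int64_t> &promotedOperands) {
      return std::none_of(promotedOperands.begin(), promotedOperands.end(),
                          [](int64_t index) { return index == 2; });
    }

C promotion only occurs on the unaligned fallback path (where doCPromotion is set together with mustBeAligned = false, visible in patch 3 below), so aligned matmuls keep prefetching enabled.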
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir index be163903800e..8ea01ff2b54e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir @@ -1,4 +1,6 @@ -// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))' --iree-gpu-test-target=sm_80 -split-input-file %s -o - | FileCheck %s +// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant( \ +// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))' \ +// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false --iree-gpu-test-target=sm_80 -split-input-file %s -o - | FileCheck %s // This test checks that the lowering of nvvm includes the extraction // and optimization of address computations. diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir index c0cd53377863..2065390cd199 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_mma_sync_pipeline_test.mlir @@ -1,4 +1,7 @@ -// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" -iree-codegen-llvmgpu-use-mma-sync %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 \ +// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant( \ +// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" \ +// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false -iree-codegen-llvmgpu-use-mma-sync %s | FileCheck %s // Verify that a simple element wise op gets lowered successfully all the way to // nvvm/llvm dialect via mma.sync path. 
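As in the two files above, the remaining legacy NVVM tests opt out of the new default so they keep exercising the old kernel-config path. The same flag works on any iree-opt invocation; a representative command is shown below (the target and input file name are placeholders, while the flag and pass pipeline appear verbatim in the RUN lines of this series):

    iree-opt --split-input-file \
      --iree-gpu-test-target=gfx942 \
      --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false \
      --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" \
      input.mlir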
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir index ad6aad32420c..b210a806ae3a 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_pipeline_test.mlir @@ -1,5 +1,11 @@ -// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_60 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s -// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s --check-prefix=SM80 +// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_60 \ +// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant( \ +// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" \ +// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s +// RUN: iree-opt --split-input-file --iree-gpu-test-target=sm_80 \ +// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant( \ +// RUN: builtin.module(iree-codegen-llvmgpu-configuration-pipeline), iree-codegen-linalg-to-nvvm-pipeline)))" \ +// RUN: --iree-codegen-llvmgpu-enable-tile-and-fuse-matmul=false -iree-codegen-llvmgpu-use-wmma %s | FileCheck %s --check-prefix=SM80 // Verify that a simple element wise op gets lowered successfully all the way to // nvvm/llvm dialect. From 365265df09cf6a1027d685b052e777f6c125d8a2 Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram Date: Fri, 10 Jan 2025 11:17:58 -0600 Subject: [PATCH 2/6] Match TileAndFuse Matmul Heuristics to VectorDistribute and raise limit of TileLargeTensorsPass Signed-off-by: Nirvedh Meshram --- .../iree/compiler/Codegen/Common/Passes.td | 2 +- .../Common/test/tile_large_tensors.mlir | 30 +++++++++---------- .../Dialect/GPU/TargetUtils/ConfigUtils.cpp | 13 +++++--- .../compiler/Codegen/LLVMGPU/KernelConfig.cpp | 3 ++ .../ROCDL/config_igemm_tile_and_fuse.mlir | 12 ++++---- .../test/ROCDL/config_tile_and_fuse.mlir | 6 ++-- .../test/ROCDL/pipeline_tile_and_fuse.mlir | 5 ++-- .../LLVMGPU/test/config_custom_op.mlir | 2 +- 8 files changed, 40 insertions(+), 33 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 7188de257ca8..245b07f6deaa 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -654,7 +654,7 @@ def TileLargeTensorsPass : ]; let options = [ Option<"maxVectorSize", "max-vector-size", "int64_t", - /*default=*/"64", + /*default=*/"256", "Maximum static size to tile to (i.e. 
all remaining ops will be smaller)">, ]; } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir index 66c73da981c0..3bb51a2d6d0c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir @@ -3,22 +3,22 @@ // RUN: FileCheck %s #map = affine_map<(d0, d1) -> (d0, d1)> -func.func @simple_generic(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>, %5: tensor<64x256xf32>) -> tensor<64x256xf32> { +func.func @simple_generic(%3: tensor<64x512xf32>, %4: tensor<64x512xf32>, %5: tensor<64x512xf32>) -> tensor<64x512xf32> { %6 = linalg.generic { indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"] - } ins(%3, %4 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%5 : tensor<64x256xf32>) { + } ins(%3, %4 : tensor<64x512xf32>, tensor<64x512xf32>) outs(%5 : tensor<64x512xf32>) { ^bb0(%in: f32, %in_0: f32, %out: f32): %7 = arith.addf %in, %in_0 : f32 linalg.yield %7 : f32 - } -> tensor<64x256xf32> - return %6 : tensor<64x256xf32> + } -> tensor<64x512xf32> + return %6 : tensor<64x512xf32> } // CHECK-LABEL: func.func @simple_generic // CHECK: scf.for %{{.*}} = %c0 to %c64 step %c1 -// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c64 -// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x64xf32>) +// CHECK: scf.for %{{.*}} = %c0 to %c512 step %c256 +// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x256xf32>) // ----- @@ -65,21 +65,21 @@ func.func @in_nested_region(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: te // ----- -func.func @multiple_use_tilable_op(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>) -> (tensor<64x256xf32>, tensor<256x64xf32>) { - %add_empty = tensor.empty() : tensor<64x256xf32> +func.func @multiple_use_tilable_op(%3: tensor<64x512xf32>, %4: tensor<64x512xf32>) -> (tensor<64x512xf32>, tensor<512x64xf32>) { + %add_empty = tensor.empty() : tensor<64x512xf32> %6 = linalg.add - ins(%3, %4 : tensor<64x256xf32>, tensor<64x256xf32>) - outs(%add_empty : tensor<64x256xf32>) -> tensor<64x256xf32> - %transpose_empty = tensor.empty() : tensor<256x64xf32> + ins(%3, %4 : tensor<64x512xf32>, tensor<64x512xf32>) + outs(%add_empty : tensor<64x512xf32>) -> tensor<64x512xf32> + %transpose_empty = tensor.empty() : tensor<512x64xf32> %7 = linalg.transpose - ins(%6 : tensor<64x256xf32>) - outs(%transpose_empty : tensor<256x64xf32>) permutation = [1, 0] - return %6, %7 : tensor<64x256xf32>, tensor<256x64xf32> + ins(%6 : tensor<64x512xf32>) + outs(%transpose_empty : tensor<512x64xf32>) permutation = [1, 0] + return %6, %7 : tensor<64x512xf32>, tensor<512x64xf32> } // CHECK-LABEL: func.func @multiple_use_tilable_op // CHECK: %[[ADD_TILING:.+]] = scf.for -// CHECK: linalg.add {{.*}} -> tensor<1x64xf32> +// CHECK: linalg.add {{.*}} -> tensor<1x256xf32> // CHECK: %[[T_TILING:.+]] = scf.for // CHECK: %[[FUSED_ADD:.+]] = linalg.add {{.*}} -> tensor<64x1xf32> // CHECK: linalg.transpose ins(%[[FUSED_ADD]] diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index c1a4896f5430..a07887a7f35f 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -149,16 +149,21 @@ static std::optional getMmaScheduleFromProblemAndTarget( seeds = {/*bestSubgroupCountPerWorkgroup=*/4, 
/*bestMNTileCountPerSubgroup=*/4, /*bestKTileCountPerSubgroup=*/8, - /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / inBitWidth}; + /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits * 2 / + inBitWidth}; } else { seeds = {/*bestSubgroupCountPerWorkgroup=*/4, /*bestMNTileCountPerSubgroup=*/16, /*bestKTileCountPerSubgroup=*/4, - /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / 2 / - inBitWidth}; + /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / inBitWidth}; } - int64_t maxSharedMemoryBytes = target.getWgp().getMaxWorkgroupMemoryBytes(); + // We target slightly below the full available shared Memory to leave room for + // `GPUReduceBankConflictsPass` that will pad shared memory without keeping + // track of usage. We can drop this after solving + // https://github.com/iree-org/iree/issues/19675 + int64_t maxSharedMemoryBytes = + target.getWgp().getMaxWorkgroupMemoryBytes() - 64 * inBitWidth; // First try to find a schedule with an exactly matching intrinsic. std::optional schedule = deduceMMASchedule( diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index 2ba637f9889c..8687bf5238d3 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -620,6 +620,9 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, /*canUpcastAcc=*/true); } + LDBG("transposedLhs: " << transposedLhs); + LDBG("transposedRhs: " << transposedRhs); + // Only batch_matmul is supported in the LLVMGPUPadAndVectorDistribute // pipeline. // TODO(hanchung): Support cases that there are fused producers. diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir index d8af22e58664..c1be57b94903 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir @@ -24,7 +24,7 @@ func.func @nhwc_conv_mfma() { // CHECK: linalg.conv_2d_nhwc_hwcf {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: mma_kind = #iree_gpu.mma_layout // CHECK-SAME: promote_operands = [0, 1] -// CHECK-SAME: reduction = [0, 0, 0, 0, 8] +// CHECK-SAME: reduction = [0, 0, 0, 0, 16] // CHECK-SAME: subgroup = [1, 2, 2, 1, 0] // CHECK-SAME: workgroup = [1, 2, 32, 64, 0] @@ -53,7 +53,7 @@ func.func @nchw_conv_mfma() { // CHECK: linalg.conv_2d_nchw_fchw {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: mma_kind = #iree_gpu.mma_layout // CHECK-SAME: promote_operands = [0, 1] -// CHECK-SAME: reduction = [0, 0, 0, 0, 8] +// CHECK-SAME: reduction = [0, 0, 0, 0, 16] // CHECK-SAME: subgroup = [1, 2, 2, 1, 0] // CHECK-SAME: workgroup = [1, 64, 2, 32, 0] @@ -81,9 +81,9 @@ func.func @nhwc_conv_unaligned_mfma() { // CHECK: linalg.conv_2d_nhwc_hwcf {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: mma_kind = #iree_gpu.mma_layout -// CHECK-SAME: padding = [2, 1, 32, 64, 32] +// CHECK-SAME: padding = [2, 1, 32, 64, 64] // CHECK-SAME: promote_operands = [0, 1, 2] -// CHECK-SAME: reduction = [0, 0, 0, 0, 8] +// CHECK-SAME: reduction = [0, 0, 0, 0, 16] // CHECK-SAME: subgroup = [2, 1, 2, 1, 0] // CHECK-SAME: workgroup = [2, 1, 32, 64, 0] @@ -111,8 +111,8 @@ func.func @nchw_conv_unaligned_mfma() { // CHECK: linalg.conv_2d_nchw_fchw {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: 
mma_kind = #iree_gpu.mma_layout -// CHECK-SAME: padding = [1, 64, 2, 32, 32] +// CHECK-SAME: padding = [1, 64, 2, 32, 64] // CHECK-SAME: promote_operands = [0, 1, 2] -// CHECK-SAME: reduction = [0, 0, 0, 0, 8] +// CHECK-SAME: reduction = [0, 0, 0, 0, 16] // CHECK-SAME: subgroup = [1, 2, 2, 1, 0] // CHECK-SAME: workgroup = [1, 64, 2, 32, 0] diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir index 25c9dda4dcc4..3576ea4416a2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir @@ -39,7 +39,7 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor // CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: mma_kind = #iree_gpu.mma_layout // CHECK-SAME: promote_operands = [0, 1] -// CHECK-SAME: reduction = [0, 0, 0, 0, 4] +// CHECK-SAME: reduction = [0, 0, 0, 0, 8] // CHECK-SAME: subgroup = [1, 1, 4, 1, 0] // CHECK-SAME: workgroup = [1, 1, 64, 64, 0] @@ -74,7 +74,7 @@ func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4 // CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: mma_kind = #iree_gpu.mma_layout // CHECK-SAME: promote_operands = [0, 1] -// CHECK-SAME: reduction = [0, 0, 0, 0, 4, 1] +// CHECK-SAME: reduction = [0, 0, 0, 0, 8, 1] // CHECK-SAME: subgroup = [2, 2, 1, 1, 0, 0] // CHECK-SAME: workgroup = [2, 2, 32, 32, 0, 0] @@ -136,7 +136,7 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor< // CHECK: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: mma_kind = #iree_gpu.mma_layout // CHECK-SAME: promote_operands = [0, 1] -// CHECK-SAME: reduction = [0, 0, 2] +// CHECK-SAME: reduction = [0, 0, 4] // CHECK-SAME: subgroup = [4, 4, 0] // CHECK-SAME: workgroup = [128, 128, 0] diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir index 1a521e61ebd9..6c9e4d5f752d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir @@ -1013,9 +1013,8 @@ hal.executable public @main { // CHECK: scf.yield %[[REDUCE]] // CHECK: scf.for %{{.*}} = %{{.*}} to %c16 step %c1 -// CHECK: scf.for -// CHECK-COUNT-4: arith.addf {{.*}} : vector<9xf32> -// CHECK: vector.transfer_write {{.*}} vector<9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type> +// CHECK-COUNT-4: arith.addf {{.*}} : vector<9x9xf32> +// CHECK: vector.transfer_write {{.*}} vector<9x9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type> // ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir index bea2f2abe738..d553fb67d11b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_custom_op.mlir @@ -40,7 +40,7 @@ func.func @custom_op(%arg0 : tensor<384x512xf32>, %arg1 : tensor<512x128xf32>, // CHECK-SAME: lowering_config = #[[CONFIG]] // CHECK: ^bb // CHECK: linalg.matmul -// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout, promote_operands = [0, 1], reduction = 
[0, 0, 8], subgroup = [2, 2, 0], workgroup = [64, 64, 0]}> +// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout, promote_operands = [0, 1], reduction = [0, 0, 16], subgroup = [2, 2, 0], workgroup = [64, 64, 0]}> // CHECK: iree_linalg_ext.yield // ----- From be30e3c689302774b7bafd6f3881108656065e97 Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram Date: Mon, 13 Jan 2025 09:28:01 -0600 Subject: [PATCH 3/6] give separate heuristics to IGEMM Signed-off-by: Nirvedh Meshram --- .../iree/compiler/Codegen/Common/Passes.td | 2 +- .../Common/test/tile_large_tensors.mlir | 30 ++++++++--------- .../Dialect/GPU/TargetUtils/ConfigUtils.cpp | 33 ++++++++++--------- .../ROCDL/config_igemm_tile_and_fuse.mlir | 12 +++---- .../test/ROCDL/config_tile_and_fuse.mlir | 4 +-- .../test/ROCDL/pipeline_tile_and_fuse.mlir | 5 +-- 6 files changed, 45 insertions(+), 41 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 245b07f6deaa..7188de257ca8 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -654,7 +654,7 @@ def TileLargeTensorsPass : ]; let options = [ Option<"maxVectorSize", "max-vector-size", "int64_t", - /*default=*/"256", + /*default=*/"64", "Maximum static size to tile to (i.e. all remaining ops will be smaller)">, ]; } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir index 3bb51a2d6d0c..66c73da981c0 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir @@ -3,22 +3,22 @@ // RUN: FileCheck %s #map = affine_map<(d0, d1) -> (d0, d1)> -func.func @simple_generic(%3: tensor<64x512xf32>, %4: tensor<64x512xf32>, %5: tensor<64x512xf32>) -> tensor<64x512xf32> { +func.func @simple_generic(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>, %5: tensor<64x256xf32>) -> tensor<64x256xf32> { %6 = linalg.generic { indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"] - } ins(%3, %4 : tensor<64x512xf32>, tensor<64x512xf32>) outs(%5 : tensor<64x512xf32>) { + } ins(%3, %4 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%5 : tensor<64x256xf32>) { ^bb0(%in: f32, %in_0: f32, %out: f32): %7 = arith.addf %in, %in_0 : f32 linalg.yield %7 : f32 - } -> tensor<64x512xf32> - return %6 : tensor<64x512xf32> + } -> tensor<64x256xf32> + return %6 : tensor<64x256xf32> } // CHECK-LABEL: func.func @simple_generic // CHECK: scf.for %{{.*}} = %c0 to %c64 step %c1 -// CHECK: scf.for %{{.*}} = %c0 to %c512 step %c256 -// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x256xf32>) +// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c64 +// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x64xf32>) // ----- -func.func @multiple_use_tilable_op(%3: tensor<64x512xf32>, %4: tensor<64x512xf32>) -> (tensor<64x512xf32>, tensor<512x64xf32>) { - %add_empty = tensor.empty() : tensor<64x512xf32> +func.func @multiple_use_tilable_op(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>) -> (tensor<64x256xf32>, tensor<256x64xf32>) { + %add_empty = tensor.empty() : tensor<64x256xf32> %6 = linalg.add - ins(%3, %4 : tensor<64x512xf32>, tensor<64x512xf32>) - outs(%add_empty : tensor<64x512xf32>) -> tensor<64x512xf32> - %transpose_empty = 
tensor.empty() : tensor<512x64xf32> + ins(%3, %4 : tensor<64x256xf32>, tensor<64x256xf32>) + outs(%add_empty : tensor<64x256xf32>) -> tensor<64x256xf32> + %transpose_empty = tensor.empty() : tensor<256x64xf32> %7 = linalg.transpose - ins(%6 : tensor<64x512xf32>) - outs(%transpose_empty : tensor<512x64xf32>) permutation = [1, 0] - return %6, %7 : tensor<64x512xf32>, tensor<512x64xf32> + ins(%6 : tensor<64x256xf32>) + outs(%transpose_empty : tensor<256x64xf32>) permutation = [1, 0] + return %6, %7 : tensor<64x256xf32>, tensor<256x64xf32> } // CHECK-LABEL: func.func @multiple_use_tilable_op // CHECK: %[[ADD_TILING:.+]] = scf.for -// CHECK: linalg.add {{.*}} -> tensor<1x256xf32> +// CHECK: linalg.add {{.*}} -> tensor<1x64xf32> // CHECK: %[[T_TILING:.+]] = scf.for // CHECK: %[[FUSED_ADD:.+]] = linalg.add {{.*}} -> tensor<64x1xf32> // CHECK: linalg.transpose ins(%[[FUSED_ADD]] diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index a07887a7f35f..b9f5ea04586b 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -118,8 +118,8 @@ LogicalResult setDataTiledMultiMmaLoweringConfig( /// problem based on the available mma intrinsics. static std::optional getMmaScheduleFromProblemAndTarget( IREE::GPU::TargetAttr target, GPUMatmulShapeType problem, - bool transposedLhs, bool transposedRhs, bool mustBeAligned = true, - bool doCPromotion = false) { + bool transposedLhs, bool transposedRhs, bool isIGEMM, + bool mustBeAligned = true, bool doCPromotion = false) { const int64_t targetSubgroupSize = target.getPreferredSubgroupSize(); SmallVector intrinsics; for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { @@ -142,20 +142,20 @@ static std::optional getMmaScheduleFromProblemAndTarget( // See https://github.com/iree-org/iree/issues/16341 for details. int64_t mSize = ShapedType::getNumElements(problem.mSizes); int64_t nSize = ShapedType::getNumElements(problem.nSizes); + int64_t cacheLineSizeElements = kCacheLineSizeBits / inBitWidth; + int64_t bestKElementCountPerSubgroup = + isIGEMM ? cacheLineSizeElements / 2 : cacheLineSizeElements; if (mSize * nSize <= 512 * 512) { // For matmuls with small M*N size, we want to distribute M*N onto more // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup // and a larger bestKTileCountPerSubgroup. 
seeds = {/*bestSubgroupCountPerWorkgroup=*/4, /*bestMNTileCountPerSubgroup=*/4, - /*bestKTileCountPerSubgroup=*/8, - /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits * 2 / - inBitWidth}; + /*bestKTileCountPerSubgroup=*/8, bestKElementCountPerSubgroup * 2}; } else { seeds = {/*bestSubgroupCountPerWorkgroup=*/4, - /*bestMNTileCountPerSubgroup=*/16, - /*bestKTileCountPerSubgroup=*/4, - /*bestKElementCountPerSubgroup*/ kCacheLineSizeBits / inBitWidth}; + /*bestMNTileCountPerSubgroup=*/8, + /*bestKTileCountPerSubgroup=*/4, bestKElementCountPerSubgroup}; } // We target slightly below the full available shared Memory to leave room for @@ -181,7 +181,8 @@ static FailureOr> getMatmulLoweringConfigAndWorkgroupSize(SmallVector bounds, ArrayRef maps, ArrayRef operands, - IREE::GPU::TargetAttr target) { + IREE::GPU::TargetAttr target, + bool isIGEMM) { if (target.getWgp().getMma().empty()) return failure(); @@ -249,7 +250,7 @@ getMatmulLoweringConfigAndWorkgroupSize(SmallVector bounds, bool mustBeAligned = true; bool doCPromotion = false; std::optional schedule = getMmaScheduleFromProblemAndTarget( - target, problem, transposedLhs, transposedRhs); + target, problem, transposedLhs, transposedRhs, isIGEMM); // TODO (nirvedhmeshram, qedawkins): The performance with this will be bad if // the GEMM is accumulating (i.e doesnt have a zero fill dpsInit) as that @@ -259,9 +260,9 @@ getMatmulLoweringConfigAndWorkgroupSize(SmallVector bounds, LDBG("Attempting to deduce unaligned TileAndFuse MMA schedulee"); mustBeAligned = false; doCPromotion = true; - schedule = getMmaScheduleFromProblemAndTarget(target, problem, - transposedLhs, transposedRhs, - mustBeAligned, doCPromotion); + schedule = getMmaScheduleFromProblemAndTarget( + target, problem, transposedLhs, transposedRhs, isIGEMM, mustBeAligned, + doCPromotion); } if (!schedule) { @@ -384,7 +385,8 @@ setIGEMMConvolutionLoweringConfig(IREE::GPU::TargetAttr target, SmallVector bounds = igemmLoopBounds.value(); FailureOr> configAndWgSize = getMatmulLoweringConfigAndWorkgroupSize( - bounds, igemmContractionMaps.value(), igemmOperands.value(), target); + bounds, igemmContractionMaps.value(), igemmOperands.value(), target, + /*isIGEMM=*/true); if (failed(configAndWgSize)) { return failure(); } @@ -434,7 +436,8 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, LDBG("Matmul TileAndFuse Config"); FailureOr> configAndWgSize = - getMatmulLoweringConfigAndWorkgroupSize(bounds, maps, operands, target); + getMatmulLoweringConfigAndWorkgroupSize(bounds, maps, operands, target, + /*isIGEMM=*/false); if (failed(configAndWgSize)) { return failure(); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir index c1be57b94903..d8af22e58664 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir @@ -24,7 +24,7 @@ func.func @nhwc_conv_mfma() { // CHECK: linalg.conv_2d_nhwc_hwcf {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: mma_kind = #iree_gpu.mma_layout // CHECK-SAME: promote_operands = [0, 1] -// CHECK-SAME: reduction = [0, 0, 0, 0, 16] +// CHECK-SAME: reduction = [0, 0, 0, 0, 8] // CHECK-SAME: subgroup = [1, 2, 2, 1, 0] // CHECK-SAME: workgroup = [1, 2, 32, 64, 0] @@ -53,7 +53,7 @@ func.func @nchw_conv_mfma() { // CHECK: linalg.conv_2d_nchw_fchw {{.*}}lowering_config = 
#iree_gpu.lowering_config // CHECK-SAME: mma_kind = #iree_gpu.mma_layout // CHECK-SAME: promote_operands = [0, 1] -// CHECK-SAME: reduction = [0, 0, 0, 0, 16] +// CHECK-SAME: reduction = [0, 0, 0, 0, 8] // CHECK-SAME: subgroup = [1, 2, 2, 1, 0] // CHECK-SAME: workgroup = [1, 64, 2, 32, 0] @@ -81,9 +81,9 @@ func.func @nhwc_conv_unaligned_mfma() { // CHECK: linalg.conv_2d_nhwc_hwcf {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: mma_kind = #iree_gpu.mma_layout -// CHECK-SAME: padding = [2, 1, 32, 64, 64] +// CHECK-SAME: padding = [2, 1, 32, 64, 32] // CHECK-SAME: promote_operands = [0, 1, 2] -// CHECK-SAME: reduction = [0, 0, 0, 0, 16] +// CHECK-SAME: reduction = [0, 0, 0, 0, 8] // CHECK-SAME: subgroup = [2, 1, 2, 1, 0] // CHECK-SAME: workgroup = [2, 1, 32, 64, 0] @@ -111,8 +111,8 @@ func.func @nchw_conv_unaligned_mfma() { // CHECK: linalg.conv_2d_nchw_fchw {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: mma_kind = #iree_gpu.mma_layout -// CHECK-SAME: padding = [1, 64, 2, 32, 64] +// CHECK-SAME: padding = [1, 64, 2, 32, 32] // CHECK-SAME: promote_operands = [0, 1, 2] -// CHECK-SAME: reduction = [0, 0, 0, 0, 16] +// CHECK-SAME: reduction = [0, 0, 0, 0, 8] // CHECK-SAME: subgroup = [1, 2, 2, 1, 0] // CHECK-SAME: workgroup = [1, 64, 2, 32, 0] diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir index 3576ea4416a2..5c1135da6a59 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir @@ -137,8 +137,8 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor< // CHECK-SAME: mma_kind = #iree_gpu.mma_layout // CHECK-SAME: promote_operands = [0, 1] // CHECK-SAME: reduction = [0, 0, 4] -// CHECK-SAME: subgroup = [4, 4, 0] -// CHECK-SAME: workgroup = [128, 128, 0] +// CHECK-SAME: subgroup = [2, 4, 0] +// CHECK-SAME: workgroup = [64, 128, 0] // ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir index 6c9e4d5f752d..1a521e61ebd9 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir @@ -1013,8 +1013,9 @@ hal.executable public @main { // CHECK: scf.yield %[[REDUCE]] // CHECK: scf.for %{{.*}} = %{{.*}} to %c16 step %c1 -// CHECK-COUNT-4: arith.addf {{.*}} : vector<9x9xf32> -// CHECK: vector.transfer_write {{.*}} vector<9x9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type> +// CHECK: scf.for +// CHECK-COUNT-4: arith.addf {{.*}} : vector<9xf32> +// CHECK: vector.transfer_write {{.*}} vector<9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type> // ----- From 6e52bd63429b7e336a090af757179fb589c49807 Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram Date: Tue, 14 Jan 2025 11:25:36 -0600 Subject: [PATCH 4/6] bring back prefetching --- .../Dialect/GPU/TargetUtils/ConfigUtils.cpp | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index b9f5ea04586b..9aff3a263a9d 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ 
b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -393,16 +393,9 @@ setIGEMMConvolutionLoweringConfig(IREE::GPU::TargetAttr target, std::array workgroupSize = {configAndWgSize->second, 1, 1}; LoweringConfigAttr loweringConfig = configAndWgSize->first; - bool usePrefetchSharedMemory = true; - // Prefetching has issues when doing c promotion, see - // https://github.com/iree-org/iree/issues/19612. - if (llvm::any_of(getPromotedOperandList(loweringConfig).value(), - [](int64_t promote) { return promote == 2; })) { - usePrefetchSharedMemory = false; - } SmallVector pipelineAttrs; auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get( - linalgOp->getContext(), /*prefetchSharedMemory=*/usePrefetchSharedMemory, + linalgOp->getContext(), /*prefetchSharedMemory=*/true, /*no_reduce_shared_memory_bank_conflicts=*/false, /*use_igemm_convolution=*/true, /*reorder_workgroups_strategy=*/std::nullopt); @@ -444,16 +437,9 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, std::array workgroupSize = {configAndWgSize->second, 1, 1}; LoweringConfigAttr loweringConfig = configAndWgSize->first; - bool usePrefetchSharedMemory = true; - // Prefetching has issues when doing c promotion, see - // https://github.com/iree-org/iree/issues/19612. - if (llvm::any_of(getPromotedOperandList(loweringConfig).value(), - [](int64_t promote) { return promote == 2; })) { - usePrefetchSharedMemory = false; - } SmallVector pipelineAttrs; auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get( - linalgOp->getContext(), /*prefetchSharedMemory=*/usePrefetchSharedMemory, + linalgOp->getContext(), /*prefetchSharedMemory=*/true, /*no_reduce_shared_memory_bank_conflicts=*/false, /*use_igemm_convolution=*/false, /*reorder_workgroups_strategy=*/std::nullopt); From acdbd1836cbeec462517180965a3504ad456955e Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram Date: Tue, 14 Jan 2025 12:11:03 -0600 Subject: [PATCH 5/6] Increase default threshold of TileLargeTensor pass to 128 --- compiler/src/iree/compiler/Codegen/Common/Passes.td | 2 +- .../compiler/Codegen/Common/test/tile_large_tensors.mlir | 6 +++--- .../Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir | 6 ++---- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 7188de257ca8..7ebd5fe49257 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -654,7 +654,7 @@ def TileLargeTensorsPass : ]; let options = [ Option<"maxVectorSize", "max-vector-size", "int64_t", - /*default=*/"64", + /*default=*/"128", "Maximum static size to tile to (i.e. 
all remaining ops will be smaller)">, ]; } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir index 66c73da981c0..5c5a102444e1 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_large_tensors.mlir @@ -17,8 +17,8 @@ func.func @simple_generic(%3: tensor<64x256xf32>, %4: tensor<64x256xf32>, %5: te // CHECK-LABEL: func.func @simple_generic // CHECK: scf.for %{{.*}} = %c0 to %c64 step %c1 -// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c64 -// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x64xf32>) +// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c128 +// CHECK: linalg.generic {{.*}} outs({{.*}}: tensor<1x128xf32>) // ----- @@ -79,7 +79,7 @@ func.func @multiple_use_tilable_op(%3: tensor<64x256xf32>, %4: tensor<64x256xf32 // CHECK-LABEL: func.func @multiple_use_tilable_op // CHECK: %[[ADD_TILING:.+]] = scf.for -// CHECK: linalg.add {{.*}} -> tensor<1x64xf32> +// CHECK: linalg.add {{.*}} -> tensor<1x128xf32> // CHECK: %[[T_TILING:.+]] = scf.for // CHECK: %[[FUSED_ADD:.+]] = linalg.add {{.*}} -> tensor<64x1xf32> // CHECK: linalg.transpose ins(%[[FUSED_ADD]] diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir index 1a521e61ebd9..525086f8b5bd 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir @@ -1013,10 +1013,8 @@ hal.executable public @main { // CHECK: scf.yield %[[REDUCE]] // CHECK: scf.for %{{.*}} = %{{.*}} to %c16 step %c1 -// CHECK: scf.for -// CHECK-COUNT-4: arith.addf {{.*}} : vector<9xf32> -// CHECK: vector.transfer_write {{.*}} vector<9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type> - +// CHECK-COUNT-4: arith.addf {{.*}} : vector<9x9xf32> +// CHECK: vector.transfer_write {{.*}} vector<9x9xi8>, memref<32x16x9x9xi8, #hal.descriptor_type> // ----- #pipeline_layout = #hal.pipeline.layout Date: Tue, 14 Jan 2025 12:13:06 -0600 Subject: [PATCH 6/6] Prefetch shared memory fix tests Signed-off-by: Nirvedh Meshram --- .../LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir | 4 ++-- .../Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir index d8af22e58664..cf170ef7d930 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_igemm_tile_and_fuse.mlir @@ -76,7 +76,7 @@ func.func @nhwc_conv_unaligned_mfma() { // CHECK-LABEL: func.func @nhwc_conv_unaligned_mfma // CHECK-SAME: #iree_codegen.translation_info, // CHECK-LABEL: func.func @unaligned_to_intrinsic_batched_matmul // CHECK-SAME: #iree_codegen.translation_info} +// CHECK-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options} // CHECK: linalg.batch_matmul {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: padding = [1, 16, 16, 4] // CHECK-SAME: promote_operands = [0, 1, 2] @@ -306,7 +306,7 @@ func.func @unaligned_to_intrinsic_batched_matmul_tiling_check(%lhs : tensor<12x5 // CHECK-LABEL: func.func 
@unaligned_to_intrinsic_batched_matmul_tiling_check // CHECK-SAME: #iree_codegen.translation_info} +// CHECK-SAME: {gpu_pipeline_options = #iree_gpu.pipeline_options} // CHECK: linalg.batch_matmul {{.*}}lowering_config = #iree_gpu.lowering_config // CHECK-SAME: padding = [1, 16, 512, 4] // CHECK-SAME: promote_operands = [0, 1, 2]
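Net effect of the series on the TileAndFuse heuristics, gathered in one compilable sketch. The value of kCacheLineSizeBits below is an assumption of this sketch — the constant is defined in ConfigUtils.cpp and never shown in these hunks; everything else mirrors the diffs in patches 2 and 3:

    #include <cstdint>

    constexpr int64_t kCacheLineSizeBits = 1024; // assumed: 128-byte cache line

    // Seed for K elements per subgroup: one cache line of elements for plain
    // matmul, half a cache line for IGEMM (patch 3), doubled for small M*N
    // problems that want deeper K tiles (the 512*512 split is from the diff).
    int64_t bestKElementCountPerSubgroup(int64_t inBitWidth, bool isIGEMM,
                                         int64_t mSize, int64_t nSize) {
      int64_t cacheLineSizeElements = kCacheLineSizeBits / inBitWidth;
      int64_t best = isIGEMM ? cacheLineSizeElements / 2 : cacheLineSizeElements;
      return (mSize * nSize <= 512 * 512) ? best * 2 : best;
    }

    // Shared-memory budget: target slightly below the hardware maximum so
    // GPUReduceBankConflictsPass can pad buffers without overflowing
    // (https://github.com/iree-org/iree/issues/19675).
    int64_t maxSharedMemoryBytes(int64_t maxWorkgroupMemoryBytes,
                                 int64_t inBitWidth) {
      return maxWorkgroupMemoryBytes - 64 * inBitWidth;
    }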