From dcc8e1990e09370d8818c42bebdec5c5e372ec01 Mon Sep 17 00:00:00 2001 From: Kojo Acquah Date: Mon, 8 Apr 2024 08:21:58 -0700 Subject: [PATCH] Adds a flag to enable/disable vector contract custom kernels in `LLVMCPUMmt4dVectorLoweringPass` (#16867) The `VectorContractCustomKernelsPatterns` are being replaced by Microkernels and improvements to the DT-only path. Adds a flag to toggle them on/off. For now they default to on. --- .../LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp | 13 +++-- .../iree/compiler/Codegen/LLVMCPU/Passes.cpp | 9 +++- .../iree/compiler/Codegen/LLVMCPU/Passes.h | 3 +- .../iree/compiler/Codegen/LLVMCPU/Passes.td | 5 ++ .../LLVMCPU/test/aarch64_vector_lowering.mlir | 48 +++++++++++++++++++ 5 files changed, 72 insertions(+), 6 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp index cd4055c15db2..7044f2c7c9c2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMmt4dVectorLowering.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" #include "iree/compiler/Codegen/LLVMCPU/Passes.h" #include "llvm/Support/Debug.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" @@ -29,8 +30,11 @@ namespace mlir::iree_compiler { namespace { struct LLVMCPUMmt4dVectorLoweringPass : public LLVMCPUMmt4dVectorLoweringBase { + LLVMCPUMmt4dVectorLoweringPass(bool enableVectorContractCustomKernels) { + this->enableVectorContractCustomKernels = enableVectorContractCustomKernels; + } void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void runOnOperation() override; }; @@ -89,7 +93,7 @@ void LLVMCPUMmt4dVectorLoweringPass::runOnOperation() { } } - { + if (enableVectorContractCustomKernels) { // Special-case vector.contract codegen paths. This needs to happen // just before the generic vector ops lowerings. RewritePatternSet patterns(context); @@ -102,8 +106,9 @@ void LLVMCPUMmt4dVectorLoweringPass::runOnOperation() { } std::unique_ptr> -createLLVMCPUMmt4dVectorLoweringPass() { - return std::make_unique(); +createLLVMCPUMmt4dVectorLoweringPass(bool enableVectorContractCustomKernels) { + return std::make_unique( + enableVectorContractCustomKernels); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index 6e0bec35a0f6..4aeffdb08b08 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -87,6 +87,12 @@ static llvm::cl::opt clUseSoftmaxInterFusion( llvm::cl::desc("Enables inter-pass fusion for the DecomposeSoftmax pass."), llvm::cl::init(true)); +static llvm::cl::opt clEnableVectorContractCustomKernels( + "iree-llvmcpu-enable-vector-contract-custom-kernels", + llvm::cl::desc("Enables vector contract custom kernels for " + "LLVMCPUMmt4dVectorLowering pass."), + llvm::cl::init(true)); + static void addTileAndDistributePasses(OpPassManager &pm) { pm.addPass(createTileAndDistributeToWorkgroupsPass()); auto &nestedModulePM = pm.nest(); @@ -579,7 +585,8 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager, // Vector lowering of Mmt4d. nestedModulePM.addNestedPass( - createLLVMCPUMmt4dVectorLoweringPass()); + createLLVMCPUMmt4dVectorLoweringPass( + clEnableVectorContractCustomKernels)); // Generic vector lowering. LLVMCPUVectorLoweringPassOptions options; diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h index 68320def1728..2efe83a2f99c 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h @@ -49,7 +49,8 @@ createLLVMCPULowerExecutableTargetPass(); std::unique_ptr createExpandF16OpToF32Pass(); std::unique_ptr> -createLLVMCPUMmt4dVectorLoweringPass(); +createLLVMCPUMmt4dVectorLoweringPass( + bool enableVectorContractCustomKernels = true); /// Pass to perform peeling on non-distributed loops. std::unique_ptr> diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td index 796941619ee8..7c77d568377b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td @@ -81,6 +81,11 @@ def LLVMCPULowerExecutableTarget : def LLVMCPUMmt4dVectorLowering : InterfacePass<"iree-llvmcpu-mmt4d-vector-lowering", "mlir::FunctionOpInterface"> { let summary = "Apply vector lowering logic to vector ops"; + let options = [ + Option<"enableVectorContractCustomKernels", "vector-contract-custom-kernels", "bool", + /*default=*/"true", + "Flag to enable or disable vector contract custom kernels.">, + ]; let constructor = "mlir::iree_compiler::createLLVMCPUMmt4dVectorLoweringPass()"; } diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/aarch64_vector_lowering.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/aarch64_vector_lowering.mlir index aea0d850b553..16d81db7ad72 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/aarch64_vector_lowering.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/aarch64_vector_lowering.mlir @@ -1,4 +1,6 @@ // RUN: iree-opt %s --pass-pipeline="builtin.module(func.func(iree-llvmcpu-mmt4d-vector-lowering),iree-codegen-llvmcpu-vector-lowering-pipeline)" --split-input-file | FileCheck %s +// RUN: iree-opt %s --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmcpu-mmt4d-vector-lowering{vector-contract-custom-kernels=false})))))" --split-input-file | FileCheck %s -check-prefix=CHECK-KERNEL-OFF +// RUN: iree-opt %s --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmcpu-mmt4d-vector-lowering{vector-contract-custom-kernels=true})))))" --split-input-file | FileCheck %s -check-prefix=CHECK-KERNEL-ON #map0 = affine_map<()[s0] -> (s0 * 64)> #map1 = affine_map<(d0, d1, d2) -> (d0, d2)> @@ -171,3 +173,49 @@ module { // CHECK-LABEL: func.func @matmul_gather() { // CHECK-32: vector.fma // CHECK: linalg.generic + +// ----- + +// CHECK-KERNEL-OFF-LABEL: @simpul_mul_mixed_mini_no_custom_kernel +// CHECK-KERNEL-OFF-NOT: llvm.inline_asm asm_dialect + +hal.executable private @simpul_mul_mixed_mini_dispatch { + hal.executable.variant public @embedded_elf_arm_64 target(<"llvm-cpu", "embedded-elf-arm_64", {cpu = "generic", cpu_features = "+neon,+i8mm,+reserve-x18", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", native_vector_size = 16 : i64, target_triple = "aarch64-unknown-unknown-eabi-elf", ukernels = "none"}>) { + hal.executable.export public @simpul_mul_mixed_mini_no_custom_kernel ordinal(0) layout(#hal.pipeline.layout, <1, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>], translation_info = #iree_codegen.translation_info} { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + func.func @simpul_mul_mixed_mini_no_custom_kernel(%5 : vector<1x1x8x1xi8>, %6 : vector<1x1x8x1xi8> , %arg3 : vector<1x1x8x8xi32> ) -> vector<1x1x8x8xi32> { + %7 = arith.extsi %5 : vector<1x1x8x1xi8> to vector<1x1x8x1xi32> + %8 = arith.extsi %6 : vector<1x1x8x1xi8> to vector<1x1x8x1xi32> + %9 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %7, %8, %arg3 : vector<1x1x8x1xi32>, vector<1x1x8x1xi32> into vector<1x1x8x8xi32> + return %9 : vector<1x1x8x8xi32> + } + } + } + } + +// ----- + +// CHECK-KERNEL-ON-LABEL: @simpul_mul_mixed_mini_custom_kernel +// CHECK-KERNEL-ON-DAG: llvm.inline_asm asm_dialect + +hal.executable private @simpul_mul_mixed_mini_dispatch { + hal.executable.variant public @embedded_elf_arm_64 target(<"llvm-cpu", "embedded-elf-arm_64", {cpu = "generic", cpu_features = "+neon,+i8mm,+reserve-x18", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", native_vector_size = 16 : i64, target_triple = "aarch64-unknown-unknown-eabi-elf", ukernels = "none"}>) { + hal.executable.export public @simpul_mul_mixed_mini_custom_kernel ordinal(0) layout(#hal.pipeline.layout, <1, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>], translation_info = #iree_codegen.translation_info} { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + func.func @simpul_mul_mixed_mini_custom_kernel(%5 : vector<1x1x8x1xi8>, %6 : vector<1x1x8x1xi8> , %arg3 : vector<1x1x8x8xi32> ) -> vector<1x1x8x8xi32> { + %7 = arith.extsi %5 : vector<1x1x8x1xi8> to vector<1x1x8x1xi32> + %8 = arith.extsi %6 : vector<1x1x8x1xi8> to vector<1x1x8x1xi32> + %9 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind} %7, %8, %arg3 : vector<1x1x8x1xi32>, vector<1x1x8x1xi32> into vector<1x1x8x8xi32> + return %9 : vector<1x1x8x8xi32> + } + } + } + }