Skip to content

Commit

Permalink
Adds a flag to enable/disable vector contract custom kernels in `LLVM…
Browse files Browse the repository at this point in the history
…CPUMmt4dVectorLoweringPass` (#16867)

The `VectorContractCustomKernelsPatterns` are being replaced by
Microkernels and improvements to the DT-only path.

This commit adds a flag to toggle these patterns on or off; for now the flag defaults to on.
  • Loading branch information
KoolJBlack authored Apr 8, 2024
1 parent 39bf204 commit dcc8e19
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
Expand All @@ -29,8 +30,11 @@ namespace mlir::iree_compiler {
namespace {
/// Function-level pass applying the mmt4d-specific vector lowering.
///
/// The `enableVectorContractCustomKernels` option is not declared here, so it
/// is presumably provided by the generated `LLVMCPUMmt4dVectorLoweringBase`.
/// `runOnOperation` (defined out of line) reads it to decide whether the
/// special-case vector.contract patterns are applied.
struct LLVMCPUMmt4dVectorLoweringPass
    : public LLVMCPUMmt4dVectorLoweringBase<LLVMCPUMmt4dVectorLoweringPass> {
  /// Stores the caller's choice into the option of the same name.
  /// Marked `explicit` so a bare `bool` never converts to a pass implicitly.
  explicit LLVMCPUMmt4dVectorLoweringPass(
      bool enableVectorContractCustomKernels) {
    // `this->` selects the option member, which the parameter shadows.
    this->enableVectorContractCustomKernels = enableVectorContractCustomKernels;
  }

  /// Registers the dialects this pass declares a dependency on.
  void getDependentDialects(DialectRegistry &registry) const override {
    // A single insert covers both dialects; an additional insert of only
    // vector::VectorDialect would be redundant.
    registry.insert<vector::VectorDialect, LLVM::LLVMDialect>();
  }

  void runOnOperation() override;
};
Expand Down Expand Up @@ -89,7 +93,7 @@ void LLVMCPUMmt4dVectorLoweringPass::runOnOperation() {
}
}

{
if (enableVectorContractCustomKernels) {
// Special-case vector.contract codegen paths. This needs to happen
// just before the generic vector ops lowerings.
RewritePatternSet patterns(context);
Expand All @@ -102,8 +106,9 @@ void LLVMCPUMmt4dVectorLoweringPass::runOnOperation() {
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUMmt4dVectorLoweringPass() {
return std::make_unique<LLVMCPUMmt4dVectorLoweringPass>();
createLLVMCPUMmt4dVectorLoweringPass(bool enableVectorContractCustomKernels) {
return std::make_unique<LLVMCPUMmt4dVectorLoweringPass>(
enableVectorContractCustomKernels);
}

} // namespace mlir::iree_compiler
9 changes: 8 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ static llvm::cl::opt<bool> clUseSoftmaxInterFusion(
llvm::cl::desc("Enables inter-pass fusion for the DecomposeSoftmax pass."),
llvm::cl::init(true));

// Command-line flag, on by default. Its value is passed to
// createLLVMCPUMmt4dVectorLoweringPass() when the mmt4d tiling expert pass
// pipeline is assembled.
static llvm::cl::opt<bool> clEnableVectorContractCustomKernels(
"iree-llvmcpu-enable-vector-contract-custom-kernels",
llvm::cl::desc("Enables vector contract custom kernels for "
"LLVMCPUMmt4dVectorLowering pass."),
llvm::cl::init(true));

static void addTileAndDistributePasses(OpPassManager &pm) {
pm.addPass(createTileAndDistributeToWorkgroupsPass());
auto &nestedModulePM = pm.nest<ModuleOp>();
Expand Down Expand Up @@ -579,7 +585,8 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &passManager,

// Vector lowering of Mmt4d.
nestedModulePM.addNestedPass<func::FuncOp>(
createLLVMCPUMmt4dVectorLoweringPass());
createLLVMCPUMmt4dVectorLoweringPass(
clEnableVectorContractCustomKernels));

// Generic vector lowering.
LLVMCPUVectorLoweringPassOptions options;
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ createLLVMCPULowerExecutableTargetPass();
std::unique_ptr<Pass> createExpandF16OpToF32Pass();

/// Creates the mmt4d vector lowering pass.
///
/// `enableVectorContractCustomKernels` defaults to true, so call sites that
/// pass no argument keep the custom kernels enabled. Only this one
/// declaration is kept: a separate zero-argument overload would make a call
/// with no arguments ambiguous.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUMmt4dVectorLoweringPass(
    bool enableVectorContractCustomKernels = true);

/// Pass to perform peeling on non-distributed loops.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def LLVMCPULowerExecutableTarget :
// Declares the `iree-llvmcpu-mmt4d-vector-lowering` pass, which is anchored on
// operations implementing mlir::FunctionOpInterface.
def LLVMCPUMmt4dVectorLowering
: InterfacePass<"iree-llvmcpu-mmt4d-vector-lowering", "mlir::FunctionOpInterface"> {
let summary = "Apply vector lowering logic to vector ops";
let options = [
// Spelled `vector-contract-custom-kernels` in textual pass pipelines;
// defaults to true.
Option<"enableVectorContractCustomKernels", "vector-contract-custom-kernels", "bool",
/*default=*/"true",
"Flag to enable or disable vector contract custom kernels.">,
];
// The factory is invoked with no arguments here, so its C++ default argument
// supplies the value handed to the pass constructor.
let constructor =
"mlir::iree_compiler::createLLVMCPUMmt4dVectorLoweringPass()";
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// RUN: iree-opt %s --pass-pipeline="builtin.module(func.func(iree-llvmcpu-mmt4d-vector-lowering),iree-codegen-llvmcpu-vector-lowering-pipeline)" --split-input-file | FileCheck %s
// RUN: iree-opt %s --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmcpu-mmt4d-vector-lowering{vector-contract-custom-kernels=false})))))" --split-input-file | FileCheck %s -check-prefix=CHECK-KERNEL-OFF
// RUN: iree-opt %s --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmcpu-mmt4d-vector-lowering{vector-contract-custom-kernels=true})))))" --split-input-file | FileCheck %s -check-prefix=CHECK-KERNEL-ON

#map0 = affine_map<()[s0] -> (s0 * 64)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
Expand Down Expand Up @@ -171,3 +173,49 @@ module {
// CHECK-LABEL: func.func @matmul_gather() {
// CHECK-32: vector.fma
// CHECK: linalg.generic

// -----

// Test case: an i8 -> i32 sign-extended vector.contract inside a
// hal.executable targeting aarch64 with +i8mm. The directives below require
// that no inline assembly op appears in the output of the run that sets
// `vector-contract-custom-kernels=false`.
// CHECK-KERNEL-OFF-LABEL: @simpul_mul_mixed_mini_no_custom_kernel
// CHECK-KERNEL-OFF-NOT: llvm.inline_asm asm_dialect

hal.executable private @simpul_mul_mixed_mini_dispatch {
hal.executable.variant public @embedded_elf_arm_64 target(<"llvm-cpu", "embedded-elf-arm_64", {cpu = "generic", cpu_features = "+neon,+i8mm,+reserve-x18", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", native_vector_size = 16 : i64, target_triple = "aarch64-unknown-unknown-eabi-elf", ukernels = "none"}>) {
hal.executable.export public @simpul_mul_mixed_mini_no_custom_kernel ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>], translation_info = #iree_codegen.translation_info<Mmt4dTilingExpert>} {
^bb0(%arg0: !hal.device):
%c1 = arith.constant 1 : index
hal.return %c1, %c1, %c1 : index, index, index
}
builtin.module {
func.func @simpul_mul_mixed_mini_no_custom_kernel(%5 : vector<1x1x8x1xi8>, %6 : vector<1x1x8x1xi8> , %arg3 : vector<1x1x8x8xi32> ) -> vector<1x1x8x8xi32> {
%7 = arith.extsi %5 : vector<1x1x8x1xi8> to vector<1x1x8x1xi32>
%8 = arith.extsi %6 : vector<1x1x8x1xi8> to vector<1x1x8x1xi32>
%9 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %8, %arg3 : vector<1x1x8x1xi32>, vector<1x1x8x1xi32> into vector<1x1x8x8xi32>
return %9 : vector<1x1x8x8xi32>
}
}
}
}

// -----

// Test case: the same i8 -> i32 sign-extended vector.contract as the
// preceding case, checked for the run that sets
// `vector-contract-custom-kernels=true`. An inline assembly op must appear
// after the label. A plain directive is used because there is only one
// pattern, so there is nothing for it to be unordered against.
// CHECK-KERNEL-ON-LABEL: @simpul_mul_mixed_mini_custom_kernel
// CHECK-KERNEL-ON: llvm.inline_asm asm_dialect

hal.executable private @simpul_mul_mixed_mini_dispatch {
  hal.executable.variant public @embedded_elf_arm_64 target(<"llvm-cpu", "embedded-elf-arm_64", {cpu = "generic", cpu_features = "+neon,+i8mm,+reserve-x18", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128", native_vector_size = 16 : i64, target_triple = "aarch64-unknown-unknown-eabi-elf", ukernels = "none"}>) {
    hal.executable.export public @simpul_mul_mixed_mini_custom_kernel ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>], translation_info = #iree_codegen.translation_info<Mmt4dTilingExpert>} {
    ^bb0(%arg0: !hal.device):
      %c1 = arith.constant 1 : index
      hal.return %c1, %c1, %c1 : index, index, index
    }
    builtin.module {
      func.func @simpul_mul_mixed_mini_custom_kernel(%5 : vector<1x1x8x1xi8>, %6 : vector<1x1x8x1xi8> , %arg3 : vector<1x1x8x8xi32> ) -> vector<1x1x8x8xi32> {
        %7 = arith.extsi %5 : vector<1x1x8x1xi8> to vector<1x1x8x1xi32>
        %8 = arith.extsi %6 : vector<1x1x8x1xi8> to vector<1x1x8x1xi32>
        %9 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %7, %8, %arg3 : vector<1x1x8x1xi32>, vector<1x1x8x1xi32> into vector<1x1x8x8xi32>
        return %9 : vector<1x1x8x8xi32>
      }
    }
  }
}

0 comments on commit dcc8e19

Please sign in to comment.