Skip to content

Commit

Permalink
[mlir][spirv] Add a generic "convert-to-spirv" pass
Browse files Browse the repository at this point in the history
This commit implements a MVP version of an MLIR lowering pipeline to SPIR-V.
The goal is to have a better test coverage of SPIR-V compilation upstream,
and enable writing simple kernels by hand. The dialects supported in this
version include arith, vector (only 1-D vectors with size 2,3,4,8 or 16),
scf, ub, index, func and math.
  • Loading branch information
angelz913 committed Jun 19, 2024
1 parent 79e668f commit cc9eb3a
Show file tree
Hide file tree
Showing 13 changed files with 1,067 additions and 0 deletions.
22 changes: 22 additions & 0 deletions mlir/include/mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===- ConvertToSPIRVPass.h - Conversion to SPIR-V pass ---*- C++ -*-=========//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
#define MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H

#include <memory>

namespace mlir {
class Pass;

#define GEN_PASS_DECL_CONVERTTOSPIRVPASS
#include "mlir/Conversion/Passes.h.inc"

} // namespace mlir

#endif // MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
];
}

//===----------------------------------------------------------------------===//
// ToSPIRV
//===----------------------------------------------------------------------===//

def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
let summary = "Convert to SPIR-V";
let description = [{
This is a generic pass to convert to SPIR-V.
}];
let dependentDialects = ["spirv::SPIRVDialect"];
}

//===----------------------------------------------------------------------===//
// AffineToStandard
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_subdirectory(ControlFlowToLLVM)
add_subdirectory(ControlFlowToSCF)
add_subdirectory(ControlFlowToSPIRV)
add_subdirectory(ConvertToLLVM)
add_subdirectory(ConvertToSPIRV)
add_subdirectory(FuncToEmitC)
add_subdirectory(FuncToLLVM)
add_subdirectory(FuncToSPIRV)
Expand Down
22 changes: 22 additions & 0 deletions mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
set(LLVM_OPTIONAL_SOURCES
ConvertToSPIRVPass.cpp
)

add_mlir_conversion_library(MLIRConvertToSPIRVPass
ConvertToSPIRVPass.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ConvertToSPIRV

DEPENDS
MLIRConversionPassIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRRewrite
MLIRSPIRVConversion
MLIRSPIRVDialect
MLIRSupport
MLIRTransformUtils
)
72 changes: 72 additions & 0 deletions mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
//===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>

#define DEBUG_TYPE "convert-to-spirv"

namespace mlir {
#define GEN_PASS_DEF_CONVERTTOSPIRVPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// A pass to perform the SPIR-V conversion.
struct ConvertToSPIRVPass final
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {

public:
void runOnOperation() final {
MLIRContext *context = &getContext();
Operation *op = getOperation();

spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
SPIRVTypeConverter typeConverter(targetAttr);

RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;

// Populate patterns.
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
populateFuncToSPIRVPatterns(typeConverter, patterns);
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
populateVectorToSPIRVPatterns(typeConverter, patterns);
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);

std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);

if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
};

} // namespace
Loading

0 comments on commit cc9eb3a

Please sign in to comment.