Skip to content

Commit

Permalink
Lower krnl.call to llvm (#1408)
Browse files Browse the repository at this point in the history
Signed-off-by: chentong319 <chentong@us.ibm.com>
Co-authored-by: Kevin O'Brien <caomhin@us.ibm.com>
  • Loading branch information
chentong319 and caoimhinuibrian authored Jun 10, 2022
1 parent 8b890b6 commit 5a321d2
Show file tree
Hide file tree
Showing 13 changed files with 708 additions and 25 deletions.
1 change: 1 addition & 0 deletions src/Conversion/KrnlToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
add_onnx_mlir_library(OMKrnlToLLVM
ConvertKrnlToLLVM.cpp
KrnlFindIndex.cpp
KrnlCall.cpp
KrnlEntryPoint.cpp
KrnlGetRef.cpp
KrnlGlobal.cpp
Expand Down
1 change: 1 addition & 0 deletions src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ std::unique_ptr<Pass> createConvertKrnlToLLVMPass() {
void populateKrnlToLLVMConversion(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, MLIRContext *ctx,
ArrayRef<bool> outputOMTensorOwnerships, bool singleEntryPoint) {
krnl::populateLoweringKrnlCallOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlEntryPointOpPattern(
typeConverter, patterns, ctx, outputOMTensorOwnerships, singleEntryPoint);
krnl::populateLoweringKrnlFindIndexOpPattern(typeConverter, patterns, ctx);
Expand Down
3 changes: 3 additions & 0 deletions src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ void populateKrnlToLLVMConversion(mlir::LLVMTypeConverter &typeConverter,
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx,
llvm::ArrayRef<bool> constantOutputs, bool singleEntryPoint);

void populateLoweringKrnlCallOpPattern(mlir::TypeConverter &typeConverter,
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);

void populateLoweringKrnlEntryPointOpPattern(mlir::TypeConverter &typeConverter,
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx,
llvm::ArrayRef<bool> constantOutputs, bool singleEntryPoint);
Expand Down
213 changes: 213 additions & 0 deletions src/Conversion/KrnlToLLVM/KrnlCall.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@

/*
* SPDX-License-Identifier: Apache-2.0
*/

//===-------------- KrnlCall.cpp - Lower KrnlCallOp -----------------------===//
//
// Copyright 2022 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the KrnlCallOp operator.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp"
#include "src/Dialect/Krnl/DialectBuilder.hpp"
#include "src/Dialect/Krnl/KrnlHelper.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"

#define DEBUG_TYPE "krnl_to_llvm"

using namespace mlir;

namespace onnx_mlir {
namespace krnl {

/// Lowers krnl.call to a call of an external function.
///
/// MemRef operands (result and parameters) are wrapped into OMTensors and
/// passed as opaque (i8*) pointers; all other operands are passed through
/// unchanged. Supported attributes (string, integer, float, dense elements)
/// are materialized as extra trailing parameters. The callee's signature is
/// derived from the collected parameter types, and a declaration is inserted
/// into the module if one does not already exist.
class KrnlCallOpLowering : public ConversionPattern {
public:
  explicit KrnlCallOpLowering(
      TypeConverter &typeConverter, MLIRContext *context)
      : ConversionPattern(
            typeConverter, KrnlCallOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    KrnlCallOpAdaptor krnlCallAdaptor(operands);
    Location loc = op->getLoc();
    KrnlCallOp krnlCallOp = llvm::cast<KrnlCallOp>(op);

    // Get a symbol reference to the function, inserting it if necessary.
    ModuleOp module = op->getParentOfType<ModuleOp>();
    llvm::SmallVector<Type, 4> parameterTypeList;
    llvm::SmallVector<Value, 4> parameterList;
    // The result memref is passed as the first parameter of the callee.
    handleOneParameter(rewriter, op, krnlCallAdaptor.result(),
        krnlCallOp.result(), parameterTypeList, parameterList);

    // Some types of operands have already been converted by the type
    // converter, so it is better to check the type of the original operands.
    // Thus, the two kinds of operands are walked together.
    auto itConverted = krnlCallAdaptor.parameters().begin();
    auto itOriginal = krnlCallOp.parameters().begin();
    for (; itConverted != krnlCallAdaptor.parameters().end();
         itConverted++, itOriginal++) {
      handleOneParameter(rewriter, op, *itConverted, *itOriginal,
          parameterTypeList, parameterList);
    }

    // Handle the attributes: each one (except funcName, which selects the
    // callee rather than being data) becomes a trailing parameter.
    for (auto namedAttr : op->getAttrs()) {
      // Skip the funcName attribute.
      if (namedAttr.getName().getValue().equals("funcName"))
        continue;
      handleOneAttribute(rewriter, getTypeConverter(), op, namedAttr.getValue(),
          parameterTypeList, parameterList);
    }

    auto callRef = getOrInsertCall(
        rewriter, module, krnlCallOp.funcName(), parameterTypeList);
    // The callee returns void; the result is communicated through the
    // pre-allocated result memref passed as the first parameter.
    // NOTE(review): the callee is declared as an LLVM::LLVMFuncOp but invoked
    // through func::CallOp — confirm this is resolved by the remaining
    // func-to-LLVM conversion.
    rewriter.create<func::CallOp>(
        loc, callRef, ArrayRef<Type>({}), parameterList);

    rewriter.eraseOp(op);
    return success();
  }

private:
  /// Appends one call parameter (type and value) for `parameter`.
  /// If the original (pre-conversion) type is a MemRef, the converted value
  /// is wrapped into a freshly created OMTensor and passed as an i8*.
  static void handleOneParameter(PatternRewriter &rewriter, Operation *op,
      Value parameter, Value original,
      llvm::SmallVector<Type, 4> &parameterTypeList,
      llvm::SmallVector<Value, 4> &parameterList) {
    MLIRContext *context = op->getContext();
    Location loc = op->getLoc();
    ModuleOp module = op->getParentOfType<ModuleOp>();
    const auto &apiRegistry = RuntimeAPIRegistry::build(module, rewriter);

    // Check the original type, not the type after conversion.
    Type ty = original.getType();
    if (ty.isa<MemRefType>()) {
      auto int64Ty = IntegerType::get(context, 64);
      auto memRefTy = parameter.getType().dyn_cast<LLVM::LLVMStructType>();
      auto memRefRank = krnl::getRankFromMemRefType(memRefTy);
      auto memRefRankVal = rewriter.create<LLVM::ConstantOp>(
          loc, int64Ty, rewriter.getI64IntegerAttr(memRefRank));
      Value omTensor = RuntimeAPI::callApi(rewriter, loc, apiRegistry,
          RuntimeAPI::API::CREATE_OMTENSOR, {memRefRankVal});

      krnl::fillOMTensorWithMemRef(parameter, omTensor, false /*outOwning*/,
          rewriter, loc, apiRegistry, module);
      auto int8Ty = IntegerType::get(context, 8);
      auto opaquePtrTy = LLVM::LLVMPointerType::get(int8Ty);
      parameterTypeList.emplace_back(opaquePtrTy);
      parameterList.emplace_back(omTensor);
    } else {
      parameterTypeList.emplace_back(parameter.getType());
      parameterList.emplace_back(parameter);
    }
  }

  /// Materializes one attribute as a call parameter:
  ///  - StringAttr        -> i8* pointing at a module-level global string.
  ///  - IntegerAttr       -> i64 constant.
  ///  - FloatAttr         -> f64 constant.
  ///  - DenseElementsAttr -> i8* to an OMTensor backed by a krnl.global.
  /// Any other attribute kind is unsupported and aborts.
  static void handleOneAttribute(PatternRewriter &rewriter,
      TypeConverter *typeConverter, Operation *op, Attribute attribute,
      llvm::SmallVector<Type, 4> &parameterTypeList,
      llvm::SmallVector<Value, 4> &parameterList) {
    auto *context = op->getContext();
    auto loc = op->getLoc();
    ModuleOp module = op->getParentOfType<ModuleOp>();

    TypeSwitch<Attribute>(attribute)
        .Case<StringAttr>([&](StringAttr strAttr) {
          StringRef attrValue = strAttr.getValue();
          LLVM::GlobalOp globalStr =
              krnl::getOrCreateGlobalString(attrValue, loc, rewriter, module,
                  static_cast<LLVMTypeConverter *>(typeConverter));
          Value strPtr = krnl::getPtrToGlobalString(globalStr, loc, rewriter);
          auto int8Ty = IntegerType::get(context, 8);
          auto opaquePtrTy = LLVM::LLVMPointerType::get(int8Ty);
          parameterTypeList.emplace_back(opaquePtrTy);
          parameterList.emplace_back(strPtr);
        })
        .Case<IntegerAttr>([&](IntegerAttr integerAttr) {
          auto int64Ty = IntegerType::get(context, 64);
          Value cst =
              rewriter.create<LLVM::ConstantOp>(loc, int64Ty, integerAttr);
          parameterTypeList.emplace_back(int64Ty);
          parameterList.emplace_back(cst);
        })
        .Case<FloatAttr>([&](FloatAttr floatAttr) {
          auto f64Ty = rewriter.getF64Type();
          Value cst = rewriter.create<LLVM::ConstantOp>(loc, f64Ty, floatAttr);
          parameterTypeList.emplace_back(f64Ty);
          parameterList.emplace_back(cst);
        })
        .Case<DenseElementsAttr>([&](DenseElementsAttr denseAttr) {
          // Use krnl.global to handle it.
          // Since the attribute is still in tensor type, the code has to cross
          // onnx-to-krnl and krnl-to-llvm boundaries here.
          // In the future, the attributes should be converted in the
          // krnl.call builder.
          // This code passed the onnx-mlir-opt --convert-krnl-to-llvm test
          // case, but failed in onnx-mlir for the tensor type attribute.
          const auto &apiRegistry = RuntimeAPIRegistry::build(module, rewriter);
          auto tensorTy = denseAttr.getType().cast<TensorType>();
          auto memRefTy =
              MemRefType::get(tensorTy.getShape(), tensorTy.getElementType());
          MultiDialectBuilder<KrnlBuilder> create(rewriter, loc);
          Value constantGlobal =
              create.krnl.constant(memRefTy, "constant_", denseAttr);
          // Bridge from the memref-typed krnl.global to the LLVM struct type
          // expected by fillOMTensorWithMemRef.
          Value convertedConstantGlobal =
              rewriter
                  .create<UnrealizedConversionCastOp>(
                      loc, typeConverter->convertType(memRefTy), constantGlobal)
                  .getResult(0);

          auto int64Ty = IntegerType::get(context, 64);
          auto memRefRank = memRefTy.getRank();
          auto memRefRankVal = rewriter.create<LLVM::ConstantOp>(
              loc, int64Ty, rewriter.getI64IntegerAttr(memRefRank));
          Value omTensor = RuntimeAPI::callApi(rewriter, loc, apiRegistry,
              RuntimeAPI::API::CREATE_OMTENSOR, {memRefRankVal});

          krnl::fillOMTensorWithMemRef(convertedConstantGlobal, omTensor,
              false /*outOwning*/, rewriter, loc, apiRegistry, module);
          auto int8Ty = IntegerType::get(context, 8);
          auto opaquePtrTy = LLVM::LLVMPointerType::get(int8Ty);
          parameterTypeList.emplace_back(opaquePtrTy);
          parameterList.emplace_back(omTensor);
        })
        .Default([&](Attribute attr) {
          llvm_unreachable("This type of Attribute used by krnl.call is not "
                           "yet implemented");
        });
  }

  /// Returns a symbol reference to `funcName`, declaring a void-returning
  /// LLVM function with the collected parameter types at the top of the
  /// module if no declaration exists yet.
  FlatSymbolRefAttr getOrInsertCall(PatternRewriter &rewriter, ModuleOp module,
      llvm::StringRef funcName, ArrayRef<Type> parameterTypeList) const {
    auto *context = module.getContext();
    if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName))
      return SymbolRefAttr::get(context, funcName);
    auto llvmVoidTy = LLVM::LLVMVoidType::get(context);
    auto llvmFnType =
        LLVM::LLVMFunctionType::get(llvmVoidTy, parameterTypeList, false);

    PatternRewriter::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());
    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), funcName, llvmFnType);
    return SymbolRefAttr::get(context, funcName);
  }
};

/// Registers the krnl.call lowering pattern into the given pattern set.
void populateLoweringKrnlCallOpPattern(TypeConverter &typeConverter,
    RewritePatternSet &patterns, MLIRContext *ctx) {
  patterns.add<KrnlCallOpLowering>(typeConverter, ctx);
}

} // namespace krnl
} // namespace onnx_mlir
21 changes: 18 additions & 3 deletions src/Conversion/ONNXToKrnl/Tensor/Resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,27 @@ struct ONNXResizeOpLowering : public ConversionPattern {

// Call external function when the mode is not "nearest"
// Create KrnlCallOp and replace the du chain
// One of inputs, scales() and size(), has to be None.
// For now, None input is picked out by KrnlCall builder,
// and different function will be called accordingly.
// Another issue is the attributes with default value.
Currently, it is assumed that all the optional attributes have
their default values and do not appear in the Attribute dictionary.
// ToFix: Handle attributes for general case
if (resizeOp.mode() != "nearest") {
Value resizeCall =
rewriter.create<KrnlCallOp>(loc, alloc, op, operands, true);
rewriter.replaceOp(op, resizeCall);
assert(op->getAttrs().size() == 1 &&
"ResizeOp: runtime lib is not supported for this case");
if (!isFromNone(resizeOp.scales())) {
rewriter.create<KrnlCallOp>(
loc, "Resize_Scales", alloc, op, operands, true);
} else {
rewriter.create<KrnlCallOp>(
loc, "Resize_Size", alloc, op, operands, true);
}
rewriter.replaceOp(op, alloc);
return success();
}
// It is much more efficient to generate codes directly if possible

// Constants used in the loop body
Value zero = create.math.constant(rewriter.getIntegerType(64), 0);
Expand Down
36 changes: 24 additions & 12 deletions src/Dialect/Krnl/Krnl.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,31 +36,36 @@ def StringType : Type<CPred<"$_self.isa<krnl::StringType>()">, "string type">;
// Require regions to have krnl.terminate terminator operation.
def ImplicitKrnlTerminator : SingleBlockImplicitTerminator<"KrnlTerminatorOp">;

def KrnlCallOp : Op<Krnl_Dialect, "call"> {
def KrnlCallOp : Op<Krnl_Dialect, "call",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]
> {
let summary = "call operation";
let description = [{
The call operation provides a generic way to call an external function
at Krnl level. The `funcName` determines which function to call.
The `alloc` is the Value to store the function return. Since allocation
of the return MemRef involves shape inference usually with IndexExpr.
Thus most of time the allocation should stay in compiler, not in runtime library.
The `result` is the Value to store the function return. Currently only
one output is supported. `result` has to be an allocated MemRef.
Since allocation of the output MemRef involves shape inference on the ONNX Op,
allocation should be done when lowering the ONNX Op, not within krnl.call.
Another reason is that krnl.call would need to be defined with the AllocationOp
interface if `result` were allocated inside this Op.
The parameters can be of any type: MemRef, NoneType or any llvm type.
Different types of parameters will be converted, if needed, when KrnlCallOp
is lowered. Attributes will be converted to parameters too (To be Added).
The function signature will be determined with the types of parameters.
An LLVM::CallOp to either a runtime library or a llvm intrinsic function
will be generated.
The krnl.call op will be lowered to llvm at krnl-to-llvm conversion.
}];

let arguments = (ins StrAttr:$funcName, AnyType:$alloc, Variadic<AnyType>:$parameters);
let results =(outs AnyType:$result);
let arguments = (ins StrAttr:$funcName, AnyType:$result, Variadic<AnyType>:$parameters);

// builders to build KrnlCallOp from op and operands, helping conversion from onnx to krnl.
// The name of the function can be determined by the op name and the element type of the return,
// or given to the builder if the simple rule does not work
// Attributes of the op will be propagated to KrnlCallOp if the copyAttrs is true
let builders = [OpBuilder<(ins "mlir::Value":$alloc, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "bool":$copyAttrs)>,
OpBuilder<(ins "std::string":$funcNameStr, "mlir::Value":$alloc, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "bool":$copyAttrs)>];
let builders = [OpBuilder<(ins "mlir::Value":$result, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "bool":$copyAttrs)>,
OpBuilder<(ins "std::string":$funcNameStr, "mlir::Value":$result, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "bool":$copyAttrs)>];
}

def KrnlDefineLoopsOp : Op<Krnl_Dialect, "define_loops"> {
Expand Down Expand Up @@ -208,11 +213,18 @@ def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {

def KrnlRegionOp : Op<Krnl_Dialect, "region", [NoTerminator, SingleBlock,
AffineScope]> {
let summary = "Affine Region for Krnl";
let summary = "Affine boundary for krnl loops";
let description = [{
This Op has AffineScope trait and is used to limit the scope of affine.for.
Otherwise, the affine.for might be illegal if a symbol is not defined at
the top of function, which has the AffineScope trait.
This Op has a region with the AffineScope trait and is used to limit the
scope of `affine.for`. The loop inside krnl.region can be affine if
its bounds are defined at the level of krnl.region. krnl.region does
not guarantee or require the loops inside it to be affine.
Without krnl.region, a krnl loop may not be affine if its bound symbol
is not defined inside an enclosing region with the AffineScope trait.
In MLIR, FuncOp has the AffineScope trait.
The `krnl.region` will be removed after affine.for is lowered.
ToFix: current krnl.region does not have input and output. You cannot
create a new memref inside the region and use it outside of the region.
}];

let regions = (region SizedRegion<1>:$bodyRegion);
Expand Down
Loading

0 comments on commit 5a321d2

Please sign in to comment.