diff --git a/src/Conversion/KrnlToLLVM/KrnlCall.cpp b/src/Conversion/KrnlToLLVM/KrnlCall.cpp index 30abd50511..3251fd1697 100644 --- a/src/Conversion/KrnlToLLVM/KrnlCall.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlCall.cpp @@ -68,10 +68,27 @@ class KrnlCallOpLowering : public ConversionPattern { rewriter, op, namedAttr.getValue(), parameterTypeList, parameterList); } - FlatSymbolRefAttr callRef = - create.llvm.getOrInsertSymbolRef(module, krnlCallOp.getFuncName(), - LLVM::LLVMVoidType::get(module.getContext()), parameterTypeList); - create.llvm.call({}, callRef, parameterList); + ValueRange returns = op->getResults(); + if (returns.size() == 0) { + // There is no return + FlatSymbolRefAttr callRef = + create.llvm.getOrInsertSymbolRef(module, krnlCallOp.getFuncName(), + LLVM::LLVMVoidType::get(module.getContext()), parameterTypeList); + create.llvm.call({}, callRef, parameterList); + + rewriter.eraseOp(op); + } else { + assert(returns.size() == 1 && + "Only one return value is allowed for krnl.call now"); + Type llvmReturnType = + llvmTypeConverter->convertType(returns[0].getType()); + + FlatSymbolRefAttr callRef = create.llvm.getOrInsertSymbolRef( + module, krnlCallOp.getFuncName(), llvmReturnType, parameterTypeList); + auto llvmCall = + create.llvm.call({llvmReturnType}, callRef, parameterList); + rewriter.replaceOp(op, llvmCall.getDefiningOp()->getResults()[0]); + } // Destroy OMTensor wrappers of parameters. const auto &apiRegistry = @@ -81,7 +98,6 @@ class KrnlCallOpLowering : public ConversionPattern { rewriter, loc, apiRegistry, RuntimeAPI::API::DESTROY_OMTENSOR, {omt}); } - rewriter.eraseOp(op); return success(); } diff --git a/src/Dialect/Krnl/Krnl.td b/src/Dialect/Krnl/Krnl.td index 6f8bd48aed..cac5423fcf 100644 --- a/src/Dialect/Krnl/Krnl.td +++ b/src/Dialect/Krnl/Krnl.td @@ -89,6 +89,13 @@ def KrnlCallOp : Op:$numOfOutput, Variadic:$parameters); + // Return Value for the Call. + // No return if the type is NoneType (void in llvm) + // Only scalar type is supported now. + // In future, return of memref can be supported with pointer of OMTensor. + // The returned memref will be created inside the call. + let results = (outs Variadic>:$returnValue); + // builders to build KrnlCallOp from op and operands, helping conversion from // onnx to krnl. // The name of function can be determined by the op name and elemnt type of @@ -96,6 +103,8 @@ def KrnlCallOp : Op, + OpBuilder<(ins "mlir::StringAttr":$funcNameStr, "IntegerAttr":$numOfOutput, "mlir::ValueRange":$operands)>, OpBuilder<(ins "std::string":$funcNameStr, "mlir::ValueRange":$results, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "std::vector":$attributeNames)>, OpBuilder<(ins "mlir::ValueRange":$results, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "bool":$copyAttrs)>, OpBuilder<(ins "std::string":$funcNameStr, "mlir::ValueRange":$results, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "std::vector":$attributeNames)>, diff --git a/src/Dialect/Krnl/KrnlOps.cpp b/src/Dialect/Krnl/KrnlOps.cpp index affa43780b..cec7b2d94d 100644 --- a/src/Dialect/Krnl/KrnlOps.cpp +++ b/src/Dialect/Krnl/KrnlOps.cpp @@ -154,6 +154,16 @@ void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState, build(builder, odsState, funcNameStr, resultVals, op, operands, copyAttrs); } +void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState, + std::string funcName, int64_t numOfOutput, ValueRange operands) { + build(builder, odsState, {}, funcName, numOfOutput, operands); +} + +void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState, + StringAttr funcName, IntegerAttr numOfOutput, ValueRange operands) { + build(builder, odsState, {}, funcName, numOfOutput, operands); +} + void KrnlCallOp::getEffects( SmallVectorImpl> &effects) { diff --git a/test/mlir/conversion/krnl_to_llvm/call_with_return.mlir b/test/mlir/conversion/krnl_to_llvm/call_with_return.mlir new file mode 100644 index 0000000000..4f04da7a4a --- /dev/null +++ b/test/mlir/conversion/krnl_to_llvm/call_with_return.mlir @@ -0,0 +1,10 @@ +// RUN: onnx-mlir-opt --convert-krnl-to-llvm %s -split-input-file | FileCheck %s + +func.func private @test_krnl_call_with_return(%arg0: memref<2x3xi32>) -> i32 { + %1 = "krnl.call"() {funcName = "get_omp_num_thread", numOfOutput = 0 : si64} : () -> (i32) + func.return %1: i32 +// CHECK: llvm.func @get_omp_num_thread() -> i32 +// CHECK: llvm.func @test_krnl_call_with_return +// CHECK: [[VAR_0_:%.+]] = llvm.call @get_omp_num_thread() : () -> i32 +// CHECK: llvm.return [[VAR_0_]] : i32 +}