From febcc7d3e347b2af7c93a5c3669cf21fa8c2140a Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Mon, 23 Mar 2020 16:15:48 -0400 Subject: [PATCH 1/7] Reorganize main function. --- src/CMakeLists.txt | 5 +- src/main.cpp | 113 ++++----------------------------------------- src/main_utils.cpp | 99 +++++++++++++++++++++++++++++++++++++++ src/main_utils.hpp | 66 ++++++++++++++++++++++++++ 4 files changed, 177 insertions(+), 106 deletions(-) create mode 100644 src/main_utils.cpp create mode 100644 src/main_utils.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 697aaa4ffd..8760185047 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -6,7 +6,10 @@ add_subdirectory(Tool) add_subdirectory(Builder) add_subdirectory(Runtime) -add_executable(onnx-mlir main.cpp) +add_executable(onnx-mlir + main_utils.hpp + main_utils.cpp + main.cpp) target_link_libraries(onnx-mlir ${MLIRLibs} OMBuilder diff --git a/src/main.cpp b/src/main.cpp index 33b65959cb..627536edb9 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -6,76 +6,10 @@ // //===----------------------------------------------------------------------===// -#include -#include - -#include "llvm/Bitcode/BitcodeWriter.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/FileUtilities.h" -#include "llvm/Support/InitLLVM.h" -#include "llvm/Support/Regex.h" -#include "llvm/Support/SourceMgr.h" - -#include "src/Builder/FrontendDialectTransformer.hpp" -#include "src/Dialect/Krnl/KrnlOps.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" -#include "src/Pass/Passes.hpp" - -#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" -#include "mlir/ExecutionEngine/ExecutionEngine.h" -#include "mlir/ExecutionEngine/OptUtils.h" -#include "mlir/InitAllDialects.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Module.h" -#include "mlir/Parser.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Target/LLVMIR.h" -#include "mlir/Transforms/Passes.h" - -void EmitLLVMBitCode(const 
mlir::OwningModuleRef &module); - -using namespace std; -using namespace onnx_mlir; - -void LoadMLIR(string inputFilename, mlir::MLIRContext &context, - mlir::OwningModuleRef &module) { - // Handle '.mlir' input to the ONNX MLIR frontend. - // The mlir format indicates that one or more of the supported - // representations are used in the file. - llvm::ErrorOr> fileOrErr = - llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); - if (std::error_code EC = fileOrErr.getError()) { - llvm::errs() << "Could not open input file: " << EC.message() << "\n"; - return; - } - - // Parse the input mlir. - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); - module = mlir::parseSourceFile(sourceMgr, &context); - if (!module) { - llvm::errs() << "Error can't load file " << inputFilename << "\n"; - return; - } -} - -void EmitLLVMBitCode(const mlir::OwningModuleRef &module) { - error_code error; - llvm::raw_fd_ostream moduleBitcodeStream("model.bc", error, - llvm::sys::fs::F_None); - llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module), - moduleBitcodeStream); - moduleBitcodeStream.flush(); -} +#include "src/main_utils.hpp" int main(int argc, char *argv[]) { - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); + registerDialectsForONNXMLIR(); llvm::cl::OptionCategory OnnxMlirOptions("ONNX MLIR Options", "These are frontend options."); @@ -83,12 +17,6 @@ int main(int argc, char *argv[]) { llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-"), llvm::cl::cat(OnnxMlirOptions)); - enum EmissionTargetType { - EmitONNXIR, - EmitMLIR, - EmitLLVMIR, - EmitLLVMBC, - }; llvm::cl::opt emissionTarget( llvm::cl::desc("Choose target to emit:"), llvm::cl::values( @@ -105,49 +33,24 @@ int main(int argc, char *argv[]) { llvm::cl::ParseCommandLineOptions(argc, argv, "ONNX MLIR modular optimizer driver\n"); - // Decide if the input 
file is an ONNX model or a model specified - // in MLIR. The extension of the file is the decider. - string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1); - bool inputIsONNX = (extension == "onnx"); - bool inputIsMLIR = (extension == "mlir"); - assert(inputIsONNX != inputIsMLIR && - "Either ONNX model or MLIR file needs to be provided."); - mlir::MLIRContext context; mlir::OwningModuleRef module; - if (inputIsONNX) { - ImportFrontendModelFile(inputFilename, context, module); - } else { - LoadMLIR(inputFilename, context, module); - } + processInputFile(inputFilename, emissionTarget, context, module); mlir::PassManager pm(&context); - pm.addPass(mlir::createDecomposeONNXToONNXPass()); - pm.addPass(mlir::createShapeInferencePass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createShapeInferencePass()); - pm.addPass(mlir::createAttributePromotionPass()); + addONNXToMLIRPasses(pm); if (emissionTarget >= EmitMLIR) { - pm.addPass(mlir::createLowerToKrnlPass()); - // An additional pass of canonicalization is helpful because lowering - // from ONNX dialect to Standard dialect exposes additional canonicalization - // oppertunities. - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createLowerKrnlPass()); + addONNXToKRNLPasses(pm); + addKRNLToAffinePasses(pm); } - if (emissionTarget >= EmitLLVMIR) { - pm.addPass(mlir::createLowerAffinePass()); - pm.addPass(mlir::createLowerToCFGPass()); - pm.addPass(mlir::createKrnlLowerToLLVMPass()); - pm.addPass(mlir::createCanonicalizerPass()); - } + if (emissionTarget >= EmitLLVMIR) + addKRNLToLLVMPasses(pm); if (mlir::failed(pm.run(*module))) return 4; - if (emissionTarget == EmitLLVMBC) { // Write LLVM bitcode to disk. 
EmitLLVMBitCode(module); diff --git a/src/main_utils.cpp b/src/main_utils.cpp new file mode 100644 index 0000000000..16ef0faa83 --- /dev/null +++ b/src/main_utils.cpp @@ -0,0 +1,99 @@ +//===--------------------------- main_utils.cpp ---------------------------===// +// +// Copyright 2019-2020 The IBM Research Authors. +// +// ============================================================================= +// +// Functions for adding passes and processing input files. +// +//===----------------------------------------------------------------------===// + +#include "src/main_utils.hpp" + +using namespace std; +using namespace onnx_mlir; + +void LoadMLIR(string inputFilename, mlir::MLIRContext &context, + mlir::OwningModuleRef &module) { + // Handle '.mlir' input to the ONNX MLIR frontend. + // The mlir format indicates that one or more of the supported + // representations are used in the file. + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code EC = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + return; + } + + // Parse the input mlir. 
+ llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module = mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return; + } +} + +void EmitLLVMBitCode(const mlir::OwningModuleRef &module) { + error_code error; + llvm::raw_fd_ostream moduleBitcodeStream("model.bc", error, + llvm::sys::fs::F_None); + llvm::WriteBitcodeToFile(*mlir::translateModuleToLLVMIR(*module), + moduleBitcodeStream); + moduleBitcodeStream.flush(); +} + +void registerDialectsForONNXMLIR() { + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); +} + +void addONNXToMLIRPasses(mlir::PassManager &pm) { + pm.addPass(mlir::createDecomposeONNXToONNXPass()); + pm.addPass(mlir::createShapeInferencePass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createShapeInferencePass()); + pm.addPass(mlir::createAttributePromotionPass()); +} + +void addONNXToKRNLPasses(mlir::PassManager &pm) { + pm.addPass(mlir::createLowerToKrnlPass()); + // An additional pass of canonicalization is helpful because lowering + // from ONNX dialect to Standard dialect exposes additional canonicalization + // oppertunities. + pm.addPass(mlir::createCanonicalizerPass()); +} + +void addKRNLToAffinePasses(mlir::PassManager &pm) { + pm.addPass(mlir::createLowerKrnlPass()); +} + +void addKRNLToLLVMPasses(mlir::PassManager &pm) { + pm.addPass(mlir::createLowerAffinePass()); + pm.addPass(mlir::createLowerToCFGPass()); + pm.addPass(mlir::createKrnlLowerToLLVMPass()); + pm.addPass(mlir::createCanonicalizerPass()); +} + +void processInputFile(string inputFilename, EmissionTargetType emissionTarget, + mlir::MLIRContext &context, mlir::OwningModuleRef &module) { + // Decide if the input file is an ONNX model or a model specified + // in MLIR. The extension of the file is the decider. 
+ string extension = inputFilename.substr(inputFilename.find_last_of(".") + 1); + bool inputIsONNX = (extension == "onnx"); + bool inputIsMLIR = (extension == "mlir"); + assert(inputIsONNX != inputIsMLIR && + "Either ONNX model or MLIR file needs to be provided."); + + + if (inputIsONNX) { + ImportFrontendModelFile(inputFilename, context, module); + } else { + LoadMLIR(inputFilename, context, module); + } +} diff --git a/src/main_utils.hpp b/src/main_utils.hpp new file mode 100644 index 0000000000..7496c53cbf --- /dev/null +++ b/src/main_utils.hpp @@ -0,0 +1,66 @@ +//===--------------------------- main_utils.hpp ---------------------------===// +// +// Copyright 2019-2020 The IBM Research Authors. +// +// ============================================================================= +// +// Functions for adding passes and processing input files. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FileUtilities.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/Regex.h" +#include "llvm/Support/SourceMgr.h" + +#include "src/Builder/FrontendDialectTransformer.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "src/Pass/Passes.hpp" + +#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/InitAllDialects.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR.h" +#include "mlir/Transforms/Passes.h" + +using namespace std; +using namespace onnx_mlir; + +enum EmissionTargetType { + EmitONNXIR, + EmitMLIR, + EmitLLVMIR, + EmitLLVMBC, +}; + +void LoadMLIR(string inputFilename, mlir::MLIRContext 
&context, + mlir::OwningModuleRef &module); + +void EmitLLVMBitCode(const mlir::OwningModuleRef &module); + +void registerDialectsForONNXMLIR(); + +void addONNXToMLIRPasses(mlir::PassManager &pm); + +void addONNXToKRNLPasses(mlir::PassManager &pm); + +void addKRNLToAffinePasses(mlir::PassManager &pm); + +void addKRNLToLLVMPasses(mlir::PassManager &pm); + +void processInputFile(string inputFilename, EmissionTargetType emissionTarget, + mlir::MLIRContext &context, mlir::OwningModuleRef &module); \ No newline at end of file From 177311eb759b66233c71c6d28d11bf009b98a424 Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Tue, 24 Mar 2020 13:27:47 -0400 Subject: [PATCH 2/7] Follow review comments. --- src/{main_utils.cpp => MainUtils.cpp} | 8 ++++---- src/{main_utils.hpp => MainUtils.hpp} | 18 ++++++++---------- src/main.cpp | 11 +++++++---- 3 files changed, 19 insertions(+), 18 deletions(-) rename src/{main_utils.cpp => MainUtils.cpp} (94%) rename src/{main_utils.hpp => MainUtils.hpp} (76%) diff --git a/src/main_utils.cpp b/src/MainUtils.cpp similarity index 94% rename from src/main_utils.cpp rename to src/MainUtils.cpp index 16ef0faa83..a7d6467045 100644 --- a/src/main_utils.cpp +++ b/src/MainUtils.cpp @@ -44,7 +44,7 @@ void EmitLLVMBitCode(const mlir::OwningModuleRef &module) { moduleBitcodeStream.flush(); } -void registerDialectsForONNXMLIR() { +void registerDialects() { mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); @@ -61,7 +61,7 @@ void addONNXToMLIRPasses(mlir::PassManager &pm) { pm.addPass(mlir::createAttributePromotionPass()); } -void addONNXToKRNLPasses(mlir::PassManager &pm) { +void addONNXToKrnlPasses(mlir::PassManager &pm) { pm.addPass(mlir::createLowerToKrnlPass()); // An additional pass of canonicalization is helpful because lowering // from ONNX dialect to Standard dialect exposes additional canonicalization @@ -69,11 +69,11 @@ void addONNXToKRNLPasses(mlir::PassManager &pm) { pm.addPass(mlir::createCanonicalizerPass()); } -void 
addKRNLToAffinePasses(mlir::PassManager &pm) { +void addKrnlToAffinePasses(mlir::PassManager &pm) { pm.addPass(mlir::createLowerKrnlPass()); } -void addKRNLToLLVMPasses(mlir::PassManager &pm) { +void addKrnlToLLVMPasses(mlir::PassManager &pm) { pm.addPass(mlir::createLowerAffinePass()); pm.addPass(mlir::createLowerToCFGPass()); pm.addPass(mlir::createKrnlLowerToLLVMPass()); diff --git a/src/main_utils.hpp b/src/MainUtils.hpp similarity index 76% rename from src/main_utils.hpp rename to src/MainUtils.hpp index 7496c53cbf..eef65210e3 100644 --- a/src/main_utils.hpp +++ b/src/MainUtils.hpp @@ -37,9 +37,6 @@ #include "mlir/Target/LLVMIR.h" #include "mlir/Transforms/Passes.h" -using namespace std; -using namespace onnx_mlir; - enum EmissionTargetType { EmitONNXIR, EmitMLIR, @@ -47,20 +44,21 @@ enum EmissionTargetType { EmitLLVMBC, }; -void LoadMLIR(string inputFilename, mlir::MLIRContext &context, +void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context, mlir::OwningModuleRef &module); void EmitLLVMBitCode(const mlir::OwningModuleRef &module); -void registerDialectsForONNXMLIR(); +void registerDialects(); void addONNXToMLIRPasses(mlir::PassManager &pm); -void addONNXToKRNLPasses(mlir::PassManager &pm); +void addONNXToKrnlPasses(mlir::PassManager &pm); -void addKRNLToAffinePasses(mlir::PassManager &pm); +void addKrnlToAffinePasses(mlir::PassManager &pm); -void addKRNLToLLVMPasses(mlir::PassManager &pm); +void addKrnlToLLVMPasses(mlir::PassManager &pm); -void processInputFile(string inputFilename, EmissionTargetType emissionTarget, - mlir::MLIRContext &context, mlir::OwningModuleRef &module); \ No newline at end of file +void processInputFile(std::string inputFilename, + EmissionTargetType emissionTarget, mlir::MLIRContext &context, + mlir::OwningModuleRef &module); \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index 627536edb9..201517be06 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -8,8 +8,11 @@ #include "src/main_utils.hpp" +using 
namespace std; +using namespace onnx_mlir; + int main(int argc, char *argv[]) { - registerDialectsForONNXMLIR(); + registerDialects(); llvm::cl::OptionCategory OnnxMlirOptions("ONNX MLIR Options", "These are frontend options."); @@ -41,12 +44,12 @@ int main(int argc, char *argv[]) { addONNXToMLIRPasses(pm); if (emissionTarget >= EmitMLIR) { - addONNXToKRNLPasses(pm); - addKRNLToAffinePasses(pm); + addONNXToKrnlPasses(pm); + addKrnlToAffinePasses(pm); } if (emissionTarget >= EmitLLVMIR) - addKRNLToLLVMPasses(pm); + addKrnlToLLVMPasses(pm); if (mlir::failed(pm.run(*module))) return 4; From f5548fbec923a13db055ae0b9e83d853ca2b8be9 Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Mon, 30 Mar 2020 18:44:15 -0400 Subject: [PATCH 3/7] Emit constants are globals in Krnl and LLVM dialects. --- .../ONNXToKrnl/ConvertONNXToKrnl.cpp | 2 +- .../ONNXToKrnl/ONNXToKrnlCommon.cpp | 4 + .../ONNXToKrnl/ONNXToKrnlCommon.hpp | 2 + src/Conversion/ONNXToKrnl/Tensor/Constant.cpp | 71 ++------ src/Dialect/Krnl/KrnlOps.td | 13 ++ src/Transform/LowerKrnl.cpp | 1 + src/Transform/LowerToLLVM.cpp | 171 +++++++++++++++--- 7 files changed, 175 insertions(+), 89 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp index fe6f5d4fff..a0ae4deca3 100644 --- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp +++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp @@ -47,7 +47,7 @@ struct FrontendToKrnlLoweringPass } // end anonymous namespace. void FrontendToKrnlLoweringPass::runOnModule() { - auto module = getModule(); + ModuleOp module = getModule(); // The first thing to define is the conversion target. This will define the // final target for this lowering. 
diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp index 9eadcac51f..d79490b329 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp @@ -485,3 +485,7 @@ Value emitNegativeInfinityConstantOp( return rewriter.create(loc, constantAttr); } + +int64_t ArrayAttrIntVal(ArrayAttr a, int i) { + return (a.getValue()[i]).cast().getInt(); +} \ No newline at end of file diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp index a6d121520d..b493a33a1e 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp @@ -117,6 +117,8 @@ Value emitPositiveInfinityConstantOp( Value emitNegativeInfinityConstantOp( ConversionPatternRewriter &rewriter, Location loc, Type type); +int64_t ArrayAttrIntVal(ArrayAttr a, int i); + //===----------------------------------------------------------------------===// // This is to get a scalar operation of a given type for a specific operation. //===----------------------------------------------------------------------===// diff --git a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp index 17f9bc4708..4e3d0524b3 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp @@ -12,37 +12,6 @@ using namespace mlir; -template -void emitConstantAndStoreOpForDenseElementsAttr( - ConversionPatternRewriter &rewriter, Location loc, - DenseElementsAttr constantValue, ArrayRef valueShape, - ArrayRef constantIndices, Value alloc) { - // The following functor recursively walks the dimensions of the constant - // shape, generating a store when the recursion hits the base case. 
- SmallVector indices; - auto valueIt = constantValue.getValues().begin(); - std::function storeElements = [&](uint64_t dimension) { - // The last dimension is the base case of the recursion, at this point - // we store the element at the given index. - if (dimension == valueShape.size()) { - rewriter.create(loc, - rewriter.create(loc, *valueIt++), alloc, - llvm::makeArrayRef(indices)); - return; - } - - // Otherwise, iterate over the current dimension and add the indices to - // the list. - for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) { - indices.push_back(constantIndices[i]); - storeElements(dimension + 1); - indices.pop_back(); - } - }; - // Start the element storing recursion from the first dimension. - storeElements(/*dimension=*/0); -} - struct ONNXConstantOpLowering : public ConversionPattern { ONNXConstantOpLowering(MLIRContext *ctx) : ConversionPattern(mlir::ONNXConstantOp::getOperationName(), 1, ctx) {} @@ -58,37 +27,21 @@ struct ONNXConstantOpLowering : public ConversionPattern { auto memRefType = convertToMemRefType(*op->result_type_begin()); - Value alloc; - bool insertDealloc = checkInsertDealloc(op); + // Shape based computations. + auto shape = memRefType.getShape(); + int64_t numElements = 1; + for (int i=0; i(); - - auto valueShape = memRefType.getShape(); - SmallVector constantIndices; - for (auto i : llvm::seq( - 0, *std::max_element(valueShape.begin(), valueShape.end()))) - constantIndices.push_back(rewriter.create(loc, i)); - - // The constant operation represents a multi-dimensional constant, so we - // will need to generate a store for each of the elements. 
- if (memRefType.getElementType().isa()) { - emitConstantAndStoreOpForDenseElementsAttr( - rewriter, loc, constantValue, valueShape, constantIndices, alloc); - } else if (memRefType.getElementType().isa()) { - emitConstantAndStoreOpForDenseElementsAttr( - rewriter, loc, constantValue, valueShape, constantIndices, alloc); - } else { - emitError(loc, "Unsupported output type"); - } + // Emit the constant global in Krnl dialect. + auto constantGlobal = rewriter.create(loc, + memRefType, + rewriter.getI64ArrayAttr(shape), + constantOp.value().getValue()); // Replace this operation with the generated alloc. - rewriter.replaceOp(op, alloc); + // rewriter.replaceOp(op, alloc); + rewriter.replaceOp(op, constantGlobal.getResult()); return matchSuccess(); } diff --git a/src/Dialect/Krnl/KrnlOps.td b/src/Dialect/Krnl/KrnlOps.td index e4e73de5ba..b3b5ff2f80 100644 --- a/src/Dialect/Krnl/KrnlOps.td +++ b/src/Dialect/Krnl/KrnlOps.td @@ -192,3 +192,16 @@ def KrnlMemcpyOp : Op { let parser = ?; let printer = ?; } + +def KrnlGlobalOp : Op { + let summary = "Krnl global operation"; + let description = [{ + Operation for holding global data values. 
+ }]; + + let arguments = (ins AnyAttr:$shape, AnyAttr:$value); + let results = (outs AnyTypeOf<[AnyMemRef]>:$output); + + let parser = ?; + let printer = ?; +} diff --git a/src/Transform/LowerKrnl.cpp b/src/Transform/LowerKrnl.cpp index ec6a51b14f..02342069c6 100644 --- a/src/Transform/LowerKrnl.cpp +++ b/src/Transform/LowerKrnl.cpp @@ -154,6 +154,7 @@ void KrnlToAffineLoweringPass::runOnFunction() { target.addIllegalDialect(); target.addLegalOp(); target.addLegalOp(); + target.addLegalOp(); OwningRewritePatternList patterns; patterns.insert("llvm.memcpy.p0i8.p0i8.i64")) + return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context); + // Create a function declaration for memcpy, the signature is: + // * `void (i8*, i8* , i64, i1)` + auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect); + auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); + auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); + auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect); + auto llvmFnType = LLVM::LLVMType::getFunctionTy( + llvmVoidTy, + ArrayRef( + {llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}), + false); + + // Insert the memcpy function into the body of the parent module. 
+ PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module.getLoc(), + "llvm.memcpy.p0i8.p0i8.i64", llvmFnType); + return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context); +} + +//===----------------------------------------------------------------------===// +// KRNL to LLVM: KrnlGlobalOpLowering +//===----------------------------------------------------------------------===// + +class KrnlGlobalOpLowering : public ConvertToLLVMPattern { +public: + explicit KrnlGlobalOpLowering(MLIRContext *context, + LLVMTypeConverter &lowering_) + : ConvertToLLVMPattern(KrnlGlobalOp::getOperationName(), context, + lowering_) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto *context = op->getContext(); + auto loc = op->getLoc(); + auto *llvmDialect = + op->getContext()->getRegisteredDialect(); + assert(llvmDialect && "expected llvm dialect to be registered"); + + auto krnlGlobalOp = llvm::dyn_cast(op); + + // Get module. + ModuleOp module = op->getParentOfType(); + + // Compute total number of elements. + auto shape = (krnlGlobalOp.shape()).dyn_cast(); + int64_t numElements = 1; + for (int i=0; igetResult(0).getType(); + auto memRefTy = type.cast(); + auto llvmMemRefType = + typeConverter.convertType(type).cast(); + + // The element type of the array. 
+ auto globalType = typeConverter.convertType(memRefTy.getElementType()); + for (int i=shape.size() - 1; i >= 0; i--) + globalType = LLVM::LLVMType::getArrayTy( + globalType.cast(), ArrayAttrIntVal(shape, i)); + auto llvmGlobalType = globalType.cast(); + + { + OpBuilder::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + + global = rewriter.create(loc, + llvmGlobalType, /*isConstant=*/true, + LLVM::Linkage::Internal, "constant_000", krnlGlobalOp.value()); + } + + // Create the llvm.mlir.undef corresponding to the MemRef. + auto llvmMemRef = MemRefDescriptor::undef(rewriter, loc, llvmMemRefType); + + // Copy over the global data: + // - Bitcast MemRef entry 1 to i8* + auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); + auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); + Value alignedMemRefDescMemory = rewriter.create( + loc, llvmMemRefType, llvmMemRef, rewriter.getI64ArrayAttr(1)); + Value int8PtrMemRef = rewriter.create( + loc, llvmI8PtrTy, alignedMemRefDescMemory); + // - Bitcast global to i8* + Value globalValue = rewriter.create(loc, global); + Value i8PtrGlobal = rewriter.create( + loc, llvmI8PtrTy, globalValue); + // - Set size. + Value memRefElementSize = rewriter.create(loc, + llvmI64Ty, rewriter.getI64IntegerAttr( + getMemRefEltSizeInBytes(memRefTy))); + Value numElementsValue = rewriter.create( + loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements)); + Value totalElementsSize = rewriter.create( + loc, memRefElementSize, numElementsValue); + Value int64Size = rewriter.create( + loc, llvmI64Ty, totalElementsSize); + // - Set volatile. + Value isVolatile = rewriter.create( + loc, LLVM::LLVMType::getInt1Ty(llvmDialect), + rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); + // - Copy constant data into the MemRef entry 1. 
+ auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect); + rewriter.create( + loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect), + ArrayRef({int8PtrMemRef, i8PtrGlobal, int64Size, isVolatile})); + + // Set MemRef offset to 0. + llvmMemRef.setConstantOffset(rewriter, loc, 0); + + // Set MemRef sizes and strides. All strides are 1. + // Strides of other dimensions not supported yet. + for (int i = 0; i < shape.size(); ++i) { + llvmMemRef.setConstantSize(rewriter, loc, i, ArrayAttrIntVal(shape, i)); + llvmMemRef.setConstantStride(rewriter, loc, i, 1); + } + + rewriter.replaceOp(op, {llvmMemRef}); + // rewriter.eraseOp(op); + return matchSuccess(); + } + +private: + static int64_t ArrayAttrIntVal(ArrayAttr a, int i) { + return (a.getValue()[i]).cast().getInt(); + } +}; + //===----------------------------------------------------------------------===// // KRNL to LLVM: KrnlMemcpyOpLowering //===----------------------------------------------------------------------===// @@ -118,35 +258,6 @@ class KrnlMemcpyOpLowering : public ConversionPattern { rewriter.eraseOp(op); return matchSuccess(); } - -private: - /// Return a symbol reference to the memcpy function, inserting it into the - /// module if necessary. 
- static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter, - ModuleOp module, - LLVM::LLVMDialect *llvmDialect) { - auto *context = module.getContext(); - if (module.lookupSymbol("llvm.memcpy.p0i8.p0i8.i64")) - return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context); - // Create a function declaration for memcpy, the signature is: - // * `void (i8*, i8* , i64, i1)` - auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect); - auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); - auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); - auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect); - auto llvmFnType = LLVM::LLVMType::getFunctionTy( - llvmVoidTy, - ArrayRef( - {llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}), - false); - - // Insert the memcpy function into the body of the parent module. - PatternRewriter::InsertionGuard insertGuard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create(module.getLoc(), - "llvm.memcpy.p0i8.p0i8.i64", llvmFnType); - return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context); - } }; //===----------------------------------------------------------------------===// @@ -514,6 +625,8 @@ void KrnlToLLVMLoweringPass::runOnModule() { /*useAlloca=*/false, /*emitCWrapper=*/true); + patterns.insert(&getContext(), typeConverter); + // Lower from the `krnl` dialect i.e. the Reshape operation. patterns.insert( &getContext()); From 036f1ad92d402373cf48884367b3297cc652a279 Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Tue, 8 Dec 2020 23:33:31 -0500 Subject: [PATCH 4/7] Fix global emission when value is returned. 
--- src/Conversion/ONNXToKrnl/Tensor/Constant.cpp | 43 +++++++- test/mlir/krnl/constant.mlir | 102 +++++++++++++++++- test/mlir/onnx/onnx_lowering.mlir | 18 +++- 3 files changed, 153 insertions(+), 10 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp index d4fa675dfe..3e41f23aae 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp @@ -12,6 +12,20 @@ using namespace mlir; +bool checkOpResultIsReturned(ONNXConstantOp *constantOp) { + FuncOp function = getContainingFunction(constantOp->getOperation()); + + bool opIsReturned = false; + function.walk([&opIsReturned, constantOp](ReturnOp op) { + auto result = constantOp->getResult(); + for (const auto &operand : op.getOperands()) + if (operand == result) + opIsReturned = true; + }); + + return opIsReturned; +} + struct ONNXConstantOpLowering : public ConversionPattern { static int constantID; @@ -25,6 +39,8 @@ struct ONNXConstantOpLowering : public ConversionPattern { auto loc = op->getLoc(); auto constantOp = llvm::dyn_cast(op); + printf("Operation is returned: %d\n", checkOpResultIsReturned(&constantOp)); + if (constantOp.sparse_value().hasValue()) return emitError(loc, "Only support dense values at this time"); @@ -47,9 +63,30 @@ struct ONNXConstantOpLowering : public ConversionPattern { // Increment constant ID: constantID++; - // Replace this operation with the generated alloc. - // rewriter.replaceOp(op, alloc); - rewriter.replaceOp(op, constantGlobal.getResult()); + // Check if the variable is returned. + if (checkOpResultIsReturned(&constantOp)) { + // In this case, use an AllocOp for the constant since krnl.Global operations + // are not mean to be returned. + AllocOp alloc = rewriter.create(loc, memRefType); + + // Compute size in bytes using the input tensor. 
+ Value tensorSize = emitConstantOp(rewriter, loc, + rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType)); + auto numElementsValue = emitConstantOp( + rewriter, loc, rewriter.getIntegerType(64), numElements); + tensorSize = rewriter.create(loc, tensorSize, numElementsValue); + + // Copy the value in the AllocOp. + //Value data = constantGlobal.value().getValue(); + rewriter.create(loc, alloc, constantGlobal.getResult(), tensorSize); + + // Since the value is returned we need to only work with the AllocOp + // not the KrnlGlobalOp. Globals cannot be returned. + rewriter.replaceOp(op, alloc.getResult()); + } else { + // Replace this operation with the generated krnl.global. + rewriter.replaceOp(op, constantGlobal.getResult()); + } return success(); } diff --git a/test/mlir/krnl/constant.mlir b/test/mlir/krnl/constant.mlir index 5790529f2d..53f13b811c 100644 --- a/test/mlir/krnl/constant.mlir +++ b/test/mlir/krnl/constant.mlir @@ -2,14 +2,33 @@ // ----- -func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> { +func @test_constant(%arg0 : tensor<3x2xf32>) -> tensor<*xf32> { %0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32> - "std.return"(%0) : (tensor<*xf32>) -> () + %1 = "onnx.Relu"(%0) : (tensor<*xf32>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i1) // CHECK: llvm.mlir.global internal constant [[GLOBAL_CONST:@.+]](dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>) : !llvm.array<3 x array<2 x float>> // CHECK: llvm.func @test_constant({{.*}}) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { + // CHECK: [[CONST_3:%.+]] = llvm.mlir.constant(3 : index) : !llvm.i64 + // CHECK: [[CONST_4:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64 + + /// This is the result MemRef: + // CHECK: 
[[MALLOC_FOR_RES:%.+]] = llvm.call @malloc + // CHECK: [[CAST_MALLOC_FOR_RES:%.+]] = llvm.bitcast [[MALLOC_FOR_RES]] : !llvm.ptr to !llvm.ptr + // CHECK: [[RES_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_1:%.+]] = llvm.insertvalue [[CAST_MALLOC_FOR_RES]], [[RES_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_2:%.+]] = llvm.insertvalue [[CAST_MALLOC_FOR_RES]], [[RES_MEMREF_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[CONST_0:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64 + // CHECK: [[RES_MEMREF_3:%.+]] = llvm.insertvalue [[CONST_0]], [[RES_MEMREF_2]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[CONST_1:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK: [[CONST_2:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64 + // CHECK: [[RES_MEMREF_4:%.+]] = llvm.insertvalue [[CONST_3]], [[RES_MEMREF_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_5:%.+]] = llvm.insertvalue [[CONST_2]], [[RES_MEMREF_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_6:%.+]] = llvm.insertvalue [[CONST_4]], [[RES_MEMREF_5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_7:%.+]] = llvm.insertvalue [[CONST_1]], [[RES_MEMREF_6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : i64) : !llvm.i64 // CHECK: [[ALLOCA:%.+]] = llvm.alloca [[CONST1]] x !llvm.array<3 x array<2 x float>> : (!llvm.i64) -> !llvm.ptr>> // CHECK: [[I8ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr>> to !llvm.ptr @@ -51,5 +70,82 @@ func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> { // CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK: [[MEMREF5:%.+]] = 
llvm.insertvalue [[CONST1]], [[MEMREF4]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: llvm.return [[MEMREF5]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.return [[RES_MEMREF_7]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +} + +// ----- + +func @test_constant(%arg0 : tensor<3x2xf32>) -> tensor<*xf32> { + %0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : i64) : !llvm.i64 + // CHECK: [[ALLOCA:%.+]] = llvm.alloca [[CONST1]] x !llvm.array<3 x array<2 x float>> : (!llvm.i64) -> !llvm.ptr>> + // CHECK: [[I8ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr>> to !llvm.ptr + + // CHECK: [[GLOBAL_ADDR:%.+]] = llvm.mlir.addressof [[GLOBAL_CONST]] : !llvm.ptr>> + // CHECK: [[I8GLOBAL:%.+]] = llvm.bitcast [[GLOBAL_ADDR]] : !llvm.ptr>> to !llvm.ptr + + /// Size of the constant tensor in bytes. + // CHECK: [[CONST4:%.+]] = llvm.mlir.constant(4 : i64) : !llvm.i64 + // CHECK: [[CONST6:%.+]] = llvm.mlir.constant(6 : i64) : !llvm.i64 + // CHECK: [[CONST_MUL1:%.+]] = llvm.mul [[CONST4]], [[CONST6]] : !llvm.i64 + // CHECK: [[GLOBAL_SIZE_BYTES:%.+]] = llvm.sext [[CONST_MUL1]] : !llvm.i64 to !llvm.i64 + + /// Volatile flag + // CHECK: [[CONST0:%.+]] = llvm.mlir.constant(false) : !llvm.i1 + + // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i1) -> () + + /// Prepare data for MemRef insertion. + // CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr>> to !llvm.ptr + + /// Insert the constant value in the local MemRef. 
+ // CHECK: [[LOCAL_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[LOCAL_MEMREF0:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[LOCAL_MEMREF1:%.+]] = llvm.insertvalue [[TYPED_ALLOCA]], [[LOCAL_MEMREF0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + + /// Insert offset. + // CHECK: [[CONST00:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64 + // CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[CONST00]], [[LOCAL_MEMREF1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + + /// Insert sizes and strides. + // CHECK: [[CONST3:%.+]] = llvm.mlir.constant(3 : index) : !llvm.i64 + // CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[CONST1:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64 + // CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF2]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + + // CHECK: [[CONST2:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64 + // CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF3]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[CONST1:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK: [[MEMREF5:%.+]] = llvm.insertvalue [[CONST1]], [[MEMREF4]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + + // CHECK: [[CONST_3:%.+]] = llvm.mlir.constant(3 : index) : !llvm.i64 + // CHECK: [[CONST_4:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64 + + /// This is the result MemRef: + // CHECK: [[MALLOC_FOR_RES:%.+]] = llvm.call @malloc + // CHECK: [[CAST_MALLOC_FOR_RES:%.+]] = llvm.bitcast [[MALLOC_FOR_RES]] : !llvm.ptr to !llvm.ptr + // CHECK: [[RES_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + 
// CHECK: [[RES_MEMREF_1:%.+]] = llvm.insertvalue [[CAST_MALLOC_FOR_RES]], [[RES_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_2:%.+]] = llvm.insertvalue [[CAST_MALLOC_FOR_RES]], [[RES_MEMREF_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[CONST_0:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64 + // CHECK: [[RES_MEMREF_3:%.+]] = llvm.insertvalue [[CONST_0]], [[RES_MEMREF_2]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[CONST_1:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK: [[CONST_2:%.+]] = llvm.mlir.constant(2 : index) : !llvm.i64 + // CHECK: [[RES_MEMREF_4:%.+]] = llvm.insertvalue [[CONST_3]], [[RES_MEMREF_3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_5:%.+]] = llvm.insertvalue [[CONST_2]], [[RES_MEMREF_4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_6:%.+]] = llvm.insertvalue [[CONST_4]], [[RES_MEMREF_5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[RES_MEMREF_7:%.+]] = llvm.insertvalue [[CONST_1]], [[RES_MEMREF_6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + + /// Copy result in a MemRef: + // CHECK: [[CONST_5:%.+]] = llvm.mlir.constant(24 : i64) : !llvm.i64 + // CHECK: [[OUT_DATA:%.+]] = llvm.extractvalue [[RES_MEMREF_7]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[TYPED_OUT_DATA:%.+]] = llvm.bitcast [[OUT_DATA]] : !llvm.ptr to !llvm.ptr + // CHECK: [[GLOBAL_DATA:%.+]] = llvm.extractvalue [[MEMREF5]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: [[TYPED_GLOBAL_DATA:%.+]] = llvm.bitcast [[GLOBAL_DATA]] : !llvm.ptr to !llvm.ptr + // CHECK: [[EXTENDED_CONST_5:%.+]] = llvm.sext [[CONST_5]] : !llvm.i64 to !llvm.i64 + // CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : !llvm.i1 + // 
CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[TYPED_OUT_DATA]], [[TYPED_GLOBAL_DATA]], [[EXTENDED_CONST_5]], [[FALSE]]) : (!llvm.ptr, !llvm.ptr, !llvm.i64, !llvm.i1) -> () + // CHECK: llvm.return [[RES_MEMREF_7]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> } diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 9da5ca5233..5b4eb4bf88 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -13,8 +13,13 @@ func @test_no_argument_2() -> tensor<*xf32> { // CHECK: test_no_argument_1 // CHECK-NEXT: test_no_argument_2 -// CHECK: [[RES:%.+]] = "{{.*}}"({{.*}}) {{.*}} : ({{.*}}) -> memref<2x2xf32> -// CHECK: return [[RES]] : memref<2x2xf32> +// CHECK: [[GLOBAL:%.+]] = "{{.*}}"({{.*}}) {{.*}} : ({{.*}}) -> memref<2x2xf32> +// CHECK: [[ALLOC:%.+]] = alloc() : memref<2x2xf32> +// CHECK: [[CONST_4:%.+]] = constant 4 : i64 +// CHECK: [[CONST_4_0:%.+]] = constant 4 : i64 +// CHECK: [[SIZE:%.+]] = muli [[CONST_4]], [[CONST_4_0]] : i64 +// CHECK: "krnl.memcpy"([[ALLOC]], [[GLOBAL]], [[SIZE]]) : (memref<2x2xf32>, memref<2x2xf32>, i64) -> () +// CHECK: return [[ALLOC]] : memref<2x2xf32> // ----- @@ -1666,8 +1671,13 @@ func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> { %0 = "onnx.Constant"() {value = dense<[[0.0, 0.0], [1.0, 1.1], [2.0, 2.1]]> : tensor<3x2xf32>} : () -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_constant_dense_2d_value - // CHECK: [[RES:%.+]] = "krnl.global"() {name = "constant_0", shape = [3, 2], value = dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>} : () -> memref<3x2xf32> - // CHECK: return [[RES]] : memref<3x2xf32> + // CHECK: [[GLOBAL:%.+]] = "krnl.global"() {name = "constant_0", shape = [3, 2], value = dense<{{.*}}[0.000000e+00, 0.000000e+00], [1.000000e+00, 1.100000e+00], [2.000000e+00, 2.100000e+00]{{.*}}> : tensor<3x2xf32>} : () -> 
memref<3x2xf32>
+  // CHECK: [[ALLOC:%.+]] = alloc() : memref<3x2xf32>
+  // CHECK: [[CONST_4:%.+]] = constant 4 : i64
+  // CHECK: [[CONST_6:%.+]] = constant 6 : i64
+  // CHECK: [[SIZE:%.+]] = muli [[CONST_4]], [[CONST_6]] : i64
+  // CHECK: "krnl.memcpy"([[ALLOC]], [[GLOBAL]], [[SIZE]]) : (memref<3x2xf32>, memref<3x2xf32>, i64) -> ()
+  // CHECK: return [[ALLOC]] : memref<3x2xf32>
 }

 // -----

From b7a04cf07bed8f42dfdb40f5f1c206ee26b340c7 Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Tue, 8 Dec 2020 23:37:22 -0500
Subject: [PATCH 5/7] Format.

---
 src/Conversion/ONNXToKrnl/Tensor/Constant.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp
index 3e41f23aae..5150868859 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp
@@ -65,8 +65,8 @@ struct ONNXConstantOpLowering : public ConversionPattern {

     // Check if the variable is returned.
     if (checkOpResultIsReturned(&constantOp)) {
-      // In this case, use an AllocOp for the constant since krnl.Global operations
-      // are not mean to be returned.
+      // In this case, use an AllocOp for the constant since krnl.Global
+      // operations are not meant to be returned.
       AllocOp alloc = rewriter.create(loc, memRefType);

       // Compute size in bytes using the input tensor.
@@ -77,8 +77,8 @@ struct ONNXConstantOpLowering : public ConversionPattern {
       tensorSize = rewriter.create(loc, tensorSize, numElementsValue);

       // Copy the value in the AllocOp.
-      //Value data = constantGlobal.value().getValue();
-      rewriter.create(loc, alloc, constantGlobal.getResult(), tensorSize);
+      rewriter.create(
+          loc, alloc, constantGlobal.getResult(), tensorSize);

       // Since the value is returned we need to only work with the AllocOp
       // not the KrnlGlobalOp. Globals cannot be returned.
From ece041630fbd2677416ccd2e3d89134082f56bba Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Wed, 9 Dec 2020 10:06:21 -0500 Subject: [PATCH 6/7] Remove comment. --- src/Conversion/ONNXToKrnl/Tensor/Constant.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp index 5150868859..7a0d1acdbf 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp @@ -39,8 +39,6 @@ struct ONNXConstantOpLowering : public ConversionPattern { auto loc = op->getLoc(); auto constantOp = llvm::dyn_cast(op); - printf("Operation is returned: %d\n", checkOpResultIsReturned(&constantOp)); - if (constantOp.sparse_value().hasValue()) return emitError(loc, "Only support dense values at this time"); From 29c0db41bdf7831e04fd5a3d460bd66b9e34a8c1 Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Wed, 9 Dec 2020 10:40:10 -0500 Subject: [PATCH 7/7] Enable size tests. --- test/backend/test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/backend/test.py b/test/backend/test.py index a44eb84134..d1c448b58b 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -570,8 +570,8 @@ # Size # TODO(tjingrant): fix unit test for size ops. - # "test_size_cpu": (test_static,), - # "test_size_example_cpu": (test_static,), + "test_size_cpu": (test_static,), + "test_size_example_cpu": (test_static,), # Slice (makes Axis a runtime argument, which is not supported).