Skip to content

Commit

Permalink
Using onnx-mlir through incremental stages (llvm#257)
Browse files Browse the repository at this point in the history
* Add lowering of Vector dialect for lower-all-llvm pass

* Fix generating CallOp instructions when return type is void

* Fix lowering of memref

* Reformat using clang-format

* Record more context.

* Reflow comments.

Co-authored-by: Tian Jin <tjingrant@gmail.com>
  • Loading branch information
kwu91 and tjingrant authored Sep 10, 2020
1 parent dbc41d2 commit 03dae57
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
24 changes: 19 additions & 5 deletions src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SCF/SCF.h"
Expand Down Expand Up @@ -287,8 +288,7 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
// - Copy constant data into the alloca.
auto memcpyRef = getOrInsertMemcpy(rewriter, module);
rewriter.create<CallOp>(loc, memcpyRef,
LLVM::LLVMType::getVoidTy(context),
rewriter.create<CallOp>(loc, memcpyRef, ArrayRef<Type>({}),
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
} else {
// Some frequently used types.
Expand Down Expand Up @@ -381,7 +381,7 @@ class KrnlMemcpyOpLowering : public ConversionPattern {
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));

// Memcpy call
rewriter.create<CallOp>(loc, memcpyRef, LLVM::LLVMType::getVoidTy(context),
rewriter.create<CallOp>(loc, memcpyRef, ArrayRef<Type>({}),
ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
int64Size, isVolatile}));

Expand Down Expand Up @@ -612,8 +612,19 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
// returned, otherwise return nullptr.
Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
API apiId, ArrayRef<Value> params) const {
// To be used as parameters in LLVM::CallOp, voidTy must be converted
// to empty list to avoid emission of an SSA value with voidTy. However,
// we still keep using LLVM voidTy (as opposed to empty list) when recording
// API function signatures in API registry because when declaring API
// functions in LLVM IR, the correct way to indicate an output type for
// "void" is still LLVM voidTy. Relevant discussion thread:
// https://github.com/onnx/onnx-mlir/issues/255.
SmallVector<Type, 1> outputTys;
auto outputTy = registry.at(apiId).outputTy;
if (!outputTy.isVoidTy())
outputTys.emplace_back(outputTy);
auto returnVals =
rewriter.create<LLVM::CallOp>(loc, registry.at(apiId).outputTy,
rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>(outputTys),
registry.at(apiId).symbolRef, ArrayRef<Value>(params));
if (returnVals.getNumResults() == 1)
return returnVals.getResult(0);
Expand Down Expand Up @@ -642,7 +653,7 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
auto memRefTy = memRefPtrTy.getPointerElementTy();
auto int64Ty = LLVM::LLVMType::getInt64Ty(context);

Value memRef = rewriter.create<LLVM::LoadOp>(loc, memRefTy, ptrToMemRef);
Value memRef = rewriter.create<LLVM::UndefOp>(loc, memRefTy);

// Set dataPtr and alignedDataPtr;
auto dataPtr =
Expand Down Expand Up @@ -859,6 +870,8 @@ void mlir::populateAffineAndKrnlToLLVMConversion(
populateAffineToStdConversionPatterns(patterns, ctx);
populateLoopToStdConversionPatterns(patterns, ctx);
populateShapeToStandardConversionPatterns(patterns, ctx);
populateVectorToLLVMMatrixConversionPatterns(typeConverter, patterns);
populateVectorToLLVMConversionPatterns(typeConverter, patterns);
populateStdToLLVMConversionPatterns(typeConverter, patterns);

patterns.insert<KrnlGlobalOpLowering, KrnlPackedConstOpLowering>(
Expand All @@ -883,6 +896,7 @@ void ConvertKrnlToLLVMPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addIllegalOp<LLVM::DialectCastOp>();

// Lower the MemRef types to a representation in LLVM.
LowerToLLVMOptions options;
Expand Down
2 changes: 1 addition & 1 deletion test/mlir/krnl/constant.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func @test_constant(%arg0 : tensor<1xf32>) -> tensor<*xf32> {
/// Volatile flag
// CHECK: [[CONST0:%.+]] = llvm.mlir.constant(false) : !llvm.i1

// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> !llvm.void
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[I8ALLOCA]], [[I8GLOBAL]], [[GLOBAL_SIZE_BYTES]], [[CONST0]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> ()

/// Prepare data for MemRef insertion.
// CHECK: [[TYPED_ALLOCA:%.+]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr<array<3 x array<2 x float>>> to !llvm.ptr<float>
Expand Down
2 changes: 1 addition & 1 deletion test/mlir/krnl/reshape.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi64>) -> tensor<*x
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm.ptr<float> to !llvm.ptr<i8>
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
// CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(false) : !llvm.i1
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> !llvm.void
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> ()
// CHECK: llvm.return [[RES]] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<4 x i64>, array<4 x i64>)>
}

0 comments on commit 03dae57

Please sign in to comment.