Skip to content

Commit

Permalink
Add a pass to decompose ONNX operations (llvm#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
tungld authored Mar 4, 2020
1 parent 7c1dd02 commit e97df0b
Show file tree
Hide file tree
Showing 11 changed files with 191 additions and 115 deletions.
3 changes: 1 addition & 2 deletions doc/gen_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@

# Operations supporting canonicalization.
OpsWithCanonicalizer = [
'Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
'ReduceLogSumExp', 'ReduceSumSquare', 'Gemm'
'Add', 'Identity', 'Gemm'
]

# Add an Op in this list if the Op needs result type deduction which is required
Expand Down
15 changes: 14 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_library(compiler
dialect/onnx/onnxop.inc
pass/onnx_combine.cpp
pass/onnx_rewrite.cpp
pass/onnx_decompose.cpp
pass/passes.hpp)

# Include root src directory.
Expand All @@ -25,6 +26,11 @@ target_link_libraries(compiler
${MLIRLibs}
curses)

set(LLVM_TARGET_DEFINITIONS pass/onnx_decompose.td)
onnf_tablegen(onnx_decompose.inc -gen-rewriters)
add_public_tablegen_target(gen_onnx_decompose)
add_dependencies(compiler gen_onnx_decompose)

set(LLVM_TARGET_DEFINITIONS pass/shape_inference_interface.td)
onnf_tablegen(shape_inference.hpp.inc -gen-op-interface-decls)
onnf_tablegen(shape_inference.cpp.inc -gen-op-interface-defs)
Expand Down Expand Up @@ -55,6 +61,13 @@ onnf_tablegen(krnl.cpp.inc -gen-op-defs)
add_public_tablegen_target(gen_krnl_ops)
add_dependencies(compiler gen_krnl_ops)

add_library(onnf_onnx_decompose pass/onnx_decompose.cpp)
target_include_directories(onnf_onnx_decompose
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
${ONNF_SRC_ROOT})
target_link_libraries(onnf_onnx_decompose ${MLIRLibs})
add_dependencies(onnf_onnx_decompose gen_krnl_ops)

add_library(onnf_shape_inference pass/shape_inference_pass.cpp)
target_include_directories(onnf_shape_inference
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
Expand Down Expand Up @@ -90,7 +103,7 @@ add_subdirectory(runtime)

add_executable(onnf main.cpp)

target_link_libraries(onnf builder ${MLIRLibs} onnf_transform onnf_shape_inference onnf_lower_frontend)
target_link_libraries(onnf builder ${MLIRLibs} onnf_transform onnf_onnx_decompose onnf_shape_inference onnf_lower_frontend)
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
find_package(ZLIB REQUIRED)
target_link_libraries(onnf ${ZLIB_LIBRARIES})
Expand Down
5 changes: 0 additions & 5 deletions src/dialect/onnx/onnxop.inc
Original file line number Diff line number Diff line change
Expand Up @@ -2296,7 +2296,6 @@ def ONNXReciprocalOp:ONNX_Op<"Reciprocal",

def ONNXReduceL1Op:ONNX_Op<"ReduceL1",
[NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX ReduceL1 operation";
let description = [{
"Computes the L1 norm of the input tensor's element along the provided axes. The resulted"
Expand All @@ -2314,7 +2313,6 @@ def ONNXReduceL1Op:ONNX_Op<"ReduceL1",

def ONNXReduceL2Op:ONNX_Op<"ReduceL2",
[NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX ReduceL2 operation";
let description = [{
"Computes the L2 norm of the input tensor's element along the provided axes. The resulted"
Expand All @@ -2332,7 +2330,6 @@ def ONNXReduceL2Op:ONNX_Op<"ReduceL2",

def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",
[NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX ReduceLogSum operation";
let description = [{
"Computes the log sum of the input tensor's element along the provided axes. The resulted"
Expand All @@ -2350,7 +2347,6 @@ def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum",

def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp",
[NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX ReduceLogSumExp operation";
let description = [{
"Computes the log sum exponent of the input tensor's element along the provided axes. The resulted"
Expand Down Expand Up @@ -2465,7 +2461,6 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum",

def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
[NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX ReduceSumSquare operation";
let description = [{
"Computes the sum square of the input tensor's element along the provided axes. The resulted"
Expand Down
3 changes: 2 additions & 1 deletion src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ int main(int argc, char *argv[]) {
}

mlir::PassManager pm(&context);
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createDecomposeONNXToONNXPass());
pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());

if (emissionTarget >= EmitMLIR) {
pm.addPass(mlir::createLowerToKrnlPass());
Expand Down
65 changes: 65 additions & 0 deletions src/pass/onnx_decompose.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
//===- onnx_decompose.cpp - ONNX High Level Rewriting ---------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file implements a set of rewriters to decompose an ONNX operation into
// composition of other ONNX operations.
//
// This pass is applied before any other pass so that there is no need to
// implement shape inference for the decomposed operation. Hence, it is expected
// that there is no knowledge about tensor shape at this point
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "src/dialect/onnx/onnx_ops.hpp"
#include "src/pass/passes.hpp"

using namespace mlir;

namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/onnx_decompose.inc"

struct DecomposeONNXToONNXPass : public FunctionPass<DecomposeONNXToONNXPass> {
void runOnFunction() final;
};
} // end anonymous namespace.

void DecomposeONNXToONNXPass::runOnFunction() {
auto function = getFunction();
MLIRContext *context = &getContext();

ConversionTarget target(getContext());
target.addLegalDialect<ONNXOpsDialect>();

// These ops will be decomposed into other ONNX ops. Hence, they will not be
// available after this pass.
target.addIllegalOp<ONNXReduceL1Op>();
target.addIllegalOp<ONNXReduceL2Op>();
target.addIllegalOp<ONNXReduceLogSumOp>();
target.addIllegalOp<ONNXReduceLogSumExpOp>();
target.addIllegalOp<ONNXReduceSumSquareOp>();

OwningRewritePatternList patterns;
populateWithGenerated(context, &patterns);

if (failed(applyPartialConversion(function, target, patterns)))
signalPassFailure();
} // end anonymous namespace

/*!
* Create a DecomposeONNX pass.
*/
std::unique_ptr<mlir::Pass> mlir::createDecomposeONNXToONNXPass() {
return std::make_unique<DecomposeONNXToONNXPass>();
}

static PassRegistration<DecomposeONNXToONNXPass> pass("decompose-onnx",
"Decompose ONNX operations into composition of other ONNX operations.");
57 changes: 57 additions & 0 deletions src/pass/onnx_decompose.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
//===----------------------------------------------------------------------===//
//=- onnx_decompose.td - Rewriting for decomposing ONNX Ops -*- tablegen -*===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// Defines language-specific pattern match rewritings for ONNX using
// Declarative Rewrite Rules (DRR) specified using TableGen records.
//

#ifndef ONNX_DECOMPOSE
#define ONNX_DECOMPOSE

#ifndef OP_BASE
include "dialect/onnx/onnx.td"
#endif // OP_BASE

/// Note: The DRR definition used for defining patterns is shown below:
///
/// class Pattern<
/// dag sourcePattern, list<dag> resultPatterns,
/// list<dag> additionalConstraints = [],
/// dag benefitsAdded = (addBenefit 0)
/// >;

//===----------------------------------------------------------------------===//
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
//===----------------------------------------------------------------------===//
def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims),
(ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>;

//===----------------------------------------------------------------------===//
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
//===----------------------------------------------------------------------===//
def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims),
(ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>;

//===----------------------------------------------------------------------===//
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
//===----------------------------------------------------------------------===//
def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims),
(ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>;

//===----------------------------------------------------------------------===//
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
//===----------------------------------------------------------------------===//
def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims),
(ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>;

//===----------------------------------------------------------------------===//
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
//===----------------------------------------------------------------------===//
def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
(ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;

#endif // ONNX_DECOMPOSE
29 changes: 0 additions & 29 deletions src/pass/onnx_rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,35 +118,6 @@ struct SplitConvOpPattern : public RewritePattern {
};
} // end anonymous namespace

/// on the ONNXReduceL1Op.
void ONNXReduceL1Op::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReduceL1OpPattern>(context);
}
/// on the ONNXReduceL2Op.
void ONNXReduceL2Op::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReduceL2OpPattern>(context);
}

/// on the ONNXReduceLogSumOp.
void ONNXReduceLogSumOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReduceLogSumOpPattern>(context);
}

/// on the ONNXReduceLogSumExpOp.
void ONNXReduceLogSumExpOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReduceLogSumExpOpPattern>(context);
}

/// on the ONNXReduceSumSquareOp.
void ONNXReduceSumSquareOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ReduceSumSquareOpPattern>(context);
}

/// on the ONNXReduceSumSquareOp.
void ONNXConvNoBiasOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
Expand Down
30 changes: 0 additions & 30 deletions src/pass/onnx_rewrite.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,34 +24,4 @@ include "dialect/onnx/onnx.td"
/// dag benefitsAdded = (addBenefit 0)
/// >;

//===----------------------------------------------------------------------===//
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X)
//===----------------------------------------------------------------------===//
def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims),
(ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>;

//===----------------------------------------------------------------------===//
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X))
//===----------------------------------------------------------------------===//
def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims),
(ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>;

//===----------------------------------------------------------------------===//
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X))
//===----------------------------------------------------------------------===//
def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims),
(ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>;

//===----------------------------------------------------------------------===//
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X)
//===----------------------------------------------------------------------===//
def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims),
(ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>;

//===----------------------------------------------------------------------===//
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X)
//===----------------------------------------------------------------------===//
def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims),
(ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>;

#endif // ONNX_REWRITE
3 changes: 3 additions & 0 deletions src/pass/passes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
namespace mlir {
class Pass;

/// Pass for rewriting inside frontend dialect.
std::unique_ptr<Pass> createDecomposeONNXToONNXPass();

std::unique_ptr<Pass> createShapeInferencePass();

/// Add pass for lowering to Krnl IR.
Expand Down
47 changes: 0 additions & 47 deletions test/mlir/onnx/onnx_canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,53 +38,6 @@ func @test_identity_identity(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>) ->
"std.return"(%2) : (tensor<10x10xf32>) -> ()
}

// CHECK-LABEL: @test_reducel1(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducel1(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceL1"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()

// CHECK-NEXT: [[ABS:%.+]] = "onnx.Abs"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[ABS]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
}

// CHECK-LABEL: @test_reducel2(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducel2(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceL2"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()

// CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[MUL]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[SQRT:%.+]] = "onnx.Sqrt"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
}

// CHECK-LABEL: @test_reducelogsum(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducelogsum(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceLogSum"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()

// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"(%arg0) {axes = [1], keepdims = 0 : i64} : (tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
}

// CHECK-LABEL: @test_reducelogsumexp(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducelogsumexp(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceLogSumExp"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()

// CHECK-NEXT: [[EXP:%.+]] = "onnx.Exp"(%arg0) : (tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32>
}

// CHECK-LABEL: @test_reducesumsquare(%{{.*}}: tensor<?x?x?xf32>) -> tensor<*xf32>
func @test_reducesumsquare(%arg0 : tensor<?x?x?xf32>) -> tensor<*xf32> {
%0 ="onnx.ReduceSumSquare"(%arg0) {axes=[1], keepdims = 0 : i64} : (tensor<?x?x?xf32>)-> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()

// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.Mul"(%arg0, %arg0) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<*xf32>
// CHECK-NEXT: %{{[0-9]+}} = "onnx.ReduceSum"([[SQUARE]]) {axes = [1], keepdims = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32>
}

// CHECK-LABEL: @test_constant_pad(%{{.*}}: tensor<?x?xf32>) -> tensor<*xf32> {
func @test_constant_pad(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
// CHECK-NEXT: [[SQUARE:%.+]] = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 2, 0, 0]} : (tensor<?x?xf32>) -> tensor<*xf32>
Expand Down
Loading

0 comments on commit e97df0b

Please sign in to comment.