forked from llvm/torch-mlir
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a pass to decompose ONNX operations (llvm#9)
- Loading branch information
Showing
11 changed files
with
191 additions
and
115 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
//===- onnx_decompose.cpp - ONNX High Level Rewriting ---------------------===// | ||
// | ||
// Copyright 2019 The IBM Research Authors. | ||
// | ||
// ============================================================================= | ||
// | ||
// This file implements a set of rewriters to decompose an ONNX operation into | ||
// composition of other ONNX operations. | ||
// | ||
// This pass is applied before any other pass so that there is no need to | ||
// implement shape inference for the decomposed operation. Hence, it is expected | ||
// that there is no knowledge about tensor shape at this point | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/IR/Matchers.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
#include "src/dialect/onnx/onnx_ops.hpp" | ||
#include "src/pass/passes.hpp" | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
/// Include the patterns defined in the Declarative Rewrite framework. | ||
#include "src/onnx_decompose.inc" | ||
|
||
struct DecomposeONNXToONNXPass : public FunctionPass<DecomposeONNXToONNXPass> { | ||
void runOnFunction() final; | ||
}; | ||
} // end anonymous namespace. | ||
|
||
void DecomposeONNXToONNXPass::runOnFunction() { | ||
auto function = getFunction(); | ||
MLIRContext *context = &getContext(); | ||
|
||
ConversionTarget target(getContext()); | ||
target.addLegalDialect<ONNXOpsDialect>(); | ||
|
||
// These ops will be decomposed into other ONNX ops. Hence, they will not be | ||
// available after this pass. | ||
target.addIllegalOp<ONNXReduceL1Op>(); | ||
target.addIllegalOp<ONNXReduceL2Op>(); | ||
target.addIllegalOp<ONNXReduceLogSumOp>(); | ||
target.addIllegalOp<ONNXReduceLogSumExpOp>(); | ||
target.addIllegalOp<ONNXReduceSumSquareOp>(); | ||
|
||
OwningRewritePatternList patterns; | ||
populateWithGenerated(context, &patterns); | ||
|
||
if (failed(applyPartialConversion(function, target, patterns))) | ||
signalPassFailure(); | ||
} // end anonymous namespace | ||
|
||
/*! | ||
* Create a DecomposeONNX pass. | ||
*/ | ||
std::unique_ptr<mlir::Pass> mlir::createDecomposeONNXToONNXPass() { | ||
return std::make_unique<DecomposeONNXToONNXPass>(); | ||
} | ||
|
||
static PassRegistration<DecomposeONNXToONNXPass> pass("decompose-onnx", | ||
"Decompose ONNX operations into composition of other ONNX operations."); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
//===----------------------------------------------------------------------===// | ||
//=- onnx_decompose.td - Rewriting for decomposing ONNX Ops -*- tablegen -*===// | ||
// | ||
// Copyright 2019 The IBM Research Authors. | ||
// | ||
// ============================================================================= | ||
// | ||
// Defines language-specific pattern match rewritings for ONNX using | ||
// Declarative Rewrite Rules (DRR) specified using TableGen records. | ||
// | ||
|
||
#ifndef ONNX_DECOMPOSE | ||
#define ONNX_DECOMPOSE | ||
|
||
#ifndef OP_BASE | ||
include "dialect/onnx/onnx.td" | ||
#endif // OP_BASE | ||
|
||
/// Note: The DRR definition used for defining patterns is shown below: | ||
/// | ||
/// class Pattern< | ||
/// dag sourcePattern, list<dag> resultPatterns, | ||
/// list<dag> additionalConstraints = [], | ||
/// dag benefitsAdded = (addBenefit 0) | ||
/// >; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ONNXReduceL1Op %X = ONNXReduceSumOp (ONNXAbsOp %X) | ||
//===----------------------------------------------------------------------===// | ||
def ReduceL1OpPattern: Pat<(ONNXReduceL1Op $oprd, $axes, $keepdims), | ||
(ONNXReduceSumOp (ONNXAbsOp $oprd), $axes, $keepdims)>; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ONNXReduceL2Op %X = ONNXSqrtOp (ONNXReduceSumSquareOp (%X)) | ||
//===----------------------------------------------------------------------===// | ||
def ReduceL2OpPattern: Pat<(ONNXReduceL2Op $oprd, $axes, $keepdims), | ||
(ONNXSqrtOp (ONNXReduceSumSquareOp $oprd, $axes, $keepdims))>; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ONNXReduceLogSumOp %X = ONNXLogOp (ONNXReduceSumOp (%X)) | ||
//===----------------------------------------------------------------------===// | ||
def ReduceLogSumOpPattern: Pat<(ONNXReduceLogSumOp $oprd, $axes, $keepdims), | ||
(ONNXLogOp (ONNXReduceSumOp $oprd, $axes, $keepdims))>; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ONNXReduceLogSumExpOp %X = ONNXReduceLogSumOp (ONNXExpOp %X) | ||
//===----------------------------------------------------------------------===// | ||
def ReduceLogSumExpOpPattern: Pat<(ONNXReduceLogSumExpOp $oprd, $axes, $keepdims), | ||
(ONNXReduceLogSumOp (ONNXExpOp $oprd), $axes, $keepdims)>; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ONNXReduceSumSquareOp %X = ONNXReduceSumOp (ONNXMulOp %X, %X) | ||
//===----------------------------------------------------------------------===// | ||
def ReduceSumSquareOpPattern: Pat<(ONNXReduceSumSquareOp $oprd, $axes, $keepdims), | ||
(ONNXReduceSumOp (ONNXMulOp $oprd, $oprd), $axes, $keepdims)>; | ||
|
||
#endif // ONNX_DECOMPOSE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.