
Commit

Merge remote-tracking branch 'upstream/master'
doru1004 committed Mar 31, 2020
2 parents f5548fb + b422116 commit fc29506
Showing 33 changed files with 596 additions and 511 deletions.
2 changes: 1 addition & 1 deletion .buildbot/p9.sh
@@ -48,4 +48,4 @@ make -j "$(nproc)" install
OMP_NUM_THREADS=20 OMP_THREAD_LIMIT=20 ctest3 -j "$(nproc)"

# Run lit+FileCheck tests:
make check-mlir-lit
make check-onnx-lit
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -40,7 +40,7 @@ jobs:
command: |
sudo pip install -q -e ./onnx-mlir/third_party/onnx
cd onnx-mlir/build
cmake --build . --target run-onnx-backend-test
cmake --build . --target check-onnx-backend
- run:
name: Run DocCheck
command: cd onnx-mlir/build && cmake --build . --target check-doc
2 changes: 1 addition & 1 deletion README.md
@@ -54,7 +54,7 @@ cmake --build . --target onnx-mlir
# Run FileCheck tests:
export LIT_OPTS=-v
cmake --build . --target check-mlir-lit
cmake --build . --target check-onnx-lit
```

After the above commands succeed, an `onnx-mlir` executable should appear in the `bin` directory.
29 changes: 0 additions & 29 deletions doc/Dialects/onnx.md
@@ -636,35 +636,6 @@ ONNX ConvInteger operation

1. `y`: memref of any type values or tensor of any type values

### onnx.ConvNoBias (ONNXConvNoBiasOp)
ONNX Conv operation with no Bias operand.

#### Description:


"The convolution operator consumes an input tensor and a filter, and"
"computes the output."

#### Operands:

1. `X`: memref of any type values or tensor of any type values
1. `W`: memref of any type values or tensor of any type values

#### Attributes:

| Attribute | MLIR Type | Description |
| :-------: | :-------: | ----------- |
| `auto_pad` | `StringAttr` | string attribute attribute |
| `dilations` | `ArrayAttr` | 64-bit integer array attribute attribute |
| `group` | `IntegerAttr` | 64-bit integer attribute attribute |
| `kernel_shape` | `ArrayAttr` | 64-bit integer array attribute attribute |
| `pads` | `ArrayAttr` | 64-bit integer array attribute attribute |
| `strides` | `ArrayAttr` | 64-bit integer array attribute attribute |

#### Results:

1. `o_Y`: memref of any type values or tensor of any type values

### onnx.Conv (ONNXConvOp)
ONNX Conv operation

30 changes: 18 additions & 12 deletions doc/ImportONNXDefs.md
@@ -2,25 +2,31 @@
The ONNX specifications are defined under the onnx/defs directory in the ONNX project.
There is a Python script, onnx/defs/gen_doc.py, that automatically generates documentation about the operations in ONNX (docs/Operations.md).
ONNX MLIR modified this script to import ONNX specifications into ONNX MLIR. Two files are generated for ONNX MLIR with the modified gen_doc.py:
1. src/dialect/onnx/onnxop.inc: Operation definitions for MLIR TableGen. Will be included in src/dialect/onnx/onnx.td
2. src/builder/op_build_table.inc: C++ code for the ONNX MLIR frontend to import operation nodes from an ONNX model. Will be included in src/builder/frontend_dialect_transformer.cpp
1. src/Dialect/ONNX/ONNXOps.td.inc: Operation definitions for MLIR TableGen. Will be included in src/Dialect/ONNX/ONNXOps.td
2. src/Builder/OpBuildTable.inc: C++ code for the ONNX MLIR frontend to import operation nodes from an ONNX model. Will be included in src/Builder/FrontendDialectTransformer.cpp

## How to use the script
1. Get ONNX. You can use onnx-mlir/third_party/onnx
2. In your ONNX directory, copy the script docs/gen_doc.py from your ONNX MLIR to onnx/defs in ONNX
3. Run the script: python onnx/defs/gen_doc.py
4. Two files, onnxop.inc and op_build_table.inc, should be generated in the current directory
5. Copy the two files into your ONNX MLIR: cp onnxop.inc your_onnx-mlir/src/dialect/onnx/onnxop.inc; cp op_build_table.inc your_onnx-mlir/src/builder
6. Go to your ONNX MLIR and build
2. Perform the following steps (assume that we use onnx-mlir/third_party/onnx):
```bash
$ cd onnx-mlir
$ cp doc/gen_doc.py third_party/onnx/onnx/defs/
$ cd third_party/onnx
$ python onnx/defs/gen_doc.py
$ cd ../..
$ cp third_party/onnx/onnx/defs/ONNXOps.td.inc src/Dialect/ONNX/
$ cp third_party/onnx/onnx/defs/OpBuildTable.inc src/Builder/
```

## Consistency
The Operators.md generated by gen_doc.py is copied into doc. Please refer to this specification, not the one in the ONNX GitHub repository, to make sure operators are consistent in version with onnxop.inc.
The Operators.md generated by gen_doc.py is copied into doc. Please refer to this specification, not the one in the ONNX GitHub repository, to make sure operators are consistent in version with ONNXOps.td.inc.

## Customization
In addition to following the ONNX specification, the modified gen_doc.py provides some mechanisms for customizing the output.
Several tables are defined at the beginning of the script (a sketch of these tables follows the list below):
1. special_attr_defaults: gives attribute special default value.
1. special_attr_defaults: gives attribute special default value.
2. special_op_handler: creates a special import function in frontend_dialect_transformer.cpp. Currently, special handlers are used for operations with optional arguments
3. ShapeInferenceList: list of operations which have shape inference defined
4. CanonicalList: list of operations which have a canonical form
5. manual_code_in_op_def: provides a way to specify additional code for an operation in its tablegen definition
3. OpsWithShapeInference: list of operations which have shape inference defined
4. OpsWithCanonicalizer: list of operations which have a canonical form
5. OpsWithPromotableConstOperands: list of operations which have operands that, if produced by constant operations, should be promoted to become an attribute (via attribute promotion)
6. custom_builder_ops_list: list of operations which need custom build methods to deduce result types
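
For orientation, a minimal sketch of what these customization tables might look like in doc/gen_doc.py after this change. The special_op_handler, OpsWithShapeInference, and OpsWithCanonicalizer entries are excerpts taken from the diff below; the remaining entries are illustrative placeholders, not the actual contents of the script.

```python
# Sketch only: excerpted and placeholder entries, not the full gen_doc.py tables.
special_attr_defaults = dict([
    # ("SomeOp.some_attr", AttributeType, default_value)  # placeholder entry
])

special_op_handler = dict([
    ("MaxPool", "ImportNodeMaxPool"),
    ("BatchNormalization", "ImportNodeBatchNormalization"),
    ("Pad", "ImportNodePad"),
])

# Excerpt of the operations with shape inference and canonicalizers after this change.
OpsWithShapeInference = ['Softmax', 'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv']
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']

# Placeholders for the remaining tables described above.
OpsWithPromotableConstOperands = {}  # e.g. {"OpName": [("operand_name", operand_index)]}
custom_builder_ops_list = []         # e.g. ['OpName']
```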
5 changes: 2 additions & 3 deletions doc/gen_doc.py
@@ -32,7 +32,6 @@

# Special operation importing handlers.
special_op_handler = dict([
("Conv", "ImportNodeConv"),
("MaxPool", "ImportNodeMaxPool"),
("BatchNormalization", "ImportNodeBatchNormalization"),
("Pad", "ImportNodePad"),
@@ -47,11 +46,11 @@
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
'Sign', 'Constant', 'AveragePool', 'Abs'
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv'
]

# Operations supporting canonicalization.
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm']
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv']

# Operations who have operands that, if produced by constant operations, should
# be promoted to become an attribute (via attribute promotion).
44 changes: 18 additions & 26 deletions src/Builder/FrontendDialectTransformer.cpp
@@ -248,15 +248,31 @@ class FrontendGenImpl {
bool variadicIn = expectedNumOperands == -1;
bool variadicOut = expectedNumResults == -1;

// In ONNX, there are two ways to leave an optional input or output
// unspecified: the first, available only for trailing inputs and outputs,
// is to simply not provide that input; the second method is to use an empty
// string in place of an input or output name.
//
// Here, we import optional inputs and outputs as NoneType.

// Trailing optional inputs.
if (!variadicIn)
for (auto i = inputs.size(); i < expectedNumOperands; i++)
inputs.emplace_back(none_);

std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
// Optional outputs using empty string.
if (item.empty())
outputTypes.emplace_back(builder_.getNoneType());
else
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
// Trailing optional outputs.
if (!variadicOut)
for (int i = node.output().size(); i < expectedNumResults; ++i)
outputTypes.emplace_back(builder_.getNoneType());

auto attributes = ImportNodeAttributes(node);

@@ -303,30 +319,6 @@ class FrontendGenImpl {
buildOutputAndOperation<mlir::ONNXReshapeOp>(node, inputs, nIn, nOut);
}

/*!
* Special handler for Conv operations.
* C++ does not allow template specialization inside a class scope, so a
* specialized function is used.
*/
void ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) {
// Conv has attributes dilations, kernel_shape, and pads, whose default values
// are determined by the shape of the first argument. However, since that
// shape is unknown at this point, these attributes cannot be generated here,
// e.g.:
// auto dilations_attr = get_attr_ints(node, "dilations",
//     std::vector<int>(inputs[0]->getType().cast<RankedTensorType>.getDims()-2,
//     1));
// attributes.push_back(dilations_attr)
// The situation is similar for pads and strides of AveragePool, and for axes
// of ReduceSum and pads, strides, dilations, and kernel_shape of MaxPool.
// TODO: fix this after type inference.
int nOps = node.input().size();

if (nOps == 2)
buildOperation<mlir::ONNXConvNoBiasOp>(node, nOps, nOut);
else
buildOperation<mlir::ONNXConvOp>(node, nOps, nOut);
}

/*!
* Special handler for MaxPool operations.
*/
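
For reference (not part of this commit), a small Python sketch using the standard onnx helper API to show the two conventions described in the importer comment above; the importer maps both forms to NoneType values. The Clip example assumes an opset in which min and max are optional inputs.

```python
from onnx import helper

# 1) A trailing optional input is simply not provided: Conv without the bias B.
conv_no_bias = helper.make_node("Conv", inputs=["X", "W"], outputs=["Y"])

# 2) An empty string stands in for a skipped optional input: Clip with no min
#    but an explicit max (assumes an opset where min/max are inputs).
clip_no_min = helper.make_node("Clip", inputs=["x", "", "max"], outputs=["y"])

print(conv_no_bias)
print(clip_no_min)
```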
2 changes: 1 addition & 1 deletion src/Builder/OpBuildTable.inc
@@ -50,7 +50,7 @@ if (opName == "Constant")
if (opName == "ConstantOfShape")
return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
if (opName == "Conv")
return ImportNodeConv(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
return buildOperation<mlir::ONNXConvOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1);
if (opName == "ConvInteger")
return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
if (opName == "ConvTranspose")
7 changes: 4 additions & 3 deletions src/Conversion/ONNXToKrnl/Math/Gemm.cpp
@@ -24,10 +24,11 @@ struct ONNXGemmOpLowering : public ConversionPattern {
bool hasBias = !op->getOperand(2).getType().isa<NoneType>();

Value A, B, C;
A = operands[0];
B = operands[1];
ONNXGemmOpOperandAdaptor operandAdaptor(operands);
A = operandAdaptor.A();
B = operandAdaptor.B();
if (hasBias)
C = operands[2];
C = operandAdaptor.C();

auto memRefType = convertToMemRefType(*op->result_type_begin());

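
As a side note, a NumPy sketch of the Gemm semantics being lowered above, with the optional bias C mirroring the hasBias/NoneType check; an illustrative reference, not the generated kernel.

```python
import numpy as np

def gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=False, transB=False):
    # Y = alpha * A' * B' (+ beta * C when the optional bias is present).
    A = A.T if transA else A
    B = B.T if transB else B
    Y = alpha * (A @ B)
    if C is not None:          # mirrors the hasBias / NoneType check above
        Y = Y + beta * C       # C broadcasts to the shape of Y
    return Y

print(gemm(np.ones((2, 3)), np.ones((3, 4)), C=np.zeros(4)))
```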
5 changes: 3 additions & 2 deletions src/Conversion/ONNXToKrnl/Math/MatMul.cpp
@@ -21,8 +21,9 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();

Value A = operands[0];
Value B = operands[1];
ONNXMatMulOpOperandAdaptor operandAdaptor(operands);
Value A = operandAdaptor.A();
Value B = operandAdaptor.B();
auto AShape = A.getType().cast<MemRefType>().getShape();
auto BShape = B.getType().cast<MemRefType>().getShape();

27 changes: 13 additions & 14 deletions src/Conversion/ONNXToKrnl/Math/Softmax.cpp
@@ -15,9 +15,8 @@ using namespace mlir;
struct ONNXSoftmaxOpLowering : public ConversionPattern {
ONNXSoftmaxOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// softmax(x) = let max_x = max(x) in
// let exp_x = exp(x - max_x) in
// let sum = sum(exp_x) in
@@ -29,7 +28,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
assert(axis >= -rank && axis <= rank - 1);

auto loc = op->getLoc();

ONNXSoftmaxOpOperandAdaptor operandAdaptor(operands);
Value input = operandAdaptor.input();
// Insert an allocation and deallocation for the result of this operation.
auto elementType = memRefType.getElementType();

@@ -38,8 +38,8 @@
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
operands[0]);
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, input);

// Shape of the result
auto memRefShape = memRefType.getShape();
@@ -49,15 +49,14 @@
Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
Value zero = emitConstantOp(rewriter, loc, elementType, 0);
Value negInfinity = rewriter.create<ConstantOp>(
loc,
Value negInfinity = rewriter.create<ConstantOp>(loc,
FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));

// Define loops.
std::vector<Value> originalLoops;
std::vector<Value> optimizedLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
optimizedLoops, rank);
Block *optimizationBlock =
defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);

// Coerce the input into a 2-D tensor. `axis` will be the coercing point.
// This coercing follows the softmax definition in ONNX:
@@ -75,7 +74,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
}
KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
for (int i = 0; i < axis; ++i)
addDimensionToPack(rewriter, loc, outerPack, operands[0], i);
addDimensionToPack(rewriter, loc, outerPack, input, i);

// Define an inner loop with respect to axis.
std::vector<Value> innerLoops, optimizedInnerLoops;
@@ -87,7 +86,7 @@
}
KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops);
for (int i = axis; i < rank; ++i)
addDimensionToPack(rewriter, loc, innerPack, operands[0], i);
addDimensionToPack(rewriter, loc, innerPack, input, i);

KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
SmallVector<Value, 4> outerLoopIVs;
@@ -144,7 +143,7 @@

// Compute the max value.
Value max = rewriter.create<LoadOp>(loc, maxOp);
Value nextMax = rewriter.create<LoadOp>(loc, operands[0], maxLoopIVs);
Value nextMax = rewriter.create<LoadOp>(loc, input, maxLoopIVs);
auto maxCond =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
Expand All @@ -167,7 +166,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {

// Sum up values.
Value sum = rewriter.create<LoadOp>(loc, sumOp);
Value next = rewriter.create<LoadOp>(loc, operands[0], sumLoopIVs);
Value next = rewriter.create<LoadOp>(loc, input, sumLoopIVs);
Value sub = rewriter.create<SubFOp>(loc, next, max);
Value exp = rewriter.create<ExpOp>(loc, sub);
sum = rewriter.create<AddFOp>(loc, sum, exp);
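
For reference, a NumPy sketch of the axis-coerced softmax that the comments above describe; an illustration of the semantics only, not the emitted Krnl loops.

```python
import numpy as np

def onnx_softmax(x, axis=1):
    # Coerce to 2-D at `axis`, normalize each row, then restore the shape.
    shape = x.shape
    x2d = x.reshape(int(np.prod(shape[:axis], dtype=np.int64)), -1)
    x2d = x2d - x2d.max(axis=1, keepdims=True)   # subtract the max for stability
    e = np.exp(x2d)
    return (e / e.sum(axis=1, keepdims=True)).reshape(shape)

print(onnx_softmax(np.arange(24, dtype=np.float32).reshape(2, 3, 4), axis=1))
```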
46 changes: 32 additions & 14 deletions src/Conversion/ONNXToKrnl/NN/Conv.cpp
@@ -12,32 +12,35 @@

using namespace mlir;

struct ONNXConvNoBiasOpLowering : public ConversionPattern {
ONNXConvNoBiasOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
struct ONNXConvOpLowering : public ConversionPattern {
ONNXConvOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXConvOp::getOperationName(), 1, ctx) {}

PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
ONNXConvOpOperandAdaptor operandAdaptor(operands);
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertToMemRefType(*op->result_type_begin());
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);
ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);

auto resultShape = memRefType.getShape();
auto inputOperand = operandAdaptor.X();
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
auto kernelOperand = operandAdaptor.W();
auto kernelShape = kernelOperand.getType().cast<MemRefType>().getShape();
auto biasOperand = operandAdaptor.B();
bool hasBias = !biasOperand.getType().isa<NoneType>();

if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, {operands[0]});
memRefType, loc, rewriter, insertDealloc, {inputOperand});

auto resultShape = memRefType.getShape();
auto &inputOperand = operands[0];
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
auto &kernelOperand = operands[1];
auto kernelShape = kernelOperand.getType().cast<MemRefType>().getShape();

// R = ConvNoBias(D, K)
// R = Conv(D, K)
//
// The input/output shapes will look like this:
//
@@ -169,8 +172,23 @@

// 3.4 Emit inner loop nest.
innerLoops.createIterateOp();
rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());

// Add the bias, if present.
if (hasBias) {
auto loadResult =
rewriter.create<LoadOp>(loc, alloc, resultIndices);
auto loadBias =
rewriter.create<LoadOp>(loc, biasOperand, kernel);
auto resultWithBias =
rewriter.create<AddFOp>(loc, loadResult, loadBias);
// Store the result with the bias added into the output location.
rewriter.create<StoreOp>(loc, resultWithBias, alloc, resultIndices);
}

rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
{
// 4. Emit inner loop body
// R[n][kernel][r1][r2] =
@@ -238,5 +256,5 @@

void populateLoweringONNXConvOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ONNXConvNoBiasOpLowering>(ctx);
patterns.insert<ONNXConvOpLowering>(ctx);
}
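
Finally, a loop-based NumPy sketch of the Conv-with-bias semantics that ONNXConvOpLowering now covers (hedged: no padding, strides, dilations, or groups); the bias contributes one value per output channel.

```python
import numpy as np

def conv2d_with_bias(X, W, B=None):
    # X: [N, C, H, W_in], W: [M, C, KH, KW], B: [M] (optional).
    N, C, H, W_in = X.shape
    M, _, KH, KW = W.shape
    Y = np.zeros((N, M, H - KH + 1, W_in - KW + 1), dtype=X.dtype)
    for n in range(N):
        for m in range(M):
            for r1 in range(Y.shape[2]):
                for r2 in range(Y.shape[3]):
                    Y[n, m, r1, r2] = np.sum(X[n, :, r1:r1 + KH, r2:r2 + KW] * W[m])
    if B is not None:
        Y += B.reshape(1, M, 1, 1)   # the bias is added per output channel
    return Y

print(conv2d_with_bias(np.ones((1, 2, 5, 5)), np.ones((3, 2, 3, 3)), np.arange(3.0)).shape)
```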