Rewrite Shape and Size ops (llvm#285)
* add shape inference

* Revert "add shape inference"

This reverts commit f9d42f39e68e14b5648abccfc8617fff00244d16.

* add rewrite rules

* test cases

* format

* add constraint

* response to review

* response to review
chentong319 authored Sep 10, 2020
1 parent 5e11429 commit ac67900
Showing 5 changed files with 113 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/Dialect/ONNX/ONNXOps.td.inc
@@ -4729,6 +4729,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength",

def ONNXShapeOp:ONNX_Op<"Shape",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let hasCanonicalizer = 1;
let summary = "ONNX Shape operation";
let description = [{
"Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor."
@@ -4863,6 +4864,7 @@ def ONNXSinhOp:ONNX_Op<"Sinh",

def ONNXSizeOp:ONNX_Op<"Size",
[NoSideEffect]> {
let hasCanonicalizer = 1;
let summary = "ONNX Size operation";
let description = [{
"Takes a tensor as input and outputs a int64 scalar that equals to the total number of elements of the input tensor."
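Setting "let hasCanonicalizer = 1;" in the ODS record tells the op definition generator to declare a canonicalization hook on the generated C++ op class; the definitions for ONNXShapeOp and ONNXSizeOp are supplied in src/Transform/ONNX/Rewrite.cpp below. A minimal sketch of the declaration this produces, mirroring the signatures defined later in this diff (the exact ODS-generated boilerplate may differ by MLIR version):

    // Declared inside the generated ONNXShapeOp (and ONNXSizeOp) class when
    // hasCanonicalizer is set; Rewrite.cpp provides the definition.
    static void getCanonicalizationPatterns(
        OwningRewritePatternList &results, MLIRContext *context);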
35 changes: 35 additions & 0 deletions src/Transform/ONNX/Rewrite.cpp
@@ -27,6 +27,29 @@ DenseElementsAttr createDenseElementsAttrFromFloatAttr(
return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values));
}

// Create a DenseElementsAttr based on the shape of type.
DenseElementsAttr createDenseElementsAttrFromShape(
PatternRewriter &rewriter, Value value) {
auto inType = value.getType().cast<ShapedType>();
auto shape = inType.getShape();
SmallVector<int64_t, 1> dims = {inType.getRank()};
SmallVector<int64_t, 4> values(shape.begin(), shape.end());
auto tensorType =
mlir::RankedTensorType::get(dims, rewriter.getIntegerType(64));
return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values));
}

// Create a DenseElementsAttr based on the size of type.
DenseElementsAttr createDenseElementsAttrFromSize(
PatternRewriter &rewriter, Value value) {
auto inType = value.getType().cast<ShapedType>();
SmallVector<int64_t, 1> dims(1, 1);
SmallVector<int64_t, 1> values = {inType.getNumElements()};
auto tensorType =
mlir::RankedTensorType::get(dims, rewriter.getIntegerType(64));
return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values));
}

// If 'lhs' is not NoneType, return 'lhs - rhs'.
// Otherwise, return '-rhs'.
Value subtractOrNeg(
@@ -128,3 +151,15 @@ void ONNXBatchNormalizationTestModeOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<FuseBatchNormTestModeConvPattern>(context);
}

/// Register optimization patterns as "canonicalization" patterns on the ONNXShapeOp.
void ONNXShapeOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ShapeToConstantPattern>(context);
}

/// Register optimization patterns as "canonicalization" patterns on the ONNXSizeOp.
void ONNXSizeOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<SizeToConstantPattern>(context);
}
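The two new helpers only read static type information off the value: createDenseElementsAttrFromShape packs one i64 per dimension into a rank-1 tensor attribute, and createDenseElementsAttrFromSize packs the total element count into a one-element tensor attribute. A minimal, MLIR-free sketch of that arithmetic for the tensor<2x4x8x16xf32> input used in the tests below; the harness is illustrative, only the numbers mirror what the helpers compute:

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main() {
      // Static shape of the sample input tensor<2x4x8x16xf32>.
      std::vector<int64_t> shape = {2, 4, 8, 16};

      // createDenseElementsAttrFromShape: one element per dimension -> tensor<4xi64>.
      for (int64_t d : shape)
        std::cout << d << ' ';          // prints: 2 4 8 16
      std::cout << '\n';

      // createDenseElementsAttrFromSize: product of the dimensions -> tensor<1xi64>.
      int64_t numElements = std::accumulate(
          shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
      std::cout << numElements << '\n'; // prints: 1024
      return 0;
    }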
29 changes: 29 additions & 0 deletions src/Transform/ONNX/Rewrite.td
@@ -28,6 +28,14 @@ include "src/Dialect/ONNX/ONNXOps.td"
def createDenseElementsAttrFromFloatAttr : NativeCodeCall<
"createDenseElementsAttrFromFloatAttr($_builder, $0.getType().cast<ShapedType>().getElementType(), $1)">;

// Create a DenseElementsAttr from the shape of the type of a value.
def createDenseElementsAttrFromShape : NativeCodeCall<
"createDenseElementsAttrFromShape($_builder, $0)">;

// Create a DenseElementsAttr from the size of the type of a value.
def createDenseElementsAttrFromSize : NativeCodeCall<
"createDenseElementsAttrFromSize($_builder, $0)">;

// If '$1' is not NoneType, do subtraction '$1 - $2'.
// Otherwise, take the negative of '$2'.
def subtractOrNeg: NativeCodeCall<
@@ -172,4 +180,25 @@ def FuseBatchNormTestModeConvPattern: Pat<
$auto_pad, $dilation, $group, $kernel_shape, $pads, $strides)
>;

def IsStaticShapeTensor:
Constraint<
CPred<
"$_self.getType().cast<::mlir::ShapedType>().hasStaticShape()">,
"hasStaticShape">;

def ShapeToConstantPattern: Pat<
(ONNXShapeOp $A),
(ONNXConstantOp
(GetNullAttr),
(createDenseElementsAttrFromShape $A)),
[(IsStaticShapeTensor:$A)]
>;

def SizeToConstantPattern: Pat<
(ONNXSizeOp $A),
(ONNXConstantOp
(GetNullAttr),
(createDenseElementsAttrFromSize $A)),
[(IsStaticShapeTensor:$A)]
>;
#endif // ONNX_REWRITE
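For readers less used to DRR, a hand-written C++ pattern equivalent to ShapeToConstantPattern would look roughly like the sketch below. It is an illustration only, not part of this commit: the operand accessor (data()) and the ONNXConstantOp builder argument order (result type, then the sparse_value and value attributes) are assumptions about the generated op interfaces.

    // Rough hand-written counterpart of the DRR ShapeToConstantPattern above.
    struct ShapeToConstantPatternCpp : public OpRewritePattern<ONNXShapeOp> {
      using OpRewritePattern<ONNXShapeOp>::OpRewritePattern;

      LogicalResult matchAndRewrite(
          ONNXShapeOp shapeOp, PatternRewriter &rewriter) const override {
        // IsStaticShapeTensor constraint: only fire when every dim is known.
        auto inType = shapeOp.data().getType().cast<ShapedType>();
        if (!inType.hasStaticShape())
          return failure();
        // Materialize the shape as a dense tensor<rank x i64> attribute.
        DenseElementsAttr shapeAttr =
            createDenseElementsAttrFromShape(rewriter, shapeOp.data());
        // Replace onnx.Shape with onnx.Constant carrying that attribute,
        // leaving sparse_value null to match GetNullAttr in the DRR pattern.
        rewriter.replaceOpWithNewOp<ONNXConstantOp>(
            shapeOp, shapeOp.getResult().getType(), Attribute(), shapeAttr);
        return success();
      }
    };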
46 changes: 46 additions & 0 deletions test/mlir/onnx/onnx_canonicalization.mlir
@@ -222,3 +222,49 @@ func @test_transpose_fusion_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10
// CHECK-NEXT: return %arg0 : tensor<10x11x12x13xf32>
"std.return"(%1) : (tensor<10x11x12x13xf32>) -> ()
}

// -----

func @test_shape1(%arg0 : tensor<2x4x8x16xf32>) -> tensor<*xi64> {
%0 = "onnx.Shape"(%arg0) : (tensor<2x4x8x16xf32>) -> tensor<*xi64>
return %0 : tensor<*xi64>

// CHECK-LABEL: @test_shape1
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[2, 4, 8, 16]> : tensor<4xi64>} : () -> tensor<*xi64>
// CHECK-NEXT: return %0 : tensor<*xi64>
}

// -----

func @test_shape2(%arg0 : tensor<?x4x8x16xf32>) -> tensor<*xi64> {
%0 = "onnx.Shape"(%arg0) : (tensor<?x4x8x16xf32>) -> tensor<*xi64>
return %0 : tensor<*xi64>

// CHECK-LABEL: @test_shape2
// CHECK-NEXT: %0 = "onnx.Shape"(%arg0) : (tensor<?x4x8x16xf32>) -> tensor<*xi64>
// CHECK-NEXT: return %0 : tensor<*xi64>
}


// -----

func @test_size1(%arg0 : tensor<2x4x8x16xf32>) -> tensor<*xi64> {
%0 = "onnx.Size"(%arg0) : (tensor<2x4x8x16xf32>) -> tensor<*xi64>
return %0 : tensor<*xi64>

// CHECK-LABEL: @test_size1
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<1024> : tensor<1xi64>} : () -> tensor<*xi64>
// CHECK-NEXT: return %0 : tensor<*xi64>
}

// -----

func @test_size2(%arg0 : tensor<*xf32>) -> tensor<*xi64> {
%0 = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<*xi64>
return %0 : tensor<*xi64>

// CHECK-LABEL: @test_size2
// CHECK-NEXT: %0 = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<*xi64>
// CHECK-NEXT: return %0 : tensor<*xi64>
}

2 changes: 1 addition & 1 deletion utils/gen_onnx_mlir.py
@@ -321,7 +321,7 @@
]

# Operations supporting canonicalization.
-OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose', 'Dropout']
+OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose', 'Dropout', 'Shape', 'Size']

# Operations who have operands that, if produced by constant operations, should
# be promoted to become an attribute (via attribute promotion).
