Skip to content

Commit

Permalink
[mlir][ONNX] Implement TileOp canonicalizer (#2765)
Browse files Browse the repository at this point in the history
Signed-off-by: Sam <srcarroll314@gmail.com>
  • Loading branch information
srcarroll authored Mar 22, 2024
1 parent 9dcf0a9 commit a1b4aa8
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ONNXOps.td.inc
Original file line number Diff line number Diff line change
Expand Up @@ -9523,6 +9523,7 @@ def ONNXThresholdedReluOp:ONNX_Op<"ThresholdedRelu",

def ONNXTileOp:ONNX_Op<"Tile",
[Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
let hasCanonicalizer = 1;
let summary = "ONNX Tile operation";
let description = [{
Constructs a tensor by tiling a given tensor.
Expand Down
6 changes: 6 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1726,6 +1726,12 @@ void ONNXSqueezeV11Op::getCanonicalizationPatterns(
result.insert<RemoveSqueezeV11CastUnsqueezeV11Pattern>(context);
}

/// on the ONNXTileOp.
void ONNXTileOp::getCanonicalizationPatterns(
RewritePatternSet &result, MLIRContext *context) {
result.insert<RemoveIdentityTilePattern>(context);
}

/// on the ONNXTransposeOp.
void ONNXTransposeOp::getCanonicalizationPatterns(
RewritePatternSet &result, MLIRContext *context) {
Expand Down
30 changes: 28 additions & 2 deletions src/Dialect/ONNX/ONNXOps/Canonicalize.td
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ class HaveSameDim<int dim>: Constraint<
"$1.getType().cast<RankedTensorType>().getShape()[" # dim # "])">,
"Two tensors have the same specified dimension">;

def HaveSameShapedType: Constraint<
CPred<"(isa<ShapedType>($0.getType()) &&"
"dyn_cast<ShapedType>($0.getType()) == "
"dyn_cast<ShapedType>($1.getType()))">,
"has same shaped type">;

def HaveSameStaticShape: Constraint<
CPred<"onnx_mlir::haveSameStaticShape($0, $1)">,
"Two tensors have the same static shape">;
Expand Down Expand Up @@ -195,8 +201,8 @@ def IsNotFromONNXConstantOp: Constraint<
"Is a value not from ONNXConstantOp">;

def IsFromONNXConstantOpWithDenseElementsAttr: Constraint<
And<[CPred<" isa<ONNXConstantOp>($_self.getDefiningOp()) ">,
CPred<" onnx_mlir::getONNXConstantOp($_self).getValueAttr().isa<DenseElementsAttr>() ">
And<[CPred<" $_self.getDefiningOp<ONNXConstantOp>() ">,
CPred<" isa<DenseElementsAttr>(onnx_mlir::getONNXConstantOp($_self).getValueAttr()) ">
]>, "Value is not a ONNXConstantOp with a DenseElementsAttr">;

def AreTheSameAxesConstant: Constraint<
Expand Down Expand Up @@ -450,6 +456,26 @@ def SwapCastSlicePattern: Pat<
(ONNXSliceOp (ONNXCastOp $data, $saturate, $to), $starts, $ends, $axes, $steps)
>;

//===----------------------------------------------------------------------===//
// Canonicalization for ONNXTileOp
//===----------------------------------------------------------------------===//

def IsFromONNXConstantOpWithOnesDenseElementsAttr: Constraint<
And<[IsFromONNXConstantOpWithDenseElementsAttr.predicate,
CPred<"::llvm::all_of("
" onnx_mlir::getONNXConstantOp($_self).getValueAttr()"
" .dyn_cast<DenseElementsAttr>().getValues<int64_t>(), "
"[](int64_t repeat) { return repeat == 1;})">
]>, "Value is not a ONNXConstantOp with a DenseElementsAttr of ones">;

def RemoveIdentityTilePattern: Pat<
// Tile with `repeats` of all constant 1's
(ONNXTileOp:$result $val, $r),
// Remove the tile.
(replaceWithValue $val),
// Check that we have indeed a identity tile pattern.
[(IsFromONNXConstantOpWithOnesDenseElementsAttr:$r), (HaveSameShapedType $val,$result)]>;

//===----------------------------------------------------------------------===//
// Canonicalization for ONNXLayoutTransformOp
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 12 additions & 0 deletions test/mlir/onnx/onnx_canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,18 @@ func.func @test_transpose_concat_reversed(%arg0: tensor<?x5x5x1xf32>, %arg1: ten

// -----

// CHECK-LABEL: func @identity_tile
func.func @identity_tile(%arg0: tensor<32x64xf32>) -> tensor<32x64xf32> {
%0 = onnx.Constant dense<1> : tensor<2xi64>
%1 = "onnx.Tile"(%arg0, %0) : (tensor<32x64xf32>, tensor<2xi64>) -> tensor<32x64xf32>
onnx.Return %1 : tensor<32x64xf32>

// CHECK-NEXT: onnx.Return %arg0
// CHECK-NOT: "onnx.Tile"
}

// -----

// Check the removal of identity reshapes.
// CHECK-LABEL: func @test_reshape_removal_1(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
func.func @test_reshape_removal_1(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
Expand Down
1 change: 1 addition & 0 deletions utils/gen_onnx_mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@
"Squeeze",
"SqueezeV11",
"Sub",
"Tile",
"Transpose",
"Unsqueeze",
"UnsqueezeV11",
Expand Down

0 comments on commit a1b4aa8

Please sign in to comment.