Revert the rule that rewrites a quantize-dequantize pair to identity (#2952)

Signed-off-by: Tung D. Le <tung@jp.ibm.com>
tungld authored Sep 25, 2024
1 parent 5c53b7e commit bb179d7
Showing 3 changed files with 1 addition and 28 deletions.
4 changes: 1 addition & 3 deletions src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
@@ -1863,6 +1863,4 @@ void ONNXWhereOp::getCanonicalizationPatterns(
 
 // on the ONNXDequantizeLinearOp.
 void ONNXDequantizeLinearOp::getCanonicalizationPatterns(
-    RewritePatternSet &result, MLIRContext *context) {
-  result.insert<QuantizeDequantizePattern>(context);
-}
+    RewritePatternSet &result, MLIRContext *context) {}
11 changes: 0 additions & 11 deletions src/Dialect/ONNX/ONNXOps/Canonicalize.td
@@ -1055,15 +1055,4 @@ def AlwaysFalseWherePattern : Pat<
   [(IsNegativeSplatConstant:$negative_constant), (AreAllDimSizes:$dims)]
 >;
 
-//===----------------------------------------------------------------------===//
-// Canonicalization for ONNXDequantizeLinear
-//===----------------------------------------------------------------------===//
-
-// Convert QuantizeLinear+DequantizeLinear to Identity.
-def QuantizeDequantizePattern: Pat<
-  (ONNXDequantizeLinearOp (ONNXQuantizeLinearOp $x, $x_scale, $x_zeropoint, $x_axis, $x_saturate),
-    $y_scale, $y_zeropoint, $y_axis),
-  (replaceWithValue $x)
->;
-
 #endif // ONNX_REWRITE
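
Context for the revert (not part of the diff, added here as a hedged illustration): folding onnx.QuantizeLinear followed by onnx.DequantizeLinear into the original value is not value-preserving in general, because quantization rounds to the nearest integer and saturates to the quantized type's range. Below is a minimal standalone C++ sketch of ONNX-style int8 affine quantization that shows the lossy round trip; the helper names and constants are illustrative, not onnx-mlir APIs.

// Sketch (assumption, not from this commit): int8 affine quantize/dequantize.
// The round trip rounds and saturates to [-128, 127], so it is not an identity.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// q = saturate(round(x / scale) + zero_point)
int8_t quantizeLinear(float x, float scale, int8_t zeroPoint) {
  float q = std::nearbyint(x / scale) + static_cast<float>(zeroPoint);
  q = std::min(127.0f, std::max(-128.0f, q)); // saturate to the int8 range
  return static_cast<int8_t>(q);
}

// y = (q - zero_point) * scale
float dequantizeLinear(int8_t q, float scale, int8_t zeroPoint) {
  return static_cast<float>(q - zeroPoint) * scale;
}

int main() {
  const float scale = 0.1f;
  const int8_t zeroPoint = 0;
  const float inputs[] = {0.1234f, 3.14159f, 100.0f}; // 100.0f saturates
  for (float x : inputs) {
    float roundTrip =
        dequantizeLinear(quantizeLinear(x, scale, zeroPoint), scale, zeroPoint);
    std::printf("x = %f  quantize->dequantize = %f\n", x, roundTrip);
  }
  return 0;
}

For these inputs the round trip yields roughly 0.1, 3.1, and 12.7 rather than the originals, which is why a quantize-dequantize pair cannot be folded to an identity in general.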
14 changes: 0 additions & 14 deletions test/mlir/onnx/onnx_canonicalization.mlir
@@ -1826,17 +1826,3 @@ func.func @test_where_with_always_false_3(%arg0: tensor<?x?xi64>) -> tensor<2xi6
 // CHECK: }
 }
 
-// -----
-
-func.func @test_dequantize_linear(%arg0: tensor<?x?x768xf32>, %arg1: tensor<f32>, %arg2: tensor<i8>) -> (tensor<?x?x768xf32>) {
-  %0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<?x?x768xf32>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
-  %1 = "onnx.DequantizeLinear"(%0, %arg1, %arg2) {axis = 1 : si64} : (tensor<?x?x768xi8>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xf32>
-  return %1: tensor<?x?x768xf32>
-
-// CHECK-LABEL: func.func @test_dequantize_linear
-// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x768xf32>, [[PARAM_1_:%.+]]: tensor<f32>, [[PARAM_2_:%.+]]: tensor<i8>) -> tensor<?x?x768xf32> {
-// CHECK-NOT: "onnx.QuantizeLinear"
-// CHECK-NOT: "onnx.DequantizeLinear"
-// CHECK: return [[PARAM_0_]] : tensor<?x?x768xf32>
-// CHECK: }
-}
