DialectBuilder for QLinearMatMul
Signed-off-by: Tung D. Le <tung@jp.ibm.com>
tungld committed Jul 12, 2024
1 parent edc18a7 commit aa29016
Showing 5 changed files with 57 additions and 20 deletions.
9 changes: 9 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.cpp
@@ -165,6 +165,15 @@ Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale,
   return layerNormOp.getY();
 }
 
+Value OnnxBuilder::qlinearMatMul(Type outputType, Value a, Value aScale,
+    Value aZeroPoint, Value b, Value bScale, Value bZeroPoint, Value yScale,
+    Value yZeroPoint) const {
+  return createOpAndInferShapes<ONNXQLinearMatMulOp>(toTensor(outputType),
+      toTensor(a), toTensor(aScale), toTensor(aZeroPoint), toTensor(b),
+      toTensor(bScale), toTensor(bZeroPoint), toTensor(yScale),
+      toTensor(yZeroPoint));
+}
+
 Value OnnxBuilder::RMSLayerNorm(Type outputType, Value input, Value scale,
     Value bias, int64_t axis, FloatAttr epsilon) const {
   IntegerAttr axisAttr = getSignedInt64Attr(axis);
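For reference, the new helper is meant to be called through MultiDialectBuilder like the other OnnxBuilder methods. A minimal sketch of a call site (the operand Values and outputType are assumed to already be in scope; this mirrors the Recompose.cpp change below rather than introducing anything new):

  MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
  // One call builds the ONNXQLinearMatMulOp; judging by the name,
  // createOpAndInferShapes also runs shape inference on the new op.
  Value y = create.onnx.qlinearMatMul(outputType, a, aScale, aZeroPoint, b,
      bScale, bZeroPoint, yScale, yZeroPoint);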
6 changes: 6 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.hpp
@@ -91,6 +91,12 @@ struct OnnxBuilder : DialectBuilder {
       mlir::Value scale, mlir::Value bias, int64_t axis,
       mlir::FloatAttr epsilon) const;
 
+  // ONNXQLinearMatMulOp
+  mlir::Value qlinearMatMul(mlir::Type outputType, mlir::Value a,
+      mlir::Value aScale, mlir::Value aZeroPoint, mlir::Value b,
+      mlir::Value bScale, mlir::Value bZeroPoint, mlir::Value yScale,
+      mlir::Value yZeroPoint) const;
+
   // ONNXRMSLayerNormalizationOp, version with one output only (Y).
   mlir::Value RMSLayerNorm(mlir::Type outputType, mlir::Value input,
       mlir::Value scale, mlir::Value bias, int64_t axis,
36 changes: 17 additions & 19 deletions src/Dialect/ONNX/Transforms/Recompose.cpp
@@ -349,17 +349,15 @@ struct RecomposeQLinearMatMulFromQuantizeLinearPattern
     using namespace onnx_mlir;
     Location loc = qlOp.getLoc();
     // Match
-    Value a, a_scale, a_zeropoint, b, b_scale, b_zeropoint, out_scale,
-        out_zeropoint;
-    if (!matchQLinearMatMulPattern(qlOp, a, a_scale, a_zeropoint, b, b_scale,
-            b_zeropoint, out_scale, out_zeropoint))
+    Value a, aScale, aZeroPoint, b, bScale, bZeroPoint, outScale, outZeroPoint;
+    if (!matchQLinearMatMulPattern(qlOp, a, aScale, aZeroPoint, b, bScale,
+            bZeroPoint, outScale, outZeroPoint))
       return failure();
 
     // Replace
     MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
-    Value res = rewriter.create<ONNXQLinearMatMulOp>(loc, qlOp.getY().getType(),
-        a, a_scale, a_zeropoint, b, b_scale, b_zeropoint, out_scale,
-        out_zeropoint);
+    Value res = create.onnx.qlinearMatMul(qlOp.getY().getType(), a, aScale,
+        aZeroPoint, b, bScale, bZeroPoint, outScale, outZeroPoint);
 
     rewriter.replaceOp(qlOp, res);
     return success();
@@ -368,11 +366,11 @@ struct RecomposeQLinearMatMulFromQuantizeLinearPattern
   // Recompose QLinearMatMul, starting from QuantizeLinear.
   // Pattern: DequantizeLinear + MatMul + QuantizeLinear.
   static bool matchQLinearMatMulPattern(ONNXQuantizeLinearOp op, Value &a,
-      Value &a_scale, Value &a_zeropoint, Value &b, Value &b_scale,
-      Value &b_zeropoint, Value &out_scale, Value &out_zeropoint) {
+      Value &aScale, Value &aZeroPoint, Value &b, Value &bScale,
+      Value &bZeroPoint, Value &outScale, Value &outZeroPoint) {
     Operation *quantizeOp = op.getOperation();
-    out_scale = op.getYScale();
-    out_zeropoint = op.getYZeroPoint();
+    outScale = op.getYScale();
+    outZeroPoint = op.getYZeroPoint();
     // Matching MatMul.
     Value qlX, matA, matB;
     Operation *matmulOp;
@@ -387,15 +385,15 @@ struct RecomposeQLinearMatMulFromQuantizeLinearPattern
     if (!dlOpA)
       return false;
     a = dlOpA.getX();
-    a_scale = dlOpA.getXScale();
-    a_zeropoint = dlOpA.getXZeroPoint();
+    aScale = dlOpA.getXScale();
+    aZeroPoint = dlOpA.getXZeroPoint();
     // Matching input B of MatMul.
     auto dlOpB = matB.getDefiningOp<ONNXDequantizeLinearOp>();
     if (!dlOpB)
       return false;
     b = dlOpB.getX();
-    b_scale = dlOpB.getXScale();
-    b_zeropoint = dlOpB.getXZeroPoint();
+    bScale = dlOpB.getXScale();
+    bZeroPoint = dlOpB.getXZeroPoint();
     // Matched the pattern.
     return true;
   }
@@ -452,11 +450,11 @@ void RecomposeONNXToONNXPass::runOnOperation() {
   // Pattern: DequantizeLinear + MatMul + QuantizeLinear.
   target.addDynamicallyLegalOp<ONNXQuantizeLinearOp>(
       [](ONNXQuantizeLinearOp op) {
-        Value a, a_scale, a_zeropoint, b, b_scale, b_zeropoint, out_scale,
-            out_zeropoint;
+        Value a, aScale, aZeroPoint, b, bScale, bZeroPoint, outScale,
+            outZeroPoint;
         return !RecomposeQLinearMatMulFromQuantizeLinearPattern::
-            matchQLinearMatMulPattern(op, a, a_scale, a_zeropoint, b, b_scale,
-                b_zeropoint, out_scale, out_zeropoint);
+            matchQLinearMatMulPattern(op, a, aScale, aZeroPoint, b, bScale,
+                bZeroPoint, outScale, outZeroPoint);
       });
 
   RewritePatternSet patterns(context);
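The net effect of the Recompose.cpp change is easy to miss in the renaming noise: the pattern no longer constructs ONNXQLinearMatMulOp directly through rewriter.create, but goes through the new OnnxBuilder::qlinearMatMul helper, which wraps each operand in toTensor and creates the op via createOpAndInferShapes (see the DialectBuilder.cpp hunk above). A condensed before/after, using only names from the diff:

  // Before: direct creation; the op is built with the given result type as-is.
  Value res = rewriter.create<ONNXQLinearMatMulOp>(loc, qlOp.getY().getType(),
      a, a_scale, a_zeropoint, b, b_scale, b_zeropoint, out_scale,
      out_zeropoint);

  // After: the builder call routes through createOpAndInferShapes, so shape
  // inference runs as part of creating the recomposed op.
  Value res = create.onnx.qlinearMatMul(qlOp.getY().getType(), a, aScale,
      aZeroPoint, b, bScale, bZeroPoint, outScale, outZeroPoint);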
2 changes: 1 addition & 1 deletion test/mlir/driver/compile_phases.mlir
@@ -1,4 +1,4 @@
-// RUN: onnx-mlir %s | FileCheck %s
+// RUN: onnx-mlir %s -o %t | FileCheck %s && rm %t.so
 
 // CHECK: [1/5] {{.*}} Importing ONNX Model to MLIR Module
 // CHECK: [2/5] {{.*}} Compiling and Optimizing MLIR Module
24 changes: 24 additions & 0 deletions test/mlir/driver/static_quantization.mlir
@@ -0,0 +1,24 @@
+// RUN: onnx-mlir --printIR --EmitONNXIR %s | FileCheck %s
+
+// COM: Check that Dequantize-MatMul-Quantize is always recomposed into QLinearMatMul before the removal of Quantize-Dequantize pairs is applied.
+// COM: Otherwise, the recomposition of QLinearMatMul would fail due to a pattern mismatch (the DequantizeLinear it expects is already gone).
+module {
+func.func @qlinear_matmul(%arg0: tensor<?x?x768xf32>, %arg1: tensor<f32>, %arg2: tensor<i8>, %arg3: tensor<768x768xi8>, %arg4: tensor<f32>, %arg5: tensor<i8>, %arg6: tensor<f32>, %arg7: tensor<i8>) -> (tensor<?x?x768xi8>) {
+  %0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<?x?x768xf32>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
+  %1 = "onnx.DequantizeLinear"(%0, %arg1, %arg2) {axis = 1 : si64} : (tensor<?x?x768xi8>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xf32>
+  %2 = "onnx.DequantizeLinear"(%arg3, %arg4, %arg5) {axis = 1 : si64} : (tensor<768x768xi8>, tensor<f32>, tensor<i8>) -> tensor<768x768xf32>
+  %3 = "onnx.MatMul"(%1, %2) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
+  %4 = "onnx.QuantizeLinear"(%3, %arg6, %arg7) {axis = 1 : si64} : (tensor<?x?x768xf32>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
+  return %4 : tensor<?x?x768xi8>
+
+}
+"onnx.EntryPoint"() {func = @main_graph} : () -> ()
+
+// CHECK-LABEL: func.func @qlinear_matmul
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x768xf32>, [[PARAM_1_:%.+]]: tensor<f32>, [[PARAM_2_:%.+]]: tensor<i8>, [[PARAM_3_:%.+]]: tensor<768x768xi8>, [[PARAM_4_:%.+]]: tensor<f32>, [[PARAM_5_:%.+]]: tensor<i8>, [[PARAM_6_:%.+]]: tensor<f32>, [[PARAM_7_:%.+]]: tensor<i8>) -> tensor<?x?x768xi8> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.QuantizeLinear"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 1 : si64, onnx_node_name = "onnx.QuantizeLinear_0", saturate = 1 : si64} : (tensor<?x?x768xf32>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
+// CHECK: [[VAR_1_:%.+]] = "onnx.QLinearMatMul"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]], [[PARAM_5_]], [[PARAM_6_]], [[PARAM_7_]]) {onnx_node_name = "onnx.QLinearMatMul_1"} : (tensor<?x?x768xi8>, tensor<f32>, tensor<i8>, tensor<768x768xi8>, tensor<f32>, tensor<i8>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
+// CHECK: return [[VAR_1_]] : tensor<?x?x768xi8>
+// CHECK: }
+// CHECK: "onnx.EntryPoint"() {func = @main_graph} : () -> ()
+}
