Skip to content

Commit

Permalink
tanOp added to dqd pass to be reused by quant-to-math
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Jul 26, 2024
1 parent 90bb15a commit 8fe560e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 76 deletions.
29 changes: 6 additions & 23 deletions docs/generated/stablehlo_passes.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,29 +144,12 @@ func.func @add(%arg0: tensor<!quant.uniform<i8:f32,1.0:0>>, %arg1: tensor<!quant
Will become:

```mlir
func.func @add(%arg0: tensor<i8>, %arg1: tensor<i8>) -> tensor<i8> {
%0 = stablehlo.convert %arg0 : (tensor<i8>) -> tensor<f32>
%cst = stablehlo.constant dense<0.333333343> : tensor<f32>
%1 = chlo.broadcast_multiply %0, %cst : (tensor<f32>, tensor<f32>) -> tensor<f32>
%cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
%2 = chlo.broadcast_add %1, %cst_0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%3 = stablehlo.round_nearest_even %2 : tensor<f32>
%4 = stablehlo.convert %3 : (tensor<f32>) -> tensor<i32>
%5 = stablehlo.convert %arg1 : (tensor<i8>) -> tensor<f32>
%cst_1 = stablehlo.constant dense<0.666666686> : tensor<f32>
%6 = chlo.broadcast_multiply %5, %cst_1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%cst_2 = stablehlo.constant dense<1.33333337> : tensor<f32>
%7 = chlo.broadcast_add %6, %cst_2 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%8 = stablehlo.round_nearest_even %7 : tensor<f32>
%9 = stablehlo.convert %8 : (tensor<f32>) -> tensor<i32>
%c = stablehlo.constant dense<2> : tensor<i32>
%10 = chlo.broadcast_add %4, %9 : (tensor<i32>, tensor<i32>) -> tensor<i32>
%11 = chlo.broadcast_subtract %10, %c : (tensor<i32>, tensor<i32>) -> tensor<i32>
%c_3 = stablehlo.constant dense<-128> : tensor<i32>
%c_4 = stablehlo.constant dense<127> : tensor<i32>
%12 = stablehlo.clamp %c_3, %11, %c_4 : tensor<i32>
%13 = stablehlo.convert %12 : (tensor<i32>) -> tensor<i8>
return %13 : tensor<i8>
func.func @add(%arg0: tensor<!quant.uniform<i8:f32, 1.000000e+00>>, %arg1: tensor<!quant.uniform<i8:f32, 2.000000e+00:1>>) -> tensor<!quant.uniform<i8:f32, 3.000000e+00:2>> {
%0 = stablehlo.uniform_dequantize %arg0 : (tensor<!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<f32>
%1 = stablehlo.uniform_dequantize %arg1 : (tensor<!quant.uniform<i8:f32, 2.000000e+00:1>>) -> tensor<f32>
%2 = stablehlo.add %0, %1 : tensor<f32>
%3 = stablehlo.uniform_quantize %2 : (tensor<f32>) -> tensor<!quant.uniform<i8:f32, 3.000000e+00:2>>
return %3 : tensor<!quant.uniform<i8:f32, 3.000000e+00:2>>
}
```
### `-stablehlo-legalize-to-vhlo`
Expand Down
52 changes: 0 additions & 52 deletions stablehlo/transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -185,58 +185,6 @@ def StablehloLegalizeQuantToMathPass : Pass<"stablehlo-legalize-quant-to-math",
let description = [{
Convert StableHLO programs using UniformQuantized types to semantically
equivalent integer math operations.
<<<<<<< HEAD

```mlir
func.func @add(%arg0: tensor<!quant.uniform<i8:f32,1.0:0>>, %arg1: tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>> {
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32,1.0:0>>, tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>>
func.return %0 : tensor<!quant.uniform<i8:f32,3.0:2>>
}
```

Will become:

```mlir
func.func @add(%arg0: tensor<i8>, %arg1: tensor<i8>) -> tensor<i8> {
%0 = stablehlo.convert %arg0 : (tensor<i8>) -> tensor<f32>
%cst = stablehlo.constant dense<0.333333343> : tensor<f32>
%1 = chlo.broadcast_multiply %0, %cst : (tensor<f32>, tensor<f32>) -> tensor<f32>
%cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
%2 = chlo.broadcast_add %1, %cst_0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%3 = stablehlo.round_nearest_even %2 : tensor<f32>
%4 = stablehlo.convert %3 : (tensor<f32>) -> tensor<i32>
%5 = stablehlo.convert %arg1 : (tensor<i8>) -> tensor<f32>
%cst_1 = stablehlo.constant dense<0.666666686> : tensor<f32>
%6 = chlo.broadcast_multiply %5, %cst_1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%cst_2 = stablehlo.constant dense<1.33333337> : tensor<f32>
%7 = chlo.broadcast_add %6, %cst_2 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%8 = stablehlo.round_nearest_even %7 : tensor<f32>
%9 = stablehlo.convert %8 : (tensor<f32>) -> tensor<i32>
%c = stablehlo.constant dense<2> : tensor<i32>
%10 = chlo.broadcast_add %4, %9 : (tensor<i32>, tensor<i32>) -> tensor<i32>
%11 = chlo.broadcast_subtract %10, %c : (tensor<i32>, tensor<i32>) -> tensor<i32>
%c_3 = stablehlo.constant dense<-128> : tensor<i32>
%c_4 = stablehlo.constant dense<127> : tensor<i32>
%12 = stablehlo.clamp %c_3, %11, %c_4 : tensor<i32>
%13 = stablehlo.convert %12 : (tensor<i32>) -> tensor<i8>
return %13 : tensor<i8>
}
```
}];
let dependentDialects = [
"mlir::chlo::ChloDialect",
"mlir::stablehlo::StablehloDialect",
];
}

def StablehloLegalizeQuantizedOpToQDQPass : Pass<"stablehlo-legalize-quantized-op-to-qdq", "mlir::func::FuncOp"> {
let summary = "Decompose StableHLO quantized ops using uniform quantize/dequantize ops.";

let description = [{
Decompose StableHLO quantized programs using uniform quantize/dequantize
operations. For example, the following program
=======
>>>>>>> e4fee2a3 (fix tests)

```mlir
func.func @add(%arg0: tensor<!quant.uniform<i8:f32,1.0:0>>, %arg1: tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>> {
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ void populateStablehloLegalizeQuantizedOpToQDQPatterns(
CosineOp, DivOp, Expm1Op, ExpOp, FloorOp, Log1pOp, LogisticOp, LogOp,
MaxOp, MinOp, MulOp, NegOp, PowOp, ReducePrecisionOp, RemOp, RoundOp,
RoundNearestEvenOp, RsqrtOp, SelectOp, SignOp, SineOp, SqrtOp, SubtractOp,
TanhOp, TriangularSolveOp>(patterns, context, benefit);
TanhOp, TanOp, TriangularSolveOp>(patterns, context, benefit);
}

} // namespace stablehlo
Expand Down

0 comments on commit 8fe560e

Please sign in to comment.