Merge qdq and quant-to-int passes (#2458)
This PR does the following:
1. Merges `stablehlo-legalize-quantized-op-to-qdq` into
`stablehlo-legalize-quant-to-int`.
1. Renames `stablehlo-legalize-quant-to-int` to
`stablehlo-legalize-quant-to-math`. This clarifies the scenario where
the fallback QDQ path is used and a full integer quantized program
cannot be generated.
1. Removes the `stablehlo-legalize-quantized-op-to-qdq` pass and replaces its
uses with `stablehlo-legalize-quant-to-math`.
1. Removes QDQ lit checks from
`stablehlo/tests/ops_stablehlo_quantized.mlir` and merges the tests
added for the QDQ pass into
`stablehlo/tests/stablehlo_legalize_quant_to_int.mlir`.
1. Updates the tests in
`stablehlo/tests/stablehlo_legalize_quant_to_int.mlir`, modifying __only__
the negative tests, which were previously unhandled by
`stablehlo-legalize-quant-to-int`. The current
`stablehlo-legalize-quant-to-math` uses the fallback to handle these
cases.
1. About the pass `stablehlo-legalize-quant-to-math`:
- It uses `PatternBenefit` to assign the highest priority (`benefit=10`) to
patterns which have specialized handling in
`stablehlo-legalize-quant-to-int`. Next in priority (`benefit=1`) are
the QDQ patterns.
 
With that, the following program, for which `stablehlo-legalize-quant-to-int`
has specialized handling, avoids the fallback path.
```mlir
func.func @max_per_tensor_same_quant_parameters(
    %arg0: tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>>
  ) -> tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>> {
  %0 = "stablehlo.maximum"(%arg0, %arg0) : (
    tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>>,
    tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>>
  ) -> tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>>
  return %0 : tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>>
}
```
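A quick way to see why the specialized path is safe here: with identical per-tensor quantization parameters and a positive scale, dequantization is a monotonically increasing map, so `maximum` commutes with it and can be computed directly on the stored integers. A small Python sketch of that argument (illustrative only, not part of this PR; the helper names are hypothetical):

```python
# Dequantization q -> scale * (q - zero_point) with scale > 0 is monotone,
# so max commutes with it; no rescaling of the i8 storage values is needed.
def dequantize(q, scale, zero_point):
    return scale * (q - zero_point)

scale, zp = 2.0, 3  # parameters from the example above: 2.000000e+00:3
lhs, rhs = [-5, 7], [4, -1]

# Reference semantics: dequantize, take the float max, then requantize.
ref = [round(max(dequantize(a, scale, zp), dequantize(b, scale, zp)) / scale + zp)
       for a, b in zip(lhs, rhs)]

# Specialized lowering: max on the raw storage integers, no rescaling.
fast = [max(a, b) for a, b in zip(lhs, rhs)]

assert ref == fast
```

The same reasoning applies to other monotone element-wise ops such as `minimum` when all operands and results share one set of parameters.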

whereas the following, which is not supported in
`stablehlo-legalize-quant-to-int`, takes the fallback path.
```mlir
func.func @max_per_tensor_diff_quant_parameters(%arg0: tensor<!quant.uniform<i8:f32,1.0:0>>, %arg1: tensor<!quant.uniform<i8:f32,2.0:1>>) ->  tensor<!quant.uniform<i8:f32,3.0:2>> {
  %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32,1.0:0>>, tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>>
  func.return %0 : tensor<!quant.uniform<i8:f32,3.0:2>>
}
```
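For this case the pass falls back to dequantize-op-quantize semantics: each operand is dequantized with its own parameters, the op runs in float, and the result is requantized with the result's parameters. A Python sketch of what the fallback computes (illustrative only; the helper names are hypothetical, not StableHLO APIs):

```python
def dequantize(q, scale, zero_point):
    # Storage integer -> real value.
    return scale * (q - zero_point)

def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # Real value -> storage integer, clamped to the i8 storage range.
    return max(qmin, min(qmax, round(x / scale + zero_point)))

# Parameters from the example above: lhs 1.0:0, rhs 2.0:1, result 3.0:2.
lhs = dequantize(10, scale=1.0, zero_point=0)   # -> 10.0
rhs = dequantize(10, scale=2.0, zero_point=1)   # -> 18.0
out = quantize(max(lhs, rhs), scale=3.0, zero_point=2)
assert out == 8  # round(18.0 / 3.0) + 2
```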


- Currently the pass handles the QDQ fallback for `AddOp` and a set of generic ops
[cs](https://github.com/openxla/stablehlo/blob/eba821aa1c54a21d70331d7926dfc8b929f988f3/stablehlo/transforms/StablehloLegalizeQuantToInt.cpp#L1239).
The QDQ fallback for `dot_general` and `convolution` will be handled in a
follow-up PR. This means quantized `dot_general`/`convolution` programs,
which are currently unsupported by `stablehlo-legalize-quant-to-int`, will
still error out.
 

[childPR](#2459)
sdasgup3 authored Jul 26, 2024
1 parent 0bfc536 commit 2980259
Showing 11 changed files with 1,189 additions and 292 deletions.
2 changes: 1 addition & 1 deletion BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1034,7 +1034,7 @@ cc_library(
"stablehlo/transforms/StablehloConvertToSignless.cpp",
"stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp",
"stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp",
"stablehlo/transforms/StablehloLegalizeQuantToInt.cpp",
"stablehlo/transforms/StablehloLegalizeQuantToMath.cpp",
"stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp",
"stablehlo/transforms/StablehloLegalizeToVhlo.cpp",
"stablehlo/transforms/StablehloRefineArguments.cpp",
42 changes: 39 additions & 3 deletions docs/generated/stablehlo_passes.md
@@ -85,12 +85,48 @@ long-term supported counterparts.
```
-fail-on-unused : Fail on (mostly) unused ops that are deprecated without any fallback.
```
### `-stablehlo-legalize-quant-to-int`
### `-stablehlo-legalize-quant-to-math`

_Convert from StableHLO quantized ops to StableHLO primitive ops._
_Convert from StableHLO quantized ops to StableHLO primitive math ops._

Convert StableHLO programs using UniformQuantized types to semantically
equivalent integer math.
equivalent integer math operations.

```mlir
func.func @add(%arg0: tensor<!quant.uniform<i8:f32,1.0:0>>, %arg1: tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>> {
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32,1.0:0>>, tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>>
func.return %0 : tensor<!quant.uniform<i8:f32,3.0:2>>
}
```

Will become:

```mlir
func.func @add(%arg0: tensor<i8>, %arg1: tensor<i8>) -> tensor<i8> {
%0 = stablehlo.convert %arg0 : (tensor<i8>) -> tensor<f32>
%cst = stablehlo.constant dense<0.333333343> : tensor<f32>
%1 = chlo.broadcast_multiply %0, %cst : (tensor<f32>, tensor<f32>) -> tensor<f32>
%cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
%2 = chlo.broadcast_add %1, %cst_0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%3 = stablehlo.round_nearest_even %2 : tensor<f32>
%4 = stablehlo.convert %3 : (tensor<f32>) -> tensor<i32>
%5 = stablehlo.convert %arg1 : (tensor<i8>) -> tensor<f32>
%cst_1 = stablehlo.constant dense<0.666666686> : tensor<f32>
%6 = chlo.broadcast_multiply %5, %cst_1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%cst_2 = stablehlo.constant dense<1.33333337> : tensor<f32>
%7 = chlo.broadcast_add %6, %cst_2 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%8 = stablehlo.round_nearest_even %7 : tensor<f32>
%9 = stablehlo.convert %8 : (tensor<f32>) -> tensor<i32>
%c = stablehlo.constant dense<2> : tensor<i32>
%10 = chlo.broadcast_add %4, %9 : (tensor<i32>, tensor<i32>) -> tensor<i32>
%11 = chlo.broadcast_subtract %10, %c : (tensor<i32>, tensor<i32>) -> tensor<i32>
%c_3 = stablehlo.constant dense<-128> : tensor<i32>
%c_4 = stablehlo.constant dense<127> : tensor<i32>
%12 = stablehlo.clamp %c_3, %11, %c_4 : tensor<i32>
%13 = stablehlo.convert %12 : (tensor<i32>) -> tensor<i8>
return %13 : tensor<i8>
}
```
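As a sanity check, the lowered integer sequence above can be replayed in plain Python and compared against the reference dequantize/add/quantize semantics. This sketch is illustrative only: the constants are taken from the IR above, and a one-unit tolerance accounts for the lowered form performing two separate roundings where the reference performs one.

```python
def lowered_add(a, b):
    # Mirrors the IR: each operand is rescaled into the result's scale with
    # the result zero point folded in, then combined with a correction term.
    p0 = round(a * 0.333333343 + 2.0)         # lhs * (1.0/3.0) + z_res
    p1 = round(b * 0.666666686 + 1.33333337)  # rhs * (2.0/3.0) + z_res - z_rhs * (2.0/3.0)
    r = p0 + p1 - 2                           # subtract the doubled z_res
    return max(-128, min(127, r))             # clamp to i8 storage range

def reference_add(a, b):
    x = 1.0 * (a - 0) + 2.0 * (b - 1)         # dequantize both operands
    return max(-128, min(127, round(x / 3.0 + 2)))  # requantize: scale 3.0, zp 2

for a in range(-128, 128, 17):
    for b in range(-128, 128, 13):
        # The two formulations round at different points, so allow a
        # difference of at most one unit in the last place.
        assert abs(lowered_add(a, b) - reference_add(a, b)) <= 1
```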
### `-stablehlo-legalize-quantized-op-to-qdq`

_Decompose StableHLO quantized ops using uniform quantize/dequantize ops._
269 changes: 30 additions & 239 deletions stablehlo/tests/ops_stablehlo_quantized.mlir

Large diffs are not rendered by default.

463 changes: 452 additions & 11 deletions stablehlo/tests/stablehlo_legalize_quant_to_int.mlir

Large diffs are not rendered by default.

592 changes: 592 additions & 0 deletions stablehlo/tests/stablehlo_legalize_quantized_op_to_qdq.mlir

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion stablehlo/transforms/CMakeLists.txt
@@ -39,7 +39,7 @@ add_mlir_dialect_library(StablehloPasses
StablehloConvertToSignless.cpp
StablehloLegalizeCompositeToCall.cpp
StablehloLegalizeDeprecatedOps.cpp
StablehloLegalizeQuantToInt.cpp
StablehloLegalizeQuantToMath.cpp
StablehloLegalizeQuantizedOpToQDQ.cpp
StablehloLegalizeToVhlo.cpp
StablehloRefineArguments.cpp
4 changes: 1 addition & 3 deletions stablehlo/transforms/PassPipelines.cpp
@@ -47,9 +47,7 @@ void createStablehloRemoveDynamismPipeline(OpPassManager &pm,

void createStablehloLowerQuantPipeline(OpPassManager &pm) {
pm.addNestedPass<mlir::func::FuncOp>(
stablehlo::createStablehloLegalizeQuantizedOpToQDQPass());
pm.addNestedPass<mlir::func::FuncOp>(
stablehlo::createStablehloLegalizeQuantToIntPass());
stablehlo::createStablehloLegalizeQuantToMathPass());
pm.addNestedPass<mlir::func::FuncOp>(
stablehlo::createChloLegalizeToStablehloPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
3 changes: 2 additions & 1 deletion stablehlo/transforms/Passes.h
@@ -69,7 +69,8 @@ void populateStablehloAggressiveFolderPatterns(RewritePatternSet *patterns,
/// Collection of rewrite patterns for lowering quantized StableHLO operations
/// using uniform dequantize/quantize operations.
void populateStablehloLegalizeQuantizedOpToQDQPatterns(
RewritePatternSet *patterns, MLIRContext *context);
RewritePatternSet *patterns, MLIRContext *context,
PatternBenefit benefit = 1);

/// A subset of folding patterns for StableHLO that is necessary for shape
/// refinement.
42 changes: 39 additions & 3 deletions stablehlo/transforms/Passes.td
@@ -179,12 +179,48 @@ def StablehloLegalizeCompositeToCallPass :
];
}

def StablehloLegalizeQuantToIntPass : Pass<"stablehlo-legalize-quant-to-int", "mlir::func::FuncOp"> {
let summary = "Convert from StableHLO quantized ops to StableHLO primitive ops.";
def StablehloLegalizeQuantToMathPass : Pass<"stablehlo-legalize-quant-to-math", "mlir::func::FuncOp"> {
let summary = "Convert from StableHLO quantized ops to StableHLO primitive math ops.";

let description = [{
Convert StableHLO programs using UniformQuantized types to semantically
equivalent integer math.
equivalent integer math operations.

```mlir
func.func @add(%arg0: tensor<!quant.uniform<i8:f32,1.0:0>>, %arg1: tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>> {
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32,1.0:0>>, tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>>
func.return %0 : tensor<!quant.uniform<i8:f32,3.0:2>>
}
```

Will become:

```mlir
func.func @add(%arg0: tensor<i8>, %arg1: tensor<i8>) -> tensor<i8> {
%0 = stablehlo.convert %arg0 : (tensor<i8>) -> tensor<f32>
%cst = stablehlo.constant dense<0.333333343> : tensor<f32>
%1 = chlo.broadcast_multiply %0, %cst : (tensor<f32>, tensor<f32>) -> tensor<f32>
%cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
%2 = chlo.broadcast_add %1, %cst_0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%3 = stablehlo.round_nearest_even %2 : tensor<f32>
%4 = stablehlo.convert %3 : (tensor<f32>) -> tensor<i32>
%5 = stablehlo.convert %arg1 : (tensor<i8>) -> tensor<f32>
%cst_1 = stablehlo.constant dense<0.666666686> : tensor<f32>
%6 = chlo.broadcast_multiply %5, %cst_1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%cst_2 = stablehlo.constant dense<1.33333337> : tensor<f32>
%7 = chlo.broadcast_add %6, %cst_2 : (tensor<f32>, tensor<f32>) -> tensor<f32>
%8 = stablehlo.round_nearest_even %7 : tensor<f32>
%9 = stablehlo.convert %8 : (tensor<f32>) -> tensor<i32>
%c = stablehlo.constant dense<2> : tensor<i32>
%10 = chlo.broadcast_add %4, %9 : (tensor<i32>, tensor<i32>) -> tensor<i32>
%11 = chlo.broadcast_subtract %10, %c : (tensor<i32>, tensor<i32>) -> tensor<i32>
%c_3 = stablehlo.constant dense<-128> : tensor<i32>
%c_4 = stablehlo.constant dense<127> : tensor<i32>
%12 = stablehlo.clamp %c_3, %11, %c_4 : tensor<i32>
%13 = stablehlo.convert %12 : (tensor<i32>) -> tensor<i8>
return %13 : tensor<i8>
}
```
}];
let dependentDialects = [
"mlir::chlo::ChloDialect",
@@ -44,6 +44,7 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/transforms/Passes.h"

namespace mlir::stablehlo {
namespace {
@@ -426,10 +427,10 @@ class ConvertUniformQuantizedAddOp
// We only handle cases where lhs, rhs and results all have quantized
// element type.
if (failed(lhsQuantType) || failed(rhsQuantType) || failed(resQuantType)) {
op->emitError(
return rewriter.notifyMatchFailure(
op,
"AddOp requires the quantized element type for all operands and "
"results");
return failure();
}

if (isPerAxisType(*lhsQuantType) || isPerAxisType(*rhsQuantType) ||
@@ -440,16 +441,16 @@
!isPerAxisType(*resQuantType) ||
getPerAxisType(*lhsQuantType) != getPerAxisType(*rhsQuantType) ||
getPerAxisType(*lhsQuantType) != getPerAxisType(*resQuantType)) {
op->emitError(
return rewriter.notifyMatchFailure(
op,
"Per-axis quantized AddOp requires the same quantized element "
"type for all operands and results");
return failure();
}
if (!getPerAxisType(*lhsQuantType).getStorageType().isInteger(32)) {
// For server-side StableHLO Quantization, add is quantized only when
// fused with conv/dot ops, whose output must be i32.
op->emitError("Per-axis quantized AddOp requires i32 storage type");
return failure();
return rewriter.notifyMatchFailure(
op, "Per-axis quantized AddOp requires i32 storage type");
}
return matchAndRewritePerAxis(op, adaptor, rewriter,
getPerAxisType(*lhsQuantType));
@@ -1229,8 +1230,9 @@ class ConvertUniformQuantizedConvolutionOp
// TODO: b/310685906 - Add operand/result type validations.
class ConvertGenericOp : public ConversionPattern {
public:
explicit ConvertGenericOp(MLIRContext *ctx, TypeConverter &converter)
: ConversionPattern(converter, MatchAnyOpTypeTag(), 1, ctx) {}
explicit ConvertGenericOp(MLIRContext *ctx, TypeConverter &converter,
PatternBenefit benefit)
: ConversionPattern(converter, MatchAnyOpTypeTag(), benefit, ctx) {}

LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
@@ -1244,7 +1246,8 @@ class ConvertGenericOp : public ConversionPattern {
stablehlo::ReturnOp, stablehlo::SelectOp, stablehlo::SliceOp,
stablehlo::TransposeOp, stablehlo::GetDimensionSizeOp,
stablehlo::DynamicBroadcastInDimOp>(op)) {
return failure();
return rewriter.notifyMatchFailure(
op, "Unsupported op for performing type change");
}

if (isa<stablehlo::MinOp, stablehlo::MaxOp>(op)) {
@@ -1253,10 +1256,10 @@
auto rhsType = getPerTensorType(op->getOperandTypes()[1]);
auto resultType = getPerTensorType(op->getResultTypes()[0]);
if (lhsType != rhsType || lhsType != resultType) {
return op->emitError(
op->getName().getStringRef() +
" with different quantization parameters for operands and"
" results is not supported.");
return rewriter.notifyMatchFailure(
op, op->getName().getStringRef() +
" with different quantization parameters for operands and"
" results is not supported.");
}
}

@@ -1294,12 +1297,12 @@ class UniformQuantizedToIntTypeConverter : public TypeConverter {

} // namespace

#define GEN_PASS_DEF_STABLEHLOLEGALIZEQUANTTOINTPASS
#define GEN_PASS_DEF_STABLEHLOLEGALIZEQUANTTOMATHPASS
#include "stablehlo/transforms/Passes.h.inc"

class StablehloLegalizeQuantToIntPass
: public impl::StablehloLegalizeQuantToIntPassBase<
StablehloLegalizeQuantToIntPass> {
class StablehloLegalizeQuantToMathPass
: public impl::StablehloLegalizeQuantToMathPassBase<
StablehloLegalizeQuantToMathPass> {
public:
// Performs conversion of stablehlo quant ops to primitive ops.
void runOnOperation() override {
@@ -1311,11 +1314,15 @@ class StablehloLegalizeQuantToIntPass
patterns.add<ConvertUniformQuantizeOp, ConvertUniformDequantizeOp,
ConvertUniformQuantizedAddOp, ConvertUniformQuantizedDotOp,
ConvertUniformQuantizedDotGeneralOp,
ConvertUniformQuantizedConvolutionOp>(context);
ConvertUniformQuantizedConvolutionOp>(context, /*benefit=*/10);

// Populate stablehlo quant-op to dq-op-q patterns as fallback.
populateStablehloLegalizeQuantizedOpToQDQPatterns(&patterns, context,
/*benefit=*/1);

// uq->int convert patterns for func.func, func.return and generic ops.
UniformQuantizedToIntTypeConverter converter;
patterns.add<ConvertGenericOp>(context, converter);
patterns.add<ConvertGenericOp>(context, converter, /*benefit=*/10);
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
converter);
populateReturnOpTypeConversionPattern(patterns, converter);
17 changes: 6 additions & 11 deletions stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp
@@ -109,27 +109,22 @@ class StablehloLegalizeQuantizedOpToQDQPass

template <typename... StablehloOpTypes>
void populateStablehloLegalizeQuantizedOpToQDQPatterns(
RewritePatternSet* patterns, MLIRContext* context) {
patterns->add<QuantizedStablehloOpConversion<StablehloOpTypes>...>(context);
RewritePatternSet* patterns, MLIRContext* context, PatternBenefit benefit) {
patterns->add<QuantizedStablehloOpConversion<StablehloOpTypes>...>(context,
benefit);
}

} // namespace

void populateStablehloLegalizeQuantizedOpToQDQPatterns(
RewritePatternSet* patterns, MLIRContext* context) {
// The following list covers most of the operations which, according to the
// stablehlo spoecification document, interprets the quantized
// operation using dequant-op-quant strategy. The ones excluded are
// AddOP, ConvolutionOp, DotGeneralOp, and DynamicConvOp, which are current
// using `stablehlo-legalize-quant-to-int` pass for decomposituion to
// primitive math operations.
RewritePatternSet* patterns, MLIRContext* context, PatternBenefit benefit) {
populateStablehloLegalizeQuantizedOpToQDQPatterns<
AbsOp, Atan2Op, BatchNormGradOp, BatchNormInferenceOp,
AbsOp, AddOp, Atan2Op, BatchNormGradOp, BatchNormInferenceOp,
BatchNormTrainingOp, CbrtOp, CeilOp, CholeskyOp, ClampOp, CompareOp,
CosineOp, DivOp, Expm1Op, ExpOp, FloorOp, Log1pOp, LogisticOp, LogOp,
MaxOp, MinOp, MulOp, NegOp, PowOp, ReducePrecisionOp, RemOp, RoundOp,
RoundNearestEvenOp, RsqrtOp, SelectOp, SignOp, SineOp, SqrtOp, SubtractOp,
TanhOp, TriangularSolveOp>(patterns, context);
TanhOp, TriangularSolveOp>(patterns, context, benefit);
}

} // namespace stablehlo
