Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support qdq decomposition of TanOp #2459

Merged
merged 6 commits into from
Jul 26, 2024

Conversation

sdasgup3
Copy link
Member

@sdasgup3 sdasgup3 commented Jul 24, 2024

parent PR

The PR adds patterns for qdq decomposition patterns for newly added TanOp

Note to the reviewers: To may just focus on the very last commit of the chain. The rest is coming from parent PR.

childPR

@sdasgup3 sdasgup3 force-pushed the support-qdq-decomp-tanop branch from ce20302 to 6591cbd Compare July 25, 2024 00:02
sdasgup3 added a commit that referenced this pull request Jul 26, 2024
This PR does the followings:
1. Merges `stablehlo-legalize-quantized-op-to-qdq` into
`stablehlo-legalize-quant-to-int`.
1. Rename `stablehlo-legalize-quant-to-int` to
`stablehlo-legalize-quant-to-math`. This is to clarify for scenario when
the fallback `qdq` is used and sull integer quantized program cannot be
generated.
1. Removes `stablehlo-legalize-quantized-op-to-qdq` pass and replace its
uses with `stablehlo-legalize-quant-to-math`.
1. Remove QDQ lit checks from
`stablehlo/tests/ops_stablehlo_quantized.mlir` and merges the tests
added for qdq pass in
`stablehlo/tests/stablehlo_legalize_quant_to_int.mlir`
1. Updates the tests in
`stablehlo/tests/stablehlo_legalize_quant_to_int.mlir` updating __only__
negatives tests, which are previously unhanded by
`stablehlo-legalize-quant-to-int`. The current
`stablehlo-legalize-quant-to-math` uses the fallback to handle these
cases.
1. About the pass    `stablehlo-legalize-quant-to-math`
- It uses `Patternbenefit` to assign highest priority (`benefit=10`) to
pattern which has specialized handling in
`stablehlo-legalize-quant-to-int`. Next in priority (`benefit=0`) are
the QDQ patterns.
 
With that the following program, which `stablehlo-legalize-quant-to-int`
has specialized handling, will avoid the fallback path.
```
func.func @max_per_tensor_same_quant_parameters(
    %arg0: tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>>
  ) -> tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>> {
  %0 = "stablehlo.maximum"(%arg0, %arg0) : (
    tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>>,
    tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>>
  ) -> tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>>
  return %0 : tensor<1x2x!quant.uniform<i8:f32, 2.000000e+00:3>>
}
```

whereas the following, which is not supported in
`stablehlo-legalize-quant-to-int` will choose the fallback path.
```
func.func @max_per_tensor_diff_quant_parameters(%arg0: tensor<!quant.uniform<i8:f32,1.0:0>>, %arg1: tensor<!quant.uniform<i8:f32,2.0:1>>) ->  tensor<!quant.uniform<i8:f32,3.0:2>> {
  %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32,1.0:0>>, tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>>
  func.return %0 : tensor<!quant.uniform<i8:f32,3.0:2>>
}
```


- Currently handles qdq fallback for AddOp and a bunch of `GenericOps`
op
[cs](https://github.com/openxla/stablehlo/blob/eba821aa1c54a21d70331d7926dfc8b929f988f3/stablehlo/transforms/StablehloLegalizeQuantToInt.cpp#L1239).
qdq fallback for `dot_general` and `convolution` will be handled in a
follow up PR. What this means is we will still see quantized
dot_gneral/convolution program, which are currently unsupported by
`stablehlo-legalize-quant-to-int`, error out.
 

[childPR](#2459)
@sdasgup3 sdasgup3 force-pushed the support-qdq-decomp-tanop branch from 6591cbd to 8fe560e Compare July 26, 2024 01:42
@sdasgup3 sdasgup3 merged commit 4286b80 into openxla:main Jul 26, 2024
10 checks passed
sdasgup3 added a commit that referenced this pull request Jul 26, 2024
[ParentPR](#2459)

Previously while creating the QDQ pattern we use `create` API without
using the result type and hence reply on the type inference to derive
the result type. That works for most of the element-wise operations,
however, for dot_general and convolution the result type can might be
infeasible to infer in the presense of input quantize types.

The PR fixes that.

Note to the reviewers: To may just focus on the very last commit of the
chain. The rest is coming from parent PR.

[ChildPR](#2461)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants