forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Relay][QNN] Simulated Quantize and Dequantize (apache#7613)
* Add initial implementation of flexible simulated qnn ops. * Added proper topi testing and fixed qnn axis bug. * Add injective schedule wrapping. * Stuck on typerel problem. * Relay integration fully working. * Simulated quantize totally finished. * Change dtype to be a scalar rather than tensor. * Undo change to quantize. * formatting. * Fix attritubes. * Fix negative axis dequantize bug. * Add topi simulated dequantize. * Add simulated_dequantize op to topi and relay. * Formatting. * Test negative axis perchannel dequantization. * Lint formatting. * Change import order to make lint happy. * Fix pytest. * Directly return make call. * Clarify disable mode for simulated qnn ops and fix typos. * Line too long oops. Co-authored-by: Ubuntu <jwfromm@jwfromm-cpu-dev.itxhlkosmouevgkdrmwxfbs5qh.xx.internal.cloudapp.net>
- Loading branch information
Showing
15 changed files
with
1,019 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=invalid-name, unused-argument, len-as-condition | ||
"""QNN operator feature registration""" | ||
|
||
from tvm import topi | ||
|
||
from ...op.op import register_compute | ||
from ...op.op import register_injective_schedule | ||
from ...op.op import register_pattern, OpPattern | ||
|
||
|
||
@register_compute("qnn.simulated_quantize") | ||
def simulated_quantize_compute(attrs, inputs, output_type): | ||
assert len(inputs) == 4 | ||
return [ | ||
topi.nn.simulated_quantize( | ||
inputs[0], inputs[1], inputs[2], inputs[3], axis=attrs.get_int("axis") | ||
) | ||
] | ||
|
||
|
||
register_injective_schedule("qnn.simulated_quantize") | ||
register_pattern("qnn.simulated_quantize", OpPattern.ELEMWISE) | ||
|
||
|
||
@register_compute("qnn.simulated_dequantize") | ||
def simulated_dequantize_compute(attrs, inputs, output_type): | ||
assert len(inputs) == 4 | ||
return [ | ||
topi.nn.simulated_dequantize( | ||
inputs[0], inputs[1], inputs[2], inputs[3], axis=attrs.get_int("axis") | ||
) | ||
] | ||
|
||
|
||
register_injective_schedule("qnn.simulated_dequantize") | ||
register_pattern("qnn.simulated_dequantize", OpPattern.ELEMWISE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Quantized Neural Network (QNN) Operators""" | ||
import tvm | ||
from tvm import te, tir, topi | ||
|
||
SQNN_DISABLE = 0 | ||
SQNN_INT8 = 1 | ||
SQNN_UINT8 = 2 | ||
SQNN_INT32 = 3 | ||
|
||
SQNN_DTYPE_TO_CODE = { | ||
"disable": SQNN_DISABLE, | ||
"int8": SQNN_INT8, | ||
"uint8": SQNN_UINT8, | ||
"int32": SQNN_INT32, | ||
} | ||
|
||
SQNN_CODE_TO_DTYPE = {v: k for k, v in SQNN_DTYPE_TO_CODE.items()} | ||
|
||
|
||
@tvm.te.tag_scope(tag=topi.tag.ELEMWISE) | ||
def simulated_quantize(data, out_dtype, output_scale=None, output_zero_point=None, axis=-1): | ||
"""Simulated QNN quantize operator that mimics QNN outputs without changing datatype. | ||
The benefit of this operator over true QNN quantize is that this operator allows dynamic | ||
datatype selection and can operate on both per-channel and scalar scales and zero points while | ||
QNN quantize requires both of these to be fixed at compile time. | ||
Parameters | ||
---------- | ||
data: tvm.te.Tensor | ||
An N-D input tensor to the operator. | ||
out_dtype: tvm.te.Tensor | ||
A scalar variable that indicates which datatype to simulate quantization with. Use | ||
SQNN_DTYPE_TO_CODE to convert a dtype string into the corresponding variable | ||
value. | ||
output_scale: tvm.te.Tensor, optional | ||
A scalar tensor representing the scale to use when quantizing to integer datatypes. | ||
When it contains more than a single value, N must match the number of channels in data. | ||
output_zero_point: tvm.te.Tensor, optional | ||
A 1-D tensor representing the zero point to use when quantizing to integer datatypes. | ||
When it contains more than a single value, N must match the number of channels in data. | ||
axis: int, optional | ||
The channel axis for quantization. Default value is -1 which corresponds to the last axis. | ||
""" | ||
# When disabled, just pass through the input values. | ||
def _compute_pass_through(value, *indices): | ||
return value[indices] | ||
|
||
# Simulate quantization for arbitrary integer datatypes. The computation for all datatypes is: | ||
# Q_output = clip((round(input_tensor/output_scale) + output_zero_point), | ||
# out_dtype::min, | ||
# out_dtype::max) | ||
def _compute_intn(dtype, value, *indices): | ||
assert output_scale is not None and output_zero_point is not None | ||
const_min = tvm.tir.min_value(dtype) | ||
const_max = tvm.tir.max_value(dtype) | ||
# Use indexmod to handle both scalar and per-channel QNN parameters. | ||
scale_idx = tir.indexmod(indices[axis], topi.shape(output_scale)[0]) | ||
zp_idx = tir.indexmod(indices[axis], topi.shape(output_zero_point)[0]) | ||
return te.max( | ||
te.min( | ||
te.round(value[indices] / output_scale[scale_idx]) + output_zero_point[zp_idx], | ||
const_max, | ||
), | ||
const_min, | ||
) | ||
|
||
# Use an if chain to dynamically return the proper quantization based on the input datatype. | ||
# This allows the op to compile once but apply different quantization approaches | ||
# using a variable datatype input. | ||
def _dispatch_sim_quantize(value): | ||
pass_through_value = te.compute( | ||
data.shape, lambda *indices: _compute_pass_through(value, *indices) | ||
) | ||
int8_value = te.compute( | ||
data.shape, | ||
lambda *indices: tir.if_then_else( | ||
out_dtype.equal(SQNN_DTYPE_TO_CODE["int8"]), | ||
_compute_intn("int8", value, *indices), | ||
pass_through_value[indices], | ||
), | ||
) | ||
uint8_value = te.compute( | ||
data.shape, | ||
lambda *indices: tir.if_then_else( | ||
out_dtype.equal(SQNN_DTYPE_TO_CODE["uint8"]), | ||
_compute_intn("uint8", value, *indices), | ||
int8_value[indices], | ||
), | ||
) | ||
int32_value = te.compute( | ||
data.shape, | ||
lambda *indices: tir.if_then_else( | ||
out_dtype.equal(SQNN_DTYPE_TO_CODE["int32"]), | ||
_compute_intn("int32", value, *indices), | ||
uint8_value[indices], | ||
), | ||
) | ||
|
||
return int32_value | ||
|
||
return te.compute(data.shape, lambda *indices: _dispatch_sim_quantize(data)[indices]) | ||
|
||
|
||
@tvm.te.tag_scope(tag=topi.tag.ELEMWISE) | ||
def simulated_dequantize(data, in_dtype, input_scale=None, input_zero_point=None, axis=-1): | ||
"""Simulated QNN dequantize operator that mimics QNN outputs without changing datatype. | ||
The benefit of this operator over true QNN dequantize is that this operator allows dynamic | ||
datatype selection and can operate on both per-channel and scalar scales and zero points while | ||
QNN dequantize requires both of these to be fixed at compile time. | ||
Parameters | ||
---------- | ||
data: tvm.te.Tensor | ||
An N-D input tensor to the operator. | ||
in_dtype: tvm.te.Tensor | ||
A scalar variable that indicates which datatype to simulate dequantization with. Use | ||
SQNN_DTYPE_TO_CODE to convert a dtype string into the corresponding variable | ||
value. | ||
input_scale: tvm.te.Tensor, optional | ||
A scalar tensor representing the scale to use when dequantizing from integer datatypes. | ||
When it contains more than a single value, N must match the number of channels in data. | ||
input_zero_point: tvm.te.Tensor, optional | ||
A 1-D tensor representing the zero point to use when dequantizing from integer datatypes. | ||
When it contains more than a single value, N must match the number of channels in data. | ||
axis: int, optional | ||
The channel axis for quantization. Default value is -1 which corresponds to the last axis. | ||
""" | ||
# When disabled simply return the input tensor. | ||
def _compute_pass_through(value, *indices): | ||
return value[indices] | ||
|
||
# Simulate dequantization for arbitrary integer datatypes. The computation for all datatypes is: | ||
# DQ_output = (input - zero_point) * scale | ||
def _compute_intn(value, *indices): | ||
assert input_scale is not None and input_zero_point is not None | ||
# Use indexmod to handle both scalar and per-channel QNN parameters. | ||
scale_idx = tir.indexmod(indices[axis], topi.shape(input_scale)[0]) | ||
zp_idx = tir.indexmod(indices[axis], topi.shape(input_zero_point)[0]) | ||
return (value[indices] - input_zero_point[zp_idx]) * input_scale[scale_idx] | ||
|
||
# Use an if chain to dynamically return the proper dequantization based on the input datatype. | ||
# This allows the op to compile once but apply different quantization approaches | ||
# using a variable datatype input. | ||
def _dispatch_sim_dequantize(value): | ||
pass_through_value = te.compute( | ||
data.shape, lambda *indices: _compute_pass_through(value, *indices) | ||
) | ||
intn_condition = tvm.te.any( | ||
in_dtype.equal(SQNN_DTYPE_TO_CODE["int8"]), | ||
in_dtype.equal(SQNN_DTYPE_TO_CODE["uint8"]), | ||
in_dtype.equal(SQNN_DTYPE_TO_CODE["int32"]), | ||
) | ||
intn_value = te.compute( | ||
data.shape, | ||
lambda *indices: tir.if_then_else( | ||
intn_condition, | ||
_compute_intn(value, *indices), | ||
pass_through_value[indices], | ||
), | ||
) | ||
|
||
return intn_value | ||
|
||
return te.compute(data.shape, lambda *indices: _dispatch_sim_dequantize(data)[indices]) |
Oops, something went wrong.