Skip to content

Commit

Permalink
[QNN] Optimize requantize for power-of-2 multipliers and fix a bug in dequantize
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Oct 13, 2020
1 parent f73a1f6 commit def4e78
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/relay/qnn/op/dequantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
expanded_input_zero_point = ExpandBiasToMatchAxis(input_zero_point, n_dim, {axis});
}

auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), input_zero_point);
auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), input_scale);
auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), expanded_input_zero_point);
auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), expanded_input_scale);
return scaled_output;
}

Expand Down
19 changes: 11 additions & 8 deletions src/relay/qnn/op/requantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,17 +155,20 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
if (!IsEqualScalar(input_scale, output_scale)) {
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);

const bool is_upward_rounding = (param->rounding == "UPWARD");

// When using upward rounding (i.e., x.5 rounded to x+1), leverage
// the FixedPointMultiply operator
scaled_int32_t =
(is_upward_rounding
? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift)
: FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape));
if (is_upward_rounding && fixed_point_multiplier == (1 << 30)) {
// Power of 2
scaled_int32_t = PowerOfTwoMultiply(scaled_int32_t, shift - 1);
} else {
// When using upward rounding (i.e., x.5 rounded to x+1), leverage
// the FixedPointMultiply operator
scaled_int32_t =
(is_upward_rounding
? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift)
: FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape));
}
}

} else {
// This is per-channel (per=axis) quantization.
std::vector<double> double_multipliers;
Expand Down
15 changes: 15 additions & 0 deletions src/relay/qnn/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,21 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
return std::make_pair(significand, exponent);
}

/*
 * \brief Multiply an int32 tensor by 2^exp using shifts instead of a full
 *        fixed-point multiply.
 * \param tensor The quantized input tensor of dtype int32.
 * \param exp The signed power of two to multiply by.
 * \return The sequence of Relay ops implementing the multiplication. For a
 *         negative exponent, a rounding offset is added before the right
 *         shift so that halves round upward (matches "UPWARD" rounding).
 */
Expr PowerOfTwoMultiply(Expr tensor, int32_t exp) {
  if (exp == 0) {
    // Multiplying by 2^0 is the identity. Handling this explicitly also
    // avoids `1 << -1` (undefined behavior) in the negative-exponent branch.
    return tensor;
  }
  if (exp > 0) {
    // Positive power of 2: a plain left shift.
    return LeftShift(tensor, MakeConstantScalar(DataType::Int(32), exp));
  }
  // Negative power of 2: add half of the divisor so the arithmetic right
  // shift rounds to nearest (half up), then shift.
  exp = -exp;
  int32_t rounding_factor = 1 << (exp - 1);
  auto rounded_t = Add(tensor, MakeConstantScalar(DataType::Int(32), rounding_factor));
  return RightShift(rounded_t, MakeConstantScalar(DataType::Int(32), exp));
}

Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape) {
// Choose high precision datatype to be int64. This is for avoiding overflow
Expand Down
8 changes: 8 additions & 0 deletions src/relay/qnn/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) {
*/
Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape);
/*
* \brief Multiply an integer datatype tensor by a power of two.
* \param tensor The quantized input tensor of dtype int32.
* \param exp The exp or the power of 2 representing the number to be multiplied.
* \return The sequence of Relay ops for power of two multiplication.
*/
Expr PowerOfTwoMultiply(Expr tensor, int32_t exp);

/*
* \brief Fixed point multiplication between integer tensor with floating point
Expand Down
14 changes: 14 additions & 0 deletions tests/python/relay/test_op_qnn_dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,22 @@ def test_channelwise_axis_1():
)


def test_channelwise_axis_0():
    """Channel-wise dequantize with the channel axis on dimension 0."""
    in_data = np.reshape(
        np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251], dtype="uint8"), (2, 5)
    )
    expected = np.reshape(
        np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32], dtype="float32"),
        (2, 5),
    )
    # Per-channel zero points and scales, one entry per row (axis 0).
    args = {
        "in_zero_point": np.array([127, 123], dtype="int32"),
        "in_scale": np.array([0.5, 0.25], dtype="float32"),
    }

    dequantize_test_driver(
        in_dtype="uint8",
        quant_args=args,
        in_data=in_data,
        verify_output_data=expected,
        axis=0,
    )


# Run every dequantize test in this file when executed as a script.
if __name__ == "__main__":
    test_uint8_to_float32()
    test_int8_to_float32()
    test_int32_to_float32()
    test_channelwise_axis_1()
    test_channelwise_axis_0()
43 changes: 43 additions & 0 deletions tests/python/relay/test_op_qnn_requantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,48 @@ def test_upscale():
verify(mod, (golden_data, golden_output))


def test_non_power_of_two():
    """Requantize with a scale ratio of 3 (not a power of two), both directions."""
    for rounding in roundings:
        # Downscale by 3: inputs are multiples of 3, outputs the raw indices.
        mod = get_mod(
            data_shape=(32,),
            data_dtype="int32",
            out_dtype="int8",
            input_scale=1,
            output_scale=3,
            rounding=rounding,
        )
        # Check positive and negative ranges.
        for base in (np.arange(0, 32, 1), np.arange(0, -32, -1)):
            verify(mod, (np.multiply(base.astype("int32"), 3), base))

        # Upscale by 3: swap the scales, so outputs are 3x the inputs.
        mod = get_mod(
            data_shape=(32,),
            data_dtype="int32",
            out_dtype="int8",
            input_scale=3,
            output_scale=1,
            rounding=rounding,
        )
        for base in (np.arange(0, 32, 1), np.arange(0, -32, -1)):
            gdata = base.astype("int32")
            verify(mod, (gdata, np.multiply(gdata, 3)))


def test_saturation():
for rounding in roundings:
mod = get_mod(
Expand Down Expand Up @@ -397,6 +439,7 @@ def test_per_channel_different_scale():
test_same_scale()
test_downscale()
test_upscale()
test_non_power_of_two()
test_saturation()
test_zero_point()
test_per_channel_same_scale()
Expand Down

0 comments on commit def4e78

Please sign in to comment.