[QNN] Optimize requantize for power of 2 and fix dequantize for per-channel quantized input #6675

Merged · 5 commits · Oct 29, 2020
Changes from 3 commits
src/relay/qnn/op/dequantize.cc (2 additions, 2 deletions)

```diff
@@ -96,8 +96,8 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
     expanded_input_zero_point = ExpandBiasToMatchAxis(input_zero_point, n_dim, {axis});
   }
 
-  auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), input_zero_point);
-  auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), input_scale);
+  auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), expanded_input_zero_point);
+  auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), expanded_input_scale);
   return scaled_output;
 }
```
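For context, per-channel dequantize gives every channel its own scale and zero point; the fix above makes the lowering use the expanded (broadcastable) parameters so this works along any axis. A minimal standalone sketch of the arithmetic, not the Relay lowering itself, using the values from the new test_channelwise_axis_0 added below:

```cpp
// Standalone sketch of per-channel dequantize along axis 0 (not the Relay
// lowering itself): each channel (row) uses its own scale and zero point.
// Values mirror the new test_channelwise_axis_0 below.
#include <cstdint>
#include <cstdio>

int main() {
  const uint8_t data[2][5] = {{0, 1, 2, 3, 4}, {243, 247, 249, 250, 251}};
  const int32_t zero_point[2] = {127, 123};  // one zero point per channel
  const float scale[2] = {0.5f, 0.25f};      // one scale per channel
  for (int c = 0; c < 2; ++c) {
    for (int i = 0; i < 5; ++i) {
      // dequantize: (int32(q) - zero_point[c]) * scale[c]
      float out = (static_cast<int32_t>(data[c][i]) - zero_point[c]) * scale[c];
      std::printf("%7.2f", out);  // row 0: -63.50 ... -61.50, row 1: 30.00 ... 32.00
    }
    std::printf("\n");
  }
  return 0;
}
```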

src/relay/qnn/op/requantize.cc (13 additions, 8 deletions)

```diff
@@ -155,17 +155,22 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
     if (!IsEqualScalar(input_scale, output_scale)) {
       int32_t fixed_point_multiplier, shift;
       std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
 
       const bool is_upward_rounding = (param->rounding == "UPWARD");
 
-      // When using upward rounding (i.e., x.5 rounded to x+1), leverage
-      // the FixedPointMultiply operator
-      scaled_int32_t =
-          (is_upward_rounding
-               ? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift)
-               : FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape));
+      if (is_upward_rounding && fixed_point_multiplier == (1 << 30)) {
```
**Contributor** (on the line above): Why must the `fixed_point_multiplier` be `(1 << 30)`?
**@anijain2305 (author), Oct 13, 2020:** So, we use frexp to represent floating point numbers. It gives a float significand which is between [0.5, 1). For a power of 2, it is always 0.5. We convert the float significand into a fixed point 32-bit integer, where the decimal point is between the first and second bit. 0.5 in this representation = 1 << 30.
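A standalone illustration of this explanation (plain C++, not TVM code): frexp returns a significand in [0.5, 1), a power-of-two multiplier always yields exactly 0.5, and 0.5 in Q31 fixed point (decimal point between bits 31 and 30) is 1 << 30:

```cpp
// For powers of two, std::frexp gives significand 0.5, whose Q31 fixed
// point encoding is exactly 1 << 30; any other multiplier differs.
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  for (double m : {0.25, 0.5, 2.0, 0.3}) {
    int exponent;
    double significand = std::frexp(m, &exponent);  // m = significand * 2^exponent
    int32_t fixed = static_cast<int32_t>(std::llround(significand * (1LL << 31)));
    std::printf("m=%.2f  significand=%.3f  exponent=%2d  fixed=%d  (1<<30 = %d)\n",
                m, significand, exponent, fixed, 1 << 30);
  }
  return 0;
}
```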

**Contributor:** @anijain2305, can you add a small one-line comment regarding `(1 << 30)`? These days, aside from float32, many other float types float around.

**@anijain2305 (author):** @cbalint13 Added a comment, can you PTAL?

**Contributor:** @anijain2305, thank you!

```diff
+        // Power of 2 is detected by fixed_point_multiplier == 1 << 30: for a power
+        // of 2, the fixed point multiplier represents a float value of 0.5, which
+        // in this fixed point representation is 1 << 30.
+        scaled_int32_t = PowerOfTwoMultiply(scaled_int32_t, shift - 1);
```
**Contributor** (on the line above): Does it make sense for this to go in `FixedPointMultiply`? That would let everybody using `FixedPointMultiply` exploit this fix.
```diff
+      } else {
+        // When using upward rounding (i.e., x.5 rounded to x+1), leverage
+        // the FixedPointMultiply operator.
+        scaled_int32_t =
+            (is_upward_rounding
+                 ? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift)
+                 : FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape));
+      }
     }
 
   } else {
     // This is per-channel (per-axis) quantization.
     std::vector<double> double_multipliers;
```
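As an illustrative walk-through of the fast path (assuming standard frexp semantics): requantizing with input_scale = 1 and output_scale = 4 gives double_multiplier = 0.25; frexp(0.25) returns significand 0.5 and exponent -1, so fixed_point_multiplier == 1 << 30 and shift == -1, and the whole requantize multiply reduces to PowerOfTwoMultiply(scaled_int32_t, -2), i.e. a rounding right shift by two bits.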
src/relay/qnn/utils.cc (15 additions)

```diff
@@ -56,6 +56,21 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier)
   return std::make_pair(significand, exponent);
 }
 
+Expr PowerOfTwoMultiply(Expr tensor, int32_t exp) {
+  Expr out;
+  if (exp > 0) {
+    // power of 2 is greater than 0, apply left shift.
+    out = LeftShift(tensor, MakeConstantScalar(DataType::Int(32), exp));
+  } else {
+    // power of 2 is less than 0, round and then apply right shift.
+    exp = -exp;
+    auto rounding_factor = 1 << (exp - 1);
+    auto rounded_t = Add(tensor, MakeConstantScalar(DataType::Int(32), rounding_factor));
+    out = RightShift(rounded_t, MakeConstantScalar(DataType::Int(32), exp));
+  }
+  return out;
+}
+
 Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
                                  const Array<IndexExpr>& input_shape) {
   // Choose high precision datatype to be int64. This is for avoiding overflow
```
**Contributor** (on the `RightShift` line): Are you sure you don't need to convert to int64 up front and then cast back to int32?
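A standalone sketch of the scalar arithmetic the new helper lowers to (illustration only, not the Relay implementation); note that the call site passes shift - 1 because the 0.5 significand already contributes one halving, so the multiplier is 2^(shift - 1):

```cpp
// Scalar sketch (not the Relay implementation) of the rounding right shift
// that PowerOfTwoMultiply lowers to for a negative exponent.
#include <cstdint>
#include <cstdio>

int32_t power_of_two_multiply(int32_t x, int32_t exp) {
  if (exp >= 0) return x << exp;  // non-negative exponent: plain left shift
  exp = -exp;
  // Add half of the divisor before shifting so that x.5 rounds to x + 1
  // ("UPWARD" rounding). Caveat: x + rounding can overflow int32 for x near
  // INT32_MAX, which is what the reviewer's int64 question above is about.
  int32_t rounding = 1 << (exp - 1);
  return (x + rounding) >> exp;  // arithmetic shift: floor division by 2^exp
}

int main() {
  // 7 * 2^-2 = 1.75 -> 2;  6 * 2^-2 = 1.5 -> 2;  -6 * 2^-2 = -1.5 -> -1
  std::printf("%d %d %d\n", power_of_two_multiply(7, -2),
              power_of_two_multiply(6, -2), power_of_two_multiply(-6, -2));
  return 0;
}
```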

src/relay/qnn/utils.h (7 additions)

```diff
@@ -136,6 +136,13 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) {
  */
 Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
                                  const Array<IndexExpr>& input_shape);
+/*
+ * \brief Multiply an integer-datatype tensor by a power of two.
+ * \param tensor The quantized input tensor of dtype int32.
+ * \param exp The exponent: the tensor is multiplied by 2^exp.
+ * \return The sequence of Relay ops for power of two multiplication.
+ */
+Expr PowerOfTwoMultiply(Expr tensor, int32_t exp);
 
 /*
  * \brief Fixed point multiplication between integer tensor with floating point
```
tests/python/relay/test_op_qnn_dequantize.py (18 additions)

```diff
@@ -101,8 +101,26 @@ def test_channelwise_axis_1():
     )
 
 
+def test_channelwise_axis_0():
+    data = np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]).astype("uint8").reshape((2, 5))
+    output = (
+        np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32])
+        .astype("float32")
+        .reshape((2, 5))
+    )
+    quant_args = {
+        "in_zero_point": np.array([127, 123]).astype("int32"),
+        "in_scale": np.array([0.5, 0.25]).astype("float32"),
+    }
+
+    dequantize_test_driver(
+        in_dtype="uint8", quant_args=quant_args, in_data=data, verify_output_data=output, axis=0
+    )
+
+
 if __name__ == "__main__":
     test_uint8_to_float32()
     test_int8_to_float32()
     test_int32_to_float32()
     test_channelwise_axis_1()
+    test_channelwise_axis_0()
```
tests/python/relay/test_op_qnn_requantize.py (43 additions)

```diff
@@ -204,6 +204,48 @@ def test_upscale():
         verify(mod, (golden_data, golden_output))
 
 
+def test_non_power_of_two():
+    for rounding in roundings:
+        mod = get_mod(
+            data_shape=(32,),
+            data_dtype="int32",
+            out_dtype="int8",
+            input_scale=1,
+            output_scale=3,
+            rounding=rounding,
+        )
+
+        # Try positive values
+        golden_data = np.multiply(np.arange(0, 32, 1).astype("int32"), 3)
+        golden_output = np.arange(0, 32, 1)
+        verify(mod, (golden_data, golden_output))
+
+        # Try negative values
+        golden_data = np.multiply(np.arange(0, -32, -1).astype("int32"), 3)
+        golden_output = np.arange(0, -32, -1)
+        verify(mod, (golden_data, golden_output))
+
+        # Try a different scale
+        mod = get_mod(
+            data_shape=(32,),
+            data_dtype="int32",
+            out_dtype="int8",
+            input_scale=3,
+            output_scale=1,
+            rounding=rounding,
+        )
+
+        # Try positive values
+        golden_data = np.arange(0, 32, 1).astype("int32")
+        golden_output = np.multiply(golden_data, 3)
+        verify(mod, (golden_data, golden_output))
+
+        # Try negative values
+        golden_data = np.arange(0, -32, -1).astype("int32")
+        golden_output = np.multiply(golden_data, 3)
+        verify(mod, (golden_data, golden_output))
+
+
 def test_saturation():
     for rounding in roundings:
         mod = get_mod(
@@ -397,6 +439,7 @@ def test_per_channel_different_scale():
     test_same_scale()
     test_downscale()
     test_upscale()
+    test_non_power_of_two()
     test_saturation()
     test_zero_point()
     test_per_channel_same_scale()
```