Skip to content

Commit

Permalink
[QNN] Optimize requantize for power of 2 and fix dequantize for per-channel quantized input (#6675)
Browse files Browse the repository at this point in the history

* [QNN] Optimize requantize for power of 2 and bug in dequantize

* Comments

* Docs

* Comments

* Ethos
  • Loading branch information
anijain2305 authored Oct 29, 2020
1 parent b0afc74 commit 380e2e9
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 33 deletions.
4 changes: 2 additions & 2 deletions src/relay/qnn/op/dequantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
expanded_input_zero_point = ExpandBiasToMatchAxis(input_zero_point, n_dim, {axis});
}

auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), input_zero_point);
auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), input_scale);
auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), expanded_input_zero_point);
auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), expanded_input_scale);
return scaled_output;
}

Expand Down
93 changes: 62 additions & 31 deletions src/target/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,37 +128,68 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.q_multiply_shift")
PrimExpr q = call->args[2];
PrimExpr s = call->args[3];

// Only int32 types are supported (any number of lanes is allowed)
ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32);
ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32);

DataType hp_dtype = DataType::Int(64, x.dtype().lanes());
DataType lp_dtype = DataType::Int(32, x.dtype().lanes());

// 1) Calculating the integer multiplier and integer shift
PrimExpr zero = make_const(s.dtype(), 0);
PrimExpr left_shift = tir::Select(s > zero, s, zero);
PrimExpr right_shift = tir::Select(s > zero, zero, -s);

// 2) Cast and Multiply the integer multiplier
PrimExpr one = make_const(hp_dtype, 1);
x = cast(hp_dtype, x);
y = cast(hp_dtype, y);
x = tir::Select(left_shift != zero, x << left_shift, x);

// 3) Perform the multiplication in higher precision.
x = x * y;

// 4) Find the rounding scalar
PrimExpr total_right_shift = right_shift + q;
PrimExpr pos_rounding_value = (one << (total_right_shift - 1));
x = x + pos_rounding_value;

// 5) Simply right shift the result to get the final output.
x = x >> total_right_shift;

// 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
*rv = cast(lp_dtype, x);
// Extracts the constant integer value carried by a PrimExpr.
// `node` must be either an IntImm or a Broadcast whose value is an IntImm;
// any other expression fails an internal check with a diagnostic message.
auto get_int_value = [](const PrimExpr& node) {
  if (const auto* int_node = node.as<IntImmNode>()) {
    return int_node->value;
  }
  // Not a plain IntImm: the only other accepted form is Broadcast(IntImm).
  const auto* broadcast_node = node.as<BroadcastNode>();
  ICHECK(broadcast_node != nullptr) << "Expected an IntImm or a Broadcast of an IntImm";
  const auto* int_node = broadcast_node->value.as<IntImmNode>();
  ICHECK(int_node != nullptr) << "Expected the Broadcast value to be an IntImm";
  return int_node->value;
};
// A power-of-2 scale is detected by fixed_point_multiplier == 1 << 30: for such a scale the
// multiplier represents the float value 0.5, and 1 << 30 is the fixed point encoding of 0.5.
if (get_int_value(y) == (1 << 30)) {
PrimExpr exp = s - 1;
int exp_val = get_int_value(s) - 1;
if (exp_val > 0) {
// power of 2 is greater than 0, apply left shift.
*rv = x << exp;
} else {
// power of 2 is less than 0, round and then apply right shift.
DataType lp_dtype = DataType::Int(32, x.dtype().lanes());
PrimExpr one = make_const(lp_dtype, 1);
exp = -exp;
PrimExpr rounding_factor = one << (exp - 1);
PrimExpr rounded_t = x + rounding_factor;
*rv = rounded_t >> exp;
}
} else {
// Only int32 types are supported (any number of lanes is allowed)
ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32);
ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32);

DataType hp_dtype = DataType::Int(64, x.dtype().lanes());
DataType lp_dtype = DataType::Int(32, x.dtype().lanes());

// 1) Calculating the integer multiplier and integer shift
PrimExpr zero = make_const(s.dtype(), 0);
PrimExpr left_shift = tir::Select(s > zero, s, zero);
PrimExpr right_shift = tir::Select(s > zero, zero, -s);

// 2) Cast and Multiply the integer multiplier
PrimExpr one = make_const(hp_dtype, 1);
x = cast(hp_dtype, x);
y = cast(hp_dtype, y);
x = tir::Select(left_shift != zero, x << left_shift, x);

// 3) Perform the multiplication in higher precision.
x = x * y;

// 4) Find the rounding scalar
PrimExpr total_right_shift = right_shift + q;
PrimExpr pos_rounding_value = (one << (total_right_shift - 1));
x = x + pos_rounding_value;

// 5) Simply right shift the result to get the final output.
x = x >> total_right_shift;

// 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
*rv = cast(lp_dtype, x);
}
});

} // namespace intrin
Expand Down
18 changes: 18 additions & 0 deletions tests/python/relay/test_op_qnn_dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,26 @@ def test_channelwise_axis_1():
)


def test_channelwise_axis_0():
    """Dequantize a (2, 5) uint8 tensor using per-channel parameters along axis 0."""
    # One zero point / scale pair per row of the input.
    per_row_params = {
        "in_zero_point": np.array([127, 123]).astype("int32"),
        "in_scale": np.array([0.5, 0.25]).astype("float32"),
    }

    quantized = np.array(
        [
            [0, 1, 2, 3, 4],
            [243, 247, 249, 250, 251],
        ]
    ).astype("uint8")

    # Expected value per element: (quantized - zero_point) * scale for its row.
    expected = np.array(
        [
            [-63.5, -63, -62.5, -62, -61.5],
            [30, 31, 31.5, 31.75, 32],
        ]
    ).astype("float32")

    dequantize_test_driver(
        in_dtype="uint8",
        quant_args=per_row_params,
        in_data=quantized,
        verify_output_data=expected,
        axis=0,
    )


if __name__ == "__main__":
    # Run every dequantize test, in order, when the file is executed as a script.
    for run_test in (
        test_uint8_to_float32,
        test_int8_to_float32,
        test_int32_to_float32,
        test_channelwise_axis_1,
        test_channelwise_axis_0,
    ):
        run_test()
43 changes: 43 additions & 0 deletions tests/python/relay/test_op_qnn_requantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,48 @@ def test_upscale():
verify(mod, (golden_data, golden_output))


def test_non_power_of_two():
    """Requantize int32 -> int8 with a scale ratio of 3, in both directions.

    For every rounding mode, checks positive and negative inputs first with
    input_scale=1 / output_scale=3 and then with input_scale=3 / output_scale=1.
    """
    # (stop, step) pairs producing 0..31 and 0..-31 respectively.
    ranges = ((32, 1), (-32, -1))

    for rounding in roundings:
        # Output scale is 3x the input scale: multiples of 3 map back to the base values.
        shrink_mod = get_mod(
            data_shape=(32,),
            data_dtype="int32",
            out_dtype="int8",
            input_scale=1,
            output_scale=3,
            rounding=rounding,
        )
        for stop, step in ranges:
            inputs = np.multiply(np.arange(0, stop, step).astype("int32"), 3)
            expected = np.arange(0, stop, step)
            verify(shrink_mod, (inputs, expected))

        # Input scale is 3x the output scale: every value is multiplied by 3.
        grow_mod = get_mod(
            data_shape=(32,),
            data_dtype="int32",
            out_dtype="int8",
            input_scale=3,
            output_scale=1,
            rounding=rounding,
        )
        for stop, step in ranges:
            inputs = np.arange(0, stop, step).astype("int32")
            expected = np.multiply(inputs, 3)
            verify(grow_mod, (inputs, expected))


def test_saturation():
for rounding in roundings:
mod = get_mod(
Expand Down Expand Up @@ -397,6 +439,7 @@ def test_per_channel_different_scale():
test_same_scale()
test_downscale()
test_upscale()
test_non_power_of_two()
test_saturation()
test_zero_point()
test_per_channel_same_scale()
Expand Down

0 comments on commit 380e2e9

Please sign in to comment.