diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 2e7a28624e26..2fe075c7e64b 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -96,8 +96,8 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale, expanded_input_zero_point = ExpandBiasToMatchAxis(input_zero_point, n_dim, {axis}); } - auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), input_zero_point); - auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), input_scale); + auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), expanded_input_zero_point); + auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), expanded_input_scale); return scaled_output; } diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 0808d237fc28..f8f4d0ef5414 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -128,37 +128,68 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.q_multiply_shift") PrimExpr q = call->args[2]; PrimExpr s = call->args[3]; - // Only int32 types are supported (any number of lanes is allowed) - ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32); - ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32); - - DataType hp_dtype = DataType::Int(64, x.dtype().lanes()); - DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); - - // 1) Calculating the integer multiplier and integer shift - PrimExpr zero = make_const(s.dtype(), 0); - PrimExpr left_shift = tir::Select(s > zero, s, zero); - PrimExpr right_shift = tir::Select(s > zero, zero, -s); - - // 2) Cast and Multiply the integer multiplier - PrimExpr one = make_const(hp_dtype, 1); - x = cast(hp_dtype, x); - y = cast(hp_dtype, y); - x = tir::Select(left_shift != zero, x << left_shift, x); - - // 3) Perform the multiplication in higher precision. - x = x * y; - - // 4) Find the rounding scalar - PrimExpr total_right_shift = right_shift + q; - PrimExpr pos_rounding_value = (one << (total_right_shift - 1)); - x = x + pos_rounding_value; - - // 5) Simply right shift the result to get the final output. - x = x >> total_right_shift; - - // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32. - *rv = cast(lp_dtype, x); + // Lambda function to extract the int value from PrimExpr + auto get_int_value = [](const PrimExpr node) { + if (auto int_node = node.as()) { + return int_node->value; + } + auto broadcast_node = node.as(); + CHECK(broadcast_node != nullptr); + auto int_node = broadcast_node->value.as(); + CHECK(int_node != nullptr); + return int_node->value; + }; + // Power of 2 is determined by the fixed_point_multiplier == 1 << 30. In case of power of 2, + // fixed point multiplier will represent a float value of 0.5. In fixed point, this is + // represented by 1 << 30. + if (get_int_value(y) == (1 << 30)) { + PrimExpr exp = s - 1; + int exp_val = get_int_value(s) - 1; + if (exp_val > 0) { + // power of 2 is greater than 0, apply left shift. + *rv = x << exp; + } else { + // power of 2 is less than 0, round and then apply right shift. + DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); + PrimExpr one = make_const(lp_dtype, 1); + exp = -exp; + PrimExpr rounding_factor = one << (exp - 1); + PrimExpr rounded_t = x + rounding_factor; + *rv = rounded_t >> exp; + } + } else { + // Only int32 types are supported (any number of lanes is allowed) + ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32); + ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32); + + DataType hp_dtype = DataType::Int(64, x.dtype().lanes()); + DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); + + // 1) Calculating the integer multiplier and integer shift + PrimExpr zero = make_const(s.dtype(), 0); + PrimExpr left_shift = tir::Select(s > zero, s, zero); + PrimExpr right_shift = tir::Select(s > zero, zero, -s); + + // 2) Cast and Multiply the integer multiplier + PrimExpr one = make_const(hp_dtype, 1); + x = cast(hp_dtype, x); + y = cast(hp_dtype, y); + x = tir::Select(left_shift != zero, x << left_shift, x); + + // 3) Perform the multiplication in higher precision. + x = x * y; + + // 4) Find the rounding scalar + PrimExpr total_right_shift = right_shift + q; + PrimExpr pos_rounding_value = (one << (total_right_shift - 1)); + x = x + pos_rounding_value; + + // 5) Simply right shift the result to get the final output. + x = x >> total_right_shift; + + // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32. + *rv = cast(lp_dtype, x); + } }); } // namespace intrin diff --git a/tests/python/relay/test_op_qnn_dequantize.py b/tests/python/relay/test_op_qnn_dequantize.py index 6598e2bb2062..e1416622c236 100644 --- a/tests/python/relay/test_op_qnn_dequantize.py +++ b/tests/python/relay/test_op_qnn_dequantize.py @@ -101,8 +101,26 @@ def test_channelwise_axis_1(): ) +def test_channelwise_axis_0(): + data = np.array([0, 1, 2, 3, 4, 243, 247, 249, 250, 251]).astype("uint8").reshape((2, 5)) + output = ( + np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32]) + .astype("float32") + .reshape((2, 5)) + ) + quant_args = { + "in_zero_point": np.array([127, 123]).astype("int32"), + "in_scale": np.array([0.5, 0.25]).astype("float32"), + } + + dequantize_test_driver( + in_dtype="uint8", quant_args=quant_args, in_data=data, verify_output_data=output, axis=0 + ) + + if __name__ == "__main__": test_uint8_to_float32() test_int8_to_float32() test_int32_to_float32() test_channelwise_axis_1() + test_channelwise_axis_0() diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py index f152a4ebf840..f40a08711451 100644 --- a/tests/python/relay/test_op_qnn_requantize.py +++ b/tests/python/relay/test_op_qnn_requantize.py @@ -204,6 +204,48 @@ def test_upscale(): verify(mod, (golden_data, golden_output)) +def test_non_power_of_two(): + for rounding in roundings: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=3, + rounding=rounding, + ) + + # Try positive values + golden_data = np.multiply(np.arange(0, 32, 1).astype("int32"), 3) + golden_output = np.arange(0, 32, 1) + verify(mod, (golden_data, golden_output)) + + # Try negative values + golden_data = np.multiply(np.arange(0, -32, -1).astype("int32"), 3) + golden_output = np.arange(0, -32, -1) + verify(mod, (golden_data, golden_output)) + + # Try a different scale + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=3, + output_scale=1, + rounding=rounding, + ) + + # Try positive values + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.multiply(golden_data, 3) + verify(mod, (golden_data, golden_output)) + + # Try negative values + golden_data = np.arange(0, -32, -1).astype("int32") + golden_output = np.multiply(golden_data, 3) + verify(mod, (golden_data, golden_output)) + + def test_saturation(): for rounding in roundings: mod = get_mod( @@ -397,6 +439,7 @@ def test_per_channel_different_scale(): test_same_scale() test_downscale() test_upscale() + test_non_power_of_two() test_saturation() test_zero_point() test_per_channel_same_scale()