diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index e5ad57c6b87a..c709e2b4e7bd 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -34,6 +34,7 @@
 from .. import expr as _expr
 from .. import function as _function
 from .. import op as _op
+from .. import qnn
 from ..ty import TupleType, TensorType, Any
 from ..loops import while_loop
 from .. import transform
@@ -805,14 +806,35 @@ def log_sigmoid(self, inputs, input_types):
         data = inputs[0]
         return _op.log(_op.tensor.sigmoid(data))
 
-    def hard_swish(self, inputs, input_types):
-        data = inputs[0]
-        dtype = input_types[0]
+    def hard_sigmoid(self, inputs, input_types):
+        def _relu6(x):
+            return _op.tensor.clip(x, 0.0, 6.0)
 
-        def _relu6(input_tensor):
-            return _op.tensor.clip(input_tensor, 0.0, 6.0)
+        def func(x):
+            return _relu6(x + _expr.const(3.0)) / _expr.const(6.0)
+
+        if self.is_quantized_tensor(inputs[0]):
+            input_scale = _expr.const(inputs[1])
+            input_zero_point = _expr.const(inputs[2])
+            # PyTorch seems to use the following output qparams, but accuracy
+            # is broken if we use this.
+            # TODO(masahi): Revisit this parameter choice
+            #
+            # Taken from src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
+            # output_scale = _expr.const(0.00390625) # 1.0 / 2^8
+            # output_zero_point = _expr.const(-128)
+            output_scale = input_scale
+            output_zero_point = input_zero_point
+
+            data = qnn.op.dequantize(inputs[0], input_scale, input_zero_point, axis=1)
+            out = func(data)
+            return qnn.op.quantize(out, output_scale, output_zero_point, out_dtype="uint8")
+
+        return func(inputs[0])
 
-        return data * _relu6(data + _expr.const(3.0, dtype=dtype)) / _expr.const(6.0, dtype=dtype)
+    def hard_swish(self, inputs, input_types):
+        data = inputs[0]
+        return data * self.hard_sigmoid(inputs, input_types)
 
     def adaptive_avg_pool_2d(self, inputs, input_types):
         data = inputs[0]
@@ -2418,6 +2440,8 @@ def create_convert_map(self):
             "aten::__not__": self.logical_not,
             "aten::hardswish_": self.hard_swish,
             "aten::hardswish": self.hard_swish,
+            "aten::hardsigmoid_": self.hard_sigmoid,
+            "aten::hardsigmoid": self.hard_sigmoid,
             "aten::cumsum": self.cumsum,
             "aten::masked_fill": self.masked_fill,
             "aten::masked_select": self.masked_select,
diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py
index e3431043bc86..2b85a1f3a1be 100644
--- a/python/tvm/relay/frontend/qnn_torch.py
+++ b/python/tvm/relay/frontend/qnn_torch.py
@@ -191,6 +191,7 @@ def _get_quant_param_for_input(input_value):
         "quantized::cat": (2, 3),
         "quantized::mul_scalar": (2, 3),
         "quantized::add_scalar": (2, 3),
+        "quantized::hardswish": (1, 2),
     }
 
     def dfs(current_node):
@@ -358,6 +359,8 @@ def add_input_quant_params_to_op_inputs(graph):
         "quantized::add_scalar": 1,
         "quantized::mul_scalar": 1,
         "quantized::relu6": 1,
+        "quantized::hardswish": 1,
+        "aten::hardsigmoid": 1,
     }
 
     need_input_quant_param = set(num_quantized_inputs.keys())
@@ -765,6 +768,7 @@ def _impl(inputs, _):
         out_zp = _expr.const(inputs[3])
 
         if q_min > z - c_q or q_max < z - c_q:
+            # TODO(masahi): Replace this with integer only compute
             dequant = relay.qnn.op.dequantize(inputs[0], _expr.const(s), _expr.const(z))
             dequantized_add = _op.tensor.add(dequant, _expr.const(c_q * s))
             return relay.qnn.op.quantize(
@@ -820,6 +824,35 @@ def _impl(inputs, _):
     return _impl
 
 
+def _hswish():
+    # refer to src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
+    # They fall back to fp32
+    def _impl(inputs, _):
+        assert len(inputs) == 5, "Input quant params not found in op inputs"
+        # TODO(masahi): Replace this with integer only compute.
+        # We do not have to strictly follow how PyTorch does it.
+
+        def relu6(x):
+            return _op.tensor.clip(x, 0.0, 6.0)
+
+        def hardsigmoid(x):
+            dtype = "float32"
+            return relu6(x + _expr.const(3.0, dtype=dtype)) / _expr.const(6.0, dtype=dtype)
+
+        output_scale = _expr.const(inputs[1])
+        output_zero_point = _expr.const(inputs[2])
+        input_scale = _expr.const(inputs[3])
+        input_zero_point = _expr.const(inputs[4])
+
+        dequant = relay.qnn.op.dequantize(inputs[0], input_scale, input_zero_point, axis=1)
+        dequantized_hswish = dequant * hardsigmoid(dequant)
+        return relay.qnn.op.quantize(
+            dequantized_hswish, output_scale, output_zero_point, out_dtype="uint8"
+        )
+
+    return _impl
+
+
 def _linear_dynamic():
     def _calculate_qparam(inp):
         # reference ATen/native/quantized/cpu/qlinear_dynamic.cpp
@@ -906,4 +939,5 @@ def _impl(inputs, _):
     "quantized::mul_scalar": _mul_scalar(),
     "quantized::relu6": _relu6(),
     "quantized::linear_dynamic": _linear_dynamic(),
+    "quantized::hardswish": _hswish(),
 }
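The quantized::hardswish converter registered above follows the same dequantize, compute in fp32, requantize pattern as the other fallback converters in qnn_torch.py. Below is a minimal standalone sketch of the Relay graph it ends up building; it is not part of the patch, and the input shape and quantization parameters are made-up illustration values.

# Standalone sketch only: shape, scale and zero point are illustrative values.
import tvm
from tvm import relay

def build_qhardswish(ishape=(1, 3, 224, 224), scale=0.05, zero_point=128):
    x = relay.var("x", shape=ishape, dtype="uint8")
    s = relay.const(scale)
    zp = relay.const(zero_point)
    # Same fp32 fallback as _hswish(): dequantize, x * hardsigmoid(x), quantize
    deq = relay.qnn.op.dequantize(x, s, zp, axis=1)
    hsig = relay.clip(deq + relay.const(3.0), 0.0, 6.0) / relay.const(6.0)
    out = relay.qnn.op.quantize(deq * hsig, s, zp, out_dtype="uint8")
    return relay.Function([x], out)

mod = tvm.IRModule.from_expr(build_qhardswish())
print(relay.transform.InferType()(mod))

In the real converter the output scale and zero point come from the op's own quantization parameters (inputs[1] and inputs[2]); the sketch simply reuses the input ones to keep the example short.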
diff --git a/src/relay/transforms/fold_explicit_padding.cc b/src/relay/transforms/fold_explicit_padding.cc
index d606eb445a79..bab8b814df05 100644
--- a/src/relay/transforms/fold_explicit_padding.cc
+++ b/src/relay/transforms/fold_explicit_padding.cc
@@ -182,7 +182,7 @@ class SimplifyExplicitPadding {
 };
 
 /*!
- * \brief ImplicitPadding finds explict padding before an op that can
+ * \brief FoldExplicitPadding finds explicit padding before an op that can
  * support implicit padding and fuses them.
  */
 Expr FoldExplicitPadding(const Expr& expr, const IRModule& mod) {
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index 07e52b7079e8..29c69abba542 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -41,7 +41,6 @@ def torch_version_check():
 
 
 def get_tvm_runtime(script_module, input_name, ishape):
-
     input_shapes = [(input_name, ishape)]
 
     mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
@@ -125,43 +124,40 @@ def fuse_model(self):
 
 # Mobilenet V3 related modules
 class Hsigmoid(nn.Module):
-    def __init__(self, inplace=True, add_stub=False):
+    def __init__(self, add_stub=False):
         super().__init__()
-        self.float_op = nn.quantized.FloatFunctional()
-        self.relu6 = nn.ReLU6(inplace=inplace)
         self.quant = QuantStub()
         self.dequant = DeQuantStub()
         self.add_stub = add_stub
+        self.hsigmoid = nn.Hardsigmoid()
 
     def forward(self, x):
         if self.add_stub:
             x = self.quant(x)
-        relu6 = self.relu6(self.float_op.add_scalar(x, 3.0))
-        mul = self.float_op.mul_scalar(relu6, 1 / 6.0)
+        x = self.hsigmoid(x)
         if self.add_stub:
-            mul = self.dequant(mul)
-        return mul
+            x = self.dequant(x)
+        return x
 
     def fuse_model(self):
         pass
 
 
 class Hswish(nn.Module):
-    def __init__(self, inplace=True, add_stub=False):
-        super(Hswish, self).__init__()
-        self.float_op = nn.quantized.FloatFunctional()
-        self.hsigmoid = Hsigmoid(inplace, add_stub=False)
+    def __init__(self, add_stub=False):
+        super().__init__()
         self.quant = QuantStub()
         self.dequant = DeQuantStub()
         self.add_stub = add_stub
+        self.hswish = nn.Hardswish()
 
     def forward(self, x):
         if self.add_stub:
             x = self.quant(x)
-        mul = self.float_op.mul(x, self.hsigmoid(x))
+        x = self.hswish(x)
         if self.add_stub:
-            mul = self.dequant(mul)
-        return mul
+            x = self.dequant(x)
+        return x
 
     def fuse_model(self):
         pass
@@ -274,18 +270,12 @@ def test_quantized_modules():
         ("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel),
         ("linear" + postfix, (16, 16), Linear(), per_channel),
         ("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel),
-    ]
-
-    if torch_version_check():
-        qmodules += [
         ("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False),
         ("hswish", imagenet_ishape, Hswish(add_stub=True), False),
         ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
         ("semodule, per_channel", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), True),
         ("mul_scalar negative", imagenet_ishape, MulScalarNegative(), False),
     ]
-    else:
-        print("Skipping tests that require torch > 1.4")
 
     for (module_name, ishape, raw_module, per_channel) in qmodules:
         raw_module.eval()
@@ -372,6 +362,13 @@ def get_imagenet_input():
         # ("googlenet", qgooglenet(pretrained=True), per_channel),
     ]
 
+    if is_version_greater_than("1.7.1"):
+        from torchvision.models.quantization import mobilenet_v3_large as qmobilenet_v3_large
+
+        qmodels.append(
+            ("mobilenet_v3_large", qmobilenet_v3_large(pretrained=True, quantize=True).eval(), True)
+        )
+
     results = []
 
     for (model_name, raw_model, per_channel) in qmodels:
@@ -385,7 +382,10 @@ def get_imagenet_input():
         inp = get_imagenet_input()
         pt_inp = torch.from_numpy(inp)
 
-        quantize_model(raw_model, pt_inp, per_channel=per_channel)
+        if "mobilenet_v3_large" not in model_name:
+            # mv3 was QAT-ed; the quantize=True option above already gives a quantized model
+            quantize_model(raw_model, pt_inp, per_channel=per_channel)
+
         script_module = torch.jit.trace(raw_model, pt_inp).eval()
 
         with torch.no_grad():
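The reworked Hsigmoid and Hswish test modules lean on PyTorch eager-mode post-training quantization: once a stub-wrapped nn.Hardswish is prepared, calibrated and converted, tracing it should yield a graph containing quantized::hardswish, which from_pytorch can now handle. A rough self-contained illustration of that flow, using generic torch.quantization helpers rather than the quantize_model helper from qnn_test.py and assuming the fbgemm backend:

# Rough illustration, not the helpers used in qnn_test.py; assumes the fbgemm backend.
import torch
from torch import nn
from torch.quantization import QuantStub, DeQuantStub, get_default_qconfig, prepare, convert

class HswishStub(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.hswish = nn.Hardswish()

    def forward(self, x):
        return self.dequant(self.hswish(self.quant(x)))

model = HswishStub().eval()
model.qconfig = get_default_qconfig("fbgemm")
prepare(model, inplace=True)
model(torch.rand(8, 3, 224, 224))   # calibration pass to record activation ranges
convert(model, inplace=True)        # nn.Hardswish should become its quantized counterpart
script_module = torch.jit.trace(model, torch.rand(1, 3, 224, 224)).eval()
# The traced graph can then be fed to relay.frontend.from_pytorch as in get_tvm_runtime.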
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 24f8edab7d98..83c1698799c7 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3651,6 +3651,13 @@ def test_hard_swish():
         verify_model(torch.nn.Hardswish(inplace=True).eval(), input_data=input)
 
 
+def test_hard_sigmoid():
+    examples = [torch.rand(8).float(), torch.rand(8, 10).float(), torch.rand(1, 1, 10).float()]
+    for input in examples:
+        verify_model(torch.nn.Hardsigmoid().eval(), input_data=input)
+        verify_model(torch.nn.Hardsigmoid(inplace=True).eval(), input_data=input)
+
+
 def test_cumsum():
     def test_fn(dim, dtype=None):
         return lambda x: torch.cumsum(x, dim=dim, dtype=dtype)
@@ -3893,6 +3900,8 @@ def test_fn(is_sorted, return_inverse, return_counts):
     test_logical_and()
     test_masked_select()
     test_unique()
+    test_hard_swish()
+    test_hard_sigmoid()
 
     # Model tests
     test_resnet18()
@@ -3931,4 +3940,3 @@ def test_fn(is_sorted, return_inverse, return_counts):
 
     # Test convert torch script(jit) with specific inputs' types
     test_convert_torch_script_with_input_types()
-    test_hard_swish()
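The new fp32 tests exercise the clip-based identities the converters rely on: hardsigmoid(x) = relu6(x + 3) / 6 and hardswish(x) = x * hardsigmoid(x). A quick numeric check of those identities against PyTorch, illustrative only and not part of test_forward.py:

# Illustrative check only (not part of the test suite).
import numpy as np
import torch

x = torch.randn(4, 8)
hsig_ref = torch.nn.functional.hardsigmoid(x)
hsig_clip = torch.clamp(x + 3.0, 0.0, 6.0) / 6.0   # relu6(x + 3) / 6
hswish_ref = torch.nn.functional.hardswish(x)
hswish_clip = x * hsig_clip                         # x * hardsigmoid(x)

np.testing.assert_allclose(hsig_ref.numpy(), hsig_clip.numpy(), rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(hswish_ref.numpy(), hswish_clip.numpy(), rtol=1e-6, atol=1e-6)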