[Torch, QNN] Support quantized mobilenet v3 from torch 1.8 (apache#7606)
* [Torch] support hardsigmoid

* qhswish first impl

* add qhardsigmoid but the result is not correct

* add qmv3 to test

* comment fix
masahi authored and Trevor Morris committed May 6, 2021
1 parent 125e101 commit 2ac41f4
Showing 5 changed files with 96 additions and 30 deletions.
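
For orientation, the two activations this commit teaches the PyTorch frontend to convert are hardsigmoid(x) = relu6(x + 3) / 6 and hardswish(x) = x * hardsigmoid(x). A minimal NumPy sketch of those reference definitions (not part of the commit, only an illustration of what the new converters compute):

import numpy as np

def hard_sigmoid(x):
    # relu6(x + 3) / 6 -- the formula the new Relay converter emits
    return np.clip(x + 3.0, 0.0, 6.0) / 6.0

def hard_swish(x):
    # x * hardsigmoid(x); the converter below reuses hard_sigmoid for this
    return x * hard_sigmoid(x)

x = np.linspace(-4.0, 4.0, 9, dtype="float32")
print(hard_sigmoid(x))
print(hard_swish(x))
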
36 changes: 30 additions & 6 deletions python/tvm/relay/frontend/pytorch.py
@@ -34,6 +34,7 @@
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from .. import qnn
from ..ty import TupleType, TensorType, Any
from ..loops import while_loop
from .. import transform
@@ -805,14 +806,35 @@ def log_sigmoid(self, inputs, input_types):
data = inputs[0]
return _op.log(_op.tensor.sigmoid(data))

def hard_swish(self, inputs, input_types):
data = inputs[0]
dtype = input_types[0]
def hard_sigmoid(self, inputs, input_types):
def _relu6(x):
return _op.tensor.clip(x, 0.0, 6.0)

def _relu6(input_tensor):
return _op.tensor.clip(input_tensor, 0.0, 6.0)
def func(x):
return _relu6(x + _expr.const(3.0)) / _expr.const(6.0)

if self.is_quantized_tensor(inputs[0]):
input_scale = _expr.const(inputs[1])
input_zero_point = _expr.const(inputs[2])
# PyTorch seems to use the following output qparams, but accuracy
# is broken if we use this.
# TODO(masahi): Revisit this parameter choice
#
# Taken from src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
# output_scale = _expr.const(0.00390625) # 1.0 / 2^8
# output_zero_point = _expr.const(-128)
output_scale = input_scale
output_zero_point = input_zero_point

data = qnn.op.dequantize(inputs[0], input_scale, input_zero_point, axis=1)
out = func(data)
return qnn.op.quantize(out, output_scale, output_zero_point, out_dtype="uint8")

return func(inputs[0])

return data * _relu6(data + _expr.const(3.0, dtype=dtype)) / _expr.const(6.0, dtype=dtype)
def hard_swish(self, inputs, input_types):
data = inputs[0]
return data * self.hard_sigmoid(inputs, input_types)

def adaptive_avg_pool_2d(self, inputs, input_types):
data = inputs[0]
@@ -2418,6 +2440,8 @@ def create_convert_map(self):
"aten::__not__": self.logical_not,
"aten::hardswish_": self.hard_swish,
"aten::hardswish": self.hard_swish,
"aten::hardsigmoid_": self.hard_sigmoid,
"aten::hardsigmoid": self.hard_sigmoid,
"aten::cumsum": self.cumsum,
"aten::masked_fill": self.masked_fill,
"aten::masked_select": self.masked_select,
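To see the new aten::hardsigmoid mapping in action, one can trace torch.nn.Hardsigmoid and import it through relay.frontend.from_pytorch, the same way the tests later in this commit drive it via verify_model. A minimal sketch; the input name and shape below are arbitrary choices for this example, not values from the commit:

import torch
from tvm import relay

model = torch.nn.Hardsigmoid().eval()
inp = torch.rand(1, 3, 224, 224)
script_module = torch.jit.trace(model, inp).eval()

# "input0" is just a name picked for this sketch; any name works.
mod, params = relay.frontend.from_pytorch(script_module, [("input0", (1, 3, 224, 224))])
print(mod)  # the Relay body should contain clip(x + 3, 0, 6) / 6
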
34 changes: 34 additions & 0 deletions python/tvm/relay/frontend/qnn_torch.py
@@ -191,6 +191,7 @@ def _get_quant_param_for_input(input_value):
"quantized::cat": (2, 3),
"quantized::mul_scalar": (2, 3),
"quantized::add_scalar": (2, 3),
"quantized::hardswish": (1, 2),
}

def dfs(current_node):
@@ -358,6 +359,8 @@ def add_input_quant_params_to_op_inputs(graph):
"quantized::add_scalar": 1,
"quantized::mul_scalar": 1,
"quantized::relu6": 1,
"quantized::hardswish": 1,
"aten::hardsigmoid": 1,
}

need_input_quant_param = set(num_quantized_inputs.keys())
@@ -765,6 +768,7 @@ def _impl(inputs, _):
out_zp = _expr.const(inputs[3])

if q_min > z - c_q or q_max < z - c_q:
# TODO(masahi): Replace this with integer only compute
dequant = relay.qnn.op.dequantize(inputs[0], _expr.const(s), _expr.const(z))
dequantized_add = _op.tensor.add(dequant, _expr.const(c_q * s))
return relay.qnn.op.quantize(
@@ -820,6 +824,35 @@ def _impl(inputs, _):
return _impl


def _hswish():
# refer to src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
# They fallback to fp32
def _impl(inputs, _):
assert len(inputs) == 5, "Input quant params not found in op inputs"
# TODO(masahi): Replace this with integer only compute.
# We do not have to strictly follow how PyTorch does it.

def relu6(x):
return _op.tensor.clip(x, 0.0, 6.0)

def hardsigmoid(x):
dtype = "float32"
return relu6(x + _expr.const(3.0, dtype=dtype)) / _expr.const(6.0, dtype=dtype)

output_scale = _expr.const(inputs[1])
output_zero_point = _expr.const(inputs[2])
input_scale = _expr.const(inputs[3])
input_zero_point = _expr.const(inputs[4])

dequant = relay.qnn.op.dequantize(inputs[0], input_scale, input_zero_point, axis=1)
dequantized_hswish = dequant * hardsigmoid(dequant)
return relay.qnn.op.quantize(
dequantized_hswish, output_scale, output_zero_point, out_dtype="uint8"
)

return _impl


def _linear_dynamic():
def _calculate_qparam(inp):
# reference ATen/native/quantized/cpu/qlinear_dynamic.cpp
@@ -906,4 +939,5 @@ def _impl(inputs, _):
"quantized::mul_scalar": _mul_scalar(),
"quantized::relu6": _relu6(),
"quantized::linear_dynamic": _linear_dynamic(),
"quantized::hardswish": _hswish(),
}
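
The quantized::hardswish lowering above follows the fp32 fallback pattern: dequantize with the input qparams, evaluate hardswish in float, then requantize with the output qparams carried on the op. A standalone NumPy sketch of that round trip, with made-up scales and zero points (not values from the commit):

import numpy as np

input_scale, input_zp = 0.05, 128      # example qparams for illustration only
output_scale, output_zp = 0.03, 120

q_in = np.array([0, 60, 128, 200, 255], dtype="uint8")

x = input_scale * (q_in.astype("float32") - input_zp)            # dequantize
y = x * np.clip(x + 3.0, 0.0, 6.0) / 6.0                          # hardswish in fp32
q_out = np.clip(np.round(y / output_scale) + output_zp, 0, 255)   # requantize to uint8
print(q_out.astype("uint8"))
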
2 changes: 1 addition & 1 deletion src/relay/transforms/fold_explicit_padding.cc
@@ -182,7 +182,7 @@ class SimplifyExplicitPadding {
};

/*!
* \brief ImplicitPadding finds explict padding before an op that can
* \brief FoldExplicitPadding finds explict padding before an op that can
* support implicit padding and fuses them.
*/
Expr FoldExplicitPadding(const Expr& expr, const IRModule& mod) {
44 changes: 22 additions & 22 deletions tests/python/frontend/pytorch/qnn_test.py
@@ -41,7 +41,6 @@ def torch_version_check():


def get_tvm_runtime(script_module, input_name, ishape):

input_shapes = [(input_name, ishape)]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)

@@ -125,43 +124,40 @@ def fuse_model(self):

# Mobilenet V3 related modules
class Hsigmoid(nn.Module):
def __init__(self, inplace=True, add_stub=False):
def __init__(self, add_stub=False):
super().__init__()
self.float_op = nn.quantized.FloatFunctional()
self.relu6 = nn.ReLU6(inplace=inplace)
self.quant = QuantStub()
self.dequant = DeQuantStub()
self.add_stub = add_stub
self.hsigmoid = nn.Hardsigmoid()

def forward(self, x):
if self.add_stub:
x = self.quant(x)
relu6 = self.relu6(self.float_op.add_scalar(x, 3.0))
mul = self.float_op.mul_scalar(relu6, 1 / 6.0)
x = self.hsigmoid(x)
if self.add_stub:
mul = self.dequant(mul)
return mul
x = self.dequant(x)
return x

def fuse_model(self):
pass


class Hswish(nn.Module):
def __init__(self, inplace=True, add_stub=False):
super(Hswish, self).__init__()
self.float_op = nn.quantized.FloatFunctional()
self.hsigmoid = Hsigmoid(inplace, add_stub=False)
def __init__(self, add_stub=False):
super().__init__()
self.quant = QuantStub()
self.dequant = DeQuantStub()
self.add_stub = add_stub
self.hswish = nn.Hardswish()

def forward(self, x):
if self.add_stub:
x = self.quant(x)
mul = self.float_op.mul(x, self.hsigmoid(x))
x = self.hswish(x)
if self.add_stub:
mul = self.dequant(mul)
return mul
x = self.dequant(x)
return x

def fuse_model(self):
pass
@@ -274,18 +270,12 @@ def test_quantized_modules():
("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel),
("linear" + postfix, (16, 16), Linear(), per_channel),
("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel),
]

if torch_version_check():
qmodules += [
("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False),
("hswish", imagenet_ishape, Hswish(add_stub=True), False),
("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
("semodule, per_channel", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), True),
("mul_scalar negative", imagenet_ishape, MulScalarNegative(), False),
]
else:
print("Skipping tests that require torch > 1.4")

for (module_name, ishape, raw_module, per_channel) in qmodules:
raw_module.eval()
@@ -372,6 +362,13 @@ def get_imagenet_input():
# ("googlenet", qgooglenet(pretrained=True), per_channel),
]

if is_version_greater_than("1.7.1"):
from torchvision.models.quantization import mobilenet_v3_large as qmobilenet_v3_large

qmodels.append(
("mobilenet_v3_large", qmobilenet_v3_large(pretrained=True, quantize=True).eval(), True)
)

results = []

for (model_name, raw_model, per_channel) in qmodels:
@@ -385,7 +382,10 @@ def get_imagenet_input():
inp = get_imagenet_input()
pt_inp = torch.from_numpy(inp)

quantize_model(raw_model, pt_inp, per_channel=per_channel)
if "mobilenet_v3_large" not in model_name:
# mv3 was qat-ed, quantize=True option above makes it already quantized
quantize_model(raw_model, pt_inp, per_channel=per_channel)

script_module = torch.jit.trace(raw_model, pt_inp).eval()

with torch.no_grad():
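Putting the test changes together, the new quantized mobilenet v3 path amounts to loading the already-quantized torchvision model, tracing it, and importing it into Relay; since quantize=True returns a QAT-trained network, the usual calibration step is skipped. A condensed sketch of that flow (assumes a torch/torchvision build recent enough to ship the quantized mobilenet v3, as the version gate above requires; the input name is arbitrary):

import torch
from torchvision.models.quantization import mobilenet_v3_large
from tvm import relay

model = mobilenet_v3_large(pretrained=True, quantize=True).eval()
inp = torch.rand(1, 3, 224, 224)
script_module = torch.jit.trace(model, inp).eval()

# "image" is just a placeholder input name for this sketch.
mod, params = relay.frontend.from_pytorch(script_module, [("image", (1, 3, 224, 224))])
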
10 changes: 9 additions & 1 deletion tests/python/frontend/pytorch/test_forward.py
@@ -3651,6 +3651,13 @@ def test_hard_swish():
verify_model(torch.nn.Hardswish(inplace=True).eval(), input_data=input)


def test_hard_sigmoid():
examples = [torch.rand(8).float(), torch.rand(8, 10).float(), torch.rand(1, 1, 10).float()]
for input in examples:
verify_model(torch.nn.Hardsigmoid().eval(), input_data=input)
verify_model(torch.nn.Hardsigmoid(inplace=True).eval(), input_data=input)


def test_cumsum():
def test_fn(dim, dtype=None):
return lambda x: torch.cumsum(x, dim=dim, dtype=dtype)
@@ -3893,6 +3900,8 @@ def test_fn(is_sorted, return_inverse, return_counts):
test_logical_and()
test_masked_select()
test_unique()
test_hard_swish()
test_hard_sigmoid()

# Model tests
test_resnet18()
@@ -3931,4 +3940,3 @@

# Test convert torch script(jit) with specific inputs' types
test_convert_torch_script_with_input_types()
test_hard_swish()
