[Torch, QNN] Support quantized mobilenet v3 from torch 1.8 #7606

Merged
merged 5 commits into from Mar 8, 2021
36 changes: 30 additions & 6 deletions python/tvm/relay/frontend/pytorch.py
@@ -34,6 +34,7 @@
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from .. import qnn
from ..ty import TupleType, TensorType, Any
from ..loops import while_loop
from .. import transform
@@ -805,14 +806,35 @@ def log_sigmoid(self, inputs, input_types):
data = inputs[0]
return _op.log(_op.tensor.sigmoid(data))

def hard_swish(self, inputs, input_types):
data = inputs[0]
dtype = input_types[0]
def hard_sigmoid(self, inputs, input_types):
def _relu6(x):
return _op.tensor.clip(x, 0.0, 6.0)

def _relu6(input_tensor):
return _op.tensor.clip(input_tensor, 0.0, 6.0)
def func(x):
return _relu6(x + _expr.const(3.0)) / _expr.const(6.0)

if self.is_quantized_tensor(inputs[0]):
input_scale = _expr.const(inputs[1])
input_zero_point = _expr.const(inputs[2])
# PyTorch seems to use the following output qparams, but accuracy
# is broken if we use this.
# TODO(masahi): Revisit this parameter choice
#
# Taken from src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
# output_scale = _expr.const(0.00390625) # 1.0 / 2^8
# output_zero_point = _expr.const(-128)
output_scale = input_scale
output_zero_point = input_zero_point

data = qnn.op.dequantize(inputs[0], input_scale, input_zero_point, axis=1)
out = func(data)
return qnn.op.quantize(out, output_scale, output_zero_point, out_dtype="uint8")

return func(inputs[0])

return data * _relu6(data + _expr.const(3.0, dtype=dtype)) / _expr.const(6.0, dtype=dtype)
def hard_swish(self, inputs, input_types):
data = inputs[0]
return data * self.hard_sigmoid(inputs, input_types)

def adaptive_avg_pool_2d(self, inputs, input_types):
data = inputs[0]
@@ -2418,6 +2440,8 @@ def create_convert_map(self):
"aten::__not__": self.logical_not,
"aten::hardswish_": self.hard_swish,
"aten::hardswish": self.hard_swish,
"aten::hardsigmoid_": self.hard_sigmoid,
"aten::hardsigmoid": self.hard_sigmoid,
"aten::cumsum": self.cumsum,
"aten::masked_fill": self.masked_fill,
"aten::masked_select": self.masked_select,
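For reference, the decomposition the new hard_sigmoid / hard_swish converters rely on can be checked directly against PyTorch's own ops. The snippet below is a minimal sketch (plain PyTorch, independent of TVM) verifying that hardsigmoid(x) = relu6(x + 3) / 6 and hardswish(x) = x * hardsigmoid(x); the tensor shape is an arbitrary choice for illustration.

import torch

x = torch.randn(4, 8)
# PyTorch reference implementations
ref_hsig = torch.nn.Hardsigmoid()(x)
ref_hswish = torch.nn.Hardswish()(x)
# The decomposition used by hard_sigmoid / hard_swish above
hsig = torch.clamp(x + 3.0, min=0.0, max=6.0) / 6.0
assert torch.allclose(ref_hsig, hsig)
assert torch.allclose(ref_hswish, x * hsig)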
34 changes: 34 additions & 0 deletions python/tvm/relay/frontend/qnn_torch.py
@@ -191,6 +191,7 @@ def _get_quant_param_for_input(input_value):
"quantized::cat": (2, 3),
"quantized::mul_scalar": (2, 3),
"quantized::add_scalar": (2, 3),
"quantized::hardswish": (1, 2),
}

def dfs(current_node):
@@ -358,6 +359,8 @@ def add_input_quant_params_to_op_inputs(graph):
"quantized::add_scalar": 1,
"quantized::mul_scalar": 1,
"quantized::relu6": 1,
"quantized::hardswish": 1,
"aten::hardsigmoid": 1,
}

need_input_quant_param = set(num_quantized_inputs.keys())
@@ -765,6 +768,7 @@ def _impl(inputs, _):
out_zp = _expr.const(inputs[3])

if q_min > z - c_q or q_max < z - c_q:
# TODO(masahi): Replace this with integer only compute
dequant = relay.qnn.op.dequantize(inputs[0], _expr.const(s), _expr.const(z))
dequantized_add = _op.tensor.add(dequant, _expr.const(c_q * s))
return relay.qnn.op.quantize(
@@ -820,6 +824,35 @@ def _impl(inputs, _):
return _impl


def _hswish():
# refer to src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
# They fall back to fp32
def _impl(inputs, _):
assert len(inputs) == 5, "Input quant params not found in op inputs"
# TODO(masahi): Replace this with integer only compute.
# We do not have to strictly follow how PyTorch does it.

def relu6(x):
return _op.tensor.clip(x, 0.0, 6.0)

def hardsigmoid(x):
dtype = "float32"
return relu6(x + _expr.const(3.0, dtype=dtype)) / _expr.const(6.0, dtype=dtype)

output_scale = _expr.const(inputs[1])
output_zero_point = _expr.const(inputs[2])
input_scale = _expr.const(inputs[3])
input_zero_point = _expr.const(inputs[4])

dequant = relay.qnn.op.dequantize(inputs[0], input_scale, input_zero_point, axis=1)
dequantized_hswish = dequant * hardsigmoid(dequant)
return relay.qnn.op.quantize(
dequantized_hswish, output_scale, output_zero_point, out_dtype="uint8"
)

return _impl


def _linear_dynamic():
def _calculate_qparam(inp):
# reference ATen/native/quantized/cpu/qlinear_dynamic.cpp
@@ -906,4 +939,5 @@ def _impl(inputs, _):
"quantized::mul_scalar": _mul_scalar(),
"quantized::relu6": _relu6(),
"quantized::linear_dynamic": _linear_dynamic(),
"quantized::hardswish": _hswish(),
}
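The quantized::hardswish converter above follows the same dequantize, compute in fp32, requantize pattern used elsewhere in this file. The snippet below is a hedged, self-contained illustration of that pattern as a standalone Relay expression; the input shape and quantization parameters are made-up values for demonstration, not taken from a real model.

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 3, 8, 8), dtype="uint8")
in_scale = relay.const(0.05, "float32")
in_zp = relay.const(128, "int32")

# dequantize -> fp32 hardswish -> quantize back to uint8
deq = relay.qnn.op.dequantize(x, in_scale, in_zp, axis=1)
hsig = relay.clip(deq + relay.const(3.0), 0.0, 6.0) / relay.const(6.0)
out = relay.qnn.op.quantize(deq * hsig, in_scale, in_zp, out_dtype="uint8")

mod = tvm.IRModule.from_expr(relay.Function([x], out))
print(mod)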
2 changes: 1 addition & 1 deletion src/relay/transforms/fold_explicit_padding.cc
@@ -182,7 +182,7 @@ class SimplifyExplicitPadding {
};

/*!
* \brief ImplicitPadding finds explicit padding before an op that can
* \brief FoldExplicitPadding finds explicit padding before an op that can
* support implicit padding and fuses them.
*/
Expr FoldExplicitPadding(const Expr& expr, const IRModule& mod) {
44 changes: 22 additions & 22 deletions tests/python/frontend/pytorch/qnn_test.py
@@ -41,7 +41,6 @@ def torch_version_check():


def get_tvm_runtime(script_module, input_name, ishape):

input_shapes = [(input_name, ishape)]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)

@@ -125,43 +124,40 @@ def fuse_model(self):

# Mobilenet V3 related modules
class Hsigmoid(nn.Module):
def __init__(self, inplace=True, add_stub=False):
def __init__(self, add_stub=False):
super().__init__()
self.float_op = nn.quantized.FloatFunctional()
self.relu6 = nn.ReLU6(inplace=inplace)
self.quant = QuantStub()
self.dequant = DeQuantStub()
self.add_stub = add_stub
self.hsigmoid = nn.Hardsigmoid()

def forward(self, x):
if self.add_stub:
x = self.quant(x)
relu6 = self.relu6(self.float_op.add_scalar(x, 3.0))
mul = self.float_op.mul_scalar(relu6, 1 / 6.0)
x = self.hsigmoid(x)
if self.add_stub:
mul = self.dequant(mul)
return mul
x = self.dequant(x)
return x

def fuse_model(self):
pass


class Hswish(nn.Module):
def __init__(self, inplace=True, add_stub=False):
super(Hswish, self).__init__()
self.float_op = nn.quantized.FloatFunctional()
self.hsigmoid = Hsigmoid(inplace, add_stub=False)
def __init__(self, add_stub=False):
super().__init__()
self.quant = QuantStub()
self.dequant = DeQuantStub()
self.add_stub = add_stub
self.hswish = nn.Hardswish()

def forward(self, x):
if self.add_stub:
x = self.quant(x)
mul = self.float_op.mul(x, self.hsigmoid(x))
x = self.hswish(x)
if self.add_stub:
mul = self.dequant(mul)
return mul
x = self.dequant(x)
return x

def fuse_model(self):
pass
@@ -274,18 +270,12 @@ def test_quantized_modules():
("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel),
("linear" + postfix, (16, 16), Linear(), per_channel),
("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel),
]

if torch_version_check():
qmodules += [
("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False),
("hswish", imagenet_ishape, Hswish(add_stub=True), False),
("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
("semodule, per_channel", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), True),
("mul_scalar negative", imagenet_ishape, MulScalarNegative(), False),
]
else:
print("Skipping tests that require torch > 1.4")

for (module_name, ishape, raw_module, per_channel) in qmodules:
raw_module.eval()
@@ -372,6 +362,13 @@ def get_imagenet_input():
# ("googlenet", qgooglenet(pretrained=True), per_channel),
]

if is_version_greater_than("1.7.1"):
from torchvision.models.quantization import mobilenet_v3_large as qmobilenet_v3_large

qmodels.append(
("mobilenet_v3_large", qmobilenet_v3_large(pretrained=True, quantize=True).eval(), True)
)

results = []

for (model_name, raw_model, per_channel) in qmodels:
@@ -385,7 +382,10 @@ def get_imagenet_input():
inp = get_imagenet_input()
pt_inp = torch.from_numpy(inp)

quantize_model(raw_model, pt_inp, per_channel=per_channel)
if "mobilenet_v3_large" not in model_name:
# mv3 is a QAT model; the quantize=True option above already returns it quantized
quantize_model(raw_model, pt_inp, per_channel=per_channel)

script_module = torch.jit.trace(raw_model, pt_inp).eval()

with torch.no_grad():
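The new mobilenet v3 entry exercises an end-to-end import of the pre-quantized torchvision model. A minimal sketch of that flow is shown below, assuming torch 1.8 with a matching torchvision; the input name and shape are assumptions for illustration.

import torch
from torchvision.models.quantization import mobilenet_v3_large
from tvm import relay

model = mobilenet_v3_large(pretrained=True, quantize=True).eval()
inp = torch.rand(1, 3, 224, 224)
script_module = torch.jit.trace(model, inp).eval()
# Convert the traced, already-quantized model to Relay
mod, params = relay.frontend.from_pytorch(script_module, [("input", (1, 3, 224, 224))])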
10 changes: 9 additions & 1 deletion tests/python/frontend/pytorch/test_forward.py
@@ -3651,6 +3651,13 @@ def test_hard_swish():
verify_model(torch.nn.Hardswish(inplace=True).eval(), input_data=input)


def test_hard_sigmoid():
examples = [torch.rand(8).float(), torch.rand(8, 10).float(), torch.rand(1, 1, 10).float()]
for input in examples:
verify_model(torch.nn.Hardsigmoid().eval(), input_data=input)
verify_model(torch.nn.Hardsigmoid(inplace=True).eval(), input_data=input)


def test_cumsum():
def test_fn(dim, dtype=None):
return lambda x: torch.cumsum(x, dim=dim, dtype=dtype)
@@ -3893,6 +3900,8 @@ def test_fn(is_sorted, return_inverse, return_counts):
test_logical_and()
test_masked_select()
test_unique()
test_hard_swish()
test_hard_sigmoid()

# Model tests
test_resnet18()
@@ -3931,4 +3940,3 @@ def test_fn(is_sorted, return_inverse, return_counts):

# Test convert torch script(jit) with specific inputs' types
test_convert_torch_script_with_input_types()
test_hard_swish()