From f665aac3ba8510661c215709d74b0b221e7ee3ad Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 16 Aug 2019 15:15:22 -0700 Subject: [PATCH 1/4] Init --- python/tvm/relay/quantize/quantize.py | 14 ++- tutorials/quantize_model.py | 134 ++++++++++++++++++++++++++ 2 files changed, 140 insertions(+), 8 deletions(-) create mode 100644 tutorials/quantize_model.py diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index adde2058267a..2ab8b79484ff 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -387,7 +387,7 @@ def _bind_params(func, params): return _expr.bind(func, bind_dict) -def prerequisite_optimize(graph, params=None): +def prerequisite_optimize(mod, params=None): """ Prerequisite optimization passes for quantization. Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and "CanonicalizeOps" optimization before quantization. """ @@ -398,15 +398,14 @@ def prerequisite_optimize(graph, params=None): _transform.FoldConstant()]) if params: - graph = _bind_params(graph, params) + mod['main'] = _bind_params(mod['main'], params) - mod = _module.Module.from_expr(graph) with _transform.PassContext(opt_level=3): mod = optimize(mod) - return mod["main"] + return mod -def quantize(graph, params=None, dataset=None): +def quantize(mod, params=None, dataset=None): """ The quantization procedure. 
Before running the three main procedure of quantization, "annotate", "calibrate" and "realize" , we need to do "SimplifyInference", "FoldScaleAxis", "FoldConstant" @@ -429,9 +428,8 @@ def quantize(graph, params=None, dataset=None): ret: Function The graph after quantization """ - graph = prerequisite_optimize(graph, params) + mod = prerequisite_optimize(mod, params) - mod = _module.Module.from_expr(graph) calibrate_pass = _transform.function_pass(calibrate, opt_level=1, name="QuantizeCalibrate") quant_passes = [partition(), @@ -448,4 +446,4 @@ def quantize(graph, params=None, dataset=None): with quantize_context(): mod = quantize_seq(mod) - return mod["main"] + return mod diff --git a/tutorials/quantize_model.py b/tutorials/quantize_model.py new file mode 100644 index 000000000000..e9a89e0a22c0 --- /dev/null +++ b/tutorials/quantize_model.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +.. _tutorial-quantize-model: + +Speed Up Inference and Compress Model with Quantization +=========================================== +**Author**: `Ziheng Jiang `_ + +This is an example to speed up and compress +a ResNet model with quantization. 
+""" + +import tvm +import tvm.relay as relay +from tvm import rpc +from tvm.contrib import util, graph_runtime as runtime +from tvm.contrib.download import download_testdata + +from mxnet.gluon.model_zoo.vision import get_model +from PIL import Image +import numpy as np +# get model + +# one line to get the model +block = get_model('resnet18_v1', pretrained=True) + +img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true' +img_name = 'cat.png' +img_path = download_testdata(img_url, img_name, module='data') +image = Image.open(img_path).resize((224, 224)) + +def transform_image(image): + image = np.array(image) - np.array([123., 117., 104.]) + image /= np.array([58.395, 57.12, 57.375]) + image = image.transpose((2, 0, 1)) + image = image[np.newaxis, :] + return image + +x = transform_image(image) + + +shape_dict = {'data': x.shape} +mod, params = relay.frontend.from_mxnet(block, shape_dict) + + +local_demo = True + +target = tvm.target.create('llvm') + +with relay.build_config(opt_level=3): + graph, lib, params = relay.build(mod, target, params=params) + +def evaluate_inference_speed(graph, lib, params): + tmp = util.tempdir() + lib_fname = tmp.relpath('net.tar') + lib.export_library(lib_fname) + + if local_demo: + remote = rpc.LocalSession() + else: + # The following is my environment, change this to the IP address of your target device + host = '10.77.1.162' + port = 9090 + remote = rpc.connect(host, port) + + # upload the library to remote device and load it + remote.upload(lib_fname) + rlib = remote.load_module('net.tar') + + # create the remote runtime module + ctx = remote.cpu(0) + module = runtime.create(graph, rlib, ctx) + # set parameter (upload params to the remote device. 
This may take a while) + module.set_input(**params) + # set input data + module.set_input('data', tvm.nd.array(x.astype('float32'))) + # run + + ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10) + prof_res = np.array(ftimer().results) * 1000 # convert to millisecond + print("Mean inference time (std dev): %.2f ms (%.2f ms)" % + (np.mean(prof_res), np.std(prof_res))) + + +evaluate_inference_speed(graph, lib, params) + +import tvm.relay.quantize as qtz + +qconfig_i8_i32 = qtz.qconfig(skip_conv_layers=[0], + nbit_input=8, + nbit_weight=8, + global_scale=4.0, + dtype_input="int8", + dtype_weight="int8", + dtype_activation="int32", + do_simulation=False) + +# explain configures + + +with qconfig_i8_i32: + mod = qtz.quantize(mod, params) + + +# compare origin size and quantized size + +# compare origin speed and quantized speed + +def profile_speed_and_size(): + pass + +# compare origin speed and i16 speed + +qconfig_i8_i16 + + +# How do we get those model + +# search configure on Machine From 5427c86ba44ff3abbec53b5f3f1ed159fb665a69 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 23 Aug 2019 15:04:16 -0700 Subject: [PATCH 2/4] Update --- python/tvm/relay/op/_tensor.py | 1 + python/tvm/relay/op/annotation/annotation.py | 16 + python/tvm/relay/quantize/__init__.py | 2 +- python/tvm/relay/quantize/_annotate.py | 213 +++++++--- python/tvm/relay/quantize/_quantized_ops.py | 104 +++++ python/tvm/relay/quantize/calibrate.py | 297 ++++++++++++++ python/tvm/relay/quantize/kl_divergence.py | 124 ------ python/tvm/relay/quantize/quantize.py | 138 ++----- src/relay/op/annotation/annotation.cc | 7 + src/relay/pass/forward_rewrite.cc | 4 +- src/relay/pass/quantize/annotate.cc | 16 +- src/relay/pass/quantize/calibrate.cc | 8 +- src/relay/pass/quantize/partition.cc | 2 + src/relay/pass/quantize/quantize.cc | 3 +- src/relay/pass/quantize/quantize.h | 7 +- src/relay/pass/quantize/realize.cc | 396 ++++++++++--------- tutorials/quantize_model.py | 134 ------- 17 
files changed, 827 insertions(+), 645 deletions(-) create mode 100644 python/tvm/relay/quantize/_quantized_ops.py create mode 100644 python/tvm/relay/quantize/calibrate.py delete mode 100644 python/tvm/relay/quantize/kl_divergence.py delete mode 100644 tutorials/quantize_model.py diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 176def347042..c5a6c413ca70 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -24,6 +24,7 @@ schedule_broadcast = schedule_injective schedule_elemwise = schedule_injective +register_schedule("annotation.cast_hint", schedule_broadcast) register_schedule("log", schedule_broadcast) register_schedule("log1p", schedule_broadcast) register_schedule("cos", schedule_broadcast) diff --git a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py index 10c898538596..4d0f9c453acb 100644 --- a/python/tvm/relay/op/annotation/annotation.py +++ b/python/tvm/relay/op/annotation/annotation.py @@ -61,3 +61,19 @@ def stop_fusion(data): The annotated expression. """ return _make.stop_fusion(data) + + +def cast_hint(data, dtype): + """Annotate an expression to prevent it being fused with previous expressions. + + Parameters + ---------- + data : tvm.relay.Expr + The expression to be annotated. + + Returns + ------- + result : tvm.relay.Expr + The annotated expression. + """ + return _make.cast_hint(data, dtype) diff --git a/python/tvm/relay/quantize/__init__.py b/python/tvm/relay/quantize/__init__.py index 29b68950fa42..aaacd1d918f5 100644 --- a/python/tvm/relay/quantize/__init__.py +++ b/python/tvm/relay/quantize/__init__.py @@ -19,6 +19,6 @@ from __future__ import absolute_import as _abs from .quantize import * +from . 
import _quantized_ops from ._partition import register_partition_function from ._annotate import register_annotate_function -from .kl_divergence import kl_divergence_scale diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 55f3597881e7..f08922469df9 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -19,46 +19,18 @@ from __future__ import absolute_import import warnings -import topi from ..._ffi.function import register_func from .. import expr as _expr from .. import analysis as _analysis from .. import op as _op from ..op import op as _reg from ..base import register_relay_node +from ..op.annotation import cast_hint from . import _quantize from .quantize import QAnnotateKind, current_qconfig, quantize_context from .quantize import _forward_op -@_reg.register_compute("relay.op.annotation.simulated_quantize") -def simulated_quantize_compute(attrs, inputs, out_type, target): - """Compiler for simulated_quantize.""" - assert len(inputs) == 4 - assert attrs.sign - assert attrs.rounding == "round" - - data, scale, clip_min, clip_max = inputs - - if attrs.kind == QAnnotateKind.IDENTITY: - return [topi.identity(data)] - - # simulate rounding error - scaled_data = topi.divide(data, scale) - clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min) - round_data = topi.round(clipped_data) - - # recover data - rdata = topi.multiply(round_data, scale) - return [rdata] - - -_reg.register_schedule("relay.op.annotation.simulated_quantize", - _reg.schedule_injective) -_reg.register_pattern("relay.op.annotation.simulated_quantize", - _reg.OpPattern.ELEMWISE) - - @register_relay_node class QAnnotateExpr(_expr.TempExpr): """A special kind of Expr for Annotating. 
@@ -146,6 +118,23 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"): register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize) +@register_annotate_function("annotation.cast_hint") +def cast_hint_rewrite(ref_call, new_args, ctx): + """Rewrite function to force cast""" + expr, x_kind = _get_expr_kind(new_args[0]) + + if quantize_context().check_to_skip(ref_call): + return expr + + if x_kind is None: + return new_args[0] + if x_kind == QAnnotateKind.ACTIVATION: + expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT) + + expr = _forward_op(ref_call, [expr]) + return QAnnotateExpr(expr, QAnnotateKind.INPUT) + + @register_annotate_function("nn.contrib_conv2d_NCHWc") def conv2d_nchwc_rewrite(ref_call, new_args, ctx): warnings.warn("NCHWc layout Conv2D detected, please use a lower " @@ -155,20 +144,24 @@ def conv2d_nchwc_rewrite(ref_call, new_args, ctx): @register_annotate_function("nn.conv2d") def conv2d_rewrite(ref_call, new_args, ctx): - """Rewrite function for conv2d. Lhs of conv will be quantized to - input field, and rhs of conv will be quantized to weight field. - Output would be in activation field""" + """Rewrite function for conv2d. 
Allowed combination: + - lhs[nbit_input, dtype_input], rhs[nbit_weight, dtype_weight] -> + out[x, dtype_activation] + """ if quantize_context().check_to_skip(ref_call): return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) + qcfg = current_qconfig() if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION: lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) + lhs_expr = cast_hint(lhs_expr, qcfg.dtype_input) assert rhs_kind is None rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) + rhs_expr = cast_hint(rhs_expr, qcfg.dtype_weight) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) @@ -186,11 +179,14 @@ def dense_rewrite(ref_call, new_args, ctx): lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) + qcfg = current_qconfig() if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION: lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) + lhs_expr = cast_hint(lhs_expr, qcfg.dtype_input) assert rhs_kind is None rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) + rhs_expr = cast_hint(rhs_expr, qcfg.dtype_weight) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) @@ -199,7 +195,48 @@ def dense_rewrite(ref_call, new_args, ctx): @register_annotate_function("multiply") def multiply_rewrite(ref_call, new_args, ctx): - """Rewrite function for multiply.""" + """Rewrite function for multiply. 
+ Allowed combination: + - lhs[nbit_input, dtype_activation] * rhs[nbit_weight/nbit_input, dtype_activation] + -> out[x, dtype_activation] + """ + if quantize_context().check_to_skip(ref_call): + return None + + lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) + rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) + + if lhs_kind is None and rhs_kind is None: + return None + + qcfg = current_qconfig() + # for now, only support multiply bias transformed by batch_norm + assert rhs_kind is None + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) + rhs_expr = cast_hint(rhs_expr, qcfg.dtype_activation) + + # print('multiply lhs: {0}'.format(lhs_kind)) + # print('multiply lhs: \n{0}'.format(lhs_expr)) + + if lhs_kind is QAnnotateKind.ACTIVATION: + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) + lhs_expr = cast_hint(lhs_expr, qcfg.dtype_activation) + + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + # print('multiply out: \n{0}'.format(expr.astext(show_meta_data=False))) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) + + +@register_annotate_function("add") +def new_add_rewrite(ref_call, new_args, ctx): + """Rewrite function for add. 
Allowed combinations: + - lhs[*, dtype_activation], rhs[nbit_weight, dtype_activation] -> + out[*, dtype_activation] + - lhs[nbit_input, dtype_activation], rhs[nbit_weight, dtype_activation] -> + out[*, dtype_activation] + - lhs[nbit_input, dtype_input], rhs[nbit_input, dtype_input] -> + out[*, dtype_input] + """ if quantize_context().check_to_skip(ref_call): return None @@ -207,23 +244,84 @@ def multiply_rewrite(ref_call, new_args, ctx): rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) if lhs_kind is None and rhs_kind is None: + # trivial case return None - if lhs_kind in [QAnnotateKind.ACTIVATION, QAnnotateKind.INPUT] and rhs_kind is None: - # quantize lhs to INPUT field - if lhs_kind == QAnnotateKind.ACTIVATION: - lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) - # quantize rhs to WEIGHT field - rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) + qcfg = current_qconfig() + + # unify lhs and rhs to the same dom_scale + dom_scale = _expr.var("dom_scale") + clip_min = _expr.var("clip_min") + clip_max = _expr.var("clip_max") + lhs_expr = _quantize.simulated_quantize(lhs_expr, + dom_scale, clip_min, clip_max, QAnnotateKind.INPUT, True, 'round') + rhs_expr = _quantize.simulated_quantize(rhs_expr, + dom_scale, clip_min, clip_max, QAnnotateKind.INPUT, True, 'round') + + if lhs_kind is QAnnotateKind.ACTIVATION and rhs_kind is None: + # introduced by bias_add from batch_norm (resnet18_v1) + rhs_expr = cast_hint(rhs_expr, qcfg.dtype_activation) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) + if lhs_kind is None and rhs_kind is QAnnotateKind.INPUT: + # introduced by residual addition, lhs is a skipped layer (resnet18_v1) + lhs_expr = cast_hint(lhs_expr, qcfg.dtype_input) + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.INPUT) + + if lhs_kind is QAnnotateKind.INPUT and rhs_kind is None: + # introduced by residual addition, rhs is a 
skipped layer (resnet18_v2) + rhs_expr = cast_hint(rhs_expr, qcfg.dtype_input) + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.INPUT) + + if lhs_kind is QAnnotateKind.INPUT and rhs_kind is QAnnotateKind.INPUT: + # introduced by residual addition (resnet18_v1) + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.INPUT) + + print('lhs: {0}'.format(lhs_expr)) + print('rhs: {0}'.format(rhs_expr)) + print('lhs: {0}'.format(lhs_kind)) + print('rhs: {0}'.format(rhs_kind)) + raise ValueError +# dom_scale = scale / 2^(valid_bit) + +# simulation +# lhs = sq(lhs, dom_scale, clip_min, clip_max) +# lhs = cast_hint(lhs, dtype) +# rhs = sq(rhs, dom_scale, clip_min, clip_max) +# rhs = cast_hint(rhs, dtype) +# out = lhs + rhs + +# realization +# lhs = lhs * ldom_scale / odom_scale +# lshift(lhs, log2(ldom_scale / odom_scale)) +# overflow risk +# lhs = lhs * dom_lscale / odom_scale +# rhs = adjust(rhs, dom_rscale, oscale, nbit) + + +# quantized_add(lhs, rhs, odom_scale, clip_min, clip_max) +# during simulation +# out = lhs + rhs +# scaled_out = out / odom_scale +# truncate(scaled_out, clip_min, clip_max) + +# during realization +# lhs = lhs * ldom_scale / odom_scale +# lshift(lhs, log2(ldom_scale / odom_scale)) +# lhs = lhs * dom_lscale / odom_scale +# rhs = adjust(rhs, dom_rscale, oscale, nbit) + +# dom_scale + + -@register_annotate_function("add") def add_rewrite(ref_call, new_args, ctx): - """Rewrite function for add.""" if quantize_context().check_to_skip(ref_call): return None @@ -243,12 +341,15 @@ def add_rewrite(ref_call, new_args, ctx): if lhs_kind is not None and rhs_kind is None: if _analysis.check_constant(rhs_expr): - # - introduced by batch_norm: add(out, const) + # introduced by batch_norm: add(out, const) rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) else: + # happens in residual addition when the rhs is a skipped layer rhs_expr = 
attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) - return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) + # print('add lhs_kind: {0}'.format(lhs_kind)) + # print('add out:\n{0}'.format(expr.astext(show_meta_data=False))) + return QAnnotateExpr(expr, lhs_kind) if lhs_kind is not None and rhs_kind is not None: if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT: @@ -281,11 +382,11 @@ def identity_rewrite(ref_call, new_args, ctx): register_annotate_function("clip", identity_rewrite) register_annotate_function("nn.relu", identity_rewrite) register_annotate_function("strided_slice", identity_rewrite) -register_annotate_function("nn.avg_pool2d", identity_rewrite) register_annotate_function("annotation.stop_fusion", identity_rewrite) -def pool2d_rewrite(ref_call, new_args, ctx): +@register_annotate_function("nn.max_pool2d") +def max_pool2d_rewrite(ref_call, new_args, ctx): """Rewrite function for max pool2d""" if quantize_context().check_to_skip(ref_call): return None @@ -296,29 +397,27 @@ def pool2d_rewrite(ref_call, new_args, ctx): return None if x_kind == QAnnotateKind.ACTIVATION: expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT) + expr = cast_hint(expr, current_qconfig().dtype_input) expr = _forward_op(ref_call, [expr]) return QAnnotateExpr(expr, QAnnotateKind.INPUT) -register_annotate_function("nn.max_pool2d", pool2d_rewrite) - +@register_annotate_function("nn.avg_pool2d") +def avg_pool2d_rewrite(ref_call, new_args, ctx): + """Rewrite function for avg_pool2d""" + if quantize_context().check_to_skip(ref_call): + return None -@register_annotate_function("annotation.cast_hint") -def cast_hint_rewrite(ref_call, new_args, ctx): - """Rewrite function to force cast""" expr, x_kind = _get_expr_kind(new_args[0]) - if quantize_context().check_to_skip(ref_call): - return expr - if x_kind is None: - return new_args[0] - if x_kind == QAnnotateKind.ACTIVATION: - expr = 
attach_simulated_quantize(expr, QAnnotateKind.INPUT) + return None + if x_kind == QAnnotateKind.INPUT: + expr = cast_hint(expr, current_qconfig().dtype_activation) expr = _forward_op(ref_call, [expr]) - return QAnnotateExpr(expr, QAnnotateKind.INPUT) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) @register_annotate_function("concatenate") diff --git a/python/tvm/relay/quantize/_quantized_ops.py b/python/tvm/relay/quantize/_quantized_ops.py new file mode 100644 index 000000000000..a705d97098c2 --- /dev/null +++ b/python/tvm/relay/quantize/_quantized_ops.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +#pylint: disable=unused-argument +"""Internal module for quantization.""" +from __future__ import absolute_import +import math + +import topi +from .. 
import expr as _expr +from ..op import op as _reg +from .quantize import QAnnotateKind + +@_reg.register_compute("relay.op.annotation.simulated_quantize") +def simulated_quantize_compute(attrs, inputs, out_type, target): + """Compiler for simulated_quantize.""" + assert len(inputs) == 4 + assert attrs.sign + assert attrs.rounding == "round" + + data, scale, clip_min, clip_max = inputs + + if attrs.kind == QAnnotateKind.IDENTITY: + return [topi.identity(data)] + + # simulate rounding error + scaled_data = topi.divide(data, scale) + clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min) + round_data = topi.round(clipped_data) + + # recover data + rdata = topi.multiply(round_data, scale) + return [rdata] + + +_reg.register_schedule("relay.op.annotation.simulated_quantize", + _reg.schedule_injective) +_reg.register_pattern("relay.op.annotation.simulated_quantize", + _reg.OpPattern.ELEMWISE) + + +# dom_scale = scale / valid_range +# qdata * dom_scale = fdata + +def adjust_scale(data, from_scale, to_scale): + if from_scale == to_scale: + return data + + factor = from_scale / to_scale + shift_factor = math.log2(factor) + assert shift_factor > 0 + if isinstance(shift_factor, int): + out = topi.left_shift(data, shift_factor) + elif isinstance(factor, int): + out = topi.mulitply(data, factor) + else: + dtype = data.dtype + out = topi.cast(data, "float32") + out = topi.mulitply(data, factor) + out = topi.cast(out, dtype) + return out + + +def extract_scalar(tensor): + assert isinstance(tensor, _expr.Constant) + arr = tensor.value + assert arr.size == 1 + return arr[0] + + +# @_reg.register_compute("relay.op.quantize.quantized_add") +def quantized_add_compute(attrs, inputs, out_type, target): + """Compiler for simulated_quantize.""" + + assert len(inputs) == 5 + + lhs, rhs, dom_lscale, dom_rscale, dom_oscale = inputs + dom_lscale = extract_scalar(dom_lscale) + dom_rscale = extract_scalar(dom_rscale) + dom_oscale = extract_scalar(dom_oscale) + + lhs = 
adjust_scale(lhs, dom_lscale, dom_oscale) + rhs = adjust_scale(rhs, dom_rscale, dom_oscale) + out = lhs + rhs + return out + + +# _reg.register_schedule("relay.op.quantize.quantized_add", +# _reg.schedule_injective) +# _reg.register_pattern("relay.op.quantize.quantized_add", +# _reg.OpPattern.ELEMWISE) diff --git a/python/tvm/relay/quantize/calibrate.py b/python/tvm/relay/quantize/calibrate.py new file mode 100644 index 000000000000..5596eedab6a8 --- /dev/null +++ b/python/tvm/relay/quantize/calibrate.py @@ -0,0 +1,297 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Find optimal scale for quantization by minimizing KL-divergence""" +from __future__ import absolute_import +import numpy as np +import multiprocessing as mp +import logging +try: + import scipy +except ImportError: + scipy = None + +import tvm + +from . import _quantize +from . import quantize +from .. import op as _op +from .. import expr as _expr +from .. import module as _module +from .. import analysis as _analysis +from .. import transform as _transform +from .. 
import build_module as _build_module +from ...contrib import graph_runtime + + +def _smooth_distribution(p, eps=0.0001): + """Given a discrete distribution (may have not been normalized to 1), + smooth it by replacing zeros with eps multiplied by a scaling factor and taking the + corresponding amount off the non-zero values. + Ref: http://hanj.cs.illinois.edu/cs412/bk3/KL-divergence.pdf + """ + is_zeros = (p == 0).astype(np.float32) + is_nonzeros = (p != 0).astype(np.float32) + n_zeros = is_zeros.sum() + n_nonzeros = p.size - n_zeros + if not n_nonzeros: + raise ValueError('The discrete probability distribution is malformed. All entries are 0.') + eps1 = eps * float(n_zeros) / float(n_nonzeros) + assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1) + hist = p.astype(np.float32) + hist += eps * is_zeros + (-eps1) * is_nonzeros + assert (hist <= 0).sum() == 0 + return hist + + +def _find_scale_by_kl(arr, + quantized_dtype='int8', + num_bins=8001, + num_quantized_bins=255): + """Given a tensor, find the optimal threshold for quantizing it. + The reference distribution is `q`, and the candidate distribution is `p`. + `q` is a truncated version of the original distribution. + + Ref: + http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf + """ + assert isinstance(arr, np.ndarray) + + min_val = np.min(arr) + max_val = np.max(arr) + th = max(abs(min_val), abs(max_val)) + + if min_val >= 0 and quantized_dtype in ['uint8']: + # We need to move negative bins to positive bins to fit uint8 range. 
+ num_quantized_bins = num_quantized_bins * 2 + 1 + + hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th)) + zero_bin_idx = num_bins // 2 + num_half_quantized_bins = num_quantized_bins // 2 + + thresholds = np.zeros(num_bins // 2 + 1 - num_quantized_bins // 2) + divergence = np.zeros_like(thresholds) + quantized_bins = np.zeros(num_quantized_bins, dtype=np.int32) + # i means the number of bins on half axis excluding the zero bin. + for i in range(num_quantized_bins // 2, + num_bins // 2 + 1): + p_bin_idx_start = zero_bin_idx - i + p_bin_idx_stop = zero_bin_idx + i + 1 + thresholds[i - num_half_quantized_bins] = hist_edges[p_bin_idx_stop] + sliced_nd_hist = hist[p_bin_idx_start:p_bin_idx_stop] + + # generate reference distribution p + p = sliced_nd_hist.copy() + assert p.size % 2 == 1 + assert p.size >= num_quantized_bins + # put left outlier count in p[0] + left_outlier_count = np.sum(hist[0:p_bin_idx_start]) + p[0] += left_outlier_count + # put right outlier count in p[-1] + right_outlier_count = np.sum(hist[p_bin_idx_stop:]) + p[-1] += right_outlier_count + # is_nonzeros[k] indicates whether hist[k] is nonzero + is_nonzeros = (p != 0).astype(np.int32) + + # calculate how many bins should be merged to generate quantized distribution q + num_merged_bins = sliced_nd_hist.size // num_quantized_bins + # merge hist into num_quantized_bins bins + for j in range(num_quantized_bins): + start = j * num_merged_bins + stop = start + num_merged_bins + quantized_bins[j] = sliced_nd_hist[start:stop].sum() + quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum() + # expand quantized_bins into p.size bins + q = np.zeros(sliced_nd_hist.size, dtype=np.float32) + for j in range(num_quantized_bins): + start = j * num_merged_bins + if j == num_quantized_bins - 1: + stop = len(is_nonzeros) + else: + stop = start + num_merged_bins + norm = is_nonzeros[start:stop].sum() + if norm != 0: + q[start:stop] = float(quantized_bins[j]) / float(norm) + 
q[p == 0] = 0 + p = _smooth_distribution(p) + # There is a chance that q is an invalid probability distribution. + try: + q = _smooth_distribution(q) + except ValueError: + divergence[i - num_half_quantized_bins] = float("inf") + divergence[i - num_half_quantized_bins] = scipy.stats.entropy(p, q) + + min_divergence_idx = np.argmin(divergence) + opt_th = thresholds[min_divergence_idx] + return opt_th + + +def collect_stats(mod, dataset): + """Given an annotated graph, create a profile graph to collect profile data from the + calibration dataset. This pass collects simulated_quantize op input into a tuple. + Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile + graph. + + Parameters + ---------- + graph: Function + The simulation graph after annotation. + + Returns + ------- + ret: Function + The profile graph which outputs a tuple of profile data. + """ + logging.info("collecting statistics for calibration...") + func = mod['main'] + func = _quantize.CreateStatsCollector(func) + with _transform.build_config(opt_level=3): + graph, lib, params = _build_module.build(func, target="llvm") + outputs = [] + runtime = graph_runtime.create(graph, lib, tvm.cpu()) + runtime.set_input(**params) + + num_outputs = runtime.get_num_outputs() + outputs = [[] for i in range(num_outputs)] + + for batch_id, batch in enumerate(dataset): + runtime.set_input(**batch) + runtime.run() + for i in range(num_outputs): + output = runtime.get_output(i).asnumpy() + outputs[i].append(output) + for i in range(num_outputs): + outputs[i] = np.concatenate(outputs[i]).reshape(-1) + return outputs + + +def _kl_scale(stats): + assert scipy is not None, "scipy need to be installed for \ + utilizing kl calibration during quantization" + with mp.Pool() as pool: + logging.info("finding threshold with kl for calibration...") + scales = list(pool.map(_find_scale_by_kl, stats)) + + def func(sq_call): + scale = scales[func.scale_idx] + func.scale_idx += 1 + return scale 
+ func.scale_idx = 0 + + return func + + +def set_params(mod, input_scale_func, weight_scale_func): + quantize_op = _op.get("relay.op.annotation.simulated_quantize") + cfg = quantize.current_qconfig() + const_params = {} + + def visit_func(expr): + # TODO(ziheng) memorize, e.g. two sq share the same scales + if isinstance(expr, _expr.Call) and expr.op == quantize_op: + sq = expr + _, ndom_scale, nclip_min, nclip_max = sq.args + attrs = sq.attrs + kind = attrs.kind + nbit = cfg.get_nbit_by_kind(kind) + valid_bit = nbit - attrs.sign + + # set scale + if kind == quantize.QAnnotateKind.WEIGHT: + assert isinstance(sq.args[0], _expr.Constant) + scale = weight_scale_func(sq) + else: + scale = input_scale_func(sq) + + def _make_const(val): + return _expr.const(val, 'float32') + + valid_range = 2**valid_bit + const_params[ndom_scale] = _make_const(scale / valid_range) + const_params[nclip_min] = _make_const(- (valid_range - 1)) + const_params[nclip_max] = _make_const((valid_range - 1)) + + func = mod['main'] + _analysis.post_order_visit(func, visit_func) + func = _expr.bind(func, const_params) + return _module.Module.from_expr(func) + + +# weight scale functions +def _power2_scale(sq_call): + """calculate weight scale with nearest mode-2 scale""" + var = sq_call.args[0] + assert isinstance(var, _expr.Constant) + val = np.amax(np.abs(var.data.asnumpy())) + return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0 + +def _max_scale(sq_call): + """calculate weight scale with maximum absolute value""" + var = sq_call.args[0] + assert isinstance(var, _expr.Constant) + val = np.amax(np.abs(var.data.asnumpy())) + return val + + +# input scale functions +def _global_scale(sq_call): + cfg = quantize.current_qconfig() + return cfg.global_scale + + +def calibrate(dataset=None): + """The calibrate procedure will try to calculate the content of + dom_scale, nbit, clip_min, clip_max for every `simulated_quantize` + operator. 
+ + Parameters + --------- + graph: Function + The simulation graph after annotation. + + mod: tvm.relay.Module + The module where calibration happens on. + + ctx: tvm.relay.PassContext + The pass context used for calibration. + + weight_scales: 'power2' or 'max'. + The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT). + power2: Find the maximum of the absolute value of the tensor, and then round up to power + of two. + max: Find the maximum of the absolute value of the tensor. + + scales: List[float] + Pre-calculated scales for input and activations. Length and the order of elements of the + scales list should match the output tuple of the profile graph created by collect_stats. + + Returns + ------- + ret: Function + The graph after calibration + """ + def wrapped_func(mod, ctx): + """make transform.module pass happy""" + cfg = quantize.current_qconfig() + + if cfg.calibrate_mode == 'kl': + stats = collect_stats(mod, dataset) + input_scale_func = _kl_scale(stats) + elif cfg.calibrate_mode == 'global_scale': + input_scale_func = _global_scale + + return set_params(mod, input_scale_func, _power2_scale) + return wrapped_func diff --git a/python/tvm/relay/quantize/kl_divergence.py b/python/tvm/relay/quantize/kl_divergence.py deleted file mode 100644 index bce45dca6f1c..000000000000 --- a/python/tvm/relay/quantize/kl_divergence.py +++ /dev/null @@ -1,124 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Find optimal scale for quantization by minimizing KL-divergence""" - -try: - from scipy import stats -except ImportError: - stats = None - -import numpy as np - - -def _smooth_distribution(p, eps=0.0001): - """Given a discrete distribution (may have not been normalized to 1), - smooth it by replacing zeros with eps multiplied by a scaling factor and taking the - corresponding amount off the non-zero values. - Ref: http://hanj.cs.illinois.edu/cs412/bk3/KL-divergence.pdf - """ - is_zeros = (p == 0).astype(np.float32) - is_nonzeros = (p != 0).astype(np.float32) - n_zeros = is_zeros.sum() - n_nonzeros = p.size - n_zeros - if not n_nonzeros: - raise ValueError('The discrete probability distribution is malformed. All entries are 0.') - eps1 = eps * float(n_zeros) / float(n_nonzeros) - assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1) - hist = p.astype(np.float32) - hist += eps * is_zeros + (-eps1) * is_nonzeros - assert (hist <= 0).sum() == 0 - return hist - - -# pylint: disable=invalid-name -def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255): - """Given a tensor, find the optimal threshold for quantizing it. - The reference distribution is `q`, and the candidate distribution is `p`. - `q` is a truncated version of the original distribution. 
- - Ref: - http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf - """ - assert isinstance(arr, np.ndarray) - - min_val = np.min(arr) - max_val = np.max(arr) - th = max(abs(min_val), abs(max_val)) - - if min_val >= 0 and quantized_dtype in ['uint8']: - # We need to move negative bins to positive bins to fit uint8 range. - num_quantized_bins = num_quantized_bins * 2 + 1 - - hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th)) - zero_bin_idx = num_bins // 2 - num_half_quantized_bins = num_quantized_bins // 2 - - thresholds = np.zeros(num_bins // 2 + 1 - num_quantized_bins // 2) - divergence = np.zeros_like(thresholds) - quantized_bins = np.zeros(num_quantized_bins, dtype=np.int32) - # i means the number of bins on half axis excluding the zero bin. - for i in range(num_quantized_bins // 2, - num_bins // 2 + 1): - p_bin_idx_start = zero_bin_idx - i - p_bin_idx_stop = zero_bin_idx + i + 1 - thresholds[i - num_half_quantized_bins] = hist_edges[p_bin_idx_stop] - sliced_nd_hist = hist[p_bin_idx_start:p_bin_idx_stop] - - # generate reference distribution p - p = sliced_nd_hist.copy() - assert p.size % 2 == 1 - assert p.size >= num_quantized_bins - # put left outlier count in p[0] - left_outlier_count = np.sum(hist[0:p_bin_idx_start]) - p[0] += left_outlier_count - # put right outlier count in p[-1] - right_outlier_count = np.sum(hist[p_bin_idx_stop:]) - p[-1] += right_outlier_count - # is_nonzeros[k] indicates whether hist[k] is nonzero - is_nonzeros = (p != 0).astype(np.int32) - - # calculate how many bins should be merged to generate quantized distribution q - num_merged_bins = sliced_nd_hist.size // num_quantized_bins - # merge hist into num_quantized_bins bins - for j in range(num_quantized_bins): - start = j * num_merged_bins - stop = start + num_merged_bins - quantized_bins[j] = sliced_nd_hist[start:stop].sum() - quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum() - # expand 
quantized_bins into p.size bins - q = np.zeros(sliced_nd_hist.size, dtype=np.float32) - for j in range(num_quantized_bins): - start = j * num_merged_bins - if j == num_quantized_bins - 1: - stop = len(is_nonzeros) - else: - stop = start + num_merged_bins - norm = is_nonzeros[start:stop].sum() - if norm != 0: - q[start:stop] = float(quantized_bins[j]) / float(norm) - q[p == 0] = 0 - p = _smooth_distribution(p) - # There is a chance that q is an invalid probability distribution. - try: - q = _smooth_distribution(q) - except ValueError: - divergence[i - num_half_quantized_bins] = float("inf") - divergence[i - num_half_quantized_bins] = stats.entropy(p, q) - - min_divergence_idx = np.argmin(divergence) - opt_th = thresholds[min_divergence_idx] - return opt_th diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 2ab8b79484ff..cc51e34fdfd9 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -17,17 +17,16 @@ #pylint: disable=unused-argument """Automatic quantization toolkit.""" from __future__ import absolute_import -import numpy as np from . import _quantize +from . import calibrate as _calibrate from .. import expr as _expr -from .. import module as _module -from .. import analysis as _analysis from .. import transform as _transform -from .. import op as _op from ... 
import make as _make from ..base import NodeBase, register_relay_node +# TODO(contributor): remove kind in sq +# TODO(ziheng): refactor the infra to modularized pass class QAnnotateKind(object): """Denote the kind of annotation field, corresponding @@ -78,8 +77,9 @@ class QConfig(NodeBase): "dtype_input": "int8", "dtype_weight": "int8", "dtype_activation": "int32", - "global_scale": 8.0, "skip_conv_layers": [0], + "calibrate_mode": "global_scale", + "global_scale": 8.0, "do_simulation": False, "round_for_shift": True, "debug_enabled_ops": None, @@ -220,6 +220,22 @@ def quantize_context(): return QuantizeContext.Current +# NOTE: +# behavior of cast_hint +# partition part, will insert cast_hint to denote, which will be +# inserted simulated quantize during annotate + +# in some condition we need to defer the cast operation +# will be transformed to real cast in realize +# but work as identity before realize + + +# behavior of quantized_add +# add will be transformed to quantized_add during annotate, +# odom_scale will be tuned by calibrate +# if simulated, only do addition, which is used before realize +# during realize, lhs and rhs's scale will be unified as odom_scale +# then do add def partition(): """Partition graph into small low-precision sections by `cast_hint` and `stop_fusion`. @@ -245,113 +261,6 @@ def annotate(): return _quantize.QuantizeAnnotate() -def collect_stats(graph): - """Given an annotated graph, create a profile graph to collect profile data from the - calibration dataset. This pass collects simulated_quantize op input into a tuple. - Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile - graph. - - Parameters - ---------- - graph: Function - The simulation graph after annotation. - - Returns - ------- - ret: Function - The profile graph which outputs a tuple of profile data.
- """ - return _quantize.CollectStats(graph) - - -def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None): - """The calibrate procedure will try to calculate the content of - dom_scale, nbit, clip_min, clip_max for every `simulated_quantize` - operator. - - Parameters - --------- - graph: Function - The simulation graph after annotation. - - mod: tvm.relay.Module - The module where calibration happens on. - - ctx: tvm.relay.PassContext - The pass context used for calibration. - - weight_scales: 'power2' or 'max'. - The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT). - power2: Find the maximum of the absolute value of the tensor, and then round up to power - of two. - max: Find the maximum of the absolute value of the tensor. - - scales: List[float] - Pre-calculated scales for input and activations. Length and the order of elements of the - scales list should match the output tuple of the profile graph created by collect_stats. - - Returns - ------- - ret: Function - The graph after calibration - """ - def power2_scale(arr): - """calculate weight scale with nearest mode-2 scale""" - val = np.amax(np.abs(arr.asnumpy())) - return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0 - - def max_scale(arr): - """calculate weight scale with maximum absolute value""" - val = np.amax(np.abs(arr.asnumpy())) - return val - - scale_idx = 0 - - cfg = current_qconfig() - const_params = {} - quantize_op = _op.get("relay.op.annotation.simulated_quantize") - - def visit_func(expr): - """Internal visit function""" - nonlocal scale_idx - if isinstance(expr, _expr.Call) and expr.op == quantize_op: - _, ndom_scale, nclip_min, nclip_max = expr.args - attrs = expr.attrs - kind = attrs.kind - nbit = cfg.get_nbit_by_kind(kind) - - valid_bit = nbit - attrs.sign - if kind in [QAnnotateKind.WEIGHT]: - if all([isinstance(arg, _expr.Constant) - for arg in [ndom_scale, nclip_min, nclip_max]]): - return - var = expr.args[0] - assert 
isinstance(var, _expr.Constant) - if weight_scales == 'max': - scale = max_scale(var.data) - elif weight_scales == 'power2': - scale = power2_scale(var.data) - else: - raise ValueError('{} not supported'.format(weight_scales)) - elif scales is not None: - scale = scales[scale_idx] - scale_idx += 1 - else: - scale = cfg.global_scale - - def _make_const(val): - return _expr.const(val, 'float32') - - valid_range = 2**valid_bit - const_params[ndom_scale] = _make_const(scale / valid_range) - const_params[nclip_min] = _make_const(- (valid_range - 1)) - const_params[nclip_max] = _make_const((valid_range - 1)) - - _analysis.post_order_visit(graph, visit_func) - ret = _expr.bind(graph, const_params) - return ret - - def realize(): """The realize pass will transform the simulated quantized graph, which actually computes with float32, to a real low-bit integer graph. It will @@ -430,8 +339,9 @@ def quantize(mod, params=None, dataset=None): """ mod = prerequisite_optimize(mod, params) - calibrate_pass = _transform.function_pass(calibrate, opt_level=1, - name="QuantizeCalibrate") + calibrate_pass = _transform.module_pass(_calibrate.calibrate(dataset), + opt_level=1, + name="QuantizeCalibrate") quant_passes = [partition(), annotate(), calibrate_pass] diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index eeacc6cbf999..bc1e358f8e6f 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -93,11 +93,18 @@ Expr CastHint(Expr data, DataType dtype) { return CallNode::make(op, {data}, Attrs{attrs}, {}); } +TVM_REGISTER_API("relay.op.annotation._make.cast_hint") +.set_body_typed([](Expr data, DataType dtype) { + return CastHint(data, dtype); +}); + + RELAY_REGISTER_OP("annotation.cast_hint") .describe(R"code(Annotate an expression to be cast into specific data type.)code" TVM_ADD_FILELINE) .set_num_inputs(1) .add_argument("data", "Tensor", "The input data.") 
+.set_attrs_type_key("relay.attrs.CastHintAttrs") .add_type_rel("Identity", IdentityRel) .set_support_level(10) .set_attr("TOpPattern", kOpaque) diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc index 6c66d6e982a7..ed84f33d00e2 100644 --- a/src/relay/pass/forward_rewrite.cc +++ b/src/relay/pass/forward_rewrite.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/relay/pass/quantize/annotate.cc b/src/relay/pass/quantize/annotate.cc index d8a7a0f24818..53ce2c53d9c4 100644 --- a/src/relay/pass/quantize/annotate.cc +++ b/src/relay/pass/quantize/annotate.cc @@ -76,23 +76,9 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr") Pass QuantizeAnnotate() { - // TODO(tvm-teams): since partition has added cast_hint in different - // branches, try to remove this in the future. 
- std::function fmulti_ref = [](const Expr& e) { - if (e->derived_from()) { - const auto* n = e.as(); - CHECK(n); - const PackedFunc* f = - runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); - Expr ret = (*f)(n->expr, static_cast(kQInput)); - return static_cast(QAnnotateExprNode::make(ret, kQInput)); - } - return e; - }; - runtime::TypedPackedFunc pass_func = [=](Function f, Module m, PassContext pc) { - auto func = Downcast(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref)); + auto func = Downcast(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, nullptr)); auto new_params = func->params; for (const auto& x : FreeVars(func)) { new_params.push_back(x); diff --git a/src/relay/pass/quantize/calibrate.cc b/src/relay/pass/quantize/calibrate.cc index 30b47ba69a6e..6faa8d9fbb32 100644 --- a/src/relay/pass/quantize/calibrate.cc +++ b/src/relay/pass/quantize/calibrate.cc @@ -66,7 +66,7 @@ class StatsCollector : private ExprMutator { // add non-const expressions to profile data if (attrs->kind != QAnnotateKind::kQWeight) { - CHECK(!quantize_input.as()); + // CHECK(!quantize_input.as()); profile_data_.push_back(identity_quantize); } return identity_quantize; @@ -87,12 +87,12 @@ class StatsCollector : private ExprMutator { * \param expr The simulation graph after annotation. * \return The profile graph. 
*/ -Expr CollectStats(const Expr& expr) { +Expr CreateStatsCollector(const Expr& expr) { return StatsCollector().Collect(expr); } -TVM_REGISTER_API("relay._quantize.CollectStats") -.set_body_typed(CollectStats); +TVM_REGISTER_API("relay._quantize.CreateStatsCollector") +.set_body_typed(CreateStatsCollector); } // namespace quantize } // namespace relay diff --git a/src/relay/pass/quantize/partition.cc b/src/relay/pass/quantize/partition.cc index 3f46cf2f227e..5d56f0c18413 100644 --- a/src/relay/pass/quantize/partition.cc +++ b/src/relay/pass/quantize/partition.cc @@ -79,6 +79,8 @@ Pass QuantizePartition() { [=](Function f, Module m, PassContext pc) { auto ret = Downcast( ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr)); + LOG(INFO) << "after partition"; + LOG(INFO) << AsText(ret, false); return ret; }; return CreateFunctionPass(pass_func, 1, "QuantizePartition", {}); diff --git a/src/relay/pass/quantize/quantize.cc b/src/relay/pass/quantize/quantize.cc index c6d71ba0ed32..88191ce29394 100644 --- a/src/relay/pass/quantize/quantize.cc +++ b/src/relay/pass/quantize/quantize.cc @@ -124,8 +124,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "nbit_input=" << op->nbit_input << ", "; p->stream << "nbit_weight=" << op->nbit_weight << ", "; p->stream << "nbit_activation=" << op->nbit_activation << ", "; - p->stream << "global_scale=" << op->global_scale << ", "; p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; + p->stream << "calibrate_mode=" << op->calibrate_mode << ", "; + p->stream << "global_scale=" << op->global_scale << ", "; p->stream << "do_simulation==" << op->do_simulation << ", "; p->stream << "round_for_shift==" << op->round_for_shift << ", "; p->stream << "debug_enabled_ops==" << op->debug_enabled_ops; diff --git a/src/relay/pass/quantize/quantize.h b/src/relay/pass/quantize/quantize.h index 4c153d522d69..3d86281ff267 100644 --- a/src/relay/pass/quantize/quantize.h +++ b/src/relay/pass/quantize/quantize.h @@ -59,7 +59,6 @@ 
struct SimulatedQuantizeAttrs : public tvm::AttrsNode { } }; - class QConfig; /*! * \brief Container for build configuration options @@ -72,8 +71,9 @@ class QConfigNode : public Node { DataType dtype_input = Int(8); DataType dtype_weight = Int(8); DataType dtype_activation = Int(32); - double global_scale = 8.0; Array skip_conv_layers = Array(NodePtr(nullptr)); + std::string calibrate_mode = "global_scale"; + double global_scale = 8.0; bool do_simulation = false; bool round_for_shift = true; Array debug_enabled_ops = Array(NodePtr(nullptr)); @@ -85,8 +85,9 @@ class QConfigNode : public Node { v->Visit("dtype_input", &dtype_input); v->Visit("dtype_weight", &dtype_weight); v->Visit("dtype_activation", &dtype_activation); - v->Visit("global_scale", &global_scale); v->Visit("skip_conv_layers", &skip_conv_layers); + v->Visit("calibrate_mode", &calibrate_mode); + v->Visit("global_scale", &global_scale); v->Visit("do_simulation", &do_simulation); v->Visit("round_for_shift", &round_for_shift); v->Visit("debug_enabled_ops", &debug_enabled_ops); diff --git a/src/relay/pass/quantize/realize.cc b/src/relay/pass/quantize/realize.cc index 7eae9992c9e4..3bb997e89dd5 100644 --- a/src/relay/pass/quantize/realize.cc +++ b/src/relay/pass/quantize/realize.cc @@ -54,17 +54,17 @@ RELAY_DEFINE_NODE_REF(QRealizeExpr, QRealizeExprNode, TempExpr); class QRealizeIntExprNode : public QRealizeExprNode { public: Expr dom_scale; - DataType dtype; + // DataType dtype; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); v->Visit("dom_scale", &dom_scale); - v->Visit("dtype", &dtype); + // v->Visit("dtype", &dtype); } Expr Realize() const final; - TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype); + TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale); static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr"; TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode); @@ -81,11 +81,11 @@ Expr 
QRealizeIntExprNode::Realize() const { return data; } -QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) { +QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale) { NodePtr n = make_node(); n->data = std::move(data); n->dom_scale = std::move(dom_scale); - n->dtype = std::move(dtype); + // n->dtype = std::move(dtype); return QRealizeIntExpr(n); } @@ -96,29 +96,32 @@ inline Expr ForwardOp(const Call& ref_call, const Array& args) { } -/* calculate `data * s1 / s2`, use shift if possible */ -inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) { - // here we assume the dtype of data is dtype activation - if (s1 == s2) return data; - - float factor = s1 / s2; - float shift_factor = std::log2(factor); - CHECK_GT(shift_factor, 0); - if (static_cast(shift_factor) == shift_factor) { - return LeftShift(data, MakeConstantScalar(dtype, - static_cast(shift_factor))); - } else if (static_cast(factor) == factor) { - return Multiply(data, MakeConstantScalar(dtype, factor)); - } else { - data = Cast(data, Float(32)); - data = Multiply(data, MakeConstantScalar(Float(32), factor)); - return Cast(Round(data), dtype); - } -} +// /* calculate `data * s1 / s2`, use shift if possible */ +// inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) { +// // here we assume the dtype of data is dtype activation +// if (s1 == s2) return data; +// +// float factor = s1 / s2; +// float shift_factor = std::log2(factor); +// CHECK_GT(shift_factor, 0); +// if (static_cast(shift_factor) == shift_factor) { +// // TODO: saturation +// // clip(data, pow(2, dtype - shift_factor)) +// return LeftShift(data, MakeConstantScalar(dtype, +// static_cast(shift_factor))); +// } else if (static_cast(factor) == factor) { +// return Multiply(data, MakeConstantScalar(dtype, factor)); +// } else { +// data = Cast(data, Float(32)); +// data = Multiply(data, MakeConstantScalar(Float(32), factor)); +// return Cast(Round(data), dtype); +// } 
+// } Expr QuantizeRealize(const Call& ref_call, const Array& new_args, const NodeRef& ctx) { + // TODO: consider not dtype_activation const QConfig& cfg = QConfig::Current(); // do not handle data type cast const auto param = ref_call->attrs.as(); @@ -142,7 +145,7 @@ Expr QuantizeRealize(const Call& ref_call, if (idom_scale_imm == odom_scale_imm) { // same domain scale, only clip data = Clip(data, clip_min_imm, clip_max_imm); - return QRealizeIntExprNode::make(data, dom_scale, n->dtype); + return QRealizeIntExprNode::make(data, dom_scale); } float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm); @@ -157,18 +160,27 @@ Expr QuantizeRealize(const Call& ref_call, } data = RightShift(data, MakeConstantScalar(cfg->dtype_activation, static_cast(shift_nbit))); - } else { + data = Clip(data, clip_min_imm, clip_max_imm); + return QRealizeIntExprNode::make(data, dom_scale); + } else { // shift_nbit < 0, use left shift + shift_nbit = - shift_nbit; + // saturation before to avoid overflow during left shifting + clip_min_imm = int(clip_min_imm) >> int(shift_nbit); + clip_max_imm = int(clip_max_imm) >> int(shift_nbit); + LOG(INFO) << "left shift happens"; + LOG(INFO) << "clip_min: " << clip_min_imm; + LOG(INFO) << "clip_max: " << clip_max_imm; + data = Clip(data, clip_min_imm, clip_max_imm); data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation, static_cast(shift_nbit))); + return QRealizeIntExprNode::make(data, dom_scale); } - data = Clip(data, clip_min_imm, clip_max_imm); - return QRealizeIntExprNode::make(data, dom_scale, n->dtype); } else { // float computation data = Cast(data, Float(32)); Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale)); Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm); - return QRealizeIntExprNode::make(round_data, dom_scale, Float(32)); + return QRealizeIntExprNode::make(round_data, dom_scale); } } @@ -177,9 +189,13 @@ Expr QuantizeRealize(const Call& ref_call, Expr data = new_args[0]; Expr 
scaled_data = Multiply(data, MakeConstantScalar(Float(32), 1 / dom_scale_imm)); Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm); - return QRealizeIntExprNode::make(round_data, dom_scale, Float(32)); + return QRealizeIntExprNode::make(round_data, dom_scale); } +RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") +.set_attr("FQRealizeRewrite", QuantizeRealize); + + Expr FoldConstantOpt(const Expr& expr) { auto mod = ModuleNode::FromExpr(expr); mod = transform::FoldConstant()(mod); @@ -187,10 +203,6 @@ Expr FoldConstantOpt(const Expr& expr) { return expr.as() == nullptr ? entry_func->body : entry_func; } -RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") -.set_attr("FQRealizeRewrite", QuantizeRealize); - - Expr Conv2dRealize(const Call& ref_call, const Array& new_args, const NodeRef& ctx) { @@ -205,10 +217,11 @@ Expr Conv2dRealize(const Call& ref_call, CHECK(rhs); Expr ldata = lhs->data; - if (lhs->dtype != cfg->dtype_input) { - ldata = Cast(ldata, cfg->dtype_input); - } - Expr rdata = Cast(rhs->data, cfg->dtype_weight); + // if (lhs->dtype != cfg->dtype_input) { + // ldata = Cast(ldata, cfg->dtype_input); + // } + // Expr rdata = Cast(rhs->data, cfg->dtype_weight); + Expr rdata = rhs->data; const auto ref_attrs = ref_call->attrs.as(); auto attrs = make_node(); @@ -220,7 +233,7 @@ Expr Conv2dRealize(const Call& ref_call, {ldata, rdata}, Attrs(attrs), ref_call->type_args); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); Expr dom_scale = FoldConstantOpt(mul); - return QRealizeIntExprNode::make(ret, dom_scale, out_dtype); + return QRealizeIntExprNode::make(ret, dom_scale); } RELAY_REGISTER_OP("nn.conv2d") @@ -239,10 +252,11 @@ Expr DenseRealize(const Call& ref_call, const auto* rhs = new_args[1].as(); Expr ldata = lhs->data; - if (lhs->dtype != cfg->dtype_input) { - ldata = Cast(ldata, cfg->dtype_input); - } - Expr rdata = Cast(rhs->data, cfg->dtype_weight); + // if (lhs->dtype != cfg->dtype_input) { + // ldata = Cast(ldata, 
cfg->dtype_input); + // } + // Expr rdata = Cast(rhs->data, cfg->dtype_weight); + Expr rdata = rhs->data; const auto ref_attrs = ref_call->attrs.as(); auto attrs = make_node(); @@ -254,7 +268,7 @@ Expr DenseRealize(const Call& ref_call, {ldata, rdata}, Attrs(attrs), ref_call->type_args); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); Expr dom_scale = FoldConstantOpt(mul); - return QRealizeIntExprNode::make(ret, dom_scale, out_dtype); + return QRealizeIntExprNode::make(ret, dom_scale); } RELAY_REGISTER_OP("nn.dense") @@ -273,22 +287,21 @@ Expr MulRealize(const Call& ref_call, Expr ldata = lhs->data; Expr rdata = rhs->data; - DataType dtype = cfg->dtype_activation; - if (lhs->dtype != dtype) { - ldata = Cast(ldata, dtype); - } else { - CHECK_EQ(lhs->dtype, dtype); - } - if (rhs->dtype != dtype) { - rdata = Cast(rdata, dtype); - } else { - CHECK_EQ(rhs->dtype, dtype); - } + // if (lhs->dtype != dtype) { + // ldata = Cast(ldata, dtype); + // } else { + // CHECK_EQ(lhs->dtype, dtype); + // } + // if (rhs->dtype != dtype) { + // rdata = Cast(rdata, dtype); + // } else { + // CHECK_EQ(rhs->dtype, dtype); + // } Expr ret = ForwardOp(ref_call, {ldata, rdata}); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); Expr dom_scale = FoldConstantOpt(mul); - return QRealizeIntExprNode::make(ret, dom_scale, dtype); + return QRealizeIntExprNode::make(ret, dom_scale); } CHECK(!new_args[0]->derived_from() && !new_args[1]->derived_from()); return Expr(nullptr); @@ -298,81 +311,84 @@ RELAY_REGISTER_OP("multiply") .set_attr("FQRealizeRewrite", MulRealize); -float ChooseDomScale(const std::vector& nptrs) { - if (nptrs.size() == 2) { - // x = a * s1, y = b * s2 - // x + y = (a * s1 / s2 + b) * s2, if s1 > s2 - // = (a + b * s2 / s1) * s1, if s2 > s1 - float s1 = GetScalarFromConstant(nptrs[0]->dom_scale); - float s2 = GetScalarFromConstant(nptrs[1]->dom_scale); - return s1 > s2 ? 
s2 : s1; - } else { - const QConfig& cfg = QConfig::Current(); - float scale = cfg->global_scale; - return scale / std::pow(2.0, cfg->nbit_activation - 1); - } -} - - -/* \brief Unify the dom scale of arguments */ -Array UnifyDTypeScale(const Array& ref_args, const Array& args, - DataType* dtype_ptr, Expr* scale_ptr) { - static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize"); - const QConfig& cfg = QConfig::Current(); - - std::vector nptrs; - Array ret; - for (auto arg : args) { - const auto* nptr = arg.as(); - CHECK(nptr); - nptrs.push_back(nptr); - ret.push_back(nptr->data); - } - - // unify the data type - CHECK_EQ(ref_args.size(), args.size()); - DataType dtype; - - if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) { - dtype = cfg->dtype_input; - } else { - dtype = cfg->dtype_activation; - } - for (size_t i = 0; i < ret.size(); ++i) { - auto ref_arg = ref_args[i].as(); - if (nptrs[i]->dtype != dtype) { - ret.Set(i, Cast(ret[i], dtype)); - } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) && - ref_arg->attrs.as()->kind == kQInput) { - auto new_arg = Cast(ret[i], cfg->dtype_input); - new_arg = StopFusion(new_arg); - ret.Set(i, Cast(new_arg, dtype)); - } - } - - // unify the dom_scale - float s = ChooseDomScale(nptrs); - Expr dom_scale = MakeConstantScalar(Float(32), s); - for (size_t i = 0; i < ret.size(); ++i) { - float cur_s = GetScalarFromConstant(nptrs[i]->dom_scale); - ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype)); - } - - *dtype_ptr = dtype; - *scale_ptr = dom_scale; - return ret; -} +// float ChooseDomScale(const std::vector& nptrs) { +// if (nptrs.size() == 2) { +// // x = a * s1, y = b * s2 +// // x + y = (a * s1 / s2 + b) * s2, if s1 > s2 +// // = (a + b * s2 / s1) * s1, if s2 > s1 +// float s1 = GetScalarFromConstant(nptrs[0]->dom_scale); +// float s2 = GetScalarFromConstant(nptrs[1]->dom_scale); +// return s1 > s2 ? 
s2 : s1; +// } else { +// const QConfig& cfg = QConfig::Current(); +// float scale = cfg->global_scale; +// return scale / std::pow(2.0, cfg->nbit_activation - 1); +// } +// } +// +// +// /* \brief Unify the dom scale of arguments */ +// Array UnifyDTypeScale(const Array& ref_args, const Array& args, +// DataType* dtype_ptr, Expr* scale_ptr) { +// static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize"); +// const QConfig& cfg = QConfig::Current(); +// +// std::vector nptrs; +// Array ret; +// for (auto arg : args) { +// const auto* nptr = arg.as(); +// CHECK(nptr); +// nptrs.push_back(nptr); +// ret.push_back(nptr->data); +// } +// +// // unify the data type +// CHECK_EQ(ref_args.size(), args.size()); +// DataType dtype; +// +// // if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) { +// // dtype = cfg->dtype_input; +// // } else { +// // dtype = cfg->dtype_activation; +// // } +// // for (size_t i = 0; i < ret.size(); ++i) { +// // auto ref_arg = ref_args[i].as(); +// // if (nptrs[i]->dtype != dtype) { +// // ret.Set(i, Cast(ret[i], dtype)); +// // } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) && +// // ref_arg->attrs.as()->kind == kQInput) { +// // auto new_arg = Cast(ret[i], cfg->dtype_input); +// // new_arg = StopFusion(new_arg); +// // ret.Set(i, Cast(new_arg, dtype)); +// // } +// // } +// +// // unify the dom_scale +// float s = ChooseDomScale(nptrs); +// Expr dom_scale = MakeConstantScalar(Float(32), s); +// for (size_t i = 0; i < ret.size(); ++i) { +// float cur_s = GetScalarFromConstant(nptrs[i]->dom_scale); +// ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype)); +// } +// +// *dtype_ptr = dtype; +// *scale_ptr = dom_scale; +// return ret; +// } Expr AddRealize(const Call& ref_call, const Array& new_args, const NodeRef& ctx) { CHECK_EQ(new_args.size(), 2); - if (new_args[0].as() && new_args[1].as()) { - DataType dtype; - Expr dom_scale; - Array ret_args = UnifyDTypeScale(ref_call->args, new_args, 
&dtype, &dom_scale); - Expr ret = ForwardOp(ref_call, ret_args); - return QRealizeIntExprNode::make(ret, dom_scale, dtype); + + const auto* lhs = new_args[0].as(); + const auto* rhs = new_args[1].as(); + if (lhs && rhs) { + float ldom_scale = GetScalarFromConstant(lhs->dom_scale); + float rdom_scale = GetScalarFromConstant(rhs->dom_scale); + CHECK_EQ(ldom_scale, rdom_scale); + Expr ret = ForwardOp(ref_call, {lhs->data, rhs->data}); + return QRealizeIntExprNode::make(ret, lhs->dom_scale); } CHECK(!new_args[0]->derived_from() && !new_args[1]->derived_from()); @@ -395,7 +411,7 @@ Expr ClipRealize(const Call& ref_call, Expr ret = CallNode::make(ref_call->op, {n->data}, Attrs(attrs), ref_call->type_args); - return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype); + return QRealizeIntExprNode::make(ret, n->dom_scale); } CHECK(!new_args[0]->derived_from()); return Expr(nullptr); @@ -405,35 +421,35 @@ RELAY_REGISTER_OP("clip") .set_attr("FQRealizeRewrite", ClipRealize); -Expr ConcatenateRealize(const Call& ref_call, - const Array& new_args, - const NodeRef& ctx) { - CHECK_EQ(new_args.size(), 1); - CHECK_EQ(ref_call->args.size(), 1); - - const auto* tuple = new_args[0].as(); - const auto* ref_tuple = ref_call->args[0].as(); - CHECK(tuple); - CHECK(ref_tuple); - const Array& arr = tuple->fields; - const Array& ref_arr = ref_tuple->fields; - - if (arr[0].as()) { - DataType dtype; - Expr dom_scale; - Array ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale); - Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)}); - return QRealizeIntExprNode::make(ret, dom_scale, dtype); - } else { - for (auto arg : new_args) { - CHECK(!arg->derived_from()); - } - return Expr(nullptr); - } -} - -RELAY_REGISTER_OP("concatenate") -.set_attr("FQRealizeRewrite", ConcatenateRealize); +// Expr ConcatenateRealize(const Call& ref_call, +// const Array& new_args, +// const NodeRef& ctx) { +// CHECK_EQ(new_args.size(), 1); +// CHECK_EQ(ref_call->args.size(), 1); +// +// const auto* 
tuple = new_args[0].as(); +// const auto* ref_tuple = ref_call->args[0].as(); +// CHECK(tuple); +// CHECK(ref_tuple); +// const Array& arr = tuple->fields; +// const Array& ref_arr = ref_tuple->fields; +// +// if (arr[0].as()) { +// DataType dtype; +// Expr dom_scale; +// Array ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale); +// Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)}); +// return QRealizeIntExprNode::make(ret, dom_scale); +// } else { +// for (auto arg : new_args) { +// CHECK(!arg->derived_from()); +// } +// return Expr(nullptr); +// } +// } +// +// RELAY_REGISTER_OP("concatenate") +// .set_attr("FQRealizeRewrite", ConcatenateRealize); /* \brief forward the original operator */ @@ -443,7 +459,7 @@ Expr IdentityRealize(const Call& ref_call, CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { Expr ret = ForwardOp(ref_call, {n->data}); - return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype); + return QRealizeIntExprNode::make(ret, n->dom_scale); } CHECK(!new_args[0]->derived_from()); return Expr(nullptr); @@ -458,44 +474,44 @@ RELAY_REGISTER_OP("strided_slice") RELAY_REGISTER_OP("annotation.stop_fusion") .set_attr("FQRealizeRewrite", IdentityRealize); -/* \brief for unary operators which requantize its input to dtype_nbit */ -Expr CastDtypeInputRealize(const Call& ref_call, - const Array& new_args, - const NodeRef& ctx) { - const QConfig& cfg = QConfig::Current(); - CHECK_EQ(new_args.size(), 1); - if (const auto* n = new_args[0].as()) { - Expr data = Cast(n->data, cfg->dtype_input); - Expr ret = ForwardOp(ref_call, {data}); - return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input); - } - CHECK(!new_args[0]->derived_from()); - return Expr(nullptr); -} +// /* \brief for unary operators which requantize its input to dtype_nbit */ +// Expr CastDtypeInputRealize(const Call& ref_call, +// const Array& new_args, +// const NodeRef& ctx) { +// const QConfig& cfg = QConfig::Current(); +// 
CHECK_EQ(new_args.size(), 1); +// if (const auto* n = new_args[0].as()) { +// Expr data = Cast(n->data, cfg->dtype_input); +// Expr ret = ForwardOp(ref_call, {data}); +// return QRealizeIntExprNode::make(ret, n->dom_scale); +// } +// CHECK(!new_args[0]->derived_from()); +// return Expr(nullptr); +// } RELAY_REGISTER_OP("nn.max_pool2d") -.set_attr("FQRealizeRewrite", CastDtypeInputRealize); +.set_attr("FQRealizeRewrite", IdentityRealize); -Expr AvgPoolRealize(const Call& ref_call, - const Array& new_args, - const NodeRef& ctx) { - const QConfig& cfg = QConfig::Current(); - CHECK_EQ(new_args.size(), 1); - if (const auto* n = new_args[0].as()) { - Expr data = n->data; - if (n->dtype != cfg->dtype_activation) { - data = Cast(n->data, cfg->dtype_activation); - } - Expr ret = ForwardOp(ref_call, {data}); - return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_activation); - } - CHECK(!new_args[0]->derived_from()); - return Expr(nullptr); -} +// Expr AvgPoolRealize(const Call& ref_call, +// const Array& new_args, +// const NodeRef& ctx) { +// const QConfig& cfg = QConfig::Current(); +// CHECK_EQ(new_args.size(), 1); +// if (const auto* n = new_args[0].as()) { +// Expr data = n->data; +// // if (n->dtype != cfg->dtype_activation) { +// // data = Cast(n->data, cfg->dtype_activation); +// // } +// Expr ret = ForwardOp(ref_call, {data}); +// return QRealizeIntExprNode::make(ret, n->dom_scale); +// } +// CHECK(!new_args[0]->derived_from()); +// return Expr(nullptr); +// } RELAY_REGISTER_OP("nn.avg_pool2d") -.set_attr("FQRealizeRewrite", AvgPoolRealize); +.set_attr("FQRealizeRewrite", IdentityRealize); Expr CastHintRealize(const Call& ref_call, const Array& new_args, @@ -504,7 +520,7 @@ Expr CastHintRealize(const Call& ref_call, CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { Expr ret = Cast(n->data, param->dtype); - return QRealizeIntExprNode::make(ret, n->dom_scale, param->dtype); + return QRealizeIntExprNode::make(ret, n->dom_scale); } 
CHECK(!new_args[0]->derived_from()); return Expr(nullptr); diff --git a/tutorials/quantize_model.py b/tutorials/quantize_model.py deleted file mode 100644 index e9a89e0a22c0..000000000000 --- a/tutorials/quantize_model.py +++ /dev/null @@ -1,134 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -.. _tutorial-quantize-model: - -Speed Up Inference and Compress Model with Quantization -=========================================== -**Author**: `Ziheng Jiang `_ - -This is an example to speed up and compress -a ResNet model with quantization. 
-""" - -import tvm -import tvm.relay as relay -from tvm import rpc -from tvm.contrib import util, graph_runtime as runtime -from tvm.contrib.download import download_testdata - -from mxnet.gluon.model_zoo.vision import get_model -from PIL import Image -import numpy as np -# get model - -# one line to get the model -block = get_model('resnet18_v1', pretrained=True) - -img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true' -img_name = 'cat.png' -img_path = download_testdata(img_url, img_name, module='data') -image = Image.open(img_path).resize((224, 224)) - -def transform_image(image): - image = np.array(image) - np.array([123., 117., 104.]) - image /= np.array([58.395, 57.12, 57.375]) - image = image.transpose((2, 0, 1)) - image = image[np.newaxis, :] - return image - -x = transform_image(image) - - -shape_dict = {'data': x.shape} -mod, params = relay.frontend.from_mxnet(block, shape_dict) - - -local_demo = True - -target = tvm.target.create('llvm') - -with relay.build_config(opt_level=3): - graph, lib, params = relay.build(mod, target, params=params) - -def evaluate_inference_speed(graph, lib, params): - tmp = util.tempdir() - lib_fname = tmp.relpath('net.tar') - lib.export_library(lib_fname) - - if local_demo: - remote = rpc.LocalSession() - else: - # The following is my environment, change this to the IP address of your target device - host = '10.77.1.162' - port = 9090 - remote = rpc.connect(host, port) - - # upload the library to remote device and load it - remote.upload(lib_fname) - rlib = remote.load_module('net.tar') - - # create the remote runtime module - ctx = remote.cpu(0) - module = runtime.create(graph, rlib, ctx) - # set parameter (upload params to the remote device. 
This may take a while) - module.set_input(**params) - # set input data - module.set_input('data', tvm.nd.array(x.astype('float32'))) - # run - - ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10) - prof_res = np.array(ftimer().results) * 1000 # convert to millisecond - print("Mean inference time (std dev): %.2f ms (%.2f ms)" % - (np.mean(prof_res), np.std(prof_res))) - - -evaluate_inference_speed(graph, lib, params) - -import tvm.relay.quantize as qtz - -qconfig_i8_i32 = qtz.qconfig(skip_conv_layers=[0], - nbit_input=8, - nbit_weight=8, - global_scale=4.0, - dtype_input="int8", - dtype_weight="int8", - dtype_activation="int32", - do_simulation=False) - -# explain configures - - -with qconfig_i8_i32: - mod = qtz.quantize(mod, params) - - -# compare origin size and quantized size - -# compare origin speed and quantized speed - -def profile_speed_and_size(): - pass - -# compare origin speed and i16 speed - -qconfig_i8_i16 - - -# How do we get those model - -# search configure on Machine From 9af9da38261a2784530d5036cecb6aa0b75692c6 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 27 Aug 2019 01:22:08 -0700 Subject: [PATCH 3/4] Update test --- .../test_quantization_accuracy.py | 113 ++++++++++++++---- 1 file changed, 89 insertions(+), 24 deletions(-) diff --git a/tests/python/nightly/quantization/test_quantization_accuracy.py b/tests/python/nightly/quantization/test_quantization_accuracy.py index f047952f3e6b..e655b30afe83 100644 --- a/tests/python/nightly/quantization/test_quantization_accuracy.py +++ b/tests/python/nightly/quantization/test_quantization_accuracy.py @@ -19,11 +19,12 @@ from tvm import relay from tvm.relay import quantize as qtz import mxnet as mx +import numpy as np from mxnet import gluon import logging import os -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) Config = namedtuple('Config', ['model', 'nbit_input', 'dtype_input', 'nbit_output', 'dtype_output', 'global_scale', 'expected_acc']) 
@@ -58,32 +59,39 @@ def batch_fn(batch, ctx): return val_data, batch_fn -def get_model(model_name, batch_size, qconfig, target=None, original=False, simulated=False): +def get_model(model_name, batch_size, qconfig, target=None, original=False, simulated=False, calib_set=None): gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True) img_size = 299 if model_name == 'inceptionv3' else 224 data_shape = (batch_size, 3, img_size, img_size) mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape}) - net = mod['main'] - with relay.build_config(opt_level=3): - qfunc = relay.quantize.prerequisite_optimize(net, params=params) + qmod = relay.quantize.prerequisite_optimize(mod, params=params) logging.debug('original') - logging.debug(qfunc.astext(show_meta_data=False)) + logging.debug(qmod['main'].astext(show_meta_data=False)) + + def visit(e): + if isinstance(e, tvm.relay.Call): + print(e.op.name) + for var in e.args: + if isinstance(var, tvm.relay.Constant): + print(np.max(var.data.asnumpy())) + relay.analysis.post_order_visit(qmod['main'], visit) + if original: - return qfunc + return qmod with qconfig: logging.debug('current quantize config') logging.debug(qtz.current_qconfig()) - qfunc = qtz.quantize(qfunc) + qmod = qtz.quantize(qmod, dataset=calib_set) logging.debug('after quantize') - logging.debug(qfunc.astext(show_meta_data=False)) - return qfunc + logging.debug(qmod['main'].astext(show_meta_data=False)) + return qmod -def eval_acc(model, dataset, batch_fn, target=tvm.target.cuda(), ctx=tvm.gpu(), log_interval=100): +def eval_acc(mod, dataset, batch_fn, target=tvm.target.cuda(), ctx=tvm.gpu(), log_interval=100): with relay.build_config(opt_level=3): - graph, lib, params = relay.build(model, target) + graph, lib, params = relay.build(mod, target) # create runtime module m = tvm.contrib.graph_runtime.create(graph, lib, ctx) m.set_input(**params) @@ -111,20 +119,33 @@ def eval_acc(model, dataset, batch_fn, target=tvm.target.cuda(), 
ctx=tvm.gpu(), logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5) return top1 + +def get_calibration_dataset(dataset, batch_fn, num_samples=100): + dataset.reset() + for i, batch in enumerate(dataset): + if i * dataset.batch_size > num_samples: + break + data, label = batch_fn(batch, [mx.cpu(0)]) + yield {'data': data[0].asnumpy()} + + def test_quantize_acc(cfg, rec_val): qconfig = qtz.qconfig(skip_conv_layers=[0], nbit_input=cfg.nbit_input, nbit_weight=cfg.nbit_input, - global_scale=cfg.global_scale, dtype_input=cfg.dtype_input, dtype_weight=cfg.dtype_input, dtype_activation=cfg.dtype_output, + global_scale=cfg.global_scale, + do_simulation=False, debug_enabled_ops=None) - model = get_model(cfg.model, 32, qconfig, tvm.target.cuda()) val_data, batch_fn = get_val_data(cfg.model, rec_val=rec_val, batch_size=32) + calib_set = get_calibration_dataset(val_data, batch_fn) + + mod = get_model(cfg.model, 32, qconfig, tvm.target.cuda(), calib_set=calib_set) - acc = eval_acc(model, val_data, batch_fn) + acc = eval_acc(mod, val_data, batch_fn) assert acc > cfg.expected_acc return acc @@ -135,19 +156,63 @@ def test_quantize_acc(cfg, rec_val): results = [] configs = [ - Config('mobilenetv2_1.0', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.666), - - Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=8.0, expected_acc=0.692), - Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.692), - Config('resnet34_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.733), - Config('resnet50_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.747), - Config('resnet101_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, 
expected_acc=0.756), - # TODO: need to fix accuracy - # Config('mobilenetv2_1.0', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=4.0), + # resnet18_v1 best configuration + Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=8.0, expected_acc=0.675), + Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.674), + Config('resnet34_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.714), + Config('resnet50_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.743), + Config('resnet101_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.751), + + # resnet18_v2 best configuration + # Config('resnet18_v2', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=4.0, expected_acc=0.611), + # Config('resnet18_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.612), + # Config('resnet34_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.726), + # Config('resnet50_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.752), + # Config('resnet101_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.765), + + # resnet18_v1 history + Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=2.0, expected_acc=0.000), + Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.401), + Config('resnet34_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', 
global_scale=2.0, expected_acc=0.259), + Config('resnet50_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.738), + Config('resnet101_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.751), + + # Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=4.0, expected_acc=0.367), + # Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.672), + # Config('resnet34_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.699), + # Config('resnet50_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.743), + # Config('resnet101_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.759), + + # Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=8.0, expected_acc=0.675), + # Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.674), + # Config('resnet34_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.714), + # Config('resnet50_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.696), + # Config('resnet101_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.713), + + # resnet18_v2 history + # Config('resnet18_v2', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=2.0, expected_acc=0.250) + # Config('resnet18_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.454), + # 
Config('resnet34_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.459), + # Config('resnet50_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.752), + # Config('resnet101_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.765), + + # Config('resnet18_v2', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=4.0, expected_acc=0.611), + # Config('resnet18_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.612), + # Config('resnet34_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.726), + # Config('resnet50_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.750), + # Config('resnet101_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.752), + + # Config('resnet18_v2', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=8.0, expected_acc=0.500), + # Config('resnet18_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.500), + # Config('resnet34_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.705), + # Config('resnet50_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.661), + # Config('resnet101_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.526), ] + # global scales for config in configs: acc = test_quantize_acc(config, rec_val) results.append((config, acc)) for res in results: print(res) + From 8dab80c86e26d093bc1d10b9e56d9ef9925295c3 Mon Sep 17 00:00:00 
2001 From: ZihengJiang Date: Tue, 27 Aug 2019 01:23:33 -0700 Subject: [PATCH 4/4] Update test --- .../test_quantization_accuracy.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/python/nightly/quantization/test_quantization_accuracy.py b/tests/python/nightly/quantization/test_quantization_accuracy.py index e655b30afe83..5786b008bcfc 100644 --- a/tests/python/nightly/quantization/test_quantization_accuracy.py +++ b/tests/python/nightly/quantization/test_quantization_accuracy.py @@ -164,18 +164,18 @@ def test_quantize_acc(cfg, rec_val): Config('resnet101_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.751), # resnet18_v2 best configuration - # Config('resnet18_v2', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=4.0, expected_acc=0.611), - # Config('resnet18_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.612), - # Config('resnet34_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.726), - # Config('resnet50_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.752), - # Config('resnet101_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.765), + Config('resnet18_v2', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=4.0, expected_acc=0.611), + Config('resnet18_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.612), + Config('resnet34_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.726), + Config('resnet50_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.752), + 
Config('resnet101_v2', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.765), # resnet18_v1 history - Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=2.0, expected_acc=0.000), - Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.401), - Config('resnet34_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.259), - Config('resnet50_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.738), - Config('resnet101_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.751), + # Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=2.0, expected_acc=0.000), + # Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.401), + # Config('resnet34_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.259), + # Config('resnet50_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.738), + # Config('resnet101_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=2.0, expected_acc=0.751), # Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=4.0, expected_acc=0.367), # Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.672),