Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QUANTIZE] Improve explicitness of rules during annotation/realization #3828

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective

register_schedule("annotation.cast_hint", schedule_broadcast)
register_schedule("log", schedule_broadcast)
register_schedule("log1p", schedule_broadcast)
register_schedule("cos", schedule_broadcast)
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/relay/op/annotation/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,19 @@ def stop_fusion(data):
The annotated expression.
"""
return _make.stop_fusion(data)


def cast_hint(data, dtype):
"""Annotate an expression to prevent it being fused with previous expressions.

Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.

Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
return _make.cast_hint(data, dtype)
2 changes: 1 addition & 1 deletion python/tvm/relay/quantize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@
from __future__ import absolute_import as _abs

from .quantize import *
from . import _quantized_ops
from ._partition import register_partition_function
from ._annotate import register_annotate_function
from .kl_divergence import kl_divergence_scale
213 changes: 156 additions & 57 deletions python/tvm/relay/quantize/_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,46 +19,18 @@
from __future__ import absolute_import
import warnings

import topi
from ..._ffi.function import register_func
from .. import expr as _expr
from .. import analysis as _analysis
from .. import op as _op
from ..op import op as _reg
from ..base import register_relay_node
from ..op.annotation import cast_hint
from . import _quantize
from .quantize import QAnnotateKind, current_qconfig, quantize_context
from .quantize import _forward_op


@_reg.register_compute("relay.op.annotation.simulated_quantize")
def simulated_quantize_compute(attrs, inputs, out_type, target):
"""Compiler for simulated_quantize."""
assert len(inputs) == 4
assert attrs.sign
assert attrs.rounding == "round"

data, scale, clip_min, clip_max = inputs

if attrs.kind == QAnnotateKind.IDENTITY:
return [topi.identity(data)]

# simulate rounding error
scaled_data = topi.divide(data, scale)
clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min)
round_data = topi.round(clipped_data)

# recover data
rdata = topi.multiply(round_data, scale)
return [rdata]


_reg.register_schedule("relay.op.annotation.simulated_quantize",
_reg.schedule_injective)
_reg.register_pattern("relay.op.annotation.simulated_quantize",
_reg.OpPattern.ELEMWISE)


@register_relay_node
class QAnnotateExpr(_expr.TempExpr):
"""A special kind of Expr for Annotating.
Expand Down Expand Up @@ -146,6 +118,23 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize)


@register_annotate_function("annotation.cast_hint")
def cast_hint_rewrite(ref_call, new_args, ctx):
"""Rewrite function to force cast"""
expr, x_kind = _get_expr_kind(new_args[0])

if quantize_context().check_to_skip(ref_call):
return expr

if x_kind is None:
return new_args[0]
if x_kind == QAnnotateKind.ACTIVATION:
expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT)

expr = _forward_op(ref_call, [expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)


@register_annotate_function("nn.contrib_conv2d_NCHWc")
def conv2d_nchwc_rewrite(ref_call, new_args, ctx):
warnings.warn("NCHWc layout Conv2D detected, please use a lower "
Expand All @@ -155,20 +144,24 @@ def conv2d_nchwc_rewrite(ref_call, new_args, ctx):

@register_annotate_function("nn.conv2d")
def conv2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for conv2d. Lhs of conv will be quantized to
input field, and rhs of conv will be quantized to weight field.
Output would be in activation field"""
"""Rewrite function for conv2d. Allowed combination:
- lhs[nbit_input, dtype_input], rhs[nbit_weight, dtype_weight] ->
out[x, dtype_activation]
"""
if quantize_context().check_to_skip(ref_call):
return None

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

qcfg = current_qconfig()
if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
lhs_expr = cast_hint(lhs_expr, qcfg.dtype_input)

assert rhs_kind is None
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
rhs_expr = cast_hint(rhs_expr, qcfg.dtype_weight)

expr = _forward_op(ref_call, [lhs_expr, rhs_expr])

Expand All @@ -186,11 +179,14 @@ def dense_rewrite(ref_call, new_args, ctx):
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

qcfg = current_qconfig()
if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
lhs_expr = cast_hint(lhs_expr, qcfg.dtype_input)

assert rhs_kind is None
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
rhs_expr = cast_hint(rhs_expr, qcfg.dtype_weight)

expr = _forward_op(ref_call, [lhs_expr, rhs_expr])

Expand All @@ -199,31 +195,133 @@ def dense_rewrite(ref_call, new_args, ctx):

@register_annotate_function("multiply")
def multiply_rewrite(ref_call, new_args, ctx):
"""Rewrite function for multiply."""
"""Rewrite function for multiply.
Allowed combination:
- lhs[nbit_input, dtype_activation] * rhs[nbit_weight/nbit_input, dtype_activation]
-> out[x, dtype_activation]
"""
if quantize_context().check_to_skip(ref_call):
return None

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

if lhs_kind is None and rhs_kind is None:
return None

qcfg = current_qconfig()
# for now, only support multiply bias transformed by batch_norm
assert rhs_kind is None
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
rhs_expr = cast_hint(rhs_expr, qcfg.dtype_activation)

# print('multiply lhs: {0}'.format(lhs_kind))
# print('multiply lhs: \n{0}'.format(lhs_expr))

if lhs_kind is QAnnotateKind.ACTIVATION:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
lhs_expr = cast_hint(lhs_expr, qcfg.dtype_activation)

expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
# print('multiply out: \n{0}'.format(expr.astext(show_meta_data=False)))
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)


@register_annotate_function("add")
def new_add_rewrite(ref_call, new_args, ctx):
"""Rewrite function for add. Allowed combinations:
- lhs[*, dtype_activation], rhs[nbit_weight, dtype_activation] ->
out[*, dtype_activation]
- lhs[nbit_input, dtype_activation], rhs[nbit_weight, dtype_activation] ->
out[*, dtype_activation]
- lhs[nbit_input, dtype_input], rhs[nbit_input, dtype_input] ->
out[*, dtype_input]
"""
if quantize_context().check_to_skip(ref_call):
return None

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

if lhs_kind is None and rhs_kind is None:
# trivial case
return None

if lhs_kind in [QAnnotateKind.ACTIVATION, QAnnotateKind.INPUT] and rhs_kind is None:
# quantize lhs to INPUT field
if lhs_kind == QAnnotateKind.ACTIVATION:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
# quantize rhs to WEIGHT field
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
qcfg = current_qconfig()

# unify lhs and rhs to the same dom_scale
dom_scale = _expr.var("dom_scale")
clip_min = _expr.var("clip_min")
clip_max = _expr.var("clip_max")
lhs_expr = _quantize.simulated_quantize(lhs_expr,
dom_scale, clip_min, clip_max, QAnnotateKind.INPUT, True, 'round')
rhs_expr = _quantize.simulated_quantize(rhs_expr,
dom_scale, clip_min, clip_max, QAnnotateKind.INPUT, True, 'round')

if lhs_kind is QAnnotateKind.ACTIVATION and rhs_kind is None:
# introduced by bias_add from batch_norm (resnet18_v1)
rhs_expr = cast_hint(rhs_expr, qcfg.dtype_activation)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

if lhs_kind is None and rhs_kind is QAnnotateKind.INPUT:
# introduced by residual addition, lhs is a skipped layer (resnet18_v1)
lhs_expr = cast_hint(lhs_expr, qcfg.dtype_input)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)

if lhs_kind is QAnnotateKind.INPUT and rhs_kind is None:
# introduced by residual addition, rhs is a skipped layer (resnet18_v2)
rhs_expr = cast_hint(rhs_expr, qcfg.dtype_input)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)

if lhs_kind is QAnnotateKind.INPUT and rhs_kind is QAnnotateKind.INPUT:
# introduced by residual addition (resnet18_v1)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)

print('lhs: {0}'.format(lhs_expr))
print('rhs: {0}'.format(rhs_expr))
print('lhs: {0}'.format(lhs_kind))
print('rhs: {0}'.format(rhs_kind))

raise ValueError

# dom_scale = scale / 2^(valid_bit)

# simulation
# lhs = sq(lhs, dom_scale, clip_min, clip_max)
# lhs = cast_hint(lhs, dtype)
# rhs = sq(rhs, dom_scale, clip_min, clip_max)
# rhs = cast_hint(rhs, dtype)
# out = lhs + rhs

# realization
# lhs = lhs * ldom_scale / odom_scale
# lshift(lhs, log2(ldom_scale / odom_scale))
# overflow risk
# lhs = lhs * dom_lscale / odom_scale
# rhs = adjust(rhs, dom_rscale, oscale, nbit)


# quantized_add(lhs, rhs, odom_scale, clip_min, clip_max)
# during simulation
# out = lhs + rhs
# scaled_out = out / odom_scale
# truncate(scaled_out, clip_min, clip_max)

# during realization
# lhs = lhs * ldom_scale / odom_scale
# lshift(lhs, log2(ldom_scale / odom_scale))
# lhs = lhs * dom_lscale / odom_scale
# rhs = adjust(rhs, dom_rscale, oscale, nbit)

# dom_scale



@register_annotate_function("add")
def add_rewrite(ref_call, new_args, ctx):
"""Rewrite function for add."""
if quantize_context().check_to_skip(ref_call):
return None

Expand All @@ -243,12 +341,15 @@ def add_rewrite(ref_call, new_args, ctx):

if lhs_kind is not None and rhs_kind is None:
if _analysis.check_constant(rhs_expr):
# - introduced by batch_norm: add(out, const)
# introduced by batch_norm: add(out, const)
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
else:
# happens in residual addition when the rhs is a skipped layer
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
# print('add lhs_kind: {0}'.format(lhs_kind))
# print('add out:\n{0}'.format(expr.astext(show_meta_data=False)))
return QAnnotateExpr(expr, lhs_kind)

if lhs_kind is not None and rhs_kind is not None:
if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT:
Expand Down Expand Up @@ -281,11 +382,11 @@ def identity_rewrite(ref_call, new_args, ctx):
register_annotate_function("clip", identity_rewrite)
register_annotate_function("nn.relu", identity_rewrite)
register_annotate_function("strided_slice", identity_rewrite)
register_annotate_function("nn.avg_pool2d", identity_rewrite)
register_annotate_function("annotation.stop_fusion", identity_rewrite)


def pool2d_rewrite(ref_call, new_args, ctx):
@register_annotate_function("nn.max_pool2d")
def max_pool2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for max pool2d"""
if quantize_context().check_to_skip(ref_call):
return None
Expand All @@ -296,29 +397,27 @@ def pool2d_rewrite(ref_call, new_args, ctx):
return None
if x_kind == QAnnotateKind.ACTIVATION:
expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT)
expr = cast_hint(expr, current_qconfig().dtype_input)

expr = _forward_op(ref_call, [expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)


register_annotate_function("nn.max_pool2d", pool2d_rewrite)

@register_annotate_function("nn.avg_pool2d")
def avg_pool2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for avg_pool2d"""
if quantize_context().check_to_skip(ref_call):
return None

@register_annotate_function("annotation.cast_hint")
def cast_hint_rewrite(ref_call, new_args, ctx):
"""Rewrite function to force cast"""
expr, x_kind = _get_expr_kind(new_args[0])

if quantize_context().check_to_skip(ref_call):
return expr

if x_kind is None:
return new_args[0]
if x_kind == QAnnotateKind.ACTIVATION:
expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT)
return None
if x_kind == QAnnotateKind.INPUT:
expr = cast_hint(expr, current_qconfig().dtype_activation)

expr = _forward_op(ref_call, [expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)


@register_annotate_function("concatenate")
Expand Down
Loading