[microNPU] Mean legalization support (apache#9576)
Supports legalizing a Relay mean operation to an equivalent series of
NPU operations. Mean can be legalized given one of three cases (Case 1 is
sketched after this list):
    - Case 1 (axis == [1, 2] and keepdims == True):
        depthwise_conv2d + binary_elementwise
    - Case 2 (ifm qparams == ofm qparams):
        pooling
    - Case 3 (else):
        depthwise_conv2d
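
A minimal sketch (editor's illustration, not part of this commit) of a Relay graph that falls under Case 1. It assumes a plain relay.mean over a cast int8 NHWC input; the exact quantized pattern the NPU matcher accepts is defined by MeanParams in legalize.py below.

import tvm
from tvm import relay

# Hypothetical int8 NHWC input; a mean over the spatial axes with
# keepdims=True is the shape of operation that Case 1 legalizes.
x = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
m = relay.mean(relay.cast(x, "int32"), axis=[1, 2], keepdims=True)
func = relay.Function([x], relay.cast(m, "int8"))
print(tvm.IRModule.from_expr(func))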

Co-authored-by: Rishabh Jain <rishabh.jain2@arm.com>
2 people authored and yangulei committed Jan 11, 2022
1 parent e6ff52c commit f97236c
Showing 9 changed files with 607 additions and 11 deletions.
166 changes: 166 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -354,6 +354,7 @@ def callback(
upscale="NONE",
ifm_layout=str(params.ifm.layout),
ofm_layout=str(params.ofm.layout),
ofm_dtype=str(params.ofm.dtype),
)
return ethosu_depthwise_conv2d

@@ -961,6 +962,170 @@ def __call__(self, *args, **kwargs):
pass


class MeanRewriter(DFPatternCallback):
"""Convert ethosu.mean composite functions to to an equivalent legalization:
- Case 1 (axis == [1, 2] and keepsdims == True):
ethosu_depthwise_conv2d + ethosu_binary_elementwise
- Case 2 (ifm qparams == ofm qparams): ethosu_pooling
- Case 3 (else): ethosu_depthwise_conv2d
"""

def __init__(self):
super().__init__(require_type=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.MeanParams.composite_name})
)(wildcard())

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
params = ethosu_patterns.MeanParams(post.op.body)
params.ifm.tensor = post.args[0]

ifm_shape = params.ifm.shape
ofm_shape = params.ofm.shape
lut = relay.const([], "int8")
axis = params.axis
reduced_op = params.ifm.tensor

# Enforce 4d input
if len(ifm_shape) < 4:
axis = [x + 1 for x in axis]
if len(ifm_shape) == 3:
ifm_shape = [1, params.height, params.width, ifm_shape[2]]
else:
ifm_shape = [1, params.height, params.width, 1]
reduced_op = relay.reshape(reduced_op, ifm_shape)

filter_height = ifm_shape[1] if 1 in axis else 1
filter_width = ifm_shape[2] if 2 in axis else 1
in_channels = out_channels = ifm_shape[-1]

# If the height is greater than max kernel height, reshape the input
# from [filter_height, filter_width] to [1, (filter_height*filter_width)]
# only in the case the axis is [1, 2].
if axis == [1, 2] and filter_height > 64:
ifm_shape = (ifm_shape[0], 1, filter_height * filter_width, in_channels)
filter_width = filter_height * filter_width
filter_height = 1
reduced_op = relay.reshape(reduced_op, ifm_shape)

if axis == [1, 2] and params.keepdims:
weight_scale = 1
weight_values = np.ones([out_channels, filter_height, filter_width, in_channels])
scale_bias = vela_api.pack_biases(
biases=np.zeros(ifm_shape[-1]),
ifm_scale=params.ifm.q_params.scale_f32,
ifm_dtype=np.dtype(params.ifm.dtype),
weight_scales=np.array([weight_scale], dtype=np.float),
ofm_scale=params.ofm.q_params.scale_f32,
is_activation_tanh_or_sigmoid=False,
)

reduced_op = ethosu_ops.ethosu_depthwise_conv2d(
ifm=reduced_op,
weight=relay.const(weight_values, params.ifm.dtype),
scale_bias=relay.const(scale_bias, "uint8"),
lut=lut,
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point),
weight_zero_point=0,
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point),
kernel_shape=(filter_height, filter_width),
ofm_channels=out_channels,
ofm_dtype="int16",
)

n = int(filter_height * filter_width)
eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0

scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="uint8"), dtype="uint8")

reduced_op = ethosu_ops.ethosu_binary_elementwise(
ifm=reduced_op,
ifm2=scalar_tensor,
lut=lut,
operator_type="MUL",
ifm_scale=float(params.ofm.q_params.scale_f32),
ifm_zero_point=int(params.ofm.q_params.zero_point),
ifm2_scale=1 / (n - eps),
ifm2_zero_point=0,
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point),
ifm_channels=out_channels,
ifm2_channels=out_channels,
reversed_operands=False,
ofm_dtype="int8",
rounding_mode="NATURAL",
)
elif (
params.ifm.q_params.scale_f32 == params.ofm.q_params.scale_f32
and params.ifm.q_params.zero_point == params.ofm.q_params.zero_point
):
reduced_op = ethosu_ops.ethosu_pooling(
ifm=reduced_op,
lut=lut,
pooling_type="AVG",
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=0,
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=0,
pool_shape=(filter_height, filter_width),
ofm_channels=out_channels,
rounding_mode="TRUNCATE",
)
else:
weight_scale = 1 / (filter_height * filter_width)
weight_values = np.ones([out_channels, filter_height, filter_width, in_channels])
bias = -1 * int(params.ifm.q_params.zero_point) * filter_height * filter_width

scale_bias = vela_api.pack_biases(
biases=np.ones([ifm_shape[-1]]) * bias,
ifm_scale=params.ifm.q_params.scale_f32,
ifm_dtype=np.dtype(params.ifm.dtype),
weight_scales=np.array([weight_scale], dtype=np.float),
ofm_scale=params.ofm.q_params.scale_f32,
is_activation_tanh_or_sigmoid=False,
)
reduced_op = ethosu_ops.ethosu_depthwise_conv2d(
ifm=reduced_op,
weight=relay.const(weight_values, params.ifm.dtype),
scale_bias=relay.const(scale_bias, "uint8"),
lut=lut,
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=0,
weight_zero_point=0,
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point),
kernel_shape=(filter_height, filter_width),
ofm_channels=out_channels,
rounding_mode="NATURAL",
)

# Reshape to original ofm shape
if len(ofm_shape) < 4:
reduced_op = relay.reshape(reduced_op, ofm_shape)

return reduced_op


@ir.transform.module_pass(opt_level=1)
class LegalizeMean:
"""This is the pass that wraps the MeanRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(MeanRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


@ir.transform.module_pass(opt_level=1)
class LegalizeEthosU:
"""This is the pass to call graph-rewrites to perform graph transformation
@@ -987,6 +1152,7 @@ def transform_module(
mod = LegalizeShl()(mod)
mod = LegalizeAbs()(mod)
mod = LegalizeTanh()(mod)
mod = LegalizeMean()(mod)
mod = LegalizeReshape()(mod)
mod = LegalizeStridedSlice()(mod)
mod = LegalizeNoOps()(mod)
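
To make the Case 1 arithmetic concrete, here is a plain-NumPy sketch (editor's illustration, no NPU involved) of why the depthwise_conv2d + MUL pair computes a mean: the all-ones depthwise convolution accumulates the window sum into int16, and the elementwise MUL rescales it by 1 / (n - eps), with the eps nudge for even n taken directly from MeanRewriter.callback above.

import numpy as np

window = np.array([3, 7, 12, 18], dtype=np.int64)  # n = 4 spatial elements
n = window.size
eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0  # as in MeanRewriter.callback

window_sum = window.sum()  # what the all-ones depthwise convolution produces
mean = int(np.round(window_sum * (1 / (n - eps))))  # what the MUL rescale produces
print(mean)  # 10, the exact mean of the window
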
6 changes: 6 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/op/depthwise.py
@@ -50,6 +50,7 @@ def _extract_ethosu_depthwise_conv2d_params(attrs, args):
upscale = attrs.upscale
ifm_layout = attrs.ifm_layout
ofm_layout = attrs.ofm_layout
ofm_dtype = attrs.ofm_dtype

return (
ifm,
@@ -71,6 +72,7 @@ def _extract_ethosu_depthwise_conv2d_params(attrs, args):
upscale,
ifm_layout,
ofm_layout,
ofm_dtype,
)


@@ -115,6 +117,7 @@ def ethosu_depthwise_conv2d(
upscale: str = "NONE",
ifm_layout: str = "NHWC",
ofm_layout: str = "NHWC",
ofm_dtype: str = "int8",
) -> tvm.relay.Call:
"""This is a quantized 2D depthwise convolution operation as supported by
the NPU. It accepts either NHWC or NHCWB16 format
@@ -183,6 +186,8 @@ def ethosu_depthwise_conv2d(
The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
ofm_layout : str, optional
The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16".
ofm_dtype : str, optional
The Output Feature Map tensor data type. Can be 'int8', 'uint8' or 'int16'.

Returns
-------
@@ -212,4 +217,5 @@ def ethosu_depthwise_conv2d(
upscale,
ifm_layout,
ofm_layout,
ofm_dtype,
)
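
A hedged usage sketch of the new attribute (editor's illustration; the import path, shapes, and quantization parameters below are assumptions): requesting an int16 OFM so a following rescale, as in Case 1 of the mean legalization, can consume the full-precision window sum.

import numpy as np
from tvm import relay
from tvm.relay.backend.contrib.ethosu import op as ethosu_ops  # assumed import path

ifm = relay.var("ifm", shape=(1, 8, 8, 4), dtype="int8")
weight = relay.const(np.ones((4, 3, 3, 1), dtype="int8"))
scale_bias = relay.const(np.zeros((4, 10), dtype="uint8"))  # 10 packed bytes per channel
lut = relay.const([], "int8")

call = ethosu_ops.ethosu_depthwise_conv2d(
    ifm=ifm, weight=weight, scale_bias=scale_bias, lut=lut,
    ifm_scale=0.5, ifm_zero_point=0, weight_zero_point=0,
    ofm_scale=0.25, ofm_zero_point=0,
    kernel_shape=(3, 3), ofm_channels=4,
    ofm_dtype="int16",  # the attribute added by this commit
)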
17 changes: 11 additions & 6 deletions python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
@@ -42,6 +42,7 @@ def depthwise_conv2d_compute(
upscale: str,
ifm_layout: str,
ofm_layout: str,
ofm_dtype: str,
) -> te.Tensor:
"""A compute operator representing the capabilities of 2D convolution for the NPU.
@@ -96,6 +97,8 @@ def depthwise_conv2d_compute(
The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
ofm_layout : str
The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16".
ofm_dtype : str
The Output Feature Map tensor data type. Can be 'int8', 'uint8' or 'int16'.

Returns
-------
@@ -146,12 +149,14 @@ def depthwise_conv2d_compute(
depthwise = te.compute(
(1, ofm_height, ofm_width, channels),
lambda nn, hh, ww, cc: te.sum(
dmaed_ifm(
nn, hh * stride_h + rh * dilation_h, ww * stride_w + rw * dilation_w, cc
).astype(ifm.dtype)
* weight[cc, rh, rw, 0].astype(ifm.dtype)
# This is a trick to load 10 elements of the scale_bias at once, not accurate maths
+ (scale_bias[cc, 0] * scale_bias[cc, 9] + lut_expr).astype(ifm.dtype),
(
dmaed_ifm(
nn, hh * stride_h + rh * dilation_h, ww * stride_w + rw * dilation_w, cc
).astype(ifm.dtype)
* weight[cc, rh, rw, 0].astype(ifm.dtype)
# This is a trick to load 10 elements of the scale_bias at once, not accurate maths
+ (scale_bias[cc, 0] * scale_bias[cc, 9] + lut_expr).astype(ifm.dtype)
).astype(ofm_dtype),
axis=[rh, rw],
),
name="ethosu_depthwise_conv2d",
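
A standalone TE sketch (editor's illustration using plain tvm.te, none of the NPU DMA machinery above) of what the added .astype(ofm_dtype) buys: casting each term to the OFM dtype before te.sum makes the reduction accumulate in the wider type, so an int8 IFM cannot overflow an int16 window sum.

from tvm import te

ifm = te.placeholder((1, 8, 8, 4), dtype="int8", name="ifm")
rh = te.reduce_axis((0, 8), name="rh")
rw = te.reduce_axis((0, 8), name="rw")
ofm_dtype = "int16"  # the widened accumulator dtype from Case 1

# Sum over the full 8x8 spatial window, accumulating in int16.
summed = te.compute(
    (1, 1, 1, 4),
    lambda nn, hh, ww, cc: te.sum(
        ifm[nn, hh + rh, ww + rw, cc].astype(ofm_dtype), axis=[rh, rw]
    ),
    name="sum_hw",
)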
25 changes: 25 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/util.py
@@ -142,6 +142,31 @@ def is_composite_func(func: relay.Function, name: str) -> bool:
return composite_name == name


def is_named_ethosu_op(expr: tvm.relay.Expr, name: str) -> bool:
"""Checks whether a relay expression matches that of the
named operator.

Parameters
----------
expr : tvm.relay.Expr
The expression to check.
name : str
The name of the expected operator
(without NPU prefix "contrib.ethosu").

Returns
-------
bool
True if expression matches name, false if not.
"""
prefix = "contrib.ethosu."
return (
isinstance(expr, tvm.relay.expr.Call)
and isinstance(expr.op, tvm.ir.op.Op)
and expr.op.name == prefix + name
)


def get_range_for_dtype_str(dtype: str) -> Tuple[int, int]:
"""
Produce the min and max for a given data type.
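
A hedged usage sketch of the new helper (editor's illustration; passing "depthwise_conv2d" so the helper prepends the "contrib.ethosu." prefix is an assumption based on the docstring above):

from tvm.relay.backend.contrib.ethosu import util

def is_npu_depthwise(call_node):
    # True only for calls to the contrib.ethosu.depthwise_conv2d operator.
    return util.is_named_ethosu_op(call_node, "depthwise_conv2d")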
