Skip to content

Commit

Permalink
[microNPU] Add support for pack and unpack (#9960)
Browse files Browse the repository at this point in the history
* [microNPU] Add support for pack and unpack

Pack is represented by a series of `expand_dims` operations
followed by a `concatenate` in Relay. Unpack is represented
by a `split` followed by a series of `squeeze` operations in
Relay. This commit legalizes `expand_dims` and `squeeze` to
reshape operations while making use of existing legalization
techniques for `split` and `concatenate` so that pack and
unpack can be offloaded to the NPU.

Change-Id: I3fbebb4ece5ca04598f8e587b9e6c0ddf280266d

* rebase and add tests for expand dims and squeeze

Change-Id: Ic6a9fd77b61368720328bfe82032490bcc66152c
  • Loading branch information
lhutton1 authored Feb 9, 2022
1 parent 345dc37 commit 9282367
Show file tree
Hide file tree
Showing 4 changed files with 557 additions and 6 deletions.
66 changes: 66 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,6 +1445,70 @@ def __call__(self, *args, **kwargs):
pass


class ExpandDimsRewriter(DFPatternCallback):
"""Legalize expand dims to a reshape operator."""

def __init__(self):
super().__init__(require_type=True, rewrite_once=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.ExpandDimsParams.composite_name})
)(None)

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
params = ethosu_patterns.ExpandDimsParams(post.op.body)
return relay.op.reshape(post.args[0], newshape=params.output.shape)


@ir.transform.module_pass(opt_level=1)
class LegalizeExpandDims:
"""This is the pass that wraps ExpandDimsRewriter."""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(ExpandDimsRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class SqueezeRewriter(DFPatternCallback):
"""Legalize squeeze to a reshape operator."""

def __init__(self):
super().__init__(require_type=True, rewrite_once=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.SqueezeParams.composite_name})
)(None)

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
params = ethosu_patterns.SqueezeParams(post.op.body)
return relay.op.reshape(post.args[0], newshape=params.output.shape)


@ir.transform.module_pass(opt_level=1)
class LegalizeSqueeze:
"""This is the pass that wraps SqueezeRewriter."""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(SqueezeRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


@ir.transform.module_pass(opt_level=1)
class LegalizeEthosU:
"""This is the pass to call graph-rewrites to perform graph transformation
Expand Down Expand Up @@ -1477,6 +1541,8 @@ def transform_module(
mod = LegalizeSigmoid()(mod)
mod = LegalizeRequantize()(mod)
mod = LegalizeResize2d()(mod)
mod = LegalizeExpandDims()(mod)
mod = LegalizeSqueeze()(mod)
mod = LegalizeReshape()(mod)
mod = LegalizeStridedSlice()(mod)
mod = LegalizeNoOps()(mod)
Expand Down
84 changes: 78 additions & 6 deletions python/tvm/relay/op/contrib/ethosu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,19 +1214,22 @@ class ConcatParams:

def __init__(self, func_body):
self.concat = func_body
self.is_qnn_variant = self.concat.op.name == "qnn.concatenate"
self.input_tensors = [TensorParams(tensor) for tensor in list(func_body.args[0])]
self.input_scales = [s.data.asnumpy() for s in list(func_body.args[1])]
self.input_zero_points = [zp.data.asnumpy() for zp in list(func_body.args[2])]
self.axis = func_body.attrs.axis

if self.is_qnn_variant:
self.input_scales = [s.data.asnumpy() for s in list(func_body.args[1])]
self.input_zero_points = [zp.data.asnumpy() for zp in list(func_body.args[2])]

def is_valid(self):
"""Checks whether Concatenate has compatible attributes with the hardware"""
if not check_valid_dtypes(self.input_tensors, supported_dtypes=[np.int8]):
return False
# Check that the scales and zero points of input tensors are the same
if not all(self.input_scales == self.input_scales[0]):
if self.is_qnn_variant and not all(self.input_scales == self.input_scales[0]):
return False
if not all(self.input_zero_points == self.input_zero_points[0]):
if self.is_qnn_variant and not all(self.input_zero_points == self.input_zero_points[0]):
return False

input_dim = len(self.input_tensors[0].shape)
Expand All @@ -1244,6 +1247,8 @@ def is_valid(self):
output_shape = self.concat.checked_type.shape
if len(output_shape) != input_dim:
return False
if len(output_shape) > 3 and output_shape[0] != 1:
return False
return True


Expand All @@ -1252,8 +1257,11 @@ def concat_pattern():
tensors = is_tuple(None)
scales = is_tuple(None)
zero_points = is_tuple(None)
concat = is_op("qnn.concatenate")(tensors, scales, zero_points, is_constant(), is_constant())
return concat
qnn_concat = is_op("qnn.concatenate")(
tensors, scales, zero_points, is_constant(), is_constant()
)
concat = is_op("concatenate")(tensors)
return concat | qnn_concat


class SplitParams:
Expand Down Expand Up @@ -1433,6 +1441,60 @@ def resize2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
return quant | is_op("image.resize2d")(wildcard()).has_attr({"method": "nearest_neighbor"})


class ExpandDimsParams:
"""
This class will parse a call to a ethos-u.expand_dims composite function
and extract the parameter information.
"""

composite_name = "ethos-u.expand_dims"

def __init__(self, func_body):
self.expand_dims = func_body
self.input = TensorParams(func_body.args[0])
self.output = TensorParams(func_body)

def is_valid(self):
"""Checks whether expand_dims has compatible attributes with the hardware."""
if not check_dimensions(self.input) or not check_dimensions(self.output):
return False
if not check_valid_dtypes([self.input, self.output], supported_dtypes=[np.int8]):
return False
return True


def expand_dims_pattern():
"""Create the pattern for expand_dims."""
return is_op("expand_dims")(wildcard())


class SqueezeParams:
"""
This class will parse a call to a ethos-u.squeeze composite function
and extract the parameter information.
"""

composite_name = "ethos-u.squeeze"

def __init__(self, func_body):
self.squeeze = func_body
self.input = TensorParams(func_body.args[0])
self.output = TensorParams(func_body)

def is_valid(self):
"""Checks whether squeeze has compatible attributes with the hardware."""
if not check_dimensions(self.output):
return False
if not check_valid_dtypes([self.input, self.output], supported_dtypes=[np.int8]):
return False
return True


def squeeze_pattern():
"""Create the pattern for squeeze."""
return is_op("squeeze")(wildcard())


@register_pattern_table("ethos-u")
def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
return [
Expand Down Expand Up @@ -1533,6 +1595,16 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
resize2d_pattern(),
lambda pat: Resize2dParams(pat).is_valid(),
),
(
ExpandDimsParams.composite_name,
expand_dims_pattern(),
lambda pat: ExpandDimsParams(pat).is_valid(),
),
(
SqueezeParams.composite_name,
squeeze_pattern(),
lambda pat: SqueezeParams(pat).is_valid(),
),
]


Expand Down
56 changes: 56 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,28 @@ def create_model():
_compare_ethosu_with_reference(ethosu_mod, input_data, output_data, accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize("ifm_shape,axis", [((2,), 0), ((1, 3, 3), 2)])
def test_tflite_expand_dims(accel_type, ifm_shape, axis):
@tf.function
def expand_dims_func(x):
return tf.expand_dims(x, axis=axis)

_compare_tvm_with_tflite(expand_dims_func, [ifm_shape], accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"ifm_shape,axis", [((1, 1, 2, 1), 0), ((1, 3, 3, 1), 3), ((1, 1, 2, 1), None)]
)
def test_tflite_squeeze(accel_type, ifm_shape, axis):
@tf.function
def squeeze_func(x):
return tf.squeeze(x, axis=axis)

_compare_tvm_with_tflite(squeeze_func, [ifm_shape], accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"ifm_shape,size",
Expand Down Expand Up @@ -1100,5 +1122,39 @@ def conv2d_transpose(x):
_compare_tvm_with_tflite(conv2d_transpose, [ifm_shape], accel_type=accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"ifm_shapes,axis",
[
([(1, 2, 2), (1, 2, 2), (1, 2, 2)], 2),
([(5, 4), (5, 4)], 1),
([(1,), (1,)], 0),
([(3, 1), (3, 1), (3, 1), (3, 1)], 0),
],
)
def test_tflite_pack(accel_type, ifm_shapes, axis):
@tf.function
def pack_func(*inputs):
return tf.stack(inputs, axis=axis)

# TODO(lhutton1) For now output is not bit exact with TFLite.
# This is because TFLite reference kernels are not being used.
# For this, TFLite will need upgrading to 2.6.
_compare_tvm_with_tflite(pack_func, ifm_shapes, accel_type, output_tolerance=1)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"ifm_shape,axis",
[[(1, 2, 3, 4), 1], [(2, 3), 1], [(5, 6, 7), 2]],
)
def test_tflite_unpack(accel_type, ifm_shape, axis):
@tf.function
def unpack_func(x):
return tf.unstack(x, axis=axis)

_compare_tvm_with_tflite(unpack_func, [ifm_shape], accel_type)


if __name__ == "__main__":
pytest.main([__file__])
Loading

0 comments on commit 9282367

Please sign in to comment.