diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index 5613d613f9847..e35fe1543fa22 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -354,6 +354,7 @@ def callback(
             upscale="NONE",
             ifm_layout=str(params.ifm.layout),
             ofm_layout=str(params.ofm.layout),
+            ofm_dtype=str(params.ofm.dtype),
         )
         return ethosu_depthwise_conv2d
 
@@ -961,6 +962,170 @@ def __call__(self, *args, **kwargs):
         pass
 
 
+class MeanRewriter(DFPatternCallback):
+    """Convert ethosu.mean composite functions to an equivalent legalization:
+    - Case 1 (axis == [1, 2] and keepdims == True):
+      ethosu_depthwise_conv2d + ethosu_binary_elementwise
+    - Case 2 (ifm qparams == ofm qparams): ethosu_pooling
+    - Case 3 (else): ethosu_depthwise_conv2d
+    """
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": ethosu_patterns.MeanParams.composite_name})
+        )(wildcard())
+
+    def callback(
+        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+    ) -> tvm.relay.Expr:
+        params = ethosu_patterns.MeanParams(post.op.body)
+        params.ifm.tensor = post.args[0]
+
+        ifm_shape = params.ifm.shape
+        ofm_shape = params.ofm.shape
+        lut = relay.const([], "int8")
+        axis = params.axis
+        reduced_op = params.ifm.tensor
+
+        # Enforce 4d input
+        if len(ifm_shape) < 4:
+            axis = [x + 1 for x in axis]
+            if len(ifm_shape) == 3:
+                ifm_shape = [1, params.height, params.width, ifm_shape[2]]
+            else:
+                ifm_shape = [1, params.height, params.width, 1]
+            reduced_op = relay.reshape(reduced_op, ifm_shape)
+
+        filter_height = ifm_shape[1] if 1 in axis else 1
+        filter_width = ifm_shape[2] if 2 in axis else 1
+        in_channels = out_channels = ifm_shape[-1]
+
+        # If the height is greater than the max kernel height, reshape the input
+        # from [filter_height, filter_width] to [1, (filter_height*filter_width)],
+        # but only when the axis is [1, 2].
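+        # For example (illustrative): an ifm of shape [1, 65, 2, 1] with axis [1, 2]
+        # exceeds the 64-row kernel limit, so it is reshaped to [1, 1, 130, 1] and
+        # reduced with a 1x130 kernel instead of a 65x2 one.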
+        if axis == [1, 2] and filter_height > 64:
+            ifm_shape = (ifm_shape[0], 1, filter_height * filter_width, in_channels)
+            filter_width = filter_height * filter_width
+            filter_height = 1
+            reduced_op = relay.reshape(reduced_op, ifm_shape)
+
+        if axis == [1, 2] and params.keepdims:
+            weight_scale = 1
+            weight_values = np.ones([out_channels, filter_height, filter_width, in_channels])
+            scale_bias = vela_api.pack_biases(
+                biases=np.zeros(ifm_shape[-1]),
+                ifm_scale=params.ifm.q_params.scale_f32,
+                ifm_dtype=np.dtype(params.ifm.dtype),
+                weight_scales=np.array([weight_scale], dtype=np.float64),
+                ofm_scale=params.ofm.q_params.scale_f32,
+                is_activation_tanh_or_sigmoid=False,
+            )
+
+            reduced_op = ethosu_ops.ethosu_depthwise_conv2d(
+                ifm=reduced_op,
+                weight=relay.const(weight_values, params.ifm.dtype),
+                scale_bias=relay.const(scale_bias, "uint8"),
+                lut=lut,
+                ifm_scale=float(params.ifm.q_params.scale_f32),
+                ifm_zero_point=int(params.ifm.q_params.zero_point),
+                weight_zero_point=0,
+                ofm_scale=float(params.ofm.q_params.scale_f32),
+                ofm_zero_point=int(params.ofm.q_params.zero_point),
+                kernel_shape=(filter_height, filter_width),
+                ofm_channels=out_channels,
+                ofm_dtype="int16",
+            )
+
+            # The elementwise MUL below rescales the int16 sum by ~1/n, turning the
+            # sum into a mean; for even n, eps nudges the scale slightly (presumably
+            # to bias the NATURAL rounding of halfway values).
+            n = int(filter_height * filter_width)
+            eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
+
+            scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="uint8"), dtype="uint8")
+
+            reduced_op = ethosu_ops.ethosu_binary_elementwise(
+                ifm=reduced_op,
+                ifm2=scalar_tensor,
+                lut=lut,
+                operator_type="MUL",
+                ifm_scale=float(params.ofm.q_params.scale_f32),
+                ifm_zero_point=int(params.ofm.q_params.zero_point),
+                ifm2_scale=1 / (n - eps),
+                ifm2_zero_point=0,
+                ofm_scale=float(params.ofm.q_params.scale_f32),
+                ofm_zero_point=int(params.ofm.q_params.zero_point),
+                ifm_channels=out_channels,
+                ifm2_channels=out_channels,
+                reversed_operands=False,
+                ofm_dtype="int8",
+                rounding_mode="NATURAL",
+            )
+        elif (
+            params.ifm.q_params.scale_f32 == params.ofm.q_params.scale_f32
+            and params.ifm.q_params.zero_point == params.ofm.q_params.zero_point
+        ):
+            reduced_op = ethosu_ops.ethosu_pooling(
+                ifm=reduced_op,
+                lut=lut,
+                pooling_type="AVG",
+                ifm_scale=float(params.ifm.q_params.scale_f32),
+                ifm_zero_point=0,
+                ofm_scale=float(params.ofm.q_params.scale_f32),
+                ofm_zero_point=0,
+                pool_shape=(filter_height, filter_width),
+                ofm_channels=out_channels,
+                rounding_mode="TRUNCATE",
+            )
+        else:
+            weight_scale = 1 / (filter_height * filter_width)
+            weight_values = np.ones([out_channels, filter_height, filter_width, in_channels])
+            bias = -1 * int(params.ifm.q_params.zero_point) * filter_height * filter_width
+
+            scale_bias = vela_api.pack_biases(
+                biases=np.ones([ifm_shape[-1]]) * bias,
+                ifm_scale=params.ifm.q_params.scale_f32,
+                ifm_dtype=np.dtype(params.ifm.dtype),
+                weight_scales=np.array([weight_scale], dtype=np.float64),
+                ofm_scale=params.ofm.q_params.scale_f32,
+                is_activation_tanh_or_sigmoid=False,
+            )
+            reduced_op = ethosu_ops.ethosu_depthwise_conv2d(
+                ifm=reduced_op,
+                weight=relay.const(weight_values, params.ifm.dtype),
+                scale_bias=relay.const(scale_bias, "uint8"),
+                lut=lut,
+                ifm_scale=float(params.ifm.q_params.scale_f32),
+                ifm_zero_point=0,
+                weight_zero_point=0,
+                ofm_scale=float(params.ofm.q_params.scale_f32),
+                ofm_zero_point=int(params.ofm.q_params.zero_point),
+                kernel_shape=(filter_height, filter_width),
+                ofm_channels=out_channels,
+                rounding_mode="NATURAL",
+            )
+
+        # Reshape to original ofm shape
+        if len(ofm_shape) < 4:
+            reduced_op = relay.reshape(reduced_op, ofm_shape)
+
+        return reduced_op
+
+
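+# Usage note (illustrative): the LegalizeMean pass below simply applies MeanRewriter
+# to every function in the module. It can be run standalone, e.g.
+# `mod = LegalizeMean()(mod)`, or as part of the LegalizeEthosU pipeline further down.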
+@ir.transform.module_pass(opt_level=1)
+class LegalizeMean:
+    """This is the pass that wraps the MeanRewriter"""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(MeanRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
 @ir.transform.module_pass(opt_level=1)
 class LegalizeEthosU:
     """This is the pass to call graph-rewrites to perform graph transformation
@@ -987,6 +1152,7 @@ def transform_module(
         mod = LegalizeShl()(mod)
         mod = LegalizeAbs()(mod)
         mod = LegalizeTanh()(mod)
+        mod = LegalizeMean()(mod)
         mod = LegalizeReshape()(mod)
         mod = LegalizeStridedSlice()(mod)
         mod = LegalizeNoOps()(mod)
diff --git a/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py
index 6d96f4465d17a..3df3e2d81303a 100644
--- a/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py
+++ b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py
@@ -50,6 +50,7 @@ def _extract_ethosu_depthwise_conv2d_params(attrs, args):
     upscale = attrs.upscale
     ifm_layout = attrs.ifm_layout
     ofm_layout = attrs.ofm_layout
+    ofm_dtype = attrs.ofm_dtype
 
     return (
         ifm,
@@ -71,6 +72,7 @@ def _extract_ethosu_depthwise_conv2d_params(attrs, args):
         upscale,
         ifm_layout,
         ofm_layout,
+        ofm_dtype,
     )
 
@@ -115,6 +117,7 @@ def ethosu_depthwise_conv2d(
     upscale: str = "NONE",
     ifm_layout: str = "NHWC",
     ofm_layout: str = "NHWC",
+    ofm_dtype: str = "int8",
 ) -> tvm.relay.Call:
     """This is a quantized 2D depthwise convolution operation as supported by
     the NPU. It accepts either NHWC or NHCWB16 format
@@ -183,6 +186,8 @@ def ethosu_depthwise_conv2d(
         The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
     ofm_layout : str, optional
         The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16".
+    ofm_dtype : str, optional
+        The Output Feature Map tensor data type. Can be 'int8', 'uint8' or 'int16'.
 
     Returns
     -------
@@ -212,4 +217,5 @@ def ethosu_depthwise_conv2d(
         upscale,
         ifm_layout,
         ofm_layout,
+        ofm_dtype,
     )
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
index 05b2993f58571..c9a88e803c3da 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
@@ -42,6 +42,7 @@ def depthwise_conv2d_compute(
     upscale: str,
     ifm_layout: str,
     ofm_layout: str,
+    ofm_dtype: str,
 ) -> te.Tensor:
     """A compute operator representing the capabilities of 2D convolution
     for the NPU.
@@ -96,6 +97,8 @@ def depthwise_conv2d_compute(
         The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
     ofm_layout : str
         The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16".
+    ofm_dtype : str
+        The Output Feature Map tensor data type. Can be 'int8', 'uint8' or 'int16'.
 
     Returns
     -------
@@ -146,12 +149,14 @@ def depthwise_conv2d_compute(
     depthwise = te.compute(
         (1, ofm_height, ofm_width, channels),
         lambda nn, hh, ww, cc: te.sum(
-            dmaed_ifm(
-                nn, hh * stride_h + rh * dilation_h, ww * stride_w + rw * dilation_w, cc
-            ).astype(ifm.dtype)
-            * weight[cc, rh, rw, 0].astype(ifm.dtype)
-            # This is a trick to load 10 elements of the scale_bias at once, not accurate maths
-            + (scale_bias[cc, 0] * scale_bias[cc, 9] + lut_expr).astype(ifm.dtype),
+            (
+                dmaed_ifm(
+                    nn, hh * stride_h + rh * dilation_h, ww * stride_w + rw * dilation_w, cc
+                ).astype(ifm.dtype)
+                * weight[cc, rh, rw, 0].astype(ifm.dtype)
+                # This is a trick to load 10 elements of the scale_bias at once, not accurate maths
+                + (scale_bias[cc, 0] * scale_bias[cc, 9] + lut_expr).astype(ifm.dtype)
+            ).astype(ofm_dtype),
             axis=[rh, rw],
         ),
         name="ethosu_depthwise_conv2d",
diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py
index 45a82d5932d63..21b0ecf789d2a 100644
--- a/python/tvm/relay/backend/contrib/ethosu/util.py
+++ b/python/tvm/relay/backend/contrib/ethosu/util.py
@@ -142,6 +142,31 @@ def is_composite_func(func: relay.Function, name: str) -> bool:
     return composite_name == name
 
 
+def is_named_ethosu_op(expr: tvm.relay.Expr, name: str) -> bool:
+    """Checks whether a relay expression matches that of the
+    named operator.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+        The expression to check.
+    name : str
+        The name of the expected operator
+        (without NPU prefix "contrib.ethosu").
+
+    Returns
+    -------
+    bool
+        True if expression matches name, false if not.
+    """
+    prefix = "contrib.ethosu."
+    return (
+        isinstance(expr, tvm.relay.expr.Call)
+        and isinstance(expr.op, tvm.ir.op.Op)
+        and expr.op.name == prefix + name
+    )
+
+
 def get_range_for_dtype_str(dtype: str) -> Tuple[int, int]:
     """
     Produce the min,max for a give data type.
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index 54185b2fca41d..e53be02b8dd2a 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -947,6 +947,92 @@ def tanh_pattern():
     return quant
 
 
+class MeanParams:
+    """
+    This class will parse a call to an ethos-u.mean composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.mean"
+
+    def __init__(self, func_body: Call):
+        requantize = func_body
+        mean_op = requantize.args[0]
+        attrs = mean_op.attrs
+        cast = mean_op.args[0]
+
+        layout = "NHWC"
+        self.ifm = TensorParams(
+            cast.args[0],
+            layout,
+            requantize.args[RequantArgs.IFM_SCALE.value],
+            requantize.args[RequantArgs.IFM_ZERO_POINT.value],
+        )
+        self.ofm = TensorParams(
+            requantize,
+            layout,
+            requantize.args[RequantArgs.OFM_SCALE.value],
+            requantize.args[RequantArgs.OFM_ZERO_POINT.value],
+        )
+
+        ifm_shape = self.ifm.shape
+        self.height = ifm_shape[0] if len(ifm_shape) in (2, 3) else ifm_shape[1]
+        self.width = ifm_shape[1] if len(ifm_shape) in (2, 3) else ifm_shape[2]
+        self.keepdims = attrs.keepdims
+
+        self.axis = list(sorted(attrs.axis))
+        if attrs.exclude:
+            self.axis = [i for i in range(len(self.ifm.shape)) if i not in self.axis]
+
+    def is_valid(self) -> bool:
+        """
+        Checks whether Mean has compatible attributes with the hardware.
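+
+        In short: both ifm and ofm must be int8, the input must be 2D, 3D or 4D,
+        the reduction axis must be one the NPU supports, and the reduced input
+        size must fit the limits of whichever legalization (Case 1-3 in
+        MeanRewriter) would be chosen.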
+        """
+
+        def check_axis(num_dims, axis):
+            if num_dims in (2, 3):
+                return axis in ([0], [1], [0, 1])
+            return axis in ([1], [2], [1, 2])
+
+        tensor_params = [self.ifm, self.ofm]
+        if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8]):
+            return False
+        if self.ifm.dtype != self.ofm.dtype:
+            return False
+        if len(self.ifm.shape) not in [2, 3, 4]:
+            return False
+        if not check_axis(len(self.ifm.shape), self.axis):
+            return False
+
+        # MEAN has further restrictions on the input size, depending on legalization method.
+        input_size = self.height * self.width
+        if input_size > 65536:
+            return False
+        if (
+            self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32
+            or self.ifm.q_params.zero_point != self.ofm.q_params.zero_point
+        ) and input_size > 4096:
+            return False
+        if self.axis == [1, 2] and self.keepdims and self.ifm.dtype == "int8" and input_size > 256:
+            return False
+        # The large-kernel reshape (for heights > 64) only applies when axis is [1, 2]
+        if self.axis != [1, 2] and self.height > 64:
+            return False
+        return True
+
+
+def mean_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for mean.
+    """
+    pattern = is_op("cast")(wildcard())
+    pattern = is_op("mean")(pattern)
+    pattern = is_op("qnn.requantize")(
+        pattern, is_constant(), is_constant(), is_constant(), is_constant()
+    )
+    return pattern
+
+
 @register_pattern_table("ethos-u")
 def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
     return [
@@ -1016,6 +1102,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             lambda pat: AbsParams(pat).is_valid(),
         ),
         (TanhParams.composite_name, tanh_pattern(), lambda pat: TanhParams(pat).is_valid()),
+        (
+            MeanParams.composite_name,
+            mean_pattern(),
+            lambda pat: MeanParams(pat).is_valid(),
+        ),
     ]
 
diff --git a/src/relay/op/contrib/ethosu/binary_elementwise.cc b/src/relay/op/contrib/ethosu/binary_elementwise.cc
index 48b085a2b6f24..4e0d086e66b86 100644
--- a/src/relay/op/contrib/ethosu/binary_elementwise.cc
+++ b/src/relay/op/contrib/ethosu/binary_elementwise.cc
@@ -128,6 +128,21 @@ struct EthosuBinaryElementwiseAttrs : public tvm::AttrsNode<EthosuBinaryElementwiseAttrs> {
 
+bool IsScalarTensor(const Array<IndexExpr>& ifm_shape, const DataType& ifm_dtype) {
+  if (ifm_dtype != DataType::UInt(8)) {
+    return false;
+  }
+
+  for (const auto& expr : ifm_shape) {
+    const auto& dim_int_node = expr.as<IntImmNode>();
+    CHECK(dim_int_node) << "Expected IntImmNode for shape dimensions.";
+    int dim = dim_int_node->value;
+    if (dim != 1) return false;
+  }
+
+  return true;
+}
+
 bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                                 const TypeReporter& reporter) {
   const int ifm_index = 0;
@@ -156,7 +171,7 @@ bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, const
     ofm_dtype = DataType::Int(32);
   }
 
-  if (ifm_dtype != ifm2_dtype) {
+  if (ifm_dtype != ifm2_dtype && !IsScalarTensor(ifm2->shape, ifm2_dtype)) {
     reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
                                      << "Invalid operator: expected ethosu_binary_elementwise "
                                      << "type for ifm2 be the same of ifm but was " << ifm2_dtype
@@ -166,11 +181,11 @@ bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, const
 
   if (operator_type == "ADD" || operator_type == "SUB" || operator_type == "MUL") {
     if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8) &&
-        ifm_dtype != DataType::Int(32)) {
+        ifm_dtype != DataType::Int(16) && ifm_dtype != DataType::Int(32)) {
       reporter->GetDiagCtx().EmitFatal(
           Diagnostic::Error(reporter->GetSpan())
          << "Invalid operator: expected ethosu_binary_elementwise " << operator_type
-         << " type(uint8) or type(int8) or type(int32) for ifm but was " << ifm_dtype);
+         << " type(uint8), type(int8), type(int16) or type(int32) for ifm but was " << ifm_dtype);
       return false;
     }
     if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
diff --git a/src/relay/op/contrib/ethosu/depthwise.cc b/src/relay/op/contrib/ethosu/depthwise.cc
index 5a8997a148e07..c95385ad95d83 100644
--- a/src/relay/op/contrib/ethosu/depthwise.cc
+++ b/src/relay/op/contrib/ethosu/depthwise.cc
@@ -56,6 +56,7 @@ struct EthosuDepthwiseConv2DAttrs : public tvm::AttrsNode<EthosuDepthwiseConv2DAttrs> {
+  String ofm_dtype;
@@ ... @@ bool EthosuDepthwiseConv2DRel(const Array<Type>& types, int num_inputs, const At
   const auto* param = attrs.as<EthosuDepthwiseConv2DAttrs>();
   ICHECK(param != nullptr) << "EthosuDepthwiseConv2DAttrs cannot be nullptr.";
 
+  DataType ofm_dtype;
+
+  if (param->ofm_dtype == "int8") {
+    ofm_dtype = DataType::Int(8);
+  } else if (param->ofm_dtype == "uint8") {
+    ofm_dtype = DataType::UInt(8);
+  } else if (param->ofm_dtype == "int16") {
+    ofm_dtype = DataType::Int(16);
+  } else if (param->ofm_dtype == "int32") {
+    ofm_dtype = DataType::Int(32);
+  }
+
   if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
     reporter->GetDiagCtx().EmitFatal(
         Diagnostic::Error(reporter->GetSpan())
@@ -156,6 +172,15 @@ bool EthosuDepthwiseConv2DRel(const Array<Type>& types, int num_inputs, const At
     return false;
   }
 
+  if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
+      ofm_dtype != DataType::Int(16) && ofm_dtype != DataType::Int(32)) {
+    reporter->GetDiagCtx().EmitFatal(
+        Diagnostic::Error(reporter->GetSpan())
+        << "Invalid operator: expected ethosu_depthwise_conv2d output data type "
+        << "type(uint8), type(int8), type(int16) or type(int32) for ofm but was " << ofm_dtype);
+    return false;
+  }
+
   // Collect the ifm, weight and ofm tensors for using in the inference function
   Array<Type> tensor_types = {types[0], types[1], types[4]};
 
@@ -169,7 +194,7 @@ bool EthosuDepthwiseConv2DRel(const Array<Type>& types, int num_inputs, const At
       EthosuInferKernelOutput(ifm->shape, param->ifm_layout, param->ofm_layout, param->kernel_shape,
                               param->ofm_channels, param->dilation, param->strides, param->padding);
 
-  reporter->Assign(types[4], TensorType(ofm_shape, ifm->dtype));
+  reporter->Assign(types[4], TensorType(ofm_shape, ofm_dtype));
   return true;
 }
 
@@ -180,7 +205,8 @@ Expr MakeEthosuDepthwiseConv2D(Expr ifm, Expr weight, Expr scale_bias, Expr lut,
                                IndexExpr ofm_channels, Array<IndexExpr> strides,
                                Array<IndexExpr> padding, Array<IndexExpr> dilation,
                                String activation, int clip_min, int clip_max, String rounding_mode,
-                               String upscale, String ifm_layout, String ofm_layout) {
+                               String upscale, String ifm_layout, String ofm_layout,
+                               String ofm_dtype) {
   auto attrs = make_object<EthosuDepthwiseConv2DAttrs>();
   attrs->ifm_scale = ifm_scale;
   attrs->ifm_zero_point = ifm_zero_point;
@@ -199,6 +225,7 @@ Expr MakeEthosuDepthwiseConv2D(Expr ifm, Expr weight, Expr scale_bias, Expr lut,
   attrs->upscale = std::move(upscale);
   attrs->ifm_layout = std::move(ifm_layout);
   attrs->ofm_layout = std::move(ofm_layout);
+  attrs->ofm_dtype = std::move(ofm_dtype);
   static const Op& op = Op::Get("contrib.ethosu.depthwise_conv2d");
   return Call(op, {ifm, weight, scale_bias, lut}, Attrs(attrs), {});
 }
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index f4393d409d048..afd635d96cc66 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -502,6 +502,107 @@ def representative_dataset():
     infra.verify_source(compiled_models, accel_type)
 
 
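+# In the test below (explanatory note): `use_same_quantization` selects how the
+# reference module is built. False converts a quantized TFLite model, whose ifm/ofm
+# quantization parameters generally differ; True builds a Relay graph with identical
+# ifm/ofm quantization parameters, exercising the average-pool legalization path.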
+@pytest.mark.parametrize(
+    "accel_type",
+    ACCEL_TYPES,
+)
+@pytest.mark.parametrize(
+    "ifm_shape, axis, keep_dims, use_same_quantization",
+    [
+        # mean to depthwise + multiply
+        [(1, 8, 16, 16), (1, 2), True, False],
+        [(1, 3, 4), (0, 1), True, False],
+        [(1, 65, 2, 1), (1, 2), True, False],  # special case when h > 64
+        # mean to average pool
+        [(1, 8, 16, 16), (2,), False, True],
+        [(3, 3, 4), (0,), True, True],
+        [(8, 5), (0,), False, True],
+        # mean to depthwise
+        [(1, 8, 16, 16), (2,), True, False],
+        [(1, 8, 16, 16), (2, 1), False, False],
+        [(8, 4), (0,), False, False],
+    ],
+)
+def test_mean(accel_type, ifm_shape, axis, keep_dims, use_same_quantization):
+    dtype = "int8"
+
+    def create_mod_from_tflite():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                op = tf.math.reduce_mean(x, axis=axis, keepdims=keep_dims)
+                return op
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_graph = converter.convert()
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+        mod, _ = relay.frontend.from_tflite(
+            tflite_model,
+            shape_dict={"ifm": ifm_shape},
+            dtype_dict={"ifm": dtype},
+        )
+        input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+        return mod, input_data, output_data
+
+    def create_mod_from_relay():
+        ifm = relay.var("input", shape=ifm_shape, dtype=dtype)
+        cast = relay.cast(ifm, dtype="int32")
+        mean = relay.mean(cast, axis=axis, keepdims=keep_dims)
+        requantize = relay.qnn.op.requantize(
+            mean,
+            input_scale=relay.const(1.0, dtype="float32"),
+            input_zero_point=relay.const(0, dtype="int32"),
+            output_scale=relay.const(1.0, dtype="float32"),
+            output_zero_point=relay.const(0, dtype="int32"),
+        )
+
+        func = relay.Function(relay.analysis.free_vars(requantize), requantize)
+        mod = tvm.IRModule.from_expr(func)
+
+        input_data = {"input": np.random.randint(low=-127, high=128, size=ifm_shape, dtype=dtype)}
+        output_data = generate_ref_data(mod, input_data)
+        return mod, input_data, output_data
+
+    mod, input_data, output_data = (
+        create_mod_from_relay() if use_same_quantization else create_mod_from_tflite()
+    )
+    mod = partition_for_ethosu(mod)
+
+    # TODO(lhutton1) For now output is not bit exact with TFLite.
+    # This is because TFLite reference kernels are not being used.
+    # For this, TFLite will need upgrading to 2.6.
+    compiled_models = infra.build_source(
+        mod, input_data, output_data, accel_type, output_tolerance=1
+    )
+
+    # Assumes only two runtime.Modules are created -- i.e. a single offload module
+    ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
+
+    # Verify generated C source
+    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
+    infra.print_payload(cmms)
+    infra.verify_source(compiled_models, accel_type)
+
+
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
 def test_binary_add_from_constant_scalar(accel_type):
     dtype = "uint8"
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 64bdae5c1b8ba..00589a98e9f9a 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -28,6 +28,7 @@
 from tvm.relay.backend.contrib.ethosu import legalize, preprocess
 from tvm.relay import dataflow_pattern
 from tvm.relay.op.contrib import ethosu
+from tvm.relay.backend.contrib.ethosu import util
 from tvm.relay.build_module import bind_params_by_name
 
 from . import relay_ir_builder
@@ -1059,5 +1060,164 @@ def representative_dataset():
     assert tuple(func_body.args[1].checked_type.shape) == (256,)
 
 
+@pytest.mark.parametrize(
+    "ifm_shape, axis, keep_dims, use_same_quantization",
+    [
+        # mean to depthwise + multiply
+        [(1, 8, 16, 16), (1, 2), True, False],
+        [(1, 8, 16, 16), (2, 1), True, False],
+        [(1, 3, 4), (0, 1), True, False],
+        [(8, 5), (1, 0), True, False],
+        [(1, 65, 2, 1), (1, 2), True, False],  # special case when h > 64
+        # mean to average pool
+        [(1, 8, 16, 16), (1,), True, True],
+        [(1, 8, 16, 16), (2,), False, True],
+        [(1, 8, 16, 16), (1, 2), False, True],
+        [(3, 3, 4), (0,), True, True],
+        [(3, 3, 4), (1,), False, True],
+        [(8, 5), (0,), False, True],
+        [(8, 5), (1,), True, True],
+        # mean to depthwise
+        [(1, 8, 16, 16), (1,), True, False],
+        [(1, 8, 16, 16), (2,), True, False],
+        [(1, 8, 16, 16), (1, 2), False, False],
+        [(8, 4), (0,), False, False],
+    ],
+)
+def test_mean(ifm_shape, axis, keep_dims, use_same_quantization):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                op = tf.math.reduce_mean(x, axis=axis, keepdims=keep_dims)
+                return op
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model, 0)
+
+        mod, _ = relay.frontend.from_tflite(
+            tflite_model,
+            shape_dict={"input": ifm_shape},
+            dtype_dict={"input": dtype},
+        )
+        return mod
+
+    def create_relay_graph_with_same_quantization():
+        ifm = relay.var("input", shape=ifm_shape, dtype=dtype)
+        cast = relay.cast(ifm, dtype="int32")
+        mean = relay.mean(cast, axis=axis, keepdims=keep_dims)
+        requantize = relay.qnn.op.requantize(
+            mean,
+            input_scale=relay.const(1.0, dtype="float32"),
+            input_zero_point=relay.const(0, dtype="int32"),
+            output_scale=relay.const(1.0, dtype="float32"),
+            output_zero_point=relay.const(0, dtype="int32"),
+        )
+
+        func = relay.Function(relay.analysis.free_vars(requantize), requantize)
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    def verify(ext_func):
+        out_var = ext_func.body
+
+        next_op = out_var
+        mul_op = None
+        pooling_op = None
+        depthwise_op = None
+        if (
+            isinstance(next_op, relay.expr.Call)
+            and isinstance(next_op.op, tvm.ir.op.Op)
+            and next_op.op.name == "reshape"
+        ):
+            next_op = next_op.args[0]
+        if util.is_named_ethosu_op(next_op, "binary_elementwise"):
+            mul_op = next_op
+            next_op = next_op.args[0]
+        if util.is_named_ethosu_op(next_op, "pooling"):
+            pooling_op = next_op
+            next_op = next_op.args[0]
+        if util.is_named_ethosu_op(next_op, "depthwise_conv2d"):
+            depthwise_op = next_op
+            next_op = next_op.args[0]
+        while (
+            isinstance(next_op, relay.expr.Call)
+            and isinstance(next_op.op, tvm.ir.op.Op)
+            and next_op.op.name == "reshape"
+        ):
+            next_op = next_op.args[0]
+        in_var = next_op
+
+        def calculate_expected_output_shape():
+            for i in range(len(ifm_shape)):
+                if i in axis:
+                    if keep_dims:
+                        yield 1
+                else:
+                    yield ifm_shape[i]
+
+        out_shape = tuple(calculate_expected_output_shape())
+
+        # check IFM
+        assert tuple(in_var.checked_type.shape) == ifm_shape
+        assert in_var.checked_type.dtype == dtype
+
+        # check OFM
+        assert tuple(out_var.checked_type.shape) == out_shape
+        assert out_var.checked_type.dtype == dtype
+
+        # check expected legalization case
+        if axis in [(1, 2), (2, 1), (0, 1), (1, 0)] and keep_dims and dtype == "int8":
+            assert depthwise_op and mul_op
+            assert mul_op.attrs.operator_type == "MUL"
+        elif pooling_op:
+            attrs = pooling_op.attrs
+            assert (
+                attrs.ifm_scale == attrs.ofm_scale and attrs.ifm_zero_point == attrs.ofm_zero_point
+            )
+        else:
+            assert depthwise_op
+            assert not mul_op
+
+    rewriter = legalize.MeanRewriter()
+    pattern_table = [
+        (
+            ethosu.MeanParams.composite_name,
+            ethosu.mean_pattern(),
+            lambda pat: ethosu.MeanParams(pat).is_valid(),
+        ),
+    ]
+
+    mod = (
+        create_relay_graph_with_same_quantization()
+        if use_same_quantization
+        else create_tflite_graph()
+    )
+    mod = partition_ethosu_by_table(mod, pattern_table)
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        rewriter, mod["tvmgen_default_ethos_u_main_0"]
+    )
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
 if __name__ == "__main__":
     pytest.main([__file__])