From 47daacc28d7e11a96558a536fd5b57dd81718c5d Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Fri, 26 Nov 2021 22:15:03 +0000 Subject: [PATCH] address comments and correct depthwise type inference Change-Id: I3c82f9b492a58082c98c200d42d5413451740504 --- .../relay/backend/contrib/ethosu/legalize.py | 2 +- .../op/contrib/ethosu/binary_elementwise.cc | 4 ++-- src/relay/op/contrib/ethosu/depthwise.cc | 23 ++++++++++++++++++- 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index e01a532ef0d9b..969766b289303 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -1041,7 +1041,7 @@ def callback( @ir.transform.module_pass(opt_level=1) class LegalizeMean: - """This is the pass that wraps the AddRewriter""" + """This is the pass that wraps the MeanRewriter""" def transform_module( self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext diff --git a/src/relay/op/contrib/ethosu/binary_elementwise.cc b/src/relay/op/contrib/ethosu/binary_elementwise.cc index 0ccbb59c728d4..4e0d086e66b86 100644 --- a/src/relay/op/contrib/ethosu/binary_elementwise.cc +++ b/src/relay/op/contrib/ethosu/binary_elementwise.cc @@ -181,11 +181,11 @@ bool EthosuBinaryElementwiseRel(const Array& types, int num_inputs, const if (operator_type == "ADD" || operator_type == "SUB" || operator_type == "MUL") { if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8) && - ifm_dtype != DataType::Int(32)) { + ifm_dtype != DataType::Int(16) && ifm_dtype != DataType::Int(32)) { reporter->GetDiagCtx().EmitFatal( Diagnostic::Error(reporter->GetSpan()) << "Invalid operator: expected ethosu_binary_elementwise " << operator_type - << " type(uint8) or type(int8) or type(int32) for ifm but was " << ifm_dtype); + << " type(uint8), type(int8), type(int16) or type(int32) for ifm but was " << ifm_dtype); return false; } if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) && diff --git a/src/relay/op/contrib/ethosu/depthwise.cc b/src/relay/op/contrib/ethosu/depthwise.cc index 9ab7682aa6db0..c95385ad95d83 100644 --- a/src/relay/op/contrib/ethosu/depthwise.cc +++ b/src/relay/op/contrib/ethosu/depthwise.cc @@ -136,6 +136,18 @@ bool EthosuDepthwiseConv2DRel(const Array& types, int num_inputs, const At const auto* param = attrs.as(); ICHECK(param != nullptr) << "EthosuDepthwiseConv2DAttrs cannot be nullptr."; + DataType ofm_dtype; + + if (param->ofm_dtype == "int8") { + ofm_dtype = DataType::Int(8); + } else if (param->ofm_dtype == "uint8") { + ofm_dtype = DataType::UInt(8); + } else if (param->ofm_dtype == "int16") { + ofm_dtype = DataType::Int(16); + } else if (param->ofm_dtype == "int32") { + ofm_dtype = DataType::Int(32); + } + if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) { reporter->GetDiagCtx().EmitFatal( Diagnostic::Error(reporter->GetSpan()) @@ -160,6 +172,15 @@ bool EthosuDepthwiseConv2DRel(const Array& types, int num_inputs, const At return false; } + if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) && + ofm_dtype != DataType::Int(16) && ofm_dtype != DataType::Int(32)) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_depthwise_conv2d output data type " + << " type(uint8), type(int8), type(int16) or type(int32) for ofm but was " << ofm_dtype); + return false; + } + // Collect the ifm, weight and ofm tensors for using in the inference function Array tensor_types = {types[0], types[1], types[4]}; @@ -173,7 +194,7 @@ bool EthosuDepthwiseConv2DRel(const Array& types, int num_inputs, const At EthosuInferKernelOutput(ifm->shape, param->ifm_layout, param->ofm_layout, param->kernel_shape, param->ofm_channels, param->dilation, param->strides, param->padding); - reporter->Assign(types[4], TensorType(ofm_shape, ifm->dtype)); + reporter->Assign(types[4], TensorType(ofm_shape, ofm_dtype)); return true; }