diff --git a/src/relay/op/contrib/ethosu/binary_elementwise.cc b/src/relay/op/contrib/ethosu/binary_elementwise.cc index e7622452166c..8966540e0703 100644 --- a/src/relay/op/contrib/ethosu/binary_elementwise.cc +++ b/src/relay/op/contrib/ethosu/binary_elementwise.cc @@ -143,98 +143,34 @@ bool EthosuBinaryElementwiseRel(const Array& types, int num_inputs, const const auto* param = attrs.as(); ICHECK(param != nullptr) << "EthosuBinaryElementwiseAttrs cannot be nullptr."; - String operator_type = param->operator_type; - auto ifm_dtype = ifm->dtype; - auto ifm2_dtype = ifm2->dtype; - DataType ofm_dtype; + const String operator_name = "ethosu_binary_elementwise"; + const String operator_type = param->operator_type; + const DataType ifm_dtype = ifm->dtype; + const DataType ifm2_dtype = ifm2->dtype; + const DataType ofm_dtype = DataTypeFromString(param->ofm_dtype); - if (param->ofm_dtype == "int8") { - ofm_dtype = DataType::Int(8); - } else if (param->ofm_dtype == "uint8") { - ofm_dtype = DataType::UInt(8); - } else if (param->ofm_dtype == "int16") { - ofm_dtype = DataType::Int(16); - } else if (param->ofm_dtype == "int32") { - ofm_dtype = DataType::Int(32); - } - - if (ifm_dtype != ifm2_dtype) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " - << "type for ifm2 be the same of ifm but was " << ifm2_dtype - << " instead of " << ifm_dtype); - return false; - } + CheckDataTypeMatch(reporter, ifm_dtype, ifm2_dtype, operator_name, "ifm", "ifm2", operator_type); if (operator_type == "ADD" || operator_type == "SUB" || operator_type == "MUL") { - if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8) && - ifm_dtype != DataType::Int(16) && ifm_dtype != DataType::Int(32)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " << operator_type - << " type(uint8), type(int8), type(int16) or type(int32) for ifm but was " << ifm_dtype); - return false; - } - if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) && - ofm_dtype != DataType::Int(16) && ofm_dtype != DataType::Int(32)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " << operator_type - << " type(uint8), type(int8), type(int16) or type(int32) for ofm but was " << ofm_dtype); - return false; - } + std::unordered_set allowed_types = {DataType::Int(8), DataType::UInt(8), + DataType::Int(16), DataType::Int(32)}; + CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type); + CheckDataType(reporter, ofm_dtype, allowed_types, operator_name, "ofm", operator_type); } else if (operator_type == "MIN" || operator_type == "MAX") { - if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " << operator_type - << " type(uint8) or type(int8) for ifm but was " << ifm_dtype); - return false; - } - if (ifm_dtype != ofm_dtype) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " - << operator_type - << " type for ofm be the same of ifm but was " << ofm_dtype - << " instead of " << ifm_dtype); - return false; - } + std::unordered_set allowed_types = {DataType::Int(8), DataType::UInt(8)}; + CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type); + CheckDataTypeMatch(reporter, ifm_dtype, ofm_dtype, operator_name, "ifm", "ofm", operator_type); } else if (operator_type == "SHR") { - if (ifm_dtype != DataType::Int(32)) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " - << operator_type << " type(int32) for ifm but was " - << ifm_dtype); - return false; - } - if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) && - ofm_dtype != DataType::Int(32)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " << operator_type - << " type(uint8) or type(int8) or type(int32) for ofm but was " << ofm_dtype); - return false; - } + CheckDataType(reporter, ifm_dtype, {DataType::Int(32)}, operator_name, "ifm", operator_type); + CheckDataType(reporter, ofm_dtype, {DataType::UInt(8), DataType::Int(8), DataType::Int(32)}, + operator_name, "ofm", operator_type); } else if (operator_type == "SHL") { - if (ifm_dtype != DataType::Int(32)) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " - << operator_type << " type(int32) for ifm but was " - << ifm_dtype); - - return false; - } - if (ofm_dtype != DataType::Int(32)) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise " - << operator_type << " type(int32) for ofm but was " - << ofm_dtype); - return false; - } + CheckDataType(reporter, ifm_dtype, {DataType::Int(32)}, operator_name, "ifm", operator_type); + CheckDataType(reporter, ofm_dtype, {DataType::Int(32)}, operator_name, "ofm", operator_type); } else { reporter->GetDiagCtx().EmitFatal( Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_binary_elementwise 'ADD' or 'SUB' or 'MUL' or " + << "Invalid operator: expected " << operator_name << " 'ADD' or 'SUB' or 'MUL' or " << "'MIN' or 'MAX' or 'SHR' or 'SHL' for operator_type but was " << param->operator_type); return false; } diff --git a/src/relay/op/contrib/ethosu/common.cc b/src/relay/op/contrib/ethosu/common.cc index eac576257721..8e705b66bcb5 100644 --- a/src/relay/op/contrib/ethosu/common.cc +++ b/src/relay/op/contrib/ethosu/common.cc @@ -24,6 +24,9 @@ #include "common.h" +#include +#include + #include "../../op_common.h" namespace tvm { @@ -92,6 +95,56 @@ Array EthosuInferUpscaledInput(Array ifm_shape, String ifm return new_ifm_shape; } +DataType DataTypeFromString(const String& dtype) { + DLDataType dl_dtype = tvm::runtime::String2DLDataType(dtype); + return DataType(dl_dtype); +} + +void CheckDataType(const TypeReporter& reporter, const DataType& data_type, + const std::unordered_set& allowed_data_types, + const String& operator_name, const String& tensor_name, + const String& operator_type) { + if (allowed_data_types.find(data_type) != allowed_data_types.end()) { + return; + } + + std::ostringstream message; + message << "Invalid operator: expected " << operator_name << " "; + if (operator_type != "") { + message << operator_type << " "; + } + message << "to have type in {"; + for (auto it = allowed_data_types.begin(); it != allowed_data_types.end(); ++it) { + message << *it; + if (std::next(it) != allowed_data_types.end()) { + message << ", "; + } + } + message << "}"; + message << " for " << tensor_name << " but was " << data_type << "."; + + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str()); +} + +void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type, + const DataType& data_type2, const String& operator_name, + const String& tensor_name, const String& tensor_name2, + const String& operator_type) { + if (data_type == data_type2) { + return; + } + + std::ostringstream message; + message << "Invalid operator: expected " << operator_name << " "; + if (operator_type != " ") { + message << operator_type << " "; + } + message << "data types for " << tensor_name << " and " << tensor_name2 << " to match, but was " + << data_type << " and " << data_type2; + + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str()); +} + } // namespace ethosu } // namespace contrib } // namespace op diff --git a/src/relay/op/contrib/ethosu/common.h b/src/relay/op/contrib/ethosu/common.h index 001b596c0949..9238a7db95b1 100644 --- a/src/relay/op/contrib/ethosu/common.h +++ b/src/relay/op/contrib/ethosu/common.h @@ -27,6 +27,8 @@ #include +#include + namespace tvm { namespace relay { namespace op { @@ -65,6 +67,40 @@ Array EthosuInferKernelOutput(Array ifm_shape, String ifm_ */ Array EthosuInferUpscaledInput(Array ifm_shape, String ifm_layout); +/*! \brief Get data type from string representation. + * \param dtype Data type in lower case format followed by number of bits e.g. "int8". + */ +DataType DataTypeFromString(const String& dtype); + +/*! \brief Check the data type for a given input matches one given in allowed_data_types. Raise a + * type inference error if not. + * \param reporter The infer type reporter. + * \param data_type The data ntype to check. + * \param allowed_data_types An unordered set of allowed data types. + * \param operator_name The name of the operator to report. + * \param tensor_name The name of the tensor to report e.g. "ifm", "ofm". + * \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise. + */ +void CheckDataType(const TypeReporter& reporter, const DataType& data_type, + const std::unordered_set& allowed_data_types, + const String& operator_name, const String& tensor_name, + const String& operator_type = ""); + +/*! \brief Check the data type matches that of the second data type provided. Raise a type inference + * error if not. + * \param reporter The infer type reporter. + * \param data_type The data type to check. + * \param data_type2 The second data type to check. + * \param operator_name The name of the operator to report. + * \param tensor_name The name of the tensor to report e.g. "ifm", "ofm". + * \param tensor_name2 The name of the second tensor to report e.g. "ifm2". + * \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise. + */ +void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type, + const DataType& data_type2, const String& operator_name, + const String& tensor_name, const String& tensor_name2, + const String& operator_type = ""); + } // namespace ethosu } // namespace contrib } // namespace op diff --git a/src/relay/op/contrib/ethosu/convolution.cc b/src/relay/op/contrib/ethosu/convolution.cc index 7b11f61acc12..4d9541ced816 100644 --- a/src/relay/op/contrib/ethosu/convolution.cc +++ b/src/relay/op/contrib/ethosu/convolution.cc @@ -131,28 +131,12 @@ bool EthosuConv2DRel(const Array& types, int num_inputs, const Attrs& attr if (ifm == nullptr || weight == nullptr) return false; const auto* param = attrs.as(); CHECK(param != nullptr) << "EthosuConv2DAttrs cannot be nullptr."; + const String operator_name = "ethosu_conv2d"; - if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_conv2d input data type " - << "of type(uint8) or type(int8) but was " << ifm->dtype); - return false; - } - - if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_conv2d weight data type " - << "of type(uint8) or type(int8) but was " << weight->dtype); - return false; - } - - if (scale_bias->dtype != DataType::UInt(8)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_conv2d scale bias data type " - << "of type(uint8) but was " << scale_bias->dtype); - return false; - } + CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm"); + CheckDataType(reporter, weight->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, + "weight"); + CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias"); const std::unordered_set upscale_methods = {"NONE", "ZEROS", "NEAREST"}; if (upscale_methods.find(param->upscale) == upscale_methods.end()) { diff --git a/src/relay/op/contrib/ethosu/depthwise.cc b/src/relay/op/contrib/ethosu/depthwise.cc index c95385ad95d8..abfe0e3856a1 100644 --- a/src/relay/op/contrib/ethosu/depthwise.cc +++ b/src/relay/op/contrib/ethosu/depthwise.cc @@ -136,50 +136,17 @@ bool EthosuDepthwiseConv2DRel(const Array& types, int num_inputs, const At const auto* param = attrs.as(); ICHECK(param != nullptr) << "EthosuDepthwiseConv2DAttrs cannot be nullptr."; - DataType ofm_dtype; + const String operator_name = "ethosu_depthwise_conv2d"; - if (param->ofm_dtype == "int8") { - ofm_dtype = DataType::Int(8); - } else if (param->ofm_dtype == "uint8") { - ofm_dtype = DataType::UInt(8); - } else if (param->ofm_dtype == "int16") { - ofm_dtype = DataType::Int(16); - } else if (param->ofm_dtype == "int32") { - ofm_dtype = DataType::Int(32); - } - - if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_depthwise_conv2d input data type " - << "of type(uint8) or type(int8) but was " << ifm->dtype); - return false; - } - - if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_depthwise_conv2d weight data type " - << "of type(uint8) or type(int8) but was " << weight->dtype); - return false; - } - - if (scale_bias->dtype != DataType::UInt(8)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_depthwise_conv2d scale bias data type " - << "of type(uint8) but was " << scale_bias->dtype); - return false; - } + CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm"); + CheckDataType(reporter, weight->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, + "weight"); + CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias"); - if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) && - ofm_dtype != DataType::Int(16) && ofm_dtype != DataType::Int(32)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_depthwise_conv2d output data type " - << " type(uint8), type(int8), type(int16) or type(int32) for ofm but was " << ofm_dtype); - return false; - } + DataType ofm_dtype = DataTypeFromString(param->ofm_dtype); + std::unordered_set ofm_dtypes = {DataType::UInt(8), DataType::Int(8), DataType::Int(16), + DataType::Int(32)}; + CheckDataType(reporter, ofm_dtype, ofm_dtypes, operator_name, "ofm"); // Collect the ifm, weight and ofm tensors for using in the inference function Array tensor_types = {types[0], types[1], types[4]}; diff --git a/src/relay/op/contrib/ethosu/identity.cc b/src/relay/op/contrib/ethosu/identity.cc index c2b67477cfe9..350e8028f201 100644 --- a/src/relay/op/contrib/ethosu/identity.cc +++ b/src/relay/op/contrib/ethosu/identity.cc @@ -69,15 +69,11 @@ bool EthosuIdentityRel(const Array& types, int num_inputs, const Attrs& at if (ifm == nullptr) return false; const auto* param = attrs.as(); - ICHECK(param != nullptr) << "EthosuIdentityAttrs cannot be nullptr."; - if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: Expected type(uint8) or type(int8) for ifm but was " << ifm->dtype); - return false; - } + const String operator_name = "ethosu_identity"; + + CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm"); if (ifm->shape.size() > 4) { reporter->GetDiagCtx().EmitFatal( diff --git a/src/relay/op/contrib/ethosu/pooling.cc b/src/relay/op/contrib/ethosu/pooling.cc index dc16c072ebe2..d9861954ac98 100644 --- a/src/relay/op/contrib/ethosu/pooling.cc +++ b/src/relay/op/contrib/ethosu/pooling.cc @@ -123,21 +123,17 @@ bool EthosuPoolingRel(const Array& types, int num_inputs, const Attrs& att const auto* param = attrs.as(); ICHECK(param != nullptr) << "EthosuPoolingAttrs cannot be nullptr."; + const String operator_name = "ethosu_pooling"; + if (param->pooling_type != "AVG" && param->pooling_type != "MAX") { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected pooling_type 'AVG' or 'MAX' but was " - << param->pooling_type); + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected " << operator_name + << " type 'AVG' or 'MAX' but was " << param->pooling_type); return false; } - if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: Expected pool type(uint8) or type(int8) for ifm but was " - << ifm->dtype); - return false; - } + CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm", + param->pooling_type); const std::unordered_set upscale_methods = {"NONE", "ZEROS", "NEAREST"}; if (upscale_methods.find(param->upscale) == upscale_methods.end()) { diff --git a/src/relay/op/contrib/ethosu/unary_elementwise.cc b/src/relay/op/contrib/ethosu/unary_elementwise.cc index 9dc07e031d75..a346f095283c 100644 --- a/src/relay/op/contrib/ethosu/unary_elementwise.cc +++ b/src/relay/op/contrib/ethosu/unary_elementwise.cc @@ -104,30 +104,22 @@ bool EthosuUnaryElementwiseRel(const Array& types, int num_inputs, const A const auto* param = attrs.as(); CHECK(param != nullptr) << "EthosuUnaryElementwiseAttrs cannot be nullptr."; - String operator_type = param->operator_type; + const String operator_name = "ethosu_unary_elementwise"; + const String operator_type = param->operator_type; if (operator_type != "ABS" && operator_type != "CLZ") { reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_unary_elementwise 'ABS' " - "or 'CLZ' for operator_type but was" + << "Invalid operator: expected << " << operator_name + << " 'ABS' or 'CLZ' for operator_type but was" << operator_type); return false; } - auto ifm_dtype = ifm->dtype; - if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8) && operator_type == "ABS") { - reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_unary_elementwise " - << operator_type << "input data type " - << "of type(uint8) or type(int8) but was " << ifm_dtype); - return false; - } - - if (ifm_dtype != DataType::Int(32) && operator_type == "CLZ") { - reporter->GetDiagCtx().EmitFatal( - Diagnostic::Error(reporter->GetSpan()) - << "Invalid operator: expected ethosu_unary_elementwise CLZ input data type " - << "of type(int32) but was " << ifm_dtype); - return false; + const DataType ifm_dtype = ifm->dtype; + if (operator_type == "CLZ") { + CheckDataType(reporter, ifm_dtype, {DataType::Int(32)}, operator_name, "ifm", operator_type); + } else { + CheckDataType(reporter, ifm_dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm", + operator_type); } // Assign ofm type