Skip to content

Commit

Permalink
[microNPU] Refactor type inference data type checks (#10060)
Browse files Browse the repository at this point in the history
* [microNPU] Refactor type inference data type checks

Aims to improve readability, extendibility and error message
unification for data type checks across NPU operators.

A follow up for the comments in #9576.

Change-Id: I83fb89a56677003f7abebb7985ad60d92cfa8df1

* unordered_set -> initializer_list and use new format for upscale check

Change-Id: Icf3d68d5cc7d5e1d5af42b1af193db89faea155e

* remove unused header and use auto for initializer type

Change-Id: I10311b718c3abd0ed75dd88b5ec9de6e0742f047
  • Loading branch information
lhutton1 authored Feb 9, 2022
1 parent 9282367 commit 86e1e56
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 198 deletions.
102 changes: 19 additions & 83 deletions src/relay/op/contrib/ethosu/binary_elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,98 +143,34 @@ bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, const
const auto* param = attrs.as<EthosuBinaryElementwiseAttrs>();
ICHECK(param != nullptr) << "EthosuBinaryElementwiseAttrs cannot be nullptr.";

String operator_type = param->operator_type;
auto ifm_dtype = ifm->dtype;
auto ifm2_dtype = ifm2->dtype;
DataType ofm_dtype;
const String operator_name = "ethosu_binary_elementwise";
const String operator_type = param->operator_type;
const DataType ifm_dtype = ifm->dtype;
const DataType ifm2_dtype = ifm2->dtype;
const DataType ofm_dtype = DataTypeFromString(param->ofm_dtype);

if (param->ofm_dtype == "int8") {
ofm_dtype = DataType::Int(8);
} else if (param->ofm_dtype == "uint8") {
ofm_dtype = DataType::UInt(8);
} else if (param->ofm_dtype == "int16") {
ofm_dtype = DataType::Int(16);
} else if (param->ofm_dtype == "int32") {
ofm_dtype = DataType::Int(32);
}

if (ifm_dtype != ifm2_dtype) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise "
<< "type for ifm2 be the same of ifm but was " << ifm2_dtype
<< " instead of " << ifm_dtype);
return false;
}
CheckDataTypeMatch(reporter, ifm_dtype, ifm2_dtype, operator_name, "ifm", "ifm2", operator_type);

if (operator_type == "ADD" || operator_type == "SUB" || operator_type == "MUL") {
if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8) &&
ifm_dtype != DataType::Int(16) && ifm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise " << operator_type
<< " type(uint8), type(int8), type(int16) or type(int32) for ifm but was " << ifm_dtype);
return false;
}
if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
ofm_dtype != DataType::Int(16) && ofm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise " << operator_type
<< " type(uint8), type(int8), type(int16) or type(int32) for ofm but was " << ofm_dtype);
return false;
}
auto allowed_types = {DataType::Int(8), DataType::UInt(8), DataType::Int(16),
DataType::Int(32)};
CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type);
CheckDataType(reporter, ofm_dtype, allowed_types, operator_name, "ofm", operator_type);
} else if (operator_type == "MIN" || operator_type == "MAX") {
if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise " << operator_type
<< " type(uint8) or type(int8) for ifm but was " << ifm_dtype);
return false;
}
if (ifm_dtype != ofm_dtype) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise "
<< operator_type
<< " type for ofm be the same of ifm but was " << ofm_dtype
<< " instead of " << ifm_dtype);
return false;
}
auto allowed_types = {DataType::Int(8), DataType::UInt(8)};
CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type);
CheckDataTypeMatch(reporter, ifm_dtype, ofm_dtype, operator_name, "ifm", "ofm", operator_type);
} else if (operator_type == "SHR") {
if (ifm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise "
<< operator_type << " type(int32) for ifm but was "
<< ifm_dtype);
return false;
}
if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
ofm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise " << operator_type
<< " type(uint8) or type(int8) or type(int32) for ofm but was " << ofm_dtype);
return false;
}
CheckDataType(reporter, ifm_dtype, {DataType::Int(32)}, operator_name, "ifm", operator_type);
CheckDataType(reporter, ofm_dtype, {DataType::UInt(8), DataType::Int(8), DataType::Int(32)},
operator_name, "ofm", operator_type);
} else if (operator_type == "SHL") {
if (ifm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise "
<< operator_type << " type(int32) for ifm but was "
<< ifm_dtype);

return false;
}
if (ofm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise "
<< operator_type << " type(int32) for ofm but was "
<< ofm_dtype);
return false;
}
CheckDataType(reporter, ifm_dtype, {DataType::Int(32)}, operator_name, "ifm", operator_type);
CheckDataType(reporter, ofm_dtype, {DataType::Int(32)}, operator_name, "ofm", operator_type);
} else {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise 'ADD' or 'SUB' or 'MUL' or "
<< "Invalid operator: expected " << operator_name << " 'ADD' or 'SUB' or 'MUL' or "
<< "'MIN' or 'MAX' or 'SHR' or 'SHL' for operator_type but was " << param->operator_type);
return false;
}
Expand Down
81 changes: 81 additions & 0 deletions src/relay/op/contrib/ethosu/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

#include "common.h"

#include <sstream>

#include "../../op_common.h"

namespace tvm {
Expand Down Expand Up @@ -92,6 +94,85 @@ Array<IndexExpr> EthosuInferUpscaledInput(Array<IndexExpr> ifm_shape, String ifm
return new_ifm_shape;
}

DataType DataTypeFromString(const String& dtype) {
DLDataType dl_dtype = tvm::runtime::String2DLDataType(dtype);
return DataType(dl_dtype);
}

void CheckDataType(const TypeReporter& reporter, const DataType& data_type,
const std::initializer_list<DataType>& allowed_data_types,
const String& operator_name, const String& tensor_name,
const String& operator_type) {
for (const auto& i : allowed_data_types) {
if (data_type == i) {
return;
}
}

std::ostringstream message;
message << "Invalid operator: expected " << operator_name << " ";
if (operator_type != "") {
message << operator_type << " ";
}
message << "to have type in {";
for (auto it = allowed_data_types.begin(); it != allowed_data_types.end(); ++it) {
message << *it;
if (std::next(it) != allowed_data_types.end()) {
message << ", ";
}
}
message << "}";
message << " for " << tensor_name << " but was " << data_type << ".";

reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str());
}

void CheckUpscaleMethod(const TypeReporter& reporter, const String& upscale_method,
const std::initializer_list<String>& allowed_upscale_methods,
const String& operator_name, const String& operator_type) {
for (const auto& i : allowed_upscale_methods) {
if (upscale_method == i) {
return;
}
}

std::ostringstream message;
message << "Invalid operator: expected " << operator_name << " ";
if (operator_type != "") {
message << operator_type << " ";
}
message << "to have upscale method in {";
for (auto it = allowed_upscale_methods.begin(); it != allowed_upscale_methods.end(); ++it) {
message << *it;
if (std::next(it) != allowed_upscale_methods.end()) {
message << ", ";
}
}
message << "}";
message << " but was " << upscale_method << ".";

reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str());
}

void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type,
const DataType& data_type2, const String& operator_name,
const String& tensor_name, const String& tensor_name2,
const String& operator_type) {
if (data_type == data_type2) {
return;
}

std::ostringstream message;
message << "Invalid operator: expected " << operator_name << " ";
if (operator_type != "") {
message << operator_type << " ";
}
message << "data types for " << tensor_name << " and " << tensor_name2 << " to match, but was "
<< data_type << " and " << data_type2;

reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str());
}

} // namespace ethosu
} // namespace contrib
} // namespace op
Expand Down
46 changes: 46 additions & 0 deletions src/relay/op/contrib/ethosu/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,52 @@ Array<IndexExpr> EthosuInferKernelOutput(Array<IndexExpr> ifm_shape, String ifm_
*/
Array<IndexExpr> EthosuInferUpscaledInput(Array<IndexExpr> ifm_shape, String ifm_layout);

/*! \brief Get data type from string representation.
* \param dtype Data type in lower case format followed by number of bits e.g. "int8".
*/
DataType DataTypeFromString(const String& dtype);

/*! \brief Check the data type for a given input matches one given in allowed_data_types. Raise a
* type inference error if not.
* \param reporter The infer type reporter.
* \param data_type The data type to check.
* \param allowed_data_types An initializer list of allowed data types.
* \param operator_name The name of the operator to report.
* \param tensor_name The name of the tensor to report e.g. "ifm", "ofm".
* \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise.
*/
void CheckDataType(const TypeReporter& reporter, const DataType& data_type,
const std::initializer_list<DataType>& allowed_data_types,
const String& operator_name, const String& tensor_name,
const String& operator_type = "");

/*! \brief Check the upscale method matches one given in allowed_upscale_methods. Raise a type
* inference error if not.
* \param reporter The infer type reporter.
* \param upscale_method The upscale method string to check.
* \param allowed_upscale_methods An initializer list of allowed upscale methods.
* \param operator_name The name of the operator to report.
* \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise.
*/
void CheckUpscaleMethod(const TypeReporter& reporter, const String& upscale_method,
const std::initializer_list<String>& allowed_upscale_methods,
const String& operator_name, const String& operator_type = "");

/*! \brief Check the data type matches that of the second data type provided. Raise a type inference
* error if not.
* \param reporter The infer type reporter.
* \param data_type The data type to check.
* \param data_type2 The second data type to check.
* \param operator_name The name of the operator to report.
* \param tensor_name The name of the tensor to report e.g. "ifm", "ofm".
* \param tensor_name2 The name of the second tensor to report e.g. "ifm2".
* \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise.
*/
void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type,
const DataType& data_type2, const String& operator_name,
const String& tensor_name, const String& tensor_name2,
const String& operator_type = "");

} // namespace ethosu
} // namespace contrib
} // namespace op
Expand Down
35 changes: 6 additions & 29 deletions src/relay/op/contrib/ethosu/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,37 +131,14 @@ bool EthosuConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attr
if (ifm == nullptr || weight == nullptr) return false;
const auto* param = attrs.as<EthosuConv2DAttrs>();
CHECK(param != nullptr) << "EthosuConv2DAttrs cannot be nullptr.";
const String operator_name = "ethosu_conv2d";

if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_conv2d input data type "
<< "of type(uint8) or type(int8) but was " << ifm->dtype);
return false;
}

if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_conv2d weight data type "
<< "of type(uint8) or type(int8) but was " << weight->dtype);
return false;
}
CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm");
CheckDataType(reporter, weight->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name,
"weight");
CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias");

if (scale_bias->dtype != DataType::UInt(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_conv2d scale bias data type "
<< "of type(uint8) but was " << scale_bias->dtype);
return false;
}

const std::unordered_set<std::string> upscale_methods = {"NONE", "ZEROS", "NEAREST"};
if (upscale_methods.find(param->upscale) == upscale_methods.end()) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: Expected upsample method to be 'NONE', "
"'ZEROS' or 'NEAREST' but got "
<< param->upscale);
return false;
}
CheckUpscaleMethod(reporter, param->upscale, {"NONE", "ZEROS", "NEAREST"}, operator_name);

// The scale_bias should be provided as a tensor of size {ofm_channels, 10}
reporter->Assign(types[2], TensorType({weight->shape[0], 10}, DataType::UInt(8)));
Expand Down
50 changes: 8 additions & 42 deletions src/relay/op/contrib/ethosu/depthwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,50 +136,16 @@ bool EthosuDepthwiseConv2DRel(const Array<Type>& types, int num_inputs, const At
const auto* param = attrs.as<EthosuDepthwiseConv2DAttrs>();
ICHECK(param != nullptr) << "EthosuDepthwiseConv2DAttrs cannot be nullptr.";

DataType ofm_dtype;
const String operator_name = "ethosu_depthwise_conv2d";

if (param->ofm_dtype == "int8") {
ofm_dtype = DataType::Int(8);
} else if (param->ofm_dtype == "uint8") {
ofm_dtype = DataType::UInt(8);
} else if (param->ofm_dtype == "int16") {
ofm_dtype = DataType::Int(16);
} else if (param->ofm_dtype == "int32") {
ofm_dtype = DataType::Int(32);
}

if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d input data type "
<< "of type(uint8) or type(int8) but was " << ifm->dtype);
return false;
}

if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d weight data type "
<< "of type(uint8) or type(int8) but was " << weight->dtype);
return false;
}

if (scale_bias->dtype != DataType::UInt(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d scale bias data type "
<< "of type(uint8) but was " << scale_bias->dtype);
return false;
}
CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm");
CheckDataType(reporter, weight->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name,
"weight");
CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias");

if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
ofm_dtype != DataType::Int(16) && ofm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d output data type "
<< " type(uint8), type(int8), type(int16) or type(int32) for ofm but was " << ofm_dtype);
return false;
}
DataType ofm_dtype = DataTypeFromString(param->ofm_dtype);
auto ofm_dtypes = {DataType::UInt(8), DataType::Int(8), DataType::Int(16), DataType::Int(32)};
CheckDataType(reporter, ofm_dtype, ofm_dtypes, operator_name, "ofm");

// Collect the ifm, weight and ofm tensors for using in the inference function
Array<Type> tensor_types = {types[0], types[1], types[4]};
Expand Down
10 changes: 3 additions & 7 deletions src/relay/op/contrib/ethosu/identity.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,11 @@ bool EthosuIdentityRel(const Array<Type>& types, int num_inputs, const Attrs& at
if (ifm == nullptr) return false;

const auto* param = attrs.as<EthosuIdentityAttrs>();

ICHECK(param != nullptr) << "EthosuIdentityAttrs cannot be nullptr.";

if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: Expected type(uint8) or type(int8) for ifm but was " << ifm->dtype);
return false;
}
const String operator_name = "ethosu_identity";

CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm");

if (ifm->shape.size() > 4) {
reporter->GetDiagCtx().EmitFatal(
Expand Down
Loading

0 comments on commit 86e1e56

Please sign in to comment.