Skip to content

Commit

Permalink
[microNPU] Refactor type inference data type checks
Browse files Browse the repository at this point in the history
Aims to improve readability, extendibility and error message
unification for data type checks across NPU operators.

A follow up for the comments in apache#9576.

Change-Id: I83fb89a56677003f7abebb7985ad60d92cfa8df1
  • Loading branch information
lhutton1 committed Feb 7, 2022
1 parent 22c488e commit d731ac8
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 182 deletions.
102 changes: 19 additions & 83 deletions src/relay/op/contrib/ethosu/binary_elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,98 +143,34 @@ bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, const
const auto* param = attrs.as<EthosuBinaryElementwiseAttrs>();
ICHECK(param != nullptr) << "EthosuBinaryElementwiseAttrs cannot be nullptr.";

String operator_type = param->operator_type;
auto ifm_dtype = ifm->dtype;
auto ifm2_dtype = ifm2->dtype;
DataType ofm_dtype;
const String operator_name = "ethosu_binary_elementwise";
const String operator_type = param->operator_type;
const DataType ifm_dtype = ifm->dtype;
const DataType ifm2_dtype = ifm2->dtype;
const DataType ofm_dtype = DataTypeFromString(param->ofm_dtype);

if (param->ofm_dtype == "int8") {
ofm_dtype = DataType::Int(8);
} else if (param->ofm_dtype == "uint8") {
ofm_dtype = DataType::UInt(8);
} else if (param->ofm_dtype == "int16") {
ofm_dtype = DataType::Int(16);
} else if (param->ofm_dtype == "int32") {
ofm_dtype = DataType::Int(32);
}

if (ifm_dtype != ifm2_dtype) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise "
<< "type for ifm2 be the same of ifm but was " << ifm2_dtype
<< " instead of " << ifm_dtype);
return false;
}
CheckDataTypeMatch(reporter, ifm_dtype, ifm2_dtype, operator_name, "ifm", "ifm2", operator_type);

if (operator_type == "ADD" || operator_type == "SUB" || operator_type == "MUL") {
if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8) &&
ifm_dtype != DataType::Int(16) && ifm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise " << operator_type
<< " type(uint8), type(int8), type(int16) or type(int32) for ifm but was " << ifm_dtype);
return false;
}
if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
ofm_dtype != DataType::Int(16) && ofm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise " << operator_type
<< " type(uint8), type(int8), type(int16) or type(int32) for ofm but was " << ofm_dtype);
return false;
}
std::unordered_set<DataType> allowed_types = {DataType::Int(8), DataType::UInt(8),
DataType::Int(16), DataType::Int(32)};
CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type);
CheckDataType(reporter, ofm_dtype, allowed_types, operator_name, "ofm", operator_type);
} else if (operator_type == "MIN" || operator_type == "MAX") {
if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise " << operator_type
<< " type(uint8) or type(int8) for ifm but was " << ifm_dtype);
return false;
}
if (ifm_dtype != ofm_dtype) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise "
<< operator_type
<< " type for ofm be the same of ifm but was " << ofm_dtype
<< " instead of " << ifm_dtype);
return false;
}
std::unordered_set<DataType> allowed_types = {DataType::Int(8), DataType::UInt(8)};
CheckDataType(reporter, ifm_dtype, allowed_types, operator_name, "ifm", operator_type);
CheckDataTypeMatch(reporter, ifm_dtype, ofm_dtype, operator_name, "ifm", "ofm", operator_type);
} else if (operator_type == "SHR") {
if (ifm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise "
<< operator_type << " type(int32) for ifm but was "
<< ifm_dtype);
return false;
}
if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
ofm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise " << operator_type
<< " type(uint8) or type(int8) or type(int32) for ofm but was " << ofm_dtype);
return false;
}
CheckDataType(reporter, ifm_dtype, {DataType::Int(32)}, operator_name, "ifm", operator_type);
CheckDataType(reporter, ofm_dtype, {DataType::UInt(8), DataType::Int(8), DataType::Int(32)},
operator_name, "ofm", operator_type);
} else if (operator_type == "SHL") {
if (ifm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise "
<< operator_type << " type(int32) for ifm but was "
<< ifm_dtype);

return false;
}
if (ofm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise "
<< operator_type << " type(int32) for ofm but was "
<< ofm_dtype);
return false;
}
CheckDataType(reporter, ifm_dtype, {DataType::Int(32)}, operator_name, "ifm", operator_type);
CheckDataType(reporter, ofm_dtype, {DataType::Int(32)}, operator_name, "ofm", operator_type);
} else {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise 'ADD' or 'SUB' or 'MUL' or "
<< "Invalid operator: expected " << operator_name << " 'ADD' or 'SUB' or 'MUL' or "
<< "'MIN' or 'MAX' or 'SHR' or 'SHL' for operator_type but was " << param->operator_type);
return false;
}
Expand Down
53 changes: 53 additions & 0 deletions src/relay/op/contrib/ethosu/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

#include "common.h"

#include <sstream>
#include <unordered_set>

#include "../../op_common.h"

namespace tvm {
Expand Down Expand Up @@ -92,6 +95,56 @@ Array<IndexExpr> EthosuInferUpscaledInput(Array<IndexExpr> ifm_shape, String ifm
return new_ifm_shape;
}

DataType DataTypeFromString(const String& dtype) {
DLDataType dl_dtype = tvm::runtime::String2DLDataType(dtype);
return DataType(dl_dtype);
}

void CheckDataType(const TypeReporter& reporter, const DataType& data_type,
const std::unordered_set<DataType>& allowed_data_types,
const String& operator_name, const String& tensor_name,
const String& operator_type) {
if (allowed_data_types.find(data_type) != allowed_data_types.end()) {
return;
}

std::ostringstream message;
message << "Invalid operator: expected " << operator_name << " ";
if (operator_type != "") {
message << operator_type << " ";
}
message << "to have type in {";
for (auto it = allowed_data_types.begin(); it != allowed_data_types.end(); ++it) {
message << *it;
if (std::next(it) != allowed_data_types.end()) {
message << ", ";
}
}
message << "}";
message << " for " << tensor_name << " but was " << data_type << ".";

reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str());
}

void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type,
const DataType& data_type2, const String& operator_name,
const String& tensor_name, const String& tensor_name2,
const String& operator_type) {
if (data_type == data_type2) {
return;
}

std::ostringstream message;
message << "Invalid operator: expected " << operator_name << " ";
if (operator_type != " ") {
message << operator_type << " ";
}
message << "data types for " << tensor_name << " and " << tensor_name2 << " to match, but was "
<< data_type << " and " << data_type2;

reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str());
}

} // namespace ethosu
} // namespace contrib
} // namespace op
Expand Down
36 changes: 36 additions & 0 deletions src/relay/op/contrib/ethosu/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#include <tvm/relay/expr.h>

#include <unordered_set>

namespace tvm {
namespace relay {
namespace op {
Expand Down Expand Up @@ -65,6 +67,40 @@ Array<IndexExpr> EthosuInferKernelOutput(Array<IndexExpr> ifm_shape, String ifm_
*/
Array<IndexExpr> EthosuInferUpscaledInput(Array<IndexExpr> ifm_shape, String ifm_layout);

/*! \brief Get data type from string representation.
* \param dtype Data type in lower case format followed by number of bits e.g. "int8".
*/
DataType DataTypeFromString(const String& dtype);

/*! \brief Check the data type for a given input matches one given in allowed_data_types. Raise a
* type inference error if not.
* \param reporter The infer type reporter.
* \param data_type The data ntype to check.
* \param allowed_data_types An unordered set of allowed data types.
* \param operator_name The name of the operator to report.
* \param tensor_name The name of the tensor to report e.g. "ifm", "ofm".
* \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise.
*/
void CheckDataType(const TypeReporter& reporter, const DataType& data_type,
const std::unordered_set<DataType>& allowed_data_types,
const String& operator_name, const String& tensor_name,
const String& operator_type = "");

/*! \brief Check the data type matches that of the second data type provided. Raise a type inference
* error if not.
* \param reporter The infer type reporter.
* \param data_type The data type to check.
* \param data_type2 The second data type to check.
* \param operator_name The name of the operator to report.
* \param tensor_name The name of the tensor to report e.g. "ifm", "ofm".
* \param tensor_name2 The name of the second tensor to report e.g. "ifm2".
* \param operator_type The type of the operator to report e.g. "ADD" for binary_elementwise.
*/
void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type,
const DataType& data_type2, const String& operator_name,
const String& tensor_name, const String& tensor_name2,
const String& operator_type = "");

} // namespace ethosu
} // namespace contrib
} // namespace op
Expand Down
26 changes: 5 additions & 21 deletions src/relay/op/contrib/ethosu/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,28 +131,12 @@ bool EthosuConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attr
if (ifm == nullptr || weight == nullptr) return false;
const auto* param = attrs.as<EthosuConv2DAttrs>();
CHECK(param != nullptr) << "EthosuConv2DAttrs cannot be nullptr.";
const String operator_name = "ethosu_conv2d";

if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_conv2d input data type "
<< "of type(uint8) or type(int8) but was " << ifm->dtype);
return false;
}

if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_conv2d weight data type "
<< "of type(uint8) or type(int8) but was " << weight->dtype);
return false;
}

if (scale_bias->dtype != DataType::UInt(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_conv2d scale bias data type "
<< "of type(uint8) but was " << scale_bias->dtype);
return false;
}
CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm");
CheckDataType(reporter, weight->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name,
"weight");
CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias");

const std::unordered_set<std::string> upscale_methods = {"NONE", "ZEROS", "NEAREST"};
if (upscale_methods.find(param->upscale) == upscale_methods.end()) {
Expand Down
51 changes: 9 additions & 42 deletions src/relay/op/contrib/ethosu/depthwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,50 +136,17 @@ bool EthosuDepthwiseConv2DRel(const Array<Type>& types, int num_inputs, const At
const auto* param = attrs.as<EthosuDepthwiseConv2DAttrs>();
ICHECK(param != nullptr) << "EthosuDepthwiseConv2DAttrs cannot be nullptr.";

DataType ofm_dtype;
const String operator_name = "ethosu_depthwise_conv2d";

if (param->ofm_dtype == "int8") {
ofm_dtype = DataType::Int(8);
} else if (param->ofm_dtype == "uint8") {
ofm_dtype = DataType::UInt(8);
} else if (param->ofm_dtype == "int16") {
ofm_dtype = DataType::Int(16);
} else if (param->ofm_dtype == "int32") {
ofm_dtype = DataType::Int(32);
}

if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d input data type "
<< "of type(uint8) or type(int8) but was " << ifm->dtype);
return false;
}

if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d weight data type "
<< "of type(uint8) or type(int8) but was " << weight->dtype);
return false;
}

if (scale_bias->dtype != DataType::UInt(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d scale bias data type "
<< "of type(uint8) but was " << scale_bias->dtype);
return false;
}
CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm");
CheckDataType(reporter, weight->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name,
"weight");
CheckDataType(reporter, scale_bias->dtype, {DataType::UInt(8)}, operator_name, "scale bias");

if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
ofm_dtype != DataType::Int(16) && ofm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d output data type "
<< " type(uint8), type(int8), type(int16) or type(int32) for ofm but was " << ofm_dtype);
return false;
}
DataType ofm_dtype = DataTypeFromString(param->ofm_dtype);
std::unordered_set<DataType> ofm_dtypes = {DataType::UInt(8), DataType::Int(8), DataType::Int(16),
DataType::Int(32)};
CheckDataType(reporter, ofm_dtype, ofm_dtypes, operator_name, "ofm");

// Collect the ifm, weight and ofm tensors for using in the inference function
Array<Type> tensor_types = {types[0], types[1], types[4]};
Expand Down
10 changes: 3 additions & 7 deletions src/relay/op/contrib/ethosu/identity.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,11 @@ bool EthosuIdentityRel(const Array<Type>& types, int num_inputs, const Attrs& at
if (ifm == nullptr) return false;

const auto* param = attrs.as<EthosuIdentityAttrs>();

ICHECK(param != nullptr) << "EthosuIdentityAttrs cannot be nullptr.";

if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: Expected type(uint8) or type(int8) for ifm but was " << ifm->dtype);
return false;
}
const String operator_name = "ethosu_identity";

CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm");

if (ifm->shape.size() > 4) {
reporter->GetDiagCtx().EmitFatal(
Expand Down
18 changes: 7 additions & 11 deletions src/relay/op/contrib/ethosu/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,21 +123,17 @@ bool EthosuPoolingRel(const Array<Type>& types, int num_inputs, const Attrs& att
const auto* param = attrs.as<EthosuPoolingAttrs>();
ICHECK(param != nullptr) << "EthosuPoolingAttrs cannot be nullptr.";

const String operator_name = "ethosu_pooling";

if (param->pooling_type != "AVG" && param->pooling_type != "MAX") {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected pooling_type 'AVG' or 'MAX' but was "
<< param->pooling_type);
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected " << operator_name
<< " type 'AVG' or 'MAX' but was " << param->pooling_type);
return false;
}

if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: Expected pool type(uint8) or type(int8) for ifm but was "
<< ifm->dtype);
return false;
}
CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm",
param->pooling_type);

const std::unordered_set<std::string> upscale_methods = {"NONE", "ZEROS", "NEAREST"};
if (upscale_methods.find(param->upscale) == upscale_methods.end()) {
Expand Down
Loading

0 comments on commit d731ac8

Please sign in to comment.