Skip to content

Commit

Permalink
address comments and correct depthwise type inference
Browse files Browse the repository at this point in the history
Change-Id: I3c82f9b492a58082c98c200d42d5413451740504
  • Loading branch information
lhutton1 committed Dec 1, 2021
1 parent 647d65e commit 5fa39af
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ def callback(

@ir.transform.module_pass(opt_level=1)
class LegalizeMean:
"""This is the pass that wraps the AddRewriter"""
"""This is the pass that wraps the MeanRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
Expand Down
4 changes: 2 additions & 2 deletions src/relay/op/contrib/ethosu/binary_elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,11 @@ bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, const

if (operator_type == "ADD" || operator_type == "SUB" || operator_type == "MUL") {
if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8) &&
ifm_dtype != DataType::Int(32)) {
ifm_dtype != DataType::Int(16) && ifm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_binary_elementwise " << operator_type
<< " type(uint8) or type(int8) or type(int32) for ifm but was " << ifm_dtype);
<< " type(uint8), type(int8), type(int16) or type(int32) for ifm but was " << ifm_dtype);
return false;
}
if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
Expand Down
23 changes: 22 additions & 1 deletion src/relay/op/contrib/ethosu/depthwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,18 @@ bool EthosuDepthwiseConv2DRel(const Array<Type>& types, int num_inputs, const At
const auto* param = attrs.as<EthosuDepthwiseConv2DAttrs>();
ICHECK(param != nullptr) << "EthosuDepthwiseConv2DAttrs cannot be nullptr.";

DataType ofm_dtype;

if (param->ofm_dtype == "int8") {
ofm_dtype = DataType::Int(8);
} else if (param->ofm_dtype == "uint8") {
ofm_dtype = DataType::UInt(8);
} else if (param->ofm_dtype == "int16") {
ofm_dtype = DataType::Int(16);
} else if (param->ofm_dtype == "int32") {
ofm_dtype = DataType::Int(32);
}

if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
Expand All @@ -160,6 +172,15 @@ bool EthosuDepthwiseConv2DRel(const Array<Type>& types, int num_inputs, const At
return false;
}

if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
ofm_dtype != DataType::Int(16) && ofm_dtype != DataType::Int(32)) {
reporter->GetDiagCtx().EmitFatal(
Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected ethosu_depthwise_conv2d output data type "
<< " type(uint8), type(int8), type(int16) or type(int32) for ofm but was " << ofm_dtype);
return false;
}

// Collect the ifm, weight and ofm tensors for using in the inference function
Array<Type> tensor_types = {types[0], types[1], types[4]};

Expand All @@ -173,7 +194,7 @@ bool EthosuDepthwiseConv2DRel(const Array<Type>& types, int num_inputs, const At
EthosuInferKernelOutput(ifm->shape, param->ifm_layout, param->ofm_layout, param->kernel_shape,
param->ofm_channels, param->dilation, param->strides, param->padding);

reporter->Assign(types[4], TensorType(ofm_shape, ifm->dtype));
reporter->Assign(types[4], TensorType(ofm_shape, ofm_dtype));

return true;
}
Expand Down

0 comments on commit 5fa39af

Please sign in to comment.