Skip to content

Commit

Permalink
Update --opt-report=NNPAUnsupportedOps for NNPA (#2858)
Browse files Browse the repository at this point in the history
* Update --opt-report=NNPAUnsupportedOps for NNPA

Signed-off-by: Mike Essenmacher <essen@us.ibm.com>

* Clang check update

Signed-off-by: Mike Essenmacher <essen@us.ibm.com>

* Use inputNNPALevel directly in the message instead of create a new onnxMlirNnpaLevel

Signed-off-by: Mike Essenmacher <essen@us.ibm.com>

---------

Signed-off-by: Mike Essenmacher <essen@us.ibm.com>
  • Loading branch information
mikeessen authored Jul 11, 2024
1 parent c5c9239 commit 8050865
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 32 deletions.
50 changes: 25 additions & 25 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ bool onnxToZHighUnsupportedReport(Operation *op, const std::string &message) {
}

/// Report incompatibility with NNPA Level.
bool onnxToZHighInCompatibilityReport(Operation *op) {
std::string onnxMlirNnpaLevel(NNPA_Z16);
bool onnxToZHighInCompatibilityReport(
Operation *op, std::string inputNNPALevel) {
std::string message =
"onnx-mlir NNPA level (" + onnxMlirNnpaLevel +
"onnx-mlir NNPA level (" + inputNNPALevel +
") is not compatible with NNPA level specified by '-mcpu'(" + mcpu +
").";
return onnxToZHighUnsupportedReport(op, message);
Expand Down Expand Up @@ -358,7 +358,7 @@ bool isSuitableForZDNN<ONNXAddOp>(
ONNXAddOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16)) {
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
}
if (!isValidElementTypeAndRank(op.getOperation(), op.getA()))
return false;
Expand All @@ -377,7 +377,7 @@ bool isSuitableForZDNN<ONNXSubOp>(
ONNXSubOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
if (!isValidElementTypeAndRank(op.getOperation(), op.getA()))
return false;
if (!isValidElementTypeAndRank(op.getOperation(), op.getB()))
Expand All @@ -395,7 +395,7 @@ bool isSuitableForZDNN<ONNXMulOp>(
ONNXMulOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
if (!isValidElementTypeAndRank(op.getOperation(), op.getA()))
return false;
if (!isValidElementTypeAndRank(op.getOperation(), op.getB()))
Expand All @@ -415,7 +415,7 @@ bool isSuitableForZDNN<ONNXDivOp>(
Value B = op.getB();
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
// Broadcast with a scalar operand.
if (isEnableScalarBcastBinary()) {
if (isF32ScalarConstantTensor(A) &&
Expand Down Expand Up @@ -443,7 +443,7 @@ bool isSuitableForZDNN<ONNXSumOp>(
ONNXSumOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
// Do not support a single input.
if (op.getData_0().size() < 2)
return onnxToZHighUnsupportedReport(op.getOperation(),
Expand Down Expand Up @@ -474,7 +474,7 @@ bool isSuitableForZDNN<ONNXMinOp>(
ONNXMinOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
int64_t opnum = op.getNumOperands();
if (opnum != 2)
return onnxToZHighUnsupportedReport(op.getOperation(),
Expand All @@ -497,7 +497,7 @@ bool isSuitableForZDNN<ONNXMaxOp>(
ONNXMaxOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
int64_t opnum = op.getNumOperands();
if (opnum != 2)
return onnxToZHighUnsupportedReport(op.getOperation(),
Expand All @@ -521,7 +521,7 @@ bool isSuitableForZDNN<ONNXSoftmaxOp>(
ONNXSoftmaxOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
if (!isValidElementTypeAndRank(op.getOperation(), op.getInput()))
return false;
ShapedType inputType = mlir::cast<ShapedType>(op.getType());
Expand All @@ -547,7 +547,7 @@ bool isSuitableForZDNN<ONNXReluOp>(
ONNXReluOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
if (!isValidElementTypeAndRank(op.getOperation(), op.getX()))
return false;
return true;
Expand All @@ -559,7 +559,7 @@ bool isSuitableForZDNN<ONNXTanhOp>(
ONNXTanhOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
if (!isValidElementTypeAndRank(op.getOperation(), op.getInput()))
return false;
return true;
Expand All @@ -571,7 +571,7 @@ bool isSuitableForZDNN<ONNXSigmoidOp>(
ONNXSigmoidOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
if (!isValidElementTypeAndRank(op.getOperation(), op.getX()))
return false;
return true;
Expand All @@ -583,7 +583,7 @@ bool isSuitableForZDNN<ONNXLogOp>(
ONNXLogOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
if (!isValidElementTypeAndRank(op.getOperation(), op.getInput()))
return false;
return true;
Expand All @@ -595,7 +595,7 @@ bool isSuitableForZDNN<ONNXExpOp>(
ONNXExpOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
if (!isValidElementTypeAndRank(op.getOperation(), op.getInput()))
return false;
return true;
Expand All @@ -607,7 +607,7 @@ bool isSuitableForZDNN<ONNXMatMulOp>(
ONNXMatMulOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
int64_t opnum = op.getNumOperands();
if (opnum != 2)
return onnxToZHighUnsupportedReport(op.getOperation(),
Expand Down Expand Up @@ -691,7 +691,7 @@ bool isSuitableForZDNN<ONNXGemmOp>(

// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);

// Check data type.
if (!isValidElementTypeAndRank(op.getOperation(), A))
Expand Down Expand Up @@ -765,7 +765,7 @@ bool isSuitableForZDNN<ONNXReduceMeanV13Op>(
ONNXReduceMeanV13Op op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);

// Check data type.
if (!isValidElementTypeAndRank(op.getOperation(), op.getData()))
Expand Down Expand Up @@ -845,7 +845,7 @@ bool isSuitableForZDNN<ONNXLSTMOp>(

// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);

// Check direction.
if ((direction != FORWARD) && (direction != REVERSE) &&
Expand Down Expand Up @@ -958,7 +958,7 @@ bool isSuitableForZDNN<ONNXGRUOp>(

// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);

// Check direction.
if ((direction != FORWARD) && (direction != REVERSE) &&
Expand Down Expand Up @@ -1063,7 +1063,7 @@ bool isSuitableForZDNN<ONNXMaxPoolSingleOutOp>(
ONNXMaxPoolSingleOutOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);

// Check data type.
if (!isValidElementTypeAndRank(op.getOperation(), op.getX()))
Expand Down Expand Up @@ -1095,7 +1095,7 @@ bool isSuitableForZDNN<ONNXAveragePoolOp>(
ONNXAveragePoolOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);

// Check data type.
if (!isValidElementTypeAndRank(op.getOperation(), op.getX()))
Expand Down Expand Up @@ -1218,7 +1218,7 @@ bool isSuitableForZDNN<ONNXConvOp>(
ONNXConvOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);

// Check data type.
if (!isValidElementTypeAndRank(op.getOperation(), op.getX()))
Expand Down Expand Up @@ -1324,7 +1324,7 @@ bool isSuitableForZDNN<ONNXBatchNormalizationInferenceModeOp>(

// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return onnxToZHighInCompatibilityReport(op.getOperation());
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);

// 4D tensors(N x C x H x W) are supported as input and output.
if (shapeInput.size() != 4 || shapeOutput.size() != 4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,5 @@ bool meetPoolParamRestrictions(mlir::Operation *op, int64_t inputShape,
bool onnxToZHighUnsupportedReport(
mlir::Operation *op, const std::string &message);

bool onnxToZHighInCompatibilityReport(mlir::Operation *op);
bool onnxToZHighInCompatibilityReport(
mlir::Operation *op, std::string inputNNPALevel);
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ void getRewriteONNXForZHighDynamicallyLegal(
target, dimAnalysis, [](ONNXAddOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return !onnxToZHighInCompatibilityReport(op.getOperation());
return !onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
// Check element type.
if (!isValidElementTypeAndRank(op.getOperation(), op.getA(), true))
return true;
Expand All @@ -547,7 +547,7 @@ void getRewriteONNXForZHighDynamicallyLegal(
target, dimAnalysis, [](ONNXDivOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return !onnxToZHighInCompatibilityReport(op.getOperation());
return !onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
// Check element type.
if (!isValidElementTypeAndRank(op.getOperation(), op.getA(), true))
return true;
Expand All @@ -560,7 +560,7 @@ void getRewriteONNXForZHighDynamicallyLegal(
target, dimAnalysis, [](ONNXMulOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return !onnxToZHighInCompatibilityReport(op.getOperation());
return !onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
// Check element type.
if (!isValidElementTypeAndRank(op.getOperation(), op.getA(), true))
return true;
Expand All @@ -573,7 +573,7 @@ void getRewriteONNXForZHighDynamicallyLegal(
target, dimAnalysis, [](ONNXSubOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return !onnxToZHighInCompatibilityReport(op.getOperation());
return !onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
// Check element type.
if (!isValidElementTypeAndRank(op.getOperation(), op.getA(), true))
return true;
Expand All @@ -597,7 +597,7 @@ void getRewriteONNXForZHighDynamicallyLegal(
target, dimAnalysis, [](ONNXMatMulOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return !onnxToZHighInCompatibilityReport(op.getOperation());
return !onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);

Value A = op.getA();
Value B = op.getB();
Expand Down Expand Up @@ -667,7 +667,7 @@ void getRewriteONNXForZHighDynamicallyLegal(
[](ONNXSoftmaxOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return !onnxToZHighInCompatibilityReport(op.getOperation());
return !onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);

Value input = op.getInput();
// std::string message = "The `input` is not reshaped to 3D because it
Expand Down

0 comments on commit 8050865

Please sign in to comment.