From cb055ae51ac94e7bf726dc0ea95990451f7e2800 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Fri, 13 Dec 2024 22:10:21 +0000 Subject: [PATCH] [TOSA] Don't run validation pass on non TOSA operations This commit ensures the validation pass is not run on operations from other dialects. In doing so, operations from other dialects that, for example, use types not supported by TOSA don't result in an error. Change-Id: If1efde2036f2d3e13b8c8588fea6344922453c2b Signed-off-by: Luke Hutton --- mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 4 ++++ mlir/test/Dialect/Tosa/invalid.mlir | 12 ++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 893cedefc1ebd..6fd671051362c 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -543,6 +543,10 @@ bool TosaValidation::isValidElementType(Type type) { void TosaValidation::runOnOperation() { configLevelAndProfile(); getOperation().walk([&](Operation *op) { + if (!op->getDialect() || + op->getDialect()->getNamespace() != TosaDialect::getDialectNamespace()) + return; + for (Value operand : op->getOperands()) { auto elementTy = getElementTypeOrSelf(operand); if (!isValidElementType(elementTy)) { diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 79bb7fce5755e..cca50b25d14d6 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -625,7 +625,6 @@ func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1 func.func @test_unsupported_int64_data_type(%arg0: tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64> { // expected-error@+1 {{'tosa.argmax' op is not profile-aligned: element type 'i64' is not legal}} %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64> - // expected-error@+1 {{'func.return' op is not profile-aligned: element type 'i64' is not legal}} return %0 : tensor<1x13x13xi64> } @@ -879,4 +878,13 @@ func.func @test_mismatch_in_out_shape_logical_not(%arg0: tensor<1x21x3xi1>) -> t // expected-error@+1 {{'tosa.logical_not' op requires the same shape for all operands and results}} %0 = tosa.logical_not %arg0 : (tensor<1x21x3xi1>) -> tensor<13x21x3xi1> return %0 : tensor<13x21x3xi1> -} \ No newline at end of file +} + +// ----- + +// Check validate pass doesn't run on non TOSA ops +func.func @test_non_tosa_ops() { + %0 = arith.constant 6 : index + %2 = tensor.empty(%0) : tensor + return +}