Skip to content

Commit

Permalink
Revert "[TOSA] Make validation pass isValidElementType check more str…
Browse files Browse the repository at this point in the history
…ict (llvm#119671)"

This reverts commit 9472c5f.
  • Loading branch information
MaheshRavishankar authored and raikonenfnu committed Dec 18, 2024
1 parent b07e7b7 commit 10073b7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
19 changes: 16 additions & 3 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,8 +524,18 @@ bool TosaValidation::isValidElementType(Type type) {
if (!isEnabledProfile(TosaProfileEnum::MainInference))
return false;
return type.isF32() || type.isF16() || type.isBF16();
} else if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isSignless()) {
}
if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isUnsigned()) {
switch (intTy.getWidth()) {
case 8:
case 16:
return true;
default:
return false;
}
} else {
// Signless - treated as signed.
switch (intTy.getWidth()) {
case 1:
case 4:
Expand All @@ -534,10 +544,13 @@ bool TosaValidation::isValidElementType(Type type) {
case 32:
case 48:
return true;
default:
return false;
}
}
return false;
}
return false;
return true;
}

void TosaValidation::runOnOperation() {
Expand Down
8 changes: 0 additions & 8 deletions mlir/test/Dialect/Tosa/level_check.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,6 @@ func.func @test_const_f64(%arg0 : tensor<1xf64>) {

// -----

func.func @test_const_ui8(%arg0 : tensor<1xui8>) {
// expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'ui8' is not legal}}
%0 = "tosa.const"() {value = dense<0> : tensor<1xui8>} : () -> tensor<1xui8>
return
}

// -----

func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :
Expand Down

0 comments on commit 10073b7

Please sign in to comment.