From 9a167e2d319641a175b22b10984c36b81f7ba267 Mon Sep 17 00:00:00 2001 From: Justin Ngo Date: Fri, 10 Jan 2025 09:53:34 -0800 Subject: [PATCH] [TOSA] Update tosa.cast check according to TOSA v1.0 spec (#3948) * Update checkValidityOfCast function for tosa.cast according to the latest TOSA v1.0 spec: https://www.mlplatform.org/tosa/tosa_spec.html#_cast * Clean up some dead code in TorchToTosa Change-Id: I41209c698a694bca57ebf49ed3608cf89a0d8ba8 Signed-off-by: Justin Ngo --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 12 +-- .../TorchToTosa/TosaLegalizeUtils.cpp | 78 +++++++++++++------ projects/pt1/e2e_testing/xfail_sets.py | 36 ++++----- 3 files changed, 75 insertions(+), 51 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 6f3e14b1cde1..066126fb0906 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -5513,11 +5513,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape))); } - rewriter.replaceOpWithNewOp( - op, resultTy, - // OpConversionPattern::getTypeConverter()->convertType( - // op.getType()), - result); + rewriter.replaceOpWithNewOp(op, resultTy, result); return success(); } @@ -6451,11 +6447,7 @@ ConvertAtenOp::matchAndRewrite( tosa::getConstTensor(rewriter, op, /*vec=*/{0, 3, 1, 2}, /*shape=*/{static_cast(4)}); - // SmallVector transposedOutputShape( - // {transposedResizedOpShape[0], transposedResizedOpShape[3], - // transposedResizedOpShape[1], transposedResizedOpShape[2]}); - // auto transposedOutputType = RankedTensorType::get( - // makeShapeLLVMCompatible(transposedOutputShape), inputElemTy); + rewriter .replaceOpWithNewOp( op, getTypeConverter()->convertType(resultType), resizeOpResult, diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index bf7086a77f66..3d97b695f1ab 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -264,42 +264,68 @@ std::optional getConstTensor(PatternRewriter &rewriter, return const_op.getResult(); } -static LogicalResult checkValidityOfCast(Type src, Type dest) { +// Valid TOSA casting pairs according to TOSA spec: +// https://www.mlplatform.org/tosa/tosa_spec.html#_cast +// Note: currently TOSA doesn't support casting to and from I64 and F64 +[[maybe_unused]] static LogicalResult checkValidityOfCast(Type src, Type dest) { // clang-format off if ((src == dest) || - // int64 -> * - (src.isInteger(64) && dest.isInteger(32)) || - (src.isInteger(64) && dest.isInteger(8)) || - (src.isInteger(64) && dest.isInteger(1)) || - (src.isInteger(64) && dest.isF32()) || // int32 -> * - (src.isInteger(32) && dest.isInteger(64)) || + (src.isInteger(32) && dest.isInteger(16)) || + (src.isInteger(32) && dest.isInteger(8)) || (src.isInteger(32) && dest.isInteger(1)) || (src.isInteger(32) && dest.isF32()) || + (src.isInteger(32) && dest.isF16()) || (src.isInteger(32) && dest.isBF16()) || // int16 -> * + (src.isInteger(16) && dest.isInteger(32)) || + (src.isInteger(16) && dest.isInteger(8)) || + (src.isInteger(16) && dest.isInteger(1)) || (src.isInteger(16) && dest.isBF16()) || + (src.isInteger(16) && dest.isF32()) || + (src.isInteger(16) && dest.isF16()) || // int8 -> * + (src.isInteger(8) && dest.isInteger(32)) || + (src.isInteger(8) && dest.isInteger(16)) || (src.isInteger(8) && dest.isInteger(1)) || (src.isInteger(8) && dest.isBF16()) || + (src.isInteger(8) && dest.isF32()) || + (src.isInteger(8) && dest.isF16()) || // int1 -> * - (src.isInteger(1) && dest.isInteger(64)) || - (src.isInteger(1) && dest.isF32()) || - // f64 -> * - (src.isF64() && dest.isF32()) || - (src.isF64() && dest.isBF16()) || + (src.isInteger(1) && dest.isInteger(32)) || + (src.isInteger(1) && dest.isInteger(16)) || + (src.isInteger(1) && dest.isInteger(8)) || // f32 -> * - (src.isF32() && dest.isF64()) || + (src.isF32() && dest.isInteger(32)) || + (src.isF32() && dest.isInteger(16)) || + (src.isF32() && dest.isInteger(8)) || (src.isF32() && dest.isBF16()) || (src.isF32() && dest.isF16()) || - (src.isF32() && dest.isInteger(8)) || - (src.isF32() && dest.isInteger(64)) || - (src.isF32() && dest.isInteger(1)) || + (src.isF32() && dest.isFloat8E4M3()) || + (src.isF32() && dest.isFloat8E5M2()) || + // f16 -> * + (src.isF16() && dest.isInteger(32)) || + (src.isF16() && dest.isInteger(16)) || + (src.isF16() && dest.isInteger(8)) || + (src.isF16() && dest.isBF16()) || + (src.isF16() && dest.isF32()) || + (src.isF16() && dest.isFloat8E4M3()) || + (src.isF16() && dest.isFloat8E5M2()) || // bf16 -> * - (src.isBF16() && dest.isInteger(8)) || - (src.isBF16() && dest.isInteger(16)) || (src.isBF16() && dest.isInteger(32)) || - (src.isBF16() && dest.isF32())) { + (src.isBF16() && dest.isInteger(16)) || + (src.isBF16() && dest.isInteger(8)) || + (src.isBF16() && dest.isF32()) || + (src.isBF16() && dest.isFloat8E4M3()) || + (src.isBF16() && dest.isFloat8E5M2()) || + // fp8e4m3 -> * + (src.isFloat8E4M3() && dest.isBF16()) || + (src.isFloat8E4M3() && dest.isF32()) || + (src.isFloat8E4M3() && dest.isF16()) || + // fp8e5m2 -> * + (src.isFloat8E5M2() && dest.isBF16()) || + (src.isFloat8E5M2() && dest.isF32()) || + (src.isFloat8E5M2() && dest.isF16())) { return success(); } // clang-format on @@ -313,9 +339,17 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, Type srcElemTy = dyn_cast(src.getType()).getElementType(); Type destElemTy = dyn_cast(destType).getElementType(); - if (failed(checkValidityOfCast(srcElemTy, destElemTy))) - return rewriter.notifyMatchFailure( - op, "casting to result dtype is invalid or unsupported"); + // Temporarily disable checkValidityOfCast as it's currently strictly + // following TOSA spec and might cause many e2e tests to fail. This is because + // even though there are some casting pairs that are not congruent to TOSA + // spec, they are still permissible. TOSA validation should flag these illegal + // constructs in a per-profile manner. This strict validity check will be + // enabled later in a potential `--strict` mode which checks for strict + // casting only when needed (the default value of `--strict` mode will be + // off). + // if (failed(checkValidityOfCast(srcElemTy, destElemTy))) + // return rewriter.notifyMatchFailure( + // op, "casting to result dtype is invalid or unsupported"); if (destElemTy.isInteger(1)) { auto srcType = dyn_cast(src.getType()); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 7e2bae685c85..7bfbcc07d2a6 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1705,6 +1705,21 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "AtenEyeMModuleInt2D_basic", + "AtenEyeModuleInt2D_basic", + "ElementwiseWhereScalarOtherStaticModule_basic", + "FullModuleFalsePinMemory_basic", + "FullModuleInt2D_basic", + "MaskedFillScalarFloatValueModule_basic", + "MaskedFillScalarFloatValueStaticModule_basic", + "NewFullModuleInt2D_basic", + "NewFullModuleInt3D_basic", + "Threshold3dIntModule_basic", + "TrilIndicesModule_basic", + "TrilIndicesOfssetGreaterThanRowModule_basic", + "TriuIndicesNegativeOffsetModule_basic", + "BmmFloat16Module_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", "Unfold_Module_Rank_4", "Unfold_Module_Rank_Zero_basic", "Unfold_Module_basic", @@ -2546,6 +2561,8 @@ } ) - { ### Test failing in make_fx_tosa but not in tosa + "ElementwiseRreluEvalStaticModule_basic", + "ElementwiseRreluTrainStaticModule_basic", "AdaptiveMaxPool1dDimOneStatic_basic", "FloatPowerTensorTensorStaticModule_basic", # Dynamic shape, has extra unsupported broadcast ops @@ -3466,7 +3483,6 @@ "LayerNormFwAndBwModule_basic", "LayerNormManualFwAndBwModule_basic", "SelfAttentionFwAndBwModule_basic", - "Threshold3dIntModule_basic", "ElementwiseCopysignModule_basic", "ElementwiseSignbitModule_basic", "Aten_TrilinearModuleVaryingRanks_basic", @@ -3515,12 +3531,9 @@ "TensorsConcatComplex64FloatModule_basic", "TimeOutModule_basic", "TrilIndicesAllZerosModule_basic", - "TrilIndicesModule_basic", "TrilIndicesNegativeOffsetModule_basic", - "TrilIndicesOfssetGreaterThanRowModule_basic", "TriuIndicesAllZerosModule_basic", "TriuIndicesModule_basic", - "TriuIndicesNegativeOffsetModule_basic", "TypeConversionUint8ToF32Module_basic", "WeightNormInterfaceModule_basic", "AdaptiveAvgPool3dDynamicNoBatch_basic", @@ -3550,8 +3563,6 @@ "AtenComplexViewModule_basic", "AtenEmbeddingBagStaticModule_basic", "AtenEmbeddingBagSumExample_basic", - "AtenEyeMModuleInt2D_basic", - "AtenEyeModuleInt2D_basic", "AtenFloatScalarModule_basic", "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", @@ -3586,11 +3597,8 @@ "AvgPool2dIntModule_basic", "AvgPool2dStaticModule_basic", "BernoulliFloatModule_basic", - "BernoulliModule_basic", - "BernoulliOnesModule_basic", "BernoulliPModule_basic", "BernoulliTensorModule_basic", - "BernoulliZerosModule_basic", "BincountMinlengthModule_basic", "BincountModule_basic", "BincountStaticSizeModule_basic", @@ -3680,11 +3688,8 @@ "ElementwiseSinhModule_basic", "ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseToDtypeI64ToUI8Module_basic", - "ElementwiseWhereScalarOtherStaticModule_basic", "EqIntModule_basic", "FloatImplicitModule_basic", - "FullLikeModuleInt2D_basic", - "FullLikeModuleInt3D_basic", "GeFloatIntModule_basic", "GeFloatModule_basic", "GeIntModule_basic", @@ -3770,8 +3775,6 @@ "NativeGroupNormBackwardModule_basic", "NeFloatIntModule_basic", "NeIntModule_basic", - "NewFullModuleInt2D_basic", - "NewFullModuleInt3D_basic", "NllLossModuleBackward1DMeanWeight_basic", "NllLossModuleBackward1DMean_basic", "NllLossModuleBackward1DSumWeight_basic", @@ -3784,7 +3787,6 @@ "NormalFunctionalModule_basic", "NumelModule_basic", "NumelZeroRankModule_basic", - "OnesLikeModule_falsePinMemory", "PowIntIntModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", @@ -3880,15 +3882,12 @@ "TorchPrimLoopWhileLikeModule_basic", "TraceModule_empty", "TraceUnsignedIntModule_empty", - "TypeConversionI1ToF64Module_basic", - "TypeConversionI1ToI32Module_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", "UpSampleNearest2dBackwardScalesNone_basic", "UpSampleNearest2dBackward_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", "VisionTransformerModule_basic", - "ZerosLikeModule_falsePinMemory", # Unexpected failures due to new PyTorch version update "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AdaptiveAvgPool1dGeneralDynamic_basic", @@ -4651,7 +4650,6 @@ "QuantizedReluUint8_basic", "QuantizedSingleLayer_basic", "RandIntDtypeModule_basic", - "RandIntLowDtypeModule_basic", "RandIntModule_basic", "RandIntPinMemoryModule_basic", "RandLikeDtypeModule_basic",