[TOSA] Update tosa.cast check according to TOSA v1.0 spec (#3948)
* Update checkValidityOfCast function for tosa.cast according to the
latest TOSA v1.0 spec:
https://www.mlplatform.org/tosa/tosa_spec.html#_cast
* Clean up some dead code in TorchToTosa


Change-Id: I41209c698a694bca57ebf49ed3608cf89a0d8ba8

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
justin-ngo-arm authored Jan 10, 2025
1 parent 98e4eb2 commit 9a167e2
Showing 3 changed files with 75 additions and 51 deletions.
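
For context before the diffs: the heart of this change is the allow-list of element-type pairs that checkValidityOfCast accepts, updated to match the TOSA v1.0 cast table linked in the commit message. Below is a rough, self-contained C++ sketch of that allow-list pattern only; the Ty enum, the isValidCast name, and the handful of pairs in its table are illustrative placeholders, not the patch's MLIR code (which queries mlir::Type directly and returns a LogicalResult).

// Illustrative sketch only: a stripped-down allow-list cast check in plain C++.
// The real function uses MLIR Type queries (isInteger(32), isF32(), ...) and
// returns success()/failure(); the enum and table below are placeholders.
#include <cstdio>
#include <set>
#include <utility>

enum class Ty { I1, I8, I16, I32, F16, BF16, F32, F8E4M3, F8E5M2 };

static bool isValidCast(Ty src, Ty dest) {
  // A few pairs in the spirit of the TOSA v1.0 cast table, for illustration.
  static const std::set<std::pair<Ty, Ty>> allowed = {
      {Ty::I32, Ty::F32},    {Ty::F32, Ty::I32},
      {Ty::F32, Ty::BF16},   {Ty::BF16, Ty::F32},
      {Ty::F32, Ty::F8E4M3}, {Ty::F8E4M3, Ty::F32},
  };
  return src == dest || allowed.count({src, dest}) != 0;
}

int main() {
  std::printf("f32 -> bf16: %s\n", isValidCast(Ty::F32, Ty::BF16) ? "ok" : "rejected");
  std::printf("i1  -> f8  : %s\n", isValidCast(Ty::I1, Ty::F8E5M2) ? "ok" : "rejected");
  return 0;
}

The actual table in TosaLegalizeUtils.cpp (below) enumerates the pairs as one chained boolean expression, which avoids any runtime table and is wrapped in clang-format off/on markers.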
12 changes: 2 additions & 10 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -5513,11 +5513,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape)));
}

rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, resultTy,
// OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
// op.getType()),
result);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, result);

return success();
}
@@ -6451,11 +6447,7 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
tosa::getConstTensor<int32_t>(rewriter, op,
/*vec=*/{0, 3, 1, 2},
/*shape=*/{static_cast<int32_t>(4)});
// SmallVector<int64_t> transposedOutputShape(
// {transposedResizedOpShape[0], transposedResizedOpShape[3],
// transposedResizedOpShape[1], transposedResizedOpShape[2]});
// auto transposedOutputType = RankedTensorType::get(
// makeShapeLLVMCompatible(transposedOutputShape), inputElemTy);

rewriter
.replaceOpWithNewOp<tosa::TransposeOp>(
op, getTypeConverter()->convertType(resultType), resizeOpResult,
78 changes: 56 additions & 22 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -264,42 +264,68 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
return const_op.getResult();
}

static LogicalResult checkValidityOfCast(Type src, Type dest) {
// Valid TOSA casting pairs according to TOSA spec:
// https://www.mlplatform.org/tosa/tosa_spec.html#_cast
// Note: currently TOSA doesn't support casting to and from I64 and F64
[[maybe_unused]] static LogicalResult checkValidityOfCast(Type src, Type dest) {
// clang-format off
if ((src == dest) ||
// int64 -> *
(src.isInteger(64) && dest.isInteger(32)) ||
(src.isInteger(64) && dest.isInteger(8)) ||
(src.isInteger(64) && dest.isInteger(1)) ||
(src.isInteger(64) && dest.isF32()) ||
// int32 -> *
(src.isInteger(32) && dest.isInteger(64)) ||
(src.isInteger(32) && dest.isInteger(16)) ||
(src.isInteger(32) && dest.isInteger(8)) ||
(src.isInteger(32) && dest.isInteger(1)) ||
(src.isInteger(32) && dest.isF32()) ||
(src.isInteger(32) && dest.isF16()) ||
(src.isInteger(32) && dest.isBF16()) ||
// int16 -> *
(src.isInteger(16) && dest.isInteger(32)) ||
(src.isInteger(16) && dest.isInteger(8)) ||
(src.isInteger(16) && dest.isInteger(1)) ||
(src.isInteger(16) && dest.isBF16()) ||
(src.isInteger(16) && dest.isF32()) ||
(src.isInteger(16) && dest.isF16()) ||
// int8 -> *
(src.isInteger(8) && dest.isInteger(32)) ||
(src.isInteger(8) && dest.isInteger(16)) ||
(src.isInteger(8) && dest.isInteger(1)) ||
(src.isInteger(8) && dest.isBF16()) ||
(src.isInteger(8) && dest.isF32()) ||
(src.isInteger(8) && dest.isF16()) ||
// int1 -> *
(src.isInteger(1) && dest.isInteger(64)) ||
(src.isInteger(1) && dest.isF32()) ||
// f64 -> *
(src.isF64() && dest.isF32()) ||
(src.isF64() && dest.isBF16()) ||
(src.isInteger(1) && dest.isInteger(32)) ||
(src.isInteger(1) && dest.isInteger(16)) ||
(src.isInteger(1) && dest.isInteger(8)) ||
// f32 -> *
(src.isF32() && dest.isF64()) ||
(src.isF32() && dest.isInteger(32)) ||
(src.isF32() && dest.isInteger(16)) ||
(src.isF32() && dest.isInteger(8)) ||
(src.isF32() && dest.isBF16()) ||
(src.isF32() && dest.isF16()) ||
(src.isF32() && dest.isInteger(8)) ||
(src.isF32() && dest.isInteger(64)) ||
(src.isF32() && dest.isInteger(1)) ||
(src.isF32() && dest.isFloat8E4M3()) ||
(src.isF32() && dest.isFloat8E5M2()) ||
// f16 -> *
(src.isF16() && dest.isInteger(32)) ||
(src.isF16() && dest.isInteger(16)) ||
(src.isF16() && dest.isInteger(8)) ||
(src.isF16() && dest.isBF16()) ||
(src.isF16() && dest.isF32()) ||
(src.isF16() && dest.isFloat8E4M3()) ||
(src.isF16() && dest.isFloat8E5M2()) ||
// bf16 -> *
(src.isBF16() && dest.isInteger(8)) ||
(src.isBF16() && dest.isInteger(16)) ||
(src.isBF16() && dest.isInteger(32)) ||
(src.isBF16() && dest.isF32())) {
(src.isBF16() && dest.isInteger(16)) ||
(src.isBF16() && dest.isInteger(8)) ||
(src.isBF16() && dest.isF32()) ||
(src.isBF16() && dest.isFloat8E4M3()) ||
(src.isBF16() && dest.isFloat8E5M2()) ||
// fp8e4m3 -> *
(src.isFloat8E4M3() && dest.isBF16()) ||
(src.isFloat8E4M3() && dest.isF32()) ||
(src.isFloat8E4M3() && dest.isF16()) ||
// fp8e5m2 -> *
(src.isFloat8E5M2() && dest.isBF16()) ||
(src.isFloat8E5M2() && dest.isF32()) ||
(src.isFloat8E5M2() && dest.isF16())) {
return success();
}
// clang-format on
@@ -313,9 +339,17 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
Type srcElemTy = dyn_cast<TensorType>(src.getType()).getElementType();
Type destElemTy = dyn_cast<TensorType>(destType).getElementType();

if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
return rewriter.notifyMatchFailure(
op, "casting to result dtype is invalid or unsupported");
// Temporarily disable checkValidityOfCast as it's currently strictly
// following TOSA spec and might cause many e2e tests to fail. This is because
// even though there are some casting pairs that are not congruent to TOSA
// spec, they are still permissible. TOSA validation should flag these illegal
// constructs in a per-profile manner. This strict validity check will be
// enabled later in a potential `--strict` mode which checks for strict
// casting only when needed (the default value of `--strict` mode will be
// off).
// if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
// return rewriter.notifyMatchFailure(
// op, "casting to result dtype is invalid or unsupported");

if (destElemTy.isInteger(1)) {
auto srcType = dyn_cast<TensorType>(src.getType());
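
The comment in tosaCastTensorToType above mentions a potential --strict mode that would re-enable this validity check only on demand. As a purely hypothetical sketch of that gating idea (the strictCastChecking flag and the lowerCast helper below are invented for illustration and are not an existing torch-mlir option or API):

// Hypothetical sketch of the "--strict" gating described in the comment above.
// The strictCastChecking flag and lowerCast helper are invented for
// illustration; they are not part of torch-mlir or the TOSA dialect.
#include <cstdio>

// pairAllowedBySpec stands in for "checkValidityOfCast(src, dest) succeeded".
static bool lowerCast(bool pairAllowedBySpec, bool strictCastChecking) {
  if (strictCastChecking && !pairAllowedBySpec) {
    std::printf("match failure: casting to result dtype is invalid\n");
    return false;
  }
  // ... emit tosa.cast (or the i1 special case) as in the real code ...
  std::printf("emitted cast\n");
  return true;
}

int main() {
  lowerCast(/*pairAllowedBySpec=*/false, /*strictCastChecking=*/false); // permissive default
  lowerCast(/*pairAllowedBySpec=*/false, /*strictCastChecking=*/true);  // strict mode rejects
  return 0;
}

Either way, the permissive default matches the behavior in this commit: pairs outside the spec table still lower, and per-profile legality is left to TOSA validation.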
36 changes: 17 additions & 19 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1705,6 +1705,21 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"AtenEyeMModuleInt2D_basic",
"AtenEyeModuleInt2D_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"FullModuleFalsePinMemory_basic",
"FullModuleInt2D_basic",
"MaskedFillScalarFloatValueModule_basic",
"MaskedFillScalarFloatValueStaticModule_basic",
"NewFullModuleInt2D_basic",
"NewFullModuleInt3D_basic",
"Threshold3dIntModule_basic",
"TrilIndicesModule_basic",
"TrilIndicesOfssetGreaterThanRowModule_basic",
"TriuIndicesNegativeOffsetModule_basic",
"BmmFloat16Module_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_basic",
"Unfold_Module_basic",
@@ -2546,6 +2561,8 @@
}
) - {
### Test failing in make_fx_tosa but not in tosa
"ElementwiseRreluEvalStaticModule_basic",
"ElementwiseRreluTrainStaticModule_basic",
"AdaptiveMaxPool1dDimOneStatic_basic",
"FloatPowerTensorTensorStaticModule_basic",
# Dynamic shape, has extra unsupported broadcast ops
@@ -3466,7 +3483,6 @@
"LayerNormFwAndBwModule_basic",
"LayerNormManualFwAndBwModule_basic",
"SelfAttentionFwAndBwModule_basic",
"Threshold3dIntModule_basic",
"ElementwiseCopysignModule_basic",
"ElementwiseSignbitModule_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
@@ -3515,12 +3531,9 @@
"TensorsConcatComplex64FloatModule_basic",
"TimeOutModule_basic",
"TrilIndicesAllZerosModule_basic",
"TrilIndicesModule_basic",
"TrilIndicesNegativeOffsetModule_basic",
"TrilIndicesOfssetGreaterThanRowModule_basic",
"TriuIndicesAllZerosModule_basic",
"TriuIndicesModule_basic",
"TriuIndicesNegativeOffsetModule_basic",
"TypeConversionUint8ToF32Module_basic",
"WeightNormInterfaceModule_basic",
"AdaptiveAvgPool3dDynamicNoBatch_basic",
@@ -3550,8 +3563,6 @@
"AtenComplexViewModule_basic",
"AtenEmbeddingBagStaticModule_basic",
"AtenEmbeddingBagSumExample_basic",
"AtenEyeMModuleInt2D_basic",
"AtenEyeModuleInt2D_basic",
"AtenFloatScalarModule_basic",
"AtenIntBoolOpConstFalseModule_basic",
"AtenIntBoolOpConstTrueModule_basic",
@@ -3586,11 +3597,8 @@
"AvgPool2dIntModule_basic",
"AvgPool2dStaticModule_basic",
"BernoulliFloatModule_basic",
"BernoulliModule_basic",
"BernoulliOnesModule_basic",
"BernoulliPModule_basic",
"BernoulliTensorModule_basic",
"BernoulliZerosModule_basic",
"BincountMinlengthModule_basic",
"BincountModule_basic",
"BincountStaticSizeModule_basic",
@@ -3680,11 +3688,8 @@
"ElementwiseSinhModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"EqIntModule_basic",
"FloatImplicitModule_basic",
"FullLikeModuleInt2D_basic",
"FullLikeModuleInt3D_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
"GeIntModule_basic",
@@ -3770,8 +3775,6 @@
"NativeGroupNormBackwardModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
"NewFullModuleInt2D_basic",
"NewFullModuleInt3D_basic",
"NllLossModuleBackward1DMeanWeight_basic",
"NllLossModuleBackward1DMean_basic",
"NllLossModuleBackward1DSumWeight_basic",
@@ -3784,7 +3787,6 @@
"NormalFunctionalModule_basic",
"NumelModule_basic",
"NumelZeroRankModule_basic",
"OnesLikeModule_falsePinMemory",
"PowIntIntModule_basic",
"PrimMaxIntModule_basic",
"PrimMinIntDynamicModule_basic",
@@ -3880,15 +3882,12 @@
"TorchPrimLoopWhileLikeModule_basic",
"TraceModule_empty",
"TraceUnsignedIntModule_empty",
"TypeConversionI1ToF64Module_basic",
"TypeConversionI1ToI32Module_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"UpSampleNearest2dBackwardScalesNone_basic",
"UpSampleNearest2dBackward_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewSizeFromOtherTensor_basic",
"VisionTransformerModule_basic",
"ZerosLikeModule_falsePinMemory",
# Unexpected failures due to new PyTorch version update
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic",
@@ -4651,7 +4650,6 @@
"QuantizedReluUint8_basic",
"QuantizedSingleLayer_basic",
"RandIntDtypeModule_basic",
"RandIntLowDtypeModule_basic",
"RandIntModule_basic",
"RandIntPinMemoryModule_basic",
"RandLikeDtypeModule_basic",
