From d4b5e05ac19a6cea0202bb6cfe6ab6676270dddf Mon Sep 17 00:00:00 2001 From: justin-ngo-arm Date: Thu, 5 Sep 2024 11:27:29 -0700 Subject: [PATCH] [TOSA] Add Torch to Tosa Legalization for torch.tril (#3678) Change-Id: Ie5ba31a27394c3adcea00266a9d562862dbd8b08 Signed-off-by: Justin Ngo --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 110 ++++++++++ projects/pt1/e2e_testing/main.py | 6 +- projects/pt1/e2e_testing/xfail_sets.py | 241 +++++++++++++-------- test/Conversion/TorchToTosa/basic.mlir | 17 ++ 4 files changed, 277 insertions(+), 97 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 5449495d63b0..2bbacaf0015a 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -22,6 +22,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "llvm/ADT/TypeSwitch.h" #include #include @@ -5385,6 +5386,114 @@ ConvertAtenOp::matchAndRewrite( return success(); } +// Template to create support tril mask tensor for aten.tril +// legalization +template +Value createTrilMask(PatternRewriter &rewriter, Operation *op, + ArrayRef shape, int64_t h, int64_t w, + int64_t diagonal) { + SmallVector vec; + + for (int64_t i = 0; i < h; i++) { + for (int64_t j = 0; j < w; j++) { + // Positive diagonal value includes as many diagonals above the main + // diagonal, while negative diagonal value excludes as many diagonals + // below the main diagonal. + if (i >= j - diagonal) { + vec.push_back(static_cast(1)); + } else { + vec.push_back(static_cast(0)); + } + } + } + + return tosa::getConstTensor(rewriter, op, vec, shape).value(); +} + +// Function to get tril mask tensor based on input type +// for aten.tril legalization +Value getTrilMask(PatternRewriter &rewriter, Operation *op, + ArrayRef shape, int64_t h, int64_t w, + int64_t diagonal, Type type) { + return TypeSwitch(type) + .Case([&](auto) { + return createTrilMask(rewriter, op, shape, h, w, diagonal); + }) + .Case([&](auto intType) { + switch (intType.getWidth()) { + case 1: + return createTrilMask(rewriter, op, shape, h, w, diagonal); + case 32: + return createTrilMask(rewriter, op, shape, h, w, diagonal); + case 64: + return createTrilMask(rewriter, op, shape, h, w, diagonal); + } + llvm_unreachable("Invalid integer width"); + }); +} + +// Legalization for aten.tril +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenTrilOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); + + // Not a ranked tensor type + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types are supported"); + + // Rank below 2 not accepted + auto selfRank = selfType.getRank(); + if (selfRank <= 1) + return rewriter.notifyMatchFailure( + op, "Rank 0 and 1 are not accepted as they cause underflow"); + + if (!selfType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "Currently only static shapes are supported"); + + const TypeConverter *typeConverter = this->getTypeConverter(); + RankedTensorType resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); + if (!resultType) + return rewriter.notifyMatchFailure(op, "Result type cannot be empty"); + + // Get height, width of input tensor, and diagonal arg to create + // a const mask tensor to multiply with input. + // This mask tensor has the same height and width of input tensor + // and consists of 1's for the lower triangle part and 0's for the rest. + // For example, with h=4, w=6, diagonal=1: + // tensor([[1, 1, 0, 0, 0, 0], + // [1, 1, 1, 0, 0, 0], + // [1, 1, 1, 1, 0, 0], + // [1, 1, 1, 1, 1, 0]]) + auto selfShape = selfType.getShape(); + int64_t h = selfShape[selfRank - 2]; + int64_t w = selfShape[selfRank - 1]; + int64_t diagonal; + + if (!matchPattern(op.getDiagonal(), m_TorchConstantInt(&diagonal))) + return rewriter.notifyMatchFailure(op, "Diagonal value is not an integer"); + + // Define shape for mask tensor based on rank + SmallVector constShape; + for (auto i = 0; i < selfRank - 2; i++) + constShape.push_back(1); + constShape.push_back(h); + constShape.push_back(w); + + Value trilMask = getTrilMask(rewriter, op, constShape, h, w, diagonal, + resultType.getElementType()); + + rewriter.replaceOpWithNewOp(op, resultType, self, trilMask, + /*shift=*/0); + + return success(); +} + } // namespace // ----------------------------------------------------------------------------- @@ -5638,6 +5747,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenSqrtOp); INSERT_ATENOP_PATTERN(AtenIscloseOp); INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp); + INSERT_ATENOP_PATTERN(AtenTrilOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ diff --git a/projects/pt1/e2e_testing/main.py b/projects/pt1/e2e_testing/main.py index ce767c567501..d99098d40f96 100644 --- a/projects/pt1/e2e_testing/main.py +++ b/projects/pt1/e2e_testing/main.py @@ -58,8 +58,10 @@ FX_IMPORTER_CRASHING_SET, FX_IMPORTER_STABLEHLO_XFAIL_SET, FX_IMPORTER_STABLEHLO_CRASHING_SET, + FX_IMPORTER_TOSA_CRASHING_SET, FX_IMPORTER_TOSA_XFAIL_SET, ONNX_TOSA_XFAIL_SET, + ONNX_TOSA_CRASHING_SET, ) # Import tests to register them in the global registry. @@ -191,7 +193,7 @@ def main(): elif args.config == "fx_importer_tosa": config = FxImporterTestConfig(LinalgOnTensorsTosaBackend(), "tosa") xfail_set = FX_IMPORTER_TOSA_XFAIL_SET - crashing_set = set() + crashing_set = FX_IMPORTER_TOSA_CRASHING_SET elif args.config == "torchdynamo": # TODO: Enanble runtime verification and extend crashing set. config = TorchDynamoTestConfig( @@ -206,7 +208,7 @@ def main(): elif args.config == "onnx_tosa": config = OnnxBackendTestConfig(LinalgOnTensorsTosaBackend(), output_type="tosa") xfail_set = ONNX_TOSA_XFAIL_SET - crashing_set = set() + crashing_set = ONNX_TOSA_CRASHING_SET do_not_attempt = set( args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or [] diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 828c7a24e26f..7ca15cbdd09d 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1571,9 +1571,25 @@ "IndexTensorNegativeIndexModule_basic", } +FX_IMPORTER_TOSA_CRASHING_SET = { + "IndexTensorNegativeIndexModule_basic", + "InterpolateDynamicModule_scales_recompute_bilinear", + "InterpolateDynamicModule_sizes_bilinear", + "InterpolateDynamicModule_sizes_nearest", + "InterpolateStaticModule_scales_bilinear_align_corners", + "UpSampleNearest2d_basic", + "UpSampleNearest2dStaticSize_basic", + "UpSampleNearest2dDynamicSize_basic", + "UpSampleNearest2dDynamicFactor_basic", + "UpSampleNearest2dStaticFactor_basic", +} + # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "AtenTrilStaticModule_basic", + "AtenTrilWithNegDiagonalStaticModule_basic", + "AtenTrilWithPosDiagonalStaticModule_basic", "ArgmaxKeepdimModule_basic", "MeshgridIndexingIJ_basic", "MeshgridIndexingXY_basic", @@ -2938,6 +2954,64 @@ } FX_IMPORTER_TOSA_XFAIL_SET = { + "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", + "AtenIntMM_basic", + "AtenKthvalueDynamicDimsModule_basic", + "AtenKthvalueFloat64DynamicDimsModule_basic", + "AtenKthvalueFloat64Module_basic", + "AtenKthvalueKeepDimModule_basic", + "AtenKthvalueModule_basic", + "AvgPool3dStaticModule_basic", + "Conv_Transpose1dModule_basic", + "Conv_Transpose1dStaticModule_basic", + "Conv_Transpose2dStaticModule_basic", + "Conv_Transpose3dModule_basic", + "Conv_Transpose3dStaticModule_basic", + "EinsumStaticDiagonalDimensionModule_basic", + "ElementwiseFloatTensorGtIntTensorModule_basic", + "ElementwiseIntTensorLtFloatTensorModule_basic", + "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic", + "ElementwiseRemainderScalarModule_Int_NegativeDividend_basic", + "ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic", + "ElementwiseRemainderTensorModule_Float_NegativeDividend_basic", + "ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic", + "ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic", + "ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic", + "ElementwiseRemainderTensorModule_Int_NegativeDividend_basic", + "ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic", + "ElementwiseRreluEvalModule_basic", + "ElementwiseRreluEvalStaticModule_basic", + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", + "FakeQuantizePerTensorAffineCachemaskModule_basic", + "IndexPutWithNoneAndBroadcastModule_basic", + "MaskedScatterStaticBasic_basic", + "MaxUnpool3dModulePad0_basic", + "MaxUnpool3dModule_basic", + "MultinomialModule2D_F32", + "MultinomialModule2D_basic", + "MultinomialModule_basic", + "ReduceAminSingleDim_basic", + "ReduceAminmaxAllDims_basic", + "ReduceAminmaxSingleDim_basic", + "ReduceAnyDimFloatModule_basic", + "RenormModuleFloat16_basic", + "ScaledDotProductAttentionDifferentModule_basic", + "ScaledDotProductAttentionSameModule_basic", + "ScatterAddStaticModule_basic", + "TensorsConcatComplex128FloatModule_basic", + "TensorsConcatComplex128IntModule_basic", + "TensorsConcatComplex64FloatModule_basic", + "TimeOutModule_basic", + "TrilIndicesAllZerosModule_basic", + "TrilIndicesModule_basic", + "TrilIndicesNegativeOffsetModule_basic", + "TrilIndicesOfssetGreaterThanRowModule_basic", + "TriuIndicesAllZerosModule_basic", + "TriuIndicesModule_basic", + "TriuIndicesNegativeOffsetModule_basic", + "TypeConversionUint8ToF32Module_basic", + "WeightNormInterfaceModule_basic", "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AdaptiveAvgPool1dGeneralDynamic_basic", "AdaptiveAvgPool1dStaticLargerOutput_basic", @@ -2960,7 +3034,6 @@ "AdaptiveMaxPool3dStatic_basic", "AddIntModule_basic", "AddFloatIntModule_basic", - "Add_MixPModule_basic", "AllBoolFalseModule_basic", "AllBoolTrueModule_basic", "AnyBoolFalseModule_basic", @@ -2987,7 +3060,6 @@ "AtenFloatScalarModule_basic", "AtenHannWindowPeriodicTrueModule_basic", "AtenHannWindowPeriodicFalseModule_basic", - "AtenInstanceNormModule_basic", "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", "AtenIntBoolOpModule_basic", @@ -3018,9 +3090,6 @@ "AtenSubFloatModule_basic", "AtenTopKModule_basic", "AtenTopKSmallestModule_basic", - "AtenTrilModule_basic", - "AtenTrilWithNegDiagonalModule_basic", - "AtenTrilWithPosDiagonalModule_basic", "Aten_CastLongModule_basic", "Aten_EmbeddingBagExample_basic", "AvgPool1dFloatModule_basic", @@ -3163,7 +3232,6 @@ "ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic", "ElementwiseDivScalarRoundingModeTruncModule_basic", "ElementwiseDivScalarRoundingModeTruncStaticModule_basic", - "ElementwiseDivTensorFloatModule_basic", "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic", "ElementwiseDivTensorRoundingModeFloorModule_basic", "ElementwiseDivTensorRoundingModeFloorStaticModule_basic", @@ -3199,7 +3267,6 @@ "ElementwiseMishModule_basic", "ElementwiseMulTensorComplexDiffModule_basic", "ElementwiseMulTensorComplexModule_basic", - "ElementwiseMulTensorFloatModule_basic", "ElementwisePowScalarModule_basic", "ElementwisePowTensorBroadcastModule_basic", "ElementwisePowTensorBroadcastStaticModule_basic", @@ -3220,14 +3287,10 @@ "ElementwiseSinhModule_basic", "ElementwiseTanIntModule_basic", "ElementwiseTanModule_basic", - "ElementwiseTernaryModule_basic", "ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseUnaryIntModule_basic", - "ElementwiseWhereScalarOtherModule_basic", "ElementwiseWhereScalarOtherStaticModule_basic", - "ElementwiseWhereScalarSelfModule_basic", - "ElementwiseWhereScalarSelfStaticModule_basic", "EmptyLikeMemoryFormatModule_basic", "EmptyLikeModule_defaultDtype", "EmptyLikeModule_falsePinMemory", @@ -3274,8 +3337,6 @@ "GridSamplerBasic2_basic", "GridSamplerBasic3_basic", "GridSamplerBasic4_basic", - "GroupNormModule_basic", - "GroupNormNoWeightAndBiasModule_basic", "GtFloatIntModule_basic", "GtIntModule_basic", "HBC_basic", @@ -3324,21 +3385,7 @@ "IndexSelectTwoIdxModule_basic", "IndexSelectWholeDimensionModule_basic", "IndexSelectWholeTensorModule_basic", - "IndexTensorDyanmicInputContiguousWithNoneModule_basic", - "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", - "IndexTensorMultiInputContiguousCenter_basic", - "IndexTensorMultiInputContiguousOneDimDynamic_basic", - "IndexTensorMultiInputNonContiguousDynamic_basic", - "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic", - "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", - "IndexTensorMultiInputNonContiguous_basic", - "IndexTensorMultiInputOneDim_basic", - "IndexTensorMultiInputThreeIndexers_basic", - "IndexTensorMultiInput_basic", "IndexTensorNegativeIndexModule_basic", - "IndexTensorSelectDimModule_basic", - "IndexTensorStaticContiguousWithNoneModule_basic", - "IndexTensorStaticNonContiguousWithNoneModule_basic", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", "InterpolateStaticModule_scales_bilinear_align_corners", @@ -3347,9 +3394,6 @@ "IntImplicitModule_basic", "IsFloatingPointFloat_True", "IsFloatingPointInt_False", - "LayerNormLastDimModule_basic", - "LayerNormModule_basic", - "LayerNormNormalizeOverAllDimsModule_basic", "LenStrModule_basic", "LinalgNormKeepDimComplexModule_basic", "LinalgVectorNormComplexModule_basic", @@ -3358,7 +3402,6 @@ "LinspaceModule_basic", "LinspaceOneSizeModule_basic", "LinspaceTwoSizeModule_basic", - "LogSoftmaxIntModule_basic", "MaskedFillTensorFloatValueModule_basic", "MatmulBroadcastBatchDim_basic", "MatmulStaticBroadcast_basic", @@ -3412,10 +3455,6 @@ "NativeDropoutTrainModule_basic", "NativeDropoutTrainStaticShapeModule_basic", "NativeGroupNormBackwardModule_basic", - "NativeGroupNormModule_basic", - "NativeLayerNormDynamicModule_basic", - "NativeLayerNormModule4D_basic", - "NativeLayerNormModule_basic", "NeFloatIntModule_basic", "NeIntModule_basic", "NewEmptyModuleDefaultDtype_basic", @@ -3506,11 +3545,8 @@ "ReduceL3NormKeepDimComplexModule_basic", "ReduceL3NormKeepDimModule_basic", "ReduceMaxAllDims_basic", - "ReduceMaxAlongDimNegative_basic", "ReduceMaxAlongDimUnsignedInt_basic", - "ReduceMaxAlongDim_basic", "ReduceMaxFloatModule_basic", - "ReduceMaxKeepDim_basic", "ReduceMaxSignedIntModule_basic", "ReduceMaxUnsignedIntModule_basic", "ReduceMinAlongDimNegative_basic", @@ -3601,8 +3637,6 @@ "SliceScatterStepVariationModule_basic", "SliceScatterZeroDimModule_basic", "SliceSizeTwoStepModule_basic", - "SoftmaxIntArgTypeF64Module_basic", - "SoftmaxIntNonNoneDtypeModule_basic", "SoftplusModule_basic", "SortIntListReverse_basic", "SortIntList_basic", @@ -3615,20 +3649,6 @@ "SplitDimStaticModule_basic", "SqrtIntConstantModule_basic", "SqrtIntModule_basic", - "StdBiasedModule_basic", - "StdCorrectionAllDimReduceModule_basic", - "StdCorrectionEmptyDimModule_basic", - "StdCorrectionKeepDimModule_basic", - "StdCorrectionLargeInputModule_basic", - "StdCorrectionModule_basic", - "StdCorrectionNoneModule_basic", - "StdCorrectionSingleDimReduceModule_basic", - "StdDimBiasedModule_basic", - "StdDimEmptyDimModule_basic", - "StdDimKeepDimFalseModule_basic", - "StdDimKeepDimTrueModule_basic", - "StdDimNoneDimModule_basic", - "StdUnbiasedModule_basic", "SubFloatModule_basic", "SubIntModule_basic", "TModuleRank0_basic", @@ -3665,8 +3685,6 @@ "TraceUnsignedIntModule_empty", "TypeConversionI1ToF64Module_basic", "TypeConversionI1ToI32Module_basic", - "UnbindIntGetItem_Module_basic", - "UnbindIntListUnpack_Module_basic", "UniformModule_basic", "UniformNoCorrelationModule_basic", "UniformStaticShapeModule_basic", @@ -3679,30 +3697,9 @@ "UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2dStaticSize_basic", "UpSampleNearest2d_basic", - "VarBiasedModule_basic", - "VarCorrectionAllDimReduceModule_basic", - "VarCorrectionEmptyDimModule_basic", - "VarCorrectionKeepDimModule_basic", - "VarCorrectionLargeInputModule_basic", - "VarCorrectionModule_basic", - "VarCorrectionNoneModule_basic", - "VarCorrectionSingleDimReduceModule_basic", - "VarDimAllDimReduceModule_basic", - "VarDimBiasedModule_basic", - "VarDimEmptyDimModule_basic", - "VarDimModule_basic", - "VarDimMultiDimModule_basic", - "VarDimNegativeModule_basic", - "VarDimNoneDimModule_basic", - "VarDimSingleDimModule_basic", - "VarDimUnbiasedModule_basic", "VarMeanBiasedModule_basic", - "VarMeanCorrectionModule_basic", "VarMeanCorrectionNoneModule_basic", - "VarMeanDimBiasedModule_basic", - "VarMeanDimModule_basic", "VarMeanUnbiasedModule_basic", - "VarUnbiasedModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", "ZeroFloat32Module_basic", @@ -3711,7 +3708,79 @@ "ZerosLikeModule_falsePinMemory", } +ONNX_TOSA_CRASHING_SET = { + "StdCorrectionEmptyDimModule_basic", + "StdDimEmptyDimModule_basic", + "VarCorrectionEmptyDimModule_basic", + "VarDimEmptyDimModule_basic", + "ViewSizeFromOtherTensor_basic", +} + ONNX_TOSA_XFAIL_SET = { + "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", + "ArgmaxKeepdimModule_basic", + "AtenIntMM_basic", + "AtenKthvalueDynamicDimsModule_basic", + "AtenKthvalueFloat64DynamicDimsModule_basic", + "AtenKthvalueFloat64Module_basic", + "AtenKthvalueKeepDimModule_basic", + "AtenKthvalueModule_basic", + "AvgPool2dCountIncludePadFalseStaticModule_basic", + "AvgPool3dStaticModule_basic", + "Conv_Transpose1dModule_basic", + "Conv_Transpose1dStaticModule_basic", + "Conv_Transpose2dStaticModule_basic", + "Conv_Transpose3dModule_basic", + "Conv_Transpose3dStaticModule_basic", + "EinsumStaticDiagonalDimensionModule_basic", + "EinsumStaticModule_basic", + "ElementwiseFmaxModule_basic", + "ElementwiseFminModule_basic", + "ElementwiseGeluApproximateTanhModule_basic", + "ElementwiseIntTensorLtFloatTensorModule_basic", + "ElementwiseNanToNumWithNoneModule_Basic", + "ElementwiseRad2DegIntModule_basic", + "ElementwiseRad2DegModule_basic", + "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic", + "ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic", + "ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic", + "ElementwiseRemainderScalarModule_Int_NegativeDividend_basic", + "ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic", + "ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic", + "ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic", + "ElementwiseRemainderTensorModule_Int_NegativeDividend_basic", + "ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic", + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", + "FakeQuantizePerTensorAffineCachemaskModule_basic", + "IndexPutWithNoneAndBroadcastModule_basic", + "MaskedScatterStaticBasic_basic", + "MaxUnpool3dModulePad0_basic", + "MaxUnpool3dModule_basic", + "MultinomialModule2D_F32", + "MultinomialModule2D_basic", + "MultinomialModule_basic", + "ReduceAmaxEmptyDim_basic", + "ReduceAminSingleDim_basic", + "ReduceAminmaxAllDims_basic", + "ReduceAminmaxSingleDim_basic", + "ReduceAnyDimFloatModule_basic", + "RenormModuleFloat16_basic", + "RenormModuleFloat32DynamicDims_basic", + "RenormModuleFloat32NegativeDim_basic", + "RenormModuleFloat32_basic", + "ScatterAddStaticModule_basic", + "TensorSplitSections_GetItemModule_basic", + "TensorSplitSections_ListUnpackModule_basic", + "TensorsConcatComplex128FloatModule_basic", + "TensorsConcatComplex128IntModule_basic", + "TensorsConcatComplex64FloatModule_basic", + "TimeOutModule_basic", + "TypeConversionUint8ToF32Module_basic", + "UnfoldModule_basic", + "WeightNormInterfaceModule_basic", "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AdaptiveAvgPool1dGeneralDynamic_basic", "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", @@ -3929,8 +3998,6 @@ "ElementwiseAcoshModule_basic", "ElementwiseAddScalarInt64Module_basic", "ElementwiseAddScalarIntModule_basic", - "ElementwiseAndScalarModule_basic", - "ElementwiseAndScalarStaticShapeModule_basic", "ElementwiseAsinIntModule_basic", "ElementwiseAsinModule_basic", "ElementwiseAsinhIntModule_basic", @@ -3951,7 +4018,6 @@ "ElementwiseAtenFloorDivideScalarNegativeModule_basic", "ElementwiseAtenFloorDivideTensorNegativeModule_basic", "ElementwiseAtenFloorDivideTensorPositiveModule_basic", - "ElementwiseAtenIsinfOpModule_basic", "ElementwiseAtenIsneginfOpModule_basic", "ElementwiseAtenIsposinfOpModule_basic", "ElementwiseAtenLogicalAndOpModule_basic", @@ -3969,10 +4035,6 @@ "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic", "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseBitwiseAndModule_basic", - "ElementwiseBitwiseAndScalarInt32Module_basic", - "ElementwiseBitwiseAndScalarInt64Module_basic", - "ElementwiseBitwiseAndScalarInt8Module_basic", - "ElementwiseBitwiseAndStaticShapeModule_basic", "ElementwiseBitwiseLeftShiftInt32Module_basic", "ElementwiseBitwiseLeftShiftInt64Module_basic", "ElementwiseBitwiseLeftShiftInt8Module_basic", @@ -3987,12 +4049,8 @@ "ElementwiseBitwiseXorStaticShapeModule_basic", "ElementwiseClampMaxModule_basic", "ElementwiseClampMinModule_basic", - "ElementwiseClampMinTensorFloatModule_basic", - "ElementwiseClampMinTensorIntModule_basic", "ElementwiseClampModule_basic", - "ElementwiseClampTensorFloatModule_basic", "ElementwiseClampTensorInt8Module_basic", - "ElementwiseClampTensorIntModule_basic", "ElementwiseCosIntModule_basic", "ElementwiseCosModule_basic", "ElementwiseCoshIntModule_basic", @@ -4006,7 +4064,6 @@ "ElementwiseDivTensorIntegerModule_basic", "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic", "ElementwiseDivTensorRoundingModeFloorModule_basic", - "ElementwiseDivTensorRoundingModeFloorStaticModule_basic", "ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic", "ElementwiseDivTensorRoundingModeTruncModule_basic", "ElementwiseDivTensorRoundingModeTruncStaticModule_basic", @@ -4030,7 +4087,6 @@ "ElementwiseGeIntScalarModule_basic", "ElementwiseGeIntTensorModule_basic", "ElementwiseGeMixedIntScalarModule_basic", - "ElementwiseGeluModule_basic", "ElementwiseGtMixed2ScalarModule_basic", "ElementwiseIntTensorLtFloatScalarModule_basic", "ElementwiseIsinfModule_basic", @@ -4084,9 +4140,7 @@ "ElementwiseUnaryIntModule_basic", "ElementwiseUnsqueezeNegDimsModule_basic", "ElementwiseWhereScalarOtherModule_basic", - "ElementwiseWhereScalarOtherStaticModule_basic", "ElementwiseWhereScalarSelfModule_basic", - "ElementwiseWhereScalarSelfStaticModule_basic", "ElementwiseWhereSelfModule_basic", "EmbeddingModule1DIndices_basic", "EmbeddingModuleF16_basic", @@ -4144,8 +4198,6 @@ "HBC_basic", "HardTanhIntModule_basic", "HardTanhModule_basic", - "HardsigmoidModule_basic", - "HardsigmoidRandomModule_basic", "HardtanhBackward_basic", "IndexPut1DFloatAccumulateModule_basic", "IndexPut1DFloatNonAccumulateModule_basic", @@ -4216,7 +4268,6 @@ "IndexTensorStaticNonContiguousWithNoneModule_basic", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", - "InterpolateStaticModule_scales_bilinear_align_corners", "InterpolateDynamicModule_scales_recompute_bilinear", "IntFloatModule_basic", "IntImplicitModule_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 3972e2fd44a6..57bbac296241 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1356,3 +1356,20 @@ func.func @torch.aten.__interpolate.size_list_scale_list.nearest(%arg0: !torch.v %1 = torch.aten.__interpolate.size_list_scale_list %arg0, %none, %0, %str, %false, %none, %false : !torch.vtensor<[1,16,135,240],f32>, !torch.none, !torch.list, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,16,270,480],f32> return %1 : !torch.vtensor<[1,16,270,480],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.tril$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],si32>) -> !torch.vtensor<[2,4],si32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],si32> -> tensor<2x4xi32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 1, 0, 0], [1, 1, 1, 0]]> : tensor<2x4xi32>}> : () -> tensor<2x4xi32> +// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<2x4xi32> -> !torch.vtensor<[2,4],si32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[2,4],si32> +// CHECK: } +func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.vtensor<[2,4], si32> { + %int0 = torch.constant.int 1 + %0 = torch.aten.tril %arg0, %int0 : !torch.vtensor<[2,4],si32>, !torch.int -> !torch.vtensor<[2,4],si32> + return %0 : !torch.vtensor<[2,4],si32> +}