From 6ae9b3218affd19b11d020ed639049e1c3c3c471 Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Sat, 13 Jan 2024 08:44:23 +0530 Subject: [PATCH 01/18] ADDED SUPPORT FLOAT VALUE IN ARANGE --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 59 ++++-- projects/pt1/e2e_testing/xfail_sets.py | 214 +++++++++++++++++++++ 2 files changed, 260 insertions(+), 13 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index e123522a4542..bec4054d9a37 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -25,6 +25,7 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include using namespace mlir; using namespace mlir::torch; @@ -4046,28 +4047,60 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "unimplemented: pin_memory must be either None or false"); } - int64_t start, step, end; - if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) + double start, step, end; + int64_t start_int, step_int, end_int; + bool is_all_inp_int; //Flag to check whether all inputs are integer + is_all_inp_int = op.getStart().getType().isa() && op.getEnd().getType().isa() && op.getStep().getType().isa(); + + if (matchPattern(op.getStart(), m_TorchConstantInt(&start_int))) + { + start = (double)(start_int); + } + + else if(!matchPattern(op.getStart(), m_TorchConstantFloat(&start))) return rewriter.notifyMatchFailure( - op, "unimplemented: value `start` should be a torch constant int"); + op, "unimplemented: value `start` should be a torch constant int or float"); - if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) + if (matchPattern(op.getEnd(), m_TorchConstantInt(&end_int))) + { + end = (double)(end_int); + } + else if (!matchPattern(op.getEnd(), m_TorchConstantFloat(&end))) return rewriter.notifyMatchFailure( - op, "unimplemented: value `end` should be a torch constant int"); + op, "unimplemented: value `end` should be a torch constant int or float"); - if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) + if (matchPattern(op.getStep(), m_TorchConstantInt(&step_int))) + { + + step = (double)(step_int); + } + + else if (!matchPattern(op.getStep(), m_TorchConstantFloat(&step))) return rewriter.notifyMatchFailure( - op, "unimplemented: value `step` should be a torch constant int"); + op, "unimplemented: value `step` should be a torch constant int or float"); // The result will always be a 1-d tensor. // The size of the result is calculated as follows: // ceil((end - start)/step) - int64_t resultShape = ceil((float)(end - start) / (float)step); - SmallVector values(resultShape, start); - for (unsigned i = 1; i < resultShape; i++) - values[i] += i * step; - Value result = - tosa::getConstTensor(rewriter, op, values, resultShape).value(); + int64_t resultShape = ceil((end - start) / step); + Value result; + if (is_all_inp_int) + { + SmallVector values(resultShape, start); + for (unsigned i = 1; i < resultShape; i++) + values[i] += i * step; + + result = tosa::getConstTensor(rewriter, op, values, resultShape).value(); + } + + else + { + SmallVector values(resultShape, start); + for (unsigned i = 1; i < resultShape; i++) + values[i] += (i * step); + + result = tosa::getConstTensor(rewriter, op, values, resultShape).value(); + } rewriter.replaceOpWithNewOp(op, resultType, result); return success(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 76f84344bd42..5f710b511d4e 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -949,6 +949,212 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "Convolution2DStridedModule_basic", + "IscloseStaticModule_basic", + "IscloseStaticModuleTrue_basic", + "TileBigDimsSizeModule_basic", + "TileSmallDimsSizeModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", + "AliasModule_basic", + "MaxPool2dEmptyStrideStaticModule_basic", + "ConstantBoolParameterModule_basic", + "ElementwiseCloneContiguousModule_basic", + "ElementwiseCloneChannelsLastMemoryFormatModule_basic", + "ElementwiseCloneModule_basic", + "ElementwiseUnaryModule_basic", + "ElementwiseBinaryModule_basic", + "ElementwiseSigmoidModule_basic", + "ElementwiseExpModule_basic", + "ElementwiseReluModule_basic", + "ElementwiseLeakyReluModule_basic", + "ElementwiseEluModule_basic", + "ElementwiseEluNonDefaultModule_basic", + "ElementwiseFloorModule_basic", + "ElementwiseFloorIntModule_basic", + "ElementwiseLogModule_basic", + "ElementwiseBinaryStaticShapeModule_basic", + "ElementwiseMinimumModule_basic", + "ElementwiseMinimumIntModule_basic", + "ElementwiseMinOtherIntModule_basic", + "ElementwiseMinOtherModule_basic", + "ElementwiseMaximumModule_basic", + "ElementwiseMaximumIntModule_basic", + "ElementwiseMaxOtherIntModule_basic", + "ElementwiseMaxOtherModule_basic", + "GluStaticModule_basic", + "ViewDoubleMergeStaticModule_basic", + "ViewCollapseOnesMiddleModule_basic", + "ViewFiveTestStaticModule_basic", + "ViewOffsetTestStaticModule_basic", + "ViewTwoFiveThreeStaticModule_basic", + "ViewTwoToThreeStaticModule_basic", + "ViewExpandOnesMiddleOppModule_basic", + "ViewOffsetBackwardTestStaticModule_basic", + "TanhBackward_basic", + "HardtanhBackward_basic", + "ElementwiseAddModule_basic", + "ReturnThreeTensorFloat32_basic", + "AddCMulModule_basic", + "AddCDivModule_basic", + "SqueezeModule_broadcast", + "BoolTensorReturnFalseModule_basic", + "BoolTensorReturnTrueModule_basic", + "BoolTensorReturnMixedModule_basic", + "BoolTensorHandleSignless_basic", + "ElementwiseRsqrtModule_basic", + "SelectIntNegativeDimAndIndexStaticModule_basic", + "SqueezeModule_static", + "SqueezeModule_noUnitDim", + "SqueezeModule_allUnitDim", + "TModuleRank1_basic", + "TModuleRank0_basic", + "ElementwiseToDtypeIdentityModule_basic", + "AtenToDeviceModule_basic", + "View1DFoldModule_basic", + "UnsafeView1DFoldModule_basic", + "UnflattenIntStaticModule_basic", + "UnflattenIntNegativeOneDimStaticModule_basic", + "UnflattenIntNegativeOneSizeStaticModule_basic", + "SqueezeDimModule_static", + "SqueezeDimModule_identity", + "SqueezeDimModule_unitDim", + "ReturnTwoTensorF32I64_basic", + "ElementwiseSignModule_basic", + "ElementwisePowModule_basic", + "BmmFloatModule_basic", + "MmDagModule_basic", + "Matmul4dStatic_basic", + "Matmul_dot", + "Matmul_3d", + "RsubFloatModule_basic", + "RsubFloatModule_noalpha_basic", + "RsubInt0d_NumToTensor_Module_basic", + "ElementwiseBitwiseAndModule_basic", + "ElementwiseBitwiseAndStaticShapeModule_basic", + "ElementwiseBitwiseNotInt32Module_basic", + "ElementwiseBitwiseNotInt64Module_basic", + "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwiseOrTensorModule_basic", + "ElementwiseBitwiseOrModule_basic", + "ElementwiseBitwiseOrStaticShapeModule_basic", + "ElementwiseBitwiseXorModule_basic", + "ElementwiseBitwiseXorStaticShapeModule_basic", + "ElementwiseGeFloatIntScalarModule_basic", + "ElementwiseGeFloatScalarModule_basic", + "ElementwiseGeIntScalarModule_basic", + "ElementwiseGeMixedIntScalarModule_basic", + "ElementwiseGtFloatScalarModule_basic", + "ElementwiseGtIntScalarModule_basic", + "ElementwiseGtMixed2ScalarModule_basic", + "ElementwiseGtFloatTensorModule_basic", + "ElementwiseGtIntTensorModule_basic", + "ElementwiseLtFloatScalarModule_basic", + "ElementwiseLtIntScalarModule_basic", + "ElementwiseLtDiffWidthScalarModule_basic", + "ElementwiseLtFloatTensorModule_basic", + "ElementwiseLtIntTensorModule_basic", + "ElementwiseEqFloatScalarModule_basic", + "ElementwiseEqIntScalarModule_basic", + "ElementwiseEqBoolScalarModule_basic", + "ElementwiseEqDiffWidthScalarModule_basic", + "ElementwiseEqFloatTensorModule_basic", + "ElementwiseEqIntTensorModule_basic", + "ElementwiseNeFloatScalarModule_basic", + "ElementwiseNeFloatTensorModule_basic", + "ElementwiseNeFloatTensorStaticModule_basic", + "ElementwiseNeIntTensorModule_basic", + "ElementwiseNeIntTensorStaticModule_basic", + "ElementwiseMulScalarModule_int", + "ElementwiseMulScalarModule_float", + "ElementwiseMulTensorIntModule_basic", + "ElementwiseDivScalarModule_basic", + "ElementwiseAtenDivIntScalarModule_basic", + "ElementwiseSubScalarFloatModule_basic", + "ElementwiseAddScalarFloatModule_basic", + "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", + "ElementwiseMulScalarModule_float", + "ElementwiseCeilModule_basic", + "ElementwiseReciprocalModule_basic", + "ElementwiseIsnanModule_basic", + "ElementwiseIsinfModule_basic", + "TypePromotionAlphaWiderModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "BatchNorm1DModule_basic", + "BatchNorm1DWith2DInputModule_basic", + "BatchNorm2DModule_basic", + "BatchNorm3DModule_basic", + "BatchNorm1DStaticShapeModule_basic", + "FlattenStaticModule_basic", + "UnflattenStaticModule_basic", + "FlattenRank0Module_basic", + "ElementwiseFlattenBroadcastModule_basic", + "SquareModule_basic", + "MaxPool2dStaticModule_basic", + "MaxPool2dStaticCeilModeTrueModule_basic", + "ResNet18StaticModule_basic", + "ReduceAmaxKeepDim_basic", + "NativeLayerNormModule4D_basic", + "LayerNormNormalizeOverAllDimsModule_basic", + "PermuteModule_basic", + "PermuteNegativeIndexModule_basic", + "ElementwiseLog2Module_basic", + "Threshold1dIntI32Module_basic", + "Threshold1dFloatModule_basic", + "Threshold2dFloatModule_basic", + "Threshold3dFloatModule_basic", + "ElementwiseSubScalarIntModule_basic", + "ElementwiseAddScalarIntModule_basic", + "ElementwiseMulScalarModule_basic", + "ZerosModuleDefaultDtype_basic", + "ZerosModuleInt2D_basic", + "ZerosModuleInt3D_basic", + "ZerosModuleFloat2D_basic", + "ZerosModuleFloat3D_basic", + "ZerosModuleFalsePinMemory_basic", + "OnesModuleDefaultDtype_basic", + "OnesModuleInt_basic", + "OnesModuleFloat_basic", + "OnesModuleFalsePinMemory_basic", + "OnesModuleCPUDevice_basic", + "NewZerosModuleDefaultDtype_basic", + "NewZerosModuleInt2D_basic", + "NewZerosModuleInt3D_basic", + "NewZerosModuleFloat2D_basic", + "NewZerosModuleFloat3D_basic", + "NewZerosModuleFalsePinMemory_basic", + "NewOnesModuleDefaultDtype_basic", + "NewOnesModuleInt2D_basic", + "NewOnesModuleInt3D_basic", + "NewOnesModuleFloat2D_basic", + "NewOnesModuleFloat3D_basic", + "NewOnesModuleFalsePinMemory_basic", + "SiluModule_basic", + "DropoutEvalIntModule_basic", + "DropoutEvalFloatModule_basic", + "ContiguousModule_basic", + "DropoutModule_basic", + "ViewExpandModule_basic", + "ViewExpandOnesModule_basic", + "ViewExpandOnesBeforeAndAfterModule_basic", + "ViewExpandOnesMiddleModule_basic", + "ViewExpandCollapseModule_basic", + "ViewExpandCollapseWithOnesModule_basic", + "ViewCollapseInferredDimModule_basic", + "ViewExpandInferredDimModule_basic", + "ViewNegativeStaticModule_basic", + "ViewNoChangeStaticModule_basic", + "UnsafeViewExpandModule_basic", + "ReshapeCollapseModule_basic", + "ReshapeAsModule_basic", + "ElementwiseGeluModule_basic", + "GeluBackwardModule_basic", + "ElementwiseNeIntScalarModule_basic", + "Convolution2DStaticModule_basic", + "ElementwiseNegModule_basic", + "TestMultipleTensorReturn_basic", + "TypeAsSameModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "AddCDivModule_basic", @@ -966,6 +1172,14 @@ "ArangeStartOutViewModule_basic", "ArangeStartStepIntModule_basic", "ArangeZeroElementOutputModule_basic", + "ArangeDtypeIntModule_basic", + "ArangeFalsePinMemoryModule_basic", + "ArangeFloatModule_basic", + "ArangeNegativeStartFloatModule_basic", + "ArangeStartFloatModule_basic", + "ArangeStartNegativeStepFloatModule_basic", + "ArangeStartOutDtypeModule_basic", + "ArangeStartStepFloatModule_basic", "ArgmaxModule_keepDim", "ArgmaxModule_with_dim", "AtenComplex64Module_basic", From 42fac70758805a31cb127b6277c2a1746fa7b1ff Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Sat, 13 Jan 2024 08:58:07 +0530 Subject: [PATCH 02/18] got rid of extra tosa tests --- projects/pt1/e2e_testing/xfail_sets.py | 206 ------------------------- 1 file changed, 206 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 753c60d3dac5..2c5a94034c4f 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -955,212 +955,6 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { - "Convolution2DStridedModule_basic", - "IscloseStaticModule_basic", - "IscloseStaticModuleTrue_basic", - "TileBigDimsSizeModule_basic", - "TileSmallDimsSizeModule_basic", - "IndexPutImpl2DNoneIndexStaticModule_basic", - "AliasModule_basic", - "MaxPool2dEmptyStrideStaticModule_basic", - "ConstantBoolParameterModule_basic", - "ElementwiseCloneContiguousModule_basic", - "ElementwiseCloneChannelsLastMemoryFormatModule_basic", - "ElementwiseCloneModule_basic", - "ElementwiseUnaryModule_basic", - "ElementwiseBinaryModule_basic", - "ElementwiseSigmoidModule_basic", - "ElementwiseExpModule_basic", - "ElementwiseReluModule_basic", - "ElementwiseLeakyReluModule_basic", - "ElementwiseEluModule_basic", - "ElementwiseEluNonDefaultModule_basic", - "ElementwiseFloorModule_basic", - "ElementwiseFloorIntModule_basic", - "ElementwiseLogModule_basic", - "ElementwiseBinaryStaticShapeModule_basic", - "ElementwiseMinimumModule_basic", - "ElementwiseMinimumIntModule_basic", - "ElementwiseMinOtherIntModule_basic", - "ElementwiseMinOtherModule_basic", - "ElementwiseMaximumModule_basic", - "ElementwiseMaximumIntModule_basic", - "ElementwiseMaxOtherIntModule_basic", - "ElementwiseMaxOtherModule_basic", - "GluStaticModule_basic", - "ViewDoubleMergeStaticModule_basic", - "ViewCollapseOnesMiddleModule_basic", - "ViewFiveTestStaticModule_basic", - "ViewOffsetTestStaticModule_basic", - "ViewTwoFiveThreeStaticModule_basic", - "ViewTwoToThreeStaticModule_basic", - "ViewExpandOnesMiddleOppModule_basic", - "ViewOffsetBackwardTestStaticModule_basic", - "TanhBackward_basic", - "HardtanhBackward_basic", - "ElementwiseAddModule_basic", - "ReturnThreeTensorFloat32_basic", - "AddCMulModule_basic", - "AddCDivModule_basic", - "SqueezeModule_broadcast", - "BoolTensorReturnFalseModule_basic", - "BoolTensorReturnTrueModule_basic", - "BoolTensorReturnMixedModule_basic", - "BoolTensorHandleSignless_basic", - "ElementwiseRsqrtModule_basic", - "SelectIntNegativeDimAndIndexStaticModule_basic", - "SqueezeModule_static", - "SqueezeModule_noUnitDim", - "SqueezeModule_allUnitDim", - "TModuleRank1_basic", - "TModuleRank0_basic", - "ElementwiseToDtypeIdentityModule_basic", - "AtenToDeviceModule_basic", - "View1DFoldModule_basic", - "UnsafeView1DFoldModule_basic", - "UnflattenIntStaticModule_basic", - "UnflattenIntNegativeOneDimStaticModule_basic", - "UnflattenIntNegativeOneSizeStaticModule_basic", - "SqueezeDimModule_static", - "SqueezeDimModule_identity", - "SqueezeDimModule_unitDim", - "ReturnTwoTensorF32I64_basic", - "ElementwiseSignModule_basic", - "ElementwisePowModule_basic", - "BmmFloatModule_basic", - "MmDagModule_basic", - "Matmul4dStatic_basic", - "Matmul_dot", - "Matmul_3d", - "RsubFloatModule_basic", - "RsubFloatModule_noalpha_basic", - "RsubInt0d_NumToTensor_Module_basic", - "ElementwiseBitwiseAndModule_basic", - "ElementwiseBitwiseAndStaticShapeModule_basic", - "ElementwiseBitwiseNotInt32Module_basic", - "ElementwiseBitwiseNotInt64Module_basic", - "ElementwiseOrTensorStaticShapeModule_basic", - "ElementwiseOrTensorModule_basic", - "ElementwiseBitwiseOrModule_basic", - "ElementwiseBitwiseOrStaticShapeModule_basic", - "ElementwiseBitwiseXorModule_basic", - "ElementwiseBitwiseXorStaticShapeModule_basic", - "ElementwiseGeFloatIntScalarModule_basic", - "ElementwiseGeFloatScalarModule_basic", - "ElementwiseGeIntScalarModule_basic", - "ElementwiseGeMixedIntScalarModule_basic", - "ElementwiseGtFloatScalarModule_basic", - "ElementwiseGtIntScalarModule_basic", - "ElementwiseGtMixed2ScalarModule_basic", - "ElementwiseGtFloatTensorModule_basic", - "ElementwiseGtIntTensorModule_basic", - "ElementwiseLtFloatScalarModule_basic", - "ElementwiseLtIntScalarModule_basic", - "ElementwiseLtDiffWidthScalarModule_basic", - "ElementwiseLtFloatTensorModule_basic", - "ElementwiseLtIntTensorModule_basic", - "ElementwiseEqFloatScalarModule_basic", - "ElementwiseEqIntScalarModule_basic", - "ElementwiseEqBoolScalarModule_basic", - "ElementwiseEqDiffWidthScalarModule_basic", - "ElementwiseEqFloatTensorModule_basic", - "ElementwiseEqIntTensorModule_basic", - "ElementwiseNeFloatScalarModule_basic", - "ElementwiseNeFloatTensorModule_basic", - "ElementwiseNeFloatTensorStaticModule_basic", - "ElementwiseNeIntTensorModule_basic", - "ElementwiseNeIntTensorStaticModule_basic", - "ElementwiseMulScalarModule_int", - "ElementwiseMulScalarModule_float", - "ElementwiseMulTensorIntModule_basic", - "ElementwiseDivScalarModule_basic", - "ElementwiseAtenDivIntScalarModule_basic", - "ElementwiseSubScalarFloatModule_basic", - "ElementwiseAddScalarFloatModule_basic", - "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", - "ElementwiseMulScalarModule_float", - "ElementwiseCeilModule_basic", - "ElementwiseReciprocalModule_basic", - "ElementwiseIsnanModule_basic", - "ElementwiseIsinfModule_basic", - "TypePromotionAlphaWiderModule_basic", - "Conv2dWithPaddingDilationStrideStaticModule_basic", - "Conv2dWithPaddingDilationStrideStaticModule_depthwise", - "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", - "BatchNorm1DModule_basic", - "BatchNorm1DWith2DInputModule_basic", - "BatchNorm2DModule_basic", - "BatchNorm3DModule_basic", - "BatchNorm1DStaticShapeModule_basic", - "FlattenStaticModule_basic", - "UnflattenStaticModule_basic", - "FlattenRank0Module_basic", - "ElementwiseFlattenBroadcastModule_basic", - "SquareModule_basic", - "MaxPool2dStaticModule_basic", - "MaxPool2dStaticCeilModeTrueModule_basic", - "ResNet18StaticModule_basic", - "ReduceAmaxKeepDim_basic", - "NativeLayerNormModule4D_basic", - "LayerNormNormalizeOverAllDimsModule_basic", - "PermuteModule_basic", - "PermuteNegativeIndexModule_basic", - "ElementwiseLog2Module_basic", - "Threshold1dIntI32Module_basic", - "Threshold1dFloatModule_basic", - "Threshold2dFloatModule_basic", - "Threshold3dFloatModule_basic", - "ElementwiseSubScalarIntModule_basic", - "ElementwiseAddScalarIntModule_basic", - "ElementwiseMulScalarModule_basic", - "ZerosModuleDefaultDtype_basic", - "ZerosModuleInt2D_basic", - "ZerosModuleInt3D_basic", - "ZerosModuleFloat2D_basic", - "ZerosModuleFloat3D_basic", - "ZerosModuleFalsePinMemory_basic", - "OnesModuleDefaultDtype_basic", - "OnesModuleInt_basic", - "OnesModuleFloat_basic", - "OnesModuleFalsePinMemory_basic", - "OnesModuleCPUDevice_basic", - "NewZerosModuleDefaultDtype_basic", - "NewZerosModuleInt2D_basic", - "NewZerosModuleInt3D_basic", - "NewZerosModuleFloat2D_basic", - "NewZerosModuleFloat3D_basic", - "NewZerosModuleFalsePinMemory_basic", - "NewOnesModuleDefaultDtype_basic", - "NewOnesModuleInt2D_basic", - "NewOnesModuleInt3D_basic", - "NewOnesModuleFloat2D_basic", - "NewOnesModuleFloat3D_basic", - "NewOnesModuleFalsePinMemory_basic", - "SiluModule_basic", - "DropoutEvalIntModule_basic", - "DropoutEvalFloatModule_basic", - "ContiguousModule_basic", - "DropoutModule_basic", - "ViewExpandModule_basic", - "ViewExpandOnesModule_basic", - "ViewExpandOnesBeforeAndAfterModule_basic", - "ViewExpandOnesMiddleModule_basic", - "ViewExpandCollapseModule_basic", - "ViewExpandCollapseWithOnesModule_basic", - "ViewCollapseInferredDimModule_basic", - "ViewExpandInferredDimModule_basic", - "ViewNegativeStaticModule_basic", - "ViewNoChangeStaticModule_basic", - "UnsafeViewExpandModule_basic", - "ReshapeCollapseModule_basic", - "ReshapeAsModule_basic", - "ElementwiseGeluModule_basic", - "GeluBackwardModule_basic", - "ElementwiseNeIntScalarModule_basic", - "Convolution2DStaticModule_basic", - "ElementwiseNegModule_basic", - "TestMultipleTensorReturn_basic", - "TypeAsSameModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "AddCDivModule_basic", From b85c84e6f91dc372daeb9f745763e32272a112cb Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Tue, 16 Jan 2024 08:55:37 +0530 Subject: [PATCH 03/18] git rid of iostream import --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 044f21134a6c..8d36194cc1c8 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -26,7 +26,6 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" -#include using namespace mlir; using namespace mlir::torch; From 153080288f18dd562e4bcf45861ef59311f34a9c Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Sat, 3 Feb 2024 13:50:02 +0530 Subject: [PATCH 04/18] using int in result shape --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index a46fdc5f549c..001b9ad981a7 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4102,11 +4102,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // The result will always be a 1-d tensor. // The size of the result is calculated as follows: // ceil((end - start)/step) - int64_t resultShape = ceil((end - start) / step); + Value result; if (is_all_inp_int) { - SmallVector values(resultShape, start); + int64_t resultShape = ceil((float)(end_int - start_int) / (float)(step_int)); + SmallVector values(resultShape, start_int); for (unsigned i = 1; i < resultShape; i++) values[i] += i * step; @@ -4115,6 +4116,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( else { + int64_t resultShape = ceil((end - start) / step); SmallVector values(resultShape, start); for (unsigned i = 1; i < resultShape; i++) values[i] += (i * step); From b6e1bcf0a49e5b62ea2d8d433e00dd60703fd54b Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Mon, 5 Feb 2024 22:25:27 +0530 Subject: [PATCH 05/18] got rid of resultshape for int case --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 001b9ad981a7..dbab3d40d14c 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4106,22 +4106,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value result; if (is_all_inp_int) { - int64_t resultShape = ceil((float)(end_int - start_int) / (float)(step_int)); - SmallVector values(resultShape, start_int); - for (unsigned i = 1; i < resultShape; i++) - values[i] += i * step; + SmallVector values(start_int); + for (int64_t i = start_int; i < end_int; i += step_int) + values.push_back(i); - result = tosa::getConstTensor(rewriter, op, values, resultShape).value(); + result = tosa::getConstTensor(rewriter, op, values, values.size()).value(); } else { - int64_t resultShape = ceil((end - start) / step); - SmallVector values(resultShape, start); - for (unsigned i = 1; i < resultShape; i++) + int64_t resultSize = ceil((end - start) / step); + SmallVector values(resultSize, start); + for (unsigned i = 1; i < resultSize; i++) values[i] += (i * step); - result = tosa::getConstTensor(rewriter, op, values, resultShape).value(); + result = tosa::getConstTensor(rewriter, op, values, resultSize).value(); } rewriter.replaceOpWithNewOp(op, resultType, result); From 5b59626ab544a918359284400ea7e33e0d1ebbd9 Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Tue, 6 Feb 2024 22:35:54 +0530 Subject: [PATCH 06/18] got rid of result shape in all int case --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index dbab3d40d14c..cb2b7fe2b8e5 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4106,9 +4106,18 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value result; if (is_all_inp_int) { - SmallVector values(start_int); - for (int64_t i = start_int; i < end_int; i += step_int) - values.push_back(i); + SmallVector values; + if (step_int >= 0) + { + for (int64_t i = start_int; i < end_int; i += step_int) + values.push_back(i); + } + + else + { + for (int64_t i = start_int; i > end_int; i += step_int) + values.push_back(i); + } result = tosa::getConstTensor(rewriter, op, values, values.size()).value(); } From cb4ed3e2bd72a30b56bd0ebcd39ec50f7ef82453 Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Thu, 15 Feb 2024 22:35:25 +0530 Subject: [PATCH 07/18] using static cast instead of dynamic cast --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index cb2b7fe2b8e5..24ad70235607 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4069,12 +4069,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( double start, step, end; int64_t start_int, step_int, end_int; - bool is_all_inp_int; //Flag to check whether all inputs are integer - is_all_inp_int = op.getStart().getType().isa() && op.getEnd().getType().isa() && op.getStep().getType().isa(); + auto isInteger = [=](Value v) { return v.getType().isa(); }; + //Flag to check whether all inputs are integer + bool integer_range = isInteger(op.getStart()) && isInteger(op.getEnd()) && isInteger(op.getStep()); if (matchPattern(op.getStart(), m_TorchConstantInt(&start_int))) { - start = (double)(start_int); + start = static_cast(start_int); } else if(!matchPattern(op.getStart(), m_TorchConstantFloat(&start))) @@ -4083,7 +4084,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (matchPattern(op.getEnd(), m_TorchConstantInt(&end_int))) { - end = (double)(end_int); + end = static_cast(end_int); } else if (!matchPattern(op.getEnd(), m_TorchConstantFloat(&end))) return rewriter.notifyMatchFailure( @@ -4092,7 +4093,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (matchPattern(op.getStep(), m_TorchConstantInt(&step_int))) { - step = (double)(step_int); + step = static_cast(step_int); } else if (!matchPattern(op.getStep(), m_TorchConstantFloat(&step))) @@ -4104,7 +4105,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // ceil((end - start)/step) Value result; - if (is_all_inp_int) + if (integer_range) { SmallVector values; if (step_int >= 0) From 5d3194bebd398d69a1c401e4b759e6d284d40b32 Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Mon, 19 Feb 2024 20:19:09 +0530 Subject: [PATCH 08/18] typecasting for int64type --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 24ad70235607..5eb22637b12a 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4070,6 +4070,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( double start, step, end; int64_t start_int, step_int, end_int; auto isInteger = [=](Value v) { return v.getType().isa(); }; + bool isOutputInt64=false; + auto intType = resultType.getElementType().dyn_cast_or_null(); + + if(intType) + { + if(intType.getWidth() == 64) + { + isOutputInt64 = true; + } + } //Flag to check whether all inputs are integer bool integer_range = isInteger(op.getStart()) && isInteger(op.getEnd()) && isInteger(op.getStep()); @@ -4123,6 +4133,18 @@ LogicalResult ConvertAtenOp::matchAndRewrite( result = tosa::getConstTensor(rewriter, op, values, values.size()).value(); } + //Since typecasting from float32 or float64 to int64 results in, seemingly + //garbage values. Therefore typecasting here itself. + else if(isOutputInt64) + { + int64_t resultSize = ceil((end - start) / step); + SmallVector values(resultSize, start); + for (unsigned i = 1; i < resultSize; i++) + values[i] += static_cast(i * step); + + result = tosa::getConstTensor(rewriter, op, values, resultSize).value(); + } + else { int64_t resultSize = ceil((end - start) / step); From 0ee752bc688c49d3b55995bb85db9262c8fdaad0 Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Sat, 13 Jan 2024 08:44:23 +0530 Subject: [PATCH 09/18] ADDED SUPPORT FLOAT VALUE IN ARANGE --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 59 ++++-- projects/pt1/e2e_testing/xfail_sets.py | 214 +++++++++++++++++++++ 2 files changed, 260 insertions(+), 13 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index b49c9af8adce..0b5819631f52 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -26,6 +26,7 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include using namespace mlir; using namespace mlir::torch; @@ -4067,28 +4068,60 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "unimplemented: pin_memory must be either None or false"); } - int64_t start, step, end; - if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) + double start, step, end; + int64_t start_int, step_int, end_int; + bool is_all_inp_int; //Flag to check whether all inputs are integer + is_all_inp_int = op.getStart().getType().isa() && op.getEnd().getType().isa() && op.getStep().getType().isa(); + + if (matchPattern(op.getStart(), m_TorchConstantInt(&start_int))) + { + start = (double)(start_int); + } + + else if(!matchPattern(op.getStart(), m_TorchConstantFloat(&start))) return rewriter.notifyMatchFailure( - op, "unimplemented: value `start` should be a torch constant int"); + op, "unimplemented: value `start` should be a torch constant int or float"); - if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) + if (matchPattern(op.getEnd(), m_TorchConstantInt(&end_int))) + { + end = (double)(end_int); + } + else if (!matchPattern(op.getEnd(), m_TorchConstantFloat(&end))) return rewriter.notifyMatchFailure( - op, "unimplemented: value `end` should be a torch constant int"); + op, "unimplemented: value `end` should be a torch constant int or float"); - if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) + if (matchPattern(op.getStep(), m_TorchConstantInt(&step_int))) + { + + step = (double)(step_int); + } + + else if (!matchPattern(op.getStep(), m_TorchConstantFloat(&step))) return rewriter.notifyMatchFailure( - op, "unimplemented: value `step` should be a torch constant int"); + op, "unimplemented: value `step` should be a torch constant int or float"); // The result will always be a 1-d tensor. // The size of the result is calculated as follows: // ceil((end - start)/step) - int64_t resultShape = ceil((float)(end - start) / (float)step); - SmallVector values(resultShape, start); - for (unsigned i = 1; i < resultShape; i++) - values[i] += i * step; - Value result = - tosa::getConstTensor(rewriter, op, values, resultShape).value(); + int64_t resultShape = ceil((end - start) / step); + Value result; + if (is_all_inp_int) + { + SmallVector values(resultShape, start); + for (unsigned i = 1; i < resultShape; i++) + values[i] += i * step; + + result = tosa::getConstTensor(rewriter, op, values, resultShape).value(); + } + + else + { + SmallVector values(resultShape, start); + for (unsigned i = 1; i < resultShape; i++) + values[i] += (i * step); + + result = tosa::getConstTensor(rewriter, op, values, resultShape).value(); + } rewriter.replaceOpWithNewOp(op, resultType, result); return success(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 70f26fe421e0..c9768152da25 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -871,6 +871,212 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "Convolution2DStridedModule_basic", + "IscloseStaticModule_basic", + "IscloseStaticModuleTrue_basic", + "TileBigDimsSizeModule_basic", + "TileSmallDimsSizeModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", + "AliasModule_basic", + "MaxPool2dEmptyStrideStaticModule_basic", + "ConstantBoolParameterModule_basic", + "ElementwiseCloneContiguousModule_basic", + "ElementwiseCloneChannelsLastMemoryFormatModule_basic", + "ElementwiseCloneModule_basic", + "ElementwiseUnaryModule_basic", + "ElementwiseBinaryModule_basic", + "ElementwiseSigmoidModule_basic", + "ElementwiseExpModule_basic", + "ElementwiseReluModule_basic", + "ElementwiseLeakyReluModule_basic", + "ElementwiseEluModule_basic", + "ElementwiseEluNonDefaultModule_basic", + "ElementwiseFloorModule_basic", + "ElementwiseFloorIntModule_basic", + "ElementwiseLogModule_basic", + "ElementwiseBinaryStaticShapeModule_basic", + "ElementwiseMinimumModule_basic", + "ElementwiseMinimumIntModule_basic", + "ElementwiseMinOtherIntModule_basic", + "ElementwiseMinOtherModule_basic", + "ElementwiseMaximumModule_basic", + "ElementwiseMaximumIntModule_basic", + "ElementwiseMaxOtherIntModule_basic", + "ElementwiseMaxOtherModule_basic", + "GluStaticModule_basic", + "ViewDoubleMergeStaticModule_basic", + "ViewCollapseOnesMiddleModule_basic", + "ViewFiveTestStaticModule_basic", + "ViewOffsetTestStaticModule_basic", + "ViewTwoFiveThreeStaticModule_basic", + "ViewTwoToThreeStaticModule_basic", + "ViewExpandOnesMiddleOppModule_basic", + "ViewOffsetBackwardTestStaticModule_basic", + "TanhBackward_basic", + "HardtanhBackward_basic", + "ElementwiseAddModule_basic", + "ReturnThreeTensorFloat32_basic", + "AddCMulModule_basic", + "AddCDivModule_basic", + "SqueezeModule_broadcast", + "BoolTensorReturnFalseModule_basic", + "BoolTensorReturnTrueModule_basic", + "BoolTensorReturnMixedModule_basic", + "BoolTensorHandleSignless_basic", + "ElementwiseRsqrtModule_basic", + "SelectIntNegativeDimAndIndexStaticModule_basic", + "SqueezeModule_static", + "SqueezeModule_noUnitDim", + "SqueezeModule_allUnitDim", + "TModuleRank1_basic", + "TModuleRank0_basic", + "ElementwiseToDtypeIdentityModule_basic", + "AtenToDeviceModule_basic", + "View1DFoldModule_basic", + "UnsafeView1DFoldModule_basic", + "UnflattenIntStaticModule_basic", + "UnflattenIntNegativeOneDimStaticModule_basic", + "UnflattenIntNegativeOneSizeStaticModule_basic", + "SqueezeDimModule_static", + "SqueezeDimModule_identity", + "SqueezeDimModule_unitDim", + "ReturnTwoTensorF32I64_basic", + "ElementwiseSignModule_basic", + "ElementwisePowModule_basic", + "BmmFloatModule_basic", + "MmDagModule_basic", + "Matmul4dStatic_basic", + "Matmul_dot", + "Matmul_3d", + "RsubFloatModule_basic", + "RsubFloatModule_noalpha_basic", + "RsubInt0d_NumToTensor_Module_basic", + "ElementwiseBitwiseAndModule_basic", + "ElementwiseBitwiseAndStaticShapeModule_basic", + "ElementwiseBitwiseNotInt32Module_basic", + "ElementwiseBitwiseNotInt64Module_basic", + "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwiseOrTensorModule_basic", + "ElementwiseBitwiseOrModule_basic", + "ElementwiseBitwiseOrStaticShapeModule_basic", + "ElementwiseBitwiseXorModule_basic", + "ElementwiseBitwiseXorStaticShapeModule_basic", + "ElementwiseGeFloatIntScalarModule_basic", + "ElementwiseGeFloatScalarModule_basic", + "ElementwiseGeIntScalarModule_basic", + "ElementwiseGeMixedIntScalarModule_basic", + "ElementwiseGtFloatScalarModule_basic", + "ElementwiseGtIntScalarModule_basic", + "ElementwiseGtMixed2ScalarModule_basic", + "ElementwiseGtFloatTensorModule_basic", + "ElementwiseGtIntTensorModule_basic", + "ElementwiseLtFloatScalarModule_basic", + "ElementwiseLtIntScalarModule_basic", + "ElementwiseLtDiffWidthScalarModule_basic", + "ElementwiseLtFloatTensorModule_basic", + "ElementwiseLtIntTensorModule_basic", + "ElementwiseEqFloatScalarModule_basic", + "ElementwiseEqIntScalarModule_basic", + "ElementwiseEqBoolScalarModule_basic", + "ElementwiseEqDiffWidthScalarModule_basic", + "ElementwiseEqFloatTensorModule_basic", + "ElementwiseEqIntTensorModule_basic", + "ElementwiseNeFloatScalarModule_basic", + "ElementwiseNeFloatTensorModule_basic", + "ElementwiseNeFloatTensorStaticModule_basic", + "ElementwiseNeIntTensorModule_basic", + "ElementwiseNeIntTensorStaticModule_basic", + "ElementwiseMulScalarModule_int", + "ElementwiseMulScalarModule_float", + "ElementwiseMulTensorIntModule_basic", + "ElementwiseDivScalarModule_basic", + "ElementwiseAtenDivIntScalarModule_basic", + "ElementwiseSubScalarFloatModule_basic", + "ElementwiseAddScalarFloatModule_basic", + "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", + "ElementwiseMulScalarModule_float", + "ElementwiseCeilModule_basic", + "ElementwiseReciprocalModule_basic", + "ElementwiseIsnanModule_basic", + "ElementwiseIsinfModule_basic", + "TypePromotionAlphaWiderModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "BatchNorm1DModule_basic", + "BatchNorm1DWith2DInputModule_basic", + "BatchNorm2DModule_basic", + "BatchNorm3DModule_basic", + "BatchNorm1DStaticShapeModule_basic", + "FlattenStaticModule_basic", + "UnflattenStaticModule_basic", + "FlattenRank0Module_basic", + "ElementwiseFlattenBroadcastModule_basic", + "SquareModule_basic", + "MaxPool2dStaticModule_basic", + "MaxPool2dStaticCeilModeTrueModule_basic", + "ResNet18StaticModule_basic", + "ReduceAmaxKeepDim_basic", + "NativeLayerNormModule4D_basic", + "LayerNormNormalizeOverAllDimsModule_basic", + "PermuteModule_basic", + "PermuteNegativeIndexModule_basic", + "ElementwiseLog2Module_basic", + "Threshold1dIntI32Module_basic", + "Threshold1dFloatModule_basic", + "Threshold2dFloatModule_basic", + "Threshold3dFloatModule_basic", + "ElementwiseSubScalarIntModule_basic", + "ElementwiseAddScalarIntModule_basic", + "ElementwiseMulScalarModule_basic", + "ZerosModuleDefaultDtype_basic", + "ZerosModuleInt2D_basic", + "ZerosModuleInt3D_basic", + "ZerosModuleFloat2D_basic", + "ZerosModuleFloat3D_basic", + "ZerosModuleFalsePinMemory_basic", + "OnesModuleDefaultDtype_basic", + "OnesModuleInt_basic", + "OnesModuleFloat_basic", + "OnesModuleFalsePinMemory_basic", + "OnesModuleCPUDevice_basic", + "NewZerosModuleDefaultDtype_basic", + "NewZerosModuleInt2D_basic", + "NewZerosModuleInt3D_basic", + "NewZerosModuleFloat2D_basic", + "NewZerosModuleFloat3D_basic", + "NewZerosModuleFalsePinMemory_basic", + "NewOnesModuleDefaultDtype_basic", + "NewOnesModuleInt2D_basic", + "NewOnesModuleInt3D_basic", + "NewOnesModuleFloat2D_basic", + "NewOnesModuleFloat3D_basic", + "NewOnesModuleFalsePinMemory_basic", + "SiluModule_basic", + "DropoutEvalIntModule_basic", + "DropoutEvalFloatModule_basic", + "ContiguousModule_basic", + "DropoutModule_basic", + "ViewExpandModule_basic", + "ViewExpandOnesModule_basic", + "ViewExpandOnesBeforeAndAfterModule_basic", + "ViewExpandOnesMiddleModule_basic", + "ViewExpandCollapseModule_basic", + "ViewExpandCollapseWithOnesModule_basic", + "ViewCollapseInferredDimModule_basic", + "ViewExpandInferredDimModule_basic", + "ViewNegativeStaticModule_basic", + "ViewNoChangeStaticModule_basic", + "UnsafeViewExpandModule_basic", + "ReshapeCollapseModule_basic", + "ReshapeAsModule_basic", + "ElementwiseGeluModule_basic", + "GeluBackwardModule_basic", + "ElementwiseNeIntScalarModule_basic", + "Convolution2DStaticModule_basic", + "ElementwiseNegModule_basic", + "TestMultipleTensorReturn_basic", + "TypeAsSameModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "AddCDivModule_basic", @@ -888,6 +1094,14 @@ "ArangeStartOutViewModule_basic", "ArangeStartStepIntModule_basic", "ArangeZeroElementOutputModule_basic", + "ArangeDtypeIntModule_basic", + "ArangeFalsePinMemoryModule_basic", + "ArangeFloatModule_basic", + "ArangeNegativeStartFloatModule_basic", + "ArangeStartFloatModule_basic", + "ArangeStartNegativeStepFloatModule_basic", + "ArangeStartOutDtypeModule_basic", + "ArangeStartStepFloatModule_basic", "ArgmaxModule_keepDim", "ArgmaxModule_with_dim", "AtenComplex64Module_basic", From 6b26100dc6710f71c222298d9956cd4cc6eb8b08 Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Sat, 13 Jan 2024 08:58:07 +0530 Subject: [PATCH 10/18] got rid of extra tosa tests --- projects/pt1/e2e_testing/xfail_sets.py | 206 ------------------------- 1 file changed, 206 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index c9768152da25..0e92f61b89c8 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -871,212 +871,6 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { - "Convolution2DStridedModule_basic", - "IscloseStaticModule_basic", - "IscloseStaticModuleTrue_basic", - "TileBigDimsSizeModule_basic", - "TileSmallDimsSizeModule_basic", - "IndexPutImpl2DNoneIndexStaticModule_basic", - "AliasModule_basic", - "MaxPool2dEmptyStrideStaticModule_basic", - "ConstantBoolParameterModule_basic", - "ElementwiseCloneContiguousModule_basic", - "ElementwiseCloneChannelsLastMemoryFormatModule_basic", - "ElementwiseCloneModule_basic", - "ElementwiseUnaryModule_basic", - "ElementwiseBinaryModule_basic", - "ElementwiseSigmoidModule_basic", - "ElementwiseExpModule_basic", - "ElementwiseReluModule_basic", - "ElementwiseLeakyReluModule_basic", - "ElementwiseEluModule_basic", - "ElementwiseEluNonDefaultModule_basic", - "ElementwiseFloorModule_basic", - "ElementwiseFloorIntModule_basic", - "ElementwiseLogModule_basic", - "ElementwiseBinaryStaticShapeModule_basic", - "ElementwiseMinimumModule_basic", - "ElementwiseMinimumIntModule_basic", - "ElementwiseMinOtherIntModule_basic", - "ElementwiseMinOtherModule_basic", - "ElementwiseMaximumModule_basic", - "ElementwiseMaximumIntModule_basic", - "ElementwiseMaxOtherIntModule_basic", - "ElementwiseMaxOtherModule_basic", - "GluStaticModule_basic", - "ViewDoubleMergeStaticModule_basic", - "ViewCollapseOnesMiddleModule_basic", - "ViewFiveTestStaticModule_basic", - "ViewOffsetTestStaticModule_basic", - "ViewTwoFiveThreeStaticModule_basic", - "ViewTwoToThreeStaticModule_basic", - "ViewExpandOnesMiddleOppModule_basic", - "ViewOffsetBackwardTestStaticModule_basic", - "TanhBackward_basic", - "HardtanhBackward_basic", - "ElementwiseAddModule_basic", - "ReturnThreeTensorFloat32_basic", - "AddCMulModule_basic", - "AddCDivModule_basic", - "SqueezeModule_broadcast", - "BoolTensorReturnFalseModule_basic", - "BoolTensorReturnTrueModule_basic", - "BoolTensorReturnMixedModule_basic", - "BoolTensorHandleSignless_basic", - "ElementwiseRsqrtModule_basic", - "SelectIntNegativeDimAndIndexStaticModule_basic", - "SqueezeModule_static", - "SqueezeModule_noUnitDim", - "SqueezeModule_allUnitDim", - "TModuleRank1_basic", - "TModuleRank0_basic", - "ElementwiseToDtypeIdentityModule_basic", - "AtenToDeviceModule_basic", - "View1DFoldModule_basic", - "UnsafeView1DFoldModule_basic", - "UnflattenIntStaticModule_basic", - "UnflattenIntNegativeOneDimStaticModule_basic", - "UnflattenIntNegativeOneSizeStaticModule_basic", - "SqueezeDimModule_static", - "SqueezeDimModule_identity", - "SqueezeDimModule_unitDim", - "ReturnTwoTensorF32I64_basic", - "ElementwiseSignModule_basic", - "ElementwisePowModule_basic", - "BmmFloatModule_basic", - "MmDagModule_basic", - "Matmul4dStatic_basic", - "Matmul_dot", - "Matmul_3d", - "RsubFloatModule_basic", - "RsubFloatModule_noalpha_basic", - "RsubInt0d_NumToTensor_Module_basic", - "ElementwiseBitwiseAndModule_basic", - "ElementwiseBitwiseAndStaticShapeModule_basic", - "ElementwiseBitwiseNotInt32Module_basic", - "ElementwiseBitwiseNotInt64Module_basic", - "ElementwiseOrTensorStaticShapeModule_basic", - "ElementwiseOrTensorModule_basic", - "ElementwiseBitwiseOrModule_basic", - "ElementwiseBitwiseOrStaticShapeModule_basic", - "ElementwiseBitwiseXorModule_basic", - "ElementwiseBitwiseXorStaticShapeModule_basic", - "ElementwiseGeFloatIntScalarModule_basic", - "ElementwiseGeFloatScalarModule_basic", - "ElementwiseGeIntScalarModule_basic", - "ElementwiseGeMixedIntScalarModule_basic", - "ElementwiseGtFloatScalarModule_basic", - "ElementwiseGtIntScalarModule_basic", - "ElementwiseGtMixed2ScalarModule_basic", - "ElementwiseGtFloatTensorModule_basic", - "ElementwiseGtIntTensorModule_basic", - "ElementwiseLtFloatScalarModule_basic", - "ElementwiseLtIntScalarModule_basic", - "ElementwiseLtDiffWidthScalarModule_basic", - "ElementwiseLtFloatTensorModule_basic", - "ElementwiseLtIntTensorModule_basic", - "ElementwiseEqFloatScalarModule_basic", - "ElementwiseEqIntScalarModule_basic", - "ElementwiseEqBoolScalarModule_basic", - "ElementwiseEqDiffWidthScalarModule_basic", - "ElementwiseEqFloatTensorModule_basic", - "ElementwiseEqIntTensorModule_basic", - "ElementwiseNeFloatScalarModule_basic", - "ElementwiseNeFloatTensorModule_basic", - "ElementwiseNeFloatTensorStaticModule_basic", - "ElementwiseNeIntTensorModule_basic", - "ElementwiseNeIntTensorStaticModule_basic", - "ElementwiseMulScalarModule_int", - "ElementwiseMulScalarModule_float", - "ElementwiseMulTensorIntModule_basic", - "ElementwiseDivScalarModule_basic", - "ElementwiseAtenDivIntScalarModule_basic", - "ElementwiseSubScalarFloatModule_basic", - "ElementwiseAddScalarFloatModule_basic", - "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", - "ElementwiseMulScalarModule_float", - "ElementwiseCeilModule_basic", - "ElementwiseReciprocalModule_basic", - "ElementwiseIsnanModule_basic", - "ElementwiseIsinfModule_basic", - "TypePromotionAlphaWiderModule_basic", - "Conv2dWithPaddingDilationStrideStaticModule_basic", - "Conv2dWithPaddingDilationStrideStaticModule_depthwise", - "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", - "BatchNorm1DModule_basic", - "BatchNorm1DWith2DInputModule_basic", - "BatchNorm2DModule_basic", - "BatchNorm3DModule_basic", - "BatchNorm1DStaticShapeModule_basic", - "FlattenStaticModule_basic", - "UnflattenStaticModule_basic", - "FlattenRank0Module_basic", - "ElementwiseFlattenBroadcastModule_basic", - "SquareModule_basic", - "MaxPool2dStaticModule_basic", - "MaxPool2dStaticCeilModeTrueModule_basic", - "ResNet18StaticModule_basic", - "ReduceAmaxKeepDim_basic", - "NativeLayerNormModule4D_basic", - "LayerNormNormalizeOverAllDimsModule_basic", - "PermuteModule_basic", - "PermuteNegativeIndexModule_basic", - "ElementwiseLog2Module_basic", - "Threshold1dIntI32Module_basic", - "Threshold1dFloatModule_basic", - "Threshold2dFloatModule_basic", - "Threshold3dFloatModule_basic", - "ElementwiseSubScalarIntModule_basic", - "ElementwiseAddScalarIntModule_basic", - "ElementwiseMulScalarModule_basic", - "ZerosModuleDefaultDtype_basic", - "ZerosModuleInt2D_basic", - "ZerosModuleInt3D_basic", - "ZerosModuleFloat2D_basic", - "ZerosModuleFloat3D_basic", - "ZerosModuleFalsePinMemory_basic", - "OnesModuleDefaultDtype_basic", - "OnesModuleInt_basic", - "OnesModuleFloat_basic", - "OnesModuleFalsePinMemory_basic", - "OnesModuleCPUDevice_basic", - "NewZerosModuleDefaultDtype_basic", - "NewZerosModuleInt2D_basic", - "NewZerosModuleInt3D_basic", - "NewZerosModuleFloat2D_basic", - "NewZerosModuleFloat3D_basic", - "NewZerosModuleFalsePinMemory_basic", - "NewOnesModuleDefaultDtype_basic", - "NewOnesModuleInt2D_basic", - "NewOnesModuleInt3D_basic", - "NewOnesModuleFloat2D_basic", - "NewOnesModuleFloat3D_basic", - "NewOnesModuleFalsePinMemory_basic", - "SiluModule_basic", - "DropoutEvalIntModule_basic", - "DropoutEvalFloatModule_basic", - "ContiguousModule_basic", - "DropoutModule_basic", - "ViewExpandModule_basic", - "ViewExpandOnesModule_basic", - "ViewExpandOnesBeforeAndAfterModule_basic", - "ViewExpandOnesMiddleModule_basic", - "ViewExpandCollapseModule_basic", - "ViewExpandCollapseWithOnesModule_basic", - "ViewCollapseInferredDimModule_basic", - "ViewExpandInferredDimModule_basic", - "ViewNegativeStaticModule_basic", - "ViewNoChangeStaticModule_basic", - "UnsafeViewExpandModule_basic", - "ReshapeCollapseModule_basic", - "ReshapeAsModule_basic", - "ElementwiseGeluModule_basic", - "GeluBackwardModule_basic", - "ElementwiseNeIntScalarModule_basic", - "Convolution2DStaticModule_basic", - "ElementwiseNegModule_basic", - "TestMultipleTensorReturn_basic", - "TypeAsSameModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "AddCDivModule_basic", From ef559c5d97864faabde4bccbee2d72fe60378ffb Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Tue, 16 Jan 2024 08:55:37 +0530 Subject: [PATCH 11/18] git rid of iostream import --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 0b5819631f52..a46fdc5f549c 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -26,7 +26,6 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" -#include using namespace mlir; using namespace mlir::torch; From 08a289f326bb139914388b1e309cdabde411c4ae Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Sat, 3 Feb 2024 13:50:02 +0530 Subject: [PATCH 12/18] using int in result shape --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index a46fdc5f549c..001b9ad981a7 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4102,11 +4102,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // The result will always be a 1-d tensor. // The size of the result is calculated as follows: // ceil((end - start)/step) - int64_t resultShape = ceil((end - start) / step); + Value result; if (is_all_inp_int) { - SmallVector values(resultShape, start); + int64_t resultShape = ceil((float)(end_int - start_int) / (float)(step_int)); + SmallVector values(resultShape, start_int); for (unsigned i = 1; i < resultShape; i++) values[i] += i * step; @@ -4115,6 +4116,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( else { + int64_t resultShape = ceil((end - start) / step); SmallVector values(resultShape, start); for (unsigned i = 1; i < resultShape; i++) values[i] += (i * step); From 7f3caa87ab5b7047946eadb911ff2af138261d4b Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Mon, 5 Feb 2024 22:25:27 +0530 Subject: [PATCH 13/18] got rid of resultshape for int case --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 001b9ad981a7..dbab3d40d14c 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4106,22 +4106,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value result; if (is_all_inp_int) { - int64_t resultShape = ceil((float)(end_int - start_int) / (float)(step_int)); - SmallVector values(resultShape, start_int); - for (unsigned i = 1; i < resultShape; i++) - values[i] += i * step; + SmallVector values(start_int); + for (int64_t i = start_int; i < end_int; i += step_int) + values.push_back(i); - result = tosa::getConstTensor(rewriter, op, values, resultShape).value(); + result = tosa::getConstTensor(rewriter, op, values, values.size()).value(); } else { - int64_t resultShape = ceil((end - start) / step); - SmallVector values(resultShape, start); - for (unsigned i = 1; i < resultShape; i++) + int64_t resultSize = ceil((end - start) / step); + SmallVector values(resultSize, start); + for (unsigned i = 1; i < resultSize; i++) values[i] += (i * step); - result = tosa::getConstTensor(rewriter, op, values, resultShape).value(); + result = tosa::getConstTensor(rewriter, op, values, resultSize).value(); } rewriter.replaceOpWithNewOp(op, resultType, result); From 0f6ef1fc5f9cc91daf8d951d692a67391ac1ba6e Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Tue, 6 Feb 2024 22:35:54 +0530 Subject: [PATCH 14/18] got rid of result shape in all int case --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index dbab3d40d14c..cb2b7fe2b8e5 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4106,9 +4106,18 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value result; if (is_all_inp_int) { - SmallVector values(start_int); - for (int64_t i = start_int; i < end_int; i += step_int) - values.push_back(i); + SmallVector values; + if (step_int >= 0) + { + for (int64_t i = start_int; i < end_int; i += step_int) + values.push_back(i); + } + + else + { + for (int64_t i = start_int; i > end_int; i += step_int) + values.push_back(i); + } result = tosa::getConstTensor(rewriter, op, values, values.size()).value(); } From 8b57a512b597d5e5f01f265e3b93a2b246882cc8 Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Thu, 15 Feb 2024 22:35:25 +0530 Subject: [PATCH 15/18] using static cast instead of dynamic cast --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index cb2b7fe2b8e5..24ad70235607 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4069,12 +4069,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( double start, step, end; int64_t start_int, step_int, end_int; - bool is_all_inp_int; //Flag to check whether all inputs are integer - is_all_inp_int = op.getStart().getType().isa() && op.getEnd().getType().isa() && op.getStep().getType().isa(); + auto isInteger = [=](Value v) { return v.getType().isa(); }; + //Flag to check whether all inputs are integer + bool integer_range = isInteger(op.getStart()) && isInteger(op.getEnd()) && isInteger(op.getStep()); if (matchPattern(op.getStart(), m_TorchConstantInt(&start_int))) { - start = (double)(start_int); + start = static_cast(start_int); } else if(!matchPattern(op.getStart(), m_TorchConstantFloat(&start))) @@ -4083,7 +4084,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (matchPattern(op.getEnd(), m_TorchConstantInt(&end_int))) { - end = (double)(end_int); + end = static_cast(end_int); } else if (!matchPattern(op.getEnd(), m_TorchConstantFloat(&end))) return rewriter.notifyMatchFailure( @@ -4092,7 +4093,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (matchPattern(op.getStep(), m_TorchConstantInt(&step_int))) { - step = (double)(step_int); + step = static_cast(step_int); } else if (!matchPattern(op.getStep(), m_TorchConstantFloat(&step))) @@ -4104,7 +4105,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // ceil((end - start)/step) Value result; - if (is_all_inp_int) + if (integer_range) { SmallVector values; if (step_int >= 0) From 3140ab1a5990507b696c088d597f8c6c06c5fa80 Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT Date: Mon, 19 Feb 2024 20:19:09 +0530 Subject: [PATCH 16/18] typecasting for int64type --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 24ad70235607..5eb22637b12a 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4070,6 +4070,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( double start, step, end; int64_t start_int, step_int, end_int; auto isInteger = [=](Value v) { return v.getType().isa(); }; + bool isOutputInt64=false; + auto intType = resultType.getElementType().dyn_cast_or_null(); + + if(intType) + { + if(intType.getWidth() == 64) + { + isOutputInt64 = true; + } + } //Flag to check whether all inputs are integer bool integer_range = isInteger(op.getStart()) && isInteger(op.getEnd()) && isInteger(op.getStep()); @@ -4123,6 +4133,18 @@ LogicalResult ConvertAtenOp::matchAndRewrite( result = tosa::getConstTensor(rewriter, op, values, values.size()).value(); } + //Since typecasting from float32 or float64 to int64 results in, seemingly + //garbage values. Therefore typecasting here itself. + else if(isOutputInt64) + { + int64_t resultSize = ceil((end - start) / step); + SmallVector values(resultSize, start); + for (unsigned i = 1; i < resultSize; i++) + values[i] += static_cast(i * step); + + result = tosa::getConstTensor(rewriter, op, values, resultSize).value(); + } + else { int64_t resultSize = ceil((end - start) / step); From 4c185db1cfef941d458e178e079112db8447e05e Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 26 Feb 2024 11:49:43 -0800 Subject: [PATCH 17/18] git format, add some stylistic changes --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 181 ++++++++++++--------- 1 file changed, 103 insertions(+), 78 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 5eb22637b12a..06aaf8965acd 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -8,24 +8,22 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" -#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" -#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" -#include "torch-mlir/Conversion/Utils/Utils.h" - #include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Traits.h" #include "mlir/IR/Matchers.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" +#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" +#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include using namespace mlir; using namespace mlir::torch; @@ -4067,93 +4065,120 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "unimplemented: pin_memory must be either None or false"); } - double start, step, end; - int64_t start_int, step_int, end_int; - auto isInteger = [=](Value v) { return v.getType().isa(); }; - bool isOutputInt64=false; - auto intType = resultType.getElementType().dyn_cast_or_null(); - - if(intType) - { - if(intType.getWidth() == 64) - { - isOutputInt64 = true; + // Stores a range value (start / end / step) and whether it was initiated with + // a constant integer, an constant float or neither. + class ConstRangeValue { + public: + explicit ConstRangeValue(double v) + : vDouble(v), fromDouble(true), vInt(static_cast(v)), + fromInt(false) {} + + explicit ConstRangeValue(int64_t v) + : vDouble(static_cast(v)), fromDouble(false), vInt(v), + fromInt(true) {} + + ConstRangeValue() + : vDouble(0), fromDouble(false), vInt(0), fromInt(false) {} + + bool hasConstInt() const { return fromInt; } + bool hasConstDouble() const { return fromDouble; } + bool hasConst() const { return fromInt || fromDouble; } + double getDouble() const { return vDouble; } + int64_t getInt() const { return vInt; } + + private: + double vDouble; + bool fromDouble; + int64_t vInt; + bool fromInt; + }; + + auto setConstantIntOrFloat = [](Value v) -> ConstRangeValue { + int64_t intVal{0}; + double floatVal{0.0}; + if (matchPattern(v, m_TorchConstantFloat(&floatVal))) { + return ConstRangeValue(floatVal); + } else if (matchPattern(v, m_TorchConstantInt(&intVal))) { + return ConstRangeValue(intVal); } - } - //Flag to check whether all inputs are integer - bool integer_range = isInteger(op.getStart()) && isInteger(op.getEnd()) && isInteger(op.getStep()); - - if (matchPattern(op.getStart(), m_TorchConstantInt(&start_int))) - { - start = static_cast(start_int); - } + return ConstRangeValue(); + }; - else if(!matchPattern(op.getStart(), m_TorchConstantFloat(&start))) + auto start = setConstantIntOrFloat(op.getStart()); + if (!start.hasConst()) { return rewriter.notifyMatchFailure( - op, "unimplemented: value `start` should be a torch constant int or float"); - - if (matchPattern(op.getEnd(), m_TorchConstantInt(&end_int))) - { - end = static_cast(end_int); + op, "unimplemented: case where `start` is not a constant int or float"); } - else if (!matchPattern(op.getEnd(), m_TorchConstantFloat(&end))) + + auto end = setConstantIntOrFloat(op.getEnd()); + if (!end.hasConst()) { return rewriter.notifyMatchFailure( - op, "unimplemented: value `end` should be a torch constant int or float"); + op, + "unimplemented: case where value `end` is not a constant int or float"); + } - if (matchPattern(op.getStep(), m_TorchConstantInt(&step_int))) - { - - step = static_cast(step_int); + auto step = setConstantIntOrFloat(op.getStep()); + if (!step.hasConst()) { + return rewriter.notifyMatchFailure(op, + "unimplemented: case where value `step` " + "is not a constant int or float"); } - else if (!matchPattern(op.getStep(), m_TorchConstantFloat(&step))) - return rewriter.notifyMatchFailure( - op, "unimplemented: value `step` should be a torch constant int or float"); + auto getRange = [](auto start, auto end, auto step) { + // Initialize a small vector of the same type as start: + using T = decltype(start); + SmallVector values; - // The result will always be a 1-d tensor. - // The size of the result is calculated as follows: - // ceil((end - start)/step) - - Value result; - if (integer_range) - { - SmallVector values; - if (step_int >= 0) - { - for (int64_t i = start_int; i < end_int; i += step_int) - values.push_back(i); + uint64_t counter{0}; + if (start == end) { + return values; } - - else - { - for (int64_t i = start_int; i > end_int; i += step_int) - values.push_back(i); + assert(step != T(0)); + values.reserve( + 1 + static_cast(std::abs((end - start) / std::abs(step)))); + if (step > 0) { + while (start + T(counter) * step < end) { + values.push_back(start + counter * step); + counter++; + } + } else { + while (start + T(counter) * step > end) { + values.push_back(start + counter * step); + counter++; + } } + return values; + }; - result = tosa::getConstTensor(rewriter, op, values, values.size()).value(); - } + const auto intType = + resultType.getElementType().dyn_cast_or_null(); - //Since typecasting from float32 or float64 to int64 results in, seemingly - //garbage values. Therefore typecasting here itself. - else if(isOutputInt64) - { - int64_t resultSize = ceil((end - start) / step); - SmallVector values(resultSize, start); - for (unsigned i = 1; i < resultSize; i++) - values[i] += static_cast(i * step); - - result = tosa::getConstTensor(rewriter, op, values, resultSize).value(); - } + auto maybeResult = [&]() -> std::optional { + if (intType && start.hasConstInt() && end.hasConstInt() && + step.hasConstInt()) { + auto values = getRange(start.getInt(), end.getInt(), step.getInt()); + return tosa::getConstTensor(rewriter, op, values, values.size()); + } - else - { - int64_t resultSize = ceil((end - start) / step); - SmallVector values(resultSize, start); - for (unsigned i = 1; i < resultSize; i++) - values[i] += (i * step); - - result = tosa::getConstTensor(rewriter, op, values, resultSize).value(); + auto values = + getRange(start.getDouble(), end.getDouble(), step.getDouble()); + if (intType) { + SmallVector values_i64; + values_i64.reserve(values.size()); + for (auto v : values) { + values_i64.push_back(static_cast(v)); + } + return tosa::getConstTensor(rewriter, op, values_i64, + values.size()); + } + return tosa::getConstTensor(rewriter, op, values, values.size()); + }(); + + if (!maybeResult.has_value()) { + return rewriter.notifyMatchFailure( + op, "failed to generate constant tensor for arange"); } + auto result = maybeResult.value(); rewriter.replaceOpWithNewOp(op, resultType, result); return success(); From 9b4ae1e06805ce2453c10e4c9b517bea5a9a1476 Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 26 Feb 2024 12:21:58 -0800 Subject: [PATCH 18/18] update --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 59 ++++++++++++++-------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 06aaf8965acd..ce0a1af2f834 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -22,6 +22,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" #include @@ -4065,8 +4066,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "unimplemented: pin_memory must be either None or false"); } - // Stores a range value (start / end / step) and whether it was initiated with - // a constant integer, an constant float or neither. + // Stores a range value (a start, end, or step value) and whether or not it + // was initiated with a constant integer, an constant float or neither. class ConstRangeValue { public: explicit ConstRangeValue(double v) @@ -4077,9 +4078,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( : vDouble(static_cast(v)), fromDouble(false), vInt(v), fromInt(true) {} + // Constructor for the case where there is no constant value to use. ConstRangeValue() : vDouble(0), fromDouble(false), vInt(0), fromInt(false) {} + static ConstRangeValue fromValue(Value v) { + int64_t intVal{0}; + double floatVal{0.0}; + if (matchPattern(v, m_TorchConstantFloat(&floatVal))) { + return ConstRangeValue(floatVal); + } else if (matchPattern(v, m_TorchConstantInt(&intVal))) { + return ConstRangeValue(intVal); + } + return ConstRangeValue(); + } + bool hasConstInt() const { return fromInt; } bool hasConstDouble() const { return fromDouble; } bool hasConst() const { return fromInt || fromDouble; } @@ -4093,31 +4106,20 @@ LogicalResult ConvertAtenOp::matchAndRewrite( bool fromInt; }; - auto setConstantIntOrFloat = [](Value v) -> ConstRangeValue { - int64_t intVal{0}; - double floatVal{0.0}; - if (matchPattern(v, m_TorchConstantFloat(&floatVal))) { - return ConstRangeValue(floatVal); - } else if (matchPattern(v, m_TorchConstantInt(&intVal))) { - return ConstRangeValue(intVal); - } - return ConstRangeValue(); - }; - - auto start = setConstantIntOrFloat(op.getStart()); + auto start = ConstRangeValue::fromValue(op.getStart()); if (!start.hasConst()) { return rewriter.notifyMatchFailure( op, "unimplemented: case where `start` is not a constant int or float"); } - auto end = setConstantIntOrFloat(op.getEnd()); + auto end = ConstRangeValue::fromValue(op.getEnd()); if (!end.hasConst()) { return rewriter.notifyMatchFailure( op, "unimplemented: case where value `end` is not a constant int or float"); } - auto step = setConstantIntOrFloat(op.getStep()); + auto step = ConstRangeValue::fromValue(op.getStep()); if (!step.hasConst()) { return rewriter.notifyMatchFailure(op, "unimplemented: case where value `step` " @@ -4150,19 +4152,24 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return values; }; - const auto intType = + const auto isIntType = resultType.getElementType().dyn_cast_or_null(); + const auto isDoubleType = + resultType.getElementType().dyn_cast_or_null(); + auto maybeResult = [&]() -> std::optional { - if (intType && start.hasConstInt() && end.hasConstInt() && + // Integer output type, and start / end / range are all integers. + if (isIntType && start.hasConstInt() && end.hasConstInt() && step.hasConstInt()) { auto values = getRange(start.getInt(), end.getInt(), step.getInt()); return tosa::getConstTensor(rewriter, op, values, values.size()); } + // Get a double range. auto values = getRange(start.getDouble(), end.getDouble(), step.getDouble()); - if (intType) { + if (isIntType) { SmallVector values_i64; values_i64.reserve(values.size()); for (auto v : values) { @@ -4171,7 +4178,19 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return tosa::getConstTensor(rewriter, op, values_i64, values.size()); } - return tosa::getConstTensor(rewriter, op, values, values.size()); + + if (!isDoubleType) { + return {}; + } + + SmallVector values_f32; + values_f32.reserve(values.size()); + for (auto v : values) { + values_f32.push_back(static_cast(v)); + } + auto vs = tosa::getConstTensor(rewriter, op, values_f32, + values_f32.size()); + return vs; }(); if (!maybeResult.has_value()) {