From b6f04fa32bb536a2cae657e233ace368b596b191 Mon Sep 17 00:00:00 2001 From: Justin Ngo Date: Thu, 7 Nov 2024 14:09:43 -0800 Subject: [PATCH] [TOSA] Fix rsub; add clamp.Tensor, avg_pool1d, max_pool1d, prims.collapse (#3855) - Fix aten.rsub.Scalar legalization with appropriate type casting - Add legalization for aten.clamp.Tensor - Resolve some unexpected test failures from PyTorch update by adding legalization for the following ops: + aten.avg_pool1d + aten.max_pool1d + torch.prims.collapse - Update xfail_sets with new e2e results - Add new LIT tests to basic.mlir Change-Id: I9762c7d36ca0b0f75ca68d0c71d7f5d5309a96ad --------- Signed-off-by: Justin Ngo --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 344 ++++++++++++++++++++- projects/pt1/e2e_testing/xfail_sets.py | 61 ++-- test/Conversion/TorchToTosa/basic.mlir | 131 +++++++- 3 files changed, 481 insertions(+), 55 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 10f6ecb357fe..df5ed5fa88c1 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -2072,26 +2072,30 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA Rsub"); + auto resultTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto resultElemTy = resultTy.getElementType(); + + self = tosa::promoteType(rewriter, self, resultTy); + Value otherTensor, alphaTensor; if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor, - selfTy.getElementType(), {}))) + resultElemTy, {}))) return rewriter.notifyMatchFailure( op, "Currently only scalar constants are supported for " "conversion in TOSA Rsub operation"); if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), alphaScalar, - alphaTensor, selfTy.getElementType(), + alphaTensor, resultElemTy, /*checkForUnity=*/true))) return failure(); - auto multTensor = rewriter.create( - op->getLoc(), getTypeConverter()->convertType(op.getType()), self, - alphaTensor, /*shift=*/0); + auto multTensor = rewriter.create(op->getLoc(), resultTy, self, + alphaTensor, /*shift=*/0); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), otherTensor, - multTensor); + rewriter.replaceOpWithNewOp(op, resultTy, otherTensor, + multTensor); return success(); } @@ -4730,6 +4734,108 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// Legalization for aten.clamp.Tensor +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenClampTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // We are not using tosa.clamp to lower aten.clamp.Tensor, as + // aten.clamp.Tensor's min and max attributes are tensors that can have size + // greater than 1, which is not compatible with tosa.clamp. + // + // Instead, we use the following formula: + // yi = min(max(xi, min_valuei), max_valuei) + auto self = adaptor.getSelf(); + + // Not a tensor type + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + auto selfElemTy = selfType.getElementType(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + + // Get min tensor. If None, there is no lower bound. + Value min; + if (succeeded(checkNotNone(rewriter, op, adaptor.getMin()))) { + min = adaptor.getMin(); + } else { + min = + TypeSwitch(selfElemTy) + .Case([&](auto) { + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::lowest(), {}, + selfElemTy) + .value(); + }) + .Case([&](auto intType) { + switch (intType.getWidth()) { + case 8: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::min(), {}) + .value(); + case 32: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::min(), + {}) + .value(); + case 64: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::min(), + {}) + .value(); + } + llvm_unreachable("Invalid integer width"); + }); + } + + // Get max tensor. If None, there is no upper bound. + Value max; + if (succeeded(checkNotNone(rewriter, op, adaptor.getMax()))) { + max = adaptor.getMax(); + } else { + max = + TypeSwitch(selfElemTy) + .Case([&](auto) { + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::max(), {}, + selfElemTy) + .value(); + }) + .Case([&](auto intType) { + switch (intType.getWidth()) { + case 8: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::max(), {}) + .value(); + case 32: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::max(), + {}) + .value(); + case 64: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::max(), + {}) + .value(); + } + llvm_unreachable("Invalid integer width"); + }); + } + + // max(xi, min_valuei) + auto minThresholdCheck = tosa::createBinaryOpAndCast( + rewriter, op, resultType, self, min); + + // yi = min(max(xi, min_valuei), max_valuei) + auto result = tosa::createBinaryOpAndCast( + rewriter, op, resultType, minThresholdCheck, max); + + rewriter.replaceOp(op, result); + return success(); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenArangeStartStepOp op, OpAdaptor adaptor, @@ -5236,11 +5342,29 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { ConvertAtenPoolingBaseOp::transposePoolingOutputToChw( op, rewriter, pooledOutput); - rewriter.replaceOpWithNewOp( - op, + Value result = transposedOutput; + auto resultTy = dyn_cast( OpConversionPattern::getTypeConverter()->convertType( - op.getType()), - transposedOutput); + op.getType())); + + if constexpr (std::is_same() || + std::is_same()) { + auto resultShape = resultTy.getShape(); + auto resultElemTy = resultTy.getElementType(); + + result = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(resultShape), + resultElemTy), + transposedOutput, + rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape))); + } + + rewriter.replaceOpWithNewOp( + op, resultTy, + // OpConversionPattern::getTypeConverter()->convertType( + // op.getType()), + result); return success(); } @@ -5387,6 +5511,12 @@ static LogicalResult getOutputTypeAndPoolingParameters( return rewriter.notifyMatchFailure( op, "Non-const kernel_size for pooling op unsupported"); + // Expand kernel size parameter to size 2 to be compatible with + // tosa::MaxPool2dOp or tosa::AvgPool2dOp + if constexpr (std::is_same() || + std::is_same()) + kernelSizeInts.push_back(1); + if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) return rewriter.notifyMatchFailure( op, "Non-const stride for pooling op unsupported"); @@ -5394,13 +5524,26 @@ static LogicalResult getOutputTypeAndPoolingParameters( // list during import. For such a case, the stride value is the kernel size. // See: // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d - if (strideInts.empty()) + if (strideInts.empty()) { strideInts.assign(kernelSizeInts); + } else { + // Expand stride parameter to size 2 to be compatible with + // tosa::MaxPool2dOp or tosa::AvgPool2dOp + if constexpr (std::is_same() || + std::is_same()) + strideInts.push_back(1); + } if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts))) return rewriter.notifyMatchFailure( op, "Non-const padding factor for pooling op unsupported"); + // Expand padding parameter to size 2 to be compatible with + // tosa::MaxPool2dOp or tosa::AvgPool2dOp + if constexpr (std::is_same() || + std::is_same()) + paddingInts.push_back(0); + SmallVector padArr = {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]}; kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts); @@ -5456,6 +5599,68 @@ class ConvertAtenMaxPool2dOp } }; +// Legalization for aten.max_pool1d +class ConvertAtenMaxPool1dOp + : public ConvertAtenPoolingBaseOp { +public: + using ConvertAtenPoolingBaseOp::ConvertAtenPoolingBaseOp; + LogicalResult processInputs(AtenMaxPool1dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Value &input, + DenseI64ArrayAttr &kernel, + DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad, + Type &outputTy) const override { + auto self = adaptor.getSelf(); + + // Not a RankedTensorType + auto selfTy = dyn_cast(self.getType()); + if (!selfTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor type inputs are supported"); + auto selfShape = selfTy.getShape(); + + // Expected a rank 3 input tensor + if (selfTy.getRank() != 3) + return rewriter.notifyMatchFailure( + op, "Input tensor for MaxPool1d should have rank 3"); + + // Unsqueeze input tensor to rank 4 to be compatible with tosa::MaxPool2dOp + SmallVector rank4Shape(selfShape); + rank4Shape.push_back(1); + auto reshapedSelf = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(rank4Shape), + selfTy.getElementType()), + self, rewriter.getDenseI64ArrayAttr(rank4Shape)); + + SmallVector dilationArray; + if (!matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilationArray))) + return rewriter.notifyMatchFailure( + op, "Non-const dilation for pooling op unsupported."); + // TOSA pooling only supports unit dilation. + if (dilationArray[0] > 1) + return rewriter.notifyMatchFailure( + op, "Cannot process non-unit pooling dilation."); + + // Expand dilation to size 2 to be compatible with tosa::MaxPool2dOp + dilationArray.push_back(1); + + if (failed(getOutputTypeAndPoolingParameters( + op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy, + kernel, stride, pad))) + return rewriter.notifyMatchFailure( + op, "invalid pooling parameters or input type"); + + // Transpose to xHWC + input = ConvertAtenPoolingBaseOp:: + transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult()); + + return success(); + } +}; + class ConvertAtenAvgPool2dOp : public ConvertAtenPoolingBaseOp { public: @@ -5504,6 +5709,68 @@ class ConvertAtenAvgPool2dOp } }; +// Legalization for aten.avg_pool1d +class ConvertAtenAvgPool1dOp + : public ConvertAtenPoolingBaseOp { +public: + using ConvertAtenPoolingBaseOp::ConvertAtenPoolingBaseOp; + LogicalResult processInputs(AtenAvgPool1dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Value &input, + DenseI64ArrayAttr &kernel, + DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad, + Type &outputTy) const override { + auto self = adaptor.getSelf(); + + // Not a RankedTensorType + auto selfTy = dyn_cast(self.getType()); + if (!selfTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor type inputs are supported"); + auto selfShape = selfTy.getShape(); + + // Expected a rank 3 input tensor + if (selfTy.getRank() != 3) + return rewriter.notifyMatchFailure( + op, "Input tensor for MaxPool1d should have rank 3"); + + // Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp + SmallVector rank4Shape(selfShape); + rank4Shape.push_back(1); + auto reshapedSelf = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(rank4Shape), + selfTy.getElementType()), + self, rewriter.getDenseI64ArrayAttr(rank4Shape)); + + // Currently, we can not represent `count_include_pad` with the existing + // TOSA AvgPool2d specification. Without the below check, we produce silent + // wrong answers (SWA) when the `count_include_pad` value is `true.` + bool countIncludePad; + if (!matchPattern(op.getCountIncludePad(), + m_TorchConstantBool(&countIncludePad)) || + countIncludePad) { + return rewriter.notifyMatchFailure( + op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp " + "`count_include_pad` value should be `False`."); + } + + SmallVector dilationArray{1, 1}; + if (failed(getOutputTypeAndPoolingParameters( + op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy, + kernel, stride, pad))) + return rewriter.notifyMatchFailure( + op, "invalid pooling parameters or input type"); + + // Transpose to xHWC + input = ConvertAtenPoolingBaseOp:: + transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult()); + + return success(); + } +}; + // Ref: Error checking based on the Torch to LinAlg lowering template class ConvertAtenConstPatternOp : public OpConversionPattern { @@ -6880,6 +7147,49 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// Legalization for torch.prims.collapse +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + PrimsCollapseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getA(); + + // Not a tensor type + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultShape = resultType.getShape(); + + int64_t start, end; + if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) + return rewriter.notifyMatchFailure( + op, "Only constant int start value is supported"); + + if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) + return rewriter.notifyMatchFailure( + op, "Only constant int end value is supported"); + + // Identity case + if (start == end) { + rewriter.replaceOp(op, self); + return success(); + } + + // Technically, I should calculate the output shape based on the input shape, + // start value, and end value. However, that would just give the same result + // as me taking the result shape straight from resultType and applying + // tosa::ReshapeOp to the input. Therefore, I'm opting for the latter approach + // here, which is more simple and quicker. + rewriter.replaceOpWithNewOp( + op, resultType, self, + rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape))); + + return success(); +} + } // namespace // ----------------------------------------------------------------------------- @@ -7101,9 +7411,15 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); + target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); + #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ @@ -7199,6 +7515,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenUniformOp); INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp); INSERT_ATENOP_PATTERN(AtenAsStridedOp); + INSERT_ATENOP_PATTERN(AtenClampTensorOp); + INSERT_ATENOP_PATTERN(PrimsCollapseOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index ce6700127867..8d7aa88ad425 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1744,6 +1744,23 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "AdaptiveMaxPool1dDimOneStatic_basic", + "CollapseAllDimensionsModule_basic", + "CollapseRank1DynamicModule_basic", + "CollapseStaticModule_basic", + "ElementwiseClampMinTensorFloatModule_basic", + "ElementwiseClampMinTensorIntModule_basic", + "ElementwiseClampTensorFloatModule_basic", + "ElementwiseClampTensorIntModule_basic", + "ElementwiseFracModule_basic", + "ElementwiseLdexpModule_basic", + "ElementwiseSignbitIntModule_basic", + "Exp2StaticIntModule_basic", + "MaxPool1dEmptyStrideStaticModule_basic", + "MaxPool1dStaticCeilModeTrueModule_basic", + "MaxPool1dStaticModule_basic", + "RepeatInterleaveSelfIntModule_basic", + "RsubIntModule_noalpha_basic", "Aten_TrilinearModuleSumAllDims_basic", "Aten_TrilinearModuleSumdims_basic", "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", @@ -3373,9 +3390,10 @@ } FX_IMPORTER_TOSA_XFAIL_SET = { + "ElementwiseCopysignModule_basic", + "ElementwiseSignbitModule_basic", "Aten_TrilinearModuleVaryingRanks_basic", "Aten_TrilinearModuleZerodDimBug_basic", - "AdaptiveMaxPool1dDimOneStatic_basic", "ElementwiseRreluWithNoiseTrainModule_basic", "ElementwiseRreluWithNoiseTrainStaticModule_basic", "MaxPool3dEmptyStrideStaticModule_basic", @@ -3519,11 +3537,6 @@ "BoolIntTrueModule_basic", "BroadcastDynamicDimModule_basic", "CeilFloatModule_basic", - "CollapseAllDimensionsModule_basic", - "CollapseFullDynamicModule_basic", - "CollapsePartialDynamicModule_basic", - "CollapseRank1DynamicModule_basic", - "CollapseStaticModule_basic", "ConstantBoolParameterModule_basic", "ContainsIntList_False", "ContainsIntList_True", @@ -3585,10 +3598,6 @@ "ElementwiseAtanhIntModule_basic", "ElementwiseAtanhModule_basic", "ElementwiseAtenLogicalNotOpPromoteModule_basic", - "ElementwiseClampMinTensorFloatModule_basic", - "ElementwiseClampMinTensorIntModule_basic", - "ElementwiseClampTensorFloatModule_basic", - "ElementwiseClampTensorIntModule_basic", "ElementwiseCosIntModule_basic", "ElementwiseCoshIntModule_basic", "ElementwiseCoshModule_basic", @@ -3784,7 +3793,6 @@ "ReplicationPad2dModule_right0", "ReplicationPad2dModule_top0", "RollModule_basic", - "RsubIntModule_noalpha_basic", "ScalarConstantTupleModule_basic", "ScalarImplicitFloatModule_basic", "ScalarImplicitIntModule_basic", @@ -3897,16 +3905,12 @@ "IndexPutImpl2DFloatNonAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic", "IouOfModule_basic", - "MaxPool1dEmptyStrideStaticModule_basic", - "MaxPool1dStaticCeilModeTrueModule_basic", - "MaxPool1dStaticModule_basic", "MeshgridIndexingIJ_basic", "MeshgridIndexingXY_basic", "Meshgrid_basic", "OneHotModule_basic", "ReduceFrobeniusNormKeepDimModule_basic", "ReduceFrobeniusNormModule_basic", - "RepeatInterleaveSelfIntModule_basic", "ScaledDotProductAttentionBoolMaskModule_basic", "ScaledDotProductAttentionDifferentCausalModule_basic", "ScaledDotProductAttentionDifferentDynamicCausalModule_basic", @@ -3927,6 +3931,16 @@ } ONNX_TOSA_XFAIL_SET = { + "ElementwiseCopysignModule_basic", + "ElementwiseFracModule_basic", + "ElementwiseLdexpModule_basic", + "ElementwiseSignbitIntModule_basic", + "ElementwiseSignbitModule_basic", + "Exp2StaticIntModule_basic", + "NllLossStaticModule_basic", + "NllLossStaticModule_mean_basic", + "NllLossStaticModule_sum_basic", + "NllLossStaticModule_weight_basic", "Exp2StaticModule_basic", "ElementwiseRreluWithNoiseEvalModule_basic", "ElementwiseRreluWithNoiseEvalStaticModule_basic", @@ -3950,7 +3964,6 @@ "TriuIndicesAllZerosModule_basic", "ElementwiseCreateComplexModule_basic", "ReduceAllDimFloatModule_basic", - "AdaptiveMaxPool1dDimOneStatic_basic", "ScaledDotProductAttentionDifferentCausalModule_basic", "HstackBasicComplexModule_basic", "HstackBasicFloatModule_basic", @@ -4029,7 +4042,6 @@ "AdaptiveAvgPool1dStaticEvenMultiple_basic", "AdaptiveAvgPool1dStaticLargerOutput_basic", "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dDynamicNoBatch_basic", "AdaptiveAvgPool2dDynamic_basic", "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", @@ -4285,10 +4297,6 @@ "ElementwiseBitwiseRightShiftInt8Module_basic", "ElementwiseBitwiseXorModule_basic", "ElementwiseBitwiseXorStaticShapeModule_basic", - "ElementwiseClampMaxModule_basic", - "ElementwiseClampMinModule_basic", - "ElementwiseClampModule_basic", - "ElementwiseClampTensorInt8Module_basic", "ElementwiseCosIntModule_basic", "ElementwiseCoshIntModule_basic", "ElementwiseCoshModule_basic", @@ -4335,7 +4343,6 @@ "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseReciprocalIntModule_basic", - "ElementwiseRelu6Module_basic", "ElementwiseRemainderScalarModule_Bool_basic", "ElementwiseRemainderScalarModule_Int_Float_basic", "ElementwiseRemainderTensorModule_Int_Float_basic", @@ -4414,8 +4421,6 @@ "GtFloatIntModule_basic", "GtIntModule_basic", "HBC_basic", - "HardTanhIntModule_basic", - "HardTanhModule_basic", "HardtanhBackward_basic", "IndexPut1DFloatAccumulateModule_basic", "IndexPut1DFloatNonAccumulateModule_basic", @@ -4463,7 +4468,6 @@ "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic", "IndexTensorModule3dInput_basic", "IndexTensorModule_basic", - "IndexTensorMultiIndexStaticModule_basic", "IndexTensorMultiInputContiguousCenter_basic", "IndexTensorMultiInputContiguousOneDimDynamic_basic", "IndexTensorMultiInputNonContiguousDynamic_basic", @@ -4474,8 +4478,6 @@ "IndexTensorMultiInputThreeIndexers_basic", "IndexTensorMultiInput_basic", "IndexTensorSelectDimModule_basic", - "IndexTensorStaticContiguousWithNoneModule_basic", - "IndexTensorStaticNonContiguousWithNoneModule_basic", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", "InterpolateDynamicModule_scales_recompute_bilinear", @@ -4503,10 +4505,7 @@ "Matmul_matvec", "Matmul_vecmat", "MaxPool1dCeilModeTrueModule_basic", - "MaxPool1dEmptyStrideStaticModule_basic", "MaxPool1dModule_basic", - "MaxPool1dStaticCeilModeTrueModule_basic", - "MaxPool1dStaticModule_basic", "MaxPool2dCeilModeTrueModule_basic", "MaxPool2dModule_basic", "MaxPool2dWithIndicesAllNegativeValuesModule_basic", @@ -4607,7 +4606,6 @@ "NormScalarOptDimKeepDimModule_basic", "NormScalarOptDimModule_basic", "NormalFunctionalModule_basic", - "NormalizeModule_basic", "NumToTensorFloatModule_basic", "NumToTensorIntModule_basic", "NumelModule_basic", @@ -4730,7 +4728,6 @@ "ReplicationPad2dModule_right0", "ReplicationPad2dModule_top0", "ResNet18Module_basic", - "ResNet18StaticModule_basic", "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", "ReshapeCollapseModule_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index ed679e852e53..548c0b4baf06 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -2258,16 +2258,8 @@ func.func @torch.aten.logical_and$basic(%arg0: !torch.vtensor<[4,5],i1>, %arg1: // ----- -// CHECK-LABEL: func.func @torch.aten.uniform$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) { -// CHECK: %[[VAL_1:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.000000e+01 -// CHECK: %[[VAL_3:.*]] = torch.constant.none -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1.00007045, 2.18384027, 7.80044794, 5.12785149], [5.79490519, 2.97063255, 1.42340159, 7.10978221], [7.11366796, 9.41223621, 4.45151854, 5.67474747]]> : tensor<3x4xf32>}> : () -> tensor<3x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf64> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf64> -> !torch.vtensor<[3,4],f64> -// CHECK: return %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64> -// CHECK: } +// CHECK-LABEL: torch.aten.uniform$basic +// CHECK: tosa.const func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) { %float1.000000e00 = torch.constant.float 1.000000e+00 %float1.000000e01 = torch.constant.float 1.000000e+01 @@ -2313,3 +2305,122 @@ func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !tor %2 = torch.aten.as_strided %arg0, %0, %1, %none : !torch.vtensor<[5,5],f32>, !torch.list, !torch.list, !torch.none -> !torch.vtensor<[3,3],f32> return %2 : !torch.vtensor<[3,3],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.max_pool1d$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,64,112],f32>) -> !torch.vtensor<[1,64,56],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,64,112],f32> -> tensor<1x64x112xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.bool false +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_5:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x64x112xf32>) -> tensor<1x64x112x1xf32> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_10]], %[[VAL_11]] : (tensor<1x64x112x1xf32>, tensor<4xi32>) -> tensor<1x112x1x64xf32> +// CHECK: %[[VAL_13:.*]] = tosa.max_pool2d %[[VAL_12]] {kernel = array, pad = array, stride = array} : (tensor<1x112x1x64xf32>) -> tensor<1x56x1x64xf32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<1x56x1x64xf32>, tensor<4xi32>) -> tensor<1x64x56x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<1x64x56x1xf32>) -> tensor<1x64x56xf32> +// CHECK: %[[VAL_17:.*]] = tensor.cast %[[VAL_16]] : tensor<1x64x56xf32> to tensor<1x64x56xf32> +// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<1x64x56xf32> -> !torch.vtensor<[1,64,56],f32> +// CHECK: return %[[VAL_18]] : !torch.vtensor<[1,64,56],f32> +// CHECK: } +func.func @torch.aten.max_pool1d$basic(%arg0: !torch.vtensor<[1,64,112],f32>) -> !torch.vtensor<[1,64,56],f32> { + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %4 = torch.aten.max_pool1d %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[1,64,112],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,64,56],f32> + return %4 : !torch.vtensor<[1,64,56],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.avg_pool1d$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_4:.*]] = torch.constant.bool false +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x512x10xf32>) -> tensor<1x512x10x1xf32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_10:.*]] = tosa.transpose %[[VAL_8]], %[[VAL_9]] : (tensor<1x512x10x1xf32>, tensor<4xi32>) -> tensor<1x10x1x512xf32> +// CHECK: %[[VAL_11:.*]] = tosa.avg_pool2d %[[VAL_10]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x10x1x512xf32>) -> tensor<1x10x1x512xf32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_11]], %[[VAL_12]] : (tensor<1x10x1x512xf32>, tensor<4xi32>) -> tensor<1x512x10x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor<1x512x10x1xf32>) -> tensor<1x512x10xf32> +// CHECK: %[[VAL_15:.*]] = tensor.cast %[[VAL_14]] : tensor<1x512x10xf32> to tensor<1x512x10xf32> +// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32> +// CHECK: return %[[VAL_16]] : !torch.vtensor<[1,512,10],f32> +// CHECK: } +func.func @torch.aten.avg_pool1d$basic(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> { + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list + %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %false : !torch.vtensor<[1,512,10],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32> + return %3 : !torch.vtensor<[1,512,10],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.clamp.Tensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],f32>) -> (!torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>) { +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[1],f32> -> tensor<1xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],f32> -> tensor<1xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32> +// CHECK: %[[VAL_6:.*]] = torch.constant.none +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<3.40282347E+38> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_4]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_9:.*]] = tosa.minimum %[[VAL_8]], %[[VAL_7]] : (tensor<3x5xf32>, tensor) -> tensor<3x5xf32> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<-3.40282347E+38> : tensor}> : () -> tensor +// CHECK: %[[VAL_12:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_11]] : (tensor<3x5xf32>, tensor) -> tensor<3x5xf32> +// CHECK: %[[VAL_13:.*]] = tosa.minimum %[[VAL_12]], %[[VAL_3]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: %[[VAL_15:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_4]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_16:.*]] = tosa.minimum %[[VAL_15]], %[[VAL_3]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: return %[[VAL_10]], %[[VAL_14]], %[[VAL_17]] : !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32> +// CHECK: } +func.func @torch.aten.clamp.Tensor$basic(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> (!torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>) { + %none = torch.constant.none + %0 = torch.aten.clamp.Tensor %arg0, %arg1, %none : !torch.vtensor<[3,5],f32>, !torch.vtensor<[1],f32>, !torch.none -> !torch.vtensor<[3,5],f32> + %1 = torch.aten.clamp.Tensor %arg0, %none, %arg2 : !torch.vtensor<[3,5],f32>, !torch.none, !torch.vtensor<[1],f32> -> !torch.vtensor<[3,5],f32> + %2 = torch.aten.clamp.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[3,5],f32> + return %0, %1, %2 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.prims.collapse$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,12],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<2x3x4xf32>) -> tensor<2x12xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<2x12xf32> -> !torch.vtensor<[2,12],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[2,12],f32> +// CHECK: } +func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,12],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prims.collapse %arg0, %int1, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,12],f32> + return %0 : !torch.vtensor<[2,12],f32> +}