diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 91dcaea73378..c033dad1bbb4 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -5549,26 +5549,6 @@ static LogicalResult getOutputTypeAndPoolingParameters( std::is_same()) paddingInts.push_back(0); - if constexpr (std::is_same() || - std::is_same()) { - // Currently, we can not represent `count_include_pad` with the existing - // TOSA AvgPool2d specification. Without the below check, we produce silent - // wrong answer (SWA) when the `count_include_pad` value is `true.` - // - // Note: We need to check for `count_include_pad` only when the `padding` - // value is non-zero. - bool countIncludePad; - if ((paddingInts[0] != 0 || paddingInts[1] != 0) && - (!matchPattern(op.getCountIncludePad(), - m_TorchConstantBool(&countIncludePad)) || - - countIncludePad)) { - return rewriter.notifyMatchFailure( - op, "Unsupported `count_include_pad` value, for tosa AvgPool " - "`count_include_pad` value should be `False`."); - } - } - SmallVector padArr = {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]}; kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts); @@ -5697,6 +5677,18 @@ class ConvertAtenAvgPool2dOp DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad, Type &outputTy) const override { + // Currently, we can not represent `count_include_pad` with the existing + // TOSA AvgPool2d specification. Without the below check, we produce silent + // wrong answers (SWA) when the `count_include_pad` value is `true.` + bool countIncludePad; + if (!matchPattern(op.getCountIncludePad(), + m_TorchConstantBool(&countIncludePad)) || + countIncludePad) { + return rewriter.notifyMatchFailure( + op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp " + "`count_include_pad` value should be `False`."); + } + // Currently, we can not represent `divisor_override` with the existing TOSA // AvgPool2d specification. Without the below check, we produce silent wrong // answers (SWA) when the `divisor_override` value is other than `None.` @@ -5745,7 +5737,7 @@ class ConvertAtenAvgPool1dOp // Expected a rank 3 input tensor if (selfTy.getRank() != 3) return rewriter.notifyMatchFailure( - op, "Input tensor for AvgPool1d should have rank 3"); + op, "Input tensor for MaxPool1d should have rank 3"); // Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp SmallVector rank4Shape(selfShape); @@ -5756,6 +5748,18 @@ class ConvertAtenAvgPool1dOp selfTy.getElementType()), self, rewriter.getDenseI64ArrayAttr(rank4Shape)); + // Currently, we can not represent `count_include_pad` with the existing + // TOSA AvgPool2d specification. Without the below check, we produce silent + // wrong answers (SWA) when the `count_include_pad` value is `true.` + bool countIncludePad; + if (!matchPattern(op.getCountIncludePad(), + m_TorchConstantBool(&countIncludePad)) || + countIncludePad) { + return rewriter.notifyMatchFailure( + op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp " + "`count_include_pad` value should be `False`."); + } + SmallVector dilationArray{1, 1}; if (failed(getOutputTypeAndPoolingParameters( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index b5d02034c1b2..8c38d0112f6c 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1736,12 +1736,6 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", - "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "ElementwiseAtenLogicalNotOpPromoteModule_basic", "ElementwiseCosIntModule_basic", "ElementwiseReciprocalIntModule_basic", @@ -2322,7 +2316,6 @@ "ReshapeExpandModule_basic", "ReturnThreeTensorFloat32_basic", "ReturnTwoTensorF32I64_basic", - "ResNet18StaticModule_basic", "RsubFloatModule_basic", "RsubFloatModule_noalpha_basic", "RsubInt0d_NumToTensor_Module_basic", @@ -3876,11 +3869,26 @@ "ViewSizeFromOtherTensor_basic", "VisionTransformerModule_basic", "ZerosLikeModule_falsePinMemory", + # count_include_pad and divisor_override check in TOSA AvgPool2d + "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic", + "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", + "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", + "ResNet18Module_basic", + "ResNet18StaticModule_basic", + "MobilenetV3Module_basic", # Unexpected failures due to new PyTorch version update "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AdaptiveAvgPool1dGeneralDynamic_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dStaticEvenMultiple_basic", "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dDynamicNoBatch_basic", "AdaptiveAvgPool2dDynamic_basic", "CrossEntropyLossModule_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 23b5f6b06f1d..548c0b4baf06 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -2424,18 +2424,3 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to %0 = torch.prims.collapse %arg0, %int1, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,12],f32> return %0 : !torch.vtensor<[2,12],f32> } - -// ----- - -func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> { - %int1 = torch.constant.int 1 - %int3 = torch.constant.int 3 - %false = torch.constant.bool false - %count_include_pad = torch.constant.bool true - %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list - %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list - %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list - // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}} - %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32> - return %3 : !torch.vtensor<[1,512,10],f32> -}