diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index c033dad1bbb4..91dcaea73378 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -5549,6 +5549,26 @@ static LogicalResult getOutputTypeAndPoolingParameters( std::is_same()) paddingInts.push_back(0); + if constexpr (std::is_same() || + std::is_same()) { + // Currently, we can not represent `count_include_pad` with the existing + // TOSA AvgPool2d specification. Without the below check, we produce silent + // wrong answer (SWA) when the `count_include_pad` value is `true.` + // + // Note: We need to check for `count_include_pad` only when the `padding` + // value is non-zero. + bool countIncludePad; + if ((paddingInts[0] != 0 || paddingInts[1] != 0) && + (!matchPattern(op.getCountIncludePad(), + m_TorchConstantBool(&countIncludePad)) || + + countIncludePad)) { + return rewriter.notifyMatchFailure( + op, "Unsupported `count_include_pad` value, for tosa AvgPool " + "`count_include_pad` value should be `False`."); + } + } + SmallVector padArr = {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]}; kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts); @@ -5677,18 +5697,6 @@ class ConvertAtenAvgPool2dOp DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad, Type &outputTy) const override { - // Currently, we can not represent `count_include_pad` with the existing - // TOSA AvgPool2d specification. Without the below check, we produce silent - // wrong answers (SWA) when the `count_include_pad` value is `true.` - bool countIncludePad; - if (!matchPattern(op.getCountIncludePad(), - m_TorchConstantBool(&countIncludePad)) || - countIncludePad) { - return rewriter.notifyMatchFailure( - op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp " - "`count_include_pad` value should be `False`."); - } - // Currently, we can not represent `divisor_override` with the existing TOSA // AvgPool2d specification. Without the below check, we produce silent wrong // answers (SWA) when the `divisor_override` value is other than `None.` @@ -5737,7 +5745,7 @@ class ConvertAtenAvgPool1dOp // Expected a rank 3 input tensor if (selfTy.getRank() != 3) return rewriter.notifyMatchFailure( - op, "Input tensor for MaxPool1d should have rank 3"); + op, "Input tensor for AvgPool1d should have rank 3"); // Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp SmallVector rank4Shape(selfShape); @@ -5748,18 +5756,6 @@ class ConvertAtenAvgPool1dOp selfTy.getElementType()), self, rewriter.getDenseI64ArrayAttr(rank4Shape)); - // Currently, we can not represent `count_include_pad` with the existing - // TOSA AvgPool2d specification. Without the below check, we produce silent - // wrong answers (SWA) when the `count_include_pad` value is `true.` - bool countIncludePad; - if (!matchPattern(op.getCountIncludePad(), - m_TorchConstantBool(&countIncludePad)) || - countIncludePad) { - return rewriter.notifyMatchFailure( - op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp " - "`count_include_pad` value should be `False`."); - } - SmallVector dilationArray{1, 1}; if (failed(getOutputTypeAndPoolingParameters( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 8c38d0112f6c..b5d02034c1b2 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1736,6 +1736,12 @@ # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "ElementwiseAtenLogicalNotOpPromoteModule_basic", "ElementwiseCosIntModule_basic", "ElementwiseReciprocalIntModule_basic", @@ -2316,6 +2322,7 @@ "ReshapeExpandModule_basic", "ReturnThreeTensorFloat32_basic", "ReturnTwoTensorF32I64_basic", + "ResNet18StaticModule_basic", "RsubFloatModule_basic", "RsubFloatModule_noalpha_basic", "RsubInt0d_NumToTensor_Module_basic", @@ -3869,26 +3876,11 @@ "ViewSizeFromOtherTensor_basic", "VisionTransformerModule_basic", "ZerosLikeModule_falsePinMemory", - # count_include_pad and divisor_override check in TOSA AvgPool2d - "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic", - "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", - "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", - "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", - "ResNet18Module_basic", - "ResNet18StaticModule_basic", - "MobilenetV3Module_basic", # Unexpected failures due to new PyTorch version update "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AdaptiveAvgPool1dGeneralDynamic_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dStaticEvenMultiple_basic", "AdaptiveAvgPool1dStaticLargerOutput_basic", - "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dDynamicNoBatch_basic", "AdaptiveAvgPool2dDynamic_basic", "CrossEntropyLossModule_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 548c0b4baf06..23b5f6b06f1d 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -2424,3 +2424,18 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to %0 = torch.prims.collapse %arg0, %int1, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,12],f32> return %0 : !torch.vtensor<[2,12],f32> } + +// ----- + +func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> { + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %false = torch.constant.bool false + %count_include_pad = torch.constant.bool true + %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}} + %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32> + return %3 : !torch.vtensor<[1,512,10],f32> +}