Skip to content

Commit

Permalink
Revert "Support default padding case for tosa::AvgPool in the presenc…
Browse files Browse the repository at this point in the history
…e of count_include_pad (#3868)"

This reverts commit 30c5193.
  • Loading branch information
rahuls-cerebras committed Jan 3, 2025
1 parent ca77a5a commit c1c0524
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 43 deletions.
46 changes: 25 additions & 21 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5549,26 +5549,6 @@ static LogicalResult getOutputTypeAndPoolingParameters(
std::is_same<AtenOpT, AtenAvgPool1dOp>())
paddingInts.push_back(0);

if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
// Currently, we can not represent `count_include_pad` with the existing
// TOSA AvgPool2d specification. Without the below check, we produce silent
// wrong answer (SWA) when the `count_include_pad` value is `true.`
//
// Note: We need to check for `count_include_pad` only when the `padding`
// value is non-zero.
bool countIncludePad;
if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
(!matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)) ||

countIncludePad)) {
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value, for tosa AvgPool "
"`count_include_pad` value should be `False`.");
}
}

SmallVector<int64_t, 4> padArr = {paddingInts[0], paddingInts[0],
paddingInts[1], paddingInts[1]};
kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts);
Expand Down Expand Up @@ -5697,6 +5677,18 @@ class ConvertAtenAvgPool2dOp
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
Type &outputTy) const override {

// Currently, we can not represent `count_include_pad` with the existing
// TOSA AvgPool2d specification. Without the below check, we produce silent
// wrong answers (SWA) when the `count_include_pad` value is `true.`
bool countIncludePad;
if (!matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)) ||
countIncludePad) {
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
"`count_include_pad` value should be `False`.");
}

// Currently, we can not represent `divisor_override` with the existing TOSA
// AvgPool2d specification. Without the below check, we produce silent wrong
// answers (SWA) when the `divisor_override` value is other than `None.`
Expand Down Expand Up @@ -5745,7 +5737,7 @@ class ConvertAtenAvgPool1dOp
// Expected a rank 3 input tensor
if (selfTy.getRank() != 3)
return rewriter.notifyMatchFailure(
op, "Input tensor for AvgPool1d should have rank 3");
op, "Input tensor for MaxPool1d should have rank 3");

// Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
SmallVector<int64_t> rank4Shape(selfShape);
Expand All @@ -5756,6 +5748,18 @@ class ConvertAtenAvgPool1dOp
selfTy.getElementType()),
self, rewriter.getDenseI64ArrayAttr(rank4Shape));

// Currently, we can not represent `count_include_pad` with the existing
// TOSA AvgPool2d specification. Without the below check, we produce silent
// wrong answers (SWA) when the `count_include_pad` value is `true.`
bool countIncludePad;
if (!matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)) ||
countIncludePad) {
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
"`count_include_pad` value should be `False`.");
}

SmallVector<int64_t, 2> dilationArray{1, 1};
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
tosa::AvgPool2dOp>(
Expand Down
22 changes: 15 additions & 7 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1736,12 +1736,6 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
Expand Down Expand Up @@ -2322,7 +2316,6 @@
"ReshapeExpandModule_basic",
"ReturnThreeTensorFloat32_basic",
"ReturnTwoTensorF32I64_basic",
"ResNet18StaticModule_basic",
"RsubFloatModule_basic",
"RsubFloatModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
Expand Down Expand Up @@ -3876,11 +3869,26 @@
"ViewSizeFromOtherTensor_basic",
"VisionTransformerModule_basic",
"ZerosLikeModule_falsePinMemory",
# count_include_pad and divisor_override check in TOSA AvgPool2d
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"ResNet18Module_basic",
"ResNet18StaticModule_basic",
"MobilenetV3Module_basic",
# Unexpected failures due to new PyTorch version update
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic",
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dDynamicNoBatch_basic",
"AdaptiveAvgPool2dDynamic_basic",
"CrossEntropyLossModule_basic",
Expand Down
15 changes: 0 additions & 15 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2424,18 +2424,3 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to
%0 = torch.prims.collapse %arg0, %int1, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,12],f32>
return %0 : !torch.vtensor<[2,12],f32>
}

// -----

func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%false = torch.constant.bool false
%count_include_pad = torch.constant.bool true
%0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
// expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
return %3 : !torch.vtensor<[1,512,10],f32>
}

0 comments on commit c1c0524

Please sign in to comment.