Skip to content

Commit

Permalink
Support default padding case for tosa::AvgPool in the presence of cou…
Browse files Browse the repository at this point in the history
…nt_include_pad (#3868)

Essentially, as part of my earlier
[change](7f9f99c)
, I didn't consider the `padding` value while erroring out for
unsupported `count_include_pad` during `torch-to-tosa` lowering for
AvgPool2d. The fix captured in this change addresses this. Please see
[issue](#3862) for more details
on this.

Co-authored-by: Hanumanth Hanumantharayappa <hhanuman@ah-hhanuman-l.dhcp.mathworks.com>
  • Loading branch information
Hanumanth04 and Hanumanth Hanumantharayappa authored Nov 12, 2024
1 parent cd38ecf commit 30c5193
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 40 deletions.
46 changes: 21 additions & 25 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5549,6 +5549,26 @@ static LogicalResult getOutputTypeAndPoolingParameters(
std::is_same<AtenOpT, AtenAvgPool1dOp>())
paddingInts.push_back(0);

if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
// Currently, we can not represent `count_include_pad` with the existing
// TOSA AvgPool2d specification. Without the below check, we produce silent
// wrong answer (SWA) when the `count_include_pad` value is `true.`
//
// Note: We need to check for `count_include_pad` only when the `padding`
// value is non-zero.
bool countIncludePad;
if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
(!matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)) ||

countIncludePad)) {
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value, for tosa AvgPool "
"`count_include_pad` value should be `False`.");
}
}

SmallVector<int64_t, 4> padArr = {paddingInts[0], paddingInts[0],
paddingInts[1], paddingInts[1]};
kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts);
Expand Down Expand Up @@ -5677,18 +5697,6 @@ class ConvertAtenAvgPool2dOp
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
Type &outputTy) const override {

// Currently, we can not represent `count_include_pad` with the existing
// TOSA AvgPool2d specification. Without the below check, we produce silent
// wrong answers (SWA) when the `count_include_pad` value is `true.`
bool countIncludePad;
if (!matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)) ||
countIncludePad) {
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
"`count_include_pad` value should be `False`.");
}

// Currently, we can not represent `divisor_override` with the existing TOSA
// AvgPool2d specification. Without the below check, we produce silent wrong
// answers (SWA) when the `divisor_override` value is other than `None.`
Expand Down Expand Up @@ -5737,7 +5745,7 @@ class ConvertAtenAvgPool1dOp
// Expected a rank 3 input tensor
if (selfTy.getRank() != 3)
return rewriter.notifyMatchFailure(
op, "Input tensor for MaxPool1d should have rank 3");
op, "Input tensor for AvgPool1d should have rank 3");

// Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
SmallVector<int64_t> rank4Shape(selfShape);
Expand All @@ -5748,18 +5756,6 @@ class ConvertAtenAvgPool1dOp
selfTy.getElementType()),
self, rewriter.getDenseI64ArrayAttr(rank4Shape));

// Currently, we can not represent `count_include_pad` with the existing
// TOSA AvgPool2d specification. Without the below check, we produce silent
// wrong answers (SWA) when the `count_include_pad` value is `true.`
bool countIncludePad;
if (!matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)) ||
countIncludePad) {
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
"`count_include_pad` value should be `False`.");
}

SmallVector<int64_t, 2> dilationArray{1, 1};
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
tosa::AvgPool2dOp>(
Expand Down
22 changes: 7 additions & 15 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1736,6 +1736,12 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
Expand Down Expand Up @@ -2316,6 +2322,7 @@
"ReshapeExpandModule_basic",
"ReturnThreeTensorFloat32_basic",
"ReturnTwoTensorF32I64_basic",
"ResNet18StaticModule_basic",
"RsubFloatModule_basic",
"RsubFloatModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
Expand Down Expand Up @@ -3869,26 +3876,11 @@
"ViewSizeFromOtherTensor_basic",
"VisionTransformerModule_basic",
"ZerosLikeModule_falsePinMemory",
# count_include_pad and divisor_override check in TOSA AvgPool2d
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"ResNet18Module_basic",
"ResNet18StaticModule_basic",
"MobilenetV3Module_basic",
# Unexpected failures due to new PyTorch version update
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic",
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dDynamicNoBatch_basic",
"AdaptiveAvgPool2dDynamic_basic",
"CrossEntropyLossModule_basic",
Expand Down
15 changes: 15 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2424,3 +2424,18 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to
%0 = torch.prims.collapse %arg0, %int1, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,12],f32>
return %0 : !torch.vtensor<[2,12],f32>
}

// -----

func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%false = torch.constant.bool false
%count_include_pad = torch.constant.bool true
%0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
// expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
return %3 : !torch.vtensor<[1,512,10],f32>
}

0 comments on commit 30c5193

Please sign in to comment.