Skip to content

Commit

Permalink
[torch] Support !countIncludePad when unpadded for average pool (llvm#2836)

Browse files Browse the repository at this point in the history

We do not support average pool when `countIncludePad` is set to false.
However, if the input is unpadded then the setting of this boolean has
no effect. Extended support by checking whether the padding is zero
before rejecting the lowering.
  • Loading branch information
rsuderman authored Jan 31, 2024
1 parent 0114a57 commit 34f6948
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
5 changes: 4 additions & 1 deletion lib/Conversion/TorchToLinalg/Pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,10 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
m_TorchConstantBool(&countIncludePad)))
return rewriter.notifyMatchFailure(
op, "count_include_pad must be a constant");
if (!countIncludePad) {

// If the padding is zero then there is no padding to include.
if (!countIncludePad &&
!llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) {
return rewriter.notifyMatchFailure(
op, "unimplemented: count_include_pad is expected to be true");
}
Expand Down
24 changes: 23 additions & 1 deletion projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,28 @@ def forward(self, x):
def AvgPool2dCeilModeTrueModule_basic(module, tu: TestUtils):
    # Drive the ceil_mode=True AvgPool2d module with a random NCHW input;
    # low/high keep values in [0.5, 1.0) so the averages stay well above zero.
    module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))

class AvgPool2dWithoutPadModule(torch.nn.Module):
    """2-D average pool with all-zero padding and count_include_pad=False.

    With zero padding there are no pad elements to (ex)clude, so
    count_include_pad cannot affect the result; this exercises the
    lowering path that accepts that combination.
    """

    def __init__(self):
        super().__init__()
        self.ap2d = torch.nn.AvgPool2d(
            kernel_size=[6, 8],
            stride=[2, 2],
            padding=[0, 0],
            ceil_mode=False,
            count_include_pad=False,
            divisor_override=None,
        )

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return self.ap2d(x)

@register_test_case(module_factory=lambda: AvgPool2dWithoutPadModule())
def AvgPool2dWithoutPadModule_basic(module, tu: TestUtils):
    """E2E driver: feed a random (2, 4, 20, 20) tensor through the module."""
    sample = tu.rand(2, 4, 20, 20, low=0.5, high=1.0)
    module.forward(sample)

# ==============================================================================

Expand Down Expand Up @@ -1141,4 +1163,4 @@ def forward(self,x):
module_factory=lambda: AdaptiveMaxPool2dStaticWithIndices())
def AdaptiveMaxPool2dStaticWithIndices_basic(
module, tu: TestUtils):
module.forward(tu.rand(1, 512, 10, 16))
module.forward(tu.rand(1, 512, 10, 16))

0 comments on commit 34f6948

Please sign in to comment.