From c531f5495bf2046d86bb76285f0d5d23076c71f8 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Wed, 24 Jan 2024 11:09:56 -0600 Subject: [PATCH] AtenAdaptiveMaxPool2d Conversion to Linalg (#2779) The logic here is very similar to the conversion for AdaptiveAvgPool1d #2661 with a few modifications: 1. buffVal = -inf instead of 0 2. the main linalg generic op accumulates a max, instead of a sum, to the first output tensor 3. avg pooling requires dividing the sum pool by the kernel width, which we stored as an auxilliary tensor (kSizeTensor). Here, the auxiliary tensor will be recording the indices. Strangely enough, the only signature available for this function is to return indices, and it appears that they must be computed whether the user desires them or not. See [pytorch/torch/nn/functional.py](https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L1174). Before writing other adaptive pooling conversions, the logic of this decomposition should be rolled into a helper function that will work for both max and avg pooling ops. Even the auxiliary tensor should likely be automated. This code was written in a slightly more tedious way than strictly necessary (often using loops to fill SmallVectors up to rank-2, which is only two in this case), in order to more easily facilitate the transition to a helper function. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 +++ lib/Conversion/TorchToLinalg/Pooling.cpp | 204 ++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 73 +++++++ .../build_tools/abstract_interp_lib_gen.py | 23 ++ .../build_tools/torch_ods_gen.py | 1 + .../torch_mlir_e2e_test/test_suite/pooling.py | 81 +++++++ 6 files changed, 407 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 8ed176a8eae5..a46c79acb941 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6856,6 +6856,31 @@ def Torch_Aten_AdaptiveAvgPool3dBackwardOp : Torch_Op<"aten._adaptive_avg_pool3d }]; } +def Torch_AtenAdaptiveMaxPool2dOp : Torch_Op<"aten.adaptive_max_pool2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::adaptive_max_pool2d : (Tensor, int[]) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAdaptiveMaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 2); + } + void AtenAdaptiveMaxPool2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 2); + } + }]; +} + def Torch_AtenTopkOp : Torch_Op<"aten.topk", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 14d2c71dbc92..eed79072d0f9 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -793,6 +793,208 @@ class ConvertAtenAdaptiveAvgPool1dOp }; } // namespace +// The logic for this conversion is similar to the AdaptiveAvgPool1dOp +// conversion. 
Before writing any more adaptive pooling conversions, the logic +// in this should be off-loaded to a helper function, since each of the adaptive +// ops are essentially the same with some minor tweaks. Instead of kSizeTensor, +// we named the additional output of the linalg generic op auxTensor. +// For max pooling, auxTensor holds the indices of max values, and for +// avg pooling, the auxTensor will be kSizeTensor, used to later divide the +// sum pool by the kernel size. +namespace { +class ConvertAtenAdaptiveMaxPool2dOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenAdaptiveMaxPool2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op->getLoc(); + const TypeConverter *typeConverter = getTypeConverter(); + + // get rank of input (same as rank of output) + int64_t rank = + adaptor.getSelf().getType().cast().getRank(); + // input operand should be NCHW (i.e. rank 4) + if (rank != 4) { + return rewriter.notifyMatchFailure(op, "only supports input type NCHW"); + } + + // input tensor and output shape + Value input = adaptor.getSelf(); + Value outputShape = op.getOutputSize(); + SmallVector outShapeVector; + getListConstructElements(outputShape, outShapeVector); + outShapeVector = + getTypeConvertedValues(rewriter, loc, typeConverter, outShapeVector); + SmallVector inputSpatialSizes; + for (unsigned i = 2; i < rank; i++) { + inputSpatialSizes.push_back(getDimOp(rewriter, loc, input, i)); + } + SmallVector outShapeIndexVector; + for (auto v : outShapeVector) { + outShapeIndexVector.push_back(castIntToIndex(rewriter, loc, v)); + } + RankedTensorType inputType = input.getType().cast(); + RankedTensorType outputType = + typeConverter->convertType(op.getResult0().getType()) + .cast(); + + // get elementType of input tensor + Type elementType = inputType.getElementType(); + + // make an iteration space of size kMax = 1 + ceildiv (hIn - 1) , hOut + Type boolType = rewriter.getI1Type(); + SmallVector kIterSizeVector; + Value constantOne = + rewriter.create(loc, rewriter.getIndexAttr(1)); + for (int i = 0; i < rank - 2; i++) { + Value hInPlusOne = rewriter.create( + loc, inputSpatialSizes[i], constantOne); + Value kMaxMinusOne = rewriter.create( + loc, hInPlusOne, outShapeIndexVector[i]); + Value kMax = + rewriter.create(loc, constantOne, kMaxMinusOne); + kIterSizeVector.push_back(kMax); + } + Value kIter = rewriter.create( + loc, getAsOpFoldResult(kIterSizeVector), boolType); + + // need to buffer input, else there will possibly be an out of bounds access + // later buffVal = 0 for avg pooling and -inf for max pooling + auto smallestFPValueAttr = rewriter.getFloatAttr( + elementType, + APFloat::getInf(elementType.cast().getFloatSemantics(), + /*Negative=*/true)); + Value buffVal = rewriter.create(loc, elementType, + smallestFPValueAttr); + SmallVector lowPadding(rank, 0); + SmallVector highPadding(2, 0); + for (int i = 0; i < rank - 2; i++) { + highPadding.push_back(1); + } + Value buffInput = torch_to_linalg::getPaddedTensor( + op, rewriter, input, lowPadding, highPadding, buffVal); + + // make a list of outputSizes + SmallVector outputSizes; + for (unsigned i = 0; i < 2; i++) { + outputSizes.push_back(getDimOp(rewriter, loc, input, i)); + } + for (unsigned i = 2; i < rank; i++) { + outputSizes.push_back(outShapeIndexVector[i - 2]); + } + + // for avg pooling the auxTensor should hold kernel widths (kSizeTensor) + // for max Pooling, it should hold the indices + 
RankedTensorType outputType1 = + typeConverter->convertType(op.getResult1().getType()) + .cast(); + Type indicesType = outputType1.getElementType(); + Value auxTensor = rewriter.create( + loc, getAsOpFoldResult(outputSizes), indicesType); + + // initialize an output tensor + Value initOutput = + createInitTensor(rewriter, loc, outputSizes, elementType, buffVal); + + // setup indexing maps and iterator types for linalg generic op (outputShape + // (rank),kIter (rank -2)) for kIter (d0,d1,d2,d3,d4,d5) -> (d4,d5) for + // output (d0,d1,d2,d3,d4,d5) -> (d0,d1,d2,d3) for auxTensor + // (d0,d1,d2,d3,d4,d5) -> (d0,d1,d2,d3) (or (d2,d3) for avg pooling) + SmallVector kIterExprs, outputExprs, auxTensorExprs; + // batch + channel + output spatial dims + for (unsigned i = 0; i < rank; i++) { + outputExprs.push_back(rewriter.getAffineDimExpr(i)); + auxTensorExprs.push_back(rewriter.getAffineDimExpr(i)); + } + // kIter covers last rank-2 indices + for (unsigned i = rank; i < 2 * rank - 2; i++) { + kIterExprs.push_back(rewriter.getAffineDimExpr(i)); + } + SmallVector indexingMaps = + AffineMap::inferFromExprList({kIterExprs, outputExprs, auxTensorExprs}); + SmallVector iteratorTypes( + rank, utils::IteratorType::parallel); + for (unsigned i = 0; i < rank - 2; i++) { + iteratorTypes.push_back(utils::IteratorType::reduction); + } + Value indexOne = rewriter.create(loc, 1); + auto maxPool = rewriter.create( + loc, /*resultTensorTypes=*/ + TypeRange({initOutput.getType(), auxTensor.getType()}), + /*inputs=*/ValueRange({kIter}), + /*outputs=*/ValueRange({initOutput, auxTensor}), + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value res = args[1]; + Value maxIndex = args[2]; + SmallVector ind; + for (unsigned i = 0; i < 2 * rank - 2; i++) { + ind.push_back(b.create(loc, i)); + } + // compute start and end indices + // st = s1( s0(ind2 * Hin) // Hout ) + SmallVector starts; + SmallVector ends; + for (unsigned i = 2; i < rank; i++) { + Value s0 = + b.create(loc, ind[i], inputSpatialSizes[i - 2]); + Value s1 = b.create( + loc, s0, outShapeIndexVector[i - 2]); + starts.push_back(s1); + // en = e4( 1 + e3( e2( e1( e0(ind2 + 1) * hIn ) - 1 ) // hOut ) ) + Value e0 = b.create(loc, ind[i], indexOne); + Value e1 = + b.create(loc, e0, inputSpatialSizes[i - 2]); + Value e2 = b.create(loc, e1, indexOne); + Value e3 = b.create( + loc, e2, outShapeIndexVector[i - 2]); + Value e4 = b.create(loc, indexOne, e3); + ends.push_back(e4); + } + SmallVector inputElementIndices; + inputElementIndices.push_back(ind[0]); + inputElementIndices.push_back(ind[1]); + for (unsigned i = 2; i < rank; i++) { + inputElementIndices.push_back( + b.create(loc, starts[i - 2], ind[rank - 2 + i])); + } + Value inElt = b.create(loc, elementType, buffInput, + inputElementIndices); + // check if we extracted at windex < end index + for (unsigned i = 0; i < rank - 2; i++) { + Value cond = + b.create(loc, arith::CmpIPredicate(6), + inputElementIndices[i + 2], ends[i]); + inElt = b.create(loc, cond, inElt, buffVal); + } + Value cond1 = b.create(loc, arith::CmpFPredicate::OGT, + inElt, res); + // index location is (ih * input_width + iw) + Value indexOut0 = b.create(loc, inputElementIndices[2], + inputSpatialSizes[1]); + Value indexOut1 = + b.create(loc, indexOut0, inputElementIndices[3]); + Value indexOut1Int = castIndexToInt64(b, loc, indexOut1); + Value indexOut2 = + b.create(loc, cond1, indexOut1Int, maxIndex); + Value out2 = b.create(loc, cond1, inElt, res); + b.create(loc, 
ValueRange({out2, indexOut2})); + }); + + Value maxValues = rewriter.create( + loc, outputType, maxPool.getResultTensors()[0]); + Value outputIndices = rewriter.create( + loc, outputType1, maxPool.getResultTensors()[1]); + rewriter.replaceOp(op, {maxValues, outputIndices}); + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -813,4 +1015,6 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 0e6313ea8978..590bea8d7176 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7964,6 +7964,73 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.adaptive_avg_pool2d(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.adaptive_max_pool2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.tuple, list> {\n" +" %0 = call @__torch__.adaptive_max_pool2d(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.tuple, list>\n" +" return %0 : !torch.tuple, list>\n" +" }\n" +" func.func @__torch__.adaptive_max_pool2d(%arg0: !torch.list, %arg1: !torch.list) -> !torch.tuple, list> {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %int4 = torch.constant.int 4\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %11 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %12 = torch.aten.eq.int %11, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %5, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %11 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.ne.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %6 = torch.prim.ListConstruct : () -> !torch.list\n" +" %7 = torch.aten.len.t %arg0 : !torch.list -> 
!torch.int\n" +" %8 = torch.aten.sub.int %7, %int2 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %8, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %11 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.append.t %6, %11 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %9 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" torch.prim.Loop %9, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %11 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.append.t %6, %11 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %10 = torch.prim.TupleConstruct %6, %6 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %10 : !torch.tuple, list>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.flatten.using_ints\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.flatten(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -9896,6 +9963,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" " return %1 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_max_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mish\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index ca5d983f2c8d..28e87cc60990 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -902,6 +902,24 @@ def aten〇avg_pool2d〡shape(self: List[int], kernel_size: List[int], stride: L def aten〇adaptive_avg_pool2d〡shape(self: List[int], output_size: List[int]) -> List[int]: return upstream_shape_functions.adaptive_avg_pool2d(self, output_size) +def adaptive_max_pool2d(self: List[int], out: List[int]): + assert len(out) == 2 + assert len(self) == 3 or len(self) == 4 + + for i in range(len(self)): + assert self[i] != 0 + + shape: List[int] = [] + for i in range(len(self) - 2): + shape.append(self[i]) + for j in range(len(out)): + shape.append(out[j]) + + return shape, shape + +def aten〇adaptive_max_pool2d〡shape(self: List[int], output_size: List[int]) -> Tuple[List[int], List[int]]: + return adaptive_max_pool2d(self, output_size) + def aten〇flatten〇using_ints〡shape(self: List[int], start_dim: int = 0, end_dim: int = -1) -> List[int]: return upstream_shape_functions.flatten(self, start_dim, end_dim) @@ -2334,6 +2352,11 @@ def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker self_rank, self_dtype = self_rank_dtype return self_dtype, 
torch.int64 +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2])) +def aten〇adaptive_max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + return self_dtype, torch.int64 + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇mish〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 45f580ba5e13..ae4c608c6de7 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -502,6 +502,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)") emit("aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)") emit("aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)") + emit("aten::adaptive_max_pool2d : (Tensor, int[]) -> (Tensor, Tensor)") emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index b19596be7031..1d3481196e5f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -1061,3 +1061,84 @@ def forward(self, x): def AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic( module, tu: TestUtils): module.forward(tu.rand(1, 512, 7)) + +class AdaptiveMaxPool2dDynamic(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=False) + + @export + @annotate_args([ + None, + ([-1,-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.amp2d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool2dDynamic()) +def AdaptiveMaxPool2dDynamic_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 16)) + +class AdaptiveMaxPool2dDynamicWithIndices(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=True) + + @export + @annotate_args([ + None, + ([-1,-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.amp2d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool2dDynamicWithIndices()) +def AdaptiveMaxPool2dDynamicWithIndices_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 16)) + + +class AdaptiveMaxPool2dStatic(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=False) + + @export + @annotate_args([ + None, + ([1, 512, 10, 9], torch.float32, True) + ]) + def forward(self,x): + return self.amp2d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool2dStatic()) +def AdaptiveMaxPool2dStatic_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 9)) + +class AdaptiveMaxPool2dStaticWithIndices(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), 
return_indices=True) + + @export + @annotate_args([ + None, + ([1, 512, 10, 16], torch.float32, True) + ]) + def forward(self,x): + return self.amp2d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool2dStaticWithIndices()) +def AdaptiveMaxPool2dStaticWithIndices_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 16)) \ No newline at end of file
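
The window arithmetic in the conversion above is compact but easy to misread, so the following standalone Python sketch walks through the same decomposition for a single NxC slice. It is illustrative only: the names (adaptive_max_pool2d_ref, h_in, w_in, h_out, w_out, ceildiv) are not from the patch, and the bounds-check conditional stands in for the -inf high padding plus the index-versus-end select performed inside the linalg.generic body.

import math

def ceildiv(a, b):
    return -(-a // b)

def adaptive_max_pool2d_ref(x, h_out, w_out):
    # x is a 2D list of floats (h_in x w_in); returns (max values, flat indices).
    h_in, w_in = len(x), len(x[0])
    # Fixed iteration space, as in the conversion:
    # kMax = 1 + ceildiv(in - 1, out) bounds the size of every adaptive window.
    kh = 1 + ceildiv(h_in - 1, h_out)
    kw = 1 + ceildiv(w_in - 1, w_out)
    vals = [[0.0] * w_out for _ in range(h_out)]
    idxs = [[0] * w_out for _ in range(h_out)]
    for oh in range(h_out):
        for ow in range(w_out):
            # start = (o * in) // out and end = 1 + ((o + 1) * in - 1) // out,
            # matching the st/en computation in the generic body.
            h_start = (oh * h_in) // h_out
            h_end = 1 + ((oh + 1) * h_in - 1) // h_out
            w_start = (ow * w_in) // w_out
            w_end = 1 + ((ow + 1) * w_in - 1) // w_out
            best, best_idx = -math.inf, 0  # buffVal = -inf for max pooling
            for i in range(kh):
                for j in range(kw):
                    ih, iw = h_start + i, w_start + j
                    # Reads past the window end see -inf, mirroring the padded
                    # input plus the select on (index < end).
                    v = x[ih][iw] if (ih < h_end and iw < w_end) else -math.inf
                    if v > best:  # OGT compare, as in the conversion
                        best, best_idx = v, ih * w_in + iw  # flat index ih * w_in + iw
            vals[oh][ow] = best
            idxs[oh][ow] = best_idx
    return vals, idxs

For example, with h_in = 10 and h_out = 7 (the spatial sizes used by the AdaptiveMaxPool2dDynamic test), the windows overlap and have size 2 or 3, and kh = 1 + ceildiv(9, 7) = 3 bounds all of them. The indices are always computed, matching result1 of aten.adaptive_max_pool2d even when the e2e test constructs the module with return_indices=False.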