diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index d154edb1ab750..1dda2e81ddc38 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -14,6 +14,20 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::onnx_c;
 
+static Value createConstantIntList(OpBinder binder,
+                                   ConversionPatternRewriter &rewriter,
+                                   SmallVector<int64_t> cstInput) {
+  SmallVector<Value> cstValue;
+  for (int64_t i : cstInput) {
+    cstValue.push_back(rewriter.create<Torch::ConstantIntOp>(
+        binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+  }
+  return rewriter.create<Torch::PrimListConstructOp>(
+      binder.getLoc(),
+      Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+      cstValue);
+}
+
 // Simple rewrites for the default domain.
 // See: https://onnx.ai/onnx/operators/
 // For operators that are effectively version invariant, we register with
@@ -148,6 +162,86 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.op, resultType, lhs, rhs);
         return success();
       });
+  patterns.onOp(
+      "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        std::string autoPad;
+        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "auto_pad bind failure");
+        if (autoPad != "NOTSET") {
+          // TODO: Add support for `auto_pad` != "NOTSET"
+          return rewriter.notifyMatchFailure(
+              binder.op, "unsupported conversion: auto_pad != NOTSET");
+        }
+
+        Torch::ValueTensorType resultType;
+        Value operand;
+        bool ceilMode;
+        int64_t storageOrder;
+        // TODO: Add support for indices output and storage_order
+        if (binder.tensorOperand(operand) ||
+            binder.s64BoolAttr(ceilMode, "ceil_mode", false) ||
+            binder.s64IntegerAttr(storageOrder, "storage_order", 0) ||
+            binder.tensorResultType(resultType))
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "operand/ceil_mode/storage_order/resultType bind failure");
+        if (storageOrder != 0)
+          return rewriter.notifyMatchFailure(
+              binder.op, "storage_order setting is not supported.");
+        // Determine the rank of input tensor.
+        std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
+        if (!maybeRank)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Unimplemented: unranked tensor");
+        unsigned rank = *maybeRank;
+
+        SmallVector<int64_t> kernel, padding, strides, dilations;
+        if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}))
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "kernel_shape bind failure");
+        if (kernel.size() != rank - 2)
+          return rewriter.notifyMatchFailure(
+              binder.op, "kernel list size does not match the number of axes");
+        if (binder.s64IntegerArrayAttr(padding, "pads", {0}))
+          return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
+        if (padding.size() != 1 && padding.size() != rank - 2)
+          return rewriter.notifyMatchFailure(
+              binder.op, "padding list size does not match the number of axes");
+        if (binder.s64IntegerArrayAttr(strides, "strides", {1}))
+          return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
+        if (strides.size() != 1 && strides.size() != rank - 2)
+          return rewriter.notifyMatchFailure(
+              binder.op, "strides list size does not match the number of axes");
+        if (binder.s64IntegerArrayAttr(dilations, "dilations", {}))
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "dilations bind failure");
+
+        Value kernelSizeList = createConstantIntList(binder, rewriter, kernel);
+        Value paddingList = createConstantIntList(binder, rewriter, padding);
+        Value stridesList = createConstantIntList(binder, rewriter, strides);
+        Value dilationsList =
+            createConstantIntList(binder, rewriter, dilations);
+        Value cstCeilMode =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
+
+        if (rank == 3)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Unimplemented: AtenMaxPool1dOp");
+        if (rank == 4) {
+          rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dOp>(
+              binder.op, resultType, operand, kernelSizeList, stridesList,
+              paddingList, dilationsList, cstCeilMode);
+          return success();
+        }
+        if (rank == 5) {
+          rewriter.replaceOpWithNewOp<Torch::AtenMaxPool3dOp>(
+              binder.op, resultType, operand, kernelSizeList, stridesList,
+              paddingList, dilationsList, cstCeilMode);
+          return success();
+        }
+        return failure();
+      });
   patterns.onOp("Greater", 16,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index e224ddfa2944c..2f7a51a783380 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -124,6 +124,8 @@ func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
 
+// -----
+
 // CHECK-LABEL: @test_matmul_2d
 func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32>
@@ -160,6 +162,62 @@ func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vt
 
 // -----
 
+// CHECK-LABEL: func.func @test_maxpool_2d_default
+func.func @test_maxpool_2d_default(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
+  // CHECK: %[[I2:.*]] = torch.constant.int 2
+  // CHECK: %[[I2_1:.*]] = torch.constant.int 2
+  // CHECK: %[[LIST22:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I0:.*]] = torch.constant.int 0
+  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[I1:.*]] = torch.constant.int 1
+  // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.aten.max_pool2d %arg0, %[[LIST22]], %[[LIST1]], %[[LIST0]], %[[LIST]], %[[FALSE]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,31,31],f32>
+  %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32>
+  return %0 : !torch.vtensor<[1,3,31,31],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_maxpool_2d_ceil
+func.func @test_maxpool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
+  // CHECK: %[[I3:.*]] = torch.constant.int 3
+  // CHECK: %[[I3_1:.*]] = torch.constant.int 3
+  // CHECK: %[[LIST33:.*]] = torch.prim.ListConstruct %[[I3]], %[[I3_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I0:.*]] = torch.constant.int 0
+  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[I2:.*]] = torch.constant.int 2
+  // CHECK: %[[I2_1:.*]] = torch.constant.int 2
+  // CHECK: %[[LIST22:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[TRUE:.*]] = torch.constant.bool true
+  // CHECK: torch.aten.max_pool2d %arg0, %[[LIST33]], %[[LIST22]], %[[LIST0]], %[[LIST]], %[[TRUE]] : !torch.vtensor<[1,1,4,4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,1,2,2],f32>
+  %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32>
+  return %0 : !torch.vtensor<[1,1,2,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_maxpool_3d_default
+func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
+  // CHECK: %[[I2:.*]] = torch.constant.int 2
+  // CHECK: %[[I2_1:.*]] = torch.constant.int 2
+  // CHECK: %[[I2_2:.*]] = torch.constant.int 2
+  // CHECK: %[[LIST222:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]], %[[I2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I0:.*]] = torch.constant.int 0
+  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[I1:.*]] = torch.constant.int 1
+  // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.aten.max_pool3d %arg0, %[[LIST222]], %[[LIST1]], %[[LIST0]], %[[LIST]], %[[FALSE]] : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,31,31,31],f32>
+  %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32>
+  return %0 : !torch.vtensor<[1,3,31,31,31],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_gelu_default_1
 func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[STR1:.*]] = torch.constant.str "none"
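
Note: per the rank == 3 branch in the pattern above, 1-D pooling is not yet converted. A hypothetical test such as the sketch below (illustrative only, not part of this patch) would hit the "Unimplemented: AtenMaxPool1dOp" match failure, leaving the onnx.MaxPool op unconverted:

  func.func @test_maxpool_1d_default(%arg0: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
    // A rank-3 (N, C, L) input takes the rank == 3 path in the pattern and
    // triggers notifyMatchFailure; no torch.aten.max_pool1d is emitted.
    %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64]} : (!torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32>
    return %0 : !torch.vtensor<[1,3,31],f32>
  }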