From c7452af4fa7b4139dbd8b78b388b84a08b8c1b7a Mon Sep 17 00:00:00 2001 From: Chi_Liu <22491986+AmosLewis@users.noreply.github.com> Date: Fri, 12 Jan 2024 14:54:38 -0800 Subject: [PATCH] [MLIR][ONNX] Add OnnxToTorch support for Maxpool Op (#2695) Add Maxpool ONNX op support. Add Utils.h/cpp files to create a constant int list for ONNX. --- .../Conversion/TorchOnnxToTorch/Utils.h | 23 +++++ .../TorchOnnxToTorch/CMakeLists.txt | 1 + .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 79 +++++++++++++++++ lib/Conversion/TorchOnnxToTorch/Utils.cpp | 28 ++++++ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 85 ++++++++++++++++++- 5 files changed, 215 insertions(+), 1 deletion(-) create mode 100644 include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h create mode 100644 lib/Conversion/TorchOnnxToTorch/Utils.cpp diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h new file mode 100644 index 000000000000..058fee4da4a2 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h @@ -0,0 +1,23 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H +#define TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" + +namespace mlir::torch::onnx_c { + +Value createConstantIntList(OpBinder binder, + ConversionPatternRewriter &rewriter, + SmallVector cstInput); + +} // namespace mlir::torch::onnx_c + +#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H diff --git a/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt index 807db64eac64..4a5015816609 100644 --- a/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt +++ b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch Passes.cpp Patterns.cpp TorchOnnxToTorch.cpp + Utils.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchOnnxToTorch diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 0102366fe01c..c0a7473e4601 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; @@ -148,6 +149,84 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); + patterns.onOp( + "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + std::string autoPad; + if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + return rewriter.notifyMatchFailure(binder.op, + "auto_pad bind failure"); + if (autoPad != "NOTSET") + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: auto_pad != NOTSET"); + + Torch::ValueTensorType resultType; + Value operand; + bool ceilMode; + int64_t storageOrder; + // TODO: Add support for indices output and storage_order + if (binder.tensorOperand(operand) || + binder.s64BoolAttr(ceilMode, "ceil_mode", false) || + binder.s64IntegerAttr(storageOrder, "storage_order", 0) || + binder.tensorResultType(resultType)) + return rewriter.notifyMatchFailure( + binder.op, + "operand/ceil_mode/storage_order/resultType bind failure"); + if (storageOrder != 0) + return rewriter.notifyMatchFailure( + binder.op, "storage_order setting is not supported."); + // Determine the rank of input tensor. + std::optional maybeRank = Torch::getTensorRank(operand); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + unsigned rank = *maybeRank; + + SmallVector kernel, padding, strides, dilations; + if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) + return rewriter.notifyMatchFailure(binder.op, + "kernel_shape bind failure"); + if (kernel.size() != rank - 2) + return rewriter.notifyMatchFailure( + binder.op, "kernel list size does not match the number of axes"); + if (binder.s64IntegerArrayAttr(padding, "pads", {0})) + return rewriter.notifyMatchFailure(binder.op, "pads bind failure"); + if (padding.size() != 1 && padding.size() != rank - 2) + return rewriter.notifyMatchFailure( + binder.op, "padding list size does not match the number of axes"); + if (binder.s64IntegerArrayAttr(strides, "strides", {1})) + return rewriter.notifyMatchFailure(binder.op, "strides bind failure"); + if (strides.size() != 1 && strides.size() != rank - 2) + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + if (binder.s64IntegerArrayAttr(dilations, "dilations", {})) + return rewriter.notifyMatchFailure(binder.op, + "dilations bind failure"); + + Value kernelSizeList = createConstantIntList(binder, rewriter, kernel); + Value paddingList = createConstantIntList(binder, rewriter, padding); + Value stridesList = createConstantIntList(binder, rewriter, strides); + Value dilationsList = + createConstantIntList(binder, rewriter, dilations); + Value cstCeilMode = + rewriter.create(binder.getLoc(), ceilMode); + + if (rank == 3) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: AtenMaxPool1dOp"); + if (rank == 4) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, dilationsList, cstCeilMode); + return success(); + } + if (rank == 5) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, dilationsList, cstCeilMode); + return success(); + } + return rewriter.notifyMatchFailure(binder.op, "No rank is matched."); + }); patterns.onOp("Greater", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/lib/Conversion/TorchOnnxToTorch/Utils.cpp b/lib/Conversion/TorchOnnxToTorch/Utils.cpp new file mode 100644 index 000000000000..8f5a2e67c0cb --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/Utils.cpp @@ -0,0 +1,28 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +Value mlir::torch::onnx_c::createConstantIntList( + OpBinder binder, ConversionPatternRewriter &rewriter, + SmallVector cstInput) { + SmallVector cstValue; + for (int64_t i : cstInput) { + cstValue.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + return rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstValue); +} diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 07ddf3e594ea..c85659c25aa8 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -13,6 +13,8 @@ func.func @test_greater(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtenso return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: func.func @test_greater_or_equal func.func @test_greater_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> @@ -22,6 +24,8 @@ func.func @test_greater_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: func.func @test_less func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> @@ -31,6 +35,8 @@ func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[ return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: func.func @test_gather_elements func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 @@ -99,7 +105,7 @@ func.func @test_gemm_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtenso return %0 : !torch.vtensor<[3,4],f32> } - // ----- +// ----- // CHECK-LABEL: func.func @test_gemm_alpha_beta func.func @test_gemm_alpha_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { @@ -137,6 +143,8 @@ func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_matmul_2d func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32> @@ -173,6 +181,62 @@ func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vt // ----- +// CHECK-LABEL: func.func @test_maxpool_2d_default +func.func @test_maxpool_2d_default(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[I2_1:.*]] = torch.constant.int 2 + // CHECK: %[[LIST22:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.max_pool2d %arg0, %[[LIST22]], %[[LIST1]], %[[LIST0]], %[[LIST]], %[[FALSE]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,31,31],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> + return %0 : !torch.vtensor<[1,3,31,31],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_2d_ceil +func.func @test_maxpool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[I3:.*]] = torch.constant.int 3 + // CHECK: %[[I3_1:.*]] = torch.constant.int 3 + // CHECK: %[[LIST33:.*]] = torch.prim.ListConstruct %[[I3]], %[[I3_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[I2_1:.*]] = torch.constant.int 2 + // CHECK: %[[LIST22:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[TRUE:.*]] = torch.constant.bool true + // CHECK: torch.aten.max_pool2d %arg0, %[[LIST33]], %[[LIST22]], %[[LIST0]], %[[LIST]], %[[TRUE]] : !torch.vtensor<[1,1,4,4],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,1,2,2],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> + return %0 : !torch.vtensor<[1,1,2,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_3d_default +func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[I2_1:.*]] = torch.constant.int 2 + // CHECK: %[[I2_2:.*]] = torch.constant.int 2 + // CHECK: %[[LIST222:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]], %[[I2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.max_pool3d %arg0, %[[LIST222]], %[[LIST1]], %[[LIST0]], %[[LIST]], %[[FALSE]] : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,31,31,31],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> + return %0 : !torch.vtensor<[1,3,31,31,31],f32> +} + +// ----- + // CHECK-LABEL: @test_gelu_default_1 func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[STR1:.*]] = torch.constant.str "none" @@ -222,6 +286,8 @@ func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch. return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: func.func @test_pow func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -229,6 +295,8 @@ func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch. return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_hardsigmoid_example func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01 @@ -252,6 +320,8 @@ func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vt return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: @test_hardsigmoid func.func @test_hardsigmoid(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01 @@ -274,6 +344,8 @@ func.func @test_hardsigmoid(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_hardsigmoid_default func.func @test_hardsigmoid_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 0.20000000298023224 @@ -331,6 +403,8 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 return %0 : !torch.vtensor<[1,1,1,1],f32> } +// ----- + // CHECK-LABEL: func.func @test_max_example func.func @test_max_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> @@ -338,6 +412,8 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: func.func @test_min_example func.func @test_min_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> @@ -345,6 +421,7 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 return %0 : !torch.vtensor<[3],f32> } +// ----- // CHECK-LABEL: func.func @test_log func.func @test_log(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { @@ -353,6 +430,8 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_neg func.func @test_neg(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -360,6 +439,8 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_not_2d func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1> @@ -367,6 +448,8 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4], return %0 : !torch.vtensor<[3,4],i1> } +// ----- + // CHECK-LABEL: func.func @test_or2d func.func @test_or2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>