diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index e825938ee65f..13d3a8de9463 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -84,6 +84,11 @@ void createTorchDynamoExportToTorchBackendPipeline( void createTorchFunctionToTorchBackendPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options); +/// Creates a pipeline that lowers the torch Onnx IR that is produced by +/// Onnx import into the form expected by torch-verify-backend-contract. +void createTorchOnnxToTorchBackendPipeline( + OpPassManager &pm, const TorchLoweringPipelineOptions &options); + /// Creates a pipeline that simplifies the computations in the program. /// This pass does not do any global program restructuring -- it works entirely /// within a single semantic model of a `builtin.module` with diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index 3ed8dc324578..846470202c15 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -10,6 +10,7 @@ #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h" void mlir::torch::registerTorchPasses() { mlir::torch::registerPasses(); @@ -25,6 +26,10 @@ void mlir::torch::registerTorchPasses() { "torch-function-to-torch-backend-pipeline", "Pipeline lowering a Torch function to Torch backend form.", mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline); + mlir::PassPipelineRegistration( + "torch-onnx-to-torch-backend-pipeline", + "Pipeline lowering Torch Onnx IR to Torch backend form.", + mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline); mlir::PassPipelineRegistration( "torch-simplification-pipeline", "Pipeline simplifying computations in the program.", @@ -86,6 +91,37 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline( options.backendLegalOps, options.extraLibrary)); } +void mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline( + OpPassManager &pm, const TorchLoweringPipelineOptions &options) { + pm.addNestedPass(onnx_c::createTorchOnnxToTorchPass()); + // The above pass just converts the torch onnx IR to torch, hence the given + // pipeline will make sure that the IR is transformed such that it satisfies + // the backend contract. + if (options.decompose) { + pm.addNestedPass( + Torch::createDecomposeComplexOpsPass(options.backendLegalOps)); + pm.addNestedPass(createCanonicalizerPass()); + } + // TODO: Move the combination of two passes i.e., ScalarizeShapes and + // TorchShapeRefinementPipeline out of here and create an onnx shape + // refinement pipeline which runs iteratively over the IR. + createTorchShapeRefinementPipeline(pm, options); + // This pass scalarizes the tensor shape computations. + pm.addNestedPass( + mlir::torch::Torch::createScalarizeShapesPass()); + createTorchShapeRefinementPipeline(pm, options); + pm.addPass(Torch::createRefinePublicReturnPass()); + pm.addNestedPass(createCanonicalizerPass()); + // The decompose pass is run again here since the scalarize shapes pass and + // shape refinement pipeline might create some ops for which decomposition + // exists. + if (options.decompose) { + pm.addNestedPass( + Torch::createDecomposeComplexOpsPass(options.backendLegalOps)); + pm.addNestedPass(createCanonicalizerPass()); + } +} + // A simplification pipeline to establish the invariants of the backend // contract (see `satisfiedBackendContract` in `LowerToBackendContract`). // diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 40d7b629a275..bdb46d636681 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -70,7 +70,6 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( // We want to fuse quantized operations together before lowering to linalg. pm.addNestedPass(Torch::createFuseQuantizedOpsPass()); - pm.addNestedPass(Torch::createScalarizeShapesPass()); // Lower to linalg + guards which is the input to codegen backends. // We do this first as it tends to involve pattern-matching against constants, diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py index a6e42e278757..79404b1d0d80 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py @@ -100,33 +100,25 @@ def _module_lowering( print("ONNX RAW IR") print(torch_mod) - # Lower from ONNX to Torch - run_pipeline_with_repro_report( - torch_mod, - # The importer may produce additional MLIR functions corresponding to - # ONNX operators that are functions. In some cases they need to be - # inlined to avoid the backend choking on them. - f"builtin.module(inline, func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))", - "Lowering Onnx backend contract to Linalg-on-Tensors backend contract", - ) - - if verbose: - print("\n====================") - print("TorchFX IR") - print(torch_mod) - backend_legal_ops = [ "aten.flatten.using_ints", "aten.adaptive_avg_pool1d", "aten.unflatten.int", ] option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}" + + # Lower from ONNX to Torch run_pipeline_with_repro_report( torch_mod, - f"builtin.module(torch-lower-to-backend-contract{option_string})", - "Lowering TorchFX IR -> Torch Backend IR", + f"builtin.module(torch-onnx-to-torch-backend-pipeline{option_string})", + "Lowering Onnx Raw IR -> Torch Backend IR", ) + if verbose: + print("\n====================") + print("Torch IR") + print(torch_mod) + return lower_mlir_module(verbose, output_type, torch_mod) diff --git a/test/Dialect/Torch/torch-onnx-to-torch-backend-pipeline.mlir b/test/Dialect/Torch/torch-onnx-to-torch-backend-pipeline.mlir new file mode 100644 index 000000000000..038f5686d6a4 --- /dev/null +++ b/test/Dialect/Torch/torch-onnx-to-torch-backend-pipeline.mlir @@ -0,0 +1,67 @@ +// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-onnx-to-torch-backend-pipeline{backend-legal-ops=aten.flatten.using_ints,aten.unflatten.int})' -split-input-file %s | FileCheck %s + +// CHECK-LABEL: func.func @test_reshape_negative_dim_decompose +func.func @test_reshape_negative_dim_decompose(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[INT6:.+]] = torch.constant.int 6 + // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT6]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.view %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,6,2],f32> + %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> + return %0 : !torch.vtensor<[2,6,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_triu_decompose +func.func @test_triu_decompose(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[ZERO_TENSOR:.+]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT4:.+]] = torch.constant.int 4 + // CHECK: %[[INT5:.+]] = torch.constant.int 5 + // CHECK: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[INT0]], %[[INT4]], %[[INT1]], %[[INT4]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[4],si64> + // CHECK: %[[ARANGE_0:.+]] = torch.aten.arange.start_step %[[INT0]], %[[INT5]], %[[INT1]], %[[INT4]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[5],si64> + // CHECK: %[[UNSQUEEZE:.+]] = torch.aten.unsqueeze %[[ARANGE]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + // CHECK: %[[UNSQUEEZE_0:.+]] = torch.aten.unsqueeze %[[ARANGE_0]], %[[INT0]] : !torch.vtensor<[5],si64>, !torch.int -> !torch.vtensor<[1,5],si64> + // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[UNSQUEEZE]], %[[INT0]], %[[INT1]] : !torch.vtensor<[4,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1],si64> + // CHECK: %[[COND:.+]] = torch.aten.ge.Tensor %[[UNSQUEEZE_0]], %[[ADD]] : !torch.vtensor<[1,5],si64>, !torch.vtensor<[4,1],si64> -> !torch.vtensor<[4,5],i1> + // CHECK: %[[RESULT:.+]] = torch.aten.where.self %[[COND]], %arg0, %[[ZERO_TENSOR]] : !torch.vtensor<[4,5],i1>, !torch.vtensor<[4,5],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[4,5],si64> + %0 = torch.operator "onnx.Trilu"(%arg0) : (!torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> + return %0 : !torch.vtensor<[4,5],si64> +} + +// ----- + +module { +// CHECK-LABEL: func.func @test_scalarize + func.func @test_scalarize(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.11.0"} { + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[ADD:.+]] = torch.aten.flatten.using_ints %arg0, %[[INT2]], %[[INT3]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32> + %0 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[4],si64> + %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__21> : tensor} : () -> !torch.vtensor<[],si64> + %2 = torch.operator "onnx.Gather"(%0, %1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[4],si64> + %4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__22> : tensor} : () -> !torch.vtensor<[],si64> + %5 = torch.operator "onnx.Gather"(%3, %4) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> + %6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> + %7 = torch.operator "onnx.Unsqueeze"(%2, %6) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> + %8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> + %9 = torch.operator "onnx.Unsqueeze"(%5, %8) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> + %10 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_onnx__Concat_3209> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> + %11 = torch.operator "onnx.Concat"(%7, %9, %10) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3],si64> + %12 = torch.operator "onnx.Reshape"(%arg0, %11) : (!torch.vtensor<[?,?,16,64],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> + return %12 : !torch.vtensor<[?,?,?],f32> + } +} + +{-# + dialect_resources: { + builtin: { + __21: "0x080000000000000000000000", + __22: "0x080000000100000000000000", + _onnx__Concat_3209: "0x080000000004000000000000" + } + } +#-}