From 4446fa00d8258311867496fc79d0b1dddd22a972 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 20 Feb 2024 08:54:02 -0800 Subject: [PATCH] Migrate passes in TorchConversion to use FunctionOpInterface. (#2935) This enables better re-use in downstreams which use different func implementations and should have no impact on those that don't except in opt pipelines if using the old form. With interfaces, explicit pipelines via `--pass-pipeline=` must be used. --- .../Dialect/TorchConversion/Transforms/Passes.h | 10 ++++++---- .../Dialect/TorchConversion/Transforms/Passes.td | 6 +++--- .../Transforms/BackendTypeConversionPasses.cpp | 4 ++-- .../Transforms/ConvertCustomQuantOp.cpp | 2 +- lib/Dialect/TorchConversion/Transforms/PassDetail.h | 2 +- lib/Dialect/TorchConversion/Transforms/Passes.cpp | 1 + .../TorchConversion/Transforms/UnpackQuantTensor.cpp | 2 +- .../TorchConversion/convert-custom-quant-op.mlir | 2 +- .../finalizing-backend-type-conversion.mlir | 2 +- test/Dialect/TorchConversion/unpack-quant-tensor.mlir | 2 +- 10 files changed, 18 insertions(+), 15 deletions(-) diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index d762bd840f7f..2f70cf990219 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -10,7 +10,7 @@ #ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H #define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H -#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -54,7 +54,7 @@ createVerifyStablehloBackendContractPass(); std::unique_ptr> createFuncBackendTypeConversionPass(); -std::unique_ptr> +std::unique_ptr> createFinalizingBackendTypeConversionPass(); // These passes do a one-off conversion of a specific kind of quantized group @@ -62,8 +62,10 @@ createFinalizingBackendTypeConversionPass(); // obviate them but that are being carried for now in order to unblock progress // on full integrations. See https://github.com/llvm/torch-mlir/issues/2417 for // the plan to support a more generalized lowering for these graphs. -std::unique_ptr> createUnpackQuantTensorPass(); -std::unique_ptr> createConvertCustomQuantOpPass(); +std::unique_ptr> +createUnpackQuantTensorPass(); +std::unique_ptr> +createConvertCustomQuantOpPass(); std::unique_ptr> createVerifyLinalgOnTensorsBackendContractPass(); diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td index 4d3e16a81c5c..73654c6f8034 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td @@ -22,7 +22,7 @@ def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "Modu } def FinalizingBackendTypeConversion - : Pass<"torch-finalizing-backend-type-conversion", "func::FuncOp"> { + : InterfacePass<"torch-finalizing-backend-type-conversion", "mlir::FunctionOpInterface"> { let summary = "Finalizes a partial conversion to builtin tensors"; let constructor = "mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass()"; @@ -51,12 +51,12 @@ def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contra // The following passes are for a one-off conversion of a specific kind of quantized group matmul. // They should not be included in default lowering flows until further along. -def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> { +def UnpackQuantTensor : InterfacePass<"torch-unpack-quant-tensor", "mlir::FunctionOpInterface"> { let summary = "Unpack quantized int4 tensor from int8 containter"; let constructor = "mlir::torch::TorchConversion::createUnpackQuantTensorPass()"; } -def ConvertCustomQuantOp : Pass<"torch-convert-custom-quant-op", "func::FuncOp"> { +def ConvertCustomQuantOp : InterfacePass<"torch-convert-custom-quant-op", "mlir::FunctionOpInterface"> { let summary = "Convert torch custom quant op to linalg"; let constructor = "mlir::torch::TorchConversion::createConvertCustomQuantOpPass()"; } diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp index 5dd3d778f8f4..896dd9577617 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp @@ -115,7 +115,7 @@ static void setupFinalization(ConversionTarget &target, setupFinalization(target, patterns, typeConverter); } -static void stripTorchAttrs(func::FuncOp func) { +static void stripTorchAttrs(FunctionOpInterface func) { bool modified = false; SmallVector newAttrs; for (auto attr : func->getDialectAttrs()) { @@ -173,7 +173,7 @@ struct FinalizingBackendTypeConversionPass }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() { return std::make_unique(); } diff --git a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp index 514d05234486..7bcb67b17c61 100644 --- a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp +++ b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp @@ -229,7 +229,7 @@ class ConvertCustomQuantOpPass }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::TorchConversion::createConvertCustomQuantOpPass() { return std::make_unique(); } diff --git a/lib/Dialect/TorchConversion/Transforms/PassDetail.h b/lib/Dialect/TorchConversion/Transforms/PassDetail.h index 224ad8e2d89a..cb80ebd89a3c 100644 --- a/lib/Dialect/TorchConversion/Transforms/PassDetail.h +++ b/lib/Dialect/TorchConversion/Transforms/PassDetail.h @@ -10,7 +10,7 @@ #ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H #define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H -#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 673d7083f585..9ff447371a76 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -9,6 +9,7 @@ #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" diff --git a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp index 25f325399f12..064c87f6e6a8 100644 --- a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp +++ b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp @@ -137,7 +137,7 @@ class UnpackQuantTensorPass }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::TorchConversion::createUnpackQuantTensorPass() { return std::make_unique(); } diff --git a/test/Dialect/TorchConversion/convert-custom-quant-op.mlir b/test/Dialect/TorchConversion/convert-custom-quant-op.mlir index 4f72f24e8868..7aca3551cfc2 100644 --- a/test/Dialect/TorchConversion/convert-custom-quant-op.mlir +++ b/test/Dialect/TorchConversion/convert-custom-quant-op.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt %s -torch-convert-custom-quant-op -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-convert-custom-quant-op))' -split-input-file -verify-diagnostics | FileCheck %s // CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)> diff --git a/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir b/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir index 46f80c06b4ce..57077a723ada 100644 --- a/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir +++ b/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt %s -torch-finalizing-backend-type-conversion -split-input-file -verify-diagnostics -allow-unregistered-dialect | FileCheck %s +// RUN: torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-finalizing-backend-type-conversion))' -split-input-file -verify-diagnostics -allow-unregistered-dialect | FileCheck %s // This test is largely copied from `finalizing-bufferize` upstream, as it // covers the same scope. diff --git a/test/Dialect/TorchConversion/unpack-quant-tensor.mlir b/test/Dialect/TorchConversion/unpack-quant-tensor.mlir index 0ca64ae09397..8fa1a775b66d 100644 --- a/test/Dialect/TorchConversion/unpack-quant-tensor.mlir +++ b/test/Dialect/TorchConversion/unpack-quant-tensor.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt %s -torch-unpack-quant-tensor -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-unpack-quant-tensor))' -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func @forward func.func @forward(%arg0: !torch.vtensor<[1,1,8],f16>) -> !torch.vtensor<[1,1,8],f16> {