diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index d762bd840f7f..2f70cf990219 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -10,7 +10,7 @@ #ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H #define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H -#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -54,7 +54,7 @@ createVerifyStablehloBackendContractPass(); std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass(); -std::unique_ptr<OperationPass<func::FuncOp>> +std::unique_ptr<InterfacePass<FunctionOpInterface>> createFinalizingBackendTypeConversionPass(); // These passes do a one-off conversion of a specific kind of quantized group @@ -62,8 +62,10 @@ createFinalizingBackendTypeConversionPass(); // obviate them but that are being carried for now in order to unblock progress // on full integrations. See https://github.com/llvm/torch-mlir/issues/2417 for // the plan to support a more generalized lowering for these graphs.
-std::unique_ptr<OperationPass<func::FuncOp>> createUnpackQuantTensorPass(); -std::unique_ptr<OperationPass<func::FuncOp>> createConvertCustomQuantOpPass(); +std::unique_ptr<InterfacePass<FunctionOpInterface>> +createUnpackQuantTensorPass(); +std::unique_ptr<InterfacePass<FunctionOpInterface>> +createConvertCustomQuantOpPass(); std::unique_ptr<OperationPass<ModuleOp>> createVerifyLinalgOnTensorsBackendContractPass(); diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td index 4d3e16a81c5c..73654c6f8034 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td @@ -22,7 +22,7 @@ def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "ModuleOp"> { } def FinalizingBackendTypeConversion - : Pass<"torch-finalizing-backend-type-conversion", "func::FuncOp"> { + : InterfacePass<"torch-finalizing-backend-type-conversion", "mlir::FunctionOpInterface"> { let summary = "Finalizes a partial conversion to builtin tensors"; let constructor = "mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass()"; @@ -51,12 +51,12 @@ def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> { // The following passes are for a one-off conversion of a specific kind of quantized group matmul. // They should not be included in default lowering flows until further along.
-def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> { +def UnpackQuantTensor : InterfacePass<"torch-unpack-quant-tensor", "mlir::FunctionOpInterface"> { let summary = "Unpack quantized int4 tensor from int8 containter"; let constructor = "mlir::torch::TorchConversion::createUnpackQuantTensorPass()"; } -def ConvertCustomQuantOp : Pass<"torch-convert-custom-quant-op", "func::FuncOp"> { +def ConvertCustomQuantOp : InterfacePass<"torch-convert-custom-quant-op", "mlir::FunctionOpInterface"> { let summary = "Convert torch custom quant op to linalg"; let constructor = "mlir::torch::TorchConversion::createConvertCustomQuantOpPass()"; } diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp index 5dd3d778f8f4..896dd9577617 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp @@ -115,7 +115,7 @@ static void setupFinalization(ConversionTarget &target, setupFinalization<OpTy2, OpTys...>(target, patterns, typeConverter); } -static void stripTorchAttrs(func::FuncOp func) { +static void stripTorchAttrs(FunctionOpInterface func) { bool modified = false; SmallVector<NamedAttribute> newAttrs; for (auto attr : func->getDialectAttrs()) { @@ -173,7 +173,7 @@ struct FinalizingBackendTypeConversionPass }; } // namespace -std::unique_ptr<OperationPass<func::FuncOp>> +std::unique_ptr<InterfacePass<FunctionOpInterface>> mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() { return std::make_unique<FinalizingBackendTypeConversionPass>(); } diff --git a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp index 514d05234486..7bcb67b17c61 100644 --- a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp +++ b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp @@ -229,7 +229,7 @@ class ConvertCustomQuantOpPass }; } // namespace -std::unique_ptr<OperationPass<func::FuncOp>> +std::unique_ptr<InterfacePass<FunctionOpInterface>>
mlir::torch::TorchConversion::createConvertCustomQuantOpPass() { return std::make_unique<ConvertCustomQuantOpPass>(); } diff --git a/lib/Dialect/TorchConversion/Transforms/PassDetail.h b/lib/Dialect/TorchConversion/Transforms/PassDetail.h index 224ad8e2d89a..cb80ebd89a3c 100644 --- a/lib/Dialect/TorchConversion/Transforms/PassDetail.h +++ b/lib/Dialect/TorchConversion/Transforms/PassDetail.h @@ -10,7 +10,7 @@ #ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H #define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H -#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 673d7083f585..9ff447371a76 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -9,6 +9,7 @@ #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" diff --git a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp index 25f325399f12..064c87f6e6a8 100644 --- a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp +++ b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp @@ -137,7 +137,7 @@ class UnpackQuantTensorPass }; } // namespace -std::unique_ptr<OperationPass<func::FuncOp>> +std::unique_ptr<InterfacePass<FunctionOpInterface>> mlir::torch::TorchConversion::createUnpackQuantTensorPass() { return std::make_unique<UnpackQuantTensorPass>(); } diff --git a/test/Dialect/TorchConversion/convert-custom-quant-op.mlir b/test/Dialect/TorchConversion/convert-custom-quant-op.mlir index 4f72f24e8868..7aca3551cfc2 100644 --- a/test/Dialect/TorchConversion/convert-custom-quant-op.mlir +++
b/test/Dialect/TorchConversion/convert-custom-quant-op.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt %s -torch-convert-custom-quant-op -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-convert-custom-quant-op))' -split-input-file -verify-diagnostics | FileCheck %s // CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)> diff --git a/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir b/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir index 46f80c06b4ce..57077a723ada 100644 --- a/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir +++ b/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt %s -torch-finalizing-backend-type-conversion -split-input-file -verify-diagnostics -allow-unregistered-dialect | FileCheck %s +// RUN: torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-finalizing-backend-type-conversion))' -split-input-file -verify-diagnostics -allow-unregistered-dialect | FileCheck %s // This test is largely copied from `finalizing-bufferize` upstream, as it // covers the same scope. diff --git a/test/Dialect/TorchConversion/unpack-quant-tensor.mlir b/test/Dialect/TorchConversion/unpack-quant-tensor.mlir index 0ca64ae09397..8fa1a775b66d 100644 --- a/test/Dialect/TorchConversion/unpack-quant-tensor.mlir +++ b/test/Dialect/TorchConversion/unpack-quant-tensor.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt %s -torch-unpack-quant-tensor -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-unpack-quant-tensor))' -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func @forward func.func @forward(%arg0: !torch.vtensor<[1,1,8],f16>) -> !torch.vtensor<[1,1,8],f16> {