Migrate passes in TorchConversion to use FunctionOpInterface. (llvm#2935)

This enables better re-use in downstreams that use different func
implementations. It should have no impact on downstreams that don't, except
in opt pipelines that still invoke these passes by their old flags: with
interface passes, explicit pipelines via `--pass-pipeline=` must be used.
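
For example, mirroring the test updates below, an invocation that previously
used a bare pass flag must now schedule the pass explicitly on func.func:

  # old form, anchored on func::FuncOp (no longer usable for these passes)
  torch-mlir-opt %s -torch-unpack-quant-tensor

  # new form, nesting the interface pass in an explicit pipeline
  torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-unpack-quant-tensor))'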
stellaraccident authored Feb 20, 2024
1 parent 135c81a commit 4446fa0
Showing 10 changed files with 18 additions and 15 deletions.
10 changes: 6 additions & 4 deletions include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
@@ -10,7 +10,7 @@
#ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
#define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

@@ -54,16 +54,18 @@ createVerifyStablehloBackendContractPass();

std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();

-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
createFinalizingBackendTypeConversionPass();

// These passes do a one-off conversion of a specific kind of quantized group
// matmul as a prototype. Generalized quantized operation handling will likely
// obviate them, but they are being carried for now in order to unblock progress
// on full integrations. See https://github.com/llvm/torch-mlir/issues/2417 for
// the plan to support a more generalized lowering for these graphs.
-std::unique_ptr<OperationPass<func::FuncOp>> createUnpackQuantTensorPass();
-std::unique_ptr<OperationPass<func::FuncOp>> createConvertCustomQuantOpPass();
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
+createUnpackQuantTensorPass();
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
+createConvertCustomQuantOpPass();

std::unique_ptr<OperationPass<ModuleOp>>
createVerifyLinalgOnTensorsBackendContractPass();
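The practical upshot of the new signatures: the factories now return interface
passes, so a downstream with its own func-like op can schedule them on that op
instead of func.func. A minimal sketch, assuming a hypothetical downstream op
"my.func" that implements FunctionOpInterface:

  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Pass/PassManager.h"
  #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

  using namespace mlir;

  // Run the migrated interface pass on a custom func-like op.
  LogicalResult runFinalization(ModuleOp module) {
    PassManager pm(module.getContext());
    // "my.func" stands in for any registered op implementing
    // FunctionOpInterface; the pass manager verifies schedulability.
    OpPassManager &fnPm = pm.nest("my.func");
    fnPm.addPass(
        torch::TorchConversion::createFinalizingBackendTypeConversionPass());
    return pm.run(module);
  }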
include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td
@@ -22,7 +22,7 @@ def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "Modu
}

def FinalizingBackendTypeConversion
: Pass<"torch-finalizing-backend-type-conversion", "func::FuncOp"> {
: InterfacePass<"torch-finalizing-backend-type-conversion", "mlir::FunctionOpInterface"> {
let summary = "Finalizes a partial conversion to builtin tensors";
let constructor =
"mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass()";
@@ -51,12 +51,12 @@ def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contra

// The following passes are for a one-off conversion of a specific kind of quantized group matmul.
// They should not be included in default lowering flows until further along.
-def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> {
+def UnpackQuantTensor : InterfacePass<"torch-unpack-quant-tensor", "mlir::FunctionOpInterface"> {
let summary = "Unpack quantized int4 tensor from int8 container";
let constructor = "mlir::torch::TorchConversion::createUnpackQuantTensorPass()";
}

-def ConvertCustomQuantOp : Pass<"torch-convert-custom-quant-op", "func::FuncOp"> {
+def ConvertCustomQuantOp : InterfacePass<"torch-convert-custom-quant-op", "mlir::FunctionOpInterface"> {
let summary = "Convert torch custom quant op to linalg";
let constructor = "mlir::torch::TorchConversion::createConvertCustomQuantOpPass()";
}
lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
@@ -115,7 +115,7 @@ static void setupFinalization(ConversionTarget &target,
setupFinalization<OpTy2, OpTys...>(target, patterns, typeConverter);
}

-static void stripTorchAttrs(func::FuncOp func) {
+static void stripTorchAttrs(FunctionOpInterface func) {
bool modified = false;
SmallVector<NamedAttribute> newAttrs;
for (auto attr : func->getDialectAttrs()) {
@@ -173,7 +173,7 @@ struct FinalizingBackendTypeConversionPass
};
} // namespace

-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() {
return std::make_unique<FinalizingBackendTypeConversionPass>();
}
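
Inside the pass the change is mostly mechanical: with
InterfacePass<FunctionOpInterface>, getOperation() returns the interface
rather than a concrete func::FuncOp, which is why helpers like stripTorchAttrs
above now take the interface. A rough sketch of the resulting shape, assuming
the tablegen-generated base class (not the commit's full pass body):

  struct FinalizingBackendTypeConversionPass
      : public FinalizingBackendTypeConversionBase<
            FinalizingBackendTypeConversionPass> {
    void runOnOperation() override {
      // getOperation() yields FunctionOpInterface, so the pass applies to
      // func.func or any downstream op implementing the interface.
      FunctionOpInterface func = getOperation();
      stripTorchAttrs(func);
    }
  };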
lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp
@@ -229,7 +229,7 @@ class ConvertCustomQuantOpPass
};
} // namespace

-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
mlir::torch::TorchConversion::createConvertCustomQuantOpPass() {
return std::make_unique<ConvertCustomQuantOpPass>();
}
2 changes: 1 addition & 1 deletion lib/Dialect/TorchConversion/Transforms/PassDetail.h
@@ -10,7 +10,7 @@
#ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H
#define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
1 change: 1 addition & 0 deletions lib/Dialect/TorchConversion/Transforms/Passes.cpp
@@ -9,6 +9,7 @@

#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp
@@ -137,7 +137,7 @@ class UnpackQuantTensorPass
};
} // namespace

-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
mlir::torch::TorchConversion::createUnpackQuantTensorPass() {
return std::make_unique<UnpackQuantTensorPass>();
}
2 changes: 1 addition & 1 deletion test/Dialect/TorchConversion/convert-custom-quant-op.mlir
@@ -1,4 +1,4 @@
-// RUN: torch-mlir-opt %s -torch-convert-custom-quant-op -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-convert-custom-quant-op))' -split-input-file -verify-diagnostics | FileCheck %s

// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir
@@ -1,4 +1,4 @@
-// RUN: torch-mlir-opt %s -torch-finalizing-backend-type-conversion -split-input-file -verify-diagnostics -allow-unregistered-dialect | FileCheck %s
+// RUN: torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-finalizing-backend-type-conversion))' -split-input-file -verify-diagnostics -allow-unregistered-dialect | FileCheck %s

// This test is largely copied from `finalizing-bufferize` upstream, as it
// covers the same scope.
2 changes: 1 addition & 1 deletion test/Dialect/TorchConversion/unpack-quant-tensor.mlir
@@ -1,4 +1,4 @@
-// RUN: torch-mlir-opt %s -torch-unpack-quant-tensor -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-unpack-quant-tensor))' -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func @forward
func.func @forward(%arg0: !torch.vtensor<[1,1,8],f16>) -> !torch.vtensor<[1,1,8],f16> {
