[WIP] Onnx llama7b Folders #2860

Closed · wants to merge 4 commits
5 changes: 5 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8582,6 +8582,7 @@ def Torch_AtenTensorOp : Torch_Op<"aten.tensor", [
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenTensorBoolOp : Torch_Op<"aten.tensor.bool", [
@@ -8979,6 +8980,7 @@ def Torch_AtenArangeStartStepOp : Torch_Op<"aten.arange.start_step", [
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
let hasCanonicalizer = 1;
}

def Torch_AtenArangeStartOutOp : Torch_Op<"aten.arange.start_out", [
@@ -9784,6 +9786,7 @@ def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasFolder = 1;
}

def Torch_Aten_IndexPutImplOp : Torch_Op<"aten._index_put_impl", [
@@ -12558,6 +12561,7 @@ def Torch_AtenSortOp : Torch_Op<"aten.sort", [
printDefaultTorchOp(printer, *this, 3, 2);
}
}];
let hasFolder = 1;
}

def Torch_AtenSplitTensorOp : Torch_Op<"aten.split.Tensor", [
@@ -14824,6 +14828,7 @@ def Torch_PrimNumToTensorScalarOp : Torch_Op<"prim.NumToTensor.Scalar", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}

def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [
@@ -26,6 +26,10 @@ namespace TorchConversion {
/// linalg-on-tensors backend contract.
void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);

/// Creates a pipeline that lowers from the onnx backend contract to the
/// linalg-on-tensors backend contract.
void createOnnxBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);

/// Creates a pipeline that lowers from the torch backend contract to the
/// TOSA backend contract.
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
19 changes: 19 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -1561,6 +1561,25 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(allowzero, "allowzero", 0))
return failure();
// If the result shape is static then we can create a result shape list
// directly using the result shape values (integers).
if (resultType.hasSizes()) {
if (resultType.areAllSizesKnown()) {
SmallVector<Value> resultShape;
for (int64_t size : resultType.getSizes()) {
resultShape.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(size)));
}
Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
resultShape);
rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(
binder.op, resultType, data, resultShapeList);
return success();
}
}
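// Otherwise, fall back to building the shape list from the runtime shape
// tensor operand.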
Torch::BaseTensorType shapeType =
shape.getType().cast<Torch::BaseTensorType>();
SmallVector<Value> dimList;
191 changes: 179 additions & 12 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -6,6 +6,12 @@
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include <cstdint>
#define DEBUG_TYPE "torch-mlir-torch-dialect"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@@ -737,8 +743,7 @@ OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
if (getOperand(0).getType() != getResult().getType())
return nullptr;
if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
return getOperand(0);
return getOperand(0);
}
return nullptr;
}
@@ -1710,6 +1715,40 @@ void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}

//===----------------------------------------------------------------------===//
// AtenSortOp
//===----------------------------------------------------------------------===//

LogicalResult AtenSortOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
auto operand = getSelf();
auto operandTy = dyn_cast<ValueTensorType>(operand.getType());
auto iTTy = cast<ValueTensorType>(getResult(1).getType());
auto indicesTy = iTTy.toBuiltinTensor().clone(iTTy.getDtype());

if (!operandTy || !operandTy.hasSizes())
return failure();
if (!indicesTy.hasStaticShape())
return failure();

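// Sorting is the identity when the sort dimension has size 1 (or when every
// dimension has size 1): the values are returned unchanged and the indices
// are all zero.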
bool unaryDim = false;
IntegerAttr dimAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
if (dimAttr) {
unaryDim = operandTy.getSizes()[dimAttr.getInt()] == 1;
}

OpBuilder b(getContext());
if (unaryDim || llvm::all_of(operandTy.getSizes(),
[](int64_t dim) { return dim == 1; })) {
results.push_back(operand);
results.push_back(DenseElementsAttr::get(
indicesTy, b.getZeroAttr(indicesTy.getElementType())));
return success();
}

return failure();
}

//===----------------------------------------------------------------------===//
// NonValueTensorLiteralOp
//===----------------------------------------------------------------------===//
@@ -2506,29 +2545,125 @@ OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//

OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
int64_t start, end, step;
if (matchPattern(getStart(), m_TorchConstantInt(&start)) &&
matchPattern(getEnd(), m_TorchConstantInt(&end)) &&
matchPattern(getStep(), m_TorchConstantInt(&step)) && step == 1 &&
start == 0 && end == std::numeric_limits<int64_t>::max())
DenseElementsAttr input =
dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
IntegerAttr start = dyn_cast_or_null<IntegerAttr>(adaptor.getStart());
IntegerAttr end = dyn_cast_or_null<IntegerAttr>(adaptor.getEnd());
IntegerAttr step = dyn_cast_or_null<IntegerAttr>(adaptor.getStep());
IntegerAttr dim = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());

if (start && end && step && step.getInt() == 1 && start.getInt() == 0 &&
end.getInt() == std::numeric_limits<int64_t>::max())
return getOperand(0);

auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
if (inType != outType)
return nullptr;
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
auto outType = getResult().getType().dyn_cast<ValueTensorType>();
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
!inType.hasDtype() || !outType.hasDtype() ||
inType.getDtype() != outType.getDtype())
return nullptr;

if (inType.getSizes().size() != outType.getSizes().size() ||
!inType.areAllSizesKnown() || !outType.areAllSizesKnown())
return nullptr;

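// Slicing a splat constant yields a splat of the output shape, so only the
// attribute's type needs to be rebuilt.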
if (input && input.isSplat())
return DenseElementsAttr::get(
outType.toBuiltinTensor().clone(inType.getDtype()),
input.getSplatValue<Attribute>());

// If the output is a single value we can index into a constant input and grab
// that single value:
if (input && start && dim &&
llvm::all_of(outType.getSizes(), [](int64_t dim) { return dim == 1; })) {
bool unaryNonDim = true;
int64_t dimInt = dim.getInt();
for (int i = 0, s = inType.getSizes().size(); i < s; ++i) {
unaryNonDim &= inType.getSizes()[i] == 1 || i == dimInt;
}
if (unaryNonDim) {
Attribute value = input.getValues<Attribute>()[start.getInt()];
return DenseElementsAttr::get(
outType.toBuiltinTensor().clone(inType.getDtype()), value);
}
}

// If the input and output shapes are the same we can just fold:
for (size_t i = 0; i < inType.getSizes().size(); ++i) {
if (inType.getSizes()[i] != outType.getSizes()[i])
return nullptr;
}
return getOperand(0);
}

//===----------------------------------------------------------------------===//
// AtenIndexSelectOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
DenseElementsAttr attr;
if (matchPattern(getOperand(0), m_Constant(&attr))) {
// If the input is a constant tensor of rank 1 whose single dimension has
// size 1, index_select necessarily returns that same tensor, so the op can
// be folded to its input.
Value input = getOperand(0);
ValueTensorType inputType = input.getType().cast<ValueTensorType>();
std::optional<unsigned> inputRank = getTensorRank(input);
if (!inputRank || *inputRank != 1 || !inputType.hasSizes())
return nullptr;
SmallVector<int64_t> inputShape(inputType.getSizes());
// The input should be of rank 1 and the only shape dim should also be 1.
if (inputShape.size() != 1 || inputShape[0] != 1)
return nullptr;
return input;
}
return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenArangeStartStepOp
//===----------------------------------------------------------------------===//

void AtenArangeStartStepOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenArangeStartStepOp op, PatternRewriter &rewriter) {
// If the result of aten.arange.start_step is a rank-1 tensor whose single
// dimension has size 1, it contains exactly one element, and that element
// equals `start`, so the op can be replaced with a constant tensor.
Value result = op.getResult();
ValueTensorType resultType = result.getType().cast<ValueTensorType>();
std::optional<unsigned> resultRank = getTensorRank(result);
if (!resultRank || *resultRank != 1 || !resultType.hasSizes())
return failure();
SmallVector<int64_t> resultShape(resultType.getSizes());
// The result should be of rank 1 and the only shape dim should also be 1.
if (resultShape.size() != 1 || resultShape[0] != 1)
return failure();

if (!resultType.hasDtype() || !isa<mlir::IntegerType>(resultType.getDtype()))
return failure();

int64_t start, step, end;
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
return failure();
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
return failure();
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
return failure();

auto elementsAttr = DenseElementsAttr::get(
RankedTensorType::get(
{1}, IntegerType::get(op->getContext(), 64, IntegerType::Signed)),
{start});

rewriter.replaceOpWithNewOp<ValueTensorLiteralOp>(op, resultType,
elementsAttr);
return success();
});
}
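
As a quick sanity check on when the pattern above can fire: the element count of arange(start, end, step) is ceil((end - start) / step), and the rewrite only applies when the statically known result shape is [1], in which case the single element is start. A minimal standalone C++ sketch of that arithmetic (illustrative only, not part of this patch):

#include <cassert>
#include <cstdint>

// Number of elements produced by arange(start, end, step), i.e.
// ceil((end - start) / step) clamped at zero; `step` must be non-zero.
int64_t arangeNumElements(int64_t start, int64_t end, int64_t step) {
  int64_t span = end - start;
  int64_t len = (span + step + (step > 0 ? -1 : 1)) / step;
  return len < 0 ? 0 : len;
}

int main() {
  assert(arangeNumElements(5, 6, 1) == 1);  // canonicalizes to the constant tensor [5]
  assert(arangeNumElements(0, 10, 2) == 5); // more than one element; pattern does not apply
  assert(arangeNumElements(3, 1, -2) == 1); // canonicalizes to the constant tensor [3]
  return 0;
}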

//===----------------------------------------------------------------------===//
// AtenMulIntOp
//===----------------------------------------------------------------------===//
@@ -2832,6 +2967,38 @@ OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) {
return nullptr;
}

if (auto tensorOp = dyn_cast_or_null<AtenTensorOp>(getOperand().getDefiningOp())) {
auto dataList = tensorOp.getData().getDefiningOp<PrimListConstructOp>();
if (!dataList)
return nullptr;
if (dataList.getNumOperands() != 1)
return nullptr;

int64_t dim;
if (!matchPattern(dataList->getOperands()[0], m_TorchConstantInt(&dim)))
return nullptr;
return getI64IntegerAttr(getContext(), dim);
}
return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
if (getOperation()->use_empty())
getOperation()->erase();
return nullptr;
}

//===----------------------------------------------------------------------===//
// PrimNumToTensorScalarOp
//===----------------------------------------------------------------------===//

OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) {
if (getOperation()->use_empty())
getOperation()->erase();
return nullptr;
}

15 changes: 14 additions & 1 deletion lib/Dialect/TorchConversion/Transforms/Passes.cpp
@@ -16,6 +16,7 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
@@ -48,7 +49,11 @@ void mlir::torch::registerTorchConversionPasses() {
"Pipeline lowering torch backend contract to linalg-on-tensors backend "
"contract.",
TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);

mlir::PassPipelineRegistration<>(
"onnx-backend-to-linalg-on-tensors-backend-pipeline",
"Pipeline lowering onnx backend contract to linalg-on-tensors backend "
"contract.",
TorchConversion::createOnnxBackendToLinalgOnTensorsBackendPipeline);
mlir::PassPipelineRegistration<>(
"torch-backend-to-tosa-backend-pipeline",
"Pipeline lowering torch backend contract to TOSA backend "
@@ -103,6 +108,14 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass());
}

void TorchConversion::createOnnxBackendToLinalgOnTensorsBackendPipeline(
OpPassManager &pm) {
pm.addNestedPass<func::FuncOp>(onnx_c::createTorchOnnxToTorchPass());
Torch::TorchLoweringPipelineOptions options;
Torch::createTorchFunctionToTorchBackendPipeline(pm, options);
createTorchBackendToLinalgOnTensorsBackendPipeline(pm);
}
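
As a usage note, the new entry point can also be driven programmatically; a minimal sketch, assuming `module` is an mlir::ModuleOp produced by the ONNX importer (the surrounding setup is illustrative, not part of this patch):

// Build and run the ONNX -> linalg-on-tensors pipeline on `module`.
mlir::PassManager pm(module->getContext());
mlir::torch::TorchConversion::createOnnxBackendToLinalgOnTensorsBackendPipeline(pm);
if (mlir::failed(pm.run(module)))
  llvm::report_fatal_error("onnx-backend-to-linalg-on-tensors-backend-pipeline failed");

The registration added above should also expose the same pipeline to torch-mlir-opt as --onnx-backend-to-linalg-on-tensors-backend-pipeline.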

void TorchConversion::createTorchBackendToTosaBackendPipeline(
OpPassManager &pm) {
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
@@ -570,7 +570,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)")
emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)", has_folder=True)
emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)")
emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
@@ -586,7 +586,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)", has_canonicalizer=True)
emit("aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)")
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)")
@@ -616,7 +616,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_folder=True)
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
emit("aten::item : (Tensor) -> (Scalar)", has_folder=True)
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
@@ -728,7 +728,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::ne.int_list : (int[], int[]) -> (bool)")
emit("aten::any.bool : (bool[]) -> (bool)", has_folder=True)
emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True)
emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)")
emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)", has_folder=True)
emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")
emit("aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])")
emit("aten::unbind.int : (Tensor, int) -> (Tensor[])")
@@ -838,7 +838,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("prim::device : (Tensor) -> (Device)", has_canonicalizer=True)
emit("prim::dtype : (Tensor) -> (int)", has_folder=True)
emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)")
emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)", has_folder=True)
emit("prim::min.self_int : (int[]) -> (int)", has_folder=True)
emit("prim::min.int : (int, int) -> (int)", has_folder=True)
emit("prim::max.self_int : (int[]) -> (int)")