[TorchToLinalg] Lower aten.cat to tensor.concat (#2650)
This replaces the lowering of aten.cat with tensor.concat, allowing more
efficient handling of concatenations in downstream flows. The refbackend
now applies tensor.concat decomposition patterns that recover the previous
insert_slice-based lowering.
qedawkins authored Dec 15, 2023
1 parent 061af69 commit 030b014
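
As a rough sketch of the change described above (not taken from the patch; the
function name and shapes are illustrative), concatenating two dynamic 2-D f32
operands along dim 0 now converts to a single tensor.concat op, where the
previous lowering materialized the result with tensor.empty and a chain of
tensor.insert_slice ops:

func.func @cat_dim0_sketch(%t0: tensor<?x?xf32>, %t1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // One op now carries the full concatenation semantics; downstream flows
  // decide how (or whether) to materialize it.
  %cat = tensor.concat dim(0) %t0, %t1
      : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  return %cat : tensor<?x?xf32>
}
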
Showing 6 changed files with 79 additions and 45 deletions.
2 changes: 2 additions & 0 deletions include/torch-mlir/RefBackend/Passes.h
@@ -31,6 +31,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass();

std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass();

std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorConcatPass();

std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorPadPass();
} // namespace RefBackend
} // namespace torch
5 changes: 5 additions & 0 deletions include/torch-mlir/RefBackend/Passes.td
@@ -35,6 +35,11 @@ def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> {
let dependentDialects = ["memref::MemRefDialect"];
}

def GeneralizeTensorConcat : Pass<"refback-generalize-tensor-concat", "func::FuncOp"> {
let summary = "Convert tensor.concat to other tensor ops";
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorConcatPass()";
}

def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> {
let summary = "Convert tensor.pad to linalg ops";
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()";
51 changes: 7 additions & 44 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1033,8 +1033,11 @@ class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {

auto outElemType = newResultType.getElementType();
for (size_t i = 0; i < tensors.size(); ++i) {
-tensors[i] = torch_to_linalg::convertTensorToElementType(
-    rewriter, loc, tensors[i], outElemType);
+auto inputType = cast<RankedTensorType>(tensors[i].getType());
+if (inputType.getElementType() != outElemType) {
+  tensors[i] = torch_to_linalg::convertTensorToElementType(
+      rewriter, loc, tensors[i], outElemType);
+}
}

int rank = newResultType.getRank();
@@ -1046,48 +1049,8 @@ class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {
if (!isValidDim(dim, rank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

-SmallVector<Value> offsets, sizes, strides;
-sizes.reserve(rank);
-strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
-offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
-
-for (int i = 0; i < rank; ++i)
-  sizes.push_back(rewriter.createOrFold<tensor::DimOp>(loc, tensors[0], i));
-
-// Calculate the size of the `dim` result dimension by adding the dim size
-// of each tensor together.
-Value resultDimSize = sizes[dim];
-
-Value dimIndex = rewriter.createOrFold<arith::ConstantOp>(
-    loc, rewriter.getIndexAttr(dim));
-for (auto tensor : ArrayRef(tensors).drop_front()) {
-  auto size = rewriter.createOrFold<tensor::DimOp>(loc, tensor, dimIndex);
-  resultDimSize =
-      rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
-}
-sizes[dim] = resultDimSize;
-
-auto toOpFoldResult = [](Value v) -> OpFoldResult {
-  auto op = v.getDefiningOp<arith::ConstantIndexOp>();
-  if (!op)
-    return v;
-  return op.getValue();
-};
-
-Value result = rewriter.create<tensor::EmptyOp>(
-    loc, getAsOpFoldResult(sizes), newResultType.getElementType());
-for (auto tensor : tensors) {
-  SmallVector<Value> sizes = getTensorSizes(rewriter, loc, tensor);
-  result = rewriter.createOrFold<tensor::InsertSliceOp>(
-      loc, tensor, result,
-      llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
-      llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
-      llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
-  offsets[dim] =
-      rewriter.createOrFold<arith::AddIOp>(loc, offsets[dim], sizes[dim]);
-}
-
-rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
+rewriter.replaceOpWithNewOp<tensor::ConcatOp>(op, newResultType, dim,
+                                              tensors);
return success();
}
};
27 changes: 26 additions & 1 deletion lib/RefBackend/RefBackend.cpp
@@ -20,10 +20,12 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Approximation.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
@@ -436,6 +438,29 @@ mlir::torch::RefBackend::createMungeMemrefCopyPass() {
return std::make_unique<MungeMemrefCopy>();
}

+namespace {
+class GeneralizeTensorConcat
+    : public GeneralizeTensorConcatBase<GeneralizeTensorConcat> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<tensor::TensorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    tensor::populateDecomposeTensorConcatPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::RefBackend::createGeneralizeTensorConcatPass() {
+  return std::make_unique<GeneralizeTensorConcat>();
+}
+
namespace {
class GeneralizeTensorPad
: public GeneralizeTensorPadBase<GeneralizeTensorPad> {
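
For reference, the GeneralizeTensorConcat pass added above applies the upstream
tensor.concat decomposition patterns, which rewrite the op back into roughly the
destination-style form the old conversion emitted directly. A minimal sketch of
that expansion for two dynamic 2-D f32 operands concatenated along dim 0 (SSA
names are illustrative, and the exact size computation produced by the patterns
may differ):

func.func @cat_dim0_decomposed(%t0: tensor<?x?xf32>, %t1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // The concatenated dimension is the sum of the operand sizes.
  %rows0 = tensor.dim %t0, %c0 : tensor<?x?xf32>
  %rows1 = tensor.dim %t1, %c0 : tensor<?x?xf32>
  %cols = tensor.dim %t0, %c1 : tensor<?x?xf32>
  %rows = arith.addi %rows0, %rows1 : index
  // Materialize the result, then insert each operand at its running offset.
  %empty = tensor.empty(%rows, %cols) : tensor<?x?xf32>
  %ins0 = tensor.insert_slice %t0 into %empty[0, 0] [%rows0, %cols] [1, 1]
      : tensor<?x?xf32> into tensor<?x?xf32>
  %ins1 = tensor.insert_slice %t1 into %ins0[%rows0, 0] [%rows1, %cols] [1, 1]
      : tensor<?x?xf32> into tensor<?x?xf32>
  return %ins1 : tensor<?x?xf32>
}
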
@@ -123,6 +123,7 @@ def invoke(*args):

LOWERING_PIPELINE = "builtin.module(" + ",".join([
"func.func(refback-generalize-tensor-pad)",
"func.func(refback-generalize-tensor-concat)",
# Apply some optimizations. It would be great if MLIR had more useful
# optimizations that worked out of the box here.
# Note: When measured, this doesn't seem to actually help that much
38 changes: 38 additions & 0 deletions test/Conversion/TorchToLinalg/basic.mlir
@@ -287,3 +287,41 @@ func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f16> {
%0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16>
return %0 : !torch.vtensor<[?,?],f16>
}

// -----

// CHECK-LABEL: func.func @torch.aten.cat$convert(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list<vtensor>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[T3:.*]] = linalg.generic {{.*}} ins(%[[T2]] : tensor<?x?xi32>) outs(%{{.*}}: tensor<?x?xf32>)
// CHECK: %[[T4:.*]] = tensor.concat dim(0) %[[T1]], %[[T3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
%int0 = torch.constant.int 0
%0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list<vtensor>
%1 = torch.aten.cat %0, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?],f32>
return %1 : !torch.vtensor<[?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.cat(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %int0 = torch.constant.int 0
// CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor>
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = tensor.concat dim(0) %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%int0 = torch.constant.int 0
%0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor>
%1 = torch.aten.cat %0, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?],f32>
return %1 : !torch.vtensor<[?,?],f32>
}
