From 030b0140d45559743dff85573ca00ba10cce7a5a Mon Sep 17 00:00:00 2001
From: Quinn Dawkins
Date: Fri, 15 Dec 2023 15:45:32 -0500
Subject: [PATCH] [TorchToLinalg] Lower aten.cat to tensor.concat (#2650)

This replaces the lowering of aten.cat with tensor.concat, allowing more
efficient handling of concatenations in downstream flows. The refbackend
populates concat decomposition patterns that can be used to recover the
previous lowering.
---
 include/torch-mlir/RefBackend/Passes.h        |  2 +
 include/torch-mlir/RefBackend/Passes.td       |  5 ++
 lib/Conversion/TorchToLinalg/DataMovement.cpp | 51 +++----------------
 lib/RefBackend/RefBackend.cpp                 | 27 +++++++++-
 .../linalg_on_tensors_backends/refbackend.py  |  1 +
 test/Conversion/TorchToLinalg/basic.mlir      | 38 ++++++++++++++
 6 files changed, 79 insertions(+), 45 deletions(-)

diff --git a/include/torch-mlir/RefBackend/Passes.h b/include/torch-mlir/RefBackend/Passes.h
index 8f1b2b525a22..be5e43a1e63c 100644
--- a/include/torch-mlir/RefBackend/Passes.h
+++ b/include/torch-mlir/RefBackend/Passes.h
@@ -31,6 +31,8 @@ std::unique_ptr<OperationPass<func::FuncOp>> createMLProgramBufferizePass();
 
 std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass();
 
+std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorConcatPass();
+
 std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorPadPass();
 } // namespace RefBackend
 } // namespace torch
diff --git a/include/torch-mlir/RefBackend/Passes.td b/include/torch-mlir/RefBackend/Passes.td
index 12d182e49e3a..3d8b7fd41b1b 100644
--- a/include/torch-mlir/RefBackend/Passes.td
+++ b/include/torch-mlir/RefBackend/Passes.td
@@ -35,6 +35,11 @@ def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> {
   let dependentDialects = ["memref::MemRefDialect"];
 }
 
+def GeneralizeTensorConcat : Pass<"refback-generalize-tensor-concat", "func::FuncOp"> {
+  let summary = "Convert tensor.concat to other tensor ops";
+  let constructor = "mlir::torch::RefBackend::createGeneralizeTensorConcatPass()";
+}
+
 def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> {
   let summary = "Convert tensor.pad to linalg ops";
   let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()";
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 4eb02215a8bf..dae387422b52 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1033,8 +1033,11 @@ class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {
 
     auto outElemType = newResultType.getElementType();
     for (size_t i = 0; i < tensors.size(); ++i) {
-      tensors[i] = torch_to_linalg::convertTensorToElementType(
-          rewriter, loc, tensors[i], outElemType);
+      auto inputType = cast<RankedTensorType>(tensors[i].getType());
+      if (inputType.getElementType() != outElemType) {
+        tensors[i] = torch_to_linalg::convertTensorToElementType(
+            rewriter, loc, tensors[i], outElemType);
+      }
     }
 
     int rank = newResultType.getRank();
@@ -1046,48 +1049,8 @@
     if (!isValidDim(dim, rank))
       return rewriter.notifyMatchFailure(op, "dim is statically invalid");
 
-    SmallVector<Value> offsets, sizes, strides;
-    sizes.reserve(rank);
-    strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
-    offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
-
-    for (int i = 0; i < rank; ++i)
-      sizes.push_back(rewriter.createOrFold<tensor::DimOp>(loc, tensors[0], i));
-
-    // Calculate the size of the `dim` result dimension by adding the dim size
-    // of each tensor together.
-    Value resultDimSize = sizes[dim];
-
-    Value dimIndex = rewriter.createOrFold<arith::ConstantOp>(
-        loc, rewriter.getIndexAttr(dim));
-    for (auto tensor : ArrayRef(tensors).drop_front()) {
-      auto size = rewriter.createOrFold<tensor::DimOp>(loc, tensor, dimIndex);
-      resultDimSize =
-          rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
-    }
-    sizes[dim] = resultDimSize;
-
-    auto toOpFoldResult = [](Value v) -> OpFoldResult {
-      auto op = v.getDefiningOp<arith::ConstantIndexOp>();
-      if (!op)
-        return v;
-      return op.getValue();
-    };
-
-    Value result = rewriter.create<tensor::EmptyOp>(
-        loc, getAsOpFoldResult(sizes), newResultType.getElementType());
-    for (auto tensor : tensors) {
-      SmallVector<Value> sizes = getTensorSizes(rewriter, loc, tensor);
-      result = rewriter.createOrFold<tensor::InsertSliceOp>(
-          loc, tensor, result,
-          llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
-          llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
-          llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
-      offsets[dim] =
-          rewriter.createOrFold<arith::AddIOp>(loc, offsets[dim], sizes[dim]);
-    }
-
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
+    rewriter.replaceOpWithNewOp<tensor::ConcatOp>(op, newResultType, dim,
+                                                  tensors);
     return success();
   }
 };
diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp
index 481bdf3426d8..4ada196e944c 100644
--- a/lib/RefBackend/RefBackend.cpp
+++ b/lib/RefBackend/RefBackend.cpp
@@ -20,10 +20,12 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Approximation.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
@@ -436,6 +438,29 @@ mlir::torch::RefBackend::createMungeMemrefCopyPass() {
   return std::make_unique<MungeMemrefCopy>();
 }
 
+namespace {
+class GeneralizeTensorConcat
+    : public GeneralizeTensorConcatBase<GeneralizeTensorConcat> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<tensor::TensorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    tensor::populateDecomposeTensorConcatPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::RefBackend::createGeneralizeTensorConcatPass() {
+  return std::make_unique<GeneralizeTensorConcat>();
+}
+
 namespace {
 class GeneralizeTensorPad
     : public GeneralizeTensorPadBase<GeneralizeTensorPad> {
diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
index 1b9dbb0d2c51..266459e00b0c 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
@@ -123,6 +123,7 @@ def invoke(*args):
 
 LOWERING_PIPELINE = "builtin.module(" + ",".join([
     "func.func(refback-generalize-tensor-pad)",
+    "func.func(refback-generalize-tensor-concat)",
     # Apply some optimizations. It would be great if MLIR had more useful
     # optimizations that worked out of the box here.
     # Note: When measured, this doesn't seem to actually help that much
diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir
index eba7546655e9..0aaca941b0d9 100644
--- a/test/Conversion/TorchToLinalg/basic.mlir
+++ b/test/Conversion/TorchToLinalg/basic.mlir
@@ -287,3 +287,41 @@ func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f16> {
   %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16>
   return %0 : !torch.vtensor<[?,?],f16>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.cat$convert(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:  %[[INT0:.*]] = torch.constant.int 0
+// CHECK:  %[[T0:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list<vtensor>
+// CHECK:  %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:  %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK:  %[[T3:.*]] = linalg.generic {{.*}} ins(%[[T2]] : tensor<?x?xi32>) outs(%{{.*}}: tensor<?x?xf32>)
+// CHECK:  %[[T4:.*]] = tensor.concat dim(0) %[[T1]], %[[T3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:  %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:  return %[[T5]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
+  %int0 = torch.constant.int 0
+  %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list<vtensor>
+  %1 = torch.aten.cat %0, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %1 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.cat(
+// CHECK-SAME:    %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME:    %[[ARG_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:  %int0 = torch.constant.int 0
+// CHECK:  %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor>
+// CHECK:  %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:  %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:  %[[VAL_3:.*]] = tensor.concat dim(0) %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:  %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:  return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int0 = torch.constant.int 0
+  %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor>
+  %1 = torch.aten.cat %0, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %1 : !torch.vtensor<[?,?],f32>
+}
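
Note on the refbackend side of this change: refback-generalize-tensor-concat
defers to the upstream tensor::populateDecomposeTensorConcatPatterns rather
than reimplementing the insert_slice-based lowering deleted from
ConvertAtenCatOp. Below is a minimal sketch of the decomposition those
patterns are expected to perform for the dynamic rank-2 case from the tests
above; %a and %b are hypothetical operand names, and the exact op sequence
is determined by the upstream patterns, not by this patch:

  // Sketch only: %a and %b stand in for the converted cat operands.
  // Input produced by the new ConvertAtenCatOp lowering:
  %cat = tensor.concat dim(0) %a, %b
      : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>

  // Expected shape of the decomposition: sum the sizes along dim 0,
  // materialize a destination, and insert each operand at its running
  // offset along the concatenated dimension.
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0a = tensor.dim %a, %c0 : tensor<?x?xf32>
  %d0b = tensor.dim %b, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %a, %c1 : tensor<?x?xf32>
  %sum = arith.addi %d0a, %d0b : index
  %dest = tensor.empty(%sum, %d1) : tensor<?x?xf32>
  %fst = tensor.insert_slice %a into %dest[0, 0] [%d0a, %d1] [1, 1]
      : tensor<?x?xf32> into tensor<?x?xf32>
  %snd = tensor.insert_slice %b into %fst[%d0a, 0] [%d0b, %d1] [1, 1]
      : tensor<?x?xf32> into tensor<?x?xf32>

This keeps ConvertAtenCatOp a one-op rewrite, while backends that cannot
consume tensor.concat directly can recover the previous behavior by running
the same decomposition pass.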