diff --git a/include/torch-mlir/RefBackend/Passes.h b/include/torch-mlir/RefBackend/Passes.h index 8f1b2b525a22..be5e43a1e63c 100644 --- a/include/torch-mlir/RefBackend/Passes.h +++ b/include/torch-mlir/RefBackend/Passes.h @@ -31,6 +31,8 @@ std::unique_ptr> createMLProgramBufferizePass(); std::unique_ptr> createMungeMemrefCopyPass(); +std::unique_ptr> createGeneralizeTensorConcatPass(); + std::unique_ptr> createGeneralizeTensorPadPass(); } // namespace RefBackend } // namespace torch diff --git a/include/torch-mlir/RefBackend/Passes.td b/include/torch-mlir/RefBackend/Passes.td index 12d182e49e3a..3d8b7fd41b1b 100644 --- a/include/torch-mlir/RefBackend/Passes.td +++ b/include/torch-mlir/RefBackend/Passes.td @@ -35,6 +35,11 @@ def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> { let dependentDialects = ["memref::MemRefDialect"]; } +def GeneralizeTensorConcat : Pass<"refback-generalize-tensor-concat", "func::FuncOp"> { + let summary = "Convert tensor.concat to other tensor ops"; + let constructor = "mlir::torch::RefBackend::createGeneralizeTensorConcatPass()"; +} + def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> { let summary = "Convert tensor.pad to linalg ops"; let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()"; diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 4eb02215a8bf..dae387422b52 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -1033,8 +1033,11 @@ class ConvertAtenCatOp : public OpConversionPattern { auto outElemType = newResultType.getElementType(); for (size_t i = 0; i < tensors.size(); ++i) { - tensors[i] = torch_to_linalg::convertTensorToElementType( - rewriter, loc, tensors[i], outElemType); + auto inputType = cast(tensors[i].getType()); + if (inputType.getElementType() != outElemType) { + tensors[i] = torch_to_linalg::convertTensorToElementType( + rewriter, loc, tensors[i], outElemType); + } } int rank = newResultType.getRank(); @@ -1046,48 +1049,8 @@ class ConvertAtenCatOp : public OpConversionPattern { if (!isValidDim(dim, rank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - SmallVector offsets, sizes, strides; - sizes.reserve(rank); - strides.resize(rank, rewriter.create(loc, 1)); - offsets.resize(rank, rewriter.create(loc, 0)); - - for (int i = 0; i < rank; ++i) - sizes.push_back(rewriter.createOrFold(loc, tensors[0], i)); - - // Calculate the size of the `dim` result dimension by adding the dim size - // of each tensor together. - Value resultDimSize = sizes[dim]; - - Value dimIndex = rewriter.createOrFold( - loc, rewriter.getIndexAttr(dim)); - for (auto tensor : ArrayRef(tensors).drop_front()) { - auto size = rewriter.createOrFold(loc, tensor, dimIndex); - resultDimSize = - rewriter.createOrFold(loc, resultDimSize, size); - } - sizes[dim] = resultDimSize; - - auto toOpFoldResult = [](Value v) -> OpFoldResult { - auto op = v.getDefiningOp(); - if (!op) - return v; - return op.getValue(); - }; - - Value result = rewriter.create( - loc, getAsOpFoldResult(sizes), newResultType.getElementType()); - for (auto tensor : tensors) { - SmallVector sizes = getTensorSizes(rewriter, loc, tensor); - result = rewriter.createOrFold( - loc, tensor, result, - llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)), - llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)), - llvm::to_vector(llvm::map_range(strides, toOpFoldResult))); - offsets[dim] = - rewriter.createOrFold(loc, offsets[dim], sizes[dim]); - } - - rewriter.replaceOpWithNewOp(op, newResultType, result); + rewriter.replaceOpWithNewOp(op, newResultType, dim, + tensors); return success(); } }; diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 481bdf3426d8..4ada196e944c 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -20,10 +20,12 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Approximation.h" #include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/Dialect/MLProgram/IR/MLProgram.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" @@ -436,6 +438,29 @@ mlir::torch::RefBackend::createMungeMemrefCopyPass() { return std::make_unique(); } +namespace { +class GeneralizeTensorConcat + : public GeneralizeTensorConcatBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + tensor::populateDecomposeTensorConcatPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> +mlir::torch::RefBackend::createGeneralizeTensorConcatPass() { + return std::make_unique(); +} + namespace { class GeneralizeTensorPad : public GeneralizeTensorPadBase { diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 1b9dbb0d2c51..266459e00b0c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -123,6 +123,7 @@ def invoke(*args): LOWERING_PIPELINE = "builtin.module(" + ",".join([ "func.func(refback-generalize-tensor-pad)", + "func.func(refback-generalize-tensor-concat)", # Apply some optimizations. It would be great if MLIR had more useful # optimizations that worked out of the box here. # Note: When measured, this doesn't seem to actually help that much diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index eba7546655e9..0aaca941b0d9 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -287,3 +287,41 @@ func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtenso %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16> return %0 : !torch.vtensor<[?,?],f16> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.cat$convert( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK: %[[T3:.*]] = linalg.generic {{.*}} ins(%[[T2]] : tensor) outs(%{{.*}}: tensor) +// CHECK: %[[T4:.*]] = tensor.concat dim(0) %[[T1]], %[[T3]] : (tensor, tensor) -> tensor +// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> { + %int0 = torch.constant.int 0 + %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list + %1 = torch.aten.cat %0, %int0 : !torch.list, !torch.int -> !torch.vtensor<[?,?],f32> + return %1 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.cat( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = tensor.concat dim(0) %[[VAL_1]], %[[VAL_2]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %int0 = torch.constant.int 0 + %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list + %1 = torch.aten.cat %0, %int0 : !torch.list, !torch.int -> !torch.vtensor<[?,?],f32> + return %1 : !torch.vtensor<[?,?],f32> +}