From 6f3d62ab04e91bbe67d51b3c0b467f12fc3ed870 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Wed, 28 Feb 2024 12:04:52 -0800
Subject: [PATCH] [torch] Fix folders and `cat` and `view` torch lowerings
 (#2963)

Several small fixes are interlinked and trigger crashes if not
addressed as a group. They include:

- `aten.view` when expanding from a rank-0 tensor
- the `aten.slice.Tensor` folder with negative indices
- the `aten._shape_as_tensor` folder on a rank-0 tensor
- `aten.cat` of a tensor with a length-0 tensor
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  1 +
 lib/Conversion/TorchToLinalg/DataMovement.cpp | 38 ++++++----
 lib/Dialect/Torch/IR/TorchOps.cpp             | 76 ++++++++++++-------
 projects/pt1/e2e_testing/xfail_sets.py        | 11 ---
 .../build_tools/torch_ods_gen.py              |  2 +-
 test/Dialect/Torch/canonicalize.mlir          | 16 +++-
 6 files changed, 89 insertions(+), 55 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 57b15ed18f4e..9d245723fd84 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8925,6 +8925,7 @@ def Torch_Aten_ShapeAsTensorOp : Torch_Op<"aten._shape_as_tensor", [
       printDefaultTorchOp(printer, *this, 1, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenIsnanOp : Torch_Op<"aten.isnan", [
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index d9132317e32f..42aacceab0b4 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -799,10 +799,15 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
       return rewriter.notifyMatchFailure(op,
                                          "result shape of rank 0 is invalid");
 
-    // TODO: add support for case inputRank 0 expanded to size 1
-    if (inputRank == 0)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: input rank 0 is not supported");
+    if (inputRank == 0) {
+      Value expanded =
+          rewriter
+              .create<tensor::ExpandShapeOp>(loc, resultType, input,
+                                             ArrayRef<ReassociationIndices>())
+              .getResult();
+      rewriter.replaceOp(op, expanded);
+      return success();
+    }
 
     // Extract the desired output size as a list of integers. This list should
     // have been created using the operation `torch.prim.ListConstruct`.
@@ -1500,6 +1505,14 @@ class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {
 
     RankedTensorType newResultType =
         typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+    int rank = newResultType.getRank();
+    Value dimValue = op.getDim();
+    int64_t dim;
+    if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
+      return op.emitError("unimplemented: dim is not constant");
+    dim = toPositiveDim(dim, rank);
+    if (!isValidDim(dim, rank))
+      return rewriter.notifyMatchFailure(op, "dim is statically invalid");
 
     auto outElemType = newResultType.getElementType();
     for (size_t i = 0; i < tensors.size(); ++i) {
@@ -1510,17 +1523,16 @@ class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {
       }
     }
 
-    int rank = newResultType.getRank();
-    Value dimValue = op.getDim();
-    int64_t dim;
-    if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
-      return op.emitError("unimplemented: dim is not constant");
-    dim = toPositiveDim(dim, rank);
-    if (!isValidDim(dim, rank))
-      return rewriter.notifyMatchFailure(op, "dim is statically invalid");
+    llvm::SmallVector<Value> filteredTensors;
+    for (auto tensor : tensors) {
+      auto inputType = cast<RankedTensorType>(tensor.getType());
+      if (inputType.getDimSize(dim) != 0) {
+        filteredTensors.push_back(tensor);
+      }
+    }
 
     rewriter.replaceOpWithNewOp<tensor::ConcatOp>(op, newResultType, dim,
-                                                  tensors);
+                                                  filteredTensors);
     return success();
   }
 };
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 2f0884b1344e..6120fd6f0e32 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2972,8 +2972,10 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
       unaryNonDim &= inType.getSizes()[i] == 1 || i == dimInt;
     }
     if (unaryNonDim) {
-      Attribute value =
-          input.getValues<Attribute>()[start.getValue().getSExtValue()];
+      int64_t idx = start.getValue().getSExtValue();
+      if (idx < 0)
+        idx += input.getNumElements();
+      Attribute value = input.getValues<Attribute>()[idx];
       return DenseElementsAttr::get(
           outType.toBuiltinTensor().clone(inType.getDtype()), value);
     }
@@ -3237,6 +3239,34 @@ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// Aten_ShapeAsTensorOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) {
+  auto selfTy = dyn_cast<BaseTensorType>(getSelf().getType());
+  auto resultTy = dyn_cast<BaseTensorType>(getType());
+  if (!selfTy || !resultTy || !selfTy.hasSizes() || !resultTy.hasDtype() ||
+      !resultTy.hasSizes())
+    return {};
+
+  llvm::SmallVector<int64_t> values(selfTy.getSizes());
+  if (llvm::any_of(values, [](int64_t d) { return d == Torch::kUnknownSize; }))
+    return {};
+
+  auto dty = dyn_cast<IntegerType>(resultTy.getDtype());
+  if (!dty)
+    return {};
+
+  llvm::SmallVector<Attribute> attrs;
+  for (auto val : values) {
+    attrs.push_back(IntegerAttr::get(dty, val));
+  }
+
+  auto attrty = RankedTensorType::get(resultTy.getSizes(), dty);
+  return DenseElementsAttr::get(attrty, attrs);
+}
+
 //===----------------------------------------------------------------------===//
 // AtenIntTensorOp
 //===----------------------------------------------------------------------===//
@@ -3409,25 +3439,25 @@ OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
   SmallVector<int64_t> sizes;
   if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
-    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: size operand is "
-                               "not a list of constant integers.\n");
     return nullptr;
   }
 
   Type resultType = getResult().getType();
   BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
   if (!resultTensorType || !resultTensorType.hasDtype()) {
-    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: result type is not "
-                               "a tensor type or does not have a dtype.\n");
     return nullptr;
   }
+  int64_t ct = sizes.size();
+  if (resultTensorType.getSizes().size() != 1)
+    return nullptr;
+  if (resultTensorType.getSizes()[0] != ct)
+    return nullptr;
+
   ShapedType shapedty =
       mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
           sizes, resultTensorType.getDtype());
   if (!shapedty) {
-    LLVM_DEBUG(llvm::dbgs()
-               << "Failing to fold AtenOnesOp: ShapedType cast failed.\n");
     return nullptr;
   }
   auto elementType = shapedty.getElementType();
@@ -3439,33 +3469,31 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
     Attribute attribute = FloatAttr::get(elementType, 1.0);
     return DenseElementsAttr::get(shapedty, attribute);
   }
-  LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: element type is "
-                             "not integer or float.\n");
   return nullptr;
 }
 
 OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
   SmallVector<int64_t> sizes;
   if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
-    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: size operand is "
-                               "not a list of constant integers.\n");
     return nullptr;
   }
 
   Type resultType = getResult().getType();
   BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
   if (!resultTensorType || !resultTensorType.hasDtype()) {
-    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: result type is "
-                               "not a tensor type or does not have a dtype.\n");
     return nullptr;
   }
+  int64_t ct = sizes.size();
+  if (resultTensorType.getSizes().size() != 1)
+    return nullptr;
+  if (resultTensorType.getSizes()[0] != ct)
+    return nullptr;
+
   ShapedType shapedty =
       mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
           sizes, resultTensorType.getDtype());
   if (!shapedty) {
-    LLVM_DEBUG(llvm::dbgs()
-               << "Failing to fold AtenZerosOp: ShapedType cast failed.\n");
     return nullptr;
   }
 
@@ -3479,33 +3507,31 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
     return DenseElementsAttr::get(shapedty, attribute);
   }
-  LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: element type is "
-                             "not integer or float.\n");
   return nullptr;
 }
 
 OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
   SmallVector<int64_t> sizes;
   if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
-    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: size operand is "
-                               "not a list of constant integers.\n");
     return nullptr;
   }
 
   Type resultType = getResult().getType();
   BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
   if (!resultTensorType || !resultTensorType.hasDtype()) {
-    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: result type is not "
-                               "a tensor type or does not have a dtype.\n");
     return nullptr;
   }
+  int64_t ct = sizes.size();
+  if (resultTensorType.getSizes().size() != 1)
+    return nullptr;
+  if (resultTensorType.getSizes()[0] != ct)
+    return nullptr;
+
   ShapedType shapedty =
       mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
           sizes, resultTensorType.getDtype());
   if (!shapedty) {
-    LLVM_DEBUG(llvm::dbgs()
-               << "Failing to fold AtenFullOp: ShapedType cast failed.\n");
     return nullptr;
   }
   auto elementType = shapedty.getElementType();
@@ -3523,8 +3549,6 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
       return DenseElementsAttr::get(shapedty, attribute);
     }
   }
-  LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: element type is "
-                             "not integer or float.\n");
   return nullptr;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 67a4f175ddb6..60b08c02539e 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -64,14 +64,6 @@
 
     # See also: https://github.com/pytorch/torchdynamo/issues/327
     "AtenEmbeddingBagSumExample_basic",
-    # error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal
-    "BernoulliFloatModule_basic",
-    "BernoulliPModule_basic",
-    # error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
-    "ElementwiseFlattenBroadcastModule_basic",
-    "FlattenRank0Module_basic",
-    "UniformModule_basic",
-    "UniformStaticShapeModule_basic",
     # error: unsupported by backend contract: tensor with unknown rank
     # note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32>
     "ElementwisePreluModule_basic",
@@ -2150,7 +2142,6 @@
     # Failure - torch.aten.view lower
     "AddSizeIntModule_basic",
     "ElementwiseFlattenBroadcastModule_basic",
-    "FlattenRank0Module_basic",
     "IndexTensorDyanmicInputContiguousWithNoneModule_basic",
     "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
     "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
@@ -2163,7 +2154,6 @@
     "IndexTensorStaticContiguousWithNoneModule_basic",
     "RepeatModule_basic",
     "SelectIntModule_basic",
-    "SelectIntNegativeDimAndIndexStaticModule_basic",
     "SliceSingleIdxModule_basic",
     "ViewFlattenAndExpandModule_basic",
     "ViewSizeDimFollowedByCollapsedOnesModule_basic",
@@ -2205,7 +2195,6 @@
     "FlattenDynamicModule_basic",
     "GluStaticModule_basic",
     "GroupNormModule_basic",
-    "GroupNormNoWeightAndBiasModule_basic",
     "IndexSelectDynamicIndexSizeModule_basic",
     "IndexSelectDynamicModulebasic",
     "IndexTensorHackedTwinModule3dInput_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index c81f543b5dc9..fed048a64340 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -582,7 +582,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
     emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)")
     emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
-    emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)")
+    emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)", has_folder=True)
     emit("aten::isnan : (Tensor) -> (Tensor)")
     emit("aten::isinf : (Tensor) -> (Tensor)")
     emit("aten::isneginf : (Tensor) -> (Tensor)")
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index 85b95eb1cdba..b3dd4c6f0641 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -2081,20 +2081,18 @@ func.func @torch.aten.slice.tensor$fold_dim_1() -> (!torch.vtensor<[1, 1],si64>,
 // CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) {
 // CHECK-NOT: torch.aten.slice.Tensor
 // CHECK: %[[RET_0:.*]] = torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
-// CHECK-NOT: torch.aten.slice.Tensor
 // CHECK: %[[RET_1:.*]] = torch.vtensor.literal(dense<6.400000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
-// CHECK-NOT: torch.aten.slice.Tensor
 // CHECK: return %[[RET_0]], %[[RET_1]] : !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>
 func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1],f32>) {
   %tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10, 1],f32>
   %int0 = torch.constant.int 0
   %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
+  %intn7 = torch.constant.int -7
   %int4 = torch.constant.int 4
   %int5 = torch.constant.int 5
   %int6 = torch.constant.int 6
   %dim = torch.constant.int 0
-  %0 = torch.aten.slice.Tensor %tensor, %dim, %int3, %int4, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32>
+  %0 = torch.aten.slice.Tensor %tensor, %dim, %intn7, %int4, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32>
   %1 = torch.aten.slice.Tensor %tensor, %dim, %int5, %int6, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32>
   return %0, %1 : !torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1], f32>
 }
@@ -2655,3 +2653,13 @@ func.func @aten_eq_tensor_dense_int() -> !torch.vtensor<[4],i1> {
   return %0 : !torch.vtensor<[4],i1>
 }
 
+// -----
+
+// CHECK-LABEL: @aten_shape_to_tensor
+func.func @aten_shape_to_tensor(%arg0 : !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[3],si32> {
+  // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[4, 5, 6]> : tensor<3xsi32>) : !torch.vtensor<[3],si32>
+  %0 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[4,5,6],f32> -> !torch.vtensor<[3],si32>
+  // CHECK: return %[[CST]]
+  return %0 : !torch.vtensor<[3],si32>
+}