[torch] Fix folders and cat and view torch lowerings (llvm#2963)
A bunch of small fixes are interlinked and trigger crashes if not
addressed as a group (illustrative IR for two of the cases is sketched
right after the list). This includes:

- aten view when expanding from a rank-0 tensor
- slice folder with negative indices
- `aten._shape_as_tensor` folder on a rank-0 tensor
- `aten.cat` of a tensor with a length-0 tensor
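
To make the affected cases concrete, here is a rough sketch of the kind of IR the first and last bullets refer to. These examples are not part of the commit; the function names, shapes, and constant values are made up for illustration.

// aten.view expanding a rank-0 tensor to rank 1.
func.func @view_rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
  %int1 = torch.constant.int 1
  %sizes = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
  %0 = torch.aten.view %arg0, %sizes : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[1],f32>
  return %0 : !torch.vtensor<[1],f32>
}

// aten.cat where one operand is empty along the concatenation dimension.
func.func @cat_with_empty(%a: !torch.vtensor<[2,4],f32>, %b: !torch.vtensor<[0,4],f32>) -> !torch.vtensor<[2,4],f32> {
  %int0 = torch.constant.int 0
  %list = torch.prim.ListConstruct %a, %b : (!torch.vtensor<[2,4],f32>, !torch.vtensor<[0,4],f32>) -> !torch.list<vtensor>
  %0 = torch.aten.cat %list, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[2,4],f32>
  return %0 : !torch.vtensor<[2,4],f32>
}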
rsuderman authored Feb 28, 2024
1 parent 73b6df9 commit 6f3d62a
Showing 6 changed files with 89 additions and 55 deletions.
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8925,6 +8925,7 @@ def Torch_Aten_ShapeAsTensorOp : Torch_Op<"aten._shape_as_tensor", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenIsnanOp : Torch_Op<"aten.isnan", [
38 changes: 25 additions & 13 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -799,10 +799,15 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
return rewriter.notifyMatchFailure(op,
"result shape of rank 0 is invalid");

// TODO: add support for case inputRank 0 expanded to size 1
if (inputRank == 0)
return rewriter.notifyMatchFailure(
op, "unimplemented: input rank 0 is not supported");
if (inputRank == 0) {
Value expanded =
rewriter
.create<tensor::ExpandShapeOp>(loc, resultType, input,
ArrayRef<ReassociationIndices>())
.getResult();
rewriter.replaceOp(op, expanded);
return success();
}

// Extract the desired output size as a list of integers. This list should
// have been created using the operation `torch.prim.ListConstruct`.
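
For reference, the rank-0 handling added above expands the input directly to the static result type via `tensor::ExpandShapeOp` with an empty reassociation list. A minimal sketch of the resulting IR, assuming illustrative shapes and the tensor-dialect syntax of this MLIR revision:

// Rank-0 input expanded to the (static) result type; the reassociation list is empty.
%expanded = tensor.expand_shape %input [] : tensor<f32> into tensor<1x1xf32>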
@@ -1500,6 +1505,14 @@ class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {

RankedTensorType newResultType =
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
int rank = newResultType.getRank();
Value dimValue = op.getDim();
int64_t dim;
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
return op.emitError("unimplemented: dim is not constant");
dim = toPositiveDim(dim, rank);
if (!isValidDim(dim, rank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

auto outElemType = newResultType.getElementType();
for (size_t i = 0; i < tensors.size(); ++i) {
@@ -1510,17 +1523,16 @@ class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {
}
}

int rank = newResultType.getRank();
Value dimValue = op.getDim();
int64_t dim;
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
return op.emitError("unimplemented: dim is not constant");
dim = toPositiveDim(dim, rank);
if (!isValidDim(dim, rank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
llvm::SmallVector<Value> filteredTensors;
for (auto tensor : tensors) {
auto inputType = cast<RankedTensorType>(tensor.getType());
if (inputType.getDimSize(dim) != 0) {
filteredTensors.push_back(tensor);
}
}

rewriter.replaceOpWithNewOp<tensor::ConcatOp>(op, newResultType, dim,
tensors);
filteredTensors);
return success();
}
};
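
Illustrative effect of the operand filtering above (a sketch, not taken from the commit; shapes are made up): an operand whose size along the concatenation dimension is statically zero is dropped before the `tensor.concat` is built.

// Operands: %a : tensor<2x4xf32>, %b : tensor<0x4xf32>, %c : tensor<3x4xf32>, dim = 0.
// %b is filtered out, so only the non-empty operands are concatenated:
%cat = tensor.concat dim(0) %a, %c : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>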
76 changes: 50 additions & 26 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2972,8 +2972,10 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
unaryNonDim &= inType.getSizes()[i] == 1 || i == dimInt;
}
if (unaryNonDim) {
Attribute value =
input.getValues<Attribute>()[start.getValue().getSExtValue()];
int64_t idx = start.getValue().getSExtValue();
if (idx < 0)
idx += input.getNumElements();
Attribute value = input.getValues<Attribute>()[idx];
return DenseElementsAttr::get(
outType.toBuiltinTensor().clone(inType.getDtype()), value);
}
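
Worked example of the negative-start normalization above, using the constant and indices from the updated `canonicalize.mlir` test in this commit:

// From the updated test below:
//   %intn7 = torch.constant.int -7
//   %0 = torch.aten.slice.Tensor %tensor, %dim, %intn7, %int4, %int1 : ... -> !torch.vtensor<[1,1],f32>
// start = -7 and the input has 10 elements  =>  idx = -7 + 10 = 3, i.e. the value 16.0,
// so the slice folds to torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>).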
@@ -3237,6 +3239,34 @@ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
return nullptr;
}

//===----------------------------------------------------------------------===//
// Aten_ShapeAsTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) {
auto selfTy = dyn_cast<BaseTensorType>(getSelf().getType());
auto resultTy = dyn_cast<BaseTensorType>(getType());
if (!selfTy || !resultTy || !selfTy.hasSizes() || !resultTy.hasDtype() ||
!resultTy.hasSizes())
return {};

llvm::SmallVector<int64_t> values(selfTy.getSizes());
if (llvm::any_of(values, [](int64_t d) { return d == Torch::kUnknownSize; }))
return {};

auto dty = dyn_cast<IntegerType>(resultTy.getDtype());
if (!dty)
return {};

llvm::SmallVector<Attribute> attrs;
for (auto val : values) {
attrs.push_back(IntegerAttr::get(dty, val));
}

auto attrty = RankedTensorType::get(resultTy.getSizes(), dty);
return DenseElementsAttr::get(attrty, attrs);
}
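
A sketch of how this folder behaves (illustrative; only the fully static case is exercised by the new test at the end of this commit): a shape with any unknown dimension is left alone, while a fully static shape folds to an integer literal.

// Not folded: the input has an unknown dimension.
//   %0 = torch.aten._shape_as_tensor %arg : !torch.vtensor<[?,5,6],f32> -> !torch.vtensor<[3],si32>
// Folded: all dimensions are static (see @aten_shape_to_tensor below).
//   %0 = torch.aten._shape_as_tensor %arg : !torch.vtensor<[4,5,6],f32> -> !torch.vtensor<[3],si32>
//   ==> torch.vtensor.literal(dense<[4, 5, 6]> : tensor<3xsi32>)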

//===----------------------------------------------------------------------===//
// AtenIntTensorOp
//===----------------------------------------------------------------------===//
@@ -3409,25 +3439,25 @@ OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) {
OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
SmallVector<int64_t> sizes;
if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: size operand is "
"not a list of constant integers.\n");
return nullptr;
}

Type resultType = getResult().getType();
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
if (!resultTensorType || !resultTensorType.hasDtype()) {
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: result type is not "
"a tensor type or does not have a dtype.\n");
return nullptr;
}

int64_t ct = sizes.size();
if (resultTensorType.getSizes().size() != 1)
return nullptr;
if (resultTensorType.getSizes()[0] != ct)
return nullptr;

ShapedType shapedty =
mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
sizes, resultTensorType.getDtype());
if (!shapedty) {
LLVM_DEBUG(llvm::dbgs()
<< "Failing to fold AtenOnesOp: ShapedType cast failed.\n");
return nullptr;
}
auto elementType = shapedty.getElementType();
@@ -3439,33 +3469,31 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
Attribute attribute = FloatAttr::get(elementType, 1.0);
return DenseElementsAttr::get(shapedty, attribute);
}
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: element type is "
"not integer or float.\n");
return nullptr;
}

OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
SmallVector<int64_t> sizes;
if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: size operand is "
"not a list of constant integers.\n");
return nullptr;
}

Type resultType = getResult().getType();
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
if (!resultTensorType || !resultTensorType.hasDtype()) {
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: result type is "
"not a tensor type or does not have a dtype.\n");
return nullptr;
}

int64_t ct = sizes.size();
if (resultTensorType.getSizes().size() != 1)
return nullptr;
if (resultTensorType.getSizes()[0] != ct)
return nullptr;

ShapedType shapedty =
mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
sizes, resultTensorType.getDtype());
if (!shapedty) {
LLVM_DEBUG(llvm::dbgs()
<< "Failing to fold AtenZerosOp: ShapedType cast failed.\n");
return nullptr;
}

@@ -3479,33 +3507,31 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
return DenseElementsAttr::get(shapedty, attribute);
}

LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: element type is "
"not integer or float.\n");
return nullptr;
}

OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
SmallVector<int64_t> sizes;
if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: size operand is "
"not a list of constant integers.\n");
return nullptr;
}

Type resultType = getResult().getType();
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
if (!resultTensorType || !resultTensorType.hasDtype()) {
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: result type is not "
"a tensor type or does not have a dtype.\n");
return nullptr;
}

int64_t ct = sizes.size();
if (resultTensorType.getSizes().size() != 1)
return nullptr;
if (resultTensorType.getSizes()[0] != ct)
return nullptr;

ShapedType shapedty =
mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
sizes, resultTensorType.getDtype());
if (!shapedty) {
LLVM_DEBUG(llvm::dbgs()
<< "Failing to fold AtenFullOp: ShapedType cast failed.\n");
return nullptr;
}
auto elementType = shapedty.getElementType();
@@ -3523,8 +3549,6 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
return DenseElementsAttr::get(shapedty, attribute);
}
}
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: element type is "
"not integer or float.\n");
return nullptr;
}
//===----------------------------------------------------------------------===//
11 changes: 0 additions & 11 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -64,14 +64,6 @@
# See also: https://github.com/pytorch/torchdynamo/issues/327
"AtenEmbeddingBagSumExample_basic",

# error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal
"BernoulliFloatModule_basic",
"BernoulliPModule_basic",
# error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
"ElementwiseFlattenBroadcastModule_basic",
"FlattenRank0Module_basic",
"UniformModule_basic",
"UniformStaticShapeModule_basic",
# error: unsupported by backend contract: tensor with unknown rank
# note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32>
"ElementwisePreluModule_basic",
@@ -2150,7 +2142,6 @@
# Failure - torch.aten.view lower
"AddSizeIntModule_basic",
"ElementwiseFlattenBroadcastModule_basic",
"FlattenRank0Module_basic",
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
@@ -2163,7 +2154,6 @@
"IndexTensorStaticContiguousWithNoneModule_basic",
"RepeatModule_basic",
"SelectIntModule_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SliceSingleIdxModule_basic",
"ViewFlattenAndExpandModule_basic",
"ViewSizeDimFollowedByCollapsedOnesModule_basic",
Expand Down Expand Up @@ -2205,7 +2195,6 @@
"FlattenDynamicModule_basic",
"GluStaticModule_basic",
"GroupNormModule_basic",
"GroupNormNoWeightAndBiasModule_basic",
"IndexSelectDynamicIndexSizeModule_basic",
"IndexSelectDynamicModulebasic",
"IndexTensorHackedTwinModule3dInput_basic",
@@ -582,7 +582,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)")
emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)")
emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)", has_folder=True)
emit("aten::isnan : (Tensor) -> (Tensor)")
emit("aten::isinf : (Tensor) -> (Tensor)")
emit("aten::isneginf : (Tensor) -> (Tensor)")
16 changes: 12 additions & 4 deletions test/Dialect/Torch/canonicalize.mlir
@@ -2081,20 +2081,18 @@ func.func @torch.aten.slice.tensor$fold_dim_1() -> (!torch.vtensor<[1, 1],si64>,
// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) {
// CHECK-NOT: torch.aten.slice.Tensor
// CHECK: %[[RET_0:.*]] = torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
// CHECK-NOT: torch.aten.slice.Tensor
// CHECK: %[[RET_1:.*]] = torch.vtensor.literal(dense<6.400000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
// CHECK-NOT: torch.aten.slice.Tensor
// CHECK: return %[[RET_0]], %[[RET_1]] : !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>
func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1],f32>) {
%tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10, 1],f32>
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%intn7 = torch.constant.int -7
%int4 = torch.constant.int 4
%int5 = torch.constant.int 5
%int6 = torch.constant.int 6
%dim = torch.constant.int 0
%0 = torch.aten.slice.Tensor %tensor, %dim, %int3, %int4, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32>
%0 = torch.aten.slice.Tensor %tensor, %dim, %intn7, %int4, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32>
%1 = torch.aten.slice.Tensor %tensor, %dim, %int5, %int6, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32>
return %0, %1 : !torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1], f32>
}
@@ -2655,3 +2653,13 @@ func.func @aten_eq_tensor_dense_int() -> !torch.vtensor<[4],i1> {
return %0 : !torch.vtensor<[4],i1>
}

// -----

// CHECK-LABEL: @aten_shape_to_tensor
func.func @aten_shape_to_tensor(%arg0 : !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[3],si32> {
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[4, 5, 6]> : tensor<3xsi32>) : !torch.vtensor<[3],si32>
%0 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[4,5,6],f32> -> !torch.vtensor<[3],si32>
// CHECK: return %[[CST]]
return %0 : !torch.vtensor<[3],si32>
}
