Improve list index normalization SimplifyShapeCalculations. (llvm#710)
The reified code that computes the shape of torch.aten.constant_pad_nd
uses negative indices when setting list elements. One place in
SimplifyShapeCalculations did not convert these to positive offsets,
which prevented computation of the static shape.
ljfitz authored Mar 29, 2022
1 parent 25ba51b commit f2269ce
Showing 5 changed files with 54 additions and 26 deletions.
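
Before the file-by-file diff, here is a minimal standalone sketch of the index normalization that this commit centralizes in the new helper. It is plain C++ using std::optional rather than the llvm::Optional-based torch-mlir code, and the function name normalizeListIndex is made up for illustration; the asserts mirror the -1/-2 cases exercised by the new test.

    #include <cassert>
    #include <cstdint>
    #include <optional>

    // Illustration only: normalize a possibly-negative constant index into a
    // list of the given length, in the spirit of the new
    // matchLegalConstantIndexIntoListOfSize helper (negative indices count
    // from the end; out-of-range indices yield no value).
    std::optional<int64_t> normalizeListIndex(int64_t index, int64_t length) {
      if (index < 0)
        index += length;             // e.g. -1 in a 2-element list becomes 1
      if (index < 0 || index >= length)
        return std::nullopt;         // statically out of bounds
      return index;
    }

    int main() {
      assert(normalizeListIndex(-1, 2) == 1);  // last element
      assert(normalizeListIndex(-2, 2) == 0);  // first element
      assert(!normalizeListIndex(2, 2));       // past the end
      assert(!normalizeListIndex(-3, 2));      // still negative after the shift
      return 0;
    }
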
5 changes: 5 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -21,6 +21,11 @@ namespace Torch {
int64_t toPositiveDim(int64_t dim, int64_t inputRank);
bool isValidDim(int64_t dim, int64_t inputRank);
bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
/// Returns the index indicated by `v` for a list of given `length`.
/// If the index is negative, it is adjusted to `length` + `v`.
/// `None` is returned if the index is not an integer in the range [0, `length`).
llvm::Optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
int64_t length);
torch_upstream::ScalarType getScalarTypeForType(Type type);
// Helper to convert a tensor to a specific scalar type.
Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
24 changes: 9 additions & 15 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -725,18 +725,14 @@ OpFoldResult AtenSizeIntOp::fold(ArrayRef<Attribute> operands) {
if (!type || !type.hasSizes())
return nullptr;

int64_t inputRank = type.getSizes().size();
int64_t dim;
if (!matchPattern(this->dim(), m_TorchConstantInt(&dim)))
llvm::Optional<int64_t> dimOpt = matchLegalConstantIndexIntoListOfSize(
this->dim(), type.getSizes().size());
if (!dimOpt)
return nullptr;
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return nullptr;

if (type.getSizes()[dim] == kUnknownSize)
if (type.getSizes()[*dimOpt] == kUnknownSize)
return nullptr;
return IntegerAttr::get(IntegerType::get(getContext(), 64),
type.getSizes()[dim]);
type.getSizes()[*dimOpt]);
}

//===----------------------------------------------------------------------===//
@@ -1227,14 +1223,12 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
return failure();

// Get the index, but be careful because it might be statically invalid.
int64_t index;
if (!matchPattern(op.getOperand(1), m_TorchConstantInt(&index)))
return failure();
int64_t positiveDim = toPositiveDim(index, listConstruct.getNumOperands());
if (!isValidDim(positiveDim, listConstruct.getNumOperands()))
llvm::Optional<int64_t> indexOpt = matchLegalConstantIndexIntoListOfSize(
op.getOperand(1), listConstruct.getNumOperands());
if (!indexOpt)
return rewriter.notifyMatchFailure(op, "statically invalid index");

rewriter.replaceOp(op, {listConstruct.getOperand(positiveDim)});
rewriter.replaceOp(op, {listConstruct.getOperand(*indexOpt)});
return success();
});
patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
21 changes: 10 additions & 11 deletions lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
@@ -176,14 +176,14 @@ class AbstractlyInterpretListOpsWithinABlock
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
if (!setItem.use_empty())
return failure();
int64_t index;
if (!matchPattern(setItem.idx(), m_TorchConstantInt(&index)))
return failure();
llvm::Optional<int64_t> indexOpt =
matchLegalConstantIndexIntoListOfSize(setItem.idx(),
runningList.size());
// The index might be statically out of bounds.
if (index < 0 || index >= static_cast<int64_t>(runningList.size()))
if (!indexOpt)
return failure();
if (setItem.l() == op) {
runningList[index] = setItem.el();
runningList[*indexOpt] = setItem.el();
generatedNewLiteral = true;
}
listLiterals.push_back(runningList);
@@ -293,15 +293,14 @@ static void refineShapeCalculateResult(ShapeCalculateOp op, int resultNum,
// change the size of the list. It might clobber some elements, which then
// become dimensions with unknown size.
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
int64_t index;
// If the index is statically known, we can clobber only a single index.
// Otherwise, we conservatively clobber all of them.
if (matchPattern(setItem.idx(), m_TorchConstantInt(&index)) &&
isValidDim(index, listConstruct->getNumOperands())) {
clobberedElements.set(index);
} else {
llvm::Optional<int64_t> indexOpt = matchLegalConstantIndexIntoListOfSize(
setItem.idx(), listConstruct->getNumOperands());
if (indexOpt)
clobberedElements.set(*indexOpt);
else
clobberedElements.set();
}
continue;
}
// An unhandled op! We can't make any assumptions about the shape.
11 changes: 11 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -22,6 +22,17 @@ bool Torch::isValidDim(int64_t dim, int64_t inputRank) {
return dim >= 0 && dim < inputRank;
}

llvm::Optional<int64_t>
Torch::matchLegalConstantIndexIntoListOfSize(Value v, int64_t length) {
int64_t dim;
if (!matchPattern(v, m_TorchConstantInt(&dim)))
return llvm::None;
dim = toPositiveDim(dim, length);
if (!isValidDim(dim, length))
return llvm::None;
return dim;
}

bool Torch::getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
auto listConstruct = v.getDefiningOp<PrimListConstructOp>();
if (!listConstruct)
19 changes: 19 additions & 0 deletions test/Dialect/Torch/simplify-shape-calculations.mlir
@@ -191,6 +191,25 @@ func @abstractly_interpret_list_ops$mutation_ops(%arg0: !torch.vtensor, %arg1: !
return %0 : !torch.vtensor
}

// Test negative indexes with set_item op.
// CHECK-LABEL: func @abstractly_interpret_list_ops$neg_index_set_item(
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %arg1, %arg2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list<int>
func @abstractly_interpret_list_ops$neg_index_set_item(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.vtensor {
%int1 = torch.constant.int 1
%int-1 = torch.constant.int -1
%int-2 = torch.constant.int -2
%0 = torch.shape.calculate {
torch.shape.calculate.yield %arg0 : !torch.vtensor
} shapes {
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten._set_item.t %1, %int-1, %arg2 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
%3 = torch.aten._set_item.t %1, %int-2, %arg1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>
torch.shape.calculate.yield.shapes %1 : !torch.list<int>
} : !torch.vtensor
return %0 : !torch.vtensor
}

// Test interspersed mutation and evaluation ops.
// CHECK-LABEL: func @abstractly_interpret_list_ops$mix_mutation_and_evaluation_ops(
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %int0, %int1, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
