From f2269ced80a680e48c2cddff243c5de9a80be8b3 Mon Sep 17 00:00:00 2001 From: Liam Fitzpatrick Date: Tue, 29 Mar 2022 22:21:47 +0200 Subject: [PATCH] Improve list index normalization SimplifyShapeCalculations. (#710) The reified code to compute the shape of torch.aten.constant_pad_nd uses negative indices when setting list elements. This was not converted to a positive offset in one place in SimplifyShapeCalculations which prevented computation of the static shape. --- .../torch-mlir/Dialect/Torch/Utils/Utils.h | 5 ++++ lib/Dialect/Torch/IR/TorchOps.cpp | 24 +++++++------------ .../Transforms/SimplifyShapeCalculations.cpp | 21 ++++++++-------- lib/Dialect/Torch/Utils/Utils.cpp | 11 +++++++++ .../Torch/simplify-shape-calculations.mlir | 19 +++++++++++++++ 5 files changed, 54 insertions(+), 26 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 123d5ce6f493..aa7268814c0c 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -21,6 +21,11 @@ namespace Torch { int64_t toPositiveDim(int64_t dim, int64_t inputRank); bool isValidDim(int64_t dim, int64_t inputRank); bool getListConstructElements(Value v, SmallVectorImpl &elems); +/// Returns the index indicated by `v` for a list of given `length`. +/// If the index is negative, it is adjusted to `length` + `v`. +/// `None` is returned the index is not an integer in the range [0,`length). +llvm::Optional matchLegalConstantIndexIntoListOfSize(Value v, + int64_t length); torch_upstream::ScalarType getScalarTypeForType(Type type); // Helper to convert a tensor to a specific scalar type. Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 717df7694f2b..0834be386612 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -725,18 +725,14 @@ OpFoldResult AtenSizeIntOp::fold(ArrayRef operands) { if (!type || !type.hasSizes()) return nullptr; - int64_t inputRank = type.getSizes().size(); - int64_t dim; - if (!matchPattern(this->dim(), m_TorchConstantInt(&dim))) + llvm::Optional dimOpt = matchLegalConstantIndexIntoListOfSize( + this->dim(), type.getSizes().size()); + if (!dimOpt) return nullptr; - dim = toPositiveDim(dim, inputRank); - if (!isValidDim(dim, inputRank)) - return nullptr; - - if (type.getSizes()[dim] == kUnknownSize) + if (type.getSizes()[*dimOpt] == kUnknownSize) return nullptr; return IntegerAttr::get(IntegerType::get(getContext(), 64), - type.getSizes()[dim]); + type.getSizes()[*dimOpt]); } //===----------------------------------------------------------------------===// @@ -1227,14 +1223,12 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns( return failure(); // Get the index, but be careful because it might be statically invalid. - int64_t index; - if (!matchPattern(op.getOperand(1), m_TorchConstantInt(&index))) - return failure(); - int64_t positiveDim = toPositiveDim(index, listConstruct.getNumOperands()); - if (!isValidDim(positiveDim, listConstruct.getNumOperands())) + llvm::Optional indexOpt = matchLegalConstantIndexIntoListOfSize( + op.getOperand(1), listConstruct.getNumOperands()); + if (!indexOpt) return rewriter.notifyMatchFailure(op, "statically invalid index"); - rewriter.replaceOp(op, {listConstruct.getOperand(positiveDim)}); + rewriter.replaceOp(op, {listConstruct.getOperand(*indexOpt)}); return success(); }); patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) { diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index 34afe1d83f55..91ed9faeebdb 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -176,14 +176,14 @@ class AbstractlyInterpretListOpsWithinABlock if (auto setItem = dyn_cast(user)) { if (!setItem.use_empty()) return failure(); - int64_t index; - if (!matchPattern(setItem.idx(), m_TorchConstantInt(&index))) - return failure(); + llvm::Optional indexOpt = + matchLegalConstantIndexIntoListOfSize(setItem.idx(), + runningList.size()); // The index might be statically out of bounds. - if (index < 0 || index >= static_cast(runningList.size())) + if (!indexOpt) return failure(); if (setItem.l() == op) { - runningList[index] = setItem.el(); + runningList[*indexOpt] = setItem.el(); generatedNewLiteral = true; } listLiterals.push_back(runningList); @@ -293,15 +293,14 @@ static void refineShapeCalculateResult(ShapeCalculateOp op, int resultNum, // change the size of the list. It might clobber some elements, which then // become dimensions with unknown size. if (auto setItem = dyn_cast(user)) { - int64_t index; // If the index is statically known, we can clobber only a single index. // Otherwise, we conservatively clobber all of them. - if (matchPattern(setItem.idx(), m_TorchConstantInt(&index)) && - isValidDim(index, listConstruct->getNumOperands())) { - clobberedElements.set(index); - } else { + llvm::Optional indexOpt = matchLegalConstantIndexIntoListOfSize( + setItem.idx(), listConstruct->getNumOperands()); + if (indexOpt) + clobberedElements.set(*indexOpt); + else clobberedElements.set(); - } continue; } // An unhandled op! We can't make any assumptions about the shape. diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 57e16dc18bb8..2a778285a80d 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -22,6 +22,17 @@ bool Torch::isValidDim(int64_t dim, int64_t inputRank) { return dim >= 0 && dim < inputRank; } +llvm::Optional +Torch::matchLegalConstantIndexIntoListOfSize(Value v, int64_t length) { + int64_t dim; + if (!matchPattern(v, m_TorchConstantInt(&dim))) + return llvm::None; + dim = toPositiveDim(dim, length); + if (!isValidDim(dim, length)) + return llvm::None; + return dim; +} + bool Torch::getListConstructElements(Value v, SmallVectorImpl &elems) { auto listConstruct = v.getDefiningOp(); if (!listConstruct) diff --git a/test/Dialect/Torch/simplify-shape-calculations.mlir b/test/Dialect/Torch/simplify-shape-calculations.mlir index ece9e87bac08..e18330ca8759 100644 --- a/test/Dialect/Torch/simplify-shape-calculations.mlir +++ b/test/Dialect/Torch/simplify-shape-calculations.mlir @@ -191,6 +191,25 @@ func @abstractly_interpret_list_ops$mutation_ops(%arg0: !torch.vtensor, %arg1: ! return %0 : !torch.vtensor } +// Test negative indexes with set_item op. +// CHECK-LABEL: func @abstractly_interpret_list_ops$neg_index_set_item( +// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %arg1, %arg2 : (!torch.int, !torch.int) -> !torch.list +// CHECK: torch.shape.calculate.yield.shapes %[[SHAPE]] : !torch.list +func @abstractly_interpret_list_ops$neg_index_set_item(%arg0: !torch.vtensor, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.vtensor { + %int1 = torch.constant.int 1 + %int-1 = torch.constant.int -1 + %int-2 = torch.constant.int -2 + %0 = torch.shape.calculate { + torch.shape.calculate.yield %arg0 : !torch.vtensor + } shapes { + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.aten._set_item.t %1, %int-1, %arg2 : !torch.list, !torch.int, !torch.int -> !torch.list + %3 = torch.aten._set_item.t %1, %int-2, %arg1 : !torch.list, !torch.int, !torch.int -> !torch.list + torch.shape.calculate.yield.shapes %1 : !torch.list + } : !torch.vtensor + return %0 : !torch.vtensor +} + // Test interspersed mutation and evaluation ops. // CHECK-LABEL: func @abstractly_interpret_list_ops$mix_mutation_and_evaluation_ops( // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %int0, %int1, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list