diff --git a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h index 14e9202222c6..b59d183b4084 100644 --- a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h +++ b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h @@ -101,6 +101,10 @@ LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter, Location loc, SmallVector dimensions, Value input, Value &result); +// Flips an input tensor based on the values of axis list. +Value flipTensor(PatternRewriter &rewriter, Location loc, Value input, + SmallVector axis); + } // namespace torch_to_linalg } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 5542e0fc642f..ac1707ec23a6 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -40,6 +40,7 @@ static int64_t productReduce(ArrayRef a) { template LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, + int64_t &dim, SmallVector &resultShape, SmallVector &offsets, SmallVector &strides) { @@ -51,7 +52,6 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, Value one = rewriter.create(loc, 1); Value negone = rewriter.create(loc, -1); - int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("unimplemented: dim is not constant"); @@ -1857,14 +1857,46 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern { RankedTensorType resultType = cast( typeConverter->convertType(op->getResult(0).getType())); - SmallVector resultShape; - SmallVector offsets; - SmallVector strides; + SmallVector resultShape, offsets, strides; + int64_t dim; if (failed(prepareArgumentsForSlicingOp( - op, adaptor, rewriter, resultShape, offsets, strides))) { + op, adaptor, rewriter, dim, resultShape, offsets, strides))) { return failure(); } + + // If stride is negative, then flip the input tensor corresponding to that + // dim, update the stride for flipped tensor by multiplying it by -1, and + // update the offset as follows: + // flipped_offset = input_shape[dim] - (result_shape[dim] * flipped_stride) + // + // For example: + // Input = [0, 1, 2, 3, 4, 5] + // stride = [-2], result_shape = [2], offset = [3] + // Result = [3, 1] + // After flipping: + // Input = [5, 4, 3, 2, 1, 0] + // stride = [2], result_shape = [2], offset = [6 - (2 * 2)] = [2] + // Result = [3, 1] + + Value flippedInput = torch_to_linalg::flipTensor(rewriter, loc, input, + SmallVector{dim}); + Value cstDim = rewriter.create(loc, dim); + Value zero = rewriter.create(loc, 0); + Value isNegativeStride = rewriter.create( + loc, arith::CmpIPredicate::slt, strides[dim], zero); + strides[dim] = rewriter.create(loc, strides[dim]); + Value resShapeMulStride = + rewriter.create(loc, resultShape[dim], strides[dim]); + Value inputDim = rewriter.create(loc, input, cstDim); + Value flippedOffset = + rewriter.create(loc, inputDim, resShapeMulStride); + offsets[dim] = rewriter.create( + loc, isNegativeStride, flippedOffset, offsets[dim]); + + input = rewriter.create(loc, isNegativeStride, + flippedInput, input); + SmallVector dynShape(resultType.getRank(), ShapedType::kDynamic); auto sliceType = RankedTensorType::get( dynShape, resultType.getElementType(), resultType.getEncoding()); @@ -2095,12 +2127,11 @@ class ConvertAtenSliceScatterOp RankedTensorType resultType = cast( typeConverter->convertType(op->getResult(0).getType())); - SmallVector resultShape; - SmallVector offsets; - SmallVector strides; + SmallVector resultShape, offsets, strides; + int64_t dim; if (failed(prepareArgumentsForSlicingOp( - op, adaptor, rewriter, resultShape, offsets, strides))) { + op, adaptor, rewriter, dim, resultShape, offsets, strides))) { return failure(); } diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 52765411bd73..fc910fa9d3f2 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -222,14 +222,9 @@ class ConvertAtenFlipOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - MLIRContext *context = op.getContext(); Value self = adaptor.getSelf(); auto selfRank = cast(adaptor.getSelf().getType()).getRank(); - Type elementType = - cast(adaptor.getSelf().getType()).getElementType(); - Value c1 = - rewriter.create(loc, rewriter.getIndexAttr(1)); SmallVector axis; if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis))) @@ -242,40 +237,8 @@ class ConvertAtenFlipOp : public OpConversionPattern { } } - // Only used to calculate flipped values, i.e. those on the flip axes. Other - // dims won't be used. - SmallVector dims = getTensorSizes(rewriter, loc, self); - for (auto flipDim : axis) - dims[flipDim] = rewriter.create(loc, dims[flipDim], c1); - - Value initTensor = createZeroInitTensor( - rewriter, loc, getTensorSizes(rewriter, loc, self), elementType); - - SmallVector iteratorTypes( - selfRank, utils::IteratorType::parallel); - SmallVector indexingMaps( - 2, AffineMap::getMultiDimIdentityMap(selfRank, context)); - Value flipped = - rewriter - .create( - loc, self.getType(), self, initTensor, indexingMaps, - iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - SmallVector indices; - for (auto i = 0; i < selfRank; i++) - indices.push_back(b.create(loc, i)); - for (auto flipDim : axis) { - indices[flipDim] = b.create( - loc, dims[flipDim], indices[flipDim]); - } - Value res = b.create(loc, self, indices) - .getResult(); - b.create(loc, res); - }) - .getResult(0); - + Value flipped = torch_to_linalg::flipTensor(rewriter, loc, self, axis); rewriter.replaceOpWithNewOp(op, self.getType(), flipped); - return success(); } }; diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 6ef947d890cd..18e8fb449ef5 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -620,3 +620,44 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op, .getResult(0); return success(); } + +// Flips an input tensor based on the values of axis list. +Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc, + Value input, SmallVector axis) { + Value c1 = rewriter.create(loc, rewriter.getIndexAttr(1)); + Type elementType = cast(input.getType()).getElementType(); + auto selfRank = cast(input.getType()).getRank(); + + // Only used to calculate flipped values, i.e. those on the flip axes. Other + // dims won't be used. + SmallVector dims = getTensorSizes(rewriter, loc, input); + for (auto flipDim : axis) + dims[flipDim] = rewriter.create(loc, dims[flipDim], c1); + + Value initTensor = createZeroInitTensor( + rewriter, loc, getTensorSizes(rewriter, loc, input), elementType); + + SmallVector iteratorTypes(selfRank, + utils::IteratorType::parallel); + SmallVector indexingMaps( + 2, AffineMap::getMultiDimIdentityMap(selfRank, rewriter.getContext())); + Value flipped = + rewriter + .create( + loc, input.getType(), input, initTensor, indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector indices; + for (auto i = 0; i < selfRank; i++) + indices.push_back(b.create(loc, i)); + for (auto flipDim : axis) { + indices[flipDim] = b.create(loc, dims[flipDim], + indices[flipDim]); + } + Value res = b.create(loc, input, indices) + .getResult(); + b.create(loc, res); + }) + .getResult(0); + return flipped; +}