[stablehlo] Support PrimsCollapseOp and PrimsSplitDimOp in stablehlo (l…
Xinyu Yang authored and archana-ramalingam committed May 8, 2024
1 parent bd440fe commit 1288f00
Showing 4 changed files with 181 additions and 30 deletions.
11 changes: 11 additions & 0 deletions include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h
@@ -69,6 +69,17 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
                                  Value tensor, ArrayRef<int64_t> inputUnsqzDims,
                                  size_t dimSizeIndexBits);
 
+// Get a tensor that collapses the specified dimensions of the input tensor
+FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
+                                Value tensor, int64_t collapseStartDim,
+                                int64_t collapseEndDim,
+                                size_t dimSizeIndexBits);
+
+// Get a tensor that splits the specified dimensions of the input tensor
+FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
+                             Value tensor, int64_t splitDim,
+                             int64_t outerLength, size_t dimSizeIndexBits);
+
 Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                          const APFloat &constant, Value shape,
                          TensorType outType);
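For orientation, the two prims ops these helpers serve have simple shape semantics; a minimal PyTorch sketch (assuming a build that exposes the prims namespace through torch.ops; not part of the patch):

import torch

x = torch.randn(2, 3, 4)

# prims.collapse folds dims [start, end] into one: (2, 3, 4) -> (2, 12)
y = torch.ops.prims.collapse(x, 1, 2)
assert y.shape == (2, 12)

# prims.split_dim splits dim into (outer_length, inner): (2, 12) -> (2, 3, 4)
z = torch.ops.prims.split_dim(y, 1, 3)
assert z.shape == (2, 3, 4)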
131 changes: 131 additions & 0 deletions lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
@@ -9,6 +9,7 @@
 #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/StablehloOps.h"
 #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
@@ -306,6 +307,136 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
       .getResult();
 }
 
+FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
+                                Value tensor, int64_t collapseStartDim,
+                                int64_t collapseEndDim,
+                                size_t dimSizeIndexBits) {
+
+  auto dimSizesInfo =
+      getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
+
+  if (failed(dimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+
+  auto dimSizes = *dimSizesInfo;
+  int64_t rank = dimSizes.size();
+
+  collapseStartDim = toPositiveDim(collapseStartDim, rank);
+  collapseEndDim = toPositiveDim(collapseEndDim, rank);
+
+  int64_t newRank = rank - (collapseEndDim - collapseStartDim + 1);
+
+  auto loc = op->getLoc();
+  auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
+  auto oldShape = rankTy.getShape();
+  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
+
+  std::vector<Value> newDimSizes;
+  std::vector<int64_t> newShape;
+  newDimSizes.reserve(newRank);
+  newShape.reserve(newRank);
+
+  Value collapseDimSize = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, 1));
+  int64_t collapseShape = 1;
+
+  for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) {
+    if (k < 0 || k >= rank) {
+      return rewriter.notifyMatchFailure(
+          op, "collapse dimensions must be within the rank of the tensor");
+    }
+    if (collapseShape == ShapedType::kDynamic ||
+        oldShape[k] == ShapedType::kDynamic) {
+      collapseShape = ShapedType::kDynamic;
+    } else {
+      collapseShape *= oldShape[k];
+    }
+    collapseDimSize =
+        rewriter.create<arith::MulIOp>(loc, collapseDimSize, dimSizes[k]);
+  }
+
+  for (int64_t k = 0; k < collapseStartDim; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+  newDimSizes.push_back(collapseDimSize);
+  newShape.push_back(collapseShape);
+  for (int64_t k = collapseEndDim + 1; k < rank; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+
+  auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
+  auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
+  return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
+      .getResult();
+}
+
+// TODO: support splitDim & outerLength to be Value
+FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
+                             Value tensor, int64_t splitDim,
+                             int64_t outerLength, size_t dimSizeIndexBits) {
+  auto dimSizesInfo =
+      getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
+
+  if (failed(dimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+
+  auto dimSizes = *dimSizesInfo;
+  int64_t rank = dimSizes.size();
+  splitDim = toPositiveDim(splitDim, rank);
+
+  auto loc = op->getLoc();
+  auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
+  auto oldShape = rankTy.getShape();
+  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
+
+  if (splitDim < 0 || splitDim >= rank) {
+    return rewriter.notifyMatchFailure(
+        op, "split dimensions must be within the rank of the tensor");
+  }
+
+  int64_t newRank = rank + 1;
+  auto outerLengthValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, outerLength));
+
+  auto innerLengthValue = rewriter.create<arith::DivSIOp>(
+      loc, dimSizes[splitDim], outerLengthValue);
+
+  int64_t originShape = oldShape[splitDim];
+  int64_t outerShape = outerLength;
+  int64_t innerShape = originShape == ShapedType::kDynamic
+                           ? ShapedType::kDynamic
+                           : originShape / outerLength;
+
+  std::vector<Value> newDimSizes;
+  std::vector<int64_t> newShape;
+
+  newDimSizes.reserve(newRank);
+  newShape.reserve(newRank);
+
+  for (int64_t k = 0; k < splitDim; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+  newDimSizes.push_back(outerLengthValue);
+  newShape.push_back(outerShape);
+  newDimSizes.push_back(innerLengthValue);
+  newShape.push_back(innerShape);
+
+  for (int64_t k = splitDim + 1; k < rank; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+
+  auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
+  auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
+  return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
+      .getResult();
+}
+
 Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                          const APFloat &constant, Value shape,
                          TensorType outType) {
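The static-shape bookkeeping in these two helpers reduces to a few lines of arithmetic; a pure-Python model (illustrative only, not part of the patch; DYNAMIC stands in for ShapedType::kDynamic). Note that splitTensor derives the inner extent with a runtime signed division (arith.divsi) of the dimension size, so outerLength is implicitly assumed to divide that dimension evenly:

DYNAMIC = -1  # stand-in for ShapedType::kDynamic

def collapse_shape(shape, start, end):
    # Fold dims [start, end] into one; any dynamic input dim makes it dynamic.
    folded = 1
    for d in shape[start:end + 1]:
        folded = DYNAMIC if DYNAMIC in (folded, d) else folded * d
    return shape[:start] + [folded] + shape[end + 1:]

def split_shape(shape, dim, outer_length):
    # Split `dim` into (outer_length, inner), where inner = size // outer_length.
    inner = DYNAMIC if shape[dim] == DYNAMIC else shape[dim] // outer_length
    return shape[:dim] + [outer_length, inner] + shape[dim + 1:]

assert collapse_shape([2, 3, 4], 1, 2) == [2, 12]
assert collapse_shape([2, DYNAMIC, 4], 1, 2) == [2, DYNAMIC]
assert split_shape([2, 12, 5], 1, 3) == [2, 3, 4, 5]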
61 changes: 36 additions & 25 deletions lib/Conversion/TorchToStablehlo/ViewLike.cpp
@@ -414,34 +414,44 @@ LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "only constant end is currently supported");
 
-  start = toPositiveDim(start, rank);
-  end = toPositiveDim(end, rank);
-  SmallVector<int64_t, 4> dims;
-  dims.reserve(rank);
-  for (int r = 0; r < start; ++r)
-    dims.push_back(r);
-  int64_t collapsedDimSize = 1;
-  for (int r = start; r <= end; ++r) {
-    if (selfType.getShape()[r] == ShapedType::kDynamic)
-      return rewriter.notifyMatchFailure(
-          op, "the size of the dimension being collapsed is can't be unknown");
-    collapsedDimSize *= selfType.getShape()[r];
+  auto collapseTensorInfo = hlo::collapseTensor(
+      rewriter, op, adaptor.getA(), start, end, options.dimSizeIndexBits);
+  if (failed(collapseTensorInfo))
+    return rewriter.notifyMatchFailure(op, "failed to create collapsed tensor");
+
+  rewriter.replaceOp(op, *collapseTensorInfo);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
+    PrimsSplitDimOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
+  if (!selfType) {
+    return op.emitError("only tensor types are currently supported");
   }
-  dims.push_back(collapsedDimSize);
-  for (int r = end + 1; r < rank; ++r)
-    dims.push_back(r);
 
-  auto newDimSizesInfo = hlo::getDimSizesOfTensor(
-      rewriter, op, adaptor.getA(), dims, options.dimSizeIndexBits);
-  if (failed(newDimSizesInfo))
+  auto rank = selfType.getRank();
+  if (rank == 0)
     return rewriter.notifyMatchFailure(
-        op, "failed to get dimension sizes of the input");
-  auto newDimSizes = *newDimSizesInfo;
-  auto stablehloShape =
-      rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
-  rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-      op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
-      stablehloShape);
+        op, "the rank of tensor must be greater than 0");
+
+  int64_t dim, outerLength;
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(
+        op, "only constant dim is currently supported");
+  if (!matchPattern(op.getOuterLength(), m_TorchConstantInt(&outerLength)))
+    return rewriter.notifyMatchFailure(
+        op, "only constant outerLength is currently supported");
+
+  auto splitTensorInfo = hlo::splitTensor(
+      rewriter, op, adaptor.getA(), dim, outerLength, options.dimSizeIndexBits);
+
+  if (failed(splitTensorInfo))
+    return rewriter.notifyMatchFailure(op, "failed to create split tensor");
+
+  rewriter.replaceOp(op, *splitTensorInfo);
   return success();
 }

@@ -458,6 +468,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
   INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
   INSERT_ATENOP_PATTERN(PrimsCollapseOp);
+  INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_VIEW_OP_PATTERN(AtenOp)                                         \
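With the patterns registered, a module built around these prims can target the StableHLO backend end to end. A rough sketch via the legacy pt1 TorchScript path (torch_mlir.compile and the "stablehlo" output type are assumptions about the API of that era, and the prims op must trace through TorchScript for this to work):

import torch
import torch_mlir  # legacy projects/pt1 entry point (assumed)

class Split(torch.nn.Module):
    def forward(self, x):
        # dim 1 of size 12 -> (3, 4)
        return torch.ops.prims.split_dim(x, 1, 3)

compiled = torch_mlir.compile(Split(), torch.randn(2, 12),
                              output_type="stablehlo")
print(compiled)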
8 changes: 3 additions & 5 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -678,11 +678,6 @@
     "NumToTensorIntModule_basic",
     "NumelModule_basic",
     "NumelZeroRankModule_basic",
-    "PixelShuffleModuleFullDynamic_basic",
-    "PixelShuffleModuleSpatiallyDynamic_basic",
-    "PixelShuffleModuleSpatiallyStatic_basic",
-    "PixelShuffleModuleStaticRank3Int64_basic",
-    "PixelShuffleModuleStaticRank4Float32_basic",
     "PowIntFloatModule_basic",
     "PrimMaxIntModule_basic",
     "PrimMinIntDynamicModule_basic",
@@ -1157,6 +1152,8 @@
     "Permute0RankModule_basic",
     "PermuteModule_basic",
     "PermuteNegativeIndexModule_basic",
+    "PixelShuffleModuleStaticRank3Int64_basic",
+    "PixelShuffleModuleStaticRank4Float32_basic",
     "PowIntFloatModule_basic",
     "PrimListUnpackNumMismatchModule_basic",
     "PrimMaxIntModule_basic",
@@ -1240,6 +1237,7 @@
     "SliceWholeTensorModule_basic",
     "SortIntListReverse_basic",
     "SortIntList_basic",
+    "SplitDimStaticModule_basic",
     "SplitTensorGetItem_Module_basic",
     "SplitTensorLastSmallerModule_basic",
     "SplitTensorListUnpackModule_basic",
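Beyond the new SplitDimStaticModule_basic entry, the PixelShuffle tests move between lists presumably because aten.pixel_shuffle is decomposed in terms of exactly these prims (split a channel dimension, permute, collapse). A shape-level sanity check in plain PyTorch:

import torch

# pixel_shuffle maps (N, C*r*r, H, W) -> (N, C, H*r, W*r); here r = 3
x = torch.randn(1, 9, 4, 4)
y = torch.nn.functional.pixel_shuffle(x, 3)
assert y.shape == (1, 1, 12, 12)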
