Skip to content

Commit

Permalink
[torch] Add edgecase for aten.shape_to_tensor for rank-0 input (llvm#…
Browse files Browse the repository at this point in the history
…2962)

Currently lowering uses `tensor.from_elements` which does not allow zero
inputs. In this case we return a `tensor.empty` operation.
  • Loading branch information
rsuderman authored Feb 28, 2024
1 parent 08bc013 commit dd673cf
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions lib/Conversion/TorchToTensor/TorchToTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ class ConvertAtenShapeToTensorPatternOp
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();

int64_t rank = operandTy.getRank();
if (rank == 0) {
rewriter.replaceOpWithNewOp<tensor::EmptyOp>(op, resultTy.getShape(),
resultTy.getElementType());
return success();
}

SmallVector<Value> dims;
for (int i = 0; i < rank; ++i) {
Value dim = rewriter.createOrFold<tensor::DimOp>(loc, operand, i);
Expand Down

0 comments on commit dd673cf

Please sign in to comment.