From dd673cfa8de6f215e27eedca54e53fb2f114b65d Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 28 Feb 2024 09:47:06 -0800 Subject: [PATCH] [torch] Add edgecase for aten.shape_to_tensor for rank-0 input (#2962) Currently lowering uses `tensor.from_elements` which does not allow zero inputs. In this case we return a `tensor.empty` operation. --- lib/Conversion/TorchToTensor/TorchToTensor.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lib/Conversion/TorchToTensor/TorchToTensor.cpp b/lib/Conversion/TorchToTensor/TorchToTensor.cpp index 1b5341028c6d..8b934ccb0484 100644 --- a/lib/Conversion/TorchToTensor/TorchToTensor.cpp +++ b/lib/Conversion/TorchToTensor/TorchToTensor.cpp @@ -84,6 +84,12 @@ class ConvertAtenShapeToTensorPatternOp getTypeConverter()->convertType(op.getType()).cast(); int64_t rank = operandTy.getRank(); + if (rank == 0) { + rewriter.replaceOpWithNewOp(op, resultTy.getShape(), + resultTy.getElementType()); + return success(); + } + SmallVector dims; for (int i = 0; i < rank; ++i) { Value dim = rewriter.createOrFold(loc, operand, i);