Skip to content

Commit

Permalink
[Torch] support 1d aten tensor shape and dtype infer (#3776)
Browse files Browse the repository at this point in the history
  • Loading branch information
yyp0 authored Oct 12, 2024
1 parent ab62f35 commit b176939
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
57 changes: 57 additions & 0 deletions lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,62 @@ class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
};
} // namespace

namespace {
class InferTensorOp : public OpRewritePattern<AtenTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenTensorOp op,
PatternRewriter &rewriter) const override {
auto context = op.getContext();
auto loc = op.getLoc();
auto result = op.getResult();
auto resultType = cast<BaseTensorType>(result.getType());
if (resultType.hasSizes() && resultType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "The result of aten.tensor is already a BaseTensorType.");
}

auto inputList = op.getOperand(0);
auto listConstruct = inputList.getDefiningOp<PrimListConstructOp>();
if (!listConstruct) {
return rewriter.notifyMatchFailure(
op, "The operand 0 of aten.tensor is not PrimListConstructOp.");
}

// Currently only support the 1d input list.
SmallVector<int64_t> sizes;
sizes.push_back(listConstruct->getOperands().size());
FailureOr<Type> torchType;
auto eleType = listConstruct->getOperands()[0].getType();
if (isa<Torch::IntType>(eleType)) {
torchType = getTypeForScalarType(op->getContext(),
torch_upstream::ScalarType::Long);
} else if (isa<Torch::FloatType>(eleType)) {
torchType = getTypeForScalarType(op->getContext(),
torch_upstream::ScalarType::Float);
} else {
return rewriter.notifyMatchFailure(
op, "Currently only support Int and Float Type.");
}
auto newResultType = ValueTensorType::get(context, sizes, *torchType);

Value originalTypedValue;
for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) {
if (!originalTypedValue) {
rewriter.setInsertionPointAfter(op);
originalTypedValue =
rewriter.create<TensorStaticInfoCastOp>(loc, resultType, result);
}
use.set(originalTypedValue);
}

result.setType(newResultType);

return success();
}
};
} // namespace

static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
int resultNum,
PatternRewriter &rewriter) {
Expand Down Expand Up @@ -135,6 +191,7 @@ class SimplifyShapeCalculationsPass
populateFoldPrimUncheckedCastOpPattern(patterns, context);
patterns.insert<DecomposeAtenSizeOp>(context);
patterns.insert<RefineShapeCalculateOp>(context);
patterns.insert<InferTensorOp>(context);

PrimIfOp::getCanonicalizationPatterns(patterns, context);
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
Expand Down
24 changes: 24 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5621,6 +5621,30 @@ def ConstantBoolParameterModule_basic(module, tu: TestUtils):
# ==============================================================================


class TensorAlloc1dStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([2, 4, 6], torch.int, True),
]
)
def forward(self, x):
res = torch.tensor([x.shape[0]])
return res


@register_test_case(module_factory=lambda: TensorAlloc1dStaticModule())
def TensorAlloc1dStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 6))


# ==============================================================================


class ScalarTensorFloat32Module(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit b176939

Please sign in to comment.