diff --git a/lib/Conversion/TorchToTensor/TorchToTensor.cpp b/lib/Conversion/TorchToTensor/TorchToTensor.cpp
index 417fd17fcb86..1b5341028c6d 100644
--- a/lib/Conversion/TorchToTensor/TorchToTensor.cpp
+++ b/lib/Conversion/TorchToTensor/TorchToTensor.cpp
@@ -28,6 +28,47 @@ using namespace mlir::torch::Torch;
 
 namespace {
 
+class ConvertAtenItemOp : public OpConversionPattern<AtenItemOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  using OpAdaptor = typename AtenItemOp::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenItemOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto operand = adaptor.getOperands()[0];
+    auto operandTy = cast<RankedTensorType>(operand.getType());
+    auto torchDTy = cast<ValueTensorType>(op.getOperand().getType()).getDtype();
+
+    if (operandTy.getNumElements() != 1)
+      return rewriter.notifyMatchFailure(op, "expected only one item");
+
+    auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+    auto rank = operandTy.getRank();
+    llvm::SmallVector<Value> indices(rank, zeroIdx);
+
+    Value extract = rewriter.create<tensor::ExtractOp>(
+        op.getLoc(), operandTy.getElementType(), operand, indices);
+    auto extractTy = extract.getType();
+    if (isa<mlir::IntegerType>(extractTy) && !extractTy.isInteger(64)) {
+      if (torchDTy.isSignlessInteger()) {
+        extract = rewriter.create<arith::ExtUIOp>(
+            op.getLoc(), rewriter.getIntegerType(64), extract);
+      } else {
+        extract = rewriter.create<arith::ExtSIOp>(
+            op.getLoc(), rewriter.getIntegerType(64), extract);
+      }
+    }
+
+    if (isa<mlir::FloatType>(extractTy) && !extractTy.isF64()) {
+      extract = rewriter.create<arith::ExtFOp>(op.getLoc(),
+                                               rewriter.getF64Type(), extract);
+    }
+
+    rewriter.replaceOp(op, extract);
+    return success();
+  }
+};
+
 class ConvertAtenShapeToTensorPatternOp
     : public OpConversionPattern<Aten_ShapeAsTensorOp> {
 public:
@@ -70,6 +111,7 @@ class ConvertTorchToTensor
     ConversionTarget target(*context);
     target.addLegalDialect<arith::ArithDialect>();
     target.addLegalDialect<tensor::TensorDialect>();
+    target.addIllegalOp<AtenItemOp>();
     target.addIllegalOp<Aten_ShapeAsTensorOp>();
 
     TypeConverter typeConverter;
@@ -77,7 +119,8 @@ class ConvertTorchToTensor
     TorchConversion::setupBackendTypeConversion(target, typeConverter);
 
     RewritePatternSet patterns(context);
-    patterns.add<ConvertAtenShapeToTensorPatternOp>(typeConverter, context);
+    patterns.add<ConvertAtenItemOp, ConvertAtenShapeToTensorPatternOp>(
+        typeConverter, context);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
index 91d468a6941f..673d7083f585 100644
--- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
@@ -20,6 +20,7 @@
 #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
 #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
 #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
+#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h"
 #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
@@ -76,6 +77,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
   pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
   pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
   pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
+  pm.addNestedPass<func::FuncOp>(createConvertTorchToTensorPass());
   pm.addPass(createConvertTorchConversionToMLProgramPass());
   pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py
index 74717d99fb4e..303c3f0a801a 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py
@@ -428,3 +428,41 @@ def forward(self, val):
 @register_test_case(module_factory=lambda: AtenIntTensorCharDtypeModule())
 def AtenIntTensorCharDtypeModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))
+
+# ==============================================================================
+
+class AtenItemIntOpModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.int8, True),
+    ])
+
+    def forward(self, val):
+        return int(val)
+
+@register_test_case(module_factory=lambda: AtenItemIntOpModule())
+def AtenItemIntOpModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))
+
+# ==============================================================================
+
+class AtenItemFpOpModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.float, True),
+    ])
+
+    def forward(self, val):
+        return float(val)
+
+@register_test_case(module_factory=lambda: AtenItemFpOpModule())
+def AtenItemFpOpModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1))
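
Note: for reference, a rough hand-written sketch of the IR the new ConvertAtenItemOp pattern is expected to produce for the rank-0 si8 case exercised by AtenItemIntOpModule. The value names (%arg0, %t, %elem, %item) and the exact surrounding IR are illustrative assumptions, not output copied from the patch or its tests:

    // Before conversion (torch dialect): take the single element of a 0-d int8 tensor.
    %item = torch.aten.item %arg0 : !torch.vtensor<[],si8> -> !torch.int

    // After ConvertAtenItemOp, operating on the converted builtin tensor %t : tensor<i8>:
    // the scalar is extracted and sign-extended to the i64 backing !torch.int.
    %elem = tensor.extract %t[] : tensor<i8>
    %item_i64 = arith.extsi %elem : i8 to i64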