Skip to content

Commit

Permalink
[torch] Support lowering torch.item to tensor.extract
Browse files Browse the repository at this point in the history
Extracting scalar values from tensors can be implemented via a lowering
to tensor.extract.
  • Loading branch information
rsuderman committed Jan 31, 2024
1 parent 25a5a22 commit 26713dd
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 1 deletion.
45 changes: 44 additions & 1 deletion lib/Conversion/TorchToTensor/TorchToTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,47 @@ using namespace mlir::torch::Torch;

namespace {

class ConvertAtenItemOp : public OpConversionPattern<AtenItemOp> {
public:
using OpConversionPattern<AtenItemOp>::OpConversionPattern;
using OpAdaptor = typename AtenItemOp::Adaptor;
LogicalResult
matchAndRewrite(AtenItemOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto operand = adaptor.getOperands()[0];
auto operandTy = cast<RankedTensorType>(operand.getType());
auto torchDTy = cast<ValueTensorType>(op.getOperand().getType()).getDtype();

if (operandTy.getNumElements() != 1)
return rewriter.notifyMatchFailure(op, "expected only one item");

auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
auto rank = operandTy.getRank();
llvm::SmallVector<Value> indices(rank, zeroIdx);

Value extract = rewriter.create<tensor::ExtractOp>(
op.getLoc(), operandTy.getElementType(), operand, indices);
auto extractTy = extract.getType();
if (isa<mlir::IntegerType>(extractTy) && !extractTy.isInteger(64)) {
if (torchDTy.isSignlessInteger()) {
extract = rewriter.create<arith::ExtUIOp>(
op.getLoc(), rewriter.getIntegerType(64), extract);
} else {
extract = rewriter.create<arith::ExtSIOp>(
op.getLoc(), rewriter.getIntegerType(64), extract);
}
}

if (isa<mlir::FloatType>(extractTy) && !extractTy.isF64()) {
extract = rewriter.create<arith::ExtFOp>(op.getLoc(),
rewriter.getF64Type(), extract);
}

rewriter.replaceOp(op, extract);
return success();
}
};

class ConvertAtenShapeToTensorPatternOp
: public OpConversionPattern<Aten_ShapeAsTensorOp> {
public:
Expand Down Expand Up @@ -70,14 +111,16 @@ class ConvertTorchToTensor
ConversionTarget target(*context);
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<tensor::TensorDialect>();
target.addIllegalOp<Torch::AtenItemOp>();
target.addIllegalOp<Torch::Aten_ShapeAsTensorOp>();

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);

RewritePatternSet patterns(context);
patterns.add<ConvertAtenShapeToTensorPatternOp>(typeConverter, context);
patterns.add<ConvertAtenShapeToTensorPatternOp, ConvertAtenItemOp>(
typeConverter, context);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/TorchConversion/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
Expand Down Expand Up @@ -76,6 +77,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToTensorPass());
pm.addPass(createConvertTorchConversionToMLProgramPass());
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());

Expand Down
38 changes: 38 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,41 @@ def forward(self, val):
@register_test_case(module_factory=lambda: AtenIntTensorCharDtypeModule())
def AtenIntTensorCharDtypeModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))

# ==============================================================================

class AtenItemIntOpModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([], torch.int8, True),
])

def forward(self, val):
return int(val)

@register_test_case(module_factory=lambda: AtenItemIntOpModule())
def AtenItemIntOpModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))

# ==============================================================================

class AtenItemFpOpModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([], torch.float, True),
])

def forward(self, val):
return float(val)

@register_test_case(module_factory=lambda: AtenItemFpOpModule())
def AtenItemFpOpModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1))

0 comments on commit 26713dd

Please sign in to comment.