Skip to content

Commit

Permalink
[Torch Dialect] decompose all index_put-like op to aten.index_put.hac…
Browse files Browse the repository at this point in the history
…ked_twin for stricter semantics (llvm#3071)

This PR decomposes all index_put-like op to aten.index_put.hacked_twin for stricter semantics, i.e., no None index in indices argument.
  • Loading branch information
Vremold authored May 8, 2024
1 parent abef114 commit 346a536
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 201 deletions.
21 changes: 5 additions & 16 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -675,12 +675,12 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch,
return b.create<AtenViewOp>(loc, valuesTy, values, outDimsList);
}

class ConvertAten_IndexPutImplOp
: public OpConversionPattern<Aten_IndexPutImplOp> {
class ConvertAtenIndexPutHackedTwinOp
: public OpConversionPattern<AtenIndexPutHackedTwinOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(Aten_IndexPutImplOp op, OpAdaptor adaptor,
matchAndRewrite(AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Expand All @@ -699,17 +699,6 @@ class ConvertAten_IndexPutImplOp
return rewriter.notifyMatchFailure(
op, "unimplemented: the values tensor type must have sizes.");

// The unsafe should be either `False` or `none`.
if (!op.getUnsafe().getType().isa<Torch::NoneType>()) {
bool unsafe;
if (!matchPattern(op.getUnsafe(), m_TorchConstantBool(&unsafe)))
return rewriter.notifyMatchFailure(
op, "unimplemented: unsafe must be a constant");
else if (unsafe)
return rewriter.notifyMatchFailure(
op, "unimplemented: unsafe is expected to be false");
}

// The accumulate should be a torch constant of boolean type.
bool accumulate;
if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate)))
Expand Down Expand Up @@ -1621,8 +1610,8 @@ class ConvertTorchToTMTensor
RewritePatternSet patterns(context);
target.addIllegalOp<AtenBincountOp>();
patterns.add<ConvertAtenBincountOp>(typeConverter, context);
target.addIllegalOp<Aten_IndexPutImplOp>();
patterns.add<ConvertAten_IndexPutImplOp>(typeConverter, context);
target.addIllegalOp<AtenIndexPutHackedTwinOp>();
patterns.add<ConvertAtenIndexPutHackedTwinOp>(typeConverter, context);
target.addIllegalOp<AtenMaxPool2dWithIndicesBackwardOp>();
patterns.add<ConvertAtenMaxPool2dWithIndicesBackwardOp>(typeConverter,
context);
Expand Down
6 changes: 3 additions & 3 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3575,8 +3575,8 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
}

template <>
LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
Aten_IndexPutImplOp op, OpAdaptor adaptor,
LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// a = torch.tensor([[0, 1, 2, 3]])
// a[..., 1:] = torch.tensor([4, 5, 6])
Expand Down Expand Up @@ -5331,7 +5331,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenGatherOp);
INSERT_ATENOP_PATTERN(Aten_IndexPutImplOp);
INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp);
INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
INSERT_ATENOP_PATTERN(AtenAbsOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
Expand Down
Loading

0 comments on commit 346a536

Please sign in to comment.