[Torch Dialect] Decompose AtenTriuOp (#2561)
Decompose `aten.triu` like:
```
import torch

def my_triu(x, diag):
    rows = torch.ops.aten.size(x, -2)
    cols = torch.ops.aten.size(x, -1)

    row_indices = torch.ops.aten.arange(rows).unsqueeze(1)
    col_indices = torch.ops.aten.arange(cols).unsqueeze(0)

    cond = torch.ops.aten.ge(
        col_indices, torch.ops.aten.add(row_indices, diag))
    return torch.ops.aten.where(cond, x, 0)

x = torch.rand(5, 7)
assert torch.allclose(my_triu(x, 0), torch.triu(x, 0))
assert torch.allclose(my_triu(x, 1), torch.triu(x, 1))
assert torch.allclose(my_triu(x, 2), torch.triu(x, 2))
assert torch.allclose(my_triu(x, -1), torch.triu(x, -1))
```
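
The comparison relies on broadcasting: `row_indices` has shape (rows, 1) and `col_indices` has shape (1, cols), so `cond` is a (rows, cols) mask that is true exactly where `col >= row + diag`. A minimal standalone sketch of the mask (names and sizes here are illustrative):

```python
import torch

# Mask for a 4x4 triu with diagonal=1: entry (i, j) is kept when j >= i + 1.
rows, cols, diag = 4, 4, 1
row_idx = torch.arange(rows).unsqueeze(1)  # shape (4, 1)
col_idx = torch.arange(cols).unsqueeze(0)  # shape (1, 4)
mask = col_idx >= row_idx + diag           # broadcasts to shape (4, 4)
print(mask.int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)
```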

---------

Co-authored-by: LiuYuanqiang <liuyuanqiang.yqliu@bytedance.com>
Mi-Jiazhi and qingyunqu authored Nov 29, 2023
1 parent 49fdc1a commit f7a92d3
Showing 3 changed files with 104 additions and 0 deletions.
57 changes: 57 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -246,6 +246,62 @@ class DecomposeAtenAmaxOp : public OpRewritePattern<AtenAmaxOp> {
};
} // end namespace

namespace {
class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenTriuOp op,
                                PatternRewriter &rewriter) const override {
    MLIRContext *context = op.getContext();
    Location loc = op.getLoc();
    Value input = op.getSelf();
    auto inputType = input.getType().cast<BaseTensorType>();
    if (!inputType.hasSizes() || !inputType.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "should have shape and dtype");
    }
    if (inputType.getSizes().size() < 2) {
      return rewriter.notifyMatchFailure(op, "rank of the tensor should be >= 2");
    }

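    // Intermediate values use a tensor type with no static sizes or dtype.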
    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
    Value cstZero =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
    Value cstOne =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
    Value none = rewriter.create<ConstantNoneOp>(loc);

    Value rowDim = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(-2));
    Value colDim = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(-1));
    Value rowSize = rewriter.create<AtenSizeIntOp>(loc, input, rowDim);
    Value colSize = rewriter.create<AtenSizeIntOp>(loc, input, colDim);

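    // Build index vectors [0, rows) and [0, cols).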
    Value rowArange = rewriter.create<AtenArangeOp>(
        loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none,
        /*device=*/none, /*pin_memory=*/none);
    Value colArange = rewriter.create<AtenArangeOp>(
        loc, baseType, colSize, /*dtype=*/none, /*layout=*/none,
        /*device=*/none, /*pin_memory=*/none);

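    // Unsqueeze to shapes (rows, 1) and (1, cols) so the comparison broadcasts.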
    Value unsqueezeRowArange =
        rewriter.create<AtenUnsqueezeOp>(loc, baseType, rowArange, cstOne);
    Value unsqueezeColArange =
        rewriter.create<AtenUnsqueezeOp>(loc, baseType, colArange, cstZero);

    Value unsqueezeRowArangePlusDiagonal = rewriter.create<AtenAddScalarOp>(
        loc, baseType, unsqueezeRowArange, op.getDiagonal(), cstOne);

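    // cond[i][j] = (j >= i + diagonal).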
    Value condTensor = rewriter.create<AtenGeTensorOp>(
        loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);

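    // Keep the input on and above the shifted diagonal; fill the rest with zero.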
    rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(
        op, op.getResult().getType(), condTensor, input, cstZero);
    return success();
  }
};
} // namespace
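For reference, a hedged PyTorch sketch of the op sequence this pattern emits, one aten overload per line (the function name `decomposed_triu` is illustrative, not part of the patch):

```python
import torch

def decomposed_triu(x: torch.Tensor, diagonal: int) -> torch.Tensor:
    rows = torch.ops.aten.size(x, -2)                          # AtenSizeIntOp
    cols = torch.ops.aten.size(x, -1)                          # AtenSizeIntOp
    row_arange = torch.ops.aten.arange(rows)                   # AtenArangeOp
    col_arange = torch.ops.aten.arange(cols)                   # AtenArangeOp
    rows_2d = torch.ops.aten.unsqueeze(row_arange, 1)          # AtenUnsqueezeOp
    cols_2d = torch.ops.aten.unsqueeze(col_arange, 0)          # AtenUnsqueezeOp
    shifted = torch.ops.aten.add.Scalar(rows_2d, diagonal, 1)  # AtenAddScalarOp
    cond = torch.ops.aten.ge.Tensor(cols_2d, shifted)          # AtenGeTensorOp
    return torch.ops.aten.where.ScalarOther(cond, x, 0)        # AtenWhereScalarOtherOp

# Sanity check against the builtin:
x = torch.rand(5, 7)
assert torch.allclose(decomposed_triu(x, 1), torch.triu(x, 1))
```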

namespace {
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
public:
@@ -5817,6 +5873,7 @@ class DecomposeComplexOpsPass
    addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);

    GreedyRewriteConfig config;
    config.useTopDownTraversal = true;
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -500,6 +500,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  target.addIllegalOp<AtenTypeAsOp>();
  target.addIllegalOp<AtenTileOp>();
  target.addIllegalOp<AtenReshapeAsOp>();
  target.addIllegalOp<AtenTriuOp>();
  for (auto &opName : backendLegalOpsSet) {
    target.addLegalOp(
        OperationName(kTorchOpPrefix + opName.first().str(), context));
46 changes: 46 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -3251,6 +3251,52 @@ def AtenTriuWithPosDiagonalModule_basic(module, tu: TestUtils):
# ==============================================================================


class TriuModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.triu(x, 1)


@register_test_case(module_factory=lambda: TriuModule())
def TriuModule_basic(module, tu: TestUtils):
    x = torch.tensor([[ 0.5876, -0.0794, -1.8373,  0.6654, 0.2],
                      [-0.2447,  0.9556, -1.2919,  1.3378, 0.3],
                      [ 0.4333,  0.3146,  0.6576, -1.0432, 0.4],
                      [-0.9888, torch.nan, torch.inf, -torch.inf, 0.5]])
    module.forward(x)


# ==============================================================================


class TriuBroadcastModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5, 6], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.triu(x, 2)


@register_test_case(module_factory=lambda: TriuBroadcastModule())
def TriuBroadcastModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5, 6))
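
The rank-4 test exercises the broadcast path: the mask is built only from the last two dimensions, and `where` broadcasts it across the leading batch dimensions. A quick equivalence check in plain PyTorch (a sketch, not part of the test suite):

```python
import torch

x = torch.rand(3, 4, 5, 6)
row_idx = torch.arange(5).unsqueeze(1)  # shape (5, 1)
col_idx = torch.arange(6).unsqueeze(0)  # shape (1, 6)
mask = col_idx >= row_idx + 2           # (5, 6) mask for diagonal=2
# The (5, 6) mask broadcasts over the leading (3, 4) batch dims.
assert torch.equal(torch.where(mask, x, torch.zeros((), dtype=x.dtype)),
                   torch.triu(x, 2))
```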


# ==============================================================================


class AtenTriuWithNegDiagonalModule(torch.nn.Module):

    def __init__(self):
