Skip to content

Commit

Permalink
Revert "[MLIR][TORCH] Only unroll prim loop-like ops within a `torch.…
Browse files Browse the repository at this point in the history
…shape.calculate` region (llvm#3812)"

This reverts commit 55ff110.
  • Loading branch information
Max191 committed Oct 31, 2024
1 parent 8b0bf2e commit 1570c15
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,16 @@ class FoldPrimUncheckedCastOp : public OpRewritePattern<PrimUncheckedCastOp> {
} // namespace

namespace {
// TODO: Only unroll inside the shape calculation region.
// Maybe do this by only applying patterns and folding greedily on the ops
// inside the region + the shape.calculate op itself?
class FullyUnrollPrimLoopOp : public OpRewritePattern<PrimLoopOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimLoopOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
// Only unroll loops if they are contained in a shape calculate region.
Region *region = op->getParentRegion();
Operation *parentOp = region->getParentOp();
if (!parentOp || !isa<Torch::ShapeCalculateOp>(parentOp))
return rewriter.notifyMatchFailure(
op, "Loop is not contained in a shape calculation region.");
if (!op.isForLike())
return rewriter.notifyMatchFailure(op, "Loop is not for-like");
int64_t maxTripCount;
Expand Down
17 changes: 0 additions & 17 deletions test/Dialect/Torch/simplify-shape-calculations.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -152,23 +152,6 @@ func.func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch
return %0 : !torch.vtensor
}

// CHECK-LABEL: func.func @fully_unroll_prim_loop$outside_region(
// CHECK: %[[LOOP:.*]] = torch.prim.Loop
func.func @fully_unroll_prim_loop$outside_region(%arg0: !torch.vtensor, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.vtensor {
%true = torch.constant.bool true
%0 = torch.prim.Loop %arg2, %true, init(%arg0) {
^bb0(%arg3: !torch.int, %arg4: !torch.vtensor):
%1 = torch.shape.calculate {
torch.shape.calculate.yield %arg4 : !torch.vtensor
} shapes {
torch.prim.Print(%arg3) : !torch.int
torch.shape.calculate.yield.shapes %arg1 : !torch.list<int>
} : !torch.vtensor
torch.prim.Loop.condition %true, iter(%1 : !torch.vtensor)
} : (!torch.int, !torch.bool, !torch.vtensor) -> !torch.vtensor
return %0 : !torch.vtensor
}

// CHECK-LABEL: func.func @abstractly_interpret_list_ops$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.int,
Expand Down

0 comments on commit 1570c15

Please sign in to comment.