Skip to content

Commit

Permalink
Revert "Decompose AtenNonzeroOp" (llvm#3289)
Browse files Browse the repository at this point in the history
Reverts llvm#3281
  • Loading branch information
vivekkhandelwal1 authored May 6, 2024
1 parent 17c3c15 commit e60160d
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 25 deletions.
18 changes: 0 additions & 18 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1201,21 +1201,6 @@ class DecomposeAtenIsposinfOp : public OpRewritePattern<AtenIsposinfOp> {
};
} // namespace

namespace {
class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
public:
using OpRewritePattern<AtenNonzeroOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNonzeroOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value zeroScalar =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
rewriter.replaceOpWithNewOp<AtenNeScalarOp>(op, op.getType(), op.getSelf(),
zeroScalar);
return success();
}
};
} // namespace
namespace {
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
public:
Expand Down Expand Up @@ -7755,13 +7740,10 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
// is-xxx ops
addPatternIfTargetOpIsIllegal<DecomposeAtenIsnanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsinfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsneginfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsposinfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNonzeroOp>(patterns);

addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
Expand Down
7 changes: 0 additions & 7 deletions test/Dialect/Torch/decompose-complex-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,3 @@ func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch
%0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16>
return %0 : !torch.tensor<[?], f16>
}

// -----
// CHECK-LABEL: func.func @torch.aten.nonzero
func.func @torch.aten.nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> {
%0 = torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64>
return %0 : !torch.vtensor<[3,4,5],si64>
}

0 comments on commit e60160d

Please sign in to comment.