From e60160d79362e0a6d7680d7fd7e14edb960a5d34 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Mon, 6 May 2024 22:22:04 +0530 Subject: [PATCH] Revert "Decompose AtenNonzeroOp" (#3289) Reverts llvm/torch-mlir#3281 --- .../Torch/Transforms/DecomposeComplexOps.cpp | 18 ------------------ test/Dialect/Torch/decompose-complex-ops.mlir | 7 ------- 2 files changed, 25 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 5354ca2339db..cc21f2155e46 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1201,21 +1201,6 @@ class DecomposeAtenIsposinfOp : public OpRewritePattern { }; } // namespace -namespace { -class DecomposeAtenNonzeroOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenNonzeroOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value zeroScalar = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); - rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), - zeroScalar); - return success(); - } -}; -} // namespace namespace { class DecomposeAtenReshapeOp : public OpRewritePattern { public: @@ -7755,13 +7740,10 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - // is-xxx ops addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 39b150339392..530160f990ae 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -78,10 +78,3 @@ func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch %0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16> return %0 : !torch.tensor<[?], f16> } - -// ----- -// CHECK-LABEL: func.func @torch.aten.nonzero -func.func @torch.aten.nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> { - %0 = torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64> - return %0 : !torch.vtensor<[3,4,5],si64> -}