[Torch Dialect] add fold pattern for aten.clone (#2804)
qingyunqu authored Jan 31, 2024
1 parent 25a5a22 commit d778950
Showing 7 changed files with 36 additions and 17 deletions.
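In short: torch.aten.clone of a value-semantics tensor now folds to its operand (the memory_format argument is ignored), the now-redundant clone-to-stablehlo.convert lowering and its conversion test are removed, and a new CloneModule end-to-end test covers the op.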
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9101,6 +9101,7 @@ def Torch_AtenCloneOp : Torch_Op<"aten.clone", [
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
+  let hasFolder = 1;
}

def Torch_AtenLiftFreshCopyOp : Torch_Op<"aten.lift_fresh_copy", [
1 change: 0 additions & 1 deletion lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -1763,7 +1763,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
#define INSERT_UNARY_PATTERN(AtenOp, StablehloOp)                              \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenUnaryOp<AtenOp, StablehloOp>>(typeConverter, context)
-  INSERT_UNARY_PATTERN(AtenCloneOp, stablehlo::ConvertOp);
  INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp);
  INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp);
  INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp);
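This pattern registration is removed because torch.aten.clone on value-semantics tensors now folds away (see the new AtenCloneOp::fold in TorchOps.cpp below), so the op no longer reaches the TorchToStablehlo conversion.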
13 changes: 13 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1662,6 +1662,19 @@ void AtenMaskedFillTensorOp::getCanonicalizationPatterns(
  });
}

+//===----------------------------------------------------------------------===//
+// AtenCloneOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenCloneOp::fold(FoldAdaptor adaptor) {
+  // Note: the memory_format argument is ignored.
+  if (llvm::dyn_cast<ValueTensorType>(getSelf().getType())) {
+    // self has value semantics, so the clone is a no-op.
+    return getSelf();
+  }
+  return {};
+}
+
//===----------------------------------------------------------------------===//
// AtenSortIntOp
//===----------------------------------------------------------------------===//
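For illustration, here is a minimal MLIR sketch (not part of this commit; the function name and shapes are made up) of what the new fold does. Because the operand is a value-semantics !torch.vtensor, the aten.clone op folds to its input:

func.func @torch.aten.clone$fold(%arg0: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],f32> {
  %none = torch.constant.none
  %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[5,5],f32>, !torch.none -> !torch.vtensor<[5,5],f32>
  return %0 : !torch.vtensor<[5,5],f32>
}

After canonicalization (e.g. torch-mlir-opt -canonicalize), the body reduces to:

func.func @torch.aten.clone$fold(%arg0: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],f32> {
  return %arg0 : !torch.vtensor<[5,5],f32>
}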
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1021,6 +1021,7 @@
"BroadcastZeroRankInputStaticModule_basic",
"BucketizeTensorStaticFloatModule_basic",
"BucketizeTensorStaticModule_basic",
"CloneModule_basic",
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"ConstantBoolParameterModule_basic",
2 changes: 1 addition & 1 deletion projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -592,7 +592,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::one_hot : (Tensor, int) -> (Tensor)")
emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)")
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
emit("aten::clone : (Tensor, int?) -> (Tensor)")
emit("aten::clone : (Tensor, int?) -> (Tensor)", has_folder=True)
emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)")
emit("aten::contiguous : (Tensor, int) -> (Tensor)")
emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)")
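Setting has_folder=True here is what makes the ODS generator emit the "let hasFolder = 1;" declaration in GeneratedTorchOps.td (first file above), which in turn requires the AtenCloneOp::fold implementation added in TorchOps.cpp.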
20 changes: 20 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -4994,3 +4994,23 @@ def forward(self, x):
@register_test_case(module_factory=lambda: IscloseStaticModuleTrue())
def IscloseStaticModuleTrue_basic(module, tu: TestUtils):
    module.forward(torch.ones(5, 5))
+
+
+# ==============================================================================
+
+class CloneModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([5, 5], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.clone(x)
+
+@register_test_case(module_factory=lambda: CloneModule())
+def CloneModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 5))
15 changes: 0 additions & 15 deletions test/Conversion/TorchToStablehlo/basic.mlir
@@ -1,21 +1,6 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s

-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.clone$basic(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[T1:.*]] = stablehlo.convert %[[T0]] : tensor<?x?xf32>
-// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %none = torch.constant.none
-  %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[?,?],f32>, !torch.none -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}

// -----

// CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {