diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c09900ce8ecc..81ee9844af30 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -9101,6 +9101,7 @@ def Torch_AtenCloneOp : Torch_Op<"aten.clone", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } def Torch_AtenLiftFreshCopyOp : Torch_Op<"aten.lift_fresh_copy", [ diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 00c9fcd7b88f..33db9ac9ee54 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1763,7 +1763,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( #define INSERT_UNARY_PATTERN(AtenOp, StablehloOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context) - INSERT_UNARY_PATTERN(AtenCloneOp, stablehlo::ConvertOp); INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp); INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp); INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 4aacd8d7693e..5877a35495e0 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1662,6 +1662,19 @@ void AtenMaskedFillTensorOp::getCanonicalizationPatterns( }); } +//===----------------------------------------------------------------------===// +// AtenCloneOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenCloneOp::fold(FoldAdaptor adaptor) { + // note: memory_format would be ignored + if (llvm::dyn_cast(getSelf().getType())) { + // self should have value semantics + return getSelf(); + } + return {}; +} + //===----------------------------------------------------------------------===// // AtenSortIntOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index d6fe18809b5f..7c1cc16261d0 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1021,6 +1021,7 @@ "BroadcastZeroRankInputStaticModule_basic", "BucketizeTensorStaticFloatModule_basic", "BucketizeTensorStaticModule_basic", + "CloneModule_basic", "ChunkListUnpackUneven_Module_basic", "ChunkListUnpack_Module_basic", "ConstantBoolParameterModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 3b930c20e79d..a329c1ae01a4 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -592,7 +592,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::one_hot : (Tensor, int) -> (Tensor)") emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)") emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)") - emit("aten::clone : (Tensor, int?) -> (Tensor)") + emit("aten::clone : (Tensor, int?) -> (Tensor)", has_folder=True) emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)") emit("aten::contiguous : (Tensor, int) -> (Tensor)") emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 51deffb6175a..91c3112135d0 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -4994,3 +4994,23 @@ def forward(self, x): @register_test_case(module_factory=lambda: IscloseStaticModuleTrue()) def IscloseStaticModuleTrue_basic(module, tu: TestUtils): module.forward(torch.ones(5, 5)) + + +# ============================================================================== + +class CloneModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([5, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.clone(x) + +@register_test_case(module_factory=lambda: CloneModule()) +def CloneModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 5)) diff --git a/test/Conversion/TorchToStablehlo/basic.mlir b/test/Conversion/TorchToStablehlo/basic.mlir index e0ab6bf1502b..b502d3ffcce9 100644 --- a/test/Conversion/TorchToStablehlo/basic.mlir +++ b/test/Conversion/TorchToStablehlo/basic.mlir @@ -1,21 +1,6 @@ // RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s -// ----- - -// CHECK-LABEL: func.func @torch.aten.clone$basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T1:.*]] = stablehlo.convert %[[T0]] : tensor -// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> -func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { - %none = torch.constant.none - %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[?,?],f32>, !torch.none -> !torch.vtensor<[?,?],f32> - return %0 : !torch.vtensor<[?,?],f32> -} - // ----- // CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> {