Skip to content

Commit

Permalink
[Torch Dialect] add fold pattern for aten.clone
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyunqu committed Jan 30, 2024
1 parent 1d6aca3 commit cf6d35c
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -9101,6 +9101,7 @@ def Torch_AtenCloneOp : Torch_Op<"aten.clone", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenLiftFreshCopyOp : Torch_Op<"aten.lift_fresh_copy", [
Expand Down
12 changes: 12 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1658,6 +1658,18 @@ void AtenMaskedFillTensorOp::getCanonicalizationPatterns(
});
}

//===----------------------------------------------------------------------===//
// AtenCloneOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenCloneOp::fold(FoldAdaptor adaptor) {
// self should have value semantics
if (llvm::dyn_cast<ValueTensorType>(getSelf().getType())) {
return getSelf();
}
return {};
}

//===----------------------------------------------------------------------===//
// AtenSortIntOp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,7 @@
"BroadcastZeroRankInputStaticModule_basic",
"BucketizeTensorStaticFloatModule_basic",
"BucketizeTensorStaticModule_basic",
"CloneModule_basic",
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"ConstantBoolParameterModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::one_hot : (Tensor, int) -> (Tensor)")
emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)")
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
emit("aten::clone : (Tensor, int?) -> (Tensor)")
emit("aten::clone : (Tensor, int?) -> (Tensor)", has_folder=True)
emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)")
emit("aten::contiguous : (Tensor, int) -> (Tensor)")
emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)")
Expand Down
20 changes: 20 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4994,3 +4994,23 @@ def forward(self, x):
@register_test_case(module_factory=lambda: IscloseStaticModuleTrue())
def IscloseStaticModuleTrue_basic(module, tu: TestUtils):
module.forward(torch.ones(5, 5))


# ==============================================================================

class CloneModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([5, 5], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.clone(x)

@register_test_case(module_factory=lambda: CloneModule())
def CloneModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 5))

0 comments on commit cf6d35c

Please sign in to comment.