From a18021e6754afca2b45b2b254b43a5f9c87685ab Mon Sep 17 00:00:00 2001 From: Ze Zhang Date: Fri, 13 Oct 2023 12:11:46 -0700 Subject: [PATCH] update e2e test --- e2e_testing/xfail_sets.py | 2 ++ .../torch_mlir_e2e_test/test_suite/basic.py | 22 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 42dfaae47ef25..f0ad16af46761 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -18,6 +18,7 @@ # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "IscloseStaticModule_basic", + "IscloseStaticModuleTrue_basic", } TORCHDYNAMO_XFAIL_SET = { @@ -929,6 +930,7 @@ # and very few tests work yet. TOSA_PASS_SET = { "IscloseStaticModule_basic", + "IscloseStaticModuleTrue_basic", "TileBigDimsSizeModule_basic", "TileSmallDimsSizeModule_basic", "IndexPutImpl2DNoneIndexStaticModule_basic", diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 8590eb6b0a268..c5df42a230e64 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -4581,3 +4581,25 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: IscloseStaticModule()) def IscloseStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 5), tu.rand(5, 5)) + + +# ============================================================================== + + +class IscloseStaticModuleTrue(torch.nn.Module): + + def __init__(self): + super().__init__() + self.register_buffer('tensor', torch.ones(1)) + + @export + @annotate_args([ + None, + ([5, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.isclose(x, self.tensor) + +@register_test_case(module_factory=lambda: IscloseStaticModuleTrue()) +def IscloseStaticModuleTrue_basic(module, tu: TestUtils): + module.forward(torch.ones(5, 5))