diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 82b955f4407a..d268a31ddf9e 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -19,6 +19,7 @@ "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "UnflattenStaticModule_basic", "IscloseStaticModule_basic", + "IscloseStaticModuleTrue_basic", } TORCHDYNAMO_XFAIL_SET = { @@ -930,6 +931,7 @@ # and very few tests work yet. TOSA_PASS_SET = { "IscloseStaticModule_basic", + "IscloseStaticModuleTrue_basic", "TileBigDimsSizeModule_basic", "TileSmallDimsSizeModule_basic", "IndexPutImpl2DNoneIndexStaticModule_basic", diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index a91fbdb64a53..d78253a58fc3 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -4603,3 +4603,25 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: IscloseStaticModule()) def IscloseStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 5), tu.rand(5, 5)) + + +# ============================================================================== + + +class IscloseStaticModuleTrue(torch.nn.Module): + + def __init__(self): + super().__init__() + self.register_buffer('tensor', torch.ones(1)) + + @export + @annotate_args([ + None, + ([5, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.isclose(x, self.tensor) + +@register_test_case(module_factory=lambda: IscloseStaticModuleTrue()) +def IscloseStaticModuleTrue_basic(module, tu: TestUtils): + module.forward(torch.ones(5, 5))