Skip to content

Commit

Permalink
update e2e test
Browse files Browse the repository at this point in the history
  • Loading branch information
Ze Zhang committed Oct 15, 2023
1 parent 70e8dff commit a8fb81f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"UnflattenStaticModule_basic",
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
}

TORCHDYNAMO_XFAIL_SET = {
Expand Down Expand Up @@ -930,6 +931,7 @@
# and very few tests work yet.
TOSA_PASS_SET = {
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
Expand Down
22 changes: 22 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4603,3 +4603,25 @@ def forward(self, x, y):
@register_test_case(module_factory=lambda: IscloseStaticModule())
def IscloseStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 5), tu.rand(5, 5))


# ==============================================================================


class IscloseStaticModuleTrue(torch.nn.Module):

def __init__(self):
super().__init__()
self.register_buffer('tensor', torch.ones(1))

@export
@annotate_args([
None,
([5, 5], torch.float32, True),
])
def forward(self, x):
return torch.isclose(x, self.tensor)

@register_test_case(module_factory=lambda: IscloseStaticModuleTrue())
def IscloseStaticModuleTrue_basic(module, tu: TestUtils):
module.forward(torch.ones(5, 5))

0 comments on commit a8fb81f

Please sign in to comment.