Skip to content

Commit

Permalink
[FxImporter] Add fx importer to stablehlo e2e test config (llvm#3183)
Browse files Browse the repository at this point in the history
  • Loading branch information
penguin-wwy authored Apr 19, 2024
1 parent 6c4f7de commit 0a60734
Show file tree
Hide file tree
Showing 3 changed files with 421 additions and 9 deletions.
15 changes: 12 additions & 3 deletions projects/pt1/e2e_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,19 @@
TORCHDYNAMO_CRASHING_SET,
ONNX_CRASHING_SET,
ONNX_XFAIL_SET,
FX_IMPORT_XFAIL_SET,
FX_IMPORTER_XFAIL_SET,
FX_IMPORTER_CRASHING_SET,
FX_IMPORTER_STABLEHLO_XFAIL_SET,
FX_IMPORTER_STABLEHLO_CRASHING_SET,
)

# Import tests to register them in the global registry.
from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests()

def _get_argparse():
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo", "onnx", "fx_importer"]
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core",
"torchdynamo", "onnx", "fx_importer", "fx_importer_stablehlo"]
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
parser.add_argument("-c", "--config",
choices=config_choices,
Expand All @@ -67,6 +70,8 @@ def _get_argparse():
"lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph.
"torchdynamo": run the model through the TorchDynamo frontend and execute the graph using Linalg-on-Tensors.
"onnx": export to the model via onnx and reimport using the torch-onnx-to-torch path.
"fx_importer": run the model through the fx importer frontend and execute the graph using Linalg-on-Tensors.
"fx_importer_stablehlo": run the model through the fx importer frontend and execute the graph using Stablehlo backend.
""")
parser.add_argument("-f", "--filter", default=".*", help="""
Regular expression specifying which tests to include in this run.
Expand Down Expand Up @@ -127,8 +132,12 @@ def main():
crashing_set = LTC_CRASHING_SET
elif args.config == "fx_importer":
config = FxImporterTestConfig(RefBackendLinalgOnTensorsBackend())
xfail_set = FX_IMPORT_XFAIL_SET
xfail_set = FX_IMPORTER_XFAIL_SET
crashing_set = FX_IMPORTER_CRASHING_SET
elif args.config == "fx_importer_stablehlo":
config = FxImporterTestConfig(LinalgOnTensorsStablehloBackend(), "stablehlo")
xfail_set = FX_IMPORTER_STABLEHLO_XFAIL_SET
crashing_set = FX_IMPORTER_STABLEHLO_CRASHING_SET
elif args.config == "torchdynamo":
config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())
xfail_set = TORCHDYNAMO_XFAIL_SET
Expand Down
Loading

0 comments on commit 0a60734

Please sign in to comment.