From bfcf93ea2191a760ed14bbefc76f3cb34806806a Mon Sep 17 00:00:00 2001
From: saienduri <77521230+saienduri@users.noreply.github.com>
Date: Tue, 6 Feb 2024 19:07:59 -0800
Subject: [PATCH] Rename torch_mlir.compile APIs and introduce FX-based analogs (#2842)

Link to related RFC:
https://discourse.llvm.org/t/rfc-rename-torch-mlir-compile-apis-and-introduce-fx-based-analogs/76646

This commit updates the documentation, tests, CMake files, and API for the
proposed changes in the RFC. There is a new torch_mlir/fx.py for user-level
APIs related to importing modules, and a corresponding test for this path can
be found at test/python/fx_importer/basic_test.py.

---------

Co-authored-by: MaheshRavishankar
---
 docs/architecture.md                          |  2 +-
 docs/development.md                           | 25 ++++++++++++++-----
 docs/{long_term_roadmap.md => roadmap.md}     | 16 ++++++++++++
 projects/pt1/examples/torchdynamo_resnet18.py |  4 +--
 projects/pt1/examples/torchscript_resnet18.py |  4 +--
 .../torchscript_resnet18_all_output_types.py  |  8 +++---
 .../torchscript_resnet_inference.ipynb        |  4 +--
 .../torchscript_stablehlo_backend_resnet.py   |  4 +--
 .../torchscript_stablehlo_backend_tinybert.py |  4 +--
 projects/pt1/python/CMakeLists.txt            |  2 +-
 .../test/compile_api/already_scripted.py      |  8 +++---
 .../python/test/compile_api/already_traced.py |  8 +++---
 .../test/compile_api/backend_legal_ops.py     |  6 ++---
 projects/pt1/python/test/compile_api/basic.py | 20 +++++++--------
 .../pt1/python/test/compile_api/make_fx.py    |  4 +--
 .../test/compile_api/multiple_methods.py      | 10 ++++----
 .../test/compile_api/output_type_spec.py      |  6 ++---
 .../pt1/python/test/compile_api/tracing.py    | 20 +++++++--------
 projects/pt1/python/torch_mlir/dynamo.py      |  2 +-
 .../{__init__.py => torchscript.py}           |  6 ++---
 .../configs/linalg_on_tensors_backend.py      |  4 +--
 .../configs/stablehlo_backend.py              |  4 +--
 .../configs/torchdynamo.py                    |  2 +-
 .../configs/tosa_backend.py                   |  4 +--
 .../pt1/python/torch_mlir_e2e_test/utils.py   |  4 +--
 .../test/python/custom_op_shape_dtype_fn.py   |  4 +--
 .../jit_ir/node_import/unimplemented.py       |  6 ++---
 python/CMakeLists.txt                         |  7 ++++++
 python/torch_mlir/fx.py                       | 25 +++++++++++++++++++
 test/python/compile.py                        |  4 +--
 test/python/fx_importer/basic_test.py         | 25 ++-----------------
 31 files changed, 146 insertions(+), 106 deletions(-)
 rename docs/{long_term_roadmap.md => roadmap.md} (94%)
 rename projects/pt1/python/torch_mlir/{__init__.py => torchscript.py} (99%)
 create mode 100644 python/torch_mlir/fx.py

diff --git a/docs/architecture.md b/docs/architecture.md
index 4c102e140d7a..e2ef378bd99c 100644
--- a/docs/architecture.md
+++ b/docs/architecture.md
@@ -184,7 +184,7 @@ semantics. And often users want to erase the shapes in the trace to allow
 dynamic shapes for the trace. Additionally, the Python-level data structures and
 APIs are very parallel between `torch.jit.script` and `torch.jit.trace`, so we
 consider both of those as the same from the perspective of the responsibilities
-of the compiler. Both are accessed via the `torch_mlir.compile` Python API.
+of the compiler. Both are accessed via the `torch_mlir.torchscript.compile` Python API.
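For example, a minimal sketch grounded in this patch's own tests (`TanhModule` here is a stand-in module):

```python
import torch
from torch_mlir import torchscript

class TanhModule(torch.nn.Module):
    def forward(self, x):
        return torch.tanh(x)

# Scripting path (the default): the module is run through torch.jit.script.
print(torchscript.compile(TanhModule(), torch.ones(2, 3)))

# Tracing path: torch.jit.trace records the ops executed on the example input.
print(torchscript.compile(TanhModule(), torch.ones(2, 3), use_tracing=True))
```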
### Modeling the `torch.nn.Module` object (`IValue`) hierarchy for TorchScript

diff --git a/docs/development.md b/docs/development.md
index 782058a63ea7..3e9192f5fa8e 100644
--- a/docs/development.md
+++ b/docs/development.md
@@ -120,37 +120,50 @@ cmake --build build

### Linux and macOS

```shell
-export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples
+export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer
```

### Windows PowerShell

```shell
-$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/projects/pt1/examples"
+$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/test/python/fx_importer"
```

## Testing MLIR output in various dialects

-To test the compiler's output to the different MLIR dialects, you can use the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`.
+To test the MLIR output to the torch dialect, you can use `test/python/fx_importer/basic_test.py`.

Make sure you have activated the virtualenv and set the `PYTHONPATH` above (if running on Windows, modify the environment variable as shown above):

```shell
source mlir_venv/bin/activate
+export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer
+python test/python/fx_importer/basic_test.py
+```
+
+This will display the basic example in the TORCH dialect.
+
+To test the compiler's output to the different MLIR dialects, you can also use the deprecated
+TorchScript path with the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`.
+This path doesn't give access to the current-generation work that is being driven via the
+fx_importer and may lead to errors.
+
+Same as above, but with a different `PYTHONPATH` and example:
+```shell
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples
python projects/pt1/examples/torchscript_resnet18_all_output_types.py
```

This will display the Resnet18 network example in three dialects: TORCH, LINALG on TENSORS and TOSA.

-The main functionality is on `torch_mlir.compile()`'s `output_type`.
+The main functionality is on `torch_mlir.torchscript.compile()`'s `output_type`.

Ex:
```python
-module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
+module = torch_mlir.torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
```

-Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.
+`output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.

## Jupyter

diff --git a/docs/long_term_roadmap.md b/docs/roadmap.md
similarity index 94%
rename from docs/long_term_roadmap.md
rename to docs/roadmap.md
index 0f0940efc32d..f60502a52423 100644
--- a/docs/long_term_roadmap.md
+++ b/docs/roadmap.md
@@ -51,6 +51,22 @@ the ecosystem are:
Most of this document describes long-term ecosystem changes that will address
these, drastically improving Torch-MLIR's ability to meet its goals.

+## Current API Paths
+
+Currently, there are two main API paths for the torch-mlir project:
+
+- The first path is part of the legacy pt1 project code
+  (`torch_mlir.torchscript.compile`). This allows users to test the compiler's
+  output to the different MLIR dialects (`TORCH`, `TOSA`, `LINALG_ON_TENSORS`,
+  `RAW` and `STABLEHLO`). This path is deprecated and doesn't give access to
+  the current-generation work that is being driven via the fx_importer.
It is tied to the old TorchScript path.
+- The second path (`torch_mlir.fx.export_and_import`) allows users to import a
+  consolidated torch.export.ExportedProgram instance of an arbitrary Python
+  callable (an nn.Module, a function or a method) and produce a torch-dialect
+  MLIR module. This path is aligned with PyTorch's roadmap, but it is not
+  fully functional yet.
+
## Roadmap

### Refactoring the frontend

diff --git a/projects/pt1/examples/torchdynamo_resnet18.py b/projects/pt1/examples/torchdynamo_resnet18.py
index d7abd80da665..377c632da36f 100644
--- a/projects/pt1/examples/torchdynamo_resnet18.py
+++ b/projects/pt1/examples/torchdynamo_resnet18.py
@@ -14,7 +14,7 @@
import torchvision.models as models
from torchvision import transforms

-import torch_mlir
+from torch_mlir import torchscript
from torch_mlir.dynamo import make_simple_dynamo_backend
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

@@ -71,7 +71,7 @@ def predictions(torch_func, jit_func, img, labels):
@make_simple_dynamo_backend
def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
                                   example_inputs: List[torch.Tensor]):
-    mlir_module = torch_mlir.compile(
+    mlir_module = torchscript.compile(
        fx_graph, example_inputs, output_type="linalg-on-tensors")
    backend = refbackend.RefBackendLinalgOnTensorsBackend()
    compiled = backend.compile(mlir_module)
diff --git a/projects/pt1/examples/torchscript_resnet18.py b/projects/pt1/examples/torchscript_resnet18.py
index ac46e6f4523b..62e5eda6cc83 100644
--- a/projects/pt1/examples/torchscript_resnet18.py
+++ b/projects/pt1/examples/torchscript_resnet18.py
@@ -12,7 +12,7 @@
import torchvision.models as models
from torchvision import transforms

-import torch_mlir
+from torch_mlir import torchscript
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

@@ -67,7 +67,7 @@ def predictions(torch_func, jit_func, img, labels):
resnet18 = models.resnet18(pretrained=True)
resnet18.train(False)

-module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
+module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)
diff --git a/projects/pt1/examples/torchscript_resnet18_all_output_types.py b/projects/pt1/examples/torchscript_resnet18_all_output_types.py
index a17fa40521d3..70a920550b2d 100644
--- a/projects/pt1/examples/torchscript_resnet18_all_output_types.py
+++ b/projects/pt1/examples/torchscript_resnet18_all_output_types.py
@@ -6,15 +6,15 @@
import torch
import torchvision

-import torch_mlir
+from torch_mlir import torchscript

resnet18 = torchvision.models.resnet18(pretrained=True)
resnet18.eval()

-module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
+module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
print("TORCH OutputType\n", module.operation.get_asm(large_elements_limit=10))
-module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
+module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
print("LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10))
# TODO: Debug why this is so slow.
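# (Note: here and above, large_elements_limit=10 elides large constants such
# as weights when printing, so the emitted IR stays readable.)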
-module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa") +module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa") print("TOSA OutputType\n", module.operation.get_asm(large_elements_limit=10)) diff --git a/projects/pt1/examples/torchscript_resnet_inference.ipynb b/projects/pt1/examples/torchscript_resnet_inference.ipynb index 3ab7cc64dadb..9970f90b8bb2 100644 --- a/projects/pt1/examples/torchscript_resnet_inference.ipynb +++ b/projects/pt1/examples/torchscript_resnet_inference.ipynb @@ -184,7 +184,7 @@ "\n", "# Compile the model with an example input.\n", "# We lower to the linalg-on-tensors form that the reference backend supports.\n", - "compiled = torch_mlir.compile(TanhModule(), torch.ones(3), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n", + "compiled = torch_mlir.torchscript.compile(TanhModule(), torch.ones(3), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n", "# Load it on the reference backend.\n", "jit_module = compile_and_load_on_refbackend(compiled)\n", "# Run it!\n", @@ -326,7 +326,7 @@ "source": [ "resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n", "resnet18.eval()\n", - "compiled = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=\"linalg-on-tensors\")\n", + "compiled = torch_mlir.torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=\"linalg-on-tensors\")\n", "jit_module = compile_and_load_on_refbackend(compiled)" ] }, diff --git a/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py b/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py index 7a97359cff62..e42828ed776e 100644 --- a/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py +++ b/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py @@ -1,13 +1,13 @@ import torch import torchvision.models as models -import torch_mlir +from torch_mlir import torchscript model = models.resnet18(pretrained=True) model.eval() data = torch.randn(2,3,200,200) out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir" -module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False) +module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False) with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: outf.write(str(module)) diff --git a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py index c035be3a54fe..c68daf12dd86 100644 --- a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py +++ b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py @@ -1,5 +1,5 @@ import torch -import torch_mlir +from torch_mlir import torchscript from transformers import BertForMaskedLM @@ -17,7 +17,7 @@ def forward(self, data): data = torch.randint(30522, (2, 128)) out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir" -module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True) +module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True) with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: outf.write(str(module)) diff --git a/projects/pt1/python/CMakeLists.txt b/projects/pt1/python/CMakeLists.txt index ce40426988a7..642b86b50490 100644 --- a/projects/pt1/python/CMakeLists.txt +++ b/projects/pt1/python/CMakeLists.txt @@ -18,7 +18,7 @@ 
declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources SOURCES - __init__.py + torchscript.py _dynamo_fx_importer.py compiler_utils.py dynamo.py diff --git a/projects/pt1/python/test/compile_api/already_scripted.py b/projects/pt1/python/test/compile_api/already_scripted.py index 367170081228..7d9720727a38 100644 --- a/projects/pt1/python/test/compile_api/already_scripted.py +++ b/projects/pt1/python/test/compile_api/already_scripted.py @@ -6,7 +6,7 @@ # RUN: %PYTHON %s | FileCheck %s import torch -import torch_mlir +from torch_mlir import torchscript class BasicModule(torch.nn.Module): @@ -15,17 +15,17 @@ def sin(self, x): return torch.ops.aten.sin(x) -example_args = torch_mlir.ExampleArgs() +example_args = torchscript.ExampleArgs() example_args.add_method("sin", torch.ones(2, 3)) scripted = torch.jit.script(BasicModule()) -print(torch_mlir.compile(scripted, example_args)) +print(torchscript.compile(scripted, example_args)) # CHECK: module # CHECK-DAG: func.func @sin scripted = torch.jit.script(BasicModule()) try: # CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition. - torch_mlir.compile(scripted, torch_mlir.ExampleArgs().add_method("nonexistent", torch.ones(2, 3))) + torchscript.compile(scripted, torchscript.ExampleArgs().add_method("nonexistent", torch.ones(2, 3))) except Exception as e: print(e) diff --git a/projects/pt1/python/test/compile_api/already_traced.py b/projects/pt1/python/test/compile_api/already_traced.py index a719eb743c73..32f7b5653fca 100644 --- a/projects/pt1/python/test/compile_api/already_traced.py +++ b/projects/pt1/python/test/compile_api/already_traced.py @@ -6,23 +6,23 @@ # RUN: %PYTHON %s | FileCheck %s import torch -import torch_mlir +from torch_mlir import torchscript class BasicModule(torch.nn.Module): def forward(self, x): return torch.ops.aten.sin(x) example_arg = torch.ones(2, 3) -example_args = torch_mlir.ExampleArgs.get(example_arg) +example_args = torchscript.ExampleArgs.get(example_arg) traced = torch.jit.trace(BasicModule(), example_arg) -print(torch_mlir.compile(traced, example_args)) +print(torchscript.compile(traced, example_args)) # CHECK: module # CHECK-DAG: func.func @forward traced = torch.jit.trace(BasicModule(), example_arg) try: # CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition. 
- torch_mlir.compile(traced, torch_mlir.ExampleArgs().add_method("nonexistent", example_arg)) + torchscript.compile(traced, torchscript.ExampleArgs().add_method("nonexistent", example_arg)) except Exception as e: print(e) diff --git a/projects/pt1/python/test/compile_api/backend_legal_ops.py b/projects/pt1/python/test/compile_api/backend_legal_ops.py index 98c034930243..64ebf7a522fa 100644 --- a/projects/pt1/python/test/compile_api/backend_legal_ops.py +++ b/projects/pt1/python/test/compile_api/backend_legal_ops.py @@ -7,7 +7,7 @@ import torch -import torch_mlir +from torch_mlir import torchscript class AddmmModule(torch.nn.Module): def __init__(self): @@ -15,9 +15,9 @@ def __init__(self): def forward(self, x, y, z): return torch.ops.aten.addmm(x, y, z) -example_args = 3 * [torch_mlir.TensorPlaceholder([-1, -1], torch.float32)] +example_args = 3 * [torchscript.TensorPlaceholder([-1, -1], torch.float32)] -print(torch_mlir.compile(AddmmModule(), example_args, +print(torchscript.compile(AddmmModule(), example_args, output_type="torch", backend_legal_ops=["aten.addmm"])) # CHECK-LABEL: @forward # CHECK: torch.aten.addmm diff --git a/projects/pt1/python/test/compile_api/basic.py b/projects/pt1/python/test/compile_api/basic.py index 999d2fe4a820..0c516b620863 100644 --- a/projects/pt1/python/test/compile_api/basic.py +++ b/projects/pt1/python/test/compile_api/basic.py @@ -7,7 +7,7 @@ import torch -import torch_mlir +from torch_mlir import torchscript class TanhModule(torch.nn.Module): def __init__(self): @@ -18,24 +18,24 @@ def forward(self, x): tanh_example_input = torch.ones(2, 3) # Simplest case: One example argument. -print(torch_mlir.compile(TanhModule(), tanh_example_input)) +print(torchscript.compile(TanhModule(), tanh_example_input)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> # Use a TensorPlaceholder to represent dynamic axes. -placeholder = torch_mlir.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1]) -print(torch_mlir.compile(TanhModule(), placeholder)) +placeholder = torchscript.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1]) +print(torchscript.compile(TanhModule(), placeholder)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32> # Explicitly construct a TensorPlaceholder. -placeholder = torch_mlir.TensorPlaceholder([-1, 2], torch.float32) -print(torch_mlir.compile(TanhModule(), placeholder)) +placeholder = torchscript.TensorPlaceholder([-1, 2], torch.float32) +print(torchscript.compile(TanhModule(), placeholder)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32> # Basic smoke test for the raw output type. -print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.RAW)) +print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.RAW)) # CHECK: torch.nn_module { # CHECK: } : !torch.nn.Module<"{{.*}}.TanhModule"> @@ -47,12 +47,12 @@ def forward(self, lhs, rhs ): # N > 1 inputs. mm_example_inputs = [torch.ones(2, 3), torch.ones(3, 4)] -print(torch_mlir.compile(MmModule(), mm_example_inputs)) +print(torchscript.compile(MmModule(), mm_example_inputs)) # CHECK-LABEL: @forward # CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32> # Mixes Tensor's and TensorPlaceholder's. 
-mm_dynamic_inputs = [mm_example_inputs[0], torch_mlir.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])] -print(torch_mlir.compile(MmModule(), mm_dynamic_inputs)) +mm_dynamic_inputs = [mm_example_inputs[0], torchscript.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])] +print(torchscript.compile(MmModule(), mm_dynamic_inputs)) # CHECK-LABEL: @forward # CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[2,?],f32> diff --git a/projects/pt1/python/test/compile_api/make_fx.py b/projects/pt1/python/test/compile_api/make_fx.py index 62add20a576b..ec859d86e369 100644 --- a/projects/pt1/python/test/compile_api/make_fx.py +++ b/projects/pt1/python/test/compile_api/make_fx.py @@ -8,7 +8,7 @@ import functorch import torch -import torch_mlir +from torch_mlir import torchscript def simple(x): return x * x @@ -17,6 +17,6 @@ def simple(x): graph = functorch.make_fx(simple)(torch.randn(1,)) # Simplest case: One example argument. -print(torch_mlir.compile(graph, example_input)) +print(torchscript.compile(graph, example_input)) # CHECK-LABEL: @forward # CHECK: torch.aten.mul.Tensor %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32> \ No newline at end of file diff --git a/projects/pt1/python/test/compile_api/multiple_methods.py b/projects/pt1/python/test/compile_api/multiple_methods.py index f70b14ab68ab..067e775bfc71 100644 --- a/projects/pt1/python/test/compile_api/multiple_methods.py +++ b/projects/pt1/python/test/compile_api/multiple_methods.py @@ -6,7 +6,7 @@ # RUN: %PYTHON %s | FileCheck %s import torch -import torch_mlir +from torch_mlir import torchscript class TwoMethodsModule(torch.nn.Module): @@ -17,14 +17,14 @@ def cos(self, x): return torch.ops.aten.cos(x) -example_args = torch_mlir.ExampleArgs() +example_args = torchscript.ExampleArgs() example_args.add_method("sin", torch.ones(2, 3)) example_args.add_method("cos", torch.ones(2, 4)) # Note: Due to https://github.com/pytorch/pytorch/issues/88735 we need to # check the `use_tracing` case first. -print(torch_mlir.compile(TwoMethodsModule(), example_args, use_tracing=True)) +print(torchscript.compile(TwoMethodsModule(), example_args, use_tracing=True)) # CHECK: module # CHECK-DAG: func.func @sin # CHECK-DAG: func.func @cos @@ -34,8 +34,8 @@ def cos(self, x): # Otherwise the user would have to do this manually, which is tedious. This # technically mutates the user input model which is not great but probably okay # for this kind of API sugar. Users can always take full control of the process -# by scripting the model themselves before passing it to `torch_mlir.compile`. -print(torch_mlir.compile(TwoMethodsModule(), example_args)) +# by scripting the model themselves before passing it to `torchscript.compile`. 
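# (Illustrative sketch of that manual path, mirroring already_scripted.py:
#   scripted = torch.jit.script(TwoMethodsModule())
#   print(torchscript.compile(scripted, example_args))
# The call below relies on the implicit scripting instead.)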
+print(torchscript.compile(TwoMethodsModule(), example_args)) # CHECK: module # CHECK-DAG: func.func @sin # CHECK-DAG: func.func @cos diff --git a/projects/pt1/python/test/compile_api/output_type_spec.py b/projects/pt1/python/test/compile_api/output_type_spec.py index b975c2b5c0ae..92ed1e425d8d 100644 --- a/projects/pt1/python/test/compile_api/output_type_spec.py +++ b/projects/pt1/python/test/compile_api/output_type_spec.py @@ -7,7 +7,7 @@ import torch -import torch_mlir +from torch_mlir import torchscript class TanhModule(torch.nn.Module): def __init__(self): @@ -17,9 +17,9 @@ def forward(self, x): tanh_example_input = torch.ones(2, 3) -print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.TORCH)) +print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.TORCH)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> -print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type="torch")) +print(torchscript.compile(TanhModule(), tanh_example_input, output_type="torch")) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> diff --git a/projects/pt1/python/test/compile_api/tracing.py b/projects/pt1/python/test/compile_api/tracing.py index ea74fea12ab4..bbf652f07a28 100644 --- a/projects/pt1/python/test/compile_api/tracing.py +++ b/projects/pt1/python/test/compile_api/tracing.py @@ -7,7 +7,7 @@ import torch -import torch_mlir +from torch_mlir import torchscript class TanhModule(torch.nn.Module): @@ -17,38 +17,38 @@ def forward(self, x): tanh_example_input = torch.ones(2, 3) # Simplest case: One example argument. -print(torch_mlir.compile(TanhModule(), tanh_example_input, use_tracing=True)) +print(torchscript.compile(TanhModule(), tanh_example_input, use_tracing=True)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> # Simplest case: Passed as a tuple. -print(torch_mlir.compile(TanhModule(), (tanh_example_input,), use_tracing=True)) +print(torchscript.compile(TanhModule(), (tanh_example_input,), use_tracing=True)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> # Simplest case: Passed as a list. -print(torch_mlir.compile(TanhModule(), [tanh_example_input], use_tracing=True)) +print(torchscript.compile(TanhModule(), [tanh_example_input], use_tracing=True)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> # TensorPlaceholder support. 
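# (A TensorPlaceholder stands in for a real example tensor; dynamic_axes=[1]
# marks dimension 1 as dynamic, which surfaces as `?` in the !torch.vtensor type.)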
-placeholder = torch_mlir.TensorPlaceholder.like( +placeholder = torchscript.TensorPlaceholder.like( tanh_example_input, dynamic_axes=[1]) -print(torch_mlir.compile(TanhModule(), [placeholder], +print(torchscript.compile(TanhModule(), [placeholder], use_tracing=True, ignore_traced_shapes=True)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32> try: # CHECK: `ignore_traced_shapes` requires `use_tracing` - torch_mlir.compile(TanhModule(), [placeholder], ignore_traced_shapes=True) + torchscript.compile(TanhModule(), [placeholder], ignore_traced_shapes=True) except Exception as e: print(e) try: # CHECK: TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True` - torch_mlir.compile(TanhModule(), [placeholder], use_tracing=True) + torchscript.compile(TanhModule(), [placeholder], use_tracing=True) except Exception as e: print(e) @@ -60,13 +60,13 @@ def forward(self, x): try: # CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}' - torch_mlir.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True) + torchscript.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True) except Exception as e: print(e) try: # CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}' - torch_mlir.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True) + torchscript.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True) except Exception as e: print(e) diff --git a/projects/pt1/python/torch_mlir/dynamo.py b/projects/pt1/python/torch_mlir/dynamo.py index d3d7978bbfee..fa00bb9a847f 100644 --- a/projects/pt1/python/torch_mlir/dynamo.py +++ b/projects/pt1/python/torch_mlir/dynamo.py @@ -125,7 +125,7 @@ def make_simple_dynamo_backend(user_backend): Args: user_backend: A function with the signature used by ordinary TorchDynamo backends. But the torch.fx.GraphModule passed to it - will be normalized for consumption by `torch_mlir.compile`. + will be normalized for consumption by `torchscript.compile`. Returns: A function with the signature used by TorchDynamo backends. """ diff --git a/projects/pt1/python/torch_mlir/__init__.py b/projects/pt1/python/torch_mlir/torchscript.py similarity index 99% rename from projects/pt1/python/torch_mlir/__init__.py rename to projects/pt1/python/torch_mlir/torchscript.py index c916043c2cdd..f3412b83addb 100644 --- a/projects/pt1/python/torch_mlir/__init__.py +++ b/projects/pt1/python/torch_mlir/torchscript.py @@ -22,7 +22,7 @@ class OutputType(Enum): - """The kind of output that `torch_mlir.compile` can produce. + """The kind of output that `torchscript.compile` can produce. In MLIR terminology, this describes the mix of dialects that will be produced by the conversion process. @@ -392,13 +392,13 @@ def compile(model: torch.nn.Module, strip_overloads(model) # Get the model as JIT IR (TorchScript) for import. - # TODO: Longer-term, we probably need to split `torch_mlir.compile`. + # TODO: Longer-term, we probably need to split `torchscript.compile`. # There should be an "acquisition" step that does # tracing/scripting/importing from FX/using torchdynamo.export/etc. # + any lowering to the backend contract. Then there should be a # "backend lowering" step that does the actual lowering to each # backend. 
This separation should be visible at the Python API level, and - # we can implement a deliberately simplified API like `torch_mlir.compile` + # we can implement a deliberately simplified API like `torchscript.compile` # on top of those building blocks. if isinstance(model, torch.jit.ScriptModule): # If the user already converted the model to JIT IR themselves, just diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py index 6ad41dd6dccb..8c99278b0ec3 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py @@ -6,7 +6,7 @@ from typing import Any import torch -import torch_mlir +from torch_mlir import torchscript from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem @@ -30,7 +30,7 @@ def __init__(self, backend: LinalgOnTensorsBackend): def compile(self, program: torch.nn.Module) -> Any: example_args = convert_annotations_to_placeholders(program.forward) - module = torch_mlir.compile( + module = torchscript.compile( program, example_args, output_type="linalg-on-tensors") return self.backend.compile(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py index 8a244b756e6c..1ab8a8d22b4f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py @@ -6,7 +6,7 @@ from typing import Any import torch -import torch_mlir +from torch_mlir import torchscript from torch_mlir_e2e_test.stablehlo_backends.abc import StablehloBackend from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem @@ -30,7 +30,7 @@ def __init__(self, backend: StablehloBackend): def compile(self, program: torch.nn.Module) -> Any: example_args = convert_annotations_to_placeholders(program.forward) - module = torch_mlir.compile(program, example_args, output_type="stablehlo") + module = torchscript.compile(program, example_args, output_type="stablehlo") return self.backend.compile(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py index c53227acf36a..e5c2475c7669 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py @@ -17,7 +17,7 @@ from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func from torch_mlir.dynamo import _get_decomposition_table -from torch_mlir import ( +from torch_mlir.torchscript import ( _example_args, OutputType, BACKEND_LEGAL_OPS, diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py index 8efab87a2bfe..8aa2d0e63eb6 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py @@ -6,7 +6,7 @@ from typing import Any import torch -import torch_mlir +from torch_mlir import torchscript from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem @@ -30,7 +30,7 @@ def __init__(self, backend: TosaBackend, use_make_fx: bool = False): def 
compile(self, program: torch.nn.Module) -> Any: example_args = convert_annotations_to_placeholders(program.forward) - module = torch_mlir.compile( + module = torchscript.compile( program, example_args, output_type="tosa", use_make_fx=self.use_make_fx) return self.backend.compile(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/utils.py b/projects/pt1/python/torch_mlir_e2e_test/utils.py index 403c455cba64..e3a76581f668 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/utils.py +++ b/projects/pt1/python/torch_mlir_e2e_test/utils.py @@ -3,13 +3,13 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -from torch_mlir import TensorPlaceholder +from torch_mlir.torchscript import TensorPlaceholder from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME def convert_annotations_to_placeholders(forward_method): """Converts the annotations on a forward method into tensor placeholders. - These placeholders are suitable for being passed to `torch_mlir.compile`. + These placeholders are suitable for being passed to `torchscript.compile`. """ annotations = getattr(forward_method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME) placeholders = [] diff --git a/projects/pt1/test/python/custom_op_shape_dtype_fn.py b/projects/pt1/test/python/custom_op_shape_dtype_fn.py index a46f1c594031..a3a2b965d655 100644 --- a/projects/pt1/test/python/custom_op_shape_dtype_fn.py +++ b/projects/pt1/test/python/custom_op_shape_dtype_fn.py @@ -5,7 +5,7 @@ import torch import torch.multiprocessing as mp import torch.utils.cpp_extension -import torch_mlir +from torch_mlir import torchscript from torch_mlir_e2e_test.annotations import export, annotate_args @@ -56,7 +56,7 @@ def run(): mod = CustomOpExampleModule() mod.eval() - module = torch_mlir.compile( + module = torchscript.compile( mod, torch.ones(3, 4), output_type="torch", diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/unimplemented.py b/projects/pt1/test/python/importer/jit_ir/node_import/unimplemented.py index eb6bb2f09ff3..533ef7586748 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/unimplemented.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/unimplemented.py @@ -1,5 +1,5 @@ import torch -import torch_mlir +from torch_mlir import torchscript # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s @@ -39,6 +39,6 @@ def forward(self, data): with torch.no_grad(): return data -output_type = torch_mlir.OutputType.RAW -mod = torch_mlir.compile(Model(), [torch.tensor([0, 1, 2, 3])], output_type) +output_type = torchscript.OutputType.RAW +mod = torchscript.compile(Model(), [torch.tensor([0, 1, 2, 3])], output_type) print(mod) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index d725aae6c584..6300df01e4ec 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -39,6 +39,13 @@ declare_mlir_python_sources(TorchMLIRPythonSources.Importers extras/onnx_importer.py ) +declare_mlir_python_sources(TorchMLIRPythonSources.PublicAPI + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources + SOURCES + fx.py +) + declare_mlir_python_sources(TorchMLIRPythonSources.Tools ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" ADD_TO_PARENT TorchMLIRPythonSources diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py new file mode 100644 index 000000000000..78b46cc3ea29 --- /dev/null +++ b/python/torch_mlir/fx.py @@ -0,0 +1,25 @@ +from typing import Optional + +import torch +import torch.export +import torch.nn as 
nn + +from torch_mlir.extras.fx_importer import FxImporter +from torch_mlir import ir +from torch_mlir.dialects import torch as torch_d + +def export_and_import( + f, + *args, + fx_importer: Optional[FxImporter] = None, + constraints: Optional[torch.export.Constraint] = None, + **kwargs, +): + context = ir.Context() + torch_d.register_dialect(context) + + if fx_importer is None: + fx_importer = FxImporter(context=context) + prog = torch.export.export(f, args, kwargs, constraints=constraints) + fx_importer.import_frozen_exported_program(prog) + return fx_importer.module_op diff --git a/test/python/compile.py b/test/python/compile.py index fc2917e9c76a..990738085020 100644 --- a/test/python/compile.py +++ b/test/python/compile.py @@ -3,7 +3,7 @@ import gc import sys import torch -import torch_mlir +from torch_mlir import torchscript def run_test(f): @@ -26,7 +26,7 @@ def forward(self, x): # CHECK-LABEL: TEST: test_enable_ir_printing @run_test def test_enable_ir_printing(): - torch_mlir.compile(TinyModel(), + torchscript.compile(TinyModel(), torch.ones(1, 3, 20, 20), output_type="linalg-on-tensors", enable_ir_printing=True) diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index acd2a559fa52..36c554862506 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -1,5 +1,3 @@ -# Copyright 2023 Advanced Micro Devices, Inc -# # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception @@ -13,26 +11,7 @@ import torch.export import torch.nn as nn -from torch_mlir.extras.fx_importer import FxImporter -from torch_mlir import ir -from torch_mlir.dialects import torch as torch_d - - -def export_and_import( - f, - *args, - fx_importer: Optional[FxImporter] = None, - constraints: Optional[torch.export.Constraint] = None, - **kwargs, -): - context = ir.Context() - torch_d.register_dialect(context) - - if fx_importer is None: - fx_importer = FxImporter(context=context) - prog = torch.export.export(f, args, kwargs, constraints=constraints) - fx_importer.import_frozen_exported_program(prog) - return fx_importer.module_op +from torch_mlir import fx def run(f): @@ -75,5 +54,5 @@ def __init__(self): def forward(self, x): return torch.tanh(x) * get_a() * self.b * self.p - m = export_and_import(Basic(), torch.randn(3, 4)) + m = fx.export_and_import(Basic(), torch.randn(3, 4)) print(m)
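
For reference, a minimal end-to-end use of the new FX path added by this patch, mirroring `test/python/fx_importer/basic_test.py` (the `Basic` module here is a simplified stand-in):

```python
import torch
from torch_mlir import fx

class Basic(torch.nn.Module):
    def forward(self, x):
        return torch.tanh(x)

# torch.export traces the module into an ExportedProgram, which FxImporter
# then converts into a torch-dialect MLIR module.
m = fx.export_and_import(Basic(), torch.randn(3, 4))
print(m)
```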