Skip to content

Commit

Permalink
Rename torch_mlir.compile APIs and introduce FX based analogs (llvm#2842
Browse files Browse the repository at this point in the history
)

Link to related RFC:
https://discourse.llvm.org/t/rfc-rename-torch-mlir-compile-apis-and-introduce-fx-based-analogs/76646
This commit updates the documentation, tests, CMake files, and API for
the proposed changes in the RFC. There is a new torch_mlir/fx.py for
user level APIs related to importing modules and a corresponding test
for this path can be found at test/python/fx_importer/basic_test.py.

---------

Co-authored-by: MaheshRavishankar <mravisha@amd.com>
  • Loading branch information
saienduri and MaheshRavishankar authored Feb 7, 2024
1 parent cc06391 commit bfcf93e
Show file tree
Hide file tree
Showing 31 changed files with 146 additions and 106 deletions.
2 changes: 1 addition & 1 deletion docs/architecture.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ semantics. And often users want to erase the shapes in the trace to allow
dynamic shapes for the trace. Additionally, the Python-level data structures and
APIs are very parallel between `torch.jit.script` and `torch.jit.trace`, so we
consider both of those as the same from the perspective of the responsibilities
of the compiler. Both are accessed via the `torch_mlir.compile` Python API.
of the compiler. Both are accessed via the `torch_mlir.torchscript.compile` Python API.

### Modeling the `torch.nn.Module` object (`IValue`) hierarchy for TorchScript

Expand Down
25 changes: 19 additions & 6 deletions docs/development.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,37 +120,50 @@ cmake --build build
### Linux and macOS

```shell
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer
```

### Windows PowerShell

```shell
$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/projects/pt1/examples"
$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/test/python/fx_importer"
```

## Testing MLIR output in various dialects

To test the compiler's output to the different MLIR dialects, you can use the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`.
To test the MLIR output to torch dialect, you can use `test/python/fx_importer/basic_test.py`.

Make sure you have activated the virtualenv and set the `PYTHONPATH` above
(if running on Windows, modify the environment variable as shown above):
```shell
source mlir_venv/bin/activate
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer
python test/python/fx_importer/basic_test.py
```

This will display the basic example in TORCH dialect.

To test the compiler's output to the different MLIR dialects, you can also use the deprecated path
using torchscript with the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`.
This path doesn't give access to the current generation work that is being driven via the fx_importer
and may lead to errors.

Same as above, but with different python path and example:
```shell
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples
python projects/pt1/examples/torchscript_resnet18_all_output_types.py
```

This will display the Resnet18 network example in three dialects: TORCH, LINALG on TENSORS and TOSA.

The main functionality is on `torch_mlir.compile()`'s `output_type`.
The main functionality is on `torch_mlir.torchscript.compile()`'s `output_type`.

Ex:
```python
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
module = torch_mlir.torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
```

Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.
`output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.

## Jupyter

Expand Down
16 changes: 16 additions & 0 deletions docs/long_term_roadmap.md → docs/roadmap.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,22 @@ the ecosystem are:
Most of this document describes long-term ecosystem changes that will address
these, drastically improving Torch-MLIR's ability to meet its goals.

## Current API Paths

Currently, there are two main API paths for the torch-mlir project:

- The first path is part of the legacy project pt1 code
(torch_mlir.torchscript.compile). This allows users to test the compiler's
output to the different MLIR dialects (`TORCH`, `TOSA`, `LINALG_ON_TENSORS`,
`RAW` and `STABLEHLO`). This path is deprecated and doesn’t give access to
the current generation work that is being driven via the fx_importer. It is
tied to the old Torchscript path.
- The second path (torch_mlir.fx.export_and_import) allows users to import a
consolidated torch.export.ExportedProgram instance of an arbitrary Python
callable (an nn.Module, a function or a method) and output to torch dialect
mlir module. This path is aligned with PyTorch's roadmap, but the path is
not fully functional yet.

## Roadmap

### Refactoring the frontend
Expand Down
4 changes: 2 additions & 2 deletions projects/pt1/examples/torchdynamo_resnet18.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torchvision.models as models
from torchvision import transforms

import torch_mlir
from torch_mlir import torchscript
from torch_mlir.dynamo import make_simple_dynamo_backend
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

Expand Down Expand Up @@ -71,7 +71,7 @@ def predictions(torch_func, jit_func, img, labels):
@make_simple_dynamo_backend
def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
example_inputs: List[torch.Tensor]):
mlir_module = torch_mlir.compile(
mlir_module = torchscript.compile(
fx_graph, example_inputs, output_type="linalg-on-tensors")
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(mlir_module)
Expand Down
4 changes: 2 additions & 2 deletions projects/pt1/examples/torchscript_resnet18.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torchvision.models as models
from torchvision import transforms

import torch_mlir
from torch_mlir import torchscript
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend


Expand Down Expand Up @@ -67,7 +67,7 @@ def predictions(torch_func, jit_func, img, labels):

resnet18 = models.resnet18(pretrained=True)
resnet18.train(False)
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
import torch
import torchvision

import torch_mlir
from torch_mlir import torchscript

resnet18 = torchvision.models.resnet18(pretrained=True)
resnet18.eval()

module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
print("TORCH OutputType\n", module.operation.get_asm(large_elements_limit=10))
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
print("LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10))
# TODO: Debug why this is so slow.
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa")
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa")
print("TOSA OutputType\n", module.operation.get_asm(large_elements_limit=10))
4 changes: 2 additions & 2 deletions projects/pt1/examples/torchscript_resnet_inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@
"\n",
"# Compile the model with an example input.\n",
"# We lower to the linalg-on-tensors form that the reference backend supports.\n",
"compiled = torch_mlir.compile(TanhModule(), torch.ones(3), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n",
"compiled = torch_mlir.torchscript.compile(TanhModule(), torch.ones(3), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n",
"# Load it on the reference backend.\n",
"jit_module = compile_and_load_on_refbackend(compiled)\n",
"# Run it!\n",
Expand Down Expand Up @@ -326,7 +326,7 @@
"source": [
"resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n",
"resnet18.eval()\n",
"compiled = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=\"linalg-on-tensors\")\n",
"compiled = torch_mlir.torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=\"linalg-on-tensors\")\n",
"jit_module = compile_and_load_on_refbackend(compiled)"
]
},
Expand Down
4 changes: 2 additions & 2 deletions projects/pt1/examples/torchscript_stablehlo_backend_resnet.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import torch
import torchvision.models as models
import torch_mlir
from torch_mlir import torchscript

model = models.resnet18(pretrained=True)
model.eval()
data = torch.randn(2,3,200,200)
out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"

module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False)
module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False)
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
import torch_mlir
from torch_mlir import torchscript

from transformers import BertForMaskedLM

Expand All @@ -17,7 +17,7 @@ def forward(self, data):
data = torch.randint(30522, (2, 128))
out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"

module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True)
module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True)
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))

Expand Down
2 changes: 1 addition & 1 deletion projects/pt1/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources
SOURCES
__init__.py
torchscript.py
_dynamo_fx_importer.py
compiler_utils.py
dynamo.py
Expand Down
8 changes: 4 additions & 4 deletions projects/pt1/python/test/compile_api/already_scripted.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# RUN: %PYTHON %s | FileCheck %s

import torch
import torch_mlir
from torch_mlir import torchscript


class BasicModule(torch.nn.Module):
Expand All @@ -15,17 +15,17 @@ def sin(self, x):
return torch.ops.aten.sin(x)


example_args = torch_mlir.ExampleArgs()
example_args = torchscript.ExampleArgs()
example_args.add_method("sin", torch.ones(2, 3))

scripted = torch.jit.script(BasicModule())
print(torch_mlir.compile(scripted, example_args))
print(torchscript.compile(scripted, example_args))
# CHECK: module
# CHECK-DAG: func.func @sin

scripted = torch.jit.script(BasicModule())
try:
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
torch_mlir.compile(scripted, torch_mlir.ExampleArgs().add_method("nonexistent", torch.ones(2, 3)))
torchscript.compile(scripted, torchscript.ExampleArgs().add_method("nonexistent", torch.ones(2, 3)))
except Exception as e:
print(e)
8 changes: 4 additions & 4 deletions projects/pt1/python/test/compile_api/already_traced.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@
# RUN: %PYTHON %s | FileCheck %s

import torch
import torch_mlir
from torch_mlir import torchscript

class BasicModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.sin(x)

example_arg = torch.ones(2, 3)
example_args = torch_mlir.ExampleArgs.get(example_arg)
example_args = torchscript.ExampleArgs.get(example_arg)

traced = torch.jit.trace(BasicModule(), example_arg)
print(torch_mlir.compile(traced, example_args))
print(torchscript.compile(traced, example_args))
# CHECK: module
# CHECK-DAG: func.func @forward

traced = torch.jit.trace(BasicModule(), example_arg)
try:
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
torch_mlir.compile(traced, torch_mlir.ExampleArgs().add_method("nonexistent", example_arg))
torchscript.compile(traced, torchscript.ExampleArgs().add_method("nonexistent", example_arg))
except Exception as e:
print(e)
6 changes: 3 additions & 3 deletions projects/pt1/python/test/compile_api/backend_legal_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@

import torch

import torch_mlir
from torch_mlir import torchscript

class AddmmModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y, z):
return torch.ops.aten.addmm(x, y, z)

example_args = 3 * [torch_mlir.TensorPlaceholder([-1, -1], torch.float32)]
example_args = 3 * [torchscript.TensorPlaceholder([-1, -1], torch.float32)]

print(torch_mlir.compile(AddmmModule(), example_args,
print(torchscript.compile(AddmmModule(), example_args,
output_type="torch", backend_legal_ops=["aten.addmm"]))
# CHECK-LABEL: @forward
# CHECK: torch.aten.addmm
20 changes: 10 additions & 10 deletions projects/pt1/python/test/compile_api/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import torch

import torch_mlir
from torch_mlir import torchscript

class TanhModule(torch.nn.Module):
def __init__(self):
Expand All @@ -18,24 +18,24 @@ def forward(self, x):
tanh_example_input = torch.ones(2, 3)

# Simplest case: One example argument.
print(torch_mlir.compile(TanhModule(), tanh_example_input))
print(torchscript.compile(TanhModule(), tanh_example_input))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>

# Use a TensorPlaceholder to represent dynamic axes.
placeholder = torch_mlir.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1])
print(torch_mlir.compile(TanhModule(), placeholder))
placeholder = torchscript.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1])
print(torchscript.compile(TanhModule(), placeholder))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32>

# Explicitly construct a TensorPlaceholder.
placeholder = torch_mlir.TensorPlaceholder([-1, 2], torch.float32)
print(torch_mlir.compile(TanhModule(), placeholder))
placeholder = torchscript.TensorPlaceholder([-1, 2], torch.float32)
print(torchscript.compile(TanhModule(), placeholder))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32>

# Basic smoke test for the raw output type.
print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.RAW))
print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.RAW))
# CHECK: torch.nn_module {
# CHECK: } : !torch.nn.Module<"{{.*}}.TanhModule">

Expand All @@ -47,12 +47,12 @@ def forward(self, lhs, rhs ):

# N > 1 inputs.
mm_example_inputs = [torch.ones(2, 3), torch.ones(3, 4)]
print(torch_mlir.compile(MmModule(), mm_example_inputs))
print(torchscript.compile(MmModule(), mm_example_inputs))
# CHECK-LABEL: @forward
# CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>

# Mixes Tensor's and TensorPlaceholder's.
mm_dynamic_inputs = [mm_example_inputs[0], torch_mlir.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])]
print(torch_mlir.compile(MmModule(), mm_dynamic_inputs))
mm_dynamic_inputs = [mm_example_inputs[0], torchscript.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])]
print(torchscript.compile(MmModule(), mm_dynamic_inputs))
# CHECK-LABEL: @forward
# CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[2,?],f32>
4 changes: 2 additions & 2 deletions projects/pt1/python/test/compile_api/make_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import functorch
import torch

import torch_mlir
from torch_mlir import torchscript

def simple(x):
return x * x
Expand All @@ -17,6 +17,6 @@ def simple(x):
graph = functorch.make_fx(simple)(torch.randn(1,))

# Simplest case: One example argument.
print(torch_mlir.compile(graph, example_input))
print(torchscript.compile(graph, example_input))
# CHECK-LABEL: @forward
# CHECK: torch.aten.mul.Tensor %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32>
10 changes: 5 additions & 5 deletions projects/pt1/python/test/compile_api/multiple_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# RUN: %PYTHON %s | FileCheck %s

import torch
import torch_mlir
from torch_mlir import torchscript


class TwoMethodsModule(torch.nn.Module):
Expand All @@ -17,14 +17,14 @@ def cos(self, x):
return torch.ops.aten.cos(x)


example_args = torch_mlir.ExampleArgs()
example_args = torchscript.ExampleArgs()
example_args.add_method("sin", torch.ones(2, 3))
example_args.add_method("cos", torch.ones(2, 4))

# Note: Due to https://github.com/pytorch/pytorch/issues/88735 we need to
# check the `use_tracing` case first.

print(torch_mlir.compile(TwoMethodsModule(), example_args, use_tracing=True))
print(torchscript.compile(TwoMethodsModule(), example_args, use_tracing=True))
# CHECK: module
# CHECK-DAG: func.func @sin
# CHECK-DAG: func.func @cos
Expand All @@ -34,8 +34,8 @@ def cos(self, x):
# Otherwise the user would have to do this manually, which is tedious. This
# technically mutates the user input model which is not great but probably okay
# for this kind of API sugar. Users can always take full control of the process
# by scripting the model themselves before passing it to `torch_mlir.compile`.
print(torch_mlir.compile(TwoMethodsModule(), example_args))
# by scripting the model themselves before passing it to `torchscript.compile`.
print(torchscript.compile(TwoMethodsModule(), example_args))
# CHECK: module
# CHECK-DAG: func.func @sin
# CHECK-DAG: func.func @cos
Loading

0 comments on commit bfcf93e

Please sign in to comment.