Skip to content

Commit

Permalink
Using these changes and
Browse files Browse the repository at this point in the history
```shell
TORCH_MLIR_ENABLE_LTC=0 TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 CMAKE_GENERATOR=Ninja python setup.py bdist_wheel --dist-dir wheelhouse
```

I can build a wheel using which PI will successfully compile/lower e2e model examples.
  • Loading branch information
makslevental committed Feb 6, 2023
1 parent 089018b commit be042d0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
23 changes: 15 additions & 8 deletions python/torch_mlir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,26 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import warnings
from typing import Optional, Sequence, Union, List, Dict, Tuple
from enum import Enum

import sys
from io import StringIO

from torch._functorch.compile_utils import strip_overloads
import torch
try:
from torch._functorch.compile_utils import strip_overloads
import torch
except ImportError:
warnings.warn("PyTorch not installed")

from torch_mlir.passmanager import PassManager
from .compiler_utils import run_pipeline_with_repro_report
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
from torch_mlir.dialects import torch as torch_dialect
try:
from torch_dialect.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
except ImportError:
warnings.warn("torch-mlir JIT IR Importer not installed")


class OutputType(Enum):
Expand Down Expand Up @@ -88,7 +95,7 @@ class TensorPlaceholder:
```
"""

def __init__(self, shape: List[int], dtype: torch.dtype):
def __init__(self, shape: List[int], dtype: "torch.dtype"):
"""Create a tensor with shape `shape` and dtype `dtype`.
Args:
Expand All @@ -100,7 +107,7 @@ def __init__(self, shape: List[int], dtype: torch.dtype):
self.dtype = dtype

@staticmethod
def like(tensor: torch.Tensor, dynamic_axes: List[int] = None):
def like(tensor: "torch.Tensor", dynamic_axes: List[int] = None):
"""Create a tensor placeholder that is like the given tensor.
Args:
Expand All @@ -119,7 +126,7 @@ def like(tensor: torch.Tensor, dynamic_axes: List[int] = None):
return TensorPlaceholder(shape, tensor.dtype)


_example_arg = Union[TensorPlaceholder, torch.Tensor]
_example_arg = Union[TensorPlaceholder, "torch.Tensor"]
_example_args_for_one_method = Union[_example_arg, Sequence[_example_arg]]
_example_args = Union[_example_args_for_one_method, "ExampleArgs"]

Expand Down Expand Up @@ -246,7 +253,7 @@ def _get_for_tracing(
}


def compile(model: torch.nn.Module,
def compile(model: "torch.nn.Module",
example_args: _example_args,
output_type: Union[str, "OutputType"] = OutputType.TORCH,
use_tracing: bool = False,
Expand Down
10 changes: 5 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@

PACKAGE_VERSION = os.environ.get("TORCH_MLIR_PYTHON_PACKAGE_VERSION") or "0.0.1"

# If true, enable LTC build by default
# If true, enable LTC and JIT IR importer build by default
TORCH_MLIR_ENABLE_LTC_DEFAULT = True
TORCH_MLIR_ENABLE_JIT_IR_IMPORTER = int(os.environ.get('TORCH_MLIR_ENABLE_JIT_IR_IMPORTER', True))

# Build phase discovery is unreliable. Just tell it what phases to run.
class CustomBuild(_build):
Expand Down Expand Up @@ -90,6 +91,7 @@ def run(self):
f"-DCMAKE_C_VISIBILITY_PRESET=hidden",
f"-DCMAKE_CXX_VISIBILITY_PRESET=hidden",
f"-DTORCH_MLIR_ENABLE_LTC={'ON' if enable_ltc else 'OFF'}",
f"-DTORCH_MLIR_ENABLE_JIT_IR_IMPORTER={'ON' if TORCH_MLIR_ENABLE_JIT_IR_IMPORTER else 'OFF'}",
]

os.makedirs(cmake_build_dir, exist_ok=True)
Expand Down Expand Up @@ -159,9 +161,7 @@ def build_extension(self, ext):
ext_modules=[
CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"),
],
install_requires=[
"numpy",
f"torch=={torch.__version__}".split("+", 1)[0],
],
install_requires=["numpy", ] + [
f"torch=={torch.__version__}".split("+", 1)[0], ] if TORCH_MLIR_ENABLE_JIT_IR_IMPORTER else [],
zip_safe=False,
)

0 comments on commit be042d0

Please sign in to comment.