Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Dynamo refactor #2104

Merged
merged 14 commits into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 22 additions & 25 deletions .circleci/config.yml
peri044 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ commands:
command: |
set -e
mkdir -p /tmp/artifacts/test_results
cd tests/py
cd tests/py/ts/
pytest --junitxml=/tmp/artifacts/test_results/api/api_test_results.xml api/
pytest --junitxml=/tmp/artifacts/test_results/models/models_test_results.xml models/
pytest --junitxml=/tmp/artifacts/test_results/integrations/integrations_test_results.xml integrations/
Expand Down Expand Up @@ -733,50 +733,47 @@ commands:
# =================== FX tests end ======================== #

# =================== Dynamo tests start ======================== #
test-dynamo-fx_ts:
description: "Test the Dynamo fx_ts_compat path"

test-dynamo-torch_compile:
description: "Test Dynamo torch_compile tests"
steps:
- run:
name: Run Dynamo fx_ts_compat core tests
name: Run Dynamo torch_compile tests
command: |
cd py/torch_tensorrt/dynamo/fx_ts_compat/test
pushd core/
pytest --junitxml=/tmp/artifacts/test_results/dynamo/fx_ts_compat/test_results.xml
popd
cd tests/py/dynamo/backend/
pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml

- store_test_results:
path: /tmp/artifacts
- store_artifacts:
path: /tmp/testlogs

test-dynamo-compile-core:
description: "Test the Dynamo compile path"
test-dynamo-models_torch_compile:
description: "Test the Dynamo models via torch_compile path"
steps:
- run:
name: Run Dynamo compile core tests
name: Run Dynamo models via torch_compile path
command: |
cd py/torch_tensorrt/dynamo/backend
pushd test/
pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml
popd
cd tests/py/dynamo/models
pip3 install timm
pip3 install transformers
pytest test_models.py --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir torch_compile

- store_test_results:
path: /tmp/artifacts
- store_artifacts:
path: /tmp/testlogs

test-dynamo-compile:
description: "Test the Dynamo compile path"
test-dynamo-models_torch_export:
description: "Test the Dynamo models via torch_export path"
steps:
- run:
name: Run Dynamo compile E2E tests
name: Run Dynamo models via torch_export path
command: |
cd py/torch_tensorrt/dynamo/
pushd test/
cd tests/py/dynamo/models
pip3 install timm
pip3 install transformers
pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo_compile
popd
pytest test_models_export.py --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo

- store_test_results:
path: /tmp/artifacts
Expand Down Expand Up @@ -1039,9 +1036,9 @@ jobs:
command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl
# We install torch after torch-trt because pip automatically enforces the version constraint otherwise
- dump-test-env
- test-dynamo-compile
- test-dynamo-compile-core
- test-dynamo-fx_ts
- test-dynamo-torch_compile
- test-dynamo-models_torch_compile
- test-dynamo-models_torch_export

package-x86_64-linux:
parameters:
Expand Down
29 changes: 20 additions & 9 deletions py/torch_tensorrt/_Input.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,47 +302,58 @@ def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple:
return result_domain

@classmethod
def from_tensor(cls, t: torch.Tensor) -> "Input":
def from_tensor(
cls, t: torch.Tensor, disable_memory_format_check: bool = False
) -> "Input":
"""
Produce a Input which contains the information of the given PyTorch tensor.

Args:
tensor (torch.Tensor): A PyTorch tensor.
disable_memory_format_check (bool): Whether to validate the memory formats of input tensors

Returns:
A Input object.
"""
if not any(
[
t.is_contiguous(memory_format=torch.contiguous_format),
t.is_contiguous(memory_format=torch.channels_last),
]
if not (
t.is_contiguous(memory_format=torch.contiguous_format)
or t.is_contiguous(memory_format=torch.channels_last)
or disable_memory_format_check
):
raise ValueError(
"Tensor does not have a supported memory format, supported formats are contiguous or channel_last"
)
frmt = (
torch.contiguous_format
if t.is_contiguous(memory_format=torch.contiguous_format)
if (
t.is_contiguous(memory_format=torch.contiguous_format)
or disable_memory_format_check
)
else torch.channels_last
)
return cls(shape=t.shape, dtype=t.dtype, format=frmt)

@classmethod
def from_tensors(cls, ts: torch.Tensor) -> List["Input"]:
def from_tensors(
cls, ts: torch.Tensor, disable_memory_format_check: bool = False
) -> List["Input"]:
"""
Produce a list of Inputs which contain
the information of all the given PyTorch tensors.

Args:
tensors (Iterable[torch.Tensor]): A list of PyTorch tensors.
disable_memory_format_check (bool): Whether to validate the memory formats of input tensors

Returns:
A list of Inputs.
"""

assert isinstance(ts, (list, tuple))
return [cls.from_tensor(t) for t in ts]
return [
cls.from_tensor(t, disable_memory_format_check=disable_memory_format_check)
for t in ts
]

def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor:
"""
Expand Down
75 changes: 54 additions & 21 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class _IRType(Enum):

ts = 0
fx = 1
fx_ts_compat = 2
dynamo_compile = 3
dynamo = 2
torch_compile = 3


class _ModuleType(Enum):
Expand Down Expand Up @@ -47,33 +47,33 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:

ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]])
ir_targets_fx = ir == "fx"
ir_targets_dynamo_compile = ir == "dynamo_compile"
ir_targets_fx_ts_compat = ir == "fx_ts_compat"
ir_targets_dynamo = ir == "dynamo"
ir_targets_torch_compile = ir == "torch_compile"

if module_is_tsable and ir_targets_torchscript:
return _IRType.ts
elif module_is_fxable and ir_targets_fx:
return _IRType.fx
elif module_is_fxable and ir_targets_fx_ts_compat:
return _IRType.fx_ts_compat
elif module_is_fxable and ir_targets_dynamo_compile:
return _IRType.dynamo_compile
elif module_is_fxable and ir_targets_dynamo:
return _IRType.dynamo
elif module_is_fxable and ir_targets_torch_compile:
return _IRType.torch_compile
else:
if ir == "default":
# Options are listed in order of preference
if module_is_tsable:
logging.log(
logging.Level.Info, "ir was set to default, using TorchScript as ir"
logging.Level.Warning,
"Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=ts",
)
return _IRType.ts
peri044 marked this conversation as resolved.
Show resolved Hide resolved
elif module_is_fxable:
raise ValueError(
"Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT"
logging.log(
logging.Level.Info, "ir was set to default, using dynamo as ir"
)
# logging.log(logging.Level.Info, "ir was set to default, using TorchScript as fx")
# return _IRType.fx
return _IRType.dynamo
else:
raise ValueError("Module was provided with in an unsupported format")
raise ValueError("Module was provided in an unsupported format")
else:
raise ValueError("Unknown ir was requested")

Expand Down Expand Up @@ -156,18 +156,41 @@ def compile(
dynamic_batch=False,
**kwargs,
)
elif target_ir == _IRType.dynamo_compile:
elif target_ir == _IRType.dynamo:
from torch_tensorrt import Device
from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
import collections.abc

if not isinstance(inputs, collections.abc.Sequence):
inputs = [inputs]
device = kwargs.get("device", Device._current_device())
torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
module = torch_tensorrt.dynamo.trace(module, torch_inputs, **kwargs)
return torch_tensorrt.dynamo.compile(
peri044 marked this conversation as resolved.
Show resolved Hide resolved
module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
)
elif target_ir == _IRType.fx_ts_compat:
return torch_tensorrt.dynamo.fx_ts_compat.compile(
module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
module,
inputs=inputs,
enabled_precisions=enabled_precisions,
**kwargs,
)
elif target_ir == _IRType.torch_compile:
return torch_compile(module, enabled_precisions=enabled_precisions, **kwargs)
else:
raise RuntimeError("Module is an unknown format or the ir requested is unknown")


def torch_compile(module, **kwargs):
    """
    Returns a boxed model which is the output of torch.compile.
    This does not compile the model to TRT. Execute this model on
    sample inputs to compile the model to TRT.
    """
    # Deferred import: the dynamo backend is only needed on this path.
    from torch_tensorrt.dynamo.backend import torch_tensorrt_backend

    # Forward all remaining keyword arguments as backend options.
    backend_options = dict(kwargs)
    compiled_module = torch.compile(
        module, backend=torch_tensorrt_backend, options=backend_options
    )
    return compiled_module


def convert_method_to_trt_engine(
peri044 marked this conversation as resolved.
Show resolved Hide resolved
module: Any,
method_name: str,
Expand Down Expand Up @@ -224,6 +247,16 @@ def convert_method_to_trt_engine(
**kwargs,
)
elif target_ir == _IRType.fx:
raise RuntimeError("fx is currently not supported")
raise RuntimeError(
"convert_method_to_trt_engine call is not supported for ir=fx"
)
elif target_ir == _IRType.dynamo:
raise RuntimeError(
"convert_method_to_trt_engine call is not supported for ir=dynamo."
)
elif target_ir == _IRType.torch_compile:
raise RuntimeError(
"convert_method_to_trt_engine call is not supported for ir=torch_compile"
)
else:
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
5 changes: 3 additions & 2 deletions py/torch_tensorrt/dynamo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
from torch_tensorrt._util import sanitized_torch_version

if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
from torch_tensorrt.dynamo import fx_ts_compat
from .backend import compile
from ._settings import *
from .compile import compile
from .aten_tracer import trace
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from torch_tensorrt.fx.utils import LowerPrecision
import torch


PRECISION = LowerPrecision.FP32
PRECISION = torch.float32
DEBUG = False
WORKSPACE_SIZE = 0
MIN_BLOCK_SIZE = 5
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from dataclasses import dataclass, field
from typing import Optional, Sequence

from torch_tensorrt.fx.utils import LowerPrecision
from torch_tensorrt.dynamo.backend._defaults import (
import torch
from torch_tensorrt.dynamo._defaults import (
PRECISION,
DEBUG,
WORKSPACE_SIZE,
Expand All @@ -17,7 +16,7 @@

@dataclass
class CompilationSettings:
precision: LowerPrecision = PRECISION
precision: torch.dtype = PRECISION
debug: bool = DEBUG
workspace_size: int = WORKSPACE_SIZE
min_block_size: int = MIN_BLOCK_SIZE
Expand Down
Loading