fix: Add new TRT 8.6 features to Dynamo compile [3 / x] #1973

Merged
21 changes: 20 additions & 1 deletion py/torch_tensorrt/dynamo/backend/__init__.py
@@ -4,7 +4,7 @@
 import torch_tensorrt
 from functools import partial

-from typing import Any, Sequence
+from typing import Any, Optional, Sequence
 from torch_tensorrt import EngineCapability, Device
 from torch_tensorrt.fx.utils import LowerPrecision

@@ -17,6 +17,9 @@
     WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
     PASS_THROUGH_BUILD_FAILURES,
+    MAX_AUX_STREAMS,
+    VERSION_COMPATIBLE,
+    OPTIMIZATION_LEVEL,
     USE_EXPERIMENTAL_RT,
 )

@@ -46,6 +49,9 @@ def compile(
     min_block_size=MIN_BLOCK_SIZE,
     torch_executed_ops=[],
     torch_executed_modules=[],
+    max_aux_streams=MAX_AUX_STREAMS,
+    version_compatible=VERSION_COMPATIBLE,
+    optimization_level=OPTIMIZATION_LEVEL,
     use_experimental_rt=USE_EXPERIMENTAL_RT,
Collaborator:
If we are adding the experimental runtime, we should move those changes into a new PR which also moves the RT into dynamo, and have that PR depend on this.

Collaborator (Author):
This PR has now been refactored into the new PR.

     **kwargs,
 ):
@@ -95,6 +101,9 @@ def compile(
         workspace_size=workspace_size,
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
+        max_aux_streams=max_aux_streams,
+        version_compatible=version_compatible,
+        optimization_level=optimization_level,
         use_experimental_rt=use_experimental_rt,
         **kwargs,
     )
@@ -119,6 +128,9 @@ def create_backend(
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Sequence[str] = set(),
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
+    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
+    version_compatible: bool = VERSION_COMPATIBLE,
+    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
     use_experimental_rt: bool = USE_EXPERIMENTAL_RT,
     **kwargs,
 ):
@@ -131,6 +143,10 @@ def create_backend(
         min_block_size: Minimum number of operators per TRT-Engine Block
         torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
         pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
+        max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
+        version_compatible: Provide version forward-compatibility for engine plan files
+        optimization_level: Builder optimization 0-5, higher levels imply longer build time,
+            searching for more optimization options. TRT defaults to 3
         use_experimental_rt: Whether to use the new experimental TRTModuleNext for TRT engines
     Returns:
         Backend for torch.compile
@@ -145,6 +161,9 @@
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
         pass_through_build_failures=pass_through_build_failures,
+        max_aux_streams=max_aux_streams,
+        version_compatible=version_compatible,
+        optimization_level=optimization_level,
         use_experimental_rt=use_experimental_rt,
     )

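For reference, a minimal sketch of calling the Dynamo path with the three new TRT 8.6 options; the model, input shapes, and the positional part of the compile signature are assumptions for illustration, not taken from this diff:

import torch
import torch_tensorrt

model = torch.nn.Linear(64, 32).eval().cuda()
inputs = [torch.randn(8, 64).cuda()]

# The new options flow from compile() -> create_backend() -> CompilationSettings.
trt_model = torch_tensorrt.dynamo.backend.compile(
    model,
    inputs,
    max_aux_streams=2,        # cap auxiliary TRT streams per engine
    version_compatible=True,  # emit a forward-compatible engine plan
    optimization_level=4,     # 0-5; higher = longer builds (TRT default: 3)
)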
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/backend/_defaults.py
@@ -6,4 +6,7 @@
 WORKSPACE_SIZE = 0
 MIN_BLOCK_SIZE = 5
 PASS_THROUGH_BUILD_FAILURES = False
+MAX_AUX_STREAMS = None
+VERSION_COMPATIBLE = False
+OPTIMIZATION_LEVEL = None
 USE_EXPERIMENTAL_RT = False
8 changes: 7 additions & 1 deletion py/torch_tensorrt/dynamo/backend/_settings.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Sequence
+from typing import Optional, Sequence

 from torch_tensorrt.fx.utils import LowerPrecision
 from torch_tensorrt.dynamo.backend._defaults import (
@@ -8,6 +8,9 @@
     WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
     PASS_THROUGH_BUILD_FAILURES,
+    MAX_AUX_STREAMS,
+    VERSION_COMPATIBLE,
+    OPTIMIZATION_LEVEL,
     USE_EXPERIMENTAL_RT,
 )

@@ -20,4 +23,7 @@ class CompilationSettings:
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Sequence[str] = field(default_factory=set)
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
+    max_aux_streams: Optional[int] = MAX_AUX_STREAMS
+    version_compatible: bool = VERSION_COMPATIBLE
+    optimization_level: Optional[int] = OPTIMIZATION_LEVEL
     use_experimental_rt: bool = USE_EXPERIMENTAL_RT
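Since the dataclass fields mirror the keyword arguments of create_backend, the settings can also be constructed directly; a minimal sketch, with illustrative field values:

from torch_tensorrt.dynamo.backend._settings import CompilationSettings

# Unset fields fall back to the module-level defaults
# (MAX_AUX_STREAMS=None, VERSION_COMPATIBLE=False, OPTIMIZATION_LEVEL=None).
settings = CompilationSettings(
    min_block_size=5,
    max_aux_streams=2,
    version_compatible=True,
    optimization_level=4,
)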
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/backend/conversion.py
@@ -42,6 +42,9 @@ def convert_module(
             if settings.debug
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
         ),
+        max_aux_streams=settings.max_aux_streams,
+        version_compatible=settings.version_compatible,
+        optimization_level=settings.optimization_level,
     )

     if settings.use_experimental_rt:
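Downstream, these three settings correspond to standard TensorRT 8.6 builder-config options. A sketch of the equivalent raw TRT calls, for orientation only; the actual wiring happens inside the fx interpreter invoked by convert_module, and the values here are illustrative:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()

config.max_aux_streams = 2                           # max_aux_streams
config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)  # version_compatible=True
config.builder_optimization_level = 4                # optimization_level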