feat: Add example usage scripts for dynamo path
- Add sample scripts covering resnet18, transformers (BERT), and a custom model, showcasing the `torch_tensorrt.dynamo.torch_compile` path, which can compile models with data-dependent control flow and other constructs that make other compilation methods more difficult
- Cover the different customizable features allowed in the new backend
Showing 3 changed files with 147 additions and 0 deletions.
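For context on the commit message's claim: "data-dependent control flow" means Python branches whose outcome depends on tensor values, which shape-based tracing cannot capture ahead of time. A minimal sketch of such a model (hypothetical, not one of the committed files), compiled through the same torch.compile path:

import torch

class ValueDependentModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The branch depends on the runtime value of x, not just its
        # shape -- this is the data-dependent control flow in question.
        if x.sum() > 0:
            return torch.relu(x)
        return torch.sigmoid(x)

# torch.compile inserts a graph break at the branch; each side can still
# be lowered by the registered backend (assumed here to be "tensorrt").
compiled = torch.compile(ValueDependentModel().cuda(), backend="tensorrt")
out = compiled(torch.randn(5, 7).cuda())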
@@ -0,0 +1,53 @@
import torch
from torch_tensorrt.dynamo.torch_compile import create_backend
from torch_tensorrt.fx.lower_setting import LowerPrecision


##### Overview
# This script is intended as an overview of the process by which
# torch_tensorrt.dynamo.torch_compile works, and how it integrates
# with the new torch.compile API.

# We begin by defining a model
class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x_out = self.relu(x)
        y_out = self.relu(y)
        x_y_out = x_out + y_out
        return torch.mean(x_y_out)


##### Compilation using default settings

sample_inputs = [torch.rand((5, 7)).cuda(), torch.rand((5, 7)).cuda()]
model = Model().eval().cuda()

# Next, we compile the model using torch.compile
# For the default settings, we can simply call torch.compile
# with the backend "tensorrt", and run the model on an
# input to trigger compilation, like so:
optimized_model = torch.compile(model, backend="tensorrt")
optimized_model(*sample_inputs)
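
# A hypothetical sanity check (an illustration, not part of the original
# script): compare the compiled output against eager execution, assuming
# an FP32 tolerance of ~1e-3 is acceptable for TensorRT-optimized kernels.
with torch.no_grad():
    eager_result = model(*sample_inputs)
    compiled_result = optimized_model(*sample_inputs)
assert torch.allclose(eager_result, compiled_result, atol=1e-3)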


##### Compilation using custom settings

sample_inputs_half = [
    torch.rand((5, 7)).half().cuda(),
    torch.rand((5, 7)).half().cuda(),
]
model_half = Model().half().eval().cuda()

# Alternatively, if we want to customize certain options in the backend,
# but still use the torch.compile call directly, we can call the
# convenience/helper function create_backend to create a custom backend
# which has been pre-populated with certain key arguments
custom_backend = create_backend(
    lower_precision=LowerPrecision.FP16, debug=True, max_num_trt_engines=2
)
optimized_model_custom = torch.compile(model_half, backend=custom_backend)
optimized_model_custom(*sample_inputs_half)
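
# Hypothetical extension (an assumption, not in the original script): the
# callable returned by create_backend is model-independent, so the same
# customized backend can be reused to compile additional modules.
another_model = Model().half().eval().cuda()
optimized_another = torch.compile(another_model, backend=custom_backend)
optimized_another(*sample_inputs_half)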
@@ -0,0 +1,45 @@
import torch
from torch_tensorrt.dynamo import torch_compile
import torchvision.models as models

##### Overview
# This script is intended as a sample of the torch_tensorrt.dynamo.torch_compile
# workflow on the resnet18 model


# Initialize model and sample inputs
model = models.resnet18(pretrained=True).half().eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda").half()]

##### Optional Input Arguments

# Enabled precision for TensorRT optimization
enabled_precisions = {torch.half}
# Whether to print verbose logs
debug = True
# Workspace size for TensorRT (20 << 30 bytes = 20 GiB)
workspace_size = 20 << 30
# Maximum number of TRT Engines
# (Higher value allows more graph segmentation)
max_num_trt_engines = 100


# Build and compile the model with torch.compile, using tensorrt backend
optimized_model = torch_compile(
    model,
    inputs,
    enabled_precisions=enabled_precisions,
    debug=debug,
    workspace_size=workspace_size,
    max_num_trt_engines=max_num_trt_engines,
)


# Does not cause recompilation (same batch size as input)
new_inputs = [torch.randn((1, 3, 224, 224)).half().to("cuda")]
new_outputs = optimized_model(*new_inputs)


# Does cause recompilation (new batch size)
new_batch_size_inputs = [torch.randn((8, 3, 224, 224)).half().to("cuda")]
new_batch_size_outputs = optimized_model(*new_batch_size_inputs)
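
# A hedged aside (an illustration, not in the original script): on PyTorch
# builds that support it, marking the batch dimension as dynamic asks the
# compiler to generalize over it rather than recompile per batch size.
# Whether the TensorRT backend honors dynamic shapes here is a separate,
# unverified question.
import torch._dynamo

dynamic_inputs = [torch.randn((4, 3, 224, 224)).half().to("cuda")]
torch._dynamo.mark_dynamic(dynamic_inputs[0], 0)  # dim 0 = batch
dynamic_outputs = optimized_model(*dynamic_inputs)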
@@ -0,0 +1,49 @@
import torch
from torch_tensorrt.dynamo import torch_compile
from transformers import BertModel

##### Overview
# This script is intended as a sample of the torch_tensorrt.dynamo.torch_compile
# workflow on the BERT base uncased model


model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
# randint's upper bound is exclusive; a (0, 2) range yields random 0/1
# values for the token ids and attention mask (a (0, 1) range is all zeros)
inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]

##### Optional Input Arguments

# Enabled precision for TensorRT optimization
enabled_precisions = {torch.float}
# Whether to print verbose logs
debug = True
# Workspace size for TensorRT (20 << 30 bytes = 20 GiB)
workspace_size = 20 << 30
# Maximum number of TRT Engines
# (Higher value allows more graph segmentation)
max_num_trt_engines = 200


# Build and compile the model with torch.compile, using tensorrt backend
optimized_model = torch_compile(
    model,
    inputs,
    enabled_precisions=enabled_precisions,
    debug=debug,
    workspace_size=workspace_size,
    max_num_trt_engines=max_num_trt_engines,
)

# Does not cause recompilation (same batch size as input)
new_inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]
new_outputs = optimized_model(*new_inputs)
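# Note: BertModel returns a structured ModelOutput rather than a plain
# tensor; e.g., the final hidden states (shape [1, 14, 768] for
# bert-base-uncased) are available as new_outputs.last_hidden_state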


# Does cause recompilation (new batch size)
# BERT takes integer token inputs, so a new batch size means int32
# tensors with a different leading dimension
new_batch_size_inputs = [
    torch.randint(0, 2, (8, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (8, 14), dtype=torch.int32).to("cuda"),
]
new_batch_size_outputs = optimized_model(*new_batch_size_inputs)
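
# Hypothetical follow-up (an illustration, not part of this commit): a
# crude latency comparison between eager and compiled execution. Numbers
# depend entirely on hardware and are not claimed by the original scripts.
import time

def time_model(fn, args, iters=100):
    # Warm up, then synchronize around the timed region since CUDA
    # kernels launch asynchronously.
    fn(*args)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

eager_latency = time_model(model, new_inputs)
compiled_latency = time_model(optimized_model, new_inputs)
print(f"Eager: {eager_latency * 1e3:.2f} ms | Compiled: {compiled_latency * 1e3:.2f} ms")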