feat: Add example usage scripts for dynamo path
- Add sample scripts covering resnet18, transformers, and custom
examples showcasing the `torch_tensorrt.dynamo.torch_compile` path,
which can compile models with data-dependent control flow and other
constructs that make other compilation methods more difficult
- Cover the different customizable features allowed in the new backend
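For illustration, a minimal sketch (not part of the commit) of the kind of data-dependent control flow this path can handle, assuming the same "tensorrt" backend registration the examples below use:

import torch

class ConditionalModel(torch.nn.Module):
    # The branch taken depends on runtime tensor values, which prevents
    # capturing the model as a single ahead-of-time trace
    def forward(self, x: torch.Tensor):
        if x.sum() > 0:
            return torch.relu(x)
        return torch.tanh(x)

conditional_model = ConditionalModel().eval().cuda()
compiled = torch.compile(conditional_model, backend="tensorrt")
compiled(torch.rand((5, 7)).cuda())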
gs-olive committed May 5, 2023
1 parent 25db257 commit 9cbd31b
Showing 3 changed files with 147 additions and 0 deletions.
53 changes: 53 additions & 0 deletions examples/dynamo/torch_compile_advanced_usage.py
@@ -0,0 +1,53 @@
import torch
from torch_tensorrt.dynamo.torch_compile import create_backend
from torch_tensorrt.fx.lower_setting import LowerPrecision


##### Overview
# This script is intended as an overview of the process by which
# torch_tensorrt.dynamo.torch_compile works, and how it integrates
# with the new torch.compile API.

# We begin by defining a model
class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x_out = self.relu(x)
        y_out = self.relu(y)
        x_y_out = x_out + y_out
        return torch.mean(x_y_out)


##### Compilation using default settings

sample_inputs = [torch.rand((5, 7)).cuda(), torch.rand((5, 7)).cuda()]
model = Model().eval().cuda()

# Next, we compile the model using torch.compile
# For the default settings, we can simply call torch.compile
# with the backend "tensorrt", and run the model on an
# input to trigger compilation, like so:
optimized_model = torch.compile(model, backend="tensorrt")
optimized_model(*sample_inputs)
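# Subsequent calls with inputs of the same shape reuse the compiled
# engines; only a change in input shape triggers recompilation (the
# resnet example below demonstrates this)
optimized_model(*sample_inputs)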


##### Compilation using custom settings

sample_inputs_half = [
    torch.rand((5, 7)).half().cuda(),
    torch.rand((5, 7)).half().cuda(),
]
model_half = Model().half().eval().cuda()

# Alternatively, if we want to customize certain options in the backend,
# but still use the torch.compile call directly, we can call the
# convenience/helper function create_backend to create a custom backend
# which has been pre-populated with certain keyword arguments:
custom_backend = create_backend(
    lower_precision=LowerPrecision.FP16, debug=True, max_num_trt_engines=2
)
optimized_model_custom = torch.compile(model_half, backend=custom_backend)
optimized_model_custom(*sample_inputs_half)
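# If several compilation experiments are run in one process, the
# torch.compile caches can be cleared between them (a general PyTorch
# utility, not specific to this backend):
# import torch._dynamo
# torch._dynamo.reset()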
45 changes: 45 additions & 0 deletions examples/dynamo/torch_compile_resnet_example.py
@@ -0,0 +1,45 @@
import torch
from torch_tensorrt.dynamo import torch_compile
import torchvision.models as models

##### Overview
# This script is intended as a sample of the torch_tensorrt.dynamo.torch_compile
# workflow on the resnet18 model


# Initialize model and sample inputs
model = models.resnet18(pretrained=True).half().eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda").half()]

##### Optional Input Arguments

# Enabled precision for TensorRT optimization
enabled_precisions = {torch.half}
# Whether to print verbose logs
debug = True
# Workspace size for TensorRT (20 << 30 bytes = 20 GiB)
workspace_size = 20 << 30
# Maximum number of TRT Engines
# (Higher value allows more graph segmentation)
max_num_trt_engines = 100


# Build and compile the model with torch.compile, using tensorrt backend
optimized_model = torch_compile(
    model,
    inputs,
    enabled_precisions=enabled_precisions,
    debug=debug,
    workspace_size=workspace_size,
    max_num_trt_engines=max_num_trt_engines,
)


# Does not cause recompilation (same batch size as input)
new_inputs = [torch.randn((1, 3, 224, 224)).half().to("cuda")]
new_outputs = optimized_model(*new_inputs)


# Does cause recompilation (new batch size)
new_batch_size_inputs = [torch.randn((8, 3, 224, 224)).half().to("cuda")]
new_batch_size_outputs = optimized_model(*new_batch_size_inputs)
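# As a quick sanity check (a minimal sketch; the FP16 tolerances are
# illustrative, not prescribed by the library), compare against the
# uncompiled model:
with torch.no_grad():
    ref_outputs = model(*new_batch_size_inputs)
assert torch.allclose(new_batch_size_outputs, ref_outputs, atol=1e-2, rtol=1e-2)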
49 changes: 49 additions & 0 deletions examples/dynamo/torch_compile_transformers_example.py
@@ -0,0 +1,49 @@
import torch
from torch_tensorrt.dynamo import torch_compile
from transformers import BertModel

##### Overview
# This script is intended as a sample of the torch_tensorrt.dynamo.torch_compile
# workflow on the BERT base uncased model


model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),  # input_ids
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),  # attention_mask
]

##### Optional Input Arguments

# Enabled precision for TensorRT optimization
enabled_precisions = {torch.float}
# Whether to print verbose logs
debug = True
# Workspace size for TensorRT (20 << 30 bytes = 20 GiB)
workspace_size = 20 << 30
# Maximum number of TRT Engines
# (Higher value allows more graph segmentation)
max_num_trt_engines = 200


# Build and compile the model with torch.compile, using tensorrt backend
optimized_model = torch_compile(
    model,
    inputs,
    enabled_precisions=enabled_precisions,
    debug=debug,
    workspace_size=workspace_size,
    max_num_trt_engines=max_num_trt_engines,
)

# Does not cause recompilation (same batch size as input)
new_inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]
new_outputs = optimized_model(*new_inputs)


# Does cause recompilation (new batch size)
new_batch_size_inputs = [
    torch.randint(0, 2, (8, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (8, 14), dtype=torch.int32).to("cuda"),
]
new_batch_size_outputs = optimized_model(*new_batch_size_inputs)
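# BertModel returns a structured output object; per the transformers API,
# the final hidden states can be accessed as:
last_hidden = new_batch_size_outputs.last_hidden_state  # (batch, seq_len, hidden)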
