Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Wrap ExportedPrograms transformations with an API, allow dynamo.compile to accept graphmodules. #2388

Merged
merged 6 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions docsrc/user_guide/saving_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ The following code illustrates this approach.
model = MyModel().eval().cuda()
inputs = torch.randn((1, 3, 224, 224)).cuda()
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
trt_script_model = torch.jit.trace(trt_gm, inputs)
torch.jit.save(trt_script_model, "trt_model.ts")
trt_traced_model = torchtrt.dynamo.serialize(trt_gm, inputs)
torch.jit.save(trt_traced_model, "trt_model.ts")

# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
Expand All @@ -37,21 +37,19 @@ b) ExportedProgram

import torch
import torch_tensorrt
from torch_tensorrt.dynamo.export import transform, create_exported_program

model = MyModel().eval().cuda()
inputs = torch.randn((1, 3, 224, 224)).cuda()
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
# Transform and create an exported program
trt_gm = transform(trt_gm, inputs)
trt_exp_program = create_exported_program(trt_gm, call_spec, trt_gm.state_dict())
torch._export.save(trt_exp_program, "trt_model.ep")
trt_exp_program = torch_tensorrt.dynamo.serialize(trt_gm, inputs, call_spec, ir="exported_program")
peri044 marked this conversation as resolved.
Show resolved Hide resolved
torch.export.save(trt_exp_program, "trt_model.ep")

# Later, you can load it and run inference
model = torch._export.load("trt_model.ep")
model = torch.export.load("trt_model.ep")
model(inputs)

`torch_tensorrt.dynamo.export.transform` inlines the submodules within a GraphModule to their corresponding nodes and stiches all the nodes together.
`torch_tensorrt.dynamo.transform` inlines the submodules within a GraphModule to their corresponding nodes, stiches all the nodes together and creates an ExportedProgram.
This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).

NOTE: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue : https://github.com/pytorch/TensorRT/issues/2341
Expand Down
5 changes: 2 additions & 3 deletions py/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
numpy
packaging
pybind11==2.6.2
--extra-index-url https://download.pytorch.org/whl/nightly/cu121
torch>=2.1.0,<2.2.0
torchvision>=0.16.0,<0.17.0
torch==2.1.0
torchvision==0.16.0
--extra-index-url https://pypi.ngc.nvidia.com
tensorrt==8.6.1
pyyaml
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
DYNAMO_CONVERTERS,
dynamo_tensorrt_converter,
)
from .export import serialize
12 changes: 10 additions & 2 deletions py/torch_tensorrt/dynamo/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@


def compile(
exported_program: ExportedProgram,
exported_program: Union[torch.fx.GraphModule, ExportedProgram],
inputs: Any,
*,
device: Optional[Union[Device, torch.device, str]] = DEVICE,
Expand Down Expand Up @@ -86,7 +86,15 @@ def compile(
inputs = prepare_inputs(inputs)
device = to_torch_tensorrt_device(device)

gm = exported_program.module()
if isinstance(exported_program, torch.fx.GraphModule):
gm = exported_program
elif isinstance(exported_program, ExportedProgram):
gm = exported_program.module()
else:
raise AssertionError(
f"Input graph should either be an ExportedProgram or a GraphModule but got type {type(exported_program)}"
)

logger.debug("Input graph: " + str(gm.graph))

# Apply lowering on the graph module
Expand Down
67 changes: 39 additions & 28 deletions py/torch_tensorrt/dynamo/export.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import copy
import operator
from typing import Any, Dict, Sequence, Tuple, Union, cast
from typing import Any, Dict, Sequence, Tuple, cast

import torch
from torch._export.exported_program import CallSpec
Expand All @@ -10,28 +10,42 @@
from torch_tensorrt.dynamo import partitioning


def transform(
gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
# Run shape analysis
_, outputs_map = partitioning.run_shape_analysis(gm, inputs)

# Inline TensorRT submodules
inline_trt_modules(gm, outputs_map)

# Inline pytorch submodules
inline_torch_modules(gm)

# Lift constant buffers and parameters in the graph
# torch.export serialization expects them to be lifted
lift_constant_pass(gm)

# Clean the graph
gm.delete_all_unused_submodules()
gm.graph.eliminate_dead_code()
gm.graph.lint()

return gm
def serialize(
gm: torch.fx.GraphModule,
inputs: Sequence[torch.Tensor],
call_spec: CallSpec = None,
ir: str = "torchscript",
) -> ExportedProgram:
if ir == "torchscript":
return torch.jit.trace(gm, inputs)
elif ir == "exported_program":
assert call_spec
# Run shape analysis
_, outputs_map = partitioning.run_shape_analysis(gm, inputs)

# Inline TensorRT submodules
inline_trt_modules(gm, outputs_map)

# Inline pytorch submodules
inline_torch_modules(gm)

# Lift constant buffers and parameters in the graph
# torch.export serialization expects them to be lifted
lift_constant_pass(gm)

# Clean the graph
gm.delete_all_unused_submodules()
gm.graph.eliminate_dead_code()
gm.graph.lint()

# Create an exported program with the TRT GraphModule
exp_program = create_trt_exp_program(gm, call_spec)

return exp_program
else:
raise ValueError(
"Invalid ir : {ir} provided for serialization. Options include torchscript | exported_program"
)


def lift_constant_pass(trt_gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
Expand Down Expand Up @@ -115,7 +129,6 @@ def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:

# Copy all nodes in the submodule into gm and
# store the output node of this submodule which is now present in gm

submodule_output = gm.graph.graph_copy(submodule.graph, val_map)

# Get their references (since we copied) in the parent graph (gm)
Expand Down Expand Up @@ -174,9 +187,7 @@ def copy_submodule_attributes(


def create_trt_exp_program(
gm: torch.fx.GraphModule,
call_spec: CallSpec,
state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
gm: torch.fx.GraphModule, call_spec: CallSpec
) -> ExportedProgram:
"""Creates a new Exported Program. This function takes an torch.fx.GraphModule which has TRT engines
and constructs an Exported Program object with the new IO node names, call_spec and state_dict
Expand Down Expand Up @@ -208,7 +219,7 @@ def create_trt_exp_program(
)

trt_exp_program = ExportedProgram(
gm, gm.graph, trt_graph_signature, call_spec, state_dict, {}, [], []
gm, gm.graph, trt_graph_signature, call_spec, gm.state_dict(), {}, [], []
)

return trt_exp_program
Expand Down
85 changes: 10 additions & 75 deletions tests/py/dynamo/models/test_export_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import torch_tensorrt as torchtrt
import torchvision.models as models
from torch._export.serde.serialize import deserialize, serialize
from torch_tensorrt.dynamo.export import create_trt_exp_program, transform
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

assertions = unittest.TestCase()
Expand Down Expand Up @@ -45,9 +44,8 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_gm = transform(trt_gm, [input])
trt_exp_program = create_trt_exp_program(
trt_gm, exp_program.call_spec, trt_gm.state_dict()
trt_exp_program = torchtrt.dynamo.serialize(
trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
)
serialized_prog = serialize(trt_exp_program)
deserialized_prog = deserialize(*serialized_prog)
Expand Down Expand Up @@ -100,11 +98,9 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_gm = transform(trt_gm, [input])
trt_exp_program = create_trt_exp_program(
trt_gm, exp_program.call_spec, trt_gm.state_dict()
trt_exp_program = torchtrt.dynamo.serialize(
trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
)

serialized_prog = serialize(trt_exp_program)
deserialized_prog = deserialize(*serialized_prog)
# Check Pyt and TRT exported program outputs
Expand Down Expand Up @@ -161,11 +157,9 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_gm = transform(trt_gm, [input])
trt_exp_program = create_trt_exp_program(
trt_gm, exp_program.call_spec, trt_gm.state_dict()
trt_exp_program = torchtrt.dynamo.serialize(
trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
)

torch._export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch._export.load("/tmp/trt.ep")

Expand Down Expand Up @@ -224,11 +218,9 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_gm = transform(trt_gm, [input])
trt_exp_program = create_trt_exp_program(
trt_gm, exp_program.call_spec, trt_gm.state_dict()
trt_exp_program = torchtrt.dynamo.serialize(
trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
)

torch._export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch._export.load("/tmp/trt.ep")

Expand Down Expand Up @@ -270,9 +262,8 @@ def test_resnet18_save_load(ir):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_gm = transform(trt_gm, [input])
trt_exp_program = create_trt_exp_program(
trt_gm, exp_program.call_spec, trt_gm.state_dict()
trt_exp_program = torchtrt.dynamo.serialize(
trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
)
torch._export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
Expand All @@ -291,59 +282,3 @@ def test_resnet18_save_load(ir):
cos_sim > COSINE_THRESHOLD,
msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


# Enable this test once this issue is resolved https://github.com/pytorch/TensorRT/issues/2341
# @pytest.mark.unit
# def test_hybrid_conv_fallback(ir):
# """
# This tests export save and load functionality on a hybrid
# model where a conv (a weighted layer) has been forced to fallback to Pytorch.
# """

# class MyModule(torch.nn.Module):
# def __init__(self):
# super().__init__()
# self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
# self.relu = torch.nn.ReLU()

# def forward(self, x):
# conv = self.conv(x)
# relu = self.relu(conv)
# mul = relu * 0.5
# return mul

# model = MyModule().eval().cuda()
# input = torch.randn((1, 3, 224, 224)).to("cuda")

# compile_spec = {
# "inputs": [
# torchtrt.Input(
# input.shape, dtype=torch.float, format=torch.contiguous_format
# )
# ],
# "ir": ir,
# "min_block_size": 1,
# "torch_executed_ops": "torch.ops.aten.convolution.default",
# }

# trt_exp_program = torchtrt.compile(model, **compile_spec)
# torch._export.save(trt_exp_program, "/tmp/trt.ep")
# deser_trt_exp_program = torch._export.load("/tmp/trt.ep")

# outputs_pyt = model(input)
# outputs_trt = trt_exp_program(input)
# for idx in range(len(outputs_pyt)):
# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
# assertions.assertTrue(
# cos_sim > COSINE_THRESHOLD,
# msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
# )

# outputs_trt_deser = deser_trt_exp_program(input)
# for idx in range(len(outputs_pyt)):
# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
# assertions.assertTrue(
# cos_sim > COSINE_THRESHOLD,
# msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
# )