Skip to content

Commit

Permalink
chore: refactor/complete export function
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
  • Loading branch information
peri044 committed Oct 30, 2023
1 parent 14dbbfd commit 8f4ff5d
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 93 deletions.
70 changes: 36 additions & 34 deletions py/torch_tensorrt/dynamo/_exporter.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
import copy
import operator
from typing import Any, Dict, Sequence, Tuple, Union, cast
from typing import Any, Dict, Sequence, Tuple, cast

import torch
from torch._export.exported_program import CallSpec
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensor
from torch.export import ExportedProgram, ExportGraphSignature
from torch.export.exported_program import (
InputKind,
InputSpec,
OutputKind,
OutputSpec,
TensorArgument,
)
from torch_tensorrt.dynamo import partitioning


# TODO: @peri044: Correct this implementation
def export(
src_gm: torch.fx.GraphModule,
trt_gm: torch.fx.GraphModule,
gm: torch.fx.GraphModule,
inputs: Sequence[torch.Tensor],
*,
ir: str = "torchscript",
) -> ExportedProgram:
"""Export a program (``torch.fx.GraphModule``) for serialization with the TensorRT engines embedded.
Expand All @@ -39,12 +45,20 @@ def export(
format=torch.channel_last
), # Dynamic input shape for input #2
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
call_spec (CallSpec): CallSpec object. If ir=torchscript, set it to None. For ir=exported_program, the
ir (str): torchscript | exported_program. Based on the provided ir, the output type would be a torchscript or exported program.
"""

patched_module = transform(torch.fx.GraphModule, inputs)

return create_trt_exp_program(patched_module, src_gm.call_spec, src_gm.state_dict)
if ir == "torchscript":
return torch.jit.trace(gm, inputs)
elif ir == "exported_program":
patched_module = transform(gm, inputs)
exp_program = create_trt_exp_program(patched_module)

return exp_program
else:
raise ValueError(
"Invalid ir : {ir} provided for serialization. Options include torchscript | exported_program"
)


def transform(
Expand Down Expand Up @@ -212,40 +226,28 @@ def copy_submodule_attributes(

def create_trt_exp_program(
gm: torch.fx.GraphModule,
call_spec: CallSpec,
state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
) -> ExportedProgram:
"""Creates a new Exported Program. This function takes an torch.fx.GraphModule which has TRT engines
and constructs an Exported Program object with the new IO node names, call_spec and state_dict
"""
input_node_names = [
node.name for node in gm.graph.nodes if node.op == "placeholder"
input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
output_nodes = [node for node in gm.graph.nodes if node.op == "output"]

input_specs = [
InputSpec(InputKind.USER_INPUT, TensorArgument(name=node.name), node.target)
for node in input_nodes
]
output_specs = [
OutputSpec(OutputKind.USER_OUTPUT, TensorArgument(name=node.name), node.target)
for node in output_nodes
]
output_node_names = [node.name for node in gm.graph.nodes if node.op == "output"]
param_names = [param[0] for param in gm.named_parameters()]
buffer_names = [buffer[0] for buffer in gm.named_buffers()]
inputs_to_parameters = {}
inputs_to_buffers = {}
for node in gm.graph.nodes:
if node.target in param_names:
inputs_to_parameters[node.name] = node.target
if node.target in buffer_names:
inputs_to_buffers[node.name] = node.target

trt_graph_signature = ExportGraphSignature(
parameters=param_names,
buffers=buffer_names,
user_inputs=input_node_names,
user_outputs=output_node_names,
inputs_to_parameters=inputs_to_parameters,
inputs_to_buffers=inputs_to_buffers,
buffers_to_mutate={},
backward_signature=None,
assertion_dep_token=None,
input_specs=input_specs, output_specs=output_specs
)

trt_exp_program = ExportedProgram(
gm, gm.graph, trt_graph_signature, call_spec, state_dict, {}, [], []
gm, gm.graph, trt_graph_signature, gm.state_dict(), {}, [], [], []
)

return trt_exp_program
Expand Down
100 changes: 41 additions & 59 deletions tests/py/dynamo/models/test_export_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import torch_tensorrt as torchtrt
import torchvision.models as models
from torch._export.serde.serialize import deserialize, serialize
from torch_tensorrt.dynamo.export import create_trt_exp_program, transform
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

assertions = unittest.TestCase()
Expand Down Expand Up @@ -45,21 +44,18 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_gm = transform(trt_gm, [input])
trt_exp_program = create_trt_exp_program(
trt_gm, exp_program.call_spec, trt_gm.state_dict()
)
trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")
serialized_prog = serialize(trt_exp_program)
deserialized_prog = deserialize(*serialized_prog)

# Check Pyt and TRT exported program outputs
cos_sim = cosine_similarity(model(input), trt_exp_program(input))
cos_sim = cosine_similarity(model(input), trt_exp_program(input)[0])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)
# Check Pyt and deserialized TRT exported program outputs
cos_sim = cosine_similarity(model(input), deserialized_prog(input))
cos_sim = cosine_similarity(model(input), deserialized_prog(input)[0])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
Expand Down Expand Up @@ -100,11 +96,7 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_gm = transform(trt_gm, [input])
trt_exp_program = create_trt_exp_program(
trt_gm, exp_program.call_spec, trt_gm.state_dict()
)

trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")
serialized_prog = serialize(trt_exp_program)
deserialized_prog = deserialize(*serialized_prog)
# Check Pyt and TRT exported program outputs
Expand Down Expand Up @@ -161,11 +153,7 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_gm = transform(trt_gm, [input])
trt_exp_program = create_trt_exp_program(
trt_gm, exp_program.call_spec, trt_gm.state_dict()
)

trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")
torch._export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch._export.load("/tmp/trt.ep")

Expand Down Expand Up @@ -224,11 +212,7 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_gm = transform(trt_gm, [input])
trt_exp_program = create_trt_exp_program(
trt_gm, exp_program.call_spec, trt_gm.state_dict()
)

trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")
torch._export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch._export.load("/tmp/trt.ep")

Expand All @@ -250,47 +234,45 @@ def forward(self, x):
)


@pytest.mark.unit
def test_resnet18_save_load(ir):
"""
This tests export save and load functionality on Resnet18 model
"""
model = models.resnet18().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")
# TODO (peri044) : Enable this test once the _frozen_param0 attribute resulting in sym_int ops issue is fixed.
# @pytest.mark.unit
# def test_resnet18_save_load(ir):
# """
# This tests export save and load functionality on Resnet18 model
# """
# model = models.resnet18().eval().cuda()
# input = torch.randn((1, 3, 224, 224)).to("cuda")

compile_spec = {
"inputs": [
torchtrt.Input(
input.shape, dtype=torch.float, format=torch.contiguous_format
)
],
"ir": ir,
"min_block_size": 1,
}
# compile_spec = {
# "inputs": [
# torchtrt.Input(
# input.shape, dtype=torch.float, format=torch.contiguous_format
# )
# ],
# "ir": ir,
# "min_block_size": 1,
# }

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_gm = transform(trt_gm, [input])
trt_exp_program = create_trt_exp_program(
trt_gm, exp_program.call_spec, trt_gm.state_dict()
)
torch._export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
# exp_program = torchtrt.dynamo.trace(model, **compile_spec)
# trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
# trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")
# torch._export.save(trt_exp_program, "/tmp/trt.ep")
# deser_trt_exp_program = torch._export.load("/tmp/trt.ep")

outputs_pyt = model(input)
outputs_trt = trt_exp_program(input)
cos_sim = cosine_similarity(outputs_pyt, outputs_trt)
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)
# outputs_pyt = model(input)
# outputs_trt = trt_exp_program(input)
# cos_sim = cosine_similarity(outputs_pyt, outputs_trt)
# assertions.assertTrue(
# cos_sim > COSINE_THRESHOLD,
# msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
# )

outputs_trt_deser = deser_trt_exp_program(input)
cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser)
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)
# outputs_trt_deser = deser_trt_exp_program(input)
# cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser)
# assertions.assertTrue(
# cos_sim > COSINE_THRESHOLD,
# msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
# )


# Enable this test once this issue is resolved https://github.com/pytorch/TensorRT/issues/2341
Expand Down

0 comments on commit 8f4ff5d

Please sign in to comment.