fix: Add support for truncate_long_and_double in Dynamo [8 / x] (#1983)
gs-olive authored Jul 24, 2023
1 parent a2d61a2 commit a765b72
Showing 7 changed files with 337 additions and 2 deletions.
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -9,3 +9,4 @@
VERSION_COMPATIBLE = False
OPTIMIZATION_LEVEL = None
USE_PYTHON_RUNTIME = None
TRUNCATE_LONG_AND_DOUBLE = False
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -11,6 +11,7 @@
VERSION_COMPATIBLE,
OPTIMIZATION_LEVEL,
USE_PYTHON_RUNTIME,
TRUNCATE_LONG_AND_DOUBLE,
)


@@ -26,3 +27,4 @@ class CompilationSettings:
version_compatible: bool = VERSION_COMPATIBLE
optimization_level: Optional[int] = OPTIMIZATION_LEVEL
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
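
For illustration only (not part of this diff), the new field can be toggled when constructing the settings dataclass; a minimal sketch, assuming the import path shown above:

from torch_tensorrt.dynamo._settings import CompilationSettings

# Override the False default from _defaults.py to enable 64-bit input repair
settings = CompilationSettings(truncate_long_and_double=True)
assert settings.truncate_long_and_double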
11 changes: 10 additions & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
@@ -16,7 +16,10 @@
get_submod_inputs,
)
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
from torch_tensorrt.dynamo.conversion import convert_module
from torch_tensorrt.dynamo.conversion import (
convert_module,
repair_long_or_double_inputs,
)

from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler

@@ -135,6 +138,12 @@ def _compile_module(
partitioned_module, submodule, sample_inputs
)

# Handle long/double inputs if requested by the user
if settings.truncate_long_and_double:
submodule_inputs = repair_long_or_double_inputs(
partitioned_module, submodule, submodule_inputs, name
)

# Create TRT Module from submodule
trt_mod = convert_module(
submodule,
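The repair step gated above rewrites the parent graph with aten._to_copy casts around the TRT submodule. As a standalone illustration (not part of this diff) of the downcast the pass inserts, the eager-mode equivalent is:

import torch

# Eager-mode equivalent of the cast node inserted before a TRT submodule
x = torch.randint(-5, 5, (4,), dtype=torch.int64)
x32 = torch.ops.aten._to_copy.default(x, dtype=torch.int32)
assert x32.dtype == torch.int32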
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/compile.py
@@ -30,6 +30,7 @@
VERSION_COMPATIBLE,
OPTIMIZATION_LEVEL,
USE_PYTHON_RUNTIME,
TRUNCATE_LONG_AND_DOUBLE,
)


@@ -53,7 +54,7 @@ def compile(
dla_local_dram_size=1073741824,
dla_global_dram_size=536870912,
calibrator=None,
truncate_long_and_double=False,
truncate_long_and_double=TRUNCATE_LONG_AND_DOUBLE,
require_full_compilation=False,
min_block_size=MIN_BLOCK_SIZE,
torch_executed_ops=[],
@@ -109,6 +110,7 @@ def compile(
"version_compatible": version_compatible,
"optimization_level": optimization_level,
"use_python_runtime": use_python_runtime,
"truncate_long_and_double": truncate_long_and_double,
}

settings = CompilationSettings(**compilation_options)
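End to end, the option is threaded through the user-facing API the same way the tests below exercise it; a hedged sketch (the Linear model here is illustrative, and a CUDA device with TensorRT is assumed):

import torch
import torch_tensorrt

model = torch.nn.Linear(8, 8).double().cuda()
inputs = [torch.randn(4, 8, dtype=torch.double).cuda()]

# float64 tensors are downcast to float32 for TRT engines when enabled
optimized = torch_tensorrt.compile(
    model,
    "torch_compile",
    inputs,
    min_block_size=1,
    truncate_long_and_double=True,
)
out = optimized(*inputs)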
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/__init__.py
@@ -1,2 +1,3 @@
from .trt_interpreter import *
from .conversion import *
from .truncate_long_and_double import repair_long_or_double_inputs
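
With this re-export, the helper is importable directly from the conversion package, matching the import added in backends.py above:

from torch_tensorrt.dynamo.conversion import repair_long_or_double_inputs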
207 changes: 207 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py
@@ -0,0 +1,207 @@
import torch
from torch.fx.node import _get_qualified_name
from typing import Optional, Sequence, Union


def _extract_downstream_get_nodes(
module_node: torch.fx.Node, output_indices: Sequence[int]
) -> Sequence[torch.fx.Node]:
"""Extracts downstream users of a node which get the item at a particular index
Certain module-type nodes have multiple outputs (tuple of outputs). This function
returns downstream nodes which call the _operator.getitem function, which extracts
the element at a particular index in the tuple
Args:
module_node: FX module-type node to analyze
        output_indices: Indices in the module node output to search for
Returns:
List of nodes which get the item at the specified index in the module node output
"""
get_nodes = []

# Iterate over all downstream users of the node object
for user in module_node.users:
# If the user is a "get" node accessing the specified index, store it
if _get_qualified_name(user.target) == "_operator.getitem" and (
user.args[1] in output_indices
):
get_nodes.append(user)

return get_nodes


def _repair_64bit_input(
gm: torch.fx.GraphModule,
position: int,
submodule_name: str,
submodule_outputs: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]],
dtype: torch.dtype,
):
"""Fixes a single Long/Double input to a TRT-accelerated subgraph
In-Place modifies the provided graph
Inserts a cast to the 32-bit equivalent type for TRT, then if necessary,
inserts an upcast back to the 64-bit type for subsequent Torch operations
Args:
gm: FX GraphModule enclosing the TRT subgraph
position: Index in the submodule inputs at which the long or double input is found
submodule_name: Name of TRT-accelerated subgraph module in FX graph
submodule_outputs: Output tensor(s) of TRT-accelerated subgraph (used for dtypes/structure)
dtype: Data type of tensor at position in submodule (double/long)
"""
assert dtype in (
torch.int64,
torch.float64,
), f"dtype argument must be torch.int64 or torch.float64, got {dtype}"

# Determine target data type in 32 and 64 bit forms
dtype_64bit = dtype
dtype_32bit = torch.int32 if (dtype == torch.int64) else torch.float32

# Find the node representing the submodule in the graph
module_node = None

# Iterate over all nodes in the graph, seeking target module name match
for n in gm.graph.nodes:
if n.op == "call_module" and str(n.target) == submodule_name:
module_node = n
break

if module_node is None:
raise AssertionError(
f"Sought module node {submodule_name}, could not find in graph:\n{gm.graph}"
)

# Extract the 64-bit node of the input
node_64bit = module_node.all_input_nodes[position]

# Prior to the module, insert a cast to the 32-bit equivalent node
with gm.graph.inserting_before(module_node):
node_32bit = gm.graph.call_function(
torch.ops.aten._to_copy.default,
args=(node_64bit,),
kwargs={"dtype": dtype_32bit},
)

# Replace 64-bit input to TRT module with new 32-bit cast node
module_node.replace_input_with(node_64bit, node_32bit)

output_positions_64bit = set()
outputs_list = (
[submodule_outputs]
if isinstance(submodule_outputs, torch.Tensor)
else submodule_outputs
)

# Determine if any outputs of the model are 64-bit type and store their indices
if submodule_outputs is not None:
for output_position, output in enumerate(outputs_list):
if output.dtype == dtype_64bit:
output_positions_64bit.add(output_position)

# Only enter this code block if there exists a 64-bit output
# This implies a cast is needed, since TRT cannot output 64-bit tensors
if output_positions_64bit:
        # Determine whether the outputs of the module are tuple-type or not
is_collection_output = False
if isinstance(submodule_outputs, tuple):
is_collection_output = True

if not is_collection_output:
# If the output is a single tensor, insert a cast back to int64
with gm.graph.inserting_after(module_node):
cast_node_64bit = gm.graph.call_function(
torch.ops.aten._to_copy.default,
args=(module_node,),
kwargs={"dtype": dtype_64bit},
)

# Replace all uses of the TRT module (except the cast node) with the 64-bit equivalent
module_node.replace_all_uses_with(
cast_node_64bit, delete_user_cb=lambda user: (user != cast_node_64bit)
)

else:
# If the output is a tuple of tensors, extract downstream users for each 64-bit output
get_nodes = _extract_downstream_get_nodes(
module_node, output_positions_64bit
)

# For each downstream user, append a cast node back to the 64-bit precision
for get_node in get_nodes:
with gm.graph.inserting_after(get_node):
cast_node_64bit = gm.graph.call_function(
torch.ops.aten._to_copy.default,
args=(get_node,),
kwargs={"dtype": torch.int64},
)

get_node.replace_all_uses_with(
cast_node_64bit,
delete_user_cb=lambda user: (user != cast_node_64bit),
)

# Clean up graph and ensure invariants are preserved
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()


def repair_long_or_double_inputs(
parent_graph: torch.fx.GraphModule,
submodule: torch.fx.GraphModule,
submodule_inputs: Sequence[torch.Tensor],
submodule_name: Optional[str] = None,
) -> Sequence[torch.Tensor]:
"""Fixes all Long/Double type inputs to a TRT-accelerated subgraph
In-Place modifies the provided graph
Inserts a cast to the 32-bit equivalent type for TRT, then if necessary,
inserts an upcast back to the 64-bit type for subsequent Torch operations
Args:
parent_graph: FX GraphModule enclosing the TRT subgraph
submodule: Child submodule to repair inputs on
submodule_inputs: Input tensor(s) of TRT-accelerated subgraph (used for dtypes/structure)
submodule_name: Optionally specify the name of the submodule target in the parent graph
Returns:
New submodule inputs, updated accordingly with long/double truncation
"""
num_submodule_inputs = len(submodule_inputs)
repaired_outputs_once = False

# For each input to the TRT subgraph, check if its type is long/double
for position in range(num_submodule_inputs):
param = submodule_inputs[position]

# If the data type of the input is long/double, insert necessary
# casts to replace the operation
if param.dtype in (torch.int64, torch.float64):
# Ensure outputs are only repaired once per submodule to avoid
# unnecessary ops showing up in the graph
if not repaired_outputs_once:
submodule_outputs = submodule(*submodule_inputs)

_repair_64bit_input(
parent_graph,
position,
submodule_name if submodule_name is not None else submodule._get_name(),
None if repaired_outputs_once else submodule_outputs,
param.dtype,
)

repaired_outputs_once = True

# Repair submodule inputs in accordance with inserted casts
dtype_32bit = torch.int32 if (param.dtype == torch.int64) else torch.float32
submodule_inputs = (
submodule_inputs[:position]
+ (param.to(dtype_32bit),)
+ submodule_inputs[position + 1 :]
)

return submodule_inputs
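
A minimal, hypothetical exercise of the pass on a toy FX graph (the module names, shapes, and dtypes here are illustrative, not from this commit):

import torch
from torch_tensorrt.dynamo.conversion import repair_long_or_double_inputs

class Sub(torch.nn.Module):
    def forward(self, x):
        return x + 1  # int64 in, int64 out before repair

class Parent(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, x):
        return self.sub(x)

gm = torch.fx.symbolic_trace(Parent())
inputs = (torch.arange(4, dtype=torch.int64),)

# Rewrites gm in place: casts x to int32 before "sub", back to int64 after,
# and returns int32-typed replacement inputs
new_inputs = repair_long_or_double_inputs(gm, gm.sub, inputs, "sub")
assert new_inputs[0].dtype == torch.int32
assert gm(*new_inputs).dtype == torch.int64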
113 changes: 113 additions & 0 deletions tests/py/dynamo/backend/test_backend_compiler.py
@@ -171,5 +171,118 @@ def forward(self, x, y):
)


class Test64BitInput(TestCase):
def test_float64_input_full_support(self):
class FullySupportedMultiOp(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.mean.dim(
torch.ops.aten.mul.Tensor(torch.ops.aten.add.Tensor(x, y), 2), [0]
)

fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3)

        self.assertEqual(
len(list(partitioned_graph.named_children())),
1,
"All operators are supported, there should be one segment",
)

inputs = [
torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
]

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
truncate_long_and_double=True,
debug=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"TRT outputs don't match with the original model.",
)

def test_int64_input_partial_support(self):
class PartiallySupportedMultiOp(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.div.Tensor_mode(
x, torch.ops.aten.add.Tensor(y, y), rounding_mode="floor"
)

fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
unexpected_ops = {torch.ops.aten.add.Tensor}

inputs = [
torch.randint(-40, 40, (16, 7, 5), dtype=torch.long).cuda(),
torch.randint(1, 40, (16, 7, 5), dtype=torch.long).cuda(),
]

(unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing(
fx_graph,
inputs,
unexpected_ops=unexpected_ops,
min_block_size=1,
torch_executed_ops={"torch.ops.aten.add.Tensor"},
testing_partitioning=True,
)

        self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)
        self.assertEqual(
len(partitioned_graphs),
1,
"Without control flow breaks, there should only be a single graph",
)
        self.assertEqual(
len(list(partitioned_graphs[0].named_children())),
1,
"Certain operators are set to run in Torch, expected 1 segment",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
truncate_long_and_double=True,
debug=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"TRT outputs don't match with the original model.",
)


if __name__ == "__main__":
run_tests()
