diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index e55c592f4b..88b2a1e8b2 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -9,3 +9,4 @@
 VERSION_COMPATIBLE = False
 OPTIMIZATION_LEVEL = None
 USE_PYTHON_RUNTIME = None
+TRUNCATE_LONG_AND_DOUBLE = False
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index 85a2693606..99ec34ec27 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -11,6 +11,7 @@
     VERSION_COMPATIBLE,
     OPTIMIZATION_LEVEL,
     USE_PYTHON_RUNTIME,
+    TRUNCATE_LONG_AND_DOUBLE,
 )
@@ -26,3 +27,4 @@ class CompilationSettings:
     version_compatible: bool = VERSION_COMPATIBLE
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
+    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 01827af13a..a89999b930 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -16,7 +16,10 @@
     get_submod_inputs,
 )
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
-from torch_tensorrt.dynamo.conversion import convert_module
+from torch_tensorrt.dynamo.conversion import (
+    convert_module,
+    repair_long_or_double_inputs,
+)
 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
@@ -135,6 +138,12 @@ def _compile_module(
             partitioned_module, submodule, sample_inputs
         )
 
+        # Handle long/double inputs if requested by the user
+        if settings.truncate_long_and_double:
+            submodule_inputs = repair_long_or_double_inputs(
+                partitioned_module, submodule, submodule_inputs, name
+            )
+
         # Create TRT Module from submodule
         trt_mod = convert_module(
             submodule,
diff --git a/py/torch_tensorrt/dynamo/compile.py b/py/torch_tensorrt/dynamo/compile.py
index 5528fe88b1..b27dcb45ee 100644
--- a/py/torch_tensorrt/dynamo/compile.py
+++ b/py/torch_tensorrt/dynamo/compile.py
@@ -30,6 +30,7 @@
     VERSION_COMPATIBLE,
     OPTIMIZATION_LEVEL,
     USE_PYTHON_RUNTIME,
+    TRUNCATE_LONG_AND_DOUBLE,
 )
@@ -53,7 +54,7 @@ def compile(
     dla_local_dram_size=1073741824,
     dla_global_dram_size=536870912,
     calibrator=None,
-    truncate_long_and_double=False,
+    truncate_long_and_double=TRUNCATE_LONG_AND_DOUBLE,
     require_full_compilation=False,
     min_block_size=MIN_BLOCK_SIZE,
     torch_executed_ops=[],
@@ -109,6 +110,7 @@ def compile(
         "version_compatible": version_compatible,
         "optimization_level": optimization_level,
         "use_python_runtime": use_python_runtime,
+        "truncate_long_and_double": truncate_long_and_double,
     }
     settings = CompilationSettings(**compilation_options)
diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py
index 56c2361f13..f50b22f27d 100644
--- a/py/torch_tensorrt/dynamo/conversion/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/__init__.py
@@ -1,2 +1,3 @@
 from .trt_interpreter import *
 from .conversion import *
+from .truncate_long_and_double import repair_long_or_double_inputs
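For context on the plumbing above: the pre-existing `truncate_long_and_double` option is now threaded from the `torch_tensorrt.compile` entrypoint into the dynamo `CompilationSettings`. Below is a minimal sketch of end-to-end usage, mirroring the tests added at the bottom of this diff; the toy module, shapes, and values are illustrative, and a CUDA device with TensorRT installed is assumed.

    import torch
    import torch_tensorrt

    class AddMul(torch.nn.Module):
        def forward(self, x, y):
            # float64 arithmetic, which TRT cannot ingest natively
            return (x + y) * 2.0

    inputs = [
        torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
        torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
    ]

    # truncate_long_and_double=True allows the 64-bit inputs by casting
    # them to 32-bit equivalents at the TRT subgraph boundary
    optimized = torch_tensorrt.compile(
        torch.fx.symbolic_trace(AddMul()),
        "torch_compile",
        inputs,
        min_block_size=1,
        truncate_long_and_double=True,
    )
    print(optimized(*inputs).dtype)  # float64 restored on the Torch side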
diff --git a/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py b/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py
new file mode 100644
index 0000000000..fc3263de57
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py
@@ -0,0 +1,207 @@
+import torch
+from torch.fx.node import _get_qualified_name
+from typing import Optional, Sequence, Union
+
+
+def _extract_downstream_get_nodes(
+    module_node: torch.fx.Node, output_indices: Sequence[int]
+) -> Sequence[torch.fx.Node]:
+    """Extracts downstream users of a node which get the item at a particular index
+
+    Certain module-type nodes have multiple outputs (tuple of outputs). This function
+    returns downstream nodes which call the _operator.getitem function, which extracts
+    the element at a particular index in the tuple
+
+    Args:
+        module_node: FX module-type node to analyze
+        output_indices: Indices in the module node output to search for
+    Returns:
+        List of nodes which get the item at the specified index in the module node output
+    """
+    get_nodes = []
+
+    # Iterate over all downstream users of the node object
+    for user in module_node.users:
+        # If the user is a "get" node accessing the specified index, store it
+        if _get_qualified_name(user.target) == "_operator.getitem" and (
+            user.args[1] in output_indices
+        ):
+            get_nodes.append(user)
+
+    return get_nodes
+
+
+def _repair_64bit_input(
+    gm: torch.fx.GraphModule,
+    position: int,
+    submodule_name: str,
+    submodule_outputs: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]],
+    dtype: torch.dtype,
+):
+    """Fixes a single Long/Double input to a TRT-accelerated subgraph
+
+    Modifies the provided graph in place
+
+    Inserts a cast to the 32-bit equivalent type for TRT, then, if necessary,
+    inserts an upcast back to the 64-bit type for subsequent Torch operations
+
+    Args:
+        gm: FX GraphModule enclosing the TRT subgraph
+        position: Index in the submodule inputs at which the long or double input is found
+        submodule_name: Name of TRT-accelerated subgraph module in FX graph
+        submodule_outputs: Output tensor(s) of TRT-accelerated subgraph (used for dtypes/structure)
+        dtype: Data type of tensor at position in submodule (double/long)
+    """
+    assert dtype in (
+        torch.int64,
+        torch.float64,
+    ), f"dtype argument must be torch.int64 or torch.float64, got {dtype}"
+
+    # Determine target data type in 32-bit and 64-bit forms
+    dtype_64bit = dtype
+    dtype_32bit = torch.int32 if (dtype == torch.int64) else torch.float32
+
+    # Find the node representing the submodule in the graph
+    module_node = None
+
+    # Iterate over all nodes in the graph, seeking target module name match
+    for n in gm.graph.nodes:
+        if n.op == "call_module" and str(n.target) == submodule_name:
+            module_node = n
+            break
+
+    if module_node is None:
+        raise AssertionError(
+            f"Could not find module node {submodule_name} in graph:\n{gm.graph}"
+        )
+
+    # Extract the 64-bit node of the input
+    node_64bit = module_node.all_input_nodes[position]
+
+    # Prior to the module, insert a cast to the 32-bit equivalent node
+    with gm.graph.inserting_before(module_node):
+        node_32bit = gm.graph.call_function(
+            torch.ops.aten._to_copy.default,
+            args=(node_64bit,),
+            kwargs={"dtype": dtype_32bit},
+        )
+
+    # Replace 64-bit input to TRT module with new 32-bit cast node
+    module_node.replace_input_with(node_64bit, node_32bit)
+
+    output_positions_64bit = set()
+    outputs_list = (
+        [submodule_outputs]
+        if isinstance(submodule_outputs, torch.Tensor)
+        else submodule_outputs
+    )
+
+    # Determine if any outputs of the module are 64-bit type and store their indices
+    if submodule_outputs is not None:
+        for output_position, output in enumerate(outputs_list):
+            if output.dtype == dtype_64bit:
+                output_positions_64bit.add(output_position)
+
+    # Only enter this code block if there exists a 64-bit output
+    # This implies a cast is needed, since TRT cannot output 64-bit tensors
+    if output_positions_64bit:
+        # Determine whether the outputs of the module are tuple-type or not
+        is_collection_output = False
+        if isinstance(submodule_outputs, tuple):
+            is_collection_output = True
+
+        if not is_collection_output:
+            # If the output is a single tensor, insert a cast back to the 64-bit type
+            with gm.graph.inserting_after(module_node):
+                cast_node_64bit = gm.graph.call_function(
+                    torch.ops.aten._to_copy.default,
+                    args=(module_node,),
+                    kwargs={"dtype": dtype_64bit},
+                )
+
+            # Replace all uses of the TRT module (except the cast node) with the 64-bit equivalent
+            module_node.replace_all_uses_with(
+                cast_node_64bit, delete_user_cb=lambda user: (user != cast_node_64bit)
+            )
+
+        else:
+            # If the output is a tuple of tensors, extract downstream users for each 64-bit output
+            get_nodes = _extract_downstream_get_nodes(
+                module_node, output_positions_64bit
+            )
+
+            # For each downstream user, append a cast node back to the 64-bit precision
+            for get_node in get_nodes:
+                with gm.graph.inserting_after(get_node):
+                    cast_node_64bit = gm.graph.call_function(
+                        torch.ops.aten._to_copy.default,
+                        args=(get_node,),
+                        kwargs={"dtype": dtype_64bit},
+                    )
+
+                get_node.replace_all_uses_with(
+                    cast_node_64bit,
+                    delete_user_cb=lambda user: (user != cast_node_64bit),
+                )
+
+    # Clean up graph and ensure invariants are preserved
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+
+
+def repair_long_or_double_inputs(
+    parent_graph: torch.fx.GraphModule,
+    submodule: torch.fx.GraphModule,
+    submodule_inputs: Sequence[torch.Tensor],
+    submodule_name: Optional[str] = None,
+) -> Sequence[torch.Tensor]:
+    """Fixes all Long/Double type inputs to a TRT-accelerated subgraph
+
+    Modifies the provided graph in place
+
+    Inserts a cast to the 32-bit equivalent type for TRT, then, if necessary,
+    inserts an upcast back to the 64-bit type for subsequent Torch operations
+
+    Args:
+        parent_graph: FX GraphModule enclosing the TRT subgraph
+        submodule: Child submodule to repair inputs on
+        submodule_inputs: Input tensor(s) of TRT-accelerated subgraph (used for dtypes/structure)
+        submodule_name: Optionally specify the name of the submodule target in the parent graph
+    Returns:
+        New submodule inputs, updated accordingly with long/double truncation
+    """
+    num_submodule_inputs = len(submodule_inputs)
+    repaired_outputs_once = False
+
+    # For each input to the TRT subgraph, check if its type is long/double
+    for position in range(num_submodule_inputs):
+        param = submodule_inputs[position]
+
+        # If the data type of the input is long/double, insert the
+        # necessary casts at the subgraph boundary
+        if param.dtype in (torch.int64, torch.float64):
+            # Ensure outputs are only repaired once per submodule to avoid
+            # unnecessary ops showing up in the graph
+            if not repaired_outputs_once:
+                submodule_outputs = submodule(*submodule_inputs)
+
+            _repair_64bit_input(
+                parent_graph,
+                position,
+                submodule_name if submodule_name is not None else submodule._get_name(),
+                None if repaired_outputs_once else submodule_outputs,
+                param.dtype,
+            )
+
+            repaired_outputs_once = True
+
+            # Repair submodule inputs in accordance with inserted casts
+            dtype_32bit = torch.int32 if (param.dtype == torch.int64) else torch.float32
+            submodule_inputs = (
+                submodule_inputs[:position]
+                + (param.to(dtype_32bit),)
+                + submodule_inputs[position + 1 :]
+            )
+
+    return submodule_inputs
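To make the new pass concrete, here is a small, self-contained sketch that exercises `repair_long_or_double_inputs` on a hand-built FX graph. The submodule name `_run_on_acc_0` merely mimics the partitioner's naming convention, and the whole setup is illustrative rather than how the backend invokes the pass:

    import torch
    from torch_tensorrt.dynamo.conversion import repair_long_or_double_inputs

    class Accelerated(torch.nn.Module):
        # Stand-in for a TRT-destined submodule
        def forward(self, x):
            return x + 1

    class Parent(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self._run_on_acc_0 = Accelerated()

        def forward(self, x):
            return self._run_on_acc_0(x)

    class LeafTracer(torch.fx.Tracer):
        # Keep the submodule as a call_module node rather than inlining it
        def is_leaf_module(self, m, module_qualified_name):
            return module_qualified_name == "_run_on_acc_0"

    root = Parent()
    parent_gm = torch.fx.GraphModule(root, LeafTracer().trace(root))

    inputs = (torch.ones(4, dtype=torch.int64),)
    new_inputs = repair_long_or_double_inputs(
        parent_gm, root._run_on_acc_0, inputs, "_run_on_acc_0"
    )

    print(new_inputs[0].dtype)  # torch.int32
    print(parent_gm.graph)  # aten._to_copy casts now bracket the submodule call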
diff --git a/tests/py/dynamo/backend/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py
index 4bc1c8c18c..282fcdbfd2 100644
--- a/tests/py/dynamo/backend/test_backend_compiler.py
+++ b/tests/py/dynamo/backend/test_backend_compiler.py
@@ -171,5 +171,118 @@ def forward(self, x, y):
         )
 
 
+class Test64BitInput(TestCase):
+    def test_float64_input_full_support(self):
+        class FullySupportedMultiOp(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.mean.dim(
+                    torch.ops.aten.mul.Tensor(torch.ops.aten.add.Tensor(x, y), 2), [0]
+                )
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
+        partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3)
+
+        self.assertEquals(
+            len(list(partitioned_graph.named_children())),
+            1,
+            "All operators are supported, there should be one segment",
+        )
+
+        inputs = [
+            torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
+            torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
+        ]
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            truncate_long_and_double=True,
+            debug=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            "TRT outputs don't match the original model.",
+        )
+
+    def test_int64_input_partial_support(self):
+        class PartiallySupportedMultiOp(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.div.Tensor_mode(
+                    x, torch.ops.aten.add.Tensor(y, y), rounding_mode="floor"
+                )
+
+        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
+        unexpected_ops = {torch.ops.aten.add.Tensor}
+
+        inputs = [
+            torch.randint(-40, 40, (16, 7, 5), dtype=torch.long).cuda(),
+            torch.randint(1, 40, (16, 7, 5), dtype=torch.long).cuda(),
+        ]
+
+        (unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing(
+            fx_graph,
+            inputs,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+            torch_executed_ops={"torch.ops.aten.add.Tensor"},
+            testing_partitioning=True,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+        self.assertEquals(
+            len(partitioned_graphs),
+            1,
+            "Without control flow breaks, there should only be a single graph",
+        )
+        self.assertEquals(
+            len(list(partitioned_graphs[0].named_children())),
+            1,
+            "Certain operators are set to run in Torch, expected 1 segment",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            truncate_long_and_double=True,
+            debug=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            "TRT outputs don't match the original model.",
+        )
+
+
 if __name__ == "__main__":
     run_tests()
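As a final sanity check on the approach: the down-cast/up-cast round trip is lossless whenever runtime values fit in 32 bits, which is why the int64 test above agrees exactly with eager Torch. A standalone, TRT-free illustration:

    import torch

    x = torch.randint(-40, 40, (16, 7, 5), dtype=torch.long)
    y = torch.randint(1, 40, (16, 7, 5), dtype=torch.long)

    # Reference computation carried out in 64-bit, as eager Torch runs it
    ref = torch.div(x, y + y, rounding_mode="floor")

    # Emulate the truncation: cast inputs down to int32, compute, then
    # upcast the result back to int64 for downstream Torch operations
    emulated = torch.div(
        x.to(torch.int32), (y + y).to(torch.int32), rounding_mode="floor"
    ).to(torch.long)

    assert torch.equal(ref, emulated)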