From 7daa1120dc1bc72d6f92f1e7aa2b357a65b6ea08 Mon Sep 17 00:00:00 2001
From: George S <113141689+gs-olive@users.noreply.github.com>
Date: Tue, 26 Sep 2023 14:45:57 -0700
Subject: [PATCH] fix: Remove input aliasing of builtin ops (#2276)

---
 py/torch_tensorrt/dynamo/backend/backends.py  | 67 +++-----------
 py/torch_tensorrt/dynamo/lowering/__init__.py |  1 +
 .../dynamo/lowering/_repair_input_aliasing.py | 38 ++++++++
 .../lowering/passes/_aten_lowering_pass.py    |  2 +
 .../lowering/passes/constant_folding.py       |  7 +-
 .../dynamo/lowering/passes/pass_utils.py      | 31 +++++++
 .../remove_input_alias_fixing_clones.py       | 43 +++++++++
 .../lowering/passes/repair_input_as_output.py | 20 ++--
 .../dynamo/backend/test_specialized_models.py | 91 +++++++++++++++++++
 tests/py/dynamo/testing_utilities.py          | 13 ++-
 10 files changed, 238 insertions(+), 75 deletions(-)
 create mode 100644 py/torch_tensorrt/dynamo/lowering/_repair_input_aliasing.py
 create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
 create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/remove_input_alias_fixing_clones.py

diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 022f3b193d..f8508d752e 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -2,17 +2,19 @@
 import logging
 import unittest
-from typing import Any, Callable, Dict, Optional, Sequence
+from typing import Any, Callable, Sequence
 
 import torch
 import torch._dynamo as td
-import torch.utils._pytree as pytree
 from torch._dynamo.utils import detect_fake_mode
-from torch._functorch.aot_autograd import _aot_export_function
-from torch._ops import OpOverload
+from torch._functorch.aot_autograd import aot_export_joint_simple
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
-from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
+from torch_tensorrt.dynamo.lowering import (
+    apply_lowering_passes,
+    get_decompositions,
+    repair_input_aliasing,
+)
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs, set_log_level
 
@@ -71,10 +73,13 @@ def _pretraced_backend(
         with unittest.mock.patch.object(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
+            repair_input_aliasing(gm)
+
             # Invoke AOTAutograd to translate operators to aten
-            gm = aot_export_for_compile(
+            gm = aot_export_joint_simple(
                 gm,
                 sample_inputs,
+                trace_joint=False,
                 decompositions=get_decompositions(
                     settings.enable_experimental_decompositions
                 ),
@@ -107,53 +112,3 @@ def _pretraced_backend(
                 + "specify pass_through_build_failures=False."
             )
             raise
-
-
-def aot_export_for_compile(
-    func: torch.fx.GraphModule,
-    args: Sequence[torch.Tensor],
-    *,
-    decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
-) -> torch.fx.GraphModule:
-    """Adapted from:
-    https://github.com/pytorch/pytorch/blob/1a5fdc2458b98697c75c32eb6f4b8b34d76429cf/torch/_functorch/aot_autograd.py#L4084-L4158
-
-    Removed check for input aliasing in resultant subgraph - TRT is functional-only
-
-    Exports the function to ATen for torch compile
-    """
-    # Trace function with input arguments and decompositions
-    with torch.no_grad():
-        fx_g, metadata, in_spec, out_spec = _aot_export_function(
-            func,
-            args,
-            decompositions=decompositions,
-        )
-
-    # No input mutations
-    if (
-        len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
-        != 0
-    ):
-        raise RuntimeError(
-            f"aot_export_joint_simple does not support input mutations. {str(metadata)}"
-        )
-    # No pytrees
-    if type(in_spec) == pytree.LeafSpec:
-        raise RuntimeError(
-            f"aot_export_for_compile requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
-        )
-    if len([x for x in in_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
-        raise RuntimeError(
-            f"aot_export_for_compile requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
-        )
-    if type(out_spec) == pytree.LeafSpec:
-        raise RuntimeError(
-            f"aot_export_for_compile requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
-        )
-    if len([x for x in out_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
-        raise RuntimeError(
-            f"aot_export_for_compile requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
-        )
-
-    return fx_g
diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py
index 34faa1d11b..2b67ef0c91 100644
--- a/py/torch_tensorrt/dynamo/lowering/__init__.py
+++ b/py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -2,5 +2,6 @@
 from ._fusers import *  # noqa: F401
 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
 from ._pre_aot_lowering import register_substitution  # noqa: F401
+from ._repair_input_aliasing import repair_input_aliasing
 from .passes import apply_lowering_passes
 from .substitutions import *  # noqa: F401
diff --git a/py/torch_tensorrt/dynamo/lowering/_repair_input_aliasing.py b/py/torch_tensorrt/dynamo/lowering/_repair_input_aliasing.py
new file mode 100644
index 0000000000..04098e99ca
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/_repair_input_aliasing.py
@@ -0,0 +1,38 @@
+import logging
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import get_tensor_placeholders
+
+logger = logging.getLogger(__name__)
+
+
+def repair_input_aliasing(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Inserts clone operators temporarily ahead of every placeholder
+
+    See: https://github.com/pytorch/pytorch/issues/108079
+    Undone by `remove_input_alias_fixing_clones` after tracing
+    """
+    # Extract graph placeholder Tensors
+    placeholders = get_tensor_placeholders(gm)
+
+    for node in placeholders:
+        # Insert clones for placeholder nodes to avoid
+        # input aliasing or mutation
+        with gm.graph.inserting_after(placeholders[-1]):
+            cloned_input = gm.graph.call_function(
+                torch.ops.aten.clone.default,
+                args=(node,),
+            )
+
+        # Replace all uses of the placeholder except the cloned node
+        # with the cloned placeholder
+        node.replace_all_uses_with(
+            cloned_input,
+            delete_user_cb=lambda node: node != cloned_input,
+        )
+
+    gm.graph.lint()
+    gm.recompile()
+    logger.debug(f"Inserted auxiliary clone nodes for placeholders:\n{gm.graph}")
+
+    return gm
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index a4c7fad607..43d70a4cac 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -5,10 +5,12 @@
 from .constant_folding import constant_fold
 from .pass_manager import DynamoPassManager
+from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
 
 ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [
+        remove_input_alias_fixing_clones,
         constant_fold,
         repair_input_as_output,
     ]
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
index d17d0a2528..ea2547f6bf 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -2,6 +2,9 @@
 import torch
 from torch_tensorrt._utils import sanitized_torch_version
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
 
 from packaging import version
 
@@ -47,9 +50,7 @@ def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     for node in erased_params:
         gm.graph.erase_node(node)
 
-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
-    gm.recompile()
+    gm = clean_up_graph_after_modifications(gm)
 
     logger.debug(f"Graph after constant folding:\n{gm.graph}")
 
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
new file mode 100644
index 0000000000..31a55099c2
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
@@ -0,0 +1,31 @@
+from typing import List
+
+import torch
+
+
+def clean_up_graph_after_modifications(
+    gm: torch.fx.GraphModule,
+) -> torch.fx.GraphModule:
+    """Runs dead-code elimination, linting, and recompilation for graph, in-place"""
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+    return gm
+
+
+def get_tensor_placeholders(
+    gm: torch.fx.GraphModule,
+) -> List[torch.fx.Node]:
+    """Returns placeholder nodes of GraphModule which are torch.Tensor types"""
+    # Tensor placeholders must be subclasses of torch.Tensor
+    placeholders = [
+        node
+        for node in gm.graph.nodes
+        if (
+            node.op == "placeholder"
+            and isinstance(node.type, type)
+            and issubclass(node.type, torch.Tensor)
+        )
+    ]
+
+    return placeholders
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_input_alias_fixing_clones.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_input_alias_fixing_clones.py
new file mode 100644
index 0000000000..dce88ad109
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_input_alias_fixing_clones.py
@@ -0,0 +1,43 @@
+import logging
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+# TODO: Delete this lowering pass once aot_export_joint_simple is patched
+def remove_input_alias_fixing_clones(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Remove the auxiliary clone nodes inserted to fix input aliasing
+
+    See: https://github.com/pytorch/pytorch/issues/108079
+    """
+    modified_graph = False
+
+    for node in gm.graph.nodes:
+        # If the node is a placeholder and its only user is a clone node
+        # it was modified by the input alias-fixing pass, and the change
+        # needs to be undone
+        if (
+            node.op == "placeholder"
+            and len(node.users) == 1
+            and list(node.users)[0].target == torch.ops.aten.clone.default
+        ):
+            modified_graph = True
+
+            # Replace all uses of the clone with the placeholder, delete the clone
+            clone_node = list(node.users)[0]
+            logger.debug(
+                f"Removing node {clone_node} from graph, since it is a clone node which "
+                f"is the only user of placeholder {node} and was inserted by the compiler."
+            )
+            clone_node.replace_all_uses_with(node)
+            gm.graph.erase_node(clone_node)
+
+    if modified_graph:
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Removed auxiliary clone nodes for placeholders:\n{gm.graph}")
+
+    return gm
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py b/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py
index 6ce846637d..ec2f5b0ae0 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py
@@ -1,6 +1,10 @@
 import logging
 
 import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+    get_tensor_placeholders,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -13,15 +17,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     modified_graph = False
 
     # Extract graph placeholder Tensors
-    placeholders = [
-        node
-        for node in gm.graph.nodes
-        if (
-            node.op == "placeholder"
-            and isinstance(node.type, type)
-            and issubclass(node.type, torch.Tensor)
-        )
-    ]
+    placeholders = get_tensor_placeholders(gm)
 
     for placeholder in placeholders:
         # If any placeholder has any users which are direct graph outputs
@@ -34,7 +30,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
         direct_outputs = [user for user in placeholder.users if user.op == "output"]
 
         # Insert clone node for placeholder to ensure placeholder is not a direct output
-        with gm.graph.inserting_after(placeholder):
+        with gm.graph.inserting_after(placeholders[-1]):
             cloned_placeholder = gm.graph.call_function(
                 torch.ops.aten.clone.default,
                 args=(placeholder,),
@@ -45,9 +41,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
         output.replace_input_with(placeholder, cloned_placeholder)
 
     if modified_graph:
-        gm.graph.eliminate_dead_code()
-        gm.graph.lint()
-        gm.recompile()
+        gm = clean_up_graph_after_modifications(gm)
         logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")
 
     return gm
diff --git a/tests/py/dynamo/backend/test_specialized_models.py b/tests/py/dynamo/backend/test_specialized_models.py
index af90fd0b3a..d32171e48b 100644
--- a/tests/py/dynamo/backend/test_specialized_models.py
+++ b/tests/py/dynamo/backend/test_specialized_models.py
@@ -57,6 +57,7 @@ def forward(self, x):
         self.assertAlmostEqual(
             max_diff,
             0,
+            DECIMALS_OF_AGREEMENT,
             msg=f"MulInt TRT outputs don't match with the original model.",
         )
         torch._dynamo.reset()
@@ -113,6 +114,7 @@ def forward(self, x):
         self.assertAlmostEqual(
             max_diff,
             0,
+            DECIMALS_OF_AGREEMENT,
             msg=f"AddFloat TRT outputs don't match with the original model.",
         )
 
@@ -273,5 +275,94 @@ def forward(self, x):
         torch._dynamo.reset()
 
 
+class TestInputModifications(TestCase):
+    def test_input_modifications_add(self):
+        class InplaceAdd(torch.nn.Module):
+            def forward(self, x):
+                x += 3
+                y = x + 1
+                return y
+
+        inputs = [
+            torch.rand(
+                3,
+                5,
+                7,
+            ).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(InplaceAdd())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"InplaceAdd TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+    def test_input_modifications_mul(self):
+        class InplaceMul(torch.nn.Module):
+            def forward(self, x, y):
+                x *= 5.0
+                x *= 1.9
+                z = x + y
+                z /= 1.3
+                return z
+
+        inputs = [
+            torch.rand(
+                1,
+                3,
+                5,
+                7,
+            ).cuda(),
+            torch.rand(
+                1,
+                3,
+                5,
+                7,
+            ).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(InplaceMul())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"InplaceMul TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/testing_utilities.py b/tests/py/dynamo/testing_utilities.py
index af5336813f..344cd6bc1d 100644
--- a/tests/py/dynamo/testing_utilities.py
+++ b/tests/py/dynamo/testing_utilities.py
@@ -5,9 +5,13 @@
 import torch
 from torch._dynamo.utils import detect_fake_mode
+from torch._functorch.aot_autograd import aot_export_joint_simple
 from torch_tensorrt.dynamo import partitioning
-from torch_tensorrt.dynamo.backend.backends import aot_export_for_compile
-from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
+from torch_tensorrt.dynamo.lowering import (
+    apply_lowering_passes,
+    get_decompositions,
+    repair_input_aliasing,
+)
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 
 DECIMALS_OF_AGREEMENT = 4
 
@@ -39,10 +43,13 @@ def fx_dynamo_testing_backend(
     with unittest.mock.patch.object(
         fake_mode, "allow_non_fake_inputs", True
     ), fake_mode:
+        repair_input_aliasing(gm)
+
         # Invoke AOTAutograd to translate operators to aten
-        gm = aot_export_for_compile(
+        gm = aot_export_joint_simple(
             gm,
             sample_inputs,
+            trace_joint=False,
             decompositions=get_decompositions(),
         )
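
Example usage of the two new passes, appended for reviewers: a minimal sketch, not part of the patch. It assumes a torch_tensorrt build that contains the modules added above; the toy `InplaceAdd` module mirrors the one in the new tests, and the explicit `torch.Tensor` annotation is needed here because `get_tensor_placeholders` only selects placeholders typed as Tensor subclasses.

import torch
from torch_tensorrt.dynamo.lowering import repair_input_aliasing
from torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones import (
    remove_input_alias_fixing_clones,
)


class InplaceAdd(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x += 3  # mutates/aliases the graph input
        return x + 1


gm = torch.fx.symbolic_trace(InplaceAdd())

# Pre-AOT: reroute every Tensor placeholder through an aten.clone node so
# that aot_export_joint_simple never sees a mutated or aliased graph input
# (see pytorch/pytorch#108079).
repair_input_aliasing(gm)
print(gm.graph)  # placeholder -> clone -> add(..., 3) -> add(..., 1)

# Post-AOT (invoked directly here, since no tracing runs in this sketch):
# the first ATen lowering pass strips the clones again, because each
# placeholder's sole user is now a compiler-inserted clone node.
remove_input_alias_fixing_clones(gm)
print(gm.graph)  # placeholder -> add(..., 3) -> add(..., 1)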