fix: Remove input aliasing of builtin ops (#2276)
gs-olive authored Sep 26, 2023
1 parent ecdc040 commit 7daa112
Showing 10 changed files with 238 additions and 75 deletions.
67 changes: 11 additions & 56 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -2,17 +2,19 @@

import logging
import unittest
from typing import Any, Callable, Dict, Optional, Sequence
from typing import Any, Callable, Sequence

import torch
import torch._dynamo as td
import torch.utils._pytree as pytree
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import _aot_export_function
from torch._ops import OpOverload
from torch._functorch.aot_autograd import aot_export_joint_simple
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo.compile import compile_module
from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
from torch_tensorrt.dynamo.lowering import (
apply_lowering_passes,
get_decompositions,
repair_input_aliasing,
)
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs, set_log_level

@@ -71,10 +73,13 @@ def _pretraced_backend(
with unittest.mock.patch.object(
fake_mode, "allow_non_fake_inputs", True
), fake_mode:
repair_input_aliasing(gm)

# Invoke AOTAutograd to translate operators to aten
gm = aot_export_for_compile(
gm = aot_export_joint_simple(
gm,
sample_inputs,
trace_joint=False,
decompositions=get_decompositions(
settings.enable_experimental_decompositions
),
@@ -107,53 +112,3 @@ def _pretraced_backend(
+ "specify pass_through_build_failures=False."
)
raise


def aot_export_for_compile(
func: torch.fx.GraphModule,
args: Sequence[torch.Tensor],
*,
decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
) -> torch.fx.GraphModule:
"""Adapted from:
https://github.com/pytorch/pytorch/blob/1a5fdc2458b98697c75c32eb6f4b8b34d76429cf/torch/_functorch/aot_autograd.py#L4084-L4158
Removed check for input aliasing in resultant subgraph - TRT is functional-only
Exports the function to ATen for torch compile
"""
# Trace function with input arguments and decompositions
with torch.no_grad():
fx_g, metadata, in_spec, out_spec = _aot_export_function(
func,
args,
decompositions=decompositions,
)

# No input mutations
if (
len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
!= 0
):
raise RuntimeError(
f"aot_export_joint_simple does not support input mutations. {str(metadata)}"
)
# No pytrees
if type(in_spec) == pytree.LeafSpec:
raise RuntimeError(
f"aot_export_for_compile requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
)
if len([x for x in in_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
raise RuntimeError(
f"aot_export_for_compile requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
)
if type(out_spec) == pytree.LeafSpec:
raise RuntimeError(
f"aot_export_for_compile requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
)
if len([x for x in out_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
raise RuntimeError(
f"aot_export_for_compile requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
)

return fx_g
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -2,5 +2,6 @@
from ._fusers import * # noqa: F401
from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401
from ._pre_aot_lowering import register_substitution # noqa: F401
from ._repair_input_aliasing import repair_input_aliasing
from .passes import apply_lowering_passes
from .substitutions import * # noqa: F401
38 changes: 38 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_repair_input_aliasing.py
@@ -0,0 +1,38 @@
import logging

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import get_tensor_placeholders

logger = logging.getLogger(__name__)


def repair_input_aliasing(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""Inserts clone operators temporarily ahead of every placeholder
See: https://github.com/pytorch/pytorch/issues/108079
Undone by `remove_input_alias_fixing_clones` after tracing
"""
# Extract graph placeholder Tensors
placeholders = get_tensor_placeholders(gm)

for node in placeholders:
# Insert clones for placeholder nodes to avoid
# input aliasing or mutation
with gm.graph.inserting_after(placeholders[-1]):
cloned_input = gm.graph.call_function(
torch.ops.aten.clone.default,
args=(node,),
)

# Replace all uses of the placeholder except the cloned node
# with the cloned placeholder
node.replace_all_uses_with(
cloned_input,
delete_user_cb=lambda node: node != cloned_input,
)

gm.graph.lint()
gm.recompile()
logger.debug(f"Inserted auxiliary clone nodes for placeholders:\n{gm.graph}")

return gm
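
A minimal sketch (not part of this commit) of what the pass does to a hand-built FX graph; the toy graph below is hypothetical, and its placeholder is annotated as torch.Tensor so that get_tensor_placeholders detects it:

import torch
from torch_tensorrt.dynamo.lowering import repair_input_aliasing

# Build a toy graph, y = relu(x), with the placeholder typed as a Tensor
graph = torch.fx.Graph()
x = graph.placeholder("x", type_expr=torch.Tensor)
y = graph.call_function(torch.ops.aten.relu.default, args=(x,))
graph.output(y)
gm = torch.fx.GraphModule(torch.nn.Module(), graph)

repair_input_aliasing(gm)
print(gm.graph)
# The relu now consumes a clone of the placeholder rather than the
# placeholder itself, so downstream tracing never sees an aliased input:
#   x -> aten.clone.default -> aten.relu.default -> output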
py/torch_tensorrt/dynamo/lowering/passes/__init__.py
@@ -5,10 +5,12 @@

from .constant_folding import constant_fold
from .pass_manager import DynamoPassManager
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output

ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
[
remove_input_alias_fixing_clones,
constant_fold,
repair_input_as_output,
]
7 changes: 4 additions & 3 deletions py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -2,6 +2,9 @@

import torch
from torch_tensorrt._utils import sanitized_torch_version
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

from packaging import version

@@ -47,9 +50,7 @@ def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in erased_params:
gm.graph.erase_node(node)

gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()
gm = clean_up_graph_after_modifications(gm)

logger.debug(f"Graph after constant folding:\n{gm.graph}")

31 changes: 31 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py
@@ -0,0 +1,31 @@
from typing import List

import torch


def clean_up_graph_after_modifications(
gm: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
"""Runs dead-code elimination, linting, and recompilation for graph, in-place"""
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()
return gm


def get_tensor_placeholders(
gm: torch.fx.GraphModule,
) -> List[torch.fx.Node]:
"""Returns placeholder nodes of GraphModule which are torch.Tensor types"""
# Tensor placeholders must be subclasses of torch.Tensor
placeholders = [
node
for node in gm.graph.nodes
if (
node.op == "placeholder"
and isinstance(node.type, type)
and issubclass(node.type, torch.Tensor)
)
]

return placeholders
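
As a usage sketch only — the pass below is hypothetical and not part of this commit — a custom lowering pass built on these helpers would typically gather the Tensor placeholders, edit the graph, and finish with the shared cleanup step:

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
    get_tensor_placeholders,
)


def detach_all_inputs(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Hypothetical pass: insert aten.detach after every Tensor placeholder"""
    placeholders = get_tensor_placeholders(gm)

    for node in placeholders:
        # Insert after the last placeholder so placeholders stay grouped
        with gm.graph.inserting_after(placeholders[-1]):
            detached = gm.graph.call_function(
                torch.ops.aten.detach.default,
                args=(node,),
            )

        # Rewire all users of the placeholder except the detach node itself
        node.replace_all_uses_with(detached, delete_user_cb=lambda n: n != detached)

    # Dead-code elimination, lint, and recompile in one shot, mirroring the built-in passes
    return clean_up_graph_after_modifications(gm)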
43 changes: 43 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/remove_input_alias_fixing_clones.py
@@ -0,0 +1,43 @@
import logging

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


# TODO: Delete this lowering pass once aot_export_joint_simple is patched
def remove_input_alias_fixing_clones(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""Remove the auxiliary clone nodes inserted to fix input aliasing
See: https://github.com/pytorch/pytorch/issues/108079
"""
modified_graph = False

for node in gm.graph.nodes:
# If the node is a placeholder and its only user is a clone node
# it was modified by the input alias-fixing pass, and the change
# needs to be undone
if (
node.op == "placeholder"
and len(node.users) == 1
and list(node.users)[0].target == torch.ops.aten.clone.default
):
modified_graph = True

# Replace all uses of the clone with the placeholder, delete the clone
clone_node = list(node.users)[0]
logger.debug(
f"Removing node {clone_node} from graph, since it is a clone node which "
f"is the only user of placeholder {node} and was inserted by the compiler."
)
clone_node.replace_all_uses_with(node)
gm.graph.erase_node(clone_node)

if modified_graph:
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Removed auxiliary clone nodes for placeholders:\n{gm.graph}")

return gm
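
Taken together, repair_input_aliasing and remove_input_alias_fixing_clones are intended to bracket the AOT export step, which is how the backends.py change above wires them up. A condensed sketch of that pairing (hypothetical helper function; the fake-mode handling and decompositions argument from backends.py are omitted for brevity):

from typing import Sequence

import torch
from torch._functorch.aot_autograd import aot_export_joint_simple
from torch_tensorrt.dynamo.lowering import repair_input_aliasing
from torch_tensorrt.dynamo.lowering.passes import remove_input_alias_fixing_clones


def export_without_input_aliasing(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    # 1. Guard every Tensor placeholder with a clone so the exporter never
    #    sees a graph input that is aliased or mutated (pytorch#108079)
    repair_input_aliasing(gm)

    # 2. Export to ATen; trace_joint=False requests the inference-only graph
    gm = aot_export_joint_simple(gm, sample_inputs, trace_joint=False)

    # 3. Strip the auxiliary clones again; in the real backend this runs as
    #    the first entry in ATEN_LOWERING_PASSES
    return remove_input_alias_fixing_clones(gm)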
20 changes: 7 additions & 13 deletions py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py
@@ -1,6 +1,10 @@
import logging

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
get_tensor_placeholders,
)

logger = logging.getLogger(__name__)

@@ -13,15 +17,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
modified_graph = False

# Extract graph placeholder Tensors
placeholders = [
node
for node in gm.graph.nodes
if (
node.op == "placeholder"
and isinstance(node.type, type)
and issubclass(node.type, torch.Tensor)
)
]
placeholders = get_tensor_placeholders(gm)

for placeholder in placeholders:
# If any placeholder has any users which are direct graph outputs
@@ -34,7 +30,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
direct_outputs = [user for user in placeholder.users if user.op == "output"]

# Insert clone node for placeholder to ensure placeholder is not a direct output
with gm.graph.inserting_after(placeholder):
with gm.graph.inserting_after(placeholders[-1]):
cloned_placeholder = gm.graph.call_function(
torch.ops.aten.clone.default,
args=(placeholder,),
Expand All @@ -45,9 +41,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
output.replace_input_with(placeholder, cloned_placeholder)

if modified_graph:
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")

return gm