Commit
fix: Address review comments and add upgrades
gs-olive committed Sep 21, 2023
1 parent 1a03fc5 commit 7fa0a0c
Showing 9 changed files with 199 additions and 184 deletions.
109 changes: 109 additions & 0 deletions docsrc/contributors/writing_dynamo_aten_lowering_passes.rst
@@ -0,0 +1,109 @@
.. _writing_dynamo_aten_lowering_passes:

Writing Dynamo ATen Lowering Passes
===================================

Basics of a Lowering Pass
-------------------------

ATen lowering passes are Python functions which take as input a graph of ATen operators in the form of a `torch.fx.GraphModule`, apply some desired modification such as operator coalescing/fusion, operator replacement, subgraph rewriting, or custom operator insertion, and then return the modified graph to the caller. These lowering passes generally modify the graph in-place and return the same input object.
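As a simplified illustration of this contract (a plain-Python sketch, not the actual `torch.fx` machinery — the `FakeGraphModule` class and `fuse_double_clone` pass below are hypothetical stand-ins), a lowering pass is just a callable that mutates the graph object it receives and returns that same object:

```python
from typing import List


class FakeGraphModule:
    """Hypothetical stand-in for torch.fx.GraphModule, used only to
    illustrate the modify-in-place-and-return contract of a pass."""

    def __init__(self, ops: List[str]) -> None:
        self.ops = ops


def fuse_double_clone(gm: FakeGraphModule) -> FakeGraphModule:
    """Example pass: collapse consecutive 'clone' ops into one."""
    fused: List[str] = []
    for op in gm.ops:
        if op == "clone" and fused and fused[-1] == "clone":
            continue  # drop the redundant clone
        fused.append(op)
    gm.ops = fused  # modify the graph in-place ...
    return gm       # ... and return the same object


gm = FakeGraphModule(["add", "clone", "clone", "mul"])
out = fuse_double_clone(gm)
assert out is gm             # the same object is returned
print(out.ops)               # ['add', 'clone', 'mul']
```

In a real pass, `gm` would be a `torch.fx.GraphModule` and the mutation would go through `gm.graph`, as in the full example later in this document.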

Lowering Pass Requirements
--------------------------

An ATen lowering pass function in Torch-TRT must satisfy two requirements:

- The function must take as input a single `torch.fx.GraphModule` and return the lowered `torch.fx.GraphModule`
- The function must leave the graph in a valid and invocable state, including performing any necessary linting and recompilation

See `Graph Manipulations <https://pytorch.org/docs/stable/fx.html#graph-manipulation>`_ for information on modifying FX graphs. Below is an example of a lowering pass which repairs graphs whose inputs are also outputs, a configuration disallowed for TRT Engines.

Example Lowering Pass
---------------------

.. code-block:: python

    def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Repair scenarios where inputs are also outputs of the graph

        TRT does not allow such cases, so we insert a clone (identity) layer
        """
        modified_graph = False

        # Extract graph placeholder Tensors
        placeholders = [
            node
            for node in gm.graph.nodes
            if (
                node.op == "placeholder"
                and isinstance(node.type, type)
                and issubclass(node.type, torch.Tensor)
            )
        ]

        for placeholder in placeholders:
            # If any placeholder has any users which are direct graph outputs
            if len(placeholder.users) >= 1 and any(
                user.op == "output" for user in placeholder.users
            ):
                modified_graph = True

                # Get direct graph outputs which are direct uses of placeholders
                direct_outputs = [user for user in placeholder.users if user.op == "output"]

                # Insert clone node for placeholder to ensure
                # placeholder is not a direct output
                with gm.graph.inserting_after(placeholder):
                    cloned_placeholder = gm.graph.call_function(
                        torch.ops.aten.clone.default,
                        args=(placeholder,),
                    )

                # Replace placeholder as output with cloned version
                for output in direct_outputs:
                    output.replace_input_with(placeholder, cloned_placeholder)

        # If the graph was modified, clean up the graph and ensure it is up-to-date
        if modified_graph:
            gm.graph.eliminate_dead_code()
            gm.graph.lint()
            gm.recompile()
            logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")

        return gm

Registering Lowering Passes
---------------------------

Lowering passes are currently registered in `py/torch_tensorrt/dynamo/lowering/passes/__init__.py`, using the `torch.fx.passes.pass_manager.PassManager` utility to assemble the list of passes in a desired order. New passes added directly to that list will be applied to graphs in the Torch-TensorRT `torch.compile` backend. Currently, we offer an ATen lowering pass registration decorator for convenience, which can be invoked either directly, or with the optional `index` keyword argument which controls where in the pass list the lowering pass will be inserted.

For instance, to insert the pass at the default location (end of the list), the following code can be used:

.. code-block:: python

    @_aten_lowering_pass
    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        ...

Alternatively, to insert the pass at a custom index (such as the front of the list), the following code can be used:

.. code-block:: python

    @_aten_lowering_pass(index=0)
    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        ...

There are also utilities provided in `torch_tensorrt.dynamo.lowering.passes` for displaying the currently-available lowering pass list, applying those passes to an arbitrary `torch.fx.GraphModule`, and removing the lowering pass at a specific index.

.. code-block:: python

    # Print all lowering passes in the list
    print(dump_lowering_passes())

    # Apply lowering passes to a GraphModule
    apply_lowering_passes(graph_module)

    # Remove the lowering pass at index 1
    _remove_lowering_pass(index=1)

**Note:** The above APIs are subject to change as the lowering pass system evolves.
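As a rough mental model of the registration semantics described above — a hedged sketch only, where `PassRegistry` is an illustrative name and not the actual `DynamoPassManager` implementation — the pass list behaves like an ordered list of callables with index-based insertion and removal:

```python
from typing import Callable, List, Optional

Pass = Callable[[object], object]


class PassRegistry:
    """Minimal model of an ordered lowering-pass list."""

    def __init__(self) -> None:
        self.passes: List[Pass] = []

    def add(self, p: Pass, index: Optional[int] = None) -> None:
        # Default: append at the end; otherwise insert at the given index.
        if index is None:
            self.passes.append(p)
        else:
            self.passes.insert(index, p)

    def remove(self, index: int) -> None:
        self.passes.pop(index)

    def __call__(self, gm: object) -> object:
        # Apply each pass in order, threading the result through.
        for p in self.passes:
            gm = p(gm)
        return gm


registry = PassRegistry()
registry.add(lambda gm: gm + ["pass_a"])             # appended at the end
registry.add(lambda gm: gm + ["pass_b"], index=0)    # inserted at the front
print(registry([]))  # ['pass_b', 'pass_a']
```

This sketch only captures the ordering and indexing behavior; in the real system each pass operates on a `torch.fx.GraphModule` and is responsible for linting and recompiling it.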
3 changes: 2 additions & 1 deletion docsrc/index.rst
@@ -73,7 +73,6 @@ Tutorials
tutorials/_rendered_examples/dynamo/torch_compile_resnet_example
tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
tutorials/_rendered_examples/dynamo/dynamo_aten_lowering_passes

Python API Documentation
------------------------
@@ -129,6 +128,7 @@ Contributor Documentation
--------------------------------
* :ref:`system_overview`
* :ref:`writing_converters`
* :ref:`writing_dynamo_aten_lowering_passes`
* :ref:`useful_links`

.. toctree::
@@ -138,6 +138,7 @@ Contributor Documentation

contributors/system_overview
contributors/writing_converters
contributors/writing_dynamo_aten_lowering_passes
contributors/useful_links

Indices
1 change: 0 additions & 1 deletion examples/dynamo/README.rst
@@ -9,4 +9,3 @@ a number of ways you can leverage this backend to accelerate inference.
* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
* :ref:`dynamo_aten_lowering_passes`: Custom modifications of a graph of ATen operators via lowering passes
113 changes: 0 additions & 113 deletions examples/dynamo/dynamo_aten_lowering_passes.py

This file was deleted.

5 changes: 2 additions & 3 deletions py/torch_tensorrt/dynamo/aten_tracer.py
@@ -6,8 +6,7 @@

import torch
from torch._export import export
from torch_tensorrt.dynamo.backend.backends import constant_fold
from torch_tensorrt.dynamo.lowering import get_decompositions
from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
from torch_tensorrt.dynamo.utils import set_log_level

logger = logging.getLogger(__name__)
@@ -29,6 +28,6 @@ def trace(
"torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
):
graph_module = export(model, tuple(inputs)).module()
constant_fold(graph_module)
graph_module = apply_lowering_passes(graph_module)
logger.debug("Post export graph: " + str(graph_module.graph))
return graph_module
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -2,5 +2,5 @@
from ._fusers import * # noqa: F401
from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401
from ._pre_aot_lowering import register_substitution # noqa: F401
from .passes import add_lowering_pass, apply_lowering_passes
from .passes import apply_lowering_passes
from .substitutions import * # noqa: F401
56 changes: 1 addition & 55 deletions py/torch_tensorrt/dynamo/lowering/passes/__init__.py
@@ -1,55 +1 @@
import logging
from typing import Callable, Optional

import torch

# Import and order lowering passes and pass manager
from .constant_folding import constant_fold
from .pass_manager import DynamoPassManager
from .repair_input_as_output import repair_input_as_output

ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
    [
        constant_fold,
        repair_input_as_output,
    ]
)

logger = logging.getLogger(__name__)


def add_lowering_pass(
    lowering_pass: Callable[[torch.fx.GraphModule], torch.fx.GraphModule],
    index: Optional[int] = None,
) -> None:
    """Adds a lowering pass to the registry, at a specified index if desired

    If no index is specified, the lowering pass is inserted at the end of the list
    """
    ATEN_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
    logger.debug(
        f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
    )
    return


def remove_lowering_pass(index: int) -> None:
    """Removes a lowering pass at a specific index from the registry"""
    ATEN_LOWERING_PASSES.remove_pass_with_index(index)
    logger.debug(
        f"Removed lowering pass at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
    )
    return


def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Applies the lowering passes to a graph module, returns the modified GraphModule"""
    logging.debug(
        f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}"
    )
    return ATEN_LOWERING_PASSES(gm)


def dump_lowering_passes() -> str:
    """Returns a string containing the lowering passes"""
    return str(ATEN_LOWERING_PASSES)
from ._aten_lowering_pass import *