examples: Add example usage scripts for torch_tensorrt.dynamo.compile path [1.1 / x] #1966

Merged · 2 commits · Aug 2, 2023
3 changes: 2 additions & 1 deletion .gitignore
@@ -32,6 +32,7 @@ docsrc/_build
docsrc/_notebooks
docsrc/_cpp_api
docsrc/_tmp
docsrc/tutorials/_rendered_examples
*.so
__pycache__
*.egg-info
@@ -67,4 +68,4 @@ bazel-tensorrt
*cifar-10-batches-py*
bazel-project
build/
wheelhouse/
wheelhouse/
1 change: 1 addition & 0 deletions docsrc/Makefile
@@ -35,6 +35,7 @@ endif
rm -rf $(SOURCEDIR)/_py_api
rm -rf $(SOURCEDIR)/_build
rm -rf $(SOURCEDIR)/_tmp
rm -rf $(SOURCEDIR)/tutorials/_rendered_examples

html:
# mkdir -p $(SOURCEDIR)/_notebooks
8 changes: 8 additions & 0 deletions docsrc/_static/css/custom.css
@@ -0,0 +1,8 @@
/* sphinx-design styles for cards/tabs
*/

.sphx-glr-thumbcontainer {
padding: 50%;
display: flex;
align-content: center;
}
127 changes: 127 additions & 0 deletions docsrc/_static/css/pytorch_theme.css
@@ -0,0 +1,127 @@
body {
font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
}

/* Default header fonts are ugly */
h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {
font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
}

/* Use white for docs background */
.wy-side-nav-search {
background-color: #fff;
}

.wy-nav-content-wrap, .wy-menu li.current > a {
background-color: #fff;
}

@media screen and (min-width: 1400px) {
.wy-nav-content-wrap {
background-color: rgba(0, 0, 0, 0.0470588);
}

.wy-nav-content {
background-color: #fff;
}
}

/* Fixes for mobile */
.wy-nav-top {
background-color: #fff;
background-image: url('../img/pytorch-logo-dark.svg');
background-repeat: no-repeat;
background-position: center;
padding: 0;
margin: 0.4045em 0.809em;
color: #333;
}

.wy-nav-top > a {
display: none;
}

@media screen and (max-width: 768px) {
.wy-side-nav-search>a img.logo {
height: 60px;
}
}

/* This is needed to ensure that the logo above search scales properly */
.wy-side-nav-search a {
display: block;
}

/* This ensures that multiple constructors will remain on separate lines. */
.rst-content dl:not(.docutils) dt {
display: table;
}

/* Use our red for literals (it's very similar to the original color) */
.rst-content tt.literal, .rst-content tt.literal, .rst-content code.literal {
color: #F05732;
}

.rst-content tt.xref, a .rst-content tt, .rst-content tt.xref,
.rst-content code.xref, a .rst-content tt, a .rst-content code {
color: #404040;
}

/* Change link colors (except for the menu) */

a {
color: #F05732;
}

a:hover {
color: #F05732;
}


a:visited {
color: #D44D2C;
}

.wy-menu a {
color: #b3b3b3;
}

.wy-menu a:hover {
color: #b3b3b3;
}

a.icon.icon-home {
color: #D44D2C;
}

.version{
color: #D44D2C !important;
}

/* Default footer text is quite big */
footer {
font-size: 80%;
}

footer .rst-footer-buttons {
font-size: 125%; /* revert footer settings - 1/80% = 125% */
}

footer p {
font-size: 100%;
}

/* For hidden headers that appear in TOC tree */
/* see https://stackoverflow.com/a/32363545/3343043 */
.rst-content .hidden-section {
display: none;
}

nav .hidden-section {
display: inherit;
}

/* Make code blocks have a background */
.codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] {
background: rgba(0, 0, 0, 0.0470588);
}
16 changes: 16 additions & 0 deletions docsrc/conf.py
@@ -18,6 +18,9 @@
import torch
import pytorch_sphinx_theme
import torch_tensorrt
from docutils.parsers.rst import Directive, directives
from docutils.statemachine import StringList
from docutils import nodes

# -- Project information -----------------------------------------------------

@@ -47,6 +50,7 @@
"sphinx.ext.coverage",
"sphinx.ext.mathjax",
"sphinx.ext.viewcode",
"sphinx_gallery.gen_gallery",
]

napoleon_use_ivar = True
@@ -78,6 +82,18 @@
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
# Custom CSS paths should either be relative to html_static_path
# or fully qualified paths (e.g. https://...)
html_css_files = [
"https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css",
"css/custom.css",
]

# sphinx-gallery configuration
sphinx_gallery_conf = {
"examples_dirs": "../examples",
"gallery_dirs": "tutorials/_rendered_examples/",
}

# Setup the breathe extension
breathe_projects = {"Torch-TensorRT": "./_tmp/xml"}
31 changes: 22 additions & 9 deletions docsrc/index.rst
@@ -36,30 +36,43 @@ Getting Started
getting_started/getting_started_with_windows


Tutorials
User Guide
------------
* :ref:`creating_a_ts_mod`
* :ref:`getting_started_with_fx`
* :ref:`ptq`
* :ref:`runtime`
* :ref:`serving_torch_tensorrt_with_triton`
* :ref:`use_from_pytorch`
* :ref:`using_dla`

.. toctree::
:caption: User Guide
:maxdepth: 1
:hidden:

user_guide/creating_torchscript_module_in_python
user_guide/getting_started_with_fx_path
user_guide/ptq
user_guide/runtime
user_guide/use_from_pytorch
user_guide/using_dla

Tutorials
------------
* :ref:`torch_tensorrt_tutorials`
* :ref:`serving_torch_tensorrt_with_triton`
* :ref:`notebooks`

.. toctree::
:caption: Tutorials
:maxdepth: 1
:maxdepth: 3
:hidden:

tutorials/creating_torchscript_module_in_python
tutorials/getting_started_with_fx_path
tutorials/ptq
tutorials/runtime
tutorials/serving_torch_tensorrt_with_triton
tutorials/use_from_pytorch
tutorials/using_dla
tutorials/notebooks
tutorials/_rendered_examples/dynamo/torch_compile_resnet_example
tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage

Python API Documentation
------------------------
1 change: 1 addition & 0 deletions docsrc/requirements.txt
@@ -1,4 +1,5 @@
sphinx==4.5.0
sphinx-gallery==0.13.0
breathe==4.33.1
exhale==0.3.1
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
File renamed without changes.
File renamed without changes.
File renamed without changes.
7 changes: 7 additions & 0 deletions examples/README.rst
@@ -0,0 +1,7 @@
.. _torch_tensorrt_tutorials:

Torch-TensorRT Tutorials
===========================

The user guide covers the basic concepts and usage of Torch-TensorRT.
We also provide a number of tutorials that explore specific use cases and advanced concepts.
11 changes: 11 additions & 0 deletions examples/dynamo/README.rst
@@ -0,0 +1,11 @@
.. _torch_compile:

Dynamo / ``torch.compile``
----------------------------

Torch-TensorRT provides a backend for the new ``torch.compile`` API released in PyTorch 2.0. In the following examples, we describe
a number of ways you can leverage this backend to accelerate inference; a minimal invocation sketch follows the list below.

* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
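
The common entry point across these examples is a plain ``torch.compile`` call that selects the ``torch_tensorrt`` backend. A minimal sketch, using an illustrative model and input shapes rather than ones taken from the examples themselves::

    import torch
    import torch_tensorrt  # imported before selecting the "torch_tensorrt" backend, as in the examples

    # Illustrative model and inputs (shapes are placeholders)
    model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU()).eval().cuda()
    inputs = [torch.rand((4, 16)).cuda()]

    # Compile with the Torch-TensorRT backend; the first call triggers compilation
    optimized_model = torch.compile(model, backend="torch_tensorrt")
    optimized_model(*inputs)
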
103 changes: 103 additions & 0 deletions examples/dynamo/torch_compile_advanced_usage.py
@@ -0,0 +1,103 @@
"""
.. _torch_compile_advanced_usage:

Torch Compile Advanced Usage
======================================================

This interactive script is intended as an overview of the process by which ``torch_tensorrt.compile(..., ir="torch_compile", ...)`` works, and how it integrates with the ``torch.compile`` API."""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
import torch_tensorrt

# %%

# We begin by defining a model
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.relu = torch.nn.ReLU()

def forward(self, x: torch.Tensor, y: torch.Tensor):
x_out = self.relu(x)
y_out = self.relu(y)
x_y_out = x_out + y_out
return torch.mean(x_y_out)


# %%
# Compilation with `torch.compile` Using Default Settings
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Define sample float inputs and initialize model
sample_inputs = [torch.rand((5, 7)).cuda(), torch.rand((5, 7)).cuda()]
model = Model().eval().cuda()

# %%

# Next, we compile the model using torch.compile
# For the default settings, we can simply call torch.compile
# with the backend "torch_tensorrt", and run the model on an
# input to cause compilation, like so:
optimized_model = torch.compile(model, backend="torch_tensorrt")
optimized_model(*sample_inputs)

# %%
# Compilation with `torch.compile` Using Custom Settings
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# First, we use Torch utilities to clean up the workspace
# after the previous compile invocation
torch._dynamo.reset()

# Define sample half inputs and initialize model
sample_inputs_half = [
torch.rand((5, 7)).half().cuda(),
torch.rand((5, 7)).half().cuda(),
]
model_half = Model().eval().cuda()

# %%

# If we want to customize certain options in the backend,
# but still use the torch.compile call directly, we can provide
# custom options to the backend via the "options" keyword
# which takes in a dictionary mapping options to values.
#
# For accepted backend options, see the CompilationSettings dataclass:
# py/torch_tensorrt/dynamo/_settings.py
backend_kwargs = {
"enabled_precisions": {torch.half},
"debug": True,
"min_block_size": 2,
"torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
"optimization_level": 4,
"use_python_runtime": False,
}

# Run the model on an input to cause compilation, like so:
optimized_model_custom = torch.compile(
model_half, backend="torch_tensorrt", options=backend_kwargs
)
optimized_model_custom(*sample_inputs_half)

# %%
# Cleanup
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Finally, we use Torch utilities to clean up the workspace
torch._dynamo.reset()

# %%
# CUDA Driver Error Note
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Occasionally, upon exiting the Python runtime after Dynamo compilation with `torch_tensorrt`,
# one may encounter a CUDA Driver Error. This issue is related to https://github.com/NVIDIA/TensorRT/issues/2052
# and can be resolved by wrapping the compilation/inference in a function and using a scoped call, as in::
#
# if __name__ == '__main__':
# compile_engine_and_infer()
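#
# A fuller sketch of that pattern, assuming the ``Model`` class and the float inputs
# defined earlier in this script (the wrapper name ``compile_engine_and_infer`` is illustrative)::
#
#     def compile_engine_and_infer():
#         model = Model().eval().cuda()
#         sample_inputs = [torch.rand((5, 7)).cuda(), torch.rand((5, 7)).cuda()]
#         optimized_model = torch.compile(model, backend="torch_tensorrt")
#         optimized_model(*sample_inputs)
#
#     if __name__ == '__main__':
#         compile_engine_and_infer()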