Skip to content

Commit

Permalink
Merge pull request #1104 from pytorch/fb-sync-2-wwei6
Browse files: browse the repository at this point in the history
Refactor the internal codebase from fx2trt_oss to torch_tensorrt
  • Loading branch information
Wei authored Jun 8, 2022
2 parents 916c3de + 7618ac5 commit fa33de2
Show file tree
Hide file tree
Showing 118 changed files with 245 additions and 203 deletions.
2 changes: 1 addition & 1 deletion docs/_modules/torch_tensorrt/_compile.html
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ <h1>Source code for torch_tensorrt._compile</h1><div class="highlight"><pre>
<span class="c1"># profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile</span>
<span class="p">)</span>
<span class="c1"># For profile</span>
<span class="c1"># from fx2trt_oss.fx.tools.trt_profiler_sorted import profile_trt_module</span>
<span class="c1"># from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module</span>
<span class="c1"># profile_trt_module(&quot;&quot;, trt_mod, acc_inputs)</span>
<span class="n">trt_mod</span> <span class="o">=</span> <span class="n">TRTModule</span><span class="p">(</span><span class="o">*</span><span class="n">r</span><span class="p">)</span>

Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def get_input(self, inputs):
# profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile
)
# For profile
# from fx2trt_oss.fx.tools.trt_profiler_sorted import profile_trt_module
# from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module
# profile_trt_module("", trt_mod, acc_inputs)
trt_mod = TRTModule(*r)

Expand Down
11 changes: 5 additions & 6 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,21 @@
import warnings
from typing import cast, Dict, Optional, Sequence, Tuple, Union

from ..tracer.acc_tracer import acc_ops
import numpy as np

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from ..tracer.acc_tracer import acc_ops
from ..types import * # noqa: F403
from ..utils import (
get_dynamic_dims,
torch_dtype_from_trt,
torch_dtype_to_trt,
)
from torch.fx.immutable_collections import immutable_list
from torch.fx.node import Argument, Target

from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt

from .converter_utils import * # noqa: F403


Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import mark_as_int8_layer
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/adaptive_avgpool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import extend_mod_attr_to_tuple, mark_as_int8_layer
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import get_dyn_range, mark_as_int8_layer
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import get_dyn_range, mark_as_int8_layer, to_numpy
Expand Down
3 changes: 2 additions & 1 deletion py/torch_tensorrt/fx/converters/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
from torch.fx.node import Argument, Target

from ..types import (
Shape,
TRTDataType,
Expand All @@ -18,7 +20,6 @@
TRTTensor,
)
from ..utils import torch_dtype_from_trt
from torch.fx.node import Argument, Target


def get_trt_plugin(
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import (
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/linear.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import get_dyn_range, mark_as_int8_layer, to_numpy
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/maxpool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import extend_mod_attr_to_tuple, mark_as_int8_layer
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import get_dyn_range, mark_as_int8_layer
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/quantization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import get_dyn_range, get_inputs_from_args_and_kwargs
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/transformation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import mark_as_int8_layer
Expand Down
12 changes: 6 additions & 6 deletions py/torch_tensorrt/fx/example/fx2trt_example.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# type: ignore[]

import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer
import torch
import torch.fx
import torch.nn as nn
from fx2trt_oss.fx import InputTensorSpec, TRTInterpreter, TRTModule
from fx2trt_oss.fx.tools.trt_splitter import TRTSplitter
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter


# The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
Expand Down Expand Up @@ -83,12 +83,12 @@ def forward(self, x):
%x : [#users=1] = placeholder[target=x]
%linear_weight : [#users=1] = get_attr[target=linear.weight]
%linear_bias : [#users=1] = get_attr[target=linear.bias]
%linear_1 : [#users=1] = call_function[target=fx2trt_oss.tracer.acc_tracer.acc_ops.linear](args = (), ...
%relu_1 : [#users=1] = call_function[target=fx2trt_oss.tracer.acc_tracer.acc_ops.relu](args = (), ...
%linear_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linear](args = (), ...
%relu_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.relu](args = (), ...
return relu_1
graph():
%relu_1 : [#users=1] = placeholder[target=relu_1]
%linalg_norm_1 : [#users=1] = call_function[target=fx2trt_oss.tracer.acc_tracer.acc_ops.linalg_norm](args = (), ...
%linalg_norm_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linalg_norm](args = (), ...
return linalg_norm_1
"""

Expand Down
4 changes: 2 additions & 2 deletions py/torch_tensorrt/fx/example/lower_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import torch
import torchvision
from fx2trt_oss.fx.lower import lower_to_trt
from fx2trt_oss.fx.utils import LowerPrecision
from torch_tensorrt.fx.lower import lower_to_trt
from torch_tensorrt.fx.utils import LowerPrecision


"""
Expand Down
8 changes: 4 additions & 4 deletions py/torch_tensorrt/fx/example/quantized_resnet_test.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import copy

import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch.fx

import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
import torchvision.models as models
from fx2trt_oss.fx import InputTensorSpec, TRTInterpreter, TRTModule
from fx2trt_oss.fx.utils import LowerPrecision
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
from torch.fx.experimental.normalize import NormalizeArgs
from torch.fx.passes import shape_prop
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
from torch_tensorrt.fx.utils import LowerPrecision

rn18 = models.resnet18().eval()

Expand Down
38 changes: 23 additions & 15 deletions py/torch_tensorrt/fx/example/test_fx2trt.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,35 @@
import torch_tensorrt
import torch
import torch_tensorrt


class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(5,3)
self.linear = torch.nn.Linear(5, 3)
self.relu = torch.nn.functional.relu
def forward(self,x):

def forward(self, x):
x = self.linear(x)
x = self.relu(x)
return x


model = MyModel().eval() # torch module needs to be in eval (not training) mode
model = MyModel().eval() # torch module needs to be in eval (not training) mode

# torch tensorrt
inputs = [torch_tensorrt.Input(
(2,5),
dtype=torch.half,
)]
enabled_precisions = {torch.float, torch.half} # Run with fp16

trt_ts_module = torch_tensorrt.compile(model, inputs=inputs, enabled_precisions=enabled_precisions)

inputs_ts = [torch.ones(2,5)]
inputs = [
torch_tensorrt.Input(
(2, 5),
dtype=torch.half,
)
]
enabled_precisions = {torch.float, torch.half} # Run with fp16

trt_ts_module = torch_tensorrt.compile(
model, inputs=inputs, enabled_precisions=enabled_precisions
)

inputs_ts = [torch.ones(2, 5)]
inputs_ts = [i.cuda().half() for i in inputs_ts]
result = trt_ts_module(*inputs_ts)
print(result)
Expand All @@ -33,12 +39,14 @@ def forward(self,x):
print(ref)

# fx2trt
inputs_fx = [torch.ones((2,5))]
inputs_fx = [torch.ones((2, 5))]

model.cuda().half()
inputs_fx = [i.cuda().half() for i in inputs_fx]

trt_fx_module = torch_tensorrt.compile(model, ir="fx", inputs=inputs_fx, enabled_precisions={torch.half})
trt_fx_module = torch_tensorrt.compile(
model, ir="fx", inputs=inputs_fx, enabled_precisions={torch.half}
)
result = trt_fx_module(*inputs_fx)
print(result)

Expand Down
4 changes: 2 additions & 2 deletions py/torch_tensorrt/fx/example/torchdynamo_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import torch
import torchdynamo
import torchvision
from fx2trt_oss.fx.lower import lower_to_trt
from fx2trt_oss.fx.utils import LowerPrecision
from torch_tensorrt.fx.lower import lower_to_trt
from torch_tensorrt.fx.utils import LowerPrecision
from torchdynamo.optimizations import backends

"""
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/fx2trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
import tensorrt as trt
import torch
import torch.fx
from .observer import Observer
from torch.fx.node import _get_qualified_name
from torch.fx.passes.shape_prop import TensorMetadata

from .converter_registry import CONVERTERS
from .input_tensor_spec import InputTensorSpec
from .observer import Observer
from .utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt


Expand Down
16 changes: 10 additions & 6 deletions py/torch_tensorrt/fx/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,27 @@
import logging
from typing import Any, Callable, Sequence

from .tracer.acc_tracer import acc_tracer

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
import torch.fx as fx
import torch.nn as nn
from .lower_setting import LowerSetting
from .passes.pass_utils import decorate_method, validate_inference
from .passes.splitter_base import SplitResult
from torch.fx.passes.splitter_base import SplitResult

from .fx2trt import TRTInterpreter, TRTInterpreterResult
from .input_tensor_spec import InputTensorSpec
from .lower_setting import LowerSetting
from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
from .passes.pass_utils import chain_passes, PassFunc
from .passes.pass_utils import (
chain_passes,
decorate_method,
PassFunc,
validate_inference,
)
from .tools.timing_cache_utils import TimingCacheManager
from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting

from .tracer.acc_tracer import acc_tracer
from .trt_module import TRTModule
from .utils import LowerPrecision

Expand Down
10 changes: 4 additions & 6 deletions py/torch_tensorrt/fx/lower_setting.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import dataclasses as dc
from typing import List, Optional, Sequence, Set, Type

from .input_tensor_spec import InputTensorSpec
from .passes.lower_basic_pass import (
fuse_permute_linear,
fuse_permute_matmul,
)
from .utils import LowerPrecision
from torch import nn
from torch.fx.passes.pass_manager import PassManager

from .input_tensor_spec import InputTensorSpec
from .passes.lower_basic_pass import fuse_permute_linear, fuse_permute_matmul
from .utils import LowerPrecision


@dc.dataclass
class LowerSetting:
Expand Down
18 changes: 10 additions & 8 deletions py/torch_tensorrt/fx/passes/lower_basic_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import warnings
from typing import Any

from ..tracer.acc_tracer import acc_ops
import torch
import torch.fx
from torch.fx.experimental.const_fold import split_const_subgraphs

from ..observer import observable
from .pass_utils import log_before_after, validate_inference

from ..tracer.acc_tracer import acc_ops
from ..tracer.acc_tracer.acc_utils import get_attr
from torch.fx.experimental.const_fold import split_const_subgraphs
from .pass_utils import log_before_after, validate_inference

# Create an alias for module input type to avoid littering pyre-ignore for Any
# throughout the file.
Expand Down Expand Up @@ -46,15 +48,15 @@ def fuse_sparse_matmul_add(gm: torch.fx.GraphModule, input: Input):
def forward(self, x):
a = self.a
b = self.b
addmm_mm = fx2trt_oss.tracer.acc_tracer.acc_ops.matmul(input = a, other = b); a = b = None
addmm_add = fx2trt_oss.tracer.acc_tracer.acc_ops.add(input = addmm_mm, other = x); addmm_mm = x = None
addmm_mm = torch_tensorrt.fx.tracer.acc_tracer.acc_ops.matmul(input = a, other = b); a = b = None
addmm_add = torch_tensorrt.fx.tracer.acc_tracer.acc_ops.add(input = addmm_mm, other = x); addmm_mm = x = None
return addmm_add
After:
def forward(self, x):
a = self.a
b = self.b
linear_1 = fx2trt_oss.tracer.acc_tracer.acc_ops.linear(input = a, weight = b, bias = x); a = b = x = None
linear_1 = torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linear(input = a, weight = b, bias = x); a = b = x = None
return linear_1
"""
counter = 0
Expand Down Expand Up @@ -198,8 +200,8 @@ def fuse_permute_matmul(gm: torch.fx.GraphModule, input: Input):
try:
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
from fx2trt_oss.fx.converter_registry import tensorrt_converter
from fx2trt_oss.fx.converters.converter_utils import (
from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters.converter_utils import (
add_binary_elementwise_layer,
broadcast,
get_trt_tensor,
Expand Down
Loading

0 comments on commit fa33de2

Please sign in to comment.