Changes done internally at Facebook (#1390)
7ae5990ba20126da1e0a93ad0887cb1892ff48cd Janet Yang <qxy11@fb.com> Pass to remove for _validate_and_get_n_vectors
d269be2fc7d84738a642d1d53eb44e6886a28d0c Alex Beloi <alexbeloi@fb.com> [fx] add deferred weights (xl_weight) and tracing for xl_embedding_bag
6f233bc9c72d90a908db0548c9d2dbe853895137 Alex Beloi <alexbeloi@fb.com> [fx] fix out of bounds indices/offsets for embedding_bag ops with xl_weight
3ca3b21c6a85ab9a6e9de503d0f13ee713a7b67c Janet Yang <qxy11@fb.com> Support div, torch.norm
52955d93d25e857510ed1b765220e8e5b0b0bb08 Janet Yang <qxy11@fb.com> Pass to replace sum(elmtwise(X))/numel(X) w/ mean(elmtwise(X))
89c56ef76a7a329f244a013ac5ccb099cb00c3c0 Janet Yang <qxy11@fb.com> Support scalar clamp, fixes for nan_to_num and benchmark
48071d8da1dc66fffceb0b42ea386079f1fb9709 Wei Wei <wwei6@fb.com> [ads] bug fix in push_down_parrallel_split_ops
afdc533da031a64e162bb08c8629ff38739e24f8 Wei Wei <wwei6@fb.com> [fx2trt] disable dispatch trace leaf node test
9905612fd8e6e2e79dc2f2bd1fa5b5d7fd5c98c3 Shirong Wu <shirong@fb.com> Add number constrain for fuse group ln
d160a7a5e554d37c142e13f100bf4d8739ced232 Wei Wei <wwei6@fb.com> add option to remove passes
c22f691e6eae1b06ecd301eb6285b32d5dc9717c Mike Iovine <mikeiovine@fb.com> [fx2trt] Support dict inputs in acc tracer
8c05a3c57b1f5c63108b979ef8c61411525d0b1f Mike Iovine <mikeiovine@fb.com> [fx2trt] Support namedtuple access in acc tracer getattr
ff2000594e3f3ff75e0074edf9c38b5609128bbd Janet Yang <qxy11@fb.com> Generalize remove split ops more
1580805d827eb40c941e769b0b99e7c6a3ed6f89 Wei Wei <wwei6@fb.com> [fx2trt] add reshape unit test
d6a975462071a3747d18edcbe87a3b143b3ece88 Archie Sravankumar <archishmans@fb.com> Added FX tracing for `log_softmax`
6943ac0e322077b36a03c50c4c9065de6cd32837 Sungmin Cho <sungmincho@fb.com> Add replace_mutable_op lower pass
baab27b81b1275de92fdaf760a158ce951564d33 Donglin Xia <doxia@fb.com> Register avg_pool3d for acc_op in acc_op.py
ae4c4e2c3c18d78542140fcc30e1c24f7c647ef3 Wei Wei <wwei6@fb.com> [aten2trt] init check-in
fc94c5e110d5552349b2634662eae41f9f0b8933 Wei Wei <wwei6@fb.com> [ads] fix a bug in fuse_parallel_linear
87ef03338c9a25c5a610a2eb590345e8935f8d75 Wei Wei <wwei6@fb.com> [aten2trt] add binary ops
fca64a5b09749284fc6028b510078257fd4717b1 Shirong Wu <shirong@meta.com> Fix dper pass
2bb168517ace7e638cffc7a241b1cbf528790b92 Mike Iovine <mikeiovine@fb.com> [fx2trt] Add acc normalization blocklist
8c912e085cf8722d572698286020ae1ce055023d Zhijing Li (Accelerator Enablement) <tissue030@fb.com> Skip unstable test_conv_add_standalone_module
137a3977ffeb03d0387e8a95ff2f32f3d15b3de8 Wei Wei <wwei6@meta.com> [aten2trt] resnet support
f06174dbb190df4ea488ca99a81d4884b5ed3aa2 wwei6 <wwei6@fb.com> [fx2trt] compile
817c1f0b6278ce0ad04dd88d43d21e7390e3baea wwei6 <wwei6@fb.com> [aten2trt] init check-in
92ce42c16f34804584a7e553eddf897c9fa4f65e wwei6 <wwei6@fb.com> [aten2trt] binary op
Wei authored Oct 5, 2022
1 parent 75fdbf0 commit 0daf301
Showing 34 changed files with 2,640 additions and 215 deletions.
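Several of the commits above are FX graph rewrites. For instance, the `sum(elmtwise(X))/numel(X)` pass (52955d9) rests on a simple algebraic identity; here is a minimal sketch of the equivalence it exploits, using `torch.relu` as a stand-in for any elementwise op:

```python
# Identity behind the sum(elmtwise(X))/numel(X) -> mean(elmtwise(X))
# rewrite; torch.relu is just an illustrative elementwise op.
import torch

x = torch.randn(4, 8)
before = torch.relu(x).sum() / x.numel()  # pattern the pass matches
after = torch.relu(x).mean()              # fused form it emits
assert torch.allclose(before, after)
```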
17 changes: 17 additions & 0 deletions py/torch_tensorrt/fx/README.md
@@ -2,3 +2,20 @@ FX2TRT is merged as FX module in Torch-TensorRT

- The user guide is in [link](../../../docsrc/tutorials/getting_started_with_fx_path.rst#installation)
- The examples are moved to [link](../../../examples/fx)

* Method 1. Follow the instructions for Torch-TensorRT.
* Method 2. To install the FX path only (Python path) and avoid the C++ build for the TorchScript path:
```
$ conda create --name python_env python=3.8
$ conda activate python_env
# We recommend PyTorch 1.12 or later
$ conda install pytorch torchvision torchtext cudatoolkit=11.3 -c pytorch-nightly
# Install the TensorRT Python package
$ pip3 install nvidia-pyindex
$ pip3 install nvidia-tensorrt==8.2.4.2
$ git clone https://github.com/pytorch/TensorRT.git
$ cd TensorRT/py && python setup.py install --fx-only && cd ..
$ python -c "import torch_tensorrt.fx"
# Run an example:
$ python py/torch_tensorrt/fx/example/lower_example.py
```
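Once the FX-only install succeeds, lowering a module looks roughly like the following. This is a sketch, assuming the `compile` helper that the "[fx2trt] compile" commit above adds under `torch_tensorrt.fx`; the exact name and signature may differ:

```python
# Hedged sketch of FX-path lowering; `torch_tensorrt.fx.compile` is
# assumed from the "[fx2trt] compile" commit and may differ in practice.
import torch
import torch_tensorrt.fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0

model = Toy().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]
trt_mod = torch_tensorrt.fx.compile(model, inputs)  # assumed entry point
print(trt_mod(*inputs).shape)
```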
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/__init__.py
@@ -13,6 +13,7 @@
from .transformation import * # noqa: F401 F403
from .quantization import * # noqa: F401 F403
from .acc_ops_converters import * # noqa: F401 F403
from .aten_ops_converters import * # noqa: F401 F403

TRT_LOGGER = trt.Logger()
trt.init_libnvinfer_plugins(TRT_LOGGER, "")
104 changes: 87 additions & 17 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -21,11 +21,63 @@
from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt

from .converter_utils import *  # noqa: F403

from torch_tensorrt.fx.passes.lower_basic_pass import (
    trt_transposed_linear,
    trt_transposed_matmul,
)

_LOGGER: logging.Logger = logging.getLogger(__name__)


@tensorrt_converter(trt_transposed_matmul)
def trt_transposed_matmul_converter(network, target, args, kwargs, name):
    lhs, rhs, lhs_transposed, rhs_transposed = args

    if isinstance(lhs, torch.nn.Parameter):
        lhs = get_trt_tensor(network, lhs, f"{name}_lhs")
    if isinstance(rhs, torch.nn.Parameter):
        rhs = get_trt_tensor(network, rhs, f"{name}_rhs")
    layer = network.add_matrix_multiply(
        lhs,
        trt.MatrixOperation.TRANSPOSE if lhs_transposed else trt.MatrixOperation.NONE,
        rhs,
        trt.MatrixOperation.TRANSPOSE if rhs_transposed else trt.MatrixOperation.NONE,
    )
    set_layer_name(layer, target, name)
    return layer.get_output(0)
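This converter maps the `trt_transposed_matmul` target onto a single TensorRT matrix-multiply layer, expressing the transposes as `trt.MatrixOperation` flags rather than materializing transposed tensors. In eager PyTorch the target computes roughly the following (an illustrative reconstruction, not the committed source):

```python
# Rough eager-mode reading of the trt_transposed_matmul target op.
import torch

def trt_transposed_matmul_ref(lhs, rhs, lhs_transposed, rhs_transposed):
    if lhs_transposed:
        lhs = lhs.transpose(-1, -2)
    if rhs_transposed:
        rhs = rhs.transpose(-1, -2)
    return torch.matmul(lhs, rhs)
```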


@tensorrt_converter(trt_transposed_linear)
def trt_transposed_linear_converter(network, target, args, kwargs, name):
    input, weight, bias = args

    weight = get_trt_tensor(network, weight.t(), f"{name}_weight")
    bias = get_trt_tensor(network, bias.reshape(1, -1), f"{name}_bias")

    input, weight = broadcast(
        network,
        input,
        weight,
        f"{input.name}_broadcast",
        f"{weight.name}_broadcast",
    )
    layer = network.add_matrix_multiply(
        input,
        trt.MatrixOperation.TRANSPOSE,
        weight,
        trt.MatrixOperation.NONE,
    )
    set_layer_name(layer, target, f"{name}_mm")
    return add_binary_elementwise_layer(
        network,
        layer.get_output(0),
        bias,
        trt.ElementWiseOperation.SUM,
        target,
        f"{name}_add",
    )
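Likewise, `trt_transposed_linear` becomes one matmul (the input transposed via a flag, the weight pre-transposed on the host) plus an elementwise add for the bias. An eager-mode reading, again an illustrative reconstruction:

```python
# Rough eager-mode reading of the trt_transposed_linear target op:
# a linear layer applied to an input whose last two dims are swapped.
import torch

def trt_transposed_linear_ref(inp, weight, bias):
    return torch.matmul(inp.transpose(-1, -2), weight.t()) + bias.reshape(1, -1)
```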


@tensorrt_converter(acc_ops.conv1d)
def acc_ops_conv1d(
    network: TRTNetwork,

@@ -1975,7 +2027,10 @@ def acc_ops_max_poolnd(
            f"MaxPool2d received input {input_val} that is not part "
            "of the TensorRT region!"
        )
-    extend_len = 2 if target == acc_ops.max_pool2d else 3
+    if target not in (acc_ops.max_pool2d, acc_ops.max_pool3d):
+        extend_len = 2 if len(kwargs["kernel_size"]) == 2 else 3
+    else:
+        extend_len = 2 if target == acc_ops.max_pool2d else 3
    kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], extend_len)
    stride = extend_attr_to_tuple(kwargs["stride"], extend_len)
    padding = extend_attr_to_tuple(kwargs["padding"], extend_len)
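The new branch lets the same converter serve a generic pooling target by inferring the pooling rank from `kernel_size` instead of from the op identity; a tiny illustration with hypothetical kwargs:

```python
# Hypothetical kwargs illustrating the rank inference added above.
kwargs = {"kernel_size": (2, 2, 2)}  # 3-d pooling window
extend_len = 2 if len(kwargs["kernel_size"]) == 2 else 3
assert extend_len == 3
```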
@@ -2259,8 +2314,11 @@ def acc_ops_adaptive_avg_poolnd(
f"AdaptiveAvgPool2d received input {input_val} that is not part "
"of the TensorRT region!"
)
if target not in (acc_ops.adaptive_avg_pool3d, acc_ops.adaptive_avg_pool2d):
extend_len = 2 if len(kwargs["output_size"]) == 2 else 3
else:
extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3

extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3
assert all(
input_val.shape[-(i + 1)] != -1 for i in range(extend_len)
), "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."
@@ -2747,7 +2805,10 @@ def acc_ops_linear(

    if isinstance(kwargs["weight"], torch.Tensor):
        weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight")
-        weight_op = trt.MatrixOperation.NONE
+        if target is not acc_ops.linear:
+            weight_op = trt.MatrixOperation.TRANSPOSE
+        else:
+            weight_op = trt.MatrixOperation.NONE
    else:
        assert isinstance(
            kwargs["weight"], TRTTensor

@@ -2782,17 +2843,26 @@ def acc_ops_linear(
    return res


-def add_clamp(network, input, val, op):
-    acc_ops_clamp_shape = (1,) * len(input.shape)  # broadcast all dimensions
-    acc_ops_clamp_tensor = (
-        val
-        * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype))
-        .cpu()
-        .numpy()
-    )
-    acc_ops_clamp_trt = network.add_constant(acc_ops_clamp_shape, acc_ops_clamp_tensor)
-    layer = network.add_elementwise(input, acc_ops_clamp_trt.get_output(0), op)
+def add_clamp(network, input, val, op, name):
+    if not len(input.shape):
+        # clamping scalar
+        acc_ops_clamp_trt = get_trt_tensor(
+            network,
+            squeeze_left(torch.tensor([val], dtype=torch_dtype_from_trt(input.dtype))),
+            f"{name}_clamp_{val}",
+        )
+    else:
+        acc_ops_clamp_shape = (1,) * len(input.shape)  # broadcast all dimensions
+        acc_ops_clamp_tensor = (
+            val
+            * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype))
+            .cpu()
+            .numpy()
+        )
+        acc_ops_clamp_trt = network.add_constant(
+            acc_ops_clamp_shape, acc_ops_clamp_tensor
+        ).get_output(0)
+    layer = network.add_elementwise(input, acc_ops_clamp_trt, op)
    return layer
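The new scalar branch keeps the clamp constant rank-compatible with a 0-d input (per the "Support scalar clamp" commit above). `squeeze_left` is assumed to come from `converter_utils`; the sketch below reconstructs its likely behavior for illustration:

```python
# Illustrative reconstruction of squeeze_left: strip leading size-1
# dims so torch.tensor([val]) becomes a 0-d scalar constant.
import torch

def squeeze_left_ref(t: torch.Tensor) -> torch.Tensor:
    while t.dim() > 0 and t.shape[0] == 1:
        t = t.squeeze(0)
    return t

assert squeeze_left_ref(torch.tensor([1.5])).shape == torch.Size([])
```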


@@ -2816,13 +2886,13 @@ def acc_ops_clamp(

    if min_val is not None:
        clamp_min_layer = add_clamp(
-            network, input_val, min_val, trt.ElementWiseOperation.MAX
+            network, input_val, min_val, trt.ElementWiseOperation.MAX, name
        )
        set_layer_name(clamp_min_layer, target, f"{name}_clamp_min")
        input_val = clamp_min_layer.get_output(0)
    if max_val is not None:
        clamp_max_layer = add_clamp(
-            network, input_val, max_val, trt.ElementWiseOperation.MIN
+            network, input_val, max_val, trt.ElementWiseOperation.MIN, name
        )
        set_layer_name(clamp_max_layer, target, f"{name}_clamp_max")
        input_val = clamp_max_layer.get_output(0)