From 834064e97180ab5d233858879a7e59a650c0389d Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 2 Jun 2023 09:15:50 -0700 Subject: [PATCH] fix: Centralize FX conv impl, add feature - Centralize convolution implementation in FX, similar across all source IRs, including aten, acc, nn - Enable pass-through of build errors in e2e tests to ensure errors are not being hidden - Allow conv layers to take bias inputs in FX, per new functionality from TRT - Remove separate `convolution.py` file and centralize `nn` converters to a single file --- .../dynamo/test/test_dynamo_backend.py | 5 + py/torch_tensorrt/fx/converters/__init__.py | 2 +- .../fx/converters/acc_ops_converters.py | 167 +++----------- .../fx/converters/aten_ops_converters.py | 33 ++- .../fx/converters/converter_utils.py | 7 +- .../fx/converters/convolution.py | 212 ------------------ .../fx/converters/impl/convolution.py | 145 ++++++++++++ .../fx/converters/nn_ops_converters.py | 118 +++++++++- 8 files changed, 327 insertions(+), 362 deletions(-) delete mode 100644 py/torch_tensorrt/fx/converters/convolution.py create mode 100644 py/torch_tensorrt/fx/converters/impl/convolution.py diff --git a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py index 3c2ec01419..462fe04e70 100644 --- a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py +++ b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py @@ -30,6 +30,7 @@ def test_resnet18(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "ir": ir, + "pass_through_build_failures": True, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -60,6 +61,7 @@ def test_mobilenet_v2(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "ir": ir, + "pass_through_build_failures": True, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -90,6 +92,7 @@ def test_efficientnet_b0(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, "ir": ir, + "pass_through_build_failures": True, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -129,6 +132,7 @@ def test_bert_base_uncased(ir): "enabled_precisions": {torch.float}, "truncate_long_and_double": True, "ir": ir, + "pass_through_build_failures": True, } trt_mod = torchtrt.compile(model, **compile_spec) @@ -163,6 +167,7 @@ def test_resnet18_half(ir): "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.half}, "ir": ir, + "pass_through_build_failures": True, } trt_mod = torchtrt.compile(model, **compile_spec) diff --git a/py/torch_tensorrt/fx/converters/__init__.py b/py/torch_tensorrt/fx/converters/__init__.py index b6f5a35be8..f037d54ce7 100644 --- a/py/torch_tensorrt/fx/converters/__init__.py +++ b/py/torch_tensorrt/fx/converters/__init__.py @@ -5,7 +5,6 @@ from .adaptive_avgpool import * # noqa: F401 F403 from .add import * # noqa: F401 F403 from .batchnorm import * # noqa: F401 F403 - from .convolution import * # noqa: F401 F403 from .linear import * # noqa: F401 F403 from .maxpool import * # noqa: F401 F403 from .mul import * # noqa: F401 F403 @@ -13,6 +12,7 @@ from .quantization import * # noqa: F401 F403 from .acc_ops_converters import * # noqa: F401 F403 from .aten_ops_converters import * # noqa: F401 F403 + from .nn_ops_converters import * # noqa: F401 F403 TRT_LOGGER = trt.Logger() trt.init_libnvinfer_plugins(TRT_LOGGER, "") diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 57c720ffba..e7dcf49e90 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -26,7 +26,7 @@ trt_transposed_matmul, ) from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous -from torch_tensorrt.fx.converters.impl import activation +from torch_tensorrt.fx.converters.impl import activation, convolution _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -96,86 +96,20 @@ def acc_ops_conv1d( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"Conv received input {input_val} that is not part " - "of the TensorRT region!" - ) - - # Process 1d input with unsqueeze -> conv2d -> squeeze to calculated conv1d - unsqueeze_layer = network.add_shuffle(input=input_val) - unsqueeze_layer.reshape_dims = tuple([*input_val.shape, 1]) - set_layer_name(unsqueeze_layer, target, name + "_unsqueeze") - input_val = unsqueeze_layer.get_output(0) - - if has_dynamic_shape(input_val.shape): - assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution." - - # for now we'll assume bias is constant Tensor or None, - # and bias being ITensor is not supported in TensorRT api - # right now - if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor): - raise RuntimeError( - f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]" - ) - bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type] - if bias is not None: - bias = bias[None] - weight = kwargs["weight"] - - if network.has_explicit_precision or isinstance(weight, TRTTensor): - weight = get_trt_tensor(network, weight, f"{name}_weight") - # Expand 1d weight with unsqueeze for calculation - unsqueeze_weight_layer = network.add_shuffle(input=weight) - unsqueeze_weight_layer.reshape_dims = tuple([*weight.shape, 1]) - set_layer_name(unsqueeze_layer, target, name + "_unsqueeze_weight") - weight = unsqueeze_weight_layer.get_output(0) - weight_shape = tuple(kwargs["weight"].shape) # type: ignore[union-attr] - # will need to use uninitialized weight and set it later to support - # ITensor weights - dummy_weight = trt.Weights() - layer = network.add_convolution_nd( - input=input_val, - num_output_maps=weight.shape[0], - kernel_shape=weight.shape[2:], - kernel=dummy_weight, - bias=bias, - ) - - layer.set_input(1, weight) - else: - if not isinstance(kwargs["weight"], torch.Tensor): - raise RuntimeError( - f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]" - ) - weight = to_numpy(weight) - weight = np.expand_dims(weight, -1) - layer = network.add_convolution_nd( - input=input_val, - num_output_maps=weight.shape[0], - kernel_shape=weight.shape[2:], - kernel=weight, - bias=bias, - ) - # expand params to 2d for computation - padding = list(kwargs["padding"]) - padding.append(0) - stride = extend_attr_to_tuple(kwargs["stride"], 2) - dilation = extend_attr_to_tuple(kwargs["dilation"], 2) - - set_layer_name(layer, target, name) - layer.stride_nd = stride - layer.padding_nd = padding - layer.dilation_nd = dilation - if kwargs["groups"] is not None: - layer.num_groups = kwargs["groups"] - - result = layer.get_output(0) - squeeze_layer = network.add_shuffle(input=result) - squeeze_layer.reshape_dims = tuple(result.shape[:-1]) - set_layer_name(squeeze_layer, target, name + "_squeeze") - return squeeze_layer.get_output(0) + return convolution.convNd( + network, + target, + source_ir=SourceIR.ACC, + name=name, + is_conv1d=True, + input_val=kwargs["input"], + weight=kwargs["weight"], + bias=kwargs["bias"], + stride=kwargs["stride"], + padding=kwargs["padding"], + dilation=kwargs["dilation"], + groups=kwargs["groups"], + ) @tensorrt_converter(acc_ops.conv3d) @@ -187,63 +121,20 @@ def acc_ops_convnd( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"Conv received input {input_val} that is not part " - "of the TensorRT region!" - ) - - if has_dynamic_shape(input_val.shape): - assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution." - - # for now we'll assume bias is constant Tensor or None, - # and bias being ITensor is not supported in TensorRT api - # right now - if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor): - raise RuntimeError( - f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]" - ) - bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type] - - if network.has_explicit_precision or isinstance(kwargs["weight"], TRTTensor): - weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight") - weight_shape = tuple(kwargs["weight"].shape) # type: ignore[union-attr] - # will need to use uninitialized weight and set it later to support - # ITensor weights - dummy_weight = trt.Weights() - layer = network.add_convolution_nd( - input=input_val, - num_output_maps=weight.shape[0], - kernel_shape=weight.shape[2:], - kernel=dummy_weight, - bias=bias, - ) - - layer.set_input(1, weight) - else: - if not isinstance(kwargs["weight"], torch.Tensor): - raise RuntimeError( - f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]" - ) - weight = to_numpy(kwargs["weight"]) - layer = network.add_convolution_nd( - input=input_val, - num_output_maps=weight.shape[0], - kernel_shape=weight.shape[2:], - kernel=weight, - bias=bias, - ) - - set_layer_name(layer, target, name) - layer.stride_nd = kwargs["stride"] - layer.padding_nd = kwargs["padding"] - layer.dilation_nd = kwargs["dilation"] - if kwargs["groups"] is not None: - layer.num_groups = kwargs["groups"] - - return layer.get_output(0) + return convolution.convNd( + network, + target, + source_ir=SourceIR.ACC, + name=name, + is_conv1d=False, + input_val=kwargs["input"], + weight=kwargs["weight"], + bias=kwargs["bias"], + stride=kwargs["stride"], + padding=kwargs["padding"], + dilation=kwargs["dilation"], + groups=kwargs["groups"], + ) @tensorrt_converter(acc_ops.conv_transpose2d) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 82847cc760..5a989a0ed7 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -22,7 +22,7 @@ from .converter_utils import * # noqa: F403 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils -from torch_tensorrt.fx.converters.impl import activation +from torch_tensorrt.fx.converters.impl import activation, convolution _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -129,13 +129,36 @@ def aten_ops_convolution( # we do not handle output_padding. if args[7] not in ([0], [0, 0], [0, 0, 0]): raise RuntimeError(f"Target {target} has non-0 output_padding") + if len(kwargs_new["stride"]) == 1: - return acc_ops_converters.acc_ops_conv1d( - network, target, None, kwargs_new, name + return convolution.convNd( + network, + target, + source_ir=SourceIR.ATEN, + name=name, + is_conv1d=True, + input_val=kwargs_new["input"], + weight=kwargs_new["weight"], + bias=kwargs_new["bias"], + stride=kwargs_new["stride"], + padding=kwargs_new["padding"], + dilation=kwargs_new["dilation"], + groups=kwargs_new["groups"], ) else: - return acc_ops_converters.acc_ops_convnd( - network, target, None, kwargs_new, name + return convolution.convNd( + network, + target, + source_ir=SourceIR.ATEN, + name=name, + is_conv1d=False, + input_val=kwargs_new["input"], + weight=kwargs_new["weight"], + bias=kwargs_new["bias"], + stride=kwargs_new["stride"], + padding=kwargs_new["padding"], + dilation=kwargs_new["dilation"], + groups=kwargs_new["groups"], ) diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index d13be41d05..a7015589a8 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -99,14 +99,17 @@ def get_positive_dim(dim: int, dim_size: int) -> int: def set_layer_name( - layer: TRTLayer, target: Target, name: str, source_ir: Optional[SourceIR] = None + layer: TRTLayer, + target: Union[Target, torch.nn.Module, str], + name: str, + source_ir: Optional[SourceIR] = None, ) -> None: """ Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]" Args: layer (TRTLayer): A TensorRT layer of which we want to set the name. - target (Target): A fx node.target. For call_function node, it's the function that + target (Target): A fx node.target or submodule. For call_function node, it's the function that the node represents. name (str): Consists of fx node.name with optional suffix. source_ir: (Optional[SourceIR]): The IR producing the op. diff --git a/py/torch_tensorrt/fx/converters/convolution.py b/py/torch_tensorrt/fx/converters/convolution.py deleted file mode 100644 index 6af940200a..0000000000 --- a/py/torch_tensorrt/fx/converters/convolution.py +++ /dev/null @@ -1,212 +0,0 @@ -# @manual=//deeplearning/trt/python:py_tensorrt -import logging - -import numpy as np -import tensorrt as trt -import torch - -from ..converter_registry import tensorrt_converter - -from .converter_utils import ( - extend_mod_attr_to_tuple, - get_dyn_range, - mark_as_int8_layer, - to_numpy, -) - -logger = logging.getLogger(__name__) - - -def common_conv(network, mod, dimension, input_val, layer_name, is_quantized): - if mod.padding_mode != "zeros": - raise RuntimeError(f"Only support padding mode: zeros, got {mod.padding_mode}.") - - kernel_size = extend_mod_attr_to_tuple(mod, "kernel_size", dimension) - stride = extend_mod_attr_to_tuple(mod, "stride", dimension) - padding = extend_mod_attr_to_tuple(mod, "padding", dimension) - dilation = extend_mod_attr_to_tuple(mod, "dilation", dimension) - - kernel = to_numpy(mod.weight() if is_quantized else mod.weight) - bias = to_numpy(mod.bias() if is_quantized else mod.bias) - - if dimension == 1: - # Append unsqueeze before conv2d to calculate conv1d - unsqueeze_layer = network.add_shuffle(input=input_val) - unsqueeze_layer.reshape_dims = (*input_val.shape, 1) - unsqueeze_layer.name = f"{layer_name}_unsqueeze" - input_val = unsqueeze_layer.get_output(0) - - kernel = np.expand_dims(kernel, -1) - kernel_size = kernel.shape[2:] - if bias is not None: - bias = bias[None] - stride = (stride[0], 1) - padding = (padding[0], 0) - dilation = (dilation[0], 1) - layer = network.add_convolution_nd( - input=input_val, - num_output_maps=mod.out_channels, - kernel_shape=kernel_size, - kernel=kernel, - bias=bias, - ) - layer.name = layer_name - layer.stride_nd = stride - layer.padding_nd = padding - layer.dilation_nd = dilation - layer.num_groups = mod.groups - - if is_quantized: - # Assume the dtype of activation is torch.quint8 - mark_as_int8_layer( - layer, get_dyn_range(mod.scale, mod.zero_point, torch.quint8) - ) - - result = layer.get_output(0) - if dimension == 1: - # Append squeeze after conv2d to calculate conv1d - squeeze_layer = network.add_shuffle(input=result) - squeeze_layer.reshape_dims = tuple(result.shape[:-1]) - squeeze_layer.name = f"{layer_name}_squeeze" - result = squeeze_layer.get_output(0) - - return result - - -def common_conv_relu(network, mod, dimension, input_val, layer_name, is_quantized): - conv_output = common_conv( - network, - mod, - dimension=2, - input_val=input_val, - layer_name=f"{layer_name}_conv", - is_quantized=is_quantized, - ) - - layer = network.add_activation(input=conv_output, type=trt.ActivationType.RELU) - layer.name = f"{layer_name}_relu" - - if is_quantized: - mark_as_int8_layer(layer, conv_output.dynamic_range) - - return layer.get_output(0) - - -@tensorrt_converter(torch.nn.modules.conv.Conv1d) -def conv1d(network, submod, args, kwargs, layer_name): - # args/kwargs should have already been normalized to kwargs - assert len(args) == 0 - input_val = kwargs["input"] - - if not isinstance(input_val, trt.tensorrt.ITensor): - raise RuntimeError( - f"Conv1d received input {input_val} that is not part " - "of the TensorRT region!" - ) - - if layer_name is None: - raise RuntimeError("layer name is none") - return common_conv( - network, - submod, - dimension=1, - input_val=input_val, - layer_name=layer_name, - is_quantized=False, - ) - - -@tensorrt_converter(torch.nn.modules.conv.Conv2d) -def conv2d(network, submod, args, kwargs, layer_name): - # args/kwargs should have already been normalized to kwargs - assert len(args) == 0 - input_val = kwargs["input"] - - if not isinstance(input_val, trt.tensorrt.ITensor): - raise RuntimeError( - f"Conv2d received input {input_val} that is not part " - "of the TensorRT region!" - ) - - return common_conv( - network, - submod, - dimension=2, - input_val=input_val, - layer_name=layer_name, - is_quantized=False, - ) - - -@tensorrt_converter(torch.nn.modules.conv.Conv3d) -def conv3d(network, submod, args, kwargs, layer_name): - # args/kwargs should have already been normalized to kwargs - assert len(args) == 0 - input_val = kwargs["input"] - # TODO: Remove this warning when https://github.com/pytorch/TensorRT/issues/1445 is fixed - kernel = to_numpy(submod.weight) - kernel_size_one = True - if len(kernel.shape) == 5: - for filter_size in kernel.shape[2:]: - if filter_size != 1: - kernel_size_one = False - if kernel_size_one: - logger.warn( - "Conv3d layer with kernel size = 1 configuration incurs a failure with TensorRT tactic optimizer in some cases. \ - Github issue: https://github.com/pytorch/TensorRT/issues/1445. Other conv variants do not have this issue." - ) - - if not isinstance(input_val, trt.tensorrt.ITensor): - raise RuntimeError( - f"Conv3d received input {input_val} that is not part " - "of the TensorRT region!" - ) - - return common_conv( - network, - submod, - dimension=3, - input_val=input_val, - layer_name=layer_name, - is_quantized=False, - ) - - -@tensorrt_converter(torch.nn.quantized.modules.conv.Conv2d) -def quantized_conv2d(network, submod, args, kwargs, layer_name): - input_val = args[0] - - if not isinstance(input_val, trt.tensorrt.ITensor): - raise RuntimeError( - f"Quantized Conv2d received input {input_val} that is not part " - "of the TensorRT region!" - ) - - return common_conv( - network, - submod, - dimension=2, - input_val=input_val, - layer_name=layer_name, - is_quantized=True, - ) - - -@tensorrt_converter(torch.nn.intrinsic.quantized.modules.ConvReLU2d) -def quantized_conv_relu2d(network, submod, args, kwargs, layer_name): - input_val = args[0] - - if not isinstance(input_val, trt.tensorrt.ITensor): - raise RuntimeError( - f"Quantized ConvReLU2d received input {input_val} that is not part " - "of the TensorRT region!" - ) - - return common_conv_relu( - network, - submod, - dimension=2, - input_val=input_val, - layer_name=f"{layer_name}_conv", - is_quantized=True, - ) diff --git a/py/torch_tensorrt/fx/converters/impl/convolution.py b/py/torch_tensorrt/fx/converters/impl/convolution.py new file mode 100644 index 0000000000..a0e7537fde --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/convolution.py @@ -0,0 +1,145 @@ +import numpy as np +from typing import Any, Optional, Sequence, Union + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.converters.converter_utils import ( + SourceIR, + extend_attr_to_tuple, + get_dyn_range, + mark_as_int8_layer, + set_layer_name, + has_dynamic_shape, + to_numpy, + get_trt_tensor, +) +from torch_tensorrt.fx.converters import acc_ops_converters + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, +) + + +def convNd( + network: TRTNetwork, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + is_conv1d: bool, + input_val: TRTTensor, + weight: Union[TRTTensor, torch.Tensor], + bias: Optional[Union[TRTTensor, torch.Tensor]], + stride: Optional[Union[int, Sequence[int]]], + padding: Optional[Union[int, Sequence[int]]], + dilation: Optional[Union[int, Sequence[int]]], + groups: Optional[int], + scale: Optional[Union[torch.Tensor, float]] = None, + zero_point: Optional[Union[torch.Tensor, float]] = None, +) -> TRTTensor: + + if has_dynamic_shape(input_val.shape): + assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution." + + if is_conv1d: + # Apply an unsqueeze operation to transform the conv1d problem into conv2d + kwargs = { + "input": input_val, + "dim": -1, + } + input_val = acc_ops_converters.acc_ops_unsqueeze( + network, target, tuple(), kwargs, name + "_unsqueeze" + ) + + # Process bias terms + if isinstance(bias, torch.Tensor): + # Transform the bias constant into a Numpy array + bias = to_numpy(bias) + + elif isinstance(bias, TRTTensor): + bias = get_trt_tensor(network, bias, f"{name}_bias") + + elif bias is not None: + raise RuntimeError( + f"Convolution {name} has bias of type {type(bias)}, Expected Torch Tensor or TRT Tensor" + ) + + # Process weight terms + if network.has_explicit_precision or isinstance(weight, TRTTensor): + weight = get_trt_tensor(network, weight, f"{name}_weight") + # Append new dimension (unsqueeze) if the convolution is 1d + if is_conv1d: + kwargs = { + "input": weight, + "dim": -1, + } + weight = acc_ops_converters.acc_ops_unsqueeze( + network, target, tuple(), kwargs, name + "_unsqueeze_weight" + ) + + elif isinstance(weight, torch.Tensor): + # Transform the weight constant into a Numpy array + weight = to_numpy(weight) + + # Append new dimension (unsqueeze) if the convolution is 1d + if is_conv1d: + weight = np.expand_dims(weight, -1) + + else: + raise RuntimeError( + f"Convolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]" + ) + + conv_layer = network.add_convolution_nd( + input=input_val, + num_output_maps=weight.shape[0], + kernel_shape=weight.shape[2:], + kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight, + bias=trt.Weights() if isinstance(bias, TRTTensor) else bias, + ) + + # If the weight is a TRTTensor, set it as an input of the layer + if isinstance(weight, TRTTensor): + conv_layer.set_input(1, weight) + + # If the bias is a TRTTensor, set it as an input of the layer + if isinstance(bias, TRTTensor): + conv_layer.set_input(2, bias) + + # Expand parameters manually for Conv1D computations + if is_conv1d: + padding = tuple(padding) + (0,) + stride = extend_attr_to_tuple(stride, 2) + dilation = extend_attr_to_tuple(dilation, 2) + + set_layer_name(conv_layer, target, name, source_ir) + + # Set relevant attributes of convolution layer + conv_layer.padding_nd = padding + conv_layer.stride_nd = stride + conv_layer.dilation_nd = dilation + + if groups is not None: + conv_layer.num_groups = groups + + # Handle quantization cases + if scale is not None and zero_point is not None: + # Assume the dtype of activation is torch.quint8 + mark_as_int8_layer(conv_layer, get_dyn_range(scale, zero_point, torch.quint8)) + + result = conv_layer.get_output(0) + + if is_conv1d: + # Apply a squeeze operation to transform the conv2d problem back into conv1d + kwargs = { + "input": result, + "dim": -1, + } + result = acc_ops_converters.acc_ops_squeeze( + network, target, tuple(), kwargs, name + "_squeeze" + ) + + return result diff --git a/py/torch_tensorrt/fx/converters/nn_ops_converters.py b/py/torch_tensorrt/fx/converters/nn_ops_converters.py index 3be5e9ae98..2aacaa9a68 100644 --- a/py/torch_tensorrt/fx/converters/nn_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/nn_ops_converters.py @@ -1,13 +1,14 @@ -import numpy as np +import torch # @manual=//deeplearning/trt/python:py_tensorrt -import tensorrt as trt -import torch +import logging from torch_tensorrt.fx.converter_registry import tensorrt_converter -from torch_tensorrt.fx.converters.impl import activation +from torch_tensorrt.fx.converters.impl import activation, convolution from torch_tensorrt.fx.converters.converter_utils import SourceIR +logger = logging.getLogger(__name__) + @tensorrt_converter(torch.nn.functional.relu) @tensorrt_converter(torch.nn.modules.activation.ReLU) @@ -113,3 +114,112 @@ def selu(network, submod, args, kwargs, layer_name): input_val=kwargs["input"], alpha=kwargs["alpha"], ) + + +@tensorrt_converter(torch.nn.modules.conv.Conv1d) +def conv1d(network, submod, args, kwargs, layer_name): + # args/kwargs should have already been normalized to kwargs + assert len(args) == 0 + + if layer_name is None: + raise RuntimeError("layer name is none") + return convolution.convNd( + network, + submod._get_name(), + source_ir=SourceIR.NN, + name=layer_name, + is_conv1d=True, + input_val=kwargs["input"], + weight=submod.weight, + bias=submod.bias, + stride=getattr(submod, "stride"), + padding=getattr(submod, "padding"), + dilation=getattr(submod, "dilation"), + groups=submod.groups, + ) + + +@tensorrt_converter(torch.nn.modules.conv.Conv2d) +def conv2d(network, submod, args, kwargs, layer_name): + # args/kwargs should have already been normalized to kwargs + assert len(args) == 0 + return convolution.convNd( + network, + submod._get_name(), + source_ir=SourceIR.NN, + name=layer_name, + is_conv1d=False, + input_val=kwargs["input"], + weight=submod.weight, + bias=submod.bias, + stride=getattr(submod, "stride"), + padding=getattr(submod, "padding"), + dilation=getattr(submod, "dilation"), + groups=submod.groups, + ) + + +@tensorrt_converter(torch.nn.modules.conv.Conv3d) +def conv3d(network, submod, args, kwargs, layer_name): + # args/kwargs should have already been normalized to kwargs + assert len(args) == 0 + return convolution.convNd( + network, + submod._get_name(), + source_ir=SourceIR.NN, + name=layer_name, + is_conv1d=False, + input_val=kwargs["input"], + weight=submod.weight, + bias=submod.bias, + stride=getattr(submod, "stride"), + padding=getattr(submod, "padding"), + dilation=getattr(submod, "dilation"), + groups=submod.groups, + ) + + +@tensorrt_converter(torch.nn.quantized.modules.conv.Conv2d) +def quantized_conv2d(network, submod, args, kwargs, layer_name): + input_val = args[0] + return convolution.convNd( + network, + submod._get_name(), + source_ir=SourceIR.NN, + name=layer_name, + is_conv1d=False, + input_val=input_val, + weight=submod.weight(), + bias=submod.bias(), + stride=getattr(submod, "stride"), + padding=getattr(submod, "padding"), + dilation=getattr(submod, "dilation"), + groups=submod.groups, + scale=submod.scale, + zero_point=submod.zero_point, + ) + + +@tensorrt_converter(torch.nn.intrinsic.quantized.modules.ConvReLU2d) +def quantized_conv_relu2d(network, submod, args, kwargs, layer_name): + input_val = args[0] + conv_out = convolution.convNd( + network, + submod._get_name(), + source_ir=SourceIR.NN, + name=layer_name, + is_conv1d=False, + input_val=input_val, + weight=submod.weight(), + bias=submod.bias(), + stride=getattr(submod, "stride"), + padding=getattr(submod, "padding"), + dilation=getattr(submod, "dilation"), + groups=submod.groups, + scale=submod.scale, + zero_point=submod.zero_point, + ) + + return activation.relu( + network, submod._get_name(), SourceIR.NN, layer_name + "_relu", conv_out + )