fix: Centralize FX conv impl, add feature
Browse files Browse the repository at this point in the history
- Centralize the convolution implementation in FX so it is shared across all source IRs, including aten, acc, and nn
- Enable pass-through of build errors in e2e tests so failures are not silently hidden
- Allow conv layers to take bias inputs in FX, per new TensorRT functionality
gs-olive committed Jun 5, 2023
1 parent dd31c9a commit 5348ac2
Showing 6 changed files with 310 additions and 314 deletions.
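Review note: the centralized implementation referenced by these diffs lives in py/torch_tensorrt/fx/converters/impl/convolution.py (imported as `convolution` below; its diff apparently did not load on this page). As a reading aid, here is a minimal sketch of the new convNd entry point. The parameter names are taken from the keyword arguments at the call sites in this commit; the body is an assumption pieced together from the removed acc_ops code and the new TRTTensor-bias handling, not the file's actual contents.

# Sketch only -- not the actual contents of impl/convolution.py.
# Signature inferred from the call sites in this commit; body adapted from
# the removed acc_ops converters plus the new TRTTensor-bias path.
import tensorrt as trt

from torch_tensorrt.fx.converters.converter_utils import (
    SourceIR,
    get_trt_tensor,
    has_dynamic_shape,
    set_layer_name,
    to_numpy,
)
from torch_tensorrt.fx.types import TRTTensor


def convNd(network, target, source_ir, name, is_conv1d, input_val, weight,
           bias, stride, padding, dilation, groups):
    if not isinstance(input_val, TRTTensor):
        raise RuntimeError(
            f"Conv received input {input_val} that is not part of the TensorRT region!"
        )
    if has_dynamic_shape(input_val.shape):
        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."

    # The conv1d path (unsqueeze input/weight to 2D, convolve, squeeze back,
    # exactly as in the removed acc_ops_conv1d below) is elided here.
    assert not is_conv1d, "sketch covers only the 2d/3d path"

    # Register the kernel as a layer input so ITensor weights are supported.
    weight = get_trt_tensor(network, weight, f"{name}_weight")
    layer = network.add_convolution_nd(
        input=input_val,
        num_output_maps=weight.shape[0],
        kernel_shape=weight.shape[2:],
        kernel=trt.Weights(),  # placeholder, replaced via set_input below
        bias=trt.Weights() if isinstance(bias, TRTTensor) else to_numpy(bias),
    )
    layer.set_input(1, weight)

    # New in this commit: a TRTTensor bias is attached as the layer's third
    # input instead of being required to be a constant.
    if isinstance(bias, TRTTensor):
        layer.set_input(2, get_trt_tensor(network, bias, f"{name}_bias"))

    set_layer_name(layer, target, name, source_ir)
    layer.stride_nd = stride
    layer.padding_nd = padding
    layer.dilation_nd = dilation
    if groups is not None:
        layer.num_groups = groups

    return layer.get_output(0)

The call sites below pass every argument by keyword, which is what pins down the parameter names used in this sketch.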
5 changes: 5 additions & 0 deletions py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
@@ -27,6 +27,7 @@ def test_resnet18(ir):
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.float},
        "ir": ir,
+        "pass_through_build_failures": True,
    }

    trt_mod = torchtrt.compile(model, **compile_spec)
@@ -57,6 +58,7 @@ def test_mobilenet_v2(ir):
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.float},
        "ir": ir,
+        "pass_through_build_failures": True,
    }

    trt_mod = torchtrt.compile(model, **compile_spec)
@@ -87,6 +89,7 @@ def test_efficientnet_b0(ir):
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.float},
        "ir": ir,
+        "pass_through_build_failures": True,
    }

    trt_mod = torchtrt.compile(model, **compile_spec)
@@ -126,6 +129,7 @@ def test_bert_base_uncased(ir):
        "enabled_precisions": {torch.float},
        "truncate_long_and_double": True,
        "ir": ir,
+        "pass_through_build_failures": True,
    }
    trt_mod = torchtrt.compile(model, **compile_spec)

@@ -160,6 +164,7 @@ def test_resnet18_half(ir):
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.half},
        "ir": ir,
+        "pass_through_build_failures": True,
    }

    trt_mod = torchtrt.compile(model, **compile_spec)
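Review note: the flag added across these tests is the second commit bullet in action — with pass_through_build_failures set, a failed TensorRT engine build raises instead of being swallowed by fallback. A hedged usage sketch outside the test harness (the model choice, input shape, and the "inputs" key are illustrative, mirroring the test file):

import torch
import torchvision.models as models
import torch_tensorrt as torchtrt

model = models.resnet18(pretrained=True).eval().to("cuda")

compile_spec = {
    "inputs": [torch.randn((1, 3, 224, 224)).to("cuda")],
    "device": torchtrt.Device("cuda:0"),
    "enabled_precisions": {torch.float},
    "ir": "dynamo",
    # Raise engine-build errors instead of hiding them behind fallback.
    "pass_through_build_failures": True,
}

trt_mod = torchtrt.compile(model, **compile_spec)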
193 changes: 44 additions & 149 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -26,7 +26,7 @@
    trt_transposed_matmul,
)
from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
-from torch_tensorrt.fx.converters.impl import activation
+from torch_tensorrt.fx.converters.impl import activation, convolution

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -96,86 +96,20 @@ def acc_ops_conv1d(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Conv received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    # Process 1d input with unsqueeze -> conv2d -> squeeze to calculated conv1d
-    unsqueeze_layer = network.add_shuffle(input=input_val)
-    unsqueeze_layer.reshape_dims = tuple([*input_val.shape, 1])
-    set_layer_name(unsqueeze_layer, target, name + "_unsqueeze")
-    input_val = unsqueeze_layer.get_output(0)
-
-    if has_dynamic_shape(input_val.shape):
-        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."
-
-    # for now we'll assume bias is constant Tensor or None,
-    # and bias being ITensor is not supported in TensorRT api
-    # right now
-    if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
-        raise RuntimeError(
-            f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
-        )
-    bias = to_numpy(kwargs["bias"])  # type: ignore[arg-type]
-    if bias is not None:
-        bias = bias[None]
-    weight = kwargs["weight"]
-
-    if network.has_explicit_precision or isinstance(weight, TRTTensor):
-        weight = get_trt_tensor(network, weight, f"{name}_weight")
-        # Expand 1d weight with unsqueeze for calculation
-        unsqueeze_weight_layer = network.add_shuffle(input=weight)
-        unsqueeze_weight_layer.reshape_dims = tuple([*weight.shape, 1])
-        set_layer_name(unsqueeze_layer, target, name + "_unsqueeze_weight")
-        weight = unsqueeze_weight_layer.get_output(0)
-        weight_shape = tuple(kwargs["weight"].shape)  # type: ignore[union-attr]
-        # will need to use uninitialized weight and set it later to support
-        # ITensor weights
-        dummy_weight = trt.Weights()
-        layer = network.add_convolution_nd(
-            input=input_val,
-            num_output_maps=weight.shape[0],
-            kernel_shape=weight.shape[2:],
-            kernel=dummy_weight,
-            bias=bias,
-        )
-
-        layer.set_input(1, weight)
-    else:
-        if not isinstance(kwargs["weight"], torch.Tensor):
-            raise RuntimeError(
-                f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
-            )
-        weight = to_numpy(weight)
-        weight = np.expand_dims(weight, -1)
-        layer = network.add_convolution_nd(
-            input=input_val,
-            num_output_maps=weight.shape[0],
-            kernel_shape=weight.shape[2:],
-            kernel=weight,
-            bias=bias,
-        )
-    # expand params to 2d for computation
-    padding = list(kwargs["padding"])
-    padding.append(0)
-    stride = extend_attr_to_tuple(kwargs["stride"], 2)
-    dilation = extend_attr_to_tuple(kwargs["dilation"], 2)
-
-    set_layer_name(layer, target, name)
-    layer.stride_nd = stride
-    layer.padding_nd = padding
-    layer.dilation_nd = dilation
-    if kwargs["groups"] is not None:
-        layer.num_groups = kwargs["groups"]
-
-    result = layer.get_output(0)
-    squeeze_layer = network.add_shuffle(input=result)
-    squeeze_layer.reshape_dims = tuple(result.shape[:-1])
-    set_layer_name(squeeze_layer, target, name + "_squeeze")
-    return squeeze_layer.get_output(0)
+    return convolution.convNd(
+        network,
+        target,
+        source_ir=SourceIR.ACC,
+        name=name,
+        is_conv1d=True,
+        input_val=kwargs["input"],
+        weight=kwargs["weight"],
+        bias=kwargs["bias"],
+        stride=kwargs["stride"],
+        padding=kwargs["padding"],
+        dilation=kwargs["dilation"],
+        groups=kwargs["groups"],
+    )


@tensorrt_converter(acc_ops.conv3d)
@@ -187,63 +121,20 @@ def acc_ops_convnd(
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Conv received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    if has_dynamic_shape(input_val.shape):
-        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."
-
-    # for now we'll assume bias is constant Tensor or None,
-    # and bias being ITensor is not supported in TensorRT api
-    # right now
-    if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
-        raise RuntimeError(
-            f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
-        )
-    bias = to_numpy(kwargs["bias"])  # type: ignore[arg-type]
-
-    if network.has_explicit_precision or isinstance(kwargs["weight"], TRTTensor):
-        weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
-        weight_shape = tuple(kwargs["weight"].shape)  # type: ignore[union-attr]
-        # will need to use uninitialized weight and set it later to support
-        # ITensor weights
-        dummy_weight = trt.Weights()
-        layer = network.add_convolution_nd(
-            input=input_val,
-            num_output_maps=weight.shape[0],
-            kernel_shape=weight.shape[2:],
-            kernel=dummy_weight,
-            bias=bias,
-        )
-
-        layer.set_input(1, weight)
-    else:
-        if not isinstance(kwargs["weight"], torch.Tensor):
-            raise RuntimeError(
-                f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
-            )
-        weight = to_numpy(kwargs["weight"])
-        layer = network.add_convolution_nd(
-            input=input_val,
-            num_output_maps=weight.shape[0],
-            kernel_shape=weight.shape[2:],
-            kernel=weight,
-            bias=bias,
-        )
-
-    set_layer_name(layer, target, name)
-    layer.stride_nd = kwargs["stride"]
-    layer.padding_nd = kwargs["padding"]
-    layer.dilation_nd = kwargs["dilation"]
-    if kwargs["groups"] is not None:
-        layer.num_groups = kwargs["groups"]
-
-    return layer.get_output(0)
+    return convolution.convNd(
+        network,
+        target,
+        source_ir=SourceIR.ACC,
+        name=name,
+        is_conv1d=False,
+        input_val=kwargs["input"],
+        weight=kwargs["weight"],
+        bias=kwargs["bias"],
+        stride=kwargs["stride"],
+        padding=kwargs["padding"],
+        dilation=kwargs["dilation"],
+        groups=kwargs["groups"],
+    )


@tensorrt_converter(acc_ops.conv_transpose2d)
@@ -268,32 +159,36 @@ def acc_ops_conv_transposend(
        input_val.shape[1] != -1
    ), "Channel dim can't be dynamic for transpose convolution."

-    # for now we'll assume bias is constant Tensor or None,
-    # and bias being ITensor is not supported in TensorRT api
-    # right now
-    if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
-        raise RuntimeError(
-            f"ConvTranspose {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
-        )
-    bias = to_numpy(kwargs["bias"])  # type: ignore[arg-type]
+    if not isinstance(kwargs["bias"], TRTTensor):
+        if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
+            raise RuntimeError(
+                f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
+            )
+        bias = to_numpy(kwargs["bias"])  # type: ignore[arg-type]
+    else:
+        bias = kwargs["bias"]

    if network.has_explicit_precision or isinstance(kwargs["weight"], TRTTensor):
        weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
        weight_shape = tuple(kwargs["weight"].shape)  # type: ignore[union-attr]
        # will need to use uninitialized weight and set it later to support
        # ITensor weights
        dummy_weight = trt.Weights()

        # nn.ConvTranspose2d/3d weight size is (in_channels, out_channels/groups, kernel_0, kernel_1, [kernel_2])
        layer = network.add_deconvolution_nd(
            input=input_val,
            num_output_maps=weight.shape[1] * kwargs["groups"],
            kernel_shape=weight.shape[2:],
-            kernel=dummy_weight,
-            bias=bias,
+            kernel=trt.Weights(),
+            bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
        )

        layer.set_input(1, weight)

+        # If the bias is a TRTTensor, set it as an input of the layer
+        if isinstance(bias, TRTTensor):
+            bias = get_trt_tensor(network, bias, f"{name}_bias")
+            layer.set_input(2, bias)
    else:
        if not isinstance(kwargs["weight"], torch.Tensor):
            raise RuntimeError(
…
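Review note on the removed acc_ops_conv1d body above: no 1D convolution layer is used here, so conv1d is lowered as unsqueeze -> conv2d -> squeeze, a trick the centralized impl presumably inherits (the is_conv1d flag at the new call sites). The identity it relies on, as a self-contained PyTorch check:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 16)  # (N, C_in, L)
w = torch.randn(8, 4, 3)   # (C_out, C_in, K)
b = torch.randn(8)

ref = F.conv1d(x, w, b, stride=1, padding=1)

# Add a trailing unit spatial dim to input and kernel, convolve in 2D with
# no padding in the unit dim, then squeeze the unit dim back out.
via_2d = F.conv2d(
    x.unsqueeze(-1), w.unsqueeze(-1), b, stride=(1, 1), padding=(1, 0)
).squeeze(-1)

assert torch.allclose(ref, via_2d, atol=1e-5)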
33 changes: 28 additions & 5 deletions py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -22,7 +22,7 @@

from .converter_utils import *  # noqa: F403
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
-from torch_tensorrt.fx.converters.impl import activation
+from torch_tensorrt.fx.converters.impl import activation, convolution

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -129,13 +129,36 @@ def aten_ops_convolution(
    # we do not handle output_padding.
    if args[7] not in ([0], [0, 0], [0, 0, 0]):
        raise RuntimeError(f"Target {target} has non-0 output_padding")

    if len(kwargs_new["stride"]) == 1:
-        return acc_ops_converters.acc_ops_conv1d(
-            network, target, None, kwargs_new, name
+        return convolution.convNd(
+            network,
+            target,
+            source_ir=SourceIR.ATEN,
+            name=name,
+            is_conv1d=True,
+            input_val=kwargs_new["input"],
+            weight=kwargs_new["weight"],
+            bias=kwargs_new["bias"],
+            stride=kwargs_new["stride"],
+            padding=kwargs_new["padding"],
+            dilation=kwargs_new["dilation"],
+            groups=kwargs_new["groups"],
        )
    else:
-        return acc_ops_converters.acc_ops_convnd(
-            network, target, None, kwargs_new, name
+        return convolution.convNd(
+            network,
+            target,
+            source_ir=SourceIR.ATEN,
+            name=name,
+            is_conv1d=False,
+            input_val=kwargs_new["input"],
+            weight=kwargs_new["weight"],
+            bias=kwargs_new["bias"],
+            stride=kwargs_new["stride"],
+            padding=kwargs_new["padding"],
+            dilation=kwargs_new["dilation"],
+            groups=kwargs_new["groups"],
        )


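For orientation: aten.convolution packs everything positionally, which is why the hunk above indexes args[7] for output_padding. The kwargs_new mapping is built earlier in this function, outside the shown hunk; per the ATen schema it corresponds roughly to the sketch below (the helper name is hypothetical, for illustration only):

from typing import Any, Dict, Sequence

def convolution_args_to_kwargs(args: Sequence[Any]) -> Dict[str, Any]:
    # aten::convolution(input, weight, bias, stride, padding, dilation,
    #                   transposed, output_padding, groups)
    return {
        "input": args[0],
        "weight": args[1],
        "bias": args[2],
        "stride": args[3],
        "padding": args[4],
        "dilation": args[5],
        # args[6] is `transposed`; args[7] is `output_padding`, which the
        # converter above requires to be all zeros.
        "groups": args[8],
    }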
7 changes: 5 additions & 2 deletions py/torch_tensorrt/fx/converters/converter_utils.py
@@ -99,14 +99,17 @@ def get_positive_dim(dim: int, dim_size: int) -> int:


def set_layer_name(
-    layer: TRTLayer, target: Target, name: str, source_ir: Optional[SourceIR] = None
+    layer: TRTLayer,
+    target: Union[Target, torch.nn.Module, str],
+    name: str,
+    source_ir: Optional[SourceIR] = None,
) -> None:
    """
    Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]"
    Args:
        layer (TRTLayer): A TensorRT layer of which we want to set the name.
-        target (Target): A fx node.target. For call_function node, it's the function that
+        target (Target): A fx node.target or submodule. For call_function node, it's the function that
            the node represents.
        name (str): Consists of fx node.name with optional suffix.
        source_ir: (Optional[SourceIR]): The IR producing the op.
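The widened target type is what lets shared impl code label layers when no fx Node target is in hand. A hypothetical call under that assumption (the wrapper, layer, and names are illustrative only):

from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name

def name_layer(layer) -> None:
    # `target` may now be a plain string or an nn.Module, not only an fx Target.
    set_layer_name(layer, "conv2d", "conv_1", SourceIR.NN)
    # layer.name then follows the documented
    # "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]" pattern.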
(Diffs for the remaining two changed files did not load.)
