Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix/feat: Move convolution core to impl + add feature (FX converter refactor) #1972

Merged
merged 1 commit into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def test_resnet18(ir):
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.float},
"ir": ir,
"pass_through_build_failures": True,
}

trt_mod = torchtrt.compile(model, **compile_spec)
Expand Down Expand Up @@ -60,6 +61,7 @@ def test_mobilenet_v2(ir):
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.float},
"ir": ir,
"pass_through_build_failures": True,
}

trt_mod = torchtrt.compile(model, **compile_spec)
Expand Down Expand Up @@ -90,6 +92,7 @@ def test_efficientnet_b0(ir):
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.float},
"ir": ir,
"pass_through_build_failures": True,
}

trt_mod = torchtrt.compile(model, **compile_spec)
Expand Down Expand Up @@ -129,6 +132,7 @@ def test_bert_base_uncased(ir):
"enabled_precisions": {torch.float},
"truncate_long_and_double": True,
"ir": ir,
"pass_through_build_failures": True,
}
trt_mod = torchtrt.compile(model, **compile_spec)

Expand Down Expand Up @@ -163,6 +167,7 @@ def test_resnet18_half(ir):
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.half},
"ir": ir,
"pass_through_build_failures": True,
}

trt_mod = torchtrt.compile(model, **compile_spec)
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from .adaptive_avgpool import * # noqa: F401 F403
from .add import * # noqa: F401 F403
from .batchnorm import * # noqa: F401 F403
from .convolution import * # noqa: F401 F403
from .linear import * # noqa: F401 F403
from .maxpool import * # noqa: F401 F403
from .mul import * # noqa: F401 F403
from .transformation import * # noqa: F401 F403
from .quantization import * # noqa: F401 F403
from .acc_ops_converters import * # noqa: F401 F403
from .aten_ops_converters import * # noqa: F401 F403
from .nn_ops_converters import * # noqa: F401 F403

TRT_LOGGER = trt.Logger()
trt.init_libnvinfer_plugins(TRT_LOGGER, "")
167 changes: 29 additions & 138 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
trt_transposed_matmul,
)
from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
from torch_tensorrt.fx.converters.impl import activation
from torch_tensorrt.fx.converters.impl import activation, convolution

_LOGGER: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -96,86 +96,20 @@ def acc_ops_conv1d(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]
if not isinstance(input_val, TRTTensor):
raise RuntimeError(
f"Conv received input {input_val} that is not part "
"of the TensorRT region!"
)

# Process 1d input with unsqueeze -> conv2d -> squeeze to calculate conv1d
unsqueeze_layer = network.add_shuffle(input=input_val)
unsqueeze_layer.reshape_dims = tuple([*input_val.shape, 1])
set_layer_name(unsqueeze_layer, target, name + "_unsqueeze")
input_val = unsqueeze_layer.get_output(0)

if has_dynamic_shape(input_val.shape):
assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."

# For now we assume bias is a constant Tensor or None;
# a bias that is an ITensor is not currently supported
# by the TensorRT API.
if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
raise RuntimeError(
f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
)
bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type]
if bias is not None:
bias = bias[None]
weight = kwargs["weight"]

if network.has_explicit_precision or isinstance(weight, TRTTensor):
weight = get_trt_tensor(network, weight, f"{name}_weight")
# Expand 1d weight with unsqueeze for calculation
unsqueeze_weight_layer = network.add_shuffle(input=weight)
unsqueeze_weight_layer.reshape_dims = tuple([*weight.shape, 1])
set_layer_name(unsqueeze_layer, target, name + "_unsqueeze_weight")
weight = unsqueeze_weight_layer.get_output(0)
weight_shape = tuple(kwargs["weight"].shape) # type: ignore[union-attr]
# will need to use uninitialized weight and set it later to support
# ITensor weights
dummy_weight = trt.Weights()
layer = network.add_convolution_nd(
input=input_val,
num_output_maps=weight.shape[0],
kernel_shape=weight.shape[2:],
kernel=dummy_weight,
bias=bias,
)

layer.set_input(1, weight)
else:
if not isinstance(kwargs["weight"], torch.Tensor):
raise RuntimeError(
f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
)
weight = to_numpy(weight)
weight = np.expand_dims(weight, -1)
layer = network.add_convolution_nd(
input=input_val,
num_output_maps=weight.shape[0],
kernel_shape=weight.shape[2:],
kernel=weight,
bias=bias,
)
# expand params to 2d for computation
padding = list(kwargs["padding"])
padding.append(0)
stride = extend_attr_to_tuple(kwargs["stride"], 2)
dilation = extend_attr_to_tuple(kwargs["dilation"], 2)

set_layer_name(layer, target, name)
layer.stride_nd = stride
layer.padding_nd = padding
layer.dilation_nd = dilation
if kwargs["groups"] is not None:
layer.num_groups = kwargs["groups"]

result = layer.get_output(0)
squeeze_layer = network.add_shuffle(input=result)
squeeze_layer.reshape_dims = tuple(result.shape[:-1])
set_layer_name(squeeze_layer, target, name + "_squeeze")
return squeeze_layer.get_output(0)
return convolution.convNd(
network,
target,
source_ir=SourceIR.ACC,
name=name,
is_conv1d=True,
input_val=kwargs["input"],
weight=kwargs["weight"],
bias=kwargs["bias"],
stride=kwargs["stride"],
padding=kwargs["padding"],
dilation=kwargs["dilation"],
groups=kwargs["groups"],
)


@tensorrt_converter(acc_ops.conv3d)
Expand All @@ -187,63 +121,20 @@ def acc_ops_convnd(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = kwargs["input"]

if not isinstance(input_val, TRTTensor):
raise RuntimeError(
f"Conv received input {input_val} that is not part "
"of the TensorRT region!"
)

if has_dynamic_shape(input_val.shape):
assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."

# For now we assume bias is a constant Tensor or None;
# a bias that is an ITensor is not currently supported
# by the TensorRT API.
if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
raise RuntimeError(
f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
)
bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type]

if network.has_explicit_precision or isinstance(kwargs["weight"], TRTTensor):
weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
weight_shape = tuple(kwargs["weight"].shape) # type: ignore[union-attr]
# will need to use uninitialized weight and set it later to support
# ITensor weights
dummy_weight = trt.Weights()
layer = network.add_convolution_nd(
input=input_val,
num_output_maps=weight.shape[0],
kernel_shape=weight.shape[2:],
kernel=dummy_weight,
bias=bias,
)

layer.set_input(1, weight)
else:
if not isinstance(kwargs["weight"], torch.Tensor):
raise RuntimeError(
f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
)
weight = to_numpy(kwargs["weight"])
layer = network.add_convolution_nd(
input=input_val,
num_output_maps=weight.shape[0],
kernel_shape=weight.shape[2:],
kernel=weight,
bias=bias,
)

set_layer_name(layer, target, name)
layer.stride_nd = kwargs["stride"]
layer.padding_nd = kwargs["padding"]
layer.dilation_nd = kwargs["dilation"]
if kwargs["groups"] is not None:
layer.num_groups = kwargs["groups"]

return layer.get_output(0)
return convolution.convNd(
network,
target,
source_ir=SourceIR.ACC,
name=name,
is_conv1d=False,
input_val=kwargs["input"],
weight=kwargs["weight"],
bias=kwargs["bias"],
stride=kwargs["stride"],
padding=kwargs["padding"],
dilation=kwargs["dilation"],
groups=kwargs["groups"],
)


@tensorrt_converter(acc_ops.conv_transpose2d)
Expand Down
33 changes: 28 additions & 5 deletions py/torch_tensorrt/fx/converters/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from .converter_utils import * # noqa: F403
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
from torch_tensorrt.fx.converters.impl import activation
from torch_tensorrt.fx.converters.impl import activation, convolution

_LOGGER: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -129,13 +129,36 @@ def aten_ops_convolution(
# we do not handle output_padding.
if args[7] not in ([0], [0, 0], [0, 0, 0]):
raise RuntimeError(f"Target {target} has non-0 output_padding")

if len(kwargs_new["stride"]) == 1:
return acc_ops_converters.acc_ops_conv1d(
network, target, None, kwargs_new, name
return convolution.convNd(
network,
target,
source_ir=SourceIR.ATEN,
name=name,
is_conv1d=True,
input_val=kwargs_new["input"],
weight=kwargs_new["weight"],
bias=kwargs_new["bias"],
stride=kwargs_new["stride"],
padding=kwargs_new["padding"],
dilation=kwargs_new["dilation"],
groups=kwargs_new["groups"],
)
else:
return acc_ops_converters.acc_ops_convnd(
network, target, None, kwargs_new, name
return convolution.convNd(
network,
target,
source_ir=SourceIR.ATEN,
name=name,
is_conv1d=False,
input_val=kwargs_new["input"],
weight=kwargs_new["weight"],
bias=kwargs_new["bias"],
stride=kwargs_new["stride"],
padding=kwargs_new["padding"],
dilation=kwargs_new["dilation"],
groups=kwargs_new["groups"],
)


Expand Down
7 changes: 5 additions & 2 deletions py/torch_tensorrt/fx/converters/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,17 @@ def get_positive_dim(dim: int, dim_size: int) -> int:


def set_layer_name(
layer: TRTLayer, target: Target, name: str, source_ir: Optional[SourceIR] = None
layer: TRTLayer,
target: Union[Target, torch.nn.Module, str],
name: str,
source_ir: Optional[SourceIR] = None,
) -> None:
"""
Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]"

Args:
layer (TRTLayer): A TensorRT layer of which we want to set the name.
target (Target): A fx node.target. For call_function node, it's the function that
target (Target): An fx node.target or submodule. For call_function node, it's the function that
the node represents.
name (str): Consists of fx node.name with optional suffix.
source_ir: (Optional[SourceIR]): The IR producing the op.
Expand Down
Loading