feat: support more elementwise and unary dynamo converters #2429
Merged 6 commits on Nov 28, 2023
266 changes: 258 additions & 8 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -754,12 +754,12 @@ def aten_ops_cumsum(
)


@dynamo_tensorrt_converter(torch.ops.aten.tile.default) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.tile.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
) # type: ignore[misc]
)
def aten_ops_tile(
ctx: ConversionContext,
target: Target,
@@ -777,7 +777,7 @@ def aten_ops_tile(
)


@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -1701,9 +1701,177 @@ def aten_ops_logical_xor(
)


def bitwise_type_validator(node: Node) -> bool:
supported_type = [torch.bool, bool]

tensor_targets = [
torch.ops.aten.bitwise_and.Tensor,
torch.ops.aten.bitwise_or.Tensor,
torch.ops.aten.bitwise_xor.Tensor,
]
scalar_targets = [
torch.ops.aten.bitwise_and.Scalar,
torch.ops.aten.bitwise_or.Scalar,
torch.ops.aten.bitwise_xor.Scalar,
]
scalar_tensor_targets = [
torch.ops.aten.bitwise_and.Scalar_Tensor,
torch.ops.aten.bitwise_or.Scalar_Tensor,
torch.ops.aten.bitwise_xor.Scalar_Tensor,
]

if node.target in tensor_targets:
lhs_val = node.args[0]
rhs_val = node.args[1]
lhs_meta = lhs_val.meta.get("tensor_meta")
rhs_meta = rhs_val.meta.get("tensor_meta")
if lhs_meta is None or rhs_meta is None:
return False
return lhs_meta.dtype in supported_type and rhs_meta.dtype in supported_type

elif node.target in scalar_targets:
lhs_val = node.args[0]
rhs_val = node.args[1]
lhs_meta = lhs_val.meta.get("tensor_meta")
if lhs_meta is None:
return False
return lhs_meta.dtype in supported_type and isinstance(rhs_val, bool)

elif node.target in scalar_tensor_targets:
lhs_val = node.args[0]
rhs_val = node.args[1]
rhs_meta = rhs_val.meta.get("tensor_meta")
if rhs_meta is None:
return False
return isinstance(lhs_val, bool) and rhs_meta.dtype in supported_type

else:
return False
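For reference (not part of the diff): the validator distinguishes the three bitwise overloads by where the Python scalar sits, and admits only boolean operands. A minimal eager-mode sketch of the argument shapes it checks, assuming the standard aten overload names:

import torch

a = torch.tensor([True, False, True])
b = torch.tensor([False, False, True])

# Tensor/Tensor: both operands must carry bool tensor_meta.
print(torch.ops.aten.bitwise_and.Tensor(a, b))            # tensor([False, False,  True])
# Tensor/Scalar: lhs must be a bool tensor, rhs a Python bool.
print(torch.ops.aten.bitwise_and.Scalar(a, True))          # tensor([ True, False,  True])
# Scalar/Tensor: lhs must be a Python bool, rhs a bool tensor.
print(torch.ops.aten.bitwise_and.Scalar_Tensor(True, b))   # tensor([False, False,  True])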


@dynamo_tensorrt_converter(
torch.ops.aten.bitwise_and.Tensor, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
torch.ops.aten.bitwise_and.Scalar, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
torch.ops.aten.bitwise_and.Scalar_Tensor,
capability_validator=bitwise_type_validator,
)
def aten_ops_bitwise_and(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.bitwise_and(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(
torch.ops.aten.bitwise_or.Tensor, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
torch.ops.aten.bitwise_or.Scalar, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
torch.ops.aten.bitwise_or.Scalar_Tensor, capability_validator=bitwise_type_validator
)
def aten_ops_bitwise_or(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.bitwise_or(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(
torch.ops.aten.bitwise_xor.Tensor, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
torch.ops.aten.bitwise_xor.Scalar, capability_validator=bitwise_type_validator
)
@dynamo_tensorrt_converter(
torch.ops.aten.bitwise_xor.Scalar_Tensor,
capability_validator=bitwise_type_validator,
)
def aten_ops_bitwise_xor(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.bitwise_xor(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


def bitwise_not_type_validator(node: Node) -> bool:
val = node.args[0]
val_meta = val.meta.get("tensor_meta")

if val_meta is None:
return False

supported_type = [torch.bool, bool]
return val_meta.dtype in supported_type


@dynamo_tensorrt_converter(
torch.ops.aten.bitwise_not.default, capability_validator=bitwise_not_type_validator
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_bitwise_not(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.unary.bitwise_not(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
)
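Usage sketch (not from this PR; assumes a CUDA device and the ir="dynamo" compile path): a boolean-only module whose bitwise ops the converters above can now lower to TensorRT layers rather than leaving to PyTorch fallback.

import torch
import torch_tensorrt

class BoolOps(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # ~, &, |, ^ on bool tensors lower to the aten bitwise ops registered above.
        return ~(x & y) ^ (x | y)

inputs = [
    torch.randint(0, 2, (8,), dtype=torch.bool, device="cuda"),
    torch.randint(0, 2, (8,), dtype=torch.bool, device="cuda"),
]
trt_mod = torch_tensorrt.compile(BoolOps().eval().cuda(), ir="dynamo", inputs=inputs)
print(trt_mod(*inputs))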


@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
def aten_ops_equal(
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_eq(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
@@ -1720,9 +1888,38 @@ def aten_ops_equal(
)


@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_ne(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.ne(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
def aten_ops_greater(
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_gt(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
@@ -1739,9 +1936,38 @@ def aten_ops_greater(
)


@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_ge(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.ge(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
def aten_ops_less(
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_lt(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
@@ -1758,6 +1984,30 @@ def aten_ops_less(
)


@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_le(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.elementwise.le(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)
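For context (not part of the diff): each comparison gets both a .Tensor and a .Scalar registration because exported graphs keep Python-number operands on the Scalar overload. A small sketch, assuming a recent PyTorch with torch.export available, that shows both overload families appearing in a traced graph:

import torch

class Compare(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        gt_tensor = x > y     # typically lowers to torch.ops.aten.gt.Tensor
        le_scalar = x <= 0.5  # typically lowers to torch.ops.aten.le.Scalar
        return gt_tensor & le_scalar

exported = torch.export.export(Compare(), (torch.rand(4), torch.rand(4)))
print(exported.graph_module.graph)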


def conv_param_validator(conv_node: Node) -> bool:
return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])

@@ -1990,14 +2240,14 @@ def aten_ops_argmax(
)


@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.addmm.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
1: (np.ndarray, torch.Tensor, TRTTensor),
2: (np.ndarray, torch.Tensor, TRTTensor),
}
) # type: ignore[misc]
)
def aten_ops_addmm(
ctx: ConversionContext,
target: Target,
8 changes: 4 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -58,8 +58,8 @@ def convert_binary_elementwise(
source_ir: Optional[SourceIR],
name: str,
op_type: trt.ElementWiseOperation,
lhs_val: Union[int, float, TRTTensor, torch.Tensor],
rhs_val: Union[int, float, TRTTensor, torch.Tensor],
lhs_val: Union[int, float, bool, TRTTensor, torch.Tensor],
rhs_val: Union[int, float, bool, TRTTensor, torch.Tensor],
) -> TRTTensor:
"""
This function adds a TensorRT elementwise layer. We allow both operands to be
@@ -120,11 +120,11 @@ def convert_binary_elementwise(
# Note that the dtype here is supposed to be the same as the scalar
# dtype but we don't have a way to detect whether it makes sense for the
# scalar to be float or half. Hence we go with the lhs dtype.
if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
if is_lhs_trt_tensor and isinstance(rhs_val, (float, int, bool)):
rhs_val = np.array(
[rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY)
)
if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
if is_rhs_trt_tensor and isinstance(lhs_val, (float, int, bool)):
lhs_val = np.array(
[lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY)
)
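Rough illustration (not from the diff) of what the widened isinstance checks enable: a Python bool scalar paired with a TRT tensor is frozen into a one-element numpy constant in the tensor operand's dtype, here assumed to map torch.bool to np.bool_ via unified_dtype_converter.

import numpy as np

rhs_val = True                                # Python bool scalar, now accepted
frozen = np.array([rhs_val], dtype=np.bool_)  # stands in for the unified_dtype_converter result
print(frozen, frozen.dtype)                   # [ True] bool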