fix bugs
zewenli98 committed Nov 9, 2023
1 parent 6ef3b21 commit 1b077ca
Showing 5 changed files with 234 additions and 64 deletions.
139 changes: 94 additions & 45 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -754,12 +754,12 @@ def aten_ops_cumsum(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.tile.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.tile.default)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_tile(
     ctx: ConversionContext,
     target: Target,
@@ -777,7 +777,7 @@ def aten_ops_tile(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.permute.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -1702,29 +1702,63 @@ def aten_ops_logical_xor(
 
 
 def bitwise_type_validator(node: Node) -> bool:
-    targets = [
+    supported_type = [torch.bool, bool]
+
+    tensor_targets = [
         torch.ops.aten.bitwise_and.Tensor,
         torch.ops.aten.bitwise_or.Tensor,
         torch.ops.aten.bitwise_xor.Tensor,
     ]
-    if node.target not in targets:
-        return False
+    scalar_targets = [
+        torch.ops.aten.bitwise_and.Scalar,
+        torch.ops.aten.bitwise_or.Scalar,
+        torch.ops.aten.bitwise_xor.Scalar,
+    ]
+    scalar_tensor_targets = [
+        torch.ops.aten.bitwise_and.Scalar_Tensor,
+        torch.ops.aten.bitwise_or.Scalar_Tensor,
+        torch.ops.aten.bitwise_xor.Scalar_Tensor,
+    ]
 
-    lhs_val = node.args[0]
-    rhs_val = node.args[1]
-    lhs_meta = lhs_val.meta.get("tensor_meta")
-    rhs_meta = rhs_val.meta.get("tensor_meta")
+    if node.target in tensor_targets:
+        lhs_val = node.args[0]
+        rhs_val = node.args[1]
+        lhs_meta = lhs_val.meta.get("tensor_meta")
+        rhs_meta = rhs_val.meta.get("tensor_meta")
+        if lhs_meta is None or rhs_meta is None:
+            return False
+        return lhs_meta.dtype in supported_type and rhs_meta.dtype in supported_type
 
-    if lhs_meta is None or rhs_meta is None:
-        return False
+    elif node.target in scalar_targets:
+        lhs_val = node.args[0]
+        rhs_val = node.args[1]
+        lhs_meta = lhs_val.meta.get("tensor_meta")
+        if lhs_meta is None:
+            return False
+        return lhs_meta.dtype in supported_type and isinstance(rhs_val, bool)
 
-    supported_type = [torch.bool, bool]
-    return lhs_meta.dtype in supported_type and rhs_meta.dtype in supported_type
+    elif node.target in scalar_tensor_targets:
+        lhs_val = node.args[0]
+        rhs_val = node.args[1]
+        rhs_meta = rhs_val.meta.get("tensor_meta")
+        if rhs_meta is None:
+            return False
+        return isinstance(lhs_val, bool) and rhs_meta.dtype in supported_type
+
+    else:
+        return False
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Tensor, capability_validator=bitwise_type_validator)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Scalar)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.bitwise_and.Scalar_Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_and.Tensor, capability_validator=bitwise_type_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_and.Scalar, capability_validator=bitwise_type_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_and.Scalar_Tensor,
+    capability_validator=bitwise_type_validator,
+)
 def aten_ops_bitwise_and(
     ctx: ConversionContext,
     target: Target,
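For context, a minimal sketch (illustration only, not part of the commit) of the three aten overload shapes the rewritten validator distinguishes; per supported_type above, only torch.bool tensors and Python bools pass:

import torch

a = torch.tensor([True, False])
b = torch.tensor([False, True])

torch.ops.aten.bitwise_and.Tensor(a, b)  # Tensor op Tensor: both dtypes must be bool
torch.ops.aten.bitwise_and.Scalar(a, True)  # Tensor op Scalar: bool lhs dtype, rhs must be a Python bool
torch.ops.aten.bitwise_and.Scalar_Tensor(True, b)  # Scalar op Tensor: lhs must be a Python bool, bool rhs dtype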
@@ -1742,9 +1776,15 @@ def aten_ops_bitwise_and(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Tensor, capability_validator=bitwise_type_validator)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Scalar)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.bitwise_or.Scalar_Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_or.Tensor, capability_validator=bitwise_type_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_or.Scalar, capability_validator=bitwise_type_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_or.Scalar_Tensor, capability_validator=bitwise_type_validator
+)
 def aten_ops_bitwise_or(
     ctx: ConversionContext,
     target: Target,
@@ -1762,9 +1802,16 @@ def aten_ops_bitwise_or(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Tensor, capability_validator=bitwise_type_validator)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Scalar)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.bitwise_xor.Scalar_Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_xor.Tensor, capability_validator=bitwise_type_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_xor.Scalar, capability_validator=bitwise_type_validator
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_xor.Scalar_Tensor,
+    capability_validator=bitwise_type_validator,
+)
 def aten_ops_bitwise_xor(
     ctx: ConversionContext,
     target: Target,
@@ -1793,12 +1840,14 @@ def bitwise_not_type_validator(node: Node) -> bool:
     return val_meta.dtype in supported_type
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.bitwise_not.default, capability_validator=bitwise_not_type_validator)  # type: ignore[misc]
+@dynamo_tensorrt_converter(
+    torch.ops.aten.bitwise_not.default, capability_validator=bitwise_not_type_validator
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_bitwise_not(
     ctx: ConversionContext,
     target: Target,
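A quick illustration of what this validator admits (eager calls shown for clarity; the note about fallback behavior is our reading of the capability_validator mechanism, not stated in the commit) — when the validator returns False, the node is simply not claimed by this converter and is left to run in PyTorch rather than TensorRT:

import torch

torch.ops.aten.bitwise_not.default(torch.tensor([True, False]))  # bool input: validator returns True, converter eligible
torch.ops.aten.bitwise_not.default(torch.tensor([1, 2]))  # int64 input: validator returns False, op left to PyTorch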
@@ -1815,13 +1864,13 @@ def aten_ops_bitwise_not(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_eq(
     ctx: ConversionContext,
     target: Target,
@@ -1839,13 +1888,13 @@ def aten_ops_eq(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_ne(
     ctx: ConversionContext,
     target: Target,
@@ -1863,13 +1912,13 @@ def aten_ops_ne(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_gt(
     ctx: ConversionContext,
     target: Target,
@@ -1887,13 +1936,13 @@ def aten_ops_gt(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_ge(
     ctx: ConversionContext,
     target: Target,
@@ -1911,13 +1960,13 @@ def aten_ops_ge(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_lt(
     ctx: ConversionContext,
     target: Target,
@@ -1935,13 +1984,13 @@ def aten_ops_lt(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_le(
     ctx: ConversionContext,
     target: Target,
@@ -2191,14 +2240,14 @@ def aten_ops_argmax(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.addmm.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.addmm.default)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
         1: (np.ndarray, torch.Tensor, TRTTensor),
         2: (np.ndarray, torch.Tensor, TRTTensor),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_addmm(
     ctx: ConversionContext,
     target: Target,
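For reference, the eager semantics behind this converter (standard PyTorch behavior, not specific to this commit): addmm(input, mat1, mat2) computes input + mat1 @ mat2 with the default beta = alpha = 1, while the enforce_tensor_types table above pins arg 0 to TRTTensor and allows numpy, torch, or TRT tensors for args 1 and 2:

import torch

bias = torch.randn(3)
mat1 = torch.randn(2, 4)
mat2 = torch.randn(4, 3)
out = torch.ops.aten.addmm.default(bias, mat1, mat2)
assert torch.allclose(out, bias + mat1 @ mat2)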
