-
Notifications
You must be signed in to change notification settings.
Fork 360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: Linter + config fix #2636
fix: Linter + config fix #2636
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py 2024-02-02 22:58:37.176921+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py 2024-02-02 23:00:23.939132+00:00
@@ -30,16 +30,18 @@
gpu_id (int): Device ID for target GPU
dla_core (int): Core ID for target DLA core
allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
"""
- device_type: Optional[
- trt.DeviceType
- ] = None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+ device_type: Optional[trt.DeviceType] = (
+ None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+ )
gpu_id: int = -1 #: Device ID for target GPU
dla_core: int = -1 #: Core ID for target DLA core
- allow_gpu_fallback: bool = False #: Whether falling back to GPU if DLA cannot support an op should be allowed
+ allow_gpu_fallback: bool = (
+ False #: Whether falling back to GPU if DLA cannot support an op should be allowed
+ )
def __init__(self, *args: Any, **kwargs: Any):
"""__init__ Method for torch_tensorrt.Device
Device accepts one of a few construction patterns
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py 2024-02-02 22:58:37.176921+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py 2024-02-02 23:00:24.141670+00:00
@@ -26,16 +26,16 @@
class _ShapeMode(Enum):
STATIC = 0
DYNAMIC = 1
- shape_mode: Optional[
- _ShapeMode
- ] = None #: Is input statically or dynamically shaped
- shape: Optional[
- Tuple[int, ...] | Dict[str, Tuple[int, ...]]
- ] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+ shape_mode: Optional[_ShapeMode] = (
+ None #: Is input statically or dynamically shaped
+ )
+ shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+ None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+ )
dtype: _enums.dtype = (
_enums.dtype.unknown
) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
_explicit_set_dtype: bool = False
format: _enums.TensorFormat = (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-02-02 22:58:37.180921+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-02-02 23:00:24.398349+00:00
@@ -26,13 +26,13 @@
from packaging import version
_LOGGER: logging.Logger = logging.getLogger(__name__)
-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
- Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+ Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)
class UnsupportedOperatorException(RuntimeError):
pass
@@ -90,13 +90,13 @@
self.input_specs_iter = 0
self._cur_node_name: Optional[str] = None
self._cur_node: Optional[torch.fx.Node] = None
self._input_names: List[str] = []
self._output_names: List[str] = []
- self._itensor_to_tensor_meta: Dict[
- trt.tensorrt.ITensor, TensorMetadata
- ] = dict()
+ self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+ dict()
+ )
self.compilation_settings = compilation_settings
# Data types for TRT Module output Tensors
self.output_dtypes = output_dtypes
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2024-02-02 22:58:37.180921+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2024-02-02 23:00:24.452976+00:00
@@ -322,17 +322,15 @@
else:
raise AssertionError(f"Cannot convert {input_val} to TRT constant")
@overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
- ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...
@overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
- ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...
def get_positive_dim(
dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py 2024-02-02 22:58:37.184921+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py 2024-02-02 23:00:24.827064+00:00
@@ -5,13 +5,13 @@
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload, OpOverloadPacket
aten = torch.ops.aten
-_core_aten_decompositions: Dict[
- OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+ core_aten_decompositions()
+)
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
aten._adaptive_avg_pool2d_backward,
aten.addcdiv,
aten.addcdiv_,
aten.addcmul,
@@ -178,13 +178,13 @@
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
aten._softmax.default,
}
-ENABLED_TORCH_DECOMPOSITIONS: Dict[
- OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+ get_torch_decompositions(torch_enabled_decompositions)
+)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}
def check_decomp_set_invariants() -> None:
"""Validates no overlap between enabled and disabled decomposition sets"""
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py 2024-02-02 22:58:37.184921+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py 2024-02-02 23:00:24.833470+00:00
@@ -20,16 +20,14 @@
logger.debug(f"Graph after lowering linear:\n{gm.graph}")
return gm
-def linear_replacement() -> (
- Tuple[
- torch.fx.GraphModule,
- Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
- ]
-):
+def linear_replacement() -> Tuple[
+ torch.fx.GraphModule,
+ Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
"""Constructs the original and replacement functions for linear"""
# Original graph
def orig(
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py 2024-02-02 22:58:37.184921+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py 2024-02-02 23:00:24.872952+00:00
@@ -20,16 +20,14 @@
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
return gm
-def view_replacement() -> (
- Tuple[
- torch.fx.GraphModule,
- Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
- ]
-):
+def view_replacement() -> Tuple[
+ torch.fx.GraphModule,
+ Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
"""Constructs the original and replacement functions for view"""
# Original graph
def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.view.default(input, shape)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py 2024-02-02 22:58:37.184921+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py 2024-02-02 23:00:24.886494+00:00
@@ -58,16 +58,14 @@
logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")
return gm
-def scaled_dot_product_attention_replacement() -> (
- Tuple[
- Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
- Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
- ]
-):
+def scaled_dot_product_attention_replacement() -> Tuple[
+ Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
+ Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
"""Constructs the original and replacement functions for efficient attention"""
# Efficient Attention original graph
def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py 2024-02-02 22:58:37.184921+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py 2024-02-02 23:00:25.480659+00:00
@@ -19,13 +19,13 @@
from .observer import Observer
from .utils import Frameworks, LowerPrecision, get_dynamic_dims, unified_dtype_converter
_LOGGER: logging.Logger = logging.getLogger(__name__)
-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
- Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+ Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)
class TRTInterpreterResult(NamedTuple):
engine: Any
input_names: Sequence[str]
@@ -73,13 +73,13 @@
self.input_specs_iter = 0
self.validate_input_specs()
self._cur_node_name: Optional[str] = None
self._input_names: List[str] = []
self._output_names: List[str] = []
- self._itensor_to_tensor_meta: Dict[
- trt.tensorrt.ITensor, TensorMetadata
- ] = dict()
+ self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+ dict()
+ )
def validate_input_specs(self):
for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
if not self.network.has_implicit_batch_dimension:
assert (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py 2024-02-02 22:58:37.184921+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py 2024-02-02 23:00:25.517674+00:00
@@ -316,13 +316,11 @@
"input": args[0],
"kernel_size": args[1],
"stride": (
args[2]
if len(args) > 2
- else (None, None)
- if len(args[1]) == 2
- else (None, None, None)
+ else (None, None) if len(args[1]) == 2 else (None, None, None)
),
"padding": (
args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
),
"dilation": (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py 2024-02-02 22:58:37.188921+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py 2024-02-02 23:00:26.048695+00:00
@@ -193,13 +193,11 @@
kwargs2 = {"equal_nan": True}
if rtol:
kwargs2["rtol"] = rtol
if atol:
kwargs2["atol"] = atol
- kwargs2[
- "msg"
- ] = (
+ kwargs2["msg"] = (
lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
)
# If tensors are on different devices, make sure to compare
# their copies that are on the same device.
if x.get_device() != y.get_device():
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py 2024-02-02 22:58:37.184921+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py 2024-02-02 23:00:26.052416+00:00
@@ -535,13 +535,13 @@
reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
maybe_reshape
)
if not reshape_batch_size:
continue
- reshape_batch_size_inferred_source: Optional[
- fx.Node
- ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+ reshape_batch_size_inferred_source: Optional[fx.Node] = (
+ get_reshape_batch_size_inferred_source(reshape_batch_size)
+ )
if not reshape_batch_size_inferred_source:
continue
reshape_input: fx.Node = maybe_reshape.kwargs["input"]
if reshape_input == reshape_batch_size_inferred_source:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py 2024-02-02 22:58:38.968690+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py 2024-02-02 23:00:28.215562+00:00
@@ -30,16 +30,18 @@
gpu_id (int): Device ID for target GPU
dla_core (int): Core ID for target DLA core
allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
"""
- device_type: Optional[
- trt.DeviceType
- ] = None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+ device_type: Optional[trt.DeviceType] = (
+ None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+ )
gpu_id: int = -1 #: Device ID for target GPU
dla_core: int = -1 #: Core ID for target DLA core
- allow_gpu_fallback: bool = False #: Whether falling back to GPU if DLA cannot support an op should be allowed
+ allow_gpu_fallback: bool = (
+ False #: Whether falling back to GPU if DLA cannot support an op should be allowed
+ )
def __init__(self, *args: Any, **kwargs: Any):
"""__init__ Method for torch_tensorrt.Device
Device accepts one of a few construction patterns
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py 2024-02-02 22:58:38.968690+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py 2024-02-02 23:00:28.428113+00:00
@@ -26,16 +26,16 @@
class _ShapeMode(Enum):
STATIC = 0
DYNAMIC = 1
- shape_mode: Optional[
- _ShapeMode
- ] = None #: Is input statically or dynamically shaped
- shape: Optional[
- Tuple[int, ...] | Dict[str, Tuple[int, ...]]
- ] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+ shape_mode: Optional[_ShapeMode] = (
+ None #: Is input statically or dynamically shaped
+ )
+ shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+ None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+ )
dtype: _enums.dtype = (
_enums.dtype.unknown
) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
_explicit_set_dtype: bool = False
format: _enums.TensorFormat = (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-02-02 22:58:38.972690+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-02-02 23:00:28.672335+00:00
@@ -26,13 +26,13 @@
from packaging import version
_LOGGER: logging.Logger = logging.getLogger(__name__)
-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
- Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+ Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)
class UnsupportedOperatorException(RuntimeError):
pass
@@ -90,13 +90,13 @@
self.input_specs_iter = 0
self._cur_node_name: Optional[str] = None
self._cur_node: Optional[torch.fx.Node] = None
self._input_names: List[str] = []
self._output_names: List[str] = []
- self._itensor_to_tensor_meta: Dict[
- trt.tensorrt.ITensor, TensorMetadata
- ] = dict()
+ self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+ dict()
+ )
self.compilation_settings = compilation_settings
# Data types for TRT Module output Tensors
self.output_dtypes = output_dtypes
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2024-02-02 22:58:38.972690+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2024-02-02 23:00:28.737105+00:00
@@ -322,17 +322,15 @@
else:
raise AssertionError(f"Cannot convert {input_val} to TRT constant")
@overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
- ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...
@overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
- ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...
def get_positive_dim(
dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py 2024-02-02 22:58:38.972690+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py 2024-02-02 23:00:29.084024+00:00
@@ -5,13 +5,13 @@
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload, OpOverloadPacket
aten = torch.ops.aten
-_core_aten_decompositions: Dict[
- OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+ core_aten_decompositions()
+)
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
aten._adaptive_avg_pool2d_backward,
aten.addcdiv,
aten.addcdiv_,
aten.addcmul,
@@ -178,13 +178,13 @@
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
aten._softmax.default,
}
-ENABLED_TORCH_DECOMPOSITIONS: Dict[
- OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+ get_torch_decompositions(torch_enabled_decompositions)
+)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}
def check_decomp_set_invariants() -> None:
"""Validates no overlap between enabled and disabled decomposition sets"""
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py 2024-02-02 22:58:38.972690+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py 2024-02-02 23:00:29.090021+00:00
@@ -20,16 +20,14 @@
logger.debug(f"Graph after lowering linear:\n{gm.graph}")
return gm
-def linear_replacement() -> (
- Tuple[
- torch.fx.GraphModule,
- Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
- ]
-):
+def linear_replacement() -> Tuple[
+ torch.fx.GraphModule,
+ Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
"""Constructs the original and replacement functions for linear"""
# Original graph
def orig(
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py 2024-02-02 22:58:38.972690+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py 2024-02-02 23:00:29.131410+00:00
@@ -20,16 +20,14 @@
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
return gm
-def view_replacement() -> (
- Tuple[
- torch.fx.GraphModule,
- Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
- ]
-):
+def view_replacement() -> Tuple[
+ torch.fx.GraphModule,
+ Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
"""Constructs the original and replacement functions for view"""
# Original graph
def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.view.default(input, shape)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py 2024-02-02 22:58:38.972690+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py 2024-02-02 23:00:29.148990+00:00
@@ -58,16 +58,14 @@
logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")
return gm
-def scaled_dot_product_attention_replacement() -> (
- Tuple[
- Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
- Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
- ]
-):
+def scaled_dot_product_attention_replacement() -> Tuple[
+ Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
+ Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
"""Constructs the original and replacement functions for efficient attention"""
# Efficient Attention original graph
def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py 2024-02-02 22:58:38.976690+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py 2024-02-02 23:00:29.729130+00:00
@@ -19,13 +19,13 @@
from .observer import Observer
from .utils import Frameworks, LowerPrecision, get_dynamic_dims, unified_dtype_converter
_LOGGER: logging.Logger = logging.getLogger(__name__)
-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
- Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+ Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)
class TRTInterpreterResult(NamedTuple):
engine: Any
input_names: Sequence[str]
@@ -73,13 +73,13 @@
self.input_specs_iter = 0
self.validate_input_specs()
self._cur_node_name: Optional[str] = None
self._input_names: List[str] = []
self._output_names: List[str] = []
- self._itensor_to_tensor_meta: Dict[
- trt.tensorrt.ITensor, TensorMetadata
- ] = dict()
+ self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+ dict()
+ )
def validate_input_specs(self):
for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
if not self.network.has_implicit_batch_dimension:
assert (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py 2024-02-02 22:58:38.976690+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py 2024-02-02 23:00:29.786451+00:00
@@ -316,13 +316,11 @@
"input": args[0],
"kernel_size": args[1],
"stride": (
args[2]
if len(args) > 2
- else (None, None)
- if len(args[1]) == 2
- else (None, None, None)
+ else (None, None) if len(args[1]) == 2 else (None, None, None)
),
"padding": (
args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
),
"dilation": (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py 2024-02-02 22:58:38.976690+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py 2024-02-02 23:00:30.238090+00:00
@@ -193,13 +193,11 @@
kwargs2 = {"equal_nan": True}
if rtol:
kwargs2["rtol"] = rtol
if atol:
kwargs2["atol"] = atol
- kwargs2[
- "msg"
- ] = (
+ kwargs2["msg"] = (
lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
)
# If tensors are on different devices, make sure to compare
# their copies that are on the same device.
if x.get_device() != y.get_device():
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py 2024-02-02 22:58:38.976690+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py 2024-02-02 23:00:30.290959+00:00
@@ -535,13 +535,13 @@
reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
maybe_reshape
)
if not reshape_batch_size:
continue
- reshape_batch_size_inferred_source: Optional[
- fx.Node
- ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+ reshape_batch_size_inferred_source: Optional[fx.Node] = (
+ get_reshape_batch_size_inferred_source(reshape_batch_size)
+ )
if not reshape_batch_size_inferred_source:
continue
reshape_input: fx.Node = maybe_reshape.kwargs["input"]
if reshape_input == reshape_batch_size_inferred_source:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py 2024-02-02 22:58:39.858138+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py 2024-02-02 23:00:30.579683+00:00
@@ -30,16 +30,18 @@
gpu_id (int): Device ID for target GPU
dla_core (int): Core ID for target DLA core
allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
"""
- device_type: Optional[
- trt.DeviceType
- ] = None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+ device_type: Optional[trt.DeviceType] = (
+ None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+ )
gpu_id: int = -1 #: Device ID for target GPU
dla_core: int = -1 #: Core ID for target DLA core
- allow_gpu_fallback: bool = False #: Whether falling back to GPU if DLA cannot support an op should be allowed
+ allow_gpu_fallback: bool = (
+ False #: Whether falling back to GPU if DLA cannot support an op should be allowed
+ )
def __init__(self, *args: Any, **kwargs: Any):
"""__init__ Method for torch_tensorrt.Device
Device accepts one of a few construction patterns
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py 2024-02-02 22:58:39.858138+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py 2024-02-02 23:00:30.809057+00:00
@@ -26,16 +26,16 @@
class _ShapeMode(Enum):
STATIC = 0
DYNAMIC = 1
- shape_mode: Optional[
- _ShapeMode
- ] = None #: Is input statically or dynamically shaped
- shape: Optional[
- Tuple[int, ...] | Dict[str, Tuple[int, ...]]
- ] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+ shape_mode: Optional[_ShapeMode] = (
+ None #: Is input statically or dynamically shaped
+ )
+ shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+ None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+ )
dtype: _enums.dtype = (
_enums.dtype.unknown
) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
_explicit_set_dtype: bool = False
format: _enums.TensorFormat = (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-02-02 22:58:39.858138+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-02-02 23:00:31.088273+00:00
@@ -26,13 +26,13 @@
from packaging import version
_LOGGER: logging.Logger = logging.getLogger(__name__)
-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
- Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+ Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)
class UnsupportedOperatorException(RuntimeError):
pass
@@ -90,13 +90,13 @@
self.input_specs_iter = 0
self._cur_node_name: Optional[str] = None
self._cur_node: Optional[torch.fx.Node] = None
self._input_names: List[str] = []
self._output_names: List[str] = []
- self._itensor_to_tensor_meta: Dict[
- trt.tensorrt.ITensor, TensorMetadata
- ] = dict()
+ self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+ dict()
+ )
self.compilation_settings = compilation_settings
# Data types for TRT Module output Tensors
self.output_dtypes = output_dtypes
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2024-02-02 22:58:39.858138+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2024-02-02 23:00:31.139260+00:00
@@ -322,17 +322,15 @@
else:
raise AssertionError(f"Cannot convert {input_val} to TRT constant")
@overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
- ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...
@overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
- ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...
def get_positive_dim(
dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py 2024-02-02 22:58:39.862138+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py 2024-02-02 23:00:31.535195+00:00
@@ -5,13 +5,13 @@
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload, OpOverloadPacket
aten = torch.ops.aten
-_core_aten_decompositions: Dict[
- OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+ core_aten_decompositions()
+)
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
aten._adaptive_avg_pool2d_backward,
aten.addcdiv,
aten.addcdiv_,
aten.addcmul,
@@ -178,13 +178,13 @@
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
aten._softmax.default,
}
-ENABLED_TORCH_DECOMPOSITIONS: Dict[
- OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+ get_torch_decompositions(torch_enabled_decompositions)
+)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}
def check_decomp_set_invariants() -> None:
"""Validates no overlap between enabled and disabled decomposition sets"""
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py 2024-02-02 22:58:39.862138+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py 2024-02-02 23:00:31.573453+00:00
@@ -20,16 +20,14 @@
logger.debug(f"Graph after lowering linear:\n{gm.graph}")
return gm
-def linear_replacement() -> (
- Tuple[
- torch.fx.GraphModule,
- Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
- ]
-):
+def linear_replacement() -> Tuple[
+ torch.fx.GraphModule,
+ Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
"""Constructs the original and replacement functions for linear"""
# Original graph
def orig(
input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py 2024-02-02 22:58:39.862138+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py 2024-02-02 23:00:31.626820+00:00
@@ -20,16 +20,14 @@
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
return gm
-def view_replacement() -> (
- Tuple[
- torch.fx.GraphModule,
- Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
- ]
-):
+def view_replacement() -> Tuple[
+ torch.fx.GraphModule,
+ Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
"""Constructs the original and replacement functions for view"""
# Original graph
def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return torch.ops.aten.view.default(input, shape)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py 2024-02-02 22:58:39.862138+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py 2024-02-02 23:00:31.636087+00:00
@@ -58,16 +58,14 @@
logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")
return gm
-def scaled_dot_product_attention_replacement() -> (
- Tuple[
- Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
- Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
- ]
-):
+def scaled_dot_product_attention_replacement() -> Tuple[
+ Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
+ Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
"""Constructs the original and replacement functions for efficient attention"""
# Efficient Attention original graph
def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py 2024-02-02 22:58:39.862138+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py 2024-02-02 23:00:32.282502+00:00
@@ -316,13 +316,11 @@
"input": args[0],
"kernel_size": args[1],
"stride": (
args[2]
if len(args) > 2
- else (None, None)
- if len(args[1]) == 2
- else (None, None, None)
+ else (None, None) if len(args[1]) == 2 else (None, None, None)
),
"padding": (
args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
),
"dilation": (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py 2024-02-02 22:58:39.866138+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py 2024-02-02 23:00:32.309571+00:00
@@ -19,13 +19,13 @@
from .observer import Observer
from .utils import Frameworks, LowerPrecision, get_dynamic_dims, unified_dtype_converter
_LOGGER: logging.Logger = logging.getLogger(__name__)
-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
- Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+ Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)
class TRTInterpreterResult(NamedTuple):
engine: Any
input_names: Sequence[str]
@@ -73,13 +73,13 @@
self.input_specs_iter = 0
self.validate_input_specs()
self._cur_node_name: Optional[str] = None
self._input_names: List[str] = []
self._output_names: List[str] = []
- self._itensor_to_tensor_meta: Dict[
- trt.tensorrt.ITensor, TensorMetadata
- ] = dict()
+ self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+ dict()
+ )
def validate_input_specs(self):
for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
if not self.network.has_implicit_batch_dimension:
assert (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py 2024-02-02 22:58:39.866138+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py 2024-02-02 23:00:32.814525+00:00
@@ -193,13 +193,11 @@
kwargs2 = {"equal_nan": True}
if rtol:
kwargs2["rtol"] = rtol
if atol:
kwargs2["atol"] = atol
- kwargs2[
- "msg"
- ] = (
+ kwargs2["msg"] = (
lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
)
# If tensors are on different devices, make sure to compare
# their copies that are on the same device.
if x.get_device() != y.get_device():
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py 2024-02-02 22:58:39.866138+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py 2024-02-02 23:00:32.920329+00:00
@@ -535,13 +535,13 @@
reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
maybe_reshape
)
if not reshape_batch_size:
continue
- reshape_batch_size_inferred_source: Optional[
- fx.Node
- ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+ reshape_batch_size_inferred_source: Optional[fx.Node] = (
+ get_reshape_batch_size_inferred_source(reshape_batch_size)
+ )
if not reshape_batch_size_inferred_source:
continue
reshape_input: fx.Node = maybe_reshape.kwargs["input"]
if reshape_input == reshape_batch_size_inferred_source:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Description
Type of change
Checklist: