diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index 88b8c13d10..a437057d04 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -1,7 +1,7 @@ import logging from copy import deepcopy from enum import Enum, auto -from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Tuple, Union +from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Union import numpy as np import torch @@ -57,9 +57,9 @@ def __init__( disable_tf32: bool = _defaults.DISABLE_TF32, assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT, sparse_weights: bool = _defaults.SPARSE_WEIGHTS, - enabled_precisions: ( - Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] - ) = _defaults.ENABLED_PRECISIONS, + enabled_precisions: Set[ + Union[torch.dtype, dtype] + ] = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, make_refitable: bool = _defaults.MAKE_REFITABLE, debug: bool = _defaults.DEBUG,