diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
index a0e570e992..cabed5b601 100644
--- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
@@ -1,3 +1,4 @@
+import inspect
 import logging
 from copy import deepcopy
 from enum import Enum, auto
@@ -41,6 +42,10 @@ def get_state(self) -> RefitFlag:
         return self._state


+class DynamicShapeOutOfRangeException(Exception):
+    pass
+
+
 class MutableTorchTensorRTModule(object):
     """
     Initialize a MutableTorchTensorRTModule to seamlessly manipulate it like a regular PyTorch module.
@@ -65,7 +70,7 @@ def __init__(
             Union[torch.dtype, dtype]
         ] = _defaults.ENABLED_PRECISIONS,
         engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-        immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
+        immutable_weights: bool = False,
         debug: bool = _defaults.DEBUG,
         num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
         workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -189,6 +194,9 @@ def __init__(
             "hardware_compatible": hardware_compatible,
             "timing_cache_path": timing_cache_path,
         }
+        self.arg_dynamic_shapes: Optional[tuple[Any]] = None
+        self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None
+        self.total_dynamic_shape: Optional[dict[Any, Any]] = None

         self.settings = CompilationSettings(**compilation_options)
         self.run_info: Optional[tuple[Any, ...]] = None
@@ -203,6 +211,31 @@ def __init__(
         )
         self.init_finished = True

+    def set_dynamic_shape_hint(
+        self,
+        args_dynamic_shape: tuple[dict[Any, Any]],
+        kwargs_dynamic_shape: dict[str, Any],
+    ) -> None:
+        assert isinstance(
+            args_dynamic_shape, tuple
+        ), "args dynamic shape has to be a tuple"
+        assert isinstance(
+            kwargs_dynamic_shape, dict
+        ), "kwargs dynamic shape has to be a dictionary"
+        self.kwarg_dynamic_shapes = kwargs_dynamic_shape
+        self.arg_dynamic_shapes = args_dynamic_shape
+        self.total_dynamic_shape = self.kwarg_dynamic_shapes.copy()
+        signature = list(
+            inspect.signature(self.original_model.forward).parameters.keys()
+        )
+        for i, arg in enumerate(self.arg_dynamic_shapes):
+            self.total_dynamic_shape[signature[i]] = arg
+        self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+
+        # Clear cached inputs
+        self.arg_inputs = tuple()
+        self.kwarg_inputs = {}
+
     def store_state_dict_metadata(self) -> None:
         for k, v in self.original_model.state_dict().items():
             self.state_dict_metadata[k] = v.shape
@@ -295,6 +328,7 @@ def compile(self) -> None:
             self.original_model,
             self.arg_inputs,
             kwargs=self.kwarg_inputs,
+            dynamic_shapes=self.total_dynamic_shape,
         )
         self.gm = dynamo_compile(
             self.exp_program,
@@ -306,14 +340,26 @@ def compile(self) -> None:
         torch.cuda.empty_cache()

     def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
-        if (
-            not self.arg_inputs
-            or not MutableTorchTensorRTModule.check_inputs_equal(self.arg_inputs, args)
-            or not MutableTorchTensorRTModule.check_inputs_equal(
-                self.kwarg_inputs, kwargs
-            )
-        ):
+        try:
+            if (
+                not self.arg_inputs
+                or not MutableTorchTensorRTModule.check_inputs_equal(
+                    self.arg_inputs, args, dynamic_shapes=self.arg_dynamic_shapes
+                )
+                or not MutableTorchTensorRTModule.check_inputs_equal(
+                    self.kwarg_inputs, kwargs, dynamic_shapes=self.kwarg_dynamic_shapes
+                )
+            ):
+                logger.info("Input change detected.")
+                self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+                self.store_inputs(args, kwargs)
+        except DynamicShapeOutOfRangeException as e:
             logger.info("Input change detected.")
+            logger.warning(e)
+            logger.warning("Recompiling the engine with static shape")
+            self.arg_dynamic_shapes = None
+            self.kwarg_dynamic_shapes = None
+            self.total_dynamic_shape = None
             self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
             self.store_inputs(args, kwargs)

@@ -436,35 +482,68 @@ def __setattr__(self, name: str, value: Any) -> None:
     def check_inputs_equal(
         input1: Any,
         input2: Any,
+        dynamic_shapes: Any = None,
     ) -> bool:
-        # TODO: Add support for dynamic shape
+
         if isinstance(input1, (tuple, list)):
             if len(input1) != len(input2):
                 return False
-            for a, b in zip(input1, input2):
+            for (i, a), b in zip(enumerate(input1), input2):
                 if type(a) != type(b):
                     return False
-                if isinstance(a, torch.Tensor) and a.shape != b.shape:
-                    return False
-                elif isinstance(a, bool) and a != b:
+                if isinstance(a, bool) and a != b:
                     return False
+                elif isinstance(a, torch.Tensor) and a.shape != b.shape:
+                    if dynamic_shapes is None:
+                        return False
+                    else:
+                        tensor_dynamic_shape = dynamic_shapes[i]
+                        if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
+                            a, b, tensor_dynamic_shape
+                        ):
+                            return False

         elif isinstance(input1, dict):
             if input1.keys() != input2.keys():
                 return False
-            for a, b in zip(input1.values(), input2.values()):
-                if type(a) != type(b):
-                    return False
-                if isinstance(a, torch.Tensor) and a.shape != b.shape:
+            for (ka, va), vb in zip(input1.items(), input2.values()):
+                if type(va) != type(vb):
                     return False
-                elif isinstance(a, bool) and a != b:
+                if isinstance(va, bool) and va != vb:
                     return False
+                elif isinstance(va, torch.Tensor) and va.shape != vb.shape:
+                    if dynamic_shapes is None:
+                        return False
+                    else:
+                        tensor_dynamic_shape = dynamic_shapes[ka]
+                        if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
+                            va, vb, tensor_dynamic_shape
+                        ):
+                            return False
                 elif isinstance(
-                    a, (list, tuple, dict)
-                ) and not MutableTorchTensorRTModule.check_inputs_equal(a, b):
+                    va, (list, tuple, dict)
+                ) and not MutableTorchTensorRTModule.check_inputs_equal(
+                    va, vb, dynamic_shapes[ka] if dynamic_shapes else None
+                ):
                     return False
         return True

+    @staticmethod
+    def check_tensor_shapes_with_dynamic_shapes(
+        t1: torch.Tensor, t2: torch.Tensor, dynamic_shape: dict[int, Any]
+    ) -> bool:
+        for (i, axis_0), axis_1 in zip(enumerate(t1.shape), t2.shape):
+            if axis_0 != axis_1:
+                if i not in dynamic_shape:
+                    return False
+                dyn = dynamic_shape[i]
+                if axis_1 > dyn.max or axis_1 < dyn.min:
+                    raise DynamicShapeOutOfRangeException(
+                        f"The input size ({axis_1}) of dimension ({i}) is not in dynamic shape range [{dyn.min}, {dyn.max}]!"
+                    )
+
+        return True
+
     @staticmethod
     def save(module: Any, path: str) -> None:
         # Cast the object back to MutableTorchTensorRTModule to save
diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py
index ab1137e2b3..8a4b2a376e 100644
--- a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py
+++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py
@@ -36,6 +36,126 @@ def test_check_output_equal():
     )


+@pytest.mark.unit
+def test_check_input_shape_dynamic():
+    torch.manual_seed(0)
+    a = {
+        "a": torch.rand(10, 3),
+        "b": [torch.rand(10, 30), torch.rand(5, 5)],
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]},
+    }
+    torch.manual_seed(0)
+    b = {
+        "a": torch.rand(10, 30),
+        "b": [torch.rand(10, 30), torch.rand(5, 5)],
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]},
+    }
+
+    dim = torch.export.Dim("dim", min=1, max=50)
+    dynamic_shape = {"a": {1: dim}, "b": [{}, {}], "c": {"a": {}, "b": [{}, {}]}}
+    assertions.assertFalse(
+        torch_trt.MutableTorchTensorRTModule.check_inputs_equal(a, b),
+        msg="test_check_input_shape_dynamic is not correct.",
+    )
+    assertions.assertTrue(
+        torch_trt.MutableTorchTensorRTModule.check_inputs_equal(a, b, dynamic_shape),
+        msg="test_check_input_shape_dynamic is not correct.",
+    )
+
+
+@pytest.mark.unit
+def test_model_complex_dynamic_shape():
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, a, b, c=None):
+            x = torch.matmul(a, b)
+            x = torch.matmul(c["a"], c["b"][0].T)
+            x = 2 * c["b"][1]
+            return x
+
+    model = Model().eval().cuda()
+    inputs = [torch.rand(10, 3)]
+    kwargs = {
+        "b": torch.rand(3, 30),
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 3)]},
+    }
+
+    dim = torch.export.Dim("dim", min=1, max=50)
+    dim2 = torch.export.Dim("dim2", min=1, max=50)
+    args_dynamic_shapes = ({1: dim},)
+    kwarg_dynamic_shapes = {
+        "b": {0: dim},
+        "c": {"a": {}, "b": [{}, {1: dim2}]},
+    }
+    # Export the model first with custom dynamic shape constraints
+    # exp_program = torch.export.export(model, tuple(inputs), kwargs=k
+    trt_gm = torch_trt.MutableTorchTensorRTModule(model, debug=True)
+    trt_gm.set_dynamic_shape_hint(args_dynamic_shapes, kwarg_dynamic_shapes)
+    # Run inference
+    trt_gm(*inputs, **kwargs)
+
+    inputs_2 = [torch.rand(10, 9)]
+    kwargs_2 = {
+        "b": torch.rand(9, 30),
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 20)]},
+    }
+
+    kwargs = torch_trt.MutableTorchTensorRTModule.process_kwarg_inputs(kwargs_2)
+    trt_gm._validate_inputs(*inputs_2, **kwargs_2)
+    assertions.assertTrue(
+        trt_gm.refit_state.get_state() == RefitFlag.LIVE,
+        msg="Dynamic shape support is not correct.",
+    )
+    trt_gm(*inputs_2, **kwargs_2)
+
+    # Change does not align with Dynamic Shape Hint
+    inputs_3 = [torch.rand(7, 9)]
+    kwargs_3 = {
+        "b": torch.rand(9, 30),
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 20)]},
+    }
+
+    kwargs = torch_trt.MutableTorchTensorRTModule.process_kwarg_inputs(kwargs_3)
+    trt_gm._validate_inputs(*inputs_3, **kwargs_3)
+    assertions.assertTrue(
+        trt_gm.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE,
+        msg="Dynamic shape support is not correct.",
+    )
+    trt_gm(*inputs_3, **kwargs_3)
+
+    # Stored input is changed (inputs first dimension is 7)
+    inputs_4 = [torch.rand(7, 20)]
+    kwargs_4 = {
+        "b": torch.rand(20, 30),
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 20)]},
+    }
+
+    kwargs = torch_trt.MutableTorchTensorRTModule.process_kwarg_inputs(kwargs_4)
+    trt_gm._validate_inputs(*inputs_4, **kwargs_4)
+    assertions.assertTrue(
+        trt_gm.refit_state.get_state() == RefitFlag.LIVE,
+        msg="Dynamic shape support is not correct.",
+    )
+    trt_gm(*inputs_4, **kwargs_4)
+
+    # Change outside of the dynamic range limit
+    inputs_5 = [torch.rand(7, 900)]
+    kwargs_5 = {
+        "b": torch.rand(900, 30),
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 20)]},
+    }
+
+    kwargs = torch_trt.MutableTorchTensorRTModule.process_kwarg_inputs(kwargs_5)
+    trt_gm._validate_inputs(*inputs_5, **kwargs_5)
+    assertions.assertTrue(
+        trt_gm.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE,
+        msg="Dynamic shape support is not correct.",
+    )
+    trt_gm(*inputs_5, **kwargs_5)
+
+
 @unittest.skipIf(
     not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
     "TorchScript Frontend is not available",
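
Reviewer note: a minimal usage sketch of the new set_dynamic_shape_hint API, distilled from the test above. The hint structure mirrors torch.export dynamic-shape specs: one dict per positional argument (dimension index -> torch.export.Dim) and a nested dict/list for keyword arguments. The model name and tensor shapes below are illustrative placeholders, not part of this change.

    import torch
    import torch_tensorrt as torch_trt

    # Hypothetical module with forward(a, b); any nn.Module with matching signature works.
    model = MyModel().eval().cuda()

    # Dimension that may vary between runs, same range used in the test above.
    dim = torch.export.Dim("dim", min=1, max=50)

    mutable_module = torch_trt.MutableTorchTensorRTModule(model)
    # Dim 1 of positional arg `a` and dim 0 of kwarg `b` are dynamic in [1, 50];
    # every dimension not listed in the hint stays static.
    mutable_module.set_dynamic_shape_hint(({1: dim},), {"b": {0: dim}})

    out = mutable_module(torch.rand(10, 3), b=torch.rand(3, 30))
    # A later call whose dynamic dims stay inside the hinted range reuses the
    # compiled engine; a value outside [1, 50] triggers a static-shape recompile.
    out = mutable_module(torch.rand(10, 9), b=torch.rand(9, 30))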