diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst
index a9eab6d698..c4d2baf0e4 100644
--- a/examples/dynamo/README.rst
+++ b/examples/dynamo/README.rst
@@ -21,4 +21,4 @@ Model Zoo
 * :ref:`_torch_export_gpt2`: Compiling a GPT2 model using AOT workflow (`ir=dynamo`)
 * :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
 * :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`)
-* :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)
\ No newline at end of file
+* :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)
diff --git a/examples/dynamo/mutable_torchtrt_module_example.py b/examples/dynamo/mutable_torchtrt_module_example.py
index 8b62855c32..766fd029ae 100644
--- a/examples/dynamo/mutable_torchtrt_module_example.py
+++ b/examples/dynamo/mutable_torchtrt_module_example.py
@@ -14,6 +14,7 @@
 1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
 2. Save a Mutable Torch TensorRT Module
 3. Integration with Huggingface pipeline in LoRA use case
+4. Usage of dynamic shape with Mutable Torch TensorRT Module
 """

 import numpy as np
@@ -63,7 +64,7 @@
 # Saving Mutable Torch TensorRT Module
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-# Currently, saving is only enabled for C++ runtime, not python runtime.
+# Currently, saving is only supported when "use_python_runtime" = False in settings
 torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
 reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")

@@ -71,8 +72,6 @@
 # Stable Diffusion with Huggingface
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-# The LoRA checkpoint is from https://civitai.com/models/12597/moxin
-
 from diffusers import DiffusionPipeline

 with torch.no_grad():
@@ -83,15 +82,13 @@
         "immutable_weights": False,
     }

-    model_id = "runwayml/stable-diffusion-v1-5"
+    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
     device = "cuda:0"

-    prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
-    negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, out of focus, cloudy, (watermark:2),"
+    prompt = "cinematic photo elsa, police uniform , . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
+    negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude"

-    pipe = DiffusionPipeline.from_pretrained(
-        model_id, revision="fp16", torch_dtype=torch.float16
-    )
+    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
     pipe.to(device)

     # The only extra line you need
@@ -103,7 +100,7 @@
     # Standard Huggingface LoRA loading procedure
     pipe.load_lora_weights(
         "stablediffusionapi/load_lora_embeddings",
-        weight_name="moxin.safetensors",
+        weight_name="all-disney-princess-xl-lo.safetensors",
         adapter_name="lora1",
     )
     pipe.set_adapters(["lora1"], adapter_weights=[1])
@@ -113,3 +110,45 @@
     # Refit triggered
     image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
     image.save("./with_LoRA_mutable.jpg")
+
+
+# %%
+# Use Mutable Torch TensorRT module with dynamic shape
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+class Model(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, a, b, c={}):
+        x = torch.matmul(a, b)
+        x = torch.matmul(c["a"], c["b"].T)
+        print(c["b"][0])
+        x = 2 * c["b"]
+        return x
+
+
+device = "cuda:0"
+model = Model().eval().to(device)
+inputs = (torch.rand(10, 3).to(device), torch.rand(3, 30).to(device))
+kwargs = {
+    "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(10, 30).to(device)},
+}
+dim_0 = torch.export.Dim("dim", min=1, max=50)
+dim_1 = dim_0  # a's dim 1 and b's dim 0 are tied together by the matmul, so they share one Dim
+dim_2 = torch.export.Dim("dim2", min=1, max=50)
+args_dynamic_shapes = ({1: dim_1}, {0: dim_0})
+kwarg_dynamic_shapes = {
+    "c": {"a": {}, "b": {0: dim_2}},
+}
+# Export the model first with custom dynamic shape constraints
+model = torch_trt.MutableTorchTensorRTModule(model, debug=True, min_block_size=1)
+model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
+# Compile
+model(*inputs, **kwargs)
+# Change input shape
+inputs_2 = (torch.rand(10, 5).to(device), torch.rand(5, 30).to(device))
+kwargs_2 = {
+    "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(5, 30).to(device)},
+}
+# Run without recompiling
+model(*inputs_2, **kwargs_2)
diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 64c2382582..96fc6daad2 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -395,9 +395,12 @@ def refit_module_weights(
     try:
         weight_name_map = compiled_submodule.weight_name_map
     except AttributeError:
-        logger.warning(
-            "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
-        )
+        if not isinstance(
+            compiled_submodule, torch.fx.graph_module.GraphModule
+        ):
+            logger.warning(
+                "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
+            )
     if not weight_name_map:
         use_weight_map_cache = False
         logger.warning(
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 489c64190f..2f35a6d124 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -375,7 +375,10 @@ def _construct_trt_network_def(self) -> None:

     @staticmethod
     def find_weight(
-        weight_name: str, np_map: dict[str, Any], state_dict: dict[str, Any]
+        weight_name: str,
+        np_map: dict[str, Any],
+        state_dict: dict[str, Any],
+        device: torch.device,
     ) -> str:
         """
         We need to build map from engine weight name to state_dict weight name.
@@ -385,19 +388,21 @@ def find_weight(
         np_map: the map from weight name to np values in INetworkDefinition
         state_dict: state of the graph module
         """
-        network_weight = torch.from_numpy(np_map[weight_name]).cuda()
+        network_weight = torch.from_numpy(np_map[weight_name]).to(device)
         for sd_w_name, sd_weight in state_dict.items():
-            if TRTInterpreter.check_weight_equal(sd_weight, network_weight):
+            if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
                 del state_dict[sd_w_name]
                 return sd_w_name
         return ""

     @staticmethod
     def check_weight_equal(
-        sd_weight: torch.tensor, network_weight: Union[torch.Tensor, np.ndarray]
+        sd_weight: torch.Tensor,
+        network_weight: Union[torch.Tensor, np.ndarray],
+        device: torch.device,
     ) -> Any:
         if not isinstance(network_weight, torch.Tensor):
-            network_weight = torch.from_numpy(network_weight).cuda()
+            network_weight = torch.from_numpy(network_weight).to(device)
         try:
             return sd_weight.shape == network_weight.shape and torch.all(
                 torch.abs(sd_weight - network_weight) < 0.01
@@ -530,10 +535,10 @@ def _save_weight_mapping(self) -> None:
                 # There is no direct connection in batch_norm layer. So skip it
                 pass
             elif sd_weight_name not in sd or not TRTInterpreter.check_weight_equal(
-                sd[sd_weight_name], np_map[engine_weight_name]
+                sd[sd_weight_name], np_map[engine_weight_name], torch_device
             ):
                 weight_name_map[engine_weight_name] = TRTInterpreter.find_weight(
-                    engine_weight_name, np_map, sd
+                    engine_weight_name, np_map, sd, torch_device
                 )
             if (
                 weight_name_map[engine_weight_name] != ""
diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
index a0e570e992..ff39746671 100644
--- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
@@ -1,3 +1,4 @@
+import inspect
 import logging
 from copy import deepcopy
 from enum import Enum, auto
@@ -41,6 +42,10 @@ def get_state(self) -> RefitFlag:
         return self._state


+class DynamicShapeOutOfRangeException(Exception):
+    pass
+
+
 class MutableTorchTensorRTModule(object):
     """
     Initialize a MutableTorchTensorRTModule to seamlessly manipulate it like a regular PyTorch module.
@@ -65,7 +70,7 @@ def __init__(
             Union[torch.dtype, dtype]
         ] = _defaults.ENABLED_PRECISIONS,
         engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-        immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
+        immutable_weights: bool = False,
         debug: bool = _defaults.DEBUG,
         num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
         workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -189,11 +194,13 @@ def __init__(
             "hardware_compatible": hardware_compatible,
             "timing_cache_path": timing_cache_path,
         }
+        self.arg_dynamic_shapes: Optional[tuple[Any, ...]] = None
+        self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None
         self.settings = CompilationSettings(**compilation_options)
         self.run_info: Optional[tuple[Any, ...]] = None
         self.state_dict_metadata: dict[str, torch.Size] = {}
-        self.store_state_dict_metadata()
+        self._store_state_dict_metadata()

         cls = self.__class__
         self.__class__ = type(
@@ -203,7 +210,64 @@ def __init__(
         )
         self.init_finished = True

-    def store_state_dict_metadata(self) -> None:
+    def set_expected_dynamic_shape_range(
+        self,
+        args_dynamic_shape: tuple[dict[Any, Any]],
+        kwargs_dynamic_shape: dict[str, Any],
+    ) -> None:
+        """
+        Set the dynamic shape range. The shape hint should EXACTLY follow arg_inputs and kwarg_inputs passed to the forward function
+        and should not omit any entries. If a dynamic shape is not required for an input, an empty dictionary should be given
+        as the shape hint for that input.
+
+        Example:
+            def forward(a, b, c=0, d=0):
+                pass
+
+            seq_len = torch.export.Dim("seq_len", min=1, max=10)
+            args_dynamic_shape = ({0: seq_len}, {})  # b does not have a dynamic shape
+            kwargs_dynamic_shape = {'c': {0: seq_len}, 'd': {}}  # d does not have a dynamic shape
+            # Later when you call the function
+            forward(*(a, b), **{'c': ..., 'd': ...})
+
+        Arguments:
+            args_dynamic_shape (tuple[dict[Any, Any]]): Dynamic shape hint for the arg_inputs
+            kwargs_dynamic_shape (dict[str, Any]): Dynamic shape hint for the kwarg_inputs
+        """
+        assert isinstance(
+            args_dynamic_shape, tuple
+        ), "args dynamic shape has to be a tuple"
+        assert isinstance(
+            kwargs_dynamic_shape, dict
+        ), "kwargs dynamic shape has to be a dictionary"
+        self.kwarg_dynamic_shapes = kwargs_dynamic_shape
+        self.arg_dynamic_shapes = args_dynamic_shape
+
+        # Clear cached inputs
+        self.arg_inputs = tuple()
+        self.kwarg_inputs = {}
+
+        self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+
+    def _get_total_dynamic_shapes(self) -> dict[str, Any] | None:
+        if not self.arg_dynamic_shapes and not self.kwarg_dynamic_shapes:
+            return None
+        total_dynamic_shape = {}
+        if self.arg_dynamic_shapes:
+            signature = list(
+                inspect.signature(self.original_model.forward).parameters.keys()
+            )
+            for i, arg in enumerate(self.arg_dynamic_shapes):
+                total_dynamic_shape[signature[i]] = arg
+
+        if self.kwarg_dynamic_shapes:
+            for kwargs, kwargs_dynamic_shape in self.kwarg_dynamic_shapes.items():
+                total_dynamic_shape[kwargs] = kwargs_dynamic_shape
+
+        return total_dynamic_shape
+
+    def _store_state_dict_metadata(self) -> None:
         for k, v in self.original_model.state_dict().items():
             self.state_dict_metadata[k] = v.shape

@@ -295,6 +359,7 @@ def compile(self) -> None:
             self.original_model,
             self.arg_inputs,
             kwargs=self.kwarg_inputs,
+            dynamic_shapes=self._get_total_dynamic_shapes(),
         )
         self.gm = dynamo_compile(
             self.exp_program,
@@ -306,28 +371,75 @@ def compile(self) -> None:
         torch.cuda.empty_cache()

     def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
-        if (
-            not self.arg_inputs
-            or not MutableTorchTensorRTModule.check_inputs_equal(self.arg_inputs, args)
-            or not MutableTorchTensorRTModule.check_inputs_equal(
-                self.kwarg_inputs, kwargs
-            )
-        ):
+
+        if not self.arg_inputs:
+            logger.info("First-time compilation initiated. This may take some time.")
+            self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+            self._store_inputs(args, kwargs)
+            if self.arg_dynamic_shapes or self.kwarg_dynamic_shapes:
+                if not self._validate_dynamic_hints():
+                    logger.warning(
+                        "Invalid dynamic shape hint. Compiling module for the provided input shapes (static)"
+                    )
+                    self.arg_dynamic_shapes = None
+                    self.kwarg_dynamic_shapes = None
+            return
+
+        # If the inputs do not match the stored inputs, or do not fall into the dynamic shape range, recompile the engine
+        try:
+            if not MutableTorchTensorRTModule._check_inputs_shape(
+                self.arg_inputs, args, dynamic_shapes=self.arg_dynamic_shapes
+            ) or not MutableTorchTensorRTModule._check_inputs_shape(
+                self.kwarg_inputs, kwargs, dynamic_shapes=self.kwarg_dynamic_shapes
+            ):
+                logger.info("Input change detected.")
+                self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
+                self._store_inputs(args, kwargs)
+        except DynamicShapeOutOfRangeException as e:
             logger.info("Input change detected.")
+            logger.warning(e)
+            logger.warning(
+                "Provided inputs are outside the expected shape range; recompiling module for the provided input shapes (static)"
+            )
+            self.arg_dynamic_shapes = None
+            self.kwarg_dynamic_shapes = None
             self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
-            self.store_inputs(args, kwargs)
+            self._store_inputs(args, kwargs)
+
+    def _validate_dynamic_hints(self) -> bool:
+        if self.arg_dynamic_shapes is None:
+            if self.arg_inputs:
+                logger.warning("arg_dynamic_shape is not provided!")
+        else:
+            if len(self.arg_dynamic_shapes) != len(self.arg_inputs):
+                logger.warning(
+                    f"The length of arg_inputs is {len(self.arg_inputs)} but the length of arg_dynamic_shape is {len(self.arg_dynamic_shapes)}!"
+                )
+                return False
+
+        if self.kwarg_dynamic_shapes is None:
+            if self.kwarg_inputs:
+                logger.warning("kwarg_dynamic_shape is not provided!")
+        else:
+            if self.kwarg_dynamic_shapes.keys() != self.kwarg_inputs.keys():
+                logger.warning(
+                    f"kwarg_inputs has {list(self.kwarg_inputs.keys())} but kwarg_dynamic_shape has {list(self.kwarg_dynamic_shapes.keys())}!"
+                )
+                return False
+
+        return True

-    def store_inputs(self, arg_inputs: Any, kwarg_inputs: Any) -> None:
+    def _store_inputs(self, arg_inputs: Any, kwarg_inputs: Any) -> None:
         self.arg_inputs = arg_inputs
         self.kwarg_inputs = kwarg_inputs

     @staticmethod
-    def process_kwarg_inputs(inputs: Any) -> Any:
+    def _process_kwarg_inputs(inputs: Any) -> Any:
         # Process kwarg inputs to be acceptable for Torch-TensorRT
         if isinstance(inputs, dict):
             # None should be excluded. AOT compile also does not allow dynamic control flow, bool is also excluded.
             return {
-                k: MutableTorchTensorRTModule.process_kwarg_inputs(v)
+                k: MutableTorchTensorRTModule._process_kwarg_inputs(v)
                 for k, v in inputs.items()
                 if (v is not None and not isinstance(v, bool))
             }
@@ -338,7 +450,10 @@ def process_kwarg_inputs(inputs: Any) -> Any:
         elif isinstance(inputs, (list, tuple)):
             if None not in inputs:
                 return type(inputs)(
-                    [MutableTorchTensorRTModule.process_kwarg_inputs(v) for v in inputs]
+                    [
+                        MutableTorchTensorRTModule._process_kwarg_inputs(v)
+                        for v in inputs
+                    ]
                 )

         raise ValueError(
@@ -348,7 +463,7 @@ def process_kwarg_inputs(inputs: Any) -> Any:

     def forward(self, *args: Any, **kwargs: Any) -> Any:
         # Step 1: Check whether the input shape has changed
-        kwargs = MutableTorchTensorRTModule.process_kwarg_inputs(kwargs)
+        kwargs = MutableTorchTensorRTModule._process_kwarg_inputs(kwargs)
         self._validate_inputs(*args, **kwargs)

         # Step 2: If the flag is unknown, it could be a recompile or refit.
@@ -360,7 +475,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         if self.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE:
             logger.info("(Re)Compiling the engine...")
             self.compile()
-            self.store_state_dict_metadata()
+            self._store_state_dict_metadata()
             self.refit_state.set_state(RefitFlag.LIVE)

         elif self.refit_state.get_state() == RefitFlag.NEEDS_REFIT:
@@ -371,7 +486,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
                 logger.error(e)
                 logger.error("Model refit failed. Recompiling the graph module.")
                 self.compile()
-                self.store_state_dict_metadata()
+                self._store_state_dict_metadata()
                 self.refit_state.set_state(RefitFlag.LIVE)

         result = self.gm(*args, **kwargs)
@@ -381,7 +496,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:

     def to(self, device: str) -> None:
         logger.warning("Original PyTorch model is moved. CPU offload may fail.")
-        self.orignial_model.to(device)
+        self.original_model.to(device)

     def __deepcopy__(self, memo: Any) -> Any:
         cls = self.__class__
@@ -433,36 +548,78 @@ def __setattr__(self, name: str, value: Any) -> None:
             object.__setattr__(self, name, value)

     @staticmethod
-    def check_inputs_equal(
+    def _check_inputs_shape(
         input1: Any,
         input2: Any,
+        dynamic_shapes: Any = None,
     ) -> bool:
-        # TODO: Add support for dynamic shape
+
         if isinstance(input1, (tuple, list)):
             if len(input1) != len(input2):
                 return False
-            for a, b in zip(input1, input2):
+            for (i, a), b in zip(enumerate(input1), input2):
                 if type(a) != type(b):
                     return False
-                if isinstance(a, torch.Tensor) and a.shape != b.shape:
-                    return False
-                elif isinstance(a, bool) and a != b:
+                if isinstance(a, bool) and a != b:
                     return False
+                elif isinstance(a, torch.Tensor) and a.shape != b.shape:
+                    if dynamic_shapes is None:
+                        logger.warning(
+                            "Dynamic shape is not properly set but the input shape has changed!"
+                        )
+                        return False
+                    else:
+                        tensor_dynamic_shape = dynamic_shapes[i]
+                        if not MutableTorchTensorRTModule._check_tensor_shapes_with_dynamic_shapes(
+                            a, b, tensor_dynamic_shape
+                        ):
+                            return False

         elif isinstance(input1, dict):
             if input1.keys() != input2.keys():
                 return False
-            for a, b in zip(input1.values(), input2.values()):
-                if type(a) != type(b):
+            for (ka, va), vb in zip(input1.items(), input2.values()):
+                if type(va) != type(vb):
                     return False
-                if isinstance(a, torch.Tensor) and a.shape != b.shape:
-                    return False
-                elif isinstance(a, bool) and a != b:
+                if isinstance(va, bool) and va != vb:
                     return False
+                elif isinstance(va, torch.Tensor) and va.shape != vb.shape:
+                    if dynamic_shapes is None:
+                        logger.warning(
+                            "Dynamic shape is not properly set but the input shape has changed!"
+                        )
+                        return False
+                    else:
+                        tensor_dynamic_shape = dynamic_shapes[ka]
+                        if not MutableTorchTensorRTModule._check_tensor_shapes_with_dynamic_shapes(
+                            va, vb, tensor_dynamic_shape
+                        ):
+                            return False
                 elif isinstance(
-                    a, (list, tuple, dict)
-                ) and not MutableTorchTensorRTModule.check_inputs_equal(a, b):
+                    va, (list, tuple, dict)
+                ) and not MutableTorchTensorRTModule._check_inputs_shape(
+                    va, vb, dynamic_shapes[ka] if dynamic_shapes else None
+                ):
+                    return False
+        return True
+
+    @staticmethod
+    def _check_tensor_shapes_with_dynamic_shapes(
+        t1: torch.Tensor, t2: torch.Tensor, dynamic_shape: dict[int, Any]
+    ) -> bool:
+        for (i, axis_0), axis_1 in zip(enumerate(t1.shape), t2.shape):
+            if axis_0 != axis_1:
+                if i not in dynamic_shape:
+                    logger.warning(
+                        "Dynamic shape does not include the axis on which the input changes!"
+                    )
                     return False
+                dyn = dynamic_shape[i]
+                if axis_1 > dyn.max or axis_1 < dyn.min:
+                    raise DynamicShapeOutOfRangeException(
+                        f"The input size ({axis_1}) of dimension ({i}) is not in the dynamic shape range [{dyn.min}, {dyn.max}]!"
+                    )
+        return True

     @staticmethod
diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py
index ab1137e2b3..d3ef4ee245 100644
--- a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py
+++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py
@@ -35,6 +35,160 @@ def test_check_output_equal():
         msg=f"test_check_output_equal is not correct.",
     )

+    torch.manual_seed(1)
+    c = {
+        "a": torch.rand(10, 30),
+        "b": [torch.rand(10, 30), torch.rand(5, 5)],
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]},
+    }
+    assertions.assertFalse(
+        check_output_equal(a, c),
+        msg=f"test_check_output_equal is not correct.",
+    )
+
+
+@pytest.mark.unit
+def test_check_input_shape_dynamic():
+    torch.manual_seed(0)
+    a = {
+        "a": torch.rand(10, 3),
+        "b": [torch.rand(10, 30), torch.rand(5, 5)],
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]},
+    }
+    torch.manual_seed(0)
+    b = {
+        "a": torch.rand(10, 30),
+        "b": [torch.rand(10, 30), torch.rand(5, 5)],
+        "c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]},
+    }
+
+    dim = torch.export.Dim("dim", min=1, max=50)
+    dynamic_shape = {"a": {1: dim}, "b": [{}, {}], "c": {"a": {}, "b": [{}, {}]}}
+    assertions.assertFalse(
+        torch_trt.MutableTorchTensorRTModule._check_inputs_shape(a, b),
+        msg=f"test_check_input_shape_dynamic is not correct.",
+    )
+    assertions.assertTrue(
+        torch_trt.MutableTorchTensorRTModule._check_inputs_shape(a, b, dynamic_shape),
+        msg=f"test_check_input_shape_dynamic is not correct.",
+    )
+
+
+@pytest.mark.unit
+def test_model_complex_dynamic_shape():
+    device = "cuda:0"
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, a, b, c=None):
+            x = torch.matmul(a, b)
+            x = torch.matmul(c["a"], c["b"][0].T)
+            x = 2 * c["b"][1]
+            return x
+
+    model = Model().eval().to(device)
+    inputs = [torch.rand(10, 3).to(device)]
+    kwargs = {
+        "b": torch.rand(3, 30).to(device),
+        "c": {
+            "a": torch.rand(10, 30).to(device),
+            "b": [torch.rand(10, 30).to(device), torch.rand(5, 3).to(device)],
+        },
+    }
+
+    dim = torch.export.Dim("dim", min=1, max=50)
+    dim2 = torch.export.Dim("dim2", min=1, max=50)
+    args_dynamic_shapes = ({1: dim},)
+    kwarg_dynamic_shapes = {
+        "b": {0: dim},
+        "c": {"a": {}, "b": [{}, {1: dim2}]},
+    }
+    # Export the model first with custom dynamic shape constraints
+    trt_gm = torch_trt.MutableTorchTensorRTModule(model, debug=True, min_block_size=1)
+    trt_gm.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
+    # Run inference
+    trt_gm(*inputs, **kwargs)
+
+    inputs_2 = [torch.rand(10, 9).to(device)]
+    kwargs_2 = {
+        "b": torch.rand(9, 30).to(device),
+        "c": {
+            "a": torch.rand(10, 30).to(device),
+            "b": [torch.rand(10, 30).to(device), torch.rand(5, 20).to(device)],
+        },
+    }
+
+    kwargs_2 = torch_trt.MutableTorchTensorRTModule._process_kwarg_inputs(kwargs_2)
+    trt_gm._validate_inputs(*inputs_2, **kwargs_2)
+    assertions.assertTrue(
+        trt_gm.refit_state.get_state() == RefitFlag.LIVE,
+        msg=f"Dynamic shape support of inputs_2 is not correct.",
+    )
+    trt_gm(*inputs_2, **kwargs_2)
+
+    # Change does not align with the dynamic shape hint
+    inputs_3 = [torch.rand(7, 9).to(device)]
+    kwargs_3 = {
+        "b": torch.rand(9, 30).to(device),
+        "c": {
+            "a": torch.rand(10, 30).to(device),
+            "b": [torch.rand(10, 30).to(device), torch.rand(5, 20).to(device)],
+        },
+    }
+
+    kwargs_3 = torch_trt.MutableTorchTensorRTModule._process_kwarg_inputs(kwargs_3)
+    trt_gm._validate_inputs(*inputs_3, **kwargs_3)
+    assertions.assertTrue(
+        trt_gm.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE,
+        msg=f"Dynamic shape support of inputs_3 is not correct.",
+    )
+    trt_gm(*inputs_3, **kwargs_3)
+
+    # Stored input has changed (the first dimension of the stored inputs is now 7)
+    inputs_4 = [torch.rand(7, 20).to(device)]
+    kwargs_4 = {
+        "b": torch.rand(20, 30).to(device),
+        "c": {
+            "a": torch.rand(10, 30).to(device),
+            "b": [torch.rand(10, 30).to(device), torch.rand(5, 20).to(device)],
+        },
+    }
+
+    kwargs_4 = torch_trt.MutableTorchTensorRTModule._process_kwarg_inputs(kwargs_4)
+    trt_gm._validate_inputs(*inputs_4, **kwargs_4)
+    assertions.assertTrue(
+        trt_gm.refit_state.get_state() == RefitFlag.LIVE,
+        msg=f"Dynamic shape support of inputs_4 is not correct.",
+    )
+    trt_gm(*inputs_4, **kwargs_4)
+
+    # Change outside of the dynamic range limit
+    inputs_5 = [torch.rand(7, 900).to(device)]
+    kwargs_5 = {
+        "b": torch.rand(900, 30).to(device),
+        "c": {
+            "a": torch.rand(10, 30).to(device),
+            "b": [torch.rand(10, 30).to(device), torch.rand(5, 20).to(device)],
+        },
+    }
+
+    kwargs_5 = torch_trt.MutableTorchTensorRTModule._process_kwarg_inputs(kwargs_5)
+    trt_gm._validate_inputs(*inputs_5, **kwargs_5)
+    assertions.assertTrue(
+        trt_gm.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE,
+        msg=f"Dynamic shape support of inputs_5 is not correct.",
+    )
+    assertions.assertTrue(
+        trt_gm.arg_dynamic_shapes is None,
+        msg=f"Dynamic shape support of inputs_5 is not correct.",
+    )
+    assertions.assertTrue(
+        trt_gm.kwarg_dynamic_shapes is None,
+        msg=f"Dynamic shape support of inputs_5 is not correct.",
+    )


 @unittest.skipIf(
     not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
@@ -188,7 +342,7 @@ def test_resnet18_modify_attribute_no_refit():
     for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
         assertions.assertTrue(
             torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
-            msg=f"The output of refitted Mutable Module is not correct.",
+            msg=f"The outputs of the original and the refitted Mutable Module do not match.",
         )

     # # Clean up model env
@@ -255,7 +409,7 @@ def forward(self, x, b=5, c=None, d=None):
     )
     assertions.assertTrue(
         check_output_equal(expected_outputs, refitted_outputs),
-        msg=f"The output of saved and reloaded Mutable Module is not correct.",
+        msg=f"The outputs of the original and the refitted Mutable Module do not match.",
     )

     # Clean up model env
@@ -318,7 +472,7 @@ def set_weights(self):
     expected_outputs, refitted_outputs = model(*args), mutable_module(*args)
     assertions.assertTrue(
         check_output_equal(expected_outputs, refitted_outputs),
-        msg=f"The output of saved and reloaded Mutable Module is not correct.",
+        msg=f"The outputs of the original and the refitted Mutable Module do not match.",
     )

     # Clean up model env
@@ -381,7 +535,7 @@ def set_layer(self):
     expected_outputs, refitted_outputs = model(*args), mutable_module(*args)
     assertions.assertTrue(
         check_output_equal(expected_outputs, refitted_outputs),
-        msg=f"The output of saved and reloaded Mutable Module is not correct.",
+        msg=f"The outputs of the original and the refitted Mutable Module do not match.",
     )

     # Clean up model env
@@ -451,7 +605,7 @@ def forward(self, x, b=5, c=None, d=None):
     )
     assertions.assertTrue(
         check_output_equal(expected_outputs, refitted_outputs),
-        msg=f"The output of saved and reloaded Mutable Module is not correct.",
+        msg=f"The outputs of the original and the refitted Mutable Module do not match.",
    )

     # Clean up model env
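Usage note: the per-argument hints passed to `set_expected_dynamic_shape_range` are merged by `_get_total_dynamic_shapes` into the single `dynamic_shapes` mapping that `torch.export.export` expects, keyed by the parameter names of the original `forward`. A minimal sketch of that equivalence; the module `M` and the shapes below are illustrative assumptions, not part of this patch:

    import torch

    class M(torch.nn.Module):
        def forward(self, a, b):
            return torch.matmul(a, b)

    # a's dim 1 and b's dim 0 are tied together by the matmul, so they share one Dim
    dim = torch.export.Dim("dim", min=1, max=50)

    # Hint style used by MutableTorchTensorRTModule:
    #   args_dynamic_shapes = ({1: dim}, {0: dim})
    # After mapping each hint onto forward's signature (a, b), this is the
    # plain torch.export form that compile() ends up passing:
    exported = torch.export.export(
        M(),
        (torch.rand(10, 3), torch.rand(3, 30)),
        dynamic_shapes={"a": {1: dim}, "b": {0: dim}},
    )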
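Usage note: `_validate_inputs` catches `DynamicShapeOutOfRangeException` internally, so callers never see the exception; the module instead drops its shape hints and recompiles statically for the offending shapes. A sketch of the observable behavior, reusing `model`, `inputs`, and `kwargs` from the dynamic-shape example in the diff above (device and shapes are illustrative):

    # First call: compiles an engine that honors the hinted [1, 50] ranges.
    model(*inputs, **kwargs)

    # In-range change (shared dim 3 -> 5): the existing engine is reused.
    model(torch.rand(10, 5).to(device), torch.rand(5, 30).to(device),
          c={"a": torch.rand(10, 30).to(device), "b": torch.rand(5, 30).to(device)})

    # Out-of-range change (shared dim -> 90 > max 50): no exception reaches the
    # caller; the hints are cleared and the module recompiles statically.
    model(torch.rand(10, 90).to(device), torch.rand(90, 30).to(device),
          c={"a": torch.rand(10, 30).to(device), "b": torch.rand(5, 30).to(device)})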
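Usage note: the updated comment in the example ties saving to `use_python_runtime` being disabled. A minimal save/load sketch under that constraint, assuming the constructor forwards `use_python_runtime` to the compiler like the other compilation settings shown in this diff:

    mutable_module = torch_trt.MutableTorchTensorRTModule(
        model,
        use_python_runtime=False,  # saving is only supported with the C++ runtime
        immutable_weights=False,   # keep weights refittable
        min_block_size=1,
    )
    mutable_module(*inputs, **kwargs)  # build the engine before saving
    torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
    reloaded = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")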