Added dynamic shape support to MutableTorchTensorRTModule
cehongwang committed Feb 13, 2025
1 parent 1e356e4 commit 9580f92
Showing 2 changed files with 219 additions and 20 deletions.
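For context, here is a minimal usage sketch of the dynamic-shape hint API introduced in this commit, distilled from the test added below. The model, tensor shapes, and variable names are illustrative and not part of the commit itself.

import torch
import torch_tensorrt as torch_trt

class Model(torch.nn.Module):
    def forward(self, a, b):
        return torch.matmul(a, b)

model = Model().eval().cuda()

# Dimension 1 of `a` and dimension 0 of `b` may vary between 1 and 50.
dim = torch.export.Dim("dim", min=1, max=50)
args_dynamic_shapes = ({1: dim},)       # positional args, keyed by dimension index
kwarg_dynamic_shapes = {"b": {0: dim}}  # keyword args, keyed by parameter name

mutable_module = torch_trt.MutableTorchTensorRTModule(model)
mutable_module.set_dynamic_shape_hint(args_dynamic_shapes, kwarg_dynamic_shapes)

# The first call compiles the engine; later calls whose shapes stay inside the
# hinted range reuse it, while out-of-range shapes fall back to a static-shape
# recompile (see _validate_inputs below).
mutable_module(torch.rand(10, 3), b=torch.rand(3, 30))
mutable_module(torch.rand(10, 9), b=torch.rand(9, 30))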
119 changes: 99 additions & 20 deletions py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
@@ -1,3 +1,4 @@
import inspect
import logging
from copy import deepcopy
from enum import Enum, auto
@@ -41,6 +42,10 @@ def get_state(self) -> RefitFlag:
return self._state


class DynamicShapeOutOfRangeException(Exception):
pass


class MutableTorchTensorRTModule(object):
"""
Initialize a MutableTorchTensorRTModule to seamlessly manipulate it like a regular PyTorch module.
@@ -65,7 +70,7 @@ def __init__(
Union[torch.dtype, dtype]
] = _defaults.ENABLED_PRECISIONS,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
immutable_weights: bool = False,
debug: bool = _defaults.DEBUG,
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -189,6 +194,9 @@ def __init__(
"hardware_compatible": hardware_compatible,
"timing_cache_path": timing_cache_path,
}
self.arg_dynamic_shapes: Optional[tuple[Any]] = None
self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None
self.total_dynamic_shape: Optional[dict[Any, Any]] = None

self.settings = CompilationSettings(**compilation_options)
self.run_info: Optional[tuple[Any, ...]] = None
@@ -203,6 +211,31 @@ def __init__(
)
self.init_finished = True

def set_dynamic_shape_hint(
self,
args_dynamic_shape: tuple[dict[Any, Any]],
kwargs_dynamic_shape: dict[str, Any],
) -> None:
assert isinstance(
args_dynamic_shape, tuple
), "args dynamic shape has to be a tuple"
assert isinstance(
kwargs_dynamic_shape, dict
), "args dynamic shape has to be a dictionary"
self.kwarg_dynamic_shapes = kwargs_dynamic_shape
self.arg_dynamic_shapes = args_dynamic_shape
self.total_dynamic_shape = self.kwarg_dynamic_shapes.copy()
signature = list(
inspect.signature(self.original_model.forward).parameters.keys()
)
for i, arg in enumerate(self.arg_dynamic_shapes):
self.total_dynamic_shape[signature[i]] = arg
self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)

# Clear cached inputs
self.arg_inputs = tuple()
self.kwarg_inputs = {}

def store_state_dict_metadata(self) -> None:
for k, v in self.original_model.state_dict().items():
self.state_dict_metadata[k] = v.shape
@@ -295,6 +328,7 @@ def compile(self) -> None:
self.original_model,
self.arg_inputs,
kwargs=self.kwarg_inputs,
dynamic_shapes=self.total_dynamic_shape,
)
self.gm = dynamo_compile(
self.exp_program,
@@ -306,14 +340,26 @@ def compile(self) -> None:
torch.cuda.empty_cache()

def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
if (
not self.arg_inputs
or not MutableTorchTensorRTModule.check_inputs_equal(self.arg_inputs, args)
or not MutableTorchTensorRTModule.check_inputs_equal(
self.kwarg_inputs, kwargs
)
):
try:
if (
not self.arg_inputs
or not MutableTorchTensorRTModule.check_inputs_equal(
self.arg_inputs, args, dynamic_shapes=self.arg_dynamic_shapes
)
or not MutableTorchTensorRTModule.check_inputs_equal(
self.kwarg_inputs, kwargs, dynamic_shapes=self.kwarg_dynamic_shapes
)
):
logger.info("Input change detected.")
self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
self.store_inputs(args, kwargs)
except DynamicShapeOutOfRangeException as e:
logger.info("Input change detected.")
logger.warning(e)
logger.warning("Recompiling the engine with static shape")
self.arg_dynamic_shapes = None
self.kwarg_dynamic_shapes = None
self.total_dynamic_shape = None
self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
self.store_inputs(args, kwargs)

@@ -436,35 +482,68 @@ def __setattr__(self, name: str, value: Any) -> None:
def check_inputs_equal(
input1: Any,
input2: Any,
dynamic_shapes: Any = None,
) -> bool:
# TODO: Add support for dynamic shape

if isinstance(input1, (tuple, list)):
if len(input1) != len(input2):
return False
for a, b in zip(input1, input2):
for (i, a), b in zip(enumerate(input1), input2):
if type(a) != type(b):
return False
if isinstance(a, torch.Tensor) and a.shape != b.shape:
return False
elif isinstance(a, bool) and a != b:
if isinstance(a, bool) and a != b:
return False
elif isinstance(a, torch.Tensor) and a.shape != b.shape:
if dynamic_shapes is None:
return False
else:
tensor_dynamic_shape = dynamic_shapes[i]
if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
a, b, tensor_dynamic_shape
):
return False

elif isinstance(input1, dict):
if input1.keys() != input2.keys():
return False
for a, b in zip(input1.values(), input2.values()):
if type(a) != type(b):
return False
if isinstance(a, torch.Tensor) and a.shape != b.shape:
for (ka, va), vb in zip(input1.items(), input2.values()):
if type(va) != type(vb):
return False
elif isinstance(a, bool) and a != b:
if isinstance(va, bool) and va != vb:
return False
elif isinstance(va, torch.Tensor) and va.shape != vb.shape:
if dynamic_shapes is None:
return False
else:
tensor_dynamic_shape = dynamic_shapes[ka]
if not MutableTorchTensorRTModule.check_tensor_shapes_with_dynamic_shapes(
va, vb, tensor_dynamic_shape
):
return False
elif isinstance(
a, (list, tuple, dict)
) and not MutableTorchTensorRTModule.check_inputs_equal(a, b):
va, (list, tuple, dict)
) and not MutableTorchTensorRTModule.check_inputs_equal(
va, vb, dynamic_shapes[ka] if dynamic_shapes else None
):
return False
return True

@staticmethod
def check_tensor_shapes_with_dynamic_shapes(
t1: torch.Tensor, t2: torch.Tensor, dynamic_shape: dict[int, Any]
) -> bool:
for (i, axis_0), axis_1 in zip(enumerate(t1.shape), t2.shape):
if axis_0 != axis_1:
if i not in dynamic_shape:
return False
dyn = dynamic_shape[i]
if axis_1 > dyn.max or axis_1 < dyn.min:
raise DynamicShapeOutOfRangeException(
f"The input size ({axis_1}) of dimension ({i}) is not in dynamic shape range [{dyn.max}, {dyn.max}]!"
)

return True

@staticmethod
def save(module: Any, path: str) -> None:
# Cast the object back to MutableTorchTensorRTModule to save
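As an aside, the way set_dynamic_shape_hint above folds the positional hints into a single name-keyed dict (which is then passed to torch.export.export as dynamic_shapes) can be illustrated standalone. This helper is only a sketch of that mechanism and is not part of the module's API.

import inspect

import torch

def merge_dynamic_shapes(forward_fn, args_dynamic_shape, kwargs_dynamic_shape):
    # Mirror of the logic in set_dynamic_shape_hint: keyword hints are kept
    # as-is, and each positional hint is keyed by the matching parameter
    # name taken from the forward signature.
    total = dict(kwargs_dynamic_shape)
    params = list(inspect.signature(forward_fn).parameters.keys())
    for i, shape in enumerate(args_dynamic_shape):
        total[params[i]] = shape
    return total

def forward(a, b, c=None):  # stand-in signature for the example
    pass

dim = torch.export.Dim("dim", min=1, max=50)
merged = merge_dynamic_shapes(forward, ({1: dim},), {"b": {0: dim}})
# merged maps parameter names to per-dimension hints: {"b": {0: dim}, "a": {1: dim}}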
120 changes: 120 additions & 0 deletions tests/py/dynamo/runtime/test_mutable_torchtrt_module.py
@@ -36,6 +36,126 @@ def test_check_output_equal():
)


@pytest.mark.unit
def test_check_input_shape_dynamic():
torch.manual_seed(0)
a = {
"a": torch.rand(10, 3),
"b": [torch.rand(10, 30), torch.rand(5, 5)],
"c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]},
}
torch.manual_seed(0)
b = {
"a": torch.rand(10, 30),
"b": [torch.rand(10, 30), torch.rand(5, 5)],
"c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 5)]},
}

dim = torch.export.Dim("dim", min=1, max=50)
dynamic_shape = {"a": {1: dim}, "b": [{}, {}], "c": {"a": {}, "b": [{}, {}]}}
assertions.assertFalse(
torch_trt.MutableTorchTensorRTModule.check_inputs_equal(a, b),
msg=f"test_check_output_equal is not correct.",
)
assertions.assertTrue(
torch_trt.MutableTorchTensorRTModule.check_inputs_equal(a, b, dynamic_shape),
msg=f"test_check_output_equal is not correct.",
)


@pytest.mark.unit
def test_model_complex_dynamic_shape():
class Model(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, a, b, c=None):
x = torch.matmul(a, b)
x = torch.matmul(c["a"], c["b"][0].T)
x = 2 * c["b"][1]
return x

model = Model().eval().cuda()
inputs = [torch.rand(10, 3)]
kwargs = {
"b": torch.rand(3, 30),
"c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 3)]},
}

dim = torch.export.Dim("dim", min=1, max=50)
dim2 = torch.export.Dim("dim2", min=1, max=50)
args_dynamic_shapes = ({1: dim},)
kwarg_dynamic_shapes = {
"b": {0: dim},
"c": {"a": {}, "b": [{}, {1: dim2}]},
}
# Export the model first with custom dynamic shape constraints
# exp_program = torch.export.export(model, tuple(inputs), kwargs=k
trt_gm = torch_trt.MutableTorchTensorRTModule(model, debug=True)
trt_gm.set_dynamic_shape_hint(args_dynamic_shapes, kwarg_dynamic_shapes)
# Run inference
trt_gm(*inputs, **kwargs)

inputs_2 = [torch.rand(10, 9)]
kwargs_2 = {
"b": torch.rand(9, 30),
"c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 20)]},
}

kwargs = torch_trt.MutableTorchTensorRTModule.process_kwarg_inputs(kwargs_2)
trt_gm._validate_inputs(*inputs_2, **kwargs_2)
assertions.assertTrue(
trt_gm.refit_state.get_state() == RefitFlag.LIVE,
msg=f"Dynamic shape support is not correct.",
)
trt_gm(*inputs_2, **kwargs_2)

# Change does not align with Dynamic Shape Hint
inputs_3 = [torch.rand(7, 9)]
kwargs_3 = {
"b": torch.rand(9, 30),
"c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 20)]},
}

kwargs = torch_trt.MutableTorchTensorRTModule.process_kwarg_inputs(kwargs_3)
trt_gm._validate_inputs(*inputs_3, **kwargs_3)
assertions.assertTrue(
trt_gm.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE,
msg=f"Dynamic shape support is not correct.",
)
trt_gm(*inputs_3, **kwargs_3)

# Stored input is changed (inputs first dimension is 7)
inputs_4 = [torch.rand(7, 20)]
kwargs_4 = {
"b": torch.rand(20, 30),
"c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 20)]},
}

kwargs = torch_trt.MutableTorchTensorRTModule.process_kwarg_inputs(kwargs_4)
trt_gm._validate_inputs(*inputs_4, **kwargs_4)
assertions.assertTrue(
trt_gm.refit_state.get_state() == RefitFlag.LIVE,
msg=f"Dynamic shape support is not correct.",
)
trt_gm(*inputs_4, **kwargs_4)

# Change outside of the dynamic range limit
inputs_5 = [torch.rand(7, 900)]
kwargs_5 = {
"b": torch.rand(900, 30),
"c": {"a": torch.rand(10, 30), "b": [torch.rand(10, 30), torch.rand(5, 20)]},
}

kwargs = torch_trt.MutableTorchTensorRTModule.process_kwarg_inputs(kwargs_5)
trt_gm._validate_inputs(*inputs_5, **kwargs_5)
assertions.assertTrue(
trt_gm.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE,
msg=f"Dynamic shape support is not correct.",
)
trt_gm(*inputs_5, **kwargs_5)


@unittest.skipIf(
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
"TorchScript Frontend is not available",
