From 42234e89d0e6e740d4c3ffe4ad395010a3144cd5 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 12 Feb 2025 11:05:03 -0800
Subject: [PATCH 1/5] fix: do not cast bias nodes for addmm FP32 accumulation

---
 .../lowering/passes/accumulate_fp32_matmul.py | 28 ++++++++++++-
 .../lowering/test_aten_lowering_passes.py     | 42 +++++++++++++++++++
 2 files changed, 68 insertions(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py b/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
index 5ffdf08b7d..724aca2c04 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
@@ -9,17 +9,41 @@
 logger = logging.getLogger(__name__)
 
 
+def split_addmm_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    target = torch.ops.aten.addmm.default
+    addmm_nodes = [node for node in gm.graph.nodes if node.target == target]
+    for addmm_node in addmm_nodes:
+        bias, mat1, mat2 = addmm_node.all_input_nodes
+
+        with gm.graph.inserting_before(addmm_node):
+            mm_node = gm.graph.call_function(
+                torch.ops.aten.mm.default,
+                args=(mat1, mat2),
+            )
+            add_node = gm.graph.call_function(
+                torch.ops.aten.add.Tensor,
+                args=(bias, mm_node),
+            )
+
+        addmm_node.replace_all_uses_with(add_node, propagate_meta=True)
+        gm.graph.erase_node(addmm_node)
+
+    return gm
+
+
 def accumulate_fp32_matmul(
     gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
-    """Replace a matmul layer with fp32 accumulation nodes"""
+    """Add cast to FP32/16 nodes around a matmul layer. This pattern is detected by TensorRT and will enable FP32 accumulation during execution."""
     if settings.use_fp32_acc:
         matmul_targets = [
             torch.ops.aten.mm.default,
             torch.ops.aten.bmm.default,
-            torch.ops.aten.addmm.default,
         ]
 
+        # Split torch.addmm nodes into add + mm and only add cast nodes around mm nodes
+        split_addmm_nodes(gm)
+
         matmul_nodes = [
             node for node in gm.graph.nodes if node.target in matmul_targets
         ]
diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
index 76d47d24bd..dded14aff2 100644
--- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py
+++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -269,6 +269,48 @@ def forward(self, input, weight):
         )
         torch._dynamo.reset()
 
+    def test_fp32_acc_for_addmm(self):
+        class FP32Acc(torch.nn.Module):
+            def forward(self, input, mat1, mat2):
+                out = torch.ops.aten.addmm.default(input, mat1, mat2)
+                return out
+
+        inputs = [
+            torch.rand((3, 5)).cuda(),
+            torch.rand((3, 4)).cuda(),
+            torch.rand((4, 5)).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(FP32Acc())
+        expected_ops = {
+            torch.ops.aten._to_copy.default,
+            torch.ops.aten.mm.default,
+            torch.ops.aten.add.Tensor,
+        }
+        unexpected_ops = {}
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+            use_fp32_acc=True,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+        torch._dynamo.reset()
+
 
 class TestLowerEfficientAttention(TestCase):
     def test_lower_efficient_attention(self):
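[Note on PATCH 1/5] aten.addmm(bias, mat1, mat2) computes bias + mat1 @ mat2. Splitting it into mm + add before the cast-insertion step means only the matrix multiply gets wrapped in FP32 casts; the bias stays in its original precision, which is what the subject line's "do not cast bias nodes" refers to, and keeps the bias out of the cast-mm-cast pattern that TensorRT matches for FP32 accumulation. A minimal eager-mode sketch (illustrative only, not the lowering pass itself) of the numerics the rewritten graph expresses:

    import torch

    def addmm_with_fp32_acc(bias, mat1, mat2):
        # Only the matmul runs with FP32 accumulation; the bias add
        # stays in the original (FP16) precision.
        mm = torch.ops.aten.mm.default(
            mat1.to(torch.float32), mat2.to(torch.float32)
        ).to(mat1.dtype)
        return torch.ops.aten.add.Tensor(bias, mm)

    bias = torch.rand(3, 5, dtype=torch.float16)
    mat1 = torch.rand(3, 4, dtype=torch.float16)
    mat2 = torch.rand(4, 5, dtype=torch.float16)
    ref = torch.addmm(bias.float(), mat1.float(), mat2.float()).half()
    out = addmm_with_fp32_acc(bias, mat1, mat2)
    assert torch.allclose(out.float(), ref.float(), atol=1e-2)
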
From 15503d0dbe20bf828025bdfdc775e55378aaf911 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Thu, 13 Feb 2025 15:32:04 -0800
Subject: [PATCH 2/5] fix: Fix assert pass

---
 examples/dynamo/torch_export_flux_dev.py           | 4 ++++
 .../dynamo/lowering/passes/remove_assert_scalar.py | 4 ++--
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/examples/dynamo/torch_export_flux_dev.py b/examples/dynamo/torch_export_flux_dev.py
index 25b2fc6d2e..7ebe118196 100644
--- a/examples/dynamo/torch_export_flux_dev.py
+++ b/examples/dynamo/torch_export_flux_dev.py
@@ -63,6 +63,8 @@
     "txt_ids": {0: SEQ_LEN},
     "img_ids": {0: IMG_ID},
     "guidance": {0: BATCH},
+    "joint_attention_kwargs": {},
+    "return_dict": None
 }
 # The guidance factor is of type torch.float32
 dummy_inputs = {
@@ -79,6 +81,8 @@
     "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
     "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
     "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
+    "joint_attention_kwargs": {},
+    "return_dict": False
 }
 # This will create an exported program which is going to be compiled with Torch-TensorRT
 ep = _export(
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py
index 67d2ba6690..7db3a3f9ba 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py
@@ -17,14 +17,14 @@ def remove_assert_scalar(
     for node in gm.graph.nodes:
         if (
             node.target == torch.ops.aten._assert_scalar.default
-            or node == torch.ops.aten._assert_tensor_metadata.default
+            or node.target == torch.ops.aten._assert_tensor_metadata.default
         ):
             gm.graph.erase_node(node)
             count += 1
 
     if count > 0:
         gm = clean_up_graph_after_modifications(gm)
-
+
     logger.debug(f"Removed {count} assert_scalar nodes:\n{gm.graph}")
 
     return gm
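[Note on PATCH 2/5] The one-character fix here (node vs. node.target) matters because an fx.Node never compares equal to an OpOverload: Node does not define __eq__, so the original condition was always False and _assert_tensor_metadata nodes were silently left in the graph. A standalone sketch of the failure mode (the traced lambda is just a stand-in; any graph shows the same behavior):

    import torch

    gm = torch.fx.symbolic_trace(lambda x: x + 1)
    op = torch.ops.aten._assert_tensor_metadata.default
    for node in gm.graph.nodes:
        # A Node compares by identity, so this is always False:
        assert not (node == op)
        # node.target is the callable the node invokes; this is the
        # comparison the pass actually needs:
        print(node.op, node.target, node.target == op)
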
From 684f424ad8abe9b5f9c3b2145d9ab00720c1f7ca Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Thu, 20 Feb 2025 16:15:22 -0800
Subject: [PATCH 3/5] chore: updates

---
 .../dynamo/lowering/_decomposition_groups.py  |  1 +
 .../lowering/passes/accumulate_fp32_matmul.py | 25 -------------------
 2 files changed, 1 insertion(+), 25 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
index 825be75076..f3a510343a 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -17,6 +17,7 @@
     aten.addcmul,
     aten.addcmul_,
     aten.addr,
+    aten.addmm,
     aten.aminmax,
     aten.arange.default,
     aten.arange.start,
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py b/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
index 724aca2c04..50a79bad34 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
@@ -9,28 +9,6 @@
 logger = logging.getLogger(__name__)
 
 
-def split_addmm_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
-    target = torch.ops.aten.addmm.default
-    addmm_nodes = [node for node in gm.graph.nodes if node.target == target]
-    for addmm_node in addmm_nodes:
-        bias, mat1, mat2 = addmm_node.all_input_nodes
-
-        with gm.graph.inserting_before(addmm_node):
-            mm_node = gm.graph.call_function(
-                torch.ops.aten.mm.default,
-                args=(mat1, mat2),
-            )
-            add_node = gm.graph.call_function(
-                torch.ops.aten.add.Tensor,
-                args=(bias, mm_node),
-            )
-
-        addmm_node.replace_all_uses_with(add_node, propagate_meta=True)
-        gm.graph.erase_node(addmm_node)
-
-    return gm
-
-
 def accumulate_fp32_matmul(
     gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
@@ -41,9 +19,6 @@ def accumulate_fp32_matmul(
             torch.ops.aten.bmm.default,
         ]
 
-        # Split torch.addmm nodes into add + mm and only add cast nodes around mm nodes
-        split_addmm_nodes(gm)
-
         matmul_nodes = [
             node for node in gm.graph.nodes if node.target in matmul_targets
         ]
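[Note on PATCH 3/5] Registering aten.addmm in the decomposition group delegates the addmm -> mm + add rewrite to PyTorch's own decomposition machinery, which runs before the lowering passes. That is why the split_addmm_nodes graph surgery from PATCH 1/5 can be deleted and accumulate_fp32_matmul only needs to match mm and bmm. A sketch of the mechanism through the public export API (the real wiring goes through _decomposition_groups.py; the toy module is an assumption for illustration):

    import torch
    from torch._decomp import get_decompositions

    class M(torch.nn.Module):
        def forward(self, bias, mat1, mat2):
            return torch.ops.aten.addmm.default(bias, mat1, mat2)

    ep = torch.export.export(
        M(), (torch.rand(3, 5), torch.rand(3, 4), torch.rand(4, 5))
    )
    # Run only the addmm decomposition over the exported graph.
    ep = ep.run_decompositions(get_decompositions([torch.ops.aten.addmm]))
    targets = {n.target for n in ep.graph.nodes if n.op == "call_function"}
    print(torch.ops.aten.addmm.default in targets)  # False: decomposed away
    print(torch.ops.aten.mm.default in targets)     # True
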
From 4d68dd0c8b2d0836423fed9df244853a7cb6db8c Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Thu, 20 Feb 2025 17:12:55 -0800
Subject: [PATCH 4/5] chore: updates

---
 examples/dynamo/torch_export_flux_dev.py                | 8 ++++----
 .../{remove_assert_scalar.py => remove_assert_nodes.py} | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)
 rename py/torch_tensorrt/dynamo/lowering/passes/{remove_assert_scalar.py => remove_assert_nodes.py} (99%)

diff --git a/examples/dynamo/torch_export_flux_dev.py b/examples/dynamo/torch_export_flux_dev.py
index 7ebe118196..421ca9874a 100644
--- a/examples/dynamo/torch_export_flux_dev.py
+++ b/examples/dynamo/torch_export_flux_dev.py
@@ -9,11 +9,11 @@
 **FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.
 It is an open-weight, guidance-distilled model for non-commercial applications.
 
-Install the following dependencies before compilation
+To run this demo, you need to have access to Flux model (request for access if you do not have it already on the `FLUX.1-dev `_ page) and install the following dependencies
 
 .. code-block:: python
 
-    pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2"
+    pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2" protobuf=="5.29.3"
 
 There are different components of the ``FLUX.1-dev`` pipeline such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler``.
 In this example, we demonstrate optimizing the ``transformer`` component of the model (which typically consumes >95% of the e2e diffusion latency)
@@ -64,7 +64,7 @@
     "img_ids": {0: IMG_ID},
     "guidance": {0: BATCH},
     "joint_attention_kwargs": {},
-    "return_dict": None
+    "return_dict": None,
 }
 # The guidance factor is of type torch.float32
 dummy_inputs = {
@@ -82,7 +82,7 @@
     "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
     "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
     "joint_attention_kwargs": {},
-    "return_dict": False
+    "return_dict": False,
 }
 # This will create an exported program which is going to be compiled with Torch-TensorRT
 ep = _export(
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_nodes.py
similarity index 99%
rename from py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py
rename to py/torch_tensorrt/dynamo/lowering/passes/remove_assert_nodes.py
index 7db3a3f9ba..c7f93ffb5f 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_nodes.py
@@ -24,7 +24,7 @@ def remove_assert_scalar(
 
     if count > 0:
         gm = clean_up_graph_after_modifications(gm)
-
+
     logger.debug(f"Removed {count} assert_scalar nodes:\n{gm.graph}")
 
     return gm
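[Note on PATCH 4/5] Besides renaming the pass to remove_assert_nodes (it erases both _assert_scalar and _assert_tensor_metadata nodes, so the old name undersold it), the Flux example changes highlight a torch.export detail: in this setup, every keyword argument passed at export time, tensor or not, gets a matching entry in dynamic_shapes, with None or {} marking arguments that have no dynamic dimensions. A toy sketch of that calling convention (Toy stands in for the Flux transformer and is an assumption, not code from the example):

    import torch
    from torch.export import Dim, export

    class Toy(torch.nn.Module):
        def forward(self, hidden_states, return_dict=True):
            out = hidden_states * 2.0
            return {"sample": out} if return_dict else out

    BATCH = Dim("batch", min=1, max=8)
    kwargs = {"hidden_states": torch.rand(2, 16), "return_dict": False}
    # Non-tensor arguments are marked with None (no dynamic dimensions).
    dynamic_shapes = {"hidden_states": {0: BATCH}, "return_dict": None}
    ep = export(Toy(), args=(), kwargs=kwargs, dynamic_shapes=dynamic_shapes)
    print(ep.module()(torch.rand(4, 16), return_dict=False).shape)
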
From 91cd13c5ab7c1a8f5fa50d145dfd1b8901e0866f Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Fri, 21 Feb 2025 15:32:44 -0800
Subject: [PATCH 5/5] chore: remove addmm converter, tests

---
 .../dynamo/conversion/aten_ops_converters.py  | 28 --------
 .../dynamo/conversion/impl/__init__.py        |  1 -
 .../dynamo/conversion/impl/addmm.py           | 34 ----------
 .../lowering/passes/_aten_lowering_pass.py    |  4 +-
 .../lowering/passes/remove_assert_nodes.py    |  2 +-
 tests/py/dynamo/conversion/test_addmm_aten.py | 65 -------------------
 6 files changed, 3 insertions(+), 131 deletions(-)
 delete mode 100644 py/torch_tensorrt/dynamo/conversion/impl/addmm.py
 delete mode 100644 tests/py/dynamo/conversion/test_addmm_aten.py

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 4f2d168d29..545e665543 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2891,34 +2891,6 @@ def aten_ops_argmin(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.addmm.default, supports_dynamic_shapes=True)
-@enforce_tensor_types(
-    {
-        0: (TRTTensor,),
-        1: (np.ndarray, torch.Tensor, TRTTensor),
-        2: (np.ndarray, torch.Tensor, TRTTensor),
-    }
-)
-def aten_ops_addmm(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.addmm.addmm(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1],
-        args[2],
-        beta=kwargs.get("beta", 1),
-        alpha=kwargs.get("alpha", 1),
-    )
-
-
 @dynamo_tensorrt_converter(
     torch.ops.aten.constant_pad_nd.default, supports_dynamic_shapes=True
 )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
index 75f7492591..f80de4d6a9 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -1,6 +1,5 @@
 from torch_tensorrt.dynamo.conversion.impl import (
     activation,
-    addmm,
     arange,
     attention,
     cast,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/addmm.py b/py/torch_tensorrt/dynamo/conversion/impl/addmm.py
deleted file mode 100644
index 1a0690852a..0000000000
--- a/py/torch_tensorrt/dynamo/conversion/impl/addmm.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from typing import Optional, Union
-
-import numpy as np
-import torch
-from torch.fx.node import Target
-from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion import impl
-from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.fx.types import TRTTensor
-
-
-def addmm(
-    ctx: ConversionContext,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-    mat1: Union[TRTTensor, torch.Tensor, np.ndarray],
-    mat2: Union[TRTTensor, torch.Tensor, np.ndarray],
-    *,
-    beta: Union[float, int],
-    alpha: Union[float, int],
-) -> TRTTensor:
-    mm = impl.matmul.matrix_multiply(ctx, target, source_ir, f"{name}_mm", mat1, mat2)
-    if alpha != 1:
-        mm = impl.elementwise.mul(
-            ctx, target, SourceIR.ATEN, f"{name}_mul_alpha", mm, alpha
-        )
-    if beta != 1:
-        input = impl.elementwise.mul(
-            ctx, target, SourceIR.ATEN, f"{name}_mul_beta", input, beta
-        )
-
-    return impl.elementwise.add(ctx, target, source_ir, f"{name}_add", input, mm)
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index f24fd8ec21..f611c90f51 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -10,7 +10,7 @@
 from .fuse_prims_broadcast import fuse_prims_broadcast
 from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
 from .pass_manager import DynamoPassManager
-from .remove_assert_scalar import remove_assert_scalar
+from .remove_assert_nodes import remove_assert_nodes
 from .remove_detach import remove_detach
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
@@ -27,7 +27,7 @@
         replace_max_pool_with_indices,
         lower_scaled_dot_product_attention,
         view_to_reshape,
-        remove_assert_scalar,
+        remove_assert_nodes,
         accumulate_fp32_matmul,
     ]
 )
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_nodes.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_nodes.py
index c7f93ffb5f..890391e280 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_nodes.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_nodes.py
@@ -9,7 +9,7 @@
 logger = logging.getLogger(__name__)
 
 
-def remove_assert_scalar(
+def remove_assert_nodes(
     gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
     """Remove assert_scalar ops in the graph"""
diff --git a/tests/py/dynamo/conversion/test_addmm_aten.py b/tests/py/dynamo/conversion/test_addmm_aten.py
deleted file mode 100644
index 6108d3ea6d..0000000000
--- a/tests/py/dynamo/conversion/test_addmm_aten.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import torch
-import torch.nn as nn
-from parameterized import parameterized
-from torch.testing._internal.common_utils import run_tests
-
-from .harness import DispatchTestCase
-
-
-class TestAddmmConverter(DispatchTestCase):
-    @parameterized.expand(
-        [
-            ((2, 2), (2, 3), (3, 2)),
-            ((4, 6), (4, 5), (5, 6)),
-            ((2, 1), (2, 3), (3, 1)),
-            ((4, 1), (4, 1), (1, 1)),
-            ((1, 2), (1, 3), (3, 2)),
-        ]
-    )
-    def test_addmm(self, input_shape, mat1_shape, mat2_shape):
-        class Addmm(nn.Module):
-            def forward(self, input, mat1, mat2):
-                return torch.ops.aten.addmm.default(input, mat1, mat2)
-
-        inputs = [
-            torch.randn(input_shape),
-            torch.randn(mat1_shape),
-            torch.randn(mat2_shape),
-        ]
-
-        self.run_test(
-            Addmm(),
-            inputs,
-        )
-
-    @parameterized.expand(
-        [
-            ((2, 2), (2, 3), (3, 2), 1.0, 1.0),
-            ((4, 6), (4, 5), (5, 6), 1.2, 0.8),
-            ((2, 1), (2, 3), (3, 1), 3, 2),
-            ((4, 1), (4, 1), (1, 1), 1, 1),
-            ((1, 2), (1, 3), (3, 2), 2, 1.0),
-            ((1, 2), (1, 3), (3, 2), 1, 2.0),
-        ]
-    )
-    def test_addmm_scale(self, input_shape, mat1_shape, mat2_shape, beta, alpha):
-        class Addmm(nn.Module):
-            def forward(self, input, mat1, mat2):
-                return torch.ops.aten.addmm.default(
-                    input, mat1, mat2, beta=beta, alpha=alpha
-                )
-
-        inputs = [
-            torch.randn(input_shape),
-            torch.randn(mat1_shape),
-            torch.randn(mat2_shape),
-        ]
-
-        self.run_test(
-            Addmm(),
-            inputs,
-        )
-
-
-if __name__ == "__main__":
-    run_tests()
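[Note on PATCH 5/5] With addmm decomposed before conversion (PATCH 3/5), the dedicated converter, its impl module, and its tests are unreachable, so deleting them drops 131 lines without changing behavior: the mm/add converters plus the accumulate_fp32_matmul pass cover the same ground. A hedged end-to-end sketch of how the series is exercised from the user side (requires a CUDA plus TensorRT environment; nn.Linear lowers to addmm):

    import torch
    import torch_tensorrt

    model = torch.nn.Linear(4, 5).cuda().half()  # Linear lowers to addmm
    inputs = [torch.rand(3, 4, dtype=torch.float16).cuda()]
    trt_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        enabled_precisions={torch.float16},
        use_fp32_acc=True,  # enables the accumulate_fp32_matmul pass
        min_block_size=1,
    )
    print(trt_model(*inputs).shape)  # torch.Size([3, 5])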