fix: Split addmm nodes to not cast bias for FP32 accumulation and flux example fixes. #3395

Open · wants to merge 6 commits into main
8 changes: 6 additions & 2 deletions examples/dynamo/torch_export_flux_dev.py
@@ -9,11 +9,11 @@

**FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications.

Install the following dependencies before compilation
To run this demo, you need access to the FLUX model (request access on the `FLUX.1-dev <https://huggingface.co/black-forest-labs/FLUX.1-dev>`_ page if you do not already have it) and the following dependencies installed

.. code-block:: python

pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2"
pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2" protobuf=="5.29.3"

There are different components of the ``FLUX.1-dev`` pipeline such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler``. In this example,
we demonstrate optimizing the ``transformer`` component of the model (which typically consumes >95% of the e2e diffusion latency)
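For orientation, loading the pipeline and pulling out the ``transformer`` component typically looks like the following (a sketch using the public ``diffusers`` API; the gated weights require prior access approval and a ``huggingface-cli login``, and the exact loading code in the example script may differ):

.. code-block:: python

    import torch
    from diffusers import FluxPipeline

    # Load the gated FLUX.1-dev weights and keep the transformer for compilation.
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.float16,
    ).to("cuda")
    transformer = pipe.transformer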
@@ -63,6 +63,8 @@
"txt_ids": {0: SEQ_LEN},
"img_ids": {0: IMG_ID},
"guidance": {0: BATCH},
"joint_attention_kwargs": {},
"return_dict": None,
}
# The guidance factor is of type torch.float32
dummy_inputs = {
@@ -79,6 +81,8 @@
"txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
"img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
"guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
"joint_attention_kwargs": {},
"return_dict": False,
}
# This will create an exported program which is going to be compiled with Torch-TensorRT
ep = _export(
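The diff is truncated here; as a rough sketch (assuming the standard Torch-TensorRT dynamo API rather than the literal continuation of the example), the export and compile steps might look like:

.. code-block:: python

    import torch_tensorrt

    # Export with the dummy inputs and dynamic shape specs defined above.
    ep = _export(
        pipe.transformer,
        args=(),
        kwargs=dummy_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=False,
    )
    # Compile with FP32 accumulation enabled for the matmul layers.
    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs=dummy_inputs,
        enabled_precisions={torch.float32},
        use_fp32_acc=True,
        use_explicit_typing=True,
    )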
28 changes: 0 additions & 28 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2891,34 +2891,6 @@ def aten_ops_argmin(
)


@dynamo_tensorrt_converter(torch.ops.aten.addmm.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
1: (np.ndarray, torch.Tensor, TRTTensor),
2: (np.ndarray, torch.Tensor, TRTTensor),
}
)
def aten_ops_addmm(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.addmm.addmm(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
args[2],
beta=kwargs.get("beta", 1),
alpha=kwargs.get("alpha", 1),
)


@dynamo_tensorrt_converter(
torch.ops.aten.constant_pad_nd.default, supports_dynamic_shapes=True
)
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -1,6 +1,5 @@
from torch_tensorrt.dynamo.conversion.impl import (
activation,
addmm,
arange,
attention,
cast,
34 changes: 0 additions & 34 deletions py/torch_tensorrt/dynamo/conversion/impl/addmm.py

This file was deleted.

1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -17,6 +17,7 @@
aten.addcmul,
aten.addcmul_,
aten.addr,
aten.addmm,
aten.aminmax,
aten.arange.default,
aten.arange.start,
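Adding ``aten.addmm`` to the decomposition group means the lowering now sees separate ``mm`` and ``add`` nodes instead of a single fused op, so the FP32-accumulation pass can cast only the matmul operands and leave the bias untouched. Conceptually the decomposition behaves like the following sketch (illustrative only, not the literal PyTorch decomposition source):

    import torch

    def addmm_decomposition(bias, mat1, mat2, *, beta=1, alpha=1):
        # addmm(bias, mat1, mat2) == beta * bias + alpha * (mat1 @ mat2)
        out = torch.ops.aten.mm.default(mat1, mat2)
        if alpha != 1:
            out = torch.ops.aten.mul.Scalar(out, alpha)
        if beta != 1:
            bias = torch.ops.aten.mul.Scalar(bias, beta)
        return torch.ops.aten.add.Tensor(bias, out)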
@@ -10,7 +10,7 @@
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
from .pass_manager import DynamoPassManager
from .remove_assert_scalar import remove_assert_scalar
from .remove_assert_nodes import remove_assert_nodes
from .remove_detach import remove_detach
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output
@@ -27,7 +27,7 @@
replace_max_pool_with_indices,
lower_scaled_dot_product_attention,
view_to_reshape,
remove_assert_scalar,
remove_assert_nodes,
accumulate_fp32_matmul,
]
)
@@ -12,12 +12,11 @@
def accumulate_fp32_matmul(
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
"""Replace a matmul layer with fp32 accumulation nodes"""
"""Add cast to FP32/16 nodes around a matmul layer. This pattern is detected by TensorRT and will enable FP32 accumulation during execution."""
if settings.use_fp32_acc:
matmul_targets = [
torch.ops.aten.mm.default,
torch.ops.aten.bmm.default,
torch.ops.aten.addmm.default,
]

matmul_nodes = [
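In eager terms, the pattern the pass inserts around each matmul is roughly the following (an illustrative sketch with FP16 inputs; the pass itself rewrites the FX graph with ``aten._to_copy`` nodes rather than calling ``.to()``):

    import torch

    def matmul_with_fp32_acc(a_fp16: torch.Tensor, b_fp16: torch.Tensor) -> torch.Tensor:
        # Cast the matmul operands up to FP32 ...
        a_fp32 = a_fp16.to(torch.float32)
        b_fp32 = b_fp16.to(torch.float32)
        # ... multiply, which TensorRT recognizes and runs as an FP16 matmul
        # with FP32 accumulation ...
        out_fp32 = torch.matmul(a_fp32, b_fp32)
        # ... and cast the result back down to FP16.
        return out_fp32.to(torch.float16)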
@@ -9,15 +9,15 @@
logger = logging.getLogger(__name__)


def remove_assert_scalar(
def remove_assert_nodes(
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
"""Remove assert_scalar ops in the graph"""
count = 0
for node in gm.graph.nodes:
if (
node.target == torch.ops.aten._assert_scalar.default
or node == torch.ops.aten._assert_tensor_metadata.default
or node.target == torch.ops.aten._assert_tensor_metadata.default
):
gm.graph.erase_node(node)
count += 1
65 changes: 0 additions & 65 deletions tests/py/dynamo/conversion/test_addmm_aten.py

This file was deleted.

42 changes: 42 additions & 0 deletions tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -269,6 +269,48 @@ def forward(self, input, weight):
)
torch._dynamo.reset()

def test_fp32_acc_for_addmm(self):
class FP32Acc(torch.nn.Module):
def forward(self, input, mat1, mat2):
out = torch.ops.aten.addmm.default(input, mat1, mat2)
return out

inputs = [
torch.rand((3, 5)).cuda(),
torch.rand((3, 4)).cuda(),
torch.rand((4, 5)).cuda(),
]

fx_graph = torch.fx.symbolic_trace(FP32Acc())
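# With addmm decomposed and use_fp32_acc enabled, the lowered graph should
# contain FP32 casts (_to_copy) around an mm node, plus a separate add for
# the bias that is left in the original precision.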
expected_ops = {
torch.ops.aten._to_copy.default,
torch.ops.aten.mm.default,
torch.ops.aten.add.Tensor,
}
unexpected_ops = {}

unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
use_fp32_acc=True,
)

self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)
torch._dynamo.reset()


class TestLowerEfficientAttention(TestCase):
def test_lower_efficient_attention(self):