[Compile] Only test compiling model forward pass #35658

Merged · 2 commits · Jan 13, 2025
25 changes: 4 additions & 21 deletions tests/generation/test_utils.py
@@ -2025,16 +2025,10 @@ def test_generate_with_quant_cache(self):
         with self.assertRaises(ValueError):
             model.generate(**generation_kwargs, **inputs_dict)

-    @parameterized.expand(
-        [
-            ("forward_only", False),  # TODO (@joao): a few models failing. After fixed, this should not be "@slow"
-            ("end_to_end", True),  # TODO (@joao): end-to-end compilation is broken with torch 2.5+, explore and fix
-        ]
-    )
     @pytest.mark.generate
     @require_torch_gpu
     @slow
-    def test_generate_compile(self, _, end_to_end):
+    def test_generate_compile_model_forward(self):
         """
         Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests
         end-to-end compilation and forward pass compilation only.
@@ -2044,14 +2038,7 @@ def test_generate_compile(self, _, end_to_end):
         if not model_class._supports_static_cache:
             self.skipTest("This model doesn't support static cache")

-        # TODO (joao) -- fix and enable me :)
-        if end_to_end and any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
-            self.skipTest("whisper model end-to-end generate compile not yet supported")
-
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        # TODO (joao) -- fix and enable me :)
-        if end_to_end and config.is_encoder_decoder:
-            self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")

         model = model_class(config).to(torch_device)
         model.eval()  # otherwise `self.training` is `True` -- this flag is used at attn mask creation time
@@ -2067,10 +2054,8 @@
             "max_new_tokens": 10,
             "return_dict_in_generate": True,
             "output_scores": True,
+            "cache_implementation": "static",
         }
-        # end-to-end works best with dynamic cache, forward compilation works best with static cache
-        if not end_to_end:
-            generation_kwargs["cache_implementation"] = "static"

         # get eager + dynamic cache results for future comparison
         dynamic_outputs = []
@@ -2081,10 +2066,8 @@
         generation_config = copy.deepcopy(model.generation_config)
         generation_config.update(**generation_kwargs)
         torch.compiler.reset()
-        if end_to_end:
-            model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
-        else:
-            model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
+
+        model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")

         compiled_outputs = []
         for model_inputs in input_ids_sets:
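For context, the pattern this test now standardizes on can be used directly. Below is a minimal sketch, assuming a checkpoint whose model class sets `_supports_static_cache` (the checkpoint name is a placeholder, not something from this PR): only `model.forward` is compiled, `generate` itself stays eager, and the static cache keeps tensor shapes fixed so the compiled graph is reused across decoding steps.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint -- substitute any model that supports the static cache.
checkpoint = "your-org/your-causal-lm"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).eval()

# Compile only the forward pass, exactly as the updated test does.
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")

inputs = tokenizer("Hello there", return_tensors="pt")
# The static cache pre-allocates fixed-shape key/value tensors, so the
# compiled forward is not retraced as the sequence grows.
outputs = model.generate(**inputs, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))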
2 changes: 1 addition & 1 deletion tests/models/chameleon/test_modeling_chameleon.py
@@ -333,7 +333,7 @@ def test_batching_equivalence(self):

     # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
     @unittest.skip("Chameleon is not compatible with end-to-end generation compilation")
-    def test_generate_compile_fullgraph(self):
+    def test_generate_compile_model_forward(self):
         pass

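The `cache_position[0] == 0` problem flagged in the TODO above is plain data-dependent control flow: a Python `if` whose condition depends on a tensor's runtime value cannot be captured in a single graph. A standalone sketch of the failure mode (illustrative, not code from this PR):

import torch

def prepare_step(x, cache_position):
    # Branching on a tensor's runtime value is data-dependent control flow;
    # Dynamo cannot fold both sides into one graph.
    if cache_position[0] == 0:
        return x * 2
    return x + 1

compiled = torch.compile(prepare_step, fullgraph=True)
try:
    compiled(torch.ones(4), torch.tensor([0]))
except Exception as e:  # typically torch._dynamo.exc.Unsupported
    print("fullgraph compilation failed:", type(e).__name__)

With `fullgraph=False` the same branch would merely cause a graph break and fall back to eager execution; `fullgraph=True` turns it into a hard error, which is why these models skip the test.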
2 changes: 1 addition & 1 deletion tests/models/dbrx/test_modeling_dbrx.py
@@ -369,7 +369,7 @@ def test_disk_offload_bin(self):
         pass

     @unittest.skip("Dbrx does not support `torch.compile` with `fullgraph=True`.")
-    def test_generate_compile_fullgraph(self):
+    def test_generate_compile_model_forward(self):
         pass

8 changes: 0 additions & 8 deletions tests/models/emu3/test_modeling_emu3.py
@@ -176,10 +176,6 @@ def test_model_rope_scaling(self, scaling_type):
     def test_custom_4d_attention_mask(self):
         pass

-    @unittest.skip("Fails with unknown error only on end-to-end compile")  # TODO raushan fixme
-    def test_generate_compile_1_end_to_end(self):
-        pass
-

 class Emu3Vision2TextModelTester:
     def __init__(
@@ -398,10 +394,6 @@ def test_custom_4d_attention_mask(self):
     def test_initialization(self):
         pass

-    @unittest.skip("End-to-end compilation is not supported due to dynamic control in `prepare_inputs_for_generation`")
-    def test_generate_compile_1_end_to_end(self):
-        pass
-

 @require_torch
 class Emu3IntegrationTest(unittest.TestCase):
2 changes: 1 addition & 1 deletion tests/models/idefics/test_modeling_idefics.py
@@ -781,7 +781,7 @@ def test_custom_4d_attention_mask(self):
         pass

     @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
-    def test_generate_compile_fullgraph(self):
+    def test_generate_compile_model_forward(self):
         pass

     @unittest.skip(reason="We only test the model that takes in multiple images")
2 changes: 1 addition & 1 deletion tests/models/paligemma/test_modeling_paligemma.py
@@ -348,7 +348,7 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):

     # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
     @unittest.skip("PaliGemma is not compatible with end-to-end generation compilation")
-    def test_generate_compile_fullgraph(self):
+    def test_generate_compile_model_forward(self):
         pass

2 changes: 1 addition & 1 deletion tests/models/qwen2_vl/test_modeling_qwen2_vl.py
@@ -332,7 +332,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
         pass

     @unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`")
-    def test_generate_compile_fullgraph(self):
+    def test_generate_compile_model_forward(self):
         pass
