
Re-enable SDPA's FA2 path #30070

Merged: fxmarty merged 22 commits into huggingface:main on Apr 17, 2024
Conversation

@fxmarty (Contributor) commented Apr 5, 2024

As per the title: SDPA can't dispatch to FA2 when the attn_mask argument is not None, which was always the case in recent releases following the llama/gemma refactor, leading to the memory issues in #30010.

When we compile with fullgraph=True, the graph is recaptured for different q_len, so using is_causal=causal_mask is None and q_len > 1 is fine.

Reference pytorch/pytorch#108108
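
A minimal sketch of the dispatch constraint described above (editorial addition, not part of the PR; assumes a CUDA GPU with PyTorch's flash attention backend, torch >= 2.2): with only the flash backend enabled, SDPA accepts is_causal=True, but there is no flash kernel for an explicit attn_mask.

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = q.clone(), q.clone()
causal_mask = torch.ones(128, 128, device="cuda", dtype=torch.bool).tril()

# Restrict SDPA to the flash (FA2) backend only.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # dispatches to FA2
    try:
        F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
    except RuntimeError as e:
        # No flash kernel supports an explicit attn_mask, hence the fallback and the memory issues.
        print("FA2 unavailable with attn_mask:", e)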

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@fxmarty (Contributor, Author) commented Apr 5, 2024

FAILED tests/test_pipeline_mixin.py::TextGenerationPipelineTests::test_stop_sequence_stopping_criteria - AssertionError: Lists differ: [{'generated_text': 'Hello I believe in in in����������'}] != [{'generated_text': 'Hello I believe in fe fe fe fe fe fe fe fe fe fe fe fe'}]

is not failing locally; I don't get it. Besides, it is labeled as tests_pipelines_tf, and I'm not sure why.

FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelTest::test_generate_with_prompt_ids_max_length - IndexError: index -1 is out of bounds for dimension 1 with size 0

is unrelated and is also failing on main. Not sure why it appears here.

Other than that, Cohere/Llama/Gemma slow tests pass.

@warner-benjamin (Contributor) left a comment

I added some comments and suggested fixes to enable torch.compile with fullgraph=True and prevent applying bidirectional instead of causal attention. The suggestions also apply to gemma and cohere models.

@fxmarty fxmarty requested review from warner-benjamin and ArthurZucker and removed request for warner-benjamin April 8, 2024 10:51
Comment on lines 262 to 268
if attention_mask is None:
    # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export`
    # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`, which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
    # Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag.
    #
    # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`).
    ignore_causal_mask = not is_tracing
@fxmarty (Contributor, Author)

@warner-benjamin I'm quite sure there is no way around this currently. We may revisit this with pytorch/pytorch#114823 if that helps.

But I agree that we should have a path for users who don't pass an attention_mask input and use torch.compile WITHOUT torch.export/torch.onnx.dynamo_export.

@warner-benjamin (Contributor)

With my proposed modification to SdpaAttention.forward, I think it would be possible to check self.training here to dispatch to FA2 during training and to the efficient kernel in eval. Otherwise, a config option to force SDPA's FA2 backend would work.
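
A hedged sketch of this dispatch idea (editorial addition; a hypothetical helper, not warner-benjamin's actual diff nor the merged code): during training with no padding mask, let SDPA take the FA2 path via is_causal; otherwise keep the explicit mask.

import torch
import torch.nn.functional as F

def sdpa_dispatch(module: torch.nn.Module, q, k, v, causal_mask, q_len: int):
    if module.training and causal_mask is None:
        # No mask to apply: eligible for the FA2 backend, causal whenever attending over more than one query token.
        return F.scaled_dot_product_attention(q, k, v, is_causal=q_len > 1)
    # Mask present (or eval/export): pass attn_mask explicitly; SDPA falls back to the efficient/math backends.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)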

@fxmarty (Contributor, Author) commented Apr 12, 2024

@warner-benjamin I think you misunderstood my message here.

When using torch.export or torch.onnx.dynamo_export, a single GraphModule is captured, no matter what, based on the sample input that is passed.

By contrast, when simply using an OptimizedModule from torch.compile (even with fullgraph=True), the full graph can still be invalidated based on the input. You can check the torchdynamo logs (torch._logging.set_logs(dynamo=logging.INFO, aot=logging.INFO, inductor=logging.INFO, graph_breaks=True, guards=True, recompiles=True, output_code=True, graph_code=True, graph=True)); e.g. we get the following even with fullgraph=True:

[2024-04-08 11:58:14,174] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function forward in /home/felix/transformers/src/transformers/models/llama/modeling_llama.py:1163
[2024-04-08 11:58:14,174] torch._dynamo.guards.__recompiles: [DEBUG]     triggered by the following guard failure(s):
[2024-04-08 11:58:14,174] torch._dynamo.guards.__recompiles: [DEBUG]     - tensor 'L['input_ids']' size mismatch at index 1. expected 4, actual 1

& you can see the related FX ops, with is_causal hard-coded:

[2024-04-08 11:58:22,052] [0/1] torch._dynamo.output_graph.__graph: [DEBUG] call_function attn_output <built-in function scaled_dot_product_attention> (query_states_2, key_states_36, value_states_35) {'attn_mask': None, 'dropout_p': 0.0, 'is_causal': False}

&

[2024-04-08 11:57:34,825] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] call_function attn_output <built-in function scaled_dot_product_attention> (query_states_2, key_states_36, value_states_35) {'attn_mask': None, 'dropout_p': 0.0, 'is_causal': True}

What I am saying is that we want to avoid hard-coding is_causal. As torch._dynamo.is_exporting() is not available (I wish it were, along with torch._dynamo.is_fullgraph_compiling(), see pytorch/pytorch#120400), my safe bet is to always use attn_mask when torch.compile is used.
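
A small illustration of the hard-coding problem described above (editorial addition; a sketch assuming torch >= 2.2 with torch.export, not transformers code): whatever q_len the example input has at export time gets frozen into the captured graph.

import torch
import torch.nn.functional as F

class TinyAttn(torch.nn.Module):
    def forward(self, q, k, v):
        q_len = q.shape[-2]
        # Python-level condition: export specializes it to the example input's shape.
        return F.scaled_dot_product_attention(q, k, v, is_causal=q_len > 1)

q = torch.randn(1, 8, 4, 64)  # q_len == 4 at export time
ep = torch.export.export(TinyAttn(), (q, q, q))
# The exported graph contains scaled_dot_product_attention with is_causal=True baked in,
# which is wrong for later decode steps where q_len == 1.
print(ep.graph)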

@warner-benjamin (Contributor) commented Apr 12, 2024

@fxmarty I've been unable to trigger the full graph invalidation as you've shown. Could you share code to reproduce it?

I've tried my proposed changes with torch.export and with torch.compile using fullgraph=True, dynamic=True, and both together, on PyTorch 2.2.2 and PyTorch 2.1.2 (minus export, since it's a 2.2 feature). I have been unable to trigger any recompilations (outside of different input shapes without dynamic=True, but that's to be expected).

What I am saying is that we want to avoid hard-coding is_causal. As torch._dynamo.is_exporting() is not available (I wish it were, along with torch._dynamo.is_fullgraph_compiling(), see pytorch/pytorch#120400), my safe bet is to always use attn_mask when torch.compile is used.

I'm not following why we cannot use the FA2 kernel during compiled training. We could dispatch via self.training to FA2/attn_mask, add a config option to force FA2/attn_mask, or both.

As it currently stands, someone testing model training in eager mode would use the memory efficient FA2 kernel, and then when switching to torch.compile for additional memory savings and speed, their training would instead use more memory due to silently switching the kernel from FA2 to efficient.

@fxmarty (Contributor, Author) commented Apr 12, 2024

@fxmarty I've been unable to trigger the full graph invalidation as you've shown. Could you share code to reproduce it?

Sure, let me do next week.

I'm not following why we cannot use the FA2 kernel during compiled training. We could dispatch via self.training to FA2/attn_mask, add a config option to force FA2/attn_mask, or both.

Sure, self.training could be a decent option

@fxmarty (Contributor, Author) commented Apr 15, 2024

@warner-benjamin using torch==2.2.2 and setting

        if attention_mask is None:
            ignore_causal_mask = True

in _ignore_causal_mask_sdpa, and tracing with fullgraph=True, dynamic=True, I do get scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool when using is_causal=q_len > 1.

When compiling with just fullgraph=True, I get the graph recompilation shown above, which is not friendly with torch.export / ONNX dynamo export.

You can see the following script:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import copy

import logging
torch._logging.set_logs(dynamo=logging.INFO, aot=logging.INFO, inductor=logging.INFO, graph_breaks=True, guards=True, recompiles=True, output_code=True, graph_code=True, graph=True)

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)

with torch.device("cuda"):
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        torch_dtype=torch.float16,
        attn_implementation="sdpa",
    )

model = model.eval()
inps = {
    "input_ids": torch.tensor([[5, 6, 7, 9]], dtype=torch.int64).to("cuda"),
    "use_cache": True,
}

with torch.no_grad():
    out = model(**inps)

    pkv = out.past_key_values
    pkv_0 = copy.deepcopy(pkv)
    pkv_1 = copy.deepcopy(pkv)
    pkv_2 = copy.deepcopy(pkv)
    pkv_3 = copy.deepcopy(pkv)

with torch.no_grad():

    print("--------- WITHOUT TORCH.COMPILE + input_ids > 1")
    inps_here = copy.deepcopy(inps)
    print("inps_here", inps_here.keys())

    res_several = model(**inps_here, past_key_values=pkv_0)

    print("--------- WITHOUT TORCH.COMPILE + input_ids == 1")
    inps_here = copy.deepcopy(inps)
    inps_here["input_ids"] = inps_here["input_ids"][:, :1]
    print("inps_here", inps_here.keys())

    res_one = model(**inps_here, past_key_values=pkv_1)

    print("compiling...")

    torch.cuda.synchronize()

    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

    print("--------- WITH TORCH.COMPILE + input_ids > 1")
    inps_here = copy.deepcopy(inps)
    print("inps_here", inps_here.keys())

    res_several_comp = model(**inps_here, past_key_values=pkv_2)

    print("--------- WITH TORCH.COMPILE + input_ids == 1")
    inps_here = copy.deepcopy(inps)
    inps_here["input_ids"] = inps_here["input_ids"][:, :1]
    print("inps_here", inps_here.keys())

    res_one_comp = model(**inps_here, past_key_values=pkv_3)

assert torch.allclose(res_several.logits, res_several_comp.logits)
assert torch.allclose(res_one.logits, res_one_comp.logits)

Thus, in my opinion, dropping the attention_mask by default when torch.compile is used is not the right choice for now.

We could have a force_causal_mask flag for users who want to make sure the SDPA FA2 backend is usable with torch.compile/torchscript/etc.; it would disable attn_mask altogether and force is_causal=True. Would that be fine with you? I would suggest doing it in another PR, though.
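
A hedged sketch of what such a flag might look like (editorial addition; force_causal_mask is hypothetical, and the name and signature here are illustrative only, not an actual transformers API):

import torch
import torch.nn.functional as F

def sdpa_attention(q, k, v, causal_mask, force_causal_mask: bool = False):
    if force_causal_mask:
        # User opts in: no padding and purely causal attention, so drop the mask entirely
        # and keep is_causal=True even under torch.compile / torch.export / TorchScript.
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Default: keep the explicit mask, which is always correct but blocks the FA2 backend.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)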

@warner-benjamin (Contributor)

Thanks. Will fix compile+FA2 in a new PR once this one is merged.

@fxmarty (Contributor, Author)

Thank you for your thorough review!

@fxmarty fxmarty requested a review from warner-benjamin April 8, 2024 11:39
fxmarty and others added 4 commits April 17, 2024 18:11
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@fxmarty (Contributor, Author) commented Apr 17, 2024

@ArthurZucker done. Also, running the slow tests for llama/gemma/cohere gives:

FAILED tests/models/code_llama/test_tokenization_code_llama.py::LlamaIntegrationTest::test_conversion - AssertionError: '{\n [226 chars]ed": true,\n      "special": true\n    },\n   [1796697 chars]}\n}' != '{\n [226 chars]ed": false,\n      "special": true...
FAILED tests/models/cohere/test_tokenization_cohere.py::CohereTokenizationTest::test_saving_tokenizer_trainer - pynvml.nvml.NVMLError_NotSupported: Not Supported
FAILED tests/models/llama/test_tokenization_llama.py::LlamaIntegrationTest::test_conversion - AssertionError: '{\n [964 chars]or": {\n    "type": "TemplateProcessing",\n   [1795198 chars]}\n}' != '{\n [964 chars]or": null,\n  "decoder": {\n    "t...
FAILED tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py::RecurrentGemmaModelTest::test_cpu_offload - ValueError: model.normalizer doesn't have any device set.
FAILED tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py::RecurrentGemmaModelTest::test_disk_offload_bin - ValueError: model.normalizer doesn't have any device set.
FAILED tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py::RecurrentGemmaModelTest::test_disk_offload_safetensors - KeyError: 'model.normalizer'
FAILED tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py::RecurrentGemmaIntegrationTest::test_2b_sample - AssertionError: Lists differ: ['Where is Paris ?\n\nChoose the word or phrase that is closest[463 chars]}}$'] != ['Where is Paris ?\n\nAnswer this quest...
FAILED tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py::RecurrentGemmaIntegrationTest::test_model_2b_8bit - AssertionError: Lists differ: ['Hello I am doing a project on the topic o[162 chars]em>'] != ['<bos>Hello I am doing a project on the to[166 chars]<u>"]

all of which also fail on main on A100.

@fxmarty (Contributor, Author) commented Apr 17, 2024

@ArthurZucker

from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
from transformers.cache_utils import StaticCache

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)

with torch.device("cuda"):
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        torch_dtype=torch.float16,
        attn_implementation="sdpa",
    )

inputs = tokenizer(
    ["I would", "Today I am in Paris and", "I am"], padding=True, return_tensors="pt"
).to(model.device)

new_tokens = 10
gen_config = GenerationConfig(
    max_new_tokens=new_tokens,
    min_new_tokens=new_tokens,
    use_cache=True,
    pad_token_id=tokenizer.pad_token_id,
    num_beams=1,
    do_sample=False,
    eos_token_id=None,  # This is required for min_new_tokens to actually have an effect.
)
model.generation_config.eos_token_id = None  # greedy_search falls back on this eos_token_id that we need to set to None as well for min_new_tokens to have an effect.

print("----- GENERATE WITHOUT COMPILE")
gen_out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")

decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)

print("decoded", decoded)

print("compiling...")

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
print("Finished compile call")

print("----- GENERATE WITH COMPILE")
gen_out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print("decoded", decoded)

gen_out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print("decoded", decoded)

works as expected as well (on 2.3 RC, see pytorch/pytorch#121943 for 2.2.2, happening on main)

@fxmarty fxmarty requested a review from ArthurZucker April 17, 2024 15:58
@ArthurZucker (Collaborator) left a comment

I like that the black magic is hidden from us with an if magic: return None.

@fxmarty fxmarty merged commit 05bdef1 into huggingface:main Apr 17, 2024
19 checks passed
ArthurZucker added a commit that referenced this pull request Apr 18, 2024
* Revert "Re-enable SDPA's FA2 path (#30070)"

This reverts commit 05bdef1.

* Revert "Fix quality Olmo + SDPA (#30302)"

This reverts commit ec92f98.
LysandreJik pushed a commit that referenced this pull request Apr 18, 2024
… + revert #30070 at the same time (#30317)

* Update awq.py

* style

* revert felix PR

* fix

* add felix comments
ydshieh pushed a commit that referenced this pull request Apr 23, 2024
* tentatively re-enable FA2 + SDPA

* better comment

* _ignore_causal_mask_sdpa as staticmethod

* type hints

* use past_seen_tokens instead

* enable copied from for sdpa

* ruff

* llama simplifications on review

* remove unnecessary self.is_causal check

* fix copies

* cleaning

* precise message

* better doc

* add test

* simplify

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* style

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
ydshieh pushed a commit that referenced this pull request Apr 23, 2024
* Revert "Re-enable SDPA's FA2 path (#30070)"

This reverts commit 05bdef1.

* Revert "Fix quality Olmo + SDPA (#30302)"

This reverts commit ec92f98.
ydshieh pushed a commit that referenced this pull request Apr 23, 2024
… + revert #30070 at the same time (#30317)

* Update awq.py

* style

* revert felix PR

* fix

* add felix comments