
torch.compile compatibility with generate + static cache #29114

Merged: 16 commits, Feb 21, 2024

Conversation

fxmarty (Contributor) commented Feb 19, 2024

This PR adds support for torch.compile when using generate with a static cache.
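Roughly, this is the usage pattern the PR targets (a condensed sketch based on the benchmark script further down; the checkpoint and generation settings are placeholders):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
with torch.device("cuda"):
    model = AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-7b-chat-hf", torch_dtype=torch.float16)

# Compile the forward pass once; generate then drives it with a static KV cache.
model.forward = torch.compile(model.forward, mode="reduce-overhead")

inputs = tokenizer("I would", return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))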

@fxmarty fxmarty marked this pull request as draft February 19, 2024 17:02
Comment on lines 1261 to 1266
# The `contiguous()` here is necessary to have a static stride during (non-speculative) decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard.
# TODO: We don't really need to handle the input_ids here, and this contiguous() call could be removed if we were
# simply using GenerationMixin.greedy_search `next_tokens` variable directly (which is already contiguous), instead of
# doing a torch.cat + then slice.
model_inputs = {"input_ids": input_ids.contiguous()}
fxmarty (Contributor, Author):

As per the comment. If input_ids is not contiguous, its stride is different at each forward call in generate, which triggers a recompilation at every step in the loop. This is very slow.

Ideally we would just use next_tokens instead of this contiguous() call, but let's do a proper refactor in another PR.
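A minimal standalone illustration of the stride behavior (not the actual generate code):

import torch

full = torch.arange(20).reshape(2, 10)    # stands in for the growing sequence of token ids
last = full[:, -1:]                       # a view: shape (2, 1), but its stride tracks the parent's width
print(last.shape, last.stride())          # torch.Size([2, 1]) (10, 1)
print(last.contiguous().stride())         # (1, 1): the same at every step, so the stride guard keeps holding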

Collaborator:

Yep, next_tokens would be less "costly", but I think we need all of them for speculative decoding? Anyway, I noticed that as well when I was compiling and just let it be at the time!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

ArthurZucker (Collaborator) left a comment:

The int vs tensor "bug" is expected; that is precisely why we use cache_positions for generation.
I think generate needs to handle that, because I'm not sure compile will like the in-place tensor modification.

fxmarty (Contributor, Author) commented Feb 20, 2024

@ArthurZucker @gante @LysandreJik This PR fixes many issues with the current torch.compile + static cache + generate implementation, as follows.

1. Always keep the same stride for inputs in the decode phase

generate apparently does not directly use the next_tokens variable as the next input_ids. Instead, next_tokens is concatenated with the previous tokens and then sliced, which results in input tensors that have different strides despite having the same shape:

------------- loop forward in generate 0
model_inputs input ids shape torch.Size([2, 7])
model_inputs input ids stride (7, 1)
model_inputs position_ids shape torch.Size([2, 7])
model_inputs position_ids stride (7, 1)
------------- loop forward in generate 1
model_inputs input ids shape torch.Size([2, 1])
model_inputs input ids stride (8, 1)
model_inputs position_ids shape torch.Size([2, 1])
model_inputs position_ids stride (8, 1)
------------- loop forward in generate 2
model_inputs input ids shape torch.Size([2, 1])
model_inputs input ids stride (9, 1)
model_inputs position_ids shape torch.Size([2, 1])
model_inputs position_ids stride (9, 1)
------------- loop forward in generate 3
model_inputs input ids shape torch.Size([2, 1])
model_inputs input ids stride (10, 1)
model_inputs position_ids shape torch.Size([2, 1])
model_inputs position_ids stride (10, 1)
------------- loop forward in generate 4
etc.

This is bad because torch.compile places guards on the strides of the inputs, so recompilation is triggered in the decode phase even though it is really not necessary.

V0220 17:27:08.283680 140705660285312 torch/_dynamo/guards.py:1381] Recompiling function forward in /home/felix/transformers/src/transformers/models/llama/modeling_llama.py:1127
V0220 17:27:08.283680 140705660285312 torch/_dynamo/guards.py:1381]     triggered by the following guard failure(s):
V0220 17:27:08.283680 140705660285312 torch/_dynamo/guards.py:1381]     - tensor 'L['input_ids']' stride mismatch at index 0. expected 7, actual 8

&

V0220 17:27:12.636874 140705660285312 torch/_dynamo/guards.py:1381] Recompiling function torch_dynamo_resume_in_forward_at_989 in /home/felix/transformers/src/transformers/models/llama/modeling_llama.py:989
V0220 17:27:12.636874 140705660285312 torch/_dynamo/guards.py:1381]     triggered by the following guard failure(s):
V0220 17:27:12.636874 140705660285312 torch/_dynamo/guards.py:1381]     - tensor 'L['position_ids']' stride mismatch at index 0. expected 7, actual 8
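The stride drift itself can be reproduced outside of generate with a small sketch:

import torch

input_ids = torch.ones((2, 7), dtype=torch.long)              # prompt of length 7
for step in range(3):
    next_tokens = torch.ones((2, 1), dtype=torch.long)
    input_ids = torch.cat([input_ids, next_tokens], dim=-1)   # the kept sequence grows by one token
    model_input = input_ids[:, -1:]                           # same shape every step...
    print(model_input.shape, model_input.stride())            # ...but strides (8, 1), (9, 1), (10, 1): a fresh guard failure each time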

2. Do not compile _update_causal_mask

_update_causal_mask uses the input attention_mask length in its code. I believe this results in an FX placeholder being a SymInt,

V0220 11:30:30.341809 140023118176640 torch/_dynamo/output_graph.py:1084] [2/1]  ===== __compiled_fn_12 =====
V0220 11:30:30.341809 140023118176640 torch/_dynamo/output_graph.py:1084] [2/1]  /home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0220 11:30:30.341809 140023118176640 torch/_dynamo/output_graph.py:1084] [2/1]     def forward(self, s0 : torch.SymInt, L_attention_mask_ : torch.Tensor):
V0220 11:30:30.341809 140023118176640 torch/_dynamo/output_graph.py:1084] [2/1]         l_attention_mask_ = L_attention_mask_

which retriggers CUDAGraph capture for every decode step:

I0220 11:30:57.881897 140023118176640 torch/_inductor/cudagraph_trees.py:375] recording cudagraph tree for symint key 13
I0220 11:30:57.902060 140023118176640 torch/_inductor/cudagraph_trees.py:375] recording cudagraph tree for symint key 14
I0220 11:30:57.922157 140023118176640 torch/_inductor/cudagraph_trees.py:375] recording cudagraph tree for symint key 15
I0220 11:30:57.942382 140023118176640 torch/_inductor/cudagraph_trees.py:375] recording cudagraph tree for symint key 16
I0220 11:30:57.962995 140023118176640 torch/_inductor/cudagraph_trees.py:375] recording cudagraph tree for symint key 17
I0220 11:30:57.983414 140023118176640 torch/_inductor/cudagraph_trees.py:375] recording cudagraph tree for symint key 18
I0220 11:30:58.004108 140023118176640 torch/_inductor/cudagraph_trees.py:375] recording cudagraph tree for symint key 19
I0220 11:30:58.024602 140023118176640 torch/_inductor/cudagraph_trees.py:375] recording cudagraph tree for symint key 20
I0220 11:30:58.045034 140023118176640 torch/_inductor/cudagraph_trees.py:375] recording cudagraph tree for symint key 21
I0220 11:30:58.065743 140023118176640 torch/_inductor/cudagraph_trees.py:375] recording cudagraph tree for symint key 22
I0220 11:30:58.086160 140023118176640 torch/_inductor/cudagraph_trees.py:375] recording cudagraph tree for symint key 23
I0220 11:30:58.106957 140023118176640 torch/_inductor/cudagraph_trees.py:375] recording cudagraph tree for symint key 24
I0220 11:30:58.127666 140023118176640 torch/_inductor/cudagraph_trees.py:375] recording cudagraph tree for symint key 25

This is very slow. We avoid capturing _update_causal_mask with @torch.compiler.disable, which fixes the issue (no more CUDA graph capture after the very first decode step).
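A sketch of the decorator approach on a toy module (illustrative only, not the actual Llama modeling code):

import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    @torch.compiler.disable  # keep the mask construction out of the captured graph
    def _update_causal_mask(self, attention_mask):
        # Depends on the attention_mask length, which would otherwise become a SymInt placeholder.
        length = attention_mask.shape[-1]
        mask = torch.full((length, length), float("-inf"), device=attention_mask.device)
        return torch.triu(mask, diagonal=1)

    def forward(self, hidden_states, attention_mask):
        mask = self._update_causal_mask(attention_mask)  # computed eagerly, outside the compiled graph
        return self.linear(hidden_states), mask

model = ToyModel()
model.forward = torch.compile(model.forward, mode="reduce-overhead")

Note that the decorator introduces a graph break, which is why it conflicts with fullgraph=True (the torch 2.2 error reported below) and why it was ultimately removed before merging.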

3. Avoid using a stateful int seen_tokens (PyTorch bug)

On main, StaticCache's seen_tokens is bugged, but only after torch.compile has been called. Convince yourself with:

from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
from transformers.cache_utils import StaticCache

tokenizer = AutoTokenizer.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)

with torch.device("cuda"):
    model = AutoModelForCausalLM.from_pretrained(
        "NousResearch/Llama-2-7b-chat-hf",
        torch_dtype=torch.float16,
        attn_implementation="sdpa",
    )

inputs = tokenizer(
    ["I would", "Today I am in Paris and"], padding=True, return_tensors="pt"
).to(model.device)

new_tokens = 10
gen_config = GenerationConfig(
    max_new_tokens=new_tokens,
    min_new_tokens=new_tokens,
    use_cache=True,
    pad_token_id=tokenizer.pad_token_id,
    num_beams=1,
    do_sample=False,
    eos_token_id=None,  # This is required for min_new_tokens to actually have an effect.
)
model.generation_config.eos_token_id = None  # greedy_search falls back on this eos_token_id, which we also need to set to None for min_new_tokens to have an effect.

gen_out = model.generate(**inputs, generation_config=gen_config)

decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)

print("decoded", decoded)

print("compiling...")

model.forward = torch.compile(model.forward, mode="reduce-overhead")
print("Finished compile call")

# warmup
gen_out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")

print("\n\n\n\n\n\n----- second call")
gen_out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")

decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)

print("decoded static", decoded)

which yields

[screenshot: decoded output showing the incorrect generations with torch.compile + static cache]

Using a torch.Tensor updated in-place instead of an int fixes the bug; however, we then hit what I believe to be a torch.compile bug when subclasses are added after the torch.compile call (in _setup_cache): even with the tensor fix, seen_tokens is still not properly updated. Making sure _setup_cache is called BEFORE torch.compile makes this issue disappear, but that would require an API change, so I am disregarding this approach.

Instead, remove seen_tokens altogether from StaticCache.
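For reference, a sketch of the replacement bookkeeping: the position lives in a cache_position tensor handed to the model, rather than in a Python int stored on the cache (simplified, assumed semantics):

import torch

prompt_length = 7
cache_position = torch.arange(prompt_length)    # prefill writes slots 0..6 of the static cache
for _ in range(3):                              # decode steps
    cache_position = cache_position[-1:] + 1    # plain tensor arithmetic, traceable by torch.compile: [7], [8], [9]
    # model(..., cache_position=cache_position) then tells StaticCache.update() which slot to write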

Results

On main (ee3af60):

-------------- STATIC CACHE
compiling...
torch.compile call: 703.207 ms
/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_inductor/compile_fx.py:148: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(

- 0-th `generate` call latency per token (new_tokens=20): 10591.283 ms

- 1-th `generate` call latency per token (new_tokens=20): 3284.225 ms

- 2-th `generate` call latency per token (new_tokens=20): 132.769 ms

- 3-th `generate` call latency per token (new_tokens=20): 11.211 ms

- 4-th `generate` call latency per token (new_tokens=20): 11.160 ms
decoded static ['I would like to know how to get a copy of my medical records from my primary care physician.\n', 'Today I am in Paris and I am feeling very grateful for this opportunity to explore this beautiful city. I have always wanted to visit']

On this branch (0c03b7d):

-------------- STATIC CACHE
compiling...
torch.compile call: 729.943 ms
/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_inductor/compile_fx.py:148: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(

- 0-th `generate` call latency per token (new_tokens=20): 4121.241 ms

- 1-th `generate` call latency per token (new_tokens=20): 239.070 ms

- 2-th `generate` call latency per token (new_tokens=20): 11.592 ms

- 3-th `generate` call latency per token (new_tokens=20): 11.602 ms

- 4-th `generate` call latency per token (new_tokens=20): 11.618 ms
decoded static ['I would like to know how to get a copy of my medical records from my primary care physician.\n', 'Today I am in Paris and I am feeling very grateful for this opportunity to explore this beautiful city. I have always wanted to visit']

@fxmarty fxmarty changed the title WIP: fix generate compatibility with torch.compile Make torch.compile compilation >2x faster when using static cache + generate Feb 20, 2024
@@ -1975,6 +1990,7 @@ def contrastive_search(
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
                standardize_cache_format=True,
+               model_inputs=model_inputs,
fxmarty (Contributor, Author):

We need to update model_kwargs with model_inputs["cache_position"], hence the additional argument.
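Roughly what the extra argument enables in the model_kwargs update helper (a sketch, not the exact implementation):

# model_inputs carries the cache_position computed in prepare_inputs_for_generation;
# copying it into model_kwargs makes it available again at the next decoding step.
model_kwargs["cache_position"] = model_inputs.get("cache_position")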

ArthurZucker (Collaborator) left a comment:

Thanks for the PR!
Good to have faster compile, and to remove the seen_tokens API and rely on cache positions, which is IMO less brittle!
⚡ compile is good. I just have to check my benchmarks for FA2 and compiled static cache. Your snippet goes from 11.2 to 11.8 ms per token, acceptable IMO (as mentioned, this is probably the causal mask update not being compiled; it would be nice to compile it, or just compile the part that post-processes it after the slicing!)

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
        return self.seen_tokens
        # TODO: Fix once the stateful `int` bug in PyTorch is fixed.
Collaborator:

I am not sure it will ever be fixed!
If it is fixed, this will only be used in prepare_inputs_for_generation, but my plan going forward is to rely entirely on cache_positions, not on the state of the cache, to know which iteration we are at in the generate function! cc @gante

Collaborator:

Using model_inputs is indeed cleaner to me, as we only keep track of the cache_positions and increment them, so there is no need for seen_tokens!
I'll let @gante validate it all, but LGTM.

Comment on lines 977 to +980
        if use_cache:  # kept for BC (cache positions)
            if not isinstance(past_key_values, StaticCache):
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                past_seen_tokens = past_key_values.get_seq_length()
            past_seen_tokens = past_key_values.get_seq_length()
Collaborator:

If we assume use_cache is always used together with cache_positions (in generate), then we probably don't need that anymore, do we?
A bit breaking, but still.

fxmarty (Contributor, Author):

@ArthurZucker Overall we should not assume cache_positions is an input to the model, e.g. for ONNX this would break things

Collaborator:

Why not? If we decide this is to be the new format, is it not better for ONNX as well?

        cache_position = torch.arange(
            past_length, past_length + position_ids.shape[-1], device=position_ids.device
        )
        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
Collaborator:

If the cache positions are given as an input in the kwargs, and the length is 1, we can just increment it, no? (no arange that way)
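A sketch of what that suggestion could look like (hypothetical, not what this PR implements):

if cache_position is not None and cache_position.shape[-1] == 1:
    # single-token decode step: just advance the previous position instead of rebuilding it
    cache_position = cache_position + 1
else:
    cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)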

fxmarty (Contributor, Author):

@ArthurZucker I left it that way because I don't properly understand the relationship between position_ids and cache_position, especially for speculative decoding. Maybe it can be improved later.

Collaborator:

position_ids != cache_position when there is padding, basically.
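A small illustration of the distinction (assumed semantics) with a left-padded batch:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])                 # two left-padding tokens
position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)      # tensor([[0, 0, 0, 1, 2]]): restarts after the pads
cache_position = torch.arange(attention_mask.shape[-1])          # tensor([0, 1, 2, 3, 4]): physical slots in the cache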


fxmarty and others added 2 commits February 21, 2024 09:34
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
ArthurZucker (Collaborator) commented Feb 21, 2024

torch._dynamo.exc.Unsupported: 'inline in skipfiles: LlamaModel._update_causal_mask | _fn /home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py, skipped according skipfiles.SKIP_DIRS'

Failing on torch 2.2, let's wait a tad bit (that was with fullgraph=True!)

fxmarty (Contributor, Author) commented Feb 21, 2024

Leaving this open for now, as we would like to avoid @torch.compiler.disable and keep compatibility with fullgraph=True.

There is likely a bug in PyTorch where CUDA graphs are re-recorded when they should not be, so we can't simply remove @torch.compiler.disable.

@fxmarty fxmarty changed the title Make torch.compile compilation >2x faster when using static cache + generate torch.compile compatibility with generate + static cache Feb 21, 2024
gante (Member) left a comment:

LGTM

Looking forward, I think we can get rid of cache_position; there are many places where this information is present, and one of them has to be compatible with torch.compile 😅

gante (Member) commented Feb 21, 2024

@fxmarty btw, I am working on a PR that has some of the changes you added here (such as resetting the cache after generate); we might get merge conflicts :)

fxmarty (Contributor, Author) commented Feb 21, 2024

Thank you @gante, awesome! Yes, I think there needs to be an alignment at some point between all the different archs; it's getting a bit complex with all the different approaches.

At the end of the day, after discussing with @ArthurZucker, we are merging but not cherry-picking into the release. I removed the @torch.compiler.disable decorator for the reason above.

I believe there is a bug in PyTorch where CUDA graphs are somehow re-recorded in the second pass.

- 0-th `generate` call latency per token (new_tokens=100): 744.209 ms

- 1-th `generate` call latency per token (new_tokens=100): 1773.975 ms

- 2-th `generate` call latency per token (new_tokens=100): 11.069 ms

- 3-th `generate` call latency per token (new_tokens=100): 11.042 ms

- 4-th `generate` call latency per token (new_tokens=100): 11.035 ms
decoded static ["I would like to know how to get a copy of my medical records from my primary care physician.\nI would like to know how to get a copy of my medical records from my primary care physician.\nGetting a copy of your medical records from your primary care physician can be a straightforward process, but it's important to follow the proper steps to ensure you receive a complete and accurate copy of your records. Here are the general steps you can take:\n\n1. Contact your", "Today I am in Paris and I am feeling very grateful for this opportunity to explore this beautiful city. I have always wanted to visit Paris and now I am finally here, and it is even more beautiful than I imagined. The Eiffel Tower is stunning, the Louvre is incredible, and the food is delicious. I am soaking up every moment and making the most of my time here. I can't wait to see what the rest of the trip has in store for me. #grateful"]

fxmarty (Contributor, Author) commented Feb 21, 2024

^ Reference for this: pytorch/pytorch#120309
