torch.compile compatibility with generate + static cache #29114

Conversation
# The `contiguous()` here is necessary to have a static stride during (non-speculative) decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard.
# TODO: We don't really need to handle the input_ids here, and this contiguous() call could be removed if we were
# simply using GenerationMixin.greedy_search `next_tokens` variable directly (which is already contiguous), instead of
# doing a torch.cat + then slice.
model_inputs = {"input_ids": input_ids.contiguous()}
As per the comment. If `input_ids` is not contiguous, its stride is different at each forward call in generate, which triggers a recompilation at every step in the loop. This is very slow. Ideally we would want to just use `next_tokens` instead of having this `contiguous()` call, but let's do a proper refactor in another PR.
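A minimal standalone sketch of the stride issue described above (illustration only, not code from this PR): slicing the last column of a growing `input_ids` tensor inherits the parent's row stride, so the stride changes at every decoding step, while `.contiguous()` pins it.

```python
import torch

# Each decoding step grows input_ids by one column; the [:, -1:] slice keeps the
# parent's row stride, so its stride changes every iteration (a dynamo guard),
# while .contiguous() always yields stride (1, 1).
input_ids = torch.zeros((1, 5), dtype=torch.long)
for _ in range(3):
    input_ids = torch.cat([input_ids, torch.zeros((1, 1), dtype=torch.long)], dim=-1)
    last_token = input_ids[:, -1:]
    print(last_token.stride(), last_token.contiguous().stride())
# (6, 1) (1, 1)
# (7, 1) (1, 1)
# (8, 1) (1, 1)
```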
Yep, `next_tokens` would be less "costly", but I think we need all of them for speculative decoding? Anyway, I noticed that as well when I was compiling; I just let it be at the time!
The int vs tensor "bug" is expected; that is precisely why we use `cache_positions` for generation.
I think generate needs to handle that, because I'm not sure compile will like the in-place tensor modification.
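To illustrate the int vs tensor distinction (a hedged sketch, not this PR's code): a Python int read by a compiled function is specialized and guarded on, so changing it forces a recompile, while a tensor argument such as `cache_position` can change value freely.

```python
import torch

# Illustration only: the int in `state` is burned into the graph as a constant
# guard, so bumping it triggers a recompilation; the tensor argument does not.
state = {"seen_tokens": 0}

@torch.compile
def step_with_int(x):
    return x + state["seen_tokens"]          # int -> specialized, recompiles when it changes

@torch.compile
def step_with_tensor(x, cache_position):
    return x + cache_position                # tensor -> value changes do not recompile

x = torch.ones(4)
cache_position = torch.arange(4)
for _ in range(3):
    step_with_int(x)
    step_with_tensor(x, cache_position)
    state["seen_tokens"] += 1                # forces a recompile of step_with_int
    cache_position = cache_position + 1      # fine: advanced outside the compiled region
```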
@ArthurZucker @gante @LysandreJik This PR fixes many issues with the current `generate` compatibility with `torch.compile`:

1. Always keep the same stride for inputs in the decode phase. With `torch.compile` there are guards on the stride of the inputs, so recompilation is triggered in the decode phase while this is really not necessary.
2. Do not compile
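For reference, a hedged sketch of the kind of workflow this PR targets; the exact API (e.g. setting `cache_implementation = "static"` on the generation config) and the checkpoint name are assumptions and may differ between transformers versions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; any decoder-only model with static cache support works.
model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# Ask generate for a static (fixed-shape) KV cache, then compile only the forward
# pass: generate itself stays eager, so only the per-token model call is traced.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.batch_decode(output, skip_special_tokens=True))
```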
@@ -1975,6 +1990,7 @@ def contrastive_search(
    model_kwargs,
    is_encoder_decoder=self.config.is_encoder_decoder,
    standardize_cache_format=True,
    model_inputs=model_inputs,
We need to update `model_kwargs` with `model_inputs["cache_position"]`, hence the additional argument.
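Roughly, the idea (a simplified sketch with hypothetical helper names, not the exact code in this PR) is that the kwargs-update helper now receives `model_inputs` so it can carry `cache_position` forward between decoding steps.

```python
import torch

# Simplified sketch of the extra argument's purpose: propagate cache_position
# from one decoding step to the next, advancing it by one generated token.
def update_model_kwargs(model_kwargs: dict, model_inputs: dict) -> dict:
    cache_position = model_inputs.get("cache_position")
    if cache_position is not None:
        # One token is generated per decode step, so the next step's
        # cache_position is simply the last slot plus one.
        model_kwargs["cache_position"] = cache_position[-1:] + 1
    return model_kwargs

model_kwargs = {}
model_inputs = {"cache_position": torch.arange(7)}  # prefill over a 7-token prompt
model_kwargs = update_model_kwargs(model_kwargs, model_inputs)
print(model_kwargs["cache_position"])  # tensor([7])
```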
Thanks for the PR!
Good to have faster compile, and to remove the seen-tokens API in favor of cache positions, which is IMO less brittle!
⚡ compile is good. I just have to check my benchmarks for FA2 and compiled static cache. Your snippet goes from 11.2 to 11.8 seconds, which is acceptable IMO (as mentioned, this is probably the update-causal-mask step not being compiled; it would be nice to compile it, or just compile the part that post-processes it after the slicing!)
    return k_out, v_out

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
    """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
    return self.seen_tokens
    # TODO: Fix once the stateful `int` bug in PyTorch is fixed.
I am not sure it will ever be fixed!
If fixed, this will only be used in prepare_inputs_for_generation, but my plan forward is to rely on the `cache_positions` entirely, not the state of the cache, to know which iteration we are at in the generate function! cc @gante
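A tiny sketch of that plan (hypothetical helper, just to illustrate the comment above): the loop in generate owns `cache_position`, so the current iteration can be read from it rather than from the cache's seen-tokens counter.

```python
import torch

def current_length(cache_position: torch.Tensor) -> int:
    # Tokens processed so far, including the ones being processed in this step.
    return int(cache_position[-1]) + 1

prompt_cache_position = torch.arange(5)        # prefill over a 5-token prompt
print(current_length(prompt_cache_position))   # 5
decode_cache_position = torch.tensor([11])     # decoding into the 12th cache slot
print(current_length(decode_cache_position))   # 12
```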
Use of `model_inputs` is indeed cleaner for me, as we only keep track of the cache positions and increment them, so there is no need for seen tokens!
I'll let @gante validate it all, but LGTM
if use_cache:  # kept for BC (cache positions)
    if not isinstance(past_key_values, StaticCache):
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    past_seen_tokens = past_key_values.get_seq_length()
If we assume `use_cache` is used with `cache_positions` (generate), then we probably don't need that anymore, do we? A bit breaking, but still.
@ArthurZucker Overall we should not assume `cache_positions` is an input to the model; e.g. for ONNX this would break things.
Why not? If we decide this to be the new format, is it not better for ONNX as well?
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
If the cache positions are given as input in the kwargs and the length is 1, we can just increment it, no? (No arange that way.)
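A sketch of the suggestion (hypothetical helper, assuming the kwarg is present and we are decoding a single token at a time):

```python
import torch

# Hypothetical variant of the line above: reuse the incoming cache_position when
# decoding one token at a time, and only fall back to arange otherwise.
def next_cache_position(cache_position, past_length, q_len, device):
    if cache_position is not None and q_len == 1:
        return cache_position[-1:] + 1  # just advance the existing position
    return torch.arange(past_length, past_length + q_len, device=device)

print(next_cache_position(torch.tensor([6]), past_length=7, q_len=1, device="cpu"))  # tensor([7])
print(next_cache_position(None, past_length=0, q_len=5, device="cpu"))               # tensor([0, 1, 2, 3, 4])
```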
@ArthurZucker I left it that way because I don't properly understand the relationship between `position_ids` and `cache_position`, especially for speculative decoding. Maybe it can be improved later.
Basically, `position_ids != cache_position` when there is padding.
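To make the distinction concrete, a small illustration (not from this PR) using left padding: `position_ids` are derived from the attention mask and skip pad tokens, while `cache_position` simply indexes slots in the KV cache.

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])            # two pad tokens on the left
position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)  # positions within the real sequence
cache_position = torch.arange(attention_mask.shape[-1])      # positions within the cache

print(position_ids)    # tensor([[0, 0, 0, 1, 2]])
print(cache_position)  # tensor([0, 1, 2, 3, 4])
```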
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
`torch._dynamo.exc.Unsupported: 'inline in skipfiles: LlamaModel._update_causal_mask | _fn /home/arthur/miniconda3/envs/py39/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py, skipped according skipfiles.SKIP_DIRS'` is failing on torch 2.2; let's wait a tad bit (that was with full graph!)
Leaving this open for now, as we would like to avoid … There is likely a bug in PyTorch where CUDA graphs are re-recorded while they should not be, so we can't simply remove …
LGTM
Looking forward, I think we can get rid of `cache_position`; there are many places where this information is present, and one of them has to be compatible with torch.compile 😅
@fxmarty btw, I am working on a PR that has some of the changes that you added here (such as resetting the cache after generate), so we might get merge conflicts :)
Thank you @gante, awesome! Yes, I think there needs to be an alignment at some point between all the different archs; it's getting a bit complex with all the different approaches. At the end of the day, after discussing with @ArthurZucker: merging, but not cherry-picking into the release. I removed the … I believe there is a bug in PyTorch where CUDA graphs are somehow re-recorded in the second pass.
^ Reference for this: pytorch/pytorch#120309
This PR adds support for `torch.compile` when using `generate` + a static cache.