Fix static generation when compiling! #28937

Merged · 42 commits · Feb 15, 2024
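For context, the scenario this PR fixes is text generation with a static KV cache under torch.compile. A minimal sketch of that setup, assuming a transformers version that exposes generation_config.cache_implementation = "static" (roughly v4.38+); the checkpoint and generation settings are illustrative, not taken from this PR:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any Llama-style model should behave the same way.
model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# Ask generate() to allocate a fixed-shape StaticCache instead of the default DynamicCache.
model.generation_config.cache_implementation = "static"

# Compile the forward pass; the fixed cache shapes are what make fullgraph compilation work.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Static KV cache plus torch.compile:", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))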
Changes from 1 commit
Commits (42)
2187685
wow I was scared!
ArthurZucker Feb 9, 2024
4922c92
fix everything
ArthurZucker Feb 9, 2024
56768a0
nits
ArthurZucker Feb 9, 2024
b565051
make it BC?
ArthurZucker Feb 12, 2024
99afd1a
add todo
ArthurZucker Feb 12, 2024
edc498f
nits
ArthurZucker Feb 12, 2024
651c4bd
is_tracing should still be used to pass tracing tests
ArthurZucker Feb 12, 2024
f69626e
nits
ArthurZucker Feb 12, 2024
96136ac
some nits to make sure generation works with static cache uncompiled
ArthurZucker Feb 12, 2024
d5ebd80
fix sdpa
ArthurZucker Feb 12, 2024
70adcf6
fix FA2 for both static and dynamic in a better way?
ArthurZucker Feb 14, 2024
61ed4cb
style
ArthurZucker Feb 14, 2024
fedc563
fix-copies
ArthurZucker Feb 14, 2024
0195d58
fix fix copies
ArthurZucker Feb 14, 2024
07f3adb
fix sequential beam search
ArthurZucker Feb 14, 2024
9402c25
style
ArthurZucker Feb 14, 2024
86303c4
use `keys_to_ignore`
ArthurZucker Feb 14, 2024
fb9e907
nit
ArthurZucker Feb 14, 2024
9aa667e
correct dtype inference when init
ArthurZucker Feb 14, 2024
68a5f29
:( the fix for FA2 is still not optimal to investigate!
ArthurZucker Feb 14, 2024
3b9969b
styling
ArthurZucker Feb 14, 2024
162ab87
Merge branch 'main' of github.com:huggingface/transformers into fix-s…
ArthurZucker Feb 14, 2024
914b0d7
nits
ArthurZucker Feb 14, 2024
e79f79f
nit
ArthurZucker Feb 14, 2024
ee2317d
this might work better
ArthurZucker Feb 14, 2024
93b2691
add comment
ArthurZucker Feb 14, 2024
3619ed3
Update src/transformers/models/llama/modeling_llama.py
ArthurZucker Feb 14, 2024
c23cdc4
"position_ids" -> "cache_position"
ArthurZucker Feb 14, 2024
717a8e7
style
ArthurZucker Feb 14, 2024
7fe0964
Merge branch 'main' of github.com:huggingface/transformers into fix-s…
ArthurZucker Feb 14, 2024
464c463
Merge branch 'main' of github.com:huggingface/transformers into fix-s…
ArthurZucker Feb 15, 2024
80148ab
nit
ArthurZucker Feb 15, 2024
c9f3c82
Remove changes that should not be propagated just yet
ArthurZucker Feb 15, 2024
5f54d84
Apply suggestions from code review
ArthurZucker Feb 15, 2024
b3fc042
Styling
ArthurZucker Feb 15, 2024
5fdb2da
make sure we raise an error for static cache with FA2 enabled
ArthurZucker Feb 15, 2024
03edf91
move to the bottom of the signature
ArthurZucker Feb 15, 2024
b762304
style
ArthurZucker Feb 15, 2024
9fbe901
Update src/transformers/models/llama/modeling_llama.py
ArthurZucker Feb 15, 2024
7afe7d9
Update src/transformers/models/llama/modeling_llama.py
ArthurZucker Feb 15, 2024
3772d1c
nit in the name
ArthurZucker Feb 15, 2024
cf0bc32
Merge branches 'fix-static-kv-cache' and 'fix-static-kv-cache' of git…
ArthurZucker Feb 15, 2024
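One behavioral change called out in the commit list above ("make sure we raise an error for static cache with FA2 enabled") amounts to a compatibility guard along these lines. This is only a sketch; the function name, error message, and call site are illustrative, not the library's actual code:

from typing import Optional

def check_cache_vs_attn_implementation(attn_implementation: str, cache_implementation: Optional[str]) -> None:
    # Hypothetical guard; the real check lives inside the model/generation code.
    if cache_implementation == "static" and attn_implementation == "flash_attention_2":
        raise ValueError(
            "A static KV cache is not supported with flash_attention_2; "
            "use the sdpa or eager attention implementation instead."
        )

try:
    check_cache_vs_attn_implementation("flash_attention_2", "static")
except ValueError as err:
    print(err)  # the incompatible combination is rejected up front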
is_tracing should still be used to pass tracing tests
ArthurZucker committed Feb 12, 2024

Verified: this commit was created on GitHub.com and signed with GitHub’s verified signature (the key has expired).
commit 651c4bd808c5ad2dd488ca4ca37f1b2984c0da56
35 changes: 21 additions & 14 deletions src/transformers/models/llama/modeling_llama.py
@@ -29,7 +29,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SinkCache
from ...cache_utils import Cache, DynamicCache
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -349,7 +349,7 @@ def forward(
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[ :, :, cache_position, : key_states.shape[-2]]
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

# upcast attention to fp32
@@ -423,7 +423,7 @@ def forward(

cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
@@ -621,7 +621,7 @@ def forward(

cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
@@ -939,14 +939,18 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

past_seen_tokens = 0
past_seen_tokens = 0
if use_cache and not isinstance(past_key_values, Cache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_usable_length(inputs_embeds.shape[1]) # kept for BC (cache positions)

past_seen_tokens = past_key_values.get_usable_length(
inputs_embeds.shape[1]
) # kept for BC (cache positions)

if cache_position is None:
cache_position = torch.arange(past_seen_tokens, past_seen_tokens+inputs_embeds.shape[1], device=inputs_embeds.device)

cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

@@ -1047,8 +1051,11 @@ def _update_causal_mask(self, attention_mask, input_tensor):
)

if self.config._attn_implementation == "sdpa":
if attention_mask is not None and torch.any(attention_mask != 1):
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(dtype)
is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
dtype
)

return causal_mask

@@ -1225,11 +1232,11 @@ def prepare_inputs_for_generation(
input_ids = input_ids[:, past_length:]
position_ids = position_ids[:, past_length:]

# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
cache_position = torch.arange(past_length, past_length+input_ids.shape[1])
cache_position = torch.arange(past_length, past_length + input_ids.shape[1])

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
@@ -1240,7 +1247,7 @@ def prepare_inputs_for_generation(
model_inputs.update(
{
"position_ids": position_ids,
"cache_position":cache_position,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,