torch.compile compatibility with generate + static cache #29114

Merged · 16 commits · Feb 21, 2024
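For context, a minimal sketch of the usage pattern this PR targets: generating with a fixed-size ("static") KV cache so every decoding step sees constant tensor shapes, and compiling the model's forward pass so generate() can reuse a single graph. The checkpoint name and the `cache_implementation = "static"` setting are illustrative assumptions, not part of this diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "meta-llama/Llama-2-7b-hf"  # hypothetical checkpoint; any Llama-style model is the intended target
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to(device)

# Pre-allocate a static KV cache: shapes stay constant during decoding, which is what
# lets torch.compile avoid recompiling at every step.
model.generation_config.cache_implementation = "static"

# Compile the forward pass; generate() then runs the compiled graph at each decoding step.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("The theory of relativity states that", return_tensors="pt").to(device)
out = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))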
9 changes: 6 additions & 3 deletions src/transformers/cache_utils.py
@@ -357,7 +357,9 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.seen_tokens = 0

# NOTE: self.seen_tokens being an int results in bugs with torch.compile, where it is somehow not updated.
self.seen_tokens = torch.tensor(0, dtype=torch.int64, device=device)

def update(
self,
@@ -390,8 +392,9 @@ def update(

k_out[:, :, new_cache_positions] = key_states
v_out[:, :, new_cache_positions] = value_states

self.seen_tokens += key_states.shape[2]

# This NEEDS to be in-place: in the modeling code we are not calling `self.past_key_value.update()` directly, but are rather using getattr.
self.seen_tokens.add_(key_states.shape[2])
return k_out, v_out

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
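A standalone sketch (not transformers code) of the pattern the NOTE above works around: the seen-token counter is kept as a 0-dim tensor and bumped in place with add_(), rather than as a plain Python int, so the update stays visible when the method runs under torch.compile. The TinyCache class and its shapes are invented for illustration.

import torch

class TinyCache:
    """Toy stand-in for StaticCache: a pre-allocated buffer plus a seen-token counter."""

    def __init__(self, max_len: int = 16, dim: int = 4):
        self.buf = torch.zeros(max_len, dim)
        # 0-dim tensor counter, updated in place below; the plain-int variant
        # (`self.seen = 0`) is the one the NOTE above reports as unreliable under torch.compile.
        self.seen = torch.tensor(0, dtype=torch.int64)

    def update(self, new_tokens: torch.Tensor, cache_positions: torch.Tensor) -> torch.Tensor:
        # Scatter the new rows into the pre-allocated buffer, analogous to
        # `k_out[:, :, new_cache_positions] = key_states` in the diff above.
        self.buf[cache_positions] = new_tokens
        self.seen.add_(new_tokens.shape[0])  # in-place, as in the PR
        return self.buf

cache = TinyCache()
step = torch.compile(cache.update)
for i in range(3):
    step(torch.randn(1, 4), torch.tensor([i]))
print(int(cache.seen))  # expected: 3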
7 changes: 7 additions & 0 deletions src/transformers/generation/utils.py
Collaborator:
Use of model_inputs is indeed cleaner for me, as we only keep track of the cached_positions and increment them, so there is no need for seen tokens!
I'll let @gante validate it all, but LGTM
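A rough illustration of the alternative the comment describes: the decoding loop itself carries the cache positions and advances them each step, so the cache no longer needs its own seen-token counter. Variable names are invented for illustration.

import torch

# Prefill: positions 0..prompt_len-1 are written into the static cache.
prompt_len = 5
cache_position = torch.arange(prompt_len)      # tensor([0, 1, 2, 3, 4])

# Decoding: each step writes exactly one new token, so the position vector is just the
# previous last position advanced by one; no per-cache `seen_tokens` counter required.
for _ in range(3):
    cache_position = cache_position[-1:] + 1   # tensor([5]), then tensor([6]), then tensor([7])
    # ... write the new key/value states into the static cache at `cache_position` ...

print(cache_position)                          # tensor([7])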

@@ -2343,6 +2343,7 @@ def greedy_search(
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

this_peer_finished = False # used by synced_gpus only
count = 0
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -2354,9 +2355,13 @@
if this_peer_finished_flag.item() == 0.0:
break

print("------------ call forward", count)
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

print("model_inputs intut_ids shape", model_inputs["input_ids"].shape)
print("model_inputs intut_ids stride", model_inputs["input_ids"].stride())
count += 1
# forward pass to get next token
outputs = self(
**model_inputs,
@@ -2394,6 +2399,8 @@
# argmax
next_tokens = torch.argmax(next_tokens_scores, dim=-1)

print("next_tokens", next_tokens.shape)

# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
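The prints added above track the shape and stride of `model_inputs["input_ids"]` across decoding steps to see what changes between forward calls. As an aside (not part of this diff), PyTorch 2.1+ can report recompilations and the failing guards directly, which gives the same information without print debugging; a sketch:

import torch
import torch._dynamo

# Log every recompilation together with the guard that failed;
# equivalent to running with the environment variable TORCH_LOGS="recompiles".
torch._logging.set_logs(recompiles=True)

# Alternatively, make any recompilation raise so the offending guard surfaces immediately.
torch._dynamo.config.error_on_recompile = True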
16 changes: 14 additions & 2 deletions src/transformers/models/llama/modeling_llama.py
@@ -623,6 +623,7 @@ def forward(
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

# In case static cache is used, it is an instance attribute.
past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
@@ -961,6 +962,10 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

# print("attention_mask", attention_mask.shape)
# print("attention_mask", attention_mask.stride())
# print("inputs_embeds", inputs_embeds.shape)
# print("inputs_embeds", inputs_embeds.stride())
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)

# embed positions
@@ -1224,7 +1229,7 @@ def prepare_inputs_for_generation(
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
@@ -1238,6 +1243,8 @@
past_length = past_key_value.get_seq_length()
input_ids = input_ids[:, past_length:]
position_ids = position_ids[:, past_length:]

print("input_ids after update", input_ids.shape)

# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
@@ -1251,7 +1258,12 @@
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
# The `contiguous()` here is necessary to have a static stride during (non-speculative) decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard.
# TODO: We don't really need to handle the input_ids here, and this contiguous() call could be removed if we were
# simply using GenerationMixin.greedy_search `next_tokens` variable directly (which is already contiguous), instead of
# doing a torch.cat followed by a slice.
model_inputs = {"input_ids": input_ids.contiguous()}
Contributor (author):
As per the comment. If input_ids is not contiguous, its stride is different at each forward call in generate, which triggers a recompilation at every step in the loop. This is very slow.

Ideally we would want to just use next_tokens instead of having this contiguous() call, but let's do a proper refactor in another PR.

Collaborator:
Yep, using next_tokens would be less "costly", but I think we need all of them for speculative decoding? Anyway, I noticed that as well when I was compiling, but just let it be at the time!


model_inputs.update(
{
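As the inline comments above explain, torchdynamo guards on input strides, so the one-token slice produced at each decoding step has a different stride every iteration and forces a recompile; calling .contiguous() normalizes the stride so the compiled graph can be reused. A standalone sketch of that effect (not transformers code, shapes invented for illustration):

import torch

# Simulate the decoding loop: the running sequence grows by one token per step and
# prepare_inputs_for_generation keeps only the newest column.
input_ids = torch.zeros(2, 5, dtype=torch.long)      # batch of 2, prompt of 5 tokens

for step in range(3):
    last = input_ids[:, -1:]                         # view into the growing tensor
    print(last.shape, last.stride(), last.is_contiguous())
    # prints stride (5, 1), then (6, 1), then (7, 1): a different stride at every step,
    # and since dynamo guards on strides this would trigger a recompile per step.
    next_tokens = torch.zeros(2, 1, dtype=torch.long)
    input_ids = torch.cat([input_ids, next_tokens], dim=-1)

# .contiguous() copies the slice into a fresh tensor with stride (1, 1) at every step,
# so the guard stays stable and the graph is reused.
print(input_ids[:, -1:].contiguous().stride())       # (1, 1)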