torch.compile compatibility with generate + static cache #29114

Merged (16 commits) on Feb 21, 2024
12 changes: 8 additions & 4 deletions src/transformers/cache_utils.py
@@ -357,7 +357,6 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.seen_tokens = 0

def update(
self,
@@ -391,15 +390,20 @@ def update(
k_out[:, :, new_cache_positions] = key_states
v_out[:, :, new_cache_positions] = value_states

self.seen_tokens += key_states.shape[2]
return k_out, v_out

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
return self.seen_tokens
# TODO: Fix once the stateful `int` bug in PyTorch is fixed.
Collaborator:
I am not sure it will ever be fixed!
If it is fixed, this will only be used in prepare_inputs_for_generation, but my plan going forward is to rely entirely on cache_position, not on the state of the cache, to know which iteration we are at in the generate function! cc @gante
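For illustration, a minimal sketch of that plan, assuming a decoder-only model that keeps its static cache as an internal attribute and accepts a cache_position argument (the helper name sketch_decode_loop is hypothetical; this is not the actual generate() code):

import torch

def sketch_decode_loop(model, input_ids, max_new_tokens):
    # Prefill: run the whole prompt once; cache slots 0..prompt_len-1 are written.
    cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)
    outputs = model(input_ids, cache_position=cache_position, use_cache=True)
    next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
    generated = [next_token]
    cache_position = cache_position[-1:] + 1
    for _ in range(max_new_tokens - 1):
        # Decode: feed only the new token and the incremented position;
        # no query of the cache state is needed to know where we are.
        outputs = model(next_token, cache_position=cache_position, use_cache=True)
        next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
        generated.append(next_token)
        cache_position = cache_position + 1
    return torch.cat([input_ids, *generated], dim=-1)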

raise ValueError(
"get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
)

def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
return self.seen_tokens
# TODO: Fix once the stateful `int` bug in PyTorch is fixed.
raise ValueError(
"get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
)

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
73 changes: 46 additions & 27 deletions src/transformers/generation/utils.py
Collaborator:
Use of model_inputs is indeed cleaner for me: we only keep track of the cache_position values and increment them, so there is no need for seen_tokens!
I'll let @gante validate it all, but LGTM

@@ -648,6 +648,7 @@ def _update_model_kwargs_for_generation(
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
model_inputs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
@@ -677,6 +678,8 @@
dim=-1,
)

model_kwargs["cache_position"] = model_inputs.get("cache_position", None)

return model_kwargs

def _reorder_cache(self, past_key_values, beam_idx):
@@ -1451,17 +1454,19 @@ def generate(
):
generation_config.max_length -= inputs_tensor.shape[1]

# if we don't pass `past_key_values` and a cache_implementation is specified
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING and not model_kwargs.get(
"past_key_values", False
):
cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING[generation_config.cache_implementation]
if not callable(getattr(self, "_setup_cache", None)):
raise ValueError(
"The `generation_config` defines a `cache_implementation` that is not compatible with this model."
" Make sure it has a `_setup_cache` function."
)
self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static":
if model_kwargs.get("past_key_values", False) is not False:
raise ValueError(
"Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository."
)
cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"]
if not callable(getattr(self, "_setup_cache", None)):
raise ValueError(
"The `generation_config` defines a `cache_implementation` that is not compatible with this model."
" Make sure it has a `_setup_cache` function."
)
self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)

self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

@@ -1523,7 +1528,7 @@ def generate(
)

# 12. run assisted generate
return self.assisted_decoding(
result = self.assisted_decoding(
input_ids,
candidate_generator=candidate_generator,
do_sample=generation_config.do_sample,
@@ -1541,7 +1546,7 @@
)
if generation_mode == GenerationMode.GREEDY_SEARCH:
# 11. run greedy search
return self.greedy_search(
result = self.greedy_search(
input_ids,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
@@ -1559,7 +1564,7 @@
if not model_kwargs["use_cache"]:
raise ValueError("Contrastive search requires `use_cache=True`")

return self.contrastive_search(
result = self.contrastive_search(
input_ids,
top_k=generation_config.top_k,
penalty_alpha=generation_config.penalty_alpha,
@@ -1589,7 +1594,7 @@
)

# 13. run sample
return self.sample(
result = self.sample(
input_ids,
logits_processor=prepared_logits_processor,
logits_warper=logits_warper,
@@ -1623,7 +1628,7 @@
**model_kwargs,
)
# 13. run beam search
return self.beam_search(
result = self.beam_search(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
@@ -1662,7 +1667,7 @@
)

# 14. run beam sample
return self.beam_sample(
result = self.beam_sample(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
@@ -1697,7 +1702,7 @@
**model_kwargs,
)
# 13. run beam search
return self.group_beam_search(
result = self.group_beam_search(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
@@ -1771,7 +1776,7 @@ def typeerror():
**model_kwargs,
)
# 13. run beam search
return self.constrained_beam_search(
result = self.constrained_beam_search(
input_ids,
constrained_beam_scorer=constrained_beam_scorer,
logits_processor=prepared_logits_processor,
@@ -1785,6 +1790,16 @@
**model_kwargs,
)

if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if not callable(getattr(self, "_reset_cache", None)):
raise ValueError(
"A `static_cache` was used to generate but there was a failure when trying to release the cache. "
" Make sure this model implements a `_reset_cache` function."
)
self._reset_cache()

return result

@torch.no_grad()
def contrastive_search(
self,
@@ -1975,6 +1990,7 @@ def contrastive_search(
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
standardize_cache_format=True,
model_inputs=model_inputs,
Contributor (author):
We need to update model_kwargs with model_inputs["cache_position"], hence the additional argument.
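A condensed sketch of the resulting round trip across one decoding step (logits processing, stopping criteria, and streaming elided, as in the loops shown below):

# 1. prepare_inputs_for_generation computes cache_position for this step
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# 2. the forward pass consumes it
outputs = self(**model_inputs, return_dict=True)
# 3. the new model_inputs argument lets the update copy cache_position back into
#    model_kwargs, so the next step can advance from the right offset
model_kwargs = self._update_model_kwargs_for_generation(
    outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)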

)
if not sequential:
# Expands model inputs top_k times, for batched forward passes (akin to beam search).
@@ -2169,7 +2185,7 @@ def contrastive_search(
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)

# if eos_token was found in one sentence, set sentence to finished
@@ -2450,7 +2466,10 @@ def greedy_search(
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
model_inputs=model_inputs,
)

# if eos_token was found in one sentence, set sentence to finished
@@ -2744,7 +2763,7 @@ def sample(
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)

# if eos_token was found in one sentence, set sentence to finished
@@ -3137,7 +3156,7 @@ def beam_search(
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3484,7 +3503,7 @@ def beam_sample(
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3883,7 +3902,7 @@ def group_beam_search(
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4235,7 +4254,7 @@ def constrained_beam_search(

input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4642,7 +4661,7 @@ def assisted_decoding(
)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)

# if eos_token was found in one sentence, set sentence to finished
42 changes: 22 additions & 20 deletions src/transformers/models/llama/modeling_llama.py
@@ -648,6 +648,7 @@ def forward(
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

# In case static cache is used, it is an instance attribute.
past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
@@ -976,9 +977,11 @@ def forward(
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
past_seen_tokens = past_key_values.get_seq_length()
Collaborator (comment on lines 977 to +980):
If we assume use_cache is always used together with cache_position (as in generate), then we probably don't need this anymore, do we?
A bit breaking, but still.

Contributor (author):
@ArthurZucker Overall, we should not assume cache_position is an input to the model; e.g. for ONNX this would break things.

Collaborator:
Why not? If we decide this is the new format, is it not better for ONNX as well?


if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
@@ -1050,6 +1053,10 @@ def forward(
attentions=all_self_attns,
)

# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
@@ -1065,16 +1072,8 @@ def _update_causal_mask(self, attention_mask, input_tensor):
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

if hasattr(self, "causal_mask"): # we use the current dtype to avoid any overflows
causal_mask = (
self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
)
else:
mask = torch.full(
(self.config.max_position_embeddings, self.config.max_position_embeddings),
fill_value=torch.finfo(dtype).min,
)
causal_mask = torch.triu(mask, diagonal=1)
# We use the current dtype to avoid any overflows
causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min

causal_mask = causal_mask.to(dtype=dtype, device=device)
if attention_mask is not None and attention_mask.dim() == 2:
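Regarding the TODO above about cudagraph recapture: a minimal sketch of the mentioned @torch.compiler.disable workaround, which keeps the dynamic-length 2D mask update out of the compiled region at the cost of fullgraph=True (illustrative only, with a hypothetical helper name; this is not part of the diff):

import torch

@torch.compiler.disable  # kept out of the compiled graph; note this prevents fullgraph=True
def update_causal_mask_eager(causal_mask, attention_mask):
    # Apply the 2D (dynamic-length) padding mask in eager mode so its changing
    # shape cannot trigger cudagraph recapture inside the compiled region.
    if attention_mask is not None and attention_mask.dim() == 2:
        min_value = torch.finfo(causal_mask.dtype).min
        mask_length = attention_mask.shape[-1]
        padding_mask = attention_mask[:, None, None, :] == 0
        causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_value)
    return causal_mask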
@@ -1260,29 +1259,32 @@ def prepare_inputs_for_generation(
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
# generation with static cache
past_length = past_key_value.get_seq_length()
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
past_length = 0
else:
past_length = cache_position[-1] + 1
input_ids = input_ids[:, past_length:]
position_ids = position_ids[:, past_length:]

# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
cache_position = torch.arange(
past_length, past_length + position_ids.shape[-1], device=position_ids.device
)
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
Collaborator:
If the cache positions are given as an input in the kwargs and their length is 1, we can just increment them, no? (No arange needed that way.)
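A sketch of that suggestion, purely illustrative (it reuses the surrounding past_length and position_ids variables and is the reviewer's proposal, not what the PR implements):

cache_position = kwargs.get("cache_position", None)
if cache_position is not None and cache_position.shape[-1] == 1:
    # single-token decode step: just advance the existing position by one
    cache_position = cache_position + 1
else:
    # prefill, or no cache_position passed: fall back to building the full range
    cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)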

Contributor (author):
@ArthurZucker I left it that way because I don't fully understand the relationship between position_ids and cache_position, especially for speculative decoding. Maybe it can be improved later.

Collaborator:
Basically, position_ids != cache_position when there is padding.
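A toy illustration of that point, assuming a single left-padded sequence of length 4 with two pad tokens:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1]])                  # left padding: two pad tokens, two real tokens
cache_position = torch.arange(4)                               # physical cache slots written: tensor([0, 1, 2, 3])
position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)    # rotary positions: tensor([[0, 0, 0, 1]])
# cache_position indexes slots in the (static) KV cache, while position_ids carry the
# rotary positions of the real tokens, so the two diverge as soon as padding is present.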


# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
# TODO: use `next_tokens` directly instead.
model_inputs = {"input_ids": input_ids.contiguous()}

model_inputs.update(
{
"position_ids": position_ids,
"position_ids": position_ids.contiguous(),
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
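For context, a sketch of the end-to-end usage this PR targets: generate() with a static KV cache plus torch.compile on the model forward. The checkpoint name and compile flags are examples and may differ across versions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

model.generation_config.cache_implementation = "static"   # generate() sets up a StaticCache
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
output_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))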