torch.compile compatibility with generate + static cache (#29114)
* fix compatibility

* working version

* cleanup

* sanity checks

* more sanity

* working version WITH refactor

* working without API change

* cleanup & tests pass

* more cleaning

* fix test

* fix tests

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* smaller comment

* update comment

* update comment

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
fxmarty and ArthurZucker authored Feb 21, 2024
1 parent 3994fa5 commit cc4a664
Showing 3 changed files with 76 additions and 51 deletions.
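For context, the feature this commit wires up is typically exercised as in the sketch below. The snippet is not part of the diff: the checkpoint name, compile flags, and generation settings are illustrative assumptions.

```python
# Minimal usage sketch (checkpoint, compile mode, and prompt are assumptions).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # any Llama-style checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# Ask `generate` to set up a StaticCache with fixed-size key/value buffers.
model.generation_config.cache_implementation = "static"

# Compile the decoder forward; with static shapes, cudagraphs can be captured once and reused.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Static KV caches are useful because", return_tensors="pt").to("cuda")
output_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```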
12 changes: 8 additions & 4 deletions src/transformers/cache_utils.py
@@ -357,7 +357,6 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.seen_tokens = 0

def update(
self,
@@ -391,15 +390,20 @@ def update(
k_out[:, :, new_cache_positions] = key_states
v_out[:, :, new_cache_positions] = value_states

self.seen_tokens += key_states.shape[2]
return k_out, v_out

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
return self.seen_tokens
# TODO: Fix once the stateful `int` bug in PyTorch is fixed.
raise ValueError(
"get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
)

def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
return self.seen_tokens
# TODO: Fix once the stateful `int` bug in PyTorch is fixed.
raise ValueError(
"get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
)

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
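Note on the `cache_utils.py` change: `StaticCache` no longer tracks `self.seen_tokens`, because incrementing a Python int on the cache during the compiled forward is stateful and runs into the PyTorch bug the TODOs above refer to. Length bookkeeping instead relies on the `cache_position` tensor supplied by the caller. The toy snippet below (shapes are made up for illustration) shows the indexing pattern `update` relies on.

```python
# Toy illustration of writing into preallocated key/value buffers at `cache_position`.
import torch

batch, heads, max_cache_len, head_dim = 1, 2, 8, 4
key_cache = torch.zeros(batch, heads, max_cache_len, head_dim)

# Prefill: write prompt positions 0..2.
prompt_keys = torch.randn(batch, heads, 3, head_dim)
key_cache[:, :, torch.arange(3)] = prompt_keys

# Decode: write a single new token at position 3.
new_key = torch.randn(batch, heads, 1, head_dim)
cache_position = torch.tensor([3])
key_cache[:, :, cache_position] = new_key

# The "tokens seen so far" count is derived from cache_position by the caller,
# not from a mutable int stored on the cache.
print(int(cache_position[-1]) + 1)  # 4
```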
73 changes: 46 additions & 27 deletions src/transformers/generation/utils.py
@@ -648,6 +648,7 @@ def _update_model_kwargs_for_generation(
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
model_inputs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
@@ -677,6 +678,8 @@ def _update_model_kwargs_for_generation(
dim=-1,
)

model_kwargs["cache_position"] = model_inputs.get("cache_position", None)

return model_kwargs

def _reorder_cache(self, past_key_values, beam_idx):
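The new `model_inputs` argument above lets `_update_model_kwargs_for_generation` copy the `cache_position` computed in `prepare_inputs_for_generation` back into `model_kwargs`, so the next decoding step can advance it instead of asking the static cache for its length. A hypothetical, much-simplified loop (the names and control flow are illustrative, not the library API):

```python
# Hypothetical sketch of how cache_position round-trips between decoding steps.
import torch

prompt_len, decode_steps = 5, 3
model_kwargs = {"cache_position": None}

for step in range(decode_steps):
    if model_kwargs["cache_position"] is None:
        cache_position = torch.arange(prompt_len)                  # prefill: all prompt slots
    else:
        cache_position = model_kwargs["cache_position"][-1:] + 1   # decode: one new slot

    model_inputs = {"cache_position": cache_position}
    # ... the model forward would write K/V into the static cache at `cache_position` ...

    # The line this commit adds: carry the position over to the next iteration.
    model_kwargs["cache_position"] = model_inputs.get("cache_position", None)

print(model_kwargs["cache_position"])  # tensor([6]): prompt slots 0..4, then decode slots 5 and 6
```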
@@ -1451,17 +1454,19 @@ def generate(
):
generation_config.max_length -= inputs_tensor.shape[1]

# if we don't pass `past_key_values` and a cache_implementation is specified
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING and not model_kwargs.get(
"past_key_values", False
):
cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING[generation_config.cache_implementation]
if not callable(getattr(self, "_setup_cache", None)):
raise ValueError(
"The `generation_config` defines a `cache_implementation` that is not compatible with this model."
" Make sure it has a `_setup_cache` function."
)
self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static":
if model_kwargs.get("past_key_values", False) is not False:
raise ValueError(
"Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository."
)
cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"]
if not callable(getattr(self, "_setup_cache", None)):
raise ValueError(
"The `generation_config` defines a `cache_implementation` that is not compatible with this model."
" Make sure it has a `_setup_cache` function."
)
self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)

self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
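Here `generate` now rejects a user-supplied `past_key_values` when a static cache is requested and sets the cache up itself via `self._setup_cache(...)`. That hook lives on the model; the sketch below is an assumption based on the Llama changes later in this diff, not a verbatim copy. The point is that each attention module gets its own preallocated `StaticCache`, so the key/value buffers keep fixed tensor addresses across compiled decode steps.

```python
# Sketch (assumption) of a model-side _setup_cache hook.
from typing import Optional


def _setup_cache(self, cache_cls, max_batch_size: int, max_cache_len: Optional[int] = None):
    for layer in self.model.layers:
        weight = layer.self_attn.o_proj.weight  # reuse the layer's device and dtype
        layer.self_attn.past_key_value = cache_cls(
            self.config, max_batch_size, max_cache_len, device=weight.device, dtype=weight.dtype
        )
```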

@@ -1523,7 +1528,7 @@ def generate(
)

# 12. run assisted generate
return self.assisted_decoding(
result = self.assisted_decoding(
input_ids,
candidate_generator=candidate_generator,
do_sample=generation_config.do_sample,
@@ -1541,7 +1546,7 @@
)
if generation_mode == GenerationMode.GREEDY_SEARCH:
# 11. run greedy search
return self.greedy_search(
result = self.greedy_search(
input_ids,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
@@ -1559,7 +1564,7 @@
if not model_kwargs["use_cache"]:
raise ValueError("Contrastive search requires `use_cache=True`")

return self.contrastive_search(
result = self.contrastive_search(
input_ids,
top_k=generation_config.top_k,
penalty_alpha=generation_config.penalty_alpha,
@@ -1589,7 +1594,7 @@
)

# 13. run sample
return self.sample(
result = self.sample(
input_ids,
logits_processor=prepared_logits_processor,
logits_warper=logits_warper,
@@ -1623,7 +1628,7 @@
**model_kwargs,
)
# 13. run beam search
return self.beam_search(
result = self.beam_search(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
@@ -1662,7 +1667,7 @@
)

# 14. run beam sample
return self.beam_sample(
result = self.beam_sample(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
@@ -1697,7 +1702,7 @@
**model_kwargs,
)
# 13. run beam search
return self.group_beam_search(
result = self.group_beam_search(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
@@ -1771,7 +1776,7 @@ def typeerror():
**model_kwargs,
)
# 13. run beam search
return self.constrained_beam_search(
result = self.constrained_beam_search(
input_ids,
constrained_beam_scorer=constrained_beam_scorer,
logits_processor=prepared_logits_processor,
@@ -1785,6 +1790,16 @@ def typeerror():
**model_kwargs,
)

if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if not callable(getattr(self, "_reset_cache", None)):
raise ValueError(
"A `static_cache` was used to generate but there was a failure when trying to release the cache. "
" Make sure this model implements a `_reset_cache` function."
)
self._reset_cache()

return result

@torch.no_grad()
def contrastive_search(
self,
@@ -1975,6 +1990,7 @@ def contrastive_search(
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
standardize_cache_format=True,
model_inputs=model_inputs,
)
if not sequential:
# Expands model inputs top_k times, for batched forward passes (akin to beam search).
@@ -2169,7 +2185,7 @@ def contrastive_search(
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)

# if eos_token was found in one sentence, set sentence to finished
@@ -2450,7 +2466,10 @@ def greedy_search(
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
model_inputs=model_inputs,
)

# if eos_token was found in one sentence, set sentence to finished
@@ -2744,7 +2763,7 @@ def sample(
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)

# if eos_token was found in one sentence, set sentence to finished
@@ -3137,7 +3156,7 @@ def beam_search(
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3484,7 +3503,7 @@ def beam_sample(
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3883,7 +3902,7 @@ def group_beam_search(
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4235,7 +4254,7 @@ def constrained_beam_search(

input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4642,7 +4661,7 @@ def assisted_decoding(
)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)

# if eos_token was found in one sentence, set sentence to finished
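A note on the repeated `return self.<strategy>(...)` → `result = self.<strategy>(...)` edits above: `generate` now needs a single exit point so that, when a static cache was set up, it can be released with `self._reset_cache()` before `result` is returned; otherwise a later call could silently reuse stale buffers. The companion hook is model-side; the sketch below is an assumption mirroring the `_setup_cache` counterpart, not a verbatim copy.

```python
# Sketch (assumption) of a model-side _reset_cache hook: drop the per-layer static
# cache so a subsequent generate() call starts from a fresh cache.
def _reset_cache(self):
    for layer in self.model.layers:
        layer.self_attn.past_key_value = None
```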
42 changes: 22 additions & 20 deletions src/transformers/models/llama/modeling_llama.py
@@ -641,6 +641,7 @@ def forward(
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

# In case static cache is used, it is an instance attribute.
past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
@@ -969,9 +970,11 @@ def forward(
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
past_seen_tokens = past_key_values.get_seq_length()

if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
@@ -1043,6 +1046,10 @@ def forward(
attentions=all_self_attns,
)

# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
@@ -1058,16 +1065,8 @@ def _update_causal_mask(self, attention_mask, input_tensor):
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

if hasattr(self, "causal_mask"): # we use the current dtype to avoid any overflows
causal_mask = (
self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
)
else:
mask = torch.full(
(self.config.max_position_embeddings, self.config.max_position_embeddings),
fill_value=torch.finfo(dtype).min,
)
causal_mask = torch.triu(mask, diagonal=1)
# We use the current dtype to avoid any overflows
causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min

causal_mask = causal_mask.to(dtype=dtype, device=device)
if attention_mask is not None and attention_mask.dim() == 2:
@@ -1253,29 +1252,32 @@ def prepare_inputs_for_generation(
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
# generation with static cache
past_length = past_key_value.get_seq_length()
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
past_length = 0
else:
past_length = cache_position[-1] + 1
input_ids = input_ids[:, past_length:]
position_ids = position_ids[:, past_length:]

# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
cache_position = torch.arange(
past_length, past_length + position_ids.shape[-1], device=position_ids.device
)
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
# TODO: use `next_tokens` directly instead.
model_inputs = {"input_ids": input_ids.contiguous()}

model_inputs.update(
{
"position_ids": position_ids,
"position_ids": position_ids.contiguous(),
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
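Finally, two of the Llama-side changes exist purely to keep shapes and strides static under `torch.compile`: the causal mask is now always read from the preallocated full-size buffer (the dynamic fallback branch was removed), and `input_ids`/`position_ids` are made `.contiguous()` so their strides, which torchdynamo guards on, do not change between decode steps. The toy example below (sizes are illustrative) shows why slicing a fixed-size mask by `cache_position` yields the same shape at every step, avoiding recompilation and cudagraph re-capture.

```python
# Toy illustration: a preallocated causal mask gives a fixed-shape slice per decode step.
import torch

max_cache_len = 8
causal_mask = torch.triu(torch.ones(max_cache_len, max_cache_len, dtype=torch.bool), diagonal=1)

for step in (3, 4, 5):  # three successive decode positions
    cache_position = torch.tensor([step])
    step_mask = causal_mask[cache_position]  # always shape (1, max_cache_len)
    print(step_mask.shape)
```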
