From 6d669eea44603b37915d9b1ae51ec2620f9c938e Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 19 Feb 2024 18:01:36 +0100 Subject: [PATCH 01/15] fix compatibility --- src/transformers/cache_utils.py | 9 ++++++--- src/transformers/generation/utils.py | 7 +++++++ src/transformers/models/llama/modeling_llama.py | 16 ++++++++++++++-- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index abdc3c7c0707..70505694fe99 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -357,7 +357,9 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) - self.seen_tokens = 0 + + # NOTE: self.seen_tokens being in an int results in bugs with torch.compile, where it is somehow not updated. + self.seen_tokens = torch.tensor(0, dtype=torch.int64, device=device) def update( self, @@ -390,8 +392,9 @@ def update( k_out[:, :, new_cache_positions] = key_states v_out[:, :, new_cache_positions] = value_states - - self.seen_tokens += key_states.shape[2] + + # # This NEEDS to be in-place as in the modeling we are not calling directly `self.past_key_value.update()`, but are rather using getattr. + self.seen_tokens.add_(key_states.shape[2]) return k_out, v_out def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0c6740b32388..7b99383e6517 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2343,6 +2343,7 @@ def greedy_search( unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) this_peer_finished = False # used by synced_gpus only + count = 0 while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. @@ -2354,9 +2355,13 @@ def greedy_search( if this_peer_finished_flag.item() == 0.0: break + print("------------ call forward", count) # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + print("model_inputs intut_ids shape", model_inputs["input_ids"].shape) + print("model_inputs intut_ids stride", model_inputs["input_ids"].stride()) + count += 1 # forward pass to get next token outputs = self( **model_inputs, @@ -2394,6 +2399,8 @@ def greedy_search( # argmax next_tokens = torch.argmax(next_tokens_scores, dim=-1) + print("next_tokens", next_tokens.shape) + # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c30be2a2da4f..812f218dd923 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -623,6 +623,7 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + # In case static cache is used, it is an instance attribute. 
past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: @@ -961,6 +962,10 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) + # print("attention_mask", attention_mask.shape) + # print("attention_mask", attention_mask.stride()) + # print("inputs_embeds", inputs_embeds.shape) + # print("inputs_embeds", inputs_embeds.stride()) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) # embed positions @@ -1224,7 +1229,7 @@ def prepare_inputs_for_generation( and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] - + position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation @@ -1238,6 +1243,8 @@ def prepare_inputs_for_generation( past_length = past_key_value.get_seq_length() input_ids = input_ids[:, past_length:] position_ids = position_ids[:, past_length:] + + print("input_ids after update", input_ids.shape) # TODO @gante we should only keep a `cache_position` in generate, and do +=1. # same goes for position ids. Could also help with continued generation. @@ -1251,7 +1258,12 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + # The `contiguous()` here is necessary to have a static stride during (non-speculative) decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # TODO: We don't really need to handle the input_ids here, and this contiguous() call could be removed if we were + # simply using GenerationMixin.greedy_search `next_tokens` variable directly (which is already contiguous), instead of + # doing a torch.cat + then slice. + model_inputs = {"input_ids": input_ids.contiguous()} model_inputs.update( { From 489105094982e87d8227e9119cf839f2a1481949 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 20 Feb 2024 14:28:24 +0100 Subject: [PATCH 02/15] working version --- src/transformers/cache_utils.py | 1 + src/transformers/generation/utils.py | 15 +++++---- .../models/llama/modeling_llama.py | 33 ++++++++----------- 3 files changed, 24 insertions(+), 25 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 70505694fe99..db7f3328933c 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -394,6 +394,7 @@ def update( v_out[:, :, new_cache_positions] = value_states # # This NEEDS to be in-place as in the modeling we are not calling directly `self.past_key_value.update()`, but are rather using getattr. 
+ # print("update seen_tokens with", key_states.shape[2]) self.seen_tokens.add_(key_states.shape[2]) return k_out, v_out diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7b99383e6517..5a80fe54aacc 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1432,6 +1432,7 @@ def generate( ): generation_config.max_length -= inputs_tensor.shape[1] + """ # if we don't pass `past_key_values` and a cache_implementation is specified if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING and not model_kwargs.get( "past_key_values", False @@ -1445,7 +1446,8 @@ def generate( self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length) self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) - + """ + # 7. determine generation mode generation_mode = self._get_generation_mode(generation_config, assistant_model) @@ -2345,6 +2347,7 @@ def greedy_search( this_peer_finished = False # used by synced_gpus only count = 0 while True: + print("----- call forward") if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence @@ -2355,12 +2358,14 @@ def greedy_search( if this_peer_finished_flag.item() == 0.0: break - print("------------ call forward", count) # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - print("model_inputs intut_ids shape", model_inputs["input_ids"].shape) - print("model_inputs intut_ids stride", model_inputs["input_ids"].stride()) + print("model_inputs input ids shape", model_inputs["input_ids"].shape) + print("model_inputs input ids stride", model_inputs["input_ids"].stride()) + print("model_inputs attention_mask shape", model_inputs["attention_mask"].shape) + print("model_inputs attention_mask stride", model_inputs["attention_mask"].stride()) + count += 1 # forward pass to get next token outputs = self( @@ -2399,8 +2404,6 @@ def greedy_search( # argmax next_tokens = torch.argmax(next_tokens_scores, dim=-1) - print("next_tokens", next_tokens.shape) - # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 812f218dd923..ceaab7cb6a11 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -961,11 +961,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - - # print("attention_mask", attention_mask.shape) - # print("attention_mask", attention_mask.stride()) - # print("inputs_embeds", inputs_embeds.shape) - # print("inputs_embeds", inputs_embeds.stride()) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) # embed positions @@ -1029,7 +1025,11 @@ def forward( hidden_states=all_hidden_states, attentions=all_self_attns, ) - + + # TODO: As of 20/02/2024, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.). Ideally, we would want to pass a statically-shaped 4D mask as input to the model. 
+ @torch.compiler.disable def _update_causal_mask(self, attention_mask, input_tensor): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: @@ -1044,16 +1044,10 @@ def _update_causal_mask(self, attention_mask, input_tensor): causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1) self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) - if hasattr(self, "causal_mask"): # we use the current dtype to avoid any overflows - causal_mask = ( - self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min - ) - else: - mask = torch.full( - (self.config.max_position_embeddings, self.config.max_position_embeddings), - fill_value=torch.finfo(dtype).min, - ) - causal_mask = torch.triu(mask, diagonal=1).to(dtype) + # We use the current dtype to avoid any overflows + causal_mask = ( + self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min + ) if attention_mask is not None and attention_mask.dim() == 2: mask_length = attention_mask.shape[-1] @@ -1241,11 +1235,11 @@ def prepare_inputs_for_generation( if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None): # generation with static cache past_length = past_key_value.get_seq_length() + print("input_ids shape here", input_ids.shape) + print("past_length here", past_length) input_ids = input_ids[:, past_length:] position_ids = position_ids[:, past_length:] - print("input_ids after update", input_ids.shape) - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. # same goes for position ids. Could also help with continued generation. cache_position = kwargs.get("cache_position", None) @@ -1253,6 +1247,7 @@ def prepare_inputs_for_generation( cache_position = torch.arange( past_length, past_length + position_ids.shape[-1], device=position_ids.device ) + print("cache_position", cache_position) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: @@ -1267,7 +1262,7 @@ def prepare_inputs_for_generation( model_inputs.update( { - "position_ids": position_ids, + "position_ids": position_ids.contiguous(), "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), From 0a00d6bba703037d89207fe701920c4e2612af2a Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 20 Feb 2024 15:48:02 +0100 Subject: [PATCH 03/15] cleanup --- src/transformers/cache_utils.py | 7 ++- src/transformers/generation/utils.py | 44 ++++++++++++++----- .../models/llama/modeling_llama.py | 12 ++--- 3 files changed, 40 insertions(+), 23 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index db7f3328933c..cd17201c483f 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -387,14 +387,13 @@ def update( A tuple containing the updated key and value states. 
""" new_cache_positions = cache_kwargs.get("cache_position") - k_out = self.key_cache - v_out = self.value_cache + k_out = self.key_cache[: key_states.shape[0]] + v_out = self.value_cache[: value_states.shape[0]] k_out[:, :, new_cache_positions] = key_states v_out[:, :, new_cache_positions] = value_states - + # # This NEEDS to be in-place as in the modeling we are not calling directly `self.past_key_value.update()`, but are rather using getattr. - # print("update seen_tokens with", key_states.shape[2]) self.seen_tokens.add_(key_states.shape[2]) return k_out, v_out diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 957d0b8e23a0..5e728ee4ac78 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1214,6 +1214,17 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de UserWarning, ) + @torch.no_grad() + def prepare_static_cache(self, max_batch_size: int, max_total_tokens: int): + # if we don't pass `past_key_values` and a cache_implementation is specified + cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"] + if not callable(getattr(self, "_setup_cache", None)): + raise ValueError( + "The `generation_config` defines a `cache_implementation` that is not compatible with this model." + " Make sure it has a `_setup_cache` function." + ) + self._setup_cache(cache_cls, max_batch_size=max_batch_size, max_cache_len=max_total_tokens) + @torch.no_grad() def generate( self, @@ -1463,10 +1474,11 @@ def generate( " Make sure it has a `_setup_cache` function." ) self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length) + """ + # TODO: sanity check on batch_size self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) - """ - + # 7. determine generation mode generation_mode = self._get_generation_mode(generation_config, assistant_model) @@ -1525,7 +1537,7 @@ def generate( ) # 12. run assisted generate - return self.assisted_decoding( + result = self.assisted_decoding( input_ids, candidate_generator=candidate_generator, do_sample=generation_config.do_sample, @@ -1543,7 +1555,7 @@ def generate( ) if generation_mode == GenerationMode.GREEDY_SEARCH: # 11. run greedy search - return self.greedy_search( + result = self.greedy_search( input_ids, logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, @@ -1561,7 +1573,7 @@ def generate( if not model_kwargs["use_cache"]: raise ValueError("Contrastive search requires `use_cache=True`") - return self.contrastive_search( + result = self.contrastive_search( input_ids, top_k=generation_config.top_k, penalty_alpha=generation_config.penalty_alpha, @@ -1591,7 +1603,7 @@ def generate( ) # 13. run sample - return self.sample( + result = self.sample( input_ids, logits_processor=prepared_logits_processor, logits_warper=logits_warper, @@ -1625,7 +1637,7 @@ def generate( **model_kwargs, ) # 13. run beam search - return self.beam_search( + result = self.beam_search( input_ids, beam_scorer, logits_processor=prepared_logits_processor, @@ -1664,7 +1676,7 @@ def generate( ) # 14. run beam sample - return self.beam_sample( + result = self.beam_sample( input_ids, beam_scorer, logits_processor=prepared_logits_processor, @@ -1699,7 +1711,7 @@ def generate( **model_kwargs, ) # 13. 
run beam search - return self.group_beam_search( + result = self.group_beam_search( input_ids, beam_scorer, logits_processor=prepared_logits_processor, @@ -1773,7 +1785,7 @@ def typeerror(): **model_kwargs, ) # 13. run beam search - return self.constrained_beam_search( + result = self.constrained_beam_search( input_ids, constrained_beam_scorer=constrained_beam_scorer, logits_processor=prepared_logits_processor, @@ -1787,6 +1799,16 @@ def typeerror(): **model_kwargs, ) + if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if not callable(getattr(self, "_reset_cache", None)): + raise ValueError( + "The `generation_config` defines a `cache_implementation` that is not compatible with this model." + " Make sure it has a `_reset_cache` function." + ) + self._reset_cache() + + return result + @torch.no_grad() def contrastive_search( self, @@ -2409,7 +2431,7 @@ def greedy_search( print("model_inputs attention_mask shape", model_inputs["attention_mask"].shape) print("model_inputs attention_mask stride", model_inputs["attention_mask"].stride()) - count += 1 + count += 1 # forward pass to get next token outputs = self( **model_inputs, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c6fcd718be3e..2e715ddd3e3f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -829,7 +829,7 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = def _reset_cache(self): for layer in self.model.layers: - layer.self_attn.past_key_value = None + layer.self_attn.past_key_value.seen_tokens.sub_(layer.self_attn.past_key_value.seen_tokens) LLAMA_INPUTS_DOCSTRING = r""" @@ -1050,7 +1050,7 @@ def forward( hidden_states=all_hidden_states, attentions=all_self_attns, ) - + # TODO: As of 20/02/2024, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. # (`recording cudagraph tree for symint key 13`, etc.). Ideally, we would want to pass a statically-shaped 4D mask as input to the model. 
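A minimal sketch, not taken from these patches, of the recapture problem the TODO comment above describes, assuming a CUDA device and a torch 2.2-era build: under `mode="reduce-overhead"`, roughly every new dynamic sequence length reaching a compiled function leads to a fresh CUDA graph recording rather than a replay, so a 2D attention mask that grows by one column per decode step keeps re-recording.

import torch

@torch.compile(mode="reduce-overhead")
def mask_to_bias(mask: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for turning a 2D padding mask into an additive attention bias.
    return (1.0 - mask.float()) * torch.finfo(torch.float32).min

if torch.cuda.is_available():
    for seq_len in range(8, 12):  # mimics the attention_mask growing by one token per step
        mask_to_bias(torch.ones(2, seq_len, device="cuda"))

Keeping every tensor that reaches the compiled forward at a fixed shape and stride (hence the `contiguous()` calls in these patches, and the TODO about a statically-shaped 4D mask) is what avoids this.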
@@ -1071,9 +1071,7 @@ def _update_causal_mask(self, attention_mask, input_tensor): self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) # We use the current dtype to avoid any overflows - causal_mask = ( - self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min - ) + causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min causal_mask = causal_mask.to(dtype=dtype, device=device) if attention_mask is not None and attention_mask.dim() == 2: @@ -1262,8 +1260,7 @@ def prepare_inputs_for_generation( if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None): # generation with static cache past_length = past_key_value.get_seq_length() - print("input_ids shape here", input_ids.shape) - print("past_length here", past_length) + print("past_length", past_length) input_ids = input_ids[:, past_length:] position_ids = position_ids[:, past_length:] @@ -1274,7 +1271,6 @@ def prepare_inputs_for_generation( cache_position = torch.arange( past_length, past_length + position_ids.shape[-1], device=position_ids.device ) - print("cache_position", cache_position) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: From 7472549870ceef89d680055adcda4f394c7123f8 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 20 Feb 2024 16:14:02 +0100 Subject: [PATCH 04/15] sanity checks --- src/transformers/cache_utils.py | 3 ++ src/transformers/generation/utils.py | 33 +++++++++++-------- .../models/llama/modeling_llama.py | 16 +++++++++ 3 files changed, 38 insertions(+), 14 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index cd17201c483f..382517ca5ec7 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -359,6 +359,7 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) # NOTE: self.seen_tokens being in an int results in bugs with torch.compile, where it is somehow not updated. + # TODO: We may want to remove `self.seen_tokens` altogether from the modeling, and leave it in the non-compiled `GenerationMixin.generate()`. self.seen_tokens = torch.tensor(0, dtype=torch.int64, device=device) def update( @@ -387,6 +388,8 @@ def update( A tuple containing the updated key and value states. """ new_cache_positions = cache_kwargs.get("cache_position") + + # `self.max_batch_size` may be larger than the current batch size. 
k_out = self.key_cache[: key_states.shape[0]] v_out = self.value_cache[: value_states.shape[0]] diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5e728ee4ac78..7951a5bccfbb 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1462,20 +1462,25 @@ def generate( ): generation_config.max_length -= inputs_tensor.shape[1] - """ - # if we don't pass `past_key_values` and a cache_implementation is specified - if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING and not model_kwargs.get( - "past_key_values", False - ): - cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING[generation_config.cache_implementation] - if not callable(getattr(self, "_setup_cache", None)): - raise ValueError( - "The `generation_config` defines a `cache_implementation` that is not compatible with this model." - " Make sure it has a `_setup_cache` function." - ) - self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length) - """ - # TODO: sanity check on batch_size + if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation == "static": + if not self.is_stateful_cache_initalized: + if hasattr(self.forward, "get_compiler_config"): + raise ValueError( + "Using `generate` with a compile model through torch.compile with a static cache implementation, but the static cache has not been initialized prior to calling torch.compile. This may lead to unexpected behavior. Please call `model.prepare_static_cache(max_batch_size=maximum_batch_size, max_total_tokens=max_total_tokens)` with the expected maximum batch size and maximum sequence length to be later used." + ) + else: + cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"] + if not callable(getattr(self, "_setup_cache", None)): + raise ValueError( + "The `generation_config` defines a `cache_implementation` that is not compatible with this model." + " Make sure it has a `_setup_cache` function." 
+ ) + self._setup_cache( + cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length + ) + else: + self._validate_cache_for_shapes(batch_size=batch_size, total_tokens=generation_config.max_length) self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 2e715ddd3e3f..77bfdde782e2 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -809,6 +809,7 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + self.is_stateful_cache_initalized = False def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: @@ -827,6 +828,21 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype ) + self.is_stateful_cache_initalized = True + + def _validate_cache_for_shapes(self, batch_size: int, total_tokens: int): + max_batch_size = self.model.layers[0].self_attn.past_key_value.max_batch_size + max_cache_len = self.model.layers[0].self_attn.past_key_value.max_cache_len + + if batch_size > max_batch_size: + raise ValueError( + f"Trying to run inference with a static KV cache for a batch size of {batch_size}, while the initialized static KV cache has a maximum batch size of {max_batch_size}." + ) + if total_tokens > max_cache_len: + raise ValueError( + f"Trying to run inference with a static KV cache for a maximum sequence lengths of {total_tokens}, while the initialized static KV cache can only handle up to {max_cache_len} tokens." + ) + def _reset_cache(self): for layer in self.model.layers: layer.self_attn.past_key_value.seen_tokens.sub_(layer.self_attn.past_key_value.seen_tokens) From b214766730bf881642186340cb6b9a4021ef8998 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 20 Feb 2024 16:16:41 +0100 Subject: [PATCH 05/15] more sanity --- src/transformers/generation/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7951a5bccfbb..d3791736b3af 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1216,7 +1216,6 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de @torch.no_grad() def prepare_static_cache(self, max_batch_size: int, max_total_tokens: int): - # if we don't pass `past_key_values` and a cache_implementation is specified cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"] if not callable(getattr(self, "_setup_cache", None)): raise ValueError( @@ -1464,6 +1463,10 @@ def generate( if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: if generation_config.cache_implementation == "static": + if model_kwargs.get("past_key_values", False) is not False: + raise ValueError( + "Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository." 
+ ) if not self.is_stateful_cache_initalized: if hasattr(self.forward, "get_compiler_config"): raise ValueError( From b9b627c6f003f9326d2bcc9b18c46e0939a595fc Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:52:02 +0100 Subject: [PATCH 06/15] working version WITH refactor From 0c03b7d45d7948363a6ece0e921b21f8c4e78286 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 20 Feb 2024 19:10:47 +0100 Subject: [PATCH 07/15] working without API change --- src/transformers/cache_utils.py | 21 ++++--- src/transformers/generation/utils.py | 55 ++++++++----------- .../models/llama/modeling_llama.py | 43 +++++---------- 3 files changed, 49 insertions(+), 70 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 382517ca5ec7..55f0c49db3b9 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -358,10 +358,6 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) - # NOTE: self.seen_tokens being in an int results in bugs with torch.compile, where it is somehow not updated. - # TODO: We may want to remove `self.seen_tokens` altogether from the modeling, and leave it in the non-compiled `GenerationMixin.generate()`. - self.seen_tokens = torch.tensor(0, dtype=torch.int64, device=device) - def update( self, key_states: torch.Tensor, @@ -389,23 +385,26 @@ def update( """ new_cache_positions = cache_kwargs.get("cache_position") - # `self.max_batch_size` may be larger than the current batch size. - k_out = self.key_cache[: key_states.shape[0]] - v_out = self.value_cache[: value_states.shape[0]] + k_out = self.key_cache + v_out = self.value_cache k_out[:, :, new_cache_positions] = key_states v_out[:, :, new_cache_positions] = value_states - # # This NEEDS to be in-place as in the modeling we are not calling directly `self.past_key_value.update()`, but are rather using getattr. - self.seen_tokens.add_(key_states.shape[2]) return k_out, v_out def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC""" - return self.seen_tokens + # TODO: Fix once the stateful `int` bug in PyTorch is fixed. + raise ValueError( + "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114." + ) def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int: - return self.seen_tokens + # TODO: Fix once the stateful `int` bug in PyTorch is fixed. + raise ValueError( + "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114." + ) def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states. 
DynamicCache does not have a maximum length.""" diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d3791736b3af..9ccb4809271c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -648,6 +648,7 @@ def _update_model_kwargs_for_generation( model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False, standardize_cache_format: bool = False, + model_inputs: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: # update past_key_values model_kwargs["past_key_values"] = self._extract_past_from_model_output( @@ -677,6 +678,8 @@ def _update_model_kwargs_for_generation( dim=-1, ) + model_kwargs["cache_position"] = model_inputs.get("cache_position", None) + return model_kwargs def _reorder_cache(self, past_key_values, beam_idx): @@ -1214,6 +1217,7 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de UserWarning, ) + """ @torch.no_grad() def prepare_static_cache(self, max_batch_size: int, max_total_tokens: int): cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"] @@ -1223,6 +1227,7 @@ def prepare_static_cache(self, max_batch_size: int, max_total_tokens: int): " Make sure it has a `_setup_cache` function." ) self._setup_cache(cache_cls, max_batch_size=max_batch_size, max_cache_len=max_total_tokens) + """ @torch.no_grad() def generate( @@ -1467,23 +1472,13 @@ def generate( raise ValueError( "Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository." ) - if not self.is_stateful_cache_initalized: - if hasattr(self.forward, "get_compiler_config"): - raise ValueError( - "Using `generate` with a compile model through torch.compile with a static cache implementation, but the static cache has not been initialized prior to calling torch.compile. This may lead to unexpected behavior. Please call `model.prepare_static_cache(max_batch_size=maximum_batch_size, max_total_tokens=max_total_tokens)` with the expected maximum batch size and maximum sequence length to be later used." - ) - else: - cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"] - if not callable(getattr(self, "_setup_cache", None)): - raise ValueError( - "The `generation_config` defines a `cache_implementation` that is not compatible with this model." - " Make sure it has a `_setup_cache` function." - ) - self._setup_cache( - cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length - ) - else: - self._validate_cache_for_shapes(batch_size=batch_size, total_tokens=generation_config.max_length) + cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"] + if not callable(getattr(self, "_setup_cache", None)): + raise ValueError( + "The `generation_config` defines a `cache_implementation` that is not compatible with this model." + " Make sure it has a `_setup_cache` function." + ) + self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length) self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) @@ -2007,6 +2002,7 @@ def contrastive_search( model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, standardize_cache_format=True, + model_inputs=model_inputs, ) if not sequential: # Expands model inputs top_k times, for batched forward passes (akin to beam search). 
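For orientation, a hedged usage sketch of how the static cache path wired up in `generate` above is meant to be exercised; it is not taken from the patch, and the checkpoint name, dtype and generation settings are illustrative assumptions. Only the decoder forward is compiled, while `generate` itself stays eager.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative; any checkpoint using this Llama modeling code
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# Routes generate() through NEED_SETUP_CACHE_CLASSES_MAPPING["static"] and _setup_cache(StaticCache, ...).
model.generation_config.cache_implementation = "static"

# Compile only the forward pass; fullgraph=True is omitted since _update_causal_mask may still graph-break.
model.forward = torch.compile(model.forward, mode="reduce-overhead")

inputs = tok("Static KV cache with torch.compile", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tok.decode(out[0], skip_special_tokens=True))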
@@ -2201,7 +2197,7 @@ def contrastive_search( if streamer is not None: streamer.put(next_tokens.cpu()) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) # if eos_token was found in one sentence, set sentence to finished @@ -2420,7 +2416,6 @@ def greedy_search( this_peer_finished = False # used by synced_gpus only count = 0 while True: - print("----- call forward") if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence @@ -2434,11 +2429,6 @@ def greedy_search( # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - print("model_inputs input ids shape", model_inputs["input_ids"].shape) - print("model_inputs input ids stride", model_inputs["input_ids"].stride()) - print("model_inputs attention_mask shape", model_inputs["attention_mask"].shape) - print("model_inputs attention_mask stride", model_inputs["attention_mask"].stride()) - count += 1 # forward pass to get next token outputs = self( @@ -2490,7 +2480,10 @@ def greedy_search( if streamer is not None: streamer.put(next_tokens.cpu()) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + model_inputs=model_inputs, ) # if eos_token was found in one sentence, set sentence to finished @@ -2784,7 +2777,7 @@ def sample( if streamer is not None: streamer.put(next_tokens.cpu()) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) # if eos_token was found in one sentence, set sentence to finished @@ -3177,7 +3170,7 @@ def beam_search( input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) if model_kwargs["past_key_values"] is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( @@ -3524,7 +3517,7 @@ def beam_sample( input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) if model_kwargs["past_key_values"] is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( @@ -3923,7 +3916,7 @@ def group_beam_search( input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) if model_kwargs["past_key_values"] is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( @@ -4275,7 +4268,7 @@ def constrained_beam_search( 
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) if model_kwargs["past_key_values"] is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( @@ -4682,7 +4675,7 @@ def assisted_decoding( ) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) # if eos_token was found in one sentence, set sentence to finished diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 77bfdde782e2..9d8de1080647 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -830,22 +830,9 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = self.is_stateful_cache_initalized = True - def _validate_cache_for_shapes(self, batch_size: int, total_tokens: int): - max_batch_size = self.model.layers[0].self_attn.past_key_value.max_batch_size - max_cache_len = self.model.layers[0].self_attn.past_key_value.max_cache_len - - if batch_size > max_batch_size: - raise ValueError( - f"Trying to run inference with a static KV cache for a batch size of {batch_size}, while the initialized static KV cache has a maximum batch size of {max_batch_size}." - ) - if total_tokens > max_cache_len: - raise ValueError( - f"Trying to run inference with a static KV cache for a maximum sequence lengths of {total_tokens}, while the initialized static KV cache can only handle up to {max_cache_len} tokens." 
- ) - def _reset_cache(self): for layer in self.model.layers: - layer.self_attn.past_key_value.seen_tokens.sub_(layer.self_attn.past_key_value.seen_tokens) + layer.self_attn.past_key_value = None LLAMA_INPUTS_DOCSTRING = r""" @@ -989,13 +976,13 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() - if cache_position is None: + past_seen_tokens = 0 + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1273,20 +1260,20 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None): + if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None: # generation with static cache - past_length = past_key_value.get_seq_length() - print("past_length", past_length) + cache_position = kwargs.get("cache_position", None) + if cache_position is None: + past_length = 0 + else: + past_length = cache_position[-1] + 1 input_ids = input_ids[:, past_length:] position_ids = position_ids[:, past_length:] # TODO @gante we should only keep a `cache_position` in generate, and do +=1. # same goes for position ids. Could also help with continued generation. - cache_position = kwargs.get("cache_position", None) - if cache_position is None: - cache_position = torch.arange( - past_length, past_length + position_ids.shape[-1], device=position_ids.device - ) + # cache_position = kwargs.get("cache_position", None) + cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: From 28cdee0fa4b1565d913a57b086bdf5fd9e994900 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 20 Feb 2024 19:47:25 +0100 Subject: [PATCH 08/15] cleanup & tests pass --- src/transformers/models/llama/modeling_llama.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 9d8de1080647..84a1c8d0ff48 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -809,7 +809,6 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - self.is_stateful_cache_initalized = False def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: @@ -828,8 +827,6 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype ) - self.is_stateful_cache_initalized = True - def _reset_cache(self): for layer in 
self.model.layers: layer.self_attn.past_key_value = None @@ -976,13 +973,13 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if cache_position is None: - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + past_seen_tokens = 0 + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) From 190e0cf2be7c8d3c729e726642a0ff4f0b464a44 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 20 Feb 2024 19:49:58 +0100 Subject: [PATCH 09/15] more cleaning --- src/transformers/cache_utils.py | 1 - src/transformers/generation/utils.py | 14 -------------- src/transformers/models/llama/modeling_llama.py | 1 - 3 files changed, 16 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 55f0c49db3b9..1cb7c429ae19 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -384,7 +384,6 @@ def update( A tuple containing the updated key and value states. """ new_cache_positions = cache_kwargs.get("cache_position") - k_out = self.key_cache v_out = self.value_cache diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9ccb4809271c..464dda742258 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1217,18 +1217,6 @@ def _validate_generated_length(self, generation_config, input_ids_length, has_de UserWarning, ) - """ - @torch.no_grad() - def prepare_static_cache(self, max_batch_size: int, max_total_tokens: int): - cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"] - if not callable(getattr(self, "_setup_cache", None)): - raise ValueError( - "The `generation_config` defines a `cache_implementation` that is not compatible with this model." - " Make sure it has a `_setup_cache` function." - ) - self._setup_cache(cache_cls, max_batch_size=max_batch_size, max_cache_len=max_total_tokens) - """ - @torch.no_grad() def generate( self, @@ -2414,7 +2402,6 @@ def greedy_search( unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) this_peer_finished = False # used by synced_gpus only - count = 0 while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. @@ -2429,7 +2416,6 @@ def greedy_search( # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - count += 1 # forward pass to get next token outputs = self( **model_inputs, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 84a1c8d0ff48..ff9d36db8975 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1269,7 +1269,6 @@ def prepare_inputs_for_generation( # TODO @gante we should only keep a `cache_position` in generate, and do +=1. # same goes for position ids. Could also help with continued generation. 
- # cache_position = kwargs.get("cache_position", None) cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step From 5019e81b80f6ad274a6ca48eb54c1aa4aab8a2b3 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 20 Feb 2024 20:00:11 +0100 Subject: [PATCH 10/15] fix test --- src/transformers/models/llama/modeling_llama.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index ff9d36db8975..35d0252dfee9 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -978,11 +978,11 @@ def forward( if not isinstance(past_key_values, StaticCache): past_key_values = DynamicCache.from_legacy_cache(past_key_values) - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) From 80b9072c4e3fe0a39383f2f62bfdbbdf60ddde4b Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 20 Feb 2024 20:08:51 +0100 Subject: [PATCH 11/15] fix tests --- src/transformers/models/llama/modeling_llama.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 35d0252dfee9..9d01b83c95d3 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -977,12 +977,14 @@ def forward( if use_cache: # kept for BC (cache positions) if not isinstance(past_key_values, StaticCache): past_key_values = DynamicCache.from_legacy_cache(past_key_values) - - if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) From a660486ee6bb9fc2ccbedf5d51a5fb6dcb6cf2f0 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 21 Feb 2024 09:34:02 +0100 Subject: [PATCH 12/15] Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 464dda742258..d337e5593440 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1793,8 +1793,8 @@ def typeerror(): if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: if not 
callable(getattr(self, "_reset_cache", None)): raise ValueError( - "The `generation_config` defines a `cache_implementation` that is not compatible with this model." - " Make sure it has a `_reset_cache` function." + "A `static_cache` was used to generate but there was a failure when trying to release the cache. " + " Make sure this model implements a `_reset_cache` function." ) self._reset_cache() From 293546296f8a5557b6f1a654e2ab9b6eacd7e176 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 21 Feb 2024 09:41:10 +0100 Subject: [PATCH 13/15] smaller comment --- src/transformers/models/llama/modeling_llama.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 9d01b83c95d3..03176aac3462 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1277,11 +1277,9 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - # The `contiguous()` here is necessary to have a static stride during (non-speculative) decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. - # TODO: We don't really need to handle the input_ids here, and this contiguous() call could be removed if we were - # simply using GenerationMixin.greedy_search `next_tokens` variable directly (which is already contiguous), instead of - # doing a torch.cat + then slice. + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. model_inputs = {"input_ids": input_ids.contiguous()} model_inputs.update( From 5dbcef434777827faffcf928ef9941b6981ed3cd Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:14:56 +0100 Subject: [PATCH 14/15] update comment --- src/transformers/models/llama/modeling_llama.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 03176aac3462..55f397ce04be 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1053,9 +1053,10 @@ def forward( attentions=all_self_attns, ) - # TODO: As of 20/02/2024, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # As of 20/02/2024, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.). Ideally, we would want to pass a statically-shaped 4D mask as input to the model. + # (`recording cudagraph tree for symint key 13`, etc.), which is very slow. + # TODO: pass a statically-shaped 4D mask as input to the model, enabling torch.compile with `fullgraph=True`. 
@torch.compiler.disable def _update_causal_mask(self, attention_mask, input_tensor): if self.config._attn_implementation == "flash_attention_2": From a8c4e1036ac6f0f78e512235cf42c17c7d3cc762 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 21 Feb 2024 11:48:25 +0100 Subject: [PATCH 15/15] update comment --- src/transformers/models/llama/modeling_llama.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 55f397ce04be..aec9f6ee7660 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1053,11 +1053,10 @@ def forward( attentions=all_self_attns, ) - # As of 20/02/2024, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is very slow. - # TODO: pass a statically-shaped 4D mask as input to the model, enabling torch.compile with `fullgraph=True`. - @torch.compiler.disable + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 def _update_causal_mask(self, attention_mask, input_tensor): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: