From 1748ff17cc90e38d17c16e4c19957199fb815ed1 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 5 Jun 2024 18:17:27 +0200 Subject: [PATCH 01/35] Add .float() in all generation methods logit outputs --- src/transformers/generation/utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index a9ebdcdd4775..3984eba644d5 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2590,7 +2590,7 @@ def _contrastive_search( # next logit for contrastive search to select top-k candidate tokens # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration # (the clone itself is always small) - logit_for_next_step = outputs.logits[:, -1, :].clone() + logit_for_next_step = outputs.logits[:, -1, :].clone().float() model_kwargs = self._update_model_kwargs_for_generation( outputs, @@ -2778,7 +2778,7 @@ def _contrastive_search( next_past_key_values = tuple(new_key_values) - logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :] + logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :].float() # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration if self.config.is_encoder_decoder: @@ -2966,7 +2966,7 @@ def _sample( # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration # (the clone itself is always small) - next_token_logits = outputs.logits[:, -1, :].clone() + next_token_logits = outputs.logits[:, -1, :].clone().float() # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) @@ -3210,7 +3210,7 @@ def _beam_search( # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration # (the clone itself is always small) - next_token_logits = outputs.logits[:, -1, :].clone() + next_token_logits = outputs.logits[:, -1, :].clone().float() next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) @@ -3484,7 +3484,7 @@ def _group_beam_search( # select outputs of beams of current group only # No need to clone() the logits here as they will not retain outputs.logits at the end of the loop - next_token_logits = outputs.logits[batch_group_indices, -1, :] + next_token_logits = outputs.logits[batch_group_indices, -1, :].float() next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 @@ -3739,7 +3739,7 @@ def _constrained_beam_search( # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration # (the clone itself is always small) - next_token_logits = outputs.logits[:, -1, :].clone() + next_token_logits = outputs.logits[:, -1, :].clone().float() next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) @@ -3999,7 +3999,7 @@ def _assisted_decoding( # 2.3. Process the new logits new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present - next_token_logits = new_logits.clone() + next_token_logits = new_logits.clone().float() if len(logits_processor) > 0: for i in range(candidate_length + 1): new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) From 3f4f4e8d7c73e7f240d1a58d947e0a6363442750 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 5 Jun 2024 19:12:44 +0200 Subject: [PATCH 02/35] Switch float-casting of logits to training only for main models --- src/transformers/models/cohere/modeling_cohere.py | 3 ++- src/transformers/models/gemma/modeling_gemma.py | 4 +++- src/transformers/models/idefics2/modeling_idefics2.py | 3 ++- src/transformers/models/jamba/modeling_jamba.py | 3 ++- src/transformers/models/llama/modeling_llama.py | 3 ++- src/transformers/models/mistral/modeling_mistral.py | 3 ++- src/transformers/models/mixtral/modeling_mixtral.py | 3 ++- src/transformers/models/olmo/modeling_olmo.py | 3 ++- src/transformers/models/phi/modeling_phi.py | 3 ++- src/transformers/models/phi3/modeling_phi3.py | 3 ++- src/transformers/models/qwen2/modeling_qwen2.py | 3 ++- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 3 ++- src/transformers/models/starcoder2/modeling_starcoder2.py | 3 ++- 13 files changed, 27 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index afcea137b58b..66ccc134a368 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1071,12 +1071,13 @@ def forward( ) hidden_states = outputs[0] + # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) logits = logits * self.logit_scale - logits = logits.float() loss = None if labels is not None: + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index a05d2c059e21..402e81c17a06 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1085,10 +1085,12 @@ def forward( ) hidden_states = outputs[0] + # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) - logits = logits.float() + loss = None if labels is not None: + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index cdc7e9ba4e77..b8c29a67615f 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1591,11 +1591,12 @@ def forward( ) hidden_states = outputs[0] + # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) - logits = logits.float() loss = None if labels is not None: + logits = logits.float() labels = labels.to(logits.device) # Shift so that tokens < n predict n if attention_mask is not None: diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 230536a83a14..eff55c7964c1 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1493,14 +1493,15 @@ def forward( ) hidden_states = outputs[0] + # Casting of the logits to float will happen in generate() in inference mode to save memory if num_logits_to_keep is None: logits = self.lm_head(hidden_states) else: logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :]) - logits = logits.float() loss = None if labels is not None: + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8716d27f5481..879a1bae7597 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1193,16 +1193,17 @@ def forward( ) hidden_states = outputs[0] + # Casting of the logits to float will happen in generate() in inference mode to save memory if self.config.pretraining_tp > 1: lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: logits = self.lm_head(hidden_states) - logits = logits.float() loss = None if labels is not None: + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 1a2b732e85e4..fdd292edbe58 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1044,11 +1044,12 @@ def forward( ) hidden_states = outputs[0] + # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) - logits = logits.float() loss = None if labels is not None: + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 522b6db7bcc7..1e2786ea4905 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1286,11 +1286,12 @@ def forward( ) hidden_states = outputs[0] + # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) - logits = logits.float() loss = None if labels is not None: + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 1940660f61b5..dc6dd5e39bad 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1116,11 +1116,12 @@ def forward( ) hidden_states = outputs[0] + # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) - logits = logits.float() loss = None if labels is not None: + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 6d63c0ea7e8e..598b31721e31 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1217,11 +1217,12 @@ def forward( ) hidden_states = outputs[0] + # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) - logits = logits.float() loss = None if labels is not None: + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 08417fcabfaa..c6f25e558bbc 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1257,11 +1257,12 @@ def forward( ) hidden_states = outputs[0] + # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) - logits = logits.float() loss = None if labels is not None: + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 28b414b1901b..ac903b83a7f6 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1115,11 +1115,12 @@ def forward( ) hidden_states = outputs[0] + # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) - logits = logits.float() loss = None if labels is not None: + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 12ebe26e058d..3bbcaa6b3617 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1296,11 +1296,12 @@ def forward( ) hidden_states = outputs[0] + # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) - logits = logits.float() loss = None if labels is not None: + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index d51077b04254..c403b2f98ca5 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1091,11 +1091,12 @@ def forward( ) hidden_states = outputs[0] + # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) - logits = logits.float() loss = None if labels is not None: + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() From 727c7e46c03399fa176e2ea3560da60019905f4b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 6 Jun 2024 14:13:27 +0200 Subject: [PATCH 03/35] Add `num_logits_to_keep` in Llama and add it by default in generate --- .../generation/candidate_generator.py | 4 ++++ src/transformers/generation/utils.py | 16 ++++++++++++++ .../models/llama/modeling_llama.py | 21 ++++++++++++++++--- 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 1ab5ea527e49..bf55ae3e2b06 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -119,6 +119,10 @@ def __init__( value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value) ) + # Remove potential default "num_logits_to_keep" key + if "num_logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_num_logits_to_keep(): + del assistant_kwargs["num_logits_to_keep"] + if "assistant_encoder_outputs" in model_kwargs: assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] elif assistant_model.config.is_encoder_decoder: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3984eba644d5..51978988e2d4 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1601,6 +1601,15 @@ def _prepare_cache_for_generation( else EncoderDecoderCache(DynamicCache(), DynamicCache()) ) + def _supports_num_logits_to_keep(self) -> bool: + """ + Return True if the current model supports the keyword argument `num_logits_to_keep` in forward() + to save memory. Checking it in this way allows to avoid using a new model attribute. + """ + # Dummy call to check if `num_logits_to_keep` is present in output dict + dummy = self.prepare_inputs_for_generation(input_ids=torch.ones(1)) + return "num_logits_to_keep" in dummy + def _prepare_special_tokens( self, generation_config: GenerationConfig, @@ -1876,6 +1885,13 @@ def generate( inputs_tensor=inputs_tensor, input_ids_length=input_ids_length, ) + + # If the model supports `num_logits_to_keep` in forward(), set it to 1 to avoid computing the whole + # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding + # dynamically overrides this value as it can need more than the last token logits + if self._supports_num_logits_to_keep(): + model_kwargs["num_logits_to_keep"] = 1 + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) # 7. Prepare the cache. diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 879a1bae7597..0457f809f3e3 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1146,6 +1146,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[Union[int, None]] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1154,6 +1155,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + Returns: Example: @@ -1193,16 +1199,24 @@ def forward( ) hidden_states = outputs[0] - # Casting of the logits to float will happen in generate() in inference mode to save memory if self.config.pretraining_tp > 1: lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + if num_logits_to_keep is None: + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + else: + logits = [F.linear(hidden_states[:, -num_logits_to_keep:, :], lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: - logits = self.lm_head(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states) + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() @@ -1293,6 +1307,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": kwargs["num_logits_to_keep"] if "num_logits_to_keep" in kwargs else None, } ) return model_inputs From 222017d4d0d7442132424124cf495adbb5253ebe Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 6 Jun 2024 14:14:12 +0200 Subject: [PATCH 04/35] Apply style --- src/transformers/models/llama/modeling_llama.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 0457f809f3e3..be64cfa63d33 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1205,7 +1205,10 @@ def forward( if num_logits_to_keep is None: logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] else: - logits = [F.linear(hidden_states[:, -num_logits_to_keep:, :], lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = [ + F.linear(hidden_states[:, -num_logits_to_keep:, :], lm_head_slices[i]) + for i in range(self.config.pretraining_tp) + ] logits = torch.cat(logits, dim=-1) else: # Only compute necessary logits, and do not upcast them to float if we are not computing the loss From dc709c687ed2639ad1307fd06abc63225ed8c44b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 6 Jun 2024 14:19:08 +0200 Subject: [PATCH 05/35] Add num_logits_to_keep as arg in prepare_input_for_generation --- src/transformers/models/llama/modeling_llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index be64cfa63d33..08841ce7d8d6 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1253,6 +1253,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1310,7 +1311,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "num_logits_to_keep": kwargs["num_logits_to_keep"] if "num_logits_to_keep" in kwargs else None, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs From d2f1566fccff68bbeced8c5b74c99acd636d3fa2 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 6 Jun 2024 15:05:27 +0200 Subject: [PATCH 06/35] Add support for Mistral --- .../models/mistral/modeling_mistral.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index fdd292edbe58..d5a18dccfa4b 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -996,6 +996,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[Union[int, None]] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1004,6 +1005,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + Returns: Example: @@ -1044,11 +1050,15 @@ def forward( ) hidden_states = outputs[0] - # Casting of the logits to float will happen in generate() in inference mode to save memory - logits = self.lm_head(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states) + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() @@ -1082,6 +1092,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1116,6 +1127,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs From f2ef90cd62e96f05efe1f7240f3a1bc55a29ddce Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 6 Jun 2024 16:57:15 +0200 Subject: [PATCH 07/35] Revert models except llama and mistral --- src/transformers/models/cohere/modeling_cohere.py | 3 +-- src/transformers/models/gemma/modeling_gemma.py | 4 +--- src/transformers/models/idefics2/modeling_idefics2.py | 3 +-- src/transformers/models/jamba/modeling_jamba.py | 3 +-- src/transformers/models/mixtral/modeling_mixtral.py | 3 +-- src/transformers/models/olmo/modeling_olmo.py | 3 +-- src/transformers/models/phi/modeling_phi.py | 3 +-- src/transformers/models/phi3/modeling_phi3.py | 3 +-- src/transformers/models/qwen2/modeling_qwen2.py | 3 +-- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 3 +-- src/transformers/models/starcoder2/modeling_starcoder2.py | 3 +-- 11 files changed, 11 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 66ccc134a368..afcea137b58b 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1071,13 +1071,12 @@ def forward( ) hidden_states = outputs[0] - # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) logits = logits * self.logit_scale + logits = logits.float() loss = None if labels is not None: - logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 402e81c17a06..a05d2c059e21 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1085,12 +1085,10 @@ def forward( ) hidden_states = outputs[0] - # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) - + logits = logits.float() loss = None if labels is not None: - logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index b8c29a67615f..cdc7e9ba4e77 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1591,12 +1591,11 @@ def forward( ) hidden_states = outputs[0] - # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) + logits = logits.float() loss = None if labels is not None: - logits = logits.float() labels = labels.to(logits.device) # Shift so that tokens < n predict n if attention_mask is not None: diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index eff55c7964c1..230536a83a14 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1493,15 +1493,14 @@ def forward( ) hidden_states = outputs[0] - # Casting of the logits to float will happen in generate() in inference mode to save memory if num_logits_to_keep is None: logits = self.lm_head(hidden_states) else: logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :]) + logits = logits.float() loss = None if labels is not None: - logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 1e2786ea4905..522b6db7bcc7 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1286,12 +1286,11 @@ def forward( ) hidden_states = outputs[0] - # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) + logits = logits.float() loss = None if labels is not None: - logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index dc6dd5e39bad..1940660f61b5 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1116,12 +1116,11 @@ def forward( ) hidden_states = outputs[0] - # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) + logits = logits.float() loss = None if labels is not None: - logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 598b31721e31..6d63c0ea7e8e 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1217,12 +1217,11 @@ def forward( ) hidden_states = outputs[0] - # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) + logits = logits.float() loss = None if labels is not None: - logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index c6f25e558bbc..08417fcabfaa 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1257,12 +1257,11 @@ def forward( ) hidden_states = outputs[0] - # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) + logits = logits.float() loss = None if labels is not None: - logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index ac903b83a7f6..28b414b1901b 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1115,12 +1115,11 @@ def forward( ) hidden_states = outputs[0] - # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) + logits = logits.float() loss = None if labels is not None: - logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 3bbcaa6b3617..12ebe26e058d 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1296,12 +1296,11 @@ def forward( ) hidden_states = outputs[0] - # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) + logits = logits.float() loss = None if labels is not None: - logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index c403b2f98ca5..d51077b04254 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1091,12 +1091,11 @@ def forward( ) hidden_states = outputs[0] - # Casting of the logits to float will happen in generate() in inference mode to save memory logits = self.lm_head(hidden_states) + logits = logits.float() loss = None if labels is not None: - logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() From ce7b980c24ae889678b7b4cbd046c69c6943bb2b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 6 Jun 2024 18:07:35 +0200 Subject: [PATCH 08/35] Fix default None value in _supports_num_logits_to_keep() --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 51978988e2d4..7011c41d1965 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1607,7 +1607,7 @@ def _supports_num_logits_to_keep(self) -> bool: to save memory. Checking it in this way allows to avoid using a new model attribute. """ # Dummy call to check if `num_logits_to_keep` is present in output dict - dummy = self.prepare_inputs_for_generation(input_ids=torch.ones(1)) + dummy = self.prepare_inputs_for_generation(torch.ones(1), attention_mask=None) return "num_logits_to_keep" in dummy def _prepare_special_tokens( From d4201f42fb4f0dbe2ae48c9c9ba29d1ccef61cac Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 7 Jun 2024 00:01:03 +0200 Subject: [PATCH 09/35] Fix dimension of dummy input --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7011c41d1965..638a264d7e47 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1607,7 +1607,7 @@ def _supports_num_logits_to_keep(self) -> bool: to save memory. Checking it in this way allows to avoid using a new model attribute. """ # Dummy call to check if `num_logits_to_keep` is present in output dict - dummy = self.prepare_inputs_for_generation(torch.ones(1), attention_mask=None) + dummy = self.prepare_inputs_for_generation(torch.ones(1, 1), attention_mask=None) return "num_logits_to_keep" in dummy def _prepare_special_tokens( From b15b5dece25c58c55bf2002492434fa8c3971cf8 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 7 Jun 2024 01:00:39 +0200 Subject: [PATCH 10/35] Add exception for prophetnet in _supports_num_logits_to_keep() --- src/transformers/generation/utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 638a264d7e47..7f8920d1362d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1606,8 +1606,12 @@ def _supports_num_logits_to_keep(self) -> bool: Return True if the current model supports the keyword argument `num_logits_to_keep` in forward() to save memory. Checking it in this way allows to avoid using a new model attribute. """ - # Dummy call to check if `num_logits_to_keep` is present in output dict - dummy = self.prepare_inputs_for_generation(torch.ones(1, 1), attention_mask=None) + # Dummy call to check if `num_logits_to_keep` is present in output dict (encoder_outputs needs to be passed + # for prophetnet model to avoid AssertionError but can be anything except None) + if "prophetnet" in self.__class__.__name__.lower(): + dummy = self.prepare_inputs_for_generation(torch.ones(1, 1), attention_mask=None, encoder_outputs=0) + else: + dummy = self.prepare_inputs_for_generation(torch.ones(1, 1), attention_mask=None) return "num_logits_to_keep" in dummy def _prepare_special_tokens( From 95e0807aa4211f4cb21086159cf570eef6efc329 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 7 Jun 2024 01:20:20 +0200 Subject: [PATCH 11/35] Update _supports_num_logits_to_keep() to use inspect.signature() --- src/transformers/generation/utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7f8920d1362d..7d3a4809faa9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1606,13 +1606,7 @@ def _supports_num_logits_to_keep(self) -> bool: Return True if the current model supports the keyword argument `num_logits_to_keep` in forward() to save memory. Checking it in this way allows to avoid using a new model attribute. """ - # Dummy call to check if `num_logits_to_keep` is present in output dict (encoder_outputs needs to be passed - # for prophetnet model to avoid AssertionError but can be anything except None) - if "prophetnet" in self.__class__.__name__.lower(): - dummy = self.prepare_inputs_for_generation(torch.ones(1, 1), attention_mask=None, encoder_outputs=0) - else: - dummy = self.prepare_inputs_for_generation(torch.ones(1, 1), attention_mask=None) - return "num_logits_to_keep" in dummy + return "num_logits_to_keep" in set(inspect.signature(self.forward).parameters.keys()) def _prepare_special_tokens( self, From 12db045777530cb8961c223c00a947e0e44fcecd Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 20 Jun 2024 09:06:39 +0200 Subject: [PATCH 12/35] Add deprecation cycle + remove modification with pretraining_tp --- .../models/llama/modeling_llama.py | 18 ++++++++---------- .../models/mistral/modeling_mistral.py | 9 +++++++-- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 08841ce7d8d6..42d59c5ed3df 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1201,21 +1201,19 @@ def forward( hidden_states = outputs[0] if self.config.pretraining_tp > 1: lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - if num_logits_to_keep is None: - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] - else: - logits = [ - F.linear(hidden_states[:, -num_logits_to_keep:, :], lm_head_slices[i]) - for i in range(self.config.pretraining_tp) - ] + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: + if labels is None: + logger.warning_once( + 'Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)' + ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove those 2 float() operations in v4.44 if num_logits_to_keep is None: - logits = self.lm_head(hidden_states) + logits = self.lm_head(hidden_states).float() else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index d5a18dccfa4b..f645bca867eb 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1050,11 +1050,16 @@ def forward( ) hidden_states = outputs[0] + if labels is None: + logger.warning_once( + 'Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)' + ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove those 2 float() operations in v4.44 if num_logits_to_keep is None: - logits = self.lm_head(hidden_states) + logits = self.lm_head(hidden_states).float() else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: From b224e24ce5a2ab5a6014feb1265beaa0a0475e1f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 20 Jun 2024 09:08:19 +0200 Subject: [PATCH 13/35] Apply style --- src/transformers/models/llama/modeling_llama.py | 2 +- src/transformers/models/mistral/modeling_mistral.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 42d59c5ed3df..e03bd588c9ea 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1206,7 +1206,7 @@ def forward( else: if labels is None: logger.warning_once( - 'Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)' + "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss # TODO: remove those 2 float() operations in v4.44 diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index f645bca867eb..d3399dec65de 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1052,7 +1052,7 @@ def forward( hidden_states = outputs[0] if labels is None: logger.warning_once( - 'Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)' + "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss # TODO: remove those 2 float() operations in v4.44 From f0e1034b096f2c2d8878fafff571cf054b9de9a9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 21 Jun 2024 09:24:04 +0200 Subject: [PATCH 14/35] Add most used models --- .../models/cohere/modeling_cohere.py | 22 +++++++++++++++-- src/transformers/models/dbrx/modeling_dbrx.py | 14 ++++++++++- .../models/gemma/modeling_gemma.py | 23 ++++++++++++++++-- .../models/idefics2/modeling_idefics2.py | 24 ++++++++++++++++--- .../models/jamba/modeling_jamba.py | 7 ++++++ .../models/jetmoe/modeling_jetmoe.py | 22 +++++++++++++++-- .../models/llama/modeling_llama.py | 2 +- .../models/mistral/modeling_mistral.py | 2 +- .../models/mixtral/modeling_mixtral.py | 22 +++++++++++++++-- src/transformers/models/olmo/modeling_olmo.py | 22 +++++++++++++++-- .../models/persimmon/modeling_persimmon.py | 14 ++++++++++- src/transformers/models/phi/modeling_phi.py | 22 +++++++++++++++-- src/transformers/models/phi3/modeling_phi3.py | 22 +++++++++++++++-- .../models/qwen2/modeling_qwen2.py | 22 +++++++++++++++-- .../models/qwen2_moe/modeling_qwen2_moe.py | 22 +++++++++++++++-- .../models/stablelm/modeling_stablelm.py | 14 ++++++++++- .../models/starcoder2/modeling_starcoder2.py | 22 +++++++++++++++-- 17 files changed, 270 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index afcea137b58b..0d3bd0939e9b 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1024,6 +1024,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[int] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1032,6 +1033,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + Returns: Example: @@ -1071,12 +1077,22 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + if labels is None: + logger.warning_once( + "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove those 2 float() operations in v4.44 + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states).float() + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() logits = logits * self.logit_scale - logits = logits.float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1109,6 +1125,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1166,6 +1183,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 3486d5ed3ab0..e66df692f42d 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1275,6 +1275,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[int] = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r"""Forward function for causal language modeling. @@ -1284,6 +1285,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + Returns: Example: @@ -1328,7 +1334,11 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + # No upscaling to float was ever done for Dbrx + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states) + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: @@ -1380,6 +1390,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1437,6 +1448,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index a05d2c059e21..566d29a12f65 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1038,6 +1038,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[int] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1046,6 +1047,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + Returns: Example: @@ -1085,10 +1091,21 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None: + logger.warning_once( + "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove those 2 float() operations in v4.44 + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states).float() + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1121,6 +1138,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1177,6 +1195,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index cdc7e9ba4e77..40a7f06c404c 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1520,6 +1520,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + num_logits_to_keep: Optional[int] = None, ) -> Union[Tuple, Idefics2CausalLMOutputWithPast]: r""" Args: @@ -1528,6 +1529,12 @@ def forward( config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics2ForConditionalGeneration`). Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + Returns: Example: @@ -1591,11 +1598,21 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None: + logger.warning_once( + "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove those 2 float() operations in v4.44 + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states).float() + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() labels = labels.to(logits.device) # Shift so that tokens < n predict n if attention_mask is not None: @@ -1623,7 +1640,7 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, num_logits_to_keep=None, **kwargs ): past_length = 0 # Omit tokens covered by past_key_values @@ -1682,6 +1699,7 @@ def prepare_inputs_for_generation( "pixel_values": pixel_values, "pixel_attention_mask": pixel_attention_mask, "image_hidden_states": image_hidden_states, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 230536a83a14..426b23253a43 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1497,10 +1497,17 @@ def forward( logits = self.lm_head(hidden_states) else: logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :]) + if labels is None: + logger.warning_once( + "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # TODO: remove this float() operations in v4.44 logits = logits.float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index a3fc645e5aea..b3a68e231f75 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1253,6 +1253,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[int] = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: @@ -1261,6 +1262,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + Returns: """ @@ -1285,11 +1291,21 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None: + logger.warning_once( + "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove those 2 float() operations in v4.44 + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states).float() + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1339,6 +1355,7 @@ def prepare_inputs_for_generation( output_router_logits=False, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1371,6 +1388,7 @@ def prepare_inputs_for_generation( "use_cache": use_cache, "attention_mask": attention_mask, "output_router_logits": output_router_logits, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e03bd588c9ea..79fb600175d9 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1146,7 +1146,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[Union[int, None]] = None, + num_logits_to_keep: Optional[int] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index d3399dec65de..e3ef86597781 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -996,7 +996,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[Union[int, None]] = None, + num_logits_to_keep: Optional[int] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 522b6db7bcc7..801e52812049 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1233,6 +1233,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[int] = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: @@ -1241,6 +1242,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + Returns: Example: @@ -1286,11 +1292,21 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None: + logger.warning_once( + "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove those 2 float() operations in v4.44 + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states).float() + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1339,6 +1355,7 @@ def prepare_inputs_for_generation( output_router_logits=False, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1371,6 +1388,7 @@ def prepare_inputs_for_generation( "use_cache": use_cache, "attention_mask": attention_mask, "output_router_logits": output_router_logits, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 1940660f61b5..89da0e940d71 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1068,6 +1068,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[int] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1076,6 +1077,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + Returns: Example: @@ -1116,11 +1122,21 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None: + logger.warning_once( + "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove those 2 float() operations in v4.44 + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states).float() + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1153,6 +1169,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1210,6 +1227,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 1e4f56c0674d..85bce2ab8930 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -885,6 +885,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[int] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -893,6 +894,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + Returns: Example: @@ -933,7 +939,11 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + # No upscaling to float was ever done for Persimmon + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states) + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: @@ -970,6 +980,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1027,6 +1038,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 6d63c0ea7e8e..953cd6cf0e92 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1169,6 +1169,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[int] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1177,6 +1178,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + Returns: Example: @@ -1217,11 +1223,21 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None: + logger.warning_once( + "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove those 2 float() operations in v4.44 + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states).float() + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1255,6 +1271,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1312,6 +1329,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 08417fcabfaa..11ac91b88943 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1210,6 +1210,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[int] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1218,6 +1219,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + Returns: Example: @@ -1257,11 +1263,21 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None: + logger.warning_once( + "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove those 2 float() operations in v4.44 + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states).float() + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1295,6 +1311,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1352,6 +1369,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 28b414b1901b..476593ca3241 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1067,6 +1067,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[int] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1075,6 +1076,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + Returns: Example: @@ -1115,11 +1121,21 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None: + logger.warning_once( + "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove those 2 float() operations in v4.44 + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states).float() + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1153,6 +1169,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1210,6 +1227,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 12ebe26e058d..5e9e88004c3c 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1244,6 +1244,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[int] = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: @@ -1252,6 +1253,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + Returns: Example: @@ -1296,11 +1302,21 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None: + logger.warning_once( + "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove those 2 float() operations in v4.44 + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states).float() + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1349,6 +1365,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1406,6 +1423,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 988948a9a827..63832dd9f592 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -1164,6 +1164,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[int] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1172,6 +1173,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + Returns: Example: @@ -1211,7 +1217,11 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + # No upscaling to float was ever done for StableLm + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states) + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: @@ -1248,6 +1258,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1305,6 +1316,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index d51077b04254..d6ce8c223c83 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1043,6 +1043,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[int] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1051,6 +1052,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + Returns: Example: @@ -1091,11 +1097,21 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None: + logger.warning_once( + "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove those 2 float() operations in v4.44 + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states).float() + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1129,6 +1145,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=None, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1186,6 +1203,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs From 9ac57db6e368ab68cb28ab0b3c8e5262839ee0d3 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 21 Jun 2024 09:58:38 +0200 Subject: [PATCH 15/35] Apply style --- src/transformers/models/idefics2/modeling_idefics2.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 40a7f06c404c..c05c65e86af1 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1640,7 +1640,13 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, num_logits_to_keep=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + num_logits_to_keep=None, + **kwargs, ): past_length = 0 # Omit tokens covered by past_key_values From f7421b6951f5678b4b8463745e3743bce1f8db6e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 12 Jul 2024 15:21:16 +0200 Subject: [PATCH 16/35] Make `num_logits_to_keep` an int in all cases to remove if-else clause --- .../models/cohere/modeling_cohere.py | 19 ++++++++----------- src/transformers/models/dbrx/modeling_dbrx.py | 17 +++++++---------- .../models/gemma/modeling_gemma.py | 19 ++++++++----------- .../models/idefics2/modeling_idefics2.py | 19 ++++++++----------- .../models/jetmoe/modeling_jetmoe.py | 19 ++++++++----------- .../models/llama/modeling_llama.py | 19 ++++++++----------- .../models/mistral/modeling_mistral.py | 19 ++++++++----------- .../models/mixtral/modeling_mixtral.py | 19 ++++++++----------- src/transformers/models/olmo/modeling_olmo.py | 19 ++++++++----------- .../models/persimmon/modeling_persimmon.py | 17 +++++++---------- src/transformers/models/phi/modeling_phi.py | 19 ++++++++----------- src/transformers/models/phi3/modeling_phi3.py | 19 ++++++++----------- .../models/qwen2/modeling_qwen2.py | 19 ++++++++----------- .../models/qwen2_moe/modeling_qwen2_moe.py | 19 ++++++++----------- .../models/stablelm/modeling_stablelm.py | 17 +++++++---------- .../models/starcoder2/modeling_starcoder2.py | 19 ++++++++----------- 16 files changed, 125 insertions(+), 173 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 0d3bd0939e9b..01fa18dc08c6 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1024,7 +1024,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[int] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1033,10 +1033,10 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: @@ -1082,11 +1082,8 @@ def forward( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove those 2 float() operations in v4.44 - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states).float() - else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + # TODO: remove the float() operation in v4.44 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() logits = logits * self.logit_scale loss = None @@ -1125,7 +1122,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index e66df692f42d..d2b76b8ef835 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1275,7 +1275,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[int] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r"""Forward function for causal language modeling. @@ -1285,10 +1285,10 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: @@ -1335,10 +1335,7 @@ def forward( hidden_states = outputs[0] # No upscaling to float was ever done for Dbrx - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states) - else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: @@ -1390,7 +1387,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 566d29a12f65..bea6a901bc36 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1038,7 +1038,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[int] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1047,10 +1047,10 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: @@ -1096,11 +1096,8 @@ def forward( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove those 2 float() operations in v4.44 - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states).float() - else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + # TODO: remove the float() operation in v4.44 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: @@ -1138,7 +1135,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index c05c65e86af1..7eb4d63cab2b 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1520,7 +1520,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - num_logits_to_keep: Optional[int] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, Idefics2CausalLMOutputWithPast]: r""" Args: @@ -1530,10 +1530,10 @@ def forward( Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: @@ -1603,11 +1603,8 @@ def forward( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove those 2 float() operations in v4.44 - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states).float() - else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + # TODO: remove the float() operation in v4.44 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: @@ -1645,7 +1642,7 @@ def prepare_inputs_for_generation( past_key_values=None, attention_mask=None, inputs_embeds=None, - num_logits_to_keep=None, + num_logits_to_keep=0, **kwargs, ): past_length = 0 diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index b3a68e231f75..d8fd5f4d10bd 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1253,7 +1253,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[int] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: @@ -1262,10 +1262,10 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: """ @@ -1296,11 +1296,8 @@ def forward( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove those 2 float() operations in v4.44 - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states).float() - else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + # TODO: remove the float() operation in v4.44 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: @@ -1355,7 +1352,7 @@ def prepare_inputs_for_generation( output_router_logits=False, position_ids=None, use_cache=True, - num_logits_to_keep=None, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 79fb600175d9..67b71ae46a9f 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1146,7 +1146,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[int] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1155,10 +1155,10 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: @@ -1209,11 +1209,8 @@ def forward( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove those 2 float() operations in v4.44 - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states).float() - else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + # TODO: remove the float() operation in v4.44 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: @@ -1251,7 +1248,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index e3ef86597781..e88757284a45 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -996,7 +996,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[int] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1005,10 +1005,10 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: @@ -1055,11 +1055,8 @@ def forward( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove those 2 float() operations in v4.44 - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states).float() - else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + # TODO: remove the float() operation in v4.44 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: @@ -1097,7 +1094,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 801e52812049..83b6447a6457 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1233,7 +1233,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[int] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: @@ -1242,10 +1242,10 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: @@ -1297,11 +1297,8 @@ def forward( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove those 2 float() operations in v4.44 - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states).float() - else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + # TODO: remove the float() operation in v4.44 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: @@ -1355,7 +1352,7 @@ def prepare_inputs_for_generation( output_router_logits=False, position_ids=None, use_cache=True, - num_logits_to_keep=None, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 89da0e940d71..521370900b5d 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1068,7 +1068,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[int] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1077,10 +1077,10 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: @@ -1127,11 +1127,8 @@ def forward( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove those 2 float() operations in v4.44 - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states).float() - else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + # TODO: remove the float() operation in v4.44 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: @@ -1169,7 +1166,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 85bce2ab8930..7d43ff552c37 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -885,7 +885,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[int] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -894,10 +894,10 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: @@ -940,10 +940,7 @@ def forward( hidden_states = outputs[0] # No upscaling to float was ever done for Persimmon - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states) - else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: @@ -980,7 +977,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 953cd6cf0e92..afc9317bdae1 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1169,7 +1169,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[int] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1178,10 +1178,10 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: @@ -1228,11 +1228,8 @@ def forward( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove those 2 float() operations in v4.44 - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states).float() - else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + # TODO: remove the float() operation in v4.44 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: @@ -1271,7 +1268,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 11ac91b88943..0deb8faed939 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1210,7 +1210,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[int] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1219,10 +1219,10 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: @@ -1268,11 +1268,8 @@ def forward( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove those 2 float() operations in v4.44 - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states).float() - else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + # TODO: remove the float() operation in v4.44 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: @@ -1311,7 +1308,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 476593ca3241..87b8269247ac 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1067,7 +1067,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[int] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1076,10 +1076,10 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: @@ -1126,11 +1126,8 @@ def forward( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove those 2 float() operations in v4.44 - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states).float() - else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + # TODO: remove the float() operation in v4.44 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: @@ -1169,7 +1166,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 5e9e88004c3c..6b648240e314 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1244,7 +1244,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[int] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: @@ -1253,10 +1253,10 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: @@ -1307,11 +1307,8 @@ def forward( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove those 2 float() operations in v4.44 - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states).float() - else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + # TODO: remove the float() operation in v4.44 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: @@ -1365,7 +1362,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 63832dd9f592..14e7d5cc65fc 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -1164,7 +1164,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[int] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1173,10 +1173,10 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: @@ -1218,10 +1218,7 @@ def forward( hidden_states = outputs[0] # No upscaling to float was ever done for StableLm - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states) - else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: @@ -1258,7 +1255,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index d6ce8c223c83..75e0c40c6ce7 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1043,7 +1043,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[int] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1052,10 +1052,10 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: @@ -1102,11 +1102,8 @@ def forward( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove those 2 float() operations in v4.44 - if num_logits_to_keep is None: - logits = self.lm_head(hidden_states).float() - else: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + # TODO: remove the float() operation in v4.44 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: @@ -1145,7 +1142,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, - num_logits_to_keep=None, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens From c8f917766a18ef9d23b5c3fe221f292c4988c6c7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 17 Jul 2024 13:46:32 +0200 Subject: [PATCH 17/35] Add compile check for the warning --- src/transformers/models/cohere/modeling_cohere.py | 2 +- src/transformers/models/gemma/modeling_gemma.py | 2 +- src/transformers/models/idefics2/modeling_idefics2.py | 2 +- src/transformers/models/jamba/modeling_jamba.py | 2 +- src/transformers/models/jetmoe/modeling_jetmoe.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- src/transformers/models/mistral/modeling_mistral.py | 2 +- src/transformers/models/mixtral/modeling_mixtral.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/phi3/modeling_phi3.py | 2 +- src/transformers/models/qwen2/modeling_qwen2.py | 2 +- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 2 +- src/transformers/models/starcoder2/modeling_starcoder2.py | 2 +- 14 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 01fa18dc08c6..a938fd6045e6 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1077,7 +1077,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None: + if labels is None and not torch.compiler.is_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index bea6a901bc36..c3c9cd4e6f78 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1091,7 +1091,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None: + if labels is None and not torch.compiler.is_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 7eb4d63cab2b..7e31fd792122 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1598,7 +1598,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None: + if labels is None and not torch.compiler.is_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 426b23253a43..00aee65bcd45 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1497,7 +1497,7 @@ def forward( logits = self.lm_head(hidden_states) else: logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :]) - if labels is None: + if labels is None and not torch.compiler.is_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index d8fd5f4d10bd..8fd69d8ef2c7 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1291,7 +1291,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None: + if labels is None and not torch.compiler.is_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 67b71ae46a9f..940d2d4e2382 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1204,7 +1204,7 @@ def forward( logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: - if labels is None: + if labels is None and not torch.compiler.is_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index e88757284a45..1c5e63c00221 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1050,7 +1050,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None: + if labels is None and not torch.compiler.is_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 83b6447a6457..44f7c09f5977 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1292,7 +1292,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None: + if labels is None and not torch.compiler.is_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 521370900b5d..98009c989d27 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1122,7 +1122,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None: + if labels is None and not torch.compiler.is_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index afc9317bdae1..41c27a97b30b 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1223,7 +1223,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None: + if labels is None and not torch.compiler.is_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 0deb8faed939..fd0a62b42a20 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1263,7 +1263,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None: + if labels is None and not torch.compiler.is_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 87b8269247ac..0cf1b96e038c 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1121,7 +1121,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None: + if labels is None and not torch.compiler.is_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 6b648240e314..b37ec376f308 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1302,7 +1302,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None: + if labels is None and not torch.compiler.is_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 75e0c40c6ce7..a4faa48063d5 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1097,7 +1097,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None: + if labels is None and not torch.compiler.is_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) From 5e1589e13217483cfee443f9285b48d168a52f37 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 17 Jul 2024 15:50:17 +0200 Subject: [PATCH 18/35] Fix torch versions --- src/transformers/models/cohere/modeling_cohere.py | 3 ++- src/transformers/models/gemma/modeling_gemma.py | 3 ++- src/transformers/models/idefics2/modeling_idefics2.py | 3 ++- src/transformers/models/jamba/modeling_jamba.py | 3 ++- src/transformers/models/jetmoe/modeling_jetmoe.py | 3 ++- src/transformers/models/llama/modeling_llama.py | 3 ++- src/transformers/models/mistral/modeling_mistral.py | 3 ++- src/transformers/models/mixtral/modeling_mixtral.py | 3 ++- src/transformers/models/olmo/modeling_olmo.py | 3 ++- src/transformers/models/phi/modeling_phi.py | 3 ++- src/transformers/models/phi3/modeling_phi3.py | 3 ++- src/transformers/models/qwen2/modeling_qwen2.py | 3 ++- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 3 ++- src/transformers/models/starcoder2/modeling_starcoder2.py | 3 ++- 14 files changed, 28 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index a938fd6045e6..550de1de7aed 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -44,6 +44,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1077,7 +1078,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None and not torch.compiler.is_compiling(): + if labels is None and not is_torchdynamo_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index c3c9cd4e6f78..dfa646641902 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -43,6 +43,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1091,7 +1092,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None and not torch.compiler.is_compiling(): + if labels is None and not is_torchdynamo_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 7e31fd792122..80980038fe8b 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -33,6 +33,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1598,7 +1599,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None and not torch.compiler.is_compiling(): + if labels is None and not is_torchdynamo_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 00aee65bcd45..aa2196522f79 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -49,6 +49,7 @@ is_causal_conv1d_available, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, is_mamba_ssm_available, ) from .configuration_jamba import JambaConfig @@ -1497,7 +1498,7 @@ def forward( logits = self.lm_head(hidden_states) else: logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :]) - if labels is None and not torch.compiler.is_compiling(): + if labels is None and not is_torchdynamo_compiling: logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 8fd69d8ef2c7..f837d688d0e9 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -37,6 +37,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1291,7 +1292,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None and not torch.compiler.is_compiling(): + if labels is None and not is_torchdynamo_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 940d2d4e2382..429735bedbb7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -44,6 +44,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1204,7 +1205,7 @@ def forward( logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: - if labels is None and not torch.compiler.is_compiling(): + if labels is None and not is_torchdynamo_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 1c5e63c00221..54b0221f47c8 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -42,6 +42,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1050,7 +1051,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None and not torch.compiler.is_compiling(): + if labels is None and not is_torchdynamo_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 44f7c09f5977..c4654f607a6d 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -43,6 +43,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1292,7 +1293,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None and not torch.compiler.is_compiling(): + if labels is None and not is_torchdynamo_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 98009c989d27..194e1f9e8da2 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -42,6 +42,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1122,7 +1123,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None and not torch.compiler.is_compiling(): + if labels is None and not is_torchdynamo_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 41c27a97b30b..9b15d88a9465 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -41,6 +41,7 @@ get_torch_version, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1223,7 +1224,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None and not torch.compiler.is_compiling(): + if labels is None and not is_torchdynamo_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index fd0a62b42a20..d8be113fcc77 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -40,6 +40,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1263,7 +1264,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None and not torch.compiler.is_compiling(): + if labels is None and not is_torchdynamo_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 0cf1b96e038c..52ff3accaa9f 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -42,6 +42,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1121,7 +1122,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None and not torch.compiler.is_compiling(): + if labels is None and not is_torchdynamo_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index b37ec376f308..f35ffbad4f2b 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -43,6 +43,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1302,7 +1303,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None and not torch.compiler.is_compiling(): + if labels is None and not is_torchdynamo_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index a4faa48063d5..9ff03995d090 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -42,6 +42,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1097,7 +1098,7 @@ def forward( ) hidden_states = outputs[0] - if labels is None and not torch.compiler.is_compiling(): + if labels is None and not is_torchdynamo_compiling(): logger.warning_once( "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) From 7998b65071fe2552fb6869dd06e33133c38b1b42 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 17 Jul 2024 16:31:11 +0200 Subject: [PATCH 19/35] style --- src/transformers/models/jamba/modeling_jamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index aa2196522f79..b5ab8f3cf2d8 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -49,8 +49,8 @@ is_causal_conv1d_available, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, - is_torchdynamo_compiling, is_mamba_ssm_available, + is_torchdynamo_compiling, ) from .configuration_jamba import JambaConfig From 8fa801811ffe9720b9003e79742624d629ff6df9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 20 Aug 2024 10:14:48 +0200 Subject: [PATCH 20/35] Add gemma2 --- src/transformers/models/gemma2/modeling_gemma2.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 398ba4abefe1..25eb6228f936 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -987,6 +987,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -995,6 +996,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + Returns: Example: @@ -1038,15 +1044,19 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) if self.config.final_logit_softcapping is not None: logits = logits / self.config.final_logit_softcapping logits = torch.tanh(logits) logits = logits * self.config.final_logit_softcapping + # TODO: remove the float() operation in v4.45 logits = logits.float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1079,6 +1089,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1139,6 +1150,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs From b49fe767e6b6e43856b61800ff671a331cc205a5 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 20 Aug 2024 10:22:51 +0200 Subject: [PATCH 21/35] Update warning version --- src/transformers/models/cohere/modeling_cohere.py | 4 ++-- src/transformers/models/gemma/modeling_gemma.py | 4 ++-- src/transformers/models/gemma2/modeling_gemma2.py | 5 +++++ src/transformers/models/idefics2/modeling_idefics2.py | 4 ++-- src/transformers/models/jamba/modeling_jamba.py | 2 +- src/transformers/models/jetmoe/modeling_jetmoe.py | 4 ++-- src/transformers/models/llama/modeling_llama.py | 4 ++-- src/transformers/models/mistral/modeling_mistral.py | 4 ++-- src/transformers/models/mixtral/modeling_mixtral.py | 4 ++-- src/transformers/models/olmo/modeling_olmo.py | 4 ++-- src/transformers/models/phi/modeling_phi.py | 4 ++-- src/transformers/models/phi3/modeling_phi3.py | 4 ++-- src/transformers/models/qwen2/modeling_qwen2.py | 4 ++-- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 4 ++-- src/transformers/models/starcoder2/modeling_starcoder2.py | 4 ++-- 15 files changed, 32 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 550de1de7aed..71cc27bc6cd6 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1080,10 +1080,10 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.44 + # TODO: remove the float() operation in v4.45 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() logits = logits * self.logit_scale diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index dfa646641902..f55b9acd3776 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1094,10 +1094,10 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.44 + # TODO: remove the float() operation in v4.45 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 25eb6228f936..0292011e88f2 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -41,6 +41,7 @@ is_flash_attn_2_available, is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1044,6 +1045,10 @@ def forward( ) hidden_states = outputs[0] + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) if self.config.final_logit_softcapping is not None: diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 80980038fe8b..cacf1b762d53 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1601,10 +1601,10 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.44 + # TODO: remove the float() operation in v4.45 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index b5ab8f3cf2d8..ba87d8374baa 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1500,7 +1500,7 @@ def forward( logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :]) if labels is None and not is_torchdynamo_compiling: logger.warning_once( - "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # TODO: remove this float() operations in v4.44 logits = logits.float() diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index f837d688d0e9..905c20eb65ab 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1294,10 +1294,10 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.44 + # TODO: remove the float() operation in v4.45 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 429735bedbb7..31446432cff7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1207,10 +1207,10 @@ def forward( else: if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.44 + # TODO: remove the float() operation in v4.45 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 54b0221f47c8..dc00c2c5196a 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1053,10 +1053,10 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.44 + # TODO: remove the float() operation in v4.45 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index c4654f607a6d..3994321390f7 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1295,10 +1295,10 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.44 + # TODO: remove the float() operation in v4.45 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 194e1f9e8da2..114902659ff5 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1125,10 +1125,10 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.44 + # TODO: remove the float() operation in v4.45 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 9b15d88a9465..d4b515ee12b5 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1226,10 +1226,10 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.44 + # TODO: remove the float() operation in v4.45 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index d8be113fcc77..261f66596b4e 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1266,10 +1266,10 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.44 + # TODO: remove the float() operation in v4.45 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 52ff3accaa9f..8e69126d8eee 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1124,10 +1124,10 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.44 + # TODO: remove the float() operation in v4.45 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index f35ffbad4f2b..3451cbaebbe3 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1305,10 +1305,10 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.44 + # TODO: remove the float() operation in v4.45 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 9ff03995d090..6543d98a708e 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1100,10 +1100,10 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.44, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.44 + # TODO: remove the float() operation in v4.45 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None From cf9378a437875c6330eae8c58ab855b11d216297 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 20 Aug 2024 11:04:27 +0200 Subject: [PATCH 22/35] Add comment about .float operations in generation utils --- src/transformers/generation/utils.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7d3a4809faa9..ff98524a8ea4 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2426,8 +2426,9 @@ def _dola_decoding( output_hidden_states=True, ) - final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone() - final_logits = outputs.logits[:, -1, :] + # .float() is needed to retain precision for later logits manipulations + final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone().float() + final_logits = outputs.logits[:, -1, :].float() candidate_premature_logits = {} for candidate_premature_layer in candidate_premature_layers: candidate_premature_logits[candidate_premature_layer] = lm_head( @@ -2604,6 +2605,7 @@ def _contrastive_search( # next logit for contrastive search to select top-k candidate tokens # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration # (the clone itself is always small) + # .float() is needed to retain precision for later logits manipulations logit_for_next_step = outputs.logits[:, -1, :].clone().float() model_kwargs = self._update_model_kwargs_for_generation( @@ -2734,7 +2736,8 @@ def _contrastive_search( next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states - logits = outputs.logits[:, -1, :] + # .float() is needed to retain precision for later logits manipulations + logits = outputs.logits[:, -1, :].float() context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the @@ -2792,7 +2795,7 @@ def _contrastive_search( next_past_key_values = tuple(new_key_values) - logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :].float() + logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :] # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration if self.config.is_encoder_decoder: @@ -2980,6 +2983,7 @@ def _sample( # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration # (the clone itself is always small) + # .float() is needed to retain precision for later logits manipulations next_token_logits = outputs.logits[:, -1, :].clone().float() # pre-process distribution @@ -3224,6 +3228,7 @@ def _beam_search( # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration # (the clone itself is always small) + # .float() is needed to retain precision for later logits manipulations next_token_logits = outputs.logits[:, -1, :].clone().float() next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 @@ -3498,6 +3503,7 @@ def _group_beam_search( # select outputs of beams of current group only # No need to clone() the logits here as they will not retain outputs.logits at the end of the loop + # .float() is needed to retain precision for later logits manipulations next_token_logits = outputs.logits[batch_group_indices, -1, :].float() next_token_scores = nn.functional.log_softmax( @@ -3753,6 +3759,7 @@ def _constrained_beam_search( # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration # (the clone itself is always small) + # .float() is needed to retain precision for later logits manipulations next_token_logits = outputs.logits[:, -1, :].clone().float() next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 @@ -4012,8 +4019,9 @@ def _assisted_decoding( outputs = self(**model_inputs) # 2.3. Process the new logits - new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present - next_token_logits = new_logits.clone().float() + # .float() is needed to retain precision for later logits manipulations + new_logits = outputs.logits[:, -candidate_length - 1 :].float() # excludes the input prompt if present + next_token_logits = new_logits.clone() if len(logits_processor) > 0: for i in range(candidate_length + 1): new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) From 66e3e9d831b63bb804e05716b392f9982957df5e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 21 Aug 2024 11:19:07 +0200 Subject: [PATCH 23/35] Add tests in GenerationTesterMixin and ModelTesterMixin --- src/transformers/generation/utils.py | 2 +- tests/generation/test_utils.py | 35 ++++++++++++++++++++++++++++ tests/test_modeling_common.py | 21 +++++++++++++++++ 3 files changed, 57 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ff98524a8ea4..616dbe6d5d3d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1887,7 +1887,7 @@ def generate( # If the model supports `num_logits_to_keep` in forward(), set it to 1 to avoid computing the whole # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding # dynamically overrides this value as it can need more than the last token logits - if self._supports_num_logits_to_keep(): + if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs: model_kwargs["num_logits_to_keep"] = 1 self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index ae52f6c67404..18299db36be4 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1828,6 +1828,41 @@ def test_generate_compile_fullgraph(self): output_compiled = compiled_generate(model_inputs, generation_config=generation_config) self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist()) + def test_generate_with_num_logits_to_keep(self): + for model_class in self.all_generative_model_classes: + if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): + self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") + + config, input_ids, attention_mask = self._get_input_ids_and_config() + config.use_cache = True + config.is_decoder = True + + model = model_class(config).to(torch_device).eval() + assistant_model = model + # All generation methods (except assisted decoding) rely on always extracting the last token logits of the + # full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works, + # other methods will work as well) + strategies = [ + { + "max_new_tokens": 10, + "do_sample": False, + }, + { + "max_new_tokens": 10, + "do_sample": False, + "assistant_model": assistant_model, + }, + ] + + for generation_kwargs in strategies: + # Setting num_logits_to_keep at 0 keeps all logits (old behavior) + with_all_logits = model.generate( + input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0 + ) + # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior) + without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + self.assertListEqual(with_all_logits.tolist(), without_all_logits.tolist()) + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): batch_size, seq_length = input_ids.shape num_sequences_in_output = batch_size * num_return_sequences diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 7cbc2f3e2813..126696459c04 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4824,6 +4824,27 @@ def test_compile_cuda_graph_time(self): self.assertTrue(record_time < 0.15 * graph_warmup_time) self.assertTrue(opt_time < record_time) + def test_forward_with_num_logits_to_keep(self): + for model_class in self.all_generative_model_classes: + if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): + self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") + + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + batch_size, sequence_length = inputs["input_ids"].shape + vocab_size = config.vocab_size + model = model_class(config).to(device=torch_device).eval() + + # num_logits_to_keep=0 is a special case meaning "keep all logits" + all_logits = model(**inputs, num_logits_to_keep=0).logits + last_token_logits = model(**inputs, num_logits_to_keep=1).logits + + # Assert all shapes are correct + self.assertEqual(tuple(all_logits.shape), (batch_size, sequence_length, vocab_size)) + self.assertEqual(tuple(last_token_logits.shape), (batch_size, 1, vocab_size)) + + # Assert the last tokens are actually the same + self.assertEqual(all_logits[:, -1, :].tolist(), last_token_logits.tolist()) + global_rng = random.Random() From e4c5a71bb1d78ea96ca38c51bf9a722d962fb345 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 21 Aug 2024 11:29:55 +0200 Subject: [PATCH 24/35] Fix batch size for assisted decoding in tests --- tests/generation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 18299db36be4..0333c5fbada3 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1833,7 +1833,7 @@ def test_generate_with_num_logits_to_keep(self): if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") - config, input_ids, attention_mask = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1) config.use_cache = True config.is_decoder = True From b68ee1661482b291427cb4ec16385e127c666080 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 21 Aug 2024 11:53:50 +0200 Subject: [PATCH 25/35] fix small issues in test --- tests/generation/test_utils.py | 4 +++- tests/test_modeling_common.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 0333c5fbada3..9f7314fe0e5f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1854,7 +1854,9 @@ def test_generate_with_num_logits_to_keep(self): }, ] - for generation_kwargs in strategies: + for i, generation_kwargs in enumerate(strategies): + if i == 1 and model_class._is_stateful: + self.skipTest(reason="Stateful models don't support assisted generation") # Setting num_logits_to_keep at 0 keeps all logits (old behavior) with_all_logits = model.generate( input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0 diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 126696459c04..e8622f73242f 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4843,7 +4843,7 @@ def test_forward_with_num_logits_to_keep(self): self.assertEqual(tuple(last_token_logits.shape), (batch_size, 1, vocab_size)) # Assert the last tokens are actually the same - self.assertEqual(all_logits[:, -1, :].tolist(), last_token_logits.tolist()) + self.assertTrue(torch.allclose(all_logits[:, -1, :], last_token_logits)) global_rng = random.Random() From e8374252e514600a41b246c59fcef9a5a3b142aa Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 21 Aug 2024 12:13:38 +0200 Subject: [PATCH 26/35] refacor test --- tests/generation/test_utils.py | 63 ++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 22 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 9f7314fe0e5f..b0cf08d0530c 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1828,11 +1828,39 @@ def test_generate_compile_fullgraph(self): output_compiled = compiled_generate(model_inputs, generation_config=generation_config) self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist()) - def test_generate_with_num_logits_to_keep(self): + def test_generate_methods_with_num_logits_to_keep(self): for model_class in self.all_generative_model_classes: if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") + config, input_ids, attention_mask = self._get_input_ids_and_config() + config.use_cache = True + config.is_decoder = True + + model = model_class(config).to(torch_device).eval() + # All generation methods (except assisted decoding) rely on always extracting the last token logits of the + # full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works, + # other methods will work as well) + generation_kwargs = { + "max_new_tokens": 10, + "do_sample": False, + } + + # Setting num_logits_to_keep at 0 keeps all logits (old behavior) + with_all_logits = model.generate( + input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0 + ) + # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior) + without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) + + def test_assisted_decoding_with_num_logits_to_keep(self): + for model_class in self.all_generative_model_classes: + if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): + self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support assisted generation") + config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1) config.use_cache = True config.is_decoder = True @@ -1842,28 +1870,19 @@ def test_generate_with_num_logits_to_keep(self): # All generation methods (except assisted decoding) rely on always extracting the last token logits of the # full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works, # other methods will work as well) - strategies = [ - { - "max_new_tokens": 10, - "do_sample": False, - }, - { - "max_new_tokens": 10, - "do_sample": False, - "assistant_model": assistant_model, - }, - ] + generation_kwargs = { + "max_new_tokens": 10, + "do_sample": False, + "assistant_model": assistant_model, + } - for i, generation_kwargs in enumerate(strategies): - if i == 1 and model_class._is_stateful: - self.skipTest(reason="Stateful models don't support assisted generation") - # Setting num_logits_to_keep at 0 keeps all logits (old behavior) - with_all_logits = model.generate( - input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0 - ) - # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior) - without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) - self.assertListEqual(with_all_logits.tolist(), without_all_logits.tolist()) + # Setting num_logits_to_keep at 0 keeps all logits (old behavior) + with_all_logits = model.generate( + input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0 + ) + # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior) + without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): batch_size, seq_length = input_ids.shape From 26863ca36046aad4ea7a9308a047a28cc114d481 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 21 Aug 2024 12:21:56 +0200 Subject: [PATCH 27/35] fix slicing removing dim issue --- tests/test_modeling_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index e8622f73242f..7617c15efabf 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4843,7 +4843,7 @@ def test_forward_with_num_logits_to_keep(self): self.assertEqual(tuple(last_token_logits.shape), (batch_size, 1, vocab_size)) # Assert the last tokens are actually the same - self.assertTrue(torch.allclose(all_logits[:, -1, :], last_token_logits)) + self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits)) global_rng = random.Random() From 3c3eeaa9df9c3c38ea44b7078b9dee8cfb01e824 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 21 Aug 2024 15:23:46 +0200 Subject: [PATCH 28/35] Add nemotron support (should fix check-copy issue in CIs) --- .../models/nemotron/modeling_nemotron.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index db4bce273ca1..d2f03f79250b 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -42,6 +42,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1029,6 +1030,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1037,6 +1039,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + Returns: Example: @@ -1076,11 +1083,19 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + # TODO: remove the float() operation in v4.45 logits = logits.float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1113,6 +1128,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1170,6 +1186,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs From c4008655bc0ea922783d23a59af3a63a173e30b4 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 21 Aug 2024 15:35:32 +0200 Subject: [PATCH 29/35] Trigger new CIs From 802eca83b29e8d467328f1d82853dd8157110947 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 22 Aug 2024 08:14:54 +0200 Subject: [PATCH 30/35] Trigger new CIs From 4d6fae6543912c6da1c7175970ad77ba84b8f8ae Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 22 Aug 2024 17:10:10 +0200 Subject: [PATCH 31/35] Bump version --- src/transformers/models/cohere/modeling_cohere.py | 2 +- src/transformers/models/gemma/modeling_gemma.py | 2 +- src/transformers/models/gemma2/modeling_gemma2.py | 2 +- src/transformers/models/idefics2/modeling_idefics2.py | 2 +- src/transformers/models/jamba/modeling_jamba.py | 2 +- src/transformers/models/jetmoe/modeling_jetmoe.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- src/transformers/models/mistral/modeling_mistral.py | 2 +- src/transformers/models/mixtral/modeling_mixtral.py | 2 +- src/transformers/models/nemotron/modeling_nemotron.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/phi3/modeling_phi3.py | 2 +- src/transformers/models/qwen2/modeling_qwen2.py | 2 +- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 2 +- src/transformers/models/starcoder2/modeling_starcoder2.py | 2 +- 16 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 71cc27bc6cd6..f516915fc8c7 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1080,7 +1080,7 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss # TODO: remove the float() operation in v4.45 diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index f55b9acd3776..3253b43d9263 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1094,7 +1094,7 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss # TODO: remove the float() operation in v4.45 diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 0292011e88f2..719ae707e6df 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -1047,7 +1047,7 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index cacf1b762d53..e30ab158ad2a 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1601,7 +1601,7 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss # TODO: remove the float() operation in v4.45 diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index ba87d8374baa..f2039489b743 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1500,7 +1500,7 @@ def forward( logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :]) if labels is None and not is_torchdynamo_compiling: logger.warning_once( - "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # TODO: remove this float() operations in v4.44 logits = logits.float() diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 905c20eb65ab..f500bab07e6f 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1294,7 +1294,7 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss # TODO: remove the float() operation in v4.45 diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 31446432cff7..a9b8f493a693 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1207,7 +1207,7 @@ def forward( else: if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss # TODO: remove the float() operation in v4.45 diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index dc00c2c5196a..e91ef2eb0e39 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1053,7 +1053,7 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss # TODO: remove the float() operation in v4.45 diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 3994321390f7..775bf90613b9 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1295,7 +1295,7 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss # TODO: remove the float() operation in v4.45 diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index d2f03f79250b..d28b9fa003d1 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -1085,7 +1085,7 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 114902659ff5..9a5c431abdbc 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1125,7 +1125,7 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss # TODO: remove the float() operation in v4.45 diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index d4b515ee12b5..913eb58a2c39 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1226,7 +1226,7 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss # TODO: remove the float() operation in v4.45 diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 261f66596b4e..02f806d8b5cc 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1266,7 +1266,7 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss # TODO: remove the float() operation in v4.45 diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 8e69126d8eee..8bef1c400bca 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1124,7 +1124,7 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss # TODO: remove the float() operation in v4.45 diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 3451cbaebbe3..6da9e813cdb5 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1305,7 +1305,7 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss # TODO: remove the float() operation in v4.45 diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 6543d98a708e..9c3a41e79e99 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1100,7 +1100,7 @@ def forward( hidden_states = outputs[0] if labels is None and not is_torchdynamo_compiling(): logger.warning_once( - "Starting from v4.45, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss # TODO: remove the float() operation in v4.45 From f12f172f2df8ce18dfb4d874d2e4822aebf86557 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 22 Aug 2024 17:15:30 +0200 Subject: [PATCH 32/35] Bump version in TODO --- src/transformers/models/cohere/modeling_cohere.py | 2 +- src/transformers/models/gemma/modeling_gemma.py | 2 +- src/transformers/models/gemma2/modeling_gemma2.py | 2 +- src/transformers/models/idefics2/modeling_idefics2.py | 2 +- src/transformers/models/jamba/modeling_jamba.py | 2 +- src/transformers/models/jetmoe/modeling_jetmoe.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- src/transformers/models/mistral/modeling_mistral.py | 2 +- src/transformers/models/mixtral/modeling_mixtral.py | 2 +- src/transformers/models/nemotron/modeling_nemotron.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/phi3/modeling_phi3.py | 2 +- src/transformers/models/qwen2/modeling_qwen2.py | 2 +- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 2 +- src/transformers/models/starcoder2/modeling_starcoder2.py | 2 +- 16 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index f516915fc8c7..ea5fd6749e6b 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1083,7 +1083,7 @@ def forward( "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.45 + # TODO: remove the float() operation in v4.46 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() logits = logits * self.logit_scale diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 3253b43d9263..ea3448e03313 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1097,7 +1097,7 @@ def forward( "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.45 + # TODO: remove the float() operation in v4.46 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 719ae707e6df..bf6ff76189d4 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -1056,7 +1056,7 @@ def forward( logits = torch.tanh(logits) logits = logits * self.config.final_logit_softcapping - # TODO: remove the float() operation in v4.45 + # TODO: remove the float() operation in v4.46 logits = logits.float() loss = None if labels is not None: diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index e30ab158ad2a..4d3db3208d77 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1604,7 +1604,7 @@ def forward( "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.45 + # TODO: remove the float() operation in v4.46 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index f2039489b743..2722a5e06909 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1502,7 +1502,7 @@ def forward( logger.warning_once( "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) - # TODO: remove this float() operations in v4.44 + # TODO: remove the float() operations in v4.46 logits = logits.float() loss = None diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index f500bab07e6f..244f74717990 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1297,7 +1297,7 @@ def forward( "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.45 + # TODO: remove the float() operation in v4.46 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a9b8f493a693..7ce9d2b028b7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1210,7 +1210,7 @@ def forward( "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.45 + # TODO: remove the float() operation in v4.46 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index e91ef2eb0e39..438d8e8a568a 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1056,7 +1056,7 @@ def forward( "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.45 + # TODO: remove the float() operation in v4.46 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 775bf90613b9..247b5e10f762 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1298,7 +1298,7 @@ def forward( "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.45 + # TODO: remove the float() operation in v4.46 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index d28b9fa003d1..76edefcf3025 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -1089,7 +1089,7 @@ def forward( ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - # TODO: remove the float() operation in v4.45 + # TODO: remove the float() operation in v4.46 logits = logits.float() loss = None diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 9a5c431abdbc..edf3eb0ab3fe 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1128,7 +1128,7 @@ def forward( "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.45 + # TODO: remove the float() operation in v4.46 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 913eb58a2c39..fbc8720c89c8 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1229,7 +1229,7 @@ def forward( "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.45 + # TODO: remove the float() operation in v4.46 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 02f806d8b5cc..fea30dc191d2 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1269,7 +1269,7 @@ def forward( "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.45 + # TODO: remove the float() operation in v4.46 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 8bef1c400bca..1db1da30f96c 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1127,7 +1127,7 @@ def forward( "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.45 + # TODO: remove the float() operation in v4.46 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 6da9e813cdb5..31ea55644611 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1308,7 +1308,7 @@ def forward( "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.45 + # TODO: remove the float() operation in v4.46 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 9c3a41e79e99..ea3f3be9d861 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1103,7 +1103,7 @@ def forward( "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" ) # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - # TODO: remove the float() operation in v4.45 + # TODO: remove the float() operation in v4.46 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None From 7b1a26cc26fed41e5a9be0f41052be26742fd66a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 22 Aug 2024 17:38:39 +0200 Subject: [PATCH 33/35] Trigger CIs From b11b048f734d2dfa26775919919d5a7df0d8e15f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 23 Aug 2024 11:19:04 +0200 Subject: [PATCH 34/35] remove blank space --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 616dbe6d5d3d..319873e3f7e8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1889,7 +1889,7 @@ def generate( # dynamically overrides this value as it can need more than the last token logits if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs: model_kwargs["num_logits_to_keep"] = 1 - + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) # 7. Prepare the cache. From f03adfb17030181d39a698b7c9b0b0ba2df80722 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 23 Aug 2024 11:31:32 +0200 Subject: [PATCH 35/35] Trigger CIs