diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 1ab5ea527e49..bf55ae3e2b06 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -119,6 +119,10 @@ def __init__( value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value) ) + # Remove potential default "num_logits_to_keep" key + if "num_logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_num_logits_to_keep(): + del assistant_kwargs["num_logits_to_keep"] + if "assistant_encoder_outputs" in model_kwargs: assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] elif assistant_model.config.is_encoder_decoder: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index a9ebdcdd4775..319873e3f7e8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1601,6 +1601,13 @@ def _prepare_cache_for_generation( else EncoderDecoderCache(DynamicCache(), DynamicCache()) ) + def _supports_num_logits_to_keep(self) -> bool: + """ + Return True if the current model supports the keyword argument `num_logits_to_keep` in forward() + to save memory. Checking the signature this way avoids adding a new model attribute. + """ + return "num_logits_to_keep" in set(inspect.signature(self.forward).parameters.keys()) + def _prepare_special_tokens( self, generation_config: GenerationConfig, @@ -1876,6 +1883,13 @@ def generate( inputs_tensor=inputs_tensor, input_ids_length=input_ids_length, ) + + # If the model supports `num_logits_to_keep` in forward(), set it to 1 to avoid computing the whole + # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding + # dynamically overrides this value, as it may need more than just the last token's logits + if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs: + model_kwargs["num_logits_to_keep"] = 1 + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) # 7. Prepare the cache.
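The `_supports_num_logits_to_keep` helper detects support by inspecting the `forward()` signature instead of adding a new class attribute. A minimal standalone sketch of that pattern, with `OldModel`/`NewModel` as toy stand-ins rather than real transformers classes:

```python
# Toy demonstration of the signature-inspection pattern used by
# `_supports_num_logits_to_keep`: a model "supports" the argument iff its
# forward() signature declares that keyword.
import inspect


class OldModel:
    def forward(self, input_ids, attention_mask=None):
        ...


class NewModel:
    def forward(self, input_ids, attention_mask=None, num_logits_to_keep: int = 0):
        ...


def supports_num_logits_to_keep(model) -> bool:
    # `inspect.signature(...).parameters` is already a name-keyed mapping,
    # so a plain membership test suffices (the set() in the diff is equivalent).
    return "num_logits_to_keep" in inspect.signature(model.forward).parameters


print(supports_num_logits_to_keep(OldModel()))  # False
print(supports_num_logits_to_keep(NewModel()))  # True
```

This keeps the optimization opt-in per model: a model gains it simply by accepting the keyword in `forward()` and routing it through `prepare_inputs_for_generation()`, as each per-model diff below does.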
@@ -2412,8 +2426,9 @@ def _dola_decoding( output_hidden_states=True, ) - final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone() - final_logits = outputs.logits[:, -1, :] + # .float() is needed to retain precision for later logits manipulations + final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone().float() + final_logits = outputs.logits[:, -1, :].float() candidate_premature_logits = {} for candidate_premature_layer in candidate_premature_layers: candidate_premature_logits[candidate_premature_layer] = lm_head( @@ -2590,7 +2605,8 @@ def _contrastive_search( # next logit for contrastive search to select top-k candidate tokens # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration # (the clone itself is always small) - logit_for_next_step = outputs.logits[:, -1, :].clone() + # .float() is needed to retain precision for later logits manipulations + logit_for_next_step = outputs.logits[:, -1, :].clone().float() model_kwargs = self._update_model_kwargs_for_generation( outputs, @@ -2720,7 +2736,8 @@ def _contrastive_search( next_hidden = outputs.hidden_states[-1] full_hidden_states = outputs.hidden_states - logits = outputs.logits[:, -1, :] + # .float() is needed to retain precision for later logits manipulations + logits = outputs.logits[:, -1, :].float() context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the @@ -2966,7 +2983,8 @@ def _sample( # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration # (the clone itself is always small) - next_token_logits = outputs.logits[:, -1, :].clone() + # .float() is needed to retain precision for later logits manipulations + next_token_logits = outputs.logits[:, -1, :].clone().float() # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) @@ -3210,7 +3228,8 @@ def _beam_search( # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration # (the clone itself is always small) - next_token_logits = outputs.logits[:, -1, :].clone() + # .float() is needed to retain precision for later logits manipulations + next_token_logits = outputs.logits[:, -1, :].clone().float() next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) @@ -3484,7 +3503,8 @@ def _group_beam_search( # select outputs of beams of current group only # No need to clone() the logits here as they will not retain outputs.logits at the end of the loop - next_token_logits = outputs.logits[batch_group_indices, -1, :] + # .float() is needed to retain precision for later logits manipulations + next_token_logits = outputs.logits[batch_group_indices, -1, :].float() next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 @@ -3739,7 +3759,8 @@ def _constrained_beam_search( # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration # (the clone itself is always small) - next_token_logits = outputs.logits[:, -1, :].clone() + # .float() is needed to retain precision for later logits manipulations + next_token_logits = outputs.logits[:, -1, :].clone().float() next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) @@ -3998,7 +4019,8 @@ def _assisted_decoding( outputs = 
self(**model_inputs) # 2.3. Process the new logits - new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present + # .float() is needed to retain precision for later logits manipulations + new_logits = outputs.logits[:, -candidate_length - 1 :].float() # excludes the input prompt if present next_token_logits = new_logits.clone() if len(logits_processor) > 0: for i in range(candidate_length + 1): diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index afcea137b58b..ea5fd6749e6b 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -44,6 +44,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1024,6 +1025,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1032,6 +1034,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + Returns: Example: @@ -1071,12 +1078,19 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() logits = logits * self.logit_scale - logits = logits.float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1109,6 +1123,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1166,6 +1181,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 3486d5ed3ab0..d2b76b8ef835 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1275,6 +1275,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: 
int = 0, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r"""Forward function for causal language modeling. @@ -1284,6 +1285,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + Returns: Example: @@ -1328,7 +1334,8 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + # No upscaling to float was ever done for Dbrx + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: @@ -1380,6 +1387,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1437,6 +1445,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index a05d2c059e21..ea3448e03313 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -43,6 +43,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1038,6 +1039,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1046,6 +1048,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
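The `num_logits_to_keep=0` special case documented above works purely through Python slicing: `-0 == 0`, so a slice starting at `-num_logits_to_keep` starts at the beginning of the axis when the argument is zero. A quick demonstration of the indexing used in the new `lm_head` calls:

```python
# Why `hidden_states[:, -num_logits_to_keep:, :]` keeps *all* positions when
# num_logits_to_keep == 0: -0 == 0, so the slice is equivalent to [:, 0:, :].
import torch

hidden_states = torch.randn(2, 5, 8)  # (batch, seq_len, hidden_size)

keep_all = hidden_states[:, -0:, :]  # all 5 positions (old behavior)
keep_one = hidden_states[:, -1:, :]  # only the final position

print(keep_all.shape)  # torch.Size([2, 5, 8])
print(keep_one.shape)  # torch.Size([2, 1, 8])
```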
+ Returns: Example: @@ -1085,10 +1092,18 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1121,6 +1136,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1177,6 +1193,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 398ba4abefe1..bf6ff76189d4 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -41,6 +41,7 @@ is_flash_attn_2_available, is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -987,6 +988,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -995,6 +997,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ Returns: Example: @@ -1038,15 +1045,23 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) if self.config.final_logit_softcapping is not None: logits = logits / self.config.final_logit_softcapping logits = torch.tanh(logits) logits = logits * self.config.final_logit_softcapping + # TODO: remove the float() operation in v4.46 logits = logits.float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1079,6 +1094,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1139,6 +1155,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index cdc7e9ba4e77..4d3db3208d77 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -33,6 +33,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1520,6 +1521,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, Idefics2CausalLMOutputWithPast]: r""" Args: @@ -1528,6 +1530,12 @@ def forward( config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics2ForConditionalGeneration`). Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ Returns: Example: @@ -1591,11 +1599,18 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() labels = labels.to(logits.device) # Shift so that tokens < n predict n if attention_mask is not None: @@ -1623,7 +1638,13 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + num_logits_to_keep=0, + **kwargs, ): past_length = 0 # Omit tokens covered by past_key_values @@ -1682,6 +1703,7 @@ def prepare_inputs_for_generation( "pixel_values": pixel_values, "pixel_attention_mask": pixel_attention_mask, "image_hidden_states": image_hidden_states, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 230536a83a14..2722a5e06909 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -50,6 +50,7 @@ is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, is_mamba_ssm_available, + is_torchdynamo_compiling, ) from .configuration_jamba import JambaConfig @@ -1497,10 +1498,17 @@ def forward( logits = self.lm_head(hidden_states) else: logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :]) + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # TODO: remove the float() operations in v4.46 logits = logits.float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index a3fc645e5aea..244f74717990 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -37,6 +37,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1253,6 +1254,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: @@ -1261,6 +1263,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + Returns: """ @@ -1285,11 +1292,18 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1339,6 +1353,7 @@ def prepare_inputs_for_generation( output_router_logits=False, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1371,6 +1386,7 @@ def prepare_inputs_for_generation( "use_cache": use_cache, "attention_mask": attention_mask, "output_router_logits": output_router_logits, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8716d27f5481..7ce9d2b028b7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -44,6 +44,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1146,6 +1147,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1154,6 +1156,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ Returns: Example: @@ -1198,11 +1205,18 @@ def forward( logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = torch.cat(logits, dim=-1) else: - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1235,6 +1249,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1292,6 +1307,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 1a2b732e85e4..438d8e8a568a 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -42,6 +42,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -996,6 +997,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1004,6 +1006,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
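The recurring "Upcast to float ... precision issues" comments are about half-precision round-off: once logits leave `lm_head` in fp16/bf16, downstream score manipulations (log-softmax, repetition penalties, the loss) operate on perturbed values. A small, self-contained illustration of the error magnitude involved, not taken from the PR:

```python
# Round-tripping fp32 logits through fp16 perturbs them by roughly 1e-2 at
# logit magnitudes of ~10-40, orders of magnitude above fp32 round-off.
import torch

logits = torch.randn(1, 32000) * 10  # a fake vocab-sized logit row

roundtrip = logits.to(torch.float16).to(torch.float32)
print((logits - roundtrip).abs().max())  # typically on the order of 1e-2
```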
+ Returns: Example: @@ -1044,11 +1051,18 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1081,6 +1095,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1115,6 +1130,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 522b6db7bcc7..247b5e10f762 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -43,6 +43,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1233,6 +1234,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: @@ -1241,6 +1243,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ Returns: Example: @@ -1286,11 +1293,18 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1339,6 +1353,7 @@ def prepare_inputs_for_generation( output_router_logits=False, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1371,6 +1386,7 @@ def prepare_inputs_for_generation( "use_cache": use_cache, "attention_mask": attention_mask, "output_router_logits": output_router_logits, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index db4bce273ca1..76edefcf3025 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -42,6 +42,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1029,6 +1030,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1037,6 +1039,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ Returns: Example: @@ -1076,11 +1083,19 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + # TODO: remove the float() operation in v4.46 logits = logits.float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1113,6 +1128,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1170,6 +1186,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 1940660f61b5..edf3eb0ab3fe 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -42,6 +42,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1068,6 +1069,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1076,6 +1078,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ Returns: Example: @@ -1116,11 +1123,18 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1153,6 +1167,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1210,6 +1225,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 1e4f56c0674d..7d43ff552c37 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -885,6 +885,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -893,6 +894,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ Returns: Example: @@ -933,7 +939,8 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + # No upscaling to float was ever done for Persimmon + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: @@ -970,6 +977,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1027,6 +1035,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 6d63c0ea7e8e..fbc8720c89c8 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -41,6 +41,7 @@ get_torch_version, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1169,6 +1170,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1177,6 +1179,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ Returns: Example: @@ -1217,11 +1224,18 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1255,6 +1269,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1312,6 +1327,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 08417fcabfaa..fea30dc191d2 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -40,6 +40,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1210,6 +1211,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1218,6 +1220,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ Returns: Example: @@ -1257,11 +1264,18 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1295,6 +1309,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1352,6 +1367,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 28b414b1901b..1db1da30f96c 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -42,6 +42,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1067,6 +1068,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1075,6 +1077,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ Returns: Example: @@ -1115,11 +1122,18 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1153,6 +1167,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1210,6 +1225,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 12ebe26e058d..31ea55644611 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -43,6 +43,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1244,6 +1245,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: @@ -1252,6 +1254,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ Returns: Example: @@ -1296,11 +1303,18 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1349,6 +1363,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1406,6 +1421,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 988948a9a827..14e7d5cc65fc 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -1164,6 +1164,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1172,6 +1173,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ Returns: Example: @@ -1211,7 +1217,8 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + # No upscaling to float was ever done for StableLm + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: @@ -1248,6 +1255,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1305,6 +1313,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index d51077b04254..ea3f3be9d861 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -42,6 +42,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1043,6 +1044,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1051,6 +1053,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ Returns: Example: @@ -1091,11 +1098,18 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() + if labels is None and not is_torchdynamo_compiling(): + logger.warning_once( + "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)" + ) + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO: remove the float() operation in v4.46 + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1129,6 +1143,7 @@ def prepare_inputs_for_generation( cache_position=None, position_ids=None, use_cache=True, + num_logits_to_keep=0, **kwargs, ): # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens @@ -1186,6 +1201,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "num_logits_to_keep": num_logits_to_keep, } ) return model_inputs diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index ae52f6c67404..b0cf08d0530c 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1828,6 +1828,62 @@ def test_generate_compile_fullgraph(self): output_compiled = compiled_generate(model_inputs, generation_config=generation_config) self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist()) + def test_generate_methods_with_num_logits_to_keep(self): + for model_class in self.all_generative_model_classes: + if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): + self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") + + config, input_ids, attention_mask = self._get_input_ids_and_config() + config.use_cache = True + config.is_decoder = True + + model = model_class(config).to(torch_device).eval() + # All generation methods (except assisted decoding) rely on always extracting the last token logits of the + # full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works, + # other methods will work as well) + generation_kwargs = { + "max_new_tokens": 10, + "do_sample": False, + } + + # Setting num_logits_to_keep at 0 keeps all logits (old behavior) + with_all_logits = model.generate( + input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0 + ) + # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior) + without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) + + def test_assisted_decoding_with_num_logits_to_keep(self): + for model_class in self.all_generative_model_classes: + if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): + self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support assisted generation") + + config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1) + config.use_cache = True + 
config.is_decoder = True + + model = model_class(config).to(torch_device).eval() + assistant_model = model + # All generation methods (except assisted decoding) rely on always extracting the last token logits of the + # full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works, + # other methods will work as well) + generation_kwargs = { + "max_new_tokens": 10, + "do_sample": False, + "assistant_model": assistant_model, + } + + # Setting num_logits_to_keep at 0 keeps all logits (old behavior) + with_all_logits = model.generate( + input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0 + ) + # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior) + without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): batch_size, seq_length = input_ids.shape num_sequences_in_output = batch_size * num_return_sequences diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 7cbc2f3e2813..7617c15efabf 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4824,6 +4824,27 @@ def test_compile_cuda_graph_time(self): self.assertTrue(record_time < 0.15 * graph_warmup_time) self.assertTrue(opt_time < record_time) + def test_forward_with_num_logits_to_keep(self): + for model_class in self.all_generative_model_classes: + if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): + self.skipTest(reason="This model does not support `num_logits_to_keep` argument.") + + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + batch_size, sequence_length = inputs["input_ids"].shape + vocab_size = config.vocab_size + model = model_class(config).to(device=torch_device).eval() + + # num_logits_to_keep=0 is a special case meaning "keep all logits" + all_logits = model(**inputs, num_logits_to_keep=0).logits + last_token_logits = model(**inputs, num_logits_to_keep=1).logits + + # Assert all shapes are correct + self.assertEqual(tuple(all_logits.shape), (batch_size, sequence_length, vocab_size)) + self.assertEqual(tuple(last_token_logits.shape), (batch_size, 1, vocab_size)) + + # Assert the last tokens are actually the same + self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits)) + global_rng = random.Random()
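For scale, a back-of-envelope estimate of the prefill logits tensor that the new `num_logits_to_keep=1` default avoids materializing. The dimensions below are hypothetical but representative of a Llama-3-style model, with fp32 logits as produced before this change:

```python
# Hypothetical large-vocab dimensions; only the ratio matters.
batch, seq_len, vocab = 1, 8192, 128256

bytes_full = batch * seq_len * vocab * 4  # logits for every prompt position
bytes_last = batch * 1 * vocab * 4        # num_logits_to_keep=1 (new default)

print(f"full: {bytes_full / 2**30:.1f} GiB")  # ~3.9 GiB
print(f"last: {bytes_last / 2**20:.2f} MiB")  # ~0.49 MiB

# Callers can opt back into the old behavior per call:
# outputs = model.generate(**inputs, num_logits_to_keep=0)
```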