
Reducing memory usage: removing useless logits computation in generate() #31292

Merged
35 commits merged on Aug 23, 2024 (diff below shows changes from all commits)

Commits (35)
1748ff1
Add .float() in all generation methods logit outputs
Cyrilvallez Jun 5, 2024
3f4f4e8
Switch float-casting of logits to training only for main models
Cyrilvallez Jun 5, 2024
727c7e4
Add `num_logits_to_keep` in Llama and add it by default in generate
Cyrilvallez Jun 6, 2024
222017d
Apply style
Cyrilvallez Jun 6, 2024
dc709c6
Add num_logits_to_keep as arg in prepare_input_for_generation
Cyrilvallez Jun 6, 2024
d2f1566
Add support for Mistral
Cyrilvallez Jun 6, 2024
f2ef90c
Revert models except llama and mistral
Cyrilvallez Jun 6, 2024
ce7b980
Fix default None value in _supports_num_logits_to_keep()
Cyrilvallez Jun 6, 2024
d4201f4
Fix dimension of dummy input
Cyrilvallez Jun 6, 2024
b15b5de
Add exception for prophetnet in _supports_num_logits_to_keep()
Cyrilvallez Jun 6, 2024
95e0807
Update _supports_num_logits_to_keep() to use inspect.signature()
Cyrilvallez Jun 6, 2024
12db045
Add deprecation cycle + remove modification with pretraining_tp
Cyrilvallez Jun 20, 2024
b224e24
Apply style
Cyrilvallez Jun 20, 2024
f0e1034
Add most used models
Cyrilvallez Jun 21, 2024
9ac57db
Apply style
Cyrilvallez Jun 21, 2024
f7421b6
Make `num_logits_to_keep` an int in all cases to remove if-else clause
Cyrilvallez Jul 12, 2024
c8f9177
Add compile check for the warning
Cyrilvallez Jul 17, 2024
5e1589e
Fix torch versions
Cyrilvallez Jul 17, 2024
7998b65
style
Cyrilvallez Jul 17, 2024
8fa8018
Add gemma2
Cyrilvallez Aug 20, 2024
b49fe76
Update warning version
Cyrilvallez Aug 20, 2024
cf9378a
Add comment about .float operations in generation utils
Cyrilvallez Aug 20, 2024
66e3e9d
Add tests in GenerationTesterMixin and ModelTesterMixin
Cyrilvallez Aug 21, 2024
e4c5a71
Fix batch size for assisted decoding in tests
Cyrilvallez Aug 21, 2024
b68ee16
fix small issues in test
Cyrilvallez Aug 21, 2024
e837425
refacor test
Cyrilvallez Aug 21, 2024
26863ca
fix slicing removing dim issue
Cyrilvallez Aug 21, 2024
3c3eeaa
Add nemotron support (should fix check-copy issue in CIs)
Cyrilvallez Aug 21, 2024
c400865
Trigger new CIs
Cyrilvallez Aug 21, 2024
802eca8
Trigger new CIs
Cyrilvallez Aug 22, 2024
4d6fae6
Bump version
Cyrilvallez Aug 22, 2024
f12f172
Bump version in TODO
Cyrilvallez Aug 22, 2024
7b1a26c
Trigger CIs
Cyrilvallez Aug 22, 2024
b11b048
remove blank space
Cyrilvallez Aug 23, 2024
f03adfb
Trigger CIs
Cyrilvallez Aug 23, 2024
4 changes: 4 additions & 0 deletions src/transformers/generation/candidate_generator.py
@@ -119,6 +119,10 @@ def __init__(
value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
)

# Remove potential default "num_logits_to_keep" key
if "num_logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_num_logits_to_keep():
del assistant_kwargs["num_logits_to_keep"]

if "assistant_encoder_outputs" in model_kwargs:
assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
elif assistant_model.config.is_encoder_decoder:
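The `_supports_num_logits_to_keep()` helper used above is added in `generation/utils.py` (next file). As a standalone sketch of the same idea, using a hypothetical free-function name:

    import inspect

    def supports_num_logits_to_keep(model) -> bool:
        # The kwarg counts as supported iff it appears in forward()'s signature,
        # which avoids introducing a new model attribute.
        return "num_logits_to_keep" in inspect.signature(model.forward).parameters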
40 changes: 31 additions & 9 deletions src/transformers/generation/utils.py
Collaborator review comment: "Maybe let's comment that we need .float() for full precision soft"

@@ -1601,6 +1601,13 @@ def _prepare_cache_for_generation(
else EncoderDecoderCache(DynamicCache(), DynamicCache())
)

def _supports_num_logits_to_keep(self) -> bool:
"""
Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
to save memory. Checking it in this way allows to avoid using a new model attribute.
"""
return "num_logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())

def _prepare_special_tokens(
self,
generation_config: GenerationConfig,
@@ -1876,6 +1883,13 @@ def generate(
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)

# If the model supports `num_logits_to_keep` in forward(), set it to 1 to avoid computing the whole
# logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
# dynamically overrides this value as it can need more than the last token logits
if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs:
model_kwargs["num_logits_to_keep"] = 1

self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

# 7. Prepare the cache.
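For models updated in this PR, the first forward pass inside `generate()` now only materializes the last-token logits. A hedged usage sketch of the equivalent manual call (the checkpoint name is an example, not taken from the PR):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = "meta-llama/Meta-Llama-3-8B"  # example checkpoint only
    tok = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16, device_map="auto")

    inputs = tok("The capital of France is", return_tensors="pt").to(model.device)
    out = model(**inputs, num_logits_to_keep=1)
    print(out.logits.shape)  # (1, 1, vocab_size) rather than (1, prompt_length, vocab_size)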
@@ -2412,8 +2426,9 @@ def _dola_decoding(
output_hidden_states=True,
)

final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone()
final_logits = outputs.logits[:, -1, :]
# .float() is needed to retain precision for later logits manipulations
final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone().float()
final_logits = outputs.logits[:, -1, :].float()
candidate_premature_logits = {}
for candidate_premature_layer in candidate_premature_layers:
candidate_premature_logits[candidate_premature_layer] = lm_head(
@@ -2590,7 +2605,8 @@ def _contrastive_search(
# next logit for contrastive search to select top-k candidate tokens
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
# (the clone itself is always small)
logit_for_next_step = outputs.logits[:, -1, :].clone()
# .float() is needed to retain precision for later logits manipulations
logit_for_next_step = outputs.logits[:, -1, :].clone().float()

model_kwargs = self._update_model_kwargs_for_generation(
outputs,
@@ -2720,7 +2736,8 @@ def _contrastive_search(
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states

logits = outputs.logits[:, -1, :]
# .float() is needed to retain precision for later logits manipulations
logits = outputs.logits[:, -1, :].float()
context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)

# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
@@ -2966,7 +2983,8 @@ def _sample(

# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].clone()
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[:, -1, :].clone().float()

# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
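All of the `.float()` calls added in this file serve the same purpose. A small sketch (my illustration, not PR code) of the precision gap between manipulating half-precision logits directly and upcasting them first:

    import torch

    torch.manual_seed(0)
    logits = (torch.randn(1, 32_000) * 20).to(torch.float16)

    scores_fp16 = torch.log_softmax(logits, dim=-1)
    scores_fp32 = torch.log_softmax(logits.float(), dim=-1)

    # The fp16 path accumulates noticeably more error than fp32 rounding alone.
    print((scores_fp16.float() - scores_fp32).abs().max())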
@@ -3210,7 +3228,8 @@ def _beam_search(

# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].clone()
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[:, -1, :].clone().float()
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
@@ -3484,7 +3503,8 @@ def _group_beam_search(

# select outputs of beams of current group only
# No need to clone() the logits here as they will not retain outputs.logits at the end of the loop
next_token_logits = outputs.logits[batch_group_indices, -1, :]
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[batch_group_indices, -1, :].float()

next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
@@ -3739,7 +3759,8 @@ def _constrained_beam_search(

# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].clone()
# .float() is needed to retain precision for later logits manipulations
next_token_logits = outputs.logits[:, -1, :].clone().float()
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
@@ -3998,7 +4019,8 @@ def _assisted_decoding(
outputs = self(**model_inputs)

# 2.3. Process the new logits
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
# .float() is needed to retain precision for later logits manipulations
new_logits = outputs.logits[:, -candidate_length - 1 :].float() # excludes the input prompt if present
next_token_logits = new_logits.clone()
if len(logits_processor) > 0:
for i in range(candidate_length + 1):
20 changes: 18 additions & 2 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -44,6 +44,7 @@
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@@ -1024,6 +1025,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1032,6 +1034,11 @@
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:
@@ -1071,12 +1078,19 @@ def forward(
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
logits = logits * self.logit_scale
logits = logits.float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@@ -1109,6 +1123,7 @@ def prepare_inputs_for_generation(
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1166,6 +1181,7 @@ def prepare_inputs_for_generation(
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
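A note on the `-num_logits_to_keep:` slice used in the model changes: `-0 == 0` in Python, so the default of `0` slices from index 0 and keeps every position, which is why it doubles as the "compute all logits" case. A minimal check:

    import torch

    hidden_states = torch.randn(2, 10, 64)  # (batch, seq_len, hidden)

    assert hidden_states[:, -0:, :].shape == (2, 10, 64)  # num_logits_to_keep=0 -> all positions
    assert hidden_states[:, -1:, :].shape == (2, 1, 64)   # num_logits_to_keep=1 -> last position only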
11 changes: 10 additions & 1 deletion src/transformers/models/dbrx/modeling_dbrx.py
@@ -1275,6 +1275,7 @@ def forward(
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""Forward function for causal language modeling.

@@ -1284,6 +1285,11 @@
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:
@@ -1328,7 +1334,8 @@ def forward(
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
# No upscaling to float was ever done for Dbrx
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

loss = None
if labels is not None:
@@ -1380,6 +1387,7 @@ def prepare_inputs_for_generation(
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1437,6 +1445,7 @@ def prepare_inputs_for_generation(
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
21 changes: 19 additions & 2 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -43,6 +43,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@@ -1038,6 +1039,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1046,6 +1048,11 @@
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:
@@ -1085,10 +1092,18 @@ def forward(
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO: remove the float() operation in v4.46
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@@ -1121,6 +1136,7 @@ def prepare_inputs_for_generation(
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1177,6 +1193,7 @@ def prepare_inputs_for_generation(
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
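The new deprecation warning in each model is guarded by `is_torchdynamo_compiling()`. A hedged sketch of that pattern (the rationale in the comment is my assumption; the PR itself does not spell it out):

    from transformers.utils import is_torchdynamo_compiling, logging

    logger = logging.get_logger(__name__)

    def warn_future_dtype_change(labels):
        # Presumably skipped while torch.compile is tracing so the Python-side logging call
        # does not interfere with compilation; warning_once keeps eager mode to a single message.
        if labels is None and not is_torchdynamo_compiling():
            logger.warning_once(
                "Starting from v4.46, the `logits` model output will have the same type as the model"
            )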
19 changes: 18 additions & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py
@@ -41,6 +41,7 @@
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@@ -987,6 +988,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -995,6 +997,11 @@
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

Returns:

Example:
@@ -1038,15 +1045,23 @@ def forward(
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
if labels is None and not is_torchdynamo_compiling():
logger.warning_once(
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
if self.config.final_logit_softcapping is not None:
logits = logits / self.config.final_logit_softcapping
logits = torch.tanh(logits)
logits = logits * self.config.final_logit_softcapping

# TODO: remove the float() operation in v4.46
logits = logits.float()
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
@@ -1079,6 +1094,7 @@ def prepare_inputs_for_generation(
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
@@ -1139,6 +1155,7 @@ def prepare_inputs_for_generation(
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs