diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 31bb0eca5c09..3b1bef6f0400 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -3038,7 +3038,9 @@ def beam_search( ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( + next_token_scores_processed + ) # Store scores, attentions and hidden_states when required if return_dict_in_generate: @@ -3363,7 +3365,9 @@ def beam_sample( ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( + next_token_scores_processed + ) # Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers # (like top_p) this is indiferent, but on others (like temperature) it is not. For reference, see # https://github.com/huggingface/transformers/pull/5420#discussion_r449779867 @@ -4080,7 +4084,9 @@ def constrained_beam_search( next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( + next_token_scores_processed + ) scores_for_all_vocab = next_token_scores.clone()