Fix remaining issues in beam score calculation #27808

Merged (6 commits, Dec 8, 2023)
57 changes: 29 additions & 28 deletions src/transformers/generation/beam_search.py
@@ -224,8 +224,8 @@ def process(
group_index: Optional[int] = 0,
decoder_prompt_len: Optional[int] = 0,
) -> Dict[str, torch.Tensor]:
- # add up to the length which the next_scores is calculated on
- cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
+ # add up to the length which the next_scores is calculated on (including decoder prompt)
+ cur_len = input_ids.shape[-1] + 1
batch_size = len(self._beam_hyps) // self.num_beam_groups

if not (batch_size == (input_ids.shape[0] // self.group_size)):
@@ -279,15 +279,11 @@ def process(
else:
beam_index = None

- # skip the corner case where the very first generated token is eos_token
- if decoder_prompt_len == input_ids.shape[-1]:
-     continue
-
self._beam_hyps[batch_group_idx].add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
- decoder_prompt_len=decoder_prompt_len,
+ generated_len=cur_len - decoder_prompt_len,
)
else:
# add next predicted token since it is not eos_token
@@ -308,7 +304,7 @@ def process(

# Check if we are done so that we can save a pad step if all(done)
self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
- next_scores[batch_idx].max().item(), cur_len
+ next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
)

return UserDict(
@@ -348,7 +344,8 @@ def finalize(
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
- beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len)
+ generated_len = final_tokens.shape[-1] - decoder_prompt_len
+ beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)

# select the best hypotheses
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
@@ -560,8 +557,8 @@ def process(
indicating to which beam the next tokens shall be added.
"""

- # add up to the length which the next_scores is calculated on
- cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
+ # add up to the length which the next_scores is calculated on (including decoder prompt)
+ cur_len = input_ids.shape[-1] + 1
batch_size = len(self._beam_hyps)
if not (batch_size == (input_ids.shape[0] // self.group_size)):
if self.num_beam_groups > 1:
@@ -617,16 +614,11 @@ def process(
else:
beam_index = None

- # skip the corner case where the only constraint token is
- # eos_token and the very first generated token is eos_token
- if decoder_prompt_len == input_ids.shape[-1]:
-     continue
-
beam_hyp.add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
- decoder_prompt_len=decoder_prompt_len,
+ generated_len=cur_len - decoder_prompt_len,
)
else:
# add next predicted token since it is not eos_token
@@ -660,7 +652,7 @@ def process(

# Check if we are done so that we can save a pad step if all(done)
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
- next_scores[batch_idx].max().item(), cur_len
+ next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
)

return UserDict(
@@ -846,9 +838,8 @@ def finalize(
completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
if completes_constraint:
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
- beam_hyp.add(
-     final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len
- )
+ generated_len = final_tokens.shape[-1] - decoder_prompt_len
+ beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
ids_collect.append(beam_id)

# due to overly complex constraints or other factors, sometimes we can't gaurantee a successful
@@ -859,7 +850,8 @@ def finalize(
batch_beam_idx = batch_idx * self.num_beams + beam_id
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
- beam_hyp.add(final_tokens, final_score, decoder_prompt_len=decoder_prompt_len)
+ generated_len = final_tokens.shape[-1] - decoder_prompt_len
+ beam_hyp.add(final_tokens, final_score, generated_len=generated_len)
if len(ids_collect) >= self.num_beam_hyps_to_keep:
break

@@ -956,12 +948,17 @@ def add(
hyp: torch.LongTensor,
sum_logprobs: float,
beam_indices: Optional[torch.LongTensor] = None,
- decoder_prompt_len: Optional[int] = 0,
+ generated_len: Optional[int] = None,
):
"""
Add a new hypothesis to the list.
"""
- score = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** self.length_penalty)
+ if generated_len is not None:
Member: I'd add a note that the else case here exists for retrocompatibility reasons :)

Contributor Author: Thanks for the reminder. Added!

+     score = sum_logprobs / (generated_len**self.length_penalty)
+ # This 'else' case exists for retrocompatibility
Collaborator (suggested change):
- # This 'else' case exists for retrocompatibility
+ # This 'else' case exists for backward compatibility

Contributor Author: oops, PR already merged, maybe let's stay with it for now?

Collaborator: of course no worries
+ else:
+     score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)

if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp, beam_indices))
if len(self) > self.num_beams:
@@ -971,7 +968,7 @@ def add(
else:
self.worst_score = min(score, self.worst_score)

- def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
+ def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
"""
If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
one in the heap, then we are done with this sentence.
@@ -987,7 +984,7 @@ def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
# when `length_penalty` is positive. See the discussion below for more details.
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
elif self.early_stopping is False:
- highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
+ highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
ret = self.worst_score >= highest_attainable_score
return ret
# `"never"`: compute the best possible score, depending on the signal of `length_penalty`
@@ -996,9 +993,13 @@ def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
# abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
# its max this way
if self.length_penalty > 0.0:
- highest_attainable_score = best_sum_logprobs / self.max_length**self.length_penalty
+ if self.max_length <= decoder_prompt_len:
+     raise ValueError("max_length is not larger than decoder prompt length")
+ highest_attainable_score = (
+     best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
+ )
# the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
else:
- highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
+ highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
ret = self.worst_score >= highest_attainable_score
return ret
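
To make the intent of the scoring change concrete, here is a minimal standalone sketch (illustrative only, with made-up numbers; the real logic lives in `BeamHypotheses.add()` and `BeamHypotheses.is_done()` above): the length penalty is applied to the number of generated tokens, with the decoder prompt excluded in both places.

```python
# Standalone sketch of the corrected length normalization; numbers are illustrative.

def hypothesis_score(sum_logprobs: float, generated_len: int, length_penalty: float = 1.0) -> float:
    # Mirrors BeamHypotheses.add(): only generated tokens (prompt excluded)
    # count toward the length penalty.
    return sum_logprobs / (generated_len ** length_penalty)


def highest_attainable_score_never(
    best_sum_logprobs: float,
    cur_len: int,
    decoder_prompt_len: int,
    max_length: int,
    length_penalty: float,
) -> float:
    # Mirrors the early_stopping="never" branch of is_done(): with a positive
    # length penalty the best case is reached at max_length, otherwise at
    # cur_len; both lengths are counted from the end of the decoder prompt.
    if length_penalty > 0.0:
        if max_length <= decoder_prompt_len:
            raise ValueError("max_length is not larger than decoder prompt length")
        return best_sum_logprobs / (max_length - decoder_prompt_len) ** length_penalty
    return best_sum_logprobs / (cur_len - decoder_prompt_len) ** length_penalty


# A hypothesis with summed log-probs of -6.0 over 10 generated tokens scores -0.6
# with length_penalty=1.0, regardless of how long the prompt was.
print(hypothesis_score(-6.0, generated_len=10))  # -0.6
# Best score still attainable under early_stopping="never" with a 4-token prompt.
print(highest_attainable_score_never(-6.0, cur_len=14, decoder_prompt_len=4,
                                     max_length=20, length_penalty=1.0))  # -0.375
```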
6 changes: 1 addition & 5 deletions tests/generation/test_framework_agnostic.py
@@ -633,11 +633,7 @@ def test_eos_token_id_int_and_list_beam_search(self):
"do_sample": False,
"num_beams": 3,
}
- if is_pt:
-     expectation = 20
- else:
-     # TODO (joao): fix me
-     expectation = 13
Collaborator (on lines -636 to -640): nice! 🔥
+ expectation = 13

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
@@ -800,7 +800,7 @@ def generate_step(pixel_values):

preds, scores = generate_step(pixel_values)

- EXPECTED_SCORES = np.array([-0.64145195])
+ EXPECTED_SCORES = np.array([-0.5956343])
max_diff = np.amax(np.abs(scores - EXPECTED_SCORES))
self.assertLessEqual(max_diff, 1e-4)
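
For context, the constant being updated here is a beam search `sequences_scores` value, i.e. the length-penalized sum of log-probabilities computed in `BeamHypotheses.add()`; with the penalty now applied to the generated length only, the reference value shifts. A hedged sketch of how such an expectation can be regenerated with the PyTorch API (the checkpoint and prompt below are borrowed from the test above, not from this file):

```python
# Hypothetical snippet for regenerating a beam-score expectation; checkpoint and
# prompt are illustrative, not the ones used in the test being modified here.
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

inputs = tok("Hello, my dog is cute and", return_tensors="pt")
out = model.generate(
    **inputs,
    num_beams=3,
    max_new_tokens=10,
    return_dict_in_generate=True,
    output_scores=True,  # needed so sequences_scores is populated
)
# Length-penalized beam scores, as computed in BeamHypotheses.add()
print(np.array(out.sequences_scores))
```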
