Fix remaining issues in beam score calculation #27808
Changes from all commits: 6d4a9f9, 240d1fe, 35887f9, 96bacb3, b9b97e9, 890ba95
@@ -224,8 +224,8 @@ def process(
         group_index: Optional[int] = 0,
         decoder_prompt_len: Optional[int] = 0,
     ) -> Dict[str, torch.Tensor]:
-        # add up to the length which the next_scores is calculated on
-        cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
+        # add up to the length which the next_scores is calculated on (including decoder prompt)
+        cur_len = input_ids.shape[-1] + 1
         batch_size = len(self._beam_hyps) // self.num_beam_groups

         if not (batch_size == (input_ids.shape[0] // self.group_size)):
@@ -279,15 +279,11 @@ def process(
                     else:
                         beam_index = None

-                    # skip the corner case where the very first generated token is eos_token
-                    if decoder_prompt_len == input_ids.shape[-1]:
-                        continue
-
                     self._beam_hyps[batch_group_idx].add(
                         input_ids[batch_beam_idx].clone(),
                         next_score.item(),
                         beam_indices=beam_index,
-                        decoder_prompt_len=decoder_prompt_len,
+                        generated_len=cur_len - decoder_prompt_len,
                     )
                 else:
                     # add next predicted token since it is not eos_token
@@ -308,7 +304,7 @@ def process(

             # Check if we are done so that we can save a pad step if all(done)
             self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
-                next_scores[batch_idx].max().item(), cur_len
+                next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
             )

         return UserDict(
@@ -348,7 +344,8 @@ def finalize(
                 final_score = final_beam_scores[batch_beam_idx].item()
                 final_tokens = input_ids[batch_beam_idx]
                 beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
-                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len)
+                generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)

         # select the best hypotheses
         sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
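In `finalize` the hypothesis is already complete, so the generated length comes straight from the tensor shape rather than from `cur_len`. A short sketch with made-up sizes:

import torch

decoder_prompt_len = 4
final_tokens = torch.arange(10)           # pretend: 4 prompt tokens + 6 generated tokens
generated_len = final_tokens.shape[-1] - decoder_prompt_len
assert generated_len == 6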
@@ -560,8 +557,8 @@ def process(
             indicating to which beam the next tokens shall be added.
         """

-        # add up to the length which the next_scores is calculated on
-        cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
+        # add up to the length which the next_scores is calculated on (including decoder prompt)
+        cur_len = input_ids.shape[-1] + 1
         batch_size = len(self._beam_hyps)
         if not (batch_size == (input_ids.shape[0] // self.group_size)):
             if self.num_beam_groups > 1:
@@ -617,16 +614,11 @@ def process(
                     else:
                         beam_index = None

-                    # skip the corner case where the only constraint token is
-                    # eos_token and the very first generated token is eos_token
-                    if decoder_prompt_len == input_ids.shape[-1]:
-                        continue
-
                     beam_hyp.add(
                         input_ids[batch_beam_idx].clone(),
                         next_score.item(),
                         beam_indices=beam_index,
-                        decoder_prompt_len=decoder_prompt_len,
+                        generated_len=cur_len - decoder_prompt_len,
                     )
                 else:
                     # add next predicted token since it is not eos_token
@@ -660,7 +652,7 @@ def process(

             # Check if we are done so that we can save a pad step if all(done)
             self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
-                next_scores[batch_idx].max().item(), cur_len
+                next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
             )

         return UserDict(
@@ -846,9 +838,8 @@ def finalize(
                     completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
                     if completes_constraint:
                         beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
-                        beam_hyp.add(
-                            final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len
-                        )
+                        generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                        beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
                         ids_collect.append(beam_id)

             # due to overly complex constraints or other factors, sometimes we can't gaurantee a successful
@@ -859,7 +850,8 @@ def finalize(
                 batch_beam_idx = batch_idx * self.num_beams + beam_id
                 final_score = final_beam_scores[batch_beam_idx].item()
                 final_tokens = input_ids[batch_beam_idx]
-                beam_hyp.add(final_tokens, final_score, decoder_prompt_len=decoder_prompt_len)
+                generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                beam_hyp.add(final_tokens, final_score, generated_len=generated_len)
                 if len(ids_collect) >= self.num_beam_hyps_to_keep:
                     break
@@ -956,12 +948,17 @@ def add(
         hyp: torch.LongTensor,
         sum_logprobs: float,
         beam_indices: Optional[torch.LongTensor] = None,
-        decoder_prompt_len: Optional[int] = 0,
+        generated_len: Optional[int] = None,
     ):
         """
         Add a new hypothesis to the list.
         """
-        score = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** self.length_penalty)
+        if generated_len is not None:
+            score = sum_logprobs / (generated_len**self.length_penalty)
+        # This 'else' case exists for retrocompatibility
+        else:
+            score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)

         if len(self) < self.num_beams or score > self.worst_score:
             self.beams.append((score, hyp, beam_indices))
             if len(self) > self.num_beams:

Review thread attached to the retrocompatibility comment: a suggested change was left here after merge. "oops, PR already merged, maybe let's stay with it for now?" / "of course no worries"
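A self-contained sketch of the new scoring rule in `BeamHypotheses.add` (the helper function and numbers below are illustrative, not part of the library): when `generated_len` is given, the summed log-probabilities are normalized by the generated length only; the fallback keeps the old full-length normalization for retrocompatibility.

from typing import Optional

import torch


def hypothesis_score(hyp: torch.LongTensor, sum_logprobs: float, length_penalty: float,
                     generated_len: Optional[int] = None) -> float:
    # Prefer the explicit generated length (prompt excluded), as in the new code...
    if generated_len is not None:
        return sum_logprobs / (generated_len**length_penalty)
    # ...and fall back to the full hypothesis length for retrocompatibility.
    return sum_logprobs / (hyp.shape[-1] ** length_penalty)


hyp = torch.arange(10)  # pretend: 4 prompt tokens + 6 generated tokens
print(hypothesis_score(hyp, sum_logprobs=-6.0, length_penalty=1.0, generated_len=6))  # -1.0, prompt excluded
print(hypothesis_score(hyp, sum_logprobs=-6.0, length_penalty=1.0))                   # -0.6, old behaviour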
@@ -971,7 +968,7 @@ def add(
         else:
             self.worst_score = min(score, self.worst_score)

-    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
+    def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
         """
         If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
         one in the heap, then we are done with this sentence.
@@ -987,7 +984,7 @@ def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
         # when `length_penalty` is positive. See the discussion below for more details.
         # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
         elif self.early_stopping is False:
-            highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
+            highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
             ret = self.worst_score >= highest_attainable_score
             return ret
         # `"never"`: compute the best possible score, depending on the signal of `length_penalty`
@@ -996,9 +993,13 @@ def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
             # abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
             # its max this way
             if self.length_penalty > 0.0:
-                highest_attainable_score = best_sum_logprobs / self.max_length**self.length_penalty
+                if self.max_length <= decoder_prompt_len:
+                    raise ValueError("max_length is not larger than decoder prompt length")
+                highest_attainable_score = (
+                    best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
+                )
             # the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
             else:
-                highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
+                highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
             ret = self.worst_score >= highest_attainable_score
             return ret
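For context, the `early_stopping="never"` branch can be read as the following standalone function (a simplification of the diff, with illustrative values; `best_sum_logprobs` is the best running score among the still-active beams):

from typing import Optional


def highest_attainable_score(best_sum_logprobs: float, cur_len: int, max_length: int,
                             length_penalty: float, decoder_prompt_len: Optional[int] = 0) -> float:
    # Positive penalty: longer continuations can still improve the score, so bound
    # with the longest possible generated length (max_length - decoder_prompt_len).
    if length_penalty > 0.0:
        if max_length <= decoder_prompt_len:
            raise ValueError("max_length is not larger than decoder prompt length")
        return best_sum_logprobs / (max_length - decoder_prompt_len) ** length_penalty
    # Non-positive penalty: the current generated length already maximizes the score.
    return best_sum_logprobs / (cur_len - decoder_prompt_len) ** length_penalty


# 4-token prompt, 20-token budget, 11 tokens seen so far: bound is -8.0 / 16 == -0.5
print(highest_attainable_score(-8.0, cur_len=11, max_length=20, length_penalty=1.0, decoder_prompt_len=4))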
@@ -633,11 +633,7 @@ def test_eos_token_id_int_and_list_beam_search(self):
             "do_sample": False,
             "num_beams": 3,
         }
-        if is_pt:
-            expectation = 20
-        else:
-            # TODO (joao): fix me
-            expectation = 13
+        expectation = 13

         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
         text = """Hello, my dog is cute and"""

Review comment on removed lines 636 to 640: nice! 🔥
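For anyone wanting to poke at the behaviour this test covers, a hedged end-to-end sketch along the same lines (the real test passes `eos_token_id` values and other kwargs that are not shown above, so the ids and lengths here are placeholders, not the test itself):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

text = """Hello, my dog is cute and"""
inputs = tokenizer(text, return_tensors="pt")

# Beam search on a decoder-only model. eos_token_id=873 and max_new_tokens=20 are
# placeholders; with this PR the length penalty is applied to the generated tokens
# only, so the PT and TF/Flax implementations should agree on when to stop.
out = model.generate(**inputs, do_sample=False, num_beams=3, eos_token_id=873, max_new_tokens=20)
print(out.shape)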
Review thread on the `score` computation in `BeamHypotheses.add`:
"I'd add a note that the `else` case here exists for retrocompatibility reasons :)"
"Thanks for the reminder. Added!"