From b31905d1f61034bb147466ba7fd281ab668e8333 Mon Sep 17 00:00:00 2001
From: Xin Qiu
Date: Fri, 8 Dec 2023 21:14:16 +0800
Subject: [PATCH] Fix remaining issues in beam score calculation (#27808)

* Fix issues in add and is_done for BeamHypotheses

* make newly added arguments optional for better compatibility

* Directly use cur_len as generated_len, add note for retrocompatibility

* update test expectation

* make cur_len represent the length of the entire sequence including the decoder prompt

* remove redundant if/else in testing
---
 src/transformers/generation/beam_search.py   | 57 ++++++++++---------
 tests/generation/test_framework_agnostic.py  |  6 +-
 .../test_modeling_vision_encoder_decoder.py  |  2 +-
 3 files changed, 31 insertions(+), 34 deletions(-)

diff --git a/src/transformers/generation/beam_search.py b/src/transformers/generation/beam_search.py
index a29d34306f83..5e73862e163d 100644
--- a/src/transformers/generation/beam_search.py
+++ b/src/transformers/generation/beam_search.py
@@ -224,8 +224,8 @@ def process(
         group_index: Optional[int] = 0,
         decoder_prompt_len: Optional[int] = 0,
     ) -> Dict[str, torch.Tensor]:
-        # add up to the length which the next_scores is calculated on
-        cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
+        # add up to the length which the next_scores is calculated on (including decoder prompt)
+        cur_len = input_ids.shape[-1] + 1
         batch_size = len(self._beam_hyps) // self.num_beam_groups
         if not (batch_size == (input_ids.shape[0] // self.group_size)):
             if self.num_beam_groups > 1:
@@ -279,15 +279,11 @@ def process(
                     else:
                         beam_index = None

-                    # skip the corner case where the very first generated token is eos_token
-                    if decoder_prompt_len == input_ids.shape[-1]:
-                        continue
-
                     self._beam_hyps[batch_group_idx].add(
                         input_ids[batch_beam_idx].clone(),
                         next_score.item(),
                         beam_indices=beam_index,
-                        decoder_prompt_len=decoder_prompt_len,
+                        generated_len=cur_len - decoder_prompt_len,
                     )
                 else:
                     # add next predicted token since it is not eos_token
@@ -308,7 +304,7 @@ def process(

             # Check if we are done so that we can save a pad step if all(done)
             self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
-                next_scores[batch_idx].max().item(), cur_len
+                next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
             )

         return UserDict(
@@ -348,7 +344,8 @@ def finalize(
                 final_score = final_beam_scores[batch_beam_idx].item()
                 final_tokens = input_ids[batch_beam_idx]
                 beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
-                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len)
+                generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)

         # select the best hypotheses
         sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
@@ -560,8 +557,8 @@ def process(
                 indicating to which beam the next tokens shall be added.

         """
-        # add up to the length which the next_scores is calculated on
-        cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
+        # add up to the length which the next_scores is calculated on (including decoder prompt)
+        cur_len = input_ids.shape[-1] + 1
         batch_size = len(self._beam_hyps)
         if not (batch_size == (input_ids.shape[0] // self.group_size)):
             if self.num_beam_groups > 1:
@@ -617,16 +614,11 @@ def process(
                         else:
                             beam_index = None

-                        # skip the corner case where the only constraint token is
-                        # eos_token and the very first generated token is eos_token
-                        if decoder_prompt_len == input_ids.shape[-1]:
-                            continue
-
                         beam_hyp.add(
                             input_ids[batch_beam_idx].clone(),
                             next_score.item(),
                             beam_indices=beam_index,
-                            decoder_prompt_len=decoder_prompt_len,
+                            generated_len=cur_len - decoder_prompt_len,
                         )
                 else:
                     # add next predicted token since it is not eos_token
@@ -660,7 +652,7 @@ def process(

             # Check if we are done so that we can save a pad step if all(done)
             self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
-                next_scores[batch_idx].max().item(), cur_len
+                next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
             )

         return UserDict(
@@ -846,9 +838,8 @@ def finalize(
                 completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
                 if completes_constraint:
                     beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
-                    beam_hyp.add(
-                        final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len
-                    )
+                    generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                    beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
                     ids_collect.append(beam_id)

             # due to overly complex constraints or other factors, sometimes we can't gaurantee a successful
@@ -859,7 +850,8 @@ def finalize(
                     batch_beam_idx = batch_idx * self.num_beams + beam_id
                     final_score = final_beam_scores[batch_beam_idx].item()
                     final_tokens = input_ids[batch_beam_idx]
-                    beam_hyp.add(final_tokens, final_score, decoder_prompt_len=decoder_prompt_len)
+                    generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                    beam_hyp.add(final_tokens, final_score, generated_len=generated_len)
                     if len(ids_collect) >= self.num_beam_hyps_to_keep:
                         break

@@ -956,12 +948,17 @@ def add(
         hyp: torch.LongTensor,
         sum_logprobs: float,
         beam_indices: Optional[torch.LongTensor] = None,
-        decoder_prompt_len: Optional[int] = 0,
+        generated_len: Optional[int] = None,
     ):
         """
         Add a new hypothesis to the list.
         """
-        score = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** self.length_penalty)
+        if generated_len is not None:
+            score = sum_logprobs / (generated_len**self.length_penalty)
+        # This 'else' case exists for retrocompatibility
+        else:
+            score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
+
         if len(self) < self.num_beams or score > self.worst_score:
             self.beams.append((score, hyp, beam_indices))
             if len(self) > self.num_beams:
@@ -971,7 +968,7 @@ def add(
         else:
             self.worst_score = min(score, self.worst_score)

-    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
+    def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
         """
         If there are enough hypotheses and that none of the hypotheses being generated can become better than the
         worst one in the heap, then we are done with this sentence.
@@ -987,7 +984,7 @@ def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
         # when `length_penalty` is positive. See the discussion below for more details.
         # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
         elif self.early_stopping is False:
-            highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
+            highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
             ret = self.worst_score >= highest_attainable_score
             return ret
         # `"never"`: compute the best possible score, depending on the signal of `length_penalty`
@@ -996,9 +993,13 @@ def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
             # abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
             # its max this way
             if self.length_penalty > 0.0:
-                highest_attainable_score = best_sum_logprobs / self.max_length**self.length_penalty
+                if self.max_length <= decoder_prompt_len:
+                    raise ValueError("max_length is not larger than decoder prompt length")
+                highest_attainable_score = (
+                    best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
+                )
             # the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
             else:
-                highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
+                highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
             ret = self.worst_score >= highest_attainable_score
             return ret
diff --git a/tests/generation/test_framework_agnostic.py b/tests/generation/test_framework_agnostic.py
index 8a269801640e..306cb15168e5 100644
--- a/tests/generation/test_framework_agnostic.py
+++ b/tests/generation/test_framework_agnostic.py
@@ -633,11 +633,7 @@ def test_eos_token_id_int_and_list_beam_search(self):
             "do_sample": False,
             "num_beams": 3,
         }
-        if is_pt:
-            expectation = 20
-        else:
-            # TODO (joao): fix me
-            expectation = 13
+        expectation = 13

         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
         text = """Hello, my dog is cute and"""
diff --git a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
index 40862c6234f4..71b1ff710679 100644
--- a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
+++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
@@ -800,7 +800,7 @@ def generate_step(pixel_values):

         preds, scores = generate_step(pixel_values)

-        EXPECTED_SCORES = np.array([-0.64145195])
+        EXPECTED_SCORES = np.array([-0.5956343])
         max_diff = np.amax(np.abs(scores - EXPECTED_SCORES))
         self.assertLessEqual(max_diff, 1e-4)
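What the scoring change amounts to, as a minimal standalone sketch (not part of the patch; the helper name and the numbers are illustrative, not taken from transformers): with `generated_len`, the length-penalty denominator counts only the tokens produced by beam search, so a long decoder prompt no longer inflates every finished hypothesis score.

def length_normalized_score(sum_logprobs: float, length: int, length_penalty: float) -> float:
    # Beam-search score: cumulative token log-probability normalized by length ** length_penalty.
    return sum_logprobs / (length ** length_penalty)

# Hypothetical beam: 8 prompt tokens, 4 generated tokens, cumulative log-prob of -6.0.
sum_logprobs, decoder_prompt_len, generated_len, length_penalty = -6.0, 8, 4, 1.0

old_score = length_normalized_score(sum_logprobs, decoder_prompt_len + generated_len, length_penalty)  # -0.5
new_score = length_normalized_score(sum_logprobs, generated_len, length_penalty)                       # -1.5
print(old_score, new_score)

With the prompt included, the denominator is 12 instead of 4, so the same hypothesis looks much better than it should; the patch makes `add()` and `is_done()` agree on the generated-only length.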
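Likewise, a rough sketch of the stopping test after the patch, with hypothetical values (the real check lives in `BeamHypotheses.is_done`, and this sketch omits its initial "fewer than num_beams finished hypotheses" early exit): the bound on the best score a running beam could still reach is now computed on generated tokens only, so it is directly comparable to the scores stored by `add()`.

def can_stop(worst_finished_score, best_running_sum_logprobs, cur_len, decoder_prompt_len,
             max_length, length_penalty, early_stopping):
    # Mirror of the patched bound: every length excludes the decoder prompt.
    if early_stopping is True:
        return True
    if early_stopping is False:
        # heuristic bound from the current generated length
        bound = best_running_sum_logprobs / (cur_len - decoder_prompt_len) ** length_penalty
    else:  # "never": use the longest generated length still reachable
        if length_penalty > 0.0:
            if max_length <= decoder_prompt_len:
                raise ValueError("max_length is not larger than decoder prompt length")
            bound = best_running_sum_logprobs / (max_length - decoder_prompt_len) ** length_penalty
        else:
            bound = best_running_sum_logprobs / (cur_len - decoder_prompt_len) ** length_penalty
    return worst_finished_score >= bound

# Hypothetical values: worst finished score -1.2, best running sum of log-probs -5.0,
# 12 tokens so far of which 8 are prompt -> bound = -5.0 / 4 = -1.25, so stopping is allowed.
print(can_stop(-1.2, -5.0, cur_len=12, decoder_prompt_len=8, max_length=20,
               length_penalty=1.0, early_stopping=False))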