🚨🚨 Fix beam score calculation issue for decoder-only models (#27351)
* Fix beam score calculation issue for decoder-only models

* Update beam search test and fix code quality issue

* Fix beam_sample, group_beam_search and constrained_beam_search

* Split test for pytorch and TF, add documentation

---------

Co-authored-by: Xin Qiu <xin.qiu@sentient.ai>
VsonicV authored Nov 15, 2023
1 parent 3d1a7bf commit 453079c
Showing 3 changed files with 55 additions and 9 deletions.
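
The core of the change is how finished hypotheses are scored. For decoder-only models, the `input_ids` seen by the beam scorer include the prompt, so the length penalty previously divided the cumulative log-probability by the prompt length plus the generated length; after this commit it divides by the number of generated tokens only. A minimal sketch of the normalization, with hypothetical values:

    sum_logprobs = -12.0      # cumulative log-probability of a finished hypothesis
    hyp_len = 25              # prompt + generated tokens for a decoder-only model
    decoder_prompt_len = 20   # tokens that came from the prompt
    length_penalty = 1.0

    old_score = sum_logprobs / (hyp_len ** length_penalty)                         # -0.48
    new_score = sum_logprobs / ((hyp_len - decoder_prompt_len) ** length_penalty)  # -2.40
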
42 changes: 34 additions & 8 deletions src/transformers/generation/beam_search.py
@@ -222,8 +222,10 @@ def process(
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
group_index: Optional[int] = 0,
decoder_prompt_len: Optional[int] = 0,
) -> Dict[str, torch.Tensor]:
cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
# add up to the length which the next_scores is calculated on
cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
batch_size = len(self._beam_hyps) // self.num_beam_groups

if not (batch_size == (input_ids.shape[0] // self.group_size)):
@@ -277,10 +279,15 @@ def process(
else:
beam_index = None

# skip the corner case where the very first generated token is eos_token
if decoder_prompt_len == input_ids.shape[-1]:
continue

self._beam_hyps[batch_group_idx].add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
decoder_prompt_len=decoder_prompt_len,
)
else:
# add next predicted token since it is not eos_token
@@ -322,6 +329,7 @@ def finalize(
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps) // self.num_beam_groups

@@ -340,7 +348,7 @@
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len)

# select the best hypotheses
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
@@ -511,6 +519,7 @@ def process(
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.Tensor]:
r"""
Args:
@@ -535,7 +544,8 @@
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
beam_indices (`torch.LongTensor`, *optional*):
Beam indices indicating to which beam hypothesis each token correspond.
decoder_prompt_len (`int`, *optional*):
The length of the prompt that is included in the input to the decoder.
Return:
`UserDict`: A dictionary composed of the fields as defined above:
@@ -550,7 +560,8 @@
indicating to which beam the next tokens shall be added.
"""

cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
# add up to the length which the next_scores is calculated on
cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
batch_size = len(self._beam_hyps)
if not (batch_size == (input_ids.shape[0] // self.group_size)):
if self.num_beam_groups > 1:
@@ -606,10 +617,16 @@ def process(
else:
beam_index = None

# skip the corner case where the only constraint token is
# eos_token and the very first generated token is eos_token
if decoder_prompt_len == input_ids.shape[-1]:
continue

beam_hyp.add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
decoder_prompt_len=decoder_prompt_len,
)
else:
# add next predicted token since it is not eos_token
@@ -805,6 +822,7 @@ def finalize(
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)

@@ -828,7 +846,9 @@
completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
if completes_constraint:
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)
beam_hyp.add(
final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len
)
ids_collect.append(beam_id)

# due to overly complex constraints or other factors, sometimes we can't guarantee a successful
@@ -839,7 +859,7 @@
batch_beam_idx = batch_idx * self.num_beams + beam_id
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_hyp.add(final_tokens, final_score)
beam_hyp.add(final_tokens, final_score, decoder_prompt_len=decoder_prompt_len)
if len(ids_collect) >= self.num_beam_hyps_to_keep:
break

@@ -931,11 +951,17 @@ def __len__(self):
"""
return len(self.beams)

def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None):
def add(
self,
hyp: torch.LongTensor,
sum_logprobs: float,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
score = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** self.length_penalty)
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp, beam_indices))
if len(self) > self.num_beams:
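
The new `continue` in `process` guards the corner case where the very first generated token is `eos_token`: the EOS has not been appended to `input_ids` yet, so the hypothesis would contain zero generated tokens and the new normalization would divide by zero. A sketch of that failure mode, with hypothetical values:

    decoder_prompt_len = 20
    hyp_len = 20                                  # input_ids still equal the prompt
    generated_len = hyp_len - decoder_prompt_len  # 0
    # sum_logprobs / (generated_len ** length_penalty) would raise ZeroDivisionError
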
16 changes: 16 additions & 0 deletions src/transformers/generation/utils.py
@@ -3172,6 +3172,8 @@ def beam_search(
beam_scores = beam_scores.view((batch_size * num_beams,))

this_peer_finished = False # used by synced_gpus only

decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -3246,6 +3248,7 @@
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)

beam_scores = beam_outputs["next_beam_scores"]
@@ -3281,6 +3284,7 @@
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)

if return_dict_in_generate:
@@ -3500,6 +3504,8 @@ def beam_sample(
beam_scores = beam_scores.view((batch_size * num_beams,))

this_peer_finished = False # used by synced_gpus only

decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -3578,6 +3584,7 @@
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
@@ -3612,6 +3619,7 @@
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)

if return_dict_in_generate:
@@ -3837,6 +3845,8 @@ def group_beam_search(
beam_scores = beam_scores.view((batch_size * num_beams,))

this_peer_finished = False # used by synced_gpus only

decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -3924,6 +3934,7 @@
eos_token_id=eos_token_id,
beam_indices=process_beam_indices,
group_index=beam_group_idx,
decoder_prompt_len=decoder_prompt_len,
)
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
@@ -3993,6 +4004,7 @@
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=final_beam_indices,
decoder_prompt_len=decoder_prompt_len,
)

if return_dict_in_generate:
@@ -4220,6 +4232,8 @@ def constrained_beam_search(
beam_scores = beam_scores.view((batch_size * num_beams,))

this_peer_finished = False # used by synced_gpus only

decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -4298,6 +4312,7 @@
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
@@ -4331,6 +4346,7 @@
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)

if return_dict_in_generate:
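
In each of `beam_search`, `beam_sample`, `group_beam_search` and `constrained_beam_search`, `decoder_prompt_len` is recorded once before the generation loop and threaded through `process` and `finalize`. An illustrative way to observe the effect from user code (the checkpoint and prompt are only examples):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("Hello, my dog is cute and", return_tensors="pt")
    outputs = model.generate(
        **inputs,
        num_beams=3,
        max_new_tokens=10,
        return_dict_in_generate=True,
        output_scores=True,
    )
    # After this commit, sequences_scores is normalized by the number of generated
    # tokens only, so the prompt length no longer dilutes the length penalty.
    print(outputs.sequences_scores)
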
6 changes: 5 additions & 1 deletion tests/generation/test_framework_agnostic.py
@@ -633,7 +633,11 @@ def test_eos_token_id_int_and_list_beam_search(self):
"do_sample": False,
"num_beams": 3,
}
expectation = 13
if is_pt:
expectation = 20
else:
# TODO (joao): fix me
expectation = 13

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
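
The PyTorch expectation changes because normalizing by the generated length alone can change which beam ranks first. A hypothetical illustration of the flip:

    prompt_len = 7
    beam_a = {"sum_logprobs": -6.0, "total_len": 10}   # 3 generated tokens
    beam_b = {"sum_logprobs": -9.0, "total_len": 13}   # 6 generated tokens

    old = lambda b: b["sum_logprobs"] / b["total_len"]                 # a: -0.60, b: -0.69 -> a wins
    new = lambda b: b["sum_logprobs"] / (b["total_len"] - prompt_len)  # a: -2.00, b: -1.50 -> b wins
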
