Avoid dummy token in PLD to optimize performance (#29445)
ofirzaf authored Mar 6, 2024
1 parent 700d48f commit 0a5b051
Showing 1 changed file with 2 additions and 2 deletions.
src/transformers/generation/candidate_generator.py: 2 additions & 2 deletions
@@ -302,8 +302,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
                 break
 
         if chosen_ids is None or len(chosen_ids) == 0:
-            # Need to make a dummy tensor to avoid errors
-            chosen_ids = torch.zeros((1), dtype=torch.long, device=input_ids.device)
+            # In case we didn't find a match return the input sequence unchanged, reverts back to autoregressive decoding
+            return input_ids, None
 
         # Now need extend input_ids with chosen_ids
         chosen_ids = chosen_ids.unsqueeze(0)
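For context, prompt lookup decoding (PLD) drafts candidate tokens by matching the trailing n-gram of the current sequence against earlier positions in the prompt and reusing the tokens that followed the match. The sketch below is a simplified, illustrative version of that idea, not the transformers implementation; the helper name find_candidate_ids and its parameters are made up for this example. It shows why returning the input unchanged on a failed match is cleaner than padding with a dummy token: the caller can simply fall back to ordinary autoregressive decoding.

import torch

def find_candidate_ids(input_ids: torch.LongTensor,
                       max_matching_ngram_size: int = 2,
                       num_output_tokens: int = 10):
    """Illustrative prompt-lookup: reuse tokens that followed an earlier
    occurrence of the trailing n-gram as draft candidates."""
    seq_len = input_ids.shape[1]
    for ngram_size in range(min(max_matching_ngram_size, seq_len - 1), 0, -1):
        ngram = input_ids[0, -ngram_size:]
        # Scan earlier positions for the same n-gram (skipping the trailing one itself).
        for start in range(seq_len - ngram_size - 1, -1, -1):
            if torch.equal(input_ids[0, start:start + ngram_size], ngram):
                end = min(start + ngram_size + num_output_tokens, seq_len)
                chosen_ids = input_ids[0, start + ngram_size:end]
                if len(chosen_ids) > 0:
                    # Match found: append the draft tokens so the model can verify them.
                    return torch.cat([input_ids, chosen_ids.unsqueeze(0)], dim=-1), None
    # No match: return the sequence as-is so generation falls back to plain
    # autoregressive decoding instead of verifying a meaningless dummy token.
    return input_ids, None

For example, with input_ids = [[5, 6, 7, 5, 6]] the trailing bigram (5, 6) also appears at the start of the sequence, so the continuation 7, 5, 6 is proposed as draft tokens; if no such repeat exists, the function simply hands the sequence back untouched, which is the behavior this commit adopts in place of the dummy tensor.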
