Skip to content

Commit

Permalink
Generate: remove unused attributes in AssistedCandidateGenerator (#…
Browse files Browse the repository at this point in the history
…29787)

remove unused attrs
  • Loading branch information
gante authored Mar 22, 2024
1 parent e85654f commit 34e07f4
Showing 1 changed file with 0 additions and 9 deletions.
9 changes: 0 additions & 9 deletions src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,27 +131,18 @@ def __init__(
if assistant_model.config.is_encoder_decoder:
# both are encoder-decoder
self.input_ids_key = "decoder_input_ids"
self.attention_key = "decoder_attention_mask"
elif "encoder_outputs" in assistant_kwargs:
# special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
self.input_ids_key = "input_ids"
self.attention_key = "attention_mask"
self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get(
"decoder_attention_mask",
torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
)
else:
# both are decoder-only
self.input_ids_key = "input_ids"
self.attention_key = "attention_mask"

# Prepare generation-related options.
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id_tensor = (
torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
)
self.logits_processor = logits_processor
self.generation_config = copy.deepcopy(generation_config)
self.generation_config.return_dict_in_generate = True
Expand Down

0 comments on commit 34e07f4

Please sign in to comment.