add attention_mask and position_ids in assisted model #26892

Merged
merged 14 commits on Nov 10, 2023
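For context: assisted generation drafts candidate tokens with a small assistant model and verifies them with the main model in one forward pass. This PR builds the assistant's inputs through `prepare_inputs_for_generation`, so the user-provided `attention_mask` (and `decoder_attention_mask` for encoder-decoder models) is respected by the assistant, and `position_ids` can be derived from the mask where the model supports them. A minimal usage sketch, not part of the PR itself (the checkpoints are just placeholders for a main model and a smaller assistant sharing a tokenizer):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")   # main model
assistant = AutoModelForCausalLM.from_pretrained("gpt2")     # smaller assistant

inputs = tokenizer("Assisted decoding lets a small model draft tokens", return_tensors="pt")
outputs = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,  # with this PR, also forwarded to the assistant
    assistant_model=assistant,
    max_new_tokens=32,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```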
src/transformers/generation/utils.py (67 changes: 33 additions & 34 deletions)
@@ -4410,15 +4410,6 @@ def assisted_decoding(

# other auxiliary variables
max_len = stopping_criteria[0].max_length
assistant_kv_indexing = (
1
if "bloom" in assistant_model.__class__.__name__.lower()
or (
assistant_model.config.architectures is not None
and "bloom" in assistant_model.config.architectures[0].lower()
)
else 0
)

this_peer_finished = False # used by synced_gpus only
while True:
@@ -4439,42 +4430,43 @@
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
# need access to the assistant cache to secure strong speedups.
candidate_input_ids = input_ids
assistant_attention_mask = model_kwargs.get("attention_mask", None)
assistant_decoder_attention_mask = model_kwargs.get("decoder_attention_mask", None)
assistant_encoder_outputs = model_kwargs.get("assistant_encoder_outputs", None)
for _ in range(int(num_assistant_tokens)):
# 1.1. use the assistant model to obtain the next candidate logits
if "assistant_past_key_values" in model_kwargs:
prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
assist_inputs = candidate_input_ids[:, -new_token_len:]
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
if assistant_model.config.is_encoder_decoder:
assistant_model_outputs = assistant_model(
decoder_input_ids=assist_inputs,
past_key_values=model_kwargs["assistant_past_key_values"],
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
)
else:
assistant_model_outputs = assistant_model(
assist_inputs,
past_key_values=model_kwargs["assistant_past_key_values"],
)
else:
if assistant_model.config.is_encoder_decoder:
assistant_model_outputs = assistant_model(
decoder_input_ids=candidate_input_ids,
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
)
else:
assistant_model_outputs = assistant_model(candidate_input_ids)
assistant_inputs = assistant_model.prepare_inputs_for_generation(
candidate_input_ids,
attention_mask=assistant_attention_mask,
decoder_attention_mask=assistant_decoder_attention_mask,
encoder_outputs=assistant_encoder_outputs,
past_key_values=model_kwargs.get("assistant_past_key_values", None),
)
assistant_model_outputs = assistant_model(
**assistant_inputs,
)

# 1.2. greedily select the next candidate token
model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
if len(logits_processor) > 0:
assistant_model_outputs.logits[:, -1, :] = logits_processor(
candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
)

new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
if self.config.is_encoder_decoder and assistant_decoder_attention_mask is not None:
assistant_decoder_attention_mask = torch.cat(
(
assistant_decoder_attention_mask,
torch.ones([1, 1], dtype=assistant_decoder_attention_mask.dtype),
),
dim=-1,
)
elif not self.config.is_encoder_decoder and assistant_attention_mask is not None:
assistant_attention_mask = torch.cat(
(assistant_attention_mask, torch.ones([1, 1], dtype=assistant_attention_mask.dtype)), dim=-1
)

# 1.3. stop assistant generation on EOS
if eos_token_id_tensor is not None:
@@ -4610,6 +4602,13 @@ def assisted_decoding(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)

# Update attention_mask
if n_matches > 0 and model_kwargs.get("attention_mask", None) is not None:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], n_matches))], dim=-1
)

# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
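To illustrate the bookkeeping the new code performs, a standalone sketch (illustrative only, not library code): the assistant's mask is extended by one position for each drafted token, and after the main model verifies the candidates, `model_kwargs["attention_mask"]` grows by the number of accepted tokens `n_matches`.

```python
import torch

attention_mask = torch.ones((1, 5), dtype=torch.long)  # batch size 1, 5 prompt tokens

# 1. Drafting: each new assistant token extends the mask by one position.
attention_mask = torch.cat(
    (attention_mask, torch.ones([1, 1], dtype=attention_mask.dtype)), dim=-1
)

# 2. Verification: the mask grows by the number of accepted candidate tokens.
n_matches = 3
attention_mask = torch.cat(
    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], n_matches))], dim=-1
)
print(attention_mask.shape)  # torch.Size([1, 9])
```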