From e9b33b37bea19b4d33b8cd1cd99336a685adcab4 Mon Sep 17 00:00:00 2001
From: "Feng, Jiqing"
Date: Tue, 17 Oct 2023 23:21:13 -0700
Subject: [PATCH 1/8] add attention_mask and position_ids in assisted model

---
 src/transformers/generation/utils.py | 48 +++++++++++++++++++++++++---
 1 file changed, 43 insertions(+), 5 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 606fbbe7060..00f6f885a3b 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2498,6 +2498,9 @@ def greedy_search(
                     break

             # prepare model inputs
+            import pdb
+
+            pdb.set_trace()
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

             # forward pass to get next token
@@ -4437,10 +4440,32 @@ def assisted_decoding(
             # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
             # need access to the assistant cache to secure strong speedups.
             candidate_input_ids = input_ids
+            # import pdb; pdb.set_trace()
+            assistant_attention_mask = model_kwargs.get("attention_mask", None)
+            if assistant_attention_mask is not None:
+                assistant_attention_mask = torch.cat(
+                    [
+                        assistant_attention_mask,
+                        torch.ones(
+                            [
+                                assistant_attention_mask.shape[0],
+                                input_ids.shape[-1] - assistant_attention_mask.shape[-1],
+                            ],
+                            dtype=assistant_attention_mask.dtype,
+                        ),
+                    ],
+                    dim=-1,
+                )
             for _ in range(int(num_assistant_tokens)):
                 # 1.1. use the assistant model to obtain the next candidate logits
-                if "assistant_past_key_values" in model_kwargs:
-                    prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
+                assistant_past_key_values = model_kwargs.get("assistant_past_key_values", None)
+                assistant_position_ids = assistant_model.prepare_inputs_for_generation(
+                    candidate_input_ids,
+                    attention_mask=assistant_attention_mask,
+                    past_key_values=assistant_past_key_values,
+                ).get("position_ids", None)
+                if assistant_past_key_values is not None:
+                    prev_seq_len = assistant_past_key_values[0][assistant_kv_indexing].shape[-2]
                     # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
                     new_token_len = candidate_input_ids.shape[1] - prev_seq_len
                     assist_inputs = candidate_input_ids[:, -new_token_len:]
@@ -4448,22 +4473,32 @@ def assisted_decoding(
                     if assistant_model.config.is_encoder_decoder:
                         assistant_model_outputs = assistant_model(
                             decoder_input_ids=assist_inputs,
-                            past_key_values=model_kwargs["assistant_past_key_values"],
+                            attention_mask=assistant_attention_mask,
+                            position_ids=assistant_position_ids,
+                            past_key_values=assistant_past_key_values,
                             encoder_outputs=model_kwargs["assistant_encoder_outputs"],
                         )
                     else:
                         assistant_model_outputs = assistant_model(
                             assist_inputs,
-                            past_key_values=model_kwargs["assistant_past_key_values"],
+                            attention_mask=assistant_attention_mask,
+                            position_ids=assistant_position_ids,
+                            past_key_values=assistant_past_key_values,
                         )
                 else:
                     if assistant_model.config.is_encoder_decoder:
                         assistant_model_outputs = assistant_model(
                             decoder_input_ids=candidate_input_ids,
+                            attention_mask=assistant_attention_mask,
+                            position_ids=assistant_position_ids,
                             encoder_outputs=model_kwargs["assistant_encoder_outputs"],
                         )
                     else:
-                        assistant_model_outputs = assistant_model(candidate_input_ids)
+                        assistant_model_outputs = assistant_model(
+                            candidate_input_ids,
+                            attention_mask=assistant_attention_mask,
+                            position_ids=assistant_position_ids,
+                        )

                 # 1.2. greedily select the next candidate token
                 model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
@@ -4473,6 +4508,9 @@ def assisted_decoding(
                     )
                 new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
                 candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
+                assistant_attention_mask = torch.cat(
+                    (assistant_attention_mask, torch.ones([1, 1], dtype=assistant_attention_mask.dtype)), dim=-1
+                )

                 # 1.3. stop assistant generation on EOS
                 if eos_token_id_tensor is not None:

From 75d47fd15929fc9482f6fd62065eaffeb6d41a0d Mon Sep 17 00:00:00 2001
From: "Feng, Jiqing"
Date: Tue, 17 Oct 2023 23:35:45 -0700
Subject: [PATCH 2/8] fix bug

---
 src/transformers/generation/utils.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 00f6f885a3b..e2163564357 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2498,9 +2498,6 @@ def greedy_search(
                     break

             # prepare model inputs
-            import pdb
-
-            pdb.set_trace()
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

             # forward pass to get next token
@@ -4440,7 +4437,6 @@ def assisted_decoding(
             # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
             # need access to the assistant cache to secure strong speedups.
             candidate_input_ids = input_ids
-            # import pdb; pdb.set_trace()
             assistant_attention_mask = model_kwargs.get("attention_mask", None)
             if assistant_attention_mask is not None:
                 assistant_attention_mask = torch.cat(

From 0f3faf00b2152c42bcaf32925e8291e599fb04d4 Mon Sep 17 00:00:00 2001
From: "Feng, Jiqing"
Date: Sun, 29 Oct 2023 04:33:43 -0700
Subject: [PATCH 3/8] fix attention mask

---
 src/transformers/generation/utils.py | 79 ++++++++--------------------
 1 file changed, 23 insertions(+), 56 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index e95875cf024..87e62353be5 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4440,63 +4440,20 @@ def assisted_decoding(
             # need access to the assistant cache to secure strong speedups.
             candidate_input_ids = input_ids
             assistant_attention_mask = model_kwargs.get("attention_mask", None)
-            if assistant_attention_mask is not None:
-                assistant_attention_mask = torch.cat(
-                    [
-                        assistant_attention_mask,
-                        torch.ones(
-                            [
-                                assistant_attention_mask.shape[0],
-                                input_ids.shape[-1] - assistant_attention_mask.shape[-1],
-                            ],
-                            dtype=assistant_attention_mask.dtype,
-                        ),
-                    ],
-                    dim=-1,
-                )
+            assistant_decoder_attention_mask = model_kwargs.get("decoder_attention_mask", None)
+            assistant_encoder_outputs = (model_kwargs.get("assistant_encoder_outputs", None),)
             for _ in range(int(num_assistant_tokens)):
                 # 1.1. use the assistant model to obtain the next candidate logits
-                assistant_past_key_values = model_kwargs.get("assistant_past_key_values", None)
-                assistant_position_ids = assistant_model.prepare_inputs_for_generation(
+                assistant_inputs = assistant_model.prepare_inputs_for_generation(
                     candidate_input_ids,
                     attention_mask=assistant_attention_mask,
-                    past_key_values=assistant_past_key_values,
-                ).get("position_ids", None)
-                if assistant_past_key_values is not None:
-                    prev_seq_len = assistant_past_key_values[0][assistant_kv_indexing].shape[-2]
-                    # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
-                    new_token_len = candidate_input_ids.shape[1] - prev_seq_len
-                    assist_inputs = candidate_input_ids[:, -new_token_len:]
-                    # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
-                    if assistant_model.config.is_encoder_decoder:
-                        assistant_model_outputs = assistant_model(
-                            decoder_input_ids=assist_inputs,
-                            attention_mask=assistant_attention_mask,
-                            position_ids=assistant_position_ids,
-                            past_key_values=assistant_past_key_values,
-                            encoder_outputs=model_kwargs["assistant_encoder_outputs"],
-                        )
-                    else:
-                        assistant_model_outputs = assistant_model(
-                            assist_inputs,
-                            attention_mask=assistant_attention_mask,
-                            position_ids=assistant_position_ids,
-                            past_key_values=assistant_past_key_values,
-                        )
-                else:
-                    if assistant_model.config.is_encoder_decoder:
-                        assistant_model_outputs = assistant_model(
-                            decoder_input_ids=candidate_input_ids,
-                            attention_mask=assistant_attention_mask,
-                            position_ids=assistant_position_ids,
-                            encoder_outputs=model_kwargs["assistant_encoder_outputs"],
-                        )
-                    else:
-                        assistant_model_outputs = assistant_model(
-                            candidate_input_ids,
-                            attention_mask=assistant_attention_mask,
-                            position_ids=assistant_position_ids,
-                        )
+                    decoder_attention_mask=assistant_decoder_attention_mask,
+                    encoder_outputs=assistant_encoder_outputs,
+                    past_key_values=model_kwargs.get("assistant_past_key_values", None),
+                )
+                assistant_model_outputs = assistant_model(
+                    **assistant_inputs,
+                )

                 # 1.2. greedily select the next candidate token
                 model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
@@ -4504,11 +4461,21 @@ def assisted_decoding(
                     assistant_model_outputs.logits[:, -1, :] = logits_processor(
                         candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
                     )
+
                 new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
                 candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
-                assistant_attention_mask = torch.cat(
-                    (assistant_attention_mask, torch.ones([1, 1], dtype=assistant_attention_mask.dtype)), dim=-1
-                )
+                if self.config.is_encoder_decoder and assistant_decoder_attention_mask is not None:
+                    assistant_decoder_attention_mask = torch.cat(
+                        (
+                            assistant_decoder_attention_mask,
+                            torch.ones([1, 1], dtype=assistant_decoder_attention_mask.dtype),
+                        ),
+                        dim=-1,
+                    )
+                elif not self.config.is_encoder_decoder and assistant_attention_mask is not None:
+                    assistant_attention_mask = torch.cat(
+                        (assistant_attention_mask, torch.ones([1, 1], dtype=assistant_attention_mask.dtype)), dim=-1
+                    )

                 # 1.3. stop assistant generation on EOS
                 if eos_token_id_tensor is not None:

From 4fd18fff7217ac48c1f0a2aca3e8cdd788ee2a48 Mon Sep 17 00:00:00 2001
From: "Feng, Jiqing"
Date: Sun, 29 Oct 2023 07:18:54 -0700
Subject: [PATCH 4/8] fix attention_mask

---
 src/transformers/generation/utils.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 87e62353be5..1460d6cd89c 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4410,15 +4410,6 @@ def assisted_decoding(

         # other auxiliary variables
         max_len = stopping_criteria[0].max_length
-        assistant_kv_indexing = (
-            1
-            if "bloom" in assistant_model.__class__.__name__.lower()
-            or (
-                assistant_model.config.architectures is not None
-                and "bloom" in assistant_model.config.architectures[0].lower()
-            )
-            else 0
-        )

         this_peer_finished = False  # used by synced_gpus only
         while True:
@@ -4611,6 +4602,13 @@ def assisted_decoding(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )

+            # Update attention_mask
+            if n_matches > 0 and model_kwargs.get("attention_mask", None) is not None:
+                attention_mask = model_kwargs["attention_mask"]
+                model_kwargs["attention_mask"] = torch.cat(
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], n_matches))], dim=-1
+                )
+
             # if eos_token was found in one sentence, set sentence to finished
             if eos_token_id_tensor is not None:
                 unfinished_sequences = unfinished_sequences.mul(

From 55feaa4a40a451cb6110dff6b02902d7d969f589 Mon Sep 17 00:00:00 2001
From: "Feng, Jiqing"
Date: Tue, 31 Oct 2023 18:56:14 -0700
Subject: [PATCH 5/8] check assist inputs

---
 src/transformers/generation/utils.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 1460d6cd89c..e7465d72a49 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4442,9 +4442,16 @@ def assisted_decoding(
                     encoder_outputs=assistant_encoder_outputs,
                     past_key_values=model_kwargs.get("assistant_past_key_values", None),
                 )
-                assistant_model_outputs = assistant_model(
-                    **assistant_inputs,
-                )
+                if assistant_inputs.get("past_key_values", None) is not None:
+                    if self.config.is_encoder_decoder:
+                        input_ids_len = assistant_inputs["decoder_input_ids"].shape[-1]
+                    else:
+                        input_ids_len = assistant_inputs["input_ids"].shape[-1]
+
+                    if input_ids_len not in (0, 1):
+                        raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")
+
+                assistant_model_outputs = assistant_model(**assistant_inputs)

                 # 1.2. greedily select the next candidate token
                 model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
@@ -4602,7 +4609,7 @@ def assisted_decoding(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )

-            # Update attention_mask
+            # Update attention_mask for the assistant's next round of generations
            if n_matches > 0 and model_kwargs.get("attention_mask", None) is not None:
                 attention_mask = model_kwargs["attention_mask"]
                 model_kwargs["attention_mask"] = torch.cat(

From d121258d2224b36804182671aee6f78d5a135786 Mon Sep 17 00:00:00 2001
From: "Feng, Jiqing"
Date: Tue, 31 Oct 2023 18:56:52 -0700
Subject: [PATCH 6/8] check assist input ids length

---
 src/transformers/generation/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index e7465d72a49..339daa48af5 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4448,7 +4448,7 @@ def assisted_decoding(
                     else:
                         input_ids_len = assistant_inputs["input_ids"].shape[-1]

-                    if input_ids_len not in (0, 1):
+                    if input_ids_len not in (1, 2):
                         raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")

                 assistant_model_outputs = assistant_model(**assistant_inputs)

From a1e3b65d97d1776f000884166d37d65c37eca9bc Mon Sep 17 00:00:00 2001
From: "Feng, Jiqing"
Date: Thu, 2 Nov 2023 19:36:46 -0700
Subject: [PATCH 7/8] fix assist model type

---
 src/transformers/generation/utils.py | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 2453a8e627f..85c4aa6fc54 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4488,11 +4488,6 @@ def assisted_decoding(
         else:
             num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens

-        # check if assistant model accepts encoder_outputs
-        assistant_accepts_encoder_outputs = "encoder_outputs" in set(
-            inspect.signature(assistant_model.forward).parameters.keys()
-        )
-
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
@@ -4568,7 +4563,7 @@ def assisted_decoding(
                     past_key_values=model_kwargs.get("assistant_past_key_values", None),
                 )
                 if assistant_inputs.get("past_key_values", None) is not None:
-                    if self.config.is_encoder_decoder:
+                    if assistant_model.config.is_encoder_decoder:
                         input_ids_len = assistant_inputs["decoder_input_ids"].shape[-1]
                     else:
                         input_ids_len = assistant_inputs["input_ids"].shape[-1]
@@ -4587,7 +4582,7 @@ def assisted_decoding(

                 new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
                 candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
-                if self.config.is_encoder_decoder and assistant_decoder_attention_mask is not None:
+                if assistant_model.config.is_encoder_decoder and assistant_decoder_attention_mask is not None:
                     assistant_decoder_attention_mask = torch.cat(
                         (
                             assistant_decoder_attention_mask,
@@ -4595,7 +4590,7 @@ def assisted_decoding(
                         ),
                         dim=-1,
                     )
-                elif not self.config.is_encoder_decoder and assistant_attention_mask is not None:
+                elif not assistant_model.config.is_encoder_decoder and assistant_attention_mask is not None:
                     assistant_attention_mask = torch.cat(
                         (assistant_attention_mask, torch.ones([1, 1], dtype=assistant_attention_mask.dtype)), dim=-1
                     )

From 54a94c20af0e7a5ae769329840f7555fc53c1fee Mon Sep 17 00:00:00 2001
From: "Feng, Jiqing"
Date: Tue, 7 Nov 2023 17:21:51 -0800
Subject: [PATCH 8/8] set assist attention mask device

---
 src/transformers/generation/utils.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 85c4aa6fc54..4dbfc367064 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4586,13 +4586,23 @@ def assisted_decoding(
                     assistant_decoder_attention_mask = torch.cat(
                         (
                             assistant_decoder_attention_mask,
-                            torch.ones([1, 1], dtype=assistant_decoder_attention_mask.dtype),
+                            torch.ones(
+                                [1, 1],
+                                dtype=assistant_decoder_attention_mask.dtype,
+                                device=assistant_decoder_attention_mask.device,
+                            ),
                         ),
                         dim=-1,
                     )
                 elif not assistant_model.config.is_encoder_decoder and assistant_attention_mask is not None:
                     assistant_attention_mask = torch.cat(
-                        (assistant_attention_mask, torch.ones([1, 1], dtype=assistant_attention_mask.dtype)), dim=-1
+                        (
+                            assistant_attention_mask,
+                            torch.ones(
+                                [1, 1], dtype=assistant_attention_mask.dtype, device=assistant_attention_mask.device
+                            ),
+                        ),
+                        dim=-1,
                     )

                 # 1.3. stop assistant generation on EOS
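
Taken together, the patches maintain one invariant: every time a drafted token is appended to candidate_input_ids, the assistant's attention_mask (or decoder_attention_mask for encoder-decoder assistants) must grow by one attended position with the same dtype and device as the existing mask. Below is a minimal standalone sketch of that bookkeeping, separate from the patched assisted_decoding code; the helper name is illustrative and not part of the patch series.

    import torch


    def extend_attention_mask(mask: torch.Tensor, num_new_tokens: int = 1) -> torch.Tensor:
        """Append `num_new_tokens` attended positions to a [batch, seq_len] 0/1 mask.

        `new_ones` reuses the dtype and device of the existing mask, which is the
        same effect PATCH 8/8 gets by passing `device=` to `torch.ones` explicitly.
        """
        return torch.cat([mask, mask.new_ones((mask.shape[0], num_new_tokens))], dim=-1)


    # Left-padded prompt: position 0 is padding, positions 1-3 are real tokens.
    mask = torch.tensor([[0, 1, 1, 1]])
    print(extend_attention_mask(mask))     # one drafted token, as in the assistant loop
    print(extend_attention_mask(mask, 3))  # several accepted tokens, as with n_matches in PATCH 4/8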
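
For context, a hedged usage sketch of the path these patches exercise: assisted decoding is selected by passing assistant_model to generate, and the tokenizer-produced attention_mask (relevant for padded batches) is what now gets forwarded to, and extended for, the assistant. The checkpoint names below are placeholders rather than anything mandated by the patches; any main/assistant pair that shares a tokenizer should behave the same way.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
    assistant = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

    inputs = tokenizer("Speculative decoding lets a small model draft tokens", return_tensors="pt")

    # `inputs` carries both `input_ids` and `attention_mask`; the patches make sure
    # the assistant sees the mask (and matching position_ids) while drafting.
    outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=32)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])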