From 1e33d3720b83c549b30317e2ed297b15539504b9 Mon Sep 17 00:00:00 2001 From: jmamou Date: Wed, 26 Jun 2024 06:03:53 -0700 Subject: [PATCH 01/18] optimal Speculation Lookahead based on probability --- .../generation/configuration_utils.py | 1 + src/transformers/generation/utils.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 8bb5e091d6db..707b11b4d247 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -396,6 +396,7 @@ def __init__(self, **kwargs): # Assistant generation self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5) self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic") + self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0) # Cache implementation self.cache_implementation = kwargs.pop("cache_implementation", None) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index dd1719294e8f..a634830e0f0c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2682,6 +2682,14 @@ def _sample( else (outputs.hidden_states,) ) + if ( + hasattr(generation_config, "assistant_confidence_threshold") + and generation_config.assistant_confidence_threshold > 0 + ): + p = next_token_scores.softmax(-1).max(-1).values + if p < generation_config.assistant_confidence_threshold: + this_peer_finished = True + # token selection if do_sample: probs = nn.functional.softmax(next_token_scores, dim=-1) @@ -3865,6 +3873,13 @@ def _assisted_decoding( if streamer is not None: streamer.end() + if ( + hasattr(candidate_generator, "assistant_threshold") + ): + candidate_generator.assistant_model.generation_config.assistant_threshold = ( + candidate_generator.assistant_threshold + ) + if ( hasattr(candidate_generator, "assistant_model") and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic" @@ -3872,6 +3887,7 @@ def _assisted_decoding( candidate_generator.assistant_model.generation_config.num_assistant_tokens = ( candidate_generator.num_assistant_tokens ) + if return_dict_in_generate: if self.config.is_encoder_decoder: return GenerateEncoderDecoderOutput( From f1d92b194fa6d7e6efc9350f276a3e9d96ea3b0e Mon Sep 17 00:00:00 2001 From: jmamou Date: Tue, 2 Jul 2024 03:09:30 -0700 Subject: [PATCH 02/18] update peer finished condition --- src/transformers/generation/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index a634830e0f0c..9eec578dda3b 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2710,9 +2710,9 @@ def _sample( model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, ) - - unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - this_peer_finished = unfinished_sequences.max() == 0 + if not this_peer_finished: + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 # This is needed to properly delete outputs.logits which may be very large for first iteration # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration From 21ab024771e30d61bcf3462282393ecd96381dd4 Mon Sep 17 00:00:00 2001 From: jmamou Date: Tue, 2 Jul 2024 05:49:13 -0700 Subject: [PATCH 03/18] add support to do_sample True --- src/transformers/generation/utils.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8d2ae878b565..62628fef5201 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2686,14 +2686,6 @@ def _sample( if self.config.is_encoder_decoder else (outputs.hidden_states,) ) - - if ( - hasattr(generation_config, "assistant_confidence_threshold") - and generation_config.assistant_confidence_threshold > 0 - ): - p = next_token_scores.softmax(-1).max(-1).values - if p < generation_config.assistant_confidence_threshold: - this_peer_finished = True # token selection if do_sample: @@ -2701,6 +2693,17 @@ def _sample( next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) else: next_tokens = torch.argmax(next_token_scores, dim=-1) + + if ( + hasattr(generation_config, "assistant_confidence_threshold") + and generation_config.assistant_confidence_threshold > 0 + ): + if do_sample: + p = probs[torch.arange(probs.size(0)), next_tokens] + else: + p = next_token_scores.softmax(-1).max(-1).values + if p < generation_config.assistant_confidence_threshold: + this_peer_finished = True # finished sentences should have their next token be a padding token if has_eos_stopping_criteria: From e7610f8996b433586f1d3d96dfa648207e59efa3 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 15 Jul 2024 03:15:55 -0700 Subject: [PATCH 04/18] add stopping criteria --- .../generation/candidate_generator.py | 2 ++ .../generation/stopping_criteria.py | 16 +++++++++++ src/transformers/generation/utils.py | 27 +++++-------------- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index e735d0a2ca7f..07c3e2d31a97 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -107,6 +107,7 @@ def __init__( # Prepare the assistant and the starting number of candidate tokens self.assistant_model = assistant_model self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens + self.assistant_confidence_threshold = assistant_model.generation_config.assistant_confidence_threshold # Prepare the kwargs for the assistant model assistant_kwargs = {} @@ -149,6 +150,7 @@ def __init__( self.generation_config = copy.deepcopy(generation_config) self.generation_config.return_dict_in_generate = True self.generation_config.output_scores = True + self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold # Disable sampling -- this implementation of assisted generation/speculative decoding uses the assistant # greedily to maximize matches. Disables sampling-related flags to prevent warnings diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index b1bf3dee9ae1..c3854e201343 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -73,6 +73,7 @@ def __init__(self, max_length: int, max_position_embeddings: Optional[int] = Non @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: cur_len = input_ids.shape[-1] + print(f"{cur_len=}\t{self.max_length=}") is_done = cur_len >= self.max_length if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings: logger.warning_once( @@ -499,13 +500,28 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa is_done = torch.isin(input_ids[:, -1], self.eos_token_id) return is_done +class ConfidenceCriteria(StoppingCriteria): + + def __init__(self, assistant_confidence_threshold): + self.assistant_confidence_threshold = assistant_confidence_threshold + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: + probs = scores[-1].softmax(-1) + p = probs[0, input_ids[0,-1]].item() + print(f"{p=}") + if p < self.assistant_confidence_threshold: + return True + return False + class StoppingCriteriaList(list): @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device, dtype=torch.bool) + print(f"BEFORE {is_done=}") for criteria in self: is_done = is_done | criteria(input_ids, scores, **kwargs) + print(f"{criteria=},{is_done=}") return is_done @property diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 62628fef5201..e3c19aef0100 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -94,6 +94,7 @@ WatermarkLogitsProcessor, ) from .stopping_criteria import ( + ConfidenceCriteria, EosTokenCriteria, MaxLengthCriteria, MaxTimeCriteria, @@ -989,6 +990,8 @@ def _get_stopping_criteria( criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer)) if generation_config.eos_token_id is not None: criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id)) + if generation_config.assistant_confidence_threshold is not None and generation_config.assistant_confidence_threshold > 0: + criteria.append(ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold)) criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) return criteria @@ -1305,6 +1308,7 @@ def _prepare_generated_length( "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) generation_config.max_length = generation_config.max_new_tokens + input_ids_length + print(f"{generation_config.max_new_tokens=}\t{input_ids_length=}") # if both `inputs_embeds` and `input_ids` are passed, we do not correct the length # otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length`` @@ -2693,17 +2697,6 @@ def _sample( next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) else: next_tokens = torch.argmax(next_token_scores, dim=-1) - - if ( - hasattr(generation_config, "assistant_confidence_threshold") - and generation_config.assistant_confidence_threshold > 0 - ): - if do_sample: - p = probs[torch.arange(probs.size(0)), next_tokens] - else: - p = next_token_scores.softmax(-1).max(-1).values - if p < generation_config.assistant_confidence_threshold: - this_peer_finished = True # finished sentences should have their next token be a padding token if has_eos_stopping_criteria: @@ -2721,6 +2714,7 @@ def _sample( if not this_peer_finished: unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) this_peer_finished = unfinished_sequences.max() == 0 + print(f"{this_peer_finished.item()=}") # This is needed to properly delete outputs.logits which may be very large for first iteration # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration @@ -2728,7 +2722,6 @@ def _sample( if streamer is not None: streamer.end() - if return_dict_in_generate: if self.config.is_encoder_decoder: return GenerateEncoderDecoderOutput( @@ -3729,7 +3722,7 @@ def _assisted_decoding( candidate_logits = candidate_logits.to(self.device) candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] - is_done_candidate = stopping_criteria(candidate_input_ids, None) + is_done_candidate = stopping_criteria(candidate_input_ids, candidate_logits) # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, @@ -3876,17 +3869,11 @@ def _assisted_decoding( unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) this_peer_finished = unfinished_sequences.max() == 0 + print(f"{this_peer_finished=}") if streamer is not None: streamer.end() - if ( - hasattr(candidate_generator, "assistant_threshold") - ): - candidate_generator.assistant_model.generation_config.assistant_threshold = ( - candidate_generator.assistant_threshold - ) - if ( hasattr(candidate_generator, "assistant_model") and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic" From a0b107d9bb45c77a4a7055aad6952902fe1e5f8e Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 15 Jul 2024 03:45:50 -0700 Subject: [PATCH 05/18] gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 337f2ef2c735..c47bf45cb4f7 100644 --- a/.gitignore +++ b/.gitignore @@ -167,3 +167,6 @@ tags # ruff .ruff_cache +slurm-*.out +run*.py +run*.sh From adf359849123eb5c9efffce5fd1786cfefd96797 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 15 Jul 2024 05:23:05 -0700 Subject: [PATCH 06/18] add print --- src/transformers/generation/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8eb6b17645b7..e2d33798678a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -3934,6 +3934,7 @@ def _assisted_decoding( this_peer_finished = False while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + print("****** START DRAFT") cur_len = input_ids.shape[-1] # 1. Fetch candidate sequences from a `CandidateGenerator` @@ -3984,7 +3985,7 @@ def _assisted_decoding( if do_sample and len(logits_warper) > 0: for i in range(candidate_length + 1): new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) - + print("****** END DRAFT") # 3. Select the accepted tokens. There are two possible cases: # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). From 39b9f63e3c4d8935503a5660920c75417e643e28 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 15 Jul 2024 06:03:46 -0700 Subject: [PATCH 07/18] remove prints --- src/transformers/generation/stopping_criteria.py | 4 ---- src/transformers/generation/utils.py | 10 ++-------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index c3854e201343..85cf4ce06bdd 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -73,7 +73,6 @@ def __init__(self, max_length: int, max_position_embeddings: Optional[int] = Non @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: cur_len = input_ids.shape[-1] - print(f"{cur_len=}\t{self.max_length=}") is_done = cur_len >= self.max_length if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings: logger.warning_once( @@ -508,7 +507,6 @@ def __init__(self, assistant_confidence_threshold): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: probs = scores[-1].softmax(-1) p = probs[0, input_ids[0,-1]].item() - print(f"{p=}") if p < self.assistant_confidence_threshold: return True return False @@ -518,10 +516,8 @@ class StoppingCriteriaList(list): @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device, dtype=torch.bool) - print(f"BEFORE {is_done=}") for criteria in self: is_done = is_done | criteria(input_ids, scores, **kwargs) - print(f"{criteria=},{is_done=}") return is_done @property diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e2d33798678a..8e7e7818ed64 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1312,7 +1312,6 @@ def _prepare_generated_length( "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) generation_config.max_length = generation_config.max_new_tokens + input_ids_length - print(f"{generation_config.max_new_tokens=}\t{input_ids_length=}") # if both `inputs_embeds` and `input_ids` are passed, we do not correct the length # otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length`` @@ -2967,10 +2966,8 @@ def _sample( model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, ) - if not this_peer_finished: - unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - this_peer_finished = unfinished_sequences.max() == 0 - print(f"{this_peer_finished.item()=}") + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 # This is needed to properly delete outputs.logits which may be very large for first iteration # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration @@ -3934,7 +3931,6 @@ def _assisted_decoding( this_peer_finished = False while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): - print("****** START DRAFT") cur_len = input_ids.shape[-1] # 1. Fetch candidate sequences from a `CandidateGenerator` @@ -3985,7 +3981,6 @@ def _assisted_decoding( if do_sample and len(logits_warper) > 0: for i in range(candidate_length + 1): new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) - print("****** END DRAFT") # 3. Select the accepted tokens. There are two possible cases: # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). @@ -4091,7 +4086,6 @@ def _assisted_decoding( unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) this_peer_finished = unfinished_sequences.max() == 0 - print(f"{this_peer_finished=}") if streamer is not None: streamer.end() From bdda459ca689f1635a88861228b65f1e74e4d027 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 15 Jul 2024 06:11:08 -0700 Subject: [PATCH 08/18] minor --- src/transformers/generation/utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8e7e7818ed64..7c19daf88e8a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2944,8 +2944,7 @@ def _sample( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) - ) - + ) # token selection if do_sample: probs = nn.functional.softmax(next_token_scores, dim=-1) @@ -3940,7 +3939,7 @@ def _assisted_decoding( candidate_logits = candidate_logits.to(self.device) candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] - is_done_candidate = stopping_criteria(candidate_input_ids, candidate_logits) + is_done_candidate = stopping_criteria(candidate_input_ids, None) # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, @@ -3981,6 +3980,7 @@ def _assisted_decoding( if do_sample and len(logits_warper) > 0: for i in range(candidate_length + 1): new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) + # 3. Select the accepted tokens. There are two possible cases: # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). @@ -4097,7 +4097,6 @@ def _assisted_decoding( candidate_generator.assistant_model.generation_config.num_assistant_tokens = ( candidate_generator.num_assistant_tokens ) - if return_dict_in_generate: if self.config.is_encoder_decoder: return GenerateEncoderDecoderOutput( From 1916bcd651b7f6100878ff13808c00cf9ed1036c Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 15 Jul 2024 06:15:28 -0700 Subject: [PATCH 09/18] minor --- src/transformers/generation/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7c19daf88e8a..fb4908a5b475 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2944,7 +2944,8 @@ def _sample( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) - ) + ) + # token selection if do_sample: probs = nn.functional.softmax(next_token_scores, dim=-1) @@ -2965,6 +2966,7 @@ def _sample( model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, ) + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) this_peer_finished = unfinished_sequences.max() == 0 @@ -2974,6 +2976,7 @@ def _sample( if streamer is not None: streamer.end() + if return_dict_in_generate: if self.config.is_encoder_decoder: return GenerateEncoderDecoderOutput( @@ -3980,7 +3983,7 @@ def _assisted_decoding( if do_sample and len(logits_warper) > 0: for i in range(candidate_length + 1): new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) - + # 3. Select the accepted tokens. There are two possible cases: # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). From 6fea2b87b8ae65cfbd81481f30464536227a8b1b Mon Sep 17 00:00:00 2001 From: jmamou Date: Wed, 17 Jul 2024 03:36:40 -0700 Subject: [PATCH 10/18] git ignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index c47bf45cb4f7..a5707744af8f 100644 --- a/.gitignore +++ b/.gitignore @@ -170,3 +170,5 @@ tags slurm-*.out run*.py run*.sh +dataset.py +prompts.py From 7b0103d60b645b8b1853d5dd6d7fce8f6d7d2173 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 2 Sep 2024 05:04:44 -0700 Subject: [PATCH 11/18] adding test to stopping ConfidenceCriteria --- src/transformers/generation/__init__.py | 2 ++ tests/generation/test_stopping_criteria.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index faf5266b84ae..2bea00261951 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -83,6 +83,7 @@ "MaxNewTokensCriteria", "MaxLengthCriteria", "MaxTimeCriteria", + "ConfidenceCriteria", "EosTokenCriteria", "StoppingCriteria", "StoppingCriteriaList", @@ -225,6 +226,7 @@ WhisperTimeStampLogitsProcessor, ) from .stopping_criteria import ( + ConfidenceCriteria, EosTokenCriteria, MaxLengthCriteria, MaxNewTokensCriteria, diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index a04dac96169e..a04dd55f7ff1 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -26,6 +26,7 @@ import torch from transformers.generation import ( + ConfidenceCriteria, EosTokenCriteria, MaxLengthCriteria, MaxTimeCriteria, @@ -100,6 +101,23 @@ def test_eos_token_criteria(self): input_ids[:, -1] = 1 self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False]) + def test_confidence_criteria(self): + criteria = ConfidenceCriteria(assistant_confidence_threshold=0.5) + + vocab_size = 250 + length = 5 + + input_ids = ids_tensor((1, length), vocab_size) + scores = (torch.randn((1,vocab_size)),) + + # Simulate high confidence by setting the probability of the last token to be high + scores[0][0, input_ids[0, -1]] = 10.0 # Logits before softmax + self.assertFalse(criteria(input_ids, scores)) + + # Simulate low confidence by setting the probability of the last token to be low + scores[0][0, input_ids[0, -1]] = -10.0 # Logits before softmax + self.assertTrue(criteria(input_ids, scores)) + def test_validate_stopping_criteria(self): validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10) From 7d4a0959d7b3141312bb376b1879c6a977ee2b08 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 2 Sep 2024 05:40:02 -0700 Subject: [PATCH 12/18] doc + format --- src/transformers/generation/configuration_utils.py | 2 ++ src/transformers/generation/stopping_criteria.py | 4 ++-- src/transformers/generation/utils.py | 9 +++++++-- tests/generation/test_stopping_criteria.py | 4 ++-- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index c532c877bdcf..7cfdea715ea8 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -324,6 +324,8 @@ class GenerationConfig(PushToHubMixin): reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model. - `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call. - `"constant"`: `num_assistant_tokens` stays unchanged during generation + assistant_confidence_threshold (`float`, *optional*, defaults to 0): + The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower than this threshold, the assistant model stops generating tokens, even if the number of speculative tokens has not been reached. prompt_lookup_num_tokens (`int`, *optional*, default to `None`): The number of tokens to be output as candidate tokens. max_matching_ngram_size (`int`, *optional*, default to `None`): diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index ca4775ea555c..255721350963 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -466,14 +466,14 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa is_done = isin_mps_friendly(input_ids[:, -1], self.eos_token_id) return is_done -class ConfidenceCriteria(StoppingCriteria): +class ConfidenceCriteria(StoppingCriteria): def __init__(self, assistant_confidence_threshold): self.assistant_confidence_threshold = assistant_confidence_threshold def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: probs = scores[-1].softmax(-1) - p = probs[0, input_ids[0,-1]].item() + p = probs[0, input_ids[0, -1]].item() if p < self.assistant_confidence_threshold: return True return False diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5424f03d10fa..045025ee45e0 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -970,8 +970,13 @@ def _get_stopping_criteria( criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer)) if generation_config._eos_token_tensor is not None: criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor)) - if generation_config.assistant_confidence_threshold is not None and generation_config.assistant_confidence_threshold > 0: - criteria.append(ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold)) + if ( + generation_config.assistant_confidence_threshold is not None + and generation_config.assistant_confidence_threshold > 0 + ): + criteria.append( + ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold) + ) criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) return criteria diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index a04dd55f7ff1..e8594dcdb07e 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -108,8 +108,8 @@ def test_confidence_criteria(self): length = 5 input_ids = ids_tensor((1, length), vocab_size) - scores = (torch.randn((1,vocab_size)),) - + scores = (torch.randn((1, vocab_size)),) + # Simulate high confidence by setting the probability of the last token to be high scores[0][0, input_ids[0, -1]] = 10.0 # Logits before softmax self.assertFalse(criteria(input_ids, scores)) From 1e6a0e0be3b037a7d8b77d6e08d99114c5b94915 Mon Sep 17 00:00:00 2001 From: jmamou Date: Mon, 2 Sep 2024 05:56:41 -0700 Subject: [PATCH 13/18] add doc --- src/transformers/generation/configuration_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 7cfdea715ea8..60c88292155b 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -325,7 +325,9 @@ class GenerationConfig(PushToHubMixin): - `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call. - `"constant"`: `num_assistant_tokens` stays unchanged during generation assistant_confidence_threshold (`float`, *optional*, defaults to 0): - The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower than this threshold, the assistant model stops generating tokens, even if the number of speculative tokens has not been reached. + The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower + than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_ + (defined by `num_assistant_tokens`) is not yet reached. prompt_lookup_num_tokens (`int`, *optional*, default to `None`): The number of tokens to be output as candidate tokens. max_matching_ngram_size (`int`, *optional*, default to `None`): From 7a005d21da0e6518a9231c8001f0303e071ed244 Mon Sep 17 00:00:00 2001 From: Jonathan Mamou Date: Sun, 8 Sep 2024 14:24:23 +0300 Subject: [PATCH 14/18] Update .gitignore --- .gitignore | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.gitignore b/.gitignore index a5707744af8f..337f2ef2c735 100644 --- a/.gitignore +++ b/.gitignore @@ -167,8 +167,3 @@ tags # ruff .ruff_cache -slurm-*.out -run*.py -run*.sh -dataset.py -prompts.py From 201741bb4e5e16217bd048f1aba61f7bfa789980 Mon Sep 17 00:00:00 2001 From: jmamou Date: Sun, 8 Sep 2024 05:19:08 -0700 Subject: [PATCH 15/18] update docstring and default value of assistant_confidence_threshold --- src/transformers/generation/configuration_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 60c88292155b..37bee501becd 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -324,10 +324,11 @@ class GenerationConfig(PushToHubMixin): reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model. - `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call. - `"constant"`: `num_assistant_tokens` stays unchanged during generation - assistant_confidence_threshold (`float`, *optional*, defaults to 0): + assistant_confidence_threshold (`float`, *optional*, defaults to None): The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_ - (defined by `num_assistant_tokens`) is not yet reached. + (defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead + from Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models . prompt_lookup_num_tokens (`int`, *optional*, default to `None`): The number of tokens to be output as candidate tokens. max_matching_ngram_size (`int`, *optional*, default to `None`): @@ -427,7 +428,7 @@ def __init__(self, **kwargs): # Assistant generation self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5) self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic") - self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0) + self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", None) # Prompt lookup decoding self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None) From 7c90a8a5891bf567c54c3957792a7d8df2cd8343 Mon Sep 17 00:00:00 2001 From: jmamou Date: Sun, 8 Sep 2024 05:19:33 -0700 Subject: [PATCH 16/18] add docstring --- src/transformers/generation/stopping_criteria.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 255721350963..069af00eb1bf 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -468,6 +468,15 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa class ConfidenceCriteria(StoppingCriteria): + """ + This class can be used to stop generation whenever assistant model's confidence in its prediction for the current token is lower than the threshold + `model.generation_config.assistant_confidence_threshold` even if the number of speculative tokens (defined by `num_assistant_tokens`) is not yet reached. + + Args: + assistant_confidence_threshold (`float`): + The value of the threshold. + """ + def __init__(self, assistant_confidence_threshold): self.assistant_confidence_threshold = assistant_confidence_threshold From f457553fb5f76d71f594b115a367c08cd5c2fd83 Mon Sep 17 00:00:00 2001 From: Jonathan Mamou Date: Wed, 11 Sep 2024 12:02:08 +0300 Subject: [PATCH 17/18] Update src/transformers/generation/configuration_utils.py implicit default value (None) Co-authored-by: Joao Gante --- src/transformers/generation/configuration_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 37bee501becd..659b239197d3 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -324,7 +324,7 @@ class GenerationConfig(PushToHubMixin): reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model. - `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call. - `"constant"`: `num_assistant_tokens` stays unchanged during generation - assistant_confidence_threshold (`float`, *optional*, defaults to None): + assistant_confidence_threshold (`float`, *optional*): The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_ (defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead From cd71a924e7d8d0ace8e90206e75a37b0f7ffcb78 Mon Sep 17 00:00:00 2001 From: jmamou Date: Wed, 11 Sep 2024 02:31:14 -0700 Subject: [PATCH 18/18] style fix --- src/transformers/generation/configuration_utils.py | 2 +- src/transformers/generation/stopping_criteria.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 659b239197d3..01ffc6aa4a7a 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -327,7 +327,7 @@ class GenerationConfig(PushToHubMixin): assistant_confidence_threshold (`float`, *optional*): The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_ - (defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead + (defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead from Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models . prompt_lookup_num_tokens (`int`, *optional*, default to `None`): The number of tokens to be output as candidate tokens. diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 069af00eb1bf..b950a69f8b64 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -469,14 +469,14 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa class ConfidenceCriteria(StoppingCriteria): """ - This class can be used to stop generation whenever assistant model's confidence in its prediction for the current token is lower than the threshold - `model.generation_config.assistant_confidence_threshold` even if the number of speculative tokens (defined by `num_assistant_tokens`) is not yet reached. + This class can be used to stop generation whenever assistant model's confidence in its prediction for the current token is lower than the threshold + `model.generation_config.assistant_confidence_threshold` even if the number of speculative tokens (defined by `num_assistant_tokens`) is not yet reached. Args: assistant_confidence_threshold (`float`): The value of the threshold. """ - + def __init__(self, assistant_confidence_threshold): self.assistant_confidence_threshold = assistant_confidence_threshold