From 90911099754ae4ef269b00d258e849e945bbb55f Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 13 Jun 2024 12:46:02 -0400 Subject: [PATCH 1/9] fix assisted decoding --- src/transformers/generation/logits_process.py | 7 ++++--- src/transformers/generation/utils.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index b226a059d106..295f8a19dec7 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -154,7 +154,7 @@ def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Te @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id) + eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id.to(vocab_tensor.device)) scores_processed = scores.clone() if input_ids.shape[-1] < self.min_length: scores_processed = torch.where(eos_token_mask, -math.inf, scores) @@ -226,7 +226,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip scores_processed = scores.clone() vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id) + eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id.to(vocab_tensor.device)) if new_tokens_length < self.min_new_tokens: scores_processed = torch.where(eos_token_mask, -math.inf, scores) @@ -1582,7 +1582,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to scores_processed = scores if cur_len == self.max_length - 1: scores_processed = torch.full_like(scores, -math.inf) - scores_processed[:, self.eos_token_id] = 0 + scores_processed[:, self.eos_token_id.to(scores_processed.device)] = 0 return scores_processed @@ -2321,6 +2321,7 @@ def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: scores_processed = scores if self.min_eos_p: + self.eos_token_id = self.eos_token_id.to(scores.device) probs = torch.nn.functional.softmax(scores.float(), dim=-1) # create scores full of -inf except for the eos_token_id early_stop_scores = torch.ones_like(scores) * -float("inf") diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c68190908925..e56060ec76f9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2183,7 +2183,7 @@ def _contrastive_search( has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) top_k = generation_config.top_k penalty_alpha = generation_config.penalty_alpha - pad_token_id = generation_config.pad_token_id + pad_token_id = generation_config.pad_token_id.to(self.device) output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -2591,7 +2591,7 @@ def _sample( `model.config.is_encoder_decoder=True`. 
""" # init values - pad_token_id = generation_config.pad_token_id + pad_token_id = generation_config.pad_token_id.to(self.device) output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -2800,8 +2800,8 @@ def _beam_search( `model.config.is_encoder_decoder=True`. """ # init values - pad_token_id = generation_config.pad_token_id - eos_token_id = generation_config.eos_token_id + pad_token_id = generation_config.pad_token_id.to(self.device) + eos_token_id = generation_config.eos_token_id.to(self.device) output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -3119,8 +3119,8 @@ def _group_beam_search( `model.config.is_encoder_decoder=True`. """ # init values - pad_token_id = generation_config.pad_token_id - eos_token_id = generation_config.eos_token_id + pad_token_id = generation_config.pad_token_id.to(self.device) + eos_token_id = generation_config.eos_token_id.to(self.device) output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -3411,8 +3411,8 @@ def _constrained_beam_search( `model.config.is_encoder_decoder=True`. """ # init values - pad_token_id = generation_config.pad_token_id - eos_token_id = generation_config.eos_token_id + pad_token_id = generation_config.pad_token_id.to(self.device) + eos_token_id = generation_config.eos_token_id.to(self.device) output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores From e2700d7c3345a01b04eb1444b64621e62d7a881c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 14 Jun 2024 09:02:55 -0400 Subject: [PATCH 2/9] check None --- src/transformers/generation/utils.py | 32 +++++++++++++++++++++------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e56060ec76f9..da539995d6f1 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2183,7 +2183,9 @@ def _contrastive_search( has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) top_k = generation_config.top_k penalty_alpha = generation_config.penalty_alpha - pad_token_id = generation_config.pad_token_id.to(self.device) + pad_token_id = ( + generation_config.pad_token_id.to(self.device) if generation_config.pad_token_id is not None else None + ) output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -2591,7 +2593,9 @@ def _sample( `model.config.is_encoder_decoder=True`. """ # init values - pad_token_id = generation_config.pad_token_id.to(self.device) + pad_token_id = ( + generation_config.pad_token_id.to(self.device) if generation_config.pad_token_id is not None else None + ) output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -2800,8 +2804,12 @@ def _beam_search( `model.config.is_encoder_decoder=True`. 
""" # init values - pad_token_id = generation_config.pad_token_id.to(self.device) - eos_token_id = generation_config.eos_token_id.to(self.device) + pad_token_id = ( + generation_config.pad_token_id.to(self.device) if generation_config.pad_token_id is not None else None + ) + eos_token_id = ( + generation_config.eos_token_id.to(self.device) if generation_config.pad_token_id is not None else None + ) output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -3119,8 +3127,12 @@ def _group_beam_search( `model.config.is_encoder_decoder=True`. """ # init values - pad_token_id = generation_config.pad_token_id.to(self.device) - eos_token_id = generation_config.eos_token_id.to(self.device) + pad_token_id = ( + generation_config.pad_token_id.to(self.device) if generation_config.pad_token_id is not None else None + ) + eos_token_id = ( + generation_config.eos_token_id.to(self.device) if generation_config.pad_token_id is not None else None + ) output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -3411,8 +3423,12 @@ def _constrained_beam_search( `model.config.is_encoder_decoder=True`. """ # init values - pad_token_id = generation_config.pad_token_id.to(self.device) - eos_token_id = generation_config.eos_token_id.to(self.device) + pad_token_id = ( + generation_config.pad_token_id.to(self.device) if generation_config.pad_token_id is not None else None + ) + eos_token_id = ( + generation_config.eos_token_id.to(self.device) if generation_config.pad_token_id is not None else None + ) output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores From 4836a200ce927bd9242c5a089b8c07999e68ca45 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 14 Jun 2024 09:08:51 -0400 Subject: [PATCH 3/9] fix typo --- src/transformers/generation/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index da539995d6f1..dbd8a45c091b 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2808,7 +2808,7 @@ def _beam_search( generation_config.pad_token_id.to(self.device) if generation_config.pad_token_id is not None else None ) eos_token_id = ( - generation_config.eos_token_id.to(self.device) if generation_config.pad_token_id is not None else None + generation_config.eos_token_id.to(self.device) if generation_config.eos_token_id is not None else None ) output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states @@ -3131,7 +3131,7 @@ def _group_beam_search( generation_config.pad_token_id.to(self.device) if generation_config.pad_token_id is not None else None ) eos_token_id = ( - generation_config.eos_token_id.to(self.device) if generation_config.pad_token_id is not None else None + generation_config.eos_token_id.to(self.device) if generation_config.eos_token_id is not None else None ) output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states @@ -3427,7 +3427,7 @@ def _constrained_beam_search( generation_config.pad_token_id.to(self.device) if generation_config.pad_token_id is not None else None ) eos_token_id = ( - 
generation_config.eos_token_id.to(self.device) if generation_config.pad_token_id is not None else None + generation_config.eos_token_id.to(self.device) if generation_config.eos_token_id is not None else None ) output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states From ff49a2917ad098d37f95d80737fdc066218059d0 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 17 Jun 2024 07:04:13 -0400 Subject: [PATCH 4/9] fix _prepare_special_tokens --- src/transformers/generation/logits_process.py | 8 ++-- src/transformers/generation/utils.py | 37 ++++++------------- 2 files changed, 15 insertions(+), 30 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 295f8a19dec7..b7948600210f 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -154,7 +154,7 @@ def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Te @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id.to(vocab_tensor.device)) + eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id) scores_processed = scores.clone() if input_ids.shape[-1] < self.min_length: scores_processed = torch.where(eos_token_mask, -math.inf, scores) @@ -226,7 +226,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip scores_processed = scores.clone() vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) - eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id.to(vocab_tensor.device)) + eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id) if new_tokens_length < self.min_new_tokens: scores_processed = torch.where(eos_token_mask, -math.inf, scores) @@ -1582,7 +1582,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to scores_processed = scores if cur_len == self.max_length - 1: scores_processed = torch.full_like(scores, -math.inf) - scores_processed[:, self.eos_token_id.to(scores_processed.device)] = 0 + scores_processed[:, self.eos_token_id] = 0 return scores_processed @@ -1696,7 +1696,6 @@ def __init__( @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: cur_len = input_ids.shape[-1] - self.eos_token_id = self.eos_token_id.to(scores.device) penalties = torch.zeros_like(scores) scores_processed = scores if cur_len > self.regulation_start: @@ -2321,7 +2320,6 @@ def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: scores_processed = scores if self.min_eos_p: - self.eos_token_id = self.eos_token_id.to(scores.device) probs = torch.nn.functional.softmax(scores.float(), dim=-1) # create scores full of -inf except for the eos_token_id early_stop_scores = torch.ones_like(scores) * -float("inf") diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index dbd8a45c091b..8e913ce22b6e 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1467,8 +1467,11 @@ def 
_tensor_or_none(token_kwargs, token_self, device=None): device = self.device token = token_kwargs if token_kwargs is not None else token_self - if token is None or isinstance(token, torch.Tensor): + if token is None: return token + elif isinstance(token, torch.Tensor): + return token.to(device) + return torch.tensor(token, device=device, dtype=torch.long) bos_token_id = _tensor_or_none( @@ -2183,9 +2186,7 @@ def _contrastive_search( has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) top_k = generation_config.top_k penalty_alpha = generation_config.penalty_alpha - pad_token_id = ( - generation_config.pad_token_id.to(self.device) if generation_config.pad_token_id is not None else None - ) + pad_token_id = generation_config.pad_token_id output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -2593,9 +2594,7 @@ def _sample( `model.config.is_encoder_decoder=True`. """ # init values - pad_token_id = ( - generation_config.pad_token_id.to(self.device) if generation_config.pad_token_id is not None else None - ) + pad_token_id = generation_config.pad_token_id output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -2804,12 +2803,8 @@ def _beam_search( `model.config.is_encoder_decoder=True`. """ # init values - pad_token_id = ( - generation_config.pad_token_id.to(self.device) if generation_config.pad_token_id is not None else None - ) - eos_token_id = ( - generation_config.eos_token_id.to(self.device) if generation_config.eos_token_id is not None else None - ) + pad_token_id = generation_config.pad_token_id + eos_token_id = generation_config.eos_token_id output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -3127,12 +3122,8 @@ def _group_beam_search( `model.config.is_encoder_decoder=True`. """ # init values - pad_token_id = ( - generation_config.pad_token_id.to(self.device) if generation_config.pad_token_id is not None else None - ) - eos_token_id = ( - generation_config.eos_token_id.to(self.device) if generation_config.eos_token_id is not None else None - ) + pad_token_id = generation_config.pad_token_id + eos_token_id = generation_config.eos_token_id output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -3423,12 +3414,8 @@ def _constrained_beam_search( `model.config.is_encoder_decoder=True`. 
""" # init values - pad_token_id = ( - generation_config.pad_token_id.to(self.device) if generation_config.pad_token_id is not None else None - ) - eos_token_id = ( - generation_config.eos_token_id.to(self.device) if generation_config.eos_token_id is not None else None - ) + pad_token_id = generation_config.pad_token_id + eos_token_id = generation_config.eos_token_id output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores From c556ecbc7f9e56afde707319fe25067c43ab5e40 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 17 Jun 2024 07:06:05 -0400 Subject: [PATCH 5/9] fix style --- src/transformers/generation/logits_process.py | 1 + src/transformers/generation/utils.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index b7948600210f..b226a059d106 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1696,6 +1696,7 @@ def __init__( @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: cur_len = input_ids.shape[-1] + self.eos_token_id = self.eos_token_id.to(scores.device) penalties = torch.zeros_like(scores) scores_processed = scores if cur_len > self.regulation_start: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8e913ce22b6e..213806a0c649 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2804,7 +2804,7 @@ def _beam_search( """ # init values pad_token_id = generation_config.pad_token_id - eos_token_id = generation_config.eos_token_id + eos_token_id = generation_config.eos_token_id output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -3123,7 +3123,7 @@ def _group_beam_search( """ # init values pad_token_id = generation_config.pad_token_id - eos_token_id = generation_config.eos_token_id + eos_token_id = generation_config.eos_token_id output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -3415,7 +3415,7 @@ def _constrained_beam_search( """ # init values pad_token_id = generation_config.pad_token_id - eos_token_id = generation_config.eos_token_id + eos_token_id = generation_config.eos_token_id output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores From 23ddb4b3b89bd4048704fa795c041a00efec8922 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 19 Jun 2024 15:32:06 -0400 Subject: [PATCH 6/9] add tests for assisted decoding --- tests/generation/test_utils.py | 44 ++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 1981f5a63919..1a323c1636fc 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3067,6 +3067,50 @@ def test_return_unprocessed_logit_scores(self): self.assertTrue(y_prob > 0.001 and n_prob > 0.001) self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0) + @slow + @require_torch_multi_gpu + def test_assisted_decoding_in_different_gpu(self): + # 
PT-only test: TF doesn't support assisted decoding yet. + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:0") + assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:1") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") + model.config.pad_token_id = tokenizer.eos_token_id + assistant.config.pad_token_id = tokenizer.eos_token_id + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + input_length = input_ids.shape[-1] + + out = model.generate( + input_ids, + assistant_model=assistant, + max_new_tokens=20, + ) + self.assertTrue(input_length <= out.shape[-1] <= 20) + + @slow + @require_torch_gpu + def test_assisted_decoding_in_different_gpu(self): + # PT-only test: TF doesn't support assisted decoding yet. + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda") + assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cpu") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") + model.config.pad_token_id = tokenizer.eos_token_id + assistant.config.pad_token_id = tokenizer.eos_token_id + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + input_length = input_ids.shape[-1] + + out = model.generate( + input_ids, + assistant_model=assistant, + max_new_tokens=20, + ) + self.assertTrue(input_length <= out.shape[-1] <= 20) + @require_torch class TokenHealingTestCase(unittest.TestCase): From 63360e7e7c7b4a5e600853bc384fd0fd5368eb78 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 19 Jun 2024 14:29:44 -0400 Subject: [PATCH 7/9] fix style --- tests/generation/test_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 1a323c1636fc..e27f7f38ef77 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -30,7 +30,9 @@ require_auto_gptq, require_quanto, require_torch, + require_torch_gpu, require_torch_multi_accelerator, + require_torch_multi_gpu, slow, torch_device, ) @@ -3091,7 +3093,7 @@ def test_assisted_decoding_in_different_gpu(self): @slow @require_torch_gpu - def test_assisted_decoding_in_different_gpu(self): + def test_assisted_decoding_in_gpu_cpu(self): # PT-only test: TF doesn't support assisted decoding yet. model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda") assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cpu") From 120ace0a4b3809b84a38e92abb7831298e98a656 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 19 Jun 2024 14:31:53 -0400 Subject: [PATCH 8/9] fix lint --- tests/generation/test_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index e27f7f38ef77..45dc054a43c0 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3074,7 +3074,9 @@ def test_return_unprocessed_logit_scores(self): def test_assisted_decoding_in_different_gpu(self): # PT-only test: TF doesn't support assisted decoding yet. 
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:0") - assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:1") + assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( + "cuda:1" + ) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") model.config.pad_token_id = tokenizer.eos_token_id assistant.config.pad_token_id = tokenizer.eos_token_id @@ -3096,7 +3098,9 @@ def test_assisted_decoding_in_different_gpu(self): def test_assisted_decoding_in_gpu_cpu(self): # PT-only test: TF doesn't support assisted decoding yet. model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda") - assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cpu") + assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( + "cpu" + ) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") model.config.pad_token_id = tokenizer.eos_token_id assistant.config.pad_token_id = tokenizer.eos_token_id From 07eee580134105a6c12579f46c7037b3abbb39ed Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 1 Jul 2024 06:07:11 -0400 Subject: [PATCH 9/9] fix tests check --- tests/generation/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index acf86d837058..93622026f91e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3121,7 +3121,7 @@ def test_assisted_decoding_in_different_gpu(self): assistant_model=assistant, max_new_tokens=20, ) - self.assertTrue(input_length <= out.shape[-1] <= 20) + self.assertTrue(input_length <= out.shape[-1] <= input_length + 20) @slow @require_torch_gpu @@ -3145,7 +3145,7 @@ def test_assisted_decoding_in_gpu_cpu(self): assistant_model=assistant, max_new_tokens=20, ) - self.assertTrue(input_length <= out.shape[-1] <= 20) + self.assertTrue(input_length <= out.shape[-1] <= input_length + 20) @require_torch
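
For reference, the end-to-end scenario this series targets — assisted generation with the main model and the assistant (draft) model on different devices — follows the pattern exercised by the tests added in patch 6. A minimal usage sketch is below; the tiny-random checkpoint, device placement, and pad-token setup are taken from those tests, and the snippet is illustrative only, not part of the patches themselves:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"

    # Main model on one GPU, assistant model on a second GPU (or "cpu").
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda:0")
    assistant = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda:1")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    model.config.pad_token_id = tokenizer.eos_token_id
    assistant.config.pad_token_id = tokenizer.eos_token_id

    inputs = tokenizer(["Hello world"], return_tensors="pt").to("cuda:0")

    # With _tensor_or_none / _prepare_special_tokens now placing special-token
    # tensors on the main model's device (patch 4), this call should no longer
    # trip a cross-device error when the assistant lives elsewhere.
    out = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
    print(tokenizer.batch_decode(out, skip_special_tokens=True))

The multi-GPU variant assumes two visible CUDA devices; with a single GPU the assistant can instead stay on "cpu", mirroring the second test added in patch 6.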