Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix assisted decoding #31401

Merged
merged 12 commits into from
Jul 3, 2024
5 changes: 4 additions & 1 deletion src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,8 +1471,11 @@ def _tensor_or_none(token_kwargs, token_self, device=None):
device = self.device

token = token_kwargs if token_kwargs is not None else token_self
if token is None or isinstance(token, torch.Tensor):
if token is None:
return token
elif isinstance(token, torch.Tensor):
return token.to(device)

return torch.tensor(token, device=device, dtype=torch.long)

bos_token_id = _tensor_or_none(
Expand Down
50 changes: 50 additions & 0 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
require_auto_gptq,
require_quanto,
require_torch,
require_torch_gpu,
require_torch_multi_accelerator,
require_torch_multi_gpu,
slow,
torch_device,
)
Expand Down Expand Up @@ -3097,6 +3099,54 @@ def test_return_unprocessed_logit_scores(self):
self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)

@slow
@require_torch_multi_gpu
def test_assisted_decoding_in_different_gpu(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:0")
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
"cuda:1"
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
model.config.pad_token_id = tokenizer.eos_token_id
assistant.config.pad_token_id = tokenizer.eos_token_id

text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
input_length = input_ids.shape[-1]

out = model.generate(
input_ids,
assistant_model=assistant,
max_new_tokens=20,
)
self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)

@slow
@require_torch_gpu
def test_assisted_decoding_in_gpu_cpu(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
"cpu"
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
model.config.pad_token_id = tokenizer.eos_token_id
assistant.config.pad_token_id = tokenizer.eos_token_id

text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
input_length = input_ids.shape[-1]

out = model.generate(
input_ids,
assistant_model=assistant,
max_new_tokens=20,
)
self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)


@require_torch
class TokenHealingTestCase(unittest.TestCase):
Expand Down