lookahead with do_sample=True does not take temperature, top_k, top_p #26

Open
learning-chip opened this issue Apr 23, 2024 · 2 comments

@learning-chip commented Apr 23, 2024

Here, lookahead_generation is called without a logits_warper argument:

elif generation_mode == GenerationMode.LOOKAHEAD_GENERATION:
    # 11. run greedy search
    return self.lookahead_generation(
        input_ids,
        logits_processor=logits_processor,
        stopping_criteria=stopping_criteria,
        pad_token_id=generation_config.pad_token_id,
        eos_token_id=generation_config.eos_token_id,
        output_scores=generation_config.output_scores,
        return_dict_in_generate=generation_config.return_dict_in_generate,
        synced_gpus=synced_gpus,
        streamer=streamer,
        **model_kwargs,
    )

In the original sample path, logits_warper is passed in and used to modify next_token_scores:

return self.sample(
    input_ids,
    logits_processor=logits_processor,
    logits_warper=logits_warper,
    stopping_criteria=stopping_criteria,
    pad_token_id=generation_config.pad_token_id,
    eos_token_id=generation_config.eos_token_id,
    output_scores=generation_config.output_scores,
    return_dict_in_generate=generation_config.return_dict_in_generate,
    synced_gpus=synced_gpus,
    streamer=streamer,
    **model_kwargs,
)

and the warper chain itself is built from temperature, top_k, top_p, and so on:

        if generation_config.temperature is not None and generation_config.temperature != 1.0:
            warpers.append(TemperatureLogitsWarper(generation_config.temperature))
        if generation_config.top_k is not None and generation_config.top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
        if generation_config.top_p is not None and generation_config.top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))

https://github.com/huggingface/transformers/blob/09f9f566de83eef1f13ee83b5a1bbeebde5c80c1/src/transformers/generation/utils.py#L728-L733
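To illustrate the user-visible effect (a hedged sketch; the exact lookahead invocation in this repo may differ, and the parameter values here are arbitrary):

# illustrative generate() call: on the LOOKAHEAD_GENERATION path these
# settings never reach the logits, so decoding samples at temperature 1.0
outputs = model.generate(
    input_ids,
    do_sample=True,
    temperature=0.7,  # silently ignored: no logits_warper in lookahead_generation
    top_k=50,         # silently ignored
    top_p=0.9,        # silently ignored
)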

This warper step is never applied inside lookahead_generation, so with do_sample=True the temperature is effectively always 1.0 and top_k / top_p are ignored.
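A possible fix direction (untested sketch, mirroring the upstream sample() implementation): forward logits_warper=logits_warper in the LOOKAHEAD_GENERATION branch above, accept it as a parameter of lookahead_generation, and apply it before token selection:

# sketch of the token-selection step inside lookahead_generation's loop;
# logits_warper is the newly forwarded parameter
next_token_scores = logits_processor(input_ids, next_token_logits)
if do_sample:
    # temperature / top_k / top_p take effect here via the warper chain
    next_token_scores = logits_warper(input_ids, next_token_scores)
    probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
    next_tokens = torch.argmax(next_token_scores, dim=-1)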

@zheyishine (Collaborator) commented:

Thank you, we will fix it soon.

@snippetzero commented:

@zheyishine Hi, has there been any progress?
