diff --git a/ch05/01_main-chapter-code/gpt_generate.py b/ch05/01_main-chapter-code/gpt_generate.py index 0a5b8141..92c00102 100644 --- a/ch05/01_main-chapter-code/gpt_generate.py +++ b/ch05/01_main-chapter-code/gpt_generate.py @@ -270,7 +270,7 @@ def main(gpt_config, input_prompt, model_size): token_ids = generate( model=gpt, - idx=text_to_token_ids(input_prompt, tokenizer), + idx=text_to_token_ids(input_prompt, tokenizer).to(device), max_new_tokens=25, context_size=gpt_config["context_length"], top_k=50,