diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py
index 273fe7497d..572c700a0c 100644
--- a/litgpt/generate/base.py
+++ b/litgpt/generate/base.py
@@ -174,7 +174,10 @@ def generate_fn(
     token = prompt
     prefill_token = True
     input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
-    input_pos_maxp1 = torch.tensor(prompt_size, device=device)
+    if model.__class__.__name__ != 'ThunderModule':
+        input_pos_maxp1 = torch.tensor(prompt_size, device=device)
+    else:
+        input_pos_maxp1 = None
 
     for current_idx in range(max_returned_tokens - prompt_size):
         # Generate the token
@@ -222,7 +225,8 @@ def generate_fn(
             input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)
         else:
             input_pos.add_(1)
-            input_pos_maxp1.add_(1)
+            if input_pos_maxp1 is not None:
+                input_pos_maxp1.add_(1)
 
     # Yield any remaining tokens
     if yielded_idx < len(tokens):
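
For context, a minimal sketch of how the now-optional input_pos_maxp1 might be consumed downstream; the helper name and call shape are illustrative assumptions, not the actual litgpt code. The idea in the patch is that when the model is wrapped by Thunder (class name ThunderModule), the scalar tensor is never created, so both the forward call and the per-step increment can skip it.

from typing import Optional

import torch


def next_token_sketch(
    model: torch.nn.Module,
    input_pos: torch.Tensor,
    x: torch.Tensor,
    input_pos_maxp1: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Hypothetical helper: when input_pos_maxp1 is None (the ThunderModule
    # branch in the patch above), the model is called without the extra
    # position-bound argument; otherwise it is forwarded as a keyword.
    if input_pos_maxp1 is not None:
        logits = model(x, input_pos, input_pos_maxp1=input_pos_maxp1)
    else:
        logits = model(x, input_pos)
    # Greedy pick of the last position, for illustration only.
    return logits[:, -1].argmax(dim=-1, keepdim=True)

Using None as the sentinel keeps the generation loop free of class-name checks: only the one-time setup inspects model.__class__.__name__, and every later step just tests for None.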