From f60103c5c0b54cf9a294e27bda95b0aa43e49d35 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Mon, 20 Jan 2025 13:09:45 +0100
Subject: [PATCH] don't pass input_pos_maxp1 to ThunderModules

---
 litgpt/generate/base.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py
index 273fe7497d..572c700a0c 100644
--- a/litgpt/generate/base.py
+++ b/litgpt/generate/base.py
@@ -174,7 +174,10 @@ def generate_fn(
     token = prompt
     prefill_token = True
     input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
-    input_pos_maxp1 = torch.tensor(prompt_size, device=device)
+    if model.__class__.__name__ != 'ThunderModule':
+        input_pos_maxp1 = torch.tensor(prompt_size, device=device)
+    else:
+        input_pos_maxp1 = None
     for current_idx in range(max_returned_tokens - prompt_size):
 
         # Generate the token
@@ -222,7 +225,8 @@ def generate_fn(
             input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)
         else:
             input_pos.add_(1)
-        input_pos_maxp1.add_(1)
+        if input_pos_maxp1 is not None:
+            input_pos_maxp1.add_(1)
 
     # Yield any remaining tokens
     if yielded_idx < len(tokens):
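
Note (not part of the patch): the two hunks implement a single pattern: build the
input_pos_maxp1 scalar tensor only for models that can accept it, and guard every
subsequent in-place update with a None check. A minimal standalone sketch of that
pattern follows, assuming the same class-name test the patch uses; the helper
names init_positions/advance_positions are hypothetical, not litgpt API.

import torch

def init_positions(model: torch.nn.Module, prompt_size: int, device: torch.device):
    # Mirrors the first hunk: Thunder-compiled wrappers (class name
    # 'ThunderModule') get input_pos_maxp1=None instead of a 0-dim tensor.
    input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
    if model.__class__.__name__ != 'ThunderModule':
        input_pos_maxp1 = torch.tensor(prompt_size, device=device)
    else:
        input_pos_maxp1 = None
    return input_pos, input_pos_maxp1

def advance_positions(input_pos, input_pos_maxp1):
    # Mirrors the second hunk: only increment the tensor that was created.
    input_pos.add_(1)
    if input_pos_maxp1 is not None:
        input_pos_maxp1.add_(1)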