gpt2_inference.py
import time

import tiktoken
import torch

from gpt2 import GPT, GPTConfig
def load_model(checkpoint_path, device='cuda'):
    # Load the checkpoint (on newer PyTorch, a pickled config object may
    # additionally require torch.load(..., weights_only=False))
    checkpoint = torch.load(checkpoint_path, map_location=device)
    print(f"Loaded checkpoint from {checkpoint_path}")
    # Recreate the model with the same config it was trained with
    model = GPT(checkpoint['config'])
    # Load the trained weights
    model.load_state_dict(checkpoint['model'])
    # Move the model to the device and switch to eval mode (disables dropout)
    model.to(device)
    model.eval()
    return model
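
# For reference, the checkpoint layout load_model() assumes (the exact
# field contents are an assumption inferred from the keys accessed above):
#
#   checkpoint = {
#       'config': GPTConfig(...),      # model hyperparameters
#       'model':  model.state_dict(),  # trained weights
#   }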

def generate(model, prompt, max_tokens=100, temperature=1.0, top_k=50, device='cuda'):
    # Encode the prompt
    enc = tiktoken.encoding_for_model('gpt2')
    tokens = enc.encode(prompt)
    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(0)  # Add batch dimension
    tokens = tokens.to(device)
    prompt_len = tokens.size(1)

    # Generate up to max_tokens new tokens
    with torch.no_grad():
        while tokens.size(1) < prompt_len + max_tokens:
            # Get logits for the last position and apply temperature
            logits, _ = model(tokens)
            logits = logits[:, -1, :] / temperature
            # Optional: top-k filtering (mask everything below the k-th largest logit)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')
            # Sample the next token from the softmax distribution
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            # Stream the token as it is generated
            print(enc.decode([next_token.item()]), end='', flush=True)
            # Append to the sequence
            tokens = torch.cat((tokens, next_token), dim=1)
            # Stop early if the model emits the end-of-text token
            if next_token.item() == enc.eot_token:
                break

    # Decode the full sequence (prompt + generated tokens)
    generated_text = enc.decode(tokens[0].tolist())
    return generated_text
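
# A minimal sketch of nucleus (top-p) sampling as an alternative to the
# top-k filter in generate(); the helper name and top_p default are
# illustrative, not part of the original script.
def top_p_filter(logits, top_p=0.9):
    # Sort logits descending and accumulate probability mass
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    # Mask tokens once the cumulative mass exceeds top_p...
    to_remove = cum_probs > top_p
    # ...shifted right so the first token over the threshold is kept
    to_remove[:, 1:] = to_remove[:, :-1].clone()
    to_remove[:, 0] = False
    sorted_logits[to_remove] = float('-inf')
    # Scatter the filtered logits back to their original positions
    return logits.scatter(-1, sorted_idx, sorted_logits)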
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")
# Load model
model = load_model('gpt2_log/model_19072.pt', device)
# Generate text
t0 = time.time()
# prompt = "Hello, I'm a language model,"
prompt = input("Enter a prompt: ")
print(prompt, end='')
generated = generate(
    model,
    prompt,
    max_tokens=50,
    temperature=0.8,  # Lower for more focused/conservative text
    top_k=50,         # Helps avoid rare/unwanted tokens
    device=device,
)
t1 = time.time()
print(f"Prompt: {prompt}")
print(f"Generated: {generated}")
print(f"Time taken: {((t1 - t0)*1000):.4f}")