-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
31 lines (24 loc) · 1023 Bytes
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import argparse
from model import get_model
from data import *
# Command-line options. The model hyper-parameters presumably need to match
# the ones the checkpoint was trained with, otherwise loading the weights
# fails with a size mismatch.
parser = argparse.ArgumentParser()
parser.add_argument('--emb_size', default=512, type=int, help='embedding size')
parser.add_argument('--hidden_size', default=512, type=int, help='hidden size')
parser.add_argument("--name", default="hinton", help="enter the name to predict country of origin")
parser.add_argument('--dropout', default=0, type=float, help='dropout value')
parser.add_argument('--checkpoint', default='./checkpoints/model.pt',
                    help='path to the trained model state_dict')
args = parser.parse_args()

# Build the network and restore its trained weights. n_letters and
# n_categories are not defined here; presumably they come from the star
# import of `data`.
model = get_model(n_letters, args.emb_size, args.hidden_size, n_categories, dropout=args.dropout)
# map_location='cpu' lets a checkpoint saved from a GPU run load on a
# CPU-only machine; load_state_dict then copies the tensors into the
# model's existing parameters.
# NOTE(review): torch.load unpickles the file - only load trusted
# checkpoints (consider weights_only=True on torch versions that support it).
model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
def predict(name):
    """Return the category the loaded model assigns to *name*.

    Each character of ``name`` is mapped to its index in ``all_letters``.
    The index sequence is fed to ``model.infer`` as a batch of one. The
    arg-max over dimension 1 of the output is then looked up in
    ``all_categories``.

    Args:
        name: The name to classify. Must be non-empty and contain only
            characters present in ``all_letters``.

    Returns:
        The entry of ``all_categories`` at the highest-scoring index.

    Raises:
        ValueError: If ``name`` is empty or contains a character that is
            not in ``all_letters``.
    """
    if not name:
        raise ValueError('name must be a non-empty string')
    # str.find returns -1 for a missing character; that index would either
    # fail deep inside the model or silently wrap to the last letter, so
    # reject it up front with a clear message.
    unknown = sorted({ch for ch in name if ch not in all_letters})
    if unknown:
        raise ValueError(f'name contains unsupported characters: {unknown!r}')

    indices = [all_letters.find(ch) for ch in name]
    # Shape (1, len(name)): a batch containing a single sequence.
    inp = torch.tensor(indices, dtype=torch.long).unsqueeze(0)

    model.eval()  # evaluation mode (e.g. dropout disabled)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        out = model.infer(inp)
    # out is presumably (1, n_categories) - TODO confirm; .item() below
    # requires exactly one arg-max index.
    _, idx = torch.max(out, dim=1)
    return all_categories[idx.item()]
if __name__ == '__main__':
    # Classify the name given on the command line and report the result.
    predicted = predict(args.name)
    print(f'prediction for {args.name} is {predicted}')