train.py
# coding=utf-8
import argparse
import visdom
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset import PoetryDataSet
from model import PoetryNet
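# Command-line options; e.g. run as `python train.py --epochs 20 --batch-size 64 --data-path tang.npz`.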
parser = argparse.ArgumentParser(description='PyTorch learns to recite poems')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=20, metavar='N',
                    help='number of epochs to train (default: 20)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--data-path', type=str, default='tang.npz', metavar='S',
                    help='the path of the data (default: \'tang.npz\')')
parser.add_argument('--model-path', type=str, default='checkpoint/', metavar='S',
                    help='the path where models are saved (default: \'checkpoint/\')')
parser.add_argument('--embedding-dim', type=int, default=128, metavar='N',
                    help='embedding dimension of the model (default: 128)')
parser.add_argument('--hidden-dim', type=int, default=256, metavar='N',
                    help='hidden dimension of the model (default: 256)')
parser.add_argument('--print-per-batch', type=int, default=2, metavar='N',
                    help='how many batches between loss printouts (default: 2)')
args = parser.parse_args()
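# Load the poetry data and wrap it in a shuffled DataLoader.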
dataset = PoetryDataSet(args.data_path)
word2ix = dataset.word2ix
data_loader = DataLoader(dataset,
                         batch_size=args.batch_size,
                         shuffle=True)
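# Visdom instance used below to plot a loss curve for each epoch.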
vis = visdom.Visdom(env=u'poem')
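# Build the model over the full vocabulary and train it with Adam and cross-entropy loss.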
model = PoetryNet(len(word2ix), args.embedding_dim, args.hidden_dim)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.CrossEntropyLoss()
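# Training loop: the network learns to predict the next token of each poem,
# loss is printed every few batches and streamed to visdom.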
for epoch in range(args.epochs):
    for i, data in enumerate(data_loader):
        # (batch_size, seq_len) -> (seq_len, batch_size)
        data = data.long().transpose(1, 0).contiguous()
        optimizer.zero_grad()
        # input is the sequence without its last token, target is the sequence shifted by one
        input, target = data[:-1, :], data[1:, :]
        output, _ = model(input)
        loss = criterion(output, target.view(-1))
        loss.backward()
        optimizer.step()
        if i % args.print_per_batch == 0:
            print('epoch {}, iteration {}, loss = {}'.format(epoch + 1, i, loss.item()))
            x = torch.Tensor([i])
            y = torch.Tensor([loss.item()])
            vis.line(X=x, Y=y, win='epoch{}'.format(epoch),
                     update='append' if i > 0 else None,
                     opts={'title': 'epoch{}'.format(epoch), 'xlabel': 'batch', 'ylabel': 'loss'})
    # save a checkpoint at the end of every epoch
    torch.save(model.state_dict(), args.model_path + "{}.pth".format(epoch))