"""
Configuration for Simultaneous Neural Machine Translation
"""
from collections import OrderedDict
# data_home = '/home/thoma/scratch/un16/'
# model_home = '/home/thoma/scratch/simul/'
# data_home = '/mnt/scratch/un16/'
# model_home = '/mnt/scratch/simul/'
data_home = '/cs/natlang-expts/aalineja/dl4mt-simul-trans/data/'
model_home = './model/'
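
# --- Optional sanity check (an added sketch, not part of the original
# training pipeline): with `data_home` hard-coded above, a typo only fails
# deep inside data loading. This helper assumes nothing beyond the config
# dicts built below; it simply warns about any referenced file that is
# missing, so a bad path fails fast.
import os

def check_paths(config):
    """Warn about dataset/dictionary paths in `config` that do not exist."""
    for key in ('datasets', 'valid_datasets', 'dictionaries'):
        for path in config.get(key) or []:
            if not os.path.exists(path):
                print('WARNING: {} refers to a missing file: {}'.format(key, path))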
def pretrain_config():
    """Configuration for pretraining the underlying NMT model."""
    config = dict()

    # training set (source, target)
    config['datasets'] = [data_home + 'all_de-en.de.tok.bpe.shuf',
                          data_home + 'all_de-en.en.tok.bpe.shuf']

    # validation set (source, target)
    # config['valid_datasets'] = [data_home + 'newstest2011.en.tok.bpe',
    #                             data_home + 'newstest2011.de.tok.bpe']
    config['valid_datasets'] = [data_home + 'newstest2011.de.tok.bpe',
                                data_home + 'newstest2011.en.tok.bpe']

    # vocabulary (source, target)
    config['dictionaries'] = [data_home + 'all_de-en.de.tok.bpe.pkl',
                              data_home + 'all_de-en.en.tok.bpe.pkl']

    # save the model to
    config['saveto'] = model_home + 'pretrained/wmt15_bpe2k_uni_de-en.npz'
    config['reload_'] = True

    # model details
    config['dim_word'] = 1028
    config['dim'] = 1028
    config['n_words'] = 20000
    config['n_words_src'] = 20000

    # learning details
    config['decay_c'] = 0
    config['clip_c'] = 1.
    config['use_dropout'] = False
    config['lrate'] = 0.0001
    config['optimizer'] = 'adadelta'
    config['patience'] = 1000
    config['maxlen'] = 50
    config['batch_size'] = 32
    config['valid_batch_size'] = 32
    config['validFreq'] = 5000
    config['dispFreq'] = 20
    config['saveFreq'] = 5000
    config['sampleFreq'] = 500
    config['birnn'] = False
    return config
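
# Usage sketch (hypothetical; the repo's actual pretraining entry point may
# wire this up differently): the dict above is meant to be splatted into the
# NMT trainer, and individual fields can be overridden first, e.g.
#
#     cfg = pretrain_config()
#     cfg['batch_size'] = 64     # trade memory for throughput
#     train(**cfg)               # `train` is the trainer, not defined here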
def rl_config():
    """Configuration for training the agent with the REINFORCE algorithm."""
    config = OrderedDict()  # general configuration

    # workspace
    config['workspace'] = model_home

    # training set (source, target); leave as None and the agent will reuse
    # the corpus saved with the pretrained model
    config['datasets'] = [data_home + 'all_de-en.de.tok.bpe.shuf',
                          data_home + 'all_de-en.en.tok.bpe.shuf']

    # validation set (source, target); leave as None and the agent will reuse
    # the corpus saved with the pretrained model
    config['valid_datasets'] = [data_home + 'newstest2011.de.tok.bpe',
                                data_home + 'newstest2011.en.tok.bpe']

    # vocabulary (source, target); leave as None and the agent will reuse
    # the dictionaries saved with the pretrained model
    config['dictionaries'] = [data_home + 'all_de-en.de.tok.bpe.pkl',
                              data_home + 'all_de-en.en.tok.bpe.pkl']

    # pretrained model (weights and its pickled options)
    config['model'] = model_home + 'pretrained/wmt15_bpe2k_uni_de-en.npz'
    config['option'] = model_home + 'pretrained/wmt15_bpe2k_uni_de-en.npz.pkl'

    # critical training parameters
    config['sample'] = 5
    config['batchsize'] = 10
    config['rl_maxlen'] = 40
    config['target_ap'] = 1   # target average proportion (AP) delay if using AP as reward; 1 disables AP (e.g. 0.75 targets that delay)
    config['target_cw'] = 8   # consecutive-wait (CW) target; cw > 0 enables CW mode

    # under construction
    config['predict'] = True

    # learning rates (policy network and underlying NMT model)
    config['lr_policy'] = 0.000002
    config['lr_model'] = 0.00002

    # policy parameters
    config['prop'] = 0.5             # leave at the default
    config['recurrent'] = True       # use a recurrent agent
    config['layernorm'] = False      # layer normalization for the GRU agent
    config['updater'] = 'REINFORCE'  # 'TRPO' does not work well
    config['act_mask'] = True        # leave at the default

    # old model parameters (possibly unused; leave at the defaults)
    config['step'] = 1
    config['peek'] = 1
    config['s0'] = 1
    config['gamma'] = 1
    config['Rtype'] = 10
    config['maxsrc'] = 10
    config['pre'] = False
    config['coverage'] = False
    config['upper'] = False
    config['finetune'] = True
    config['train_gt'] = False       # when training with ground truth (GT), fix the random agent?
    config['full_att'] = True
    return config
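
# Minimal self-check (an added sketch, assuming only `pretrain_config`,
# `rl_config`, and `check_paths` defined above): running this module directly
# prints both configurations and warns about missing data files.
if __name__ == '__main__':
    for name, builder in (('pretrain', pretrain_config), ('rl', rl_config)):
        print('=== {} config ==='.format(name))
        cfg = builder()
        for k, v in cfg.items():
            print('{:16s} = {!r}'.format(k, v))
        check_paths(cfg)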