-
Notifications
You must be signed in to change notification settings - Fork 18
/
train-teaching.py
executable file
·136 lines (117 loc) · 5.43 KB
/
train-teaching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python
import warnings
# This removes the annoying warning from h5py
warnings.simplefilter(action='ignore', category=FutureWarning)
import click
import os
import tpprl.exp_teacher as ET
from tpprl.utils import _now
import tensorflow as tf
import numpy as np
import sys
def _read_initial_difficulties(csv_path):
    """Parse the per-item initial difficulties from the first line of a CSV.

    Only the first line of *csv_path* is read. It is split on commas and each
    non-empty field is converted with ``float`` (which tolerates surrounding
    whitespace, including the trailing newline).

    Raises:
        click.BadParameter: if the first line contains no values at all.
        ValueError: if a non-empty field is not a valid float.
    """
    with open(csv_path, 'r') as f:
        first_line = f.readline()
    # Skip empty fields so a trailing comma does not crash on float('').
    n_0s = [float(field) for field in first_line.split(',') if field.strip()]
    if not n_0s:
        raise click.BadParameter(
            'no initial difficulties found on the first line of "{}".'
            .format(csv_path),
            param_hint='INITIAL_DIFFICULTY_CSV',
        )
    return n_0s


def _past_until(step, until):
    """Return True, after printing a notice, if *step* already exceeds *until*."""
    if step <= until:
        return False
    print(
        _now(),
        'Have already run {} > {} iterations, not going further.'
        .format(step, until)
    )
    return True


@click.command()
@click.argument('initial_difficulty_csv', type=click.Path(exists=True))
@click.argument('alpha', type=float)
@click.argument('beta', type=float)
@click.argument('output_dir', type=click.Path())
@click.option('--epochs', 'epochs', help='How many epochs to train for.', default=1000, show_default=True)
@click.option('--num-iters', 'num_iters', help='How many iterations in each epoch.', default=50, show_default=True)
@click.option('--save-every', 'save_every', help='How many epochs to save a copy of the parameters to disk.', default=200, show_default=True)
@click.option('--T', 'T', help='The learning duration (in days).', default=14, show_default=True)
@click.option('--tau', 'tau', help='Delay before the test.', default=2, show_default=True)
@click.option('--with-summaries/--no-with-summaries', 'with_summaries', help='Whether to save summaries.', default=False, show_default=True)
@click.option('--summary-suffix', 'summary_suffix', help='Suffix to add to the summary directory', default='', show_default=True)
@click.option('--only-cpu/--no-only-cpu', 'only_cpu', help='Whether to use only the CPU for training.', default=True, show_default=True)
@click.option('--q', 'q', help='Weight for the intensity regularizer.', default=0.00025, show_default=True)
@click.option('--q-entropy', 'q_entropy', help='Weight for the entropy regularizer.', default=0.002, show_default=True)
@click.option('--restore/--no-restore', 'should_restore', help='Whether to restore from the last save or overwrite the previous progress (if it exists).', default=True, show_default=True)
@click.option('--until', 'until', help='How many steps of iterations to run.', default=20000, show_default=True)
@click.option('--with-mp/--no-with-mp', 'with_MP', help='Whether to use multiprocessing module to run simulations in parallel.', default=True, show_default=True)
@click.option('--with-recall-probs/--no-with-recall-probs', 'with_recall_probs', help='Whether to provide true probability of recall or only binary feedback to the agent.', default=False, show_default=True)
@click.option('--with-zero-wt/--no-with-zero-wt', 'with_zero_wt', help='Force wt to be zero.', default=False, show_default=True)
def cmd(initial_difficulty_csv, alpha, beta, output_dir, should_restore,
        T, tau, with_summaries, summary_suffix, only_cpu, q, q_entropy,
        epochs, num_iters, save_every, until, with_MP, with_recall_probs,
        with_zero_wt):
    """Train an optimal teacher and save the results to OUTPUT_DIR.

    The initial difficulty of each item is read from the first line of
    INITIAL_DIFFICULTY_CSV (comma-separated numbers); ALPHA and BETA are
    applied to every item.
    """
    n_0s = _read_initial_difficulties(initial_difficulty_csv)
    num_items = len(n_0s)

    # Every item gets the same ALPHA and BETA.
    scenario_opts = {
        'T': T,
        'tau': tau,
        'n_0s': n_0s,
        'alphas': np.ones(num_items) * alpha,
        'betas': np.ones(num_items) * beta,
    }

    summary_dir = os.path.join(output_dir, 'summary/train-{}'.format(summary_suffix))
    save_dir = os.path.join(output_dir, 'save/')
    os.makedirs(summary_dir, exist_ok=True)
    os.makedirs(save_dir, exist_ok=True)

    teacher_opts = ET.mk_def_teacher_opts(
        num_items=num_items,
        hidden_dims=8,
        learning_rate=0.02,
        decay_rate=0.02,
        summary_dir=summary_dir,
        save_dir=save_dir,
        batch_size=32,
        only_cpu=only_cpu,
        T=T,
        tau=tau,
        q=q,
        q_entropy=q_entropy,
        learning_bump=1.0,
        decay_steps=10,
        scenario_opts=scenario_opts,
        set_wt_zero=with_zero_wt,
    )

    config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False
    )
    # Let TensorFlow grow GPU memory on demand instead of reserving it all.
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    teacher = ET.ExpRecurrentTeacher(
        _opts=teacher_opts,
        sess=sess,
        num_items=num_items
    )
    teacher.initialize(finalize=True)

    # save_dir is guaranteed to exist here (created above), so only the flag
    # decides whether a restore is attempted.
    if should_restore:
        try:
            teacher.restore()
        except (FileNotFoundError, AttributeError):
            warnings.warn('"{}" exists, but no save files were found. Not restoring.'
                          .format(save_dir))
        else:
            global_steps = teacher.sess.run(teacher.global_step)
            print(_now(), "Restored successfully to step {}.".format(global_steps))

    if _past_until(teacher.sess.run(teacher.global_step), until):
        # Honour the "not going further" notice: do not train another epoch.
        return

    for _ in range(epochs):
        # Flush buffered output before each (potentially long) epoch,
        # presumably so progress shows up promptly in redirected logs.
        sys.stdout.flush()
        teacher.train_many(
            num_iters=num_iters,
            init_seed=42,
            with_summaries=with_summaries,
            with_MP=with_MP,
            with_memorize_loss=False,
            save_every=save_every,
            with_recall_probs=with_recall_probs,
        )
        if _past_until(teacher.sess.run(teacher.global_step), until):
            break
if __name__ == '__main__':
    # Script entry point: click parses sys.argv and dispatches to the command.
    cmd()