# train.py
import gym
import torch
import psutil
from torch.backends import cudnn

import hkenv
import models
import trainer
import buffer

DEVICE = 'cuda'
# Input resolution is fixed for the whole run, so cuDNN autotuning is a safe win.
cudnn.benchmark = True


def get_model(env: gym.Env, n_frames: int):
    # Observations are (c, h, w); frames are stacked along the channel axis,
    # so the extractor takes n_frames * c input channels.
    c, *shape = env.observation_space.shape
    m = models.SimpleExtractor(shape, n_frames * c,
                               activation='relu', sn=False)
    m = models.DuelingMLP(m, env.action_space.n,
                          activation='relu', noisy=True, sn=False)
    return m.to(DEVICE)
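

# A quick shape sanity check (a sketch; it assumes the env exposes a Gym-style
# observation_space and that grayscale mode yields a single channel -- both
# assumptions, not guaranteed by this file):
#
#     env = hkenv.HKEnvV2((192, 192), rgb=False, gap=0.17, w1=0.8, w2=0.5, w3=-8e-5)
#     model = get_model(env, n_frames=4)
#     dummy = torch.zeros(1, 4 * 1, 192, 192, device=DEVICE)  # (batch, frames * channels, h, w)
#     print(model(dummy).shape)  # expected: (1, env.action_space.n)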


def train(dqn):
    print('training started')
    # Fill the replay buffer with saved exploration episodes before learning.
    dqn.save_explorations(75)
    dqn.load_explorations()
    # dqn.load_explorations('saved/1673839254HornetTweaks/transitions')
    # raise ValueError
    dqn.learn()  # warmup

    saved_rew = float('-inf')
    saved_train_rew = float('-inf')
    for i in range(1, 550):
        print('episode', i)
        rew, loss, lr = dqn.run_episode()
        if rew > saved_train_rew and dqn.eps < 0.11:
            print('new best train model found')
            saved_train_rew = rew
            dqn.save_models('besttrain', online_only=True)
        if i % 10 == 0:
            # Mix in a fully random episode now and then to keep the buffer diverse.
            dqn.run_episode(random_action=True)
        if i >= 100:
            eval_rew = dqn.evaluate()
            if eval_rew > saved_rew:
                print('new best eval model found')
                saved_rew = eval_rew
                dqn.save_models('best', online_only=True)
        dqn.save_models('latest', online_only=True)
        dqn.log({'reward': rew, 'loss': loss, 'total steps': dqn.steps}, i)
        print(f'episode {i} finished, total step {dqn.steps}, learned {dqn.learn_steps}, epsilon {dqn.eps}',
              f'total rewards {round(rew, 3)}, loss {round(loss, 3)}, current lr {round(lr, 8)}',
              f'total memory usage {psutil.virtual_memory().percent}%', sep='\n')
        print()
    dqn.save_models('latest', online_only=False)
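
# Checkpoint scheme used above: 'besttrain' tracks the best training-episode
# reward once exploration is low (eps < 0.11), 'best' tracks the best
# evaluation reward from episode 100 onward, and 'latest' is refreshed every
# episode (online network only), with a full save after the final episode.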


def main():
    n_frames = 4
    # env = hkenv.HKEnv((160, 160), rgb=False, gap=0.165, w1=0.8, w2=0.8, w3=-0.0001)
    env = hkenv.HKEnvV2((192, 192), rgb=False, gap=0.17, w1=0.8, w2=0.5, w3=-8e-5)
    m = get_model(env, n_frames)
    # Uniform sampling; the commented dict below is the prioritized-replay alternative.
    replay_buffer = buffer.MultistepBuffer(180000, n=10, gamma=0.99, prioritized=None)
    # prioritized={
    #     'alpha': 0.5,
    #     'beta': 0.4,
    #     'beta_anneal': 0.6 / 550.
    # })
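    # With n=10 and gamma=0.99 the buffer serves 10-step returns; the target is
    # presumably the standard multistep DQN form (exact details live in buffer/trainer):
    #     G_t = sum_{k=0}^{9} 0.99**k * r_{t+k} + 0.99**10 * max_a Q_target(s_{t+10}, a)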
    dqn = trainer.Trainer(env=env, replay_buffer=replay_buffer,
                          n_frames=n_frames, gamma=0.99, eps=0.,
                          # Noisy layers drive exploration, so epsilon is pinned to 0.
                          eps_func=(lambda val, step: 0.),
                          target_steps=8000,
                          learn_freq=4,
                          model=m,
                          lr=8e-5,
                          lr_decay=False,
                          criterion=torch.nn.MSELoss(),
                          batch_size=32,
                          device=DEVICE,
                          is_double=True,
                          drq=True,
                          svea=True,
                          reset=0,  # no reset
                          n_targets=1,
                          save_suffix='HornetV2',
                          no_save=False)
    train(dqn)


if __name__ == '__main__':
    main()
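

# To launch training (assuming the game client and the capture setup that
# hkenv expects are already running -- a repo-level assumption, not shown here):
#
#     python train.py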