cocontrol.py
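"""Command-line entry point for training, replaying, and demonstrating a CoControl PPO agent."""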
import numpy as np
import torch

from collections import namedtuple

from cocontrol.environment import CoControlEnv, CoControlAgent
from cocontrol.training import PPOLearner
from cocontrol.util import print_progress, plot, start_plot, save_plot

PARAMETER_PATH = "parameters.pt"
def run_learner(learner, env, result_path, epochs=15, checkpoint_window=1):
    """Run the training loop, checkpointing whenever the recent average score improves.

    Returns the environment's score history and the collected performances.
    """
    performances = ()
    episode_cnt = 0
    episode_step = 0
    max_avg_score = 0.0

    for cnt, data in enumerate(learner.train(epochs)):
        episode_step += 1
        performance, score, terminal = data

        if terminal:
            episode_cnt += 1
            episode_step = 0

            # Save a checkpoint when the average score over the last
            # checkpoint_window episodes exceeds the best average seen so far.
            window_avg = np.mean(env.get_score_history()[-checkpoint_window:])
            if window_avg > max_avg_score:
                learner.save(result_path)
                print("\nSaved checkpoint in epoch " + str(cnt)
                        + " with avg score: " + str(window_avg) + "\n")
                max_avg_score = window_avg

        print_progress(episode_cnt, episode_step, performance.item(), env.get_score_history(), total=8)

        if terminal:
            print("")

    return env.get_score_history(), performances
def run_dummy(env):
    class DummyPolicy:
        def sample(self, states):
            return torch.rand(env.get_agent_size(), env.get_action_size()) * 2.0 - 1.0

    agent = CoControlAgent(DummyPolicy())

    episode = enumerate(env.generate_episode(agent))
    for count, step_data in episode:
        # Consume the generated steps
        pass
def replay(env, parameter_path):
    print("Replay from " + parameter_path)

    learner = PPOLearner(env=env)
    learner.load(parameter_path)
    replay_agent = learner.get_agent(0.5)

    for _ in range(100):
        episode = env.generate_episode(replay_agent, train_mode=True)
        for _ in episode:
            # Consume the generated steps
            pass

    print("Average score on 100 episodes: " + str(np.mean(env.get_score_history())))
def run_demo(env, parameter_path):
    print("Replay from " + parameter_path)

    learner = PPOLearner(env=env)
    learner.load(parameter_path)
    replay_agent = learner.get_agent(0.2)

    episode = env.generate_episode(replay_agent)
    for count, step_data in enumerate(episode):
        # Consume the generated steps
        pass

    print("Score: " + str(env.get_score()))
def learn(env, epochs, parameter_path):
    print("\nStart learning\n")

    learner = PPOLearner(env=env)
    scores, losses = run_learner(learner, env, parameter_path, epochs=epochs)

    print("\nStore results\n")
    plot(scores, path="scores.png", windows=[1, 100],
            colors=['r', 'g'], labels=["Agent avg", "100 episode avg"])
    torch.save(scores, "scores.pt")

    return scores, losses
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--replay',
            help='Replay 100 episodes from the stored parameters',
            action='store_true', required=False)
    parser.add_argument('-s', '--show',
            help='Demonstrate from the stored parameters',
            action='store_true', required=False)
    parser.add_argument('-d', '--dummy',
            help='Demonstrate the environment with a dummy policy',
            action='store_true', required=False)
    parser.add_argument('-n', '--epochs',
            help='Number of epochs used for training',
            type=int, default=16, required=False)
    args = parser.parse_args()

    env = CoControlEnv()

    if args.replay:
        replay(env, PARAMETER_PATH)
    elif args.show:
        run_demo(env, PARAMETER_PATH)
    elif args.dummy:
        run_dummy(env)
    else:
        scores, losses = learn(env, args.epochs, PARAMETER_PATH)
        plot(scores, windows=[1, 100], colors=['b', 'g'], labels=["Score", "Avg"], path=None)
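# Example invocations (a sketch of assumed usage based on the argument parser
# above; PARAMETER_PATH defaults to parameters.pt as defined in this file):
#
#   python cocontrol.py              # train for the default 16 epochs
#   python cocontrol.py -n 32        # train for 32 epochs
#   python cocontrol.py --replay     # average score over 100 replayed episodes
#   python cocontrol.py --show       # run a single demonstration episode
#   python cocontrol.py --dummy      # run the environment with a random policy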