-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
317 lines (265 loc) · 10.9 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
import argparse
import logging
import multiprocessing as mp
import os
import time
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from environments import SimulatedSpe_edEnv, WebsocketEnv
from environments.logging import CloudUploader, Spe_edLogger
from environments.spe_ed import SavedGame
from heuristics import PathLengthHeuristic
from policies import HeuristicPolicy, load_named_policy
from tournament.tournament import run_tournament
# Configure root logging: INFO and above, each line as "<timestamp> <LEVEL>: <message>".
logging.basicConfig(
    format="%(asctime)s %(levelname)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

# Default (width, height) in pixels for game windows and rendered videos.
default_window_size = (720, 720)
def play(env, pol, show=False, render_file=None, fps=10, logger=None, silent=True, window_size=default_window_size):
    """Play one game in ``env``, choosing every action with ``pol``.

    Args:
        env: Environment providing ``reset``, ``step``, ``render`` and ``game_state``.
        pol: Policy; ``pol.act(*obs)`` returns the action for the current observation.
        show: Display the game in a matplotlib window. If ``env.render`` returns a falsy
            value during the game, the function returns immediately without logging.
        render_file: If given, path of an .mp4 file the game is recorded to.
        fps: Frame rate of the recorded video.
        logger: If given, ``logger.log(states, time_limits)`` is called once the game is
            over with every game state (initial state included) and, for a
            ``WebsocketEnv``, the time limit observed after each step.
        silent: Hide the progress bar.
        window_size: ``(width, height)`` of the window/video in pixels.
    """
    width, height = window_size[0], window_size[1]

    def rgb_frame():
        # Pass a C-contiguous copy of the rendered frame to the video writer.
        return env.render(mode="rgb_array", screen_width=width, screen_height=height).copy(order="C")

    obs = env.reset()
    if show and not env.render(screen_width=width, screen_height=height):
        return

    writer = None
    try:
        if render_file is not None:  # Initialize video writer
            from imageio_ffmpeg import write_frames

            writer = write_frames(render_file, window_size, fps=fps, codec="libx264", quality=8)
            writer.send(None)  # seed the generator
            writer.send(rgb_frame())

        if logger is not None:  # Log initial state
            states = [env.game_state()]
            time_limits = []

        done = False
        with tqdm(disable=silent) as pbar:
            while not done:
                action = pol.act(*obs)
                obs, _, done, _ = env.step(action)
                if show and not env.render(screen_width=width, screen_height=height):
                    # render returned falsy (presumably the window was closed): stop without logging
                    return
                if writer is not None:
                    writer.send(rgb_frame())
                if logger is not None:
                    states.append(env.game_state())
                    if isinstance(env, WebsocketEnv):
                        time_limits.append(env.time_limit)
                pbar.update()

        if logger is not None:
            logger.log(states, time_limits)
    finally:
        # Finalize the video even when the game ends early or raises; previously the
        # writer (and its ffmpeg process) leaked in those cases.
        if writer is not None:
            writer.close()

    if show:
        # Keep showing the final state until render returns falsy
        while env.render(screen_width=width, screen_height=height):
            plt.pause(0.01)  # Sleep
def show_logfile(log_file, window_size=default_window_size):
    """Interactively replay a logged game in a matplotlib window.

    A slider below the board selects the time step ``t``. A text panel on the right lists
    the player states at ``t``, followed by the actions inferred from step ``t`` to
    ``t + 1``. For the last step, it instead shows "win"/"inactive" per player, based on
    whether the player is still active. Blocks in ``plt.show()`` until the window is closed.

    Args:
        log_file: Log file to load with ``SavedGame.load``.
        window_size: ``(width, height)`` of the figure in pixels (at 100 dpi).
    """
    from matplotlib.widgets import Slider
    from visualization import Spe_edAx

    game = SavedGame.load(log_file)
    if game.you is not None:
        game.move_controlled_player_to_front()
    last_t = len(game.data) - 1

    def format_state(t):
        """Text for the side panel at time step ``t``."""
        players = game.player_states[t]
        if t < last_t:
            actions = [str(a) for a in game.infer_actions(t)]
        else:
            # No following step to infer actions from: report the final outcome instead.
            actions = ["win" if p.active else "inactive" for p in players]
        return "Players:\n" + "\n".join(str(p) for p in players) + "\n" + "\nActions:\n" + "\n".join(actions) + "\n"

    fig = plt.figure(figsize=(window_size[0] / 100, window_size[1] / 100), dpi=100)
    ax1 = plt.subplot(1, 1, 1)
    viewer = Spe_edAx(fig, ax1, game.cell_states[0], game.player_states[0])
    plt.tight_layout()
    # Leave room at the bottom for the slider and on the right for the text panel.
    plt.subplots_adjust(bottom=0.1, right=0.6)
    slider = Slider(plt.axes([0.1, 0.025, 0.8, 0.03]), "t", 0, last_t, valinit=0, valstep=1, valfmt="%d")
    text_box = fig.text(0.61, 0.975, format_state(0), ha="left", va="top")

    def change_t(val):
        t = int(slider.val)
        viewer.update(game.cell_states[t], game.player_states[t])
        text_box.set_text(format_state(t))

    slider.on_changed(change_t)
    plt.show()
def render_logfile(log_file, fps=10, silent=False, window_size=default_window_size):
    """Render a logged game to an .mp4 video.

    The resulting .mp4 is placed alongside the log file, named like it with the trailing
    five characters (the ``.json`` extension) replaced by ``.mp4``. The figure as it looks
    after the last frame is embedded into the video as an attached picture (thumbnail).
    Intermediate files live in a private temporary directory that is always removed.

    Args:
        log_file: Log file to render (``pathlib.Path`` or path string).
        fps: FPS of generated video.
        silent: Show no progress bar.
        window_size: ``(width, height)`` of the video in pixels (figure drawn at 100 dpi).
    """
    import subprocess
    import tempfile

    from imageio_ffmpeg import get_ffmpeg_exe
    from visualization import Spe_edAx, render_video

    log_file = Path(log_file)
    out_file = log_file.parent / (log_file.name[:-5] + ".mp4")

    game = SavedGame.load(log_file)
    if game.you:
        game.move_controlled_player_to_front()

    fig = plt.figure(
        figsize=(window_size[0] / 100, window_size[1] / 100),
        dpi=100,
        tight_layout=True,
    )
    try:
        ax = plt.subplot(1, 1, 1)
        viewer = Spe_edAx(fig, ax, game.cell_states[0], game.player_states[0])
        width, height = fig.canvas.get_width_height()

        def frames():
            """Draw all game states, yielding each as a (height, width, 3) uint8 RGB array."""
            for i in tqdm(range(len(game.cell_states)), desc=f"Rendering {log_file.name}", disable=silent):
                viewer.update(game.cell_states[i], game.player_states[i])
                fig.canvas.draw()
                # Image buffers are row-major, i.e. (height, width) - reshaping with
                # (width, height) scrambles non-square frames. buffer_rgba replaces
                # tostring_rgb, which is deprecated since Matplotlib 3.8.
                rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape(height, width, 4)
                yield np.ascontiguousarray(rgba[:, :, :3])  # drop alpha channel

        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_video = Path(tmp_dir) / "video.mp4"
            tmp_thumbnail = Path(tmp_dir) / "thumbnail.jpg"
            # Render video to temp file
            render_video(tmp_video, frames(), width, height, fps=fps)
            # Create thumbnail from the figure's final state
            fig.savefig(tmp_thumbnail)
            # Join both into the output file next to the log, the image as attached picture
            result = subprocess.run(
                [
                    get_ffmpeg_exe(),
                    "-i",
                    str(tmp_video),
                    "-i",
                    str(tmp_thumbnail),
                    "-y",
                    "-map",
                    "0",
                    "-map",
                    "1",
                    "-c",
                    "copy",
                    "-disposition:v:1",
                    "attached_pic",
                    "-v",
                    "warning",
                    str(out_file),
                ]
            )
            if result.returncode != 0:
                logging.warning(f"ffmpeg exited with code {result.returncode} while writing {out_file}")
    finally:
        # Always release the figure, also when rendering fails (matters in long-lived pool workers).
        plt.close(fig)
if __name__ == "__main__":
    import sys

    def _parse_bool(value):
        """Parse a boolean command-line value.

        argparse's ``type=bool`` turns every non-empty string (including "False") into True,
        so an explicit parser is needed.
        """
        normalized = value.strip().lower()
        if normalized in ("", "0", "false", "no", "n", "off"):
            return False
        if normalized in ("1", "true", "yes", "y", "on"):
            return True
        raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")

    parser = argparse.ArgumentParser(description="spe_ed")
    parser.add_argument(
        "mode",
        nargs="?",
        choices=["play", "replay", "render_logdir", "plot", "tournament", "tournament-plot"],
        default="play",
    )
    parser.add_argument("--show", action="store_true", help="Display games using an updating matplotlib plot.")
    parser.add_argument("--render-file", type=str, default=None, help="File to render to. Should end with .mp4")
    parser.add_argument(
        "--sim",
        action="store_true",
        help="The simulator environment runs a local simulation of Spe_ed instead of using the webserver.",
    )
    parser.add_argument("--log-file", type=str, default=None, help="Path to a log file, used to load and replay games.")
    parser.add_argument("--log-dir", type=str, default=None, help="Directory for storing or retrieving logs.")
    parser.add_argument(
        "--t-config",
        type=str,
        default="./tournament/tournament_config.py",
        help="Path of the tournament config file containing which settings to run.",
    )
    parser.add_argument("--upload", action="store_true", help="Upload generated log to cloud server.")
    parser.add_argument("--fps", type=int, default=10, help="FPS for rendering.")
    parser.add_argument(
        "--cores", type=int, default=None, help="Number of cores for multiprocessing, default uses all."
    )
    parser.add_argument("--repeat", type=_parse_bool, default=False, help="Play endlessly.")
    args = parser.parse_args()

    def _log_dir_or_exit(must_exist=True):
        """Return --log-dir as a Path; log an error and exit with status 1 if it is missing
        or (when ``must_exist``) not an existing directory."""
        if args.log_dir is None:
            logging.error(f"--log-dir is required for mode {args.mode}")
            sys.exit(1)
        log_dir = Path(args.log_dir)
        if must_exist and not log_dir.is_dir():
            logging.error(f"{log_dir} is not a directory")
            sys.exit(1)
        return log_dir

    if args.mode == "render_logdir":
        log_dir = _log_dir_or_exit()
        # Render every .json log that does not have a matching .mp4 yet
        log_files = [
            log_file
            for log_file in log_dir.iterdir()
            if log_file.name.endswith(".json") and not (log_dir / (log_file.name[:-5] + ".mp4")).exists()
        ]
        with mp.Pool(args.cores) as pool, tqdm(desc="Rendering games", total=len(log_files)) as pbar:

            def report_failure(path, exc):
                # Without an error_callback, worker exceptions are silently dropped
                # and the progress bar never reaches its total.
                logging.error(f"Rendering {path} failed", exc_info=exc)
                pbar.update()

            for log_file in log_files:
                pool.apply_async(
                    render_logfile,
                    (log_file, args.fps, True),
                    callback=lambda _: pbar.update(),
                    error_callback=lambda exc, path=log_file: report_failure(path, exc),
                )
            pool.close()
            pool.join()
    elif args.mode == "replay":
        if args.log_file is None:
            logging.error("--log-file is required for mode replay")
            sys.exit(1)
        show_logfile(args.log_file)
    elif args.mode == "tournament":
        from statistics import create_tournament_plots  # project-local module, not the stdlib one

        log_dir = _log_dir_or_exit(must_exist=False)
        run_tournament(args.show, log_dir, args.t_config, args.cores)
        create_tournament_plots(log_dir, log_dir.parent)
    elif args.mode == "tournament-plot":
        from statistics import create_tournament_plots  # project-local module, not the stdlib one

        log_dir = _log_dir_or_exit()
        create_tournament_plots(log_dir, log_dir.parent)
    elif args.mode == "plot":
        from statistics import create_plots  # project-local module, not the stdlib one

        log_dir = _log_dir_or_exit()
        create_plots(log_dir, log_dir.parent / "statistics.csv")
    else:
        # Create logger
        if args.log_dir is not None:
            logger_callbacks = []
            if args.upload:
                logger_callbacks.append(
                    CloudUploader(
                        os.environ["CLOUD_URL"],
                        os.environ["CLOUD_USER"],
                        os.environ["CLOUD_PASSWORD"],
                        remote_dir="logs/",
                    ).upload
                )
            logger = Spe_edLogger(args.log_dir, logger_callbacks)
        else:
            logger = None

        # Create environment
        if args.sim:
            env = SimulatedSpe_edEnv(40, 40, [HeuristicPolicy(PathLengthHeuristic(10)) for _ in range(5)])
        else:
            env = WebsocketEnv(os.environ["URL"], os.environ["KEY"], os.environ["TIME_URL"])

        # Create policy
        pol = load_named_policy("GarrukV3")

        while True:
            try:
                play(
                    env,
                    pol,
                    show=args.show,
                    render_file=args.render_file,
                    fps=args.fps,
                    logger=logger,
                    silent=args.repeat,
                )
            except Exception:
                logging.exception("Exception during play")
                if args.repeat:
                    time.sleep(60)  # Sleep for a bit and try again (only when another game follows)
            if not args.repeat:
                break