-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
91 lines (78 loc) · 2.58 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
from carla_env.env import CarlaEnv
import carla_env.scenarios as scenarios
import carla_env.rewards as rewards
from example_collector import EpisodeCollector
from train import TRAIN_CONFIG, TRAIN_MODEL
import deepq_learner
# Number of valid episodes the collector must gather before the run stops.
MAX_EPISODES = 1000
class DoneError(BaseException):
    """Raised from the episode callback to stop the learner's run loop.

    NOTE(review): deriving from BaseException (not Exception) appears
    deliberate so that broad ``except Exception`` handlers inside the
    learner cannot swallow it — confirm before changing the base class.
    """
    pass
# Test configuration: the training config with evaluation-specific
# overrides merged on top (lane-keeping scenario on Town01, low quality,
# image logging off).
TEST_ENV = {
    **TRAIN_CONFIG,
    "server_map": "/Game/Maps/Town01",
    "reward_function": rewards.REWARD_LANE_KEEP,
    "scenarios": scenarios.TOWN1_LANE_KEEP,
    "log_images": False,
    "quality": "Low",
}
def main():
    """Run evaluation episodes in CARLA and save collected metrics.

    Builds a CarlaEnv from TEST_ENV, hooks an EpisodeCollector into the
    environment callbacks, and drives the DQN learner until MAX_EPISODES
    valid episodes have been collected, then writes the metrics CSVs.
    """
    collector = EpisodeCollector()

    def on_step(py_measurements):
        # Forward every simulation step's measurements to the collector.
        collector.step(py_measurements)

    def on_next():
        # Episode boundary: advance the collector, and stop the learner
        # once enough valid episodes have been gathered.
        collector.next()
        if collector.valid_episodes >= MAX_EPISODES:
            raise DoneError()

    env = CarlaEnv(TEST_ENV)
    env.on_step = on_step
    env.on_next = on_next

    # NOTE(review): hard-coded, machine-specific output path — consider
    # an environment variable or CLI argument.
    carla_out_path = "/media/grant/FastData/carla"
    checkpoint_path = os.path.join(carla_out_path, "checkpoints")
    # makedirs(exist_ok=True) replaces the racy exists()/mkdir() pairs
    # and also creates any missing parent directories in one call.
    os.makedirs(checkpoint_path, exist_ok=True)

    # Learner config. The tiny lr/eps and huge train_freq /
    # learning_starts / target-update values effectively disable learning
    # and exploration, so the run behaves as policy evaluation.
    learn_config = deepq_learner.DEEPQ_CONFIG.copy()
    learn_config.update({
        "gpu_memory_fraction": 0.7,
        "lr": 1e-90,
        "max_timesteps": int(1e8),
        "buffer_size": int(1e3),
        "exploration_fraction": 0.000001,
        "exploration_final_eps": 0.000001,
        "train_freq": 4000000,
        "learning_starts": 1000000,
        "target_network_update_freq": 10000000,
        "gamma": 0.99,
        "prioritized_replay": True,
        "prioritized_replay_alpha": 0.6,
        "checkpoint_freq": 100000000,
        "checkpoint_path": checkpoint_path,
        "print_freq": 1
    })
    learn = deepq_learner.DeepqLearner(env=env, q_func=TRAIN_MODEL, config=learn_config)

    print("Running training....")
    try:
        learn.run()
    except DoneError:
        # Expected: raised by on_next once MAX_EPISODES is reached.
        pass
    except Exception:
        print("Training Failed!")
        raise  # bare raise preserves the original traceback
    finally:
        print("Closing environment.")
        env.close()

    # Determine results and persist them as CSV files.
    results = collector.results()
    print(",".join(str(x) for x in results))
    with open(os.path.join(carla_out_path, 'results.csv'), 'w') as file:
        collector.save_metrics_file(file)
    with open(os.path.join(carla_out_path, 'crashes.csv'), 'w') as file:
        collector.save_crashes_file(file)
    with open(os.path.join(carla_out_path, 'out-of-lanes.csv'), 'w') as file:
        collector.save_out_of_lane_file(file)
# Script entry point: run only when executed directly, not on import.
if __name__ == '__main__':
    main()