-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathexample_ray_ma_teams.py
62 lines (54 loc) · 1.87 KB
/
example_ray_ma_teams.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import ray
from ray import tune
from soccer_twos import EnvType
from utils import create_rllib_env
NUM_ENVS_PER_WORKER = 3


def _probe_spaces(env_config, make_env=None):
    """Return ``(observation_space, action_space)`` for an env built from *env_config*.

    The env is constructed only to read its spaces. It is closed in a
    ``finally`` so it is released even if reading a space raises.

    Args:
        env_config: Config dict passed to the env factory.
        make_env: Env factory. Defaults to ``create_rllib_env``; exposed so the
            probe can be exercised with a lightweight stand-in.

    Returns:
        A ``(observation_space, action_space)`` tuple read from the env.
    """
    if make_env is None:
        make_env = create_rllib_env
    env = make_env(env_config)
    try:
        return env.observation_space, env.action_space
    finally:
        env.close()


def _map_to_default_policy(*args, **kwargs):
    """Map every agent to the single shared ``"default"`` policy.

    Accepts arbitrary arguments because the arguments RLlib passes to
    ``policy_mapping_fn`` differ between RLlib versions.
    """
    return "default"


if __name__ == "__main__":
    ray.init()

    tune.registry.register_env("Soccer", create_rllib_env)
    obs_space, act_space = _probe_spaces({"variation": EnvType.multiagent_team})

    analysis = tune.run(
        "PPO",
        name="PPO_teams_1",
        config={
            # system settings
            "num_gpus": 1,
            "num_workers": 6,
            "num_envs_per_worker": NUM_ENVS_PER_WORKER,
            "log_level": "INFO",
            "framework": "torch",
            # RL setup: every agent is mapped to one shared, trainable policy.
            "multiagent": {
                "policies": {
                    # RLlib policy spec: (policy class or None for the trainer's
                    # default, obs space, action space, per-policy config overrides)
                    "default": (None, obs_space, act_space, {}),
                },
                # A plain callable; wrapping in tune.function() is not needed.
                "policy_mapping_fn": _map_to_default_policy,
                "policies_to_train": ["default"],
            },
            "env": "Soccer",
            "env_config": {
                # Also handed to create_rllib_env through env_config.
                "num_envs_per_worker": NUM_ENVS_PER_WORKER,
                "variation": EnvType.multiagent_team,
            },
        },
        stop={
            "timesteps_total": 15000000,  # 15M
            # "time_total_s": 14400, # 4h
        },
        checkpoint_freq=100,
        checkpoint_at_end=True,
        local_dir="./ray_results",
        # To resume, uncomment and point at a saved checkpoint:
        # restore="./ray_results/PPO_teams_1/PPO_Soccer_ID/checkpoint_00X/checkpoint-X",
    )

    # Best trial ranked by episode_reward_mean (highest wins). Ranking uses
    # get_best_trial's default `scope`; pass scope="all" to rank each trial by
    # its best iteration rather than Tune's default.
    best_trial = analysis.get_best_trial("episode_reward_mean", mode="max")
    print(best_trial)
    # get_best_trial returns None when no trial reported the metric; asking for
    # that trial's checkpoint would then fail, so skip it.
    if best_trial is not None:
        # Best checkpoint of that trial, ranked by the same metric.
        best_checkpoint = analysis.get_best_checkpoint(
            trial=best_trial, metric="episode_reward_mean", mode="max"
        )
        print(best_checkpoint)
    print("Done training")