-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathBeadyRing_fullObs_train.py
128 lines (106 loc) · 4.17 KB
/
BeadyRing_fullObs_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import socket
import struct
import pickle
import numpy as np
import gym
from stable_baselines3 import A2C, PPO, DQN
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from stable_baselines3.common.env_checker import check_env
import wandb
from wandb.integration.sb3 import WandbCallback
class Connection:
    """Length-prefixed pickle messaging over an already-connected stream socket.

    Each message on the wire is a 4-byte little-endian unsigned length
    followed by exactly that many bytes of pickled payload.
    """

    # Pre-compiled header codec: one little-endian unsigned 32-bit length.
    _HEADER = struct.Struct("<L")
    # Bytes requested per recv() call; surplus bytes stay buffered for the
    # next message, so reading more than one frame at a time is safe.
    _RECV_CHUNK = 4096

    def __init__(self, s):
        """Wrap the connected socket *s* and start with an empty receive buffer."""
        self._socket = s
        self._buffer = bytearray()

    def receive_object(self):
        """Block until one complete message is buffered and return it unpickled.

        Returns:
            The decoded object, or ``None`` if the peer closed the connection
            before a complete message arrived.
        """
        header_size = self._HEADER.size
        while True:
            if len(self._buffer) >= header_size:
                (length,) = self._HEADER.unpack_from(self._buffer)
                end = header_size + length
                if len(self._buffer) >= end:
                    break
            new_bytes = self._socket.recv(self._RECV_CHUNK)
            if not new_bytes:
                # recv() returning no data means the peer closed the stream.
                return None
            self._buffer += new_bytes
        # NOTE(security): pickle.loads can execute arbitrary code; only use
        # this class with a trusted peer (e.g. a local, known process).
        obj = pickle.loads(self._buffer[header_size:end])
        # Drop the consumed frame; anything after it belongs to the next one.
        del self._buffer[:end]
        return obj

    def send_object(self, d):
        """Pickle *d* and transmit it as one length-prefixed frame."""
        # Protocol 2 is kept from the original code; presumably the peer runs
        # an older Python that cannot read newer protocols - TODO confirm.
        body = pickle.dumps(d, protocol=2)
        # sendall() loops until every byte is written; plain send() may
        # transmit only part of the frame and corrupt the stream.
        self._socket.sendall(self._HEADER.pack(len(body)) + body)
class Env(gym.Env):
    """Gym environment whose simulation runs in an external process over TCP.

    On construction the environment listens on 127.0.0.1:50710 and blocks
    until one client connects. Afterwards every ``reset``/``step`` call is a
    request/reply exchange through a ``Connection``: the reply is indexed
    with the keys ``"state"`` (and, for ``step``, ``"reward"``, ``"done"``
    and ``"info"``).
    """

    metadata = {'render.modes': ['rgb_array']}

    def __init__(self):
        """Wait for the external client to connect and define the spaces."""
        super(Env, self).__init__()
        addr = ("127.0.0.1", 50710)
        # Exactly one client is served, so the listening socket is closed as
        # soon as that client is accepted (or if binding/accepting fails)
        # instead of being leaked for the lifetime of the process.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
            listener.bind(addr)
            listener.listen(1)
            clientsocket, address = listener.accept()  # blocks until connected
        self._socket = clientsocket
        self._conn = Connection(clientsocket)
        self.grid_len = 41  # Make sure grid size matches max_row_len in the Gh env
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(low=0, high=255,
                                                shape=(1, self.grid_len, self.grid_len),
                                                dtype=np.uint8)

    def _receive_message(self):
        """Return the next reply from the peer, failing loudly on disconnect."""
        msg = self._conn.receive_object()
        if msg is None:
            # receive_object() signals a closed connection with None; without
            # this check the failure would surface as an opaque TypeError.
            raise ConnectionError("external environment closed the connection")
        return msg

    def _to_observation(self, state):
        """Reshape *state* to (1, grid_len, grid_len) in the declared uint8 dtype."""
        obs = np.asarray(state).reshape(1, self.grid_len, self.grid_len)
        # Match the dtype promised by observation_space.
        return obs.astype(np.uint8, copy=False)

    def reset(self):
        """Ask the peer to reset and return its initial observation."""
        self._conn.send_object("reset")
        msg = self._receive_message()
        return self._to_observation(msg["state"])

    def step(self, action):
        """Send *action* to the peer and return ``(obs, reward, done, info)``.

        *action* may be a NumPy scalar, a size-1 array, or a plain Python int.
        """
        # np.asarray(...).item() yields a plain Python scalar for all of the
        # accepted forms, whereas action.item() fails on a built-in int.
        self._conn.send_object(np.asarray(action).item())
        msg = self._receive_message()
        obs = self._to_observation(msg["state"])
        rwd = msg["reward"]
        done = msg["done"]
        info = msg["info"]
        return obs, rwd, done, info

    def render(self, mode='rgb_array'):
        """Read one message from the peer and return its state as an array.

        Returns ``None`` for any mode other than ``'rgb_array'``.
        """
        # NOTE(review): no request is sent before reading, so this consumes
        # whatever message the peer sends next - confirm the peer actually
        # pushes a frame for rendering, otherwise this blocks or
        # desynchronises the step/reset exchange.
        msg = self._receive_message()
        if mode == 'rgb_array':
            return self._to_observation(msg["state"])
        return None

    def close(self):
        """Tell the peer to shut down, then close the client socket."""
        try:
            self._conn.send_object("close")
        finally:
            # Release the socket even if the peer is already gone.
            self._socket.close()
# Authenticate with Weights & Biases before creating the run.
print('Wandb login ...')
wandb.login(key='')  # place wandb key here!

# Hyperparameters: logged to W&B and read back below when the model is
# constructed and trained.
config = dict(
    policy_type="CnnPolicy",
    total_timesteps=1000000,
)

run = wandb.init(
    project="BeadyRing_DRL",
    entity='',  # Replace with your wandb entity & project
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    config=config,
)

print('\n Reset and Loop HoopSnake Gh component ... \n')
def make_env():
    """Create the socket-backed environment, reset it once, and wrap it in a Monitor."""
    raw_env = Env()
    # debug
    # check_env(raw_env)  # check if the env follows the gym interface
    raw_env.reset()
    # Monitor records stats such as returns.
    return Monitor(raw_env)
# Stable-Baselines3 works on vectorised environments; wrap the single env.
env = DummyVecEnv([make_env])
# env = VecVideoRecorder(env, f"videos/{run.id}", record_video_trigger=lambda x: x % 4410 == 0,
#                        video_length=441) #21*21*10 = 4410 | 21 is self._max_row_len in Gh env

# Train a DQN agent; metrics go to TensorBoard (synced to W&B) and the
# callback saves model checkpoints under models/<run id>.
model = DQN(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}", device='cuda')
model.learn(total_timesteps=config["total_timesteps"], log_interval=10,
            callback=WandbCallback(model_save_path=f'models/{run.id}', model_save_freq=100))

# Roll out the trained policy greedily for 300 steps and print the return of
# every episode that finishes within that window.
cum_rwd = 0
obs = env.reset()
for i in range(300):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    # A VecEnv returns one entry per sub-environment; there is exactly one.
    cum_rwd += reward[0]
    if done[0]:
        # No explicit reset here: the VecEnv resets a finished
        # sub-environment itself and `obs` is already the first observation
        # of the next episode. Resetting again would reset the external
        # environment a second time and discard that observation.
        print("Return = ", cum_rwd)
        cum_rwd = 0

env.close()
run.finish()