
Commit 70f76fb

hopefully working memory management of experience replay
1 parent 2cd2cb1 commit 70f76fb

File tree

4 files changed: +58 -33 lines

4 files changed

+58
-33
lines changed

.idea/workspace.xml

+33 -25

Generated IDE file; diff not rendered by default.

atari.py

+15 -1

```diff
@@ -58,7 +58,21 @@ def _main_loop(self, game_model, env, render, total_step_limit):
 
     def _preprocess_observation(self, obs):
         image = Image.fromarray(obs, "RGB").convert("L").resize((FRAME_SIZE, FRAME_SIZE))
-        return np.asarray(image.getdata(), dtype=np.uint8).reshape(image.size[1], image.size[0]) #TODO: possibly memory heavy
+        return np.asarray(image.getdata(), dtype=np.uint8).reshape(image.size[1], image.size[0]) #TODO: possibly memory heavy, we should pass regular lists here
+
+    # class WarpFrame(gym.ObservationWrapper):
+    #     def __init__(self, env):
+    #         """Warp frames to 84x84 as done in the Nature paper and later work."""
+    #         gym.ObservationWrapper.__init__(self, env)
+    #         self.width = 84
+    #         self.height = 84
+    #         self.observation_space = spaces.Box(low=0, high=255,
+    #                                             shape=(self.height, self.width, 1), dtype=np.uint8)
+    #
+    #     def observation(self, frame):
+    #         frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
+    #         frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
+    #         return frame[:, :, None]
 
     def _args(self):
         parser = argparse.ArgumentParser()
```
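
The commented-out WarpFrame wrapper and the PIL-based method do the same job: grayscale + downscale each frame before it enters the replay memory. What matters for memory is the dtype: kept as uint8, a frame costs FRAME_SIZE² bytes, eight times less than float64. A minimal standalone sketch of the PIL path, assuming FRAME_SIZE = 84 (the resolution the WarpFrame comment uses; the constant itself is defined elsewhere in the repo):

```python
import numpy as np
from PIL import Image

FRAME_SIZE = 84  # assumed; defined elsewhere in the repo

def preprocess_observation(obs):
    # Grayscale ("L") + downscale, as in the method above.
    image = Image.fromarray(obs, "RGB").convert("L").resize((FRAME_SIZE, FRAME_SIZE))
    # np.asarray(image) yields the same (height, width) uint8 array as the
    # getdata()/reshape route, without materialising an intermediate sequence,
    # which is one way to address the TODO.
    return np.asarray(image, dtype=np.uint8)
```

The float64 conversion is deferred to `_train` (see the last hunk in game_models/ddqn_game_model.py below), so the buffer itself stays uint8.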

convolutional_neural_network.py

+3 -1

```diff
@@ -32,6 +32,8 @@ def __init__(self, input_shape, action_space):
         self.model.add(Dense(512, activation="relu"))
         self.model.add(Dense(action_space))
         self.model.compile(loss="mean_squared_error",
-                           optimizer=RMSprop(lr=0.00025, rho=0.95, epsilon=0.01),
+                           optimizer=RMSprop(lr=0.00025,
+                                             rho=0.95,
+                                             epsilon=0.01),
                            metrics=["accuracy"])
         self.model.summary()
```
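
This hunk only reflows the RMSprop arguments; lr=0.00025, rho=0.95, epsilon=0.01 are the hyperparameters from DeepMind's Nature DQN paper. For context, a sketch of the full constructor this tail plausibly belongs to, using the standard Nature-DQN convolutional stack; the Conv2D layers are an assumption, since the hunk does not show them:

```python
from keras.models import Sequential
from keras.layers import Conv2D, Flatten, Dense
from keras.optimizers import RMSprop

def build_model(input_shape, action_space):
    model = Sequential()
    # Assumed Nature-DQN conv stack (not shown in this diff).
    model.add(Conv2D(32, 8, strides=(4, 4), activation="relu", input_shape=input_shape))
    model.add(Conv2D(64, 4, strides=(2, 2), activation="relu"))
    model.add(Conv2D(64, 3, strides=(1, 1), activation="relu"))
    model.add(Flatten())
    model.add(Dense(512, activation="relu"))
    model.add(Dense(action_space))  # linear output: one Q-value per action
    model.compile(loss="mean_squared_error",
                  optimizer=RMSprop(lr=0.00025,  # DeepMind's DQN settings
                                    rho=0.95,
                                    epsilon=0.01),
                  metrics=["accuracy"])
    return model
```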

game_models/ddqn_game_model.py

+7 -6

```diff
@@ -25,7 +25,8 @@
 class DDQNGameModel(BaseGameModel):
 
     def __init__(self, game_name, mode_name, input_shape, action_space, logger_path, model_path):
-        BaseGameModel.__init__(self, game_name,
+        BaseGameModel.__init__(self,
+                               game_name,
                                mode_name,
                                logger_path,
                                input_shape,
@@ -47,7 +48,7 @@ class DDQNSolver(DDQNGameModel):
 
     def __init__(self, game_name, input_shape, action_space):
         testing_model_path = "./output/neural_nets/" + game_name + "/ddqn/testing/model.h5"
-        assert os.path.exists(os.path.dirname(testing_model_path)), "No testing model in: " + str(testing_model_path)
+        assert os.path.exists(os.path.dirname(testing_model_path)), "No model to test in: " + str(testing_model_path)
         DDQNGameModel.__init__(self,
                                game_name,
                                "DDQN testing",
```

```diff
@@ -89,10 +90,10 @@ def move(self, state):
         return np.argmax(q_values[0])
 
     def remember(self, current_state, action, reward, next_state, terminal):
-        self.memory.append({"current_state": current_state, #np.asarray([current_state])
+        self.memory.append({"current_state": np.asarray([current_state]),
                            "action": action,
                            "reward": reward,
-                           "next_state": next_state,
+                           "next_state": np.asarray([next_state]),
                            "terminal": terminal})
         if len(self.memory) > MEMORY_SIZE:
             self.memory.pop(0)
```
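
Trimming with `pop(0)` on a Python list is O(n) per eviction once the buffer reaches MEMORY_SIZE, since every remaining entry shifts down. A bounded `collections.deque` gives the same keep-the-newest semantics with O(1) appends and evictions. A minimal sketch of an equivalent buffer, not what this commit does; the MEMORY_SIZE value is assumed, the constant lives elsewhere in the repo:

```python
from collections import deque

import numpy as np

MEMORY_SIZE = 900000  # assumed value

# maxlen makes the deque drop its oldest entry automatically on append,
# replacing the explicit `if len(...) > MEMORY_SIZE: pop(0)` check.
memory = deque(maxlen=MEMORY_SIZE)

def remember(current_state, action, reward, next_state, terminal):
    memory.append({"current_state": np.asarray([current_state]),
                   "action": action,
                   "reward": reward,
                   "next_state": np.asarray([next_state]),
                   "terminal": terminal})
```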

```diff
@@ -127,9 +128,9 @@ def _train(self):
         max_q_values = []
 
         for entry in batch:
-            current_state = np.expand_dims(entry["current_state"].astype(np.float64), axis=0)
+            current_state = entry["current_state"].astype(np.float64)
             current_states.append(current_state)
-            next_state = np.expand_dims(entry["next_state"].astype(np.float64), axis=0)
+            next_state = entry["next_state"].astype(np.float64)
             next_state_prediction = self.ddqn_target.predict(next_state).ravel()
             next_q_value = np.max(next_state_prediction)
             q = list(self.ddqn.predict(current_state)[0])
```
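
Storing each state as `np.asarray([current_state])` bakes the leading batch axis into the buffer, which is why both `expand_dims` calls disappear: an entry comes out of memory already shaped `(1, ...)` and can be fed to `predict` directly. A small shape check illustrating the equivalence, with a hypothetical 84x84 four-frame stack:

```python
import numpy as np

# Hypothetical stacked-frame state, kept as uint8 in the buffer.
state = np.zeros((4, 84, 84), dtype=np.uint8)

stored = np.asarray([state])  # what remember() keeps after this commit
old = np.expand_dims(state.astype(np.float64), axis=0)  # the removed _train path

assert stored.shape == old.shape == (1, 4, 84, 84)
# stored.astype(np.float64) reproduces the old array exactly, so _train can
# pass entry["current_state"].astype(np.float64) straight into predict().
assert np.array_equal(stored.astype(np.float64), old)
```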
