
Commit ee42b74

parametrized reward clipping
1 parent 887f598 commit ee42b74

File tree

5 files changed: +83 −69 lines changed


.idea/workspace.xml

+69 −61
Some generated files are not rendered by default.

atari.py

+11 −5
@@ -1,5 +1,6 @@
 import gym
 import argparse
+import numpy as np
 import atari_py
 from game_models.ddqn_game_model import DDQNTrainer, DDQNSolver
 from game_models.ge_game_model import GETrainer, GESolver
@@ -13,12 +14,12 @@
 class Atari:
 
     def __init__(self):
-        game_name, game_mode, render, total_step_limit, total_run_limit = self._args()
+        game_name, game_mode, render, total_step_limit, total_run_limit, clip = self._args()
         env_name = game_name + "Deterministic-v4" # Handles frame skipping (4) at every iteration
         env = MainGymWrapper.wrap(gym.make(env_name))
-        self._main_loop(self._game_model(game_mode, game_name, env.action_space.n), env, render, total_step_limit, total_run_limit)
+        self._main_loop(self._game_model(game_mode, game_name, env.action_space.n), env, render, total_step_limit, total_run_limit, clip)
 
-    def _main_loop(self, game_model, env, render, total_step_limit, total_run_limit):
+    def _main_loop(self, game_model, env, render, total_step_limit, total_run_limit, clip):
         run = 0
         total_step = 0
         while True:
@@ -42,6 +43,8 @@ def _main_loop(self, game_model, env, render, total_step_limit, total_run_limit)
 
                 action = game_model.move(current_state)
                 next_state, reward, terminal, info = env.step(action)
+                if clip:
+                    np.sign(reward)
                 score += reward
                 game_model.remember(current_state, action, reward, next_state, terminal)
                 current_state = next_state
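As written, the new clip branch evaluates np.sign(reward) but never assigns the result back, so score, replay memory and training all still see the raw reward. A minimal, self-contained sketch of the intended clipping step; the clip_reward helper below is hypothetical and not part of this commit:

import numpy as np

def clip_reward(reward, clip=True):
    # Map any raw Atari reward onto -1.0, 0.0 or +1.0 when clipping is enabled.
    # np.sign returns a new value, so it must be assigned back
    # (reward = np.sign(reward)) for the clipped value to take effect.
    return float(np.sign(reward)) if clip else reward

# A raw reward of 7 becomes 1.0 with clipping on and stays 7 with clipping off.
assert clip_reward(7) == 1.0
assert clip_reward(-3) == -1.0
assert clip_reward(7, clip=False) == 7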
@@ -60,18 +63,21 @@ def _args(self):
         parser.add_argument("-r", "--render", help="Choose if the game should be rendered. Default is 'False'.", default=False)
         parser.add_argument("-tsl", "--total_step_limit", help="Choose how many total steps (frames visible by agent) should be performed. Default is '10000000'.", default=10000000)
         parser.add_argument("-trl", "--total_run_limit", help="Choose after how many runs should we stop. Default is None (no limit).", default=None)
+        parser.add_argument("-c", "--clip", help="Choose whether we should clip rewards to (0, 1) range. Default is 'True'", default=True)
         args = parser.parse_args()
         game_mode = args.mode
         game_name = args.game
         render = args.render
         total_step_limit = args.total_step_limit
-        total_run_limit = args.run_limit
+        total_run_limit = args.total_run_limit
+        clip = args.clip
         print "Selected game: " + str(game_name)
         print "Selected mode: " + str(game_mode)
         print "Should render: " + str(render)
+        print "Should clip: " + str(clip)
         print "Total step limit: " + str(total_step_limit)
         print "Total run limit: " + str(total_run_limit)
-        return game_name, game_mode, render, total_step_limit, total_run_limit
+        return game_name, game_mode, render, total_step_limit, total_run_limit, clip
 
     def _game_model(self, game_mode,game_name, action_space):
         if game_mode == "ddqn_training":
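A note on the new --clip option: it is declared without a type, so a value passed on the command line, for example --clip False, arrives as the string "False", which is truthy in Python; default=True only applies when the flag is omitted. A hedged sketch of one common way to parse a boolean option; the str2bool helper is illustrative and not from this repository:

import argparse

def str2bool(value):
    # argparse hands values over as strings, so map the usual spellings
    # of true/false onto real booleans.
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got: " + value)

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--clip", type=str2bool, default=True,
                    help="Clip rewards with np.sign. Default is True.")

print(parser.parse_args(["--clip", "false"]).clip)  # False, not the truthy string "false"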

game_models/ddqn_game_model.py

+1 −1
@@ -16,7 +16,7 @@
 
 EXPLORATION_MAX = 1.0
 EXPLORATION_MIN = 0.1
-EXPLORATION_TEST = 0.02
+EXPLORATION_TEST = 0.01
 EXPLORATION_STEPS = 850000
 EXPLORATION_DECAY = (EXPLORATION_MAX-EXPLORATION_MIN)/EXPLORATION_STEPS
 
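For context, these constants define a linear epsilon-greedy schedule, and EXPLORATION_TEST is the small fixed epsilon used when evaluating a trained solver. A hedged sketch of how such a schedule is typically computed; the epsilon_at function is illustrative, not code from the repository:

EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.1
EXPLORATION_TEST = 0.01
EXPLORATION_STEPS = 850000
EXPLORATION_DECAY = (EXPLORATION_MAX - EXPLORATION_MIN) / EXPLORATION_STEPS

def epsilon_at(step, testing=False):
    # Small fixed exploration while testing; linear anneal from 1.0 to 0.1 while training.
    if testing:
        return EXPLORATION_TEST
    return max(EXPLORATION_MIN, EXPLORATION_MAX - EXPLORATION_DECAY * step)

assert epsilon_at(0) == EXPLORATION_MAX
assert abs(epsilon_at(EXPLORATION_STEPS) - EXPLORATION_MIN) < 1e-9
assert epsilon_at(12345, testing=True) == EXPLORATION_TEST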
gym_wrappers.py

+1 −1
@@ -191,5 +191,5 @@ def wrap(env):
         env = ProcessFrame84(env)
         env = ChannelsFirstImageShape(env)
         env = FrameStack(env, 4)
-        env = ClippedRewardsWrapper(env)
+        # env = ClippedRewardsWrapper(env)
         return env
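Commenting out ClippedRewardsWrapper in wrap() moves the clipping decision from the wrapper stack to the new --clip flag in atari.py. For reference, a reward-clipping wrapper of this kind is usually a thin gym.RewardWrapper; a sketch under that assumption, with a class name and body that are illustrative rather than the repository's own implementation:

import gym
import numpy as np

class SignClippedRewards(gym.RewardWrapper):
    # Standard Atari preprocessing: squash every reward to -1, 0 or +1 so that one
    # set of DQN hyperparameters transfers across games with very different score scales.
    def reward(self, reward):
        return np.sign(reward)

# Usage: env = SignClippedRewards(gym.make("BreakoutDeterministic-v4"))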

logger.py

+1 −1
@@ -95,7 +95,7 @@ def _save_png(self, input_path, output_path, small_batch_length, big_batch_lengt
                 batch_averages_y.append(mean(temp_values_in_batch))
                 batch_averages_x.append(len(batch_averages_y)*big_batch_length)
                 temp_values_in_batch = []
-        if batch_averages_x and batch_averages_y:
+        if len(batch_averages_x) > 1:
             plt.plot(batch_averages_x, batch_averages_y, linestyle="--", label="last " + str(big_batch_length) + " average")
 
         if len(x) > 1:
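The new guard mirrors the len(x) > 1 check on the following line: matplotlib accepts a single point, but a dashed line through one point is not visible, so the rolling-average series is only drawn once at least two batch averages exist. A small standalone sketch of that guard; the function name is hypothetical:

import matplotlib
matplotlib.use("Agg")  # headless backend, since a logger normally runs without a display
import matplotlib.pyplot as plt

def maybe_plot_average(batch_averages_x, batch_averages_y, big_batch_length):
    # Only draw the dashed rolling-average line once two averaged points exist.
    if len(batch_averages_x) > 1:
        plt.plot(batch_averages_x, batch_averages_y, linestyle="--",
                 label="last " + str(big_batch_length) + " average")

maybe_plot_average([100], [12.0], 100)             # skipped: a single point would not show
maybe_plot_average([100, 200], [12.0, 15.5], 100)  # drawn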

0 commit comments
