From 3e4db23256c8d8a219d792c41b1b0900c628a854 Mon Sep 17 00:00:00 2001 From: cmwslw Date: Wed, 11 Nov 2015 23:51:57 -0500 Subject: [PATCH 1/5] Double DQN support. --- deep_q_rl/launcher.py | 5 +++ deep_q_rl/q_network.py | 21 +++++++--- deep_q_rl/run_double.py | 66 ++++++++++++++++++++++++++++++++ deep_q_rl/run_nature.py | 1 + deep_q_rl/run_nips.py | 1 + deep_q_rl/test/test_q_network.py | 12 +++--- 6 files changed, 95 insertions(+), 11 deletions(-) create mode 100755 deep_q_rl/run_double.py diff --git a/deep_q_rl/launcher.py b/deep_q_rl/launcher.py index c136f01..e3df92a 100755 --- a/deep_q_rl/launcher.py +++ b/deep_q_rl/launcher.py @@ -138,6 +138,10 @@ def process_args(args, defaults, description): type=bool, default=defaults.CUDNN_DETERMINISTIC, help=('Whether to use deterministic backprop. ' + '(default: %(default)s)')) + parser.add_argument('--use_double', dest="use_double", + type=bool, default=defaults.USE_DOUBLE, + help=('Whether to use Double DQN. ' + + '(default: %(default)s)')) parameters = parser.parse_args(args) if parameters.experiment_prefix is None: @@ -216,6 +220,7 @@ def launch(args, defaults, description): parameters.momentum, parameters.clip_delta, parameters.freeze_interval, + parameters.use_double, parameters.batch_size, parameters.network_type, parameters.update_rule, diff --git a/deep_q_rl/q_network.py b/deep_q_rl/q_network.py index 0fa360b..0c5b425 100644 --- a/deep_q_rl/q_network.py +++ b/deep_q_rl/q_network.py @@ -28,7 +28,7 @@ class DeepQLearner: def __init__(self, input_width, input_height, num_actions, num_frames, discount, learning_rate, rho, rms_epsilon, momentum, clip_delta, freeze_interval, - batch_size, network_type, update_rule, + use_double, batch_size, network_type, update_rule, batch_accumulator, rng, input_scale=255.0): self.input_width = input_width @@ -43,8 +43,13 @@ def __init__(self, input_width, input_height, num_actions, self.momentum = momentum self.clip_delta = clip_delta self.freeze_interval = freeze_interval + self.use_double = use_double self.rng = rng + # Using Double DQN is pointless without periodic freezing + if self.use_double: + assert self.freeze_interval > 0 + lasagne.random.set_rng(self.rng) self.update_counter = 0 @@ -93,9 +98,15 @@ def __init__(self, input_width, input_height, num_actions, next_states / input_scale) next_q_vals = theano.gradient.disconnected_grad(next_q_vals) - target = (rewards + - (T.ones_like(terminals) - terminals) * - self.discount * T.max(next_q_vals, axis=1, keepdims=True)) + if self.use_double: + maxaction = T.argmax(q_vals, axis=1, keepdims=True) + target = (rewards + + (T.ones_like(terminals) - terminals) * + self.discount * next_q_vals[maxaction]) + else: + target = (rewards + + (T.ones_like(terminals) - terminals) * + self.discount * T.max(next_q_vals, axis=1, keepdims=True)) diff = target - q_vals[T.arange(batch_size), actions.reshape((-1,))].reshape((-1, 1)) @@ -476,7 +487,7 @@ def build_linear_network(self, input_width, input_height, output_dim, return l_out def main(): - net = DeepQLearner(84, 84, 16, 4, .99, .00025, .95, .95, 10000, + net = DeepQLearner(84, 84, 16, 4, .99, .00025, .95, .95, 10000, False, 32, 'nature_cuda') diff --git a/deep_q_rl/run_double.py b/deep_q_rl/run_double.py new file mode 100755 index 0000000..e1e9ac1 --- /dev/null +++ b/deep_q_rl/run_double.py @@ -0,0 +1,66 @@ +#! /usr/bin/env python +""" +Execute a training run of deep-Q-Leaning with parameters that +are consistent with: + +Human-level control through deep reinforcement learning. 
+Nature, 518(7540):529-533, February 2015 + +""" + +import launcher +import sys + +class Defaults: + # ---------------------- + # Experiment Parameters + # ---------------------- + STEPS_PER_EPOCH = 250000 + EPOCHS = 200 + STEPS_PER_TEST = 125000 + + # ---------------------- + # ALE Parameters + # ---------------------- + BASE_ROM_PATH = "../roms/" + ROM = 'breakout.bin' + FRAME_SKIP = 4 + REPEAT_ACTION_PROBABILITY = 0 + + # ---------------------- + # Agent/Network parameters: + # ---------------------- + UPDATE_RULE = 'deepmind_rmsprop' + BATCH_ACCUMULATOR = 'sum' + LEARNING_RATE = .00025 + DISCOUNT = .99 + RMS_DECAY = .95 # (Rho) + RMS_EPSILON = .01 + MOMENTUM = 0 # Note that the "momentum" value mentioned in the Nature + # paper is not used in the same way as a traditional momentum + # term. It is used to track gradient for the purpose of + # estimating the standard deviation. This package uses + # rho/RMS_DECAY to track both the history of the gradient + # and the squared gradient. + CLIP_DELTA = 1.0 + EPSILON_START = 1.0 + EPSILON_MIN = .1 + EPSILON_DECAY = 1000000 + PHI_LENGTH = 4 + UPDATE_FREQUENCY = 4 + REPLAY_MEMORY_SIZE = 1000000 + BATCH_SIZE = 32 + NETWORK_TYPE = "nature_dnn" + FREEZE_INTERVAL = 10000 + REPLAY_START_SIZE = 50000 + RESIZE_METHOD = 'scale' + RESIZED_WIDTH = 84 + RESIZED_HEIGHT = 84 + DEATH_ENDS_EPISODE = 'true' + MAX_START_NULLOPS = 30 + DETERMINISTIC = True + CUDNN_DETERMINISTIC = False + USE_DOUBLE = True + +if __name__ == "__main__": + launcher.launch(sys.argv[1:], Defaults, __doc__) diff --git a/deep_q_rl/run_nature.py b/deep_q_rl/run_nature.py index 2da46bc..8199546 100755 --- a/deep_q_rl/run_nature.py +++ b/deep_q_rl/run_nature.py @@ -60,6 +60,7 @@ class Defaults: MAX_START_NULLOPS = 30 DETERMINISTIC = True CUDNN_DETERMINISTIC = False + USE_DOUBLE = False if __name__ == "__main__": launcher.launch(sys.argv[1:], Defaults, __doc__) diff --git a/deep_q_rl/run_nips.py b/deep_q_rl/run_nips.py index 8a6ddfc..1585f2c 100755 --- a/deep_q_rl/run_nips.py +++ b/deep_q_rl/run_nips.py @@ -55,6 +55,7 @@ class Defaults: MAX_START_NULLOPS = 0 DETERMINISTIC = True CUDNN_DETERMINISTIC = False + USE_DOUBLE = False if __name__ == "__main__": launcher.launch(sys.argv[1:], Defaults, __doc__) diff --git a/deep_q_rl/test/test_q_network.py b/deep_q_rl/test/test_q_network.py index 82cd142..87bc10d 100644 --- a/deep_q_rl/test/test_q_network.py +++ b/deep_q_rl/test/test_q_network.py @@ -114,7 +114,7 @@ def test_updates_sgd_no_freeze(self): self.mdp.num_actions, 1, self.discount, self.learning_rate, 0, 0, 0, 0, - freeze_interval, 1, 'linear', + freeze_interval, False, 1, 'linear', 'sgd', 'sum', 1.0) mdp = self.mdp @@ -157,7 +157,7 @@ def test_convergence_sgd_no_freeze(self): self.mdp.num_actions, 1, self.discount, self.learning_rate, 0, 0, 0, 0, - freeze_interval, 1, 'linear', + freeze_interval, False, 1, 'linear', 'sgd', 'sum', 1.0) @@ -178,7 +178,7 @@ def test_convergence_random_initialization(self): self.mdp.num_actions, 1, self.discount, self.learning_rate, 0, 0, 0, 0, - freeze_interval, 1, 'linear', + freeze_interval, False, 1, 'linear', 'sgd', 'sum', 1.0) # Randomize initial q-values: @@ -203,7 +203,7 @@ def test_convergence_sgd_permanent_freeze(self): self.mdp.num_actions, 1, self.discount, self.learning_rate, 0, 0, 0, 0, - freeze_interval, 1, 'linear', + freeze_interval, False, 1, 'linear', 'sgd', 'sum', 1.0) self.train(net, 1000) @@ -218,7 +218,7 @@ def test_convergence_sgd_frequent_freeze(self): self.mdp.num_actions, 1, self.discount, self.learning_rate, 0, 0, 0, 0, - 
freeze_interval, 1, 'linear', + freeze_interval, False, 1, 'linear', 'sgd', 'sum', 1.0) self.train(net, 1000) @@ -233,7 +233,7 @@ def test_convergence_sgd_one_freeze(self): self.mdp.num_actions, 1, self.discount, self.learning_rate, 0, 0, 0, 0, - freeze_interval, 1, 'linear', + freeze_interval, False, 1, 'linear', 'sgd', 'sum', 1.0) self.train(net, freeze_interval * 2) From e2aa85e111bb0c087d21cb16ca906e9949789629 Mon Sep 17 00:00:00 2001 From: Cory Walker Date: Thu, 12 Nov 2015 21:51:56 +0000 Subject: [PATCH 2/5] Bug fix, some testing code. --- deep_q_rl/q_network.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/deep_q_rl/q_network.py b/deep_q_rl/q_network.py index 0c5b425..28d8592 100644 --- a/deep_q_rl/q_network.py +++ b/deep_q_rl/q_network.py @@ -49,6 +49,7 @@ def __init__(self, input_width, input_height, num_actions, # Using Double DQN is pointless without periodic freezing if self.use_double: assert self.freeze_interval > 0 + # pass lasagne.random.set_rng(self.rng) @@ -91,18 +92,21 @@ def __init__(self, input_width, input_height, num_actions, q_vals = lasagne.layers.get_output(self.l_out, states / input_scale) if self.freeze_interval > 0: + # Nature. If using periodic freezing next_q_vals = lasagne.layers.get_output(self.next_l_out, next_states / input_scale) else: + # NIPS next_q_vals = lasagne.layers.get_output(self.l_out, next_states / input_scale) next_q_vals = theano.gradient.disconnected_grad(next_q_vals) if self.use_double: - maxaction = T.argmax(q_vals, axis=1, keepdims=True) + maxaction = T.argmax(q_vals, axis=1, keepdims=False) + temptargets = next_q_vals[T.arange(batch_size),maxaction].reshape((-1, 1)) target = (rewards + (T.ones_like(terminals) - terminals) * - self.discount * next_q_vals[maxaction]) + self.discount * temptargets) else: target = (rewards + (T.ones_like(terminals) - terminals) * @@ -156,8 +160,30 @@ def __init__(self, input_width, input_height, num_actions, updates = lasagne.updates.apply_momentum(updates, None, self.momentum) - self._train = theano.function([], [loss, q_vals], updates=updates, - givens=givens) + if False: + def inspect_inputs(i, node, fn): + if ('maxand' not in str(node).lower() and '12345' not in str(node)): + return + print i, node, "input(s) value(s):", [input[0] for input in fn.inputs], + raw_input('press enter') + + def inspect_outputs(i, node, fn): + if ('maxand' not in str(node).lower() and '12345' not in str(node)): + return + if '12345' in str(node): + print "output(s) value(s):", [np.asarray(output[0]) for output in fn.outputs] + else: + print "output(s) value(s):", [output[0] for output in fn.outputs] + raw_input('press enter') + + self._train = theano.function([], [loss, q_vals], updates=updates, + givens=givens, mode=theano.compile.MonitorMode( + pre_func=inspect_inputs, + post_func=inspect_outputs)) + theano.printing.debugprint(target) + else: + self._train = theano.function([], [loss, q_vals], updates=updates, + givens=givens) self._q_vals = theano.function([], q_vals, givens={states: self.states_shared}) From 91236a8e081d641610e2b59d587a56a88fc9444c Mon Sep 17 00:00:00 2001 From: Cory Walker Date: Fri, 13 Nov 2015 16:11:46 +0000 Subject: [PATCH 3/5] Checkpoint before instance shutdown. 
--- deep_q_rl/ale_run_watch.py | 2 +- deep_q_rl/q_network.py | 52 +++++++++++++++++++++++++------------- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/deep_q_rl/ale_run_watch.py b/deep_q_rl/ale_run_watch.py index 67a0bd5..2f12b2a 100644 --- a/deep_q_rl/ale_run_watch.py +++ b/deep_q_rl/ale_run_watch.py @@ -9,7 +9,7 @@ import sys def run_watch(): - command = ['./run_nature.py', '--steps-per-epoch', '0', + command = ['./run_double.py', '--steps-per-epoch', '0', '--test-length', '10000', '--nn-file', sys.argv[1], '--display-screen'] diff --git a/deep_q_rl/q_network.py b/deep_q_rl/q_network.py index 28d8592..11f41dc 100644 --- a/deep_q_rl/q_network.py +++ b/deep_q_rl/q_network.py @@ -160,22 +160,22 @@ def __init__(self, input_width, input_height, num_actions, updates = lasagne.updates.apply_momentum(updates, None, self.momentum) - if False: - def inspect_inputs(i, node, fn): - if ('maxand' not in str(node).lower() and '12345' not in str(node)): - return - print i, node, "input(s) value(s):", [input[0] for input in fn.inputs], - raw_input('press enter') - - def inspect_outputs(i, node, fn): - if ('maxand' not in str(node).lower() and '12345' not in str(node)): - return - if '12345' in str(node): - print "output(s) value(s):", [np.asarray(output[0]) for output in fn.outputs] - else: - print "output(s) value(s):", [output[0] for output in fn.outputs] - raw_input('press enter') + def inspect_inputs(i, node, fn): + if ('maxand' not in str(node).lower() and '12345' not in str(node)): + return + print i, node, "input(s) value(s):", [input[0] for input in fn.inputs], + raw_input('press enter') + + def inspect_outputs(i, node, fn): + if ('maxand' not in str(node).lower() and '12345' not in str(node)): + return + if '12345' in str(node): + print "output(s) value(s):", [np.asarray(output[0]) for output in fn.outputs] + else: + print "output(s) value(s):", [output[0] for output in fn.outputs] + raw_input('press enter') + if False: self._train = theano.function([], [loss, q_vals], updates=updates, givens=givens, mode=theano.compile.MonitorMode( pre_func=inspect_inputs, @@ -184,8 +184,14 @@ def inspect_outputs(i, node, fn): else: self._train = theano.function([], [loss, q_vals], updates=updates, givens=givens) - self._q_vals = theano.function([], q_vals, - givens={states: self.states_shared}) + if False: + self._q_vals = theano.function([], q_vals, + givens={states: self.states_shared}, mode=theano.compile.MonitorMode( + pre_func=inspect_inputs, + post_func=inspect_outputs)) + else: + self._q_vals = theano.function([], q_vals, + givens={states: self.states_shared}) def build_network(self, network_type, input_width, input_height, output_dim, num_frames, batch_size): @@ -250,6 +256,18 @@ def choose_action(self, state, epsilon): if self.rng.rand() < epsilon: return self.rng.randint(0, self.num_actions) q_vals = self.q_vals(state) + scaled_q = q_vals - np.min(q_vals) + scaled_q = scaled_q / np.max(scaled_q) + for cmpi in range(1*15, -1, -1): + cmpval = float(cmpi) / 15 + line = '' + for i in range(4): + if scaled_q[i] > cmpval: + line += '####' + line += '\t' + print line + print 'noop\tfire\tright\tleft' + print q_vals return np.argmax(q_vals) def reset_q_hat(self): From 94f7d724f69291ecfe23ac656aef38c63a776b8d Mon Sep 17 00:00:00 2001 From: cmwslw Date: Sat, 21 Nov 2015 22:18:52 -0500 Subject: [PATCH 4/5] Prepare for pull request. 
--- deep_q_rl/ale_run_watch.py | 2 +- deep_q_rl/q_network.py | 51 +++----------------------------------- 2 files changed, 5 insertions(+), 48 deletions(-) diff --git a/deep_q_rl/ale_run_watch.py b/deep_q_rl/ale_run_watch.py index 2f12b2a..67a0bd5 100644 --- a/deep_q_rl/ale_run_watch.py +++ b/deep_q_rl/ale_run_watch.py @@ -9,7 +9,7 @@ import sys def run_watch(): - command = ['./run_double.py', '--steps-per-epoch', '0', + command = ['./run_nature.py', '--steps-per-epoch', '0', '--test-length', '10000', '--nn-file', sys.argv[1], '--display-screen'] diff --git a/deep_q_rl/q_network.py b/deep_q_rl/q_network.py index 11f41dc..95fef35 100644 --- a/deep_q_rl/q_network.py +++ b/deep_q_rl/q_network.py @@ -49,7 +49,6 @@ def __init__(self, input_width, input_height, num_actions, # Using Double DQN is pointless without periodic freezing if self.use_double: assert self.freeze_interval > 0 - # pass lasagne.random.set_rng(self.rng) @@ -92,11 +91,9 @@ def __init__(self, input_width, input_height, num_actions, q_vals = lasagne.layers.get_output(self.l_out, states / input_scale) if self.freeze_interval > 0: - # Nature. If using periodic freezing next_q_vals = lasagne.layers.get_output(self.next_l_out, next_states / input_scale) else: - # NIPS next_q_vals = lasagne.layers.get_output(self.l_out, next_states / input_scale) next_q_vals = theano.gradient.disconnected_grad(next_q_vals) @@ -160,38 +157,10 @@ def __init__(self, input_width, input_height, num_actions, updates = lasagne.updates.apply_momentum(updates, None, self.momentum) - def inspect_inputs(i, node, fn): - if ('maxand' not in str(node).lower() and '12345' not in str(node)): - return - print i, node, "input(s) value(s):", [input[0] for input in fn.inputs], - raw_input('press enter') - - def inspect_outputs(i, node, fn): - if ('maxand' not in str(node).lower() and '12345' not in str(node)): - return - if '12345' in str(node): - print "output(s) value(s):", [np.asarray(output[0]) for output in fn.outputs] - else: - print "output(s) value(s):", [output[0] for output in fn.outputs] - raw_input('press enter') - - if False: - self._train = theano.function([], [loss, q_vals], updates=updates, - givens=givens, mode=theano.compile.MonitorMode( - pre_func=inspect_inputs, - post_func=inspect_outputs)) - theano.printing.debugprint(target) - else: - self._train = theano.function([], [loss, q_vals], updates=updates, - givens=givens) - if False: - self._q_vals = theano.function([], q_vals, - givens={states: self.states_shared}, mode=theano.compile.MonitorMode( - pre_func=inspect_inputs, - post_func=inspect_outputs)) - else: - self._q_vals = theano.function([], q_vals, - givens={states: self.states_shared}) + self._train = theano.function([], [loss, q_vals], updates=updates, + givens=givens) + self._q_vals = theano.function([], q_vals, + givens={states: self.states_shared}) def build_network(self, network_type, input_width, input_height, output_dim, num_frames, batch_size): @@ -256,18 +225,6 @@ def choose_action(self, state, epsilon): if self.rng.rand() < epsilon: return self.rng.randint(0, self.num_actions) q_vals = self.q_vals(state) - scaled_q = q_vals - np.min(q_vals) - scaled_q = scaled_q / np.max(scaled_q) - for cmpi in range(1*15, -1, -1): - cmpval = float(cmpi) / 15 - line = '' - for i in range(4): - if scaled_q[i] > cmpval: - line += '####' - line += '\t' - print line - print 'noop\tfire\tright\tleft' - print q_vals return np.argmax(q_vals) def reset_q_hat(self): From f6668014d45f35278079c7fd023bc6a1126c4945 Mon Sep 17 00:00:00 2001 From: Cory Walker Date: 
Tue, 1 Dec 2015 11:20:32 -0500 Subject: [PATCH 5/5] Update citation Minor change to update the citation for Double DQN. --- deep_q_rl/run_double.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deep_q_rl/run_double.py b/deep_q_rl/run_double.py index e1e9ac1..2bce656 100755 --- a/deep_q_rl/run_double.py +++ b/deep_q_rl/run_double.py @@ -3,8 +3,8 @@ Execute a training run of deep-Q-Leaning with parameters that are consistent with: -Human-level control through deep reinforcement learning. -Nature, 518(7540):529-533, February 2015 +Deep Reinforcement Learning with Double Q-learning. +arXiv preprint arXiv:1509.06461. """
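
The rule these patches implement is the one from the citation they add: Double DQN (van Hasselt, Guez & Silver, arXiv:1509.06461) replaces the standard target r + gamma * max_a Q_target(s', a) with r + gamma * Q_target(s', argmax_a Q_online(s', a)), so the online network selects the next action while the periodically frozen target network evaluates it. The following is a minimal NumPy sketch of the two target rules as defined in that paper, assuming batched 2-D Q-value arrays; the function and variable names are illustrative and do not come from q_network.py.

    import numpy as np

    def dqn_targets(rewards, terminals, q_target_next, discount=0.99):
        # Standard DQN: the frozen target network both selects and
        # evaluates the next action: max_a Q_target(s', a).
        return rewards + (1.0 - terminals) * discount * q_target_next.max(axis=1)

    def double_dqn_targets(rewards, terminals, q_online_next, q_target_next,
                           discount=0.99):
        # Double DQN: the online network selects the action, the frozen
        # target network evaluates it: Q_target(s', argmax_a Q_online(s', a)).
        best_actions = q_online_next.argmax(axis=1)
        batch_idx = np.arange(q_target_next.shape[0])
        evaluated = q_target_next[batch_idx, best_actions]
        return rewards + (1.0 - terminals) * discount * evaluated

    # Illustrative batch of 2 transitions with 4 actions each.
    rng = np.random.default_rng(0)
    q_online_next = rng.normal(size=(2, 4))
    q_target_next = rng.normal(size=(2, 4))
    rewards = np.array([1.0, 0.0])
    terminals = np.array([0.0, 1.0])
    print(double_dqn_targets(rewards, terminals, q_online_next, q_target_next))

Decoupling action selection from action evaluation is what counteracts the overestimation bias of the plain max operator, and it is also why patch 1 asserts freeze_interval > 0: without a separate, periodically frozen target network the two roles collapse onto a single estimator and the correction disappears.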