diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py
index bb9d54babbf5..2f768683f9e8 100644
--- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py
+++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py
@@ -700,6 +700,7 @@ def _scope_vars(scope, trainable_only=False):
     update_ops_fn=lambda policy: policy.q_batchnorm_update_ops,
     before_init=setup_early_mixins,
     after_init=setup_late_mixins,
+    obs_include_prev_action_reward=False,
     mixins=[
         ExplorationStateMixin,
         TargetNetworkMixin,
diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py
index cc1bbf642c7c..869ce5c7ad48 100644
--- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py
+++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py
@@ -240,14 +240,9 @@ def __init__(self, obs_space, action_space, config):
                         "a custom LSTM model that overrides the "
                         "value_function() method.")
                 with tf.variable_scope("value_function"):
-                    self.value_function = ModelCatalog.get_model({
-                        "obs": self.get_placeholder(SampleBatch.CUR_OBS),
-                        "prev_actions": self.get_placeholder(
-                            SampleBatch.PREV_ACTIONS),
-                        "prev_rewards": self.get_placeholder(
-                            SampleBatch.PREV_REWARDS),
-                        "is_training": self._get_is_training_placeholder(),
-                    }, obs_space, action_space, 1, vf_config).outputs
+                    self.value_function = ModelCatalog.get_model(
+                        self.get_obs_input_dict(), obs_space, action_space, 1,
+                        vf_config).outputs
                     self.value_function = tf.reshape(self.value_function, [-1])
         else:
             self.value_function = tf.zeros(
diff --git a/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py b/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py
index 5b051693d4fa..a25c2bdc945e 100644
--- a/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py
+++ b/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py
@@ -42,7 +42,8 @@ def __init__(self,
                  before_loss_init=None,
                  make_action_sampler=None,
                  existing_inputs=None,
-                 get_batch_divisibility_req=None):
+                 get_batch_divisibility_req=None,
+                 obs_include_prev_action_reward=True):
         """Initialize a dynamic TF policy graph.
 
         Arguments:
@@ -68,33 +69,41 @@ def __init__(self,
                 defining new ones
             get_batch_divisibility_req (func): optional function that
                 returns the divisibility requirement for sample batches
+            obs_include_prev_action_reward (bool): whether to include the
+                previous action and reward in the model input
         """
         self.config = config
         self._loss_fn = loss_fn
         self._stats_fn = stats_fn
         self._grad_stats_fn = grad_stats_fn
         self._update_ops_fn = update_ops_fn
+        self._obs_include_prev_action_reward = obs_include_prev_action_reward
 
         # Setup standard placeholders
+        prev_actions = None
+        prev_rewards = None
         if existing_inputs is not None:
             obs = existing_inputs[SampleBatch.CUR_OBS]
-            prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS]
-            prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS]
+            if self._obs_include_prev_action_reward:
+                prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS]
+                prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS]
         else:
             obs = tf.placeholder(
                 tf.float32,
                 shape=[None] + list(obs_space.shape),
                 name="observation")
-            prev_actions = ModelCatalog.get_action_placeholder(action_space)
-            prev_rewards = tf.placeholder(
-                tf.float32, [None], name="prev_reward")
+            if self._obs_include_prev_action_reward:
+                prev_actions = ModelCatalog.get_action_placeholder(
+                    action_space)
+                prev_rewards = tf.placeholder(
+                    tf.float32, [None], name="prev_reward")
 
-        self.input_dict = UsageTrackingDict({
+        self.input_dict = {
             SampleBatch.CUR_OBS: obs,
             SampleBatch.PREV_ACTIONS: prev_actions,
             SampleBatch.PREV_REWARDS: prev_rewards,
             "is_training": self._get_is_training_placeholder(),
-        })
+        }
 
         # Create the model network and action outputs
         if make_action_sampler:
@@ -162,6 +171,13 @@ def __init__(self,
         if not existing_inputs:
             self._initialize_loss()
 
+    def get_obs_input_dict(self):
+        """Returns the obs input dict used to build policy models.
+
+        This dict includes the obs, prev actions, prev rewards, etc. tensors.
+        """
+        return self.input_dict
+
     @override(TFPolicyGraph)
     def copy(self, existing_inputs):
         """Creates a copy of self using existing input placeholders."""
@@ -218,14 +234,17 @@ def fake_array(tensor):
             return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)
 
         dummy_batch = {
-            SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
-            SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
             SampleBatch.CUR_OBS: fake_array(self._obs_input),
             SampleBatch.NEXT_OBS: fake_array(self._obs_input),
-            SampleBatch.ACTIONS: fake_array(self._prev_action_input),
-            SampleBatch.REWARDS: np.array([0], dtype=np.float32),
             SampleBatch.DONES: np.array([False], dtype=np.bool),
+            SampleBatch.ACTIONS: np.array([self.action_space.sample()]),
+            SampleBatch.REWARDS: np.array([0], dtype=np.float32),
         }
+        if self._obs_include_prev_action_reward:
+            dummy_batch.update({
+                SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
+                SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
+            })
         state_init = self.get_initial_state()
         for i, h in enumerate(state_init):
             dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
@@ -240,20 +259,24 @@ def fake_array(tensor):
         postprocessed_batch = self.postprocess_trajectory(
             SampleBatch(dummy_batch))
 
-        batch_tensors = UsageTrackingDict({
-            SampleBatch.PREV_ACTIONS: self._prev_action_input,
-            SampleBatch.PREV_REWARDS: self._prev_reward_input,
-            SampleBatch.CUR_OBS: self._obs_input,
-        })
-        loss_inputs = [
-            (SampleBatch.CUR_OBS, self._obs_input),
-        ]
-        if SampleBatch.PREV_ACTIONS in self.input_dict.accessed_keys:
-            loss_inputs.append((SampleBatch.PREV_ACTIONS,
-                                self._prev_action_input))
-        if SampleBatch.PREV_REWARDS in self.input_dict.accessed_keys:
-            loss_inputs.append((SampleBatch.PREV_REWARDS,
-                                self._prev_reward_input))
+        if self._obs_include_prev_action_reward:
+            batch_tensors = UsageTrackingDict({
+                SampleBatch.PREV_ACTIONS: self._prev_action_input,
+                SampleBatch.PREV_REWARDS: self._prev_reward_input,
+                SampleBatch.CUR_OBS: self._obs_input,
+            })
+            loss_inputs = [
+                (SampleBatch.PREV_ACTIONS, self._prev_action_input),
+                (SampleBatch.PREV_REWARDS, self._prev_reward_input),
+                (SampleBatch.CUR_OBS, self._obs_input),
+            ]
+        else:
+            batch_tensors = UsageTrackingDict({
+                SampleBatch.CUR_OBS: self._obs_input,
+            })
+            loss_inputs = [
+                (SampleBatch.CUR_OBS, self._obs_input),
+            ]
 
         for k, v in postprocessed_batch.items():
             if k in batch_tensors:
diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py
index fcc896603ab6..69b8f1619002 100644
--- a/python/ray/rllib/evaluation/tf_policy_graph.py
+++ b/python/ray/rllib/evaluation/tf_policy_graph.py
@@ -141,16 +141,24 @@ def __init__(self,
                 "seq_lens tensor must be given if state inputs are defined")
 
     def get_placeholder(self, name):
-        """Returns the given loss input placeholder by name.
+        """Returns the given action or loss input placeholder by name.
 
-        These are the same placeholders passed in as the loss_inputs arg. If
-        the loss has not been initialized, an error is raised.
+        If the loss has not been initialized and a loss input placeholder is
+        requested, an error is raised.
         """
 
+        obs_inputs = {
+            SampleBatch.CUR_OBS: self._obs_input,
+            SampleBatch.PREV_ACTIONS: self._prev_action_input,
+            SampleBatch.PREV_REWARDS: self._prev_reward_input,
+        }
+        if name in obs_inputs:
+            return obs_inputs[name]
+
         if not self.loss_initialized():
             raise RuntimeError(
-                "You cannot call policy.get_placeholder() before the loss "
-                "has been initialized. To avoid this, use "
+                "You cannot call policy.get_placeholder() for non-obs inputs "
+                "before the loss has been initialized. To avoid this, use "
                 "policy.loss_initialized() to check whether this is the case.")
         return self._loss_input_dict[name]
 
diff --git a/python/ray/rllib/evaluation/tf_policy_template.py b/python/ray/rllib/evaluation/tf_policy_template.py
index f22b2e3af0bc..f16f422e39b8 100644
--- a/python/ray/rllib/evaluation/tf_policy_template.py
+++ b/python/ray/rllib/evaluation/tf_policy_template.py
@@ -27,7 +27,8 @@ def build_tf_policy(name,
                     make_action_sampler=None,
                     mixins=None,
                     get_batch_divisibility_req=None,
-                    update_ops_fn=None):
+                    update_ops_fn=None,
+                    obs_include_prev_action_reward=True):
     """Helper function for creating a dynamic tf policy at runtime.
 
     Arguments:
@@ -70,6 +71,8 @@ def build_tf_policy(name,
             the divisibility requirement for sample batches
         update_ops_fn (func): optional function that returns a list
             overriding the update ops to run when applying gradients
+        obs_include_prev_action_reward (bool): whether to include the
+            previous action and reward in the model input
 
     Returns:
         a DynamicTFPolicyGraph instance that uses the specified args
@@ -118,7 +121,8 @@ def before_loss_init_wrapper(policy, obs_space, action_space,
                 update_ops_fn=update_ops_fn,
                 before_loss_init=before_loss_init_wrapper,
                 make_action_sampler=make_action_sampler,
-                existing_inputs=existing_inputs)
+                existing_inputs=existing_inputs,
+                obs_include_prev_action_reward=obs_include_prev_action_reward)
 
             if after_init:
                 after_init(self, obs_space, action_space, config)
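
Usage note (not part of the diff): a minimal sketch of how a templated policy graph might opt out of the new prev-action/prev-reward placeholders. The policy name, the toy loss, and the loss_fn(policy, batch_tensors) calling convention are assumptions based on how build_tf_policy is used elsewhere in this codebase (for example the DQN change above), as is the policy.action_dist attribute used in the loss.

import tensorflow as tf

from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.evaluation.tf_policy_template import build_tf_policy


def toy_loss(policy, batch_tensors):
    # Toy REINFORCE-style loss; it only reads CUR_OBS/ACTIONS/REWARDS, so the
    # prev-action/prev-reward placeholders are never needed.
    log_probs = policy.action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
    return -tf.reduce_mean(log_probs * batch_tensors[SampleBatch.REWARDS])


# Hypothetical policy graph: with the flag disabled, the model input dict
# carries None for PREV_ACTIONS/PREV_REWARDS and no extra placeholders or
# loss inputs are created for them.
MyTFPolicyGraph = build_tf_policy(
    name="MyTFPolicyGraph",
    loss_fn=toy_loss,
    obs_include_prev_action_reward=False)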
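Similarly, a sketch of how get_obs_input_dict() can replace a hand-built input dict when constructing an auxiliary network, mirroring the PPO change above. The helper name build_aux_value_network and the idea of calling it from a hook that receives the policy (e.g. a before_loss_init-style callback) are illustrative assumptions, not part of the diff.

import tensorflow as tf

from ray.rllib.models import ModelCatalog


def build_aux_value_network(policy, obs_space, action_space, config):
    # Illustrative hook: build a separate value head from the same input
    # tensors (obs, prev actions/rewards, is_training) the policy model sees.
    vf_config = config["model"].copy()
    with tf.variable_scope("value_function"):
        value_out = ModelCatalog.get_model(
            policy.get_obs_input_dict(), obs_space, action_space, 1,
            vf_config).outputs
    policy.value_function = tf.reshape(value_out, [-1])

Because get_placeholder() now also resolves CUR_OBS, PREV_ACTIONS, and PREV_REWARDS before the loss is initialized, such hooks can alternatively fetch individual placeholders instead of the full input dict.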