
Commit 04d0157

make prev action reward optional
ericl committed May 19, 2019
1 parent b4bb8ce commit 04d0157
Showing 5 changed files with 72 additions and 41 deletions.
1 change: 1 addition & 0 deletions python/ray/rllib/agents/dqn/dqn_policy_graph.py
@@ -700,6 +700,7 @@ def _scope_vars(scope, trainable_only=False):
update_ops_fn=lambda policy: policy.q_batchnorm_update_ops,
before_init=setup_early_mixins,
after_init=setup_late_mixins,
obs_include_prev_action_reward=False,
mixins=[
ExplorationStateMixin,
TargetNetworkMixin,
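The single added line above opts DQN out of the new inputs, presumably because its Q-model conditions only on the current observation. A minimal plain-Python sketch of what the flag controls at construction time (the strings stand in for TF placeholders; the real branch lives in dynamic_tf_policy_graph.py below):

def make_standard_placeholders(include_prev_action_reward):
    # Mirrors the branch in DynamicTFPolicyGraph.__init__: when the flag is
    # off, the prev-action/prev-reward slots are simply left as None.
    prev_actions = None
    prev_rewards = None
    if include_prev_action_reward:
        prev_actions = "prev_action_placeholder"
        prev_rewards = "prev_reward_placeholder"
    return {
        "obs": "obs_placeholder",
        "prev_actions": prev_actions,
        "prev_rewards": prev_rewards,
        "is_training": "is_training_placeholder",
    }

# With obs_include_prev_action_reward=False, as DQN now passes:
print(make_standard_placeholders(include_prev_action_reward=False))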
11 changes: 3 additions & 8 deletions python/ray/rllib/agents/ppo/ppo_policy_graph.py
@@ -240,14 +240,9 @@ def __init__(self, obs_space, action_space, config):
"a custom LSTM model that overrides the "
"value_function() method.")
with tf.variable_scope("value_function"):
self.value_function = ModelCatalog.get_model({
"obs": self.get_placeholder(SampleBatch.CUR_OBS),
"prev_actions": self.get_placeholder(
SampleBatch.PREV_ACTIONS),
"prev_rewards": self.get_placeholder(
SampleBatch.PREV_REWARDS),
"is_training": self._get_is_training_placeholder(),
}, obs_space, action_space, 1, vf_config).outputs
self.value_function = ModelCatalog.get_model(
self.get_obs_input_dict(), obs_space, action_space, 1,
vf_config).outputs
self.value_function = tf.reshape(self.value_function, [-1])
else:
self.value_function = tf.zeros(
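The PPO change swaps the hand-assembled input dict for the policy's own get_obs_input_dict(), so the value-function branch sees exactly the keys the policy model was built with. A plain-Python stand-in for the pattern (TinyPolicy and build_value_head are illustrative, not RLlib classes):

class TinyPolicy:
    # Stand-in for DynamicTFPolicyGraph: it owns one shared input dict.
    def __init__(self, include_prev_action_reward=True):
        self.input_dict = {
            "obs": "obs_placeholder",
            "prev_actions": ("prev_action_placeholder"
                             if include_prev_action_reward else None),
            "prev_rewards": ("prev_reward_placeholder"
                             if include_prev_action_reward else None),
            "is_training": "is_training_placeholder",
        }

    def get_obs_input_dict(self):
        # Same contract as the accessor added in this commit.
        return self.input_dict

def build_value_head(input_dict):
    # Stand-in for ModelCatalog.get_model(input_dict, ...); a real model
    # consumes whichever entries are populated.
    return {k: v for k, v in input_dict.items() if v is not None}

policy = TinyPolicy(include_prev_action_reward=False)
print(build_value_head(policy.get_obs_input_dict()))

Reusing one dict keeps auxiliary heads consistent with whatever placeholders were actually created, instead of unconditionally requesting prev-action and prev-reward inputs.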
75 changes: 49 additions & 26 deletions python/ray/rllib/evaluation/dynamic_tf_policy_graph.py
@@ -42,7 +42,8 @@ def __init__(self,
before_loss_init=None,
make_action_sampler=None,
existing_inputs=None,
get_batch_divisibility_req=None):
get_batch_divisibility_req=None,
obs_include_prev_action_reward=True):
"""Initialize a dynamic TF policy graph.
Arguments:
@@ -68,33 +69,41 @@ def __init__(self,
defining new ones
get_batch_divisibility_req (func): optional function that returns
the divisibility requirement for sample batches
obs_include_prev_action_reward (bool): whether to include the
previous action and reward in the model input
"""
self.config = config
self._loss_fn = loss_fn
self._stats_fn = stats_fn
self._grad_stats_fn = grad_stats_fn
self._update_ops_fn = update_ops_fn
self._obs_include_prev_action_reward = obs_include_prev_action_reward

# Setup standard placeholders
prev_actions = None
prev_rewards = None
if existing_inputs is not None:
obs = existing_inputs[SampleBatch.CUR_OBS]
prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS]
prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS]
if self._obs_include_prev_action_reward:
prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS]
prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS]
else:
obs = tf.placeholder(
tf.float32,
shape=[None] + list(obs_space.shape),
name="observation")
prev_actions = ModelCatalog.get_action_placeholder(action_space)
prev_rewards = tf.placeholder(
tf.float32, [None], name="prev_reward")
if self._obs_include_prev_action_reward:
prev_actions = ModelCatalog.get_action_placeholder(
action_space)
prev_rewards = tf.placeholder(
tf.float32, [None], name="prev_reward")

self.input_dict = UsageTrackingDict({
self.input_dict = {
SampleBatch.CUR_OBS: obs,
SampleBatch.PREV_ACTIONS: prev_actions,
SampleBatch.PREV_REWARDS: prev_rewards,
"is_training": self._get_is_training_placeholder(),
})
}

# Create the model network and action outputs
if make_action_sampler:
@@ -162,6 +171,13 @@ def __init__(self,
if not existing_inputs:
self._initialize_loss()

def get_obs_input_dict(self):
"""Returns the obs input dict used to build policy models.
This dict includes the obs, prev actions, prev rewards, etc. tensors.
"""
return self.input_dict

@override(TFPolicyGraph)
def copy(self, existing_inputs):
"""Creates a copy of self using existing input placeholders."""
@@ -218,14 +234,17 @@ def fake_array(tensor):
return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

dummy_batch = {
SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
SampleBatch.CUR_OBS: fake_array(self._obs_input),
SampleBatch.NEXT_OBS: fake_array(self._obs_input),
SampleBatch.ACTIONS: fake_array(self._prev_action_input),
SampleBatch.REWARDS: np.array([0], dtype=np.float32),
SampleBatch.DONES: np.array([False], dtype=np.bool),
SampleBatch.ACTIONS: np.array([self.action_space.sample()]),
SampleBatch.REWARDS: np.array([0], dtype=np.float32),
}
if self._obs_include_prev_action_reward:
dummy_batch.update({
SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
})
state_init = self.get_initial_state()
for i, h in enumerate(state_init):
dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
@@ -240,20 +259,24 @@ def fake_array(tensor):
postprocessed_batch = self.postprocess_trajectory(
SampleBatch(dummy_batch))

batch_tensors = UsageTrackingDict({
SampleBatch.PREV_ACTIONS: self._prev_action_input,
SampleBatch.PREV_REWARDS: self._prev_reward_input,
SampleBatch.CUR_OBS: self._obs_input,
})
loss_inputs = [
(SampleBatch.CUR_OBS, self._obs_input),
]
if SampleBatch.PREV_ACTIONS in self.input_dict.accessed_keys:
loss_inputs.append((SampleBatch.PREV_ACTIONS,
self._prev_action_input))
if SampleBatch.PREV_REWARDS in self.input_dict.accessed_keys:
loss_inputs.append((SampleBatch.PREV_REWARDS,
self._prev_reward_input))
if self._obs_include_prev_action_reward:
batch_tensors = UsageTrackingDict({
SampleBatch.PREV_ACTIONS: self._prev_action_input,
SampleBatch.PREV_REWARDS: self._prev_reward_input,
SampleBatch.CUR_OBS: self._obs_input,
})
loss_inputs = [
(SampleBatch.PREV_ACTIONS, self._prev_action_input),
(SampleBatch.PREV_REWARDS, self._prev_reward_input),
(SampleBatch.CUR_OBS, self._obs_input),
]
else:
batch_tensors = UsageTrackingDict({
SampleBatch.CUR_OBS: self._obs_input,
})
loss_inputs = [
(SampleBatch.CUR_OBS, self._obs_input),
]

for k, v in postprocessed_batch.items():
if k in batch_tensors:
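Beyond the placeholder setup, the loss-bootstrapping code above now branches on the same flag: the dummy batch and the initial loss inputs carry prev-action and prev-reward entries only when those placeholders exist. A plain-Python sketch of that branch (string keys stand in for the SampleBatch constants, strings for the placeholder tensors):

import numpy as np

def make_dummy_batch_and_loss_inputs(include_prev_action_reward):
    # Rough analogue of the loss initialization above, with stand-in values.
    dummy_batch = {
        "obs": np.zeros((1, 4), dtype=np.float32),
        "new_obs": np.zeros((1, 4), dtype=np.float32),
        "dones": np.array([False]),
        "actions": np.array([0]),
        "rewards": np.array([0], dtype=np.float32),
    }
    loss_inputs = [("obs", "obs_placeholder")]
    if include_prev_action_reward:
        dummy_batch["prev_actions"] = np.array([0])
        dummy_batch["prev_rewards"] = np.array([0], dtype=np.float32)
        loss_inputs = [
            ("prev_actions", "prev_action_placeholder"),
            ("prev_rewards", "prev_reward_placeholder"),
            ("obs", "obs_placeholder"),
        ]
    return dummy_batch, loss_inputs

batch, inputs = make_dummy_batch_and_loss_inputs(False)
print(sorted(batch), inputs)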
18 changes: 13 additions & 5 deletions python/ray/rllib/evaluation/tf_policy_graph.py
@@ -141,16 +141,24 @@ def __init__(self,
"seq_lens tensor must be given if state inputs are defined")

def get_placeholder(self, name):
"""Returns the given loss input placeholder by name.
"""Returns the given action or loss input placeholder by name.
These are the same placeholders passed in as the loss_inputs arg. If
the loss has not been initialized, an error is raised.
If the loss has not been initialized and a loss input placeholder is
requested, an error is raised.
"""

obs_inputs = {
SampleBatch.CUR_OBS: self._obs_input,
SampleBatch.PREV_ACTIONS: self._prev_action_input,
SampleBatch.PREV_REWARDS: self._prev_reward_input,
}
if name in obs_inputs:
return obs_inputs[name]

if not self.loss_initialized():
raise RuntimeError(
"You cannot call policy.get_placeholder() before the loss "
"has been initialized. To avoid this, use "
"You cannot call policy.get_placeholder() for non-obs inputs "
"before the loss has been initialized. To avoid this, use "
"policy.loss_initialized() to check whether this is the case.")

return self._loss_input_dict[name]
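get_placeholder() now serves the observation, prev-action, and prev-reward placeholders directly, and only falls back to the loss-input dict (with its loss-initialized check) for other names; this matters because models are built before the loss exists. A plain-Python stand-in for the revised lookup (TinyTFPolicy and the string placeholders are illustrative):

class TinyTFPolicy:
    def __init__(self):
        self._obs_input = "obs_placeholder"
        self._prev_action_input = "prev_action_placeholder"
        self._prev_reward_input = "prev_reward_placeholder"
        self._loss_input_dict = None  # populated once the loss is built

    def loss_initialized(self):
        return self._loss_input_dict is not None

    def get_placeholder(self, name):
        # Obs-related inputs are always available, mirroring the change above.
        obs_inputs = {
            "obs": self._obs_input,
            "prev_actions": self._prev_action_input,
            "prev_rewards": self._prev_reward_input,
        }
        if name in obs_inputs:
            return obs_inputs[name]
        if not self.loss_initialized():
            raise RuntimeError(
                "non-obs placeholders are only available after loss init")
        return self._loss_input_dict[name]

policy = TinyTFPolicy()
print(policy.get_placeholder("prev_rewards"))  # available before loss init
# policy.get_placeholder("advantages") here would raise RuntimeError.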
8 changes: 6 additions & 2 deletions python/ray/rllib/evaluation/tf_policy_template.py
@@ -27,7 +27,8 @@ def build_tf_policy(name,
make_action_sampler=None,
mixins=None,
get_batch_divisibility_req=None,
update_ops_fn=None):
update_ops_fn=None,
obs_include_prev_action_reward=True):
"""Helper function for creating a dynamic tf policy at runtime.
Arguments:
Expand Down Expand Up @@ -70,6 +71,8 @@ def build_tf_policy(name,
the divisibility requirement for sample batches
update_ops_fn (func): optional function that returns a list overriding
the update ops to run when applying gradients
obs_include_prev_action_reward (bool): whether to include the
previous action and reward in the model input
Returns:
a DynamicTFPolicyGraph instance that uses the specified args
Expand Down Expand Up @@ -118,7 +121,8 @@ def before_loss_init_wrapper(policy, obs_space, action_space,
update_ops_fn=update_ops_fn,
before_loss_init=before_loss_init_wrapper,
make_action_sampler=make_action_sampler,
existing_inputs=existing_inputs)
existing_inputs=existing_inputs,
obs_include_prev_action_reward=obs_include_prev_action_reward)

if after_init:
after_init(self, obs_space, action_space, config)
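The template change simply threads the new keyword through to DynamicTFPolicyGraph.__init__. A plain-Python sketch of a factory forwarding a keyword into the class it builds (build_policy_class, the stub constructor arguments, and the lambda loss are illustrative stand-ins, not the real build_tf_policy):

def build_policy_class(name, loss_fn, obs_include_prev_action_reward=True):
    # Stand-in for build_tf_policy: capture the keyword and forward it
    # whenever the generated class is instantiated.
    class policy_cls:
        def __init__(self, obs_space, action_space, config):
            self.loss_fn = loss_fn
            self.obs_include_prev_action_reward = \
                obs_include_prev_action_reward

    policy_cls.__name__ = name
    return policy_cls

MyPolicyGraph = build_policy_class(
    "MyPolicyGraph",
    loss_fn=lambda policy, batch: 0.0,  # illustrative signature only
    obs_include_prev_action_reward=False)
print(MyPolicyGraph(None, None, {}).obs_include_prev_action_reward)  # False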
