[rllib] Rough port of DQN to build_tf_policy() pattern (#4823)
ericl authored Jun 2, 2019
1 parent c2ade07 commit 665d081
Showing 8 changed files with 547 additions and 626 deletions.
python/ray/rllib/agents/a3c/a3c_tf_policy.py (224 changes: 85 additions & 139 deletions)
@@ -4,20 +4,13 @@
from __future__ import division
from __future__ import print_function

import gym

import ray
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.policy.policy import Policy
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.policy.tf_policy import TFPolicy, \
LearningRateSchedule
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.utils import try_import_tf

tf = try_import_tf()
@@ -44,144 +37,97 @@ def __init__(self,
self.entropy * entropy_coeff)
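
For orientation, only the final line of the A3CLoss constructor is visible in this hunk. The sketch below reconstructs the loss it computes from the constructor arguments shown elsewhere in this file; it is not copied from this commit, and the reduce_sum choices and default coefficients are assumptions.

# Sketch only: reconstructed A3CLoss, relying on the module-level `tf`
# obtained from try_import_tf() above. Not part of this diff.
class A3CLossSketch(object):
    def __init__(self, action_dist, actions, advantages, v_target, vf,
                 vf_loss_coeff=0.5, entropy_coeff=0.01):
        log_prob = action_dist.logp(actions)
        # Policy-gradient term: log-probabilities weighted by advantages.
        self.pi_loss = -tf.reduce_sum(log_prob * advantages)
        # Value-function regression toward the computed value targets.
        self.vf_loss = 0.5 * tf.reduce_sum(tf.square(vf - v_target))
        # Entropy bonus to keep the policy exploratory.
        self.entropy = tf.reduce_sum(action_dist.entropy())
        self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                           self.entropy * entropy_coeff)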


class A3CPostprocessing(object):
"""Adds the VF preds and advantages fields to the trajectory."""

@override(TFPolicy)
def extra_compute_action_fetches(self):
return dict(
TFPolicy.extra_compute_action_fetches(self),
**{SampleBatch.VF_PREDS: self.vf})

@override(Policy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
completed = sample_batch[SampleBatch.DONES][-1]
if completed:
last_r = 0.0
else:
next_state = []
for i in range(len(self.model.state_in)):
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
return compute_advantages(sample_batch, last_r, self.config["gamma"],
self.config["lambda"])


class A3CTFPolicy(LearningRateSchedule, A3CPostprocessing, TFPolicy):
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
self.config = config
self.sess = tf.get_default_session()

# Setup the policy
self.observations = tf.placeholder(
tf.float32, [None] + list(observation_space.shape))
dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
self.prev_actions = ModelCatalog.get_action_placeholder(action_space)
self.prev_rewards = tf.placeholder(
tf.float32, [None], name="prev_reward")
self.model = ModelCatalog.get_model({
"obs": self.observations,
"prev_actions": self.prev_actions,
"prev_rewards": self.prev_rewards,
"is_training": self._get_is_training_placeholder(),
}, observation_space, action_space, logit_dim, self.config["model"])
action_dist = dist_class(self.model.outputs)
def actor_critic_loss(policy, batch_tensors):
policy.loss = A3CLoss(
policy.action_dist, batch_tensors[SampleBatch.ACTIONS],
batch_tensors[Postprocessing.ADVANTAGES],
batch_tensors[Postprocessing.VALUE_TARGETS], policy.vf,
policy.config["vf_loss_coeff"], policy.config["entropy_coeff"])
return policy.loss.total_loss


def postprocess_advantages(policy,
sample_batch,
other_agent_batches=None,
episode=None):
completed = sample_batch[SampleBatch.DONES][-1]
if completed:
last_r = 0.0
else:
next_state = []
for i in range(len(policy.model.state_in)):
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
return compute_advantages(sample_batch, last_r, policy.config["gamma"],
policy.config["lambda"])


def add_value_function_fetch(policy):
return {SampleBatch.VF_PREDS: policy.vf}
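
The per-step value predictions fetched under SampleBatch.VF_PREDS above are what compute_advantages consumes when GAE is enabled, together with the bootstrap value last_r produced by policy._value() in postprocess_advantages. A minimal NumPy sketch of that computation, written against the standard GAE definition rather than RLlib's exact compute_advantages implementation (the function name and the default gamma/lam values are illustrative):

# Sketch of generalized advantage estimation (GAE); illustrative only,
# not RLlib's compute_advantages.
import numpy as np

def gae_sketch(rewards, vf_preds, last_r, gamma=0.99, lam=1.0):
    rewards = np.asarray(rewards, dtype=np.float64)
    # Append the bootstrap value for the state after the last step.
    vpred = np.append(np.asarray(vf_preds, dtype=np.float64), last_r)
    # One-step TD residuals.
    deltas = rewards + gamma * vpred[1:] - vpred[:-1]
    advantages = np.zeros(len(rewards))
    running = 0.0
    # Discounted backward accumulation of the residuals.
    for t in reversed(range(len(rewards))):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    value_targets = advantages + vpred[:-1]
    return advantages, value_targets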


class ValueNetworkMixin(object):
def __init__(self):
self.vf = self.model.value_function()
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)

# Setup the policy loss
if isinstance(action_space, gym.spaces.Box):
ac_size = action_space.shape[0]
actions = tf.placeholder(tf.float32, [None, ac_size], name="ac")
elif isinstance(action_space, gym.spaces.Discrete):
actions = tf.placeholder(tf.int64, [None], name="ac")
else:
raise UnsupportedSpaceException(
"Action space {} is not supported for A3C.".format(
action_space))
advantages = tf.placeholder(tf.float32, [None], name="advantages")
self.v_target = tf.placeholder(tf.float32, [None], name="v_target")
self.loss = A3CLoss(action_dist, actions, advantages, self.v_target,
self.vf, self.config["vf_loss_coeff"],
self.config["entropy_coeff"])

# Initialize TFPolicy
loss_in = [
(SampleBatch.CUR_OBS, self.observations),
(SampleBatch.ACTIONS, actions),
(SampleBatch.PREV_ACTIONS, self.prev_actions),
(SampleBatch.PREV_REWARDS, self.prev_rewards),
(Postprocessing.ADVANTAGES, advantages),
(Postprocessing.VALUE_TARGETS, self.v_target),
]
LearningRateSchedule.__init__(self, self.config["lr"],
self.config["lr_schedule"])
TFPolicy.__init__(
self,
observation_space,
action_space,
self.sess,
obs_input=self.observations,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=self.loss.total_loss,
model=self.model,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
prev_action_input=self.prev_actions,
prev_reward_input=self.prev_rewards,
seq_lens=self.model.seq_lens,
max_seq_len=self.config["model"]["max_seq_len"])

self.stats_fetches = {
LEARNER_STATS_KEY: {
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"policy_loss": self.loss.pi_loss,
"policy_entropy": self.loss.entropy,
"grad_gnorm": tf.global_norm(self._grads),
"var_gnorm": tf.global_norm(self.var_list),
"vf_loss": self.loss.vf_loss,
"vf_explained_var": explained_variance(self.v_target, self.vf),
},
}

self.sess.run(tf.global_variables_initializer())

@override(Policy)
def get_initial_state(self):
return self.model.state_init

@override(TFPolicy)
def gradients(self, optimizer, loss):
grads = tf.gradients(loss, self.var_list)
self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
clipped_grads = list(zip(self.grads, self.var_list))
return clipped_grads

@override(TFPolicy)
def extra_compute_grad_fetches(self):
return self.stats_fetches

def _value(self, ob, prev_action, prev_reward, *args):
feed_dict = {
self.observations: [ob],
self.prev_actions: [prev_action],
self.prev_rewards: [prev_reward],
self.get_placeholder(SampleBatch.CUR_OBS): [ob],
self.get_placeholder(SampleBatch.PREV_ACTIONS): [prev_action],
self.get_placeholder(SampleBatch.PREV_REWARDS): [prev_reward],
self.model.seq_lens: [1]
}
assert len(args) == len(self.model.state_in), \
(args, self.model.state_in)
for k, v in zip(self.model.state_in, args):
feed_dict[k] = v
vf = self.sess.run(self.vf, feed_dict)
vf = self.get_session().run(self.vf, feed_dict)
return vf[0]


def stats(policy, batch_tensors):
return {
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
"policy_loss": policy.loss.pi_loss,
"policy_entropy": policy.loss.entropy,
"var_gnorm": tf.global_norm(policy.var_list),
"vf_loss": policy.loss.vf_loss,
}


def grad_stats(policy, grads):
return {
"grad_gnorm": tf.global_norm(grads),
"vf_explained_var": explained_variance(
policy.get_placeholder(Postprocessing.VALUE_TARGETS), policy.vf),
}
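
grad_stats reports vf_explained_var through the explained_variance helper imported at the top of the file. For reference, the statistic follows the usual definition 1 - Var(y - y_pred) / Var(y); the NumPy sketch below is illustrative and is not the TF helper itself.

# Illustrative NumPy version of explained variance; the real helper
# (ray.rllib.utils.explained_variance) operates on TF tensors.
import numpy as np

def explained_variance_sketch(y, y_pred):
    y = np.asarray(y, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    var_y = np.var(y)
    if var_y == 0.0:
        return np.nan  # Undefined when the targets are constant.
    return 1.0 - np.var(y - y_pred) / var_y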


def clip_gradients(policy, optimizer, loss):
grads = tf.gradients(loss, policy.var_list)
grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"])
clipped_grads = list(zip(grads, policy.var_list))
return clipped_grads


def setup_mixins(policy, obs_space, action_space, config):
ValueNetworkMixin.__init__(policy)
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
policy.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)


A3CTFPolicy = build_tf_policy(
name="A3CTFPolicy",
get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
loss_fn=actor_critic_loss,
stats_fn=stats,
grad_stats_fn=grad_stats,
gradients_fn=clip_gradients,
postprocess_fn=postprocess_advantages,
extra_action_fetches_fn=add_value_function_fetch,
before_loss_init=setup_mixins,
mixins=[ValueNetworkMixin, LearningRateSchedule])
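
A policy class produced by build_tf_policy() plugs into a trainer the same way the previous hand-written class did. The usage sketch below assumes the build_trainer helper from ray.rllib.agents.trainer_template in the same policy-builder series; the trainer name and environment are illustrative choices, not part of this commit.

# Sketch only: wiring the generated A3CTFPolicy into a trainer.
import ray
from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.trainer_template import build_trainer  # assumed available

MyA3CTrainer = build_trainer(
    name="MyA3C",
    default_config=DEFAULT_CONFIG,
    default_policy=A3CTFPolicy)

if __name__ == "__main__":
    ray.init()
    trainer = MyA3CTrainer(env="CartPole-v0")
    print(trainer.train())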