[rllib] Rough port of DQN to build_tf_policy() pattern #4823

Merged Jun 2, 2019 (12 commits)
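The diff below shows python/ray/rllib/agents/a3c/a3c_tf_policy.py being rewritten from a hand-built TFPolicy subclass into the functional build_tf_policy() pattern that this PR also applies to DQN: the loss, postprocessing, stats, and gradient logic become free functions that are handed to build_tf_policy(), which assembles the policy class. As rough orientation only, here is a minimal sketch of the pattern under the RLlib API of this era; ExamplePolicy and example_loss are illustrative names and are not part of the PR.

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils import try_import_tf

tf = try_import_tf()


def example_loss(policy, batch_tensors):
    # Toy REINFORCE-style loss (-logp(action) * reward); a real policy such
    # as the A3C port below combines policy, value, and entropy terms here.
    log_probs = policy.action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
    return -tf.reduce_mean(log_probs * batch_tensors[SampleBatch.REWARDS])


# build_tf_policy() generates a full TFPolicy class from plain functions.
# Optional hooks (stats_fn, gradients_fn, postprocess_fn, mixins, ...) are
# shown in the actual A3C port in the diff below.
ExamplePolicy = build_tf_policy(
    name="ExamplePolicy",
    loss_fn=example_loss)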
224 changes: 85 additions & 139 deletions python/ray/rllib/agents/a3c/a3c_tf_policy.py
@@ -4,20 +4,13 @@
from __future__ import division
from __future__ import print_function

import gym

import ray
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.policy.policy import Policy
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.policy.tf_policy import TFPolicy, \
LearningRateSchedule
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.utils import try_import_tf

tf = try_import_tf()
@@ -44,144 +37,97 @@ def __init__(self,
self.entropy * entropy_coeff)


class A3CPostprocessing(object):
"""Adds the VF preds and advantages fields to the trajectory."""

@override(TFPolicy)
def extra_compute_action_fetches(self):
return dict(
TFPolicy.extra_compute_action_fetches(self),
**{SampleBatch.VF_PREDS: self.vf})

@override(Policy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
completed = sample_batch[SampleBatch.DONES][-1]
if completed:
last_r = 0.0
else:
next_state = []
for i in range(len(self.model.state_in)):
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
return compute_advantages(sample_batch, last_r, self.config["gamma"],
self.config["lambda"])


class A3CTFPolicy(LearningRateSchedule, A3CPostprocessing, TFPolicy):
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
self.config = config
self.sess = tf.get_default_session()

# Setup the policy
self.observations = tf.placeholder(
tf.float32, [None] + list(observation_space.shape))
dist_class, logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
self.prev_actions = ModelCatalog.get_action_placeholder(action_space)
self.prev_rewards = tf.placeholder(
tf.float32, [None], name="prev_reward")
self.model = ModelCatalog.get_model({
"obs": self.observations,
"prev_actions": self.prev_actions,
"prev_rewards": self.prev_rewards,
"is_training": self._get_is_training_placeholder(),
}, observation_space, action_space, logit_dim, self.config["model"])
action_dist = dist_class(self.model.outputs)
def actor_critic_loss(policy, batch_tensors):
policy.loss = A3CLoss(
policy.action_dist, batch_tensors[SampleBatch.ACTIONS],
batch_tensors[Postprocessing.ADVANTAGES],
batch_tensors[Postprocessing.VALUE_TARGETS], policy.vf,
policy.config["vf_loss_coeff"], policy.config["entropy_coeff"])
return policy.loss.total_loss


def postprocess_advantages(policy,
sample_batch,
other_agent_batches=None,
episode=None):
completed = sample_batch[SampleBatch.DONES][-1]
if completed:
last_r = 0.0
else:
next_state = []
for i in range(len(policy.model.state_in)):
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
return compute_advantages(sample_batch, last_r, policy.config["gamma"],
policy.config["lambda"])


def add_value_function_fetch(policy):
return {SampleBatch.VF_PREDS: policy.vf}


class ValueNetworkMixin(object):
def __init__(self):
self.vf = self.model.value_function()
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)

# Setup the policy loss
if isinstance(action_space, gym.spaces.Box):
ac_size = action_space.shape[0]
actions = tf.placeholder(tf.float32, [None, ac_size], name="ac")
elif isinstance(action_space, gym.spaces.Discrete):
actions = tf.placeholder(tf.int64, [None], name="ac")
else:
raise UnsupportedSpaceException(
"Action space {} is not supported for A3C.".format(
action_space))
advantages = tf.placeholder(tf.float32, [None], name="advantages")
self.v_target = tf.placeholder(tf.float32, [None], name="v_target")
self.loss = A3CLoss(action_dist, actions, advantages, self.v_target,
self.vf, self.config["vf_loss_coeff"],
self.config["entropy_coeff"])

# Initialize TFPolicy
loss_in = [
(SampleBatch.CUR_OBS, self.observations),
(SampleBatch.ACTIONS, actions),
(SampleBatch.PREV_ACTIONS, self.prev_actions),
(SampleBatch.PREV_REWARDS, self.prev_rewards),
(Postprocessing.ADVANTAGES, advantages),
(Postprocessing.VALUE_TARGETS, self.v_target),
]
LearningRateSchedule.__init__(self, self.config["lr"],
self.config["lr_schedule"])
TFPolicy.__init__(
self,
observation_space,
action_space,
self.sess,
obs_input=self.observations,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=self.loss.total_loss,
model=self.model,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
prev_action_input=self.prev_actions,
prev_reward_input=self.prev_rewards,
seq_lens=self.model.seq_lens,
max_seq_len=self.config["model"]["max_seq_len"])

self.stats_fetches = {
LEARNER_STATS_KEY: {
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"policy_loss": self.loss.pi_loss,
"policy_entropy": self.loss.entropy,
"grad_gnorm": tf.global_norm(self._grads),
"var_gnorm": tf.global_norm(self.var_list),
"vf_loss": self.loss.vf_loss,
"vf_explained_var": explained_variance(self.v_target, self.vf),
},
}

self.sess.run(tf.global_variables_initializer())

@override(Policy)
def get_initial_state(self):
return self.model.state_init

@override(TFPolicy)
def gradients(self, optimizer, loss):
grads = tf.gradients(loss, self.var_list)
self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
clipped_grads = list(zip(self.grads, self.var_list))
return clipped_grads

@override(TFPolicy)
def extra_compute_grad_fetches(self):
return self.stats_fetches

def _value(self, ob, prev_action, prev_reward, *args):
feed_dict = {
self.observations: [ob],
self.prev_actions: [prev_action],
self.prev_rewards: [prev_reward],
self.get_placeholder(SampleBatch.CUR_OBS): [ob],
self.get_placeholder(SampleBatch.PREV_ACTIONS): [prev_action],
self.get_placeholder(SampleBatch.PREV_REWARDS): [prev_reward],
self.model.seq_lens: [1]
}
assert len(args) == len(self.model.state_in), \
(args, self.model.state_in)
for k, v in zip(self.model.state_in, args):
feed_dict[k] = v
vf = self.sess.run(self.vf, feed_dict)
vf = self.get_session().run(self.vf, feed_dict)
return vf[0]


def stats(policy, batch_tensors):
return {
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
"policy_loss": policy.loss.pi_loss,
"policy_entropy": policy.loss.entropy,
"var_gnorm": tf.global_norm(policy.var_list),
"vf_loss": policy.loss.vf_loss,
}


def grad_stats(policy, grads):
return {
"grad_gnorm": tf.global_norm(grads),
"vf_explained_var": explained_variance(
policy.get_placeholder(Postprocessing.VALUE_TARGETS), policy.vf),
}


def clip_gradients(policy, optimizer, loss):
grads = tf.gradients(loss, policy.var_list)
grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"])
clipped_grads = list(zip(grads, policy.var_list))
return clipped_grads


def setup_mixins(policy, obs_space, action_space, config):
ValueNetworkMixin.__init__(policy)
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
policy.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)


A3CTFPolicy = build_tf_policy(
name="A3CTFPolicy",
get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
loss_fn=actor_critic_loss,
stats_fn=stats,
grad_stats_fn=grad_stats,
gradients_fn=clip_gradients,
postprocess_fn=postprocess_advantages,
extra_action_fetches_fn=add_value_function_fetch,
before_loss_init=setup_mixins,
mixins=[ValueNetworkMixin, LearningRateSchedule])
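The generated A3CTFPolicy class is consumed by the trainer the same way the old subclass was. A minimal end-to-end sketch, assuming the A3CTrainer entry point of this era of RLlib (the environment and config values are arbitrary examples):

import ray
from ray.rllib.agents.a3c import A3CTrainer

ray.init()
# Train a few iterations on a toy environment and print the mean reward.
trainer = A3CTrainer(env="CartPole-v0", config={"num_workers": 2})
for _ in range(3):
    result = trainer.train()
    print(result["episode_reward_mean"])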