From 6dc2b92c2e540b1ed2edaa7194cb5933c8d188a8 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 19 May 2019 16:17:30 -0700 Subject: [PATCH 01/10] works now --- .../ray/rllib/agents/dqn/dqn_policy_graph.py | 560 +++++++++--------- .../ray/rllib/agents/ppo/ppo_policy_graph.py | 6 +- .../evaluation/dynamic_tf_policy_graph.py | 44 +- .../ray/rllib/evaluation/tf_policy_graph.py | 10 +- .../rllib/evaluation/tf_policy_template.py | 31 +- 5 files changed, 354 insertions(+), 297 deletions(-) diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py index 1e682ce80cfa..e29f36cb46d1 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py @@ -3,22 +3,25 @@ from __future__ import print_function from gym.spaces import Discrete +import logging import numpy as np from scipy.stats import entropy import ray from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.models import ModelCatalog, Categorical from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ LearningRateSchedule +from ray.rllib.evaluation.tf_policy_template import build_tf_policy from ray.rllib.utils import try_import_tf tf = try_import_tf() +logger = logging.getLogger(__name__) + Q_SCOPE = "q_func" Q_TARGET_SCOPE = "target_q_func" @@ -102,46 +105,6 @@ def __init__(self, } -class DQNPostprocessing(object): - """Implements n-step learning and param noise adjustments.""" - - @override(TFPolicyGraph) - def extra_compute_action_fetches(self): - return dict( - TFPolicyGraph.extra_compute_action_fetches(self), **{ - "q_values": self.q_values, - }) - - @override(PolicyGraph) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - if self.config["parameter_noise"]: - # adjust the sigma of parameter space noise - states = [list(x) for x in sample_batch.columns(["obs"])][0] - - noisy_action_distribution = self.sess.run( - self.action_probs, feed_dict={self.cur_observations: states}) - self.sess.run(self.remove_noise_op) - clean_action_distribution = self.sess.run( - self.action_probs, feed_dict={self.cur_observations: states}) - distance_in_action_space = np.mean( - entropy(clean_action_distribution.T, - noisy_action_distribution.T)) - self.pi_distance = distance_in_action_space - if (distance_in_action_space < - -np.log(1 - self.cur_epsilon + - self.cur_epsilon / self.num_actions)): - self.parameter_noise_sigma_val *= 1.01 - else: - self.parameter_noise_sigma_val /= 1.01 - self.parameter_noise_sigma.load( - self.parameter_noise_sigma_val, session=self.sess) - - return _postprocess_dqn(self, sample_batch) - - class QNetwork(object): def __init__(self, model, @@ -345,98 +308,31 @@ def __init__(self, q_values, observations, num_actions, stochastic, eps, self.action_prob = None -class DQNPolicyGraph(LearningRateSchedule, DQNPostprocessing, TFPolicyGraph): - def __init__(self, observation_space, action_space, config): - config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config) - if not isinstance(action_space, Discrete): - raise UnsupportedSpaceException( - "Action space {} is not supported for DQN.".format( - action_space)) - - self.config = config +class ExplorationStateMixin(object): + def __init__(self, obs_space, action_space, config): 
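+        # Holds epsilon-greedy exploration state: `cur_epsilon` is a plain
+        # Python value, while `stochastic` and `eps` are placeholders that
+        # are fed at action-computation time (see exploration_setting_inputs).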
self.cur_epsilon = 1.0 - self.num_actions = action_space.n - - # Action inputs self.stochastic = tf.placeholder(tf.bool, (), name="stochastic") self.eps = tf.placeholder(tf.float32, (), name="eps") - self.cur_observations = tf.placeholder( - tf.float32, shape=(None, ) + observation_space.shape) - - # Action Q network - with tf.variable_scope(Q_SCOPE) as scope: - q_values, q_logits, q_dist, _ = self._build_q_network( - self.cur_observations, observation_space, action_space) - self.q_values = q_values - self.q_func_vars = _scope_vars(scope.name) - # Noise vars for Q network except for layer normalization vars + def add_parameter_noise(self): if self.config["parameter_noise"]: - self._build_parameter_noise([ - var for var in self.q_func_vars if "LayerNorm" not in var.name - ]) - self.action_probs = tf.nn.softmax(self.q_values) - - # Action outputs - self.output_actions, self.action_prob = self._build_q_value_policy( - q_values) - - # Replay inputs - self.obs_t = tf.placeholder( - tf.float32, shape=(None, ) + observation_space.shape) - self.act_t = tf.placeholder(tf.int32, [None], name="action") - self.rew_t = tf.placeholder(tf.float32, [None], name="reward") - self.obs_tp1 = tf.placeholder( - tf.float32, shape=(None, ) + observation_space.shape) - self.done_mask = tf.placeholder(tf.float32, [None], name="done") - self.importance_weights = tf.placeholder( - tf.float32, [None], name="weight") - - # q network evaluation - with tf.variable_scope(Q_SCOPE, reuse=True): - prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - q_t, q_logits_t, q_dist_t, model = self._build_q_network( - self.obs_t, observation_space, action_space) - q_batchnorm_update_ops = list( - set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - - prev_update_ops) - - # target q network evalution - with tf.variable_scope(Q_TARGET_SCOPE) as scope: - q_tp1, q_logits_tp1, q_dist_tp1, _ = self._build_q_network( - self.obs_tp1, observation_space, action_space) - self.target_q_func_vars = _scope_vars(scope.name) - - # q scores for actions which we know were selected in the given state. 
- one_hot_selection = tf.one_hot(self.act_t, self.num_actions) - q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1) - q_logits_t_selected = tf.reduce_sum( - q_logits_t * tf.expand_dims(one_hot_selection, -1), 1) - - # compute estimate of best possible value starting from state at t + 1 - if config["double_q"]: - with tf.variable_scope(Q_SCOPE, reuse=True): - q_tp1_using_online_net, q_logits_tp1_using_online_net, \ - q_dist_tp1_using_online_net, _ = self._build_q_network( - self.obs_tp1, observation_space, action_space) - q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1) - q_tp1_best_one_hot_selection = tf.one_hot( - q_tp1_best_using_online_net, self.num_actions) - q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1) - q_dist_tp1_best = tf.reduce_sum( - q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), - 1) - else: - q_tp1_best_one_hot_selection = tf.one_hot( - tf.argmax(q_tp1, 1), self.num_actions) - q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1) - q_dist_tp1_best = tf.reduce_sum( - q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), - 1) + self.sess.run(self.add_noise_op) - self.loss = self._build_q_loss(q_t_selected, q_logits_t_selected, - q_tp1_best, q_dist_tp1_best) + def set_epsilon(self, epsilon): + self.cur_epsilon = epsilon + @override(PolicyGraph) + def get_state(self): + return [TFPolicyGraph.get_state(self), self.cur_epsilon] + + @override(PolicyGraph) + def set_state(self, state): + TFPolicyGraph.set_state(self, state[0]) + self.set_epsilon(state[1]) + + +class TargetNetworkMixin(object): + def __init__(self, obs_space, action_space, config): # update_target_fn will be called periodically to copy Q network to # target Q network update_target_expr = [] @@ -446,166 +342,255 @@ def __init__(self, observation_space, action_space, config): update_target_expr.append(var_target.assign(var)) self.update_target_expr = tf.group(*update_target_expr) - # initialize TFPolicyGraph - self.sess = tf.get_default_session() - self.loss_inputs = [ - (SampleBatch.CUR_OBS, self.obs_t), - (SampleBatch.ACTIONS, self.act_t), - (SampleBatch.REWARDS, self.rew_t), - (SampleBatch.NEXT_OBS, self.obs_tp1), - (SampleBatch.DONES, self.done_mask), - (PRIO_WEIGHTS, self.importance_weights), - ] - - LearningRateSchedule.__init__(self, self.config["lr"], - self.config["lr_schedule"]) - TFPolicyGraph.__init__( - self, - observation_space, - action_space, - self.sess, - obs_input=self.cur_observations, - action_sampler=self.output_actions, - action_prob=self.action_prob, - loss=self.loss.loss, - model=model, - loss_inputs=self.loss_inputs, - update_ops=q_batchnorm_update_ops) - self.sess.run(tf.global_variables_initializer()) - - self.stats_fetches = dict({ - "cur_lr": tf.cast(self.cur_lr, tf.float64), - }, **self.loss.stats) - - @override(TFPolicyGraph) - def optimizer(self): - return tf.train.AdamOptimizer( - learning_rate=self.cur_lr, epsilon=self.config["adam_epsilon"]) - - @override(TFPolicyGraph) - def gradients(self, optimizer, loss): - if self.config["grad_norm_clipping"] is not None: - grads_and_vars = _minimize_and_clip( - optimizer, - loss, - var_list=self.q_func_vars, - clip_val=self.config["grad_norm_clipping"]) - else: - grads_and_vars = optimizer.compute_gradients( - loss, var_list=self.q_func_vars) - grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None] - return grads_and_vars - - @override(TFPolicyGraph) - def extra_compute_action_feed_dict(self): - return { - self.stochastic: True, - self.eps: 
self.cur_epsilon, - } - - @override(TFPolicyGraph) - def extra_compute_grad_fetches(self): - return { - "td_error": self.loss.td_error, - LEARNER_STATS_KEY: self.stats_fetches, - } - - @override(PolicyGraph) - def get_state(self): - return [TFPolicyGraph.get_state(self), self.cur_epsilon] - - @override(PolicyGraph) - def set_state(self, state): - TFPolicyGraph.set_state(self, state[0]) - self.set_epsilon(state[1]) + def update_target(self): + return self._sess.run(self.update_target_expr) - def _build_parameter_noise(self, pnet_params): - self.parameter_noise_sigma_val = 1.0 - self.parameter_noise_sigma = tf.get_variable( - initializer=tf.constant_initializer( - self.parameter_noise_sigma_val), - name="parameter_noise_sigma", - shape=(), - trainable=False, - dtype=tf.float32) - self.parameter_noise = list() - # No need to add any noise on LayerNorm parameters - for var in pnet_params: - noise_var = tf.get_variable( - name=var.name.split(":")[0] + "_noise", - shape=var.shape, - initializer=tf.constant_initializer(.0), - trainable=False) - self.parameter_noise.append(noise_var) - remove_noise_ops = list() - for var, var_noise in zip(pnet_params, self.parameter_noise): - remove_noise_ops.append(tf.assign_add(var, -var_noise)) - self.remove_noise_op = tf.group(*tuple(remove_noise_ops)) - generate_noise_ops = list() - for var_noise in self.parameter_noise: - generate_noise_ops.append( - tf.assign( - var_noise, - tf.random_normal( - shape=var_noise.shape, - stddev=self.parameter_noise_sigma))) - with tf.control_dependencies(generate_noise_ops): - add_noise_ops = list() - for var, var_noise in zip(pnet_params, self.parameter_noise): - add_noise_ops.append(tf.assign_add(var, var_noise)) - self.add_noise_op = tf.group(*tuple(add_noise_ops)) - self.pi_distance = None +class ComputeTDErrorMixin(object): def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights): - td_err = self.sess.run( + if not hasattr(self, "loss"): + # this is a hack so that the td error placeholder can be + # generated by build_tf_policy() + logger.warn( + "compute_td_error() called before loss has been initialized, " + "returning zeros") + return np.zeros_like(rew_t) + + td_err = self._sess.run( self.loss.td_error, feed_dict={ - self.obs_t: [np.array(ob) for ob in obs_t], - self.act_t: act_t, - self.rew_t: rew_t, - self.obs_tp1: [np.array(ob) for ob in obs_tp1], - self.done_mask: done_mask, - self.importance_weights: importance_weights + self.get_placeholder(SampleBatch.CUR_OBS): [ + np.array(ob) for ob in obs_t + ], + self.get_placeholder(SampleBatch.ACTIONS): act_t, + self.get_placeholder(SampleBatch.REWARDS): rew_t, + self.get_placeholder(SampleBatch.NEXT_OBS): [ + np.array(ob) for ob in obs_tp1 + ], + self.get_placeholder(SampleBatch.DONES): done_mask, + self.get_placeholder(PRIO_WEIGHTS): importance_weights, }) return td_err - def add_parameter_noise(self): - if self.config["parameter_noise"]: - self.sess.run(self.add_noise_op) - def update_target(self): - return self.sess.run(self.update_target_expr) - - def set_epsilon(self, epsilon): - self.cur_epsilon = epsilon - - def _build_q_network(self, obs, obs_space, action_space): - qnet = QNetwork( - ModelCatalog.get_model({ - "obs": obs, - "is_training": self._get_is_training_placeholder(), - }, obs_space, action_space, self.num_actions, - self.config["model"]), self.num_actions, - self.config["dueling"], self.config["hiddens"], - self.config["noisy"], self.config["num_atoms"], - self.config["v_min"], self.config["v_max"], self.config["sigma0"], - 
self.config["parameter_noise"]) - return qnet.value, qnet.logits, qnet.dist, qnet.model - - def _build_q_value_policy(self, q_values): - policy = QValuePolicy( - q_values, self.cur_observations, self.num_actions, self.stochastic, - self.eps, self.config["soft_q"], self.config["softmax_temp"]) - return policy.action, policy.action_prob - - def _build_q_loss(self, q_t_selected, q_logits_t_selected, q_tp1_best, - q_dist_tp1_best): - return QLoss(q_t_selected, q_logits_t_selected, q_tp1_best, - q_dist_tp1_best, self.importance_weights, self.rew_t, - self.done_mask, self.config["gamma"], - self.config["n_step"], self.config["num_atoms"], - self.config["v_min"], self.config["v_max"]) +def postprocess_trajectory(policy, + sample_batch, + other_agent_batches=None, + episode=None): + if policy.config["parameter_noise"]: + # adjust the sigma of parameter space noise + states = [list(x) for x in sample_batch.columns(["obs"])][0] + + noisy_action_distribution = policy._sess.run( + policy.action_probs, feed_dict={policy.cur_observations: states}) + policy._sess.run(policy.remove_noise_op) + clean_action_distribution = policy._sess.run( + policy.action_probs, feed_dict={policy.cur_observations: states}) + distance_in_action_space = np.mean( + entropy(clean_action_distribution.T, noisy_action_distribution.T)) + policy.pi_distance = distance_in_action_space + if (distance_in_action_space < + -np.log(1 - policy.cur_epsilon + + policy.cur_epsilon / policy.num_actions)): + policy.parameter_noise_sigma_val *= 1.01 + else: + policy.parameter_noise_sigma_val /= 1.01 + policy.parameter_noise_sigma.load( + policy.parameter_noise_sigma_val, session=policy._sess) + + return _postprocess_dqn(policy, sample_batch) + + +def build_q_networks(policy, input_dict, observation_space, action_space, + config): + + if not isinstance(action_space, Discrete): + raise UnsupportedSpaceException( + "Action space {} is not supported for DQN.".format(action_space)) + + # Action Q network + with tf.variable_scope(Q_SCOPE) as scope: + q_values, q_logits, q_dist, _ = _build_q_network( + policy, input_dict[SampleBatch.CUR_OBS], observation_space, + action_space) + policy.q_values = q_values + policy.q_func_vars = _scope_vars(scope.name) + + # Noise vars for Q network except for layer normalization vars + if config["parameter_noise"]: + _build_parameter_noise( + policy, + [var for var in policy.q_func_vars if "LayerNorm" not in var.name]) + policy.action_probs = tf.nn.softmax(policy.q_values) + + # Action outputs + qvp = QValuePolicy(q_values, input_dict[SampleBatch.CUR_OBS], + action_space.n, policy.stochastic, policy.eps, + policy.config["soft_q"], policy.config["softmax_temp"]) + policy.output_actions, policy.action_prob = qvp.action, qvp.action_prob + + return policy.output_actions, policy.action_prob + + +def _build_parameter_noise(policy, pnet_params): + policy.parameter_noise_sigma_val = 1.0 + policy.parameter_noise_sigma = tf.get_variable( + initializer=tf.constant_initializer(policy.parameter_noise_sigma_val), + name="parameter_noise_sigma", + shape=(), + trainable=False, + dtype=tf.float32) + policy.parameter_noise = list() + # No need to add any noise on LayerNorm parameters + for var in pnet_params: + noise_var = tf.get_variable( + name=var.name.split(":")[0] + "_noise", + shape=var.shape, + initializer=tf.constant_initializer(.0), + trainable=False) + policy.parameter_noise.append(noise_var) + remove_noise_ops = list() + for var, var_noise in zip(pnet_params, policy.parameter_noise): + 
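+        # assign_add(var, -noise) subtracts the stored noise, restoring the
+        # clean (pre-noise) parameter values.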
remove_noise_ops.append(tf.assign_add(var, -var_noise)) + policy.remove_noise_op = tf.group(*tuple(remove_noise_ops)) + generate_noise_ops = list() + for var_noise in policy.parameter_noise: + generate_noise_ops.append( + tf.assign( + var_noise, + tf.random_normal( + shape=var_noise.shape, + stddev=policy.parameter_noise_sigma))) + with tf.control_dependencies(generate_noise_ops): + add_noise_ops = list() + for var, var_noise in zip(pnet_params, policy.parameter_noise): + add_noise_ops.append(tf.assign_add(var, var_noise)) + policy.add_noise_op = tf.group(*tuple(add_noise_ops)) + policy.pi_distance = None + + +def build_q_losses(policy, batch_tensors): + # q network evaluation + with tf.variable_scope(Q_SCOPE, reuse=True): + prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) + q_t, q_logits_t, q_dist_t, model = _build_q_network( + policy, batch_tensors[SampleBatch.CUR_OBS], + policy.observation_space, policy.action_space) + policy.q_batchnorm_update_ops = list( + set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops) + + # target q network evalution + with tf.variable_scope(Q_TARGET_SCOPE) as scope: + q_tp1, q_logits_tp1, q_dist_tp1, _ = _build_q_network( + policy, batch_tensors[SampleBatch.NEXT_OBS], + policy.observation_space, policy.action_space) + policy.target_q_func_vars = _scope_vars(scope.name) + + # q scores for actions which we know were selected in the given state. + one_hot_selection = tf.one_hot(batch_tensors[SampleBatch.ACTIONS], + policy.action_space.n) + q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1) + q_logits_t_selected = tf.reduce_sum( + q_logits_t * tf.expand_dims(one_hot_selection, -1), 1) + + # compute estimate of best possible value starting from state at t + 1 + if policy.config["double_q"]: + with tf.variable_scope(Q_SCOPE, reuse=True): + q_tp1_using_online_net, q_logits_tp1_using_online_net, \ + q_dist_tp1_using_online_net, _ = _build_q_network( + policy, + batch_tensors[SampleBatch.NEXT_OBS], + policy.observation_space, policy.action_space) + q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1) + q_tp1_best_one_hot_selection = tf.one_hot(q_tp1_best_using_online_net, + policy.action_space.n) + q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1) + q_dist_tp1_best = tf.reduce_sum( + q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1) + else: + q_tp1_best_one_hot_selection = tf.one_hot( + tf.argmax(q_tp1, 1), policy.action_space.n) + q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1) + q_dist_tp1_best = tf.reduce_sum( + q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1) + + policy.loss = _build_q_loss( + q_t_selected, q_logits_t_selected, q_tp1_best, q_dist_tp1_best, + batch_tensors[SampleBatch.REWARDS], batch_tensors[SampleBatch.DONES], + batch_tensors[PRIO_WEIGHTS], policy.config) + + return policy.loss.loss + + +def adam_optimizer(policy, config): + return tf.train.AdamOptimizer( + learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"]) + + +def clip_gradients(policy, optimizer, loss): + if policy.config["grad_norm_clipping"] is not None: + grads_and_vars = _minimize_and_clip( + optimizer, + loss, + var_list=policy.q_func_vars, + clip_val=policy.config["grad_norm_clipping"]) + else: + grads_and_vars = optimizer.compute_gradients( + loss, var_list=policy.q_func_vars) + grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None] + return grads_and_vars + + +def exploration_setting_inputs(policy): + return { + policy.stochastic: True, + 
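+        # Feed the current epsilon value (a plain Python float tracked by
+        # ExplorationStateMixin) on every compute_actions() call.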
policy.eps: policy.cur_epsilon, + } + + +def build_q_stats(policy, batch_tensors): + return dict({ + "cur_lr": tf.cast(policy.cur_lr, tf.float64), + }, **policy.loss.stats) + + +def setup_early_mixins(policy, obs_space, action_space, config): + LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) + ExplorationStateMixin.__init__(policy, obs_space, action_space, config) + + +def setup_late_mixins(policy, obs_space, action_space, config): + TargetNetworkMixin.__init__(policy, obs_space, action_space, config) + + +def _build_q_network(policy, obs, obs_space, action_space): + config = policy.config + qnet = QNetwork( + ModelCatalog.get_model({ + "obs": obs, + "is_training": policy._get_is_training_placeholder(), + }, obs_space, action_space, action_space.n, config["model"]), + action_space.n, config["dueling"], config["hiddens"], config["noisy"], + config["num_atoms"], config["v_min"], config["v_max"], + config["sigma0"], config["parameter_noise"]) + return qnet.value, qnet.logits, qnet.dist, qnet.model + + +def _build_q_value_policy(policy, q_values): + policy = QValuePolicy(q_values, policy.cur_observations, + policy.num_actions, policy.stochastic, policy.eps, + policy.config["soft_q"], + policy.config["softmax_temp"]) + return policy.action, policy.action_prob + + +def _build_q_loss(q_t_selected, q_logits_t_selected, q_tp1_best, + q_dist_tp1_best, rewards, dones, importance_weights, config): + return QLoss(q_t_selected, q_logits_t_selected, q_tp1_best, + q_dist_tp1_best, importance_weights, rewards, + tf.cast(dones, tf.float32), config["gamma"], config["n_step"], + config["num_atoms"], config["v_min"], config["v_max"]) def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones): @@ -706,3 +691,26 @@ def _scope_vars(scope, trainable_only=False): tf.GraphKeys.TRAINABLE_VARIABLES if trainable_only else tf.GraphKeys.VARIABLES, scope=scope if isinstance(scope, str) else scope.name) + + +DQNPolicyGraph = build_tf_policy( + name="DQNTFPolicy", + get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, + make_action_sampler=build_q_networks, + loss_fn=build_q_losses, + stats_fn=build_q_stats, + postprocess_fn=postprocess_trajectory, + optimizer_fn=adam_optimizer, + gradients_fn=clip_gradients, + extra_action_feed_fn=exploration_setting_inputs, + extra_action_fetches_fn=lambda policy: {"q_values": policy.q_values}, + extra_learn_fetches_fn=lambda policy: {"td_error": policy.loss.td_error}, + update_ops_fn=lambda policy: policy.q_batchnorm_update_ops, + before_init=setup_early_mixins, + after_init=setup_late_mixins, + mixins=[ + ExplorationStateMixin, + TargetNetworkMixin, + ComputeTDErrorMixin, + LearningRateSchedule, + ]) diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py index 334ca788c936..d5f41887331c 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py @@ -252,9 +252,9 @@ def __init__(self, obs_space, action_space, config): def _value(self, ob, prev_action, prev_reward, *args): feed_dict = { - self._obs_input: [ob], - self._prev_action_input: [prev_action], - self._prev_reward_input: [prev_reward], + self.get_placeholder(SampleBatch.CUR_OBS): [ob], + self.get_placeholder(SampleBatch.PREV_ACTIONS): [prev_action], + self.get_placeholder(SampleBatch.PREV_REWARDS): [prev_reward], self.model.seq_lens: [1] } assert len(args) == len(self.model.state_in), \ diff --git a/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py 
b/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py index 73e08fcf9093..5b051693d4fa 100644 --- a/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py +++ b/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py @@ -37,6 +37,7 @@ def __init__(self, config, loss_fn, stats_fn=None, + update_ops_fn=None, grad_stats_fn=None, before_loss_init=None, make_action_sampler=None, @@ -54,6 +55,8 @@ def __init__(self, TF fetches given the policy graph and batch input tensors grad_stats_fn (func): optional function that returns a dict of TF fetches given the policy graph and loss gradient tensors + update_ops_fn (func): optional function that returns a list + overriding the update ops to run when applying gradients before_loss_init (func): optional function to run prior to loss init that takes the same arguments as __init__ make_action_sampler (func): optional function that returns a @@ -70,6 +73,7 @@ def __init__(self, self._loss_fn = loss_fn self._stats_fn = stats_fn self._grad_stats_fn = grad_stats_fn + self._update_ops_fn = update_ops_fn # Setup standard placeholders if existing_inputs is not None: @@ -85,12 +89,12 @@ def __init__(self, prev_rewards = tf.placeholder( tf.float32, [None], name="prev_reward") - input_dict = { - "obs": obs, - "prev_actions": prev_actions, - "prev_rewards": prev_rewards, + self.input_dict = UsageTrackingDict({ + SampleBatch.CUR_OBS: obs, + SampleBatch.PREV_ACTIONS: prev_actions, + SampleBatch.PREV_REWARDS: prev_rewards, "is_training": self._get_is_training_placeholder(), - } + }) # Create the model network and action outputs if make_action_sampler: @@ -100,7 +104,7 @@ def __init__(self, self.dist_class = None self.action_dist = None action_sampler, action_prob = make_action_sampler( - self, input_dict, obs_space, action_space, config) + self, self.input_dict, obs_space, action_space, config) else: self.dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) @@ -117,7 +121,7 @@ def __init__(self, existing_state_in = [] existing_seq_lens = None self.model = ModelCatalog.get_model( - input_dict, + self.input_dict, obs_space, action_space, logit_dim, @@ -190,10 +194,8 @@ def copy(self, existing_inputs): self.action_space, self.config, existing_inputs=input_dict) - loss = instance._loss_fn(instance, input_dict) - if instance._stats_fn: - instance._stats_fetches.update( - instance._stats_fn(instance, input_dict)) + + loss = instance._do_loss_init(input_dict) TFPolicyGraph._initialize_loss( instance, loss, [(k, existing_inputs[i]) for i, (k, _) in enumerate(self._loss_inputs)]) @@ -244,10 +246,14 @@ def fake_array(tensor): SampleBatch.CUR_OBS: self._obs_input, }) loss_inputs = [ - (SampleBatch.PREV_ACTIONS, self._prev_action_input), - (SampleBatch.PREV_REWARDS, self._prev_reward_input), (SampleBatch.CUR_OBS, self._obs_input), ] + if SampleBatch.PREV_ACTIONS in self.input_dict.accessed_keys: + loss_inputs.append((SampleBatch.PREV_ACTIONS, + self._prev_action_input)) + if SampleBatch.PREV_REWARDS in self.input_dict.accessed_keys: + loss_inputs.append((SampleBatch.PREV_REWARDS, + self._prev_reward_input)) for k, v in postprocessed_batch.items(): if k in batch_tensors: @@ -264,12 +270,18 @@ def fake_array(tensor): "Initializing loss function with dummy input:\n\n{}\n".format( summarize(batch_tensors))) - loss = self._loss_fn(self, batch_tensors) - if self._stats_fn: - self._stats_fetches.update(self._stats_fn(self, batch_tensors)) + loss = self._do_loss_init(batch_tensors) for k in sorted(batch_tensors.accessed_keys): loss_inputs.append((k, 
batch_tensors[k])) TFPolicyGraph._initialize_loss(self, loss, loss_inputs) if self._grad_stats_fn: self._stats_fetches.update(self._grad_stats_fn(self, self._grads)) self._sess.run(tf.global_variables_initializer()) + + def _do_loss_init(self, batch_tensors): + loss = self._loss_fn(self, batch_tensors) + if self._stats_fn: + self._stats_fetches.update(self._stats_fn(self, batch_tensors)) + if self._update_ops_fn: + self._update_ops = self._update_ops_fn(self) + return loss diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index b921e6cfb0d1..e71ede8cb08d 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -140,6 +140,14 @@ def __init__(self, raise ValueError( "seq_lens tensor must be given if state inputs are defined") + def get_placeholder(self, name): + """Returns the given loss input placeholder by name. + + These are the same placeholders passed in as the loss_inputs arg. + """ + + return self._loss_input_dict[name] + def _initialize_loss(self, loss, loss_inputs): self._loss_inputs = loss_inputs self._loss_input_dict = dict(self._loss_inputs) @@ -173,7 +181,7 @@ def _initialize_loss(self, loss, loss_inputs): self._grads_and_vars) if log_once("loss_used"): - logger.debug( + logger.info( "These tensors were used in the loss_fn:\n\n{}\n".format( summarize(self._loss_input_dict))) diff --git a/python/ray/rllib/evaluation/tf_policy_template.py b/python/ray/rllib/evaluation/tf_policy_template.py index b2549e973a65..f22b2e3af0bc 100644 --- a/python/ray/rllib/evaluation/tf_policy_template.py +++ b/python/ray/rllib/evaluation/tf_policy_template.py @@ -3,6 +3,7 @@ from __future__ import print_function from ray.rllib.evaluation.dynamic_tf_policy_graph import DynamicTFPolicyGraph +from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph from ray.rllib.utils.annotations import override, DeveloperAPI @@ -18,12 +19,15 @@ def build_tf_policy(name, postprocess_fn=None, optimizer_fn=None, gradients_fn=None, + extra_action_feed_fn=None, + extra_learn_fetches_fn=None, before_init=None, before_loss_init=None, after_init=None, make_action_sampler=None, mixins=None, - get_batch_divisibility_req=None): + get_batch_divisibility_req=None, + update_ops_fn=None): """Helper function for creating a dynamic tf policy at runtime. Arguments: @@ -45,6 +49,10 @@ def build_tf_policy(name, gradients_fn (func): optional function that returns a list of gradients given a tf optimizer and loss tensor. 
If not specified, this defaults to optimizer.compute_gradients(loss) + extra_action_feed_fn (func): optional function that returns a feed dict + to also feed to TF when computing actions + extra_learn_fetches_fn (func): optional function that returns a dict of + extra values to fetch and return when learning on a batch before_init (func): optional function to run at the beginning of policy init that takes the same arguments as the policy constructor before_loss_init (func): optional function to run prior to loss @@ -60,6 +68,8 @@ def build_tf_policy(name, precedence than the DynamicTFPolicyGraph class get_batch_divisibility_req (func): optional function that returns the divisibility requirement for sample batches + update_ops_fn (func): optional function that returns a list overriding + the update ops to run when applying gradients Returns: a DynamicTFPolicyGraph instance that uses the specified args @@ -105,7 +115,9 @@ def before_loss_init_wrapper(policy, obs_space, action_space, loss_fn, stats_fn=stats_fn, grad_stats_fn=grad_stats_fn, + update_ops_fn=update_ops_fn, before_loss_init=before_loss_init_wrapper, + make_action_sampler=make_action_sampler, existing_inputs=existing_inputs) if after_init: @@ -135,6 +147,23 @@ def gradients(self, optimizer, loss): else: return TFPolicyGraph.gradients(self, optimizer, loss) + @override(TFPolicyGraph) + def extra_compute_grad_fetches(self): + if extra_learn_fetches_fn: + # auto-add empty learner stats dict if needed + return dict({ + LEARNER_STATS_KEY: {} + }, **extra_learn_fetches_fn(self)) + else: + return TFPolicyGraph.extra_compute_grad_fetches(self) + + @override(TFPolicyGraph) + def extra_compute_action_feed_dict(self): + if extra_action_feed_fn: + return extra_action_feed_fn(self) + else: + return TFPolicyGraph.extra_compute_action_feed_dict(self) + @override(TFPolicyGraph) def extra_compute_action_fetches(self): return dict( From d9ee507fe27b5cb3c0661fcb00455d4f70e10d19 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 19 May 2019 16:20:36 -0700 Subject: [PATCH 02/10] use get ph --- python/ray/rllib/agents/ppo/ppo_policy_graph.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py index d5f41887331c..e853ac95380c 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py @@ -241,14 +241,17 @@ def __init__(self, obs_space, action_space, config): "value_function() method.") with tf.variable_scope("value_function"): self.value_function = ModelCatalog.get_model({ - "obs": self._obs_input, - "prev_actions": self._prev_action_input, - "prev_rewards": self._prev_reward_input, + "obs": self.get_placeholder(SampleBatch.CUR_OBS), + "prev_actions": self.get_placeholder( + SampleBatch.PREV_ACTIONS), + "prev_rewards": self.get_placeholder( + SampleBatch.PREV_REWARDS), "is_training": self._get_is_training_placeholder(), }, obs_space, action_space, 1, vf_config).outputs self.value_function = tf.reshape(self.value_function, [-1]) else: - self.value_function = tf.zeros(shape=tf.shape(self._obs_input)[:1]) + self.value_function = tf.zeros( + shape=tf.shape(self.get_placeholder(SampleBatch.CUR_OBS))[:1]) def _value(self, ob, prev_action, prev_reward, *args): feed_dict = { From 17364d8bacb4db859f076ac9d5a6a248ea342565 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 19 May 2019 16:28:51 -0700 Subject: [PATCH 03/10] loss initialized function --- 
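Notes: this replaces the hasattr(self, "loss") workaround with the public
TFPolicyGraph accessors added here, so callers guard with, for example:

    if not self.loss_initialized():
        return np.zeros_like(rew_t)
    td_err = self.get_session().run(self.loss.td_error, feed_dict={...})

and the DQN/PPO mixins call policy.get_session() instead of policy._sess.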
.../ray/rllib/agents/dqn/dqn_policy_graph.py | 22 ++++++------------ .../ray/rllib/agents/ppo/ppo_policy_graph.py | 4 ++-- .../ray/rllib/evaluation/tf_policy_graph.py | 23 +++++++++++++++---- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py index e29f36cb46d1..bb9d54babbf5 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py @@ -3,7 +3,6 @@ from __future__ import print_function from gym.spaces import Discrete -import logging import numpy as np from scipy.stats import entropy @@ -20,8 +19,6 @@ tf = try_import_tf() -logger = logging.getLogger(__name__) - Q_SCOPE = "q_func" Q_TARGET_SCOPE = "target_q_func" @@ -343,21 +340,16 @@ def __init__(self, obs_space, action_space, config): self.update_target_expr = tf.group(*update_target_expr) def update_target(self): - return self._sess.run(self.update_target_expr) + return self.get_session().run(self.update_target_expr) class ComputeTDErrorMixin(object): def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights): - if not hasattr(self, "loss"): - # this is a hack so that the td error placeholder can be - # generated by build_tf_policy() - logger.warn( - "compute_td_error() called before loss has been initialized, " - "returning zeros") + if not self.loss_initialized(): return np.zeros_like(rew_t) - td_err = self._sess.run( + td_err = self.get_session().run( self.loss.td_error, feed_dict={ self.get_placeholder(SampleBatch.CUR_OBS): [ @@ -382,10 +374,10 @@ def postprocess_trajectory(policy, # adjust the sigma of parameter space noise states = [list(x) for x in sample_batch.columns(["obs"])][0] - noisy_action_distribution = policy._sess.run( + noisy_action_distribution = policy.get_session().run( policy.action_probs, feed_dict={policy.cur_observations: states}) - policy._sess.run(policy.remove_noise_op) - clean_action_distribution = policy._sess.run( + policy.get_session().run(policy.remove_noise_op) + clean_action_distribution = policy.get_session().run( policy.action_probs, feed_dict={policy.cur_observations: states}) distance_in_action_space = np.mean( entropy(clean_action_distribution.T, noisy_action_distribution.T)) @@ -397,7 +389,7 @@ def postprocess_trajectory(policy, else: policy.parameter_noise_sigma_val /= 1.01 policy.parameter_noise_sigma.load( - policy.parameter_noise_sigma_val, session=policy._sess) + policy.parameter_noise_sigma_val, session=policy.get_session()) return _postprocess_dqn(policy, sample_batch) diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py index e853ac95380c..cc1bbf642c7c 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py @@ -216,7 +216,7 @@ def update_kl(self, sampled_kl): self.kl_coeff_val *= 1.5 elif sampled_kl < 0.5 * self.kl_target: self.kl_coeff_val *= 0.5 - self.kl_coeff.load(self.kl_coeff_val, session=self._sess) + self.kl_coeff.load(self.kl_coeff_val, session=self.get_session()) return self.kl_coeff_val @@ -264,7 +264,7 @@ def _value(self, ob, prev_action, prev_reward, *args): (args, self.model.state_in) for k, v in zip(self.model.state_in, args): feed_dict[k] = v - vf = self._sess.run(self.value_function, feed_dict) + vf = self.get_session().run(self.value_function, feed_dict) return vf[0] diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py 
b/python/ray/rllib/evaluation/tf_policy_graph.py index e71ede8cb08d..fcc896603ab6 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -143,11 +143,26 @@ def __init__(self, def get_placeholder(self, name): """Returns the given loss input placeholder by name. - These are the same placeholders passed in as the loss_inputs arg. + These are the same placeholders passed in as the loss_inputs arg. If + the loss has not been initialized, an error is raised. """ + if not self.loss_initialized(): + raise RuntimeError( + "You cannot call policy.get_placeholder() before the loss " + "has been initialized. To avoid this, use " + "policy.loss_initialized() to check whether this is the case.") + return self._loss_input_dict[name] + def get_session(self): + """Returns a reference to the TF session for this policy.""" + return self._sess + + def loss_initialized(self): + """Returns whether the loss function has been initialized.""" + return self._loss is not None + def _initialize_loss(self, loss, loss_inputs): self._loss_inputs = loss_inputs self._loss_input_dict = dict(self._loss_inputs) @@ -204,21 +219,21 @@ def compute_actions(self, @override(PolicyGraph) def compute_gradients(self, postprocessed_batch): - assert self._loss is not None, "Loss not initialized" + assert self.loss_initialized() builder = TFRunBuilder(self._sess, "compute_gradients") fetches = self._build_compute_gradients(builder, postprocessed_batch) return builder.get(fetches) @override(PolicyGraph) def apply_gradients(self, gradients): - assert self._loss is not None, "Loss not initialized" + assert self.loss_initialized() builder = TFRunBuilder(self._sess, "apply_gradients") fetches = self._build_apply_gradients(builder, gradients) builder.get(fetches) @override(PolicyGraph) def learn_on_batch(self, postprocessed_batch): - assert self._loss is not None, "Loss not initialized" + assert self.loss_initialized() builder = TFRunBuilder(self._sess, "learn_on_batch") fetches = self._build_learn_on_batch(builder, postprocessed_batch) return builder.get(fetches) From b4bb8ced118ad70f654797ac9138419b35637f07 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 19 May 2019 16:32:37 -0700 Subject: [PATCH 04/10] fix appo --- python/ray/rllib/agents/ppo/appo_policy_graph.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/ray/rllib/agents/ppo/appo_policy_graph.py b/python/ray/rllib/agents/ppo/appo_policy_graph.py index 5aa76913194f..bf2b3318521c 100644 --- a/python/ray/rllib/agents/ppo/appo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/appo_policy_graph.py @@ -366,12 +366,15 @@ def __init__(self): tf.get_variable_scope().name) def value(self, ob, *args): - feed_dict = {self._obs_input: [ob], self.model.seq_lens: [1]} + feed_dict = { + self.get_placeholder(SampleBatch.CUR_OBS): [ob], + self.model.seq_lens: [1] + } assert len(args) == len(self.model.state_in), \ (args, self.model.state_in) for k, v in zip(self.model.state_in, args): feed_dict[k] = v - vf = self._sess.run(self.value_function, feed_dict) + vf = self.get_session().run(self.value_function, feed_dict) return vf[0] From 04d0157a7cb71b448985395695928f4d511eb49b Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 19 May 2019 16:53:44 -0700 Subject: [PATCH 05/10] make prev action reward optional --- .../ray/rllib/agents/dqn/dqn_policy_graph.py | 1 + .../ray/rllib/agents/ppo/ppo_policy_graph.py | 11 +-- .../evaluation/dynamic_tf_policy_graph.py | 75 ++++++++++++------- 
.../ray/rllib/evaluation/tf_policy_graph.py | 18 +++-- .../rllib/evaluation/tf_policy_template.py | 8 +- 5 files changed, 72 insertions(+), 41 deletions(-) diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py index bb9d54babbf5..2f768683f9e8 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py @@ -700,6 +700,7 @@ def _scope_vars(scope, trainable_only=False): update_ops_fn=lambda policy: policy.q_batchnorm_update_ops, before_init=setup_early_mixins, after_init=setup_late_mixins, + obs_include_prev_action_reward=False, mixins=[ ExplorationStateMixin, TargetNetworkMixin, diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py index cc1bbf642c7c..869ce5c7ad48 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py @@ -240,14 +240,9 @@ def __init__(self, obs_space, action_space, config): "a custom LSTM model that overrides the " "value_function() method.") with tf.variable_scope("value_function"): - self.value_function = ModelCatalog.get_model({ - "obs": self.get_placeholder(SampleBatch.CUR_OBS), - "prev_actions": self.get_placeholder( - SampleBatch.PREV_ACTIONS), - "prev_rewards": self.get_placeholder( - SampleBatch.PREV_REWARDS), - "is_training": self._get_is_training_placeholder(), - }, obs_space, action_space, 1, vf_config).outputs + self.value_function = ModelCatalog.get_model( + self.get_obs_input_dict(), obs_space, action_space, 1, + vf_config).outputs self.value_function = tf.reshape(self.value_function, [-1]) else: self.value_function = tf.zeros( diff --git a/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py b/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py index 5b051693d4fa..a25c2bdc945e 100644 --- a/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py +++ b/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py @@ -42,7 +42,8 @@ def __init__(self, before_loss_init=None, make_action_sampler=None, existing_inputs=None, - get_batch_divisibility_req=None): + get_batch_divisibility_req=None, + obs_include_prev_action_reward=True): """Initialize a dynamic TF policy graph. 
Arguments: @@ -68,33 +69,41 @@ def __init__(self, defining new ones get_batch_divisibility_req (func): optional function that returns the divisibility requirement for sample batches + obs_include_prev_action_reward (bool): whether to include the + previous action and reward in the model input """ self.config = config self._loss_fn = loss_fn self._stats_fn = stats_fn self._grad_stats_fn = grad_stats_fn self._update_ops_fn = update_ops_fn + self._obs_include_prev_action_reward = obs_include_prev_action_reward # Setup standard placeholders + prev_actions = None + prev_rewards = None if existing_inputs is not None: obs = existing_inputs[SampleBatch.CUR_OBS] - prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS] - prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS] + if self._obs_include_prev_action_reward: + prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS] + prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS] else: obs = tf.placeholder( tf.float32, shape=[None] + list(obs_space.shape), name="observation") - prev_actions = ModelCatalog.get_action_placeholder(action_space) - prev_rewards = tf.placeholder( - tf.float32, [None], name="prev_reward") + if self._obs_include_prev_action_reward: + prev_actions = ModelCatalog.get_action_placeholder( + action_space) + prev_rewards = tf.placeholder( + tf.float32, [None], name="prev_reward") - self.input_dict = UsageTrackingDict({ + self.input_dict = { SampleBatch.CUR_OBS: obs, SampleBatch.PREV_ACTIONS: prev_actions, SampleBatch.PREV_REWARDS: prev_rewards, "is_training": self._get_is_training_placeholder(), - }) + } # Create the model network and action outputs if make_action_sampler: @@ -162,6 +171,13 @@ def __init__(self, if not existing_inputs: self._initialize_loss() + def get_obs_input_dict(self): + """Returns the obs input dict used to build policy models. + + This dict includes the obs, prev actions, prev rewards, etc. tensors. 
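+
+        For example, the PPO policy in this change builds its value function
+        model directly from this dict:
+
+            ModelCatalog.get_model(self.get_obs_input_dict(), obs_space,
+                                   action_space, 1, vf_config)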
+ """ + return self.input_dict + @override(TFPolicyGraph) def copy(self, existing_inputs): """Creates a copy of self using existing input placeholders.""" @@ -218,14 +234,17 @@ def fake_array(tensor): return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype) dummy_batch = { - SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input), - SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input), SampleBatch.CUR_OBS: fake_array(self._obs_input), SampleBatch.NEXT_OBS: fake_array(self._obs_input), - SampleBatch.ACTIONS: fake_array(self._prev_action_input), - SampleBatch.REWARDS: np.array([0], dtype=np.float32), SampleBatch.DONES: np.array([False], dtype=np.bool), + SampleBatch.ACTIONS: np.array([self.action_space.sample()]), + SampleBatch.REWARDS: np.array([0], dtype=np.float32), } + if self._obs_include_prev_action_reward: + dummy_batch.update({ + SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input), + SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input), + }) state_init = self.get_initial_state() for i, h in enumerate(state_init): dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0) @@ -240,20 +259,24 @@ def fake_array(tensor): postprocessed_batch = self.postprocess_trajectory( SampleBatch(dummy_batch)) - batch_tensors = UsageTrackingDict({ - SampleBatch.PREV_ACTIONS: self._prev_action_input, - SampleBatch.PREV_REWARDS: self._prev_reward_input, - SampleBatch.CUR_OBS: self._obs_input, - }) - loss_inputs = [ - (SampleBatch.CUR_OBS, self._obs_input), - ] - if SampleBatch.PREV_ACTIONS in self.input_dict.accessed_keys: - loss_inputs.append((SampleBatch.PREV_ACTIONS, - self._prev_action_input)) - if SampleBatch.PREV_REWARDS in self.input_dict.accessed_keys: - loss_inputs.append((SampleBatch.PREV_REWARDS, - self._prev_reward_input)) + if self._obs_include_prev_action_reward: + batch_tensors = UsageTrackingDict({ + SampleBatch.PREV_ACTIONS: self._prev_action_input, + SampleBatch.PREV_REWARDS: self._prev_reward_input, + SampleBatch.CUR_OBS: self._obs_input, + }) + loss_inputs = [ + (SampleBatch.PREV_ACTIONS, self._prev_action_input), + (SampleBatch.PREV_REWARDS, self._prev_reward_input), + (SampleBatch.CUR_OBS, self._obs_input), + ] + else: + batch_tensors = UsageTrackingDict({ + SampleBatch.CUR_OBS: self._obs_input, + }) + loss_inputs = [ + (SampleBatch.CUR_OBS, self._obs_input), + ] for k, v in postprocessed_batch.items(): if k in batch_tensors: diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index fcc896603ab6..69b8f1619002 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -141,16 +141,24 @@ def __init__(self, "seq_lens tensor must be given if state inputs are defined") def get_placeholder(self, name): - """Returns the given loss input placeholder by name. + """Returns the given action or loss input placeholder by name. - These are the same placeholders passed in as the loss_inputs arg. If - the loss has not been initialized, an error is raised. + If the loss has not been initialized and a loss input placeholder is + requested, an error is raised. """ + obs_inputs = { + SampleBatch.CUR_OBS: self._obs_input, + SampleBatch.PREV_ACTIONS: self._prev_action_input, + SampleBatch.PREV_REWARDS: self._prev_reward_input, + } + if name in obs_inputs: + return obs_inputs[name] + if not self.loss_initialized(): raise RuntimeError( - "You cannot call policy.get_placeholder() before the loss " - "has been initialized. 
To avoid this, use " + "You cannot call policy.get_placeholder() for non-obs inputs " + "before the loss has been initialized. To avoid this, use " "policy.loss_initialized() to check whether this is the case.") return self._loss_input_dict[name] diff --git a/python/ray/rllib/evaluation/tf_policy_template.py b/python/ray/rllib/evaluation/tf_policy_template.py index f22b2e3af0bc..f16f422e39b8 100644 --- a/python/ray/rllib/evaluation/tf_policy_template.py +++ b/python/ray/rllib/evaluation/tf_policy_template.py @@ -27,7 +27,8 @@ def build_tf_policy(name, make_action_sampler=None, mixins=None, get_batch_divisibility_req=None, - update_ops_fn=None): + update_ops_fn=None, + obs_include_prev_action_reward=True): """Helper function for creating a dynamic tf policy at runtime. Arguments: @@ -70,6 +71,8 @@ def build_tf_policy(name, the divisibility requirement for sample batches update_ops_fn (func): optional function that returns a list overriding the update ops to run when applying gradients + obs_include_prev_action_reward (bool): whether to include the + previous action and reward in the model input Returns: a DynamicTFPolicyGraph instance that uses the specified args @@ -118,7 +121,8 @@ def before_loss_init_wrapper(policy, obs_space, action_space, update_ops_fn=update_ops_fn, before_loss_init=before_loss_init_wrapper, make_action_sampler=make_action_sampler, - existing_inputs=existing_inputs) + existing_inputs=existing_inputs, + obs_include_prev_action_reward=obs_include_prev_action_reward) if after_init: after_init(self, obs_space, action_space, config) From 9dabc72af4889996838ea31d45c5947aa8415198 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 19 May 2019 17:00:37 -0700 Subject: [PATCH 06/10] cleanup template args --- .../rllib/evaluation/tf_policy_template.py | 38 ++++++++++++------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/python/ray/rllib/evaluation/tf_policy_template.py b/python/ray/rllib/evaluation/tf_policy_template.py index f16f422e39b8..f4a939745dca 100644 --- a/python/ray/rllib/evaluation/tf_policy_template.py +++ b/python/ray/rllib/evaluation/tf_policy_template.py @@ -15,12 +15,13 @@ def build_tf_policy(name, get_default_config=None, stats_fn=None, grad_stats_fn=None, - extra_action_fetches_fn=None, postprocess_fn=None, optimizer_fn=None, gradients_fn=None, + extra_action_fetches_fn=None, extra_action_feed_fn=None, extra_learn_fetches_fn=None, + extra_learn_feed_fn=None, before_init=None, before_loss_init=None, after_init=None, @@ -34,15 +35,13 @@ def build_tf_policy(name, Arguments: name (str): name of the graph (e.g., "PPOPolicy") loss_fn (func): function that returns a loss tensor the policy, - and dict of experience tensor placeholders + and dict of experience tensor placeholdes get_default_config (func): optional function that returns the default config to merge with any overrides stats_fn (func): optional function that returns a dict of TF fetches given the policy and batch input tensors grad_stats_fn (func): optional function that returns a dict of TF fetches given the policy and loss gradient tensors - extra_action_fetches_fn (func): optional function that returns - a dict of TF fetches given the policy object postprocess_fn (func): optional experience postprocessing function that takes the same args as PolicyGraph.postprocess_trajectory() optimizer_fn (func): optional function that returns a tf.Optimizer @@ -50,10 +49,14 @@ def build_tf_policy(name, gradients_fn (func): optional function that returns a list of gradients given a tf optimizer and 
loss tensor. If not specified, this defaults to optimizer.compute_gradients(loss) + extra_action_fetches_fn (func): optional function that returns + a dict of TF fetches given the policy object extra_action_feed_fn (func): optional function that returns a feed dict to also feed to TF when computing actions extra_learn_fetches_fn (func): optional function that returns a dict of extra values to fetch and return when learning on a batch + extra_learn_feed_fn (func): optional function that returns a feed dict + to also feed to TF when learning on a batch before_init (func): optional function to run at the beginning of policy init that takes the same arguments as the policy constructor before_loss_init (func): optional function to run prior to loss @@ -151,6 +154,19 @@ def gradients(self, optimizer, loss): else: return TFPolicyGraph.gradients(self, optimizer, loss) + @override(TFPolicyGraph) + def extra_compute_action_fetches(self): + return dict( + TFPolicyGraph.extra_compute_action_fetches(self), + **self._extra_action_fetches) + + @override(TFPolicyGraph) + def extra_compute_action_feed_dict(self): + if extra_action_feed_fn: + return extra_action_feed_fn(self) + else: + return TFPolicyGraph.extra_compute_action_feed_dict(self) + @override(TFPolicyGraph) def extra_compute_grad_fetches(self): if extra_learn_fetches_fn: @@ -162,17 +178,11 @@ def extra_compute_grad_fetches(self): return TFPolicyGraph.extra_compute_grad_fetches(self) @override(TFPolicyGraph) - def extra_compute_action_feed_dict(self): - if extra_action_feed_fn: - return extra_action_feed_fn(self) + def extra_compute_grad_feed_dict(self): + if extra_learn_feed_fn: + return extra_learn_feed_fn(self) else: - return TFPolicyGraph.extra_compute_action_feed_dict(self) - - @override(TFPolicyGraph) - def extra_compute_action_fetches(self): - return dict( - TFPolicyGraph.extra_compute_action_fetches(self), - **self._extra_action_fetches) + return TFPolicyGraph.extra_compute_grad_feed_dict(self) graph_cls.__name__ = name graph_cls.__qualname__ = name From d7c661d3748a7f32ab905cb4291b37f8acb76420 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 19 May 2019 17:16:11 -0700 Subject: [PATCH 07/10] doc init --- .../rllib/evaluation/tf_policy_template.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/python/ray/rllib/evaluation/tf_policy_template.py b/python/ray/rllib/evaluation/tf_policy_template.py index f4a939745dca..bb1e277eff69 100644 --- a/python/ray/rllib/evaluation/tf_policy_template.py +++ b/python/ray/rllib/evaluation/tf_policy_template.py @@ -13,11 +13,12 @@ def build_tf_policy(name, loss_fn, get_default_config=None, - stats_fn=None, - grad_stats_fn=None, postprocess_fn=None, + stats_fn=None, + update_ops_fn=None, optimizer_fn=None, gradients_fn=None, + grad_stats_fn=None, extra_action_fetches_fn=None, extra_action_feed_fn=None, extra_learn_fetches_fn=None, @@ -28,27 +29,36 @@ def build_tf_policy(name, make_action_sampler=None, mixins=None, get_batch_divisibility_req=None, - update_ops_fn=None, obs_include_prev_action_reward=True): """Helper function for creating a dynamic tf policy at runtime. + Functions will be run in this order to initialize the policy: + 1. Placeholder setup: postprocess_fn + 2. Loss init: loss_fn, stats_fn, update_ops_fn + 3. Optimizer init: optimizer_fn, gradients_fn, grad_stats_fn + + This means that you can e.g., depend on any policy attributes created in + the running of `loss_fn` in later functions such as `stats_fn`. 
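+
+    For example (an illustrative sketch; `my_loss_fn` and `my_stats_fn` are
+    hypothetical user functions), a loss_fn can attach tensors to the policy
+    that the stats_fn later reuses:
+
+        def my_loss_fn(policy, batch_tensors):
+            policy.my_error = batch_tensors[SampleBatch.REWARDS] - 1.0
+            return tf.reduce_mean(tf.square(policy.my_error))
+
+        def my_stats_fn(policy, batch_tensors):
+            return {"mean_error": tf.reduce_mean(policy.my_error)}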
+ Arguments: name (str): name of the graph (e.g., "PPOPolicy") loss_fn (func): function that returns a loss tensor the policy, and dict of experience tensor placeholdes get_default_config (func): optional function that returns the default config to merge with any overrides - stats_fn (func): optional function that returns a dict of - TF fetches given the policy and batch input tensors - grad_stats_fn (func): optional function that returns a dict of - TF fetches given the policy and loss gradient tensors postprocess_fn (func): optional experience postprocessing function that takes the same args as PolicyGraph.postprocess_trajectory() + stats_fn (func): optional function that returns a dict of + TF fetches given the policy and batch input tensors + update_ops_fn (func): optional function that returns a list overriding + the update ops to run when applying gradients optimizer_fn (func): optional function that returns a tf.Optimizer given the policy and config gradients_fn (func): optional function that returns a list of gradients given a tf optimizer and loss tensor. If not specified, this defaults to optimizer.compute_gradients(loss) + grad_stats_fn (func): optional function that returns a dict of + TF fetches given the policy and loss gradient tensors extra_action_fetches_fn (func): optional function that returns a dict of TF fetches given the policy object extra_action_feed_fn (func): optional function that returns a feed dict @@ -72,8 +82,6 @@ def build_tf_policy(name, precedence than the DynamicTFPolicyGraph class get_batch_divisibility_req (func): optional function that returns the divisibility requirement for sample batches - update_ops_fn (func): optional function that returns a list overriding - the update ops to run when applying gradients obs_include_prev_action_reward (bool): whether to include the previous action and reward in the model input From 0e347af2937c1403126d4f3e8b83914b16984ac1 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 23 May 2019 17:57:54 -0700 Subject: [PATCH 08/10] fix action placeholder --- python/ray/rllib/policy/dynamic_tf_policy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/ray/rllib/policy/dynamic_tf_policy.py b/python/ray/rllib/policy/dynamic_tf_policy.py index 886215cf351f..afa72a0af709 100644 --- a/python/ray/rllib/policy/dynamic_tf_policy.py +++ b/python/ray/rllib/policy/dynamic_tf_policy.py @@ -237,7 +237,8 @@ def fake_array(tensor): SampleBatch.CUR_OBS: fake_array(self._obs_input), SampleBatch.NEXT_OBS: fake_array(self._obs_input), SampleBatch.DONES: np.array([False], dtype=np.bool), - SampleBatch.ACTIONS: np.array([self.action_space.sample()]), + SampleBatch.ACTIONS: fake_array( + ModelCatalog.get_action_placeholder(self.action_space)), SampleBatch.REWARDS: np.array([0], dtype=np.float32), } if self._obs_include_prev_action_reward: From be1daa00ddfe667650ea73b91ca08149399d5f3b Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 29 May 2019 21:30:59 -0700 Subject: [PATCH 09/10] port a3c policy to builder pattern --- python/ray/rllib/agents/a3c/a3c_tf_policy.py | 224 +++++++------------ 1 file changed, 85 insertions(+), 139 deletions(-) diff --git a/python/ray/rllib/agents/a3c/a3c_tf_policy.py b/python/ray/rllib/agents/a3c/a3c_tf_policy.py index eb5becceaa71..ed3676472850 100644 --- a/python/ray/rllib/agents/a3c/a3c_tf_policy.py +++ b/python/ray/rllib/agents/a3c/a3c_tf_policy.py @@ -4,20 +4,13 @@ from __future__ import division from __future__ import print_function -import gym - import ray -from 
ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.explained_variance import explained_variance -from ray.rllib.policy.policy import Policy from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.policy.tf_policy import TFPolicy, \ - LearningRateSchedule -from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.utils.annotations import override +from ray.rllib.policy.tf_policy_template import build_tf_policy +from ray.rllib.policy.tf_policy import LearningRateSchedule from ray.rllib.utils import try_import_tf tf = try_import_tf() @@ -44,144 +37,97 @@ def __init__(self, self.entropy * entropy_coeff) -class A3CPostprocessing(object): - """Adds the VF preds and advantages fields to the trajectory.""" - - @override(TFPolicy) - def extra_compute_action_fetches(self): - return dict( - TFPolicy.extra_compute_action_fetches(self), - **{SampleBatch.VF_PREDS: self.vf}) - - @override(Policy) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - completed = sample_batch[SampleBatch.DONES][-1] - if completed: - last_r = 0.0 - else: - next_state = [] - for i in range(len(self.model.state_in)): - next_state.append([sample_batch["state_out_{}".format(i)][-1]]) - last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1], - sample_batch[SampleBatch.ACTIONS][-1], - sample_batch[SampleBatch.REWARDS][-1], - *next_state) - return compute_advantages(sample_batch, last_r, self.config["gamma"], - self.config["lambda"]) - - -class A3CTFPolicy(LearningRateSchedule, A3CPostprocessing, TFPolicy): - def __init__(self, observation_space, action_space, config): - config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config) - self.config = config - self.sess = tf.get_default_session() - - # Setup the policy - self.observations = tf.placeholder( - tf.float32, [None] + list(observation_space.shape)) - dist_class, logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) - self.prev_actions = ModelCatalog.get_action_placeholder(action_space) - self.prev_rewards = tf.placeholder( - tf.float32, [None], name="prev_reward") - self.model = ModelCatalog.get_model({ - "obs": self.observations, - "prev_actions": self.prev_actions, - "prev_rewards": self.prev_rewards, - "is_training": self._get_is_training_placeholder(), - }, observation_space, action_space, logit_dim, self.config["model"]) - action_dist = dist_class(self.model.outputs) +def actor_critic_loss(policy, batch_tensors): + policy.loss = A3CLoss( + policy.action_dist, batch_tensors[SampleBatch.ACTIONS], + batch_tensors[Postprocessing.ADVANTAGES], + batch_tensors[Postprocessing.VALUE_TARGETS], policy.vf, + policy.config["vf_loss_coeff"], policy.config["entropy_coeff"]) + return policy.loss.total_loss + + +def postprocess_advantages(policy, + sample_batch, + other_agent_batches=None, + episode=None): + completed = sample_batch[SampleBatch.DONES][-1] + if completed: + last_r = 0.0 + else: + next_state = [] + for i in range(len(policy.model.state_in)): + next_state.append([sample_batch["state_out_{}".format(i)][-1]]) + last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1], + sample_batch[SampleBatch.ACTIONS][-1], + sample_batch[SampleBatch.REWARDS][-1], + *next_state) + return compute_advantages(sample_batch, last_r, policy.config["gamma"], + policy.config["lambda"]) + + +def add_value_function_fetch(policy): 
+ return {SampleBatch.VF_PREDS: policy.vf} + + +class ValueNetworkMixin(object): + def __init__(self): self.vf = self.model.value_function() - self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, - tf.get_variable_scope().name) - - # Setup the policy loss - if isinstance(action_space, gym.spaces.Box): - ac_size = action_space.shape[0] - actions = tf.placeholder(tf.float32, [None, ac_size], name="ac") - elif isinstance(action_space, gym.spaces.Discrete): - actions = tf.placeholder(tf.int64, [None], name="ac") - else: - raise UnsupportedSpaceException( - "Action space {} is not supported for A3C.".format( - action_space)) - advantages = tf.placeholder(tf.float32, [None], name="advantages") - self.v_target = tf.placeholder(tf.float32, [None], name="v_target") - self.loss = A3CLoss(action_dist, actions, advantages, self.v_target, - self.vf, self.config["vf_loss_coeff"], - self.config["entropy_coeff"]) - - # Initialize TFPolicy - loss_in = [ - (SampleBatch.CUR_OBS, self.observations), - (SampleBatch.ACTIONS, actions), - (SampleBatch.PREV_ACTIONS, self.prev_actions), - (SampleBatch.PREV_REWARDS, self.prev_rewards), - (Postprocessing.ADVANTAGES, advantages), - (Postprocessing.VALUE_TARGETS, self.v_target), - ] - LearningRateSchedule.__init__(self, self.config["lr"], - self.config["lr_schedule"]) - TFPolicy.__init__( - self, - observation_space, - action_space, - self.sess, - obs_input=self.observations, - action_sampler=action_dist.sample(), - action_prob=action_dist.sampled_action_prob(), - loss=self.loss.total_loss, - model=self.model, - loss_inputs=loss_in, - state_inputs=self.model.state_in, - state_outputs=self.model.state_out, - prev_action_input=self.prev_actions, - prev_reward_input=self.prev_rewards, - seq_lens=self.model.seq_lens, - max_seq_len=self.config["model"]["max_seq_len"]) - - self.stats_fetches = { - LEARNER_STATS_KEY: { - "cur_lr": tf.cast(self.cur_lr, tf.float64), - "policy_loss": self.loss.pi_loss, - "policy_entropy": self.loss.entropy, - "grad_gnorm": tf.global_norm(self._grads), - "var_gnorm": tf.global_norm(self.var_list), - "vf_loss": self.loss.vf_loss, - "vf_explained_var": explained_variance(self.v_target, self.vf), - }, - } - - self.sess.run(tf.global_variables_initializer()) - - @override(Policy) - def get_initial_state(self): - return self.model.state_init - - @override(TFPolicy) - def gradients(self, optimizer, loss): - grads = tf.gradients(loss, self.var_list) - self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) - clipped_grads = list(zip(self.grads, self.var_list)) - return clipped_grads - - @override(TFPolicy) - def extra_compute_grad_fetches(self): - return self.stats_fetches def _value(self, ob, prev_action, prev_reward, *args): feed_dict = { - self.observations: [ob], - self.prev_actions: [prev_action], - self.prev_rewards: [prev_reward], + self.get_placeholder(SampleBatch.CUR_OBS): [ob], + self.get_placeholder(SampleBatch.PREV_ACTIONS): [prev_action], + self.get_placeholder(SampleBatch.PREV_REWARDS): [prev_reward], self.model.seq_lens: [1] } assert len(args) == len(self.model.state_in), \ (args, self.model.state_in) for k, v in zip(self.model.state_in, args): feed_dict[k] = v - vf = self.sess.run(self.vf, feed_dict) + vf = self.get_session().run(self.vf, feed_dict) return vf[0] + + +def stats(policy, batch_tensors): + return { + "cur_lr": tf.cast(policy.cur_lr, tf.float64), + "policy_loss": policy.loss.pi_loss, + "policy_entropy": policy.loss.entropy, + "var_gnorm": tf.global_norm(policy.var_list), + "vf_loss": 
policy.loss.vf_loss, + } + + +def grad_stats(policy, grads): + return { + "grad_gnorm": tf.global_norm(grads), + "vf_explained_var": explained_variance( + policy.get_placeholder(Postprocessing.VALUE_TARGETS), policy.vf), + } + + +def clip_gradients(policy, optimizer, loss): + grads = tf.gradients(loss, policy.var_list) + grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"]) + clipped_grads = list(zip(grads, policy.var_list)) + return clipped_grads + + +def setup_mixins(policy, obs_space, action_space, config): + ValueNetworkMixin.__init__(policy) + LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) + policy.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + tf.get_variable_scope().name) + + +A3CTFPolicy = build_tf_policy( + name="A3CTFPolicy", + get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, + loss_fn=actor_critic_loss, + stats_fn=stats, + grad_stats_fn=grad_stats, + gradients_fn=clip_gradients, + postprocess_fn=postprocess_advantages, + extra_action_fetches_fn=add_value_function_fetch, + before_loss_init=setup_mixins, + mixins=[ValueNetworkMixin, LearningRateSchedule]) From c7ab0a9827a64da58b6b99cea5e8805f8cedeb8c Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 29 May 2019 21:34:54 -0700 Subject: [PATCH 10/10] add doc --- python/ray/rllib/policy/tf_policy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/ray/rllib/policy/tf_policy.py b/python/ray/rllib/policy/tf_policy.py index bc0f711038eb..ed234f809512 100644 --- a/python/ray/rllib/policy/tf_policy.py +++ b/python/ray/rllib/policy/tf_policy.py @@ -158,7 +158,9 @@ def get_placeholder(self, name): raise RuntimeError( "You cannot call policy.get_placeholder() for non-obs inputs " "before the loss has been initialized. To avoid this, use " - "policy.loss_initialized() to check whether this is the case.") + "policy.loss_initialized() to check whether this is the " + "case, or move the call to later (e.g., from stats_fn to " + "grad_stats_fn).") return self._loss_input_dict[name]
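
To make the initialization order documented in patch 07 concrete, here is a minimal sketch of a policy-gradient style policy assembled with build_tf_policy. It is illustrative only and not part of this changeset: the name MyTFPolicy, the toy loss, and the reuse of the PG trainer's DEFAULT_CONFIG are assumptions, and the import paths follow the post-move ray.rllib.policy layout used by the A3C port in patch 09.

# Illustrative sketch only; MyTFPolicy and the toy loss are hypothetical
# and do not appear anywhere in this patch series.
from ray.rllib.agents.pg.pg import DEFAULT_CONFIG
from ray.rllib.evaluation.postprocessing import compute_advantages, \
    Postprocessing
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils import try_import_tf

tf = try_import_tf()


def postprocess_advantages(policy, sample_batch,
                           other_agent_batches=None, episode=None):
    # Step 1, placeholder setup: this also runs on a dummy batch before the
    # loss exists, so it must not rely on loss-time policy attributes.
    return compute_advantages(
        sample_batch, 0.0, policy.config["gamma"], use_gae=False)


def policy_gradient_loss(policy, batch_tensors):
    # Step 2, loss init: attributes assigned here (policy.loss) are visible
    # to stats_fn and grad_stats_fn, which run later.
    policy.loss = -tf.reduce_mean(
        policy.action_dist.logp(batch_tensors[SampleBatch.ACTIONS]) *
        batch_tensors[Postprocessing.ADVANTAGES])
    return policy.loss


def stats(policy, batch_tensors):
    # Runs during loss init, immediately after policy_gradient_loss.
    return {"policy_loss": policy.loss}


def grad_stats(policy, grads):
    # Step 3, optimizer init: the loss is initialized by now, so calling
    # policy.get_placeholder() here is safe (see the error message added
    # in patch 10).
    return {
        "grad_gnorm": tf.global_norm(grads),
        "advantages_mean": tf.reduce_mean(
            policy.get_placeholder(Postprocessing.ADVANTAGES)),
    }


MyTFPolicy = build_tf_policy(
    name="MyTFPolicy",
    get_default_config=lambda: DEFAULT_CONFIG,
    loss_fn=policy_gradient_loss,
    postprocess_fn=postprocess_advantages,
    stats_fn=stats,
    grad_stats_fn=grad_stats)

The A3C port in patch 09 follows the same shape, with its ValueNetworkMixin and LearningRateSchedule state attached through mixins and before_loss_init.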