Commit 17364d8

loss initialized function
ericl committed May 19, 2019
1 parent d9ee507 commit 17364d8
Showing 3 changed files with 28 additions and 21 deletions.
python/ray/rllib/agents/dqn/dqn_policy_graph.py (7 additions, 15 deletions)
@@ -3,7 +3,6 @@
from __future__ import print_function

from gym.spaces import Discrete
-import logging
import numpy as np
from scipy.stats import entropy

@@ -20,8 +19,6 @@

tf = try_import_tf()

-logger = logging.getLogger(__name__)

Q_SCOPE = "q_func"
Q_TARGET_SCOPE = "target_q_func"

@@ -343,21 +340,16 @@ def __init__(self, obs_space, action_space, config):
self.update_target_expr = tf.group(*update_target_expr)

def update_target(self):
-return self._sess.run(self.update_target_expr)
+return self.get_session().run(self.update_target_expr)


class ComputeTDErrorMixin(object):
def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
-if not hasattr(self, "loss"):
-    # this is a hack so that the td error placeholder can be
-    # generated by build_tf_policy()
-    logger.warn(
-        "compute_td_error() called before loss has been initialized, "
-        "returning zeros")
+if not self.loss_initialized():
return np.zeros_like(rew_t)

-td_err = self._sess.run(
+td_err = self.get_session().run(
self.loss.td_error,
feed_dict={
self.get_placeholder(SampleBatch.CUR_OBS): [
@@ -382,10 +374,10 @@ def postprocess_trajectory(policy,
# adjust the sigma of parameter space noise
states = [list(x) for x in sample_batch.columns(["obs"])][0]

-noisy_action_distribution = policy._sess.run(
+noisy_action_distribution = policy.get_session().run(
policy.action_probs, feed_dict={policy.cur_observations: states})
-policy._sess.run(policy.remove_noise_op)
-clean_action_distribution = policy._sess.run(
+policy.get_session().run(policy.remove_noise_op)
+clean_action_distribution = policy.get_session().run(
policy.action_probs, feed_dict={policy.cur_observations: states})
distance_in_action_space = np.mean(
entropy(clean_action_distribution.T, noisy_action_distribution.T))
@@ -397,7 +389,7 @@
else:
policy.parameter_noise_sigma_val /= 1.01
policy.parameter_noise_sigma.load(
-policy.parameter_noise_sigma_val, session=policy._sess)
+policy.parameter_noise_sigma_val, session=policy.get_session())

return _postprocess_dqn(policy, sample_batch)
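Not part of this commit, but as an illustration of why the loss_initialized() guard in ComputeTDErrorMixin above is convenient: a caller that needs TD errors (e.g. for prioritized replay) can now call compute_td_error() unconditionally and simply receives zeros back before the loss is built. The batch column names below are assumptions made for the sketch, not taken from this diff.

import numpy as np

def priorities_for_batch(policy, batch, eps=1e-6):
    # Before the loss is initialized this returns zeros, so priorities
    # degrade gracefully to a uniform floor of `eps`.
    td_errors = policy.compute_td_error(
        batch["obs"], batch["actions"], batch["rewards"],
        batch["new_obs"], batch["dones"], batch["weights"])
    # Absolute TD error plus a small constant is a common priority choice.
    return np.abs(td_errors) + eps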

python/ray/rllib/agents/ppo/ppo_policy_graph.py (2 additions, 2 deletions)
@@ -216,7 +216,7 @@ def update_kl(self, sampled_kl):
self.kl_coeff_val *= 1.5
elif sampled_kl < 0.5 * self.kl_target:
self.kl_coeff_val *= 0.5
-self.kl_coeff.load(self.kl_coeff_val, session=self._sess)
+self.kl_coeff.load(self.kl_coeff_val, session=self.get_session())
return self.kl_coeff_val
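For reference, a standalone sketch of the adaptive-KL rule that update_kl() applies before loading the new coefficient into the graph. The 1.5x/0.5x multipliers and the lower threshold are taken from the hunk above; the upper threshold (2.0 * kl_target) is not visible here and is an assumption.

def adapt_kl_coeff(kl_coeff_val, sampled_kl, kl_target):
    # Tighten the KL penalty when the sampled KL overshoots the target,
    # relax it when it undershoots, and leave it unchanged otherwise.
    if sampled_kl > 2.0 * kl_target:  # upper threshold assumed
        kl_coeff_val *= 1.5
    elif sampled_kl < 0.5 * kl_target:
        kl_coeff_val *= 0.5
    return kl_coeff_val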


@@ -264,7 +264,7 @@ def _value(self, ob, prev_action, prev_reward, *args):
(args, self.model.state_in)
for k, v in zip(self.model.state_in, args):
feed_dict[k] = v
-vf = self._sess.run(self.value_function, feed_dict)
+vf = self.get_session().run(self.value_function, feed_dict)
return vf[0]


python/ray/rllib/evaluation/tf_policy_graph.py (19 additions, 4 deletions)
@@ -143,11 +143,26 @@ def __init__(self,
def get_placeholder(self, name):
"""Returns the given loss input placeholder by name.
-These are the same placeholders passed in as the loss_inputs arg.
+These are the same placeholders passed in as the loss_inputs arg. If
+the loss has not been initialized, an error is raised.
"""

+if not self.loss_initialized():
+    raise RuntimeError(
+        "You cannot call policy.get_placeholder() before the loss "
+        "has been initialized. To avoid this, use "
+        "policy.loss_initialized() to check whether this is the case.")

return self._loss_input_dict[name]

+def get_session(self):
+    """Returns a reference to the TF session for this policy."""
+    return self._sess
+
+def loss_initialized(self):
+    """Returns whether the loss function has been initialized."""
+    return self._loss is not None

def _initialize_loss(self, loss, loss_inputs):
self._loss_inputs = loss_inputs
self._loss_input_dict = dict(self._loss_inputs)
@@ -204,21 +219,21 @@ def compute_actions(self,

@override(PolicyGraph)
def compute_gradients(self, postprocessed_batch):
-assert self._loss is not None, "Loss not initialized"
+assert self.loss_initialized()
builder = TFRunBuilder(self._sess, "compute_gradients")
fetches = self._build_compute_gradients(builder, postprocessed_batch)
return builder.get(fetches)

@override(PolicyGraph)
def apply_gradients(self, gradients):
-assert self._loss is not None, "Loss not initialized"
+assert self.loss_initialized()
builder = TFRunBuilder(self._sess, "apply_gradients")
fetches = self._build_apply_gradients(builder, gradients)
builder.get(fetches)

@override(PolicyGraph)
def learn_on_batch(self, postprocessed_batch):
-assert self._loss is not None, "Loss not initialized"
+assert self.loss_initialized()
builder = TFRunBuilder(self._sess, "learn_on_batch")
fetches = self._build_learn_on_batch(builder, postprocessed_batch)
return builder.get(fetches)
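A minimal sketch, not part of this commit, of how downstream policy-graph code can use the two new public accessors instead of reaching into the private _sess and _loss attributes; my_debug_op is a hypothetical tensor used only for illustration.

class DebugMixin(object):
    def fetch_debug_value(self):
        # Graph-dependent work has to wait until build_tf_policy() has
        # set up the loss and its input placeholders.
        if not self.loss_initialized():
            return None
        # Run the fetch through the session exposed by TFPolicyGraph.
        return self.get_session().run(self.my_debug_op)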
