support eager with ppo
ericl committed Jun 4, 2019
1 parent 77daae0 commit 84b0553
Showing 5 changed files with 44 additions and 17 deletions.
14 changes: 7 additions & 7 deletions ci/jenkins_tests/run_rllib_tests.sh
@@ -158,13 +158,6 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     --stop '{"training_iteration": 1}' \
     --config '{"num_workers": 2}'
 
-docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
-    /ray/ci/suppress_output /ray/python/ray/rllib/train.py \
-    --env CartPole-v0 \
-    --run PG \
-    --stop '{"training_iteration": 1}' \
-    --config '{"use_eager": true}'
-
 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     /ray/ci/suppress_output /ray/python/ray/rllib/train.py \
     --env CartPole-v0 \
@@ -402,6 +395,13 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     /ray/ci/suppress_output python /ray/python/ray/rllib/examples/eager_execution.py --iters=2
 
+docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
+    /ray/ci/suppress_output /ray/python/ray/rllib/train.py \
+    --env CartPole-v0 \
+    --run PPO \
+    --stop '{"training_iteration": 1}' \
+    --config '{"use_eager": true, "simple_optimizer": true}'
+
 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_tf_policy.py --iters=2
 
2 changes: 1 addition & 1 deletion doc/source/rllib-concepts.rst
@@ -375,7 +375,7 @@ While RLlib runs all TF operations in graph mode, you can still leverage TensorF
 You can find a runnable file for the above eager execution example `here <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/eager_execution.py>`__.
 
-There is also experimental support for automatically running the entire loss function in eager mode. This can be enabled with ``use_eager: True``, e.g., ``rllib train --env=CartPole-v0 --run=PG --config='{"use_eager": true}'``. However this currently only works for basic algorithms such as PG.
+There is also experimental support for running the entire loss function in eager mode. This can be enabled with ``use_eager: True``, e.g., ``rllib train --env=CartPole-v0 --run=PPO --config='{"use_eager": true, "simple_optimizer": true}'``. However, this currently only works for a couple of algorithms.
 
 Building Policies in PyTorch
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
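For reference, the experimental mode described in the doc paragraph above can also be enabled from the Python API. The snippet below is an illustrative sketch only, not part of this commit; it assumes the ``ray.tune.run`` entry point available in this Ray release. The canonical invocation is the ``rllib train`` command quoted in the diff.

import ray
from ray import tune

ray.init()
tune.run(
    "PPO",
    stop={"training_iteration": 1},
    config={
        "env": "CartPole-v0",
        # Run the loss function eagerly (experimental).
        "use_eager": True,
        # The multi-GPU optimizer does not support eager mode, so PPO
        # must fall back to the simple optimizer.
        "simple_optimizer": True,
    },
)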
2 changes: 2 additions & 0 deletions doc/source/rllib-training.rst
@@ -372,6 +372,8 @@ TensorFlow Eager
 
 While RLlib uses TF graph mode for all computations, you can still leverage TF eager to inspect the intermediate state of computations using `tf.py_function <https://www.tensorflow.org/api_docs/python/tf/py_function>`__. Here's an example of using eager mode in `a custom RLlib model and loss <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/eager_execution.py>`__.
 
+There is also experimental support for running the entire loss function in eager mode. This can be enabled with ``use_eager: True``, e.g., ``rllib train --env=CartPole-v0 --run=PPO --config='{"use_eager": true, "simple_optimizer": true}'``. However, this currently only works for a couple of algorithms.
+
 Episode Traces
 ~~~~~~~~~~~~~~
 
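The ``tf.py_function`` inspection trick referenced above looks roughly like the following minimal sketch. It is not RLlib code; the tensor names (``logp``, ``advantages``) and the helper ``print_mean`` are made up for illustration.

import tensorflow as tf


def print_mean(advantages):
    # Runs eagerly inside tf.py_function, so .numpy() is available even
    # though the surrounding loss is built in graph mode.
    print("mean advantage:", advantages.numpy().mean())
    return advantages


def loss_with_debugging(logp, advantages):
    # Route the tensor through an eager callback purely for inspection.
    advantages = tf.py_function(print_mean, [advantages], tf.float32)
    return -tf.reduce_mean(logp * advantages)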
10 changes: 6 additions & 4 deletions python/ray/rllib/agents/ppo/ppo_policy.py
@@ -106,8 +106,10 @@ def reduce_mean_valid(t):
 
 def ppo_surrogate_loss(policy, batch_tensors):
     if policy.model.state_in:
-        max_seq_len = tf.reduce_max(policy.model.seq_lens)
-        mask = tf.sequence_mask(policy.model.seq_lens, max_seq_len)
+        max_seq_len = tf.reduce_max(
+            policy.convert_to_eager(policy.model.seq_lens))
+        mask = tf.sequence_mask(
+            policy.convert_to_eager(policy.model.seq_lens), max_seq_len)
         mask = tf.reshape(mask, [-1])
     else:
         mask = tf.ones_like(
@@ -121,8 +123,8 @@ def ppo_surrogate_loss(policy, batch_tensors):
         batch_tensors[BEHAVIOUR_LOGITS],
         batch_tensors[SampleBatch.VF_PREDS],
         policy.action_dist,
-        policy.value_function,
-        policy.kl_coeff,
+        policy.convert_to_eager(policy.value_function),
+        policy.convert_to_eager(policy.kl_coeff),
         mask,
         entropy_coeff=policy.config["entropy_coeff"],
         clip_param=policy.config["clip_param"],
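The edits above follow one pattern: tensors a loss reads off the ``policy`` object (rather than receiving through ``batch_tensors``) are wrapped in ``policy.convert_to_eager()``, so an eager replacement can be substituted when ``use_eager`` is on. A hypothetical loss using the same convention might look like the sketch below; the batch keys and the loss terms are schematic, not RLlib code.

import tensorflow as tf


def my_surrogate_loss(policy, batch_tensors):
    # Inputs delivered through batch_tensors are converted automatically.
    advantages = batch_tensors["advantages"]
    actions = batch_tensors["actions"]
    # Graph tensors hung off the policy must be converted explicitly.
    vf_preds = policy.convert_to_eager(policy.value_function)
    kl_coeff = policy.convert_to_eager(policy.kl_coeff)
    logp = policy.action_dist.logp(actions)
    # Schematic loss terms; a real loss would use proper targets.
    pg_loss = -tf.reduce_mean(logp * advantages)
    vf_loss = 0.5 * tf.reduce_mean(tf.square(vf_preds))
    kl = tf.reduce_mean(-logp)  # stand-in for a real KL estimate
    return pg_loss + vf_loss + kl_coeff * kl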
33 changes: 28 additions & 5 deletions python/ray/rllib/policy/dynamic_tf_policy.py
@@ -167,6 +167,8 @@ def __init__(self,
             batch_divisibility_req=batch_divisibility_req)
 
         # Phase 2 init
+        self._needs_eager_conversion = set()
+        self._eager_tensors = {}
         before_loss_init(self, obs_space, action_space, config)
         if not existing_inputs:
             self._initialize_loss()
@@ -178,12 +180,25 @@ def get_obs_input_dict(self):
         """
         return self.input_dict
 
+    def convert_to_eager(self, tensor):
+        """Convert a graph tensor accessed in the loss to an eager tensor.
+
+        Experimental.
+        """
+        if tf.executing_eagerly():
+            return self._eager_tensors[tensor]
+        else:
+            self._needs_eager_conversion.add(tensor)
+            return tensor
+
     @override(TFPolicy)
     def copy(self, existing_inputs):
         """Creates a copy of self using existing input placeholders."""
 
         if self.config["use_eager"]:
-            raise ValueError("eager not implemented in this case")
+            raise ValueError(
+                "eager not implemented for multi-GPU, try setting "
+                "`simple_optimizer: true`")
 
         # Note that there might be RNN state inputs at the end of the list
         if self._state_inputs:
@@ -304,23 +319,31 @@ def fake_array(tensor):
         # XXX experimental support for automatically eagerifying the loss.
         # The main limitation right now is that TF doesn't support mixing eager
         # and non-eager tensors, so losses that read non-eager tensors through
-        # the `policy` reference will crash.
+        # `policy` need to use `policy.convert_to_eager(tensor)`.
         if self.config["use_eager"]:
             if not self.model:
                 raise ValueError("eager not implemented in this case")
+            graph_tensors = list(self._needs_eager_conversion)
 
             def gen_loss(model_outputs, *args):
-                eager_inputs = dict(zip([k for (k, v) in loss_inputs], args))
+                # fill in the batch tensor dict with eager tensors
+                eager_inputs = dict(
+                    zip([k for (k, v) in loss_inputs],
+                        args[:len(loss_inputs)]))
+                # fill in the eager versions of all accessed graph tensors
+                self._eager_tensors = dict(
+                    zip(graph_tensors, args[len(loss_inputs):]))
                 # patch the action dist to use eager mode tensors
                 self.action_dist.inputs = model_outputs
                 return self._loss_fn(self, eager_inputs)
 
             loss = tf.py_function(
                 gen_loss,
-                [self.model.outputs] +
                 # cast works around TypeError: Cannot convert provided value
                 # to EagerTensor. Provided value: 0.0 Requested dtype: int64
-                [tf.cast(v, tf.float32) for (k, v) in loss_inputs],
+                [self.model.outputs] + [
+                    tf.cast(v, tf.float32) for (k, v) in loss_inputs
+                ] + [tf.cast(t, tf.float32) for t in graph_tensors],
                 tf.float32)
 
             TFPolicy._initialize_loss(self, loss, loss_inputs)
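Stripped of the RLlib plumbing, the mechanism above passes every graph tensor the loss touches as an explicit argument to ``tf.py_function`` and rebuilds a graph-to-eager lookup inside the eagerly executed function. Below is a minimal standalone sketch of that idea, assuming TF 1.x with ``tf.py_function`` available; all names are invented for illustration.

import tensorflow as tf

# A regular loss input plus one extra graph tensor the "loss" reads globally.
obs_ph = tf.placeholder(tf.float32, [None, 4])
kl_coeff = tf.get_variable("kl_coeff", initializer=0.2, trainable=False)
graph_tensors = [kl_coeff]


def gen_loss(obs, *extras):
    # Everything arriving here is an eager tensor; map the original graph
    # tensors to their eager counterparts by position.
    eager = dict(zip(graph_tensors, extras))
    return tf.reduce_mean(obs) * eager[kl_coeff]


loss = tf.py_function(
    gen_loss,
    [obs_ph] + [tf.cast(t, tf.float32) for t in graph_tensors],
    tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(loss, feed_dict={obs_ph: [[1.0, 2.0, 3.0, 4.0]]}))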
