This repository has been archived by the owner on Nov 15, 2021. It is now read-only.

Commit

Add entropy coeff schedule
stefanpantic committed Jun 26, 2019
1 parent b1827d5 commit d639737
Showing 3 changed files with 37 additions and 12 deletions.
1 change: 1 addition & 0 deletions python/ray/rllib/agents/impala/impala.py
@@ -75,6 +75,7 @@
# balancing the three losses
"vf_loss_coeff": 0.5,
"entropy_coeff": 0.01,
"entropy_schedule": None,

# use fake (infinite speed) sampler for testing
"_fake_sampler": False,
10 changes: 7 additions & 3 deletions python/ray/rllib/agents/impala/vtrace_policy.py
@@ -14,7 +14,7 @@
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy, \
LearningRateSchedule
LearningRateSchedule, EntropyCoeffSchedule
from ray.rllib.models.action_dist import Categorical
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
@@ -126,7 +126,7 @@ def postprocess_trajectory(self,
return sample_batch


class VTraceTFPolicy(LearningRateSchedule, VTracePostprocessing, TFPolicy):
class VTraceTFPolicy(LearningRateSchedule, EntropyCoeffSchedule, VTracePostprocessing, TFPolicy):
def __init__(self,
observation_space,
action_space,
@@ -241,6 +241,9 @@ def make_time_major(tensor, drop_last=False):
loss_actions = actions if is_multidiscrete else tf.expand_dims(
actions, axis=1)

EntropyCoeffSchedule.__init__(self, self.config["entropy_coeff"],
self.config["entropy_schedule"])

# Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
self.loss = VTraceLoss(
actions=make_time_major(loss_actions, drop_last=True),
@@ -259,7 +262,7 @@ def make_time_major(tensor, drop_last=False):
dist_class=Categorical if is_multidiscrete else dist_class,
valid_mask=make_time_major(mask, drop_last=True),
vf_loss_coeff=self.config["vf_loss_coeff"],
entropy_coeff=self.config["entropy_coeff"],
entropy_coeff=self.entropy_coeff,
clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"])

@@ -299,6 +302,7 @@ def make_time_major(tensor, drop_last=False):
self.stats_fetches = {
LEARNER_STATS_KEY: {
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"entropy_coeff": tf.cast(self.entropy_coeff, tf.float64),
"policy_loss": self.loss.pi_loss,
"entropy": self.loss.entropy,
"grad_gnorm": tf.global_norm(self._grads),
38 changes: 29 additions & 9 deletions python/ray/rllib/policy/tf_policy.py
@@ -2,21 +2,21 @@
from __future__ import division
from __future__ import print_function

import os
import errno
import logging
import numpy as np
import os

import numpy as np
import ray
import ray.experimental.tf_utils
from ray.rllib.models.lstm import chop_into_sequences
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.lstm import chop_into_sequences
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.debug import log_once, summarize
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils import try_import_tf

tf = try_import_tf()
logger = logging.getLogger(__name__)
@@ -416,7 +416,7 @@ def _build_compute_actions(self,
if len(self._state_inputs) != len(state_batches):
raise ValueError(
"Must pass in RNN state batches for placeholders {}, got {}".
format(self._state_inputs, state_batches))
format(self._state_inputs, state_batches))
builder.add_feed_dict(self.extra_compute_action_feed_dict())
builder.add_feed_dict({self._obs_input: obs_batch})
if state_batches:
@@ -443,7 +443,7 @@ def _build_apply_gradients(self, builder, gradients):
if len(gradients) != len(self._grads):
raise ValueError(
"Unexpected number of gradients to apply, got {} for {}".
format(gradients, self._grads))
format(gradients, self._grads))
builder.add_feed_dict({self._is_training: True})
builder.add_feed_dict(dict(zip(self._grads, gradients)))
fetches = builder.add_fetches([self._apply_op])
@@ -473,9 +473,9 @@ def _get_loss_inputs_dict(self, batch):
feed_dict = {}
if self._batch_divisibility_req > 1:
meets_divisibility_reqs = (
len(batch[SampleBatch.CUR_OBS]) %
self._batch_divisibility_req == 0
and max(batch[SampleBatch.AGENT_INDEX]) == 0) # not multiagent
len(batch[SampleBatch.CUR_OBS]) %
self._batch_divisibility_req == 0
and max(batch[SampleBatch.AGENT_INDEX]) == 0) # not multiagent
else:
meets_divisibility_reqs = True

@@ -544,3 +544,23 @@ def on_global_var_update(self, global_vars):
@override(TFPolicy)
def optimizer(self):
return tf.train.AdamOptimizer(self.cur_lr)


@DeveloperAPI
class EntropyCoeffSchedule(object):
"""Mixin for TFPolicy that adds entropy coeff decay."""

@DeveloperAPI
def __init__(self, entropy_coeff, entropy_schedule):
self.entropy_coeff = tf.get_variable("entropy_coeff", initializer=entropy_coeff)
self._entropy_schedule = entropy_schedule

@override(Policy)
def on_global_var_update(self, global_vars):
super(EntropyCoeffSchedule, self).on_global_var_update(global_vars)
if self._entropy_schedule is not None:
self.entropy_coeff.load(
self.config['entropy_coeff'] *
(1 - global_vars['timestep'] /
self.config['entropy_schedule']),
session=self._sess)
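
The schedule above anneals the coefficient linearly, reaching zero at timestep == entropy_schedule. A plain-Python sketch of the same arithmetic (a TF variable in the real code); the example numbers are illustrative only:

def annealed_entropy_coeff(initial_coeff, schedule, timestep):
    # Same linear decay as EntropyCoeffSchedule.on_global_var_update above.
    # Note it reaches 0 at timestep == schedule and turns negative beyond it,
    # so the horizon should cover the intended training length.
    return initial_coeff * (1 - timestep / schedule)

print(annealed_entropy_coeff(0.01, 1e6, 0))        # 0.01
print(annealed_entropy_coeff(0.01, 1e6, 500000))   # 0.005
print(annealed_entropy_coeff(0.01, 1e6, 1000000))  # 0.0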
