[RLlib] Remove vtrace_drop_last_ts option and add proper vf bootstrapping to IMPALA and APPO. #36013
Changes from 5 commits
@@ -144,7 +144,6 @@ def loss(
  is_multidiscrete = False
  output_hidden_shape = 1

- # TODO: (sven) deprecate this when trajectory view API gets activated.
  def make_time_major(*args, **kw):
      return _make_time_major(
          self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw
@@ -159,12 +158,28 @@ def make_time_major(*args, **kw):
  prev_action_dist = dist_class(behaviour_logits, self.model)
  values = self.model.value_function()
  values_time_major = make_time_major(values)
+ bootstrap_values_time_major = make_time_major(
+     train_batch[SampleBatch.VALUES_BOOTSTRAPPED]
+ )
+ # Add values to bootstrap values to yield correct t=1 to T+1 trajectories,
+ # with T being the rollout length (max trajectory len).
+ # Note that the `SampleBatch.VALUES_BOOTSTRAPPED` values are always recorded
+ # ONLY at the last ts of a trajectory (for the following timestep,
+ # which is one past(!) the last ts). All other values in that tensor are
+ # zero.
+ shape = tf.shape(values_time_major)
+ B = shape[1]
+ values_time_major = tf.concat([values_time_major, tf.zeros((1, B))], axis=0)
+ bootstrap_values_time_major = tf.concat(
+     [tf.zeros((1, B)), bootstrap_values_time_major], axis=0
+ )
+ values_time_major += bootstrap_values_time_major

  if self.is_recurrent():
      max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
      mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
      mask = tf.reshape(mask, [-1])
-     mask = make_time_major(mask, drop_last=self.config["vtrace"])
+     mask = make_time_major(mask)

      def reduce_mean_valid(t):
          return tf.reduce_mean(tf.boolean_mask(t, mask))
@@ -173,11 +188,7 @@ def reduce_mean_valid(t):
      reduce_mean_valid = tf.reduce_mean

  if self.config["vtrace"]:
-     drop_last = self.config["vtrace_drop_last_ts"]
-     logger.debug(
-         "Using V-Trace surrogate loss (vtrace=True; "
-         f"drop_last={drop_last})"
-     )
+     logger.debug("Using V-Trace surrogate loss (vtrace=True)")

      # Prepare actions for loss.
      loss_actions = (
@@ -188,9 +199,7 @@ def reduce_mean_valid(t):
      old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

      # Prepare KL for Loss
-     mean_kl = make_time_major(
-         old_policy_action_dist.multi_kl(action_dist), drop_last=drop_last
-     )
+     mean_kl = make_time_major(old_policy_action_dist.multi_kl(action_dist))

      unpacked_behaviour_logits = tf.split(
          behaviour_logits, output_hidden_shape, axis=1
@@ -203,25 +212,19 @@ def reduce_mean_valid(t):
      with tf.device("/cpu:0"):
          vtrace_returns = vtrace.multi_from_logits(
              behaviour_policy_logits=make_time_major(
-                 unpacked_behaviour_logits, drop_last=drop_last
+                 unpacked_behaviour_logits
              ),
              target_policy_logits=make_time_major(
-                 unpacked_old_policy_behaviour_logits, drop_last=drop_last
-             ),
-             actions=tf.unstack(
-                 make_time_major(loss_actions, drop_last=drop_last), axis=2
+                 unpacked_old_policy_behaviour_logits

Review comment: do you not need to keep the tf unstack logic? I guess if vtrace tests are passing, then no...

              ),
+             actions=tf.unstack(make_time_major(loss_actions), axis=2),
              discounts=tf.cast(
-                 ~make_time_major(
-                     tf.cast(dones, tf.bool), drop_last=drop_last
-                 ),
+                 ~make_time_major(tf.cast(dones, tf.bool)),
                  tf.float32,
              )
              * self.config["gamma"],
-             rewards=make_time_major(rewards, drop_last=drop_last),
-             values=values_time_major[:-1]
-             if drop_last
-             else values_time_major,
+             rewards=make_time_major(rewards),
+             values=values_time_major[:-1],
              bootstrap_value=values_time_major[-1],
              dist_class=Categorical if is_multidiscrete else dist_class,
              model=model,
@@ -233,14 +236,10 @@ def reduce_mean_valid(t):
              ),
          )

-     actions_logp = make_time_major(
-         action_dist.logp(actions), drop_last=drop_last
-     )
-     prev_actions_logp = make_time_major(
-         prev_action_dist.logp(actions), drop_last=drop_last
-     )
+     actions_logp = make_time_major(action_dist.logp(actions))
+     prev_actions_logp = make_time_major(prev_action_dist.logp(actions))
      old_policy_actions_logp = make_time_major(
-         old_policy_action_dist.logp(actions), drop_last=drop_last
+         old_policy_action_dist.logp(actions)
      )

      is_ratio = tf.clip_by_value(
@@ -267,17 +266,12 @@ def reduce_mean_valid(t):
      mean_policy_loss = -reduce_mean_valid(surrogate_loss)

      # The value function loss.
-     if drop_last:
-         delta = values_time_major[:-1] - vtrace_returns.vs
-     else:
-         delta = values_time_major - vtrace_returns.vs
+     value_targets = vtrace_returns.vs
+     delta = values_time_major[:-1] - value_targets
      mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

      # The entropy loss.
-     actions_entropy = make_time_major(
-         action_dist.multi_entropy(), drop_last=True
-     )
+     actions_entropy = make_time_major(action_dist.multi_entropy())
      mean_entropy = reduce_mean_valid(actions_entropy)

  else:
@@ -312,7 +306,7 @@ def reduce_mean_valid(t):
      value_targets = make_time_major(
          train_batch[Postprocessing.VALUE_TARGETS]
      )
-     delta = values_time_major - value_targets
+     delta = values_time_major[:-1] - value_targets

Review comment: can we rename

      mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

      # The entropy loss.
@@ -353,7 +347,6 @@ def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
      self,
      train_batch.get(SampleBatch.SEQ_LENS),
      self.model.value_function(),
-     drop_last=self.config["vtrace"] and self.config["vtrace_drop_last_ts"],
  )

  stats_dict = {
@@ -388,6 +381,11 @@ def postprocess_trajectory(
      other_agent_batches: Optional[SampleBatch] = None,
      episode: Optional["Episode"] = None,
  ):
+     # Call super's postprocess_trajectory first.
+     sample_batch = super().postprocess_trajectory(
+         sample_batch, other_agent_batches, episode
+     )
+
      if not self.config["vtrace"]:
          sample_batch = compute_gae_for_sample_batch(
              self, sample_batch, other_agent_batches, episode

Review comment: same comments in appo_tf_policy here.
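The shift-and-add scheme in the hunks above (append a zero row to the value predictions, prepend a zero row to SampleBatch.VALUES_BOOTSTRAPPED, then add the two tensors) can be illustrated with a small standalone NumPy sketch. This is illustrative only, not the actual RLlib code; the toy numbers and shapes (T=3, B=2) are assumptions, and each column is assumed to hold a single trajectory spanning the whole rollout fragment.

import numpy as np

# Assumed toy shapes: rollout length T=3, B=2 trajectories (one per column).
T, B = 3, 2

# Per-timestep value predictions, time-major: shape (T, B).
values_time_major = np.array(
    [[1.0, 10.0],
     [2.0, 20.0],
     [3.0, 30.0]]
)

# SampleBatch.VALUES_BOOTSTRAPPED, time-major: shape (T, B). Only the last
# timestep holds the value estimate for the step one past the rollout; all
# other entries are zero.
bootstrap_values_time_major = np.array(
    [[0.0, 0.0],
     [0.0, 0.0],
     [4.0, 40.0]]
)

# Append one all-zero timestep to the values ...
values_ext = np.concatenate([values_time_major, np.zeros((1, B))], axis=0)
# ... and prepend one all-zero timestep to the bootstrapped values.
bootstrap_ext = np.concatenate(
    [np.zeros((1, B)), bootstrap_values_time_major], axis=0
)

# Adding the two (T+1, B) tensors keeps the per-timestep predictions in the
# first T rows and places the bootstrap estimates in the extra last row.
print(values_ext + bootstrap_ext)
# [[ 1. 10.]
#  [ 2. 20.]
#  [ 3. 30.]
#  [ 4. 40.]]

The same pattern appears again in the learner-API diff below, where the resulting (T+1, B) tensor is split into the values and bootstrap_value arguments of the vtrace call.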
@@ -61,17 +61,39 @@ def compute_loss_for_module(
      trajectory_len=hps.rollout_frag_or_episode_len,
      recurrent_seq_len=hps.recurrent_seq_len,
  )
- values_time_major = make_time_major(
-     values,
+ rewards_time_major = make_time_major(
+     batch[SampleBatch.REWARDS],
      trajectory_len=hps.rollout_frag_or_episode_len,
      recurrent_seq_len=hps.recurrent_seq_len,
  )
- bootstrap_value = values_time_major[-1]
- rewards_time_major = make_time_major(
-     batch[SampleBatch.REWARDS],
+ values_time_major = make_time_major(
+     values,
      trajectory_len=hps.rollout_frag_or_episode_len,
      recurrent_seq_len=hps.recurrent_seq_len,
  )
+ bootstrap_values_time_major = make_time_major(
+     batch[SampleBatch.VALUES_BOOTSTRAPPED]
+ )
+ # Then add the shifted-by-one bootstrapped values to that to yield the final
+ # value tensor. Use the last ts in that resulting tensor as the
+ # "bootstrapped" values for vtrace.
+ shape = tf.shape(values_time_major)
+ B = shape[1]
+ # Augment `values_time_major` by one timestep at the end (all zeros).
+ values_time_major = tf.concat([values_time_major, tf.zeros((1, B))], axis=0)
+ # Augment `bootstrap_values_time_major` by one timestep at the beginning
+ # (all zeros).
+ bootstrap_values_time_major = tf.concat(
+     [tf.zeros((1, B)), bootstrap_values_time_major], axis=0
+ )
+ # Note that the `SampleBatch.VALUES_BOOTSTRAPPED` values are always recorded
+ # ONLY at the last ts of a trajectory (for the following timestep,
+ # which is one past(!) the last ts). All other values in that tensor are
+ # zero.
+ # Adding values and bootstrap_values yields the correct values+bootstrap
+ # configuration, from which we can then take t=-1 (last timestep) to get
+ # the bootstrap_value arg for the vtrace function below.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need to do this. below we end up passing in values_time_major[:-1] and values_time_major[-1] separately, which means that we never needed to expand values_time_major, and could have instead directly passed in the boostrap_value There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do think we still need this. For the following reason: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm struggling to understand how There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that |

  # the discount factor that is used should be gamma except for timesteps where
  # the episode is terminated. In that case, the discount factor should be 0.
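Continuing the sketch given after the first file (toy shapes assumed; this is not the actual learner code): the expanded (T+1, B) tensor is consumed by the vtrace call in the next hunk as two separate arguments, and the discounts described in the comment above are gamma everywhere except at terminated timesteps. The dones tensor and gamma value below are hypothetical stand-ins.

import tensorflow as tf

# Toy (T+1, B) tensor as produced by the shift-and-add above (T=3, B=2).
values_time_major = tf.constant(
    [[1.0, 10.0],
     [2.0, 20.0],
     [3.0, 30.0],
     [4.0, 40.0]]  # last row: bootstrap estimates
)

values = values_time_major[:-1]          # shape (3, 2): per-timestep estimates
bootstrap_value = values_time_major[-1]  # shape (2,): one estimate per trajectory

# Discounts as described in the comment above: gamma everywhere, 0.0 where the
# episode terminated. `dones` (terminal flags) and `gamma` are stand-ins.
gamma = 0.99
dones = tf.constant([[0.0, 0.0],
                     [0.0, 1.0],
                     [0.0, 0.0]])
discounts_time_major = (1.0 - dones) * gamma

print(values.shape, bootstrap_value.shape, discounts_time_major.shape)
# (3, 2) (2,) (3, 2)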
@@ -93,8 +115,8 @@ def compute_loss_for_module(
      behaviour_action_log_probs=behaviour_actions_logp_time_major,
      discounts=discounts_time_major,
      rewards=rewards_time_major,
-     values=values_time_major,
-     bootstrap_value=bootstrap_value,
+     values=values_time_major[:-1],
+     bootstrap_value=values_time_major[-1],
      clip_pg_rho_threshold=hps.vtrace_clip_pg_rho_threshold,
      clip_rho_threshold=hps.vtrace_clip_rho_threshold,
  )
@@ -127,7 +149,7 @@ def compute_loss_for_module(
  mean_pi_loss = -tf.math.reduce_mean(surrogate_loss)

  # The baseline loss.
- delta = values_time_major - vtrace_adjusted_target_values
+ delta = values_time_major[:-1] - vtrace_adjusted_target_values
  mean_vf_loss = 0.5 * tf.math.reduce_mean(delta**2)

  # The entropy loss.
Review comment: I'm finding this comment confusing. Can you annotate it with what the shapes of values_time_major and bootstrap_values_time_major are supposed to be?

Reply: I expanded the docstring of ray.rllib.evaluation.postprocessing.compute_bootstrap_value() and added the computation example there (how to add the two columns together to get the final vtrace-usable value estimates). Then I linked from each of the 4 loss functions (IMPALA and APPO, torch and tf) to this new docstring.
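The docstring example referred to above is not reproduced here, but the shape bookkeeping the reviewer asks about can be sketched as follows. This is an illustration under stated assumptions, not the compute_bootstrap_value docstring itself: each column of the time-major batch is assumed to hold exactly one trajectory of length T, so SampleBatch.VALUES_BOOTSTRAPPED is nonzero only in its last row.

import numpy as np

T, B = 4, 3  # assumed rollout fragment length and batch size

values_time_major = np.random.randn(T, B)              # V(s_0) .. V(s_{T-1})
bootstrap_values_time_major = np.zeros((T, B))
bootstrap_values_time_major[-1] = np.random.randn(B)   # V(s_T), last row only

combined = (
    np.concatenate([values_time_major, np.zeros((1, B))], axis=0)
    + np.concatenate([np.zeros((1, B)), bootstrap_values_time_major], axis=0)
)

assert combined.shape == (T + 1, B)
# The first T rows are the unchanged per-timestep value predictions ...
assert np.allclose(combined[:-1], values_time_major)
# ... and the extra last row is exactly the recorded bootstrap estimate,
# i.e. the `bootstrap_value` argument handed to the vtrace call.
assert np.allclose(combined[-1], bootstrap_values_time_major[-1])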