[RLlib] Remove vtrace_drop_last_ts option and add proper vf bootstrapping to IMPALA and APPO. #36013

Merged

Changes shown from 5 commits.

Commits (29 total):
d53832a  wip (sven1977, Jun 2, 2023)
faab512  wip (sven1977, Jun 2, 2023)
935b7f7  wip (sven1977, Jun 2, 2023)
4a44f70  wip (sven1977, Jun 2, 2023)
7b3d7b5  wip (sven1977, Jun 2, 2023)
c2451bb  fix test (sven1977, Jun 3, 2023)
5f0b92c  fix test (sven1977, Jun 3, 2023)
605cf6e  fix test (sven1977, Jun 3, 2023)
8a0e5c9  wip (sven1977, Jun 5, 2023)
95bd73f  fix (sven1977, Jun 5, 2023)
957e185  wip (sven1977, Jun 9, 2023)
20de58c  Merge branch 'master' of https://github.com/ray-project/ray into fix_… (sven1977, Jun 9, 2023)
a2730a4  wip (sven1977, Jun 9, 2023)
1752f21  Merge branch 'master' of https://github.com/ray-project/ray into fix_… (sven1977, Jun 14, 2023)
72f4375  Merge branch 'master' of https://github.com/ray-project/ray into fix_… (sven1977, Jun 20, 2023)
4305ffa  wip (sven1977, Jun 20, 2023)
644098e  wip (sven1977, Jun 20, 2023)
c8cc8ae  wip (sven1977, Jun 20, 2023)
65bb4be  Merge branch 'master' of https://github.com/ray-project/ray into fix_… (sven1977, Jun 21, 2023)
56a3dea  wip (sven1977, Jun 21, 2023)
bf0bb36  wip (sven1977, Jun 21, 2023)
9c8c7e2  wip (sven1977, Jun 21, 2023)
d2d0eb3  wip (sven1977, Jun 21, 2023)
3718254  Merge branch 'master' of https://github.com/ray-project/ray into fix_… (sven1977, Jun 21, 2023)
5f43198  Merge branch 'master' of https://github.com/ray-project/ray into fix_… (sven1977, Jun 21, 2023)
0ab5703  wip (sven1977, Jun 21, 2023)
17be7cc  wip (sven1977, Jun 21, 2023)
8e46813  Merge branch 'master' of https://github.com/ray-project/ray into fix_… (sven1977, Jun 22, 2023)
f55532c  wip (sven1977, Jun 22, 2023)
74 changes: 36 additions & 38 deletions rllib/algorithms/appo/appo_tf_policy.py
@@ -144,7 +144,6 @@ def loss(
is_multidiscrete = False
output_hidden_shape = 1

# TODO: (sven) deprecate this when trajectory view API gets activated.
def make_time_major(*args, **kw):
return _make_time_major(
self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw
@@ -159,12 +158,28 @@ def make_time_major(*args, **kw):
prev_action_dist = dist_class(behaviour_logits, self.model)
values = self.model.value_function()
values_time_major = make_time_major(values)
bootstrap_values_time_major = make_time_major(
train_batch[SampleBatch.VALUES_BOOTSTRAPPED]
)
# Add values to bootstrap values to yield correct t=1 to T+1 trajectories,
# with T being the rollout length (max trajectory len).
Member:

I'm finding this comment confusing. Can you annotate what the shapes of values_time_major and bootstrap_values_time_major are supposed to be?

Contributor Author:

I expanded the docstring of ray.rllib.evaluation.postprocessing.compute_bootstrap_value() and added the computation example there (how to add the two columns together to get the final vtrace-usable value estimates). I then linked to this new docstring from each of the four loss functions (IMPALA and APPO, tf and torch). A minimal shape example is sketched below.
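For illustration, a minimal NumPy sketch of the shapes involved and of the concat-and-add described above (T=3 timesteps, B=2 trajectories; all numbers are made up):

```python
import numpy as np

T, B = 3, 2  # rollout length (time-major) and batch size; values are hypothetical.

# V(s_t) for t=0..T-1, shape (T, B).
values_time_major = np.array(
    [[0.1, 0.4],
     [0.2, 0.5],
     [0.3, 0.6]]
)
# VALUES_BOOTSTRAPPED is zero everywhere except at each trajectory's last
# timestep, where it holds the value estimate for the step one past the end.
bootstrap_values_time_major = np.array(
    [[0.0, 0.0],
     [0.0, 0.0],
     [0.7, 0.8]]
)

# Append a zero row to the values, prepend a zero row to the bootstrapped
# values, then add them -> a (T+1, B) tensor whose rows 0..T-1 are V(s_t)
# and whose last row holds the bootstrap values.
values_ext = np.concatenate([values_time_major, np.zeros((1, B))], axis=0)
boot_ext = np.concatenate([np.zeros((1, B)), bootstrap_values_time_major], axis=0)
combined = values_ext + boot_ext  # shape (T+1, B)

assert combined.shape == (T + 1, B)
# combined[:-1] -> the `values` argument for vtrace; combined[-1] -> `bootstrap_value`.
print(combined)
```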

# Note that the `SampleBatch.VALUES_BOOTSTRAPPED` values are always recorded
# ONLY at the last ts of a trajectory (for the following timestep,
# which is one past(!) the last ts). All other values in that tensor are
# zero.
shape = tf.shape(values_time_major)
B = shape[1]
values_time_major = tf.concat([values_time_major, tf.zeros((1, B))], axis=0)
bootstrap_values_time_major = tf.concat(
[tf.zeros((1, B)), bootstrap_values_time_major], axis=0
)
values_time_major += bootstrap_values_time_major

if self.is_recurrent():
max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
mask = tf.reshape(mask, [-1])
mask = make_time_major(mask, drop_last=self.config["vtrace"])
mask = make_time_major(mask)

def reduce_mean_valid(t):
return tf.reduce_mean(tf.boolean_mask(t, mask))
@@ -173,11 +188,7 @@ def reduce_mean_valid(t):
reduce_mean_valid = tf.reduce_mean

if self.config["vtrace"]:
drop_last = self.config["vtrace_drop_last_ts"]
logger.debug(
"Using V-Trace surrogate loss (vtrace=True; "
f"drop_last={drop_last})"
)
logger.debug("Using V-Trace surrogate loss (vtrace=True)")

# Prepare actions for loss.
loss_actions = (
@@ -188,9 +199,7 @@ def reduce_mean_valid(t):
old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

# Prepare KL for Loss
mean_kl = make_time_major(
old_policy_action_dist.multi_kl(action_dist), drop_last=drop_last
)
mean_kl = make_time_major(old_policy_action_dist.multi_kl(action_dist))

unpacked_behaviour_logits = tf.split(
behaviour_logits, output_hidden_shape, axis=1
@@ -203,25 +212,19 @@ def reduce_mean_valid(t):
with tf.device("/cpu:0"):
vtrace_returns = vtrace.multi_from_logits(
behaviour_policy_logits=make_time_major(
unpacked_behaviour_logits, drop_last=drop_last
unpacked_behaviour_logits
),
target_policy_logits=make_time_major(
unpacked_old_policy_behaviour_logits, drop_last=drop_last
),
actions=tf.unstack(
make_time_major(loss_actions, drop_last=drop_last), axis=2
unpacked_old_policy_behaviour_logits
Member:

Do you not need to keep the tf.unstack logic? I guess if the vtrace tests are passing, then no...

),
actions=tf.unstack(make_time_major(loss_actions), axis=2),
discounts=tf.cast(
~make_time_major(
tf.cast(dones, tf.bool), drop_last=drop_last
),
~make_time_major(tf.cast(dones, tf.bool)),
tf.float32,
)
* self.config["gamma"],
rewards=make_time_major(rewards, drop_last=drop_last),
values=values_time_major[:-1]
if drop_last
else values_time_major,
rewards=make_time_major(rewards),
values=values_time_major[:-1],
bootstrap_value=values_time_major[-1],
dist_class=Categorical if is_multidiscrete else dist_class,
model=model,
@@ -233,14 +236,10 @@ def reduce_mean_valid(t):
),
)

actions_logp = make_time_major(
action_dist.logp(actions), drop_last=drop_last
)
prev_actions_logp = make_time_major(
prev_action_dist.logp(actions), drop_last=drop_last
)
actions_logp = make_time_major(action_dist.logp(actions))
prev_actions_logp = make_time_major(prev_action_dist.logp(actions))
old_policy_actions_logp = make_time_major(
old_policy_action_dist.logp(actions), drop_last=drop_last
old_policy_action_dist.logp(actions)
)

is_ratio = tf.clip_by_value(
@@ -267,17 +266,12 @@ def reduce_mean_valid(t):
mean_policy_loss = -reduce_mean_valid(surrogate_loss)

# The value function loss.
if drop_last:
delta = values_time_major[:-1] - vtrace_returns.vs
else:
delta = values_time_major - vtrace_returns.vs
value_targets = vtrace_returns.vs
delta = values_time_major[:-1] - value_targets
mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

# The entropy loss.
actions_entropy = make_time_major(
action_dist.multi_entropy(), drop_last=True
)
actions_entropy = make_time_major(action_dist.multi_entropy())
mean_entropy = reduce_mean_valid(actions_entropy)

else:
@@ -312,7 +306,7 @@ def reduce_mean_valid(t):
value_targets = make_time_major(
train_batch[Postprocessing.VALUE_TARGETS]
)
delta = values_time_major - value_targets
delta = values_time_major[:-1] - value_targets
Member:

Can we rename values_time_major to something like values_time_major_w_bootstrap_value?

mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

# The entropy loss.
@@ -353,7 +347,6 @@ def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
self,
train_batch.get(SampleBatch.SEQ_LENS),
self.model.value_function(),
drop_last=self.config["vtrace"] and self.config["vtrace_drop_last_ts"],
)

stats_dict = {
@@ -388,6 +381,11 @@ def postprocess_trajectory(
other_agent_batches: Optional[SampleBatch] = None,
episode: Optional["Episode"] = None,
):
# Call super's postprocess_trajectory first.
sample_batch = super().postprocess_trajectory(
sample_batch, other_agent_batches, episode
)

if not self.config["vtrace"]:
sample_batch = compute_gae_for_sample_batch(
self, sample_batch, other_agent_batches, episode
68 changes: 32 additions & 36 deletions rllib/algorithms/appo/appo_torch_policy.py
Member:

Same comments as in appo_tf_policy apply here.

@@ -157,14 +157,28 @@ def _make_time_major(*args, **kwargs):
prev_action_dist = dist_class(behaviour_logits, model)
values = model.value_function()
values_time_major = _make_time_major(values)
bootstrap_values_time_major = _make_time_major(
train_batch[SampleBatch.VALUES_BOOTSTRAPPED]
)

drop_last = self.config["vtrace"] and self.config["vtrace_drop_last_ts"]
# Add values to bootstrap values to yield correct t=1 to T+1 trajectories,
# with T being the rollout length (max trajectory len).
# Note that the `SampleBatch.VALUES_BOOTSTRAPPED` values are always recorded
# ONLY at the last ts of a trajectory (for the following timestep,
# which is one past(!) the last ts). All other values in that tensor are
# zero.
_, B = values_time_major.shape
values_time_major = torch.cat([values_time_major, torch.zeros((1, B))], dim=0)
bootstrap_values_time_major = torch.cat(
[torch.zeros((1, B)), bootstrap_values_time_major], dim=0
)
values_time_major += bootstrap_values_time_major

if self.is_recurrent():
max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
mask = torch.reshape(mask, [-1])
mask = _make_time_major(mask, drop_last=drop_last)
mask = _make_time_major(mask)
num_valid = torch.sum(mask)

def reduce_mean_valid(t):
@@ -174,9 +188,7 @@ def reduce_mean_valid(t):
reduce_mean_valid = torch.mean

if self.config["vtrace"]:
logger.debug(
"Using V-Trace surrogate loss (vtrace=True; " f"drop_last={drop_last})"
)
logger.debug("Using V-Trace surrogate loss (vtrace=True)")

old_policy_behaviour_logits = target_model_out.detach()
old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
@@ -202,40 +214,30 @@ def reduce_mean_valid(t):
)

# Prepare KL for loss.
action_kl = _make_time_major(
old_policy_action_dist.kl(action_dist), drop_last=drop_last
)
action_kl = _make_time_major(old_policy_action_dist.kl(action_dist))

# Compute vtrace on the CPU for better perf.
vtrace_returns = vtrace.multi_from_logits(
behaviour_policy_logits=_make_time_major(
unpacked_behaviour_logits, drop_last=drop_last
),
behaviour_policy_logits=_make_time_major(unpacked_behaviour_logits),
target_policy_logits=_make_time_major(
unpacked_old_policy_behaviour_logits, drop_last=drop_last
),
actions=torch.unbind(
_make_time_major(loss_actions, drop_last=drop_last), dim=2
unpacked_old_policy_behaviour_logits
),
discounts=(1.0 - _make_time_major(dones, drop_last=drop_last).float())
actions=torch.unbind(_make_time_major(loss_actions), dim=2),
discounts=(1.0 - _make_time_major(dones).float())
* self.config["gamma"],
rewards=_make_time_major(rewards, drop_last=drop_last),
values=values_time_major[:-1] if drop_last else values_time_major,
rewards=_make_time_major(rewards),
values=values_time_major[:-1],
bootstrap_value=values_time_major[-1],
dist_class=TorchCategorical if is_multidiscrete else dist_class,
model=model,
clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"],
)

actions_logp = _make_time_major(
action_dist.logp(actions), drop_last=drop_last
)
prev_actions_logp = _make_time_major(
prev_action_dist.logp(actions), drop_last=drop_last
)
actions_logp = _make_time_major(action_dist.logp(actions))
prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
old_policy_actions_logp = _make_time_major(
old_policy_action_dist.logp(actions), drop_last=drop_last
old_policy_action_dist.logp(actions)
)
is_ratio = torch.clamp(
torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0
@@ -259,16 +261,11 @@ def reduce_mean_valid(t):

# The value function loss.
value_targets = vtrace_returns.vs.to(values_time_major.device)
if drop_last:
delta = values_time_major[:-1] - value_targets
else:
delta = values_time_major - value_targets
delta = values_time_major[:-1] - value_targets
mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

# The entropy loss.
mean_entropy = reduce_mean_valid(
_make_time_major(action_dist.entropy(), drop_last=drop_last)
)
mean_entropy = reduce_mean_valid(_make_time_major(action_dist.entropy()))

else:
logger.debug("Using PPO surrogate loss (vtrace=False)")
@@ -296,7 +293,7 @@ def reduce_mean_valid(t):

# The value function loss.
value_targets = _make_time_major(train_batch[Postprocessing.VALUE_TARGETS])
delta = values_time_major - value_targets
delta = values_time_major[:-1] - value_targets
mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

# The entropy loss.
@@ -323,9 +320,7 @@ def reduce_mean_valid(t):
model.tower_stats["value_targets"] = value_targets
model.tower_stats["vf_explained_var"] = explained_variance(
torch.reshape(value_targets, [-1]),
torch.reshape(
values_time_major[:-1] if drop_last else values_time_major, [-1]
),
torch.reshape(values_time_major[:-1], [-1]),
)

return total_loss
@@ -402,6 +397,7 @@ def postprocess_trajectory(
sample_batch = compute_gae_for_sample_batch(
self, sample_batch, other_agent_batches, episode
)

return sample_batch

@override(TorchPolicyV2)
38 changes: 30 additions & 8 deletions rllib/algorithms/appo/tf/appo_tf_learner.py
@@ -61,17 +61,39 @@ def compute_loss_for_module(
trajectory_len=hps.rollout_frag_or_episode_len,
recurrent_seq_len=hps.recurrent_seq_len,
)
values_time_major = make_time_major(
values,
rewards_time_major = make_time_major(
batch[SampleBatch.REWARDS],
trajectory_len=hps.rollout_frag_or_episode_len,
recurrent_seq_len=hps.recurrent_seq_len,
)
bootstrap_value = values_time_major[-1]
rewards_time_major = make_time_major(
batch[SampleBatch.REWARDS],
values_time_major = make_time_major(
values,
trajectory_len=hps.rollout_frag_or_episode_len,
recurrent_seq_len=hps.recurrent_seq_len,
)
bootstrap_values_time_major = make_time_major(
batch[SampleBatch.VALUES_BOOTSTRAPPED]
)
# Then add the shifted-by-one bootstrapped values to that to yield the final
# value tensor. Use the last ts in that resulting tensor as the
# "bootstrapped" values for vtrace.
shape = tf.shape(values_time_major)
B = shape[1]
# Augment `values_time_major` by one timestep at the end (all zeros).
values_time_major = tf.concat([values_time_major, tf.zeros((1, B))], axis=0)
# Augment `bootstrap_values_time_major` by one timestep at the beginning
# (all zeros).
bootstrap_values_time_major = tf.concat(
[tf.zeros((1, B)), bootstrap_values_time_major], axis=0
)
# Note that the `SampleBatch.VALUES_BOOTSTRAPPED` values are always recorded
# ONLY at the last ts of a trajectory (for the following timestep,
# which is one past(!) the last ts). All other values in that tensor are
# zero.
# Adding values and bootstrap_values yields the correct values+bootstrap
# configuration, from which we can then take t=-1 (last timestep) to get
# the bootstrap_value arg for the vtrace function below.
values_time_major += bootstrap_values_time_major
Member:

I don't think we need to do this. Below we end up passing in values_time_major[:-1] and values_time_major[-1] separately, which means we never needed to expand values_time_major and could instead have passed in the bootstrap value directly.

Contributor Author:

I do think we still need this, for the following reason:
In APPO and IMPALA, we build/preprocess the train batch along rollout_fragment_len (or max_seq_len if LSTM) boundaries. This means that when a rollout ends within an episode, that rollout's last trajectory ends up in the train batch with a zero-padded right side, and the bootstrapped value for this fragment therefore sits in the middle of the train batch, NOT at time-axis index -1!
So adding these two tensors together covers that particular case as well. Most bootstrapped values will land at the end of the train batch rows, but some (where the rollout ended within an episode) will sit in the middle of the rows. Here, "row" means a trajectory (along the T-axis) within the (B, T, ...) train batch. A small numeric sketch of this case follows below.

Member:

I'm struggling to understand how compute_bootstrap_value handles this situation where the rollout ends within an episode. Reading its code, on the surface I don't see anything that searches the middle of an episode for the reward at the terminated timestep; it looks like we're only checking the last timestep in the sample batch.

Contributor Author:

Note that compute_bootstrap_value is only ever called by a Policy's postprocess_trajectory method, which, by design, is only called once a rollout has ended (either within or at the terminal of an episode).
If at a terminal: assume the value to be 0.0 (no value computation necessary).
If NOT at a terminal: use the Policy's vf to compute the value at the last timestep of the trajectory. This is the "bootstrap" value to be used in the losses (instead of dropping the last ts and using that ts as the "bootstrapped" value). A schematic sketch of this logic follows below.
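A schematic Python sketch of the logic described above. This is not the actual RLlib implementation; the dict keys and the value_fn argument are illustrative stand-ins for the SampleBatch columns and the policy's value function:

```python
import numpy as np

def compute_bootstrap_value_sketch(sample_batch, value_fn):
    """Records a single bootstrap value at the last timestep of a rollout.

    If the rollout ended at a terminal, the bootstrap value is 0.0; otherwise
    the policy's value function is queried for the observation one step past
    the end of the trajectory.
    """
    T = len(sample_batch["rewards"])
    if sample_batch["terminateds"][-1]:
        bootstrap_value = 0.0  # Episode is done: no future return to estimate.
    else:
        # Rollout was cut mid-episode: estimate the return from the next obs.
        bootstrap_value = value_fn(sample_batch["new_obs"][-1])

    # Only the last timestep carries the bootstrap value; all others stay zero.
    values_bootstrapped = np.zeros(T, dtype=np.float32)
    values_bootstrapped[-1] = bootstrap_value
    sample_batch["values_bootstrapped"] = values_bootstrapped
    return sample_batch
```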


# the discount factor that is used should be gamma except for timesteps where
# the episode is terminated. In that case, the discount factor should be 0.
@@ -93,8 +115,8 @@ def compute_loss_for_module(
behaviour_action_log_probs=behaviour_actions_logp_time_major,
discounts=discounts_time_major,
rewards=rewards_time_major,
values=values_time_major,
bootstrap_value=bootstrap_value,
values=values_time_major[:-1],
bootstrap_value=values_time_major[-1],
clip_pg_rho_threshold=hps.vtrace_clip_pg_rho_threshold,
clip_rho_threshold=hps.vtrace_clip_rho_threshold,
)
@@ -127,7 +149,7 @@ def compute_loss_for_module(
mean_pi_loss = -tf.math.reduce_mean(surrogate_loss)

# The baseline loss.
delta = values_time_major - vtrace_adjusted_target_values
delta = values_time_major[:-1] - vtrace_adjusted_target_values
mean_vf_loss = 0.5 * tf.math.reduce_mean(delta**2)

# The entropy loss.