[RLlib] APPO enhancements (new API stack) vol 02: Cleanup loss function, add GAE-lambda to vtrace, make rho-clip configurable. (ray-project#48800)

Signed-off-by: hjiang <dentinyhao@gmail.com>
sven1977 authored and dentiny committed Dec 7, 2024
1 parent 95e5c4d commit 91f5f14
Showing 4 changed files with 94 additions and 47 deletions.
10 changes: 8 additions & 2 deletions rllib/algorithms/appo/appo.py
@@ -32,8 +32,7 @@

LEARNER_RESULTS_KL_KEY = "mean_kl_loss"
LEARNER_RESULTS_CURR_KL_COEFF_KEY = "curr_kl_coeff"
OLD_ACTION_DIST_KEY = "old_action_dist"
OLD_ACTION_DIST_LOGITS_KEY = "old_action_dist_logits"
TARGET_ACTION_DIST_LOGITS_KEY = "target_action_dist_logits"


class APPOConfig(IMPALAConfig):
@@ -108,6 +107,7 @@ def __init__(self, algo_class=None):
self.use_kl_loss = False
self.kl_coeff = 1.0
self.kl_target = 0.01
self.target_worker_clipping = 2.0
# TODO (sven): Activate once v-trace sequences in non-RNN batch are solved.
# If we switch this on right now, the shuffling would destroy the rollout
# sequences (non-zero-padded!) needed in the batch for v-trace.
@@ -163,6 +163,7 @@ def training(
kl_target: Optional[float] = NotProvided,
tau: Optional[float] = NotProvided,
target_network_update_freq: Optional[int] = NotProvided,
target_worker_clipping: Optional[float] = NotProvided,
# Deprecated keys.
target_update_frequency=DEPRECATED_VALUE,
**kwargs,
@@ -193,6 +194,9 @@
on before updating the target networks and tune the kl loss
coefficients. NOTE: This parameter is only applicable when using the
Learner API (enable_rl_module_and_learner=True).
target_worker_clipping: The maximum value ρ at which the worker-to-target
ratio π(i) / π(target) is clipped when computing the IS ratio described
in [1]: IS = min(π(i) / π(target), ρ) * (π / π(i))
Returns:
This updated AlgorithmConfig object.
@@ -227,6 +231,8 @@
self.tau = tau
if target_network_update_freq is not NotProvided:
self.target_network_update_freq = target_network_update_freq
if target_worker_clipping is not NotProvided:
self.target_worker_clipping = target_worker_clipping

return self

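For orientation, the new target_worker_clipping setting is wired through APPOConfig.training() like the other APPO hyperparameters above. A minimal usage sketch, assuming a Ray build that already contains this commit (the environment and the non-default values shown are illustrative only, not recommendations):

from ray.rllib.algorithms.appo import APPOConfig

config = (
    APPOConfig()
    .environment("CartPole-v1")
    .training(
        use_kl_loss=True,
        kl_coeff=1.0,
        kl_target=0.01,
        # rho in IS = min(pi_i / pi_target, rho) * (pi / pi_i); 2.0 is the
        # paper's value and the new config default.
        target_worker_clipping=2.0,
    )
)
algo = config.build()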
4 changes: 2 additions & 2 deletions rllib/algorithms/appo/appo_rl_module.py
@@ -2,7 +2,7 @@
from typing import Any, Dict, List, Tuple

from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule
from ray.rllib.algorithms.appo.appo import OLD_ACTION_DIST_LOGITS_KEY
from ray.rllib.algorithms.appo.appo import TARGET_ACTION_DIST_LOGITS_KEY
from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.models.base import ACTOR
from ray.rllib.core.models.tf.encoder import ENCODER_OUT
@@ -32,7 +32,7 @@ def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
old_pi_inputs_encoded = self._old_encoder(batch)[ENCODER_OUT][ACTOR]
old_action_dist_logits = self._old_pi(old_pi_inputs_encoded)
return {OLD_ACTION_DIST_LOGITS_KEY: old_action_dist_logits}
return {TARGET_ACTION_DIST_LOGITS_KEY: old_action_dist_logits}

@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(PPORLModule)
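The rename above is purely mechanical: the key returned by forward_target() now states what the logits are (the target network's), and the learner reads them back under the same name. A short sketch of the consuming side, assuming module is an APPO RLModule and batch is a train batch:

from ray.rllib.algorithms.appo.appo import TARGET_ACTION_DIST_LOGITS_KEY

# Logits from the (lagging) target policy head, keyed by the renamed constant.
target_logits = module.forward_target(batch)[TARGET_ACTION_DIST_LOGITS_KEY]
target_dist = module.get_train_action_dist_cls().from_logits(target_logits)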
126 changes: 83 additions & 43 deletions rllib/algorithms/appo/torch/appo_torch_learner.py
@@ -1,10 +1,21 @@
"""Asynchronous Proximal Policy Optimization (APPO)
The algorithm is described in [1] (under the name of "IMPACT"):
Detailed documentation:
https://docs.ray.io/en/master/rllib-algorithms.html#appo
[1] IMPACT: Importance Weighted Asynchronous Architectures with Clipped Target Networks.
Luo et al. 2020
https://arxiv.org/pdf/1912.00167
"""
from typing import Dict

from ray.rllib.algorithms.appo.appo import (
APPOConfig,
LEARNER_RESULTS_CURR_KL_COEFF_KEY,
LEARNER_RESULTS_KL_KEY,
OLD_ACTION_DIST_LOGITS_KEY,
TARGET_ACTION_DIST_LOGITS_KEY,
)
from ray.rllib.algorithms.appo.appo_learner import APPOLearner
from ray.rllib.algorithms.impala.torch.impala_torch_learner import IMPALATorchLearner
@@ -60,45 +71,49 @@ def compute_loss_for_module(
)

action_dist_cls_train = module.get_train_action_dist_cls()
target_policy_dist = action_dist_cls_train.from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
)

old_target_policy_dist = action_dist_cls_train.from_logits(
module.forward_target(batch)[OLD_ACTION_DIST_LOGITS_KEY]
)
old_target_policy_actions_logp = old_target_policy_dist.logp(
batch[Columns.ACTIONS]
# Policy being trained (current).
current_action_dist = action_dist_cls_train.from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
)
behaviour_actions_logp = batch[Columns.ACTION_LOGP]
target_actions_logp = target_policy_dist.logp(batch[Columns.ACTIONS])

behaviour_actions_logp_time_major = make_time_major(
behaviour_actions_logp,
current_actions_logp = current_action_dist.logp(batch[Columns.ACTIONS])
current_actions_logp_time_major = make_time_major(
current_actions_logp,
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)

# Target policy.
target_action_dist = action_dist_cls_train.from_logits(
module.forward_target(batch)[TARGET_ACTION_DIST_LOGITS_KEY]
)
target_actions_logp = target_action_dist.logp(batch[Columns.ACTIONS])
target_actions_logp_time_major = make_time_major(
target_actions_logp,
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)
old_actions_logp_time_major = make_time_major(
old_target_policy_actions_logp,

# EnvRunner's policy (behavior).
behavior_actions_logp = batch[Columns.ACTION_LOGP]
behavior_actions_logp_time_major = make_time_major(
behavior_actions_logp,
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)

rewards_time_major = make_time_major(
batch[Columns.REWARDS],
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)

assert Columns.VALUES_BOOTSTRAPPED not in batch
values_time_major = make_time_major(
values,
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)
assert Columns.VALUES_BOOTSTRAPPED not in batch
# Use as bootstrap values the vf-preds in the next "batch row", except
# for the very last row (which doesn't have a next row), for which the
# bootstrap value does not matter b/c it has a +1ts value at its end
@@ -112,61 +127,86 @@ def compute_loss_for_module(
dim=0,
)

# The discount factor that is used should be gamma except for timesteps where
# the episode is terminated. In that case, the discount factor should be 0.
# The discount factor that is used should be `gamma * lambda_`, except for
# termination timesteps, in which case the discount factor should be 0.
discounts_time_major = (
1.0
- make_time_major(
batch[Columns.TERMINATEDS],
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
).float()
) * config.gamma
(
1.0
- make_time_major(
batch[Columns.TERMINATEDS],
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
).float()
# See [1] 3.1: Discounts must contain the GAE lambda_ parameter as well.
)
* config.gamma
* config.lambda_
)

# Note that vtrace will compute the main loop on the CPU for better performance.
vtrace_adjusted_target_values, pg_advantages = vtrace_torch(
target_action_log_probs=old_actions_logp_time_major,
behaviour_action_log_probs=behaviour_actions_logp_time_major,
# See [1] 3.1: For AˆV-GAE, the ratios used are: min(c¯, π(target)/π(i))
# π(target)
target_action_log_probs=target_actions_logp_time_major,
# π(i)
behaviour_action_log_probs=behavior_actions_logp_time_major,
# See [1] 3.1: Discounts must contain the GAE lambda_ parameter as well.
discounts=discounts_time_major,
rewards=rewards_time_major,
values=values_time_major,
bootstrap_values=bootstrap_values,
clip_pg_rho_threshold=config.vtrace_clip_pg_rho_threshold,
# c¯
clip_rho_threshold=config.vtrace_clip_rho_threshold,
# c¯ (but we allow users to distinguish between c¯ used for
# value estimates and c¯ used for the advantages).
clip_pg_rho_threshold=config.vtrace_clip_pg_rho_threshold,
)
pg_advantages = pg_advantages * loss_mask_time_major

# The policy gradients loss.
is_ratio = torch.clip(
torch.exp(behaviour_actions_logp_time_major - old_actions_logp_time_major),
# The policy gradient loss.
# As described in [1], use a logp-ratio of:
# min(π(i) / π(target), ρ) * (π / π(i)), where ..
# - π are the action probs from the current (learner) policy
# - π(i) are the action probs from the ith EnvRunner
# - π(target) are the action probs from the target network
# - ρ is the "target-worker clipping" (2.0 in the paper)
target_worker_is_ratio = torch.clip(
torch.exp(
behavior_actions_logp_time_major - target_actions_logp_time_major
),
0.0,
2.0,
config.target_worker_clipping,
)
logp_ratio = is_ratio * torch.exp(
target_actions_logp_time_major - behaviour_actions_logp_time_major
target_worker_logp_ratio = target_worker_is_ratio * torch.exp(
current_actions_logp_time_major - behavior_actions_logp_time_major
)

surrogate_loss = torch.minimum(
pg_advantages * logp_ratio,
pg_advantages * target_worker_logp_ratio,
pg_advantages
* torch.clip(logp_ratio, 1 - config.clip_param, 1 + config.clip_param),
* torch.clip(
target_worker_logp_ratio,
1 - config.clip_param,
1 + config.clip_param,
),
)
mean_pi_loss = -(torch.sum(surrogate_loss) / size_loss_mask)

# Compute KL-loss (if required): KL divergence between current action dist.
# and target action dist.
if config.use_kl_loss:
action_kl = old_target_policy_dist.kl(target_policy_dist) * loss_mask
action_kl = target_action_dist.kl(current_action_dist) * loss_mask
mean_kl_loss = torch.sum(action_kl) / size_loss_mask
else:
mean_kl_loss = 0.0
mean_pi_loss = -(torch.sum(surrogate_loss) / size_loss_mask)

# The baseline loss.
# Compute value function loss.
delta = values_time_major - vtrace_adjusted_target_values
vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0) * loss_mask_time_major)
mean_vf_loss = vf_loss / size_loss_mask

# The entropy loss.
# Compute entropy loss.
mean_entropy_loss = (
-torch.sum(target_policy_dist.entropy() * loss_mask) / size_loss_mask
-torch.sum(current_action_dist.entropy() * loss_mask) / size_loss_mask
)

# The summed weighted loss.
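Condensed, the learner changes amount to two things: the v-trace discounts now fold in the GAE lambda ((1 - terminated) * gamma * lambda_), and the policy-gradient ratio follows the IMPACT formulation with a configurable rho. Below is a self-contained sketch of just the ratio-and-surrogate step, with illustrative names and defaults rather than RLlib's actual signatures:

import torch

def impact_surrogate_sketch(
    current_logp,      # log pi(a|s) of the policy being trained
    behavior_logp,     # log pi_i(a|s) of the EnvRunner that collected the data
    target_logp,       # log pi_target(a|s) of the lagging target network
    pg_advantages,     # v-trace advantages (already time-major and masked)
    clip_param=0.4,                # illustrative
    target_worker_clipping=2.0,    # rho, 2.0 in the paper
):
    # min(pi_i / pi_target, rho): clipped behavior-vs-target IS ratio.
    target_worker_is_ratio = torch.clip(
        torch.exp(behavior_logp - target_logp), 0.0, target_worker_clipping
    )
    # Multiply by pi / pi_i to obtain the overall ratio from [1].
    logp_ratio = target_worker_is_ratio * torch.exp(current_logp - behavior_logp)
    # PPO-style clipped surrogate on that ratio; the actual learner negates and
    # divides by the loss-mask size instead of taking a plain mean.
    surrogate = torch.minimum(
        pg_advantages * logp_ratio,
        pg_advantages * torch.clip(logp_ratio, 1.0 - clip_param, 1.0 + clip_param),
    )
    return -surrogate.mean()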
1 change: 1 addition & 0 deletions rllib/algorithms/impala/vtrace_torch.py
@@ -228,6 +228,7 @@ def multi_from_logits(
behaviour_action_log_probs, device="cpu"
)
behaviour_action_log_probs = force_list(behaviour_action_log_probs)
# log_rhos = target_logp - behavior_logp
log_rhos = get_log_rhos(target_action_log_probs, behaviour_action_log_probs)

vtrace_returns = from_importance_weights(
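The one-line comment added to vtrace_torch.py just spells out what get_log_rhos returns. Numerically, for a couple of timesteps (toy values, illustrative only):

import torch

target_logp = torch.tensor([-0.9, -1.2])    # log pi_target(a_t | s_t)
behavior_logp = torch.tensor([-1.0, -1.0])  # log pi_behavior(a_t | s_t)

log_rhos = target_logp - behavior_logp      # log importance weights
rhos = torch.exp(log_rhos)                  # rho_t = pi_target / pi_behavior
# Inside v-trace these get clipped at clip_rho_threshold (rho-bar), e.g. 1.0:
clipped_rhos = torch.clamp(rhos, max=1.0)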
