diff --git a/rllib/BUILD b/rllib/BUILD index 32f0259bad2c4..feab865a38dcc 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -301,6 +301,16 @@ py_test( args = ["--dir=tuned_examples/appo"] ) +py_test( + name = "learning_tests_stateless_cartpole_appo_vtrace", + main = "tests/run_regression_tests.py", + tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"], + size = "large", + srcs = ["tests/run_regression_tests.py"], + data = ["tuned_examples/appo/stateless-cartpole-appo-vtrace.py"], + args = ["--dir=tuned_examples/appo"] +) + # ARS py_test( name = "learning_tests_cartpole_ars", @@ -3710,7 +3720,7 @@ py_test( ) # Taking out this test for now: Mixed torch- and tf- policies within the same -# Algorothm never really worked. +# Algorithm never really worked. # py_test( # name = "examples/multi_agent_two_trainers_mixed_torch_tf", # main = "examples/multi_agent_two_trainers.py", diff --git a/rllib/algorithms/appo/appo_tf_policy.py b/rllib/algorithms/appo/appo_tf_policy.py index 8441f8032ede8..ec4dd78c295e3 100644 --- a/rllib/algorithms/appo/appo_tf_policy.py +++ b/rllib/algorithms/appo/appo_tf_policy.py @@ -19,6 +19,7 @@ ) from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.postprocessing import ( + compute_bootstrap_value, compute_gae_for_sample_batch, Postprocessing, ) @@ -144,7 +145,6 @@ def loss( is_multidiscrete = False output_hidden_shape = 1 - # TODO: (sven) deprecate this when trajectory view API gets activated. def make_time_major(*args, **kw): return _make_time_major( self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw @@ -159,12 +159,16 @@ def make_time_major(*args, **kw): prev_action_dist = dist_class(behaviour_logits, self.model) values = self.model.value_function() values_time_major = make_time_major(values) + bootstrap_values_time_major = make_time_major( + train_batch[SampleBatch.VALUES_BOOTSTRAPPED] + ) + bootstrap_value = bootstrap_values_time_major[-1] if self.is_recurrent(): max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS]) mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) mask = tf.reshape(mask, [-1]) - mask = make_time_major(mask, drop_last=self.config["vtrace"]) + mask = make_time_major(mask) def reduce_mean_valid(t): return tf.reduce_mean(tf.boolean_mask(t, mask)) @@ -173,11 +177,7 @@ def reduce_mean_valid(t): reduce_mean_valid = tf.reduce_mean if self.config["vtrace"]: - drop_last = self.config["vtrace_drop_last_ts"] - logger.debug( - "Using V-Trace surrogate loss (vtrace=True; " - f"drop_last={drop_last})" - ) + logger.debug("Using V-Trace surrogate loss (vtrace=True)") # Prepare actions for loss. 
loss_actions = ( @@ -188,9 +188,7 @@ def reduce_mean_valid(t): old_policy_action_dist = dist_class(old_policy_behaviour_logits, model) # Prepare KL for Loss - mean_kl = make_time_major( - old_policy_action_dist.multi_kl(action_dist), drop_last=drop_last - ) + mean_kl = make_time_major(old_policy_action_dist.multi_kl(action_dist)) unpacked_behaviour_logits = tf.split( behaviour_logits, output_hidden_shape, axis=1 @@ -203,26 +201,20 @@ def reduce_mean_valid(t): with tf.device("/cpu:0"): vtrace_returns = vtrace.multi_from_logits( behaviour_policy_logits=make_time_major( - unpacked_behaviour_logits, drop_last=drop_last + unpacked_behaviour_logits ), target_policy_logits=make_time_major( - unpacked_old_policy_behaviour_logits, drop_last=drop_last - ), - actions=tf.unstack( - make_time_major(loss_actions, drop_last=drop_last), axis=2 + unpacked_old_policy_behaviour_logits ), + actions=tf.unstack(make_time_major(loss_actions), axis=2), discounts=tf.cast( - ~make_time_major( - tf.cast(dones, tf.bool), drop_last=drop_last - ), + ~make_time_major(tf.cast(dones, tf.bool)), tf.float32, ) * self.config["gamma"], - rewards=make_time_major(rewards, drop_last=drop_last), - values=values_time_major[:-1] - if drop_last - else values_time_major, - bootstrap_value=values_time_major[-1], + rewards=make_time_major(rewards), + values=values_time_major, + bootstrap_value=bootstrap_value, dist_class=Categorical if is_multidiscrete else dist_class, model=model, clip_rho_threshold=tf.cast( @@ -233,14 +225,10 @@ def reduce_mean_valid(t): ), ) - actions_logp = make_time_major( - action_dist.logp(actions), drop_last=drop_last - ) - prev_actions_logp = make_time_major( - prev_action_dist.logp(actions), drop_last=drop_last - ) + actions_logp = make_time_major(action_dist.logp(actions)) + prev_actions_logp = make_time_major(prev_action_dist.logp(actions)) old_policy_actions_logp = make_time_major( - old_policy_action_dist.logp(actions), drop_last=drop_last + old_policy_action_dist.logp(actions) ) is_ratio = tf.clip_by_value( @@ -267,17 +255,12 @@ def reduce_mean_valid(t): mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. - if drop_last: - delta = values_time_major[:-1] - vtrace_returns.vs - else: - delta = values_time_major - vtrace_returns.vs value_targets = vtrace_returns.vs + delta = values_time_major - value_targets mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta)) # The entropy loss. - actions_entropy = make_time_major( - action_dist.multi_entropy(), drop_last=True - ) + actions_entropy = make_time_major(action_dist.multi_entropy()) mean_entropy = reduce_mean_valid(actions_entropy) else: @@ -353,7 +336,6 @@ def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: self, train_batch.get(SampleBatch.SEQ_LENS), self.model.value_function(), - drop_last=self.config["vtrace"] and self.config["vtrace_drop_last_ts"], ) stats_dict = { @@ -388,20 +370,22 @@ def postprocess_trajectory( other_agent_batches: Optional[SampleBatch] = None, episode: Optional["Episode"] = None, ): + # Call super's postprocess_trajectory first. + # sample_batch = super().postprocess_trajectory( + # sample_batch, other_agent_batches, episode + # ) + if not self.config["vtrace"]: sample_batch = compute_gae_for_sample_batch( self, sample_batch, other_agent_batches, episode ) + else: + # Add the SampleBatch.VALUES_BOOTSTRAPPED column, which we'll need + # inside the loss for vtrace calculations. 
+ sample_batch = compute_bootstrap_value(sample_batch, self) return sample_batch - @override(base) - def extra_action_out_fn(self) -> Dict[str, TensorType]: - extra_action_fetches = super().extra_action_out_fn() - if not self.config["vtrace"]: - extra_action_fetches[SampleBatch.VF_PREDS] = self.model.value_function() - return extra_action_fetches - @override(base) def get_batch_divisibility_req(self) -> int: return self.config["rollout_fragment_length"] diff --git a/rllib/algorithms/appo/appo_torch_policy.py b/rllib/algorithms/appo/appo_torch_policy.py index 4a7754830f321..a97446c44281c 100644 --- a/rllib/algorithms/appo/appo_torch_policy.py +++ b/rllib/algorithms/appo/appo_torch_policy.py @@ -19,6 +19,7 @@ ) from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.postprocessing import ( + compute_bootstrap_value, compute_gae_for_sample_batch, Postprocessing, ) @@ -157,14 +158,16 @@ def _make_time_major(*args, **kwargs): prev_action_dist = dist_class(behaviour_logits, model) values = model.value_function() values_time_major = _make_time_major(values) - - drop_last = self.config["vtrace"] and self.config["vtrace_drop_last_ts"] + bootstrap_values_time_major = _make_time_major( + train_batch[SampleBatch.VALUES_BOOTSTRAPPED] + ) + bootstrap_value = bootstrap_values_time_major[-1] if self.is_recurrent(): max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS]) mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) mask = torch.reshape(mask, [-1]) - mask = _make_time_major(mask, drop_last=drop_last) + mask = _make_time_major(mask) num_valid = torch.sum(mask) def reduce_mean_valid(t): @@ -174,9 +177,7 @@ def reduce_mean_valid(t): reduce_mean_valid = torch.mean if self.config["vtrace"]: - logger.debug( - "Using V-Trace surrogate loss (vtrace=True; " f"drop_last={drop_last})" - ) + logger.debug("Using V-Trace surrogate loss (vtrace=True)") old_policy_behaviour_logits = target_model_out.detach() old_policy_action_dist = dist_class(old_policy_behaviour_logits, model) @@ -202,40 +203,30 @@ def reduce_mean_valid(t): ) # Prepare KL for loss. - action_kl = _make_time_major( - old_policy_action_dist.kl(action_dist), drop_last=drop_last - ) + action_kl = _make_time_major(old_policy_action_dist.kl(action_dist)) # Compute vtrace on the CPU for better perf. 
vtrace_returns = vtrace.multi_from_logits( - behaviour_policy_logits=_make_time_major( - unpacked_behaviour_logits, drop_last=drop_last - ), + behaviour_policy_logits=_make_time_major(unpacked_behaviour_logits), target_policy_logits=_make_time_major( - unpacked_old_policy_behaviour_logits, drop_last=drop_last + unpacked_old_policy_behaviour_logits ), - actions=torch.unbind( - _make_time_major(loss_actions, drop_last=drop_last), dim=2 - ), - discounts=(1.0 - _make_time_major(dones, drop_last=drop_last).float()) + actions=torch.unbind(_make_time_major(loss_actions), dim=2), + discounts=(1.0 - _make_time_major(dones).float()) * self.config["gamma"], - rewards=_make_time_major(rewards, drop_last=drop_last), - values=values_time_major[:-1] if drop_last else values_time_major, - bootstrap_value=values_time_major[-1], + rewards=_make_time_major(rewards), + values=values_time_major, + bootstrap_value=bootstrap_value, dist_class=TorchCategorical if is_multidiscrete else dist_class, model=model, clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"], ) - actions_logp = _make_time_major( - action_dist.logp(actions), drop_last=drop_last - ) - prev_actions_logp = _make_time_major( - prev_action_dist.logp(actions), drop_last=drop_last - ) + actions_logp = _make_time_major(action_dist.logp(actions)) + prev_actions_logp = _make_time_major(prev_action_dist.logp(actions)) old_policy_actions_logp = _make_time_major( - old_policy_action_dist.logp(actions), drop_last=drop_last + old_policy_action_dist.logp(actions) ) is_ratio = torch.clamp( torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0 @@ -259,16 +250,11 @@ def reduce_mean_valid(t): # The value function loss. value_targets = vtrace_returns.vs.to(values_time_major.device) - if drop_last: - delta = values_time_major[:-1] - value_targets - else: - delta = values_time_major - value_targets + delta = values_time_major - value_targets mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0)) # The entropy loss. - mean_entropy = reduce_mean_valid( - _make_time_major(action_dist.entropy(), drop_last=drop_last) - ) + mean_entropy = reduce_mean_valid(_make_time_major(action_dist.entropy())) else: logger.debug("Using PPO surrogate loss (vtrace=False)") @@ -323,9 +309,7 @@ def reduce_mean_valid(t): model.tower_stats["value_targets"] = value_targets model.tower_stats["vf_explained_var"] = explained_variance( torch.reshape(value_targets, [-1]), - torch.reshape( - values_time_major[:-1] if drop_last else values_time_major, [-1] - ), + torch.reshape(values_time_major, [-1]), ) return total_loss @@ -378,10 +362,7 @@ def extra_action_out( model: TorchModelV2, action_dist: TorchDistributionWrapper, ) -> Dict[str, TensorType]: - out = {} - if not self.config["vtrace"]: - out[SampleBatch.VF_PREDS] = model.value_function() - return out + return {SampleBatch.VF_PREDS: model.value_function()} @override(TorchPolicyV2) def postprocess_trajectory( @@ -391,17 +372,23 @@ def postprocess_trajectory( episode: Optional["Episode"] = None, ): # Call super's postprocess_trajectory first. - sample_batch = super().postprocess_trajectory( - sample_batch, other_agent_batches, episode - ) - if not self.config["vtrace"]: - # Do all post-processing always with no_grad(). - # Not using this here will introduce a memory leak - # in torch (issue #6962). 
- with torch.no_grad(): + # sample_batch = super().postprocess_trajectory( + # sample_batch, other_agent_batches, episode + # ) + + # Do all post-processing always with no_grad(). + # Not using this here will introduce a memory leak + # in torch (issue #6962). + with torch.no_grad(): + if not self.config["vtrace"]: sample_batch = compute_gae_for_sample_batch( self, sample_batch, other_agent_batches, episode ) + else: + # Add the SampleBatch.VALUES_BOOTSTRAPPED column, which we'll need + # inside the loss for vtrace calculations. + sample_batch = compute_bootstrap_value(sample_batch, self) + return sample_batch @override(TorchPolicyV2) diff --git a/rllib/algorithms/appo/tests/test_appo_learner.py b/rllib/algorithms/appo/tests/test_appo_learner.py index 19bcacdc89bec..6a018ca403447 100644 --- a/rllib/algorithms/appo/tests/test_appo_learner.py +++ b/rllib/algorithms/appo/tests/test_appo_learner.py @@ -36,6 +36,9 @@ SampleBatch.VF_PREDS: np.array( list(reversed(range(frag_length))), dtype=np.float32 ), + SampleBatch.VALUES_BOOTSTRAPPED: np.array( + list(reversed(range(frag_length))), dtype=np.float32 + ), SampleBatch.ACTION_LOGP: np.log( np.random.uniform(low=0, high=1, size=(frag_length,)) ).astype(np.float32), diff --git a/rllib/algorithms/appo/tf/appo_tf_learner.py b/rllib/algorithms/appo/tf/appo_tf_learner.py index 84fc8f4714cc5..3c949d1a090fa 100644 --- a/rllib/algorithms/appo/tf/appo_tf_learner.py +++ b/rllib/algorithms/appo/tf/appo_tf_learner.py @@ -61,19 +61,24 @@ def compute_loss_for_module( trajectory_len=hps.rollout_frag_or_episode_len, recurrent_seq_len=hps.recurrent_seq_len, ) + rewards_time_major = make_time_major( + batch[SampleBatch.REWARDS], + trajectory_len=hps.rollout_frag_or_episode_len, + recurrent_seq_len=hps.recurrent_seq_len, + ) values_time_major = make_time_major( values, trajectory_len=hps.rollout_frag_or_episode_len, recurrent_seq_len=hps.recurrent_seq_len, ) - bootstrap_value = values_time_major[-1] - rewards_time_major = make_time_major( - batch[SampleBatch.REWARDS], + bootstrap_values_time_major = make_time_major( + batch[SampleBatch.VALUES_BOOTSTRAPPED], trajectory_len=hps.rollout_frag_or_episode_len, recurrent_seq_len=hps.recurrent_seq_len, ) + bootstrap_value = bootstrap_values_time_major[-1] - # the discount factor that is used should be gamma except for timesteps where + # The discount factor that is used should be gamma except for timesteps where # the episode is terminated. In that case, the discount factor should be 0. 
discounts_time_major = ( 1.0 diff --git a/rllib/algorithms/appo/torch/appo_torch_learner.py b/rllib/algorithms/appo/torch/appo_torch_learner.py index e1f170a092dcc..91b86a9892648 100644 --- a/rllib/algorithms/appo/torch/appo_torch_learner.py +++ b/rllib/algorithms/appo/torch/appo_torch_learner.py @@ -74,19 +74,24 @@ def compute_loss_for_module( trajectory_len=hps.rollout_frag_or_episode_len, recurrent_seq_len=hps.recurrent_seq_len, ) + rewards_time_major = make_time_major( + batch[SampleBatch.REWARDS], + trajectory_len=hps.rollout_frag_or_episode_len, + recurrent_seq_len=hps.recurrent_seq_len, + ) values_time_major = make_time_major( values, trajectory_len=hps.rollout_frag_or_episode_len, recurrent_seq_len=hps.recurrent_seq_len, ) - bootstrap_value = values_time_major[-1] - rewards_time_major = make_time_major( - batch[SampleBatch.REWARDS], + bootstrap_values_time_major = make_time_major( + batch[SampleBatch.VALUES_BOOTSTRAPPED], trajectory_len=hps.rollout_frag_or_episode_len, recurrent_seq_len=hps.recurrent_seq_len, ) + bootstrap_value = bootstrap_values_time_major[-1] - # the discount factor that is used should be gamma except for timesteps where + # The discount factor that is used should be gamma except for timesteps where # the episode is terminated. In that case, the discount factor should be 0. discounts_time_major = ( 1.0 diff --git a/rllib/algorithms/impala/impala.py b/rllib/algorithms/impala/impala.py index 9bd59b19dbd92..db4ed7e0d5690 100644 --- a/rllib/algorithms/impala/impala.py +++ b/rllib/algorithms/impala/impala.py @@ -112,11 +112,6 @@ def __init__(self, algo_class=None): self.vtrace = True self.vtrace_clip_rho_threshold = 1.0 self.vtrace_clip_pg_rho_threshold = 1.0 - # TODO (sven): Deprecate this setting. It makes no sense to drop the last ts. - # It's actually dangerous if there are important rewards "hiding" in that ts. - # This setting is already ignored (always False) on the new Learner API - # (if _enable_learner_api=True). - self.vtrace_drop_last_ts = False self.num_multi_gpu_tower_stacks = 1 self.minibatch_buffer_size = 1 self.num_sgd_iter = 1 @@ -171,6 +166,7 @@ def __init__(self, algo_class=None): # Deprecated value. self.num_data_loader_buffers = DEPRECATED_VALUE + self.vtrace_drop_last_ts = DEPRECATED_VALUE @override(AlgorithmConfig) def training( @@ -179,7 +175,6 @@ def training( vtrace: Optional[bool] = NotProvided, vtrace_clip_rho_threshold: Optional[float] = NotProvided, vtrace_clip_pg_rho_threshold: Optional[float] = NotProvided, - vtrace_drop_last_ts: Optional[bool] = NotProvided, gamma: Optional[float] = NotProvided, num_multi_gpu_tower_stacks: Optional[int] = NotProvided, minibatch_buffer_size: Optional[int] = NotProvided, @@ -206,6 +201,8 @@ def training( _separate_vf_optimizer: Optional[bool] = NotProvided, _lr_vf: Optional[float] = NotProvided, after_train_step: Optional[Callable[[dict], None]] = NotProvided, + # deprecated. + vtrace_drop_last_ts=None, **kwargs, ) -> "ImpalaConfig": """Sets the training related configuration. @@ -214,13 +211,6 @@ def training( vtrace: V-trace params (see vtrace_tf/torch.py). vtrace_clip_rho_threshold: vtrace_clip_pg_rho_threshold: - vtrace_drop_last_ts: If True, drop the last timestep for the vtrace - calculations, such that all data goes into the calculations as [B x T-1] - (+ the bootstrap value). This is the default and legacy RLlib behavior, - however, could potentially have a destabilizing effect on learning, - especially in sparse reward or reward-at-goal environments. - False for not dropping the last timestep. 
- System params. gamma: Float specifying the discount factor of the Markov Decision process. num_multi_gpu_tower_stacks: For each stack of multi-GPU towers, how many slots should we reserve for parallel data loading? Set this to >1 to @@ -293,6 +283,16 @@ def training( Returns: This updated AlgorithmConfig object. """ + if vtrace_drop_last_ts is not None: + deprecation_warning( + old="vtrace_drop_last_ts", + help="The v-trace operations in RLlib have been enhanced and we are " + "now using proper value bootstrapping at the end of each " + "trajectory, such that no timesteps in our loss functions have to " + "be dropped anymore.", + error=True, + ) + # Pass kwargs onto super's `training()` method. super().training(**kwargs) @@ -302,8 +302,6 @@ def training( self.vtrace_clip_rho_threshold = vtrace_clip_rho_threshold if vtrace_clip_pg_rho_threshold is not NotProvided: self.vtrace_clip_pg_rho_threshold = vtrace_clip_pg_rho_threshold - if vtrace_drop_last_ts is not NotProvided: - self.vtrace_drop_last_ts = vtrace_drop_last_ts if num_multi_gpu_tower_stacks is not NotProvided: self.num_multi_gpu_tower_stacks = num_multi_gpu_tower_stacks if minibatch_buffer_size is not NotProvided: diff --git a/rllib/algorithms/impala/impala_tf_policy.py b/rllib/algorithms/impala/impala_tf_policy.py index d8b830ef7653b..5e2700cb3639a 100644 --- a/rllib/algorithms/impala/impala_tf_policy.py +++ b/rllib/algorithms/impala/impala_tf_policy.py @@ -5,9 +5,11 @@ import numpy as np import logging import gymnasium as gym -from typing import Dict, List, Type, Union +from typing import Dict, List, Optional, Type, Union from ray.rllib.algorithms.impala import vtrace_tf as vtrace +from ray.rllib.evaluation.episode import Episode +from ray.rllib.evaluation.postprocessing import compute_bootstrap_value from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import Categorical, TFActionDistribution from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 @@ -18,7 +20,7 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_utils import explained_variance -from ray.rllib.policy.tf_mixins import GradStatsMixin +from ray.rllib.policy.tf_mixins import GradStatsMixin, ValueNetworkMixin from ray.rllib.utils.typing import ( LocalOptimizer, ModelGradients, @@ -131,14 +133,13 @@ def __init__( self.total_loss += self.vf_loss * vf_loss_coeff -def _make_time_major(policy, seq_lens, tensor, drop_last=False): +def _make_time_major(policy, seq_lens, tensor): """Swaps batch and trajectory axis. Args: policy: Policy reference seq_lens: Sequence lengths if recurrent or None tensor: A tensor or list of tensors to reshape. - drop_last: A bool indicating whether to drop the last trajectory item. Returns: @@ -146,7 +147,7 @@ def _make_time_major(policy, seq_lens, tensor, drop_last=False): swapped axes. 
""" if isinstance(tensor, list): - return [_make_time_major(policy, seq_lens, t, drop_last) for t in tensor] + return [_make_time_major(policy, seq_lens, t) for t in tensor] if policy.is_recurrent(): B = tf.shape(seq_lens)[0] @@ -163,8 +164,6 @@ def _make_time_major(policy, seq_lens, tensor, drop_last=False): # swap B and T axes res = tf.transpose(rs, [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0])))) - if drop_last: - return res[:-1] return res @@ -274,6 +273,7 @@ class ImpalaTFPolicy( LearningRateSchedule, EntropyCoeffSchedule, GradStatsMixin, + ValueNetworkMixin, base, ): def __init__( @@ -296,6 +296,7 @@ def __init__( existing_inputs=existing_inputs, existing_model=existing_model, ) + ValueNetworkMixin.__init__(self, config) # If Learner API is used, we don't need any loss-specific mixins. # However, we also would like to avoid creating special Policy-subclasses @@ -349,6 +350,11 @@ def make_time_major(*args, **kw): ) unpacked_outputs = tf.split(model_out, output_hidden_shape, axis=1) values = model.value_function() + values_time_major = make_time_major(values) + bootstrap_values_time_major = make_time_major( + train_batch[SampleBatch.VALUES_BOOTSTRAPPED] + ) + bootstrap_value = bootstrap_values_time_major[-1] if self.is_recurrent(): max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS]) @@ -363,30 +369,21 @@ def make_time_major(*args, **kw): ) # Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc. - drop_last = self.config["vtrace_drop_last_ts"] self.vtrace_loss = VTraceLoss( - actions=make_time_major(loss_actions, drop_last=drop_last), - actions_logp=make_time_major( - action_dist.logp(actions), drop_last=drop_last - ), - actions_entropy=make_time_major( - action_dist.multi_entropy(), drop_last=drop_last - ), - dones=make_time_major(dones, drop_last=drop_last), - behaviour_action_logp=make_time_major( - behaviour_action_logp, drop_last=drop_last - ), - behaviour_logits=make_time_major( - unpacked_behaviour_logits, drop_last=drop_last - ), - target_logits=make_time_major(unpacked_outputs, drop_last=drop_last), + actions=make_time_major(loss_actions), + actions_logp=make_time_major(action_dist.logp(actions)), + actions_entropy=make_time_major(action_dist.multi_entropy()), + dones=make_time_major(dones), + behaviour_action_logp=make_time_major(behaviour_action_logp), + behaviour_logits=make_time_major(unpacked_behaviour_logits), + target_logits=make_time_major(unpacked_outputs), discount=self.config["gamma"], - rewards=make_time_major(rewards, drop_last=drop_last), - values=make_time_major(values, drop_last=drop_last), - bootstrap_value=make_time_major(values)[-1], + rewards=make_time_major(rewards), + values=values_time_major, + bootstrap_value=bootstrap_value, dist_class=Categorical if is_multidiscrete else dist_class, model=model, - valid_mask=make_time_major(mask, drop_last=drop_last), + valid_mask=make_time_major(mask), config=self.config, vf_loss_coeff=self.config["vf_loss_coeff"], entropy_coeff=self.entropy_coeff, @@ -401,12 +398,10 @@ def make_time_major(*args, **kw): @override(base) def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: - drop_last = self.config["vtrace"] and self.config["vtrace_drop_last_ts"] values_batched = _make_time_major( self, train_batch.get(SampleBatch.SEQ_LENS), self.model.value_function(), - drop_last=drop_last, ) return { @@ -422,6 +417,25 @@ def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: ), } + @override(base) + def postprocess_trajectory( + self, + sample_batch: SampleBatch, + 
other_agent_batches: Optional[SampleBatch] = None, + episode: Optional["Episode"] = None, + ): + # Call super's postprocess_trajectory first. + # sample_batch = super().postprocess_trajectory( + # sample_batch, other_agent_batches, episode + # ) + + if self.config["vtrace"]: + # Add the SampleBatch.VALUES_BOOTSTRAPPED column, which we'll need + # inside the loss for vtrace calculations. + sample_batch = compute_bootstrap_value(sample_batch, self) + + return sample_batch + @override(base) def get_batch_divisibility_req(self) -> int: return self.config["rollout_fragment_length"] diff --git a/rllib/algorithms/impala/impala_torch_policy.py b/rllib/algorithms/impala/impala_torch_policy.py index 71aed03206015..80b39f64f065c 100644 --- a/rllib/algorithms/impala/impala_torch_policy.py +++ b/rllib/algorithms/impala/impala_torch_policy.py @@ -1,9 +1,11 @@ import gymnasium as gym import logging import numpy as np -from typing import Dict, List, Type, Union +from typing import Dict, List, Optional, Type, Union import ray +from ray.rllib.evaluation.episode import Episode +from ray.rllib.evaluation.postprocessing import compute_bootstrap_value from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.torch.torch_action_dist import TorchCategorical @@ -11,6 +13,7 @@ from ray.rllib.policy.torch_mixins import ( EntropyCoeffSchedule, LearningRateSchedule, + ValueNetworkMixin, ) from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 from ray.rllib.utils.annotations import override @@ -126,22 +129,20 @@ def __init__( ) -def make_time_major(policy, seq_lens, tensor, drop_last=False): +def make_time_major(policy, seq_lens, tensor): """Swaps batch and trajectory axis. Args: policy: Policy reference seq_lens: Sequence lengths if recurrent or None tensor: A tensor or list of tensors to reshape. - drop_last: A bool indicating whether to drop the last - trajectory item. Returns: res: A tensor with swapped axes or a list of tensors with swapped axes. """ if isinstance(tensor, (list, tuple)): - return [make_time_major(policy, seq_lens, t, drop_last) for t in tensor] + return [make_time_major(policy, seq_lens, t) for t in tensor] if policy.is_recurrent(): B = seq_lens.shape[0] @@ -158,8 +159,6 @@ def make_time_major(policy, seq_lens, tensor, drop_last=False): # Swap B and T axes. res = torch.transpose(rs, 1, 0) - if drop_last: - return res[:-1] return res @@ -192,6 +191,7 @@ class ImpalaTorchPolicy( VTraceOptimizer, LearningRateSchedule, EntropyCoeffSchedule, + ValueNetworkMixin, TorchPolicyV2, ): """PyTorch policy class used with Impala.""" @@ -222,6 +222,8 @@ def __init__(self, observation_space, action_space, config): max_seq_len=config["model"]["max_seq_len"], ) + ValueNetworkMixin.__init__(self, config) + self._initialize_loss_from_dummy_batch() @override(TorchPolicyV2) @@ -265,6 +267,11 @@ def _make_time_major(*args, **kw): ) unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1) values = model.value_function() + values_time_major = _make_time_major(values) + bootstrap_values_time_major = _make_time_major( + train_batch[SampleBatch.VALUES_BOOTSTRAPPED] + ) + bootstrap_value = bootstrap_values_time_major[-1] if self.is_recurrent(): max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS]) @@ -277,30 +284,21 @@ def _make_time_major(*args, **kw): loss_actions = actions if is_multidiscrete else torch.unsqueeze(actions, dim=1) # Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc. 
- drop_last = self.config["vtrace_drop_last_ts"] loss = VTraceLoss( - actions=_make_time_major(loss_actions, drop_last=drop_last), - actions_logp=_make_time_major( - action_dist.logp(actions), drop_last=drop_last - ), - actions_entropy=_make_time_major( - action_dist.entropy(), drop_last=drop_last - ), - dones=_make_time_major(dones, drop_last=drop_last), - behaviour_action_logp=_make_time_major( - behaviour_action_logp, drop_last=drop_last - ), - behaviour_logits=_make_time_major( - unpacked_behaviour_logits, drop_last=drop_last - ), - target_logits=_make_time_major(unpacked_outputs, drop_last=drop_last), + actions=_make_time_major(loss_actions), + actions_logp=_make_time_major(action_dist.logp(actions)), + actions_entropy=_make_time_major(action_dist.entropy()), + dones=_make_time_major(dones), + behaviour_action_logp=_make_time_major(behaviour_action_logp), + behaviour_logits=_make_time_major(unpacked_behaviour_logits), + target_logits=_make_time_major(unpacked_outputs), discount=self.config["gamma"], - rewards=_make_time_major(rewards, drop_last=drop_last), - values=_make_time_major(values, drop_last=drop_last), - bootstrap_value=_make_time_major(values)[-1], + rewards=_make_time_major(rewards), + values=values_time_major, + bootstrap_value=bootstrap_value, dist_class=TorchCategorical if is_multidiscrete else dist_class, model=model, - valid_mask=_make_time_major(mask, drop_last=drop_last), + valid_mask=_make_time_major(mask), config=self.config, vf_loss_coeff=self.config["vf_loss_coeff"], entropy_coeff=self.entropy_coeff, @@ -320,7 +318,6 @@ def _make_time_major(*args, **kw): self, train_batch.get(SampleBatch.SEQ_LENS), values, - drop_last=self.config["vtrace"] and drop_last, ) model.tower_stats["vf_explained_var"] = explained_variance( torch.reshape(loss.value_targets, [-1]), torch.reshape(values_batched, [-1]) @@ -349,6 +346,25 @@ def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: } ) + @override(TorchPolicyV2) + def postprocess_trajectory( + self, + sample_batch: SampleBatch, + other_agent_batches: Optional[SampleBatch] = None, + episode: Optional["Episode"] = None, + ): + # Call super's postprocess_trajectory first. + # sample_batch = super().postprocess_trajectory( + # sample_batch, other_agent_batches, episode + # ) + + if self.config["vtrace"]: + # Add the SampleBatch.VALUES_BOOTSTRAPPED column, which we'll need + # inside the loss for vtrace calculations. 
+ sample_batch = compute_bootstrap_value(sample_batch, self) + + return sample_batch + @override(TorchPolicyV2) def extra_grad_process( self, optimizer: "torch.optim.Optimizer", loss: TensorType diff --git a/rllib/algorithms/impala/tests/test_impala_learner.py b/rllib/algorithms/impala/tests/test_impala_learner.py index 7f130ca939976..1e3672074c37f 100644 --- a/rllib/algorithms/impala/tests/test_impala_learner.py +++ b/rllib/algorithms/impala/tests/test_impala_learner.py @@ -31,6 +31,9 @@ SampleBatch.VF_PREDS: np.array( list(reversed(range(frag_length))), dtype=np.float32 ), + SampleBatch.VALUES_BOOTSTRAPPED: np.array( + list(reversed(range(frag_length))), dtype=np.float32 + ), SampleBatch.ACTION_LOGP: np.log( np.random.uniform(low=0, high=1, size=(frag_length,)) ).astype(np.float32), diff --git a/rllib/algorithms/impala/tf/impala_tf_learner.py b/rllib/algorithms/impala/tf/impala_tf_learner.py index 87ddc2cd938fc..03a960c691e3a 100644 --- a/rllib/algorithms/impala/tf/impala_tf_learner.py +++ b/rllib/algorithms/impala/tf/impala_tf_learner.py @@ -48,17 +48,22 @@ def compute_loss_for_module( trajectory_len=hps.rollout_frag_or_episode_len, recurrent_seq_len=hps.recurrent_seq_len, ) + rewards_time_major = make_time_major( + batch[SampleBatch.REWARDS], + trajectory_len=hps.rollout_frag_or_episode_len, + recurrent_seq_len=hps.recurrent_seq_len, + ) values_time_major = make_time_major( values, trajectory_len=hps.rollout_frag_or_episode_len, recurrent_seq_len=hps.recurrent_seq_len, ) - bootstrap_value = values_time_major[-1] - rewards_time_major = make_time_major( - batch[SampleBatch.REWARDS], + bootstrap_values_time_major = make_time_major( + batch[SampleBatch.VALUES_BOOTSTRAPPED], trajectory_len=hps.rollout_frag_or_episode_len, recurrent_seq_len=hps.recurrent_seq_len, ) + bootstrap_value = bootstrap_values_time_major[-1] # the discount factor that is used should be gamma except for timesteps where # the episode is terminated. In that case, the discount factor should be 0. diff --git a/rllib/algorithms/impala/tf/vtrace_tf_v2.py b/rllib/algorithms/impala/tf/vtrace_tf_v2.py index 5f878ddbc1c19..4aecc6c655490 100644 --- a/rllib/algorithms/impala/tf/vtrace_tf_v2.py +++ b/rllib/algorithms/impala/tf/vtrace_tf_v2.py @@ -9,7 +9,6 @@ def make_time_major( *, trajectory_len: int = None, recurrent_seq_len: int = None, - drop_last: bool = False, ): """Swaps batch and trajectory axis. @@ -21,8 +20,6 @@ def make_time_major( If None then `recurrent_seq_len` must be set. recurrent_seq_len: Sequence lengths if recurrent. If None then `trajectory_len` must be set. - drop_last: A bool indicating whether to drop the last - trajectory item. Note: Either `trajectory_len` or `recurrent_seq_len` must be set. 
`trajectory_len` should be used in cases where tensor is not produced from a @@ -33,7 +30,7 @@ def make_time_major( """ if isinstance(tensor, list): return [ - make_time_major(_tensor, trajectory_len, recurrent_seq_len, drop_last) + make_time_major(_tensor, trajectory_len, recurrent_seq_len) for _tensor in tensor ] @@ -52,8 +49,6 @@ def make_time_major( # swap B and T axes res = tf.transpose(rs, [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0])))) - if drop_last: - return res[:-1] return res diff --git a/rllib/algorithms/impala/torch/impala_torch_learner.py b/rllib/algorithms/impala/torch/impala_torch_learner.py index 9a5d25f32ed67..8aa7b8b3d07d3 100644 --- a/rllib/algorithms/impala/torch/impala_torch_learner.py +++ b/rllib/algorithms/impala/torch/impala_torch_learner.py @@ -58,17 +58,22 @@ def compute_loss_for_module( trajectory_len=hps.rollout_frag_or_episode_len, recurrent_seq_len=hps.recurrent_seq_len, ) + rewards_time_major = make_time_major( + batch[SampleBatch.REWARDS], + trajectory_len=hps.rollout_frag_or_episode_len, + recurrent_seq_len=hps.recurrent_seq_len, + ) values_time_major = make_time_major( values, trajectory_len=hps.rollout_frag_or_episode_len, recurrent_seq_len=hps.recurrent_seq_len, ) - bootstrap_value = values_time_major[-1] - rewards_time_major = make_time_major( - batch[SampleBatch.REWARDS], + bootstrap_values_time_major = make_time_major( + batch[SampleBatch.VALUES_BOOTSTRAPPED], trajectory_len=hps.rollout_frag_or_episode_len, recurrent_seq_len=hps.recurrent_seq_len, ) + bootstrap_value = bootstrap_values_time_major[-1] # the discount factor that is used should be gamma except for timesteps where # the episode is terminated. In that case, the discount factor should be 0. diff --git a/rllib/algorithms/impala/torch/vtrace_torch_v2.py b/rllib/algorithms/impala/torch/vtrace_torch_v2.py index 1f4b6f9411faf..83ba8879d9558 100644 --- a/rllib/algorithms/impala/torch/vtrace_torch_v2.py +++ b/rllib/algorithms/impala/torch/vtrace_torch_v2.py @@ -9,7 +9,6 @@ def make_time_major( *, trajectory_len: int = None, recurrent_seq_len: int = None, - drop_last: bool = False, ): """Swaps batch and trajectory axis. @@ -21,8 +20,6 @@ def make_time_major( If None then `recurrent_seq_len` must be set. recurrent_seq_len: Sequence lengths if recurrent. If None then `trajectory_len` must be set. - drop_last: A bool indicating whether to drop the last - trajectory item. Returns: res: A tensor with swapped axes or a list of tensors with @@ -30,7 +27,7 @@ def make_time_major( """ if isinstance(tensor, (list, tuple)): return [ - make_time_major(_tensor, trajectory_len, recurrent_seq_len, drop_last) + make_time_major(_tensor, trajectory_len, recurrent_seq_len) for _tensor in tensor ] @@ -53,8 +50,6 @@ def make_time_major( # Swap B and T axes. res = torch.transpose(rs, 1, 0) - if drop_last: - return res[:-1] return res diff --git a/rllib/evaluation/postprocessing.py b/rllib/evaluation/postprocessing.py index 21355cefd4550..525d375a67afc 100644 --- a/rllib/evaluation/postprocessing.py +++ b/rllib/evaluation/postprocessing.py @@ -179,10 +179,64 @@ def compute_gae_for_sample_batch( Returns: The postprocessed, modified SampleBatch (or a new one). """ + # Compute the SampleBatch.VALUES_BOOTSTRAPPED column, which we'll need for the + # following `last_r` arg in `compute_advantages()`. + sample_batch = compute_bootstrap_value(sample_batch, policy) + # Adds the policy logits, VF preds, and advantages to the batch, + # using GAE ("generalized advantage estimation") or not. 
+    batch = compute_advantages(
+        rollout=sample_batch,
+        last_r=sample_batch[SampleBatch.VALUES_BOOTSTRAPPED][-1],
+        gamma=policy.config["gamma"],
+        lambda_=policy.config["lambda"],
+        use_gae=policy.config["use_gae"],
+        use_critic=policy.config.get("use_critic", True),
+    )
+
+    return batch
+
+
+@DeveloperAPI
+def compute_bootstrap_value(sample_batch: SampleBatch, policy: Policy) -> SampleBatch:
+    """Performs a value function computation at the end of a trajectory.
+
+    If the trajectory is terminated (not truncated), will not use the value function,
+    but assume that the value of the last timestep is 0.0.
+    In all other cases, will use the given policy's value function to compute the
+    "bootstrapped" value estimate at the end of the given trajectory. To do so, the
+    very last observation (sample_batch[NEXT_OBS][-1]) and - if applicable -
+    the very last state output (sample_batch[STATE_OUT][-1]) will be used as inputs to
+    the value function.
+
+    The thus computed value estimate will be stored in a new column of the
+    `sample_batch`: SampleBatch.VALUES_BOOTSTRAPPED. This column holds the VF_PREDS
+    shifted left by one timestep, except for the last timestep, which receives the
+    computed bootstrapped value.
+    This is done so that any loss function (which processes raw, intact
+    trajectories, such as those of IMPALA and APPO) can use this new column as follows:
+
+    Example: numbers=ts in episode, '.'=value function prediction (VF_PREDS),
+    X=bootstrapped value (!= 0.0 b/c ts=12 is not a terminal, only truncated).
+    The trajectory is truncated (not terminated) after ts=12.
+    T:                     8   9  10  11  12 <- no terminal
+    VF_PREDS:              .   .   .   .   .
+    VALUES_BOOTSTRAPPED:   .   .   .   .   X
+
+    Args:
+        sample_batch: The SampleBatch (single trajectory) for which to compute the
+            bootstrap value at the end. This SampleBatch will be altered in place
+            (by adding a new column: SampleBatch.VALUES_BOOTSTRAPPED).
+        policy: The Policy object, whose value function to use.
+
+    Returns:
+        The altered SampleBatch (with the extra SampleBatch.VALUES_BOOTSTRAPPED
+        column).
+    """
     # Trajectory is actually complete -> last r=0.0.
     if sample_batch[SampleBatch.TERMINATEDS][-1]:
         last_r = 0.0
+
     # Trajectory has been truncated -> last r=VF estimate of last obs.
     else:
         # Input dict is provided to us automatically via the Model's
@@ -192,7 +246,6 @@ def compute_gae_for_sample_batch(
         input_dict = sample_batch.get_single_step_input_dict(
             policy.model.view_requirements, index="last"
         )
-
         if policy.config.get("_enable_rl_module_api"):
             # Note: During sampling you are using the parameters at the beginning of
             # the sampling process. If I'll be using this advantages during training
@@ -217,18 +270,18 @@ def compute_gae_for_sample_batch(
         else:
             last_r = policy._value(**input_dict)
 
-    # Adds the policy logits, VF preds, and advantages to the batch,
-    # using GAE ("generalized advantage estimation") or not.
-    batch = compute_advantages(
-        sample_batch,
-        last_r,
-        policy.config["gamma"],
-        policy.config["lambda"],
-        use_gae=policy.config["use_gae"],
-        use_critic=policy.config.get("use_critic", True),
+    # Set the SampleBatch.VALUES_BOOTSTRAPPED field to VF_PREDS[1:] + the
+    # very last timestep (where this bootstrapping value is actually needed), which
+    # we set to the computed `last_r`.
+ sample_batch[SampleBatch.VALUES_BOOTSTRAPPED] = np.concatenate( + [ + convert_to_numpy(sample_batch[SampleBatch.VF_PREDS][1:]), + np.array([convert_to_numpy(last_r)], dtype=np.float32), + ], + axis=0, ) - return batch + return sample_batch @DeveloperAPI diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py index 6bcf6225444a4..0a7d76e6959f8 100644 --- a/rllib/evaluation/tests/test_trajectory_view_api.py +++ b/rllib/evaluation/tests/test_trajectory_view_api.py @@ -129,7 +129,7 @@ def test_traj_view_lstm_prev_actions_and_rewards(self): view_req_policy = policy.view_requirements # 7=obs, prev-a + r, 2x state-in, 2x state-out. assert len(view_req_model) == 7, view_req_model - assert len(view_req_policy) == 22, (len(view_req_policy), view_req_policy) + assert len(view_req_policy) == 23, (len(view_req_policy), view_req_policy) for key in [ SampleBatch.OBS, SampleBatch.ACTIONS, diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index ded4f4bb741a6..34dfb1c017e6c 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -151,6 +151,10 @@ class SampleBatch(dict): # Value function predictions emitted by the behaviour policy. VF_PREDS = "vf_preds" + # Values one ts beyond the last ts taken. These are usually calculated via the value + # function network using the final observation (and in case of an RNN: the last + # returned internal state). + VALUES_BOOTSTRAPPED = "values_bootstrapped" # RE 3 # This is only computed and used when RE3 exploration strategy is enabled. @@ -162,9 +166,8 @@ class SampleBatch(dict): # Deprecated keys: - # SampleBatches must already not be constructed anymore by setting this key - # directly. Instead, the values under this key are auto-computed via the values of - # the new TERMINATEDS and TRUNCATEDS keys. + # Do not set this key directly. Instead, the values under this key are + # auto-computed via the values of the TERMINATEDS and TRUNCATEDS keys. DONES = "dones" # Use SampleBatch.OBS instead. CUR_OBS = "obs" diff --git a/rllib/policy/tf_mixins.py b/rllib/policy/tf_mixins.py index fe5e23a330e84..49ba23a0f0395 100644 --- a/rllib/policy/tf_mixins.py +++ b/rllib/policy/tf_mixins.py @@ -293,9 +293,9 @@ class ValueNetworkMixin: """ def __init__(self, config): - # When doing GAE, we need the value function estimate on the + # When doing GAE or vtrace, we need the value function estimate on the # observation. - if config["use_gae"]: + if config.get("use_gae") or config.get("vtrace"): # Input dict is provided to us automatically via the Model's # requirements. It's a single-timestep (last one in trajectory) diff --git a/rllib/policy/torch_mixins.py b/rllib/policy/torch_mixins.py index b258c1d74560f..359eb31089a59 100644 --- a/rllib/policy/torch_mixins.py +++ b/rllib/policy/torch_mixins.py @@ -126,7 +126,7 @@ class ValueNetworkMixin: def __init__(self, config): # When doing GAE, we need the value function estimate on the # observation. - if config["use_gae"]: + if config.get("use_gae") or config.get("vtrace"): # Input dict is provided to us automatically via the Model's # requirements. It's a single-timestep (last one in trajectory) # input_dict. 
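To make the core of this change concrete, here is a minimal sketch (not part of the patch) of what `compute_bootstrap_value()` writes into the new SampleBatch.VALUES_BOOTSTRAPPED column for a truncated rollout, and how the V-trace losses consume it. The numeric values are invented for illustration only.

import numpy as np

# Per-timestep value predictions (SampleBatch.VF_PREDS); made-up numbers.
vf_preds = np.array([0.9, 0.8, 0.7, 0.6], dtype=np.float32)
# Value estimate for the final NEXT_OBS of a truncated rollout.
# compute_bootstrap_value() uses 0.0 here instead if the episode terminated.
last_r = np.float32(0.55)

# Mirrors the np.concatenate() in compute_bootstrap_value() above:
# VF_PREDS shifted left by one timestep, with `last_r` appended at the end.
values_bootstrapped = np.concatenate(
    [vf_preds[1:], np.array([last_r], dtype=np.float32)], axis=0
)
assert np.allclose(values_bootstrapped, [0.8, 0.7, 0.6, 0.55])

# In the APPO/IMPALA losses, only the last entry is consumed: after
# make_time_major(), bootstrap_values_time_major[-1] becomes the V-trace
# `bootstrap_value`, replacing the old values_time_major[-1] / drop_last logic.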
diff --git a/rllib/tuned_examples/appo/stateless-cartpole-appo-vtrace.py b/rllib/tuned_examples/appo/stateless-cartpole-appo-vtrace.py new file mode 100644 index 0000000000000..e127da1e28c31 --- /dev/null +++ b/rllib/tuned_examples/appo/stateless-cartpole-appo-vtrace.py @@ -0,0 +1,29 @@ +from ray.rllib.algorithms.appo.appo import APPOConfig +from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole + + +config = ( + APPOConfig() + .environment(StatelessCartPole) + .resources(num_gpus=0) + .rollouts(num_rollout_workers=1, observation_filter="MeanStdFilter") + .training( + lr=0.0003, + num_sgd_iter=6, + vf_loss_coeff=0.01, + model={ + "fcnet_hiddens": [32], + "fcnet_activation": "linear", + "vf_share_layers": True, + "use_lstm": True, + }, + # TODO: Switch over to new stack once it supports LSTMs. + # _enable_learner_api=True, + ) + # .rl_module(_enable_rl_module_api=True) +) + +stop = { + "timesteps_total": 500000, + "sampler_results/episode_reward_mean": 150.0, +}
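As a closing, hedged sketch (also not part of the patch): this is how user configs are affected by the removal of `vtrace_drop_last_ts`. Configs that never set the flag keep working unchanged, while passing it explicitly now fails fast through the `deprecation_warning(..., error=True)` path added in impala.py above. The assumption that this surfaces as a ValueError reflects RLlib's usual deprecation helper behavior and is not stated in this diff.

from ray.rllib.algorithms.impala import ImpalaConfig

# Untouched configs behave as before; V-trace now always bootstraps from
# SampleBatch.VALUES_BOOTSTRAPPED instead of dropping the last timestep.
config = ImpalaConfig().training(vtrace=True, vtrace_clip_rho_threshold=1.0)

# Explicitly passing the removed setting now raises immediately.
try:
    ImpalaConfig().training(vtrace_drop_last_ts=True)
except ValueError as e:  # assumed error type raised by deprecation_warning(error=True)
    print(f"vtrace_drop_last_ts has been removed: {e}")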