[RLlib] New API stack: (Multi)RLModule overhaul vol 03 (Introduce generic _forward to further simplify the user experience). #47889

16 changes: 8 additions & 8 deletions rllib/algorithms/ppo/torch/ppo_torch_learner.py
@@ -40,6 +40,8 @@ def compute_loss_for_module(
batch: Dict[str, Any],
fwd_out: Dict[str, TensorType],
) -> TensorType:
module = self.module[module_id].unwrapped()

# Possibly apply masking to some sub loss terms and to the total loss term
# at the end. Masking could be used for RNN-based model (zero padded `batch`)
# and for PPO's batched value function (and bootstrap value) computations,
@@ -55,12 +57,8 @@ def possibly_masked_mean(data_):
else:
possibly_masked_mean = torch.mean

action_dist_class_train = (
self.module[module_id].unwrapped().get_train_action_dist_cls()
)
action_dist_class_exploration = (
self.module[module_id].unwrapped().get_exploration_action_dist_cls()
)
action_dist_class_train = module.get_train_action_dist_cls()
Collaborator:
Don't we need to use unwrapped in case DDP is used?

Contributor Author:
Great question! DDP already wraps this method to use the unwrapped underlying RLModule, so this is ok here.

action_dist_class_exploration = module.get_exploration_action_dist_cls()
@smanolloff (Contributor), Oct 24, 2024:
Hey, I have a question here: shouldn't the exploration or inference dist be used, similar to the GetActions connector's logic?

This affects the KL loss calculation, which might end up using a different distribution class (exploration_dist) than the one used for the surrogate loss (inference_dist). It is somewhat of an edge case, since the two are actually the same in TorchRLModule, but users sub-classing it would be unaware.
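To make the edge case concrete, here is a minimal, hypothetical sketch (not part of this PR) of a subclass whose train and exploration distribution classes differ. The loss above fetches the two classes separately, so a subclass like this would run into exactly the mismatch described in the comment:

```python
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.models.torch.torch_distributions import TorchCategorical


class MyExplorationDist(TorchCategorical):
    """Stand-in for a user-defined exploration distribution (hypothetical)."""


class MyPPOTorchRLModule(PPOTorchRLModule):
    def get_train_action_dist_cls(self):
        # The learner builds `curr_action_dist` from this class (see the
        # surrogate-loss code below).
        return TorchCategorical

    def get_exploration_action_dist_cls(self):
        # A different class than the train one -- this is the kind of
        # mismatch the comment above is concerned about.
        return MyExplorationDist
```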


curr_action_dist = action_dist_class_train.from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
@@ -91,12 +89,14 @@ def possibly_masked_mean(data_):

# Compute a value function loss.
if config.use_critic:
value_fn_out = fwd_out[Columns.VF_PREDS]
value_fn_out = module.compute_values(
Collaborator:
I wonder if this again causes problems in the DDP case. I remember similar problems with CQL and SAC when not running everything in forward_train, but I guess the problem there was that forward_train was run multiple times. So my guess: this works here.

Contributor Author:
Yeah, good point, I think you are right. Let's see what the tests say ...

batch, features=fwd_out.get(Columns.FEATURES)
)
vf_loss = torch.pow(value_fn_out - batch[Postprocessing.VALUE_TARGETS], 2.0)
vf_loss_clipped = torch.clamp(vf_loss, 0, config.vf_clip_param)
mean_vf_loss = possibly_masked_mean(vf_loss_clipped)
mean_vf_unclipped_loss = possibly_masked_mean(vf_loss)
# Ignore the value function.
# Ignore the value function -> Set all to 0.0.
else:
z = torch.tensor(0.0, device=surrogate_loss.device)
value_fn_out = mean_vf_unclipped_loss = vf_loss_clipped = mean_vf_loss = z
69 changes: 28 additions & 41 deletions rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional

from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule
from ray.rllib.core.columns import Columns
@@ -17,63 +17,50 @@ class PPOTorchRLModule(TorchRLModule, PPORLModule):
framework: str = "torch"

@override(RLModule)
def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""Default forward pass (used for inference and exploration)."""
output = {}

# Encoder forward pass.
encoder_outs = self.encoder(batch)
# Stateful encoder?
if Columns.STATE_OUT in encoder_outs:
output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]

# Pi head.
output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR])

return output

@override(RLModule)
def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
return self._forward_inference(batch)

@override(RLModule)
def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
if self.config.inference_only:
raise RuntimeError(
"Trying to train a module that is not a learner module. Set the "
"flag `inference_only=False` when building the module."
)
def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""Train forward pass (keep features for possible shared value func. call)."""
output = {}

# Shared encoder.
encoder_outs = self.encoder(batch)
output[Columns.FEATURES] = encoder_outs[ENCODER_OUT][CRITIC]
Collaborator:
IMO, features is a misleading term here, as features are usually the inputs to a neural network (or a model in general); embeddings might fit better.

Contributor Author:
You are right!
Changed everywhere to Columns.EMBEDDINGS and argument name: compute_values(self, batch, embedding=None).

if Columns.STATE_OUT in encoder_outs:
output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]

# Value head.
vf_out = self.vf(encoder_outs[ENCODER_OUT][CRITIC])
# Squeeze out last dim (value function node).
output[Columns.VF_PREDS] = vf_out.squeeze(-1)

# Policy head.
action_logits = self.pi(encoder_outs[ENCODER_OUT][ACTOR])
output[Columns.ACTION_DIST_INPUTS] = action_logits

output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR])
return output

@override(ValueFunctionAPI)
def compute_values(self, batch: Dict[str, Any]) -> TensorType:
# Separate vf-encoder.
if hasattr(self.encoder, "critic_encoder"):
batch_ = batch
if self.is_stateful():
# The recurrent encoders expect a `(state_in, h)` key in the
# input dict while the key returned is `(state_in, critic, h)`.
batch_ = batch.copy()
batch_[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC]
encoder_outs = self.encoder.critic_encoder(batch_)[ENCODER_OUT]
# Shared encoder.
else:
encoder_outs = self.encoder(batch)[ENCODER_OUT][CRITIC]
def compute_values(
self,
batch: Dict[str, Any],
features: Optional[Any] = None,
) -> TensorType:
if features is None:
Collaborator:
Why not put features into the batch instead of passing them in as an extra argument?

Contributor Author:
Good question. This would mean we would have to change the batch (add a new key to it) during the update procedure, which might clash when we have to (torch-)compile this operation. We had the same problem with the tf static graph.
Also, design-wise, I think it's cleaner not to change the batch after it comes out of a connector pipeline. Separation of concerns: Only connector pipelines are ever allowed to write to a batch:

connector -> train_batch  # <- read-only from here on
fwd_out = rl_module.forward_train(train_batch)
losses = rl_module.compute_losses(train_batch, fwd_out)
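For concreteness, a minimal sketch of the call pattern described here, using the names introduced in this diff (the small helper function itself is hypothetical):

```python
from ray.rllib.core.columns import Columns


def ppo_value_preds(module, train_batch, fwd_out):
    # `train_batch` came out of the connector pipeline and is treated as
    # read-only from here on. The encoder output produced by `_forward_train()`
    # travels through `fwd_out` and is handed to `compute_values()` explicitly,
    # so a shared encoder does not have to be run a second time.
    return module.compute_values(
        train_batch,
        features=fwd_out.get(Columns.FEATURES),
    )
```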

# Separate vf-encoder.
if hasattr(self.encoder, "critic_encoder"):
batch_ = batch
if self.is_stateful():
# The recurrent encoders expect a `(state_in, h)` key in the
# input dict while the key returned is `(state_in, critic, h)`.
batch_ = batch.copy()
batch_[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC]
features = self.encoder.critic_encoder(batch_)[ENCODER_OUT]
# Shared encoder.
else:
features = self.encoder(batch)[ENCODER_OUT][CRITIC]

# Value head.
vf_out = self.vf(encoder_outs)
vf_out = self.vf(features)
# Squeeze out last dimension (single node value head).
return vf_out.squeeze(-1)
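Taken together, the changes in this file show the pattern the PR title refers to: one generic `_forward()` for inference and exploration, a `_forward_train()` that keeps the encoder output around, and a `compute_values()` that can reuse it. Below is a rough, hypothetical sketch of a user-defined module following this pattern (hard-coded dimensions; names as in this diff, keeping in mind that, per the review thread above, Columns.FEATURES was later renamed to Columns.EMBEDDINGS):

```python
import torch

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule


class TinyValueModule(TorchRLModule, ValueFunctionAPI):
    def setup(self):
        # Dimensions hard-coded for brevity; real code would derive them from
        # the module's observation and action spaces.
        self._encoder = torch.nn.Linear(4, 64)
        self._pi = torch.nn.Linear(64, 2)
        self._vf = torch.nn.Linear(64, 1)

    def _forward(self, batch, **kwargs):
        # Single default forward pass; with this PR, inference and exploration
        # both route here unless overridden.
        features = torch.relu(self._encoder(batch[Columns.OBS]))
        return {Columns.ACTION_DIST_INPUTS: self._pi(features)}

    def _forward_train(self, batch, **kwargs):
        # Keep the encoder output so the learner can pass it back into
        # `compute_values()` and avoid a second encoder pass.
        features = torch.relu(self._encoder(batch[Columns.OBS]))
        return {
            Columns.FEATURES: features,
            Columns.ACTION_DIST_INPUTS: self._pi(features),
        }

    def compute_values(self, batch, features=None):
        if features is None:
            features = torch.relu(self._encoder(batch[Columns.OBS]))
        return self._vf(features).squeeze(-1)
```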
1 change: 1 addition & 0 deletions rllib/core/columns.py
@@ -44,6 +44,7 @@ class Columns:
# Common extra RLModule output keys.
STATE_IN = "state_in"
STATE_OUT = "state_out"
FEATURES = "features"
ACTION_DIST_INPUTS = "action_dist_inputs"
ACTION_PROB = "action_prob"
ACTION_LOGP = "action_logp"
4 changes: 2 additions & 2 deletions rllib/core/learner/learner.py
@@ -691,7 +691,7 @@ def filter_param_dict_for_optimizer(
def get_param_ref(self, param: Param) -> Hashable:
"""Returns a hashable reference to a trainable parameter.

This should be overriden in framework specific specialization. For example in
This should be overridden in framework specific specialization. For example in
torch it will return the parameter itself, while in tf it returns the .ref() of
the variable. The purpose is to retrieve a unique reference to the parameters.

@@ -706,7 +706,7 @@ def get_param_ref(self, param: Param) -> Hashable:
def get_parameters(self, module: RLModule) -> Sequence[Param]:
"""Returns the list of parameters of a module.

This should be overriden in framework specific learner. For example in torch it
This should be overridden in framework specific learner. For example in torch it
will return .parameters(), while in tf it returns .trainable_variables.

Args:
16 changes: 13 additions & 3 deletions rllib/core/rl_module/apis/value_function_api.py
@@ -1,20 +1,30 @@
import abc
from typing import Any, Dict
from typing import Any, Dict, Optional

from ray.rllib.utils.typing import TensorType


class ValueFunctionAPI(abc.ABC):
"""An API to be implemented by RLModules for handling value function-based learning.

RLModules implementing this API must override the `compute_values` method."""
RLModules implementing this API must override the `compute_values` method.
"""

@abc.abstractmethod
def compute_values(self, batch: Dict[str, Any]) -> TensorType:
def compute_values(
self,
batch: Dict[str, Any],
features: Optional[Any] = None,
) -> TensorType:
"""Computes the value estimates given `batch`.

Args:
batch: The batch to compute value function estimates for.
features: Optional features already computed from the `batch` (by a prior
forward pass through the model's encoder or another feature-computing
subcomponent). For example, the caller of this method should provide
`features` - if available - to avoid duplicate passes through a shared
encoder.

Returns:
A tensor of shape (B,) or (B, T) (in case the input `batch` has a